1pub mod macros;
2pub mod modbus;
3
4use crate::modbus::ModbusDriver;
5use anyhow::{anyhow, Error};
6use async_trait::async_trait;
7use libloading::{Library, Symbol};
8use node_grove_common_api::ng_proto::DriverMetadata;
9use prost::Message;
10use std::any::Any;
11use std::collections::HashMap;
12use std::path::Path;
13use std::slice;
14use std::sync::Arc;
15use tokio::fs;
16use tokio::sync::{Mutex, OnceCell};
17use tracing::{info, warn};
18
/// A device driver that can be registered with [`DriverManager`].
///
/// Drivers are either built in (registered by `DriverManager::init`) or loaded at
/// runtime from a shared library that exports a `driver` symbol returning
/// `Arc<dyn NGDriver>`.
///
/// NOTE(review): because custom drivers are compiled separately and cross the
/// dynamic-library boundary as a trait object, the method order and signatures of
/// this trait are effectively part of the plugin ABI; changing them presumably
/// requires rebuilding every plugin — confirm before editing.
#[async_trait]
pub trait NGDriver: Send + Sync + Any {
    /// Called by `DriverManager::register_driver` before the driver is stored; an
    /// error here prevents the driver from being registered.
    async fn initialize(&self) -> Result<(), Error>;

    /// Runs the driver. Not invoked in this file; whether it is long-running is up to
    /// the implementation (TODO confirm with callers).
    async fn run(&self) -> Result<(), Error>;

    /// Executes the driver-specific `command` for the device identified by `device_id`.
    ///
    /// `params` and the returned value are type-erased; their concrete types are a
    /// contract between caller and driver (presumably mediated by the driver's
    /// [`NGConverter`] — TODO confirm).
    async fn execute_command(
        &self,
        device_id: i32,
        command: &str,
        params: Box<dyn Any + Send>,
    ) -> Result<Box<dyn Any + Send>, Error>;

    /// Shuts the driver down. Not invoked in this file.
    async fn shutdown(&self) -> Result<(), Error>;

    /// Returns a converter for this driver's type-erased values.
    fn converter(&self) -> Box<dyn NGConverter>;
}
36
/// Converts type-erased values on behalf of a driver (obtained via `NGDriver::converter`).
///
/// NOTE(review): which side "in" and "out" refer to (e.g. external representation vs.
/// driver-native representation) is not established in this file — confirm against
/// implementations. Like `NGDriver`, this trait object may be produced by plugin code,
/// so its method order/signatures are presumably part of the plugin ABI.
pub trait NGConverter {
    /// Converts an incoming type-erased value; returns an error on failure.
    fn convert_in(&self, input: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, Error>;
    /// Converts an outgoing type-erased value; returns an error on failure.
    fn convert_out(&self, input: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, Error>;
}
41
/// Registry of named drivers, both built-in and loaded from shared libraries.
///
/// A single process-wide instance is created by `DriverManager::init` and retrieved
/// with `DriverManager::instance`.
pub struct DriverManager {
    /// Registered drivers keyed by name, behind an async (tokio) mutex.
    drivers: Arc<Mutex<HashMap<String, Arc<dyn NGDriver>>>>,
}
45
/// Process-wide driver manager; set once by `DriverManager::init` and read by
/// `DriverManager::instance`.
static INSTANCE: OnceCell<Arc<Mutex<DriverManager>>> = OnceCell::const_new();
47
48impl DriverManager {
49 pub fn instance() -> Result<Arc<Mutex<Self>>, Error> {
50 INSTANCE
51 .get()
52 .ok_or_else(|| anyhow!("NGControlCenter is not initialized"))
53 .cloned()
54 }
55
56 pub async fn init(extension_paths: &[String]) -> Result<(), Error> {
57 let mut dm = DriverManager {
58 drivers: Arc::new(Mutex::new(HashMap::new())),
59 };
60 dm.register_builtin_drivers().await?;
61 dm.load_custom_drivers(extension_paths).await;
62 INSTANCE
63 .set(Arc::new(Mutex::new(dm)))
64 .map_err(|_| anyhow!("Failed to initialize driver manager"))
65 }
66
67 async fn register_builtin_drivers(&mut self) -> Result<(), Error> {
69 self.register_driver("modbus", Arc::new(ModbusDriver::new()))
70 .await?;
71 Ok(())
72 }
73
74 pub async fn register_driver(
76 &mut self,
77 name: &str,
78 driver: Arc<dyn NGDriver>,
79 ) -> Result<(), Error> {
80 let mut drivers = self.drivers.lock().await;
81
82 if drivers.contains_key(name) {
83 return Err(anyhow!(format!("Driver '{}' already exists", name)));
84 }
85
86 driver.initialize().await.map_err(|err| {
87 anyhow!(format!(
88 "driver initialize failed with: {}",
89 err.to_string()
90 ))
91 })?;
92 info!("Driver '{}' registered", name);
93 drivers.insert(name.to_string(), driver.clone());
94 Ok(())
95 }
96
97 pub async fn get_driver(&self, name: &str) -> Option<Arc<dyn NGDriver>> {
98 let drivers = Arc::clone(&self.drivers);
99 let driver = drivers.lock().await.get(name).cloned();
100 driver
101 }
102
103 async fn load_custom_drivers(&mut self, paths: &[String]) {
104 let lib_extension = if cfg!(target_os = "macos") {
105 "dylib"
106 } else if cfg!(target_os = "linux") {
107 "so"
108 } else {
109 warn!("Unsupported OS to load custom drivers");
110 return;
111 };
112
113 for path in paths {
114 let path = Path::new(path);
115
116 if path.is_dir() {
117 let mut entries = fs::read_dir(path).await.unwrap();
119 while let Some(entry) = entries.next_entry().await.unwrap() {
120 let file_path = entry.path();
121 if file_path.is_file()
122 && file_path
123 .extension()
124 .map_or(false, |ext| ext == lib_extension)
125 {
126 match self.load_library(&file_path).await {
127 Err(err) => {
128 warn!(
129 "Failed to load custom driver by path: {}, {}",
130 file_path.display(),
131 err
132 );
133 }
134 _ => {}
135 }
136 }
137 }
138 } else if path.is_file() && path.extension().map_or(false, |ext| ext == lib_extension) {
139 match self.load_library(&path).await {
140 Err(err) => {
141 warn!(
142 "Failed to load custom driver by path: {}, {}",
143 path.display(),
144 err
145 );
146 }
147 _ => {}
148 }
149 }
150 }
151 }
152
153 #[allow(unused_unsafe)]
154 async fn load_library(&mut self, path: &Path) -> Result<(), Error> {
155 let library = unsafe { Library::new(path) }
156 .map_err(|err| anyhow!(format!("Failed to load library: {:?}", err)))?;
157 let driver_symbol: Symbol<fn() -> Arc<dyn NGDriver>> = unsafe { library.get(b"driver") }
158 .map_err(|err| anyhow!(format!("Failed to load driver symbol: {:?}", err)))?;
159 let metadata_symbol: Symbol<fn() -> *const u8> = unsafe { library.get(b"metadata") }
160 .map_err(|err| anyhow!(format!("Failed to load metadata symbol: {:?}", err)))?;
161 let metadata_len_symbol: Symbol<fn() -> usize> = unsafe { library.get(b"metadata_length") }
162 .map_err(|err| anyhow!(format!("Failed to load metadata_length symbol: {:?}", err)))?;
163 let free_symbol: Symbol<fn()> = unsafe { library.get(b"free_metadata") }
164 .map_err(|err| anyhow!(format!("Failed to load free symbol: {:?}", err)))?;
165
166 let raw_metadata = unsafe { metadata_symbol() };
167 let len = unsafe { metadata_len_symbol() };
168
169 let metadata_slice = unsafe { slice::from_raw_parts(raw_metadata, len) };
170
171 let metadata = DriverMetadata::decode(metadata_slice)?;
172 let driver = driver_symbol();
173 if !raw_metadata.is_null() {
174 unsafe {
175 free_symbol();
176 }
177 }
178 self.register_driver(metadata.name.as_str(), driver).await?;
179 Ok(())
180 }
181}