ng_gateway_driver/
lib.rs

1pub mod macros;
2pub mod modbus;
3
4use crate::modbus::ModbusDriver;
5use anyhow::{anyhow, Error};
6use async_trait::async_trait;
7use libloading::{Library, Symbol};
8use node_grove_common_api::ng_proto::DriverMetadata;
9use prost::Message;
10use std::any::Any;
11use std::collections::HashMap;
12use std::path::Path;
13use std::slice;
14use std::sync::Arc;
15use tokio::fs;
16use tokio::sync::{Mutex, OnceCell};
17use tracing::{info, warn};
18
/// Interface implemented by every driver, whether built in (e.g. the Modbus
/// driver) or loaded from a custom driver library.
///
/// Custom driver libraries hand out an `Arc<dyn NGDriver>` across a
/// dynamic-library boundary. Any change to this trait therefore also requires
/// rebuilding those libraries.
#[async_trait]
pub trait NGDriver: Send + Sync + Any {
    /// Called by `DriverManager::register_driver` before the driver is stored.
    /// An error aborts the registration and the driver is not stored.
    async fn initialize(&self) -> Result<(), Error>;

    /// NOTE(review): not invoked in this file. Presumably starts the driver's
    /// main work after registration; confirm with callers.
    async fn run(&self) -> Result<(), Error>;

    /// Executes `command` against the device identified by `device_id`.
    ///
    /// `params` and the returned value are type-erased (`Box<dyn Any + Send>`).
    /// NOTE(review): the concrete types are driver-specific and not defined
    /// here; presumably each driver downcasts to its own types. Confirm per driver.
    async fn execute_command(
        &self,
        device_id: i32,
        command: &str,
        params: Box<dyn Any + Send>,
    ) -> Result<Box<dyn Any + Send>, Error>;

    /// NOTE(review): not invoked in this file. Presumably releases the driver's
    /// resources; confirm with callers.
    async fn shutdown(&self) -> Result<(), Error>;

    /// Returns this driver's value converter. The returned box is not `Send`,
    /// because `NGConverter` has no `Send` bound.
    fn converter(&self) -> Box<dyn NGConverter>;
}
36
/// Conversion hooks a driver exposes through `NGDriver::converter`.
///
/// Both directions operate on type-erased values. The trait declares no
/// `Send`/`Sync` supertraits, so a `Box<dyn NGConverter>` cannot be sent
/// across threads.
pub trait NGConverter {
    /// NOTE(review): by its name, converts a value entering the driver; the
    /// concrete input/output types are driver-specific. Confirm per driver.
    fn convert_in(&self, input: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, Error>;
    /// NOTE(review): by its name, converts a value leaving the driver; the
    /// concrete input/output types are driver-specific. Confirm per driver.
    fn convert_out(&self, input: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, Error>;
}
41
/// Registry of named drivers.
///
/// One instance is created by `DriverManager::init` and shared process-wide
/// through `DriverManager::instance`.
pub struct DriverManager {
    // Registered driver name -> driver handle.
    // NOTE(review): this inner Mutex looks redundant with the outer
    // `Arc<Mutex<DriverManager>>` stored in `INSTANCE`; kept as is.
    drivers: Arc<Mutex<HashMap<String, Arc<dyn NGDriver>>>>,
}
45
/// Process-wide `DriverManager`, populated by `DriverManager::init`.
/// A `OnceCell` can only be set once, so later `init` calls fail.
static INSTANCE: OnceCell<Arc<Mutex<DriverManager>>> = OnceCell::const_new();
47
48impl DriverManager {
49    pub fn instance() -> Result<Arc<Mutex<Self>>, Error> {
50        INSTANCE
51            .get()
52            .ok_or_else(|| anyhow!("NGControlCenter is not initialized"))
53            .cloned()
54    }
55
56    pub async fn init(extension_paths: &[String]) -> Result<(), Error> {
57        let mut dm = DriverManager {
58            drivers: Arc::new(Mutex::new(HashMap::new())),
59        };
60        dm.register_builtin_drivers().await?;
61        dm.load_custom_drivers(extension_paths).await;
62        INSTANCE
63            .set(Arc::new(Mutex::new(dm)))
64            .map_err(|_| anyhow!("Failed to initialize driver manager"))
65    }
66
67    // register builtin drivers
68    async fn register_builtin_drivers(&mut self) -> Result<(), Error> {
69        self.register_driver("modbus", Arc::new(ModbusDriver::new()))
70            .await?;
71        Ok(())
72    }
73
74    // register driver(support builtin & custom)
75    pub async fn register_driver(
76        &mut self,
77        name: &str,
78        driver: Arc<dyn NGDriver>,
79    ) -> Result<(), Error> {
80        let mut drivers = self.drivers.lock().await;
81
82        if drivers.contains_key(name) {
83            return Err(anyhow!(format!("Driver '{}' already exists", name)));
84        }
85
86        driver.initialize().await.map_err(|err| {
87            anyhow!(format!(
88                "driver initialize failed with: {}",
89                err.to_string()
90            ))
91        })?;
92        info!("Driver '{}' registered", name);
93        drivers.insert(name.to_string(), driver.clone());
94        Ok(())
95    }
96
97    pub async fn get_driver(&self, name: &str) -> Option<Arc<dyn NGDriver>> {
98        let drivers = Arc::clone(&self.drivers);
99        let driver = drivers.lock().await.get(name).cloned();
100        driver
101    }
102
103    async fn load_custom_drivers(&mut self, paths: &[String]) {
104        let lib_extension = if cfg!(target_os = "macos") {
105            "dylib"
106        } else if cfg!(target_os = "linux") {
107            "so"
108        } else {
109            warn!("Unsupported OS to load custom drivers");
110            return;
111        };
112
113        for path in paths {
114            let path = Path::new(path);
115
116            if path.is_dir() {
117                // 异步读取目录
118                let mut entries = fs::read_dir(path).await.unwrap();
119                while let Some(entry) = entries.next_entry().await.unwrap() {
120                    let file_path = entry.path();
121                    if file_path.is_file()
122                        && file_path
123                            .extension()
124                            .map_or(false, |ext| ext == lib_extension)
125                    {
126                        match self.load_library(&file_path).await {
127                            Err(err) => {
128                                warn!(
129                                    "Failed to load custom driver by path: {}, {}",
130                                    file_path.display(),
131                                    err
132                                );
133                            }
134                            _ => {}
135                        }
136                    }
137                }
138            } else if path.is_file() && path.extension().map_or(false, |ext| ext == lib_extension) {
139                match self.load_library(&path).await {
140                    Err(err) => {
141                        warn!(
142                            "Failed to load custom driver by path: {}, {}",
143                            path.display(),
144                            err
145                        );
146                    }
147                    _ => {}
148                }
149            }
150        }
151    }
152
153    #[allow(unused_unsafe)]
154    async fn load_library(&mut self, path: &Path) -> Result<(), Error> {
155        let library = unsafe { Library::new(path) }
156            .map_err(|err| anyhow!(format!("Failed to load library: {:?}", err)))?;
157        let driver_symbol: Symbol<fn() -> Arc<dyn NGDriver>> = unsafe { library.get(b"driver") }
158            .map_err(|err| anyhow!(format!("Failed to load driver symbol: {:?}", err)))?;
159        let metadata_symbol: Symbol<fn() -> *const u8> = unsafe { library.get(b"metadata") }
160            .map_err(|err| anyhow!(format!("Failed to load metadata symbol: {:?}", err)))?;
161        let metadata_len_symbol: Symbol<fn() -> usize> = unsafe { library.get(b"metadata_length") }
162            .map_err(|err| anyhow!(format!("Failed to load metadata_length symbol: {:?}", err)))?;
163        let free_symbol: Symbol<fn()> = unsafe { library.get(b"free_metadata") }
164            .map_err(|err| anyhow!(format!("Failed to load free symbol: {:?}", err)))?;
165
166        let raw_metadata = unsafe { metadata_symbol() };
167        let len = unsafe { metadata_len_symbol() };
168
169        let metadata_slice = unsafe { slice::from_raw_parts(raw_metadata, len) };
170
171        let metadata = DriverMetadata::decode(metadata_slice)?;
172        let driver = driver_symbol();
173        if !raw_metadata.is_null() {
174            unsafe {
175                free_symbol();
176            }
177        }
178        self.register_driver(metadata.name.as_str(), driver).await?;
179        Ok(())
180    }
181}