// ng_gateway_driver/lib.rs

pub mod macros;
pub mod modbus;

use crate::modbus::ModbusDriver;
use anyhow::{anyhow, Error};
use async_trait::async_trait;
use libloading::{Library, Symbol};
use node_grove_common_api::ng_proto::DriverMetadata;
use prost::Message;
use std::any::Any;
use std::collections::HashMap;
use std::path::Path;
use std::slice;
use std::sync::Arc;
use tokio::fs;
use tokio::sync::{Mutex, OnceCell};
use tracing::{info, warn};

/// Interface implemented by every gateway driver.
///
/// Both the built-in `ModbusDriver` and drivers loaded from shared libraries
/// are stored and handed out as `Arc<dyn NGDriver>` by `DriverManager`.
/// The `Send + Sync` bounds allow those handles to be shared across tasks,
/// and the `Any` supertrait is part of the trait object's contract.
#[async_trait]
pub trait NGDriver: Send + Sync + Any {
    /// Prepares the driver for use.
    ///
    /// `DriverManager::register_driver` awaits this before storing the
    /// driver; returning an error aborts the registration.
    async fn initialize(&self) -> Result<(), Error>;

    /// Runs the driver.
    ///
    /// NOTE(review): not called anywhere in this file; presumably the
    /// driver's main processing entry point — confirm against implementors.
    async fn run(&self) -> Result<(), Error>;

    /// Executes `command` against the device identified by `device_id`.
    ///
    /// `params` and the successful return value are type-erased
    /// (`Box<dyn Any + Send>`); their concrete types are defined by each
    /// driver implementation and are not visible from this file.
    async fn execute_command(
        &self,
        device_id: i32,
        command: &str,
        params: Box<dyn Any + Send>,
    ) -> Result<Box<dyn Any + Send>, Error>;

    /// Shuts the driver down.
    ///
    /// NOTE(review): not called anywhere in this file; presumably releases
    /// whatever `initialize`/`run` acquired — confirm against implementors.
    async fn shutdown(&self) -> Result<(), Error>;

    /// Returns a boxed [`NGConverter`] associated with this driver.
    fn converter(&self) -> Box<dyn NGConverter>;
}

/// Converts type-erased values on behalf of a driver.
///
/// Instances are obtained from `NGDriver::converter`. Both directions take
/// and return `Box<dyn Any + Send>`, so the concrete input and output types
/// are an agreement between each driver and its callers.
pub trait NGConverter {
    /// Converts an incoming value.
    ///
    /// NOTE(review): presumably gateway-side data into the driver's own
    /// representation — confirm against implementors.
    fn convert_in(&self, input: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, Error>;
    /// Converts an outgoing value.
    ///
    /// NOTE(review): presumably the inverse direction of `convert_in` —
    /// confirm against implementors.
    fn convert_out(&self, input: Box<dyn Any + Send>) -> Result<Box<dyn Any + Send>, Error>;
}

/// Registry of gateway drivers, keyed by driver name.
///
/// Holds the built-in drivers plus any custom drivers loaded from shared
/// libraries during `DriverManager::init`.
pub struct DriverManager {
    /// Registered drivers by name, behind an async mutex so lookups and
    /// registrations can be awaited.
    drivers: Arc<Mutex<HashMap<String, Arc<dyn NGDriver>>>>,
}

/// Process-wide `DriverManager`, set once by `DriverManager::init` and read
/// by `DriverManager::instance`.
static INSTANCE: OnceCell<Arc<Mutex<DriverManager>>> = OnceCell::const_new();

impl DriverManager {
    /// Returns a shared handle to the global driver manager.
    ///
    /// # Errors
    ///
    /// Returns an error when `DriverManager::init` has not completed
    /// successfully yet.
    pub fn instance() -> Result<Arc<Mutex<Self>>, Error> {
        INSTANCE
            .get()
            .cloned()
            // Name this type in the message so the failure points at the
            // component that actually needs initializing.
            .ok_or_else(|| anyhow!("DriverManager is not initialized"))
    }

    /// Builds the global driver manager and publishes it for
    /// `DriverManager::instance`.
    ///
    /// Built-in drivers are registered first; then every entry of
    /// `extension_paths` (a shared-library file or a directory of them) is
    /// scanned for custom drivers. Failures to load a custom driver are
    /// logged and skipped, while a failing built-in driver aborts
    /// initialization.
    ///
    /// # Errors
    ///
    /// Returns an error if the manager is already initialized or if a
    /// built-in driver cannot be registered.
    pub async fn init(extension_paths: &[String]) -> Result<(), Error> {
        // Bail out before touching any driver: registering runs each driver's
        // `initialize` hook and loading plugins maps shared libraries, neither
        // of which should happen again on a repeated call.
        if INSTANCE.initialized() {
            return Err(anyhow!("Driver manager is already initialized"));
        }

        let mut manager = DriverManager {
            drivers: Arc::new(Mutex::new(HashMap::new())),
        };
        manager.register_builtin_drivers().await?;
        manager.load_custom_drivers(extension_paths).await;

        // `set` still guards against two concurrent `init` calls racing past
        // the check above.
        INSTANCE
            .set(Arc::new(Mutex::new(manager)))
            .map_err(|_| anyhow!("Failed to initialize driver manager"))
    }

    // register builtin drivers
    async fn register_builtin_drivers(&mut self) -> Result<(), Error> {
        self.register_driver("modbus", Arc::new(ModbusDriver::new()))
            .await?;
        Ok(())
    }

    /// Registers `driver` under `name`; used for built-in and custom drivers
    /// alike.
    ///
    /// The driver's `initialize` hook is awaited before the driver is stored,
    /// so a driver that fails to initialize never becomes visible through
    /// `get_driver`.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` is already registered (the new driver is
    /// not initialized in that case) or if `initialize` fails.
    pub async fn register_driver(
        &mut self,
        name: &str,
        driver: Arc<dyn NGDriver>,
    ) -> Result<(), Error> {
        // The guard is kept across `initialize` so the duplicate check and the
        // insert below act on the same view of the registry.
        let mut drivers = self.drivers.lock().await;

        if drivers.contains_key(name) {
            return Err(anyhow!("Driver '{}' already exists", name));
        }

        driver
            .initialize()
            .await
            .map_err(|err| anyhow!("driver initialize failed with: {}", err))?;

        info!("Driver '{}' registered", name);
        // `driver` is owned here, so it can be moved into the map directly.
        drivers.insert(name.to_string(), driver);
        Ok(())
    }

    /// Looks up a registered driver by `name`.
    ///
    /// Returns a new `Arc` handle to the driver, or `None` when nothing is
    /// registered under that name.
    pub async fn get_driver(&self, name: &str) -> Option<Arc<dyn NGDriver>> {
        let registry = self.drivers.lock().await;
        registry.get(name).map(Arc::clone)
    }

    /// Loads custom drivers from every entry of `paths`.
    ///
    /// Each entry may be a shared-library file or a directory; directories
    /// are scanned one level deep. Only files with the platform's library
    /// extension (`dylib` on macOS, `so` on Linux) are considered; on any
    /// other OS nothing is loaded.
    ///
    /// Loading is best effort: unreadable directories and libraries that fail
    /// to load are logged with `warn!` and skipped, never propagated.
    async fn load_custom_drivers(&mut self, paths: &[String]) {
        let lib_extension = if cfg!(target_os = "macos") {
            "dylib"
        } else if cfg!(target_os = "linux") {
            "so"
        } else {
            warn!("Unsupported OS to load custom drivers");
            return;
        };

        let is_driver_library = |candidate: &Path| {
            candidate.is_file()
                && candidate
                    .extension()
                    .map_or(false, |ext| ext == lib_extension)
        };

        for path in paths {
            let path = Path::new(path);

            if path.is_dir() {
                // Read the directory asynchronously. An unreadable directory
                // (e.g. permission denied) must not abort start-up, so it is
                // logged and skipped.
                let mut entries = match fs::read_dir(path).await {
                    Ok(entries) => entries,
                    Err(err) => {
                        warn!(
                            "Failed to read custom driver directory: {}, {}",
                            path.display(),
                            err
                        );
                        continue;
                    }
                };

                loop {
                    let file_path = match entries.next_entry().await {
                        Ok(Some(entry)) => entry.path(),
                        Ok(None) => break,
                        Err(err) => {
                            // Stop scanning this directory; entries already
                            // loaded stay registered.
                            warn!(
                                "Failed to list custom driver directory: {}, {}",
                                path.display(),
                                err
                            );
                            break;
                        }
                    };

                    if is_driver_library(&file_path) {
                        self.load_library_or_warn(&file_path).await;
                    }
                }
            } else if is_driver_library(path) {
                self.load_library_or_warn(path).await;
            }
        }
    }

    /// Loads the driver library at `path`, logging a warning instead of
    /// failing when it cannot be loaded or registered.
    async fn load_library_or_warn(&mut self, path: &Path) {
        if let Err(err) = self.load_library(path).await {
            warn!(
                "Failed to load custom driver by path: {}, {}",
                path.display(),
                err
            );
        }
    }

    /// Loads one custom driver from the shared library at `path` and
    /// registers it under the name found in its metadata.
    ///
    /// The library must export four symbols: `driver` (builds the driver),
    /// `metadata` / `metadata_length` (a protobuf-encoded `DriverMetadata`
    /// buffer and its length) and `free_metadata` (releases that buffer).
    ///
    /// Once a driver has been created the library is intentionally kept
    /// loaded for the rest of the process: the driver's code lives inside it,
    /// so unloading it would leave the registered driver pointing at unmapped
    /// memory.
    ///
    /// # Errors
    ///
    /// Returns an error if the library or one of its symbols cannot be
    /// loaded, if the metadata pointer is null or does not decode, or if
    /// registering the driver fails.
    // NOTE(review): presumably kept so the `unsafe` blocks compile warning-free
    // with libloading releases where `Library::new`/`get` are safe — confirm.
    #[allow(unused_unsafe)]
    async fn load_library(&mut self, path: &Path) -> Result<(), Error> {
        // SAFETY: loading a library runs its initializers; the paths come from
        // the operator-supplied extension list and are trusted as driver plugins.
        let library = unsafe { Library::new(path) }
            .map_err(|err| anyhow!("Failed to load library: {:?}", err))?;

        // All symbols and the raw metadata pointer are confined to this block
        // so that nothing borrowed from `library` outlives it.
        let (metadata, driver) = {
            // SAFETY: the signatures below must match what the plugin exports;
            // this is the plugin contract and cannot be checked at runtime.
            let driver_symbol: Symbol<fn() -> Arc<dyn NGDriver>> =
                unsafe { library.get(b"driver") }
                    .map_err(|err| anyhow!("Failed to load driver symbol: {:?}", err))?;
            let metadata_symbol: Symbol<fn() -> *const u8> =
                unsafe { library.get(b"metadata") }
                    .map_err(|err| anyhow!("Failed to load metadata symbol: {:?}", err))?;
            let metadata_len_symbol: Symbol<fn() -> usize> =
                unsafe { library.get(b"metadata_length") }.map_err(|err| {
                    anyhow!("Failed to load metadata_length symbol: {:?}", err)
                })?;
            let free_symbol: Symbol<fn()> = unsafe { library.get(b"free_metadata") }
                .map_err(|err| anyhow!("Failed to load free symbol: {:?}", err))?;

            let raw_metadata = metadata_symbol();
            // A null pointer must be rejected before a slice is built from it.
            if raw_metadata.is_null() {
                return Err(anyhow!("Driver library returned a null metadata pointer"));
            }
            let len = metadata_len_symbol();

            let decoded = {
                // SAFETY: `raw_metadata` is non-null and the plugin contract is
                // that it points to `len` readable bytes until `free_metadata`
                // is called.
                let bytes = unsafe { slice::from_raw_parts(raw_metadata, len) };
                DriverMetadata::decode(bytes)
            };
            // Release the plugin's buffer whether or not decoding succeeded;
            // the decoded metadata owns its data and no longer refers to it.
            free_symbol();
            let metadata = decoded?;

            (metadata, driver_symbol())
        };

        // Deliberately leak the handle: dropping `Library` unloads the shared
        // object while `driver` still executes code from it.
        std::mem::forget(library);

        self.register_driver(metadata.name.as_str(), driver).await
    }
}