1use std::collections::HashMap;
29use std::sync::{Arc, RwLock};
30use serde::{Serialize, Deserialize};
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
38#[serde(untagged)]
39pub enum ModuleArg {
40 Null,
41 Bool(bool),
42 Int(i64),
43 Float(f64),
44 String(String),
45 Bytes(Vec<u8>),
46 Array(Vec<ModuleArg>),
47 Map(HashMap<String, ModuleArg>),
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(untagged)]
53pub enum ModuleValue {
54 Null,
55 Bool(bool),
56 Int(i64),
57 Float(f64),
58 String(String),
59 Bytes(Vec<u8>),
60 Array(Vec<ModuleValue>),
61 Map(HashMap<String, ModuleValue>),
62}
63
64pub type ModuleResult = Result<ModuleValue, ModuleError>;
66
67#[derive(Debug, thiserror::Error)]
69pub enum ModuleError {
70 #[error("Module not found: {0}")]
71 ModuleNotFound(String),
72
73 #[error("Method not found: {module}.{method}")]
74 MethodNotFound { module: String, method: String },
75
76 #[error("Invalid arguments: {0}")]
77 InvalidArgs(String),
78
79 #[error("Method {0} is async, use invoke_async instead")]
80 IsAsync(String),
81
82 #[error("Module error: {0}")]
83 Internal(String),
84}
85
/// Describes one method exposed by a native module.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MethodDescriptor {
    /// Method name, compared exactly against the name supplied when invoking.
    pub name: String,
    /// Human-readable description of what the method does.
    pub description: String,
    /// `true` when the method must be invoked asynchronously. The registry's
    /// synchronous entry point refuses such methods with `ModuleError::IsAsync`.
    pub is_async: bool,
}

impl MethodDescriptor {
    /// Creates a descriptor for a synchronous method.
    #[must_use]
    pub fn sync(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self { name: name.into(), description: description.into(), is_async: false }
    }

    /// Creates a descriptor for an asynchronous method.
    #[must_use]
    pub fn async_method(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self { name: name.into(), description: description.into(), is_async: true }
    }
}
107
108pub trait NativeModule: Send + Sync {
116 fn name(&self) -> &str;
118
119 fn methods(&self) -> Vec<MethodDescriptor>;
121
122 fn invoke_sync(&self, method: &str, args: &[ModuleArg]) -> ModuleResult {
124 Err(ModuleError::MethodNotFound {
125 module: self.name().to_string(),
126 method: method.to_string(),
127 })
128 }
129
130 fn invoke_async(
133 &self, method: &str, args: &[ModuleArg],
134 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ModuleResult> + Send + '_>> {
135 let module = self.name().to_string();
136 let method = method.to_string();
137 Box::pin(async move {
138 Err(ModuleError::MethodNotFound { module, method })
139 })
140 }
141
142 fn on_init(&self) {}
144
145 fn on_destroy(&self) {}
147}
148
149pub struct ModuleRegistry {
155 modules: RwLock<HashMap<String, Arc<dyn NativeModule>>>,
156}
157
158impl ModuleRegistry {
159 pub fn new() -> Self {
160 Self { modules: RwLock::new(HashMap::new()) }
161 }
162
163 pub fn register(&self, module: Arc<dyn NativeModule>) -> Result<(), ModuleError> {
165 let name = module.name().to_string();
166 let mut modules = self.modules.write().unwrap();
167 if modules.contains_key(&name) {
168 return Err(ModuleError::Internal(format!("Module '{name}' already registered")));
169 }
170 module.on_init();
171 modules.insert(name, module);
172 Ok(())
173 }
174
175 pub fn unregister(&self, name: &str) -> Option<Arc<dyn NativeModule>> {
177 let mut modules = self.modules.write().unwrap();
178 let module = modules.remove(name);
179 if let Some(ref m) = module {
180 m.on_destroy();
181 }
182 module
183 }
184
185 pub fn has(&self, name: &str) -> bool {
187 self.modules.read().unwrap().contains_key(name)
188 }
189
190 pub fn module_names(&self) -> Vec<String> {
192 self.modules.read().unwrap().keys().cloned().collect()
193 }
194
195 pub fn module_methods(&self, name: &str) -> Result<Vec<MethodDescriptor>, ModuleError> {
197 let modules = self.modules.read().unwrap();
198 let module = modules.get(name)
199 .ok_or_else(|| ModuleError::ModuleNotFound(name.to_string()))?;
200 Ok(module.methods())
201 }
202
203 pub fn invoke_sync(
205 &self, module_name: &str, method: &str, args: &[ModuleArg],
206 ) -> ModuleResult {
207 let modules = self.modules.read().unwrap();
208 let module = modules.get(module_name)
209 .ok_or_else(|| ModuleError::ModuleNotFound(module_name.to_string()))?;
210
211 let methods = module.methods();
213 let desc = methods.iter().find(|m| m.name == method);
214 match desc {
215 None => Err(ModuleError::MethodNotFound {
216 module: module_name.to_string(),
217 method: method.to_string(),
218 }),
219 Some(d) if d.is_async => Err(ModuleError::IsAsync(method.to_string())),
220 Some(_) => module.invoke_sync(method, args),
221 }
222 }
223
224 pub async fn invoke_async(
226 &self, module_name: &str, method: &str, args: &[ModuleArg],
227 ) -> ModuleResult {
228 let module = {
229 let modules = self.modules.read().unwrap();
230 modules.get(module_name)
231 .ok_or_else(|| ModuleError::ModuleNotFound(module_name.to_string()))?
232 .clone()
233 };
234
235 module.invoke_async(method, args).await
236 }
237}
238
239impl Default for ModuleRegistry {
240 fn default() -> Self { Self::new() }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Fixture module with two sync methods (`greet`, `add`) and one method
    /// declared async (`fetch`) that deliberately has no `invoke_async` override.
    struct TestModule;

    impl NativeModule for TestModule {
        fn name(&self) -> &str {
            "Test"
        }

        fn methods(&self) -> Vec<MethodDescriptor> {
            vec![
                MethodDescriptor::sync("greet", "Returns a greeting"),
                MethodDescriptor::sync("add", "Adds two numbers"),
                MethodDescriptor::async_method("fetch", "Fetches data"),
            ]
        }

        fn invoke_sync(&self, method: &str, args: &[ModuleArg]) -> ModuleResult {
            match method {
                "greet" => {
                    let name = match args.first() {
                        Some(ModuleArg::String(s)) => s.as_str(),
                        _ => "World",
                    };
                    Ok(ModuleValue::String(format!("Hello, {name}!")))
                }
                "add" => {
                    let a = match args.first() {
                        Some(ModuleArg::Int(n)) => *n,
                        _ => return Err(ModuleError::InvalidArgs("expected int".into())),
                    };
                    let b = match args.get(1) {
                        Some(ModuleArg::Int(n)) => *n,
                        _ => return Err(ModuleError::InvalidArgs("expected int".into())),
                    };
                    Ok(ModuleValue::Int(a + b))
                }
                _ => Err(ModuleError::MethodNotFound {
                    module: self.name().to_string(),
                    method: method.to_string(),
                }),
            }
        }
    }

    /// Fixture module that counts how often its lifecycle hooks run.
    struct LifecycleModule {
        inits: Arc<AtomicUsize>,
        destroys: Arc<AtomicUsize>,
    }

    impl NativeModule for LifecycleModule {
        fn name(&self) -> &str {
            "Lifecycle"
        }

        fn methods(&self) -> Vec<MethodDescriptor> {
            Vec::new()
        }

        fn on_init(&self) {
            self.inits.fetch_add(1, Ordering::SeqCst);
        }

        fn on_destroy(&self) {
            self.destroys.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Polls `fut` exactly once and returns its output.
    ///
    /// Every future exercised here resolves without waiting, so one poll with a
    /// no-op waker is enough and no async runtime is needed.
    fn block_on_ready<F: Future>(fut: F) -> F::Output {
        struct NoopWaker;

        impl std::task::Wake for NoopWaker {
            fn wake(self: Arc<Self>) {}
        }

        let waker = std::task::Waker::from(Arc::new(NoopWaker));
        let mut cx = std::task::Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        match fut.as_mut().poll(&mut cx) {
            std::task::Poll::Ready(output) => output,
            std::task::Poll::Pending => panic!("future unexpectedly returned Pending"),
        }
    }

    #[test]
    fn test_register_and_invoke() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();

        assert!(registry.has("Test"));
        assert!(!registry.has("Unknown"));

        let result = registry
            .invoke_sync("Test", "greet", &[ModuleArg::String("Rust".into())])
            .unwrap();
        assert!(matches!(result, ModuleValue::String(s) if s == "Hello, Rust!"));
    }

    #[test]
    fn test_greet_default_name() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();

        let result = registry.invoke_sync("Test", "greet", &[]).unwrap();
        assert!(matches!(result, ModuleValue::String(s) if s == "Hello, World!"));
    }

    #[test]
    fn test_add_method() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();

        let result = registry
            .invoke_sync("Test", "add", &[ModuleArg::Int(3), ModuleArg::Int(4)])
            .unwrap();
        assert!(matches!(result, ModuleValue::Int(7)));
    }

    #[test]
    fn test_add_rejects_missing_argument() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();

        let result = registry.invoke_sync("Test", "add", &[ModuleArg::Int(1)]);
        assert!(matches!(result, Err(ModuleError::InvalidArgs(_))));
    }

    #[test]
    fn test_module_not_found() {
        let registry = ModuleRegistry::new();
        let result = registry.invoke_sync("Missing", "greet", &[]);
        assert!(matches!(result, Err(ModuleError::ModuleNotFound(_))));
    }

    #[test]
    fn test_method_not_found() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();
        let result = registry.invoke_sync("Test", "missing", &[]);
        assert!(matches!(result, Err(ModuleError::MethodNotFound { .. })));
    }

    #[test]
    fn test_async_method_guard() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();
        let result = registry.invoke_sync("Test", "fetch", &[]);
        assert!(matches!(result, Err(ModuleError::IsAsync(_))));
    }

    #[test]
    fn test_invoke_async_module_not_found() {
        let registry = ModuleRegistry::new();
        let result = block_on_ready(registry.invoke_async("Missing", "fetch", &[]));
        assert!(matches!(result, Err(ModuleError::ModuleNotFound(_))));
    }

    #[test]
    fn test_invoke_async_default_is_method_not_found() {
        // TestModule does not override `invoke_async`, so the trait default answers.
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();
        let result = block_on_ready(registry.invoke_async("Test", "fetch", &[]));
        assert!(matches!(result, Err(ModuleError::MethodNotFound { .. })));
    }

    #[test]
    fn test_duplicate_register() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();
        let result = registry.register(Arc::new(TestModule));
        assert!(result.is_err());
    }

    #[test]
    fn test_unregister() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();
        assert!(registry.has("Test"));

        let removed = registry.unregister("Test");
        assert!(removed.is_some());
        assert!(!registry.has("Test"));
    }

    #[test]
    fn test_unregister_unknown_returns_none() {
        let registry = ModuleRegistry::new();
        assert!(registry.unregister("Missing").is_none());
    }

    #[test]
    fn test_lifecycle_hooks_run_once() {
        let inits = Arc::new(AtomicUsize::new(0));
        let destroys = Arc::new(AtomicUsize::new(0));
        let make = || LifecycleModule {
            inits: Arc::clone(&inits),
            destroys: Arc::clone(&destroys),
        };

        let registry = ModuleRegistry::new();
        registry.register(Arc::new(make())).unwrap();
        assert_eq!(inits.load(Ordering::SeqCst), 1);
        assert_eq!(destroys.load(Ordering::SeqCst), 0);

        // A rejected duplicate must not run `on_init`.
        assert!(registry.register(Arc::new(make())).is_err());
        assert_eq!(inits.load(Ordering::SeqCst), 1);

        assert!(registry.unregister("Lifecycle").is_some());
        assert_eq!(destroys.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_module_names() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();
        let names = registry.module_names();
        assert_eq!(names, vec!["Test".to_string()]);
    }

    #[test]
    fn test_module_methods_list() {
        let registry = ModuleRegistry::new();
        registry.register(Arc::new(TestModule)).unwrap();
        let methods = registry.module_methods("Test").unwrap();
        assert_eq!(methods.len(), 3);
        assert_eq!(methods[0].name, "greet");
        assert!(!methods[0].is_async);
        assert_eq!(methods[2].name, "fetch");
        assert!(methods[2].is_async);
    }

    #[test]
    fn test_module_methods_unknown_module() {
        let registry = ModuleRegistry::new();
        let result = registry.module_methods("Missing");
        assert!(matches!(result, Err(ModuleError::ModuleNotFound(_))));
    }
}