versa_core/
lib.rs

1use dashmap::DashMap;
2use libloading::Library;
3use log::info;
4use std::any::Any;
5use std::collections::HashSet;
6use std::sync::atomic::{AtomicU32, Ordering};
7use std::sync::Arc;
8
9#[cfg(feature = "async")]
10use tokio::fs as async_fs;
11#[cfg(feature = "async")]
12use tokio::task;
13
/// Type-erased value passed to and returned from plugin functions.
///
/// Any `'static` value can be wrapped; use [`PluginValue::downcast`] to get it back.
pub struct PluginValue(Box<dyn Any>);

impl PluginValue {
    /// Boxes `value` into a type-erased `PluginValue`.
    pub fn new<T: 'static>(value: T) -> Self {
        let erased: Box<dyn Any> = Box::new(value);
        Self(erased)
    }

    /// Attempts to extract the wrapped value as a `T`.
    ///
    /// # Errors
    /// Returns the original `PluginValue`, unchanged, when it does not hold a `T`.
    pub fn downcast<T: 'static>(self) -> Result<T, Self> {
        match self.0.downcast::<T>() {
            Ok(inner) => Ok(*inner),
            Err(original) => Err(Self(original)),
        }
    }
}
28
29#[derive(Clone)]
30pub struct PluginFunction {
31    func: fn(&PluginApi, Option<PluginValue>) -> PluginValue,
32}
33
34impl PluginFunction {
35    pub fn new(func: fn(&PluginApi, Option<PluginValue>) -> PluginValue) -> Self {
36        PluginFunction { func }
37    }
38
39    pub fn call(&self, api: &PluginApi, arg: Option<PluginValue>) -> PluginValue {
40        (self.func)(api, arg)
41    }
42}
43
44pub struct PluginApi {
45    pub registry: Arc<FunctionRegistry>,
46    pub mod_id: u32,
47}
48
49impl PluginApi {
50    pub fn new(registry: Arc<FunctionRegistry>) -> Self {
51        let mod_id = registry.assign_mod_id();
52        Self { registry, mod_id }
53    }
54
55    pub fn register_function(
56        &self,
57        name: &str,
58        func: fn(&PluginApi, Option<PluginValue>) -> PluginValue,
59    ) {
60        let wrapped_function = PluginFunction::new(func);
61        self.registry
62            .register_function(self.mod_id, name, wrapped_function);
63    }
64
65    pub fn call_function(&self, function_name: &str, arg: Option<PluginValue>) -> PluginValue {
66        self.registry
67            .call_function(function_name, self, arg)
68            .unwrap_or_else(|| PluginValue::new("Function not found".to_string()))
69    }
70
71    pub fn unregister_function(&self, name: &str) {
72        self.registry.unregister_function(self.mod_id, name);
73    }
74
75    pub fn list_functions(&self) -> DashMap<u32, HashSet<String>> {
76        self.registry.list_functions()
77    }
78
79    pub fn get_function(&self, name: &str) -> Option<DashMap<u32, PluginFunction>> {
80        self.registry.get_function(name)
81    }
82
83    #[cfg(not(feature = "async"))]
84    pub fn load_plugins(&self, mods_dir: &str) -> Result<(), String> {
85        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
86        let registry = Arc::clone(&self.registry);
87
88        let entries = std::fs::read_dir(mods_dir).map_err(|e| e.to_string())?;
89
90        for entry in entries {
91            let entry = entry.map_err(|e| e.to_string())?;
92            let path = entry.path();
93
94            if path.extension().and_then(|s| s.to_str()) == Some(LIB_EXTENSION) {
95                let lib = unsafe { Library::new(&path).map_err(|e| e.to_string())? };
96
97                unsafe {
98                    let plugin: libloading::Symbol<*const Plugin> =
99                        lib.get(b"PLUGIN\0").map_err(|e| e.to_string())?;
100                    let plugin = &**plugin;
101
102                    info!(
103                        "Loading plugin: {} v{} by {}",
104                        plugin.metadata.name, plugin.metadata.version, plugin.metadata.author
105                    );
106                    let api = PluginApi::new(registry.clone());
107                    (plugin.initialize)(&api);
108                }
109            }
110        }
111
112        Ok(())
113    }
114
115    #[cfg(feature = "async")]
116    pub async fn load_plugins(&self, mods_dir: &str) -> Result<(), String> {
117        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
118        let registry = Arc::clone(&self.registry);
119
120        let mut entries = async_fs::read_dir(mods_dir)
121            .await
122            .map_err(|e| format!("Could not read mods directory: {}", e))?;
123
124        while let Some(entry) = entries
125            .next_entry()
126            .await
127            .map_err(|e| format!("Failed to read entry: {}", e))?
128        {
129            let path = entry.path();
130
131            if path.extension().and_then(|s| s.to_str()) == Some(LIB_EXTENSION) {
132                let lib = unsafe {
133                    Library::new(&path).map_err(|e| format!("Failed to load library: {}", e))?
134                };
135
136                let registry_clone = Arc::clone(&registry);
137                task::spawn(async move {
138                    unsafe {
139                        let plugin: libloading::Symbol<*const Plugin> = lib
140                            .get(b"PLUGIN\0")
141                            .map_err(|e| format!("Failed to load PLUGIN symbol: {}", e))
142                            .unwrap();
143                        let plugin = &**plugin;
144
145                        info!(
146                            "Loading plugin: {} v{} by {}",
147                            plugin.metadata.name, plugin.metadata.version, plugin.metadata.author
148                        );
149                        let api = PluginApi::new(registry_clone);
150                        (plugin.initialize)(&api);
151                    }
152                })
153                .await
154                .expect("Failed to load plugin in async task");
155            }
156        }
157
158        Ok(())
159    }
160}
161
162pub struct Plugin {
163    pub initialize: fn(&PluginApi),
164    pub metadata: PluginMetadata,
165}
166
/// Static, human-readable description of a plugin.
///
/// The loader logs `name`, `version` and `author` when the plugin is loaded.
/// `description` is carried but not read by the loader. The field order is part of the
/// [`Plugin`] layout shared with plugin libraries, so keep it unchanged.
pub struct PluginMetadata {
    /// Display name of the plugin.
    pub name: &'static str,
    /// Version string, logged after a `v` prefix.
    pub version: &'static str,
    /// Author credited in the load log line.
    pub author: &'static str,
    /// Free-form description.
    pub description: &'static str,
}
173
174pub struct FunctionRegistry {
175    functions: DashMap<String, DashMap<u32, PluginFunction>>,
176    next_mod_id: AtomicU32,
177}
178
179impl Default for FunctionRegistry {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185impl FunctionRegistry {
186    pub fn new() -> Self {
187        Self {
188            functions: DashMap::new(),
189            next_mod_id: AtomicU32::new(1),
190        }
191    }
192
193    fn unregister_function(&self, mod_id: u32, name: &str) {
194        if let Some(entry) = self.functions.get_mut(name) {
195            if entry.remove(&mod_id).is_some() {
196                info!("Unregistered function: {}", name);
197            }
198        }
199    }
200
201    fn register_function(&self, mod_id: u32, original_name: &str, function: PluginFunction) {
202        let mut name = original_name.to_string();
203        let mut count = 2;
204
205        while self.functions.contains_key(&name) {
206            name = format!("{}_{}", original_name, count);
207            count += 1;
208        }
209
210        self.functions
211            .entry(name)
212            .or_default()
213            .insert(mod_id, function);
214    }
215
216    fn get_function(&self, name: &str) -> Option<DashMap<u32, PluginFunction>> {
217        self.functions.get(name).map(|entry| entry.clone())
218    }
219
220    fn assign_mod_id(&self) -> u32 {
221        self.next_mod_id.fetch_add(1, Ordering::SeqCst)
222    }
223
224    fn list_functions(&self) -> DashMap<u32, HashSet<String>> {
225        let mod_functions = DashMap::new();
226        self.functions.iter().for_each(|entry| {
227            let (name, mod_map) = entry.pair();
228            mod_map.iter().for_each(|mod_entry| {
229                let (mod_id, _) = mod_entry.pair();
230                mod_functions
231                    .entry(*mod_id)
232                    .or_insert_with(HashSet::new)
233                    .insert(name.clone());
234            });
235        });
236        mod_functions
237    }
238
239    fn call_function(
240        &self,
241        function_name: &str,
242        api: &PluginApi,
243        arg: Option<PluginValue>,
244    ) -> Option<PluginValue> {
245        self.functions.get(function_name).and_then(|mod_map| {
246            mod_map.iter().next().map(|entry| {
247                let (_, func) = entry.pair();
248                func.call(api, arg)
249            })
250        })
251    }
252}
253
254lazy_static::lazy_static! {
255    static ref FUNCTION_REGISTRY: Arc<FunctionRegistry> = Arc::new(FunctionRegistry::new());
256}
257
/// File extension (without the dot) of dynamic libraries that are loaded as plugins.
#[cfg(target_os = "windows")]
pub const LIB_EXTENSION: &str = "dll";

/// File extension (without the dot) of dynamic libraries that are loaded as plugins.
#[cfg(target_os = "macos")]
pub const LIB_EXTENSION: &str = "dylib";

/// File extension (without the dot) of dynamic libraries that are loaded as plugins.
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
pub const LIB_EXTENSION: &str = "so";