Skip to main content

ai_lib_rust/plugins/
registry.rs

//! Plugin registry: stores named plugins and dispatches their lifecycle hooks in priority order.
2
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};

use super::base::{Plugin, PluginContext};
use crate::Result;
7
8pub struct PluginRegistry { plugins: RwLock<HashMap<String, Arc<dyn Plugin>>>, enabled: RwLock<bool> }
9
10impl PluginRegistry {
11    pub fn new() -> Self { Self { plugins: RwLock::new(HashMap::new()), enabled: RwLock::new(true) } }
12
13    pub async fn register(&self, plugin: Arc<dyn Plugin>) -> Result<()> {
14        let name = plugin.name().to_string();
15        plugin.on_register().await?;
16        self.plugins.write().unwrap().insert(name, plugin);
17        Ok(())
18    }
19
20    pub async fn unregister(&self, name: &str) -> Result<Option<Arc<dyn Plugin>>> {
21        let plugin = self.plugins.write().unwrap().remove(name);
22        if let Some(ref p) = plugin { p.on_unregister().await?; }
23        Ok(plugin)
24    }
25
26    pub fn get(&self, name: &str) -> Option<Arc<dyn Plugin>> { self.plugins.read().unwrap().get(name).cloned() }
27    pub fn has(&self, name: &str) -> bool { self.plugins.read().unwrap().contains_key(name) }
28    pub fn list(&self) -> Vec<Arc<dyn Plugin>> { self.plugins.read().unwrap().values().cloned().collect() }
29    pub fn list_by_priority(&self) -> Vec<Arc<dyn Plugin>> { let mut p = self.list(); p.sort_by_key(|x| x.priority()); p }
30    pub fn count(&self) -> usize { self.plugins.read().unwrap().len() }
31    pub fn set_enabled(&self, e: bool) { *self.enabled.write().unwrap() = e; }
32    pub fn is_enabled(&self) -> bool { *self.enabled.read().unwrap() }
33
34    pub async fn trigger_before_request(&self, ctx: &mut PluginContext) -> Result<()> {
35        if !self.is_enabled() { return Ok(()); }
36        for p in self.list_by_priority() { if ctx.should_skip() { break; } p.on_before_request(ctx).await?; }
37        Ok(())
38    }
39
40    pub async fn trigger_after_response(&self, ctx: &mut PluginContext) -> Result<()> {
41        if !self.is_enabled() { return Ok(()); }
42        for p in self.list_by_priority() { if ctx.should_skip() { break; } p.on_after_response(ctx).await?; }
43        Ok(())
44    }
45
46    pub async fn trigger_on_error(&self, ctx: &mut PluginContext) -> Result<()> {
47        if !self.is_enabled() { return Ok(()); }
48        for p in self.list_by_priority() { p.on_error(ctx).await?; }
49        Ok(())
50    }
51
52    pub async fn clear(&self) -> Result<()> {
53        let plugins: HashMap<_, _> = std::mem::take(&mut *self.plugins.write().unwrap());
54        for (_, p) in plugins { let _ = p.on_unregister().await; }
55        Ok(())
56    }
57}
58impl Default for PluginRegistry { fn default() -> Self { Self::new() } }
59
60static GLOBAL_REGISTRY: once_cell::sync::Lazy<PluginRegistry> = once_cell::sync::Lazy::new(PluginRegistry::new);
61pub fn get_plugin_registry() -> &'static PluginRegistry { &GLOBAL_REGISTRY }