Skip to main content

ai_lib_rust/plugins/
registry.rs

1//! Plugin registry.
2
use super::base::{Plugin, PluginContext};
use crate::Result;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
7
8pub struct PluginRegistry {
9    plugins: RwLock<HashMap<String, Arc<dyn Plugin>>>,
10    enabled: RwLock<bool>,
11}
12
13impl PluginRegistry {
14    pub fn new() -> Self {
15        Self {
16            plugins: RwLock::new(HashMap::new()),
17            enabled: RwLock::new(true),
18        }
19    }
20
21    pub async fn register(&self, plugin: Arc<dyn Plugin>) -> Result<()> {
22        let name = plugin.name().to_string();
23        plugin.on_register().await?;
24        self.plugins.write().unwrap().insert(name, plugin);
25        Ok(())
26    }
27
28    pub async fn unregister(&self, name: &str) -> Result<Option<Arc<dyn Plugin>>> {
29        let plugin = self.plugins.write().unwrap().remove(name);
30        if let Some(ref p) = plugin {
31            p.on_unregister().await?;
32        }
33        Ok(plugin)
34    }
35
36    pub fn get(&self, name: &str) -> Option<Arc<dyn Plugin>> {
37        self.plugins.read().unwrap().get(name).cloned()
38    }
39    pub fn has(&self, name: &str) -> bool {
40        self.plugins.read().unwrap().contains_key(name)
41    }
42    pub fn list(&self) -> Vec<Arc<dyn Plugin>> {
43        self.plugins.read().unwrap().values().cloned().collect()
44    }
45    pub fn list_by_priority(&self) -> Vec<Arc<dyn Plugin>> {
46        let mut p = self.list();
47        p.sort_by_key(|x| x.priority());
48        p
49    }
50    pub fn count(&self) -> usize {
51        self.plugins.read().unwrap().len()
52    }
53    pub fn set_enabled(&self, e: bool) {
54        *self.enabled.write().unwrap() = e;
55    }
56    pub fn is_enabled(&self) -> bool {
57        *self.enabled.read().unwrap()
58    }
59
60    pub async fn trigger_before_request(&self, ctx: &mut PluginContext) -> Result<()> {
61        if !self.is_enabled() {
62            return Ok(());
63        }
64        for p in self.list_by_priority() {
65            if ctx.should_skip() {
66                break;
67            }
68            p.on_before_request(ctx).await?;
69        }
70        Ok(())
71    }
72
73    pub async fn trigger_after_response(&self, ctx: &mut PluginContext) -> Result<()> {
74        if !self.is_enabled() {
75            return Ok(());
76        }
77        for p in self.list_by_priority() {
78            if ctx.should_skip() {
79                break;
80            }
81            p.on_after_response(ctx).await?;
82        }
83        Ok(())
84    }
85
86    pub async fn trigger_on_error(&self, ctx: &mut PluginContext) -> Result<()> {
87        if !self.is_enabled() {
88            return Ok(());
89        }
90        for p in self.list_by_priority() {
91            p.on_error(ctx).await?;
92        }
93        Ok(())
94    }
95
96    pub async fn clear(&self) -> Result<()> {
97        let plugins: HashMap<_, _> = std::mem::take(&mut *self.plugins.write().unwrap());
98        for (_, p) in plugins {
99            let _ = p.on_unregister().await;
100        }
101        Ok(())
102    }
103}
104impl Default for PluginRegistry {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110static GLOBAL_REGISTRY: once_cell::sync::Lazy<PluginRegistry> =
111    once_cell::sync::Lazy::new(PluginRegistry::new);
112pub fn get_plugin_registry() -> &'static PluginRegistry {
113    &GLOBAL_REGISTRY
114}