Skip to main content

ai_lib_rust/plugins/
base.rs

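//! Base plugin types: the [`Plugin`] hook trait, the per-request [`PluginContext`]
//! passed to hooks, [`PluginPriority`], and [`CompositePlugin`], which fans hooks
//! out to a group of child plugins.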
2
3use crate::Result;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::Arc;
7
/// Priority level attached to a plugin.
///
/// The derived `Ord` compares variants in declaration order, so
/// `Highest < High < Normal < Low < Lowest`. Each variant also has an explicit
/// discriminant (0, 25, 50, 75, 100) that can be read with an `as` cast.
///
/// NOTE(review): nothing in this file sorts plugins by priority; whichever
/// registry consumes this value decides whether "lower" means "runs earlier" —
/// confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum PluginPriority {
    /// Discriminant 0; compares lowest.
    Highest = 0,
    /// Discriminant 25.
    High = 25,
    /// Discriminant 50; the default priority.
    #[default]
    Normal = 50,
    /// Discriminant 75.
    Low = 75,
    /// Discriminant 100; compares highest.
    Lowest = 100,
}
21
22#[derive(Debug, Clone, Default)]
23pub struct PluginContext {
24    pub request: Option<serde_json::Value>,
25    pub response: Option<serde_json::Value>,
26    pub request_id: Option<String>,
27    pub model: Option<String>,
28    pub provider: Option<String>,
29    pub metadata: HashMap<String, serde_json::Value>,
30    pub error: Option<String>,
31    pub skip: bool,
32}
33
34impl PluginContext {
35    pub fn new() -> Self {
36        Self::default()
37    }
38    pub fn with_request(mut self, r: serde_json::Value) -> Self {
39        self.request = Some(r);
40        self
41    }
42    pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
43        self.request_id = Some(id.into());
44        self
45    }
46    pub fn with_model(mut self, m: impl Into<String>) -> Self {
47        self.model = Some(m.into());
48        self
49    }
50    pub fn skip(&mut self) {
51        self.skip = true;
52    }
53    pub fn should_skip(&self) -> bool {
54        self.skip
55    }
56    pub fn set_error(&mut self, e: impl Into<String>) {
57        self.error = Some(e.into());
58    }
59    pub fn has_error(&self) -> bool {
60        self.error.is_some()
61    }
62}
63
64#[async_trait]
65pub trait Plugin: Send + Sync {
66    fn name(&self) -> &str;
67    fn priority(&self) -> PluginPriority {
68        PluginPriority::Normal
69    }
70    async fn on_register(&self) -> Result<()> {
71        Ok(())
72    }
73    async fn on_unregister(&self) -> Result<()> {
74        Ok(())
75    }
76    async fn on_before_request(&self, _ctx: &mut PluginContext) -> Result<()> {
77        Ok(())
78    }
79    async fn on_after_response(&self, _ctx: &mut PluginContext) -> Result<()> {
80        Ok(())
81    }
82    async fn on_error(&self, _ctx: &mut PluginContext) -> Result<()> {
83        Ok(())
84    }
85    async fn on_stream_event(
86        &self,
87        _ctx: &mut PluginContext,
88        _event: &serde_json::Value,
89    ) -> Result<()> {
90        Ok(())
91    }
92}
93
94pub struct CompositePlugin {
95    name: String,
96    plugins: Vec<Arc<dyn Plugin>>,
97}
98impl CompositePlugin {
99    pub fn new(name: impl Into<String>) -> Self {
100        Self {
101            name: name.into(),
102            plugins: Vec::new(),
103        }
104    }
105    pub fn add(mut self, p: Arc<dyn Plugin>) -> Self {
106        self.plugins.push(p);
107        self
108    }
109    pub fn len(&self) -> usize {
110        self.plugins.len()
111    }
112    pub fn is_empty(&self) -> bool {
113        self.plugins.is_empty()
114    }
115}
116
117#[async_trait]
118impl Plugin for CompositePlugin {
119    fn name(&self) -> &str {
120        &self.name
121    }
122    async fn on_register(&self) -> Result<()> {
123        for p in &self.plugins {
124            p.on_register().await?;
125        }
126        Ok(())
127    }
128    async fn on_unregister(&self) -> Result<()> {
129        for p in &self.plugins {
130            p.on_unregister().await?;
131        }
132        Ok(())
133    }
134    async fn on_before_request(&self, ctx: &mut PluginContext) -> Result<()> {
135        for p in &self.plugins {
136            if ctx.should_skip() {
137                break;
138            }
139            p.on_before_request(ctx).await?;
140        }
141        Ok(())
142    }
143    async fn on_after_response(&self, ctx: &mut PluginContext) -> Result<()> {
144        for p in &self.plugins {
145            if ctx.should_skip() {
146                break;
147            }
148            p.on_after_response(ctx).await?;
149        }
150        Ok(())
151    }
152    async fn on_error(&self, ctx: &mut PluginContext) -> Result<()> {
153        for p in &self.plugins {
154            p.on_error(ctx).await?;
155        }
156        Ok(())
157    }
158}