ai_lib_rust/plugins/
base.rs1use crate::Result;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::Arc;
7
/// Relative ordering hint a plugin reports through [`Plugin::priority`].
///
/// Variants carry explicit discriminants from `0` (`Highest`) to `100` (`Lowest`),
/// and the derived ordering follows them: `Highest < High < Normal < Low < Lowest`.
/// A smaller value therefore means a higher priority.
/// The default is [`PluginPriority::Normal`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PluginPriority {
    Highest = 0,
    High = 25,
    // Derived `Default` replaces the former hand-written impl that returned `Normal`.
    #[default]
    Normal = 50,
    Low = 75,
    Lowest = 100,
}
21
22#[derive(Debug, Clone, Default)]
23pub struct PluginContext {
24 pub request: Option<serde_json::Value>,
25 pub response: Option<serde_json::Value>,
26 pub request_id: Option<String>,
27 pub model: Option<String>,
28 pub provider: Option<String>,
29 pub metadata: HashMap<String, serde_json::Value>,
30 pub error: Option<String>,
31 pub skip: bool,
32}
33
34impl PluginContext {
35 pub fn new() -> Self {
36 Self::default()
37 }
38 pub fn with_request(mut self, r: serde_json::Value) -> Self {
39 self.request = Some(r);
40 self
41 }
42 pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
43 self.request_id = Some(id.into());
44 self
45 }
46 pub fn with_model(mut self, m: impl Into<String>) -> Self {
47 self.model = Some(m.into());
48 self
49 }
50 pub fn skip(&mut self) {
51 self.skip = true;
52 }
53 pub fn should_skip(&self) -> bool {
54 self.skip
55 }
56 pub fn set_error(&mut self, e: impl Into<String>) {
57 self.error = Some(e.into());
58 }
59 pub fn has_error(&self) -> bool {
60 self.error.is_some()
61 }
62}
63
64#[async_trait]
65pub trait Plugin: Send + Sync {
66 fn name(&self) -> &str;
67 fn priority(&self) -> PluginPriority {
68 PluginPriority::Normal
69 }
70 async fn on_register(&self) -> Result<()> {
71 Ok(())
72 }
73 async fn on_unregister(&self) -> Result<()> {
74 Ok(())
75 }
76 async fn on_before_request(&self, _ctx: &mut PluginContext) -> Result<()> {
77 Ok(())
78 }
79 async fn on_after_response(&self, _ctx: &mut PluginContext) -> Result<()> {
80 Ok(())
81 }
82 async fn on_error(&self, _ctx: &mut PluginContext) -> Result<()> {
83 Ok(())
84 }
85 async fn on_stream_event(
86 &self,
87 _ctx: &mut PluginContext,
88 _event: &serde_json::Value,
89 ) -> Result<()> {
90 Ok(())
91 }
92}
93
94pub struct CompositePlugin {
95 name: String,
96 plugins: Vec<Arc<dyn Plugin>>,
97}
98impl CompositePlugin {
99 pub fn new(name: impl Into<String>) -> Self {
100 Self {
101 name: name.into(),
102 plugins: Vec::new(),
103 }
104 }
105 pub fn add(mut self, p: Arc<dyn Plugin>) -> Self {
106 self.plugins.push(p);
107 self
108 }
109 pub fn len(&self) -> usize {
110 self.plugins.len()
111 }
112 pub fn is_empty(&self) -> bool {
113 self.plugins.is_empty()
114 }
115}
116
117#[async_trait]
118impl Plugin for CompositePlugin {
119 fn name(&self) -> &str {
120 &self.name
121 }
122 async fn on_register(&self) -> Result<()> {
123 for p in &self.plugins {
124 p.on_register().await?;
125 }
126 Ok(())
127 }
128 async fn on_unregister(&self) -> Result<()> {
129 for p in &self.plugins {
130 p.on_unregister().await?;
131 }
132 Ok(())
133 }
134 async fn on_before_request(&self, ctx: &mut PluginContext) -> Result<()> {
135 for p in &self.plugins {
136 if ctx.should_skip() {
137 break;
138 }
139 p.on_before_request(ctx).await?;
140 }
141 Ok(())
142 }
143 async fn on_after_response(&self, ctx: &mut PluginContext) -> Result<()> {
144 for p in &self.plugins {
145 if ctx.should_skip() {
146 break;
147 }
148 p.on_after_response(ctx).await?;
149 }
150 Ok(())
151 }
152 async fn on_error(&self, ctx: &mut PluginContext) -> Result<()> {
153 for p in &self.plugins {
154 p.on_error(ctx).await?;
155 }
156 Ok(())
157 }
158}