1use crate::error::AiError;
4use crate::provider::TextStream;
5use crate::types::*;
6use async_trait::async_trait;
7use std::fmt::Debug;
8use std::sync::Arc;
9
/// Ordering bucket for a [`Plugin`].
///
/// [`PluginEngine::new`] sorts plugins so that every `Pre` plugin precedes
/// every `Normal` plugin, which in turn precedes every `Post` plugin.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PluginPhase {
    /// Sorted ahead of all `Normal` and `Post` plugins.
    Pre,
    /// The default phase (see `Plugin::enforce`); sorted between `Pre` and `Post`.
    Normal,
    /// Sorted after all `Pre` and `Normal` plugins.
    Post,
}
20
21#[async_trait]
27pub trait Plugin: Send + Sync + Debug + 'static {
28 fn name(&self) -> &str;
30
31 fn enforce(&self) -> PluginPhase {
33 PluginPhase::Normal
34 }
35
36 async fn resolve_model(
44 &self,
45 _model_id: &str,
46 _ctx: &RequestContext,
47 ) -> Result<Option<String>, AiError> {
48 Ok(None)
49 }
50
51 async fn load_template(
55 &self,
56 _template_name: &str,
57 _ctx: &RequestContext,
58 ) -> Result<Option<Vec<Message>>, AiError> {
59 Ok(None)
60 }
61
62 async fn transform_params(
70 &self,
71 params: TextParams,
72 _ctx: &RequestContext,
73 ) -> Result<TextParams, AiError> {
74 Ok(params)
75 }
76
77 async fn transform_result(
81 &self,
82 result: TextResult,
83 _ctx: &RequestContext,
84 ) -> Result<TextResult, AiError> {
85 Ok(result)
86 }
87
88 async fn on_request_start(&self, _ctx: &RequestContext) -> Result<(), AiError> {
93 Ok(())
94 }
95
96 async fn on_request_end(
98 &self,
99 _ctx: &RequestContext,
100 _result: &TextResult,
101 ) -> Result<(), AiError> {
102 Ok(())
103 }
104
105 async fn on_error(&self, _error: &AiError, _ctx: &RequestContext) -> Result<(), AiError> {
107 Ok(())
108 }
109
110 fn transform_stream(&self, stream: Box<TextStream>) -> Box<TextStream> {
117 stream
118 }
119}
120
121#[derive(Debug, Clone)]
125pub struct PluginEngine {
126 plugins: Vec<Arc<dyn Plugin>>,
127}
128
129impl PluginEngine {
130 pub fn new(mut plugins: Vec<Arc<dyn Plugin>>) -> Self {
132 plugins.sort_by_key(|p| match p.enforce() {
134 PluginPhase::Pre => 0,
135 PluginPhase::Normal => 1,
136 PluginPhase::Post => 2,
137 });
138
139 Self { plugins }
140 }
141
142 pub fn plugins(&self) -> &[Arc<dyn Plugin>] {
144 &self.plugins
145 }
146
147 pub async fn resolve_model(
151 &self,
152 model_id: &str,
153 ctx: &RequestContext,
154 ) -> Result<String, AiError> {
155 for plugin in &self.plugins {
156 if let Some(resolved) = plugin.resolve_model(model_id, ctx).await? {
157 return Ok(resolved);
158 }
159 }
160 Ok(model_id.to_string())
161 }
162
163 pub async fn load_template(
165 &self,
166 template_name: &str,
167 ctx: &RequestContext,
168 ) -> Result<Option<Vec<Message>>, AiError> {
169 for plugin in &self.plugins {
170 if let Some(messages) = plugin.load_template(template_name, ctx).await? {
171 return Ok(Some(messages));
172 }
173 }
174 Ok(None)
175 }
176
177 pub async fn transform_params(
181 &self,
182 mut params: TextParams,
183 ctx: &RequestContext,
184 ) -> Result<TextParams, AiError> {
185 for plugin in &self.plugins {
186 params = plugin.transform_params(params, ctx).await?;
187 }
188 Ok(params)
189 }
190
191 pub async fn transform_result(
193 &self,
194 mut result: TextResult,
195 ctx: &RequestContext,
196 ) -> Result<TextResult, AiError> {
197 for plugin in &self.plugins {
198 result = plugin.transform_result(result, ctx).await?;
199 }
200 Ok(result)
201 }
202
203 pub async fn on_request_start(&self, ctx: &RequestContext) -> Result<(), AiError> {
207 use futures::future::try_join_all;
208
209 let futures = self
210 .plugins
211 .iter()
212 .map(|p| p.on_request_start(ctx))
213 .collect::<Vec<_>>();
214
215 try_join_all(futures).await?;
216 Ok(())
217 }
218
219 pub async fn on_request_end(
221 &self,
222 ctx: &RequestContext,
223 result: &TextResult,
224 ) -> Result<(), AiError> {
225 use futures::future::try_join_all;
226
227 let futures = self
228 .plugins
229 .iter()
230 .map(|p| p.on_request_end(ctx, result))
231 .collect::<Vec<_>>();
232
233 try_join_all(futures).await?;
234 Ok(())
235 }
236
237 pub async fn on_error(&self, error: &AiError, ctx: &RequestContext) -> Result<(), AiError> {
239 use futures::future::try_join_all;
240
241 let futures = self
242 .plugins
243 .iter()
244 .map(|p| p.on_error(error, ctx))
245 .collect::<Vec<_>>();
246
247 try_join_all(futures).await?;
248 Ok(())
249 }
250
251 pub fn apply_stream_transforms(&self, stream: Box<TextStream>) -> Box<TextStream> {
255 self.plugins
256 .iter()
257 .fold(stream, |stream, plugin| plugin.transform_stream(stream))
258 }
259}
260
261impl Default for PluginEngine {
262 fn default() -> Self {
263 Self::new(Vec::new())
264 }
265}