1mod context_editing;
2mod human_in_the_loop;
3mod model_call_limit;
4mod model_fallback;
5mod security;
6mod summarization;
7mod todo_list;
8mod tool_call_limit;
9mod tool_retry;
10
11pub use context_editing::{ContextEditingMiddleware, ContextStrategy};
12pub use human_in_the_loop::{ApprovalCallback, HumanInTheLoopMiddleware};
13pub use model_call_limit::ModelCallLimitMiddleware;
14pub use model_fallback::ModelFallbackMiddleware;
15pub use security::{
16 ConfirmationPolicy, RiskLevel, RuleBasedAnalyzer, SecurityAnalyzer,
17 SecurityConfirmationCallback, SecurityMiddleware, ThresholdConfirmationPolicy,
18};
19pub use summarization::SummarizationMiddleware;
20pub use todo_list::TodoListMiddleware;
21pub use tool_call_limit::ToolCallLimitMiddleware;
22pub use tool_retry::ToolRetryMiddleware;
23
24use std::sync::Arc;
25
26use async_trait::async_trait;
27use serde_json::Value;
28use synaptic_core::{
29 ChatModel, ChatRequest, ChatResponse, Message, SynapticError, TokenUsage, ToolCall, ToolChoice,
30 ToolDefinition,
31};
32
33#[derive(Debug, Clone)]
42pub struct ModelRequest {
43 pub messages: Vec<Message>,
44 pub tools: Vec<ToolDefinition>,
45 pub tool_choice: Option<ToolChoice>,
46 pub system_prompt: Option<String>,
47}
48
49impl ModelRequest {
50 pub fn to_chat_request(&self) -> ChatRequest {
52 let mut messages = Vec::new();
53 if let Some(ref prompt) = self.system_prompt {
54 messages.push(Message::system(prompt));
55 }
56 messages.extend(self.messages.clone());
57 let mut req = ChatRequest::new(messages).with_tools(self.tools.clone());
58 if let Some(ref choice) = self.tool_choice {
59 req = req.with_tool_choice(choice.clone());
60 }
61 req
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct ModelResponse {
68 pub message: Message,
69 pub usage: Option<TokenUsage>,
70}
71
72impl From<ChatResponse> for ModelResponse {
73 fn from(resp: ChatResponse) -> Self {
74 Self {
75 message: resp.message,
76 usage: resp.usage,
77 }
78 }
79}
80
81#[derive(Debug, Clone)]
87pub struct ToolCallRequest {
88 pub call: ToolCall,
89}
90
91#[async_trait]
100pub trait ModelCaller: Send + Sync {
101 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError>;
102}
103
104#[async_trait]
106pub trait ToolCaller: Send + Sync {
107 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError>;
108}
109
110#[async_trait]
130pub trait AgentMiddleware: Send + Sync {
131 async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
133 Ok(())
134 }
135
136 async fn after_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
138 Ok(())
139 }
140
141 async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
143 Ok(())
144 }
145
146 async fn after_model(
148 &self,
149 _request: &ModelRequest,
150 _response: &mut ModelResponse,
151 ) -> Result<(), SynapticError> {
152 Ok(())
153 }
154
155 async fn wrap_model_call(
157 &self,
158 request: ModelRequest,
159 next: &dyn ModelCaller,
160 ) -> Result<ModelResponse, SynapticError> {
161 next.call(request).await
162 }
163
164 async fn wrap_tool_call(
166 &self,
167 request: ToolCallRequest,
168 next: &dyn ToolCaller,
169 ) -> Result<Value, SynapticError> {
170 next.call(request).await
171 }
172}
173
174pub struct MiddlewareChain {
180 middlewares: Vec<Arc<dyn AgentMiddleware>>,
181}
182
183impl MiddlewareChain {
184 pub fn new(middlewares: Vec<Arc<dyn AgentMiddleware>>) -> Self {
185 Self { middlewares }
186 }
187
188 pub fn is_empty(&self) -> bool {
189 self.middlewares.is_empty()
190 }
191
192 pub async fn run_before_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
193 for mw in &self.middlewares {
194 mw.before_agent(messages).await?;
195 }
196 Ok(())
197 }
198
199 pub async fn run_after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
200 for mw in self.middlewares.iter().rev() {
201 mw.after_agent(messages).await?;
202 }
203 Ok(())
204 }
205
206 pub async fn run_before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
207 for mw in &self.middlewares {
208 mw.before_model(request).await?;
209 }
210 Ok(())
211 }
212
213 pub async fn run_after_model(
214 &self,
215 request: &ModelRequest,
216 response: &mut ModelResponse,
217 ) -> Result<(), SynapticError> {
218 for mw in self.middlewares.iter().rev() {
219 mw.after_model(request, response).await?;
220 }
221 Ok(())
222 }
223
224 pub async fn call_model(
229 &self,
230 mut request: ModelRequest,
231 base: &dyn ModelCaller,
232 ) -> Result<ModelResponse, SynapticError> {
233 self.run_before_model(&mut request).await?;
235
236 let mut response = if self.middlewares.is_empty() {
238 base.call(request.clone()).await?
239 } else {
240 let chain = WrapModelChain {
241 middlewares: &self.middlewares,
242 index: 0,
243 base,
244 };
245 chain.call(request.clone()).await?
246 };
247
248 self.run_after_model(&request, &mut response).await?;
250
251 Ok(response)
252 }
253
254 pub async fn call_tool(
256 &self,
257 request: ToolCallRequest,
258 base: &dyn ToolCaller,
259 ) -> Result<Value, SynapticError> {
260 if self.middlewares.is_empty() {
261 base.call(request).await
262 } else {
263 let chain = WrapToolChain {
264 middlewares: &self.middlewares,
265 index: 0,
266 base,
267 };
268 chain.call(request).await
269 }
270 }
271}
272
273struct WrapModelChain<'a> {
276 middlewares: &'a [Arc<dyn AgentMiddleware>],
277 index: usize,
278 base: &'a dyn ModelCaller,
279}
280
281#[async_trait]
282impl ModelCaller for WrapModelChain<'_> {
283 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
284 if self.index >= self.middlewares.len() {
285 self.base.call(request).await
286 } else {
287 let next = WrapModelChain {
288 middlewares: self.middlewares,
289 index: self.index + 1,
290 base: self.base,
291 };
292 self.middlewares[self.index]
293 .wrap_model_call(request, &next)
294 .await
295 }
296 }
297}
298
299struct WrapToolChain<'a> {
300 middlewares: &'a [Arc<dyn AgentMiddleware>],
301 index: usize,
302 base: &'a dyn ToolCaller,
303}
304
305#[async_trait]
306impl ToolCaller for WrapToolChain<'_> {
307 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
308 if self.index >= self.middlewares.len() {
309 self.base.call(request).await
310 } else {
311 let next = WrapToolChain {
312 middlewares: self.middlewares,
313 index: self.index + 1,
314 base: self.base,
315 };
316 self.middlewares[self.index]
317 .wrap_tool_call(request, &next)
318 .await
319 }
320 }
321}
322
323pub struct BaseChatModelCaller {
329 model: Arc<dyn ChatModel>,
330}
331
332impl BaseChatModelCaller {
333 pub fn new(model: Arc<dyn ChatModel>) -> Self {
334 Self { model }
335 }
336}
337
338#[async_trait]
339impl ModelCaller for BaseChatModelCaller {
340 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
341 let chat_request = request.to_chat_request();
342 let response = self.model.chat(chat_request).await?;
343 Ok(response.into())
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Middleware that counts how often its model hooks are invoked.
    #[derive(Default)]
    struct CountingMiddleware {
        before_count: AtomicUsize,
        after_count: AtomicUsize,
    }

    impl CountingMiddleware {
        /// Both counters start at zero.
        fn new() -> Self {
            Self::default()
        }
    }

    #[async_trait]
    impl AgentMiddleware for CountingMiddleware {
        async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_model(
            &self,
            _request: &ModelRequest,
            _response: &mut ModelResponse,
        ) -> Result<(), SynapticError> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Request with a single human "hello" message, no tools, and the given
    /// system prompt.
    fn hello_request(system_prompt: Option<&str>) -> ModelRequest {
        ModelRequest {
            messages: vec![Message::human("hello")],
            tools: Vec::new(),
            tool_choice: None,
            system_prompt: system_prompt.map(String::from),
        }
    }

    #[test]
    fn middleware_chain_creation() {
        let counting: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![counting]);
        assert!(!chain.is_empty());
    }

    #[test]
    fn empty_middleware_chain() {
        let chain = MiddlewareChain::new(Vec::new());
        assert!(chain.is_empty());
    }

    #[test]
    fn model_request_to_chat_request() {
        let chat_req = hello_request(Some("You are helpful.")).to_chat_request();

        // The system prompt comes first, followed by the original message.
        assert_eq!(chat_req.messages.len(), 2);
        assert!(chat_req.messages[0].is_system());
        assert!(chat_req.messages[1].is_human());
    }

    #[test]
    fn model_request_without_system_prompt() {
        let chat_req = hello_request(None).to_chat_request();
        assert_eq!(chat_req.messages.len(), 1);
    }
}