1pub mod latency;
4pub mod log;
5pub mod retry;
6
7pub use latency::*;
8pub use log::*;
9pub use retry::*;
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use rs_genai::prelude::FunctionCall;
16
17use crate::context::AgentEvent;
18use crate::context::InvocationContext;
19use crate::error::{AgentError, ToolError};
20use crate::llm::{LlmRequest, LlmResponse};
21
/// Lifecycle hooks for observing or intercepting an agent run.
///
/// Every hook has a no-op default implementation, so implementors
/// override only the phases they care about. Errors returned from a
/// hook propagate out through [`MiddlewareChain`], which stops calling
/// the remaining layers for that hook.
#[async_trait]
pub trait Middleware: Send + Sync + 'static {
    /// Human-readable identifier for this middleware (e.g. for logging).
    fn name(&self) -> &str;

    /// Called before the agent invocation begins.
    async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
        Ok(())
    }
    /// Called after the agent invocation finishes.
    async fn after_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
        Ok(())
    }

    /// Called before a tool is executed.
    async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
        Ok(())
    }
    /// Called after a tool executes successfully, with its JSON result.
    async fn after_tool(
        &self,
        _call: &FunctionCall,
        _result: &serde_json::Value,
    ) -> Result<(), AgentError> {
        Ok(())
    }
    /// Called when a tool execution fails.
    async fn on_tool_error(
        &self,
        _call: &FunctionCall,
        _err: &ToolError,
    ) -> Result<(), AgentError> {
        Ok(())
    }

    /// Called for each [`AgentEvent`] emitted during the invocation.
    async fn on_event(&self, _event: &AgentEvent) -> Result<(), AgentError> {
        Ok(())
    }

    /// Called when the agent encounters an error.
    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
        Ok(())
    }

    /// Called before the model is invoked. Returning `Ok(Some(response))`
    /// short-circuits: [`MiddlewareChain::run_before_model`] returns that
    /// response without consulting later layers.
    async fn before_model(&self, _request: &LlmRequest) -> Result<Option<LlmResponse>, AgentError> {
        Ok(None)
    }

    /// Called after the model responds. Returning `Ok(Some(replacement))`
    /// substitutes the model's response; the first layer (in reverse chain
    /// order) to return `Some` wins.
    async fn after_model(
        &self,
        _request: &LlmRequest,
        _response: &LlmResponse,
    ) -> Result<Option<LlmResponse>, AgentError> {
        Ok(None)
    }
}
105
/// An ordered stack of [`Middleware`] layers.
///
/// "Before"-style hooks run front to back; "after"-style hooks run back
/// to front (see the `run_*` methods). Cloning the chain is cheap: only
/// the `Arc` handles are cloned, not the middleware themselves.
#[derive(Clone, Default)]
pub struct MiddlewareChain {
    // Layers in registration order; shared via `Arc` so chains can be cloned.
    layers: Vec<Arc<dyn Middleware>>,
}
111
112impl MiddlewareChain {
113 pub fn new() -> Self {
115 Self::default()
116 }
117
118 pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
120 self.layers.push(middleware);
121 }
122
123 pub fn prepend(&mut self, middleware: Arc<dyn Middleware>) {
125 self.layers.insert(0, middleware);
126 }
127
128 pub async fn run_before_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
130 for m in &self.layers {
131 m.before_agent(ctx).await?;
132 }
133 Ok(())
134 }
135
136 pub async fn run_after_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
138 for m in self.layers.iter().rev() {
139 m.after_agent(ctx).await?;
140 }
141 Ok(())
142 }
143
144 pub async fn run_before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
146 for m in &self.layers {
147 m.before_tool(call).await?;
148 }
149 Ok(())
150 }
151
152 pub async fn run_after_tool(
154 &self,
155 call: &FunctionCall,
156 result: &serde_json::Value,
157 ) -> Result<(), AgentError> {
158 for m in self.layers.iter().rev() {
159 m.after_tool(call, result).await?;
160 }
161 Ok(())
162 }
163
164 pub async fn run_on_tool_error(
166 &self,
167 call: &FunctionCall,
168 err: &ToolError,
169 ) -> Result<(), AgentError> {
170 for m in &self.layers {
171 m.on_tool_error(call, err).await?;
172 }
173 Ok(())
174 }
175
176 pub async fn run_on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
178 for m in &self.layers {
179 m.on_event(event).await?;
180 }
181 Ok(())
182 }
183
184 pub async fn run_on_error(&self, err: &AgentError) -> Result<(), AgentError> {
186 for m in &self.layers {
187 m.on_error(err).await?;
188 }
189 Ok(())
190 }
191
192 pub async fn run_before_model(
194 &self,
195 request: &LlmRequest,
196 ) -> Result<Option<LlmResponse>, AgentError> {
197 for m in &self.layers {
198 if let Some(response) = m.before_model(request).await? {
199 return Ok(Some(response));
200 }
201 }
202 Ok(None)
203 }
204
205 pub async fn run_after_model(
207 &self,
208 request: &LlmRequest,
209 response: &LlmResponse,
210 ) -> Result<Option<LlmResponse>, AgentError> {
211 for m in self.layers.iter().rev() {
212 if let Some(replacement) = m.after_model(request, response).await? {
213 return Ok(Some(replacement));
214 }
215 }
216 Ok(None)
217 }
218
219 pub fn is_empty(&self) -> bool {
221 self.layers.is_empty()
222 }
223
224 pub fn len(&self) -> usize {
226 self.layers.len()
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Builds a minimal FunctionCall with a fixed JSON args payload.
    fn test_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            args: serde_json::json!({"key": "value"}),
            id: None,
        }
    }

    // Middleware stub that counts `before_agent` invocations; all other
    // hooks keep their no-op defaults.
    struct CountingMiddleware {
        call_count: Arc<std::sync::atomic::AtomicU32>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        fn name(&self) -> &str {
            "counter"
        }

        async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
            self.call_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        }
    }

    // A freshly-constructed chain has no layers.
    #[test]
    fn middleware_chain_ordering() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    // Compile-time check: the trait must remain object-safe so it can be
    // stored as `Arc<dyn Middleware>`.
    #[test]
    fn middleware_is_object_safe() {
        fn _assert(_: &dyn Middleware) {}
    }

    // `add` appends a layer and updates the chain's length/emptiness.
    #[test]
    fn add_middleware_to_chain() {
        let mut chain = MiddlewareChain::new();
        let counter = Arc::new(CountingMiddleware {
            call_count: Arc::new(std::sync::atomic::AtomicU32::new(0)),
        });
        chain.add(counter);
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }

    // Cloning a chain preserves its layers (Arc handles are shared).
    #[test]
    fn chain_is_clone() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        let chain2 = chain.clone();
        assert_eq!(chain2.len(), 1);
    }

    #[test]
    fn log_middleware_defaults() {
        let log = LogMiddleware::new();
        assert_eq!(log.name(), "log");
    }

    #[test]
    fn latency_middleware_defaults() {
        let lat = LatencyMiddleware::new();
        assert_eq!(lat.name(), "latency");
    }

    // Smoke test: every LogMiddleware hook succeeds on representative
    // inputs (success, tool error, and agent error paths).
    #[tokio::test]
    async fn logging_middleware_doesnt_panic() {
        let log = LogMiddleware::new();
        let call = test_call("my_tool");
        let result = serde_json::json!({"ok": true});
        let tool_err = ToolError::ExecutionFailed("boom".to_string());
        let agent_err = AgentError::Other("oops".to_string());

        assert!(log.before_tool(&call).await.is_ok());
        assert!(log.after_tool(&call, &result).await.is_ok());
        assert!(log.on_tool_error(&call, &tool_err).await.is_ok());
        assert!(log.on_error(&agent_err).await.is_ok());
    }

    // before_tool/after_tool around a real (async) sleep should record one
    // successful latency entry with a non-trivial elapsed time.
    #[tokio::test]
    async fn latency_middleware_records_timing() {
        let lat = LatencyMiddleware::new();
        let call = test_call("slow_tool");
        let result = serde_json::json!("done");

        lat.before_tool(&call).await.unwrap();
        // Sleep 5ms but only assert >= 1ms to keep the test robust to
        // timer granularity.
        tokio::time::sleep(Duration::from_millis(5)).await;
        lat.after_tool(&call, &result).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "slow_tool");
        assert!(records[0].success);
        assert!(records[0].elapsed >= Duration::from_millis(1));
    }

    // A tool error closes the pending timing with success == false.
    #[tokio::test]
    async fn latency_middleware_records_failure() {
        let lat = LatencyMiddleware::new();
        let call = test_call("failing_tool");
        let err = ToolError::ExecutionFailed("kaboom".to_string());

        lat.before_tool(&call).await.unwrap();
        lat.on_tool_error(&call, &err).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "failing_tool");
        assert!(!records[0].success);
    }

    // `clear` discards all recorded latencies.
    #[tokio::test]
    async fn latency_middleware_clear() {
        let lat = LatencyMiddleware::new();
        let call = test_call("tool_a");
        let result = serde_json::json!(null);

        lat.before_tool(&call).await.unwrap();
        lat.after_tool(&call, &result).await.unwrap();
        assert_eq!(lat.tool_latencies().len(), 1);

        lat.clear();
        assert!(lat.tool_latencies().is_empty());
    }

    // Exercises the retry state machine in order: an error arms
    // `should_retry`, `record_attempt` consumes it, and once `attempts()`
    // reaches `max_retries` further errors no longer arm a retry.
    // NOTE: the step order matters — record_attempt clears the error flag.
    #[tokio::test]
    async fn retry_middleware_tracks_retries() {
        let retry = RetryMiddleware::new(3);
        assert_eq!(retry.max_retries(), 3);
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry(), "no error yet, should not retry");

        let err = AgentError::Other("transient".to_string());
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry(), "error recorded, should retry");

        retry.record_attempt();
        assert_eq!(retry.attempts(), 1);
        assert!(!retry.should_retry(), "error was cleared by record_attempt");

        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 2);

        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 3);

        retry.on_error(&err).await.unwrap();
        assert!(!retry.should_retry(), "at max retries, should not retry");
    }

    // `reset` zeroes both counters; the test pokes the crate-visible
    // atomics directly to simulate prior state.
    #[test]
    fn retry_middleware_reset() {
        let retry = RetryMiddleware::new(2);
        retry
            .error_count
            .store(1, std::sync::atomic::Ordering::SeqCst);
        retry.attempt.store(1, std::sync::atomic::Ordering::SeqCst);
        retry.reset();
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry());
    }

    // All built-in middleware coexist in one chain.
    #[test]
    fn chain_with_all_builtin_middleware() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        chain.add(Arc::new(LatencyMiddleware::new()));
        chain.add(Arc::new(RetryMiddleware::new(3)));
        assert_eq!(chain.len(), 3);
    }
}