Skip to main content

rs_adk/middleware/
mod.rs

1//! Middleware trait and chain — wraps agent execution at lifecycle points.
2
3pub mod latency;
4pub mod log;
5pub mod retry;
6
7pub use latency::*;
8pub use log::*;
9pub use retry::*;
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use rs_genai::prelude::FunctionCall;
16
17use crate::context::AgentEvent;
18use crate::context::InvocationContext;
19use crate::error::{AgentError, ToolError};
20use crate::llm::{LlmRequest, LlmResponse};
21
22/// Middleware hooks — all optional, implement only what you need.
23///
24/// # Examples
25///
26/// ```rust,ignore
27/// use async_trait::async_trait;
28/// use rs_adk::middleware::Middleware;
29/// use rs_adk::error::AgentError;
30/// use rs_genai::prelude::FunctionCall;
31///
32/// struct AuditMiddleware;
33///
34/// #[async_trait]
35/// impl Middleware for AuditMiddleware {
36///     fn name(&self) -> &str { "audit" }
37///
38///     async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
39///         println!("Calling tool: {}", call.name);
40///         Ok(())
41///     }
42/// }
43/// ```
44#[async_trait]
45pub trait Middleware: Send + Sync + 'static {
46    /// Unique name for this middleware (used in logging/debugging).
47    fn name(&self) -> &str;
48
49    /// Called before an agent begins execution.
50    async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
51        Ok(())
52    }
53    /// Called after an agent completes execution.
54    async fn after_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
55        Ok(())
56    }
57
58    /// Called before a tool is invoked.
59    async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
60        Ok(())
61    }
62    /// Called after a tool completes successfully.
63    async fn after_tool(
64        &self,
65        _call: &FunctionCall,
66        _result: &serde_json::Value,
67    ) -> Result<(), AgentError> {
68        Ok(())
69    }
70    /// Called when a tool execution fails.
71    async fn on_tool_error(
72        &self,
73        _call: &FunctionCall,
74        _err: &ToolError,
75    ) -> Result<(), AgentError> {
76        Ok(())
77    }
78
79    /// Called when an agent event is emitted.
80    async fn on_event(&self, _event: &AgentEvent) -> Result<(), AgentError> {
81        Ok(())
82    }
83
84    /// Called when an agent error occurs.
85    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
86        Ok(())
87    }
88
89    /// Called before an LLM model call is made. Return `Some(LlmResponse)` to skip the LLM call
90    /// and use the returned response instead (e.g., for caching, guardrails). Return `None` to proceed.
91    async fn before_model(&self, _request: &LlmRequest) -> Result<Option<LlmResponse>, AgentError> {
92        Ok(None)
93    }
94
95    /// Called after an LLM model call completes. Return `Some(LlmResponse)` to replace the model's
96    /// response (e.g., for output filtering, safety). Return `None` to use the original response.
97    async fn after_model(
98        &self,
99        _request: &LlmRequest,
100        _response: &LlmResponse,
101    ) -> Result<Option<LlmResponse>, AgentError> {
102        Ok(None)
103    }
104}
105
106/// Ordered chain of middleware.
107#[derive(Clone, Default)]
108pub struct MiddlewareChain {
109    layers: Vec<Arc<dyn Middleware>>,
110}
111
112impl MiddlewareChain {
113    /// Create a new empty middleware chain.
114    pub fn new() -> Self {
115        Self::default()
116    }
117
118    /// Append a middleware to the end of the chain.
119    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
120        self.layers.push(middleware);
121    }
122
123    /// Prepend a middleware to the front of the chain.
124    pub fn prepend(&mut self, middleware: Arc<dyn Middleware>) {
125        self.layers.insert(0, middleware);
126    }
127
128    /// Run all `before_agent` hooks in order.
129    pub async fn run_before_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
130        for m in &self.layers {
131            m.before_agent(ctx).await?;
132        }
133        Ok(())
134    }
135
136    /// Run all `after_agent` hooks in reverse order.
137    pub async fn run_after_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
138        for m in self.layers.iter().rev() {
139            m.after_agent(ctx).await?;
140        }
141        Ok(())
142    }
143
144    /// Run all `before_tool` hooks in order.
145    pub async fn run_before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
146        for m in &self.layers {
147            m.before_tool(call).await?;
148        }
149        Ok(())
150    }
151
152    /// Run all `after_tool` hooks in reverse order.
153    pub async fn run_after_tool(
154        &self,
155        call: &FunctionCall,
156        result: &serde_json::Value,
157    ) -> Result<(), AgentError> {
158        for m in self.layers.iter().rev() {
159            m.after_tool(call, result).await?;
160        }
161        Ok(())
162    }
163
164    /// Run all `on_tool_error` hooks in order.
165    pub async fn run_on_tool_error(
166        &self,
167        call: &FunctionCall,
168        err: &ToolError,
169    ) -> Result<(), AgentError> {
170        for m in &self.layers {
171            m.on_tool_error(call, err).await?;
172        }
173        Ok(())
174    }
175
176    /// Run all `on_event` hooks in order.
177    pub async fn run_on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
178        for m in &self.layers {
179            m.on_event(event).await?;
180        }
181        Ok(())
182    }
183
184    /// Run all `on_error` hooks in order.
185    pub async fn run_on_error(&self, err: &AgentError) -> Result<(), AgentError> {
186        for m in &self.layers {
187            m.on_error(err).await?;
188        }
189        Ok(())
190    }
191
192    /// Run all `before_model` hooks in order. Returns the first non-None override response.
193    pub async fn run_before_model(
194        &self,
195        request: &LlmRequest,
196    ) -> Result<Option<LlmResponse>, AgentError> {
197        for m in &self.layers {
198            if let Some(response) = m.before_model(request).await? {
199                return Ok(Some(response));
200            }
201        }
202        Ok(None)
203    }
204
205    /// Run all `after_model` hooks in reverse order. Returns the first non-None override response.
206    pub async fn run_after_model(
207        &self,
208        request: &LlmRequest,
209        response: &LlmResponse,
210    ) -> Result<Option<LlmResponse>, AgentError> {
211        for m in self.layers.iter().rev() {
212            if let Some(replacement) = m.after_model(request, response).await? {
213                return Ok(Some(replacement));
214            }
215        }
216        Ok(None)
217    }
218
219    /// Whether the chain has no middleware layers.
220    pub fn is_empty(&self) -> bool {
221        self.layers.is_empty()
222    }
223
224    /// Number of middleware layers in the chain.
225    pub fn len(&self) -> usize {
226        self.layers.len()
227    }
228}
229
230// ── Tests ────────────────────────────────────────────────────────────────────
231
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::time::Duration;

    // Helper: create a FunctionCall for testing.
    fn test_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            args: serde_json::json!({"key": "value"}),
            id: None,
        }
    }

    // ── Test middleware ──

    struct CountingMiddleware {
        call_count: Arc<std::sync::atomic::AtomicU32>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        fn name(&self) -> &str {
            "counter"
        }

        async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
            self.call_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        }
    }

    /// Appends `"<hook>:<label>"` to a shared log on every `before_tool` /
    /// `after_tool` call, so tests can assert the order in which the chain
    /// invokes its layers.
    struct RecordingMiddleware {
        label: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingMiddleware {
        fn new(label: &'static str, log: &Arc<Mutex<Vec<String>>>) -> Arc<Self> {
            Arc::new(Self {
                label,
                log: Arc::clone(log),
            })
        }

        // The lock guard is a temporary dropped at the end of this sync call,
        // so it is never held across an `.await`.
        fn record(&self, hook: &str) {
            self.log
                .lock()
                .unwrap()
                .push(format!("{}:{}", hook, self.label));
        }
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        fn name(&self) -> &str {
            self.label
        }

        async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
            self.record("before");
            Ok(())
        }

        async fn after_tool(
            &self,
            _call: &FunctionCall,
            _result: &serde_json::Value,
        ) -> Result<(), AgentError> {
            self.record("after");
            Ok(())
        }
    }

    /// Middleware whose `before_tool` always fails.
    struct RejectingMiddleware;

    #[async_trait]
    impl Middleware for RejectingMiddleware {
        fn name(&self) -> &str {
            "reject"
        }

        async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
            Err(AgentError::Other("rejected".to_string()))
        }
    }

    // ── Chain tests ──

    #[test]
    fn empty_chain() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    #[tokio::test]
    async fn empty_chain_hooks_succeed() {
        let chain = MiddlewareChain::new();
        let call = test_call("noop");
        assert!(chain.run_before_tool(&call).await.is_ok());
        assert!(chain
            .run_after_tool(&call, &serde_json::json!(null))
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn middleware_chain_ordering() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut chain = MiddlewareChain::new();
        chain.add(RecordingMiddleware::new("b", &log));
        chain.add(RecordingMiddleware::new("c", &log));
        // `prepend` must place the layer ahead of everything added so far.
        chain.prepend(RecordingMiddleware::new("a", &log));
        assert_eq!(chain.len(), 3);

        let call = test_call("my_tool");
        chain.run_before_tool(&call).await.unwrap();
        chain
            .run_after_tool(&call, &serde_json::json!({"ok": true}))
            .await
            .unwrap();

        // `before_*` hooks run front-to-back, `after_*` hooks back-to-front.
        assert_eq!(
            *log.lock().unwrap(),
            vec![
                "before:a", "before:b", "before:c", "after:c", "after:b", "after:a"
            ]
        );
    }

    #[tokio::test]
    async fn hook_error_short_circuits_chain() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(RejectingMiddleware));
        chain.add(RecordingMiddleware::new("unreached", &log));

        let call = test_call("guarded_tool");
        assert!(chain.run_before_tool(&call).await.is_err());
        assert!(
            log.lock().unwrap().is_empty(),
            "layers after a failing hook must not run"
        );
    }

    #[test]
    fn middleware_is_object_safe() {
        fn _assert(_: &dyn Middleware) {}
    }

    #[test]
    fn add_middleware_to_chain() {
        let mut chain = MiddlewareChain::new();
        let counter = Arc::new(CountingMiddleware {
            call_count: Arc::new(std::sync::atomic::AtomicU32::new(0)),
        });
        chain.add(counter);
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }

    #[test]
    fn chain_is_clone() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        let chain2 = chain.clone();
        assert_eq!(chain2.len(), 1);
    }

    #[test]
    fn log_middleware_defaults() {
        let log = LogMiddleware::new();
        assert_eq!(log.name(), "log");
    }

    #[test]
    fn latency_middleware_defaults() {
        let lat = LatencyMiddleware::new();
        assert_eq!(lat.name(), "latency");
    }

    // ── LogMiddleware tests ──

    #[tokio::test]
    async fn logging_middleware_doesnt_panic() {
        let log = LogMiddleware::new();
        let call = test_call("my_tool");
        let result = serde_json::json!({"ok": true});
        let tool_err = ToolError::ExecutionFailed("boom".to_string());
        let agent_err = AgentError::Other("oops".to_string());

        // All hooks should complete without panic.
        assert!(log.before_tool(&call).await.is_ok());
        assert!(log.after_tool(&call, &result).await.is_ok());
        assert!(log.on_tool_error(&call, &tool_err).await.is_ok());
        assert!(log.on_error(&agent_err).await.is_ok());
    }

    // ── LatencyMiddleware tests ──

    #[tokio::test]
    async fn latency_middleware_records_timing() {
        let lat = LatencyMiddleware::new();
        let call = test_call("slow_tool");
        let result = serde_json::json!("done");

        // Simulate a tool call.
        lat.before_tool(&call).await.unwrap();
        // Small delay to ensure non-zero elapsed time.
        tokio::time::sleep(Duration::from_millis(5)).await;
        lat.after_tool(&call, &result).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "slow_tool");
        assert!(records[0].success);
        assert!(records[0].elapsed >= Duration::from_millis(1));
    }

    #[tokio::test]
    async fn latency_middleware_records_failure() {
        let lat = LatencyMiddleware::new();
        let call = test_call("failing_tool");
        let err = ToolError::ExecutionFailed("kaboom".to_string());

        lat.before_tool(&call).await.unwrap();
        lat.on_tool_error(&call, &err).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "failing_tool");
        assert!(!records[0].success);
    }

    #[tokio::test]
    async fn latency_middleware_clear() {
        let lat = LatencyMiddleware::new();
        let call = test_call("tool_a");
        let result = serde_json::json!(null);

        lat.before_tool(&call).await.unwrap();
        lat.after_tool(&call, &result).await.unwrap();
        assert_eq!(lat.tool_latencies().len(), 1);

        lat.clear();
        assert!(lat.tool_latencies().is_empty());
    }

    // ── RetryMiddleware tests ──

    #[tokio::test]
    async fn retry_middleware_tracks_retries() {
        let retry = RetryMiddleware::new(3);
        assert_eq!(retry.max_retries(), 3);
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry(), "no error yet, should not retry");

        // Simulate an error.
        let err = AgentError::Other("transient".to_string());
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry(), "error recorded, should retry");

        // Record first attempt.
        retry.record_attempt();
        assert_eq!(retry.attempts(), 1);
        assert!(!retry.should_retry(), "error was cleared by record_attempt");

        // Another error + attempt cycle.
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 2);

        // Third error + attempt.
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 3);

        // Now at max — should not retry even with new error.
        retry.on_error(&err).await.unwrap();
        assert!(!retry.should_retry(), "at max retries, should not retry");
    }

    #[test]
    fn retry_middleware_reset() {
        let retry = RetryMiddleware::new(2);
        retry
            .error_count
            .store(1, std::sync::atomic::Ordering::SeqCst);
        retry.attempt.store(1, std::sync::atomic::Ordering::SeqCst);
        retry.reset();
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry());
    }

    // ── Chain integration test ──

    #[test]
    fn chain_with_all_builtin_middleware() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        chain.add(Arc::new(LatencyMiddleware::new()));
        chain.add(Arc::new(RetryMiddleware::new(3)));
        assert_eq!(chain.len(), 3);
    }
}