Skip to main content

a2a_client/
middleware.rs

// Copyright AGNTCY Contributors (https://github.com/agntcy)
// SPDX-License-Identifier: Apache-2.0
3use a2a::A2AError;
4use async_trait::async_trait;
5
6use crate::transport::ServiceParams;
7
8/// Interceptor for modifying requests and responses at the client level.
9///
10/// Interceptors are called in order for `before`, and in reverse order for `after`.
11#[async_trait]
12pub trait CallInterceptor: Send + Sync {
13    /// Called before sending a request. Can modify params (e.g., add auth headers).
14    async fn before(&self, method: &str, params: &mut ServiceParams) -> Result<(), A2AError> {
15        let _ = (method, params);
16        Ok(())
17    }
18
19    /// Called after receiving a response.
20    async fn after(&self, method: &str, result: &Result<(), A2AError>) -> Result<(), A2AError> {
21        let _ = (method, result);
22        Ok(())
23    }
24}
25
26/// Logging interceptor using `tracing`.
27pub struct LoggingInterceptor;
28
29#[async_trait]
30impl CallInterceptor for LoggingInterceptor {
31    async fn before(&self, method: &str, _params: &mut ServiceParams) -> Result<(), A2AError> {
32        tracing::info!(method = method, "A2A client request");
33        Ok(())
34    }
35
36    async fn after(&self, method: &str, result: &Result<(), A2AError>) -> Result<(), A2AError> {
37        match result {
38            Ok(()) => tracing::info!(method = method, "A2A client response"),
39            Err(e) => tracing::warn!(method = method, error = %e, "A2A client error"),
40        }
41        Ok(())
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use a2a::jsonrpc::methods::SEND_MESSAGE;

    /// Interceptor that relies entirely on the trait's default method bodies.
    struct NoopInterceptor;

    #[async_trait]
    impl CallInterceptor for NoopInterceptor {}

    // ----- default trait methods --------------------------------------------

    #[tokio::test]
    async fn test_default_before() {
        let mut service_params = ServiceParams::new();
        let outcome = NoopInterceptor.before("test", &mut service_params).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_default_after() {
        let outcome = NoopInterceptor.after("test", &Ok(())).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_default_after_with_error() {
        let failed_call: Result<(), A2AError> = Err(A2AError::internal("fail"));
        let outcome = NoopInterceptor.after("test", &failed_call).await;
        assert!(outcome.is_ok());
    }

    // ----- LoggingInterceptor -----------------------------------------------

    #[tokio::test]
    async fn test_logging_interceptor_before() {
        let mut service_params = ServiceParams::new();
        let outcome = LoggingInterceptor
            .before(SEND_MESSAGE, &mut service_params)
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_logging_interceptor_after_ok() {
        let outcome = LoggingInterceptor.after(SEND_MESSAGE, &Ok(())).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_logging_interceptor_after_err() {
        let failed_call: Result<(), A2AError> = Err(A2AError::internal("boom"));
        let outcome = LoggingInterceptor.after(SEND_MESSAGE, &failed_call).await;
        assert!(outcome.is_ok());
    }
}