Skip to main content

a2a_protocol_client/
interceptor.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! Request/response interceptor infrastructure.
5//!
//! Interceptors let callers inspect and modify every A2A request before it is
//! sent, and inspect (but not modify) every response after it is received.
//! Common uses include:
8//!
9//! - Adding `Authorization` headers (see [`crate::auth::AuthInterceptor`]).
10//! - Logging or tracing.
11//! - Injecting custom metadata.
12//!
13//! # Example
14//!
15//! ```rust
16//! use a2a_protocol_client::interceptor::{CallInterceptor, ClientRequest, ClientResponse};
17//! use a2a_protocol_client::error::ClientResult;
18//!
19//! struct LoggingInterceptor;
20//!
21//! impl CallInterceptor for LoggingInterceptor {
22//!     fn before<'a>(&'a self, req: &'a mut ClientRequest)
23//!         -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a
24//!     {
25//!         async move { let _ = req; Ok(()) }
26//!     }
27//!     fn after<'a>(&'a self, resp: &'a ClientResponse)
28//!         -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a
29//!     {
30//!         async move { let _ = resp; Ok(()) }
31//!     }
32//! }
33//! ```
34
35use std::collections::HashMap;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39
40use crate::error::ClientResult;
41
42// ── ClientRequest ─────────────────────────────────────────────────────────────
43
44/// A logical A2A request as seen by interceptors.
45///
46/// Interceptors may mutate `params` and `extra_headers` before the request is
47/// dispatched to the transport layer.
48#[derive(Debug)]
49pub struct ClientRequest {
50    /// The A2A method name (e.g. `"message/send"`).
51    pub method: String,
52
53    /// Method parameters as a JSON value.
54    pub params: serde_json::Value,
55
56    /// Additional HTTP headers to include with this request.
57    ///
58    /// Auth interceptors use this to inject `Authorization` headers.
59    pub extra_headers: HashMap<String, String>,
60}
61
62impl ClientRequest {
63    /// Creates a new [`ClientRequest`] with the given method and params.
64    #[must_use]
65    pub fn new(method: impl Into<String>, params: serde_json::Value) -> Self {
66        Self {
67            method: method.into(),
68            params,
69            extra_headers: HashMap::new(),
70        }
71    }
72}
73
74// ── ClientResponse ────────────────────────────────────────────────────────────
75
76/// A logical A2A response as seen by interceptors.
77#[derive(Debug)]
78pub struct ClientResponse {
79    /// The A2A method name that produced this response.
80    pub method: String,
81
82    /// The JSON-decoded result value.
83    pub result: serde_json::Value,
84
85    /// The HTTP status code.
86    pub status_code: u16,
87}
88
89// ── CallInterceptor (public async-fn trait) ───────────────────────────────────
90
/// Hooks called before every A2A request and after every response.
///
/// Implement this trait to add cross-cutting concerns such as authentication,
/// logging, or metrics. Register interceptors via
/// [`crate::ClientBuilder::with_interceptor`].
///
/// # Ordering
///
/// Inside an [`InterceptorChain`], `before` hooks run in registration order and
/// `after` hooks run in reverse registration order.
///
/// # Object-safety note
///
/// This trait uses `impl Future` return types with explicit lifetimes, which
/// is not object-safe. Internally the SDK wraps implementations in a
/// boxed-future shim. Callers implement the ergonomic trait API.
pub trait CallInterceptor: Send + Sync + 'static {
    /// Called before the request is sent.
    ///
    /// Mutate `req` to modify parameters or inject headers. Interceptors
    /// registered later see the mutations made by earlier ones.
    ///
    /// # Errors
    ///
    /// Returning an error stops [`InterceptorChain::run_before`]. The `before`
    /// hooks of interceptors registered after this one are not run, and the
    /// error is returned to the caller.
    fn before<'a>(
        &'a self,
        req: &'a mut ClientRequest,
    ) -> impl Future<Output = ClientResult<()>> + Send + 'a;

    /// Called after a successful response is received.
    ///
    /// `resp` is borrowed immutably, so this hook can inspect the response but
    /// not modify it.
    ///
    /// # Errors
    ///
    /// Returning an error stops [`InterceptorChain::run_after`]. Because that
    /// pass runs in reverse order, the `after` hooks of interceptors registered
    /// *before* this one are skipped, and the error is returned to the caller.
    fn after<'a>(
        &'a self,
        resp: &'a ClientResponse,
    ) -> impl Future<Output = ClientResult<()>> + Send + 'a;
}
117
118// ── Internal boxed trait for object-safe storage ──────────────────────────────
119
120/// Object-safe version of [`CallInterceptor`] used internally.
121///
122/// Not part of the public API; users implement [`CallInterceptor`].
123pub(crate) trait CallInterceptorBoxed: Send + Sync + 'static {
124    fn before_boxed<'a>(
125        &'a self,
126        req: &'a mut ClientRequest,
127    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>>;
128
129    fn after_boxed<'a>(
130        &'a self,
131        resp: &'a ClientResponse,
132    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>>;
133}
134
135impl<T: CallInterceptor> CallInterceptorBoxed for T {
136    fn before_boxed<'a>(
137        &'a self,
138        req: &'a mut ClientRequest,
139    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
140        Box::pin(self.before(req))
141    }
142
143    fn after_boxed<'a>(
144        &'a self,
145        resp: &'a ClientResponse,
146    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
147        Box::pin(self.after(resp))
148    }
149}
150
151impl CallInterceptorBoxed for Box<dyn CallInterceptorBoxed> {
152    fn before_boxed<'a>(
153        &'a self,
154        req: &'a mut ClientRequest,
155    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
156        (**self).before_boxed(req)
157    }
158
159    fn after_boxed<'a>(
160        &'a self,
161        resp: &'a ClientResponse,
162    ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
163        (**self).after_boxed(resp)
164    }
165}
166
167// ── InterceptorChain ──────────────────────────────────────────────────────────
168
169/// An ordered list of [`CallInterceptor`]s applied to every request.
170///
171/// Interceptors run in registration order for `before` and reverse order for
172/// `after` (outermost wraps innermost).
173#[derive(Default)]
174pub struct InterceptorChain {
175    interceptors: Vec<Arc<dyn CallInterceptorBoxed>>,
176}
177
178impl InterceptorChain {
179    /// Creates an empty [`InterceptorChain`].
180    #[must_use]
181    pub fn new() -> Self {
182        Self::default()
183    }
184
185    /// Adds an interceptor to the end of the chain.
186    pub fn push<I: CallInterceptor>(&mut self, interceptor: I) {
187        self.interceptors.push(Arc::new(interceptor));
188    }
189
190    /// Returns `true` if no interceptors have been registered.
191    #[must_use]
192    pub fn is_empty(&self) -> bool {
193        self.interceptors.is_empty()
194    }
195
196    /// Runs all `before` hooks in registration order.
197    ///
198    /// # Errors
199    ///
200    /// Returns the first error returned by any interceptor in the chain.
201    pub async fn run_before(&self, req: &mut ClientRequest) -> ClientResult<()> {
202        for interceptor in &self.interceptors {
203            interceptor.before_boxed(req).await?;
204        }
205        Ok(())
206    }
207
208    /// Runs all `after` hooks in reverse registration order.
209    ///
210    /// # Errors
211    ///
212    /// Returns the first error returned by any interceptor in the chain.
213    pub async fn run_after(&self, resp: &ClientResponse) -> ClientResult<()> {
214        for interceptor in self.interceptors.iter().rev() {
215            interceptor.after_boxed(resp).await?;
216        }
217        Ok(())
218    }
219}
220
221impl std::fmt::Debug for InterceptorChain {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        f.debug_struct("InterceptorChain")
224            .field("count", &self.interceptors.len())
225            .finish()
226    }
227}
228
229// ── Tests ─────────────────────────────────────────────────────────────────────
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Shared, ordered log of `(interceptor id, hook name)` invocations.
    type CallLog = Arc<Mutex<Vec<(usize, &'static str)>>>;

    /// Interceptor that appends to a shared [`CallLog`] from each hook. This
    /// lets tests assert the exact order in which the chain invoked them.
    struct RecordingInterceptor {
        id: usize,
        log: CallLog,
    }

    impl RecordingInterceptor {
        fn record(&self, hook: &'static str) {
            self.log
                .lock()
                .expect("call log mutex poisoned")
                .push((self.id, hook));
        }
    }

    impl CallInterceptor for RecordingInterceptor {
        // The trait spells out `impl Future + Send + 'a`; mirror it explicitly
        // instead of using `async fn` sugar.
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            _req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("before");
                Ok(())
            }
        }

        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("after");
                Ok(())
            }
        }
    }

    /// Interceptor that injects a fixed header, exercising request mutation.
    struct HeaderInterceptor;

    impl CallInterceptor for HeaderInterceptor {
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                req.extra_headers.insert("x-test".into(), "1".into());
                Ok(())
            }
        }

        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move { Ok(()) }
        }
    }

    /// Builds a chain of `len` recording interceptors with ids `0..len`.
    fn recording_chain(len: usize) -> (InterceptorChain, CallLog) {
        let log = CallLog::default();
        let mut chain = InterceptorChain::new();
        for id in 0..len {
            chain.push(RecordingInterceptor {
                id,
                log: Arc::clone(&log),
            });
        }
        (chain, log)
    }

    fn sample_response() -> ClientResponse {
        ClientResponse {
            method: "message/send".into(),
            result: serde_json::Value::Null,
            status_code: 200,
        }
    }

    #[tokio::test]
    async fn chain_runs_before_in_order() {
        let (chain, log) = recording_chain(3);
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            vec![(0, "before"), (1, "before"), (2, "before")]
        );
    }

    #[tokio::test]
    async fn chain_runs_after_in_reverse_order() {
        let (chain, log) = recording_chain(3);
        chain.run_after(&sample_response()).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            vec![(2, "after"), (1, "after"), (0, "after")]
        );
    }

    #[tokio::test]
    async fn before_hook_can_mutate_request() {
        let mut chain = InterceptorChain::new();
        chain.push(HeaderInterceptor);
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        assert_eq!(
            req.extra_headers.get("x-test").map(String::as_str),
            Some("1")
        );
    }

    #[tokio::test]
    async fn empty_chain_is_a_no_op() {
        let chain = InterceptorChain::new();
        assert!(chain.is_empty());
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        chain.run_after(&sample_response()).await.unwrap();
        assert!(req.extra_headers.is_empty());
    }

    #[test]
    fn debug_reports_interceptor_count() {
        let (chain, _log) = recording_chain(2);
        assert!(!chain.is_empty());
        assert_eq!(format!("{chain:?}"), "InterceptorChain { count: 2 }");
    }
}