a2a_protocol_client/
interceptor.rs1use std::collections::HashMap;
38use std::future::Future;
39use std::pin::Pin;
40use std::sync::Arc;
41
42use crate::error::ClientResult;
43
44#[derive(Debug)]
51pub struct ClientRequest {
52 pub method: String,
54
55 pub params: serde_json::Value,
57
58 pub extra_headers: HashMap<String, String>,
62}
63
64impl ClientRequest {
65 #[must_use]
67 pub fn new(method: impl Into<String>, params: serde_json::Value) -> Self {
68 Self {
69 method: method.into(),
70 params,
71 extra_headers: HashMap::new(),
72 }
73 }
74}
75
76#[derive(Debug)]
80pub struct ClientResponse {
81 pub method: String,
83
84 pub result: serde_json::Value,
86
87 pub status_code: u16,
95}
96
97pub trait CallInterceptor: Send + Sync + 'static {
111 fn before<'a>(
115 &'a self,
116 req: &'a mut ClientRequest,
117 ) -> impl Future<Output = ClientResult<()>> + Send + 'a;
118
119 fn after<'a>(
121 &'a self,
122 resp: &'a ClientResponse,
123 ) -> impl Future<Output = ClientResult<()>> + Send + 'a;
124}
125
126pub(crate) trait CallInterceptorBoxed: Send + Sync + 'static {
132 fn before_boxed<'a>(
133 &'a self,
134 req: &'a mut ClientRequest,
135 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>>;
136
137 fn after_boxed<'a>(
138 &'a self,
139 resp: &'a ClientResponse,
140 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>>;
141}
142
143impl<T: CallInterceptor> CallInterceptorBoxed for T {
144 fn before_boxed<'a>(
145 &'a self,
146 req: &'a mut ClientRequest,
147 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
148 Box::pin(self.before(req))
149 }
150
151 fn after_boxed<'a>(
152 &'a self,
153 resp: &'a ClientResponse,
154 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
155 Box::pin(self.after(resp))
156 }
157}
158
159impl CallInterceptorBoxed for Box<dyn CallInterceptorBoxed> {
160 fn before_boxed<'a>(
161 &'a self,
162 req: &'a mut ClientRequest,
163 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
164 (**self).before_boxed(req)
165 }
166
167 fn after_boxed<'a>(
168 &'a self,
169 resp: &'a ClientResponse,
170 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
171 (**self).after_boxed(resp)
172 }
173}
174
175#[derive(Default)]
182pub struct InterceptorChain {
183 interceptors: Vec<Arc<dyn CallInterceptorBoxed>>,
184}
185
186impl InterceptorChain {
187 #[must_use]
189 pub fn new() -> Self {
190 Self::default()
191 }
192
193 pub fn push<I: CallInterceptor>(&mut self, interceptor: I) {
195 self.interceptors.push(Arc::new(interceptor));
196 }
197
198 #[must_use]
200 pub fn is_empty(&self) -> bool {
201 self.interceptors.is_empty()
202 }
203
204 pub async fn run_before(&self, req: &mut ClientRequest) -> ClientResult<()> {
210 for interceptor in &self.interceptors {
211 interceptor.before_boxed(req).await?;
212 }
213 Ok(())
214 }
215
216 pub async fn run_after(&self, resp: &ClientResponse) -> ClientResult<()> {
222 for interceptor in self.interceptors.iter().rev() {
223 interceptor.after_boxed(resp).await?;
224 }
225 Ok(())
226 }
227}
228
229impl std::fmt::Debug for InterceptorChain {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("InterceptorChain")
232 .field("count", &self.interceptors.len())
233 .finish()
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Counts hook invocations on a shared counter: `before` adds 1 and
    /// `after` adds 10, so a single number shows which hooks ran.
    struct CountingInterceptor(Arc<AtomicUsize>);

    impl CallInterceptor for CountingInterceptor {
        // The trait declares `-> impl Future + Send + 'a` with one explicit
        // lifetime; mirror that shape here rather than the `async fn` that
        // clippy suggests.
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            _req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.0.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        // See `before` for why the explicit future form is kept.
        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.0.fetch_add(10, Ordering::SeqCst);
                Ok(())
            }
        }
    }

    /// Appends `"<hook>:<label>"` to a shared log on every call, so tests can
    /// assert the exact order in which the chain invokes interceptors.
    struct RecordingInterceptor {
        label: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingInterceptor {
        /// Records one hook invocation; the lock is released before returning.
        fn record(&self, hook: &str) {
            self.log
                .lock()
                .expect("recording log mutex poisoned")
                .push(format!("{hook}:{}", self.label));
        }
    }

    impl CallInterceptor for RecordingInterceptor {
        // Explicit future form mirrors the trait signature (see CountingInterceptor).
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            _req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("before");
                Ok(())
            }
        }

        // Explicit future form mirrors the trait signature (see CountingInterceptor).
        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("after");
                Ok(())
            }
        }
    }

    /// Builds a chain of `RecordingInterceptor`s, pushed in `labels` order and
    /// sharing one log, and returns the chain together with that log.
    fn recording_chain(labels: &[&'static str]) -> (InterceptorChain, Arc<Mutex<Vec<String>>>) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut chain = InterceptorChain::new();
        for &label in labels {
            chain.push(RecordingInterceptor {
                label,
                log: Arc::clone(&log),
            });
        }
        (chain, log)
    }

    /// Returns a copy of everything recorded so far.
    fn snapshot(log: &Arc<Mutex<Vec<String>>>) -> Vec<String> {
        log.lock().expect("recording log mutex poisoned").clone()
    }

    /// A minimal successful response for `method`.
    fn sample_response(method: &str) -> ClientResponse {
        ClientResponse {
            method: method.into(),
            result: serde_json::Value::Null,
            status_code: 200,
        }
    }

    #[test]
    fn request_new_starts_without_extra_headers() {
        let req = ClientRequest::new("message/send", serde_json::Value::Null);
        assert_eq!(req.method, "message/send");
        assert_eq!(req.params, serde_json::Value::Null);
        assert!(req.extra_headers.is_empty(), "new request should have no extra headers");
    }

    #[test]
    fn chain_is_empty_when_new() {
        let chain = InterceptorChain::new();
        assert!(chain.is_empty(), "new chain should be empty");
    }

    #[test]
    fn chain_is_not_empty_after_push() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut chain = InterceptorChain::new();
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        assert!(
            !chain.is_empty(),
            "chain with one interceptor should not be empty"
        );
    }

    #[test]
    fn chain_debug_reports_interceptor_count() {
        let (chain, _log) = recording_chain(&["a", "b"]);
        assert_eq!(format!("{chain:?}"), "InterceptorChain { count: 2 }");
    }

    #[tokio::test]
    async fn empty_chain_runs_without_error() {
        let chain = InterceptorChain::new();
        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        chain
            .run_after(&sample_response("message/send"))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn chain_runs_every_before_hook() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut chain = InterceptorChain::new();
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        chain.push(CountingInterceptor(Arc::clone(&counter)));

        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn chain_runs_before_in_order() {
        let (chain, log) = recording_chain(&["first", "second", "third"]);

        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        assert_eq!(
            snapshot(&log),
            vec!["before:first", "before:second", "before:third"],
            "before hooks must run in registration order"
        );
    }

    #[tokio::test]
    async fn chain_runs_after_hook() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut chain = InterceptorChain::new();
        chain.push(CountingInterceptor(Arc::clone(&counter)));

        chain
            .run_after(&sample_response("message/send"))
            .await
            .unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn chain_runs_after_in_reverse_order() {
        let (chain, log) = recording_chain(&["first", "second", "third"]);

        chain
            .run_after(&sample_response("message/send"))
            .await
            .unwrap();
        assert_eq!(
            snapshot(&log),
            vec!["after:third", "after:second", "after:first"],
            "after hooks must run in reverse registration order"
        );
    }

    #[tokio::test]
    async fn before_and_after_nest_like_layers() {
        let (chain, log) = recording_chain(&["outer", "inner"]);

        let mut req = ClientRequest::new("message/send", serde_json::Value::Null);
        chain.run_before(&mut req).await.unwrap();
        chain
            .run_after(&sample_response("message/send"))
            .await
            .unwrap();
        assert_eq!(
            snapshot(&log),
            vec!["before:outer", "before:inner", "after:inner", "after:outer"],
            "the first-registered interceptor should wrap all others"
        );
    }

    #[tokio::test]
    async fn boxed_interceptor_delegates_before_and_after() {
        let counter = Arc::new(AtomicUsize::new(0));
        let interceptor = CountingInterceptor(Arc::clone(&counter));
        let boxed: Box<dyn CallInterceptorBoxed> = Box::new(interceptor);

        let mut req = ClientRequest::new("test", serde_json::Value::Null);
        boxed.before_boxed(&mut req).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "before_boxed should delegate"
        );

        let resp = sample_response("test");
        boxed.after_boxed(&resp).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            11,
            "after_boxed should delegate"
        );

        // Exercises the `impl CallInterceptorBoxed for Box<dyn CallInterceptorBoxed>`
        // forwarding path (it must reach the inner interceptor, not recurse).
        let double_boxed: Box<dyn CallInterceptorBoxed> = Box::new(boxed);
        double_boxed.before_boxed(&mut req).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            12,
            "double-boxed before should delegate"
        );
        double_boxed.after_boxed(&resp).await.unwrap();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            22,
            "double-boxed after should delegate"
        );
    }
}