a2a_protocol_client/
interceptor.rs1use std::collections::HashMap;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39
40use crate::error::ClientResult;
41
42#[derive(Debug)]
49pub struct ClientRequest {
50 pub method: String,
52
53 pub params: serde_json::Value,
55
56 pub extra_headers: HashMap<String, String>,
60}
61
62impl ClientRequest {
63 #[must_use]
65 pub fn new(method: impl Into<String>, params: serde_json::Value) -> Self {
66 Self {
67 method: method.into(),
68 params,
69 extra_headers: HashMap::new(),
70 }
71 }
72}
73
74#[derive(Debug)]
78pub struct ClientResponse {
79 pub method: String,
81
82 pub result: serde_json::Value,
84
85 pub status_code: u16,
87}
88
89pub trait CallInterceptor: Send + Sync + 'static {
103 fn before<'a>(
107 &'a self,
108 req: &'a mut ClientRequest,
109 ) -> impl Future<Output = ClientResult<()>> + Send + 'a;
110
111 fn after<'a>(
113 &'a self,
114 resp: &'a ClientResponse,
115 ) -> impl Future<Output = ClientResult<()>> + Send + 'a;
116}
117
118pub(crate) trait CallInterceptorBoxed: Send + Sync + 'static {
124 fn before_boxed<'a>(
125 &'a self,
126 req: &'a mut ClientRequest,
127 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>>;
128
129 fn after_boxed<'a>(
130 &'a self,
131 resp: &'a ClientResponse,
132 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>>;
133}
134
135impl<T: CallInterceptor> CallInterceptorBoxed for T {
136 fn before_boxed<'a>(
137 &'a self,
138 req: &'a mut ClientRequest,
139 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
140 Box::pin(self.before(req))
141 }
142
143 fn after_boxed<'a>(
144 &'a self,
145 resp: &'a ClientResponse,
146 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
147 Box::pin(self.after(resp))
148 }
149}
150
151impl CallInterceptorBoxed for Box<dyn CallInterceptorBoxed> {
152 fn before_boxed<'a>(
153 &'a self,
154 req: &'a mut ClientRequest,
155 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
156 (**self).before_boxed(req)
157 }
158
159 fn after_boxed<'a>(
160 &'a self,
161 resp: &'a ClientResponse,
162 ) -> Pin<Box<dyn Future<Output = ClientResult<()>> + Send + 'a>> {
163 (**self).after_boxed(resp)
164 }
165}
166
167#[derive(Default)]
174pub struct InterceptorChain {
175 interceptors: Vec<Arc<dyn CallInterceptorBoxed>>,
176}
177
178impl InterceptorChain {
179 #[must_use]
181 pub fn new() -> Self {
182 Self::default()
183 }
184
185 pub fn push<I: CallInterceptor>(&mut self, interceptor: I) {
187 self.interceptors.push(Arc::new(interceptor));
188 }
189
190 #[must_use]
192 pub fn is_empty(&self) -> bool {
193 self.interceptors.is_empty()
194 }
195
196 pub async fn run_before(&self, req: &mut ClientRequest) -> ClientResult<()> {
202 for interceptor in &self.interceptors {
203 interceptor.before_boxed(req).await?;
204 }
205 Ok(())
206 }
207
208 pub async fn run_after(&self, resp: &ClientResponse) -> ClientResult<()> {
214 for interceptor in self.interceptors.iter().rev() {
215 interceptor.after_boxed(resp).await?;
216 }
217 Ok(())
218 }
219}
220
221impl std::fmt::Debug for InterceptorChain {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 f.debug_struct("InterceptorChain")
224 .field("count", &self.interceptors.len())
225 .finish()
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Shared log of `(interceptor label, hook name)` entries, in call order.
    type HookLog = Arc<Mutex<Vec<(&'static str, &'static str)>>>;

    /// Adds 1 to the shared counter from `before` and 10 from `after`, so a
    /// single total shows how many hooks of each kind ran.
    struct CountingInterceptor(Arc<AtomicUsize>);

    impl CallInterceptor for CountingInterceptor {
        // The trait signature needs an explicit `+ Send` future, not `async fn`.
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            _req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.0.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        // The trait signature needs an explicit `+ Send` future, not `async fn`.
        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.0.fetch_add(10, Ordering::SeqCst);
                Ok(())
            }
        }
    }

    /// Appends `(label, "before")` / `(label, "after")` to a shared log so the
    /// tests can assert the exact order in which the chain invoked each hook.
    struct RecordingInterceptor {
        label: &'static str,
        log: HookLog,
    }

    impl RecordingInterceptor {
        fn new(label: &'static str, log: &HookLog) -> Self {
            Self {
                label,
                log: Arc::clone(log),
            }
        }

        fn record(&self, hook: &'static str) {
            self.log.lock().unwrap().push((self.label, hook));
        }
    }

    impl CallInterceptor for RecordingInterceptor {
        // The trait signature needs an explicit `+ Send` future, not `async fn`.
        #[allow(clippy::manual_async_fn)]
        fn before<'a>(
            &'a self,
            _req: &'a mut ClientRequest,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("before");
                Ok(())
            }
        }

        // The trait signature needs an explicit `+ Send` future, not `async fn`.
        #[allow(clippy::manual_async_fn)]
        fn after<'a>(
            &'a self,
            _resp: &'a ClientResponse,
        ) -> impl std::future::Future<Output = ClientResult<()>> + Send + 'a {
            async move {
                self.record("after");
                Ok(())
            }
        }
    }

    fn request() -> ClientRequest {
        ClientRequest::new("message/send", serde_json::Value::Null)
    }

    fn response() -> ClientResponse {
        ClientResponse {
            method: "message/send".into(),
            result: serde_json::Value::Null,
            status_code: 200,
        }
    }

    #[tokio::test]
    async fn empty_chain_is_a_no_op() {
        let chain = InterceptorChain::new();
        assert!(chain.is_empty());

        let mut req = request();
        chain.run_before(&mut req).await.unwrap();
        chain.run_after(&response()).await.unwrap();
    }

    #[tokio::test]
    async fn chain_runs_each_hook_once() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut chain = InterceptorChain::new();
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        chain.push(CountingInterceptor(Arc::clone(&counter)));
        assert!(!chain.is_empty());

        let mut req = request();
        chain.run_before(&mut req).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        chain.run_after(&response()).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 22);
    }

    #[tokio::test]
    async fn chain_runs_before_in_order() {
        let log = HookLog::default();
        let mut chain = InterceptorChain::new();
        chain.push(RecordingInterceptor::new("first", &log));
        chain.push(RecordingInterceptor::new("second", &log));

        let mut req = request();
        chain.run_before(&mut req).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            [("first", "before"), ("second", "before")]
        );
    }

    #[tokio::test]
    async fn chain_runs_after_in_reverse_order() {
        let log = HookLog::default();
        let mut chain = InterceptorChain::new();
        chain.push(RecordingInterceptor::new("first", &log));
        chain.push(RecordingInterceptor::new("second", &log));

        chain.run_after(&response()).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            [("second", "after"), ("first", "after")]
        );
    }
}
287}