a2a_protocol_server/
interceptor.rs1use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use a2a_protocol_types::error::A2aResult;
17
18use crate::call_context::CallContext;
19
20pub trait ServerInterceptor: Send + Sync + 'static {
29 fn before<'a>(
37 &'a self,
38 ctx: &'a CallContext,
39 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
40
41 fn after<'a>(
50 &'a self,
51 ctx: &'a CallContext,
52 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
53}
54
55#[derive(Default)]
60pub struct ServerInterceptorChain {
61 interceptors: Vec<Arc<dyn ServerInterceptor>>,
62}
63
64impl ServerInterceptorChain {
65 #[must_use]
67 pub fn new() -> Self {
68 Self::default()
69 }
70
71 pub fn push(&mut self, interceptor: Arc<dyn ServerInterceptor>) {
73 self.interceptors.push(interceptor);
74 }
75
76 pub async fn run_before(&self, ctx: &CallContext) -> A2aResult<()> {
84 for interceptor in &self.interceptors {
85 interceptor.before(ctx).await?;
86 }
87 Ok(())
88 }
89
90 pub async fn run_after(&self, ctx: &CallContext) -> A2aResult<()> {
98 for interceptor in self.interceptors.iter().rev() {
99 interceptor.after(ctx).await?;
100 }
101 Ok(())
102 }
103}
104
105impl fmt::Debug for ServerInterceptorChain {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 f.debug_struct("ServerInterceptorChain")
108 .field("count", &self.interceptors.len())
109 .finish()
110 }
111}
112
113use std::fmt;
114
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Interceptor whose hooks always succeed and do nothing.
    struct NoopInterceptor;

    impl ServerInterceptor for NoopInterceptor {
        fn before<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async { Ok(()) })
        }

        fn after<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async { Ok(()) })
        }
    }

    /// Shared, ordered record of `(hook name, interceptor id)` invocations.
    type CallLog = Arc<Mutex<Vec<(&'static str, usize)>>>;

    /// Interceptor that appends every hook invocation to a shared log, so
    /// tests can assert the exact order in which the chain calls hooks.
    struct RecordingInterceptor {
        id: usize,
        log: CallLog,
    }

    impl RecordingInterceptor {
        fn new(id: usize, log: &CallLog) -> Arc<Self> {
            Arc::new(Self {
                id,
                log: Arc::clone(log),
            })
        }

        fn record(&self, hook: &'static str) {
            // The guard is dropped before the enclosing async block completes,
            // so it is never held across an `.await`.
            self.log
                .lock()
                .expect("call log mutex poisoned")
                .push((hook, self.id));
        }
    }

    impl ServerInterceptor for RecordingInterceptor {
        fn before<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.record("before");
                Ok(())
            })
        }

        fn after<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.record("after");
                Ok(())
            })
        }
    }

    /// Builds a chain of recording interceptors with ids 1, 2, 3, pushed in
    /// that order, together with the log they all write to.
    fn recording_chain() -> (ServerInterceptorChain, CallLog) {
        let log: CallLog = Arc::default();
        let mut chain = ServerInterceptorChain::new();
        for id in 1..=3 {
            chain.push(RecordingInterceptor::new(id, &log));
        }
        (chain, log)
    }

    #[test]
    fn debug_shows_count() {
        let chain = ServerInterceptorChain::new();
        let debug = format!("{chain:?}");
        assert!(debug.contains("ServerInterceptorChain"));
        assert!(debug.contains("count: 0"), "expected count: 0 in debug: {debug}");
    }

    #[test]
    fn debug_shows_correct_count_after_push() {
        let mut chain = ServerInterceptorChain::new();
        chain.push(Arc::new(NoopInterceptor));
        chain.push(Arc::new(NoopInterceptor));
        let debug = format!("{chain:?}");
        assert!(debug.contains("count: 2"), "expected count: 2 in debug: {debug}");
    }

    #[tokio::test]
    async fn run_before_calls_interceptors_in_order() {
        let (chain, log) = recording_chain();
        let ctx = CallContext::new("test");
        chain.run_before(&ctx).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            vec![("before", 1), ("before", 2), ("before", 3)]
        );
    }

    #[tokio::test]
    async fn run_after_calls_interceptors_in_reverse() {
        let (chain, log) = recording_chain();
        let ctx = CallContext::new("test");
        chain.run_after(&ctx).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            vec![("after", 3), ("after", 2), ("after", 1)]
        );
    }

    #[tokio::test]
    async fn before_and_after_nest_symmetrically() {
        let (chain, log) = recording_chain();
        let ctx = CallContext::new("test");
        chain.run_before(&ctx).await.unwrap();
        chain.run_after(&ctx).await.unwrap();
        assert_eq!(
            *log.lock().unwrap(),
            vec![
                ("before", 1),
                ("before", 2),
                ("before", 3),
                ("after", 3),
                ("after", 2),
                ("after", 1),
            ]
        );
    }

    #[tokio::test]
    async fn empty_chain_succeeds() {
        let chain = ServerInterceptorChain::new();
        let ctx = CallContext::new("test");
        chain.run_before(&ctx).await.unwrap();
        chain.run_after(&ctx).await.unwrap();
    }
}