use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use a2a_protocol_types::error::A2aResult;
use crate::call_context::CallContext;
/// Server-side hook with a `before` stage and an `after` stage, each of which
/// receives the call's [`CallContext`].
///
/// Implementations are stored as `Arc<dyn ServerInterceptor>` inside a
/// [`ServerInterceptorChain`], hence the `Send + Sync + 'static` bounds. Both
/// hooks return a pinned, boxed future (instead of being `async fn`s) so the
/// trait remains usable as a `dyn` trait object.
pub trait ServerInterceptor: Send + Sync + 'static {
/// Hook driven by [`ServerInterceptorChain::run_before`], which visits
/// interceptors in registration order.
///
/// The returned future may borrow both `self` and `ctx` for `'a`.
///
/// # Errors
///
/// Returning `Err` stops the chain: `run_before` propagates the error and the
/// `before` hooks of later-registered interceptors are not called.
fn before<'a>(
&'a self,
ctx: &'a CallContext,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
/// Hook driven by [`ServerInterceptorChain::run_after`], which visits
/// interceptors in reverse registration order.
///
/// The returned future may borrow both `self` and `ctx` for `'a`.
///
/// # Errors
///
/// Returning `Err` stops the chain: `run_after` propagates the error and the
/// `after` hooks of earlier-registered interceptors are not called.
fn after<'a>(
&'a self,
ctx: &'a CallContext,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
}
/// Ordered collection of [`ServerInterceptor`]s.
///
/// `before` hooks run in push order and `after` hooks run in reverse push
/// order. The derived `Default` produces an empty chain.
#[derive(Default)]
pub struct ServerInterceptorChain {
/// Registered interceptors, kept in the order they were pushed.
interceptors: Vec<Arc<dyn ServerInterceptor>>,
}
impl ServerInterceptorChain {
    /// Builds a chain that has no interceptors registered yet.
    #[must_use]
    pub fn new() -> Self {
        Self {
            interceptors: Vec::new(),
        }
    }

    /// Appends `interceptor` to the end of the chain.
    ///
    /// Push order is significant: [`run_before`](Self::run_before) walks the
    /// chain front-to-back, while [`run_after`](Self::run_after) walks it
    /// back-to-front.
    pub fn push(&mut self, interceptor: Arc<dyn ServerInterceptor>) {
        self.interceptors.push(interceptor);
    }

    /// Awaits each interceptor's `before` hook, first-registered first.
    ///
    /// # Errors
    ///
    /// Returns the error of the first hook that fails. Hooks registered after
    /// the failing one are not invoked.
    pub async fn run_before(&self, ctx: &CallContext) -> A2aResult<()> {
        let registered = &self.interceptors;
        for hook in registered.iter() {
            hook.before(ctx).await?;
        }
        Ok(())
    }

    /// Awaits each interceptor's `after` hook, last-registered first.
    ///
    /// Walking the chain in reverse mirrors `run_before`. The first
    /// interceptor to see a call on the way in is therefore the last to see
    /// it on the way out.
    ///
    /// # Errors
    ///
    /// Returns the error of the first hook that fails. Hooks registered
    /// earlier than the failing one are not invoked.
    pub async fn run_after(&self, ctx: &CallContext) -> A2aResult<()> {
        let registered = &self.interceptors;
        for hook in registered.iter().rev() {
            hook.after(ctx).await?;
        }
        Ok(())
    }
}
use std::fmt;

/// Hand-written `Debug` that reports only how many interceptors are
/// registered. The interceptors cannot be printed because
/// `dyn ServerInterceptor` carries no `Debug` bound.
impl fmt::Debug for ServerInterceptorChain {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.interceptors.len();
        let mut out = f.debug_struct("ServerInterceptorChain");
        out.field("count", &count);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Shared log of `(hook_name, interceptor_id)` entries, in call order.
    type CallLog = Arc<Mutex<Vec<(&'static str, usize)>>>;

    #[test]
    fn debug_shows_count() {
        let chain = ServerInterceptorChain::new();
        let debug = format!("{chain:?}");
        assert!(debug.contains("ServerInterceptorChain"));
        assert!(debug.contains("count: 0"), "expected count: 0 in debug: {debug}");
    }

    /// Interceptor whose hooks do nothing and always succeed.
    struct NoopInterceptor;

    impl ServerInterceptor for NoopInterceptor {
        fn before<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async { Ok(()) })
        }

        fn after<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async { Ok(()) })
        }
    }

    /// Interceptor that appends `(hook, id)` to a shared log, so tests can
    /// assert the exact order in which the chain invokes hooks.
    struct RecordingInterceptor {
        id: usize,
        log: CallLog,
    }

    impl ServerInterceptor for RecordingInterceptor {
        fn before<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async move {
                // The guard is a statement temporary, so it is never held across an await.
                self.log.lock().unwrap().push(("before", self.id));
                Ok(())
            })
        }

        fn after<'a>(
            &'a self,
            _ctx: &'a CallContext,
        ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.log.lock().unwrap().push(("after", self.id));
                Ok(())
            })
        }
    }

    /// Builds a chain of `count` recording interceptors with ids `0..count`,
    /// pushed in ascending id order, all writing to `log`.
    fn recording_chain(log: &CallLog, count: usize) -> ServerInterceptorChain {
        let mut chain = ServerInterceptorChain::new();
        for id in 0..count {
            chain.push(Arc::new(RecordingInterceptor {
                id,
                log: Arc::clone(log),
            }));
        }
        chain
    }

    #[test]
    fn debug_shows_correct_count_after_push() {
        let mut chain = ServerInterceptorChain::new();
        chain.push(Arc::new(NoopInterceptor));
        chain.push(Arc::new(NoopInterceptor));
        let debug = format!("{chain:?}");
        assert!(debug.contains("count: 2"), "expected count: 2 in debug: {debug}");
    }

    #[tokio::test]
    async fn run_before_calls_interceptors_in_order() {
        let log = CallLog::default();
        let chain = recording_chain(&log, 3);
        let ctx = CallContext::new("test");
        chain.run_before(&ctx).await.unwrap();
        let calls = log.lock().unwrap().clone();
        assert_eq!(calls, vec![("before", 0), ("before", 1), ("before", 2)]);
    }

    #[tokio::test]
    async fn run_after_calls_interceptors_in_reverse() {
        let log = CallLog::default();
        let chain = recording_chain(&log, 3);
        let ctx = CallContext::new("test");
        chain.run_after(&ctx).await.unwrap();
        let calls = log.lock().unwrap().clone();
        assert_eq!(calls, vec![("after", 2), ("after", 1), ("after", 0)]);
    }

    #[tokio::test]
    async fn before_then_after_nests_hooks() {
        let log = CallLog::default();
        let chain = recording_chain(&log, 2);
        let ctx = CallContext::new("test");
        chain.run_before(&ctx).await.unwrap();
        chain.run_after(&ctx).await.unwrap();
        let calls = log.lock().unwrap().clone();
        assert_eq!(
            calls,
            vec![("before", 0), ("before", 1), ("after", 1), ("after", 0)]
        );
    }

    #[tokio::test]
    async fn empty_chain_succeeds() {
        let chain = ServerInterceptorChain::new();
        let ctx = CallContext::new("test");
        chain.run_before(&ctx).await.unwrap();
        chain.run_after(&ctx).await.unwrap();
    }
}