use std::collections::HashMap;
use std::fmt;
use async_trait::async_trait;
/// Error produced while running A2A delegation interceptors.
///
/// Carries a message and, for errors built with [`A2aError::rejected`], a numeric
/// code. The `Display` form prefixes the message with `"A2A interceptor error"` and
/// includes the code when one is present.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct A2aError {
    message: String,
    code: Option<i32>,
}

impl A2aError {
    /// Creates an error with the given message and no code.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self { message: message.into(), code: None }
    }

    /// Creates an error carrying a rejection `code` alongside `message`.
    #[must_use]
    pub fn rejected(code: i32, message: impl Into<String>) -> Self {
        Self { message: message.into(), code: Some(code) }
    }

    /// Returns the code, if this error was built with [`A2aError::rejected`].
    #[must_use]
    pub fn code(&self) -> Option<i32> {
        self.code
    }

    /// Returns the raw message, without the prefix that `Display` adds.
    #[must_use]
    pub fn message(&self) -> &str {
        &self.message
    }
}

impl fmt::Display for A2aError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.code {
            Some(code) => write!(f, "A2A interceptor error (code {code}): {}", self.message),
            None => write!(f, "A2A interceptor error: {}", self.message),
        }
    }
}

impl std::error::Error for A2aError {}
/// Converts an interceptor error into an `adk_core::AdkError` via `AdkError::agent`.
///
/// The error's `Display` rendering is used as the message. Any code is therefore
/// carried only as part of that text, not as a separate field.
impl From<A2aError> for adk_core::AdkError {
    fn from(err: A2aError) -> Self {
        let rendered = err.to_string();
        Self::agent(rendered)
    }
}
/// Describes one delegation call as seen by the interceptor chain.
///
/// `before_delegation` hooks receive it by `&mut` and may modify it (for example by
/// inserting `metadata` entries). Interceptors that run later observe those changes.
/// `after_delegation` hooks receive it read-only.
#[derive(Debug, Clone)]
pub struct A2aDelegationContext {
    /// Name of the delegated method (the tests use `"tasks/send"`).
    pub method: String,
    /// JSON parameters of the call.
    pub params: serde_json::Value,
    /// Identifier of the caller, if known.
    pub caller_id: Option<String>,
    /// Free-form string key/value pairs that interceptors may read or add to.
    pub metadata: HashMap<String, String>,
}
/// Outcome of a single interceptor's `before_delegation` hook.
///
/// Derives `PartialEq` so decisions can be compared directly (e.g. with `assert_eq!`)
/// rather than through `matches!` or a hand-written `match`.
#[derive(Debug, Clone, PartialEq)]
pub enum InterceptorDecision {
    /// Let the next interceptor run.
    ///
    /// If every interceptor continues, `InterceptorChain::run_before` returns
    /// `Continue`; presumably the caller then performs the real delegation.
    Continue,
    /// Stop the chain and hand this value back to the caller.
    ///
    /// Presumably it is used as the response in place of delegating; the caller
    /// decides this.
    ShortCircuit(serde_json::Value),
    /// Stop the chain and refuse the call.
    Reject {
        /// Numeric rejection code (the tests use `-32001`).
        code: i32,
        /// Reason for the rejection.
        message: String,
    },
}
/// Hook that observes, and may alter, an A2A delegation.
///
/// Implementors must be `Send + Sync`. They are stored as `Box<dyn A2aInterceptor>`
/// inside an `InterceptorChain`.
#[async_trait]
pub trait A2aInterceptor: Send + Sync {
    /// Runs before the call is delegated.
    ///
    /// May mutate `ctx`; interceptors registered later see the change. Returning any
    /// decision other than [`InterceptorDecision::Continue`] stops the chain, and
    /// later interceptors are not called.
    ///
    /// # Errors
    /// Return an [`A2aError`] to abort. The chain propagates it unchanged and does not
    /// call later interceptors.
    async fn before_delegation(
        &self,
        ctx: &mut A2aDelegationContext,
    ) -> Result<InterceptorDecision, A2aError>;

    /// Runs after a response is available and may modify `response` in place.
    ///
    /// The default implementation does nothing and returns `Ok(())`. Interceptors that
    /// only act on the request side therefore need not implement it.
    ///
    /// # Errors
    /// Return an [`A2aError`] to abort. The chain propagates it and skips the remaining
    /// after-hooks.
    async fn after_delegation(
        &self,
        _ctx: &A2aDelegationContext,
        _response: &mut serde_json::Value,
    ) -> Result<(), A2aError> {
        Ok(())
    }
}
/// Ordered collection of boxed [`A2aInterceptor`]s run around an A2A delegation.
///
/// Before-hooks run in registration order. After-hooks run in reverse registration
/// order, so the first interceptor added is the outermost one.
pub struct InterceptorChain {
    interceptors: Vec<Box<dyn A2aInterceptor>>,
}

impl InterceptorChain {
    /// Creates an empty chain.
    #[must_use]
    pub fn new() -> Self {
        Self { interceptors: Vec::new() }
    }

    /// Appends `interceptor` to the end of the chain and returns the chain (builder style).
    // `add` shares its name with `std::ops::Add::add`, which triggers clippy's
    // `should_implement_trait`. This is a builder method, not an arithmetic operator.
    #[allow(clippy::should_implement_trait)]
    #[must_use = "`add` consumes the chain and returns it; dropping the result discards every registered interceptor"]
    pub fn add(mut self, interceptor: impl A2aInterceptor + 'static) -> Self {
        self.interceptors.push(Box::new(interceptor));
        self
    }

    /// Returns the number of registered interceptors.
    #[must_use]
    pub fn len(&self) -> usize {
        self.interceptors.len()
    }

    /// Returns `true` when no interceptors are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.interceptors.is_empty()
    }

    /// Runs each interceptor's `before_delegation` hook in registration order.
    ///
    /// Stops at the first interceptor whose decision is not
    /// [`InterceptorDecision::Continue`] and returns that decision; later interceptors
    /// are not called. Returns `Continue` when every interceptor continued, and also
    /// when the chain is empty.
    ///
    /// # Errors
    /// Propagates the first [`A2aError`] returned by a hook; later interceptors are
    /// not called.
    pub async fn run_before(
        &self,
        ctx: &mut A2aDelegationContext,
    ) -> Result<InterceptorDecision, A2aError> {
        for interceptor in &self.interceptors {
            match interceptor.before_delegation(ctx).await? {
                InterceptorDecision::Continue => continue,
                decision => return Ok(decision),
            }
        }
        Ok(InterceptorDecision::Continue)
    }

    /// Runs each interceptor's `after_delegation` hook in reverse registration order.
    ///
    /// NOTE: this invokes every registered interceptor. It keeps no record of whether
    /// `run_before` stopped early, so callers decide whether to call it after a
    /// short-circuit or reject.
    ///
    /// # Errors
    /// Propagates the first [`A2aError`] returned by a hook; the remaining (earlier
    /// registered) interceptors are not called.
    pub async fn run_after(
        &self,
        ctx: &A2aDelegationContext,
        response: &mut serde_json::Value,
    ) -> Result<(), A2aError> {
        for interceptor in self.interceptors.iter().rev() {
            interceptor.after_delegation(ctx, response).await?;
        }
        Ok(())
    }
}

impl Default for InterceptorChain {
    fn default() -> Self {
        Self::new()
    }
}

/// Reports only the interceptor count; the boxed trait objects are not `Debug`.
impl fmt::Debug for InterceptorChain {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InterceptorChain")
            .field("interceptor_count", &self.interceptors.len())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared, ordered record of `(interceptor id, hook name)` invocations.
    type CallLog = Arc<Mutex<Vec<(usize, &'static str)>>>;

    fn new_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Records every hook call in `order` and always lets the chain continue.
    struct CountingInterceptor {
        id: usize,
        order: CallLog,
    }

    #[async_trait]
    impl A2aInterceptor for CountingInterceptor {
        async fn before_delegation(
            &self,
            _ctx: &mut A2aDelegationContext,
        ) -> Result<InterceptorDecision, A2aError> {
            self.order.lock().unwrap().push((self.id, "before"));
            Ok(InterceptorDecision::Continue)
        }

        async fn after_delegation(
            &self,
            _ctx: &A2aDelegationContext,
            _response: &mut serde_json::Value,
        ) -> Result<(), A2aError> {
            self.order.lock().unwrap().push((self.id, "after"));
            Ok(())
        }
    }

    /// Always rejects in `before_delegation`.
    struct RejectingInterceptor;

    #[async_trait]
    impl A2aInterceptor for RejectingInterceptor {
        async fn before_delegation(
            &self,
            _ctx: &mut A2aDelegationContext,
        ) -> Result<InterceptorDecision, A2aError> {
            Ok(InterceptorDecision::Reject { code: -32001, message: "denied".to_string() })
        }

        async fn after_delegation(
            &self,
            _ctx: &A2aDelegationContext,
            _response: &mut serde_json::Value,
        ) -> Result<(), A2aError> {
            Ok(())
        }
    }

    /// Always short-circuits with a fixed payload.
    struct ShortCircuitInterceptor;

    #[async_trait]
    impl A2aInterceptor for ShortCircuitInterceptor {
        async fn before_delegation(
            &self,
            _ctx: &mut A2aDelegationContext,
        ) -> Result<InterceptorDecision, A2aError> {
            Ok(InterceptorDecision::ShortCircuit(serde_json::json!({"cached": true})))
        }

        async fn after_delegation(
            &self,
            _ctx: &A2aDelegationContext,
            _response: &mut serde_json::Value,
        ) -> Result<(), A2aError> {
            Ok(())
        }
    }

    /// Fails both hooks with an `A2aError`.
    struct FailingInterceptor;

    #[async_trait]
    impl A2aInterceptor for FailingInterceptor {
        async fn before_delegation(
            &self,
            _ctx: &mut A2aDelegationContext,
        ) -> Result<InterceptorDecision, A2aError> {
            Err(A2aError::rejected(-32002, "before failed"))
        }

        async fn after_delegation(
            &self,
            _ctx: &A2aDelegationContext,
            _response: &mut serde_json::Value,
        ) -> Result<(), A2aError> {
            Err(A2aError::new("after failed"))
        }
    }

    /// Appends `tag` to the response's `"tags"` array in `after_delegation`.
    struct TaggingInterceptor {
        tag: &'static str,
    }

    #[async_trait]
    impl A2aInterceptor for TaggingInterceptor {
        async fn before_delegation(
            &self,
            _ctx: &mut A2aDelegationContext,
        ) -> Result<InterceptorDecision, A2aError> {
            Ok(InterceptorDecision::Continue)
        }

        async fn after_delegation(
            &self,
            _ctx: &A2aDelegationContext,
            response: &mut serde_json::Value,
        ) -> Result<(), A2aError> {
            response["tags"]
                .as_array_mut()
                .expect("test response must contain a `tags` array")
                .push(serde_json::json!(self.tag));
            Ok(())
        }
    }

    fn make_ctx() -> A2aDelegationContext {
        A2aDelegationContext {
            method: "tasks/send".to_string(),
            params: serde_json::json!({}),
            caller_id: None,
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_empty_chain_returns_continue() {
        let chain = InterceptorChain::new();
        let mut ctx = make_ctx();
        let decision = chain.run_before(&mut ctx).await.unwrap();
        assert!(matches!(decision, InterceptorDecision::Continue));
    }

    #[tokio::test]
    async fn test_before_executes_in_registration_order() {
        let order = new_log();
        let chain = InterceptorChain::new()
            .add(CountingInterceptor { id: 1, order: Arc::clone(&order) })
            .add(CountingInterceptor { id: 2, order: Arc::clone(&order) })
            .add(CountingInterceptor { id: 3, order: Arc::clone(&order) });
        let mut ctx = make_ctx();
        chain.run_before(&mut ctx).await.unwrap();
        let recorded = order.lock().unwrap();
        assert_eq!(recorded.as_slice(), &[(1, "before"), (2, "before"), (3, "before")]);
    }

    #[tokio::test]
    async fn test_after_executes_in_reverse_order() {
        let order = new_log();
        let chain = InterceptorChain::new()
            .add(CountingInterceptor { id: 1, order: Arc::clone(&order) })
            .add(CountingInterceptor { id: 2, order: Arc::clone(&order) })
            .add(CountingInterceptor { id: 3, order: Arc::clone(&order) });
        let ctx = make_ctx();
        let mut response = serde_json::json!({"result": "ok"});
        chain.run_after(&ctx, &mut response).await.unwrap();
        let recorded = order.lock().unwrap();
        assert_eq!(recorded.as_slice(), &[(3, "after"), (2, "after"), (1, "after")]);
    }

    #[tokio::test]
    async fn test_reject_stops_chain() {
        let order = new_log();
        let chain = InterceptorChain::new()
            .add(CountingInterceptor { id: 1, order: Arc::clone(&order) })
            .add(RejectingInterceptor)
            .add(CountingInterceptor { id: 3, order: Arc::clone(&order) });
        let mut ctx = make_ctx();
        let decision = chain.run_before(&mut ctx).await.unwrap();
        assert!(matches!(decision, InterceptorDecision::Reject { code: -32001, .. }));
        let recorded = order.lock().unwrap();
        assert_eq!(recorded.as_slice(), &[(1, "before")]);
    }

    #[tokio::test]
    async fn test_short_circuit_stops_chain() {
        let order = new_log();
        let chain = InterceptorChain::new()
            .add(CountingInterceptor { id: 1, order: Arc::clone(&order) })
            .add(ShortCircuitInterceptor)
            .add(CountingInterceptor { id: 3, order: Arc::clone(&order) });
        let mut ctx = make_ctx();
        let decision = chain.run_before(&mut ctx).await.unwrap();
        match decision {
            InterceptorDecision::ShortCircuit(val) => {
                assert_eq!(val, serde_json::json!({"cached": true}));
            }
            _ => panic!("expected ShortCircuit"),
        }
        let recorded = order.lock().unwrap();
        assert_eq!(recorded.as_slice(), &[(1, "before")]);
    }

    #[tokio::test]
    async fn test_before_error_propagates_and_stops_chain() {
        let order = new_log();
        let chain = InterceptorChain::new()
            .add(CountingInterceptor { id: 1, order: Arc::clone(&order) })
            .add(FailingInterceptor)
            .add(CountingInterceptor { id: 3, order: Arc::clone(&order) });
        let mut ctx = make_ctx();
        let err = chain.run_before(&mut ctx).await.unwrap_err();
        assert_eq!(err.code(), Some(-32002));
        let recorded = order.lock().unwrap();
        assert_eq!(recorded.as_slice(), &[(1, "before")]);
    }

    #[tokio::test]
    async fn test_after_error_stops_remaining_after_hooks() {
        let order = new_log();
        let chain = InterceptorChain::new()
            .add(CountingInterceptor { id: 1, order: Arc::clone(&order) })
            .add(FailingInterceptor)
            .add(CountingInterceptor { id: 3, order: Arc::clone(&order) });
        let ctx = make_ctx();
        let mut response = serde_json::json!({"result": "ok"});
        let err = chain.run_after(&ctx, &mut response).await.unwrap_err();
        assert_eq!(err.code(), None);
        // Reverse order: id 3 runs, the failing hook aborts, id 1 never runs.
        let recorded = order.lock().unwrap();
        assert_eq!(recorded.as_slice(), &[(3, "after")]);
    }

    #[tokio::test]
    async fn test_after_mutates_response_in_reverse_order() {
        let chain = InterceptorChain::new()
            .add(TaggingInterceptor { tag: "a" })
            .add(TaggingInterceptor { tag: "b" })
            .add(TaggingInterceptor { tag: "c" });
        let ctx = make_ctx();
        let mut response = serde_json::json!({"tags": []});
        chain.run_after(&ctx, &mut response).await.unwrap();
        assert_eq!(response, serde_json::json!({"tags": ["c", "b", "a"]}));
    }

    #[test]
    fn test_chain_len_and_is_empty() {
        let chain = InterceptorChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
        let chain = chain.add(ShortCircuitInterceptor);
        assert!(!chain.is_empty());
        assert_eq!(chain.len(), 1);
    }

    #[tokio::test]
    async fn test_context_mutation_propagates() {
        struct MutatingInterceptor;

        #[async_trait]
        impl A2aInterceptor for MutatingInterceptor {
            async fn before_delegation(
                &self,
                ctx: &mut A2aDelegationContext,
            ) -> Result<InterceptorDecision, A2aError> {
                ctx.metadata.insert("enriched".to_string(), "true".to_string());
                Ok(InterceptorDecision::Continue)
            }

            async fn after_delegation(
                &self,
                _ctx: &A2aDelegationContext,
                _response: &mut serde_json::Value,
            ) -> Result<(), A2aError> {
                Ok(())
            }
        }

        let chain = InterceptorChain::new().add(MutatingInterceptor);
        let mut ctx = make_ctx();
        chain.run_before(&mut ctx).await.unwrap();
        assert_eq!(ctx.metadata.get("enriched"), Some(&"true".to_string()));
    }

    #[test]
    fn test_a2a_error_display() {
        let err = A2aError::new("something failed");
        assert_eq!(err.to_string(), "A2A interceptor error: something failed");
        assert_eq!(err.code(), None);
        let err = A2aError::rejected(-32001, "rate limited");
        assert_eq!(err.to_string(), "A2A interceptor error (code -32001): rate limited");
        assert_eq!(err.code(), Some(-32001));
    }
}