use std::sync::Arc;
use async_trait::async_trait;
use super::interceptor::{A2aDelegationContext, A2aError, A2aInterceptor, InterceptorDecision};
/// Decides whether a bearer token presented to [`BearerAuthInterceptor`] is accepted.
///
/// `BearerAuthInterceptor::before_delegation` calls this with the token extracted from
/// the `authorization` metadata entry and interprets the result as follows:
/// - `Ok(Some(id))`: accepted; `id` is stored in the context's `caller_id`.
/// - `Ok(None)`: accepted, but `caller_id` is set to `None`.
/// - `Err(e)`: the delegation is rejected with code `e.code()` (falling back to
///   `-32001`) and message `e.to_string()`.
#[async_trait]
pub trait TokenValidator: Send + Sync {
/// Checks `token`, optionally returning the identifier of the authenticated caller.
async fn validate_token(&self, token: &str) -> Result<Option<String>, A2aError>;
}
/// [`A2aInterceptor`] that authenticates delegations using a `Bearer` token taken
/// from the `authorization` metadata entry.
///
/// Cloning is cheap: the only field is an [`Arc`], so clones share the same validator.
#[derive(Clone)]
pub struct BearerAuthInterceptor {
/// Validator that decides whether a presented token is accepted and which caller
/// (if any) it identifies.
pub validator: Arc<dyn TokenValidator>,
}
impl BearerAuthInterceptor {
    /// Creates an interceptor that checks presented tokens with `validator`.
    pub fn new(validator: Arc<dyn TokenValidator>) -> Self {
        Self { validator }
    }

    /// Extracts the token from an `Authorization` header value that uses the `Bearer` scheme.
    ///
    /// Surrounding whitespace is ignored, and the scheme name is matched
    /// case-insensitively. The scheme must be followed by at least one ASCII space.
    /// Any further spaces before the token are skipped, since RFC 6750 allows `1*SP`.
    ///
    /// Returns `None` when the value does not start with the `Bearer` scheme or
    /// carries no token after it. It never panics, whatever the input.
    fn extract_bearer_token(auth_value: &str) -> Option<&str> {
        const PREFIX: &str = "bearer ";
        let trimmed = auth_value.trim();
        // `str::get` yields `None` when `PREFIX.len()` is out of range or falls inside
        // a multi-byte UTF-8 character (e.g. "Beare€x"). Plain slicing would panic
        // there, and this value comes straight from an untrusted header.
        let scheme = trimmed.get(..PREFIX.len())?;
        if !scheme.eq_ignore_ascii_case(PREFIX) {
            return None;
        }
        // The ASCII prefix matched, so `PREFIX.len()` is a char boundary and this
        // slice cannot panic. `trimmed` has no trailing whitespace, so once the
        // separating spaces are skipped the token is non-empty.
        Some(trimmed[PREFIX.len()..].trim_start_matches(' '))
    }
}
#[async_trait]
impl A2aInterceptor for BearerAuthInterceptor {
    /// Authenticates the delegation from its `authorization` metadata entry.
    ///
    /// The delegation is rejected with code `-32001` when the entry is missing or
    /// does not use the `Bearer` scheme. Otherwise the token is handed to the
    /// configured validator:
    /// - On success, the returned identity (possibly `None`) is written to
    ///   `ctx.caller_id` and the delegation continues.
    /// - On failure, the delegation is rejected with the error's own code (falling
    ///   back to `-32001`) and its `to_string()` message.
    ///
    /// Every outcome is reported as `Ok(decision)`; this method never returns `Err`.
    async fn before_delegation(
        &self,
        ctx: &mut A2aDelegationContext,
    ) -> Result<InterceptorDecision, A2aError> {
        // Shared shape of the header-level rejections below.
        let unauthorized = |reason: &str| InterceptorDecision::Reject {
            code: -32001,
            message: reason.to_string(),
        };

        // Take an owned copy so `ctx` can be mutated once validation is done.
        let header = match ctx.metadata.get("authorization").cloned() {
            Some(header) => header,
            None => return Ok(unauthorized("missing authorization header")),
        };

        let token = match Self::extract_bearer_token(&header) {
            Some(token) => token,
            None => {
                return Ok(unauthorized(
                    "invalid authorization header: expected Bearer scheme",
                ))
            }
        };

        let decision = match self.validator.validate_token(token).await {
            Err(err) => InterceptorDecision::Reject {
                code: err.code().unwrap_or(-32001),
                message: err.to_string(),
            },
            Ok(identity) => {
                ctx.caller_id = identity;
                InterceptorDecision::Continue
            }
        };
        Ok(decision)
    }

    /// No post-processing: the response is left untouched.
    async fn after_delegation(
        &self,
        _ctx: &A2aDelegationContext,
        _response: &mut serde_json::Value,
    ) -> Result<(), A2aError> {
        Ok(())
    }
}
/// Hand-written because `dyn TokenValidator` does not require `Debug`.
/// The validator is therefore elided and rendered as `..`.
impl std::fmt::Debug for BearerAuthInterceptor {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("BearerAuthInterceptor");
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Validator stub: accepts every token and identifies the caller as `"test-user"`.
    struct AcceptAllValidator;

    #[async_trait]
    impl TokenValidator for AcceptAllValidator {
        async fn validate_token(&self, _token: &str) -> Result<Option<String>, A2aError> {
            Ok(Some("test-user".to_string()))
        }
    }

    /// Validator stub: fails every token with code -32001.
    struct RejectAllValidator;

    #[async_trait]
    impl TokenValidator for RejectAllValidator {
        async fn validate_token(&self, _token: &str) -> Result<Option<String>, A2aError> {
            Err(A2aError::rejected(-32001, "token rejected"))
        }
    }

    /// Validator stub: accepts every token but reports no caller identity.
    struct NoCaller;

    #[async_trait]
    impl TokenValidator for NoCaller {
        async fn validate_token(&self, _token: &str) -> Result<Option<String>, A2aError> {
            Ok(None)
        }
    }

    /// Builds a `tasks/send` context with empty params, no caller and `metadata`.
    fn ctx_with_metadata(metadata: HashMap<String, String>) -> A2aDelegationContext {
        A2aDelegationContext {
            method: "tasks/send".to_string(),
            params: serde_json::json!({}),
            caller_id: None,
            metadata,
        }
    }

    fn make_ctx_with_auth(auth: &str) -> A2aDelegationContext {
        ctx_with_metadata(HashMap::from([(
            "authorization".to_string(),
            auth.to_string(),
        )]))
    }

    fn make_ctx_no_auth() -> A2aDelegationContext {
        ctx_with_metadata(HashMap::new())
    }

    /// Runs `before_delegation` with `validator` and returns the decision together
    /// with the (possibly updated) context.
    async fn run_before(
        validator: Arc<dyn TokenValidator>,
        mut ctx: A2aDelegationContext,
    ) -> (InterceptorDecision, A2aDelegationContext) {
        let decision = BearerAuthInterceptor::new(validator)
            .before_delegation(&mut ctx)
            .await
            .unwrap();
        (decision, ctx)
    }

    /// Asserts that `decision` is a -32001 rejection whose message contains `needle`.
    fn assert_rejected(decision: InterceptorDecision, needle: &str) {
        match decision {
            InterceptorDecision::Reject { code, message } => {
                assert_eq!(code, -32001);
                assert!(message.contains(needle));
            }
            _ => panic!("expected Reject"),
        }
    }

    #[tokio::test]
    async fn test_valid_bearer_token_sets_caller_id() {
        let (decision, ctx) = run_before(
            Arc::new(AcceptAllValidator),
            make_ctx_with_auth("Bearer my-secret-token"),
        )
        .await;
        assert!(matches!(decision, InterceptorDecision::Continue));
        assert_eq!(ctx.caller_id.as_deref(), Some("test-user"));
    }

    #[tokio::test]
    async fn test_valid_bearer_token_no_caller_id() {
        let (decision, ctx) =
            run_before(Arc::new(NoCaller), make_ctx_with_auth("Bearer some-token")).await;
        assert!(matches!(decision, InterceptorDecision::Continue));
        assert_eq!(ctx.caller_id, None);
    }

    #[tokio::test]
    async fn test_missing_authorization_header_rejects() {
        let (decision, _) = run_before(Arc::new(AcceptAllValidator), make_ctx_no_auth()).await;
        assert_rejected(decision, "missing authorization header");
    }

    #[tokio::test]
    async fn test_non_bearer_scheme_rejects() {
        let (decision, _) = run_before(
            Arc::new(AcceptAllValidator),
            make_ctx_with_auth("Basic dXNlcjpwYXNz"),
        )
        .await;
        assert_rejected(decision, "expected Bearer scheme");
    }

    #[tokio::test]
    async fn test_invalid_token_rejects() {
        let (decision, _) = run_before(
            Arc::new(RejectAllValidator),
            make_ctx_with_auth("Bearer bad-token"),
        )
        .await;
        assert_rejected(decision, "token rejected");
    }

    #[tokio::test]
    async fn test_bearer_prefix_case_insensitive() {
        let (decision, ctx) =
            run_before(Arc::new(AcceptAllValidator), make_ctx_with_auth("BEARER my-token")).await;
        assert!(matches!(decision, InterceptorDecision::Continue));
        assert_eq!(ctx.caller_id.as_deref(), Some("test-user"));
    }

    #[tokio::test]
    async fn test_bearer_prefix_with_leading_whitespace() {
        let (decision, ctx) =
            run_before(Arc::new(AcceptAllValidator), make_ctx_with_auth(" Bearer my-token")).await;
        assert!(matches!(decision, InterceptorDecision::Continue));
        assert_eq!(ctx.caller_id.as_deref(), Some("test-user"));
    }

    #[tokio::test]
    async fn test_after_delegation_is_noop() {
        let interceptor = BearerAuthInterceptor::new(Arc::new(AcceptAllValidator));
        let ctx = A2aDelegationContext {
            caller_id: Some("user".to_string()),
            ..make_ctx_no_auth()
        };
        let mut response = serde_json::json!({"result": "ok"});
        let result = interceptor.after_delegation(&ctx, &mut response).await;
        assert!(result.is_ok());
        assert_eq!(response, serde_json::json!({"result": "ok"}));
    }

    #[tokio::test]
    async fn test_empty_bearer_value_rejects() {
        let (decision, _) =
            run_before(Arc::new(AcceptAllValidator), make_ctx_with_auth("Bearer")).await;
        assert_rejected(decision, "expected Bearer scheme");
    }
}