use std::sync::Arc;
use async_trait::async_trait;
use super::context::RequestContext;
use super::error::MiddlewareError;
use super::traits::A2aMiddleware;
/// An ordered collection of middleware.
///
/// `before_request` runs the entries one after another, in the order they
/// were supplied to `new`, and stops at the first one that returns an error.
pub struct MiddlewareStack {
/// Middleware entries, stored in the order in which they are executed.
middleware: Vec<Arc<dyn A2aMiddleware>>,
}
impl MiddlewareStack {
pub fn new(middleware: Vec<Arc<dyn A2aMiddleware>>) -> Self {
Self { middleware }
}
pub fn is_empty(&self) -> bool {
self.middleware.is_empty()
}
pub async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
for mw in &self.middleware {
mw.before_request(ctx).await?;
}
Ok(())
}
}
/// Composite middleware that accepts a request as soon as any one of its
/// children accepts it.
///
/// Children are tried in order; if none succeeds, the rejection with the
/// highest `precedence()` is returned (see the `A2aMiddleware` impl).
pub struct AnyOfMiddleware {
/// Alternative middleware, tried in order. `new` rejects an empty list.
children: Vec<Arc<dyn A2aMiddleware>>,
}
impl AnyOfMiddleware {
pub fn new(children: Vec<Arc<dyn A2aMiddleware>>) -> Self {
assert!(
!children.is_empty(),
"AnyOfMiddleware requires at least one child"
);
Self { children }
}
}
#[async_trait]
impl A2aMiddleware for AnyOfMiddleware {
    /// Tries each child in registration order and accepts on the first success.
    ///
    /// Every child runs against its own copy of `ctx`. Only the `identity`
    /// and `extensions` of the successful attempt are written back to `ctx`;
    /// whatever a rejected child did to its copy is discarded.
    ///
    /// # Errors
    ///
    /// * A `MiddlewareError::Internal` from any child is returned immediately
    ///   and the remaining children are not tried.
    /// * Otherwise, when every child rejects, the rejection with the highest
    ///   `precedence()` is returned; among equals the earliest child wins.
    async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
        // Highest-precedence rejection observed so far.
        let mut strongest: Option<MiddlewareError> = None;

        for candidate in self.children.iter() {
            // Scratch copy so a failed attempt cannot leave partial state in `ctx`.
            let mut scratch = RequestContext {
                bearer_token: ctx.bearer_token.clone(),
                headers: ctx.headers.clone(),
                identity: ctx.identity.clone(),
                extensions: ctx.extensions.clone(),
            };

            let rejection = match candidate.before_request(&mut scratch).await {
                Ok(()) => {
                    ctx.identity = scratch.identity;
                    ctx.extensions = scratch.extensions;
                    return Ok(());
                }
                Err(rejection) => rejection,
            };

            // Internal failures are never masked by trying further children.
            if matches!(rejection, MiddlewareError::Internal(_)) {
                return Err(rejection);
            }

            strongest = Some(match strongest.take() {
                None => rejection,
                Some(current) => {
                    // Strictly greater only: ties keep the earlier child's error.
                    if rejection.precedence() > current.precedence() {
                        rejection
                    } else {
                        current
                    }
                }
            });
        }

        // `new` guarantees at least one child, so at least one rejection was recorded.
        Err(strongest.expect("AnyOfMiddleware has at least one child"))
    }

    /// Merges the security contributions of all children, in child order.
    ///
    /// Schemes and requirements are appended as-is; nothing is deduplicated.
    fn security_contribution(&self) -> super::traits::SecurityContribution {
        let mut merged = super::traits::SecurityContribution::new();
        for candidate in self.children.iter() {
            let part = candidate.security_contribution();
            merged.schemes.extend(part.schemes);
            merged.requirements.extend(part.requirements);
        }
        merged
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MiddlewareStack`, `AnyOfMiddleware`, and the basic
    //! `AuthIdentity` / `MiddlewareError` helpers they rely on.

    use super::*;
    use crate::middleware::context::AuthIdentity;
    use crate::middleware::error::AuthFailureKind;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Middleware that always succeeds and authenticates the request as `owner`.
    struct SucceedingMiddleware {
        owner: String,
    }

    #[async_trait]
    impl A2aMiddleware for SucceedingMiddleware {
        async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
            ctx.identity = AuthIdentity::Authenticated {
                owner: self.owner.clone(),
                claims: None,
            };
            Ok(())
        }
    }

    /// Middleware that always rejects with `MiddlewareError::Unauthenticated`.
    struct FailUnauthenticated {
        kind: AuthFailureKind,
    }

    impl FailUnauthenticated {
        fn new(kind: AuthFailureKind) -> Self {
            Self { kind }
        }
    }

    #[async_trait]
    impl A2aMiddleware for FailUnauthenticated {
        async fn before_request(&self, _ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
            Err(MiddlewareError::Unauthenticated(self.kind))
        }
    }

    /// Middleware that always rejects with `MiddlewareError::HttpChallenge`.
    struct FailHttpChallenge {
        kind: AuthFailureKind,
    }

    impl FailHttpChallenge {
        fn new(kind: AuthFailureKind) -> Self {
            Self { kind }
        }
    }

    #[async_trait]
    impl A2aMiddleware for FailHttpChallenge {
        async fn before_request(&self, _ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
            Err(MiddlewareError::HttpChallenge(self.kind))
        }
    }

    /// Middleware that always rejects with `MiddlewareError::Forbidden`.
    struct FailForbidden {
        kind: AuthFailureKind,
    }

    impl FailForbidden {
        fn new(kind: AuthFailureKind) -> Self {
            Self { kind }
        }
    }

    #[async_trait]
    impl A2aMiddleware for FailForbidden {
        async fn before_request(&self, _ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
            Err(MiddlewareError::Forbidden(self.kind))
        }
    }

    /// Middleware that always fails with `MiddlewareError::Internal(message)`.
    struct FailInternal {
        message: String,
    }

    #[async_trait]
    impl A2aMiddleware for FailInternal {
        async fn before_request(&self, _ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
            Err(MiddlewareError::Internal(self.message.clone()))
        }
    }

    /// Middleware that writes an identity into the context it is given and
    /// then rejects the request anyway.
    struct MutateThenFail;

    #[async_trait]
    impl A2aMiddleware for MutateThenFail {
        async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
            ctx.identity = AuthIdentity::Authenticated {
                owner: "leaked".into(),
                claims: None,
            };
            Err(MiddlewareError::Unauthenticated(
                AuthFailureKind::InvalidApiKey,
            ))
        }
    }

    /// Wrapper that records whether it was invoked before delegating to `inner`.
    struct CallTracker {
        called: Arc<AtomicBool>,
        inner: Box<dyn A2aMiddleware>,
    }

    #[async_trait]
    impl A2aMiddleware for CallTracker {
        async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
            self.called.store(true, Ordering::SeqCst);
            self.inner.before_request(ctx).await
        }
    }

    /// An empty stack accepts the request and leaves the identity untouched.
    #[tokio::test]
    async fn empty_stack_passes_through() {
        let stack = MiddlewareStack::new(vec![]);
        let mut ctx = RequestContext::new();
        assert!(stack.before_request(&mut ctx).await.is_ok());
        assert!(!ctx.identity.is_authenticated());
    }

    /// The first failing layer stops the stack; later layers never run.
    #[tokio::test]
    async fn stack_error_halts_chain() {
        let called = Arc::new(AtomicBool::new(false));
        let stack = MiddlewareStack::new(vec![
            Arc::new(FailUnauthenticated::new(AuthFailureKind::MissingCredential)),
            Arc::new(CallTracker {
                called: Arc::clone(&called),
                inner: Box::new(SucceedingMiddleware {
                    owner: "user".into(),
                }),
            }),
        ]);
        let mut ctx = RequestContext::new();
        assert!(stack.before_request(&mut ctx).await.is_err());
        assert!(
            !called.load(Ordering::SeqCst),
            "Second middleware should not be called after first fails"
        );
    }

    /// A later child may succeed after an earlier one rejected.
    #[tokio::test]
    async fn anyof_first_success_wins() {
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailUnauthenticated::new(AuthFailureKind::InvalidApiKey)),
            Arc::new(SucceedingMiddleware {
                owner: "user-b".into(),
            }),
        ]);
        let mut ctx = RequestContext::new();
        assert!(any.before_request(&mut ctx).await.is_ok());
        assert_eq!(ctx.identity.owner(), "user-b");
    }

    /// Once a child succeeds, the remaining children are not invoked.
    #[tokio::test]
    async fn anyof_first_child_succeeds_skips_rest() {
        let called = Arc::new(AtomicBool::new(false));
        let any = AnyOfMiddleware::new(vec![
            Arc::new(SucceedingMiddleware {
                owner: "user-a".into(),
            }),
            Arc::new(CallTracker {
                called: Arc::clone(&called),
                inner: Box::new(SucceedingMiddleware {
                    owner: "user-b".into(),
                }),
            }),
        ]);
        let mut ctx = RequestContext::new();
        any.before_request(&mut ctx).await.unwrap();
        assert_eq!(ctx.identity.owner(), "user-a");
        assert!(!called.load(Ordering::SeqCst));
    }

    /// State written by a rejected child must not reach the caller's context,
    /// because each attempt runs against its own copy.
    #[tokio::test]
    async fn anyof_failed_attempt_does_not_leak_into_context() {
        let any = AnyOfMiddleware::new(vec![Arc::new(MutateThenFail)]);
        let mut ctx = RequestContext::new();
        assert!(any.before_request(&mut ctx).await.is_err());
        assert!(
            !ctx.identity.is_authenticated(),
            "Identity set by a rejected child should be discarded"
        );
    }

    /// An internal error is returned immediately without trying other children.
    #[tokio::test]
    async fn anyof_internal_short_circuits() {
        let called = Arc::new(AtomicBool::new(false));
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailInternal {
                message: "db down".into(),
            }),
            Arc::new(CallTracker {
                called: Arc::clone(&called),
                inner: Box::new(SucceedingMiddleware {
                    owner: "user".into(),
                }),
            }),
        ]);
        let mut ctx = RequestContext::new();
        let err = any.before_request(&mut ctx).await.unwrap_err();
        assert!(matches!(err, MiddlewareError::Internal(_)));
        assert!(
            !called.load(Ordering::SeqCst),
            "Internal error should short-circuit, not try next child"
        );
    }

    #[tokio::test]
    async fn anyof_forbidden_beats_unauthenticated() {
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailUnauthenticated::new(AuthFailureKind::MissingCredential)),
            Arc::new(FailForbidden::new(AuthFailureKind::InsufficientScope)),
        ]);
        let mut ctx = RequestContext::new();
        let err = any.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(err, MiddlewareError::Forbidden(_)),
            "Forbidden should win over Unauthenticated"
        );
    }

    #[tokio::test]
    async fn anyof_forbidden_beats_http_challenge() {
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailHttpChallenge::new(AuthFailureKind::InvalidToken)),
            Arc::new(FailForbidden::new(AuthFailureKind::InsufficientScope)),
        ]);
        let mut ctx = RequestContext::new();
        let err = any.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(err, MiddlewareError::Forbidden(_)),
            "Forbidden should win over HttpChallenge"
        );
    }

    #[tokio::test]
    async fn anyof_http_challenge_beats_unauthenticated() {
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailUnauthenticated::new(AuthFailureKind::InvalidApiKey)),
            Arc::new(FailHttpChallenge::new(AuthFailureKind::InvalidToken)),
        ]);
        let mut ctx = RequestContext::new();
        let err = any.before_request(&mut ctx).await.unwrap_err();
        assert!(
            matches!(err, MiddlewareError::HttpChallenge(_)),
            "HttpChallenge should win over Unauthenticated"
        );
    }

    /// Among rejections of equal precedence the earliest child's error is kept.
    #[tokio::test]
    async fn anyof_all_unauthenticated_returns_first() {
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailUnauthenticated::new(AuthFailureKind::MissingCredential)),
            Arc::new(FailUnauthenticated::new(AuthFailureKind::InvalidApiKey)),
        ]);
        let mut ctx = RequestContext::new();
        let err = any.before_request(&mut ctx).await.unwrap_err();
        match err {
            MiddlewareError::Unauthenticated(kind) => {
                assert_eq!(
                    kind,
                    AuthFailureKind::MissingCredential,
                    "Tie should go to first-registered"
                );
            }
            other => panic!("Expected Unauthenticated, got: {other:?}"),
        }
    }

    #[test]
    fn middleware_error_http_status_mapping() {
        assert_eq!(
            MiddlewareError::Unauthenticated(AuthFailureKind::MissingCredential).http_status(),
            401
        );
        assert_eq!(
            MiddlewareError::HttpChallenge(AuthFailureKind::InvalidToken).http_status(),
            401
        );
        assert_eq!(
            MiddlewareError::Forbidden(AuthFailureKind::InsufficientScope).http_status(),
            403
        );
        assert_eq!(MiddlewareError::Internal("x".into()).http_status(), 500);
    }

    /// Two `HttpChallenge` rejections tie, so the first-registered one is returned.
    ///
    /// The children carry different failure kinds so the assertion can tell
    /// which one was selected.
    // NOTE(review): assumes `precedence()` ranks by error variant only, as the
    // Unauthenticated tie test above already relies on — confirm.
    #[tokio::test]
    async fn anyof_multiple_http_challenges_selects_first_by_precedence() {
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailHttpChallenge::new(AuthFailureKind::InvalidToken)),
            Arc::new(FailHttpChallenge::new(AuthFailureKind::MissingCredential)),
        ]);
        let mut ctx = RequestContext::new();
        let err = any.before_request(&mut ctx).await.unwrap_err();
        match err {
            MiddlewareError::HttpChallenge(kind) => {
                assert_eq!(
                    kind,
                    AuthFailureKind::InvalidToken,
                    "Tie should go to first-registered"
                );
            }
            other => panic!("Expected HttpChallenge, got: {other:?}"),
        }
    }

    /// The selected `HttpChallenge` keeps the failure kind of the child that raised it.
    #[tokio::test]
    async fn anyof_unauthenticated_and_http_challenge_selects_challenge() {
        let any = AnyOfMiddleware::new(vec![
            Arc::new(FailUnauthenticated::new(AuthFailureKind::MissingCredential)),
            Arc::new(FailHttpChallenge::new(AuthFailureKind::InvalidToken)),
        ]);
        let mut ctx = RequestContext::new();
        let err = any.before_request(&mut ctx).await.unwrap_err();
        match err {
            MiddlewareError::HttpChallenge(kind) => {
                assert_eq!(kind, AuthFailureKind::InvalidToken);
            }
            other => panic!("Expected HttpChallenge, got: {other:?}"),
        }
    }

    #[test]
    #[should_panic(expected = "at least one child")]
    fn anyof_empty_children_panics() {
        AnyOfMiddleware::new(vec![]);
    }

    #[test]
    fn anonymous_is_not_authenticated() {
        let id = AuthIdentity::Anonymous;
        assert!(!id.is_authenticated());
        assert_eq!(id.owner(), "anonymous");
        assert!(id.claims().is_none());
    }

    #[test]
    fn authenticated_is_authenticated() {
        let id = AuthIdentity::Authenticated {
            owner: "user-1".into(),
            claims: Some(serde_json::json!({"sub": "user-1"})),
        };
        assert!(id.is_authenticated());
        assert_eq!(id.owner(), "user-1");
        assert!(id.claims().is_some());
    }

    /// Authentication status comes from the variant, not from the owner string.
    #[test]
    fn authenticated_with_literal_anonymous_owner_is_still_authenticated() {
        let id = AuthIdentity::Authenticated {
            owner: "anonymous".into(),
            claims: None,
        };
        assert!(id.is_authenticated());
    }

    #[test]
    fn api_key_auth_has_no_claims_but_is_authenticated() {
        let id = AuthIdentity::Authenticated {
            owner: "api-key-user".into(),
            claims: None,
        };
        assert!(id.is_authenticated());
        assert!(id.claims().is_none());
        assert_eq!(id.owner(), "api-key-user");
    }
}