velocia 0.3.1

velocia – production-ready AI agent framework using ADK-Rust, A2A protocol, and AWS DynamoDB
use std::collections::HashSet;
use std::sync::Arc;

use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde_json::json;
use tracing::error;

use crate::a2a::types::AgentCard;
use crate::config::auth::{AuthConfig, AuthType};
use crate::error::AgentKitError;

use super::cognito::CognitoM2MCredentialService;
use super::noauth::NoAuthCredentialService;
use super::strategy::AuthStrategy;

pub struct AuthMiddlewareState {
    pub agent_card: AgentCard,
    pub auth_config: AuthConfig,
    pub public_paths: HashSet<String>,
    pub strategy: Arc<dyn AuthStrategy>,
}

impl AuthMiddlewareState {
    pub fn new(
        agent_card: AgentCard,
        auth_config: AuthConfig,
        public_paths: Vec<String>,
    ) -> Self {
        let strategy: Arc<dyn AuthStrategy> = match auth_config.auth_type {
            AuthType::Cognito => Arc::new(CognitoM2MCredentialService::from_env()),
            AuthType::NoAuth => Arc::new(NoAuthCredentialService),
        };
        Self {
            agent_card,
            auth_config,
            public_paths: public_paths.into_iter().collect(),
            strategy,
        }
    }
}

fn unauthorized(reason: &str) -> Response {
    (
        StatusCode::UNAUTHORIZED,
        Json(json!({"error": "unauthorized", "reason": reason})),
    )
        .into_response()
}

fn forbidden(reason: &str) -> Response {
    (
        StatusCode::FORBIDDEN,
        Json(json!({"error": "forbidden", "reason": reason})),
    )
        .into_response()
}

fn service_unavailable(reason: &str) -> Response {
    (
        StatusCode::SERVICE_UNAVAILABLE,
        Json(json!({"error": "service_unavailable", "reason": reason})),
    )
        .into_response()
}

/// Axum middleware that validates JWT bearer tokens for A2A endpoints.
///
/// Requests are forwarded to `next` without any checks when the auth config is in
/// no-auth mode or the request path is listed in `state.public_paths`. Otherwise the
/// request must carry a bearer token that the configured [`AuthStrategy`] accepts
/// against the keys for the first configured security scheme. If that scheme lists
/// required scopes, the token's space-separated `scope` claim must contain all of them.
///
/// Failure responses (JSON bodies with `error` and `reason` fields):
/// - `401 Unauthorized`: missing or malformed `Authorization` header, token
///   extraction or validation failure, a missing or unparseable security scheme,
///   or another auth setup error.
/// - `403 Forbidden`: the token is valid but lacks one or more required scopes.
/// - `503 Service Unavailable`: the JWKS keys could not be fetched.
pub async fn auth_middleware(
    State(state): State<Arc<AuthMiddlewareState>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Allow no-auth mode and public paths to pass through. `HashSet<String>` can be
    // queried with a `&str`, so the path does not need to be copied.
    if state.auth_config.is_no_auth() || state.public_paths.contains(request.uri().path()) {
        return next.run(request).await;
    }

    // Extract the bearer token first. It is the cheapest check, and rejecting
    // anonymous requests here avoids a JWKS fetch for them (and a misleading 503
    // when the JWKS endpoint is unavailable). Headers are borrowed, not cloned.
    let token = match state.strategy.get_token(request.headers()) {
        Ok(t) => t,
        Err(AgentKitError::InvalidAuthHeader) => {
            return unauthorized("Missing or malformed Authorization header")
        }
        Err(e) => {
            error!("Token extraction error: {e}");
            return unauthorized("Token extraction failed");
        }
    };

    // Resolve the first configured scheme once; its required scopes are reused for
    // the scope check below instead of calling `first_scheme()` a second time.
    let Some((scheme_name, required_scopes)) = state.auth_config.first_scheme() else {
        return unauthorized("No security scheme configured");
    };
    let Some(raw_scheme) = state
        .agent_card
        .security_schemes
        .as_ref()
        .and_then(|schemes| schemes.get(&scheme_name).cloned())
    else {
        return unauthorized("No security scheme configured");
    };
    // The agent card stores schemes as raw JSON; convert to the typed definition.
    let scheme = match serde_json::from_value::<crate::config::auth::SecurityScheme>(raw_scheme) {
        Ok(s) => s,
        Err(e) => {
            error!("Failed to parse security scheme: {e}");
            return unauthorized("Invalid security scheme configuration");
        }
    };

    // Fetch JWKS keys; an upstream fetch failure is reported as 503, not 401.
    let keys = match state.strategy.get_keys(&scheme).await {
        Ok(k) => k,
        Err(AgentKitError::JwksFetch(msg)) => {
            error!("JWKS fetch error: {msg}");
            return service_unavailable("Unable to fetch JWKS");
        }
        Err(e) => {
            error!("Auth setup error: {e}");
            return unauthorized("Auth configuration error");
        }
    };

    // Validate the token against the fetched keys.
    let claims = match state.strategy.validate_token(&token, &keys) {
        Ok(c) => c,
        Err(AgentKitError::JwtValidation(msg)) => {
            error!("JWT validation failed: {msg}");
            return unauthorized(&format!("Invalid JWT: {msg}"));
        }
        Err(e) => {
            error!("Unexpected validation error: {e}");
            return unauthorized("Token validation failed");
        }
    };

    // Check that every scope required by the scheme is granted by the token.
    if !required_scopes.is_empty() {
        let granted: HashSet<&str> = claims["scope"]
            .as_str()
            .unwrap_or("")
            .split_whitespace()
            .collect();

        let missing: Vec<&str> = required_scopes
            .iter()
            .map(|s| s.as_str())
            .filter(|s| !granted.contains(s))
            .collect();

        if !missing.is_empty() {
            error!("Missing required scopes: {missing:?}");
            return forbidden(&format!("Missing required scopes: {missing:?}"));
        }
    }

    next.run(request).await
}