use std::collections::HashSet;
use std::sync::Arc;
use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde_json::json;
use tracing::error;
use crate::a2a::types::AgentCard;
use crate::config::auth::{AuthConfig, AuthType};
use crate::error::AgentKitError;
use super::cognito::CognitoM2MCredentialService;
use super::noauth::NoAuthCredentialService;
use super::strategy::AuthStrategy;
/// Shared state for [`auth_middleware`].
///
/// Bundles the agent card (whose `security_schemes` are looked up by name),
/// the auth configuration, the set of request paths that bypass
/// authentication, and the credential strategy used to extract tokens, fetch
/// verification keys, and validate tokens.
pub struct AuthMiddlewareState {
/// Agent card; its `security_schemes` map is consulted to resolve the scheme
/// named by `auth_config`.
pub agent_card: AgentCard,
/// Auth configuration: names the scheme (and its required scopes) and
/// determines whether authentication is skipped entirely (`is_no_auth`).
pub auth_config: AuthConfig,
/// Request URI paths that skip authentication (exact string match).
pub public_paths: HashSet<String>,
/// Credential strategy. When built via [`AuthMiddlewareState::new`], it is
/// chosen from `auth_config.auth_type`.
pub strategy: Arc<dyn AuthStrategy>,
}
impl AuthMiddlewareState {
pub fn new(
agent_card: AgentCard,
auth_config: AuthConfig,
public_paths: Vec<String>,
) -> Self {
let strategy: Arc<dyn AuthStrategy> = match auth_config.auth_type {
AuthType::Cognito => Arc::new(CognitoM2MCredentialService::from_env()),
AuthType::NoAuth => Arc::new(NoAuthCredentialService),
};
Self {
agent_card,
auth_config,
public_paths: public_paths.into_iter().collect(),
strategy,
}
}
}
fn unauthorized(reason: &str) -> Response {
(
StatusCode::UNAUTHORIZED,
Json(json!({"error": "unauthorized", "reason": reason})),
)
.into_response()
}
fn forbidden(reason: &str) -> Response {
(
StatusCode::FORBIDDEN,
Json(json!({"error": "forbidden", "reason": reason})),
)
.into_response()
}
fn service_unavailable(reason: &str) -> Response {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"error": "service_unavailable", "reason": reason})),
)
.into_response()
}
/// Axum middleware that authenticates requests against the first security
/// scheme named in the auth configuration.
///
/// A request passes straight through when its URI path is in
/// `state.public_paths` or when the config is no-auth. Otherwise, in order:
///
/// 1. the scheme named by `auth_config.first_scheme()` is looked up in
///    `agent_card.security_schemes` and parsed (401 if missing or unparsable);
/// 2. the token is extracted from the request headers (401 on failure);
/// 3. verification keys are fetched through the strategy (503 on a JWKS fetch
///    error, 401 on any other error);
/// 4. the token is validated against those keys (401 on failure);
/// 5. every scope required by the scheme must appear in the
///    whitespace-separated `scope` claim (403 listing the missing scopes).
///
/// Only then is the request forwarded to `next`.
pub async fn auth_middleware(
    State(state): State<Arc<AuthMiddlewareState>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // `HashSet<String>` can be queried with `&str`, so the path is not copied.
    if state.public_paths.contains(request.uri().path()) || state.auth_config.is_no_auth() {
        return next.run(request).await;
    }

    // Resolve the configured scheme once. Its required scopes are reused for
    // the scope check at the end, so both steps refer to the same scheme.
    // NOTE(review): a missing/invalid scheme is a server misconfiguration;
    // 500 may be more accurate than 401. Kept as 401 for compatibility.
    let Some((scheme_name, required_scopes)) = state.auth_config.first_scheme() else {
        return unauthorized("No security scheme configured");
    };
    let Some(raw_scheme) = state
        .agent_card
        .security_schemes
        .as_ref()
        .and_then(|ss| ss.get(&scheme_name).cloned())
    else {
        return unauthorized("No security scheme configured");
    };
    let scheme = match serde_json::from_value::<crate::config::auth::SecurityScheme>(raw_scheme) {
        Ok(s) => s,
        Err(e) => {
            error!("Failed to parse security scheme: {e}");
            return unauthorized("Invalid security scheme configuration");
        }
    };

    // Extract the token before fetching keys. Extraction is synchronous (not
    // awaited), so a request without usable credentials is rejected with 401
    // without triggering a (possibly remote) JWKS fetch. It is also not
    // misreported as 503 when that fetch fails. Borrowing the headers avoids
    // cloning the whole map.
    let token = match state.strategy.get_token(request.headers()) {
        Ok(t) => t,
        Err(AgentKitError::InvalidAuthHeader) => {
            return unauthorized("Missing or malformed Authorization header");
        }
        Err(e) => {
            error!("Token extraction error: {e}");
            return unauthorized("Token extraction failed");
        }
    };

    let keys = match state.strategy.get_keys(&scheme).await {
        Ok(k) => k,
        Err(AgentKitError::JwksFetch(msg)) => {
            error!("JWKS fetch error: {msg}");
            return service_unavailable("Unable to fetch JWKS");
        }
        Err(e) => {
            error!("Auth setup error: {e}");
            return unauthorized("Auth configuration error");
        }
    };

    let claims = match state.strategy.validate_token(&token, &keys) {
        Ok(c) => c,
        Err(AgentKitError::JwtValidation(msg)) => {
            error!("JWT validation failed: {msg}");
            return unauthorized(&format!("Invalid JWT: {msg}"));
        }
        Err(e) => {
            error!("Unexpected validation error: {e}");
            return unauthorized("Token validation failed");
        }
    };

    if !required_scopes.is_empty() {
        // `.get` returns `None` for a missing claim instead of relying on
        // `Index`, which panics on missing keys for map types. A missing or
        // non-string `scope` claim is treated as granting no scopes.
        let token_scopes: HashSet<&str> = claims
            .get("scope")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .split_whitespace()
            .collect();
        let missing: Vec<&str> = required_scopes
            .iter()
            .map(|s| s.as_str())
            .filter(|s| !token_scopes.contains(s))
            .collect();
        if !missing.is_empty() {
            error!("Missing required scopes: {missing:?}");
            return forbidden(&format!("Missing required scopes: {missing:?}"));
        }
    }

    next.run(request).await
}