velocia 0.3.0

velocia – production-ready AI agent framework using ADK-Rust, A2A protocol, and AWS DynamoDB
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use async_trait::async_trait;
use axum::http::HeaderMap;
use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::{error, info};

use crate::config::auth::SecurityScheme;
use crate::error::{AgentKitError, Result};

use super::strategy::AuthStrategy;

/// A cached access token together with the moment it should stop being reused.
struct TokenCache {
    /// The `access_token` string taken from the token endpoint's JSON response.
    access_token: String,
    /// Unix timestamp (seconds) after which the cached token is treated as stale.
    /// Computed at fetch time from the response's `expires_in`, minus a 60 s buffer.
    expires_at: u64,
}

/// AWS Cognito M2M credential service using the OAuth2 client-credentials flow.
///
/// Tokens are cached until expiration and refreshed automatically: a cached
/// token is reused while the current time is before its recorded expiry, and a
/// new one is requested from the token endpoint otherwise.
pub struct CognitoM2MCredentialService {
    /// OAuth2 client id sent as the `client_id` form field.
    client_id: String,
    /// OAuth2 client secret sent as the `client_secret` form field.
    client_secret: String,
    /// Token endpoint, built as `https://{domain}/oauth2/token`.
    token_url: String,
    /// Optional value for the `scope` form field; omitted from the request when `None`.
    scope: Option<String>,
    /// Most recently fetched token, if any, behind an async mutex.
    cache: Mutex<Option<TokenCache>>,
    /// HTTP client used for both token requests and JWKS downloads.
    http: reqwest::Client,
}

impl CognitoM2MCredentialService {
    pub fn from_env() -> Self {
        let domain = std::env::var("COGNITO_USER_POOL_DOMAIN")
            .expect("COGNITO_USER_POOL_DOMAIN must be set");
        Self {
            client_id: std::env::var("COGNITO_CLIENT_ID").unwrap_or_default(),
            client_secret: std::env::var("COGNITO_CLIENT_SECRET").unwrap_or_default(),
            token_url: format!("https://{domain}/oauth2/token"),
            scope: std::env::var("COGNITO_SCOPE").ok(),
            cache: Mutex::new(None),
            http: reqwest::Client::new(),
        }
    }

    /// Creates a service from explicit credentials.
    ///
    /// `user_pool_domain` is the bare host name; the token endpoint becomes
    /// `https://{user_pool_domain}/oauth2/token`. `scope`, when present, is
    /// sent with every token request. The token cache starts out empty.
    pub fn new(
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
        user_pool_domain: impl Into<String>,
        scope: Option<String>,
    ) -> Self {
        let token_url = format!("https://{}/oauth2/token", user_pool_domain.into());
        let client_id = client_id.into();
        let client_secret = client_secret.into();
        let cache = Mutex::new(None);
        let http = reqwest::Client::new();

        Self {
            client_id,
            client_secret,
            token_url,
            scope,
            cache,
            http,
        }
    }

    /// Current wall-clock time as whole seconds since the Unix epoch.
    ///
    /// Yields `0` when the system clock reports a time before the epoch.
    fn now_secs() -> u64 {
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => Duration::ZERO.as_secs(),
        }
    }

    /// Requests a fresh access token from the token endpoint and caches it.
    ///
    /// Sends a form-encoded `client_credentials` grant with the client id,
    /// the client secret and, when configured, the scope. On success the token
    /// is stored in the cache together with its expiry and returned.
    ///
    /// # Errors
    ///
    /// Propagates HTTP errors (transport failure, non-success status, invalid
    /// JSON body) and returns [`AgentKitError::Auth`] when the response has no
    /// string `access_token` field.
    async fn fetch_token(&self) -> Result<String> {
        /// Lifetime assumed when the response carries no usable `expires_in`.
        const DEFAULT_EXPIRES_IN_SECS: u64 = 3600;
        /// Safety margin so a token is refreshed shortly before it really expires.
        const EXPIRY_BUFFER_SECS: u64 = 60;

        let mut form: Vec<(&str, &str)> = vec![
            ("grant_type", "client_credentials"),
            ("client_id", self.client_id.as_str()),
            ("client_secret", self.client_secret.as_str()),
        ];
        if let Some(scope) = self.scope.as_deref() {
            form.push(("scope", scope));
        }

        let resp = self
            .http
            .post(&self.token_url)
            .form(&form)
            .send()
            .await?
            .error_for_status()?
            .json::<serde_json::Value>()
            .await?;

        let token = resp["access_token"]
            .as_str()
            .ok_or_else(|| AgentKitError::Auth("Missing access_token in response".into()))?
            .to_string();

        // `expires_in` comes from the server, so use saturating arithmetic:
        // a huge value must not overflow, and a lifetime shorter than the
        // buffer must not underflow. A lifetime at or below the buffer yields
        // an expiry of "now", i.e. the token is never served from the cache.
        let expires_in = resp["expires_in"]
            .as_u64()
            .unwrap_or(DEFAULT_EXPIRES_IN_SECS);
        let expires_at =
            Self::now_secs().saturating_add(expires_in.saturating_sub(EXPIRY_BUFFER_SECS));

        let mut cache = self.cache.lock().await;
        *cache = Some(TokenCache {
            access_token: token.clone(),
            expires_at,
        });

        Ok(token)
    }

    /// Returns an access token, reusing the cached one while it is still fresh.
    ///
    /// A cached token is returned as long as the current time is before its
    /// recorded expiry; otherwise a new token is fetched (and cached) via
    /// `fetch_token`.
    pub async fn get_credentials(&self) -> Result<String> {
        {
            let guard = self.cache.lock().await;
            let fresh = guard
                .as_ref()
                .filter(|entry| Self::now_secs() < entry.expires_at)
                .map(|entry| entry.access_token.clone());
            if let Some(token) = fresh {
                return Ok(token);
            }
            // The guard goes out of scope here: `fetch_token` locks the cache
            // itself, so the lock must be released before calling it.
        }

        info!("Fetching new Cognito access token");
        self.fetch_token().await
    }
}

#[async_trait]
impl AuthStrategy for CognitoM2MCredentialService {
    async fn get_keys(&self, scheme: &SecurityScheme) -> Result<serde_json::Value> {
        let jwks_url = scheme
            .description
            .as_deref()
            .ok_or_else(|| AgentKitError::JwksFetch("Missing JWKS URL in security scheme description".into()))?;

        let jwks = self
            .http
            .get(jwks_url)
            .send()
            .await
            .map_err(|e| AgentKitError::JwksFetch(e.to_string()))?
            .error_for_status()
            .map_err(|e| AgentKitError::JwksFetch(e.to_string()))?
            .json::<serde_json::Value>()
            .await
            .map_err(|e| AgentKitError::JwksFetch(e.to_string()))?;

        Ok(jwks)
    }

    /// Extracts the bearer token from the `Authorization` header.
    ///
    /// The `Bearer` scheme name is matched case-insensitively, as HTTP
    /// authentication scheme names are case-insensitive, and surrounding
    /// whitespace is trimmed from the token.
    ///
    /// # Errors
    ///
    /// Returns [`AgentKitError::InvalidAuthHeader`] when the header is absent,
    /// cannot be read as a string, or does not use the `Bearer` scheme.
    fn get_token(&self, headers: &HeaderMap) -> Result<String> {
        const SCHEME: &str = "Bearer ";

        let auth = headers
            .get(axum::http::header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok())
            .ok_or(AgentKitError::InvalidAuthHeader)?;

        // `str::get` yields `None` for values shorter than the scheme prefix
        // instead of panicking the way direct slicing would.
        let (scheme, token) = match (auth.get(..SCHEME.len()), auth.get(SCHEME.len()..)) {
            (Some(scheme), Some(token)) => (scheme, token),
            _ => return Err(AgentKitError::InvalidAuthHeader),
        };
        if !scheme.eq_ignore_ascii_case(SCHEME) {
            return Err(AgentKitError::InvalidAuthHeader);
        }

        Ok(token.trim().to_string())
    }

    fn validate_token(&self, token: &str, keys: &Value) -> Result<serde_json::Value> {
        let header = decode_header(token)
            .map_err(|e| AgentKitError::JwtValidation(e.to_string()))?;

        let kid = header.kid.as_deref().unwrap_or("");
        let key_arr = keys["keys"]
            .as_array()
            .ok_or_else(|| AgentKitError::JwtValidation("Invalid JWKS: missing 'keys' array".into()))?;

        let jwk = key_arr
            .iter()
            .find(|k| k["kid"].as_str() == Some(kid))
            .ok_or_else(|| AgentKitError::JwtValidation(format!("Key ID '{kid}' not found in JWKS")))?;

        let decoding_key = DecodingKey::from_rsa_components(
            jwk["n"].as_str().unwrap_or(""),
            jwk["e"].as_str().unwrap_or(""),
        )
        .map_err(|e| AgentKitError::JwtValidation(e.to_string()))?;

        let mut validation = Validation::new(header.alg);
        validation.validate_aud = false; // Cognito M2M tokens may omit `aud`

        let data = decode::<serde_json::Value>(token, &decoding_key, &validation)
            .map_err(|e| {
                error!("JWT decode error: {e}");
                AgentKitError::JwtValidation(e.to_string())
            })?;

        Ok(data.claims)
    }
}