use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use axum::http::HeaderMap;
use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::{error, info};
use crate::config::auth::SecurityScheme;
use crate::error::{AgentKitError, Result};
use super::strategy::AuthStrategy;
/// A previously issued access token together with the moment it stops being
/// served from the cache.
struct TokenCache {
    /// Unix timestamp (whole seconds) at or after which the cached token is
    /// treated as stale and a new one is requested.
    expires_at: u64,
    /// The `access_token` string returned by the token endpoint.
    access_token: String,
}
/// OAuth2 `client_credentials` (machine-to-machine) token provider for a
/// Cognito user pool domain. It caches the most recent access token and is
/// also used as an [`AuthStrategy`] to validate incoming bearer tokens.
pub struct CognitoM2MCredentialService {
    /// Client identifier sent as the `client_id` form field.
    client_id: String,
    /// Client secret sent as the `client_secret` form field; treat as sensitive.
    client_secret: String,
    /// Full token endpoint URL, `https://<domain>/oauth2/token`.
    token_url: String,
    /// Optional `scope` form field; omitted from the request when `None`.
    scope: Option<String>,
    /// Last fetched token, guarded by an async mutex.
    cache: Mutex<Option<TokenCache>>,
    /// HTTP client reused for both token and JWKS requests.
    http: reqwest::Client,
}
impl CognitoM2MCredentialService {
    /// Seconds subtracted from the server-reported lifetime so a cached token
    /// is refreshed shortly before it actually expires.
    const EXPIRY_SKEW_SECS: u64 = 60;

    /// Lifetime assumed when the token response has no numeric `expires_in`.
    const DEFAULT_EXPIRES_IN_SECS: u64 = 3600;

    /// Builds a service from environment variables.
    ///
    /// Reads `COGNITO_USER_POOL_DOMAIN` (required), `COGNITO_CLIENT_ID` and
    /// `COGNITO_CLIENT_SECRET` (each defaults to an empty string when unset or
    /// not valid Unicode), and `COGNITO_SCOPE` (optional).
    ///
    /// # Panics
    ///
    /// Panics if `COGNITO_USER_POOL_DOMAIN` is unset or not valid Unicode.
    pub fn from_env() -> Self {
        let domain = std::env::var("COGNITO_USER_POOL_DOMAIN")
            .expect("COGNITO_USER_POOL_DOMAIN must be set");
        // NOTE(review): an empty client id/secret is accepted here and will
        // presumably only surface later as a rejected token request.
        Self::new(
            std::env::var("COGNITO_CLIENT_ID").unwrap_or_default(),
            std::env::var("COGNITO_CLIENT_SECRET").unwrap_or_default(),
            domain,
            std::env::var("COGNITO_SCOPE").ok(),
        )
    }

    /// Builds a service for the user pool reachable at `user_pool_domain`.
    ///
    /// `user_pool_domain` is a bare host name without a scheme. The token
    /// endpoint becomes `https://{user_pool_domain}/oauth2/token`. When `scope`
    /// is `Some`, it is sent as the `scope` form field on every token request.
    pub fn new(
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
        user_pool_domain: impl Into<String>,
        scope: Option<String>,
    ) -> Self {
        let domain: String = user_pool_domain.into();
        Self {
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            token_url: format!("https://{domain}/oauth2/token"),
            scope,
            cache: Mutex::new(None),
            http: reqwest::Client::new(),
        }
    }

    /// Current Unix time in whole seconds; a clock set before the epoch yields 0.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
            .as_secs()
    }

    /// Requests a fresh access token via the OAuth2 `client_credentials` grant
    /// and stores it in the cache, replacing any previous entry.
    ///
    /// # Errors
    ///
    /// Fails on transport errors, non-2xx responses, bodies that are not JSON,
    /// or a JSON body without a string `access_token`.
    async fn fetch_token(&self) -> Result<String> {
        let mut form: Vec<(&str, &str)> = vec![
            ("grant_type", "client_credentials"),
            ("client_id", self.client_id.as_str()),
            ("client_secret", self.client_secret.as_str()),
        ];
        if let Some(scope) = self.scope.as_deref() {
            form.push(("scope", scope));
        }

        let resp = self
            .http
            .post(&self.token_url)
            .form(&form)
            .send()
            .await?
            .error_for_status()?
            .json::<serde_json::Value>()
            .await?;

        let token = resp["access_token"]
            .as_str()
            .ok_or_else(|| AgentKitError::Auth("Missing access_token in response".into()))?
            .to_string();
        let expires_in = resp["expires_in"]
            .as_u64()
            .unwrap_or(Self::DEFAULT_EXPIRES_IN_SECS);
        // Saturating arithmetic: an absurd `expires_in` from the endpoint must
        // not overflow (a panic in debug builds, a wrapped deadline in release).
        let expires_at = Self::now_secs()
            .saturating_add(expires_in)
            .saturating_sub(Self::EXPIRY_SKEW_SECS);

        *self.cache.lock().await = Some(TokenCache {
            access_token: token.clone(),
            expires_at,
        });
        Ok(token)
    }

    /// Returns the cached access token while it is still fresh; otherwise
    /// fetches and caches a new one.
    ///
    /// The cache lock is released before the network request. Concurrent
    /// callers that all observe a stale token may therefore each issue their
    /// own token request.
    ///
    /// # Errors
    ///
    /// Propagates any error from the token request.
    pub async fn get_credentials(&self) -> Result<String> {
        {
            let cache = self.cache.lock().await;
            if let Some(cached) = &*cache {
                if Self::now_secs() < cached.expires_at {
                    return Ok(cached.access_token.clone());
                }
            }
        }
        info!("Fetching new Cognito access token");
        self.fetch_token().await
    }
}
#[async_trait]
impl AuthStrategy for CognitoM2MCredentialService {
async fn get_keys(&self, scheme: &SecurityScheme) -> Result<serde_json::Value> {
let jwks_url = scheme
.description
.as_deref()
.ok_or_else(|| AgentKitError::JwksFetch("Missing JWKS URL in security scheme description".into()))?;
let jwks = self
.http
.get(jwks_url)
.send()
.await
.map_err(|e| AgentKitError::JwksFetch(e.to_string()))?
.error_for_status()
.map_err(|e| AgentKitError::JwksFetch(e.to_string()))?
.json::<serde_json::Value>()
.await
.map_err(|e| AgentKitError::JwksFetch(e.to_string()))?;
Ok(jwks)
}
fn get_token(&self, headers: &HeaderMap) -> Result<String> {
let auth = headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.ok_or(AgentKitError::InvalidAuthHeader)?;
if !auth.starts_with("Bearer ") {
return Err(AgentKitError::InvalidAuthHeader);
}
Ok(auth["Bearer ".len()..].to_string())
}
fn validate_token(&self, token: &str, keys: &Value) -> Result<serde_json::Value> {
let header = decode_header(token)
.map_err(|e| AgentKitError::JwtValidation(e.to_string()))?;
let kid = header.kid.as_deref().unwrap_or("");
let key_arr = keys["keys"]
.as_array()
.ok_or_else(|| AgentKitError::JwtValidation("Invalid JWKS: missing 'keys' array".into()))?;
let jwk = key_arr
.iter()
.find(|k| k["kid"].as_str() == Some(kid))
.ok_or_else(|| AgentKitError::JwtValidation(format!("Key ID '{kid}' not found in JWKS")))?;
let decoding_key = DecodingKey::from_rsa_components(
jwk["n"].as_str().unwrap_or(""),
jwk["e"].as_str().unwrap_or(""),
)
.map_err(|e| AgentKitError::JwtValidation(e.to_string()))?;
let mut validation = Validation::new(header.alg);
validation.validate_aud = false;
let data = decode::<serde_json::Value>(token, &decoding_key, &validation)
.map_err(|e| {
error!("JWT decode error: {e}");
AgentKitError::JwtValidation(e.to_string())
})?;
Ok(data.claims)
}
}