use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use turul_a2a::middleware::{
A2aMiddleware, AuthFailureKind, AuthIdentity, MiddlewareError, RequestContext,
SecurityContribution,
};
/// Resolves a presented API key to the owner it authenticates as.
///
/// `ApiKeyMiddleware` treats `None` as an unrecognised key and rejects an
/// owner that is empty or whitespace-only.
#[async_trait]
pub trait ApiKeyLookup: Send + Sync {
    /// Returns the owner registered for `key`, or `None` if the key is unknown.
    async fn lookup(&self, key: &str) -> Option<String>;
}
/// In-memory API-key table mapping each key to the owner it authenticates as.
pub struct StaticApiKeyLookup {
    /// Registered API keys, each mapped to its owner identifier.
    keys: HashMap<String, String>,
}

impl StaticApiKeyLookup {
    /// Creates a lookup over `keys`, a map from API key to owner.
    pub fn new(keys: HashMap<String, String>) -> Self {
        StaticApiKeyLookup { keys }
    }
}
#[async_trait]
impl ApiKeyLookup for StaticApiKeyLookup {
    /// Returns an owned copy of the owner registered for `key`, or `None` if absent.
    async fn lookup(&self, key: &str) -> Option<String> {
        let owner: &String = self.keys.get(key)?;
        Some(owner.clone())
    }
}
/// In-memory API-key table whose `Debug` output reveals only the number of
/// entries, never the keys or owners themselves.
pub struct RedactedApiKeyLookup {
    /// Registered API keys, each mapped to its owner identifier.
    keys: HashMap<String, String>,
}

impl RedactedApiKeyLookup {
    /// Creates a lookup over `keys`, a map from API key to owner.
    pub fn new(keys: HashMap<String, String>) -> Self {
        RedactedApiKeyLookup { keys }
    }

    /// Number of registered API keys.
    pub fn len(&self) -> usize {
        let Self { keys } = self;
        keys.len()
    }

    /// `true` when no API keys are registered.
    pub fn is_empty(&self) -> bool {
        let Self { keys } = self;
        keys.is_empty()
    }
}

impl fmt::Debug for RedactedApiKeyLookup {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the entry count is printed so keys/owners never reach logs.
        let entry_count = self.keys.len();
        let mut out = f.debug_struct("RedactedApiKeyLookup");
        out.field("len", &entry_count);
        out.finish()
    }
}
#[async_trait]
impl ApiKeyLookup for RedactedApiKeyLookup {
    /// Returns an owned copy of the owner registered for `key`, or `None` if absent.
    async fn lookup(&self, key: &str) -> Option<String> {
        if let Some(owner) = self.keys.get(key) {
            Some(owner.clone())
        } else {
            None
        }
    }
}
/// Middleware that authenticates requests by an API key carried in a request
/// header, resolving the key to an owner through an [`ApiKeyLookup`].
pub struct ApiKeyMiddleware {
    /// Backend that maps a presented key to its owner.
    lookup: Arc<dyn ApiKeyLookup>,
    /// Name of the request header the key is read from.
    header_name: String,
}

impl ApiKeyMiddleware {
    /// Builds the middleware.
    ///
    /// * `lookup` — resolves presented keys to owners.
    /// * `header_name` — request header carrying the key (e.g. `"X-API-Key"`).
    pub fn new(lookup: Arc<dyn ApiKeyLookup>, header_name: impl Into<String>) -> Self {
        let header_name: String = header_name.into();
        ApiKeyMiddleware {
            lookup,
            header_name,
        }
    }
}
#[async_trait]
impl A2aMiddleware for ApiKeyMiddleware {
    /// Authenticates the request using the API key in the configured header.
    ///
    /// On success, `ctx.identity` is set to `AuthIdentity::Authenticated` with the
    /// owner returned by the lookup and no claims.
    ///
    /// # Errors
    ///
    /// Returns `MiddlewareError::Unauthenticated` carrying:
    /// - `AuthFailureKind::MissingCredential` when the header is absent or empty;
    /// - `AuthFailureKind::InvalidApiKey` when the header value cannot be read as a
    ///   string (`to_str` fails), or when the lookup does not recognise the key;
    /// - `AuthFailureKind::EmptyPrincipal` when the key maps to an empty or
    ///   whitespace-only owner.
    async fn before_request(&self, ctx: &mut RequestContext) -> Result<(), MiddlewareError> {
        // Single statement: the borrow of `ctx.headers` ends with it, so no header
        // reference is held across the lookup's `.await` below.
        let key = match ctx.headers.get(&self.header_name).map(|v| v.to_str()) {
            None | Some(Ok("")) => {
                return Err(MiddlewareError::Unauthenticated(
                    AuthFailureKind::MissingCredential,
                ));
            }
            // The client did send a credential; it just cannot be a valid key.
            // Reporting it as "missing" would mislead the caller.
            Some(Err(_)) => {
                return Err(MiddlewareError::Unauthenticated(
                    AuthFailureKind::InvalidApiKey,
                ));
            }
            Some(Ok(k)) => k.to_owned(),
        };

        let Some(owner) = self.lookup.lookup(&key).await else {
            return Err(MiddlewareError::Unauthenticated(
                AuthFailureKind::InvalidApiKey,
            ));
        };

        // A recognised key must still resolve to a usable principal.
        if owner.trim().is_empty() {
            return Err(MiddlewareError::Unauthenticated(
                AuthFailureKind::EmptyPrincipal,
            ));
        }

        ctx.identity = AuthIdentity::Authenticated {
            owner,
            // API-key authentication attaches no claims.
            claims: None,
        };
        Ok(())
    }

    /// Advertises an `apiKey` security scheme located in the configured header.
    fn security_contribution(&self) -> SecurityContribution {
        SecurityContribution::new().with_scheme(
            "apiKey",
            turul_a2a_proto::SecurityScheme {
                scheme: Some(
                    turul_a2a_proto::security_scheme::Scheme::ApiKeySecurityScheme(
                        turul_a2a_proto::ApiKeySecurityScheme {
                            description: String::new(),
                            location: "header".into(),
                            name: self.header_name.clone(),
                        },
                    ),
                ),
            },
            vec![],
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lookup fixture: one valid key plus two keys that map to blank owners.
    fn test_lookup() -> Arc<dyn ApiKeyLookup> {
        let mut keys = HashMap::new();
        keys.insert("valid-key".to_string(), "user-from-key".to_string());
        keys.insert("empty-owner-key".to_string(), "".to_string());
        keys.insert("whitespace-key".to_string(), " ".to_string());
        Arc::new(StaticApiKeyLookup::new(keys))
    }

    /// Middleware under test. The mixed-case header name also exercises
    /// case-insensitive header lookup against the lowercase inserts below.
    fn middleware() -> ApiKeyMiddleware {
        ApiKeyMiddleware::new(test_lookup(), "X-API-Key")
    }

    #[tokio::test]
    async fn valid_key_sets_authenticated_identity() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "valid-key".parse().unwrap());
        mw.before_request(&mut ctx).await.unwrap();
        assert!(ctx.identity.is_authenticated());
        assert_eq!(ctx.identity.owner(), "user-from-key");
        assert!(ctx.identity.claims().is_none(), "API key has no claims");
    }

    #[tokio::test]
    async fn missing_key_returns_unauthenticated() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(matches!(
            err,
            MiddlewareError::Unauthenticated(AuthFailureKind::MissingCredential)
        ));
    }

    #[tokio::test]
    async fn empty_key_value_is_treated_as_missing() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers.insert("x-api-key", "".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(matches!(
            err,
            MiddlewareError::Unauthenticated(AuthFailureKind::MissingCredential)
        ));
    }

    #[tokio::test]
    async fn invalid_key_returns_unauthenticated() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers.insert("x-api-key", "bad-key".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(matches!(
            err,
            MiddlewareError::Unauthenticated(AuthFailureKind::InvalidApiKey)
        ));
    }

    #[tokio::test]
    async fn empty_owner_from_lookup_rejected() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "empty-owner-key".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(matches!(
            err,
            MiddlewareError::Unauthenticated(AuthFailureKind::EmptyPrincipal)
        ));
    }

    #[tokio::test]
    async fn whitespace_owner_from_lookup_rejected() {
        let mw = middleware();
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "whitespace-key".parse().unwrap());
        let err = mw.before_request(&mut ctx).await.unwrap_err();
        assert!(matches!(
            err,
            MiddlewareError::Unauthenticated(AuthFailureKind::EmptyPrincipal)
        ));
    }

    #[tokio::test]
    async fn redacted_lookup_authenticates_and_hides_keys_in_debug() {
        let mut keys = HashMap::new();
        keys.insert("secret-key".to_string(), "redacted-owner".to_string());
        let lookup = RedactedApiKeyLookup::new(keys);

        let rendered = format!("{lookup:?}");
        assert_eq!(rendered, "RedactedApiKeyLookup { len: 1 }");
        assert!(!rendered.contains("secret-key"));
        assert!(!rendered.contains("redacted-owner"));

        let mw = ApiKeyMiddleware::new(Arc::new(lookup), "X-API-Key");
        let mut ctx = RequestContext::new();
        ctx.headers
            .insert("x-api-key", "secret-key".parse().unwrap());
        mw.before_request(&mut ctx).await.unwrap();
        assert!(ctx.identity.is_authenticated());
        assert_eq!(ctx.identity.owner(), "redacted-owner");
    }

    #[test]
    fn security_contribution_returns_api_key_scheme() {
        let mw = middleware();
        let contrib = mw.security_contribution();
        assert_eq!(contrib.schemes.len(), 1);
        assert_eq!(contrib.schemes[0].0, "apiKey");
        assert_eq!(contrib.requirements.len(), 1);
    }
}