use axum::extract::Request;
use axum::http::{HeaderValue, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::Engine;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use agentix::server::AuthedUser;
use crate::tokens::TokenRegistry;
/// Builds a middleware closure that authenticates proxy requests by API key.
///
/// The key is read with [`extract_proxy_token`] and resolved through
/// `TokenRegistry::lookup`. A missing or unrecognised key short-circuits with
/// the 401 JSON error produced by [`unauthorized_json`]. On success an
/// [`AuthedUser`] holding the raw token and the registry entry's `user` is
/// inserted into the request extensions before the request reaches `next`.
pub fn token_auth_layer(
    registry: TokenRegistry,
) -> impl Clone
+ Send
+ Sync
+ 'static
+ Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
    move |mut req: Request, next: Next| {
        // The returned future must be 'static, so it owns its own handle.
        let registry = registry.clone();
        Box::pin(async move {
            const REJECTION: &str = "missing or unknown API key";

            let Some(token) = extract_proxy_token(req.headers()) else {
                return unauthorized_json(REJECTION);
            };
            let Some(entry) = registry.lookup(token.as_str()) else {
                return unauthorized_json(REJECTION);
            };

            let user = entry.user.clone();
            req.extensions_mut().insert(AuthedUser { token, user });
            next.run(req).await
        })
    }
}
/// Extracts the caller's proxy API key from the request headers.
///
/// `Authorization` is checked first. A leading `Bearer` scheme is stripped,
/// matched ASCII-case-insensitively, along with any whitespace after it. Any
/// other value is used as-is after trimming. If `Authorization` is absent, is
/// not valid visible ASCII, or yields an empty token (e.g. a bare `Bearer`),
/// the `x-api-key` header is tried instead.
///
/// Returns `None` when neither header carries a non-empty token.
fn extract_proxy_token(headers: &axum::http::HeaderMap) -> Option<String> {
    let from_authorization = headers
        .get(header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .map(strip_bearer_scheme)
        .filter(|t| !t.is_empty());
    if let Some(token) = from_authorization {
        return Some(token.to_owned());
    }
    let xapi = headers.get("x-api-key").and_then(|v| v.to_str().ok())?;
    let t = xapi.trim();
    if t.is_empty() { None } else { Some(t.to_string()) }
}

/// Trims `value` and removes a leading `Bearer` auth scheme if present.
///
/// The scheme name is compared ASCII-case-insensitively (RFC 7235 §2.1) and
/// any whitespace between it and the token is dropped. A bare `Bearer` with
/// nothing after it yields `""`, so the caller can fall back to other headers.
/// A value that merely starts with the letters `bearer` without following
/// whitespace (e.g. `bearerXYZ`) is returned unchanged.
fn strip_bearer_scheme(value: &str) -> &str {
    const SCHEME: &str = "Bearer";
    let value = value.trim();
    // `get` (not indexing) so short or non-char-boundary inputs cannot panic.
    let Some(head) = value.get(..SCHEME.len()) else {
        return value;
    };
    if !head.eq_ignore_ascii_case(SCHEME) {
        return value;
    }
    let rest = &value[SCHEME.len()..];
    if rest.is_empty() {
        ""
    } else if rest.starts_with(char::is_whitespace) {
        rest.trim_start()
    } else {
        value
    }
}
fn unauthorized_json(message: &str) -> Response {
let body = json!({
"error": {
"message": message,
"type": "authentication_error",
"param": serde_json::Value::Null,
"code": serde_json::Value::Null,
}
});
(StatusCode::UNAUTHORIZED, axum::Json(body)).into_response()
}
pub fn admin_basic_auth_layer(
admin_password: String,
) -> impl Clone
+ Send
+ Sync
+ 'static
+ Fn(
Request,
Next,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
move |req: Request, next: Next| {
let admin_password = admin_password.clone();
Box::pin(async move {
let provided = req
.headers()
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().strip_prefix("Basic "))
.and_then(|b64| base64::engine::general_purpose::STANDARD.decode(b64).ok())
.and_then(|bytes| String::from_utf8(bytes).ok());
let ok = match provided {
Some(creds) => {
let (_user, pwd) = creds.split_once(':').unwrap_or(("", &creds));
pwd == admin_password
}
None => false,
};
if !ok {
let mut resp = (
StatusCode::UNAUTHORIZED,
"agentix admin: authentication required",
)
.into_response();
resp.headers_mut().insert(
header::WWW_AUTHENTICATE,
HeaderValue::from_static("Basic realm=\"agentix admin\""),
);
return resp;
}
next.run(req).await
})
}
}
/// Builds a middleware closure that enforces per-key monthly token budgets.
///
/// - Requests without an [`AuthedUser`] in their extensions are forwarded
///   unchecked (e.g. when no authenticating layer ran before this one).
/// - The budget is the registry entry's `monthly_token_budget`; keys without
///   one are forwarded unchecked.
/// - Usage is totalled by `crate::aggregate::user_month_token_total` over
///   `usage_log` for the authenticated user. If that call fails, the request
///   is allowed through (fail-open).
/// - Once usage reaches the budget the request is rejected with
///   `429 Too Many Requests` and a `monthly_budget_exhausted` JSON error.
///
/// NOTE(review): the usage total is computed synchronously on every budgeted
/// request. If it reads `usage_log` from disk, it blocks the async executor
/// thread; consider `spawn_blocking` or caching.
pub fn quota_layer(
    registry: TokenRegistry,
    usage_log: PathBuf,
) -> impl Clone
+ Send
+ Sync
+ 'static
+ Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
    let shared_log = Arc::new(usage_log);
    move |req: Request, next: Next| {
        let registry = registry.clone();
        let usage_log = Arc::clone(&shared_log);
        Box::pin(async move {
            // `Some((used, budget))` only when the caller has a budget and has
            // already consumed all of it; every other outcome forwards the request.
            let exhausted = match req.extensions().get::<AuthedUser>().cloned() {
                None => None,
                Some(authed) => registry
                    .lookup(&authed.token)
                    .and_then(|e| e.monthly_token_budget)
                    .and_then(|budget| {
                        // A failure to total usage is deliberately ignored (fail-open).
                        crate::aggregate::user_month_token_total(usage_log.as_ref(), &authed.user)
                            .ok()
                            .filter(|used| *used >= budget)
                            .map(|used| (used, budget))
                    }),
            };

            let Some((used, budget)) = exhausted else {
                return next.run(req).await;
            };

            let body = json!({
                "error": {
                    "message": format!(
                        "monthly token budget exhausted: used {used} of {budget}",
                    ),
                    "type": "rate_limit_error",
                    "param": null,
                    "code": "monthly_budget_exhausted",
                }
            });
            (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response()
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue};

    /// Builds a header map from `(lowercase name, value)` pairs.
    fn headers(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    #[test]
    fn extracts_bearer_token() {
        let h = headers(&[("authorization", "Bearer sk-x")]);
        assert_eq!(extract_proxy_token(&h).as_deref(), Some("sk-x"));
    }

    #[test]
    fn extracts_lowercase_bearer_token() {
        let h = headers(&[("authorization", "bearer sk-x")]);
        assert_eq!(extract_proxy_token(&h).as_deref(), Some("sk-x"));
    }

    #[test]
    fn trims_whitespace_around_authorization() {
        let h = headers(&[("authorization", "  Bearer sk-x  ")]);
        assert_eq!(extract_proxy_token(&h).as_deref(), Some("sk-x"));
    }

    #[test]
    fn extracts_x_api_key() {
        let h = headers(&[("x-api-key", "sk-x")]);
        assert_eq!(extract_proxy_token(&h).as_deref(), Some("sk-x"));
    }

    #[test]
    fn authorization_takes_precedence_over_x_api_key() {
        let h = headers(&[
            ("authorization", "Bearer sk-auth"),
            ("x-api-key", "sk-key"),
        ]);
        assert_eq!(extract_proxy_token(&h).as_deref(), Some("sk-auth"));
    }

    #[test]
    fn blank_x_api_key_returns_none() {
        let h = headers(&[("x-api-key", "   ")]);
        assert!(extract_proxy_token(&h).is_none());
    }

    #[test]
    fn empty_returns_none() {
        let h = HeaderMap::new();
        assert!(extract_proxy_token(&h).is_none());
    }
}