use std::{net::IpAddr, sync::Arc};
use axum::{
extract::{FromRef, FromRequestParts},
http::{request::Parts, StatusCode},
};
use ipnet::IpNet;
use subtle::ConstantTimeEq;
use tracing::warn;
use crate::{config::ApiKeyConfig, AppState};
/// Identity of an authenticated caller, produced by the `FromRequestParts` impl for this
/// type once the request has passed the source-IP check and its token matched an
/// enabled API key.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// `id` of the configured API key whose secret matched the presented token.
    pub key_id: String,
    /// Client IP used for the `allowed_source_cidrs` check: the socket peer address, or,
    /// when `trust_proxy_headers` is enabled, the address resolved from forwarding headers.
    pub client_ip: IpAddr,
    /// Copied from the matched key's `rate_limit_per_min`.
    /// NOTE(review): how `None` is interpreted (presumably "use the global default") is
    /// decided by the consumer of this struct — confirm there.
    pub key_rate_limit_per_min: Option<u32>,
    /// Copied from the matched key's `burst`.
    pub key_burst: u32,
}
impl<S> FromRequestParts<S> for AuthContext
where
    Arc<AppState>: axum::extract::FromRef<S>,
    S: Send + Sync,
{
    type Rejection = (StatusCode, axum::Json<serde_json::Value>);

    /// Authenticates the request.
    ///
    /// Steps:
    /// 1. Resolve the client IP.
    /// 2. Enforce `allowed_source_cidrs`.
    /// 3. Match the presented token (`Authorization: Bearer …` or `x-api-key`) against
    ///    the configured API keys.
    ///
    /// # Errors
    ///
    /// * `403` when the client IP is outside a non-empty `allowed_source_cidrs`.
    /// * `401` when no token could be extracted.
    /// * `403` when the token matches a disabled key or no key at all.
    ///
    /// Every rejection is logged and counted through `inc_auth_failure` and
    /// `inc_request("4xx")`.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let app_state = Arc::<AppState>::from_ref(state);
        // All rejection paths are recorded identically; previously the source-IP
        // rejection skipped the metrics, hiding those denials from dashboards.
        let record_failure = |reason: &'static str| {
            app_state.metrics.inc_auth_failure(reason);
            app_state.metrics.inc_request("4xx");
        };

        let cfg = app_state.config();
        let security = &cfg.security;

        let peer_ip = resolve_peer_ip(parts);
        // Forwarding headers are only consulted when explicitly enabled; otherwise the
        // socket peer is the client.
        let client_ip = if security.trust_proxy_headers {
            resolve_client_ip(parts, peer_ip, &security.trusted_source_cidrs)
        } else {
            peer_ip
        };

        // An empty allow-list means "no source restriction".
        if !security.allowed_source_cidrs.is_empty()
            && !ip_in_cidrs(client_ip, &security.allowed_source_cidrs)
        {
            warn!(
                client_ip = %client_ip,
                "auth: client IP not in allowed_source_cidrs"
            );
            record_failure("ip_not_allowed");
            return Err(forbidden());
        }

        let Some(token) = extract_token(parts) else {
            warn!(client_ip = %client_ip, "auth: missing or malformed token");
            record_failure("missing_token");
            return Err(unauthorized());
        };

        match find_matching_key(&security.api_keys, token) {
            MatchResult::Matched(key_id, key_rate_limit_per_min, key_burst) => Ok(AuthContext {
                key_id,
                client_ip,
                key_rate_limit_per_min,
                key_burst,
            }),
            MatchResult::Disabled(key_id) => {
                warn!(client_ip = %client_ip, key_id = %key_id, "auth: key is disabled");
                record_failure("disabled_key");
                Err(forbidden())
            }
            MatchResult::NotFound => {
                warn!(client_ip = %client_ip, "auth: token not matched");
                record_failure("invalid_token");
                Err(forbidden())
            }
        }
    }
}
/// Extracts the API token presented by the client.
///
/// If an `Authorization` header is present (and is valid visible ASCII), it is
/// authoritative:
/// * Only a `Bearer <token>` credential is accepted from it.
/// * The `x-api-key` header is then not consulted.
///
/// Otherwise the `x-api-key` header value is used.
///
/// Surrounding whitespace is trimmed, and an empty token counts as "no token". This means
/// an empty credential is never compared against a configured secret.
fn extract_token(parts: &Parts) -> Option<&str> {
    let headers = &parts.headers;
    if let Some(auth) = headers
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
    {
        return parse_bearer(auth);
    }
    headers
        .get("x-api-key")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .filter(|token| !token.is_empty())
}

/// Returns the credential from a `Bearer <token>` authorization value.
///
/// The scheme name is matched case-insensitively (RFC 9110 §11.1), so `bearer` and
/// `BEARER` are accepted as well. Returns `None` for any other scheme or an empty token.
fn parse_bearer(value: &str) -> Option<&str> {
    let (scheme, credential) = value.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let credential = credential.trim();
    (!credential.is_empty()).then_some(credential)
}
/// Outcome of comparing a presented token against the configured API keys.
#[derive(Debug, PartialEq, Eq)]
enum MatchResult {
    /// The token matched an enabled key. Payload is `(id, rate_limit_per_min, burst)`,
    /// copied from that key's configuration.
    Matched(String, Option<u32>, u32),
    /// The token matched a key whose `enabled` flag is false; carries that key's id.
    Disabled(String),
    /// The token matched no configured key.
    NotFound,
}
/// Looks up which configured API key (if any) owns `token`.
///
/// Secrets are compared with `ct_eq_bytes`. Every key in `keys` is visited, with no early
/// exit on a hit, so the amount of comparison work does not depend on which key matched.
/// If several keys share the same secret, the last one in `keys` wins.
fn find_matching_key(keys: &[ApiKeyConfig], token: &str) -> MatchResult {
    let presented = token.as_bytes();
    let hit = keys.iter().fold(None::<&ApiKeyConfig>, |previous, candidate| {
        if ct_eq_bytes(presented, candidate.secret.expose().as_bytes()) {
            Some(candidate)
        } else {
            previous
        }
    });

    let Some(key) = hit else {
        return MatchResult::NotFound;
    };
    if key.enabled {
        MatchResult::Matched(key.id.clone(), key.rate_limit_per_min, key.burst)
    } else {
        MatchResult::Disabled(key.id.clone())
    }
}
/// Compares two byte strings for equality using `subtle`'s constant-time comparison.
///
/// Only the *contents* are compared in constant time: differing lengths return `false`
/// immediately. This means a secret's length is not protected from timing observation.
fn ct_eq_bytes(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && bool::from(a.ct_eq(b))
}
/// Returns the IP address of the directly connected peer.
///
/// Reads the `ConnectInfo<SocketAddr>` request extension. When that extension is absent,
/// the loopback address `127.0.0.1` is returned instead.
///
/// NOTE(review): with this fallback, a server started without connect-info makes every
/// request appear to come from 127.0.0.1. Such requests would then satisfy any
/// `allowed_source_cidrs` / `trusted_source_cidrs` entry covering loopback. Consider
/// failing closed (e.g. `0.0.0.0`) once callers and tests are confirmed not to rely on
/// this behavior.
fn resolve_peer_ip(parts: &Parts) -> IpAddr {
    match parts
        .extensions
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
    {
        Some(axum::extract::ConnectInfo(peer)) => peer.ip(),
        None => IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
    }
}
/// Resolves the originating client IP for a request that may have passed through proxies.
///
/// Forwarding headers are only honoured when the directly connected peer (`peer_ip`) is
/// itself listed in `trusted_cidrs`; otherwise `peer_ip` is returned unchanged.
///
/// When the peer is trusted, a hop list is built from the RFC 7239 `Forwarded` header.
/// If that yields no address at all, `X-Forwarded-For` is used instead. The list is
/// walked from the nearest hop (rightmost) towards the client. Hops inside
/// `trusted_cidrs` are skipped, and the first untrusted hop is returned.
///
/// The walk runs right-to-left because the leftmost entries are whatever the client
/// chose to send. Taking the first entry would let any client behind an appending proxy
/// spoof its address and bypass source-IP restrictions.
///
/// Falls back to `peer_ip` when no usable address can be found.
fn resolve_client_ip(parts: &Parts, peer_ip: IpAddr, trusted_cidrs: &[String]) -> IpAddr {
    if !ip_in_cidrs(peer_ip, trusted_cidrs) {
        return peer_ip;
    }
    let mut hops = forwarded_hops(&parts.headers, "forwarded", parse_forwarded_element);
    if hops.iter().all(Option::is_none) {
        hops = forwarded_hops(&parts.headers, "x-forwarded-for", parse_node_addr);
    }
    nearest_untrusted_hop(&hops, trusted_cidrs).unwrap_or(peer_ip)
}

/// Collects one entry per comma-separated hop across *every* instance of header `name`,
/// in header order (client side first).
///
/// `parse_hop` maps a hop to its address, or to `None` when the hop is unknown,
/// obfuscated or malformed. A header line that is not valid visible ASCII contributes
/// a single `None` hop, so the walk stops there instead of silently skipping it.
fn forwarded_hops(
    headers: &axum::http::HeaderMap,
    name: &str,
    parse_hop: fn(&str) -> Option<IpAddr>,
) -> Vec<Option<IpAddr>> {
    let mut hops = Vec::new();
    for value in headers.get_all(name).iter() {
        match value.to_str() {
            Ok(text) => hops.extend(text.split(',').map(parse_hop)),
            Err(_) => hops.push(None),
        }
    }
    hops
}

/// Extracts the `for=` node address from one RFC 7239 `Forwarded` element, e.g.
/// `for=192.0.2.60;proto=http;by=203.0.113.43`. The parameter name is matched
/// case-insensitively.
fn parse_forwarded_element(element: &str) -> Option<IpAddr> {
    element.split(';').find_map(|pair| {
        let (name, value) = pair.split_once('=')?;
        if name.trim().eq_ignore_ascii_case("for") {
            parse_node_addr(value)
        } else {
            None
        }
    })
}

/// Parses a forwarded node address, ignoring surrounding whitespace and double quotes.
///
/// Accepted forms include `192.0.2.60`, `"192.0.2.60:4711"`, `"[2001:db8::1]"` and
/// `"[2001:db8::1]:4711"`. Returns `None` for `unknown`, obfuscated identifiers, or
/// anything else.
fn parse_node_addr(raw: &str) -> Option<IpAddr> {
    let node = raw.trim().trim_matches('"');
    if let Ok(ip) = node.parse::<IpAddr>() {
        return Some(ip);
    }
    // `ip:port` and `[ipv6]:port` forms.
    if let Ok(socket) = node.parse::<std::net::SocketAddr>() {
        return Some(socket.ip());
    }
    // Bracketed IPv6 without a port.
    node.strip_prefix('[')?.strip_suffix(']')?.parse().ok()
}

/// Walks `hops` from the nearest hop (rightmost) towards the client, skipping hops inside
/// `trusted_cidrs`, and returns the first untrusted address.
///
/// * If an unparseable hop is reached first, the walk stops there and returns the last
///   trusted hop seen, if any. Nothing further left can be vouched for.
/// * If every hop is trusted, the leftmost one is returned.
fn nearest_untrusted_hop(hops: &[Option<IpAddr>], trusted_cidrs: &[String]) -> Option<IpAddr> {
    let mut last_trusted = None;
    for hop in hops.iter().rev() {
        match *hop {
            Some(ip) if ip_in_cidrs(ip, trusted_cidrs) => last_trusted = Some(ip),
            Some(ip) => return Some(ip),
            None => break,
        }
    }
    last_trusted
}
/// Returns `true` if `ip` falls inside any entry of `cidrs`.
///
/// Entries are trimmed and may be either:
/// * CIDR networks, e.g. `10.0.0.0/8` or `::1/128`;
/// * bare addresses, e.g. `192.0.2.7`, which match only that exact address.
///
/// Previously, bare addresses were silently dropped because they do not parse as
/// networks. Entries that parse as neither form are ignored, so an empty or entirely
/// invalid list matches nothing.
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`, as a dual-stack listener may report
/// for IPv4 clients) is also checked in its plain IPv4 form, so IPv4 entries match it.
fn ip_in_cidrs(ip: IpAddr, cidrs: &[String]) -> bool {
    let unmapped = match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        IpAddr::V4(_) => ip,
    };
    cidrs.iter().map(|entry| entry.trim()).any(|entry| {
        if let Ok(net) = entry.parse::<IpNet>() {
            net.contains(&ip) || net.contains(&unmapped)
        } else if let Ok(addr) = entry.parse::<IpAddr>() {
            addr == ip || addr == unmapped
        } else {
            false
        }
    })
}
/// Builds the JSON rejection shared by every auth failure:
/// `{"status": "error", "code": <code>, "message": <message>}` with the given status.
fn error_response(
    status: StatusCode,
    code: &str,
    message: &str,
) -> (StatusCode, axum::Json<serde_json::Value>) {
    let body = serde_json::json!({
        "status": "error",
        "code": code,
        "message": message
    });
    (status, axum::Json(body))
}

/// `401 Unauthorized` rejection, used when no credentials could be extracted.
fn unauthorized() -> (StatusCode, axum::Json<serde_json::Value>) {
    error_response(StatusCode::UNAUTHORIZED, "unauthorized", "Authentication required")
}

/// `403 Forbidden` rejection, used for disallowed source IPs and for tokens that are
/// unknown or belong to a disabled key.
fn forbidden() -> (StatusCode, axum::Json<serde_json::Value>) {
    error_response(StatusCode::FORBIDDEN, "forbidden", "Access denied")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ApiKeyConfig, SecretString};

    /// Builds an `ApiKeyConfig` with the given id/secret/enabled flag and neutral defaults
    /// for every other field.
    fn make_key(id: &str, secret: &str, enabled: bool) -> ApiKeyConfig {
        ApiKeyConfig {
            id: id.to_string(),
            secret: SecretString::new(secret),
            enabled,
            description: None,
            allowed_recipient_domains: vec![],
            rate_limit_per_min: None,
            allowed_recipients: vec![],
            burst: 0,
            mask_recipient: None,
        }
    }

    #[test]
    fn matching_key_returns_key_id() {
        let keys = vec![make_key("svc-a", "secret-a", true)];
        match find_matching_key(&keys, "secret-a") {
            MatchResult::Matched(id, _, _) => assert_eq!(id, "svc-a"),
            _ => panic!("expected Matched"),
        }
    }

    #[test]
    fn matched_key_carries_rate_limit_and_burst() {
        let mut key = make_key("svc-a", "secret-a", true);
        key.rate_limit_per_min = Some(30);
        key.burst = 5;
        match find_matching_key(&[key], "secret-a") {
            MatchResult::Matched(id, rate, burst) => {
                assert_eq!(id, "svc-a");
                assert_eq!(rate, Some(30));
                assert_eq!(burst, 5);
            }
            _ => panic!("expected Matched"),
        }
    }

    #[test]
    fn wrong_token_returns_not_found() {
        let keys = vec![make_key("svc-a", "secret-a", true)];
        assert!(matches!(
            find_matching_key(&keys, "wrong"),
            MatchResult::NotFound
        ));
    }

    #[test]
    fn empty_key_list_returns_not_found() {
        assert!(matches!(
            find_matching_key(&[], "anything"),
            MatchResult::NotFound
        ));
    }

    #[test]
    fn disabled_key_returns_disabled() {
        let keys = vec![make_key("svc-a", "secret-a", false)];
        assert!(matches!(
            find_matching_key(&keys, "secret-a"),
            MatchResult::Disabled(_)
        ));
    }

    #[test]
    fn disabled_key_is_reported_among_enabled_keys() {
        let keys = vec![
            make_key("svc-a", "token-aaa", true),
            make_key("svc-b", "token-bbb", false),
        ];
        match find_matching_key(&keys, "token-bbb") {
            MatchResult::Disabled(id) => assert_eq!(id, "svc-b"),
            _ => panic!("expected Disabled for svc-b"),
        }
    }

    #[test]
    fn multiple_keys_correct_one_matches() {
        let keys = vec![
            make_key("svc-a", "token-aaa", true),
            make_key("svc-b", "token-bbb", true),
        ];
        match find_matching_key(&keys, "token-bbb") {
            MatchResult::Matched(id, _, _) => assert_eq!(id, "svc-b"),
            _ => panic!("expected Matched for svc-b"),
        }
    }

    #[test]
    fn ip_in_cidrs_loopback() {
        let cidrs = vec!["127.0.0.1/32".to_string()];
        assert!(ip_in_cidrs("127.0.0.1".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("10.0.0.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn ip_in_cidrs_range() {
        let cidrs = vec!["10.0.0.0/8".to_string()];
        assert!(ip_in_cidrs("10.1.2.3".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("192.168.1.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn ip_in_cidrs_ipv6() {
        let cidrs = vec!["::1/128".to_string()];
        assert!(ip_in_cidrs("::1".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("::2".parse().unwrap(), &cidrs));
    }

    #[test]
    fn invalid_cidr_entries_are_skipped() {
        let cidrs = vec!["not-a-cidr".to_string(), "10.0.0.0/8".to_string()];
        assert!(ip_in_cidrs("10.1.2.3".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("10.1.2.3".parse().unwrap(), &["garbage".to_string()]));
    }

    #[test]
    fn empty_cidr_list_returns_false() {
        assert!(!ip_in_cidrs("127.0.0.1".parse().unwrap(), &[]));
    }

    #[test]
    fn different_length_tokens_do_not_match() {
        assert!(!ct_eq_bytes(b"short", b"longer-token"));
    }

    #[test]
    fn equal_tokens_match_and_same_length_mismatch_does_not() {
        assert!(ct_eq_bytes(b"token-aaa", b"token-aaa"));
        assert!(!ct_eq_bytes(b"token-aaa", b"token-aab"));
        assert!(ct_eq_bytes(b"", b""));
    }
}