use std::collections::HashSet;
use std::sync::Arc;
use argon2::{Argon2, PasswordHash, PasswordVerifier};
use axum::{
body::Body,
extract::{Request, State},
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use crate::config::{AuthConfig, AuthMode, Scope};
/// An authenticated caller.
///
/// `auth_middleware` inserts one into the request extensions, and
/// `require_scope` reads it back to make authorization decisions.
#[derive(Debug, Clone)]
pub struct Principal {
/// Caller identifier, e.g. `"local"` from `local_admin` or `"api_key"` in API-key mode.
pub id: String,
/// Scopes granted to this caller; holding `Scope::Admin` satisfies every `has_scope` check.
pub scopes: HashSet<Scope>,
}
impl Principal {
    /// The unrestricted scope set, shared by [`Principal::local_admin`] and by
    /// [`Principal::with_scopes`] when no scopes are configured.
    ///
    /// NOTE(review): assumed to list every `Scope` variant; keep it in sync
    /// when adding scopes (not verifiable from this file).
    fn all_scopes() -> HashSet<Scope> {
        HashSet::from([
            Scope::TownRead,
            Scope::TownWrite,
            Scope::AgentManage,
            Scope::Admin,
        ])
    }

    /// Principal used when authentication is disabled: id `"local"` with the
    /// full scope set, including `Scope::Admin`.
    #[must_use]
    pub fn local_admin() -> Self {
        Self {
            id: "local".to_string(),
            scopes: Self::all_scopes(),
        }
    }

    /// Builds a principal named `id` holding exactly `scopes`.
    ///
    /// NOTE(review): an EMPTY `scopes` slice grants the full scope set,
    /// including `Scope::Admin`. In API-key mode this means leaving
    /// `api_key_scopes` unconfigured yields an admin key. Behavior is kept for
    /// compatibility; confirm this default is intended.
    #[must_use]
    pub fn with_scopes(id: impl Into<String>, scopes: &[Scope]) -> Self {
        let scopes = if scopes.is_empty() {
            Self::all_scopes()
        } else {
            scopes.iter().copied().collect()
        };
        Self {
            id: id.into(),
            scopes,
        }
    }

    /// Returns `true` if the principal holds `scope` directly, or holds
    /// `Scope::Admin`, which implies every scope.
    #[must_use]
    pub fn has_scope(&self, scope: Scope) -> bool {
        self.scopes.contains(&scope) || self.scopes.contains(&Scope::Admin)
    }
}
/// Shared state handed to `auth_middleware` through axum's `State` extractor.
#[derive(Clone)]
pub struct AuthState {
/// Authentication settings (mode, API-key hash, API-key scopes); wrapped in
/// `Arc` so cloning the state only bumps a reference count.
pub config: Arc<AuthConfig>,
}
#[derive(Debug)]
pub struct AuthError {
status: StatusCode,
message: &'static str,
}
impl AuthError {
pub const UNAUTHORIZED: Self = Self {
status: StatusCode::UNAUTHORIZED,
message: "Authentication required",
};
pub const FORBIDDEN: Self = Self {
status: StatusCode::FORBIDDEN,
message: "Insufficient permissions",
};
pub const INVALID_CREDENTIALS: Self = Self {
status: StatusCode::UNAUTHORIZED,
message: "Invalid credentials",
};
}
impl IntoResponse for AuthError {
    /// Renders the error as `{"error": <message>, "status": <numeric code>}`
    /// with the error's HTTP status.
    fn into_response(self) -> Response {
        let Self { status, message } = self;
        let payload = axum::Json(serde_json::json!({
            "error": message,
            "status": status.as_u16()
        }));
        (status, payload).into_response()
    }
}
fn extract_api_key(headers: &HeaderMap) -> Option<String> {
if let Some(auth) = headers.get("authorization")
&& let Ok(auth_str) = auth.to_str()
&& let Some(key) = auth_str.strip_prefix("Bearer ")
{
return Some(key.trim().to_string());
}
if let Some(key) = headers.get("x-api-key")
&& let Ok(key_str) = key.to_str()
{
return Some(key_str.trim().to_string());
}
None
}
/// Returns `true` if `key` matches the Argon2 PHC-format string `hash`.
///
/// A `hash` that fails to parse yields `false` rather than an error, so a
/// malformed configured hash rejects every key.
fn verify_api_key(key: &str, hash: &str) -> bool {
    PasswordHash::new(hash).is_ok_and(|stored| {
        Argon2::default()
            .verify_password(key.as_bytes(), &stored)
            .is_ok()
    })
}
/// Authenticates each request according to the configured `AuthMode` and
/// attaches the resulting [`Principal`] to the request extensions, where
/// [`require_scope`] and downstream handlers can read it.
///
/// - `AuthMode::None`: every request runs as [`Principal::local_admin`].
/// - `AuthMode::ApiKey`: the key found by `extract_api_key` is checked against
///   `api_key_hash`. On success, the principal `"api_key"` receives
///   `api_key_scopes`.
/// - `AuthMode::Oidc`: no OIDC handling exists here; every request is rejected.
///
/// # Errors
/// - [`AuthError::UNAUTHORIZED`] when no key is supplied, when no
///   `api_key_hash` is configured, or in OIDC mode.
/// - [`AuthError::INVALID_CREDENTIALS`] when the supplied key does not verify.
pub async fn auth_middleware(
    State(auth_state): State<AuthState>,
    mut request: Request<Body>,
    next: Next,
) -> Result<Response, AuthError> {
    let cfg = &auth_state.config;
    let principal = match cfg.mode {
        AuthMode::None => Principal::local_admin(),
        AuthMode::Oidc => return Err(AuthError::UNAUTHORIZED),
        AuthMode::ApiKey => {
            // NOTE(review): `verify_api_key` runs Argon2 (deliberately
            // CPU/memory-expensive) synchronously on this async task. Consider
            // `spawn_blocking` or caching verified keys if latency suffers.
            let presented = extract_api_key(request.headers());
            let stored_hash = cfg.api_key_hash.as_ref();
            match (presented, stored_hash) {
                (Some(key), Some(hash)) if verify_api_key(&key, hash) => {
                    Principal::with_scopes("api_key", &cfg.api_key_scopes)
                }
                (Some(_), Some(_)) => return Err(AuthError::INVALID_CREDENTIALS),
                _ => return Err(AuthError::UNAUTHORIZED),
            }
        }
    };
    request.extensions_mut().insert(principal);
    Ok(next.run(request).await)
}
/// Middleware guard that forwards the request only if the attached
/// [`Principal`] holds `scope`. `Scope::Admin` satisfies any scope.
///
/// # Errors
/// - [`AuthError::UNAUTHORIZED`] if no principal is attached, e.g. when
///   `auth_middleware` did not run before this layer.
/// - [`AuthError::FORBIDDEN`] if the principal lacks `scope`.
pub async fn require_scope(
    scope: Scope,
    request: Request<Body>,
    next: Next,
) -> Result<Response, AuthError> {
    // `None` = no principal, `Some(granted)` = principal present.
    let granted = request
        .extensions()
        .get::<Principal>()
        .map(|principal| principal.has_scope(scope));
    match granted {
        None => Err(AuthError::UNAUTHORIZED),
        Some(false) => Err(AuthError::FORBIDDEN),
        Some(true) => Ok(next.run(request).await),
    }
}
pub mod route_scopes {
pub use crate::config::Scope::{
Admin as ADMIN_OPS, AgentManage as AGENT_MGMT, TownRead as READ_OPS, TownWrite as WRITE_OPS,
};
}
/// Generates a fresh random API key together with its Argon2 PHC hash string.
///
/// Returns `(raw_key, hash)`:
/// - `raw_key` is two hyphen-less v4 UUIDs concatenated (64 hex characters).
/// - `hash` is the PHC string that `verify_api_key` checks `raw_key` against,
///   i.e. the value `auth_middleware` reads from `api_key_hash`.
///
/// # Panics
/// Panics if the UUID-derived salt cannot be encoded or Argon2 hashing fails.
#[must_use]
pub fn generate_api_key() -> (String, String) {
    use argon2::{PasswordHasher, password_hash::SaltString};
    let mut raw_key = String::with_capacity(64);
    for _ in 0..2 {
        raw_key.push_str(&uuid::Uuid::new_v4().simple().to_string());
    }
    // 16 random UUID bytes, base64-encoded, serve as the Argon2 salt.
    let salt = SaltString::encode_b64(uuid::Uuid::new_v4().as_bytes()).expect("valid salt");
    let hash = Argon2::default()
        .hash_password(raw_key.as_bytes(), &salt)
        .map(|phc| phc.to_string())
        .expect("failed to hash password");
    (raw_key, hash)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_api_key_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer test-key-123".parse().unwrap());
        assert_eq!(extract_api_key(&headers), Some("test-key-123".to_string()));
    }

    #[test]
    fn test_extract_api_key_x_api_key() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", "test-key-456".parse().unwrap());
        assert_eq!(extract_api_key(&headers), Some("test-key-456".to_string()));
    }

    #[test]
    fn test_extract_api_key_bearer_priority() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer bearer-key".parse().unwrap());
        headers.insert("x-api-key", "x-api-key".parse().unwrap());
        assert_eq!(extract_api_key(&headers), Some("bearer-key".to_string()));
    }

    #[test]
    fn test_extract_api_key_none() {
        let headers = HeaderMap::new();
        assert_eq!(extract_api_key(&headers), None);
    }

    /// A non-Bearer `Authorization` header must not block the `X-Api-Key` fallback.
    #[test]
    fn test_extract_api_key_non_bearer_authorization_falls_back() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Basic dXNlcjpwYXNz".parse().unwrap());
        headers.insert("x-api-key", "fallback-key".parse().unwrap());
        assert_eq!(extract_api_key(&headers), Some("fallback-key".to_string()));
    }

    #[test]
    fn test_generate_and_verify_api_key() {
        let (raw_key, hash) = generate_api_key();
        assert!(verify_api_key(&raw_key, &hash));
        assert!(!verify_api_key("wrong-key", &hash));
    }

    /// A stored hash that is not a valid PHC string must reject every key.
    #[test]
    fn test_verify_api_key_rejects_malformed_hash() {
        assert!(!verify_api_key("anything", "not-a-phc-hash"));
        assert!(!verify_api_key("anything", ""));
    }

    #[test]
    fn test_principal_has_scope() {
        let admin = Principal::local_admin();
        assert!(admin.has_scope(Scope::TownRead));
        assert!(admin.has_scope(Scope::TownWrite));
        assert!(admin.has_scope(Scope::AgentManage));
        assert!(admin.has_scope(Scope::Admin));
    }

    #[test]
    fn test_principal_admin_grants_all() {
        let mut scopes = HashSet::new();
        scopes.insert(Scope::Admin);
        let admin = Principal {
            id: "admin".to_string(),
            scopes,
        };
        assert!(admin.has_scope(Scope::TownRead));
        assert!(admin.has_scope(Scope::TownWrite));
        assert!(admin.has_scope(Scope::AgentManage));
    }

    /// An explicit scope list grants exactly those scopes and nothing more.
    #[test]
    fn test_with_scopes_restricts_to_listed_scopes() {
        let reader = Principal::with_scopes("reader", &[Scope::TownRead]);
        assert_eq!(reader.id, "reader");
        assert!(reader.has_scope(Scope::TownRead));
        assert!(!reader.has_scope(Scope::TownWrite));
        assert!(!reader.has_scope(Scope::AgentManage));
        assert!(!reader.has_scope(Scope::Admin));
    }

    /// Pins the (security-relevant) default: an empty scope list grants everything.
    #[test]
    fn test_with_scopes_empty_grants_everything() {
        let principal = Principal::with_scopes("default", &[]);
        assert!(principal.has_scope(Scope::Admin));
        assert!(principal.has_scope(Scope::TownWrite));
    }
}