use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::extract::{Query, State};
use axum::http::{Request, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Json, Response};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
/// How long an issued session token stays valid: 24 hours.
const TOKEN_TTL: Duration = Duration::from_secs(86_400);
/// Maximum number of login attempts accepted within any one [`LOGIN_WINDOW`].
const MAX_LOGIN_ATTEMPTS: u32 = 5;
/// Sliding window (one minute) over which login attempts are counted for rate limiting.
const LOGIN_WINDOW: Duration = Duration::from_secs(60);
/// Shared authentication state.
///
/// Holds the single configured credential pair, the set of issued session tokens, and a log of
/// recent login attempts used for rate limiting.
pub struct AuthState {
    /// Username that `login` accepts.
    pub username: String,
    /// Password that `login` accepts, held in memory as plain text.
    pub password: String,
    /// Issued session tokens, each mapped to the instant it was issued.
    pub tokens: RwLock<HashMap<String, Instant>>,
    /// Instants of recent login attempts, as recorded by `check_rate_limit`.
    login_attempts: RwLock<Vec<Instant>>,
}

impl AuthState {
    /// Creates state for the given credentials with no tokens issued and no attempts recorded.
    pub fn new(username: String, password: String) -> Self {
        let tokens = RwLock::new(HashMap::new());
        let login_attempts = RwLock::new(Vec::new());
        Self {
            username,
            password,
            tokens,
            login_attempts,
        }
    }

    /// Generates a new session token, records it as issued now, and returns it.
    ///
    /// The token is 32 bytes from `rand::random`, encoded as unpadded URL-safe base64. While the
    /// write lock is held, entries older than [`TOKEN_TTL`] are purged. Within this impl, this is
    /// the only place expired tokens are removed.
    pub async fn issue_token(&self) -> String {
        let random_bytes: [u8; 32] = rand::random();
        let encoded = base64::Engine::encode(
            &base64::engine::general_purpose::URL_SAFE_NO_PAD,
            random_bytes,
        );
        let issued_at = Instant::now();

        let mut live = self.tokens.write().await;
        live.retain(|_, &mut then| issued_at.duration_since(then) < TOKEN_TTL);
        live.insert(encoded.clone(), issued_at);
        encoded
    }

    /// Returns `true` if `token` is present in `tokens` with an issue time less than
    /// [`TOKEN_TTL`] ago.
    pub async fn validate_token(&self, token: &str) -> bool {
        let live = self.tokens.read().await;
        matches!(live.get(token), Some(issued_at) if issued_at.elapsed() < TOKEN_TTL)
    }

    /// Records a login attempt and reports whether it may proceed.
    ///
    /// Attempts older than [`LOGIN_WINDOW`] are discarded first. If [`MAX_LOGIN_ATTEMPTS`] or
    /// more remain, returns `false` without recording anything. Otherwise records this attempt
    /// and returns `true`.
    ///
    /// NOTE(review): the attempt log is a single shared list with no per-client key. Repeated
    /// attempts from anyone therefore block every caller, including the legitimate user, for up
    /// to `LOGIN_WINDOW`. Confirm this is acceptable.
    async fn check_rate_limit(&self) -> bool {
        let now = Instant::now();
        let mut recent = self.login_attempts.write().await;
        recent.retain(|&at| now.duration_since(at) < LOGIN_WINDOW);

        let allowed = recent.len() < MAX_LOGIN_ATTEMPTS as usize;
        if allowed {
            recent.push(now);
        }
        allowed
    }
}
/// JSON request body accepted by [`login`]: the credentials to check against [`AuthState`].
#[derive(Deserialize)]
pub struct LoginRequest {
/// Submitted username, compared with `AuthState::username`.
username: String,
/// Submitted password, compared with `AuthState::password`.
password: String,
}
/// JSON body returned by [`login`], and by [`require_auth`] when it rejects a request.
///
/// Every response constructed in `login` and `require_auth` sets exactly one of the two fields.
#[derive(Serialize)]
pub struct LoginResponse {
/// Newly issued session token on a successful login; `None` otherwise.
token: Option<String>,
/// Human-readable reason the request failed; `None` on success.
error: Option<String>,
}
/// Login handler: checks the submitted credentials and issues a session token on success.
///
/// Every call first goes through `AuthState::check_rate_limit`, which counts successful and
/// failed attempts alike. Once the limit is reached, the handler answers
/// `429 Too Many Requests` without inspecting the credentials. Correct credentials yield
/// `200 OK` with a fresh token; anything else yields `401 Unauthorized`.
pub async fn login(
    State(auth): State<Arc<AuthState>>,
    Json(req): Json<LoginRequest>,
) -> (StatusCode, Json<LoginResponse>) {
    if !auth.check_rate_limit().await {
        return (
            StatusCode::TOO_MANY_REQUESTS,
            Json(LoginResponse {
                token: None,
                error: Some("Too many login attempts. Try again later.".into()),
            }),
        );
    }

    // Run both comparisons unconditionally and combine them with the non-short-circuiting `&`.
    // With `&&`, a wrong username skipped the password comparison entirely, so the response
    // time revealed whether the submitted username was correct.
    let username_ok = constant_time_eq(req.username.as_bytes(), auth.username.as_bytes());
    let password_ok = constant_time_eq(req.password.as_bytes(), auth.password.as_bytes());

    if username_ok & password_ok {
        let token = auth.issue_token().await;
        (
            StatusCode::OK,
            Json(LoginResponse {
                token: Some(token),
                error: None,
            }),
        )
    } else {
        // The same message is returned whichever of the two fields was wrong.
        (
            StatusCode::UNAUTHORIZED,
            Json(LoginResponse {
                token: None,
                error: Some("Invalid password".into()),
            }),
        )
    }
}
/// Compares two byte strings for equality without short-circuiting on the first difference.
///
/// `a` is the caller-supplied value and `b` the secret (stored credential). The loop runs over
/// `a` only and wraps around `b`, so the running time is intended to depend on `a.len()`
/// rather than on the contents or length of `b`. Returns `true` iff `a == b`.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Fold a length mismatch into the accumulator instead of returning early. An early return
    // would let anyone timing requests of varying sizes learn the secret's length.
    let mut diff = a.len() ^ b.len();
    if b.is_empty() {
        // Nothing to compare against; equal only when `a` is empty too.
        return diff == 0;
    }
    for (i, &x) in a.iter().enumerate() {
        // Wrap around `b` so every byte of `a` is compared with something, whether `a` is
        // shorter or longer than `b`. When lengths match, this is a plain byte-by-byte compare.
        diff |= usize::from(x ^ b[i % b.len()]);
    }
    diff == 0
}
// NOTE(review): not constructed in this section. Presumably served by the `/api/auth/status`
// route that `require_auth` exempts from authentication; confirm against that handler.
/// Response payload reporting whether clients must authenticate.
#[derive(Clone, Serialize)]
pub struct AuthStatusResponse {
/// Whether authentication is required. Meaning assumed from the name; confirm with the handler.
pub auth_required: bool,
}
/// Query-string parameters extracted by [`require_auth`].
#[derive(Deserialize)]
pub struct TokenQuery {
/// Session token passed as `?token=...`.
/// Used only when no usable `Authorization: Bearer` header is present.
pub token: Option<String>,
}
/// Axum middleware that rejects requests lacking a valid session token.
///
/// `/api/health`, `/api/auth/status` and `/api/auth/login` pass through unauthenticated. For
/// every other path, the token comes from an `Authorization: Bearer <token>` header. If that
/// header is absent or unusable, the `token` query parameter is used instead. The request is
/// forwarded only if `AuthState::validate_token` accepts the token. Otherwise the response is
/// `401 Unauthorized` with a JSON [`LoginResponse`] error body.
pub async fn require_auth(
    State(auth): State<Arc<AuthState>>,
    Query(query): Query<TokenQuery>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let path = request.uri().path();
    if path == "/api/health" || path == "/api/auth/status" || path == "/api/auth/login" {
        return next.run(request).await;
    }

    // NOTE(review): a token in the query string can end up in access logs and browser history.
    // Presumably this exists for clients that cannot set headers; confirm it is still needed.
    let token = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token)
        .map(str::to_owned)
        .or(query.token);

    match token {
        Some(t) if auth.validate_token(&t).await => next.run(request).await,
        _ => (
            StatusCode::UNAUTHORIZED,
            Json(LoginResponse {
                token: None,
                error: Some("Invalid or missing token".into()),
            }),
        )
            .into_response(),
    }
}

/// Extracts the credential from an `Authorization` header value that uses the `Bearer` scheme.
///
/// HTTP auth-scheme names are case-insensitive, so `bearer abc` and `Bearer abc` both yield
/// `Some("abc")`. Returns `None` for any other scheme or a value shorter than the prefix.
fn bearer_token(value: &str) -> Option<&str> {
    const PREFIX: &str = "Bearer ";
    // `get` fails cleanly if the value is too short or the cut is not on a char boundary.
    let scheme = value.get(..PREFIX.len())?;
    scheme
        .eq_ignore_ascii_case(PREFIX)
        .then(|| &value[PREFIX.len()..])
}