use parking_lot::Mutex;
use std::sync::Arc;
use std::time::Instant;
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use crate::server::AppState;
/// A token-bucket rate limiter.
///
/// The bucket holds at most `max_requests_per_minute` tokens, so that many requests can be
/// admitted in one burst. It then refills continuously at `max_requests_per_minute / 60`
/// tokens per second.
///
/// Cloning is cheap and every clone shares the same bucket (the state lives behind an
/// `Arc`). A single limiter used as middleware state therefore throttles everything it sees
/// collectively rather than per client.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Current bucket contents, shared by every clone.
    state: Arc<Mutex<RateLimiterState>>,
    /// Bucket capacity: the largest burst that can be admitted at once.
    max_tokens: f64,
    /// Tokens added back per second of elapsed time.
    refill_rate: f64,
}

/// Mutable bucket state guarded by the limiter's mutex.
#[derive(Debug)]
struct RateLimiterState {
    /// Tokens currently available; fractional so that partial refills accumulate.
    tokens: f64,
    /// When `tokens` was last topped up.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a limiter whose bucket starts full with `max_requests_per_minute` tokens.
    ///
    /// A limit of `0` gives a bucket that starts empty and never refills, so every call to
    /// [`RateLimiter::try_acquire`] returns `false`.
    pub fn new(max_requests_per_minute: u32) -> Self {
        // `u32 -> f64` is lossless, so `From` is preferred over an `as` cast.
        let max_tokens = f64::from(max_requests_per_minute);
        Self {
            state: Arc::new(Mutex::new(RateLimiterState {
                tokens: max_tokens,
                last_refill: Instant::now(),
            })),
            max_tokens,
            refill_rate: max_tokens / 60.0,
        }
    }

    /// Tries to take one token from the bucket.
    ///
    /// Returns `true` if the caller may proceed. The bucket is first refilled for the time
    /// elapsed since the previous call, capped at capacity. When `false` is returned, no
    /// token is consumed.
    pub fn try_acquire(&self) -> bool {
        let mut state = self.state.lock();
        let now = Instant::now();
        // State the "never negative" intent explicitly rather than relying on
        // `Instant - Instant` saturating (older toolchains panicked there).
        let elapsed = now
            .saturating_duration_since(state.last_refill)
            .as_secs_f64();
        state.tokens = (state.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        state.last_refill = now;
        if state.tokens >= 1.0 {
            state.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
/// Axum middleware that forwards the request only if the shared [`RateLimiter`] has a
/// token available.
///
/// # Errors
///
/// Returns [`StatusCode::TOO_MANY_REQUESTS`] without calling the inner service when the
/// bucket is empty.
pub async fn rate_limit_layer(
    State(limiter): State<RateLimiter>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let admitted = limiter.try_acquire();
    if !admitted {
        return Err(StatusCode::TOO_MANY_REQUESTS);
    }
    let response = next.run(request).await;
    Ok(response)
}
/// Axum middleware that enforces bearer-token authentication when it is enabled in config.
///
/// The request is forwarded untouched in two cases: when `security.auth_enabled` is
/// `false`, and for public paths (the `/health` probe and the static front-end bundle).
///
/// Otherwise the `Authorization` header must carry a `Bearer` credential that satisfies at
/// least one of:
/// - `state.kernel.security.validate_token` accepts it;
/// - it equals the non-empty `OXIOS_API_KEY` environment variable;
/// - it equals the non-empty API key from the configuration.
///
/// Both the environment variable and the config are read on every request.
///
/// # Errors
///
/// Returns [`StatusCode::UNAUTHORIZED`] when:
/// - the header is missing or cannot be read as a string;
/// - it does not use the `Bearer` scheme;
/// - the credential is empty;
/// - the credential is accepted by none of the sources above.
pub async fn require_auth(
    State(state): State<Arc<AppState>>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // The config guard is a temporary of the condition, so it is released before `.await`.
    if !state.config.read().security.auth_enabled || is_public_path(request.uri().path()) {
        return Ok(next.run(request).await);
    }

    let token = request
        .headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token)
        .ok_or(StatusCode::UNAUTHORIZED)?;

    let env_key = std::env::var("OXIOS_API_KEY").ok();
    let config_key = state.config.read().api_key();

    let kernel_valid = state.kernel.security.validate_token(token);
    let env_valid = key_matches(env_key.as_deref(), token);
    // Blank config keys are now rejected just like a blank env key. Previously an empty
    // configured key could match an empty credential.
    let config_valid = key_matches(config_key.as_deref(), token);

    if !(kernel_valid || env_valid || config_valid) {
        tracing::warn!(path = %request.uri().path(), "Authentication failed");
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(request).await)
}

/// Returns `true` for paths served without authentication: the health probe and the
/// static front-end bundle.
fn is_public_path(path: &str) -> bool {
    const STATIC_PREFIXES: [&str; 3] = ["/assets/", "/dioxus/", "/favicon"];
    matches!(path, "/health" | "/" | "/index.html")
        || STATIC_PREFIXES.iter().any(|prefix| path.starts_with(prefix))
}

/// Extracts the credential from an `Authorization` header value using the `Bearer` scheme.
///
/// The scheme name is matched case-insensitively, because HTTP auth schemes are
/// case-insensitive. Returns `None` when the scheme differs or the credential is empty.
fn bearer_token(header: &str) -> Option<&str> {
    let (scheme, credential) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let credential = credential.trim_start_matches(' ');
    (!credential.is_empty()).then_some(credential)
}

/// Returns `true` when `key` is present, non-empty, and byte-for-byte equal to `token`.
///
/// Content bytes are OR-accumulated rather than compared with an early exit. This is a
/// best-effort guard against timing leaks of how much of a guessed key was correct; the
/// compiler is not obliged to preserve it. The key's length remains observable.
fn key_matches(key: Option<&str>, token: &str) -> bool {
    match key {
        Some(key) if !key.is_empty() && key.len() == token.len() => {
            key.bytes()
                .zip(token.bytes())
                .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                == 0
        }
        _ => false,
    }
}