use std::sync::Arc;
use axum::extract::Request;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::Response;
use crate::security::{self, RateLimiter, constant_time_eq, is_allowed_origin, is_localhost_host};
/// Byte length of the `Bearer ` scheme prefix (scheme name plus one space)
/// in an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = b"Bearer ".len();
/// Shared configuration for the [`require_auth`] middleware.
///
/// `Debug` is intentionally not derived so the secret token cannot end up in
/// log output through `{:?}` formatting.
#[derive(Clone)]
pub struct AuthState {
    /// Expected bearer token. `None` disables authentication entirely.
    pub token: Option<String>,
}
pub async fn require_auth(
axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let Some(expected) = &auth.token else {
return Ok(next.run(request).await);
};
let provided = request
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| {
let lower = v.to_lowercase();
if lower.starts_with("bearer ") {
Some(v[BEARER_PREFIX_LEN..].to_string())
} else {
None
}
});
match provided {
Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
Ok(next.run(request).await)
}
_ => {
tracing::warn!("Victauri: rejected request — invalid or missing auth token");
Err(StatusCode::UNAUTHORIZED)
}
}
}
/// Builds a shared [`RateLimiter`] configured with
/// [`security::DEFAULT_RATE_LIMIT`], ready to be passed as state to
/// [`rate_limit`].
#[must_use]
pub fn default_rate_limiter() -> Arc<RateLimiter> {
    let limiter = RateLimiter::new(security::DEFAULT_RATE_LIMIT);
    Arc::new(limiter)
}
/// Middleware that consults the shared [`RateLimiter`] before forwarding a request.
///
/// # Errors
///
/// When `limiter.try_acquire()` returns `false`, the request is rejected with
/// `429 Too Many Requests` and a `Retry-After: 1` header.
pub async fn rate_limit(
    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiter>>,
    request: Request,
    next: Next,
) -> Result<
    Response,
    (
        StatusCode,
        [(axum::http::HeaderName, axum::http::HeaderValue); 1],
    ),
> {
    if !limiter.try_acquire() {
        let retry_after = (
            axum::http::header::RETRY_AFTER,
            axum::http::HeaderValue::from_static("1"),
        );
        return Err((StatusCode::TOO_MANY_REQUESTS, [retry_after]));
    }
    Ok(next.run(request).await)
}
pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
let host = request
.headers()
.get("host")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !is_localhost_host(host) {
tracing::warn!("DNS rebinding attempt blocked: Host={host}");
return Err(StatusCode::FORBIDDEN);
}
Ok(next.run(request).await)
}
pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
if let Some(origin) = request
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
&& !is_allowed_origin(origin)
{
tracing::warn!("Cross-origin request blocked: Origin={origin}");
return Err(StatusCode::FORBIDDEN);
}
Ok(next.run(request).await)
}
/// Middleware that adds defensive headers to every response.
///
/// After the inner service runs, this sets (overwriting any existing value):
/// - `X-Content-Type-Options: nosniff`
/// - `Cache-Control: no-store`
/// - `X-Frame-Options: DENY`
/// - `Content-Security-Policy: default-src 'none'`
///
/// It also removes any `Access-Control-Allow-Origin` header set further in, so
/// browsers never grant cross-origin reads of these responses. Sending
/// `Access-Control-Allow-Origin: null` would be unsafe for this purpose.
/// Browsers treat it as a grant to every opaque-origin document (sandboxed
/// iframes, `data:` and `file:` pages), and any attacker can create one.
pub async fn security_headers(request: Request, next: Next) -> Response {
    use axum::http::HeaderValue;
    use axum::http::header;

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    headers.insert(
        header::X_CONTENT_TYPE_OPTIONS,
        HeaderValue::from_static("nosniff"),
    );
    headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
    headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
    headers.insert(
        header::CONTENT_SECURITY_POLICY,
        HeaderValue::from_static("default-src 'none'"),
    );
    // Absence of ACAO is what denies CORS reads; remove rather than set "null".
    headers.remove(header::ACCESS_CONTROL_ALLOW_ORIGIN);
    response
}