use crate::auth::AuthMethod;
use crate::core::models::{ApiKey, user::types::User};
use crate::core::types::context::RequestContext;
use crate::server::middleware::auth_rate_limiter::get_auth_rate_limiter;
use crate::server::middleware::helpers::{
extract_auth_method_with_api_key_header, is_public_route,
};
use crate::server::state::AppState;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready};
use actix_web::http::header::{HeaderMap, HeaderName};
use actix_web::{HttpMessage, HttpRequest, web};
use futures::future::{Ready, ready};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use tracing::{debug, warn};
pub struct AuthMiddleware;
impl<S, B> Transform<S, ServiceRequest> for AuthMiddleware
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type InitError = ();
type Transform = AuthMiddlewareService<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(AuthMiddlewareService {
service: Rc::new(service),
}))
}
}
pub struct AuthMiddlewareService<S> {
service: Rc<S>,
}
impl<S, B> Service<ServiceRequest> for AuthMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
forward_ready!(service);
fn call(&self, mut req: ServiceRequest) -> Self::Future {
let service = Rc::clone(&self.service);
Box::pin(async move {
let is_public = is_public_route(req.path());
let app_state = match req.app_data::<web::Data<AppState>>().cloned() {
Some(state) => state,
None => {
return Err(actix_web::error::ErrorInternalServerError(
"Missing application state",
));
}
};
let cfg = app_state.config.load();
let enable_jwt = cfg.auth().enable_jwt;
let enable_api_key = cfg.auth().enable_api_key;
let api_key_header = cfg.auth().api_key_header.clone();
let context = build_request_context(&mut req);
let auth_method =
extract_auth_method_with_api_key_header(req.headers(), api_key_header.as_str());
let client_id = get_client_identifier(&req, api_key_header.as_str());
let rate_limiter = get_auth_rate_limiter();
if is_public {
req.extensions_mut().insert(context);
return service.call(req).await;
}
let auth_enabled = enable_jwt || enable_api_key;
if !auth_enabled {
req.extensions_mut().insert(context);
return service.call(req).await;
}
if let Err(wait_seconds) = rate_limiter.check_allowed(&client_id) {
return Err(actix_web::error::ErrorTooManyRequests(format!(
"Too many failed attempts. Try again in {} seconds",
wait_seconds
)));
}
let auth_method = match auth_method {
AuthMethod::Jwt(_) if !enable_jwt => {
rate_limiter.record_failure(&client_id);
return Err(actix_web::error::ErrorUnauthorized(
"JWT authentication disabled",
));
}
AuthMethod::ApiKey(_) if !enable_api_key => {
rate_limiter.record_failure(&client_id);
return Err(actix_web::error::ErrorUnauthorized(
"API key authentication disabled",
));
}
other => other,
};
if matches!(auth_method, AuthMethod::None) {
rate_limiter.record_failure(&client_id);
return Err(actix_web::error::ErrorUnauthorized(
"Missing authentication",
));
}
match app_state.auth.authenticate(auth_method, context).await {
Ok(result) if result.success => {
rate_limiter.record_success(&client_id);
debug!("Authentication succeeded");
req.extensions_mut().insert(result.context.clone());
if let Some(user) = result.user {
req.extensions_mut().insert::<User>(user);
}
if let Some(api_key) = result.api_key {
req.extensions_mut().insert::<ApiKey>(api_key);
}
service.call(req).await
}
Ok(result) => {
rate_limiter.record_failure(&client_id);
warn!(
"Authentication failed: {}",
result
.error
.clone()
.unwrap_or_else(|| "unauthorized".to_string())
);
Err(actix_web::error::ErrorUnauthorized(
result.error.unwrap_or_else(|| "Unauthorized".to_string()),
))
}
Err(err) => {
rate_limiter.record_failure(&client_id);
Err(actix_web::error::ErrorInternalServerError(format!(
"Authentication error: {}",
err
)))
}
}
})
}
}
pub fn get_request_context(req: &HttpRequest) -> Result<RequestContext, actix_web::Error> {
req.extensions()
.get::<RequestContext>()
.cloned()
.ok_or_else(|| actix_web::error::ErrorInternalServerError("Missing request context"))
}
fn get_client_identifier(req: &ServiceRequest, api_key_header: &str) -> String {
let ip = req
.connection_info()
.peer_addr()
.map(|s| s.to_string())
.unwrap_or_else(|| "unknown".to_string());
let api_key = get_header_value(req.headers(), api_key_header)
.or_else(|| {
if api_key_header.eq_ignore_ascii_case("x-api-key") {
None
} else {
get_header_value(req.headers(), "x-api-key")
}
})
.or_else(|| get_header_value(req.headers(), "authorization"));
if let Some(api_key) = api_key {
use sha2::{Digest, Sha256};
let hash = Sha256::digest(api_key.as_bytes());
format!("{}:{:x}", ip, hash)
} else {
format!("ip:{}", ip)
}
}
fn get_header_value(headers: &HeaderMap, header_name: &str) -> Option<String> {
let header_name = HeaderName::try_from(header_name.trim()).ok()?;
headers
.get(&header_name)
.and_then(|value| value.to_str().ok())
.map(str::to_string)
}
fn build_request_context(req: &mut ServiceRequest) -> RequestContext {
let mut context = RequestContext::new();
if let Some(id) = req
.headers()
.get("x-request-id")
.and_then(|value| value.to_str().ok())
.filter(|s| !s.is_empty())
{
context.request_id = id.to_string();
}
context.user_agent = req
.headers()
.get("user-agent")
.and_then(|value| value.to_str().ok())
.map(str::to_string);
context.client_ip = req.connection_info().peer_addr().map(|ip| ip.to_string());
let mut headers = HashMap::new();
for (name, value) in req.headers().iter() {
if name.as_str().eq_ignore_ascii_case("authorization")
|| name.as_str().eq_ignore_ascii_case("x-api-key")
{
continue;
}
if let Ok(value) = value.to_str() {
headers.insert(name.as_str().to_string(), value.to_string());
}
}
context.headers = headers;
context
}