use axum::{
body::Body,
extract::{FromRequestParts, Request, State},
http::{request::Parts, HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use hyperinfer_core::User;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use utoipa::ToSchema;
/// JWT payload minted by [`create_auth_token`] and decoded by [`validate_auth_token`].
///
/// [`auth_middleware`] inserts the validated claims into the request extensions,
/// where extractors such as [`RequireAdmin`] read them back.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AuthClaims {
    /// Subject: the user's id (copied from `User::id`).
    pub sub: String,
    /// The user's email at the time the token was issued.
    pub email: String,
    /// The user's role; [`RequireAdmin`] requires this to equal `"admin"`.
    pub role: String,
    /// The user's team id (copied from `User::team_id`).
    pub team_id: String,
    /// Expiry time in seconds since the Unix epoch (now + lifetime at issue time).
    pub exp: u64,
}
/// Credentials submitted to log in (presumably the login handler's JSON body — confirm).
///
/// `Debug` is implemented by hand so the plaintext password is never written to
/// logs or error output when the request is formatted with `{:?}`.
#[derive(Deserialize, ToSchema)]
pub struct LoginRequest {
    /// Email address identifying the account.
    pub email: String,
    /// Plaintext password as submitted by the client.
    pub password: String,
}

impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("email", &self.email)
            // Redact the secret; the field name stays visible so the shape is still clear.
            .field("password", &"[REDACTED]")
            .finish()
    }
}
/// User identity serialized after a successful login (presumably returned by the
/// login handler — confirm). Contains no secrets; the fields mirror the identity
/// data carried in [`AuthClaims`].
#[derive(Debug, Serialize, ToSchema)]
pub struct LoginResponse {
    /// The user's id.
    pub id: String,
    /// The user's email address.
    pub email: String,
    /// The user's role (e.g. `"admin"`, which [`RequireAdmin`] checks for).
    pub role: String,
    /// The user's team id.
    pub team_id: String,
}
/// Identity of the currently authenticated user (presumably the body of a
/// "who am I" endpoint built from [`AuthClaims`] — confirm against the handler).
#[derive(Debug, Serialize, ToSchema)]
pub struct MeResponse {
    /// The user's id.
    pub id: String,
    /// The user's email address.
    pub email: String,
    /// The user's role.
    pub role: String,
    /// The user's team id.
    pub team_id: String,
}
/// Returns whether auth cookies should carry the `Secure` attribute.
///
/// Reads `AUTH_COOKIE_SECURE` on every call.
/// - It defaults to `true` when the variable is unset or not valid Unicode.
/// - It is disabled only by `false` or `0`.
/// - Matching ignores surrounding whitespace and ASCII case, so `FALSE`, `False`
///   and ` false ` behave like `false` instead of silently keeping `Secure` on.
fn cookie_secure() -> bool {
    match std::env::var("AUTH_COOKIE_SECURE") {
        Ok(raw) => {
            let value = raw.trim();
            !(value.eq_ignore_ascii_case("false") || value == "0")
        }
        // Fail safe: anything we cannot read keeps the Secure attribute.
        Err(_) => true,
    }
}
/// Builds the `Set-Cookie` value that stores `token` in the `auth_token` cookie.
///
/// The cookie is `HttpOnly`, `SameSite=Strict`, scoped to `/`, and lives for 24 hours
/// (`Max-Age=86400`). `Secure` is added when [`cookie_secure`] says so.
/// `token` is inserted verbatim, so it must not contain `;`, `,` or whitespace.
///
/// NOTE(review): the 24h `Max-Age` is independent of the `expires_in_secs` used when
/// the token was minted; keep the two in sync.
pub fn auth_cookie(token: &str) -> String {
    let mut cookie = format!("auth_token={token};");
    if cookie_secure() {
        cookie.push_str(" Secure;");
    }
    cookie.push_str(" HttpOnly; SameSite=Strict; Path=/; Max-Age=86400");
    cookie
}
/// Builds the `Set-Cookie` value that deletes the `auth_token` cookie.
///
/// It uses an empty value with `Max-Age=0`. The `Path`, `SameSite`, `HttpOnly` and
/// (conditional) `Secure` attributes are the same ones `auth_cookie` sets, so the
/// browser targets the same cookie.
pub fn clear_auth_cookie() -> String {
    let attributes = "HttpOnly; SameSite=Strict; Path=/; Max-Age=0";
    if cookie_secure() {
        format!("auth_token=; Secure; {attributes}")
    } else {
        format!("auth_token=; {attributes}")
    }
}
/// Issues an HS256-signed JWT for `user`, valid for `expires_in_secs` seconds.
///
/// The claims are the user's id (as `sub`), email, role and team id, plus `exp`.
/// `exp` is the current Unix time plus `expires_in_secs`, saturating at `u64::MAX`
/// rather than overflowing. If the system clock reads before the Unix epoch, "now"
/// is treated as 0.
///
/// # Errors
///
/// Returns the underlying [`jsonwebtoken::errors::Error`] if encoding fails.
pub fn create_auth_token(
    user: &User,
    jwt_secret: &str,
    expires_in_secs: u64,
) -> Result<String, jsonwebtoken::errors::Error> {
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use std::time::{SystemTime, UNIX_EPOCH};

    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // A plain `+` panics in debug builds and wraps to a past (already-expired)
    // timestamp in release builds when the lifetime is huge; saturate instead.
    let exp = now.saturating_add(expires_in_secs);

    let claims = AuthClaims {
        sub: user.id.clone(),
        email: user.email.clone(),
        role: user.role.clone(),
        team_id: user.team_id.clone(),
        exp,
    };
    encode(
        &Header::new(Algorithm::HS256),
        &claims,
        &EncodingKey::from_secret(jwt_secret.as_bytes()),
    )
}
/// Verifies `token` against the HS256 secret `jwt_secret` and returns its claims.
///
/// Only `Algorithm::HS256` is accepted. The remaining claim checks are those that
/// `jsonwebtoken`'s `Validation::new` enables by default (which include `exp`, per the
/// crate docs).
///
/// # Errors
///
/// Returns the `jsonwebtoken` error for:
/// - a malformed token,
/// - a signature that does not match `jwt_secret`,
/// - a failed claim check.
pub fn validate_auth_token(
    token: &str,
    jwt_secret: &str,
) -> Result<AuthClaims, jsonwebtoken::errors::Error> {
    use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};

    decode::<AuthClaims>(
        token,
        &DecodingKey::from_secret(jwt_secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    )
    .map(|token_data| token_data.claims)
}
/// Axum middleware that authenticates a request with a JWT.
///
/// Token lookup order:
/// 1. `Authorization: Bearer <token>`.
/// 2. The `auth_token` cookie, consulted only when no bearer token was found.
///
/// A bearer token that is present but invalid does not fall back to the cookie.
/// The HS256 secret comes from the middleware's `State<Arc<String>>`.
///
/// Responses:
/// - `401 "Missing or invalid Authorization header"` when no token is found.
/// - `401 "Invalid or expired token"` when validation fails; the error is logged at
///   debug level.
/// - On success, the [`AuthClaims`] are inserted into the request extensions and the
///   request is forwarded to `next`.
pub async fn auth_middleware(
    State(jwt_secret): State<Arc<String>>,
    mut req: Request<Body>,
    next: Next,
) -> Response {
    let headers = req.headers();
    let Some(token) = extract_bearer_token(headers).or_else(|| extract_cookie_token(headers))
    else {
        return (
            StatusCode::UNAUTHORIZED,
            "Missing or invalid Authorization header",
        )
            .into_response();
    };

    match validate_auth_token(&token, jwt_secret.as_str()) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            next.run(req).await
        }
        Err(err) => {
            tracing::debug!("JWT validation failed: {:?}", err);
            (StatusCode::UNAUTHORIZED, "Invalid or expired token").into_response()
        }
    }
}
/// Extracts the token from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively, and whitespace around the token
/// is trimmed. Returns `None` when:
/// - the header is absent,
/// - the header is not valid text (`HeaderValue::to_str` fails),
/// - the header uses another scheme,
/// - the header carries an empty token.
///
/// The empty-token case matters because `auth_middleware` only falls back to the
/// `auth_token` cookie when this returns `None`. A bare `Bearer ` must not block
/// that fallback.
fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
    let value = headers
        .get(axum::http::header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, rest) = value.split_once(char::is_whitespace)?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = rest.trim();
    if token.is_empty() {
        None
    } else {
        Some(token.to_owned())
    }
}
/// Returns the value of the first non-empty `auth_token` cookie, if any.
///
/// Every `Cookie` header field is scanned, not just the first. HTTP/2 clients may
/// split cookies across several fields (RFC 9113 §8.2.3), and a field that is not
/// valid text is skipped rather than aborting the search. Within a field, pairs are
/// separated by `;`, and the cookie name is matched exactly (case-sensitive).
fn extract_cookie_token(headers: &HeaderMap) -> Option<String> {
    headers
        .get_all(axum::http::header::COOKIE)
        .iter()
        .filter_map(|field| field.to_str().ok())
        .flat_map(|field| field.split(';'))
        .filter_map(|pair| pair.trim().strip_prefix("auth_token="))
        .find(|value| !value.is_empty())
        .map(str::to_owned)
}
/// Extractor that admits only requests whose [`AuthClaims`] carry the `admin` role.
///
/// It expects the claims to already be in the request extensions, as inserted by
/// [`auth_middleware`]. Rejections:
/// - `401 "Not authenticated"` when no claims are present.
/// - `403 "Admin access required"` when the role is anything other than `"admin"`.
pub struct RequireAdmin(pub AuthClaims);

impl<S: Send + Sync> FromRequestParts<S> for RequireAdmin {
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(claims) = parts.extensions.get::<AuthClaims>() else {
            return Err((StatusCode::UNAUTHORIZED, "Not authenticated").into_response());
        };
        if claims.role == "admin" {
            Ok(RequireAdmin(claims.clone()))
        } else {
            Err((StatusCode::FORBIDDEN, "Admin access required").into_response())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{header, HeaderName, HeaderValue};
    use chrono::Utc;

    /// Fixed admin user shared by the token round-trip tests.
    fn sample_user() -> User {
        User {
            id: "user-123".to_string(),
            team_id: "team-456".to_string(),
            email: "test@example.com".to_string(),
            role: "admin".to_string(),
            password_hash: None,
            created_at: Utc::now(),
        }
    }

    /// Builds a header map containing exactly one header.
    fn headers_with(name: HeaderName, value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn test_create_and_validate_token() {
        let user = sample_user();
        let secret = "test-secret";
        let token = create_auth_token(&user, secret, 3600).unwrap();
        let claims = validate_auth_token(&token, secret).unwrap();
        assert_eq!(claims.sub, user.id);
        assert_eq!(claims.email, user.email);
        assert_eq!(claims.role, user.role);
        assert_eq!(claims.team_id, user.team_id);
    }

    #[test]
    fn test_validate_token_wrong_secret() {
        let token = create_auth_token(&sample_user(), "correct-secret", 3600).unwrap();
        assert!(validate_auth_token(&token, "wrong-secret").is_err());
    }

    #[test]
    fn test_validate_token_rejects_malformed_input() {
        assert!(validate_auth_token("not-a-jwt", "test-secret").is_err());
        assert!(validate_auth_token("", "test-secret").is_err());
    }

    #[test]
    fn test_extract_bearer_token() {
        let headers = headers_with(header::AUTHORIZATION, "Bearer test-token");
        assert_eq!(
            extract_bearer_token(&headers),
            Some("test-token".to_string())
        );
        // The scheme name is case-insensitive.
        let lowercase = headers_with(header::AUTHORIZATION, "bearer test-token");
        assert_eq!(
            extract_bearer_token(&lowercase),
            Some("test-token".to_string())
        );
        assert_eq!(extract_bearer_token(&HeaderMap::new()), None);
        let basic = headers_with(header::AUTHORIZATION, "Basic test-token");
        assert_eq!(extract_bearer_token(&basic), None);
    }

    #[test]
    fn test_extract_cookie_token() {
        let headers = headers_with(header::COOKIE, "theme=dark; auth_token=abc.def.ghi");
        assert_eq!(
            extract_cookie_token(&headers),
            Some("abc.def.ghi".to_string())
        );
        // An empty auth_token value is treated as absent.
        let empty_value = headers_with(header::COOKIE, "auth_token=");
        assert_eq!(extract_cookie_token(&empty_value), None);
        let unrelated = headers_with(header::COOKIE, "session=xyz");
        assert_eq!(extract_cookie_token(&unrelated), None);
        assert_eq!(extract_cookie_token(&HeaderMap::new()), None);
    }

    #[test]
    fn test_auth_cookie_attributes() {
        // `Secure` depends on AUTH_COOKIE_SECURE, so only check env-independent parts.
        let cookie = auth_cookie("tok");
        assert!(cookie.starts_with("auth_token=tok;"));
        assert!(cookie.ends_with(" HttpOnly; SameSite=Strict; Path=/; Max-Age=86400"));

        let cleared = clear_auth_cookie();
        assert!(cleared.starts_with("auth_token=;"));
        assert!(cleared.ends_with(" HttpOnly; SameSite=Strict; Path=/; Max-Age=0"));
    }
}