hyperinfer-server 0.1.1

High-performance LLM Gateway server built with Axum
//! Authentication module: email/password login with JWT-based sessions
//!
//! Provides endpoints for:
//! - POST /auth/login - Login with email/password, returns JWT
//! - GET /auth/me - Get current user info from JWT
//! - POST /auth/logout - Logout (client-side token removal)

use axum::{
    body::Body,
    extract::{FromRequestParts, Request, State},
    http::{request::Parts, HeaderMap, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use hyperinfer_core::User;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use utoipa::ToSchema;

// ── JWT Claims ───────────────────────────────────────────────────────────────

/// Claims expected inside a user authentication JWT.
///
/// Built by [`create_auth_token`] from a `User` and recovered by [`validate_auth_token`];
/// `auth_middleware` stores the decoded value in the request extensions.
/// NOTE: field declaration order is the order in which serde serializes the payload.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AuthClaims {
    /// User ID, copied from `User::id` (documented as a UUID — TODO confirm format)
    pub sub: String,
    /// User email, copied from `User::email`
    pub email: String,
    /// User role, copied from `User::role` (expected "admin" or "member");
    /// `RequireAdmin` only admits the exact value "admin".
    pub role: String,
    /// Team ID, copied from `User::team_id`
    pub team_id: String,
    /// Expiration timestamp in seconds since the Unix epoch
    pub exp: u64,
}

// ── Request/Response Types ────────────────────────────────────────────────────

/// Credentials submitted to `POST /auth/login`.
///
/// `Debug` is implemented by hand rather than derived so the plaintext password is
/// never written to logs or traces that format this request with `{:?}`.
#[derive(Deserialize, ToSchema)]
pub struct LoginRequest {
    /// Email address identifying the account.
    pub email: String,
    /// Plaintext password supplied by the client.
    pub password: String,
}

impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("email", &self.email)
            // Never expose the secret, even in debug output.
            .field("password", &"<redacted>")
            .finish()
    }
}

/// User summary serialized after a successful login.
///
/// NOTE(review): presumably the JSON body of `POST /auth/login`; the JWT itself is not a
/// field here and is presumably delivered via `auth_cookie` — TODO confirm in the handler.
#[derive(Debug, Serialize, ToSchema)]
pub struct LoginResponse {
    /// User ID
    pub id: String,
    /// User email
    pub email: String,
    /// User role ("admin" or "member")
    pub role: String,
    /// Team ID
    pub team_id: String,
}

/// Current-user info serialized by `GET /auth/me`.
///
/// Fields mirror `AuthClaims` (`id` corresponds to the `sub` claim) — presumably filled
/// from the validated JWT; TODO confirm in the handler.
#[derive(Debug, Serialize, ToSchema)]
pub struct MeResponse {
    /// User ID
    pub id: String,
    /// User email
    pub email: String,
    /// User role ("admin" or "member")
    pub role: String,
    /// Team ID
    pub team_id: String,
}

/// Decide whether auth cookies carry the `Secure` attribute.
///
/// Reads `AUTH_COOKIE_SECURE`. Only the exact strings `"false"` and `"0"` turn the flag
/// off. Any other value, an unset variable, or a non-Unicode value leaves it on, so the
/// default is secure.
fn cookie_secure() -> bool {
    match std::env::var("AUTH_COOKIE_SECURE") {
        Ok(setting) => !matches!(setting.as_str(), "false" | "0"),
        Err(_) => true,
    }
}

/// Build the `Set-Cookie` header value that stores `token` in the `auth_token` cookie.
///
/// The cookie is `HttpOnly`, `SameSite=Strict`, scoped to `Path=/`, and expires after
/// 86400 seconds (24 hours). `Secure` is added unless `cookie_secure()` returns false.
pub fn auth_cookie(token: &str) -> String {
    let secure_attr = if cookie_secure() { " Secure;" } else { "" };
    [
        "auth_token=",
        token,
        ";",
        secure_attr,
        " HttpOnly; SameSite=Strict; Path=/; Max-Age=86400",
    ]
    .concat()
}

/// Build the `Set-Cookie` header value that clears the `auth_token` cookie.
///
/// The value is empty and `Max-Age=0`. The other attributes match `auth_cookie`
/// (`HttpOnly`, `SameSite=Strict`, `Path=/`, and `Secure` unless disabled), so the
/// clearing cookie targets the same cookie that was set.
pub fn clear_auth_cookie() -> String {
    let secure_attr = if cookie_secure() { " Secure;" } else { "" };
    [
        "auth_token=;",
        secure_attr,
        " HttpOnly; SameSite=Strict; Path=/; Max-Age=0",
    ]
    .concat()
}

// ── JWT Token Generation ────────────────────────────────────────────────────

/// Issue an HS256-signed JWT carrying [`AuthClaims`] for `user`.
///
/// # Arguments
/// * `user` - supplies the `sub`, `email`, `role`, and `team_id` claims.
/// * `jwt_secret` - shared HMAC secret; its UTF-8 bytes are the signing key.
/// * `expires_in_secs` - lifetime added to the current Unix time to form `exp`.
///
/// If the system clock reads earlier than the Unix epoch, "now" is treated as 0.
/// The expiry saturates at `u64::MAX` instead of overflowing. An oversized
/// `expires_in_secs` therefore can neither panic (debug builds) nor wrap around into an
/// already-expired token (release builds).
///
/// # Errors
/// Returns the `jsonwebtoken` error if the claims cannot be encoded or signed.
pub fn create_auth_token(
    user: &User,
    jwt_secret: &str,
    expires_in_secs: u64,
) -> Result<String, jsonwebtoken::errors::Error> {
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use std::time::{SystemTime, UNIX_EPOCH};

    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // Saturate rather than overflow: plain `+` panics in debug and wraps in release.
    let exp = now.saturating_add(expires_in_secs);

    let claims = AuthClaims {
        sub: user.id.clone(),
        email: user.email.clone(),
        role: user.role.clone(),
        team_id: user.team_id.clone(),
        exp,
    };

    encode(
        &Header::new(Algorithm::HS256),
        &claims,
        &EncodingKey::from_secret(jwt_secret.as_bytes()),
    )
}

/// Verify an HS256 JWT signed with `jwt_secret` and return its [`AuthClaims`].
///
/// Validation is `Validation::new(Algorithm::HS256)`, so tokens using any other algorithm
/// are rejected. The library's default claim checks also apply; per the `jsonwebtoken`
/// `Validation` docs these presumably include `exp`.
///
/// # Errors
/// Returns the `jsonwebtoken` error for a bad signature, a malformed token, a failed
/// claim check, or a payload that does not deserialize into `AuthClaims`.
pub fn validate_auth_token(
    token: &str,
    jwt_secret: &str,
) -> Result<AuthClaims, jsonwebtoken::errors::Error> {
    use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};

    decode::<AuthClaims>(
        token,
        &DecodingKey::from_secret(jwt_secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    )
    .map(|token_data| token_data.claims)
}

// ── Middleware ────────────────────────────────────────────────────────────────

/// Extract JWT from Authorization header or auth_token cookie and validate it.
/// On success, adds AuthClaims to request extensions.
/// Authenticate a request using a bearer token or the `auth_token` cookie.
///
/// The `Authorization: Bearer …` header takes precedence. The cookie is consulted only
/// when no bearer token can be extracted. The token is checked with
/// [`validate_auth_token`] against the secret held in router state. On success the
/// decoded [`AuthClaims`] are inserted into the request extensions (where
/// [`RequireAdmin`] reads them) and the request is forwarded to `next`. Otherwise the
/// middleware answers `401 Unauthorized` itself.
pub async fn auth_middleware(
    State(jwt_secret): State<Arc<String>>,
    mut req: Request<Body>,
    next: Next,
) -> Response {
    let headers = req.headers();
    let Some(token) = extract_bearer_token(headers).or_else(|| extract_cookie_token(headers))
    else {
        return (
            StatusCode::UNAUTHORIZED,
            "Missing or invalid Authorization header",
        )
            .into_response();
    };

    match validate_auth_token(&token, &jwt_secret) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            next.run(req).await
        }
        Err(err) => {
            // Keep the failure reason out of the response; log it for diagnosis only.
            tracing::debug!("JWT validation failed: {:?}", err);
            (StatusCode::UNAUTHORIZED, "Invalid or expired token").into_response()
        }
    }
}

/// Extract the credential from an `Authorization: Bearer <token>` header.
///
/// The scheme is matched case-insensitively, and whitespace around the token is trimmed.
/// Returns `None` in these cases:
/// * the header is absent or cannot be read as a string (`HeaderValue::to_str` fails);
/// * the header has no whitespace after the scheme;
/// * the scheme is not `Bearer`;
/// * the token is empty.
///
/// Treating an empty token as "no token" lets `auth_middleware` fall back to the
/// `auth_token` cookie instead of rejecting the request for an empty bearer credential.
fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
    let raw = headers
        .get(axum::http::header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, rest) = raw.split_once(char::is_whitespace)?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = rest.trim();
    (!token.is_empty()).then(|| token.to_string())
}

/// Find the `auth_token` value in the request's `Cookie` header.
///
/// The header is split on `;` and each pair is trimmed. The first `auth_token=` pair
/// with a non-empty value wins; empty ones are skipped. Returns `None` if the header is
/// missing, cannot be read as a string, or contains no usable `auth_token` entry.
fn extract_cookie_token(headers: &HeaderMap) -> Option<String> {
    let cookies = headers
        .get(axum::http::header::COOKIE)?
        .to_str()
        .ok()?;

    cookies
        .split(';')
        .map(str::trim)
        .filter_map(|pair| pair.strip_prefix("auth_token="))
        .find(|value| !value.is_empty())
        .map(String::from)
}

// ── Admin Extractor ───────────────────────────────────────────────────────────

/// Extractor that admits only requests whose [`AuthClaims`] carry the `"admin"` role.
///
/// It reads the claims that `auth_middleware` places in the request extensions, so it
/// must run behind that middleware. It rejects with `401 Unauthorized` when no claims
/// are present and with `403 Forbidden` for any role other than `"admin"`.
pub struct RequireAdmin(pub AuthClaims);

impl<S: Send + Sync> FromRequestParts<S> for RequireAdmin {
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(claims) = parts.extensions.get::<AuthClaims>() else {
            return Err((StatusCode::UNAUTHORIZED, "Not authenticated").into_response());
        };

        if claims.role == "admin" {
            Ok(RequireAdmin(claims.clone()))
        } else {
            Err((StatusCode::FORBIDDEN, "Admin access required").into_response())
        }
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Admin user fixture shared by the token round-trip tests.
    fn sample_user() -> User {
        User {
            id: "user-123".to_string(),
            team_id: "team-456".to_string(),
            email: "test@example.com".to_string(),
            role: "admin".to_string(),
            password_hash: None,
            created_at: Utc::now(),
        }
    }

    /// Header map containing a single `Authorization` header with `value`.
    fn authorization(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::AUTHORIZATION,
            axum::http::HeaderValue::from_static(value),
        );
        headers
    }

    #[test]
    fn test_create_and_validate_token() {
        let user = sample_user();
        let secret = "test-secret";

        let token = create_auth_token(&user, secret, 3600).unwrap();
        let claims = validate_auth_token(&token, secret).unwrap();

        assert_eq!(claims.sub, user.id);
        assert_eq!(claims.email, user.email);
        assert_eq!(claims.role, user.role);
        assert_eq!(claims.team_id, user.team_id);
    }

    #[test]
    fn test_validate_token_wrong_secret() {
        let token = create_auth_token(&sample_user(), "correct-secret", 3600).unwrap();
        assert!(validate_auth_token(&token, "wrong-secret").is_err());
    }

    #[test]
    fn test_extract_bearer_token() {
        let expected = Some("test-token".to_string());

        // Canonical and lowercase scheme spellings both match.
        assert_eq!(
            extract_bearer_token(&authorization("Bearer test-token")),
            expected
        );
        assert_eq!(
            extract_bearer_token(&authorization("bearer test-token")),
            expected
        );

        // No Authorization header at all.
        assert_eq!(extract_bearer_token(&HeaderMap::new()), None);

        // A non-bearer scheme is ignored.
        assert_eq!(
            extract_bearer_token(&authorization("Basic test-token")),
            None
        );
    }
}