reasonkit-web 0.1.7

High-performance MCP server for browser automation, web capture, and content extraction. Rust-powered CDP client for AI agents.
Documentation
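//! # Portal Middleware
//!
//! JWT authentication middleware for protecting routes.
//!
//! ## Usage
//!
//! The middleware functions are plain `async fn`s, so wrap them with
//! `axum::middleware::from_fn`. Layers added later run first, so `require_auth`
//! is added after `require_scope` to insert the claims that `require_scope` checks.
//!
//! ```ignore
//! use axum::{middleware, routing::get, Router};
//! use reasonkit_web::portal::middleware::{require_auth, require_scope};
//!
//! let protected_routes = Router::new()
//!     .route("/protected", get(handler))
//!     .layer(middleware::from_fn(require_scope("read")))
//!     .layer(middleware::from_fn(require_auth));
//! ```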

use axum::{
    extract::Request,
    http::{header, StatusCode},
    middleware::Next,
    response::Response,
    Json,
};
use serde::Serialize;

use crate::portal::auth::{AuthService, Claims};

/// Authentication error response
///
/// JSON body sent back (wrapped in `Json`) whenever the middleware or the
/// `AuthClaims` extractor rejects a request. With the derived `Serialize`
/// it is encoded as `{"error": "...", "code": "..."}`.
#[derive(Debug, Serialize)]
pub struct AuthErrorResponse {
    /// Human-readable description of why the request was rejected.
    pub error: String,
    /// Machine-readable error code, e.g. `MISSING_TOKEN`, `INVALID_TOKEN`,
    /// `WRONG_TOKEN_TYPE`, `NOT_AUTHENTICATED` or `INSUFFICIENT_SCOPE`.
    pub code: String,
}

/// Extract the bearer token from the `Authorization` header.
///
/// Expects `Authorization: <scheme> <token>` with the scheme `Bearer`. The
/// scheme is compared case-insensitively, because HTTP auth-scheme names are
/// case-insensitive (RFC 7235 §2.1), so `bearer abc` is accepted too.
/// Surrounding whitespace is trimmed from the token.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - the header is not a valid visible-ASCII string;
/// - the header uses a different scheme;
/// - the token is empty.
///
/// Returning `None` for an empty token makes callers report a missing token
/// instead of handing `""` to the validator.
fn extract_token(req: &Request) -> Option<&str> {
    let value = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?;
    let (scheme, token) = value.trim().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = token.trim();
    (!token.is_empty()).then_some(token)
}

/// JWT authentication middleware
///
/// Validates the JWT token and injects Claims into request extensions.
pub async fn require_auth(
    req: Request,
    next: Next,
) -> Result<Response, (StatusCode, Json<AuthErrorResponse>)> {
    // Extract token from header
    let token = extract_token(&req).ok_or_else(|| {
        (
            StatusCode::UNAUTHORIZED,
            Json(AuthErrorResponse {
                error: "Missing or invalid Authorization header".to_string(),
                code: "MISSING_TOKEN".to_string(),
            }),
        )
    })?;

    // Create auth service with default config
    let auth_service = AuthService::new(Default::default());

    // Validate token
    let claims = auth_service.validate_token(token).map_err(|e| {
        (
            StatusCode::UNAUTHORIZED,
            Json(AuthErrorResponse {
                error: e.to_string(),
                code: "INVALID_TOKEN".to_string(),
            }),
        )
    })?;

    // Check if it's an access token (not refresh)
    if claims.token_type != crate::portal::auth::TokenType::Access {
        return Err((
            StatusCode::UNAUTHORIZED,
            Json(AuthErrorResponse {
                error: "Invalid token type".to_string(),
                code: "WRONG_TOKEN_TYPE".to_string(),
            }),
        ));
    }

    // Insert claims into request extensions for handlers to access
    let mut req = req;
    req.extensions_mut().insert(claims);

    Ok(next.run(req).await)
}

/// Optional authentication middleware.
///
/// When the request carries a bearer token that validates and is an access
/// token, the decoded [`Claims`] are inserted into the request extensions.
/// A missing, invalid or non-access token is silently ignored. The request is
/// always forwarded to `next`; this middleware never rejects.
pub async fn optional_auth(mut req: Request, next: Next) -> Response {
    // The validator is only constructed when a token is actually present.
    let claims = extract_token(&req).and_then(|token| {
        AuthService::new(Default::default())
            .validate_token(token)
            .ok()
            .filter(|claims| claims.token_type == crate::portal::auth::TokenType::Access)
    });

    if let Some(claims) = claims {
        req.extensions_mut().insert(claims);
    }

    next.run(req).await
}

/// Scope-checking middleware factory.
///
/// Returns a middleware function that admits a request only if the [`Claims`]
/// in its extensions (inserted by `require_auth`) list either `required_scope`
/// or the `admin` scope.
///
/// # Errors
///
/// The returned middleware rejects with:
/// - `401 Unauthorized` / `NOT_AUTHENTICATED` when no claims are present;
/// - `403 Forbidden` / `INSUFFICIENT_SCOPE` when the claims carry neither scope.
// An `async` block's type cannot be named in an `impl Fn` signature, so the
// future is boxed and type-erased; the resulting long type trips clippy's
// `type_complexity` lint, which is silenced here.
#[allow(clippy::type_complexity)]
pub fn require_scope(
    required_scope: &'static str,
) -> impl Fn(
    Request,
    Next,
) -> std::pin::Pin<
    Box<
        dyn std::future::Future<Output = Result<Response, (StatusCode, Json<AuthErrorResponse>)>>
            + Send,
    >,
> + Clone {
    move |req: Request, next: Next| {
        Box::pin(async move {
            let has_scope = match req.extensions().get::<Claims>() {
                Some(claims) => claims
                    .scopes
                    .iter()
                    .any(|scope| scope == required_scope || scope == "admin"),
                None => {
                    return Err((
                        StatusCode::UNAUTHORIZED,
                        Json(AuthErrorResponse {
                            error: "Not authenticated".to_string(),
                            code: "NOT_AUTHENTICATED".to_string(),
                        }),
                    ))
                }
            };

            if has_scope {
                Ok(next.run(req).await)
            } else {
                Err((
                    StatusCode::FORBIDDEN,
                    Json(AuthErrorResponse {
                        error: format!("Missing required scope: {}", required_scope),
                        code: "INSUFFICIENT_SCOPE".to_string(),
                    }),
                ))
            }
        })
    }
}

/// Borrow the authenticated user's [`Claims`] from the request extensions.
///
/// Returns `None` when nothing inserted claims for this request. For example,
/// `optional_auth` ran but found no valid access token.
pub fn get_claims(req: &Request) -> Option<&Claims> {
    let extensions = req.extensions();
    extensions.get()
}

/// Axum extractor for authenticated claims.
///
/// Use it as a handler argument (e.g. `AuthClaims(claims): AuthClaims`). It
/// clones the [`Claims`] stored in the request extensions by `require_auth` or
/// `optional_auth`. When none are present, the request is rejected with
/// `401 Unauthorized` / `NOT_AUTHENTICATED`.
#[derive(Debug, Clone)]
pub struct AuthClaims(pub Claims);

#[axum::async_trait]
impl<S> axum::extract::FromRequestParts<S> for AuthClaims
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, Json<AuthErrorResponse>);

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<Claims>() {
            Some(claims) => Ok(Self(claims.clone())),
            None => Err((
                StatusCode::UNAUTHORIZED,
                Json(AuthErrorResponse {
                    error: "Not authenticated".to_string(),
                    code: "NOT_AUTHENTICATED".to_string(),
                }),
            )),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    /// Build an empty-bodied request, adding an `Authorization` header when one is given.
    fn request_with_authorization(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder();
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        builder
            .body(Body::empty())
            .expect("static request parts are always valid")
    }

    #[test]
    fn test_extract_token_valid() {
        let req = request_with_authorization(Some("Bearer test_token_123"));
        assert_eq!(extract_token(&req), Some("test_token_123"));
    }

    #[test]
    fn test_extract_token_missing() {
        let req = request_with_authorization(None);
        assert_eq!(extract_token(&req), None);
    }

    #[test]
    fn test_extract_token_invalid_format() {
        // A non-Bearer scheme must never be treated as a token.
        let req = request_with_authorization(Some("Basic user:pass"));
        assert_eq!(extract_token(&req), None);
    }
}