auth-framework 0.4.2

A comprehensive, production-ready authentication and authorization framework for Rust applications
Documentation
//! REST API Server Implementation
//!
//! Main server that hosts all API endpoints

use crate::AuthFramework;
use crate::api::{ApiState, admin, auth, health, mfa, middleware, oauth, users};
use axum::{
    Router,
    extract::DefaultBodyLimit,
    http::Method,
    middleware as axum_middleware,
    routing::{delete, get, post, put},
};
use std::net::SocketAddr;
use std::sync::Arc;
use tower::ServiceBuilder;
use tower_http::{
    cors::{Any, CorsLayer},
    trace::TraceLayer,
};
use tracing::info;

/// API Server configuration
#[derive(Debug, Clone)]
pub struct ApiServerConfig {
    pub host: String,
    pub port: u16,
    pub enable_cors: bool,
    pub max_body_size: usize,
    pub enable_tracing: bool,
}

impl Default for ApiServerConfig {
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 8080,
            enable_cors: true,
            max_body_size: 1024 * 1024, // 1MB
            enable_tracing: true,
        }
    }
}

/// REST API Server
pub struct ApiServer {
    config: ApiServerConfig,
    auth_framework: Arc<AuthFramework>,
}

impl ApiServer {
    /// Create new API server
    pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
        Self {
            config: ApiServerConfig::default(),
            auth_framework,
        }
    }

    /// Create new API server with custom configuration
    pub fn with_config(auth_framework: Arc<AuthFramework>, config: ApiServerConfig) -> Self {
        Self {
            config,
            auth_framework,
        }
    }

    /// Build the router with all routes and middleware
    pub async fn build_router(&self) -> crate::errors::Result<Router> {
        let state = ApiState::new(self.auth_framework.clone()).await?;

        // Create the main router with all routes
        let router = Router::new()
            // Health and monitoring endpoints (public)
            .route("/health", get(health::health_check))
            .route("/health/detailed", get(health::detailed_health_check))
            .route("/metrics", get(health::metrics))
            .route("/readiness", get(health::readiness_check))
            .route("/liveness", get(health::liveness_check))
            // Authentication endpoints (public)
            .route("/auth/login", post(auth::login))
            .route("/auth/refresh", post(auth::refresh_token))
            .route("/auth/logout", post(auth::logout))
            .route("/auth/validate", get(auth::validate_token))
            .route("/auth/providers", get(auth::list_providers))
            // OAuth 2.0 endpoints (mostly public)
            .route("/oauth/authorize", get(oauth::authorize))
            .route("/oauth/token", post(oauth::token))
            .route("/oauth/revoke", post(oauth::revoke_token))
            .route("/oauth/introspect", post(oauth::introspect_token))
            .route("/oauth/clients/:client_id", get(oauth::get_client_info))
            // User management endpoints (authenticated)
            .route("/users/profile", get(users::get_profile))
            .route("/users/profile", put(users::update_profile))
            .route("/users/change-password", post(users::change_password))
            .route("/users/sessions", get(users::get_sessions))
            .route("/users/sessions/:session_id", delete(users::revoke_session))
            .route("/users/:user_id/profile", get(users::get_user_profile))
            // Multi-factor authentication endpoints (authenticated)
            .route("/mfa/setup", post(mfa::setup_mfa))
            .route("/mfa/verify", post(mfa::verify_mfa))
            .route("/mfa/disable", post(mfa::disable_mfa))
            .route("/mfa/status", get(mfa::get_mfa_status))
            .route(
                "/mfa/regenerate-backup-codes",
                post(mfa::regenerate_backup_codes),
            )
            .route("/mfa/verify-backup-code", post(mfa::verify_backup_code))
            // Administrative endpoints (admin only)
            .route("/admin/users", get(admin::list_users))
            .route("/admin/users", post(admin::create_user))
            .route("/admin/users/:user_id/roles", put(admin::update_user_roles))
            .route("/admin/users/:user_id", delete(admin::delete_user))
            .route("/admin/users/:user_id/activate", put(admin::activate_user))
            .route("/admin/stats", get(admin::get_system_stats))
            .route("/admin/audit-logs", get(admin::get_audit_logs))
            // Set shared state
            .with_state(state.clone());

        // Add middleware layers
        let middleware_stack = ServiceBuilder::new()
            .layer(axum_middleware::from_fn(middleware::timeout_middleware))
            .layer(axum_middleware::from_fn(
                middleware::security_headers_middleware,
            ))
            .layer(axum_middleware::from_fn(middleware::rate_limit_middleware))
            .layer(axum_middleware::from_fn(middleware::logging_middleware));

        let router = if self.config.enable_cors {
            router.layer(
                CorsLayer::new()
                    .allow_origin(Any)
                    .allow_methods([
                        Method::GET,
                        Method::POST,
                        Method::PUT,
                        Method::DELETE,
                        Method::OPTIONS,
                    ])
                    .allow_headers(Any)
                    .max_age(std::time::Duration::from_secs(3600)),
            )
        } else {
            router
        };

        let router = if self.config.enable_tracing {
            router.layer(TraceLayer::new_for_http())
        } else {
            router
        };

        Ok(router
            .layer(middleware_stack)
            .layer(DefaultBodyLimit::max(self.config.max_body_size)))
    }

    /// Start the API server
    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
        let app = self.build_router().await?;

        let addr = SocketAddr::new(self.config.host.parse()?, self.config.port);

        info!("🚀 AuthFramework API server starting on http://{}", addr);
        info!("📖 API documentation available at http://{}/docs", addr);
        info!("🏥 Health check available at http://{}/health", addr);
        info!("📊 Metrics available at http://{}/metrics", addr);

        let listener = tokio::net::TcpListener::bind(addr).await?;
        axum::serve(listener, app).await?;

        Ok(())
    }

    /// Get server configuration
    pub fn config(&self) -> &ApiServerConfig {
        &self.config
    }

    /// Get server address
    pub fn address(&self) -> String {
        format!("{}:{}", self.config.host, self.config.port)
    }
}

/// Create a basic API server with default configuration
pub async fn create_api_server(auth_framework: Arc<AuthFramework>) -> ApiServer {
    ApiServer::new(auth_framework)
}

/// Create an API server with custom host and port
pub async fn create_api_server_with_address(
    auth_framework: Arc<AuthFramework>,
    host: impl Into<String>,
    port: u16,
) -> ApiServer {
    let config = ApiServerConfig {
        host: host.into(),
        port,
        ..Default::default()
    };
    ApiServer::with_config(auth_framework, config)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::InMemoryStorage;
    use crate::{AuthConfig, AuthFramework};
    use axum_test::TestServer;

    #[ignore = "TestServer compatibility issue - axum-test version mismatch"]
    async fn create_test_server() -> TestServer {
        let _storage = Arc::new(InMemoryStorage::new());
        let config = AuthConfig::default();
        let auth_framework = Arc::new(AuthFramework::new(config));

        let api_server = ApiServer::new(auth_framework);
        let _app = api_server.build_router().await.unwrap();

        // Note: TestServer compatibility issue with current axum-test version
        // The Router type doesn't properly implement IntoTransportLayer
        todo!("TestServer::new needs axum-test compatibility fix")
    }

    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_health_endpoint() {
        let server = create_test_server().await;

        let response = server.get("/health").await;
        response.assert_status_ok();

        let body: serde_json::Value = response.json();
        assert_eq!(body["status"], "healthy");
    }

    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_auth_required_endpoints() {
        let server = create_test_server().await;

        // Try to access protected endpoint without token
        let response = server.get("/users/profile").await;
        response.assert_status_unauthorized();
    }

    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_cors_headers() {
        let server = create_test_server().await;

        let response = server.get("/health").await;
        response.assert_status_ok();

        // Check CORS headers are present
        assert!(
            response
                .headers()
                .contains_key("access-control-allow-origin")
        );
    }
}