Skip to main content

auth_framework/api/
server.rs

1//! REST API Server Implementation
2//!
3//! Main server that hosts all API endpoints
4
5use crate::AuthFramework;
6#[cfg(feature = "saml")]
7use crate::api::saml;
8use crate::api::{
9    ApiState, admin, advanced_protocols, auth, email_verification, health, mfa, middleware,
10    oauth_advanced, oauth2, openapi, users, webauthn,
11};
12use axum::{
13    Router,
14    extract::DefaultBodyLimit,
15    http::Method,
16    middleware as axum_middleware,
17    routing::{delete, get, post, put},
18};
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tower::ServiceBuilder;
22use tower_http::{cors::CorsLayer, trace::TraceLayer};
23use tracing::info;
24
25/// REST API server configuration.
26///
27/// Use [`ApiServerConfig::builder()`] for ergonomic construction:
28///
29/// ```rust,ignore
30/// let config = ApiServerConfig::builder()
31///     .host("0.0.0.0")
32///     .port(3000)
33///     .enable_cors(true)
34///     .allow_origin("https://example.com")
35///     .build();
36/// ```
37///
38/// Default values bind to `127.0.0.1:8080` with tracing enabled, CORS disabled,
39/// and a 1 MB maximum request body.
40#[derive(Debug, Clone)]
41pub struct ApiServerConfig {
42    /// Address to bind the server to (default: `"127.0.0.1"`).
43    pub host: String,
44    /// TCP port to listen on (default: `8080`).
45    pub port: u16,
46    /// Centralized CORS configuration. Enable and set `allowed_origins` to
47    /// permit cross-origin requests. Origins are validated strictly — wildcard
48    /// (`"*"`) is never accepted.
49    pub cors: crate::config::CorsConfig,
50    /// Maximum allowed request body size in bytes (default: 1 MB).
51    pub max_body_size: usize,
52    /// Whether to attach a `tower_http::TraceLayer` for structured request/response logging.
53    pub enable_tracing: bool,
54}
55
56impl ApiServerConfig {
57    /// Convenience: is CORS enabled?
58    pub fn enable_cors(&self) -> bool {
59        self.cors.enabled
60    }
61}
62
63impl Default for ApiServerConfig {
64    fn default() -> Self {
65        Self {
66            host: "127.0.0.1".to_string(),
67            port: 8080,
68            cors: crate::config::CorsConfig::default(), // disabled by default
69            max_body_size: 1024 * 1024,                 // 1MB
70            enable_tracing: true,
71        }
72    }
73}
74
75/// REST API Server
76
77impl ApiServerConfig {
78    /// Create a new builder for `ApiServerConfig`
79    pub fn builder() -> ApiServerConfigBuilder {
80        ApiServerConfigBuilder::default()
81    }
82}
83
84/// Fluent builder for [`ApiServerConfig`].
85///
86/// Obtain via [`ApiServerConfig::builder()`].  All fields start with the same
87/// defaults as `ApiServerConfig::default()`.
88pub struct ApiServerConfigBuilder {
89    config: ApiServerConfig,
90}
91
92impl Default for ApiServerConfigBuilder {
93    fn default() -> Self {
94        Self {
95            config: ApiServerConfig::default(),
96        }
97    }
98}
99
100impl ApiServerConfigBuilder {
101    /// Set the address to bind to (e.g. `"0.0.0.0"`).
102    pub fn host(mut self, host: impl Into<String>) -> Self {
103        self.config.host = host.into();
104        self
105    }
106
107    /// Set the TCP port (e.g. `3000`).
108    pub fn port(mut self, port: u16) -> Self {
109        self.config.port = port;
110        self
111    }
112
113    /// Enable or disable CORS (default: disabled).
114    pub fn enable_cors(mut self, enable: bool) -> Self {
115        self.config.cors.enabled = enable;
116        self
117    }
118
119    /// Append a single allowed origin for CORS (e.g. `"https://example.com"`).
120    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
121        self.config.cors.allowed_origins.push(origin.into());
122        self
123    }
124
125    /// Replace the allowed origins list for CORS.
126    pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
127        self.config.cors.allowed_origins = origins;
128        self
129    }
130
131    /// Set the maximum request body size in bytes (default: 1 MB).
132    pub fn max_body_size(mut self, size: usize) -> Self {
133        self.config.max_body_size = size;
134        self
135    }
136
137    /// Enable or disable structured request/response tracing (default: enabled).
138    pub fn enable_tracing(mut self, enable: bool) -> Self {
139        self.config.enable_tracing = enable;
140        self
141    }
142
143    /// Consume the builder and return the finished [`ApiServerConfig`].
144    pub fn build(self) -> ApiServerConfig {
145        self.config
146    }
147}
148
149/// The REST API server that hosts all authentication, user-management,
150/// and health-check endpoints.
151///
152/// # Example
153///
154/// ```rust,ignore
155/// let server = ApiServer::with_config(auth.clone(), config);
156/// server.start().await?;
157/// ```
158pub struct ApiServer {
159    config: ApiServerConfig,
160    auth_framework: Arc<AuthFramework>,
161}
162
163impl ApiServer {
164    /// Create a server with the default [`ApiServerConfig`].
165    pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
166        Self {
167            config: ApiServerConfig::default(),
168            auth_framework,
169        }
170    }
171
172    /// Create a server with a custom [`ApiServerConfig`].
173    pub fn with_config(auth_framework: Arc<AuthFramework>, config: ApiServerConfig) -> Self {
174        Self {
175            config,
176            auth_framework,
177        }
178    }
179
180    /// Assemble the Axum [`Router`] with all route groups and middleware.
181    pub async fn build_router(&self) -> crate::errors::Result<Router> {
182        let state = ApiState::new(self.auth_framework.clone()).await?;
183
184        // Create nested router for API v1
185        let api_v1 = Router::new()
186            // Health and monitoring endpoints (public)
187            .route("/health", get(health::health_check))
188            .route("/health/detailed", get(health::detailed_health_check))
189            .route("/metrics", get(health::metrics))
190            .route("/readiness", get(health::readiness_check))
191            .route("/liveness", get(health::liveness_check))
192            // Authentication endpoints (public)
193            .route("/auth/login", post(auth::login))
194            .route("/auth/register", post(auth::register))
195            .route("/auth/refresh", post(auth::refresh_token))
196            .route("/auth/logout", post(auth::logout))
197            .route("/auth/validate", get(auth::validate_token))
198            .route("/auth/providers", get(auth::list_providers))
199            .route("/api-keys", post(auth::create_api_key))
200            // Email verification endpoints
201            .route(
202                "/auth/verify-email/send",
203                post(email_verification::send_verification),
204            )
205            .route("/auth/verify-email", post(email_verification::verify_email))
206            .route(
207                "/auth/resend-verification",
208                post(email_verification::resend_verification),
209            )
210            // OAuth 2.0 endpoints
211            .route("/oauth/authorize", get(oauth2::authorize))
212            .route("/oauth/token", post(oauth2::token))
213            .route("/oauth/revoke", post(oauth2::revoke))
214            // RFC 7662: Token Introspection (form-encoded, client auth required)
215            .route("/oauth/introspect", post(oauth_advanced::introspect_token))
216            // RFC 9126: Pushed Authorization Requests
217            .route(
218                "/oauth/par",
219                post(oauth_advanced::pushed_authorization_request),
220            )
221            .route("/oauth/clients/{client_id}", get(oauth2::get_client_info))
222            // RFC 8628: Device Authorization Grant
223            .route("/oauth/device", post(oauth_advanced::device_authorization))
224            // OpenID Connect CIBA (Client Initiated Backchannel Auth)
225            .route("/oauth/ciba", post(oauth_advanced::ciba_backchannel_auth))
226            // OIDC UserInfo endpoint
227            .route("/oauth/userinfo", get(oauth2::userinfo))
228            // OIDC RP-Initiated Logout
229            .route("/oauth/end_session", get(oauth2::end_session))
230            // RFC 7591: Dynamic Client Registration
231            .route("/oauth/register", post(oauth2::register_client))
232            // OpenID Connect Discovery
233            .route(
234                "/.well-known/openid-configuration",
235                get(oauth2::openid_configuration),
236            )
237            // JWKS endpoint
238            .route("/.well-known/jwks.json", get(oauth2::jwks))
239            // User management endpoints (authenticated)
240            .route("/users/me", get(oauth2::users_me))
241            .route("/users/profile", get(users::get_profile))
242            .route("/users/profile", put(users::update_profile))
243            .route("/users/change-password", post(users::change_password))
244            .route("/users/sessions", get(users::get_sessions))
245            .route(
246                "/users/sessions/{session_id}",
247                delete(users::revoke_session),
248            )
249            .route("/users/{user_id}/profile", get(users::get_user_profile))
250            // Multi-factor authentication endpoints (authenticated)
251            .route("/mfa/setup", post(mfa::setup_mfa))
252            .route("/mfa/verify", post(mfa::verify_mfa))
253            .route("/mfa/disable", post(mfa::disable_mfa))
254            .route("/mfa/status", get(mfa::get_mfa_status))
255            .route(
256                "/mfa/regenerate-backup-codes",
257                post(mfa::regenerate_backup_codes),
258            )
259            .route("/mfa/verify-backup-code", post(mfa::verify_backup_code))
260            // Administrative endpoints (admin only)
261            .route("/admin/users", get(admin::list_users))
262            .route("/admin/users", post(admin::create_user))
263            .route(
264                "/admin/users/{user_id}/roles",
265                put(admin::update_user_roles),
266            )
267            .route("/admin/users/{user_id}", delete(admin::delete_user))
268            .route("/admin/users/{user_id}/activate", put(admin::activate_user))
269            .route("/admin/stats", get(admin::get_system_stats))
270            .route("/admin/audit-logs", get(admin::get_audit_logs))
271            .route("/admin/audit-logs/stats", get(admin::get_audit_log_stats))
272            .route(
273                "/admin/config",
274                get(admin::get_config).put(admin::update_config),
275            )
276            // WebAuthn endpoints
277            .route(
278                "/webauthn/registration/init",
279                post(webauthn::webauthn_registration_init),
280            )
281            .route(
282                "/webauthn/registration/complete",
283                post(webauthn::webauthn_registration_complete),
284            )
285            .route(
286                "/webauthn/authentication/init",
287                post(webauthn::webauthn_authentication_init),
288            )
289            .route(
290                "/webauthn/authentication/complete",
291                post(webauthn::webauthn_authentication_complete),
292            )
293            .route(
294                "/webauthn/credentials/{username}",
295                get(webauthn::list_webauthn_credentials),
296            )
297            .route(
298                "/webauthn/credentials/{username}/{credential_id}",
299                delete(webauthn::delete_webauthn_credential),
300            );
301
302        // Build the router with conditional SAML routes
303        let api_v1 = {
304            let router = api_v1;
305
306            #[cfg(feature = "saml")]
307            {
308                router
309                    .route("/saml/metadata", get(saml::get_saml_metadata))
310                    .route("/saml/sso", post(saml::initiate_saml_sso))
311                    .route("/saml/acs", post(saml::handle_saml_acs))
312                    .route("/saml/slo", post(saml::initiate_saml_slo))
313                    .route("/saml/slo/response", get(saml::handle_saml_slo_response))
314                    .route("/saml/assertion", post(saml::create_saml_assertion))
315                    .route("/saml/idps", get(saml::list_saml_idps))
316            }
317
318            #[cfg(not(feature = "saml"))]
319            {
320                router
321            }
322        };
323
324        // Create the main router with all routes
325        let router = Router::new()
326            .route("/api/openapi.json", get(openapi::serve_openapi_json))
327            .route("/docs", get(openapi::serve_swagger_ui))
328            .nest("/api/v1", api_v1)
329            .merge(advanced_protocols::router())
330            .with_state(state.clone());
331
332        // Add middleware layers
333        let middleware_stack = ServiceBuilder::new()
334            .layer(axum_middleware::from_fn(middleware::timeout_middleware))
335            .layer(axum_middleware::from_fn(
336                middleware::security_headers_middleware,
337            ))
338            .layer(axum_middleware::from_fn({
339                let state = state.clone();
340                move |request, next| {
341                    let state = state.clone();
342                    async move {
343                        middleware::rate_limit_middleware_with_state(state, request, next).await
344                    }
345                }
346            }))
347            .layer(axum_middleware::from_fn(middleware::logging_middleware));
348
349        let router = if self.config.cors.enabled {
350            if self.config.cors.allowed_origins.is_empty() {
351                tracing::warn!(
352                    "SECURITY/CORS: CORS is enabled but allowed_origins is empty. All cross-origin requests will be rejected! Disable CORS or add allowed origins."
353                );
354            }
355
356            let header_origins: Vec<axum::http::HeaderValue> = self
357                .config
358                .cors
359                .allowed_origins
360                .iter()
361                .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
362                .collect();
363
364            if header_origins.is_empty() && !self.config.cors.allowed_origins.is_empty() {
365                tracing::warn!(
366                    "CORS: none of the configured allowed_origins could be parsed as valid HTTP \
367                     header values; cross-origin requests will be rejected"
368                );
369            }
370
371            let allow_origin = tower_http::cors::AllowOrigin::list(header_origins);
372
373            let allowed_methods: Vec<Method> = self
374                .config
375                .cors
376                .allowed_methods
377                .iter()
378                .filter_map(|m| m.parse::<Method>().ok())
379                .collect();
380
381            let allowed_headers: Vec<axum::http::HeaderName> = self
382                .config
383                .cors
384                .allowed_headers
385                .iter()
386                .filter_map(|h| h.parse::<axum::http::HeaderName>().ok())
387                .collect();
388
389            router.layer(
390                CorsLayer::new()
391                    .allow_origin(allow_origin)
392                    .allow_methods(allowed_methods)
393                    .allow_headers(allowed_headers)
394                    .max_age(std::time::Duration::from_secs(
395                        self.config.cors.max_age_secs as u64,
396                    )),
397            )
398        } else {
399            router
400        };
401
402        let router = if self.config.enable_tracing {
403            router.layer(TraceLayer::new_for_http())
404        } else {
405            router
406        };
407
408        Ok(router
409            .layer(middleware_stack)
410            .layer(DefaultBodyLimit::max(self.config.max_body_size)))
411    }
412
413    /// Start the API server
414    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
415        let app = self.build_router().await?;
416
417        let addr = SocketAddr::new(self.config.host.parse()?, self.config.port);
418
419        info!("🚀 AuthFramework API server starting on http://{}", addr);
420        info!("📖 API documentation available at http://{}/docs", addr);
421        info!("📘 OpenAPI JSON available at http://{}/api/openapi.json", addr);
422        info!("🏥 Health check available at http://{}/health", addr);
423        info!("📊 Metrics available at http://{}/metrics", addr);
424
425        let listener = tokio::net::TcpListener::bind(addr).await?;
426        axum::serve(listener, app).await?;
427
428        Ok(())
429    }
430
431    /// Get server configuration
432    pub fn config(&self) -> &ApiServerConfig {
433        &self.config
434    }
435
436    /// Get server address
437    pub fn address(&self) -> String {
438        format!("{}:{}", self.config.host, self.config.port)
439    }
440}
441
442/// Create a basic API server with default configuration
443pub async fn create_api_server(auth_framework: Arc<AuthFramework>) -> ApiServer {
444    ApiServer::new(auth_framework)
445}
446
447/// Create an API server with custom host and port
448pub async fn create_api_server_with_address(
449    auth_framework: Arc<AuthFramework>,
450    host: impl Into<String>,
451    port: u16,
452) -> ApiServer {
453    let config = ApiServerConfig {
454        host: host.into(),
455        port,
456        ..Default::default()
457    };
458    ApiServer::with_config(auth_framework, config)
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AuthConfig, AuthFramework};
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Fresh framework instance built from the default `AuthConfig`.
    fn test_auth_framework() -> Arc<AuthFramework> {
        Arc::new(AuthFramework::new(AuthConfig::default()))
    }

    async fn create_test_api_server() -> ApiServer {
        ApiServer::new(test_auth_framework())
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/health")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_required_endpoints() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/users/profile")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        // Protected endpoint should reject request without auth
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_cors_headers() {
        let api_config = ApiServerConfig {
            cors: crate::config::CorsConfig {
                enabled: true,
                allowed_origins: vec!["http://localhost:3000".to_string()],
                ..crate::config::CorsConfig::default()
            },
            ..ApiServerConfig::default()
        };
        let api_server = ApiServer::with_config(test_auth_framework(), api_config);
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/health")
            .method("GET")
            .header("Origin", "http://localhost:3000")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // Check CORS headers are present when a matching Origin is sent
        assert!(
            response
                .headers()
                .contains_key("access-control-allow-origin")
        );
    }

    #[tokio::test]
    async fn test_readiness_endpoint() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/readiness")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        // Should be OK or SERVICE_UNAVAILABLE, not a 404
        assert!(
            response.status() == StatusCode::OK
                || response.status() == StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[tokio::test]
    async fn test_liveness_endpoint() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/liveness")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_metrics_endpoint() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/metrics")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_nonexistent_route_returns_404() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/this-does-not-exist")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_login_with_empty_body() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/auth/login")
            .method("POST")
            .header("Content-Type", "application/json")
            .body(Body::from("{}"))
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        // Should return an error (400 or 422), not 200
        assert_ne!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_register_endpoint_accessible() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let body = serde_json::json!({
            "username": "newuser",
            "password": "StrongP@ssw0rd123!",
            "email": "test@example.com"
        });

        let request = Request::builder()
            .uri("/api/v1/auth/register")
            .method("POST")
            .header("Content-Type", "application/json")
            .body(Body::from(serde_json::to_string(&body).unwrap()))
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        // It should process the request (not 404 or 405)
        assert_ne!(response.status(), StatusCode::NOT_FOUND);
        assert_ne!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn test_server_config_defaults() {
        let config = ApiServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 8080);
        assert!(!config.enable_cors());
        assert_eq!(config.max_body_size, 1024 * 1024);
        assert!(config.enable_tracing);
    }

    #[test]
    fn test_config_builder() {
        let config = ApiServerConfig::builder()
            .host("0.0.0.0")
            .port(3000)
            .enable_cors(true)
            .allow_origin("https://example.com")
            .max_body_size(2048)
            .enable_tracing(false)
            .build();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 3000);
        assert!(config.enable_cors());
        // `allow_origin` appends to whatever the CORS defaults already hold.
        assert_eq!(
            config.cors.allowed_origins.last().map(String::as_str),
            Some("https://example.com")
        );
        assert_eq!(config.max_body_size, 2048);
        assert!(!config.enable_tracing);

        let replaced = ApiServerConfig::builder()
            .allow_origin("https://dropped.example")
            .allowed_origins(vec!["https://kept.example".to_string()])
            .build();
        assert_eq!(
            replaced.cors.allowed_origins,
            vec!["https://kept.example".to_string()]
        );
    }

    #[tokio::test]
    async fn test_server_address() {
        let api_server = create_test_api_server().await;
        assert_eq!(api_server.address(), "127.0.0.1:8080");
    }

    #[tokio::test]
    async fn test_create_api_server_with_address() {
        let api_server =
            create_api_server_with_address(test_auth_framework(), "0.0.0.0", 8080).await;
        assert_eq!(api_server.address(), "0.0.0.0:8080");
    }

    #[tokio::test]
    async fn test_admin_endpoints_require_auth() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/admin/users")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_security_headers_present() {
        let api_server = create_test_api_server().await;
        let router = api_server.build_router().await.unwrap();

        let request = Request::builder()
            .uri("/api/v1/health")
            .method("GET")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let headers = response.headers();
        assert_eq!(
            headers
                .get("x-content-type-options")
                .map(|v| v.to_str().unwrap()),
            Some("nosniff")
        );
        assert_eq!(
            headers.get("x-frame-options").map(|v| v.to_str().unwrap()),
            Some("DENY")
        );
    }
}