1use crate::AuthFramework;
6#[cfg(feature = "saml")]
7use crate::api::saml;
8use crate::api::{
9 ApiState, admin, advanced_protocols, auth, email_verification, health, mfa, middleware,
10 oauth_advanced, oauth2, openapi, users, webauthn,
11};
12use axum::{
13 Router,
14 extract::DefaultBodyLimit,
15 http::Method,
16 middleware as axum_middleware,
17 routing::{delete, get, post, put},
18};
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tower::ServiceBuilder;
22use tower_http::{cors::CorsLayer, trace::TraceLayer};
23use tracing::info;
24
25#[derive(Debug, Clone)]
41pub struct ApiServerConfig {
42 pub host: String,
44 pub port: u16,
46 pub cors: crate::config::CorsConfig,
50 pub max_body_size: usize,
52 pub enable_tracing: bool,
54}
55
56impl ApiServerConfig {
57 pub fn enable_cors(&self) -> bool {
59 self.cors.enabled
60 }
61}
62
63impl Default for ApiServerConfig {
64 fn default() -> Self {
65 Self {
66 host: "127.0.0.1".to_string(),
67 port: 8080,
68 cors: crate::config::CorsConfig::default(), max_body_size: 1024 * 1024, enable_tracing: true,
71 }
72 }
73}
74
75impl ApiServerConfig {
78 pub fn builder() -> ApiServerConfigBuilder {
80 ApiServerConfigBuilder::default()
81 }
82}
83
84pub struct ApiServerConfigBuilder {
89 config: ApiServerConfig,
90}
91
92impl Default for ApiServerConfigBuilder {
93 fn default() -> Self {
94 Self {
95 config: ApiServerConfig::default(),
96 }
97 }
98}
99
100impl ApiServerConfigBuilder {
101 pub fn host(mut self, host: impl Into<String>) -> Self {
103 self.config.host = host.into();
104 self
105 }
106
107 pub fn port(mut self, port: u16) -> Self {
109 self.config.port = port;
110 self
111 }
112
113 pub fn enable_cors(mut self, enable: bool) -> Self {
115 self.config.cors.enabled = enable;
116 self
117 }
118
119 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
121 self.config.cors.allowed_origins.push(origin.into());
122 self
123 }
124
125 pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
127 self.config.cors.allowed_origins = origins;
128 self
129 }
130
131 pub fn max_body_size(mut self, size: usize) -> Self {
133 self.config.max_body_size = size;
134 self
135 }
136
137 pub fn enable_tracing(mut self, enable: bool) -> Self {
139 self.config.enable_tracing = enable;
140 self
141 }
142
143 pub fn build(self) -> ApiServerConfig {
145 self.config
146 }
147}
148
149pub struct ApiServer {
159 config: ApiServerConfig,
160 auth_framework: Arc<AuthFramework>,
161}
162
163impl ApiServer {
164 pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
166 Self {
167 config: ApiServerConfig::default(),
168 auth_framework,
169 }
170 }
171
172 pub fn with_config(auth_framework: Arc<AuthFramework>, config: ApiServerConfig) -> Self {
174 Self {
175 config,
176 auth_framework,
177 }
178 }
179
180 pub async fn build_router(&self) -> crate::errors::Result<Router> {
182 let state = ApiState::new(self.auth_framework.clone()).await?;
183
184 let api_v1 = Router::new()
186 .route("/health", get(health::health_check))
188 .route("/health/detailed", get(health::detailed_health_check))
189 .route("/metrics", get(health::metrics))
190 .route("/readiness", get(health::readiness_check))
191 .route("/liveness", get(health::liveness_check))
192 .route("/auth/login", post(auth::login))
194 .route("/auth/register", post(auth::register))
195 .route("/auth/refresh", post(auth::refresh_token))
196 .route("/auth/logout", post(auth::logout))
197 .route("/auth/validate", get(auth::validate_token))
198 .route("/auth/providers", get(auth::list_providers))
199 .route("/api-keys", post(auth::create_api_key))
200 .route(
202 "/auth/verify-email/send",
203 post(email_verification::send_verification),
204 )
205 .route("/auth/verify-email", post(email_verification::verify_email))
206 .route(
207 "/auth/resend-verification",
208 post(email_verification::resend_verification),
209 )
210 .route("/oauth/authorize", get(oauth2::authorize))
212 .route("/oauth/token", post(oauth2::token))
213 .route("/oauth/revoke", post(oauth2::revoke))
214 .route("/oauth/introspect", post(oauth_advanced::introspect_token))
216 .route(
218 "/oauth/par",
219 post(oauth_advanced::pushed_authorization_request),
220 )
221 .route("/oauth/clients/{client_id}", get(oauth2::get_client_info))
222 .route("/oauth/device", post(oauth_advanced::device_authorization))
224 .route("/oauth/ciba", post(oauth_advanced::ciba_backchannel_auth))
226 .route("/oauth/userinfo", get(oauth2::userinfo))
228 .route("/oauth/end_session", get(oauth2::end_session))
230 .route("/oauth/register", post(oauth2::register_client))
232 .route(
234 "/.well-known/openid-configuration",
235 get(oauth2::openid_configuration),
236 )
237 .route("/.well-known/jwks.json", get(oauth2::jwks))
239 .route("/users/me", get(oauth2::users_me))
241 .route("/users/profile", get(users::get_profile))
242 .route("/users/profile", put(users::update_profile))
243 .route("/users/change-password", post(users::change_password))
244 .route("/users/sessions", get(users::get_sessions))
245 .route(
246 "/users/sessions/{session_id}",
247 delete(users::revoke_session),
248 )
249 .route("/users/{user_id}/profile", get(users::get_user_profile))
250 .route("/mfa/setup", post(mfa::setup_mfa))
252 .route("/mfa/verify", post(mfa::verify_mfa))
253 .route("/mfa/disable", post(mfa::disable_mfa))
254 .route("/mfa/status", get(mfa::get_mfa_status))
255 .route(
256 "/mfa/regenerate-backup-codes",
257 post(mfa::regenerate_backup_codes),
258 )
259 .route("/mfa/verify-backup-code", post(mfa::verify_backup_code))
260 .route("/admin/users", get(admin::list_users))
262 .route("/admin/users", post(admin::create_user))
263 .route(
264 "/admin/users/{user_id}/roles",
265 put(admin::update_user_roles),
266 )
267 .route("/admin/users/{user_id}", delete(admin::delete_user))
268 .route("/admin/users/{user_id}/activate", put(admin::activate_user))
269 .route("/admin/stats", get(admin::get_system_stats))
270 .route("/admin/audit-logs", get(admin::get_audit_logs))
271 .route("/admin/audit-logs/stats", get(admin::get_audit_log_stats))
272 .route(
273 "/admin/config",
274 get(admin::get_config).put(admin::update_config),
275 )
276 .route(
278 "/webauthn/registration/init",
279 post(webauthn::webauthn_registration_init),
280 )
281 .route(
282 "/webauthn/registration/complete",
283 post(webauthn::webauthn_registration_complete),
284 )
285 .route(
286 "/webauthn/authentication/init",
287 post(webauthn::webauthn_authentication_init),
288 )
289 .route(
290 "/webauthn/authentication/complete",
291 post(webauthn::webauthn_authentication_complete),
292 )
293 .route(
294 "/webauthn/credentials/{username}",
295 get(webauthn::list_webauthn_credentials),
296 )
297 .route(
298 "/webauthn/credentials/{username}/{credential_id}",
299 delete(webauthn::delete_webauthn_credential),
300 );
301
302 let api_v1 = {
304 let router = api_v1;
305
306 #[cfg(feature = "saml")]
307 {
308 router
309 .route("/saml/metadata", get(saml::get_saml_metadata))
310 .route("/saml/sso", post(saml::initiate_saml_sso))
311 .route("/saml/acs", post(saml::handle_saml_acs))
312 .route("/saml/slo", post(saml::initiate_saml_slo))
313 .route("/saml/slo/response", get(saml::handle_saml_slo_response))
314 .route("/saml/assertion", post(saml::create_saml_assertion))
315 .route("/saml/idps", get(saml::list_saml_idps))
316 }
317
318 #[cfg(not(feature = "saml"))]
319 {
320 router
321 }
322 };
323
324 let router = Router::new()
326 .route("/api/openapi.json", get(openapi::serve_openapi_json))
327 .route("/docs", get(openapi::serve_swagger_ui))
328 .nest("/api/v1", api_v1)
329 .merge(advanced_protocols::router())
330 .with_state(state.clone());
331
332 let middleware_stack = ServiceBuilder::new()
334 .layer(axum_middleware::from_fn(middleware::timeout_middleware))
335 .layer(axum_middleware::from_fn(
336 middleware::security_headers_middleware,
337 ))
338 .layer(axum_middleware::from_fn({
339 let state = state.clone();
340 move |request, next| {
341 let state = state.clone();
342 async move {
343 middleware::rate_limit_middleware_with_state(state, request, next).await
344 }
345 }
346 }))
347 .layer(axum_middleware::from_fn(middleware::logging_middleware));
348
349 let router = if self.config.cors.enabled {
350 if self.config.cors.allowed_origins.is_empty() {
351 tracing::warn!(
352 "SECURITY/CORS: CORS is enabled but allowed_origins is empty. All cross-origin requests will be rejected! Disable CORS or add allowed origins."
353 );
354 }
355
356 let header_origins: Vec<axum::http::HeaderValue> = self
357 .config
358 .cors
359 .allowed_origins
360 .iter()
361 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
362 .collect();
363
364 if header_origins.is_empty() && !self.config.cors.allowed_origins.is_empty() {
365 tracing::warn!(
366 "CORS: none of the configured allowed_origins could be parsed as valid HTTP \
367 header values; cross-origin requests will be rejected"
368 );
369 }
370
371 let allow_origin = tower_http::cors::AllowOrigin::list(header_origins);
372
373 let allowed_methods: Vec<Method> = self
374 .config
375 .cors
376 .allowed_methods
377 .iter()
378 .filter_map(|m| m.parse::<Method>().ok())
379 .collect();
380
381 let allowed_headers: Vec<axum::http::HeaderName> = self
382 .config
383 .cors
384 .allowed_headers
385 .iter()
386 .filter_map(|h| h.parse::<axum::http::HeaderName>().ok())
387 .collect();
388
389 router.layer(
390 CorsLayer::new()
391 .allow_origin(allow_origin)
392 .allow_methods(allowed_methods)
393 .allow_headers(allowed_headers)
394 .max_age(std::time::Duration::from_secs(
395 self.config.cors.max_age_secs as u64,
396 )),
397 )
398 } else {
399 router
400 };
401
402 let router = if self.config.enable_tracing {
403 router.layer(TraceLayer::new_for_http())
404 } else {
405 router
406 };
407
408 Ok(router
409 .layer(middleware_stack)
410 .layer(DefaultBodyLimit::max(self.config.max_body_size)))
411 }
412
413 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
415 let app = self.build_router().await?;
416
417 let addr = SocketAddr::new(self.config.host.parse()?, self.config.port);
418
419 info!("🚀 AuthFramework API server starting on http://{}", addr);
420 info!("📖 API documentation available at http://{}/docs", addr);
421 info!("📘 OpenAPI JSON available at http://{}/api/openapi.json", addr);
422 info!("🏥 Health check available at http://{}/health", addr);
423 info!("📊 Metrics available at http://{}/metrics", addr);
424
425 let listener = tokio::net::TcpListener::bind(addr).await?;
426 axum::serve(listener, app).await?;
427
428 Ok(())
429 }
430
431 pub fn config(&self) -> &ApiServerConfig {
433 &self.config
434 }
435
436 pub fn address(&self) -> String {
438 format!("{}:{}", self.config.host, self.config.port)
439 }
440}
441
442pub async fn create_api_server(auth_framework: Arc<AuthFramework>) -> ApiServer {
444 ApiServer::new(auth_framework)
445}
446
447pub async fn create_api_server_with_address(
449 auth_framework: Arc<AuthFramework>,
450 host: impl Into<String>,
451 port: u16,
452) -> ApiServer {
453 let config = ApiServerConfig {
454 host: host.into(),
455 port,
456 ..Default::default()
457 };
458 ApiServer::with_config(auth_framework, config)
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AuthConfig, AuthFramework};
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Builds an `ApiServer` with the default server config around a default `AuthFramework`.
    fn create_test_api_server() -> ApiServer {
        ApiServer::new(Arc::new(AuthFramework::new(AuthConfig::default())))
    }

    /// Builds a request with the given method and URI and an empty body.
    fn empty_request(method: &str, uri: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .method(method)
            .body(Body::empty())
            .unwrap()
    }

    /// Builds a JSON POST request to `uri` carrying `body`.
    fn json_post(uri: &str, body: String) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .method("POST")
            .header("Content-Type", "application/json")
            .body(Body::from(body))
            .unwrap()
    }

    /// Sends `request` through a freshly built router of the default test server.
    async fn send(request: Request<Body>) -> axum::response::Response {
        let router = create_test_api_server().build_router().await.unwrap();
        router.oneshot(request).await.unwrap()
    }

    /// Builds a router with CORS enabled for exactly `http://localhost:3000`.
    async fn cors_router() -> Router {
        let api_config = ApiServerConfig {
            cors: crate::config::CorsConfig {
                enabled: true,
                allowed_origins: vec!["http://localhost:3000".to_string()],
                ..crate::config::CorsConfig::default()
            },
            ..ApiServerConfig::default()
        };
        let auth_framework = Arc::new(AuthFramework::new(AuthConfig::default()));
        ApiServer::with_config(auth_framework, api_config)
            .build_router()
            .await
            .unwrap()
    }

    /// Builds a GET request for the health endpoint carrying an `Origin` header.
    fn health_request_from(origin: &str) -> Request<Body> {
        Request::builder()
            .uri("/api/v1/health")
            .method("GET")
            .header("Origin", origin)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let response = send(empty_request("GET", "/api/v1/health")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_required_endpoints() {
        let response = send(empty_request("GET", "/api/v1/users/profile")).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_cors_headers() {
        let response = cors_router()
            .await
            .oneshot(health_request_from("http://localhost:3000"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(
            response
                .headers()
                .contains_key("access-control-allow-origin")
        );
    }

    #[tokio::test]
    async fn test_cors_headers_absent_for_unlisted_origin() {
        let response = cors_router()
            .await
            .oneshot(health_request_from("http://evil.example"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(
            !response
                .headers()
                .contains_key("access-control-allow-origin")
        );
    }

    #[tokio::test]
    async fn test_readiness_endpoint() {
        let response = send(empty_request("GET", "/api/v1/readiness")).await;
        // Readiness may legitimately report the service as not ready yet.
        assert!(
            response.status() == StatusCode::OK
                || response.status() == StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[tokio::test]
    async fn test_liveness_endpoint() {
        let response = send(empty_request("GET", "/api/v1/liveness")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_metrics_endpoint() {
        let response = send(empty_request("GET", "/api/v1/metrics")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_nonexistent_route_returns_404() {
        let response = send(empty_request("GET", "/api/v1/this-does-not-exist")).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_login_with_empty_body() {
        let response = send(json_post("/api/v1/auth/login", "{}".to_string())).await;
        assert_ne!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_register_endpoint_accessible() {
        let body = serde_json::json!({
            "username": "newuser",
            "password": "StrongP@ssw0rd123!",
            "email": "test@example.com"
        });

        let response = send(json_post(
            "/api/v1/auth/register",
            serde_json::to_string(&body).unwrap(),
        ))
        .await;
        // Only routing is checked here: the handler must exist for POST.
        assert_ne!(response.status(), StatusCode::NOT_FOUND);
        assert_ne!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn test_server_config_defaults() {
        let config = ApiServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 8080);
        assert!(!config.enable_cors());
    }

    #[test]
    fn test_config_builder_overrides_defaults() {
        let config = ApiServerConfig::builder()
            .host("0.0.0.0")
            .port(9090)
            .enable_cors(true)
            .allowed_origins(vec!["https://a.example".to_string()])
            .allow_origin("https://b.example")
            .max_body_size(42)
            .enable_tracing(false)
            .build();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 9090);
        assert!(config.enable_cors());
        assert_eq!(
            config.cors.allowed_origins,
            vec![
                "https://a.example".to_string(),
                "https://b.example".to_string()
            ]
        );
        assert_eq!(config.max_body_size, 42);
        assert!(!config.enable_tracing);
    }

    #[test]
    fn test_server_address() {
        let api_server = create_test_api_server();
        assert_eq!(api_server.address(), "127.0.0.1:8080");
    }

    #[tokio::test]
    async fn test_create_api_server_with_address() {
        let auth_framework = Arc::new(AuthFramework::new(AuthConfig::default()));
        let api_server = create_api_server_with_address(auth_framework, "0.0.0.0", 8080).await;
        assert_eq!(api_server.address(), "0.0.0.0:8080");
    }

    #[tokio::test]
    async fn test_admin_endpoints_require_auth() {
        let response = send(empty_request("GET", "/api/v1/admin/users")).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_security_headers_present() {
        let response = send(empty_request("GET", "/api/v1/health")).await;
        assert_eq!(response.status(), StatusCode::OK);

        let headers = response.headers();
        assert_eq!(
            headers
                .get("x-content-type-options")
                .map(|v| v.to_str().unwrap()),
            Some("nosniff")
        );
        assert_eq!(
            headers.get("x-frame-options").map(|v| v.to_str().unwrap()),
            Some("DENY")
        );
    }
}