1use crate::AuthFramework;
6use crate::api::{
7 ApiState, admin, auth, health, mfa, middleware, oauth, oauth_advanced, saml, security, users,
8 webauthn,
9};
10use axum::{
11 Router,
12 extract::DefaultBodyLimit,
13 http::Method,
14 middleware as axum_middleware,
15 routing::{delete, get, post, put},
16};
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tower::ServiceBuilder;
20use tower_http::{
21 cors::{Any, CorsLayer},
22 trace::TraceLayer,
23};
24use tracing::info;
25
/// Settings that control how the API server listens and which optional
/// layers are attached to its router.
///
/// Derives `PartialEq`/`Eq` so two configurations can be compared directly
/// (for example in tests that check builder output).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApiServerConfig {
    /// Address the server listens on; combined with `port`.
    pub host: String,
    /// TCP port the server listens on.
    pub port: u16,
    /// When `true`, a CORS layer is attached to the router.
    pub enable_cors: bool,
    /// Request body limit, in bytes, handed to axum's `DefaultBodyLimit::max`.
    pub max_body_size: usize,
    /// When `true`, an HTTP tracing layer is attached to the router.
    pub enable_tracing: bool,
}
35
36impl Default for ApiServerConfig {
37 fn default() -> Self {
38 Self {
39 host: "127.0.0.1".to_string(),
40 port: 8080,
41 enable_cors: true,
42 max_body_size: 1024 * 1024, enable_tracing: true,
44 }
45 }
46}
47
48pub struct ApiServer {
50 config: ApiServerConfig,
51 auth_framework: Arc<AuthFramework>,
52}
53
54impl ApiServer {
55 pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
57 Self {
58 config: ApiServerConfig::default(),
59 auth_framework,
60 }
61 }
62
63 pub fn with_config(auth_framework: Arc<AuthFramework>, config: ApiServerConfig) -> Self {
65 Self {
66 config,
67 auth_framework,
68 }
69 }
70
71 pub async fn build_router(&self) -> crate::errors::Result<Router> {
73 let state = ApiState::new(self.auth_framework.clone()).await?;
74
75 let router = Router::new()
77 .route("/health", get(health::health_check))
79 .route("/health/detailed", get(health::detailed_health_check))
80 .route("/metrics", get(health::metrics))
81 .route("/readiness", get(health::readiness_check))
82 .route("/liveness", get(health::liveness_check))
83 .route("/api/v1/auth/register", post(auth::register))
85 .route("/api/v1/auth/login", post(auth::login))
86 .route("/api/v1/auth/authenticate", post(auth::authenticate))
87 .route("/api/v1/auth/refresh", post(auth::refresh_token))
88 .route("/api/v1/auth/logout", post(auth::logout))
89 .route("/api/v1/auth/validate", get(auth::validate_token))
90 .route("/api/v1/auth/providers", get(auth::list_providers))
91 .route("/api/v1/api-keys", post(auth::create_api_key))
93 .route("/api/v1/api-keys", get(auth::list_api_keys))
94 .route("/api/v1/api-keys/revoke", post(auth::revoke_api_key))
95 .route("/api/v1/oauth/authorize", get(oauth::authorize))
97 .route("/api/v1/oauth/token", post(oauth::token))
98 .route("/api/v1/oauth/revoke", post(oauth::revoke_token))
99 .route(
101 "/api/v1/oauth2/authorize",
102 get(crate::api::oauth2::authorize),
103 )
104 .route("/api/v1/oauth2/token", post(crate::api::oauth2::token))
105 .route("/api/v1/oauth2/revoke", post(crate::api::oauth2::revoke))
106 .route("/api/v1/oauth2/userinfo", get(crate::api::oauth2::userinfo))
107 .route("/api/v1/oauth/token-exchange", post(oauth::token_exchange))
109 .route(
110 "/api/v1/oauth/clients/{client_id}",
111 get(oauth::get_client_info),
112 )
113 .route(
115 "/api/v1/oauth/introspect",
116 post(oauth_advanced::introspect_token),
117 )
118 .route(
119 "/api/v1/oauth/par",
120 post(oauth_advanced::pushed_authorization_request),
121 )
122 .route(
123 "/api/v1/oauth/device_authorization",
124 post(oauth_advanced::device_authorization),
125 )
126 .route(
128 "/.well-known/openid-configuration",
129 get(oauth::oidc_discovery),
130 )
131 .route("/.well-known/jwks.json", get(oauth::jwks))
132 .route("/api/v1/oidc/userinfo", get(oauth::userinfo))
133 .route("/api/v1/users/me", get(users::get_profile)) .route("/api/v1/users/profile", get(users::get_profile))
136 .route("/api/v1/users/profile", put(users::update_profile))
137 .route(
138 "/api/v1/users/change-password",
139 post(users::change_password),
140 )
141 .route("/api/v1/users/sessions", get(users::get_sessions))
142 .route(
143 "/api/v1/users/sessions/{session_id}",
144 delete(users::revoke_session),
145 )
146 .route(
147 "/api/v1/users/{user_id}/profile",
148 get(users::get_user_profile),
149 )
150 .route("/api/v1/mfa/setup", post(mfa::setup_mfa))
152 .route("/api/v1/mfa/verify", post(mfa::verify_mfa))
153 .route("/api/v1/mfa/disable", post(mfa::disable_mfa))
154 .route("/api/v1/mfa/status", get(mfa::get_mfa_status))
155 .route(
156 "/api/v1/mfa/regenerate-backup-codes",
157 post(mfa::regenerate_backup_codes),
158 )
159 .route(
160 "/api/v1/mfa/verify-backup-code",
161 post(mfa::verify_backup_code),
162 )
163 .route("/api/v1/admin/users", get(admin::list_users))
165 .route("/api/v1/admin/users", post(admin::create_user))
166 .route(
167 "/api/v1/admin/users/{user_id}/roles",
168 put(admin::update_user_roles),
169 )
170 .route("/api/v1/admin/users/{user_id}", delete(admin::delete_user))
171 .route(
172 "/api/v1/admin/users/{user_id}/activate",
173 put(admin::activate_user),
174 )
175 .route("/api/v1/admin/stats", get(admin::get_system_stats))
176 .route("/api/v1/admin/audit-logs", get(admin::get_audit_logs))
177 .route(
179 "/api/v1/admin/security/blacklist",
180 post(security::blacklist_ip_endpoint),
181 )
182 .route(
183 "/api/v1/admin/security/unblock",
184 post(security::unblock_ip_endpoint),
185 )
186 .route(
187 "/api/v1/admin/security/stats",
188 get(security::stats_endpoint),
189 )
190 .route(
192 "/api/v1/webauthn/register/init",
193 post(webauthn::webauthn_registration_init),
194 )
195 .route(
196 "/api/v1/webauthn/register/complete",
197 post(webauthn::webauthn_registration_complete),
198 )
199 .route(
200 "/api/v1/webauthn/authenticate/init",
201 post(webauthn::webauthn_authentication_init),
202 )
203 .route(
204 "/api/v1/webauthn/authenticate/complete",
205 post(webauthn::webauthn_authentication_complete),
206 )
207 .route(
208 "/api/v1/webauthn/credentials/{username}",
209 get(webauthn::list_webauthn_credentials),
210 )
211 .route(
212 "/api/v1/webauthn/credentials/{username}/{credential_id}",
213 delete(webauthn::delete_webauthn_credential),
214 )
215 .route("/api/v1/saml/metadata", get(saml::get_saml_metadata))
217 .route("/api/v1/saml/sso/init", post(saml::initiate_saml_sso))
218 .route("/api/v1/saml/acs", post(saml::handle_saml_acs))
219 .route("/api/v1/saml/slo/init", post(saml::initiate_saml_slo))
220 .route(
221 "/api/v1/saml/slo/response",
222 get(saml::handle_saml_slo_response),
223 )
224 .route(
225 "/api/v1/saml/assertion/create",
226 post(saml::create_saml_assertion),
227 )
228 .route("/api/v1/saml/idps", get(saml::list_saml_idps));
229
230 #[cfg(feature = "enhanced-rbac")]
232 let router = {
233 use crate::api::rbac_endpoints;
234 router
235 .route("/api/v1/rbac/roles", post(rbac_endpoints::create_role))
236 .route("/api/v1/rbac/roles", get(rbac_endpoints::list_roles))
237 .route(
238 "/api/v1/rbac/roles/{role_id}",
239 get(rbac_endpoints::get_role),
240 )
241 .route(
242 "/api/v1/rbac/roles/{role_id}",
243 put(rbac_endpoints::update_role),
244 )
245 .route(
246 "/api/v1/rbac/roles/{role_id}",
247 delete(rbac_endpoints::delete_role),
248 )
249 .route(
250 "/api/v1/rbac/users/{user_id}/roles",
251 post(rbac_endpoints::assign_user_role),
252 )
253 .route(
254 "/api/v1/rbac/users/{user_id}/roles/{role_id}",
255 delete(rbac_endpoints::revoke_user_role),
256 )
257 .route(
258 "/api/v1/rbac/users/{user_id}/roles",
259 get(rbac_endpoints::get_user_roles),
260 )
261 .route(
262 "/api/v1/rbac/bulk/assign",
263 post(rbac_endpoints::bulk_assign_roles),
264 )
265 .route(
266 "/api/v1/rbac/check-permission",
267 post(rbac_endpoints::check_permission),
268 )
269 .route("/api/v1/rbac/elevate", post(rbac_endpoints::elevate_role))
270 .route("/api/v1/rbac/audit", get(rbac_endpoints::get_audit_logs))
271 };
272
273 let router = router.with_state(state.clone());
275
276 let middleware_stack = ServiceBuilder::new()
278 .layer(axum_middleware::from_fn(middleware::timeout_middleware))
279 .layer(axum_middleware::from_fn(
280 middleware::security_headers_middleware,
281 ))
282 .layer(axum_middleware::from_fn(middleware::rate_limit_middleware))
283 .layer(axum_middleware::from_fn(middleware::logging_middleware));
284
285 let router = if self.config.enable_cors {
286 router.layer(
287 CorsLayer::new()
288 .allow_origin(Any)
289 .allow_methods([
290 Method::GET,
291 Method::POST,
292 Method::PUT,
293 Method::DELETE,
294 Method::OPTIONS,
295 ])
296 .allow_headers(Any)
297 .max_age(std::time::Duration::from_secs(3600)),
298 )
299 } else {
300 router
301 };
302
303 let router = if self.config.enable_tracing {
304 router.layer(TraceLayer::new_for_http())
305 } else {
306 router
307 };
308
309 Ok(router
310 .layer(middleware_stack)
311 .layer(DefaultBodyLimit::max(self.config.max_body_size)))
312 }
313
314 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
316 let app = self.build_router().await?;
317
318 let addr = SocketAddr::new(self.config.host.parse()?, self.config.port);
319
320 info!("🚀 AuthFramework API server starting on http://{}", addr);
321 info!("📖 API documentation available at http://{}/docs", addr);
322 info!("🏥 Health check available at http://{}/health", addr);
323 info!("📊 Metrics available at http://{}/metrics", addr);
324
325 let listener = tokio::net::TcpListener::bind(addr).await?;
326 axum::serve(listener, app).await?;
327
328 Ok(())
329 }
330 pub fn config(&self) -> &ApiServerConfig {
332 &self.config
333 }
334
335 pub fn address(&self) -> String {
337 format!("{}:{}", self.config.host, self.config.port)
338 }
339}
340
341pub async fn create_api_server(auth_framework: Arc<AuthFramework>) -> ApiServer {
343 ApiServer::new(auth_framework)
344}
345
346pub async fn create_api_server_with_address(
348 auth_framework: Arc<AuthFramework>,
349 host: impl Into<String>,
350 port: u16,
351) -> ApiServer {
352 let config = ApiServerConfig {
353 host: host.into(),
354 port,
355 ..Default::default()
356 };
357 ApiServer::with_config(auth_framework, config)
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AuthConfig, AuthFramework};

    /// JWT secret exported before each framework is constructed.
    const TEST_JWT_SECRET: &str = "test-secret-key-at-least-32-characters-long";

    /// Builds an initialized framework with default config, wraps it in an
    /// [`ApiServer`] and returns an in-process test server for its router.
    async fn create_test_server() -> axum_test::TestServer {
        // SAFETY: every caller writes the same constant value.
        // NOTE(review): the test harness runs tests on several threads, so this
        // relies on no other thread touching the environment at the same
        // time — confirm, or move the setup to process-wide initialization.
        unsafe {
            std::env::set_var("JWT_SECRET", TEST_JWT_SECRET);
        }

        let mut auth_framework = AuthFramework::new(AuthConfig::default());
        auth_framework.initialize().await.unwrap();

        let api_server = ApiServer::new(Arc::new(auth_framework));
        let app = api_server.build_router().await.unwrap();

        axum_test::TestServer::new(app).unwrap()
    }

    /// Smoke test: building the router and the test server must not panic.
    #[tokio::test]
    async fn test_create_test_server() {
        let _server = create_test_server().await;
    }

    /// `/health` answers 200 with a JSON body whose `success` field is `true`.
    #[tokio::test]
    async fn test_health_endpoint() {
        let server = create_test_server().await;
        let response = server.get("/health").await;
        response.assert_status_ok();

        let body: serde_json::Value = response.json();
        assert!(
            body.get("success")
                .and_then(|v| v.as_bool())
                .unwrap_or(false)
        );
    }

    /// A protected endpoint rejects requests that carry no credentials.
    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_auth_required_endpoints() {
        let server = create_test_server().await;

        // Use the path that is actually registered by `build_router`; the
        // unprefixed `/users/profile` does not exist and would yield 404.
        let response = server.get("/api/v1/users/profile").await;
        response.assert_status_unauthorized();
    }

    /// Responses carry a CORS allow-origin header when CORS is enabled
    /// (the default configuration).
    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_cors_headers() {
        let server = create_test_server().await;

        let response = server.get("/health").await;
        response.assert_status_ok();

        assert!(
            response
                .headers()
                .contains_key("access-control-allow-origin")
        );
    }
}