1use crate::AuthFramework;
6use crate::api::{ApiState, admin, auth, health, mfa, middleware, oauth, users};
7use axum::{
8 Router,
9 extract::DefaultBodyLimit,
10 http::Method,
11 middleware as axum_middleware,
12 routing::{delete, get, post, put},
13};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tower::ServiceBuilder;
17use tower_http::{
18 cors::{Any, CorsLayer},
19 trace::TraceLayer,
20};
21use tracing::info;
22
/// Runtime settings for [`ApiServer`].
///
/// Use struct-update syntax to override individual fields, e.g.
/// `ApiServerConfig { port: 9000, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApiServerConfig {
    /// Host to bind: an IP literal such as `"127.0.0.1"` or `"::1"`.
    pub host: String,
    /// TCP port to bind.
    pub port: u16,
    /// Whether to install the CORS layer when building the router.
    pub enable_cors: bool,
    /// Maximum accepted request body size, in bytes.
    pub max_body_size: usize,
    /// Whether to install `tower_http`'s HTTP `TraceLayer`.
    pub enable_tracing: bool,
}

impl Default for ApiServerConfig {
    /// Loopback on port 8080, CORS and tracing enabled, 1 MiB body limit.
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 8080,
            enable_cors: true,
            // 1 MiB
            max_body_size: 1024 * 1024,
            enable_tracing: true,
        }
    }
}
44
45pub struct ApiServer {
47 config: ApiServerConfig,
48 auth_framework: Arc<AuthFramework>,
49}
50
51impl ApiServer {
52 pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
54 Self {
55 config: ApiServerConfig::default(),
56 auth_framework,
57 }
58 }
59
60 pub fn with_config(auth_framework: Arc<AuthFramework>, config: ApiServerConfig) -> Self {
62 Self {
63 config,
64 auth_framework,
65 }
66 }
67
68 pub async fn build_router(&self) -> crate::errors::Result<Router> {
70 let state = ApiState::new(self.auth_framework.clone()).await?;
71
72 let router = Router::new()
74 .route("/health", get(health::health_check))
76 .route("/health/detailed", get(health::detailed_health_check))
77 .route("/metrics", get(health::metrics))
78 .route("/readiness", get(health::readiness_check))
79 .route("/liveness", get(health::liveness_check))
80 .route("/auth/login", post(auth::login))
82 .route("/auth/refresh", post(auth::refresh_token))
83 .route("/auth/logout", post(auth::logout))
84 .route("/auth/validate", get(auth::validate_token))
85 .route("/auth/providers", get(auth::list_providers))
86 .route("/oauth/authorize", get(oauth::authorize))
88 .route("/oauth/token", post(oauth::token))
89 .route("/oauth/revoke", post(oauth::revoke_token))
90 .route("/oauth/introspect", post(oauth::introspect_token))
91 .route("/oauth/clients/:client_id", get(oauth::get_client_info))
92 .route("/users/profile", get(users::get_profile))
94 .route("/users/profile", put(users::update_profile))
95 .route("/users/change-password", post(users::change_password))
96 .route("/users/sessions", get(users::get_sessions))
97 .route("/users/sessions/:session_id", delete(users::revoke_session))
98 .route("/users/:user_id/profile", get(users::get_user_profile))
99 .route("/mfa/setup", post(mfa::setup_mfa))
101 .route("/mfa/verify", post(mfa::verify_mfa))
102 .route("/mfa/disable", post(mfa::disable_mfa))
103 .route("/mfa/status", get(mfa::get_mfa_status))
104 .route(
105 "/mfa/regenerate-backup-codes",
106 post(mfa::regenerate_backup_codes),
107 )
108 .route("/mfa/verify-backup-code", post(mfa::verify_backup_code))
109 .route("/admin/users", get(admin::list_users))
111 .route("/admin/users", post(admin::create_user))
112 .route("/admin/users/:user_id/roles", put(admin::update_user_roles))
113 .route("/admin/users/:user_id", delete(admin::delete_user))
114 .route("/admin/users/:user_id/activate", put(admin::activate_user))
115 .route("/admin/stats", get(admin::get_system_stats))
116 .route("/admin/audit-logs", get(admin::get_audit_logs))
117 .with_state(state.clone());
119
120 let middleware_stack = ServiceBuilder::new()
122 .layer(axum_middleware::from_fn(middleware::timeout_middleware))
123 .layer(axum_middleware::from_fn(
124 middleware::security_headers_middleware,
125 ))
126 .layer(axum_middleware::from_fn(middleware::rate_limit_middleware))
127 .layer(axum_middleware::from_fn(middleware::logging_middleware));
128
129 let router = if self.config.enable_cors {
130 router.layer(
131 CorsLayer::new()
132 .allow_origin(Any)
133 .allow_methods([
134 Method::GET,
135 Method::POST,
136 Method::PUT,
137 Method::DELETE,
138 Method::OPTIONS,
139 ])
140 .allow_headers(Any)
141 .max_age(std::time::Duration::from_secs(3600)),
142 )
143 } else {
144 router
145 };
146
147 let router = if self.config.enable_tracing {
148 router.layer(TraceLayer::new_for_http())
149 } else {
150 router
151 };
152
153 Ok(router
154 .layer(middleware_stack)
155 .layer(DefaultBodyLimit::max(self.config.max_body_size)))
156 }
157
158 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
160 let app = self.build_router().await?;
161
162 let addr = SocketAddr::new(self.config.host.parse()?, self.config.port);
163
164 info!("🚀 AuthFramework API server starting on http://{}", addr);
165 info!("📖 API documentation available at http://{}/docs", addr);
166 info!("🏥 Health check available at http://{}/health", addr);
167 info!("📊 Metrics available at http://{}/metrics", addr);
168
169 let listener = tokio::net::TcpListener::bind(addr).await?;
170 axum::serve(listener, app).await?;
171
172 Ok(())
173 }
174
175 pub fn config(&self) -> &ApiServerConfig {
177 &self.config
178 }
179
180 pub fn address(&self) -> String {
182 format!("{}:{}", self.config.host, self.config.port)
183 }
184}
185
186pub async fn create_api_server(auth_framework: Arc<AuthFramework>) -> ApiServer {
188 ApiServer::new(auth_framework)
189}
190
191pub async fn create_api_server_with_address(
193 auth_framework: Arc<AuthFramework>,
194 host: impl Into<String>,
195 port: u16,
196) -> ApiServer {
197 let config = ApiServerConfig {
198 host: host.into(),
199 port,
200 ..Default::default()
201 };
202 ApiServer::with_config(auth_framework, config)
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AuthConfig, AuthFramework};
    use axum_test::TestServer;

    /// Builds the default API router and is meant to wrap it in a [`TestServer`].
    ///
    /// This helper is not functional yet. Constructing the `TestServer` is
    /// blocked on an axum-test version mismatch, so the helper panics via
    /// `todo!`. Every test that calls it is marked `#[ignore]` for that reason.
    async fn create_test_server() -> TestServer {
        let auth_framework = Arc::new(AuthFramework::new(AuthConfig::default()));

        let api_server = ApiServer::new(auth_framework);
        let _app = api_server
            .build_router()
            .await
            .expect("router should build with the default configuration");

        todo!("TestServer::new needs axum-test compatibility fix")
    }

    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_health_endpoint() {
        let server = create_test_server().await;

        let response = server.get("/health").await;
        response.assert_status_ok();

        let body: serde_json::Value = response.json();
        assert_eq!(body["status"], "healthy");
    }

    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_auth_required_endpoints() {
        let server = create_test_server().await;

        // No Authorization header is sent, so the profile endpoint should refuse.
        let response = server.get("/users/profile").await;
        response.assert_status_unauthorized();
    }

    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_cors_headers() {
        let server = create_test_server().await;

        let response = server.get("/health").await;
        response.assert_status_ok();

        // The default configuration enables CORS with `allow_origin(Any)`.
        assert!(
            response
                .headers()
                .contains_key("access-control-allow-origin")
        );
    }
}
264
265