auth_framework/api/
server.rs

1//! REST API Server Implementation
2//!
3//! Main server that hosts all API endpoints
4
5use crate::AuthFramework;
6use crate::api::{ApiState, admin, auth, health, mfa, middleware, oauth, users};
7use axum::{
8    Router,
9    extract::DefaultBodyLimit,
10    http::Method,
11    middleware as axum_middleware,
12    routing::{delete, get, post, put},
13};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tower::ServiceBuilder;
17use tower_http::{
18    cors::{Any, CorsLayer},
19    trace::TraceLayer,
20};
21use tracing::info;
22
/// Settings consumed by [`ApiServer`] when it builds its router and binds its listener.
#[derive(Clone, Debug)]
pub struct ApiServerConfig {
    /// Host part of the listen address.
    pub host: String,
    /// TCP port of the listen address.
    pub port: u16,
    /// When `true`, a CORS layer is added to the router.
    pub enable_cors: bool,
    /// Value handed to axum's `DefaultBodyLimit::max` (request body limit in bytes).
    pub max_body_size: usize,
    /// When `true`, an HTTP `TraceLayer` is added to the router.
    pub enable_tracing: bool,
}
32
33impl Default for ApiServerConfig {
34    fn default() -> Self {
35        Self {
36            host: "127.0.0.1".to_string(),
37            port: 8080,
38            enable_cors: true,
39            max_body_size: 1024 * 1024, // 1MB
40            enable_tracing: true,
41        }
42    }
43}
44
45/// REST API Server
46pub struct ApiServer {
47    config: ApiServerConfig,
48    auth_framework: Arc<AuthFramework>,
49}
50
51impl ApiServer {
52    /// Create new API server
53    pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
54        Self {
55            config: ApiServerConfig::default(),
56            auth_framework,
57        }
58    }
59
60    /// Create new API server with custom configuration
61    pub fn with_config(auth_framework: Arc<AuthFramework>, config: ApiServerConfig) -> Self {
62        Self {
63            config,
64            auth_framework,
65        }
66    }
67
68    /// Build the router with all routes and middleware
69    pub async fn build_router(&self) -> crate::errors::Result<Router> {
70        let state = ApiState::new(self.auth_framework.clone()).await?;
71
72        // Create the main router with all routes
73        let router = Router::new()
74            // Health and monitoring endpoints (public)
75            .route("/health", get(health::health_check))
76            .route("/health/detailed", get(health::detailed_health_check))
77            .route("/metrics", get(health::metrics))
78            .route("/readiness", get(health::readiness_check))
79            .route("/liveness", get(health::liveness_check))
80            // Authentication endpoints (public)
81            .route("/auth/login", post(auth::login))
82            .route("/auth/refresh", post(auth::refresh_token))
83            .route("/auth/logout", post(auth::logout))
84            .route("/auth/validate", get(auth::validate_token))
85            .route("/auth/providers", get(auth::list_providers))
86            // OAuth 2.0 endpoints (mostly public)
87            .route("/oauth/authorize", get(oauth::authorize))
88            .route("/oauth/token", post(oauth::token))
89            .route("/oauth/revoke", post(oauth::revoke_token))
90            .route("/oauth/introspect", post(oauth::introspect_token))
91            .route("/oauth/clients/:client_id", get(oauth::get_client_info))
92            // User management endpoints (authenticated)
93            .route("/users/profile", get(users::get_profile))
94            .route("/users/profile", put(users::update_profile))
95            .route("/users/change-password", post(users::change_password))
96            .route("/users/sessions", get(users::get_sessions))
97            .route("/users/sessions/:session_id", delete(users::revoke_session))
98            .route("/users/:user_id/profile", get(users::get_user_profile))
99            // Multi-factor authentication endpoints (authenticated)
100            .route("/mfa/setup", post(mfa::setup_mfa))
101            .route("/mfa/verify", post(mfa::verify_mfa))
102            .route("/mfa/disable", post(mfa::disable_mfa))
103            .route("/mfa/status", get(mfa::get_mfa_status))
104            .route(
105                "/mfa/regenerate-backup-codes",
106                post(mfa::regenerate_backup_codes),
107            )
108            .route("/mfa/verify-backup-code", post(mfa::verify_backup_code))
109            // Administrative endpoints (admin only)
110            .route("/admin/users", get(admin::list_users))
111            .route("/admin/users", post(admin::create_user))
112            .route("/admin/users/:user_id/roles", put(admin::update_user_roles))
113            .route("/admin/users/:user_id", delete(admin::delete_user))
114            .route("/admin/users/:user_id/activate", put(admin::activate_user))
115            .route("/admin/stats", get(admin::get_system_stats))
116            .route("/admin/audit-logs", get(admin::get_audit_logs))
117            // Set shared state
118            .with_state(state.clone());
119
120        // Add middleware layers
121        let middleware_stack = ServiceBuilder::new()
122            .layer(axum_middleware::from_fn(middleware::timeout_middleware))
123            .layer(axum_middleware::from_fn(
124                middleware::security_headers_middleware,
125            ))
126            .layer(axum_middleware::from_fn(middleware::rate_limit_middleware))
127            .layer(axum_middleware::from_fn(middleware::logging_middleware));
128
129        let router = if self.config.enable_cors {
130            router.layer(
131                CorsLayer::new()
132                    .allow_origin(Any)
133                    .allow_methods([
134                        Method::GET,
135                        Method::POST,
136                        Method::PUT,
137                        Method::DELETE,
138                        Method::OPTIONS,
139                    ])
140                    .allow_headers(Any)
141                    .max_age(std::time::Duration::from_secs(3600)),
142            )
143        } else {
144            router
145        };
146
147        let router = if self.config.enable_tracing {
148            router.layer(TraceLayer::new_for_http())
149        } else {
150            router
151        };
152
153        Ok(router
154            .layer(middleware_stack)
155            .layer(DefaultBodyLimit::max(self.config.max_body_size)))
156    }
157
158    /// Start the API server
159    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
160        let app = self.build_router().await?;
161
162        let addr = SocketAddr::new(self.config.host.parse()?, self.config.port);
163
164        info!("🚀 AuthFramework API server starting on http://{}", addr);
165        info!("📖 API documentation available at http://{}/docs", addr);
166        info!("🏥 Health check available at http://{}/health", addr);
167        info!("📊 Metrics available at http://{}/metrics", addr);
168
169        let listener = tokio::net::TcpListener::bind(addr).await?;
170        axum::serve(listener, app).await?;
171
172        Ok(())
173    }
174
175    /// Get server configuration
176    pub fn config(&self) -> &ApiServerConfig {
177        &self.config
178    }
179
180    /// Get server address
181    pub fn address(&self) -> String {
182        format!("{}:{}", self.config.host, self.config.port)
183    }
184}
185
186/// Create a basic API server with default configuration
187pub async fn create_api_server(auth_framework: Arc<AuthFramework>) -> ApiServer {
188    ApiServer::new(auth_framework)
189}
190
191/// Create an API server with custom host and port
192pub async fn create_api_server_with_address(
193    auth_framework: Arc<AuthFramework>,
194    host: impl Into<String>,
195    port: u16,
196) -> ApiServer {
197    let config = ApiServerConfig {
198        host: host.into(),
199        port,
200        ..Default::default()
201    };
202    ApiServer::with_config(auth_framework, config)
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::InMemoryStorage;
    use crate::{AuthConfig, AuthFramework};
    use axum_test::TestServer;

    /// Shared fixture for the endpoint tests below.
    ///
    /// This is a plain helper, not a `#[test]` function, so it carries no
    /// `#[ignore]` attribute: that attribute is only meaningful on tests. The
    /// helper currently always panics through `todo!()`, which is why every
    /// test that calls it is marked `#[ignore]`.
    async fn create_test_server() -> TestServer {
        // Constructed but not yet wired into the framework.
        let _storage = Arc::new(InMemoryStorage::new());
        let config = AuthConfig::default();
        let auth_framework = Arc::new(AuthFramework::new(config));

        let api_server = ApiServer::new(auth_framework);
        let _app = api_server.build_router().await.unwrap();

        // Note: TestServer compatibility issue with current axum-test version
        // The Router type doesn't properly implement IntoTransportLayer
        todo!("TestServer::new needs axum-test compatibility fix")
    }

    /// `/health` should answer 200 with a JSON body whose `status` is `healthy`.
    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_health_endpoint() {
        let server = create_test_server().await;

        let response = server.get("/health").await;
        response.assert_status_ok();

        let body: serde_json::Value = response.json();
        assert_eq!(body["status"], "healthy");
    }

    /// A protected endpoint should reject a request that carries no token.
    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_auth_required_endpoints() {
        let server = create_test_server().await;

        // Try to access protected endpoint without token
        let response = server.get("/users/profile").await;
        response.assert_status_unauthorized();
    }

    /// Responses should carry CORS headers when CORS is enabled (the default).
    #[tokio::test]
    #[ignore = "TestServer compatibility issue"]
    async fn test_cors_headers() {
        let server = create_test_server().await;

        let response = server.get("/health").await;
        response.assert_status_ok();

        // Check CORS headers are present
        assert!(
            response
                .headers()
                .contains_key("access-control-allow-origin")
        );
    }
}
264
265