use crate::AuthFramework;
use crate::api::{ApiState, admin, auth, health, mfa, middleware, oauth, users};
use axum::{
Router,
extract::DefaultBodyLimit,
http::Method,
middleware as axum_middleware,
routing::{delete, get, post, put},
};
use std::net::SocketAddr;
use std::sync::Arc;
use tower::ServiceBuilder;
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use tracing::info;
#[derive(Debug, Clone)]
pub struct ApiServerConfig {
pub host: String,
pub port: u16,
pub enable_cors: bool,
pub max_body_size: usize,
pub enable_tracing: bool,
}
impl Default for ApiServerConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 8080,
enable_cors: true,
max_body_size: 1024 * 1024, enable_tracing: true,
}
}
}
pub struct ApiServer {
config: ApiServerConfig,
auth_framework: Arc<AuthFramework>,
}
impl ApiServer {
pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
Self {
config: ApiServerConfig::default(),
auth_framework,
}
}
pub fn with_config(auth_framework: Arc<AuthFramework>, config: ApiServerConfig) -> Self {
Self {
config,
auth_framework,
}
}
pub async fn build_router(&self) -> crate::errors::Result<Router> {
let state = ApiState::new(self.auth_framework.clone()).await?;
let router = Router::new()
.route("/health", get(health::health_check))
.route("/health/detailed", get(health::detailed_health_check))
.route("/metrics", get(health::metrics))
.route("/readiness", get(health::readiness_check))
.route("/liveness", get(health::liveness_check))
.route("/auth/login", post(auth::login))
.route("/auth/refresh", post(auth::refresh_token))
.route("/auth/logout", post(auth::logout))
.route("/auth/validate", get(auth::validate_token))
.route("/auth/providers", get(auth::list_providers))
.route("/oauth/authorize", get(oauth::authorize))
.route("/oauth/token", post(oauth::token))
.route("/oauth/revoke", post(oauth::revoke_token))
.route("/oauth/introspect", post(oauth::introspect_token))
.route("/oauth/clients/:client_id", get(oauth::get_client_info))
.route("/users/profile", get(users::get_profile))
.route("/users/profile", put(users::update_profile))
.route("/users/change-password", post(users::change_password))
.route("/users/sessions", get(users::get_sessions))
.route("/users/sessions/:session_id", delete(users::revoke_session))
.route("/users/:user_id/profile", get(users::get_user_profile))
.route("/mfa/setup", post(mfa::setup_mfa))
.route("/mfa/verify", post(mfa::verify_mfa))
.route("/mfa/disable", post(mfa::disable_mfa))
.route("/mfa/status", get(mfa::get_mfa_status))
.route(
"/mfa/regenerate-backup-codes",
post(mfa::regenerate_backup_codes),
)
.route("/mfa/verify-backup-code", post(mfa::verify_backup_code))
.route("/admin/users", get(admin::list_users))
.route("/admin/users", post(admin::create_user))
.route("/admin/users/:user_id/roles", put(admin::update_user_roles))
.route("/admin/users/:user_id", delete(admin::delete_user))
.route("/admin/users/:user_id/activate", put(admin::activate_user))
.route("/admin/stats", get(admin::get_system_stats))
.route("/admin/audit-logs", get(admin::get_audit_logs))
.with_state(state.clone());
let middleware_stack = ServiceBuilder::new()
.layer(axum_middleware::from_fn(middleware::timeout_middleware))
.layer(axum_middleware::from_fn(
middleware::security_headers_middleware,
))
.layer(axum_middleware::from_fn(middleware::rate_limit_middleware))
.layer(axum_middleware::from_fn(middleware::logging_middleware));
let router = if self.config.enable_cors {
router.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers(Any)
.max_age(std::time::Duration::from_secs(3600)),
)
} else {
router
};
let router = if self.config.enable_tracing {
router.layer(TraceLayer::new_for_http())
} else {
router
};
Ok(router
.layer(middleware_stack)
.layer(DefaultBodyLimit::max(self.config.max_body_size)))
}
pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
let app = self.build_router().await?;
let addr = SocketAddr::new(self.config.host.parse()?, self.config.port);
info!("🚀 AuthFramework API server starting on http://{}", addr);
info!("📖 API documentation available at http://{}/docs", addr);
info!("🏥 Health check available at http://{}/health", addr);
info!("📊 Metrics available at http://{}/metrics", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
pub fn config(&self) -> &ApiServerConfig {
&self.config
}
pub fn address(&self) -> String {
format!("{}:{}", self.config.host, self.config.port)
}
}
pub async fn create_api_server(auth_framework: Arc<AuthFramework>) -> ApiServer {
ApiServer::new(auth_framework)
}
pub async fn create_api_server_with_address(
auth_framework: Arc<AuthFramework>,
host: impl Into<String>,
port: u16,
) -> ApiServer {
let config = ApiServerConfig {
host: host.into(),
port,
..Default::default()
};
ApiServer::with_config(auth_framework, config)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::memory::InMemoryStorage;
use crate::{AuthConfig, AuthFramework};
use axum_test::TestServer;
#[ignore = "TestServer compatibility issue - axum-test version mismatch"]
async fn create_test_server() -> TestServer {
let _storage = Arc::new(InMemoryStorage::new());
let config = AuthConfig::default();
let auth_framework = Arc::new(AuthFramework::new(config));
let api_server = ApiServer::new(auth_framework);
let _app = api_server.build_router().await.unwrap();
todo!("TestServer::new needs axum-test compatibility fix")
}
#[tokio::test]
#[ignore = "TestServer compatibility issue"]
async fn test_health_endpoint() {
let server = create_test_server().await;
let response = server.get("/health").await;
response.assert_status_ok();
let body: serde_json::Value = response.json();
assert_eq!(body["status"], "healthy");
}
#[tokio::test]
#[ignore = "TestServer compatibility issue"]
async fn test_auth_required_endpoints() {
let server = create_test_server().await;
let response = server.get("/users/profile").await;
response.assert_status_unauthorized();
}
#[tokio::test]
#[ignore = "TestServer compatibility issue"]
async fn test_cors_headers() {
let server = create_test_server().await;
let response = server.get("/health").await;
response.assert_status_ok();
assert!(
response
.headers()
.contains_key("access-control-allow-origin")
);
}
}