use std::sync::Arc;
use std::time::Duration;
use axum::http::{HeaderValue, Method};
use axum::{
Router, middleware,
routing::{get, post},
};
use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
use tower_http::{compression::CompressionLayer, cors::CorsLayer, trace::TraceLayer};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use crate::auth::require_api_key;
use crate::config::ServerConfig;
use crate::handlers::{datasets, export, harvest, health, portals, search, stats};
use crate::openapi::ApiDoc;
use crate::state::AppState;
pub fn create_router(state: AppState, config: &ServerConfig) -> Router {
let public_routes = Router::new()
.route("/health", get(health::health_check))
.route("/stats", get(stats::get_stats))
.route("/search", get(search::search))
.route("/portals", get(portals::list_portals))
.route("/portals/:name/stats", get(portals::get_portal_stats))
.route("/harvest/status", get(harvest::get_harvest_status))
.route("/datasets/:id", get(datasets::get_dataset_by_id));
let protected_routes = Router::new()
.route(
"/portals/:name/harvest",
post(portals::trigger_portal_harvest),
)
.route("/harvest", post(harvest::trigger_harvest_all))
.route("/export", get(export::export_datasets))
.layer(middleware::from_fn_with_state(
state.clone(),
require_api_key,
));
let api_routes = public_routes.merge(protected_routes);
let governor_config = Arc::new(
GovernorConfigBuilder::default()
.per_second(config.rate_limit_rps.into())
.burst_size(config.rate_limit_burst)
.finish()
.expect("Invalid rate limit configuration"),
);
let cors_layer = build_cors_layer(&config.cors_origins);
Router::new()
.nest("/api/v1", api_routes)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.layer(CompressionLayer::new())
.layer(TraceLayer::new_for_http())
.layer(cors_layer)
.layer(GovernorLayer {
config: governor_config,
})
.with_state(state)
}
fn build_cors_layer(origins: &str) -> CorsLayer {
let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
axum::http::header::ACCEPT,
])
.max_age(Duration::from_secs(3600));
if origins == "*" {
cors.allow_origin(tower_http::cors::Any)
} else {
let allowed: Vec<HeaderValue> = origins
.split(',')
.filter_map(|s| s.trim().parse().ok())
.collect();
cors.allow_origin(allowed)
}
}