use std::sync::Arc;
use axum::{
Router,
extract::DefaultBodyLimit,
routing::{delete, get, post},
};
use tower_http::{
cors::{AllowOrigin, Any, CorsLayer},
limit::RequestBodyLimitLayer,
trace::TraceLayer,
};
use crate::{ExtractionConfig, core::ServerConfig};
use super::{
handlers::{
cache_clear_handler, cache_manifest_handler, cache_stats_handler, cache_warm_handler, chunk_handler,
detect_handler, embed_handler, extract_handler, formats_handler, health_handler, info_handler, version_handler,
},
types::{ApiSizeLimits, ApiState},
};
/// Builds the API router for the given extraction configuration.
///
/// Uses `ApiSizeLimits::default()` for the size limits and delegates to
/// [`create_router_with_limits`].
pub fn create_router(config: ExtractionConfig) -> Router {
    let limits = ApiSizeLimits::default();
    create_router_with_limits(config, limits)
}
/// Builds the API router with caller-supplied size limits.
///
/// Uses `ServerConfig::default()` for the server configuration and delegates to
/// [`create_router_with_limits_and_server_config`].
pub fn create_router_with_limits(config: ExtractionConfig, limits: ApiSizeLimits) -> Router {
    let server_config = ServerConfig::default();
    create_router_with_limits_and_server_config(config, limits, server_config)
}
pub fn create_router_with_limits_and_server_config(
config: ExtractionConfig,
limits: ApiSizeLimits,
server_config: ServerConfig,
) -> Router {
let state = ApiState {
default_config: Arc::new(config),
};
let cors_layer = if server_config.cors_allows_all() {
tracing::warn!(
"CORS configured to allow all origins (default). This permits CSRF attacks. \
For production, set KREUZBERG_CORS_ORIGINS environment variable to comma-separated \
list of allowed origins (e.g., 'https://app.example.com,https://api.example.com')"
);
CorsLayer::new().allow_origin(Any).allow_methods(Any).allow_headers(Any)
} else {
let origins: Vec<_> = server_config
.cors_origins
.iter()
.filter_map(|s| s.trim().parse::<axum::http::HeaderValue>().ok())
.collect();
if !origins.is_empty() {
tracing::info!("CORS configured with {} explicit allowed origin(s)", origins.len());
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods(Any)
.allow_headers(Any)
} else {
tracing::warn!(
"CORS origins configured but empty/invalid - falling back to permissive CORS. \
This allows CSRF attacks. Set explicit origins for production."
);
CorsLayer::new().allow_origin(Any).allow_methods(Any).allow_headers(Any)
}
};
let mut router = Router::new()
.route("/extract", post(extract_handler))
.route("/detect", post(detect_handler))
.route("/embed", post(embed_handler))
.route("/chunk", post(chunk_handler))
.route("/formats", get(formats_handler))
.route("/health", get(health_handler))
.route("/info", get(info_handler))
.route("/version", get(version_handler))
.route("/cache/stats", get(cache_stats_handler))
.route("/cache/clear", delete(cache_clear_handler))
.route("/cache/manifest", get(cache_manifest_handler))
.route("/cache/warm", post(cache_warm_handler));
#[cfg(feature = "api")]
{
router = router.route("/openapi.json", get(openapi_schema_handler));
}
router
.layer(DefaultBodyLimit::max(limits.max_request_body_bytes))
.layer(RequestBodyLimitLayer::new(limits.max_request_body_bytes))
.layer(cors_layer)
.layer(TraceLayer::new_for_http())
.with_state(state)
}
/// Serves the OpenAPI schema as a JSON response.
///
/// The schema string produced by `openapi_json()` is parsed into a JSON value.
/// If parsing fails, a JSON object with an `error` message is returned instead.
#[cfg(feature = "api")]
async fn openapi_schema_handler() -> axum::Json<serde_json::Value> {
    let raw = crate::api::openapi::openapi_json();
    let document = match serde_json::from_str::<serde_json::Value>(&raw) {
        Ok(parsed) => parsed,
        Err(_) => serde_json::json!({"error": "Failed to generate OpenAPI schema"}),
    };
    axum::Json(document)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a router from the default extraction config must not panic.
    #[test]
    fn test_create_router() {
        let _router = create_router(ExtractionConfig::default());
    }

    /// Smoke check: the constructed router value has a non-zero size.
    #[test]
    fn test_router_has_routes() {
        let router = create_router(ExtractionConfig::default());
        let footprint = std::mem::size_of_val(&router);
        assert!(footprint > 0);
    }

    /// Construction with custom size limits must not panic.
    #[test]
    fn test_create_router_with_limits() {
        let custom_limits = ApiSizeLimits::from_mb(50, 50);
        let _router = create_router_with_limits(ExtractionConfig::default(), custom_limits);
    }

    /// Construction with custom limits and the default server config must not panic.
    #[test]
    fn test_create_router_with_server_config() {
        let _router = create_router_with_limits_and_server_config(
            ExtractionConfig::default(),
            ApiSizeLimits::from_mb(100, 100),
            ServerConfig::default(),
        );
    }

    /// Construction with one explicit CORS origin must not panic.
    #[test]
    fn test_server_config_cors_handling() {
        let with_origin = ServerConfig {
            cors_origins: vec!["https://example.com".to_string()],
            ..Default::default()
        };
        let _router = create_router_with_limits_and_server_config(
            ExtractionConfig::default(),
            ApiSizeLimits::default(),
            with_origin,
        );
    }
}