use crate::{Config, Result, UsenetDownloader};
use axum::{
Router,
http::HeaderValue,
middleware,
routing::{delete, get, patch, post, put},
};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
pub mod auth;
pub mod error_response;
pub mod openapi;
pub mod rate_limit;
pub mod routes;
pub mod state;
pub use openapi::ApiDoc;
pub use state::AppState;
pub fn create_router(downloader: Arc<UsenetDownloader>, config: Arc<Config>) -> Router {
let state = AppState::new(downloader, config.clone());
let router = Router::new()
.route("/downloads", get(routes::list_downloads))
.route("/downloads", post(routes::add_download))
.route("/downloads/:id", get(routes::get_download))
.route("/downloads/:id", delete(routes::delete_download))
.route("/downloads/:id/pause", post(routes::pause_download))
.route("/downloads/:id/resume", post(routes::resume_download))
.route(
"/downloads/:id/priority",
patch(routes::set_download_priority),
)
.route("/downloads/:id/reprocess", post(routes::reprocess_download))
.route("/downloads/:id/reextract", post(routes::reextract_download))
.route("/downloads/url", post(routes::add_download_url))
.route("/queue/pause", post(routes::pause_queue))
.route("/queue/resume", post(routes::resume_queue))
.route("/queue/stats", get(routes::queue_stats))
.route("/history", get(routes::get_history))
.route("/history", delete(routes::clear_history))
.route("/servers/test", post(routes::test_server))
.route("/servers/test", get(routes::test_all_servers))
.route("/config", get(routes::get_config))
.route("/config", patch(routes::update_config))
.route("/config/speed-limit", get(routes::get_speed_limit))
.route("/config/speed-limit", put(routes::set_speed_limit))
.route("/categories", get(routes::list_categories))
.route("/categories/:name", put(routes::create_or_update_category))
.route("/categories/:name", delete(routes::delete_category))
.route("/capabilities", get(routes::get_capabilities))
.route("/health", get(routes::health_check))
.route("/openapi.json", get(routes::openapi_spec))
.route("/events", get(routes::event_stream))
.route("/shutdown", post(routes::shutdown))
.route("/rss", get(routes::list_rss_feeds))
.route("/rss", post(routes::add_rss_feed))
.route("/rss/:id", put(routes::update_rss_feed))
.route("/rss/:id", delete(routes::delete_rss_feed))
.route("/rss/:id/check", post(routes::check_rss_feed))
.route("/scheduler", get(routes::list_schedule_rules))
.route("/scheduler", post(routes::add_schedule_rule))
.route("/scheduler/:id", put(routes::update_schedule_rule))
.route("/scheduler/:id", delete(routes::delete_schedule_rule));
let router = if config.server.api.swagger_ui {
router.merge(SwaggerUi::new("/swagger-ui").url("/api/v1/openapi.json", ApiDoc::openapi()))
} else {
router
};
let router = router.with_state(state);
let router = if config.server.api.api_key.is_some() {
router.layer(middleware::from_fn_with_state(
config.server.api.api_key.clone(),
auth::require_api_key,
))
} else {
router
};
let router = if config.server.api.rate_limit.enabled {
let limiter = Arc::new(rate_limit::RateLimiter::new(
config.server.api.rate_limit.clone(),
));
router.layer(middleware::from_fn_with_state(
limiter,
rate_limit::rate_limit_middleware,
))
} else {
router
};
if config.server.api.cors_enabled {
let cors = build_cors_layer(&config.server.api.cors_origins);
router.layer(cors)
} else {
router
}
}
fn build_cors_layer(origins: &[String]) -> CorsLayer {
let allow_any = origins.iter().any(|o| o == "*");
if allow_any || origins.is_empty() {
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
} else {
let allowed: Vec<HeaderValue> = origins.iter().filter_map(|o| o.parse().ok()).collect();
CorsLayer::new()
.allow_origin(AllowOrigin::list(allowed))
.allow_methods(Any)
.allow_headers(Any)
}
}
pub async fn start_api_server(
downloader: Arc<UsenetDownloader>,
config: Arc<Config>,
) -> Result<()> {
let bind_address = config.server.api.bind_address;
tracing::info!(
address = %bind_address,
"Starting API server"
);
let app = create_router(downloader, config);
let listener = TcpListener::bind(bind_address)
.await
.map_err(crate::error::Error::Io)?;
tracing::info!(
address = %bind_address,
"API server listening"
);
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.map_err(|e| crate::error::Error::ApiServerError(e.to_string()))?;
tracing::info!("API server stopped");
Ok(())
}
// Unit tests for this module live in a separate `tests` file.
// `unwrap`/`expect` are allowed there because a panic in test code simply
// fails the test, which is the intended outcome.
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(test)]
mod tests;