use axum::{
Router, extract::DefaultBodyLimit, http::StatusCode, response::IntoResponse, routing::get,
};
use axum_tracing_opentelemetry::middleware::{OtelAxumLayer, OtelInResponseLayer};
use rsketch_base::readable_size::ReadableSize;
use rsketch_error::{ParseAddressSnafu, Result};
use serde::{Deserialize, Serialize};
use smart_default::SmartDefault;
use snafu::ResultExt;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;
use super::ServiceHandler;
/// Default request-body size limit for the REST server (`ReadableSize::mb(100)`).
///
/// Used as the default value of [`RestServerConfig::max_body_size`].
pub const DEFAULT_MAX_HTTP_BODY_SIZE: ReadableSize = ReadableSize::mb(100);
/// Configuration consumed by [`start_rest_server`].
///
/// `Default` is derived through `SmartDefault`, so each field's default is the
/// value given in its `#[default ...]` attribute.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, SmartDefault, bon::Builder)]
pub struct RestServerConfig {
/// Address the server listens on; `start_rest_server` parses it as a
/// `std::net::SocketAddr`, so it must have the `ip:port` form.
#[default = "127.0.0.1:3000"]
pub bind_address: String,
/// Maximum request-body size, installed through axum's `DefaultBodyLimit`.
#[default(_code = "DEFAULT_MAX_HTTP_BODY_SIZE")]
pub max_body_size: ReadableSize,
/// When `true`, a `CorsLayer` allowing any origin, method and header is added.
#[default = true]
pub enable_cors: bool,
}
/// Builds the REST router and starts serving it on a background Tokio task.
///
/// The router always exposes `GET /health`. Every closure in `route_handlers`
/// is then applied, in order, to register additional routes. Only after all
/// routes exist are the middleware layers attached (OpenTelemetry, the
/// request-body limit from `config.max_body_size`, and — when
/// `config.enable_cors` is set — a CORS layer allowing any origin, method and
/// header). axum's `Router::layer` wraps only the routes that are already
/// registered, so attaching the layers last is what makes them cover the
/// caller-supplied routes as well.
///
/// The returned [`ServiceHandler`] carries the task's join handle, the
/// cancellation token that triggers graceful shutdown, and a one-shot receiver
/// that is signalled once the server has begun serving.
///
/// # Errors
///
/// Returns an error if `config.bind_address` cannot be parsed as a
/// `std::net::SocketAddr`.
///
/// # Panics
///
/// This function itself does not panic, but the spawned server task panics if
/// the TCP listener cannot be bound to the configured address.
// The function is kept `async` for API compatibility although it awaits nothing.
#[allow(clippy::unused_async)]
pub async fn start_rest_server<F>(
    config: RestServerConfig,
    route_handlers: Vec<F>,
) -> Result<ServiceHandler>
where
    F: Fn(Router) -> Router + Send + Sync + 'static,
{
    let bind_addr = config
        .bind_address
        .parse::<std::net::SocketAddr>()
        .context(ParseAddressSnafu {
            addr: config.bind_address.clone(),
        })?;

    // Register every route first: layers added afterwards wrap all of them.
    let mut router = Router::new().route("/health", get(health_check));
    for handler in &route_handlers {
        info!("Registering REST route handler");
        router = handler(router);
    }

    // Saturate instead of silently truncating if the configured size does not
    // fit in `usize` on this target.
    let max_body_bytes = usize::try_from(config.max_body_size.as_bytes()).unwrap_or(usize::MAX);
    router = router
        .layer(OtelInResponseLayer)
        .layer(OtelAxumLayer::default())
        .layer(DefaultBodyLimit::max(max_body_bytes));

    if config.enable_cors {
        let cors = CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any);
        router = router.layer(cors);
    }

    let cancellation_token = CancellationToken::new();
    let (started_tx, started_rx) = oneshot::channel::<()>();
    let shutdown_token = cancellation_token.clone();

    let join_handle = tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(bind_addr)
            .await
            .unwrap_or_else(|err| panic!("REST server failed to bind to {bind_addr}: {err}"));

        // This future is handed to axum as the shutdown signal; the code before
        // its first await point doubles as the "server started" notification.
        let shutdown_signal = async move {
            info!("REST server (on {}) starting", bind_addr);
            // The receiver may already be gone; that is not an error here.
            let _ = started_tx.send(());
            info!("REST server (on {}) started", bind_addr);
            shutdown_token.cancelled().await;
            info!("REST server (on {}) received shutdown signal", bind_addr);
        };

        let result = axum::serve(listener, router)
            .with_graceful_shutdown(shutdown_signal)
            .await;
        info!(
            "REST server (on {}) task completed: {:?}",
            bind_addr, result
        );
    });

    Ok(ServiceHandler {
        join_handle,
        cancellation_token,
        started_rx: Some(started_rx),
        reporter_handles: Vec::new(),
    })
}
/// Handler for `GET /health`: always responds with status `200 OK` and the body `OK`.
async fn health_check() -> impl IntoResponse {
    let status = StatusCode::OK;
    let body = "OK";
    (status, body)
}
async fn api_health_handler() -> axum::Json<serde_json::Value> {
axum::Json(serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"service": "rsketch",
"version": env!("CARGO_PKG_VERSION")
}))
}
/// Registers the JSON health endpoints on `router` and returns it.
///
/// Both `/api/v1/health` and `/api/health` are served by the same handler.
pub fn health_routes(router: Router) -> Router {
    let with_versioned = router.route("/api/v1/health", get(api_health_handler));
    with_versioned.route("/api/health", get(api_health_handler))
}
#[cfg(test)]
mod tests {
    use axum::{Json, routing::get};

    use super::*;

    /// Installs a debug-level tracing subscriber, discarding the result of `try_init`.
    fn init_test_logging() {
        let subscriber = tracing_subscriber::fmt().with_env_filter("debug");
        let _ = subscriber.try_init();
    }

    /// Binds to port 0 on localhost, reads back the assigned port, and releases
    /// the listener before returning that port.
    async fn get_available_port() -> u16 {
        let probe = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = probe.local_addr().unwrap();
        drop(probe);
        addr.port()
    }

    /// Default configuration bound to `127.0.0.1` on the given port.
    fn local_config(port: u16) -> RestServerConfig {
        RestServerConfig {
            bind_address: format!("127.0.0.1:{port}"),
            ..RestServerConfig::default()
        }
    }

    /// Issues `GET http://127.0.0.1:{port}{path}` and asserts a 200 response.
    async fn assert_get_ok(client: &reqwest::Client, port: u16, path: &str) {
        let url = format!("http://127.0.0.1:{port}{path}");
        let response = client.get(url).send().await.unwrap();
        assert_eq!(response.status(), 200);
    }

    #[tokio::test]
    async fn test_rest_server_lifecycle() {
        init_test_logging();
        let port = get_available_port().await;

        let routes: Vec<fn(Router) -> Router> = vec![health_routes];
        let mut server = start_rest_server(local_config(port), routes)
            .await
            .unwrap();
        server.wait_for_start().await.unwrap();

        let client = reqwest::Client::new();
        for path in ["/health", "/api/v1/health"] {
            assert_get_ok(&client, port, path).await;
        }

        server.shutdown();
        server.wait_for_stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_rest_server_without_cors() {
        init_test_logging();
        let port = get_available_port().await;

        let config = RestServerConfig {
            enable_cors: false,
            ..local_config(port)
        };
        let routes: Vec<fn(Router) -> Router> = vec![health_routes];
        let mut server = start_rest_server(config, routes).await.unwrap();
        server.wait_for_start().await.unwrap();

        let client = reqwest::Client::new();
        assert_get_ok(&client, port, "/health").await;

        server.shutdown();
        server.wait_for_stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple_route_handlers() {
        /// Extra handler used to check that several route functions compose.
        async fn goodbye_handler() -> Json<&'static str> {
            Json("Goodbye, World!")
        }

        /// Registers the goodbye endpoint on `router`.
        fn goodbye_routes(router: Router) -> Router {
            router.route("/api/v1/goodbye", get(goodbye_handler))
        }

        init_test_logging();
        let port = get_available_port().await;

        let routes: Vec<fn(Router) -> Router> = vec![health_routes, goodbye_routes];
        let mut server = start_rest_server(local_config(port), routes)
            .await
            .unwrap();
        server.wait_for_start().await.unwrap();

        let client = reqwest::Client::new();
        for path in ["/api/v1/health", "/api/v1/goodbye"] {
            assert_get_ok(&client, port, path).await;
        }

        server.shutdown();
        server.wait_for_stop().await.unwrap();
    }
}