1use axum::{
16 Router, extract::DefaultBodyLimit, http::StatusCode, response::IntoResponse, routing::get,
17};
18use axum_tracing_opentelemetry::middleware::{OtelAxumLayer, OtelInResponseLayer};
19use rsketch_base::readable_size::ReadableSize;
20use rsketch_error::{ParseAddressSnafu, Result};
21use serde::{Deserialize, Serialize};
22use smart_default::SmartDefault;
23use snafu::ResultExt;
24use tokio::sync::oneshot;
25use tokio_util::sync::CancellationToken;
26use tower_http::cors::{Any, CorsLayer};
27use tracing::info;
28
29use super::ServiceHandler;
30
/// Default cap on the size of an HTTP request body; used as the default value of
/// [`RestServerConfig::max_body_size`].
///
/// Built with `ReadableSize::mb(100)`. Whether one "mb" is 10^6 or 2^20 bytes is
/// defined by `ReadableSize` (NOTE(review): confirm against `rsketch_base`).
pub const DEFAULT_MAX_HTTP_BODY_SIZE: ReadableSize = ReadableSize::mb(100);
33
/// Configuration for the REST server started by [`start_rest_server`].
///
/// Field defaults come from `SmartDefault` (see each field). A builder is
/// generated by `bon::Builder`, and the type is (de)serializable with serde.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, SmartDefault, bon::Builder)]
pub struct RestServerConfig {
    /// Address to listen on, in `ip:port` form.
    ///
    /// It is parsed as a `std::net::SocketAddr`, so host names such as
    /// `localhost:3000` are rejected. Defaults to `127.0.0.1:3000`.
    #[default = "127.0.0.1:3000"]
    pub bind_address: String,
    /// Maximum accepted request body size, passed to axum's `DefaultBodyLimit::max`.
    /// Defaults to [`DEFAULT_MAX_HTTP_BODY_SIZE`].
    #[default(_code = "DEFAULT_MAX_HTTP_BODY_SIZE")]
    pub max_body_size: ReadableSize,
    /// When `true`, a permissive CORS layer is installed that allows any origin,
    /// method and header. Defaults to `true`.
    #[default = true]
    pub enable_cors: bool,
}
47
48#[allow(clippy::unused_async)]
96pub async fn start_rest_server<F>(
97 config: RestServerConfig,
98 route_handlers: Vec<F>,
99) -> Result<ServiceHandler>
100where
101 F: Fn(Router) -> Router + Send + Sync + 'static,
102{
103 let bind_addr = config
105 .bind_address
106 .parse::<std::net::SocketAddr>()
107 .context(ParseAddressSnafu {
108 addr: config.bind_address.clone(),
109 })?;
110
111 let mut router = Router::new()
113 .route("/health", get(health_check))
114 .layer(OtelInResponseLayer)
115 .layer(OtelAxumLayer::default())
116 .layer({
117 #[allow(clippy::cast_possible_truncation)]
118 DefaultBodyLimit::max(config.max_body_size.as_bytes() as usize)
119 });
120
121 if config.enable_cors {
123 let cors = CorsLayer::new()
124 .allow_origin(Any)
125 .allow_methods(Any)
126 .allow_headers(Any);
127 router = router.layer(cors);
128 }
129
130 for handler in &route_handlers {
132 info!("Registering REST route handler");
133 router = handler(router);
134 }
135
136 let cancellation_token = CancellationToken::new();
138 let (join_handle, started_rx) = {
139 let (started_tx, started_rx) = oneshot::channel::<()>();
140 let cancellation_token_clone = cancellation_token.clone();
141 let join_handle = tokio::spawn(async move {
142 let listener = tokio::net::TcpListener::bind(bind_addr).await.unwrap();
143 let result = axum::serve(listener, router)
144 .with_graceful_shutdown(async move {
145 info!("REST server (on {}) starting", bind_addr);
146 let _ = started_tx.send(());
147 info!("REST server (on {}) started", bind_addr);
148 cancellation_token_clone.cancelled().await;
149 info!("REST server (on {}) received shutdown signal", bind_addr);
150 })
151 .await;
152
153 info!(
154 "REST server (on {}) task completed: {:?}",
155 bind_addr, result
156 );
157 });
158 (join_handle, started_rx)
159 };
160
161 Ok(ServiceHandler {
162 join_handle,
163 cancellation_token,
164 started_rx: Some(started_rx),
165 reporter_handles: Vec::new(), })
167}
168
169async fn health_check() -> impl IntoResponse { (StatusCode::OK, "OK") }
171
172async fn api_health_handler() -> axum::Json<serde_json::Value> {
174 axum::Json(serde_json::json!({
175 "status": "healthy",
176 "timestamp": chrono::Utc::now().to_rfc3339(),
177 "service": "rsketch",
178 "version": env!("CARGO_PKG_VERSION")
179 }))
180}
181
182pub fn health_routes(router: Router) -> Router {
188 router
189 .route("/api/v1/health", get(api_health_handler))
190 .route("/api/health", get(api_health_handler))
191}
192
#[cfg(test)]
mod tests {
    use axum::{Json, routing::get};

    use super::*;

    /// Installs a `debug`-level tracing subscriber. `try_init` fails once a global
    /// subscriber exists, so the result is ignored and repeated calls are harmless.
    fn init_test_logging() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter("debug")
            .try_init();
    }

    /// Lets the OS pick a free port by binding `127.0.0.1:0`, then releases it.
    ///
    /// NOTE(review): the port is only known to be free at probe time. Another
    /// process could claim it before the server binds, which would make a test flaky.
    async fn get_available_port() -> u16 {
        let probe = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = probe.local_addr().unwrap();
        drop(probe);
        local.port()
    }

    /// Default server configuration, bound to `127.0.0.1:{port}`.
    fn config_on(port: u16) -> RestServerConfig {
        RestServerConfig {
            bind_address: format!("127.0.0.1:{port}"),
            ..RestServerConfig::default()
        }
    }

    /// Sends `GET http://127.0.0.1:{port}{path}` and returns the numeric status code.
    async fn status_of(client: &reqwest::Client, port: u16, path: &str) -> u16 {
        let response = client
            .get(format!("http://127.0.0.1:{port}{path}"))
            .send()
            .await
            .unwrap();
        response.status().as_u16()
    }

    #[tokio::test]
    async fn test_rest_server_lifecycle() {
        init_test_logging();

        let port = get_available_port().await;
        let handlers: Vec<fn(Router) -> Router> = vec![health_routes];
        let mut handler = start_rest_server(config_on(port), handlers).await.unwrap();
        handler.wait_for_start().await.unwrap();

        let client = reqwest::Client::new();
        // Built-in liveness route first, then the JSON route added by `health_routes`.
        for path in ["/health", "/api/v1/health"] {
            assert_eq!(status_of(&client, port, path).await, 200, "GET {path}");
        }

        handler.shutdown();
        handler.wait_for_stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_rest_server_without_cors() {
        init_test_logging();

        let port = get_available_port().await;
        let config = RestServerConfig {
            enable_cors: false,
            ..config_on(port)
        };
        let mut handler = start_rest_server(config, vec![health_routes]).await.unwrap();
        handler.wait_for_start().await.unwrap();

        let client = reqwest::Client::new();
        assert_eq!(status_of(&client, port, "/health").await, 200);

        handler.shutdown();
        handler.wait_for_stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple_route_handlers() {
        init_test_logging();

        async fn goodbye_handler() -> Json<&'static str> {
            Json("Goodbye, World!")
        }

        fn goodbye_routes(router: Router) -> Router {
            router.route("/api/v1/goodbye", get(goodbye_handler))
        }

        let port = get_available_port().await;
        let handlers = vec![health_routes, goodbye_routes];
        let mut handler = start_rest_server(config_on(port), handlers).await.unwrap();
        handler.wait_for_start().await.unwrap();

        let client = reqwest::Client::new();
        // One route from each registered handler.
        for path in ["/api/v1/health", "/api/v1/goodbye"] {
            assert_eq!(status_of(&client, port, path).await, 200, "GET {path}");
        }

        handler.shutdown();
        handler.wait_for_stop().await.unwrap();
    }
}