rsketch_server/
http.rs

1// Copyright 2025 Crrow
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use axum::{
16    Router, extract::DefaultBodyLimit, http::StatusCode, response::IntoResponse, routing::get,
17};
18use axum_tracing_opentelemetry::middleware::{OtelAxumLayer, OtelInResponseLayer};
19use rsketch_base::readable_size::ReadableSize;
20use rsketch_error::{ParseAddressSnafu, Result};
21use serde::{Deserialize, Serialize};
22use smart_default::SmartDefault;
23use snafu::ResultExt;
24use tokio::sync::oneshot;
25use tokio_util::sync::CancellationToken;
26use tower_http::cors::{Any, CorsLayer};
27use tracing::info;
28
29use super::ServiceHandler;
30
31/// Default maximum HTTP request body size (100 MB)
32pub const DEFAULT_MAX_HTTP_BODY_SIZE: ReadableSize = ReadableSize::mb(100);
33
34/// Configuration options for a REST server
35#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, SmartDefault, bon::Builder)]
36pub struct RestServerConfig {
37    /// The address to bind the REST server
38    #[default = "127.0.0.1:3000"]
39    pub bind_address:  String,
40    /// Maximum HTTP request body size
41    #[default(_code = "DEFAULT_MAX_HTTP_BODY_SIZE")]
42    pub max_body_size: ReadableSize,
43    /// Whether to enable CORS
44    #[default = true]
45    pub enable_cors:   bool,
46}
47
48/// Starts the REST server and returns a handle for managing its lifecycle.
49///
50/// This method:
51/// 1. Sets up the Axum router with middleware (CORS, body size limits)
52/// 2. Registers all provided route handlers
53/// 3. Parses and binds to the configured address
54/// 4. Spawns the server in a background task
55/// 5. Returns a handle for lifecycle management
56///
57/// The server will automatically register all provided route handlers and
58/// supports graceful shutdown through the returned handle.
59///
60/// # Arguments
61/// * `config` - Configuration for the REST server
62/// * `route_handlers` - Vector of functions that take a Router and return a
63///   modified Router
64///
65/// # Errors
66/// Returns an error if the bind address cannot be parsed.
67///
68/// # Example
69///
70/// ```rust,ignore
71/// use axum::{Router, routing::get};
72/// use rsketch_server::http::{RestServerConfig, start_rest_server};
73///
74/// fn my_routes(router: Router) -> Router {
75///     router.route("/api/v1/hello", get(|| async { "Hello, World!" }))
76/// }
77///
78/// #[tokio::main]
79/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
80///     let config = RestServerConfig::default();
81///     let handlers = vec![my_routes];
82///     let handle = start_rest_server(config, handlers).await?;
83///     Ok(())
84/// }
85/// ```
86///
87/// # Errors
88///
89/// Returns an error if server binding fails or graceful shutdown encounters
90/// issues.
91///
92/// # Panics
93///
94/// May panic if TcpListener binding fails within the spawn context.
95#[allow(clippy::unused_async)]
96pub async fn start_rest_server<F>(
97    config: RestServerConfig,
98    route_handlers: Vec<F>,
99) -> Result<ServiceHandler>
100where
101    F: Fn(Router) -> Router + Send + Sync + 'static,
102{
103    // Parse bind address
104    let bind_addr = config
105        .bind_address
106        .parse::<std::net::SocketAddr>()
107        .context(ParseAddressSnafu {
108            addr: config.bind_address.clone(),
109        })?;
110
111    // Build the router with middleware
112    let mut router = Router::new()
113        .route("/health", get(health_check))
114        .layer(OtelInResponseLayer)
115        .layer(OtelAxumLayer::default())
116        .layer({
117            #[allow(clippy::cast_possible_truncation)]
118            DefaultBodyLimit::max(config.max_body_size.as_bytes() as usize)
119        });
120
121    // Add CORS if enabled
122    if config.enable_cors {
123        let cors = CorsLayer::new()
124            .allow_origin(Any)
125            .allow_methods(Any)
126            .allow_headers(Any);
127        router = router.layer(cors);
128    }
129
130    // Register route handlers
131    for handler in &route_handlers {
132        info!("Registering REST route handler");
133        router = handler(router);
134    }
135
136    // Spawn the server task
137    let cancellation_token = CancellationToken::new();
138    let (join_handle, started_rx) = {
139        let (started_tx, started_rx) = oneshot::channel::<()>();
140        let cancellation_token_clone = cancellation_token.clone();
141        let join_handle = tokio::spawn(async move {
142            let listener = tokio::net::TcpListener::bind(bind_addr).await.unwrap();
143            let result = axum::serve(listener, router)
144                .with_graceful_shutdown(async move {
145                    info!("REST server (on {}) starting", bind_addr);
146                    let _ = started_tx.send(());
147                    info!("REST server (on {}) started", bind_addr);
148                    cancellation_token_clone.cancelled().await;
149                    info!("REST server (on {}) received shutdown signal", bind_addr);
150                })
151                .await;
152
153            info!(
154                "REST server (on {}) task completed: {:?}",
155                bind_addr, result
156            );
157        });
158        (join_handle, started_rx)
159    };
160
161    Ok(ServiceHandler {
162        join_handle,
163        cancellation_token,
164        started_rx: Some(started_rx),
165        reporter_handles: Vec::new(), // No readiness reporting for simple route handlers
166    })
167}
168
169/// Health check endpoint for the REST server
170async fn health_check() -> impl IntoResponse { (StatusCode::OK, "OK") }
171
172/// Health check handler that returns detailed health information
173async fn api_health_handler() -> axum::Json<serde_json::Value> {
174    axum::Json(serde_json::json!({
175        "status": "healthy",
176        "timestamp": chrono::Utc::now().to_rfc3339(),
177        "service": "rsketch",
178        "version": env!("CARGO_PKG_VERSION")
179    }))
180}
181
182/// Add health routes to the router
183///
184/// This function adds health check endpoints for API monitoring and readiness
185/// checks. It provides both simple health check and detailed health information
186/// endpoints.
187pub fn health_routes(router: Router) -> Router {
188    router
189        .route("/api/v1/health", get(api_health_handler))
190        .route("/api/health", get(api_health_handler))
191}
192
#[cfg(test)]
mod tests {
    use axum::{Json, routing::get};

    use super::*;

    /// Installs a debug-level tracing subscriber; later calls are no-ops.
    fn init_test_logging() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter("debug")
            .try_init();
    }

    /// Asks the OS for a free port by binding to port 0, then releases it.
    ///
    /// The port is not reserved after this returns, so another process could
    /// in principle take it before the server binds.
    async fn get_available_port() -> u16 {
        let probe = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        probe.local_addr().unwrap().port()
    }

    /// Default server config bound to `127.0.0.1:{port}`.
    fn local_config(port: u16) -> RestServerConfig {
        RestServerConfig {
            bind_address: format!("127.0.0.1:{port}"),
            ..RestServerConfig::default()
        }
    }

    /// Sends a GET for `path` to the local server and expects HTTP 200.
    async fn assert_get_ok(client: &reqwest::Client, port: u16, path: &str) {
        let url = format!("http://127.0.0.1:{port}{path}");
        let response = client.get(url).send().await.unwrap();
        assert_eq!(response.status(), 200);
    }

    #[tokio::test]
    async fn test_rest_server_lifecycle() {
        init_test_logging();

        let port = get_available_port().await;
        let handlers: Vec<fn(Router) -> Router> = vec![health_routes];
        let mut handler = start_rest_server(local_config(port), handlers)
            .await
            .unwrap();
        handler.wait_for_start().await.unwrap();

        // Built-in probe plus the route contributed by `health_routes`.
        let client = reqwest::Client::new();
        for path in ["/health", "/api/v1/health"] {
            assert_get_ok(&client, port, path).await;
        }

        handler.shutdown();
        handler.wait_for_stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_rest_server_without_cors() {
        init_test_logging();

        let port = get_available_port().await;
        let config = RestServerConfig {
            enable_cors: false,
            ..local_config(port)
        };
        let handlers = vec![health_routes];
        let mut handler = start_rest_server(config, handlers).await.unwrap();
        handler.wait_for_start().await.unwrap();

        let client = reqwest::Client::new();
        assert_get_ok(&client, port, "/health").await;

        handler.shutdown();
        handler.wait_for_stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple_route_handlers() {
        init_test_logging();

        async fn goodbye_handler() -> Json<&'static str> { Json("Goodbye, World!") }

        fn goodbye_routes(router: Router) -> Router {
            router.route("/api/v1/goodbye", get(goodbye_handler))
        }

        let port = get_available_port().await;
        let handlers = vec![health_routes, goodbye_routes];
        let mut handler = start_rest_server(local_config(port), handlers)
            .await
            .unwrap();
        handler.wait_for_start().await.unwrap();

        // One route from each registered handler must be reachable.
        let client = reqwest::Client::new();
        for path in ["/api/v1/health", "/api/v1/goodbye"] {
            assert_get_ok(&client, port, path).await;
        }

        handler.shutdown();
        handler.wait_for_stop().await.unwrap();
    }
}