Skip to main content

entrenar/server/
api.rs

1//! API router and server setup
2//!
3//! Configures axum routes and runs the HTTP server.
4
5use crate::server::{
6    handlers::{
7        create_experiment, create_run, get_experiment, get_run, health_check, list_experiments,
8        log_metrics, log_params, update_run,
9    },
10    state::AppState,
11    Result, ServerConfig, ServerError,
12};
13use axum::{
14    routing::{get, patch, post},
15    Router,
16};
17use std::net::SocketAddr;
18use tower_http::cors::{Any, CorsLayer};
19use tower_http::trace::TraceLayer;
20
21/// Tracking server for experiment management
22pub struct TrackingServer {
23    config: ServerConfig,
24    state: AppState,
25}
26
27impl TrackingServer {
28    /// Create a new tracking server
29    pub fn new(config: ServerConfig) -> Self {
30        let state = AppState::new(config.clone());
31        Self { config, state }
32    }
33
34    /// Build the router
35    pub fn router(&self) -> Router {
36        let mut app = Router::new()
37            // Health check
38            .route("/health", get(health_check))
39            // Experiments
40            .route("/api/v1/experiments", post(create_experiment))
41            .route("/api/v1/experiments", get(list_experiments))
42            .route("/api/v1/experiments/{id}", get(get_experiment))
43            // Runs
44            .route("/api/v1/runs", post(create_run))
45            .route("/api/v1/runs/{id}", get(get_run))
46            .route("/api/v1/runs/{id}", patch(update_run))
47            .route("/api/v1/runs/{id}/params", post(log_params))
48            .route("/api/v1/runs/{id}/metrics", post(log_metrics))
49            // State
50            .with_state(self.state.clone())
51            // Tracing
52            .layer(TraceLayer::new_for_http());
53
54        // Add CORS if enabled
55        if self.config.cors_enabled {
56            let cors = CorsLayer::new().allow_origin(Any).allow_methods(Any).allow_headers(Any);
57            app = app.layer(cors);
58        }
59
60        app
61    }
62
63    /// Run the server
64    pub async fn run(&self) -> Result<()> {
65        let addr = self.config.address;
66        let listener = tokio::net::TcpListener::bind(addr)
67            .await
68            .map_err(|e| ServerError::Bind(e.to_string()))?;
69
70        println!("🚀 Entrenar tracking server running on http://{addr}");
71
72        axum::serve(listener, self.router()).await.map_err(ServerError::Io)?;
73
74        Ok(())
75    }
76
77    /// Get the configured address
78    pub fn address(&self) -> SocketAddr {
79        self.config.address
80    }
81
82    /// Get the current state (for testing)
83    pub fn state(&self) -> &AppState {
84        &self.state
85    }
86}
87
88// =============================================================================
89// Tests
90// =============================================================================
91
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, ORIGIN};
    use axum::http::{Method, Request, StatusCode};
    use axum::response::Response;
    use tower::ServiceExt;

    /// Server built from the default configuration.
    fn test_server() -> TrackingServer {
        TrackingServer::new(ServerConfig::default())
    }

    /// Body-less GET request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).expect("GET request should be well-formed")
    }

    /// Request with a JSON content type carrying `json` as its body.
    fn json_request(method: Method, uri: &str, json: &str) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(json.to_owned()))
            .expect("JSON request should be well-formed")
    }

    /// Drive a single request through a freshly built router.
    async fn send(server: &TrackingServer, request: Request<Body>) -> Response {
        server.router().oneshot(request).await.expect("router service should not fail")
    }

    /// Create an experiment and a run directly in storage, returning the
    /// run's API path with `suffix` appended.
    fn seeded_run_uri(server: &TrackingServer, suffix: &str) -> String {
        let exp = server
            .state
            .storage
            .create_experiment("test", None, None)
            .expect("experiment should be created");
        let run = server
            .state
            .storage
            .create_run(&exp.id, None, None)
            .expect("run should be created");
        format!("/api/v1/runs/{}{}", run.id, suffix)
    }

    /// GET `/health` with an `Origin` header, as a cross-origin client would.
    fn cross_origin_health_request() -> Request<Body> {
        Request::builder()
            .uri("/health")
            .header(ORIGIN, "http://example.com")
            .body(Body::empty())
            .expect("GET request should be well-formed")
    }

    #[tokio::test]
    async fn test_tracking_server_new() {
        let server = test_server();
        assert_eq!(server.address().port(), 5000);
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let server = test_server();
        let response = send(&server, get_request("/health")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_create_experiment_endpoint() {
        let server = test_server();
        let body = r#"{"name": "test-experiment"}"#;
        let response =
            send(&server, json_request(Method::POST, "/api/v1/experiments", body)).await;
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_list_experiments_endpoint() {
        let server = test_server();
        let response = send(&server, get_request("/api/v1/experiments")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_experiment_not_found() {
        let server = test_server();
        let response = send(&server, get_request("/api/v1/experiments/nonexistent")).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_create_run_endpoint() {
        let server = test_server();

        // A run can only be created under an existing experiment.
        let exp = server
            .state
            .storage
            .create_experiment("test", None, None)
            .expect("experiment should be created");
        let body = format!(r#"{{"experiment_id": "{}"}}"#, exp.id);

        let response = send(&server, json_request(Method::POST, "/api/v1/runs", &body)).await;
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_log_params_endpoint() {
        let server = test_server();
        let uri = seeded_run_uri(&server, "/params");
        let body = r#"{"params": {"lr": 0.001, "batch_size": 32}}"#;

        let response = send(&server, json_request(Method::POST, &uri, body)).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_metrics_endpoint() {
        let server = test_server();
        let uri = seeded_run_uri(&server, "/metrics");
        let body = r#"{"metrics": {"loss": 0.5, "accuracy": 0.9}, "step": 100}"#;

        let response = send(&server, json_request(Method::POST, &uri, body)).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_update_run_endpoint() {
        let server = test_server();
        let uri = seeded_run_uri(&server, "");
        let body = r#"{"status": "completed"}"#;

        let response = send(&server, json_request(Method::PATCH, &uri, body)).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cors_enabled() {
        let config = ServerConfig::default();
        assert!(config.cors_enabled);

        let server = TrackingServer::new(config);
        let response = send(&server, cross_origin_health_request()).await;

        // The CORS layer must actually be applied, not merely compile.
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key(ACCESS_CONTROL_ALLOW_ORIGIN));
    }

    #[tokio::test]
    async fn test_cors_disabled() {
        let config = ServerConfig::default().without_cors();
        assert!(!config.cors_enabled);

        let server = TrackingServer::new(config);
        let response = send(&server, cross_origin_health_request()).await;

        // Without the CORS layer no allow-origin header is expected.
        assert_eq!(response.status(), StatusCode::OK);
        assert!(!response.headers().contains_key(ACCESS_CONTROL_ALLOW_ORIGIN));
    }
}