1use crate::server::{
6 handlers::{
7 create_experiment, create_run, get_experiment, get_run, health_check, list_experiments,
8 log_metrics, log_params, update_run,
9 },
10 state::AppState,
11 Result, ServerConfig, ServerError,
12};
13use axum::{
14 routing::{get, patch, post},
15 Router,
16};
17use std::net::SocketAddr;
18use tower_http::cors::{Any, CorsLayer};
19use tower_http::trace::TraceLayer;
20
21pub struct TrackingServer {
23 config: ServerConfig,
24 state: AppState,
25}
26
27impl TrackingServer {
28 pub fn new(config: ServerConfig) -> Self {
30 let state = AppState::new(config.clone());
31 Self { config, state }
32 }
33
34 pub fn router(&self) -> Router {
36 let mut app = Router::new()
37 .route("/health", get(health_check))
39 .route("/api/v1/experiments", post(create_experiment))
41 .route("/api/v1/experiments", get(list_experiments))
42 .route("/api/v1/experiments/{id}", get(get_experiment))
43 .route("/api/v1/runs", post(create_run))
45 .route("/api/v1/runs/{id}", get(get_run))
46 .route("/api/v1/runs/{id}", patch(update_run))
47 .route("/api/v1/runs/{id}/params", post(log_params))
48 .route("/api/v1/runs/{id}/metrics", post(log_metrics))
49 .with_state(self.state.clone())
51 .layer(TraceLayer::new_for_http());
53
54 if self.config.cors_enabled {
56 let cors = CorsLayer::new().allow_origin(Any).allow_methods(Any).allow_headers(Any);
57 app = app.layer(cors);
58 }
59
60 app
61 }
62
63 pub async fn run(&self) -> Result<()> {
65 let addr = self.config.address;
66 let listener = tokio::net::TcpListener::bind(addr)
67 .await
68 .map_err(|e| ServerError::Bind(e.to_string()))?;
69
70 println!("🚀 Entrenar tracking server running on http://{addr}");
71
72 axum::serve(listener, self.router()).await.map_err(ServerError::Io)?;
73
74 Ok(())
75 }
76
77 pub fn address(&self) -> SocketAddr {
79 self.config.address
80 }
81
82 pub fn state(&self) -> &AppState {
84 &self.state
85 }
86}
87
#[cfg(test)]
mod tests {
    //! Router-level tests: each test drives a freshly built router with a
    //! single in-memory request (no socket is bound).

    use super::*;
    use axum::body::Body;
    use axum::http::{header, Method, Request, StatusCode};
    use axum::response::Response;
    use tower::ServiceExt;

    /// Origin sent by the CORS tests; the value is arbitrary because the CORS
    /// layer under test allows any origin.
    const TEST_ORIGIN: &str = "http://example.com";

    /// Builds a server from the default configuration.
    fn test_server() -> TrackingServer {
        TrackingServer::new(ServerConfig::default())
    }

    /// Sends one request through a fresh router built from `server`.
    ///
    /// When `json` is `Some`, it is sent as the body with a JSON content type;
    /// otherwise the body is empty.
    async fn send(
        server: &TrackingServer,
        method: Method,
        uri: &str,
        json: Option<&str>,
    ) -> Response {
        let builder = Request::builder().method(method).uri(uri);
        let request = match json {
            Some(payload) => builder
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(payload.to_owned())),
            None => builder.body(Body::empty()),
        }
        .expect("test request should be well-formed");

        server.router().oneshot(request).await.expect("router should produce a response")
    }

    /// Stores an experiment and a run directly in the server's storage and
    /// returns the run's resource URI (`/api/v1/runs/<id>`).
    fn seeded_run_uri(server: &TrackingServer) -> String {
        let exp = server
            .state
            .storage
            .create_experiment("test", None, None)
            .expect("seeding an experiment should succeed");
        let run = server
            .state
            .storage
            .create_run(&exp.id, None, None)
            .expect("seeding a run should succeed");

        format!("/api/v1/runs/{}", run.id)
    }

    /// Issues a `GET /health` carrying an `Origin` header against a server
    /// built from `config` and returns the `Access-Control-Allow-Origin`
    /// response header, if one was set.
    async fn cors_allow_origin(config: ServerConfig) -> Option<String> {
        let request = Request::builder()
            .uri("/health")
            .header(header::ORIGIN, TEST_ORIGIN)
            .body(Body::empty())
            .expect("test request should be well-formed");

        let response = TrackingServer::new(config)
            .router()
            .oneshot(request)
            .await
            .expect("router should produce a response");

        response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).map(|value| {
            value.to_str().expect("allow-origin header should be visible ASCII").to_owned()
        })
    }

    #[tokio::test]
    async fn test_tracking_server_new() {
        let server = test_server();
        assert_eq!(server.address().port(), 5000);
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let server = test_server();

        let response = send(&server, Method::GET, "/health", None).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_create_experiment_endpoint() {
        let server = test_server();
        let body = r#"{"name": "test-experiment"}"#;

        let response = send(&server, Method::POST, "/api/v1/experiments", Some(body)).await;

        assert_eq!(response.status(), StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_list_experiments_endpoint() {
        let server = test_server();

        let response = send(&server, Method::GET, "/api/v1/experiments", None).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_experiment_not_found() {
        let server = test_server();

        let response = send(&server, Method::GET, "/api/v1/experiments/nonexistent", None).await;

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_create_run_endpoint() {
        let server = test_server();

        // The run must reference an experiment that already exists in storage.
        let exp = server
            .state
            .storage
            .create_experiment("test", None, None)
            .expect("seeding an experiment should succeed");
        let body = format!(r#"{{"experiment_id": "{}"}}"#, exp.id);

        let response = send(&server, Method::POST, "/api/v1/runs", Some(&body)).await;

        assert_eq!(response.status(), StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_log_params_endpoint() {
        let server = test_server();
        let uri = format!("{}/params", seeded_run_uri(&server));
        let body = r#"{"params": {"lr": 0.001, "batch_size": 32}}"#;

        let response = send(&server, Method::POST, &uri, Some(body)).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_metrics_endpoint() {
        let server = test_server();
        let uri = format!("{}/metrics", seeded_run_uri(&server));
        let body = r#"{"metrics": {"loss": 0.5, "accuracy": 0.9}, "step": 100}"#;

        let response = send(&server, Method::POST, &uri, Some(body)).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_update_run_endpoint() {
        let server = test_server();
        let uri = seeded_run_uri(&server);
        let body = r#"{"status": "completed"}"#;

        let response = send(&server, Method::PATCH, &uri, Some(body)).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cors_enabled() {
        let config = ServerConfig::default();
        assert!(config.cors_enabled);

        // With CORS on, the any-origin policy must show up on the response.
        assert_eq!(cors_allow_origin(config).await.as_deref(), Some("*"));
    }

    #[tokio::test]
    async fn test_cors_disabled() {
        let config = ServerConfig::default().without_cors();
        assert!(!config.cors_enabled);

        // With CORS off, no allow-origin header may be emitted.
        assert_eq!(cors_allow_origin(config).await, None);
    }
}