Skip to main content

entrenar/server/
handlers.rs

1//! HTTP request handlers
2//!
3//! Axum handlers for the tracking server API.
4
5use crate::server::{
6    state::{AppState, RunStatus},
7    ApiResponse, CreateExperimentRequest, CreateRunRequest, ExperimentResponse, HealthResponse,
8    LogMetricsRequest, LogParamsRequest, RunResponse, UpdateRunRequest,
9};
10use axum::{
11    extract::{Path, State},
12    http::StatusCode,
13    Json,
14};
15
16/// Generate a request ID
17fn request_id() -> String {
18    format!("req-{:016x}", rand::random::<u64>())
19}
20
21/// Health check handler
22pub async fn health_check(State(state): State<AppState>) -> (StatusCode, Json<HealthResponse>) {
23    let health = HealthResponse {
24        status: "healthy".to_string(),
25        version: env!("CARGO_PKG_VERSION").to_string(),
26        uptime_secs: state.uptime_secs(),
27        experiments_count: state.storage.experiments_count(),
28        runs_count: state.storage.runs_count(),
29    };
30
31    (StatusCode::OK, Json(health))
32}
33
34/// Create a new experiment
35pub async fn create_experiment(
36    State(state): State<AppState>,
37    Json(payload): Json<CreateExperimentRequest>,
38) -> (StatusCode, Json<ApiResponse<ExperimentResponse>>) {
39    let req_id = request_id();
40
41    match state.storage.create_experiment(&payload.name, payload.description, payload.tags) {
42        Ok(exp) => {
43            let response: ExperimentResponse = exp.into();
44            (StatusCode::CREATED, Json(ApiResponse::success(response, &req_id)))
45        }
46        Err(e) => {
47            (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error(&e.to_string(), &req_id)))
48        }
49    }
50}
51
52/// Get an experiment by ID
53pub async fn get_experiment(
54    State(state): State<AppState>,
55    Path(id): Path<String>,
56) -> (StatusCode, Json<ApiResponse<ExperimentResponse>>) {
57    let req_id = request_id();
58
59    match state.storage.get_experiment(&id) {
60        Ok(exp) => {
61            let response: ExperimentResponse = exp.into();
62            (StatusCode::OK, Json(ApiResponse::success(response, &req_id)))
63        }
64        Err(e) => (StatusCode::NOT_FOUND, Json(ApiResponse::error(&e.to_string(), &req_id))),
65    }
66}
67
68/// List all experiments
69pub async fn list_experiments(
70    State(state): State<AppState>,
71) -> (StatusCode, Json<ApiResponse<Vec<ExperimentResponse>>>) {
72    let req_id = request_id();
73
74    match state.storage.list_experiments() {
75        Ok(exps) => {
76            let responses: Vec<ExperimentResponse> = exps.into_iter().map(Into::into).collect();
77            (StatusCode::OK, Json(ApiResponse::success(responses, &req_id)))
78        }
79        Err(e) => {
80            (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error(&e.to_string(), &req_id)))
81        }
82    }
83}
84
85/// Create a new run
86pub async fn create_run(
87    State(state): State<AppState>,
88    Json(payload): Json<CreateRunRequest>,
89) -> (StatusCode, Json<ApiResponse<RunResponse>>) {
90    let req_id = request_id();
91
92    match state.storage.create_run(&payload.experiment_id, payload.name, payload.tags) {
93        Ok(run) => {
94            let response: RunResponse = run.into();
95            (StatusCode::CREATED, Json(ApiResponse::success(response, &req_id)))
96        }
97        Err(e) => {
98            let status = if e.to_string().contains("not found") {
99                StatusCode::NOT_FOUND
100            } else {
101                StatusCode::INTERNAL_SERVER_ERROR
102            };
103            (status, Json(ApiResponse::error(&e.to_string(), &req_id)))
104        }
105    }
106}
107
108/// Get a run by ID
109pub async fn get_run(
110    State(state): State<AppState>,
111    Path(id): Path<String>,
112) -> (StatusCode, Json<ApiResponse<RunResponse>>) {
113    let req_id = request_id();
114
115    match state.storage.get_run(&id) {
116        Ok(run) => {
117            let response: RunResponse = run.into();
118            (StatusCode::OK, Json(ApiResponse::success(response, &req_id)))
119        }
120        Err(e) => (StatusCode::NOT_FOUND, Json(ApiResponse::error(&e.to_string(), &req_id))),
121    }
122}
123
124/// Update a run
125pub async fn update_run(
126    State(state): State<AppState>,
127    Path(id): Path<String>,
128    Json(payload): Json<UpdateRunRequest>,
129) -> (StatusCode, Json<ApiResponse<RunResponse>>) {
130    let req_id = request_id();
131
132    let status = payload.status.as_ref().and_then(|s| s.parse::<RunStatus>().ok());
133
134    let end_time = payload.end_time.as_ref().and_then(|t| {
135        chrono::DateTime::parse_from_rfc3339(t).ok().map(|dt| dt.with_timezone(&chrono::Utc))
136    });
137
138    match state.storage.update_run(&id, status, end_time) {
139        Ok(run) => {
140            let response: RunResponse = run.into();
141            (StatusCode::OK, Json(ApiResponse::success(response, &req_id)))
142        }
143        Err(e) => {
144            let status_code = if e.to_string().contains("not found") {
145                StatusCode::NOT_FOUND
146            } else {
147                StatusCode::INTERNAL_SERVER_ERROR
148            };
149            (status_code, Json(ApiResponse::error(&e.to_string(), &req_id)))
150        }
151    }
152}
153
154/// Log parameters for a run
155pub async fn log_params(
156    State(state): State<AppState>,
157    Path(id): Path<String>,
158    Json(payload): Json<LogParamsRequest>,
159) -> (StatusCode, Json<ApiResponse<&'static str>>) {
160    let req_id = request_id();
161
162    match state.storage.log_params(&id, payload.params) {
163        Ok(()) => (StatusCode::OK, Json(ApiResponse::success("Parameters logged", &req_id))),
164        Err(e) => {
165            let status = if e.to_string().contains("not found") {
166                StatusCode::NOT_FOUND
167            } else {
168                StatusCode::INTERNAL_SERVER_ERROR
169            };
170            (status, Json(ApiResponse::error(&e.to_string(), &req_id)))
171        }
172    }
173}
174
175/// Log metrics for a run
176pub async fn log_metrics(
177    State(state): State<AppState>,
178    Path(id): Path<String>,
179    Json(payload): Json<LogMetricsRequest>,
180) -> (StatusCode, Json<ApiResponse<&'static str>>) {
181    let req_id = request_id();
182
183    match state.storage.log_metrics(&id, payload.metrics) {
184        Ok(()) => (StatusCode::OK, Json(ApiResponse::success("Metrics logged", &req_id))),
185        Err(e) => {
186            let status = if e.to_string().contains("not found") {
187                StatusCode::NOT_FOUND
188            } else {
189                StatusCode::INTERNAL_SERVER_ERROR
190            };
191            (status, Json(ApiResponse::error(&e.to_string(), &req_id)))
192        }
193    }
194}
195
196// =============================================================================
197// Tests
198// =============================================================================
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::ServerConfig;
    use std::collections::HashMap;

    /// Fresh state built from the default server configuration.
    fn test_state() -> AppState {
        AppState::new(ServerConfig::default())
    }

    /// Insert an experiment directly through storage and return its ID.
    fn seed_experiment(state: &AppState, name: &str) -> String {
        state.storage.create_experiment(name, None, None).expect("operation should succeed").id
    }

    /// Insert an experiment plus one run beneath it and return the run's ID.
    fn seed_run(state: &AppState) -> String {
        let exp_id = seed_experiment(state, "test");
        state.storage.create_run(&exp_id, None, None).expect("operation should succeed").id
    }

    #[tokio::test]
    async fn test_health_check() {
        let (status, Json(body)) = health_check(State(test_state())).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body.status, "healthy");
    }

    #[tokio::test]
    async fn test_create_experiment() {
        let req = CreateExperimentRequest {
            name: "test".to_string(),
            description: None,
            tags: None,
        };

        let (status, _) = create_experiment(State(test_state()), Json(req)).await;
        assert_eq!(status, StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_get_experiment() {
        let state = test_state();
        let exp_id = seed_experiment(&state, "test");

        let (status, _) = get_experiment(State(state), Path(exp_id)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_experiment_not_found() {
        let missing = Path("nonexistent".to_string());

        let (status, _) = get_experiment(State(test_state()), missing).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_list_experiments() {
        let state = test_state();
        seed_experiment(&state, "exp1");
        seed_experiment(&state, "exp2");

        let (status, _) = list_experiments(State(state)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_create_run() {
        let state = test_state();
        let req = CreateRunRequest {
            experiment_id: seed_experiment(&state, "test"),
            name: Some("run-1".to_string()),
            tags: None,
        };

        let (status, _) = create_run(State(state), Json(req)).await;
        assert_eq!(status, StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_create_run_invalid_experiment() {
        let req = CreateRunRequest {
            experiment_id: "nonexistent".to_string(),
            name: None,
            tags: None,
        };

        let (status, _) = create_run(State(test_state()), Json(req)).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_get_run() {
        let state = test_state();
        let run_id = seed_run(&state);

        let (status, _) = get_run(State(state), Path(run_id)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_update_run() {
        let state = test_state();
        let run_id = seed_run(&state);
        let req = UpdateRunRequest { status: Some("completed".to_string()), end_time: None };

        let (status, _) = update_run(State(state), Path(run_id), Json(req)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_params() {
        let state = test_state();
        let run_id = seed_run(&state);
        let req = LogParamsRequest {
            params: HashMap::from([("lr".to_string(), serde_json::json!(0.001))]),
        };

        let (status, _) = log_params(State(state), Path(run_id), Json(req)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_metrics() {
        let state = test_state();
        let run_id = seed_run(&state);
        let req = LogMetricsRequest {
            metrics: HashMap::from([("loss".to_string(), 0.5)]),
            step: None,
        };

        let (status, _) = log_metrics(State(state), Path(run_id), Json(req)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_params_not_found() {
        let req = LogParamsRequest { params: HashMap::new() };
        let missing = Path("nonexistent".to_string());

        let (status, _) = log_params(State(test_state()), missing, Json(req)).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }
}