1use crate::server::{
6 state::{AppState, RunStatus},
7 ApiResponse, CreateExperimentRequest, CreateRunRequest, ExperimentResponse, HealthResponse,
8 LogMetricsRequest, LogParamsRequest, RunResponse, UpdateRunRequest,
9};
10use axum::{
11 extract::{Path, State},
12 http::StatusCode,
13 Json,
14};
15
16fn request_id() -> String {
18 format!("req-{:016x}", rand::random::<u64>())
19}
20
21pub async fn health_check(State(state): State<AppState>) -> (StatusCode, Json<HealthResponse>) {
23 let health = HealthResponse {
24 status: "healthy".to_string(),
25 version: env!("CARGO_PKG_VERSION").to_string(),
26 uptime_secs: state.uptime_secs(),
27 experiments_count: state.storage.experiments_count(),
28 runs_count: state.storage.runs_count(),
29 };
30
31 (StatusCode::OK, Json(health))
32}
33
34pub async fn create_experiment(
36 State(state): State<AppState>,
37 Json(payload): Json<CreateExperimentRequest>,
38) -> (StatusCode, Json<ApiResponse<ExperimentResponse>>) {
39 let req_id = request_id();
40
41 match state.storage.create_experiment(&payload.name, payload.description, payload.tags) {
42 Ok(exp) => {
43 let response: ExperimentResponse = exp.into();
44 (StatusCode::CREATED, Json(ApiResponse::success(response, &req_id)))
45 }
46 Err(e) => {
47 (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error(&e.to_string(), &req_id)))
48 }
49 }
50}
51
52pub async fn get_experiment(
54 State(state): State<AppState>,
55 Path(id): Path<String>,
56) -> (StatusCode, Json<ApiResponse<ExperimentResponse>>) {
57 let req_id = request_id();
58
59 match state.storage.get_experiment(&id) {
60 Ok(exp) => {
61 let response: ExperimentResponse = exp.into();
62 (StatusCode::OK, Json(ApiResponse::success(response, &req_id)))
63 }
64 Err(e) => (StatusCode::NOT_FOUND, Json(ApiResponse::error(&e.to_string(), &req_id))),
65 }
66}
67
68pub async fn list_experiments(
70 State(state): State<AppState>,
71) -> (StatusCode, Json<ApiResponse<Vec<ExperimentResponse>>>) {
72 let req_id = request_id();
73
74 match state.storage.list_experiments() {
75 Ok(exps) => {
76 let responses: Vec<ExperimentResponse> = exps.into_iter().map(Into::into).collect();
77 (StatusCode::OK, Json(ApiResponse::success(responses, &req_id)))
78 }
79 Err(e) => {
80 (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error(&e.to_string(), &req_id)))
81 }
82 }
83}
84
85pub async fn create_run(
87 State(state): State<AppState>,
88 Json(payload): Json<CreateRunRequest>,
89) -> (StatusCode, Json<ApiResponse<RunResponse>>) {
90 let req_id = request_id();
91
92 match state.storage.create_run(&payload.experiment_id, payload.name, payload.tags) {
93 Ok(run) => {
94 let response: RunResponse = run.into();
95 (StatusCode::CREATED, Json(ApiResponse::success(response, &req_id)))
96 }
97 Err(e) => {
98 let status = if e.to_string().contains("not found") {
99 StatusCode::NOT_FOUND
100 } else {
101 StatusCode::INTERNAL_SERVER_ERROR
102 };
103 (status, Json(ApiResponse::error(&e.to_string(), &req_id)))
104 }
105 }
106}
107
108pub async fn get_run(
110 State(state): State<AppState>,
111 Path(id): Path<String>,
112) -> (StatusCode, Json<ApiResponse<RunResponse>>) {
113 let req_id = request_id();
114
115 match state.storage.get_run(&id) {
116 Ok(run) => {
117 let response: RunResponse = run.into();
118 (StatusCode::OK, Json(ApiResponse::success(response, &req_id)))
119 }
120 Err(e) => (StatusCode::NOT_FOUND, Json(ApiResponse::error(&e.to_string(), &req_id))),
121 }
122}
123
124pub async fn update_run(
126 State(state): State<AppState>,
127 Path(id): Path<String>,
128 Json(payload): Json<UpdateRunRequest>,
129) -> (StatusCode, Json<ApiResponse<RunResponse>>) {
130 let req_id = request_id();
131
132 let status = payload.status.as_ref().and_then(|s| s.parse::<RunStatus>().ok());
133
134 let end_time = payload.end_time.as_ref().and_then(|t| {
135 chrono::DateTime::parse_from_rfc3339(t).ok().map(|dt| dt.with_timezone(&chrono::Utc))
136 });
137
138 match state.storage.update_run(&id, status, end_time) {
139 Ok(run) => {
140 let response: RunResponse = run.into();
141 (StatusCode::OK, Json(ApiResponse::success(response, &req_id)))
142 }
143 Err(e) => {
144 let status_code = if e.to_string().contains("not found") {
145 StatusCode::NOT_FOUND
146 } else {
147 StatusCode::INTERNAL_SERVER_ERROR
148 };
149 (status_code, Json(ApiResponse::error(&e.to_string(), &req_id)))
150 }
151 }
152}
153
154pub async fn log_params(
156 State(state): State<AppState>,
157 Path(id): Path<String>,
158 Json(payload): Json<LogParamsRequest>,
159) -> (StatusCode, Json<ApiResponse<&'static str>>) {
160 let req_id = request_id();
161
162 match state.storage.log_params(&id, payload.params) {
163 Ok(()) => (StatusCode::OK, Json(ApiResponse::success("Parameters logged", &req_id))),
164 Err(e) => {
165 let status = if e.to_string().contains("not found") {
166 StatusCode::NOT_FOUND
167 } else {
168 StatusCode::INTERNAL_SERVER_ERROR
169 };
170 (status, Json(ApiResponse::error(&e.to_string(), &req_id)))
171 }
172 }
173}
174
175pub async fn log_metrics(
177 State(state): State<AppState>,
178 Path(id): Path<String>,
179 Json(payload): Json<LogMetricsRequest>,
180) -> (StatusCode, Json<ApiResponse<&'static str>>) {
181 let req_id = request_id();
182
183 match state.storage.log_metrics(&id, payload.metrics) {
184 Ok(()) => (StatusCode::OK, Json(ApiResponse::success("Metrics logged", &req_id))),
185 Err(e) => {
186 let status = if e.to_string().contains("not found") {
187 StatusCode::NOT_FOUND
188 } else {
189 StatusCode::INTERNAL_SERVER_ERROR
190 };
191 (status, Json(ApiResponse::error(&e.to_string(), &req_id)))
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::ServerConfig;
    use std::collections::HashMap;

    /// Build a fresh, empty `AppState` from the default `ServerConfig`.
    fn test_state() -> AppState {
        AppState::new(ServerConfig::default())
    }

    /// Insert an experiment named "test" into `state` and return its id.
    fn seed_experiment(state: &AppState) -> String {
        state.storage.create_experiment("test", None, None).expect("operation should succeed").id
    }

    /// Insert an experiment plus one unnamed, untagged run; return the run id.
    fn seed_run(state: &AppState) -> String {
        let experiment_id = seed_experiment(state);
        state.storage.create_run(&experiment_id, None, None).expect("operation should succeed").id
    }

    #[tokio::test]
    async fn test_health_check() {
        let (code, Json(health)) = health_check(State(test_state())).await;
        assert_eq!(code, StatusCode::OK);
        assert_eq!(health.status, "healthy");
    }

    #[tokio::test]
    async fn test_create_experiment() {
        let request =
            CreateExperimentRequest { name: "test".to_string(), description: None, tags: None };

        let (code, _) = create_experiment(State(test_state()), Json(request)).await;
        assert_eq!(code, StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_get_experiment() {
        let state = test_state();
        let experiment_id = seed_experiment(&state);

        let (code, _) = get_experiment(State(state), Path(experiment_id)).await;
        assert_eq!(code, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_experiment_not_found() {
        let (code, _) =
            get_experiment(State(test_state()), Path("nonexistent".to_string())).await;
        assert_eq!(code, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_list_experiments() {
        let state = test_state();
        for name in ["exp1", "exp2"] {
            state.storage.create_experiment(name, None, None).expect("operation should succeed");
        }

        let (code, _) = list_experiments(State(state)).await;
        assert_eq!(code, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_create_run() {
        let state = test_state();
        let request = CreateRunRequest {
            experiment_id: seed_experiment(&state),
            name: Some("run-1".to_string()),
            tags: None,
        };

        let (code, _) = create_run(State(state), Json(request)).await;
        assert_eq!(code, StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_create_run_invalid_experiment() {
        let request =
            CreateRunRequest { experiment_id: "nonexistent".to_string(), name: None, tags: None };

        let (code, _) = create_run(State(test_state()), Json(request)).await;
        assert_eq!(code, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_get_run() {
        let state = test_state();
        let run_id = seed_run(&state);

        let (code, _) = get_run(State(state), Path(run_id)).await;
        assert_eq!(code, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_update_run() {
        let state = test_state();
        let run_id = seed_run(&state);
        let request = UpdateRunRequest { status: Some("completed".to_string()), end_time: None };

        let (code, _) = update_run(State(state), Path(run_id), Json(request)).await;
        assert_eq!(code, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_params() {
        let state = test_state();
        let run_id = seed_run(&state);
        let params = HashMap::from([("lr".to_string(), serde_json::json!(0.001))]);
        let request = LogParamsRequest { params };

        let (code, _) = log_params(State(state), Path(run_id), Json(request)).await;
        assert_eq!(code, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_metrics() {
        let state = test_state();
        let run_id = seed_run(&state);
        let metrics = HashMap::from([("loss".to_string(), 0.5)]);
        let request = LogMetricsRequest { metrics, step: None };

        let (code, _) = log_metrics(State(state), Path(run_id), Json(request)).await;
        assert_eq!(code, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_log_params_not_found() {
        let request = LogParamsRequest { params: HashMap::new() };

        let (code, _) =
            log_params(State(test_state()), Path("nonexistent".to_string()), Json(request)).await;
        assert_eq!(code, StatusCode::NOT_FOUND);
    }
}