batuta/serve/banco/
handlers_experiment.rs1use axum::{extract::State, http::StatusCode, response::Json};
4use serde::Deserialize;
5
6use super::experiment::{Experiment, RunComparison};
7use super::state::BancoState;
8use super::types::ErrorResponse;
9
10pub async fn create_experiment_handler(
12 State(state): State<BancoState>,
13 Json(request): Json<CreateExperimentRequest>,
14) -> Json<Experiment> {
15 let exp = state.experiments.create(&request.name, &request.description.unwrap_or_default());
16 Json(exp)
17}
18
19pub async fn list_experiments_handler(
21 State(state): State<BancoState>,
22) -> Json<ExperimentsResponse> {
23 Json(ExperimentsResponse { experiments: state.experiments.list() })
24}
25
26pub async fn add_run_to_experiment_handler(
28 State(state): State<BancoState>,
29 axum::extract::Path(id): axum::extract::Path<String>,
30 Json(request): Json<AddRunRequest>,
31) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
32 state.experiments.add_run(&id, &request.run_id).map(|()| StatusCode::OK).map_err(|_| {
33 (
34 StatusCode::NOT_FOUND,
35 Json(ErrorResponse::new(format!("Experiment {id} not found"), "not_found", 404)),
36 )
37 })
38}
39
40pub async fn compare_experiment_handler(
42 State(state): State<BancoState>,
43 axum::extract::Path(id): axum::extract::Path<String>,
44) -> Result<Json<RunComparison>, (StatusCode, Json<ErrorResponse>)> {
45 state.experiments.compare(&id, &state.training).map(Json).map_err(|_| {
46 (
47 StatusCode::NOT_FOUND,
48 Json(ErrorResponse::new(format!("Experiment {id} not found"), "not_found", 404)),
49 )
50 })
51}
52
53#[derive(Debug, Deserialize)]
54pub struct CreateExperimentRequest {
55 pub name: String,
56 #[serde(default)]
57 pub description: Option<String>,
58}
59
60#[derive(Debug, Deserialize)]
61pub struct AddRunRequest {
62 pub run_id: String,
63}
64
65#[derive(Debug, serde::Serialize)]
66pub struct ExperimentsResponse {
67 pub experiments: Vec<Experiment>,
68}