Skip to main content

batuta/serve/banco/
handlers_experiment.rs

1//! Experiment endpoint handlers.
2
3use axum::{extract::State, http::StatusCode, response::Json};
4use serde::Deserialize;
5
6use super::experiment::{Experiment, RunComparison};
7use super::state::BancoState;
8use super::types::ErrorResponse;
9
10/// POST /api/v1/experiments — create experiment.
11pub async fn create_experiment_handler(
12    State(state): State<BancoState>,
13    Json(request): Json<CreateExperimentRequest>,
14) -> Json<Experiment> {
15    let exp = state.experiments.create(&request.name, &request.description.unwrap_or_default());
16    Json(exp)
17}
18
19/// GET /api/v1/experiments — list experiments.
20pub async fn list_experiments_handler(
21    State(state): State<BancoState>,
22) -> Json<ExperimentsResponse> {
23    Json(ExperimentsResponse { experiments: state.experiments.list() })
24}
25
26/// POST /api/v1/experiments/:id/runs — add a run to an experiment.
27pub async fn add_run_to_experiment_handler(
28    State(state): State<BancoState>,
29    axum::extract::Path(id): axum::extract::Path<String>,
30    Json(request): Json<AddRunRequest>,
31) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
32    state.experiments.add_run(&id, &request.run_id).map(|()| StatusCode::OK).map_err(|_| {
33        (
34            StatusCode::NOT_FOUND,
35            Json(ErrorResponse::new(format!("Experiment {id} not found"), "not_found", 404)),
36        )
37    })
38}
39
40/// GET /api/v1/experiments/:id/compare — compare runs in experiment.
41pub async fn compare_experiment_handler(
42    State(state): State<BancoState>,
43    axum::extract::Path(id): axum::extract::Path<String>,
44) -> Result<Json<RunComparison>, (StatusCode, Json<ErrorResponse>)> {
45    state.experiments.compare(&id, &state.training).map(Json).map_err(|_| {
46        (
47            StatusCode::NOT_FOUND,
48            Json(ErrorResponse::new(format!("Experiment {id} not found"), "not_found", 404)),
49        )
50    })
51}
52
53#[derive(Debug, Deserialize)]
54pub struct CreateExperimentRequest {
55    pub name: String,
56    #[serde(default)]
57    pub description: Option<String>,
58}
59
60#[derive(Debug, Deserialize)]
61pub struct AddRunRequest {
62    pub run_id: String,
63}
64
65#[derive(Debug, serde::Serialize)]
66pub struct ExperimentsResponse {
67    pub experiments: Vec<Experiment>,
68}