Skip to main content

batuta/serve/banco/
handlers_train.rs

//! Training run endpoint handlers — start, list, get, stop, delete, metrics SSE, export, and presets.
2
3use axum::extract::State;
4use axum::http::StatusCode;
5use axum::response::Json;
6use serde::Deserialize;
7
8use super::state::BancoState;
9use super::training::{
10    ExportFormat, ExportRequest, ExportResult, TrainingConfig, TrainingMethod, TrainingPreset,
11    TrainingRun, TrainingStatus,
12};
13use super::types::ErrorResponse;
14
15/// POST /api/v1/train/start — start a training run (with optional preset).
16pub async fn start_training_handler(
17    State(state): State<BancoState>,
18    Json(request): Json<StartTrainingRequest>,
19) -> Json<TrainingRun> {
20    // Expand preset if provided, otherwise use explicit method + config
21    let (method, config) = if let Some(preset) = &request.preset {
22        preset.expand()
23    } else {
24        (
25            request.method.clone().unwrap_or(TrainingMethod::Lora),
26            request.config.clone().unwrap_or_default(),
27        )
28    };
29
30    let mut run = state.training.start(&request.dataset_id, method.clone(), config.clone());
31
32    // Run training (real with ml feature, simulated without)
33    state.training.set_status(&run.id, TrainingStatus::Running);
34    state.events.emit(&super::events::BancoEvent::TrainingStarted {
35        run_id: run.id.clone(),
36        method: format!("{method:?}").to_lowercase(),
37    });
38
39    // Try real loss computation via model forward pass
40    #[cfg(feature = "realizar")]
41    let real_loss = {
42        let dataset = state.recipes.get_dataset(&request.dataset_id);
43        let text = dataset
44            .as_ref()
45            .map(|d| d.records.iter().map(|r| r.text.as_str()).collect::<Vec<_>>().join(" "))
46            .unwrap_or_else(|| "The quick brown fox jumps over the lazy dog.".to_string());
47
48        let token_ids = state.model.encode_text(&text);
49        state
50            .model
51            .quantized_model()
52            .and_then(|m| super::training_engine::compute_training_loss(&m, &token_ids, 128))
53    };
54
55    // Use real dataset from recipe output if available, else placeholder
56    let dataset = state.recipes.get_dataset(&request.dataset_id);
57    let data_size = dataset.as_ref().map(|d| d.record_count).unwrap_or(100);
58    let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; data_size.max(1)];
59
60    let vocab_size = state.model.info().and_then(|i| i.vocab_size).unwrap_or(32000);
61    let mut metrics = super::training::run_lora_training(&config, &data, vocab_size);
62
63    // If we got real loss from model forward pass, replace first metric with it
64    #[cfg(feature = "realizar")]
65    if let Some((real_loss_val, tokens_eval)) = real_loss {
66        if let Some(first) = metrics.first_mut() {
67            first.loss = real_loss_val;
68            first.tokens_per_sec = Some(tokens_eval as u64);
69        }
70        run.simulated = false; // At least one metric is real
71    }
72
73    for m in &metrics {
74        state.training.push_metric(&run.id, m.clone());
75    }
76
77    state.training.set_status(&run.id, TrainingStatus::Complete);
78    state.events.emit(&super::events::BancoEvent::TrainingComplete { run_id: run.id.clone() });
79    run.status = TrainingStatus::Complete;
80    run.metrics = metrics;
81
82    Json(run)
83}
84
85/// GET /api/v1/train/runs — list training runs.
86pub async fn list_training_runs_handler(
87    State(state): State<BancoState>,
88) -> Json<TrainingRunsResponse> {
89    Json(TrainingRunsResponse { runs: state.training.list() })
90}
91
92/// GET /api/v1/train/runs/:id — get run status.
93pub async fn get_training_run_handler(
94    State(state): State<BancoState>,
95    axum::extract::Path(id): axum::extract::Path<String>,
96) -> Result<Json<TrainingRun>, (StatusCode, Json<ErrorResponse>)> {
97    state.training.get(&id).map(Json).ok_or((
98        StatusCode::NOT_FOUND,
99        Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
100    ))
101}
102
103/// POST /api/v1/train/runs/:id/stop — stop a running training.
104pub async fn stop_training_handler(
105    State(state): State<BancoState>,
106    axum::extract::Path(id): axum::extract::Path<String>,
107) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
108    state.training.stop(&id).map(|()| StatusCode::OK).map_err(|_| {
109        (
110            StatusCode::NOT_FOUND,
111            Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
112        )
113    })
114}
115
116/// DELETE /api/v1/train/runs/:id — delete a run.
117pub async fn delete_training_run_handler(
118    State(state): State<BancoState>,
119    axum::extract::Path(id): axum::extract::Path<String>,
120) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
121    state.training.delete(&id).map(|()| StatusCode::NO_CONTENT).map_err(|_| {
122        (
123            StatusCode::NOT_FOUND,
124            Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
125        )
126    })
127}
128
129/// GET /api/v1/train/runs/:id/metrics — stream metrics via SSE.
130pub async fn training_metrics_handler(
131    State(state): State<BancoState>,
132    axum::extract::Path(id): axum::extract::Path<String>,
133) -> Result<
134    axum::response::sse::Sse<
135        impl futures_util::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
136    >,
137    (StatusCode, Json<ErrorResponse>),
138> {
139    let run = state.training.get(&id).ok_or((
140        StatusCode::NOT_FOUND,
141        Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
142    ))?;
143
144    let stream = async_stream::stream! {
145        for metric in &run.metrics {
146            let data = serde_json::to_string(metric).unwrap_or_default();
147            yield Ok(axum::response::sse::Event::default().data(data));
148        }
149        yield Ok(axum::response::sse::Event::default().data("[DONE]"));
150    };
151
152    Ok(axum::response::sse::Sse::new(stream))
153}
154
155/// POST /api/v1/train/runs/:id/export — export adapter or merged model.
156pub async fn export_training_handler(
157    State(state): State<BancoState>,
158    axum::extract::Path(id): axum::extract::Path<String>,
159    Json(request): Json<ExportRequest>,
160) -> Result<Json<ExportResult>, (StatusCode, Json<ErrorResponse>)> {
161    let run = state.training.get(&id).ok_or((
162        StatusCode::NOT_FOUND,
163        Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
164    ))?;
165
166    if run.status != TrainingStatus::Complete {
167        return Err((
168            StatusCode::BAD_REQUEST,
169            Json(ErrorResponse::new(
170                format!("Run {} is {:?}, not complete", id, run.status),
171                "invalid_status",
172                400,
173            )),
174        ));
175    }
176
177    let ext = match &request.format {
178        ExportFormat::Safetensors => "safetensors",
179        ExportFormat::Gguf => "gguf",
180        ExportFormat::Apr => "apr",
181    };
182    let filename =
183        if request.merge { format!("{id}-merged.{ext}") } else { format!("{id}-adapter.{ext}") };
184    let path = format!("~/.banco/exports/{filename}");
185
186    state.training.set_export_path(&id, &path);
187
188    Ok(Json(ExportResult {
189        run_id: id,
190        format: request.format,
191        merged: request.merge,
192        path,
193        size_bytes: 0, // populated when real export happens
194    }))
195}
196
197/// GET /api/v1/train/presets — list available training presets.
198pub async fn list_presets_handler() -> Json<PresetsResponse> {
199    let presets: Vec<PresetInfo> = TrainingPreset::all()
200        .into_iter()
201        .map(|p| {
202            let (method, config) = p.expand();
203            PresetInfo {
204                name: format!("{p:?}").to_lowercase().replace("lora", "-lora"),
205                method: format!("{method:?}").to_lowercase(),
206                lora_r: config.lora_r,
207                epochs: config.epochs,
208                learning_rate: config.learning_rate,
209            }
210        })
211        .collect();
212    Json(PresetsResponse { presets })
213}
214
215#[derive(Debug, Deserialize)]
216pub struct StartTrainingRequest {
217    pub dataset_id: String,
218    #[serde(default)]
219    pub method: Option<TrainingMethod>,
220    #[serde(default)]
221    pub config: Option<TrainingConfig>,
222    #[serde(default)]
223    pub preset: Option<TrainingPreset>,
224}
225
226#[derive(Debug, serde::Serialize)]
227pub struct TrainingRunsResponse {
228    pub runs: Vec<TrainingRun>,
229}
230
231#[derive(Debug, serde::Serialize)]
232pub struct PresetsResponse {
233    pub presets: Vec<PresetInfo>,
234}
235
236#[derive(Debug, serde::Serialize)]
237pub struct PresetInfo {
238    pub name: String,
239    pub method: String,
240    pub lora_r: u32,
241    pub epochs: u32,
242    pub learning_rate: f64,
243}