1use axum::extract::State;
4use axum::http::StatusCode;
5use axum::response::Json;
6use serde::Deserialize;
7
8use super::state::BancoState;
9use super::training::{
10 ExportFormat, ExportRequest, ExportResult, TrainingConfig, TrainingMethod, TrainingPreset,
11 TrainingRun, TrainingStatus,
12};
13use super::types::ErrorResponse;
14
15pub async fn start_training_handler(
17 State(state): State<BancoState>,
18 Json(request): Json<StartTrainingRequest>,
19) -> Json<TrainingRun> {
20 let (method, config) = if let Some(preset) = &request.preset {
22 preset.expand()
23 } else {
24 (
25 request.method.clone().unwrap_or(TrainingMethod::Lora),
26 request.config.clone().unwrap_or_default(),
27 )
28 };
29
30 let mut run = state.training.start(&request.dataset_id, method.clone(), config.clone());
31
32 state.training.set_status(&run.id, TrainingStatus::Running);
34 state.events.emit(&super::events::BancoEvent::TrainingStarted {
35 run_id: run.id.clone(),
36 method: format!("{method:?}").to_lowercase(),
37 });
38
39 #[cfg(feature = "realizar")]
41 let real_loss = {
42 let dataset = state.recipes.get_dataset(&request.dataset_id);
43 let text = dataset
44 .as_ref()
45 .map(|d| d.records.iter().map(|r| r.text.as_str()).collect::<Vec<_>>().join(" "))
46 .unwrap_or_else(|| "The quick brown fox jumps over the lazy dog.".to_string());
47
48 let token_ids = state.model.encode_text(&text);
49 state
50 .model
51 .quantized_model()
52 .and_then(|m| super::training_engine::compute_training_loss(&m, &token_ids, 128))
53 };
54
55 let dataset = state.recipes.get_dataset(&request.dataset_id);
57 let data_size = dataset.as_ref().map(|d| d.record_count).unwrap_or(100);
58 let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; data_size.max(1)];
59
60 let vocab_size = state.model.info().and_then(|i| i.vocab_size).unwrap_or(32000);
61 let mut metrics = super::training::run_lora_training(&config, &data, vocab_size);
62
63 #[cfg(feature = "realizar")]
65 if let Some((real_loss_val, tokens_eval)) = real_loss {
66 if let Some(first) = metrics.first_mut() {
67 first.loss = real_loss_val;
68 first.tokens_per_sec = Some(tokens_eval as u64);
69 }
70 run.simulated = false; }
72
73 for m in &metrics {
74 state.training.push_metric(&run.id, m.clone());
75 }
76
77 state.training.set_status(&run.id, TrainingStatus::Complete);
78 state.events.emit(&super::events::BancoEvent::TrainingComplete { run_id: run.id.clone() });
79 run.status = TrainingStatus::Complete;
80 run.metrics = metrics;
81
82 Json(run)
83}
84
85pub async fn list_training_runs_handler(
87 State(state): State<BancoState>,
88) -> Json<TrainingRunsResponse> {
89 Json(TrainingRunsResponse { runs: state.training.list() })
90}
91
92pub async fn get_training_run_handler(
94 State(state): State<BancoState>,
95 axum::extract::Path(id): axum::extract::Path<String>,
96) -> Result<Json<TrainingRun>, (StatusCode, Json<ErrorResponse>)> {
97 state.training.get(&id).map(Json).ok_or((
98 StatusCode::NOT_FOUND,
99 Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
100 ))
101}
102
103pub async fn stop_training_handler(
105 State(state): State<BancoState>,
106 axum::extract::Path(id): axum::extract::Path<String>,
107) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
108 state.training.stop(&id).map(|()| StatusCode::OK).map_err(|_| {
109 (
110 StatusCode::NOT_FOUND,
111 Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
112 )
113 })
114}
115
116pub async fn delete_training_run_handler(
118 State(state): State<BancoState>,
119 axum::extract::Path(id): axum::extract::Path<String>,
120) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
121 state.training.delete(&id).map(|()| StatusCode::NO_CONTENT).map_err(|_| {
122 (
123 StatusCode::NOT_FOUND,
124 Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
125 )
126 })
127}
128
129pub async fn training_metrics_handler(
131 State(state): State<BancoState>,
132 axum::extract::Path(id): axum::extract::Path<String>,
133) -> Result<
134 axum::response::sse::Sse<
135 impl futures_util::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
136 >,
137 (StatusCode, Json<ErrorResponse>),
138> {
139 let run = state.training.get(&id).ok_or((
140 StatusCode::NOT_FOUND,
141 Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
142 ))?;
143
144 let stream = async_stream::stream! {
145 for metric in &run.metrics {
146 let data = serde_json::to_string(metric).unwrap_or_default();
147 yield Ok(axum::response::sse::Event::default().data(data));
148 }
149 yield Ok(axum::response::sse::Event::default().data("[DONE]"));
150 };
151
152 Ok(axum::response::sse::Sse::new(stream))
153}
154
155pub async fn export_training_handler(
157 State(state): State<BancoState>,
158 axum::extract::Path(id): axum::extract::Path<String>,
159 Json(request): Json<ExportRequest>,
160) -> Result<Json<ExportResult>, (StatusCode, Json<ErrorResponse>)> {
161 let run = state.training.get(&id).ok_or((
162 StatusCode::NOT_FOUND,
163 Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
164 ))?;
165
166 if run.status != TrainingStatus::Complete {
167 return Err((
168 StatusCode::BAD_REQUEST,
169 Json(ErrorResponse::new(
170 format!("Run {} is {:?}, not complete", id, run.status),
171 "invalid_status",
172 400,
173 )),
174 ));
175 }
176
177 let ext = match &request.format {
178 ExportFormat::Safetensors => "safetensors",
179 ExportFormat::Gguf => "gguf",
180 ExportFormat::Apr => "apr",
181 };
182 let filename =
183 if request.merge { format!("{id}-merged.{ext}") } else { format!("{id}-adapter.{ext}") };
184 let path = format!("~/.banco/exports/{filename}");
185
186 state.training.set_export_path(&id, &path);
187
188 Ok(Json(ExportResult {
189 run_id: id,
190 format: request.format,
191 merged: request.merge,
192 path,
193 size_bytes: 0, }))
195}
196
197pub async fn list_presets_handler() -> Json<PresetsResponse> {
199 let presets: Vec<PresetInfo> = TrainingPreset::all()
200 .into_iter()
201 .map(|p| {
202 let (method, config) = p.expand();
203 PresetInfo {
204 name: format!("{p:?}").to_lowercase().replace("lora", "-lora"),
205 method: format!("{method:?}").to_lowercase(),
206 lora_r: config.lora_r,
207 epochs: config.epochs,
208 learning_rate: config.learning_rate,
209 }
210 })
211 .collect();
212 Json(PresetsResponse { presets })
213}
214
215#[derive(Debug, Deserialize)]
216pub struct StartTrainingRequest {
217 pub dataset_id: String,
218 #[serde(default)]
219 pub method: Option<TrainingMethod>,
220 #[serde(default)]
221 pub config: Option<TrainingConfig>,
222 #[serde(default)]
223 pub preset: Option<TrainingPreset>,
224}
225
226#[derive(Debug, serde::Serialize)]
227pub struct TrainingRunsResponse {
228 pub runs: Vec<TrainingRun>,
229}
230
231#[derive(Debug, serde::Serialize)]
232pub struct PresetsResponse {
233 pub presets: Vec<PresetInfo>,
234}
235
236#[derive(Debug, serde::Serialize)]
237pub struct PresetInfo {
238 pub name: String,
239 pub method: String,
240 pub lora_r: u32,
241 pub epochs: u32,
242 pub learning_rate: f64,
243}