use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Json;
use serde::Deserialize;
use super::state::BancoState;
use super::training::{
ExportFormat, ExportRequest, ExportResult, TrainingConfig, TrainingMethod, TrainingPreset,
TrainingRun, TrainingStatus,
};
use super::types::ErrorResponse;
pub async fn start_training_handler(
State(state): State<BancoState>,
Json(request): Json<StartTrainingRequest>,
) -> Json<TrainingRun> {
let (method, config) = if let Some(preset) = &request.preset {
preset.expand()
} else {
(
request.method.clone().unwrap_or(TrainingMethod::Lora),
request.config.clone().unwrap_or_default(),
)
};
let mut run = state.training.start(&request.dataset_id, method.clone(), config.clone());
state.training.set_status(&run.id, TrainingStatus::Running);
state.events.emit(&super::events::BancoEvent::TrainingStarted {
run_id: run.id.clone(),
method: format!("{method:?}").to_lowercase(),
});
#[cfg(feature = "realizar")]
let real_loss = {
let dataset = state.recipes.get_dataset(&request.dataset_id);
let text = dataset
.as_ref()
.map(|d| d.records.iter().map(|r| r.text.as_str()).collect::<Vec<_>>().join(" "))
.unwrap_or_else(|| "The quick brown fox jumps over the lazy dog.".to_string());
let token_ids = state.model.encode_text(&text);
state
.model
.quantized_model()
.and_then(|m| super::training_engine::compute_training_loss(&m, &token_ids, 128))
};
let dataset = state.recipes.get_dataset(&request.dataset_id);
let data_size = dataset.as_ref().map(|d| d.record_count).unwrap_or(100);
let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; data_size.max(1)];
let vocab_size = state.model.info().and_then(|i| i.vocab_size).unwrap_or(32000);
let mut metrics = super::training::run_lora_training(&config, &data, vocab_size);
#[cfg(feature = "realizar")]
if let Some((real_loss_val, tokens_eval)) = real_loss {
if let Some(first) = metrics.first_mut() {
first.loss = real_loss_val;
first.tokens_per_sec = Some(tokens_eval as u64);
}
run.simulated = false; }
for m in &metrics {
state.training.push_metric(&run.id, m.clone());
}
state.training.set_status(&run.id, TrainingStatus::Complete);
state.events.emit(&super::events::BancoEvent::TrainingComplete { run_id: run.id.clone() });
run.status = TrainingStatus::Complete;
run.metrics = metrics;
Json(run)
}
pub async fn list_training_runs_handler(
State(state): State<BancoState>,
) -> Json<TrainingRunsResponse> {
Json(TrainingRunsResponse { runs: state.training.list() })
}
pub async fn get_training_run_handler(
State(state): State<BancoState>,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<TrainingRun>, (StatusCode, Json<ErrorResponse>)> {
state.training.get(&id).map(Json).ok_or((
StatusCode::NOT_FOUND,
Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
))
}
pub async fn stop_training_handler(
State(state): State<BancoState>,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
state.training.stop(&id).map(|()| StatusCode::OK).map_err(|_| {
(
StatusCode::NOT_FOUND,
Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
)
})
}
pub async fn delete_training_run_handler(
State(state): State<BancoState>,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
state.training.delete(&id).map(|()| StatusCode::NO_CONTENT).map_err(|_| {
(
StatusCode::NOT_FOUND,
Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
)
})
}
pub async fn training_metrics_handler(
State(state): State<BancoState>,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<
axum::response::sse::Sse<
impl futures_util::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
>,
(StatusCode, Json<ErrorResponse>),
> {
let run = state.training.get(&id).ok_or((
StatusCode::NOT_FOUND,
Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
))?;
let stream = async_stream::stream! {
for metric in &run.metrics {
let data = serde_json::to_string(metric).unwrap_or_default();
yield Ok(axum::response::sse::Event::default().data(data));
}
yield Ok(axum::response::sse::Event::default().data("[DONE]"));
};
Ok(axum::response::sse::Sse::new(stream))
}
pub async fn export_training_handler(
State(state): State<BancoState>,
axum::extract::Path(id): axum::extract::Path<String>,
Json(request): Json<ExportRequest>,
) -> Result<Json<ExportResult>, (StatusCode, Json<ErrorResponse>)> {
let run = state.training.get(&id).ok_or((
StatusCode::NOT_FOUND,
Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
))?;
if run.status != TrainingStatus::Complete {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse::new(
format!("Run {} is {:?}, not complete", id, run.status),
"invalid_status",
400,
)),
));
}
let ext = match &request.format {
ExportFormat::Safetensors => "safetensors",
ExportFormat::Gguf => "gguf",
ExportFormat::Apr => "apr",
};
let filename =
if request.merge { format!("{id}-merged.{ext}") } else { format!("{id}-adapter.{ext}") };
let path = format!("~/.banco/exports/{filename}");
state.training.set_export_path(&id, &path);
Ok(Json(ExportResult {
run_id: id,
format: request.format,
merged: request.merge,
path,
size_bytes: 0, }))
}
pub async fn list_presets_handler() -> Json<PresetsResponse> {
let presets: Vec<PresetInfo> = TrainingPreset::all()
.into_iter()
.map(|p| {
let (method, config) = p.expand();
PresetInfo {
name: format!("{p:?}").to_lowercase().replace("lora", "-lora"),
method: format!("{method:?}").to_lowercase(),
lora_r: config.lora_r,
epochs: config.epochs,
learning_rate: config.learning_rate,
}
})
.collect();
Json(PresetsResponse { presets })
}
#[derive(Debug, Deserialize)]
pub struct StartTrainingRequest {
pub dataset_id: String,
#[serde(default)]
pub method: Option<TrainingMethod>,
#[serde(default)]
pub config: Option<TrainingConfig>,
#[serde(default)]
pub preset: Option<TrainingPreset>,
}
#[derive(Debug, serde::Serialize)]
pub struct TrainingRunsResponse {
pub runs: Vec<TrainingRun>,
}
#[derive(Debug, serde::Serialize)]
pub struct PresetsResponse {
pub presets: Vec<PresetInfo>,
}
#[derive(Debug, serde::Serialize)]
pub struct PresetInfo {
pub name: String,
pub method: String,
pub lora_r: u32,
pub epochs: u32,
pub learning_rate: f64,
}