batuta 0.7.3

Sovereign AI orchestration: autonomous agents, ML serving, code analysis, and transpilation pipelines
Documentation
//! Training run endpoint handlers — start, stop, list, metrics SSE, export.

use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Json;
use serde::Deserialize;

use super::state::BancoState;
use super::training::{
    ExportFormat, ExportRequest, ExportResult, TrainingConfig, TrainingMethod, TrainingPreset,
    TrainingRun, TrainingStatus,
};
use super::types::ErrorResponse;

/// POST /api/v1/train/start — start a training run (with optional preset).
pub async fn start_training_handler(
    State(state): State<BancoState>,
    Json(request): Json<StartTrainingRequest>,
) -> Json<TrainingRun> {
    // Expand preset if provided, otherwise use explicit method + config
    let (method, config) = if let Some(preset) = &request.preset {
        preset.expand()
    } else {
        (
            request.method.clone().unwrap_or(TrainingMethod::Lora),
            request.config.clone().unwrap_or_default(),
        )
    };

    let mut run = state.training.start(&request.dataset_id, method.clone(), config.clone());

    // Run training (real with ml feature, simulated without)
    state.training.set_status(&run.id, TrainingStatus::Running);
    state.events.emit(&super::events::BancoEvent::TrainingStarted {
        run_id: run.id.clone(),
        method: format!("{method:?}").to_lowercase(),
    });

    // Try real loss computation via model forward pass
    #[cfg(feature = "realizar")]
    let real_loss = {
        let dataset = state.recipes.get_dataset(&request.dataset_id);
        let text = dataset
            .as_ref()
            .map(|d| d.records.iter().map(|r| r.text.as_str()).collect::<Vec<_>>().join(" "))
            .unwrap_or_else(|| "The quick brown fox jumps over the lazy dog.".to_string());

        let token_ids = state.model.encode_text(&text);
        state
            .model
            .quantized_model()
            .and_then(|m| super::training_engine::compute_training_loss(&m, &token_ids, 128))
    };

    // Use real dataset from recipe output if available, else placeholder
    let dataset = state.recipes.get_dataset(&request.dataset_id);
    let data_size = dataset.as_ref().map(|d| d.record_count).unwrap_or(100);
    let data: Vec<Vec<f32>> = vec![vec![0.0; 64]; data_size.max(1)];

    let vocab_size = state.model.info().and_then(|i| i.vocab_size).unwrap_or(32000);
    let mut metrics = super::training::run_lora_training(&config, &data, vocab_size);

    // If we got real loss from model forward pass, replace first metric with it
    #[cfg(feature = "realizar")]
    if let Some((real_loss_val, tokens_eval)) = real_loss {
        if let Some(first) = metrics.first_mut() {
            first.loss = real_loss_val;
            first.tokens_per_sec = Some(tokens_eval as u64);
        }
        run.simulated = false; // At least one metric is real
    }

    for m in &metrics {
        state.training.push_metric(&run.id, m.clone());
    }

    state.training.set_status(&run.id, TrainingStatus::Complete);
    state.events.emit(&super::events::BancoEvent::TrainingComplete { run_id: run.id.clone() });
    run.status = TrainingStatus::Complete;
    run.metrics = metrics;

    Json(run)
}

/// GET /api/v1/train/runs — list training runs.
pub async fn list_training_runs_handler(
    State(state): State<BancoState>,
) -> Json<TrainingRunsResponse> {
    Json(TrainingRunsResponse { runs: state.training.list() })
}

/// GET /api/v1/train/runs/:id — get run status.
pub async fn get_training_run_handler(
    State(state): State<BancoState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<TrainingRun>, (StatusCode, Json<ErrorResponse>)> {
    state.training.get(&id).map(Json).ok_or((
        StatusCode::NOT_FOUND,
        Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
    ))
}

/// POST /api/v1/train/runs/:id/stop — stop a running training.
pub async fn stop_training_handler(
    State(state): State<BancoState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
    state.training.stop(&id).map(|()| StatusCode::OK).map_err(|_| {
        (
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
        )
    })
}

/// DELETE /api/v1/train/runs/:id — delete a run.
pub async fn delete_training_run_handler(
    State(state): State<BancoState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
    state.training.delete(&id).map(|()| StatusCode::NO_CONTENT).map_err(|_| {
        (
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
        )
    })
}

/// GET /api/v1/train/runs/:id/metrics — stream metrics via SSE.
pub async fn training_metrics_handler(
    State(state): State<BancoState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<
    axum::response::sse::Sse<
        impl futures_util::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    >,
    (StatusCode, Json<ErrorResponse>),
> {
    let run = state.training.get(&id).ok_or((
        StatusCode::NOT_FOUND,
        Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
    ))?;

    let stream = async_stream::stream! {
        for metric in &run.metrics {
            let data = serde_json::to_string(metric).unwrap_or_default();
            yield Ok(axum::response::sse::Event::default().data(data));
        }
        yield Ok(axum::response::sse::Event::default().data("[DONE]"));
    };

    Ok(axum::response::sse::Sse::new(stream))
}

/// POST /api/v1/train/runs/:id/export — export adapter or merged model.
pub async fn export_training_handler(
    State(state): State<BancoState>,
    axum::extract::Path(id): axum::extract::Path<String>,
    Json(request): Json<ExportRequest>,
) -> Result<Json<ExportResult>, (StatusCode, Json<ErrorResponse>)> {
    let run = state.training.get(&id).ok_or((
        StatusCode::NOT_FOUND,
        Json(ErrorResponse::new(format!("Run {id} not found"), "not_found", 404)),
    ))?;

    if run.status != TrainingStatus::Complete {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse::new(
                format!("Run {} is {:?}, not complete", id, run.status),
                "invalid_status",
                400,
            )),
        ));
    }

    let ext = match &request.format {
        ExportFormat::Safetensors => "safetensors",
        ExportFormat::Gguf => "gguf",
        ExportFormat::Apr => "apr",
    };
    let filename =
        if request.merge { format!("{id}-merged.{ext}") } else { format!("{id}-adapter.{ext}") };
    let path = format!("~/.banco/exports/{filename}");

    state.training.set_export_path(&id, &path);

    Ok(Json(ExportResult {
        run_id: id,
        format: request.format,
        merged: request.merge,
        path,
        size_bytes: 0, // populated when real export happens
    }))
}

/// GET /api/v1/train/presets — list available training presets.
pub async fn list_presets_handler() -> Json<PresetsResponse> {
    let presets: Vec<PresetInfo> = TrainingPreset::all()
        .into_iter()
        .map(|p| {
            let (method, config) = p.expand();
            PresetInfo {
                name: format!("{p:?}").to_lowercase().replace("lora", "-lora"),
                method: format!("{method:?}").to_lowercase(),
                lora_r: config.lora_r,
                epochs: config.epochs,
                learning_rate: config.learning_rate,
            }
        })
        .collect();
    Json(PresetsResponse { presets })
}

#[derive(Debug, Deserialize)]
pub struct StartTrainingRequest {
    pub dataset_id: String,
    #[serde(default)]
    pub method: Option<TrainingMethod>,
    #[serde(default)]
    pub config: Option<TrainingConfig>,
    #[serde(default)]
    pub preset: Option<TrainingPreset>,
}

#[derive(Debug, serde::Serialize)]
pub struct TrainingRunsResponse {
    pub runs: Vec<TrainingRun>,
}

#[derive(Debug, serde::Serialize)]
pub struct PresetsResponse {
    pub presets: Vec<PresetInfo>,
}

#[derive(Debug, serde::Serialize)]
pub struct PresetInfo {
    pub name: String,
    pub method: String,
    pub lora_r: u32,
    pub epochs: u32,
    pub learning_rate: f64,
}