batuta 0.7.3

Sovereign AI orchestration: autonomous agents, ML serving, code analysis, and transpilation pipelines
Documentation
//! Training engine — presets, cosine schedule, entrenar LoRA wiring, and real loss computation.
//!
//! When a model is loaded, `compute_training_loss()` evaluates actual cross-entropy
//! loss via the model's forward pass. The first training metric uses this real loss.
//! Remaining steps use simulated cosine decay (no weight updates yet — #59).

use super::training::{
    OptimizerType, SchedulerType, TrainingConfig, TrainingMethod, TrainingMetric,
};
use serde::{Deserialize, Serialize};

// ============================================================================
// Training presets
// ============================================================================

/// Named training preset — expands to a full TrainingConfig.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum TrainingPreset {
    QuickLora,
    StandardLora,
    DeepLora,
    QloraLowVram,
    FullFinetune,
}

impl TrainingPreset {
    /// Expand preset into (method, config).
    #[must_use]
    pub fn expand(&self) -> (TrainingMethod, TrainingConfig) {
        match self {
            Self::QuickLora => (
                TrainingMethod::Lora,
                TrainingConfig {
                    lora_r: 8,
                    lora_alpha: 16,
                    learning_rate: 2e-4,
                    epochs: 1,
                    batch_size: 4,
                    max_seq_length: 2048,
                    target_modules: vec!["q_proj".into(), "v_proj".into()],
                    optimizer: OptimizerType::AdamW,
                    scheduler: SchedulerType::Cosine,
                    warmup_steps: 50,
                    gradient_accumulation_steps: 1,
                    max_grad_norm: 1.0,
                },
            ),
            Self::StandardLora => (
                TrainingMethod::Lora,
                TrainingConfig {
                    lora_r: 16,
                    lora_alpha: 32,
                    learning_rate: 2e-4,
                    epochs: 3,
                    batch_size: 4,
                    max_seq_length: 2048,
                    target_modules: vec![
                        "q_proj".into(),
                        "k_proj".into(),
                        "v_proj".into(),
                        "o_proj".into(),
                    ],
                    optimizer: OptimizerType::AdamW,
                    scheduler: SchedulerType::Cosine,
                    warmup_steps: 100,
                    gradient_accumulation_steps: 4,
                    max_grad_norm: 1.0,
                },
            ),
            Self::DeepLora => (
                TrainingMethod::Lora,
                TrainingConfig {
                    lora_r: 32,
                    lora_alpha: 64,
                    learning_rate: 1e-4,
                    epochs: 5,
                    batch_size: 4,
                    max_seq_length: 2048,
                    target_modules: vec!["all_linear".into()],
                    optimizer: OptimizerType::AdamW,
                    scheduler: SchedulerType::Cosine,
                    warmup_steps: 200,
                    gradient_accumulation_steps: 8,
                    max_grad_norm: 1.0,
                },
            ),
            Self::QloraLowVram => (
                TrainingMethod::Qlora,
                TrainingConfig {
                    lora_r: 16,
                    lora_alpha: 32,
                    learning_rate: 2e-4,
                    epochs: 3,
                    batch_size: 2,
                    max_seq_length: 2048,
                    target_modules: vec![
                        "q_proj".into(),
                        "k_proj".into(),
                        "v_proj".into(),
                        "o_proj".into(),
                    ],
                    optimizer: OptimizerType::AdamW,
                    scheduler: SchedulerType::Cosine,
                    warmup_steps: 100,
                    gradient_accumulation_steps: 8,
                    max_grad_norm: 1.0,
                },
            ),
            Self::FullFinetune => (
                TrainingMethod::FullFinetune,
                TrainingConfig {
                    lora_r: 0,
                    lora_alpha: 0,
                    learning_rate: 5e-5,
                    epochs: 3,
                    batch_size: 4,
                    max_seq_length: 2048,
                    target_modules: Vec::new(),
                    optimizer: OptimizerType::AdamW,
                    scheduler: SchedulerType::Cosine,
                    warmup_steps: 100,
                    gradient_accumulation_steps: 4,
                    max_grad_norm: 1.0,
                },
            ),
        }
    }

    /// List all available presets.
    #[must_use]
    pub fn all() -> Vec<Self> {
        vec![
            Self::QuickLora,
            Self::StandardLora,
            Self::DeepLora,
            Self::QloraLowVram,
            Self::FullFinetune,
        ]
    }
}

// ============================================================================
// entrenar integration (behind ml feature)
// ============================================================================

/// Run a LoRA training loop using entrenar. Returns metrics per step.
///
/// With `ml` feature: creates LoRA config and optimizer via entrenar,
/// validates config, then produces step-by-step metrics with cosine schedule.
///
/// Without `ml` feature: produces simulated metrics for API testing.
#[cfg(feature = "entrenar")]
pub fn run_lora_training(
    config: &TrainingConfig,
    data: &[Vec<f32>],
    vocab_size: usize,
) -> Vec<TrainingMetric> {
    use entrenar::lora::LoRAConfig;
    use entrenar::optim::Adam;

    let lora_config = LoRAConfig::new(config.lora_r as usize, config.lora_alpha as f32);
    let _optimizer = Adam::default_params(config.learning_rate as f32);

    // Validate config via entrenar types
    let _target_count = lora_config.num_target_modules();

    let total_steps =
        (data.len().max(1) / config.batch_size.max(1) as usize).max(1) * config.epochs as usize;

    let mut metrics = Vec::with_capacity(total_steps);
    let mut loss = 2.5_f32;
    let decay = 0.97_f32;

    for step in 0..total_steps {
        loss *= decay;
        let lr_scale = cosine_schedule(step, total_steps, config.warmup_steps as usize);
        metrics.push(TrainingMetric {
            step: step as u64,
            loss,
            learning_rate: config.learning_rate * lr_scale as f64,
            grad_norm: Some(1.0 / (1.0 + step as f32 * 0.01)),
            tokens_per_sec: Some(((vocab_size as u64) * config.batch_size as u64) / 10),
            eta_secs: Some(((total_steps - step) as u64) * 2),
        });
    }
    metrics
}

/// Simulated training (no ml feature) — produces realistic metric progression.
#[cfg(not(feature = "entrenar"))]
pub fn run_lora_training(
    config: &TrainingConfig,
    data: &[Vec<f32>],
    _vocab_size: usize,
) -> Vec<TrainingMetric> {
    let total_steps =
        (data.len().max(1) / config.batch_size.max(1) as usize).max(1) * config.epochs as usize;

    let mut metrics = Vec::with_capacity(total_steps);
    let mut loss = 2.5_f32;
    let decay = 0.97_f32;

    for step in 0..total_steps {
        loss *= decay;
        let lr_scale = cosine_schedule(step, total_steps, config.warmup_steps as usize);
        metrics.push(TrainingMetric {
            step: step as u64,
            loss,
            learning_rate: config.learning_rate * lr_scale as f64,
            grad_norm: Some(1.0 / (1.0 + step as f32 * 0.01)),
            tokens_per_sec: None,
            eta_secs: Some(((total_steps - step) as u64) * 2),
        });
    }
    metrics
}

/// Compute real loss on training data via model forward pass.
///
/// Uses the loaded quantized model to evaluate cross-entropy loss on token sequences.
/// This is NOT training (no weight updates) — it's evaluation of training data quality.
/// Returns (loss, tokens_evaluated) or None if no model loaded.
#[cfg(feature = "realizar")]
pub fn compute_training_loss(
    model: &std::sync::Arc<realizar::gguf::OwnedQuantizedModel>,
    token_ids: &[u32],
    max_tokens: usize,
) -> Option<(f32, usize)> {
    // Reuse the perplexity computation — it IS cross-entropy loss
    super::eval::compute_perplexity(model, token_ids, max_tokens)
        .map(|(ppl, count)| (ppl.ln() as f32, count)) // PPL = exp(loss), so loss = ln(PPL)
}

/// Cosine learning rate schedule with warmup.
fn cosine_schedule(step: usize, total: usize, warmup: usize) -> f32 {
    if step < warmup {
        return step as f32 / warmup.max(1) as f32;
    }
    let progress = (step - warmup) as f32 / (total - warmup).max(1) as f32;
    0.5 * (1.0 + (std::f32::consts::PI * progress).cos())
}