// batuta/serve/banco/training_engine.rs
//! Training engine — presets, cosine schedule, entrenar LoRA wiring, and real loss computation.
//!
//! When a model is loaded, `compute_training_loss()` evaluates actual cross-entropy
//! loss via the model's forward pass. The first training metric uses this real loss.
//! Remaining steps use simulated cosine decay (no weight updates yet — #59).

use super::training::{
    OptimizerType, SchedulerType, TrainingConfig, TrainingMethod, TrainingMetric,
};
use serde::{Deserialize, Serialize};
11
// ============================================================================
// Training presets
// ============================================================================

/// Named training preset — expands to a full `TrainingConfig` via [`TrainingPreset::expand`].
///
/// Serialized in kebab-case (e.g. `quick-lora`, `qlora-low-vram`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum TrainingPreset {
    /// Fast single-epoch LoRA (r=8) on q/v projections only.
    QuickLora,
    /// Default LoRA (r=16) on all four attention projections, 3 epochs.
    StandardLora,
    /// Higher-rank LoRA (r=32) over all linear layers, 5 epochs.
    DeepLora,
    /// QLoRA with a smaller batch size for memory-constrained setups.
    QloraLowVram,
    /// Full-parameter fine-tune (no LoRA adapters, lower learning rate).
    FullFinetune,
}
26
27impl TrainingPreset {
28    /// Expand preset into (method, config).
29    #[must_use]
30    pub fn expand(&self) -> (TrainingMethod, TrainingConfig) {
31        match self {
32            Self::QuickLora => (
33                TrainingMethod::Lora,
34                TrainingConfig {
35                    lora_r: 8,
36                    lora_alpha: 16,
37                    learning_rate: 2e-4,
38                    epochs: 1,
39                    batch_size: 4,
40                    max_seq_length: 2048,
41                    target_modules: vec!["q_proj".into(), "v_proj".into()],
42                    optimizer: OptimizerType::AdamW,
43                    scheduler: SchedulerType::Cosine,
44                    warmup_steps: 50,
45                    gradient_accumulation_steps: 1,
46                    max_grad_norm: 1.0,
47                },
48            ),
49            Self::StandardLora => (
50                TrainingMethod::Lora,
51                TrainingConfig {
52                    lora_r: 16,
53                    lora_alpha: 32,
54                    learning_rate: 2e-4,
55                    epochs: 3,
56                    batch_size: 4,
57                    max_seq_length: 2048,
58                    target_modules: vec![
59                        "q_proj".into(),
60                        "k_proj".into(),
61                        "v_proj".into(),
62                        "o_proj".into(),
63                    ],
64                    optimizer: OptimizerType::AdamW,
65                    scheduler: SchedulerType::Cosine,
66                    warmup_steps: 100,
67                    gradient_accumulation_steps: 4,
68                    max_grad_norm: 1.0,
69                },
70            ),
71            Self::DeepLora => (
72                TrainingMethod::Lora,
73                TrainingConfig {
74                    lora_r: 32,
75                    lora_alpha: 64,
76                    learning_rate: 1e-4,
77                    epochs: 5,
78                    batch_size: 4,
79                    max_seq_length: 2048,
80                    target_modules: vec!["all_linear".into()],
81                    optimizer: OptimizerType::AdamW,
82                    scheduler: SchedulerType::Cosine,
83                    warmup_steps: 200,
84                    gradient_accumulation_steps: 8,
85                    max_grad_norm: 1.0,
86                },
87            ),
88            Self::QloraLowVram => (
89                TrainingMethod::Qlora,
90                TrainingConfig {
91                    lora_r: 16,
92                    lora_alpha: 32,
93                    learning_rate: 2e-4,
94                    epochs: 3,
95                    batch_size: 2,
96                    max_seq_length: 2048,
97                    target_modules: vec![
98                        "q_proj".into(),
99                        "k_proj".into(),
100                        "v_proj".into(),
101                        "o_proj".into(),
102                    ],
103                    optimizer: OptimizerType::AdamW,
104                    scheduler: SchedulerType::Cosine,
105                    warmup_steps: 100,
106                    gradient_accumulation_steps: 8,
107                    max_grad_norm: 1.0,
108                },
109            ),
110            Self::FullFinetune => (
111                TrainingMethod::FullFinetune,
112                TrainingConfig {
113                    lora_r: 0,
114                    lora_alpha: 0,
115                    learning_rate: 5e-5,
116                    epochs: 3,
117                    batch_size: 4,
118                    max_seq_length: 2048,
119                    target_modules: Vec::new(),
120                    optimizer: OptimizerType::AdamW,
121                    scheduler: SchedulerType::Cosine,
122                    warmup_steps: 100,
123                    gradient_accumulation_steps: 4,
124                    max_grad_norm: 1.0,
125                },
126            ),
127        }
128    }
129
130    /// List all available presets.
131    #[must_use]
132    pub fn all() -> Vec<Self> {
133        vec![
134            Self::QuickLora,
135            Self::StandardLora,
136            Self::DeepLora,
137            Self::QloraLowVram,
138            Self::FullFinetune,
139        ]
140    }
141}
142
// ============================================================================
// entrenar integration (behind ml feature)
// ============================================================================

/// Run a LoRA training loop using entrenar. Returns metrics per step.
///
/// With the `entrenar` feature: creates a LoRA config and optimizer via
/// entrenar, validates the config, then produces step-by-step metrics with a
/// cosine learning-rate schedule. Losses follow a fixed exponential decay —
/// no actual weight updates happen yet (#59).
///
/// Without the `entrenar` feature: a sibling implementation produces simulated
/// metrics for API testing.
#[cfg(feature = "entrenar")]
pub fn run_lora_training(
    config: &TrainingConfig,
    data: &[Vec<f32>],
    vocab_size: usize,
) -> Vec<TrainingMetric> {
    use entrenar::lora::LoRAConfig;
    use entrenar::optim::Adam;

    // Construct entrenar-side objects so the config passes through the real
    // library types, even though no optimizer step runs yet.
    let lora_config = LoRAConfig::new(config.lora_r as usize, config.lora_alpha as f32);
    let _optimizer = Adam::default_params(config.learning_rate as f32);

    // Validate config via entrenar types
    let _target_count = lora_config.num_target_modules();

    // Steps = batches-per-epoch * epochs; max(1) guards both empty data and a
    // zero batch size (also guarantees at least one metric is produced).
    let total_steps =
        (data.len().max(1) / config.batch_size.max(1) as usize).max(1) * config.epochs as usize;

    let mut metrics = Vec::with_capacity(total_steps);
    let mut loss = 2.5_f32; // synthetic starting loss
    let decay = 0.97_f32; // synthetic per-step multiplicative decay

    for step in 0..total_steps {
        loss *= decay;
        let lr_scale = cosine_schedule(step, total_steps, config.warmup_steps as usize);
        metrics.push(TrainingMetric {
            step: step as u64,
            loss,
            learning_rate: config.learning_rate * lr_scale as f64,
            // Smoothly decaying synthetic gradient norm.
            grad_norm: Some(1.0 / (1.0 + step as f32 * 0.01)),
            // Rough synthetic throughput estimate, not a measurement.
            tokens_per_sec: Some(((vocab_size as u64) * config.batch_size as u64) / 10),
            // Assumes ~2 s per remaining step — placeholder until real timing lands.
            eta_secs: Some(((total_steps - step) as u64) * 2),
        });
    }
    metrics
}
189
190/// Simulated training (no ml feature) — produces realistic metric progression.
191#[cfg(not(feature = "entrenar"))]
192pub fn run_lora_training(
193    config: &TrainingConfig,
194    data: &[Vec<f32>],
195    _vocab_size: usize,
196) -> Vec<TrainingMetric> {
197    let total_steps =
198        (data.len().max(1) / config.batch_size.max(1) as usize).max(1) * config.epochs as usize;
199
200    let mut metrics = Vec::with_capacity(total_steps);
201    let mut loss = 2.5_f32;
202    let decay = 0.97_f32;
203
204    for step in 0..total_steps {
205        loss *= decay;
206        let lr_scale = cosine_schedule(step, total_steps, config.warmup_steps as usize);
207        metrics.push(TrainingMetric {
208            step: step as u64,
209            loss,
210            learning_rate: config.learning_rate * lr_scale as f64,
211            grad_norm: Some(1.0 / (1.0 + step as f32 * 0.01)),
212            tokens_per_sec: None,
213            eta_secs: Some(((total_steps - step) as u64) * 2),
214        });
215    }
216    metrics
217}
218
/// Compute real loss on training data via model forward pass.
///
/// Uses the loaded quantized model to evaluate cross-entropy loss on token sequences.
/// This is NOT training (no weight updates) — it's evaluation of training data quality.
///
/// Returns `Some((loss, tokens_evaluated))`, or `None` when the underlying
/// perplexity evaluation yields nothing (see `super::eval::compute_perplexity`
/// for the exact conditions — presumably too few evaluable tokens; verify there).
#[cfg(feature = "realizar")]
pub fn compute_training_loss(
    model: &std::sync::Arc<realizar::gguf::OwnedQuantizedModel>,
    token_ids: &[u32],
    max_tokens: usize,
) -> Option<(f32, usize)> {
    // Reuse the perplexity computation — it IS cross-entropy loss
    super::eval::compute_perplexity(model, token_ids, max_tokens)
        .map(|(ppl, count)| (ppl.ln() as f32, count)) // PPL = exp(loss), so loss = ln(PPL)
}
234
/// Cosine learning rate schedule with linear warmup.
///
/// Returns a multiplier in `[0, 1]`: ramps linearly from 0.0 over the first
/// `warmup` steps, then follows a half-cosine from 1.0 down to 0.0 across the
/// remaining `total - warmup` steps.
///
/// Panic-free for degenerate inputs: `saturating_sub` avoids the usize
/// underflow the previous `total - warmup` hit (debug-build panic) when a
/// caller passes `step >= warmup > total`, and `max(1)` guards both divisions.
fn cosine_schedule(step: usize, total: usize, warmup: usize) -> f32 {
    if step < warmup {
        // Linear warmup; max(1) guards division by zero when warmup == 0.
        return step as f32 / warmup.max(1) as f32;
    }
    // Fraction of the post-warmup schedule completed (in [0, 1] while step < total).
    let span = total.saturating_sub(warmup).max(1) as f32;
    let progress = (step - warmup) as f32 / span;
    0.5 * (1.0 + (std::f32::consts::PI * progress).cos())
}