Skip to main content

batuta/serve/banco/
training.rs

1//! Training run management — start, stop, list, metrics, export.
2//!
3//! Types and store for training runs. Presets and the training engine
4//! live in `training_engine.rs`.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10// Re-export engine types so callers use `training::TrainingPreset`
11pub use super::training_engine::{run_lora_training, TrainingPreset};
12
13// ============================================================================
14// Training types
15// ============================================================================
16
/// Training run metadata.
///
/// Serialized to/from JSON via serde; `None` optional fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRun {
    /// Unique run identifier (format: `run-{epoch_secs}-{seq}`, see `TrainingStore::start`).
    pub id: String,
    /// Identifier of the dataset this run trains on.
    pub dataset_id: String,
    /// Fine-tuning method (LoRA, QLoRA, or full fine-tune).
    pub method: TrainingMethod,
    /// Hyperparameters for this run.
    pub config: TrainingConfig,
    /// Current lifecycle status (queued/running/complete/failed/stopped).
    pub status: TrainingStatus,
    /// Creation time, seconds since the Unix epoch.
    pub created_at: u64,
    /// Metric snapshots appended as training progresses.
    pub metrics: Vec<TrainingMetric>,
    /// True when metrics are from simulated cosine schedule, not real gradients.
    /// Honest labeling per Jidoka — stop-the-line on false claims.
    #[serde(default)]
    pub simulated: bool,
    /// Filesystem path of the exported artifact, set once exported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub export_path: Option<String>,
    /// Error message, set when the run fails.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
36
/// Training method.
///
/// Serialized in snake_case: `lora`, `qlora`, `full_finetune`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TrainingMethod {
    /// Low-Rank Adaptation adapters.
    Lora,
    /// Quantized LoRA.
    Qlora,
    /// Update all model weights.
    FullFinetune,
}
45
/// Optimizer type.
///
/// NOTE(review): with `rename_all = "snake_case"`, `AdamW` serializes as
/// "adam_w" (word split at the capital W) — confirm clients don't expect "adamw".
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OptimizerType {
    Adam,
    AdamW,
    Sgd,
}
54
/// Learning rate scheduler type.
///
/// Serialized in snake_case: `constant`, `cosine`, `linear`, `step_decay`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulerType {
    Constant,
    Cosine,
    Linear,
    StepDecay,
}
64
/// Training configuration.
///
/// Every field carries a serde default, so a partial (or empty) JSON object
/// deserializes into a usable config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// LoRA rank (default 16).
    #[serde(default = "default_lora_r")]
    pub lora_r: u32,
    /// LoRA alpha scaling factor (default 32).
    #[serde(default = "default_lora_alpha")]
    pub lora_alpha: u32,
    /// Peak learning rate (default 2e-4).
    #[serde(default = "default_learning_rate")]
    pub learning_rate: f64,
    /// Number of passes over the dataset (default 3).
    #[serde(default = "default_epochs")]
    pub epochs: u32,
    /// Batch size (default 4).
    #[serde(default = "default_batch_size")]
    pub batch_size: u32,
    /// Maximum sequence length in tokens (default 2048).
    #[serde(default = "default_max_seq_length")]
    pub max_seq_length: u32,
    /// Module names to adapt.
    /// NOTE(review): when absent from JSON this defaults to an *empty* vec,
    /// while `Default::default()` uses the four attention projections —
    /// confirm the asymmetry is intentional.
    #[serde(default)]
    pub target_modules: Vec<String>,
    /// Optimizer (default AdamW).
    #[serde(default = "default_optimizer")]
    pub optimizer: OptimizerType,
    /// Learning-rate scheduler (default cosine).
    #[serde(default = "default_scheduler")]
    pub scheduler: SchedulerType,
    /// Scheduler warmup steps (default 100).
    #[serde(default = "default_warmup_steps")]
    pub warmup_steps: u32,
    /// Gradient accumulation steps (default 4).
    #[serde(default = "default_grad_accum")]
    pub gradient_accumulation_steps: u32,
    /// Gradient clipping norm (default 1.0).
    #[serde(default = "default_max_grad_norm")]
    pub max_grad_norm: f64,
}
93
// Serde default helpers. Plain literals, so they can be `const fn` (usable in
// const contexts); serde's `default = "path"` accepts them unchanged.

/// Default LoRA rank.
const fn default_lora_r() -> u32 {
    16
}
/// Default LoRA alpha.
const fn default_lora_alpha() -> u32 {
    32
}
/// Default peak learning rate.
const fn default_learning_rate() -> f64 {
    2e-4
}
/// Default epoch count.
const fn default_epochs() -> u32 {
    3
}
/// Default batch size.
const fn default_batch_size() -> u32 {
    4
}
/// Default maximum sequence length (tokens).
const fn default_max_seq_length() -> u32 {
    2048
}
112fn default_optimizer() -> OptimizerType {
113    OptimizerType::AdamW
114}
115fn default_scheduler() -> SchedulerType {
116    SchedulerType::Cosine
117}
/// Default scheduler warmup steps.
const fn default_warmup_steps() -> u32 {
    100
}
/// Default gradient accumulation steps.
const fn default_grad_accum() -> u32 {
    4
}
/// Default gradient clipping norm.
const fn default_max_grad_norm() -> f64 {
    1.0
}
127
128impl Default for TrainingConfig {
129    fn default() -> Self {
130        Self {
131            lora_r: default_lora_r(),
132            lora_alpha: default_lora_alpha(),
133            learning_rate: default_learning_rate(),
134            epochs: default_epochs(),
135            batch_size: default_batch_size(),
136            max_seq_length: default_max_seq_length(),
137            target_modules: vec![
138                "q_proj".into(),
139                "k_proj".into(),
140                "v_proj".into(),
141                "o_proj".into(),
142            ],
143            optimizer: default_optimizer(),
144            scheduler: default_scheduler(),
145            warmup_steps: default_warmup_steps(),
146            gradient_accumulation_steps: default_grad_accum(),
147            max_grad_norm: default_max_grad_norm(),
148        }
149    }
150}
151
/// Training run status.
///
/// Typical progression (per `TrainingStore`): `Queued` on creation, then
/// `Running`, ending in `Complete`, `Failed`, or `Stopped`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TrainingStatus {
    Queued,
    Running,
    Complete,
    Failed,
    Stopped,
}
162
/// A single training metric snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetric {
    /// Training step this snapshot was taken at.
    pub step: u64,
    /// Loss value at this step.
    pub loss: f32,
    /// Learning rate in effect at this step.
    pub learning_rate: f64,
    /// Gradient norm, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grad_norm: Option<f32>,
    /// Throughput in tokens per second, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens_per_sec: Option<u64>,
    /// Estimated seconds remaining, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub eta_secs: Option<u64>,
}
176
/// Export format for trained adapters/models.
///
/// Serialized in snake_case: `safetensors`, `gguf`, `apr`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ExportFormat {
    Safetensors,
    Gguf,
    Apr,
}
185
/// Export request configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportRequest {
    /// Output format (default: safetensors).
    #[serde(default = "default_export_format")]
    pub format: ExportFormat,
    /// Whether to merge before export (default false).
    /// NOTE(review): presumably merges adapter weights into the base model —
    /// confirm against the export engine.
    #[serde(default)]
    pub merge: bool,
    /// Optional quantization identifier; free-form string here, interpreted
    /// by the export engine.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantization: Option<String>,
}
196
197fn default_export_format() -> ExportFormat {
198    ExportFormat::Safetensors
199}
200
/// Export result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportResult {
    /// Id of the run that was exported.
    pub run_id: String,
    /// Format the artifact was written in.
    pub format: ExportFormat,
    /// Whether a merge was performed (mirrors `ExportRequest::merge`).
    pub merged: bool,
    /// Filesystem path of the written artifact.
    pub path: String,
    /// Size of the artifact in bytes.
    pub size_bytes: u64,
}
210
211// ============================================================================
212// Training store
213// ============================================================================
214
/// Training run store.
///
/// Thread-safe in-memory map of run id → `TrainingRun`, shared via `Arc`.
pub struct TrainingStore {
    /// All runs keyed by run id.
    runs: RwLock<HashMap<String, TrainingRun>>,
    /// Monotonic sequence number folded into generated run ids.
    counter: std::sync::atomic::AtomicU64,
}
220
221impl TrainingStore {
222    #[must_use]
223    pub fn new() -> Arc<Self> {
224        Arc::new(Self {
225            runs: RwLock::new(HashMap::new()),
226            counter: std::sync::atomic::AtomicU64::new(0),
227        })
228    }
229
230    /// Start a training run.
231    pub fn start(
232        &self,
233        dataset_id: &str,
234        method: TrainingMethod,
235        config: TrainingConfig,
236    ) -> TrainingRun {
237        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
238        let run = TrainingRun {
239            id: format!("run-{}-{seq}", epoch_secs()),
240            dataset_id: dataset_id.to_string(),
241            method,
242            config,
243            status: TrainingStatus::Queued,
244            created_at: epoch_secs(),
245            metrics: Vec::new(),
246            simulated: true, // No real gradient-based training yet
247            export_path: None,
248            error: None,
249        };
250        if let Ok(mut store) = self.runs.write() {
251            store.insert(run.id.clone(), run.clone());
252        }
253        run
254    }
255
256    /// Push a metric snapshot for a run.
257    pub fn push_metric(&self, run_id: &str, metric: TrainingMetric) {
258        if let Ok(mut store) = self.runs.write() {
259            if let Some(run) = store.get_mut(run_id) {
260                run.metrics.push(metric);
261            }
262        }
263    }
264
265    /// Update run status.
266    pub fn set_status(&self, run_id: &str, status: TrainingStatus) {
267        if let Ok(mut store) = self.runs.write() {
268            if let Some(run) = store.get_mut(run_id) {
269                run.status = status;
270            }
271        }
272    }
273
274    /// Mark run as failed with error message.
275    pub fn fail(&self, run_id: &str, error: &str) {
276        if let Ok(mut store) = self.runs.write() {
277            if let Some(run) = store.get_mut(run_id) {
278                run.status = TrainingStatus::Failed;
279                run.error = Some(error.to_string());
280            }
281        }
282    }
283
284    /// Set export path for a completed run.
285    pub fn set_export_path(&self, run_id: &str, path: &str) {
286        if let Ok(mut store) = self.runs.write() {
287            if let Some(run) = store.get_mut(run_id) {
288                run.export_path = Some(path.to_string());
289            }
290        }
291    }
292
293    /// List all runs.
294    #[must_use]
295    pub fn list(&self) -> Vec<TrainingRun> {
296        let store = self.runs.read().unwrap_or_else(|e| e.into_inner());
297        let mut runs: Vec<TrainingRun> = store.values().cloned().collect();
298        runs.sort_by(|a, b| b.created_at.cmp(&a.created_at));
299        runs
300    }
301
302    /// Get a run by ID.
303    #[must_use]
304    pub fn get(&self, id: &str) -> Option<TrainingRun> {
305        self.runs.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
306    }
307
308    /// Stop a run.
309    pub fn stop(&self, id: &str) -> Result<(), TrainingError> {
310        let mut store = self.runs.write().map_err(|_| TrainingError::LockPoisoned)?;
311        let run = store.get_mut(id).ok_or(TrainingError::NotFound(id.to_string()))?;
312        run.status = TrainingStatus::Stopped;
313        Ok(())
314    }
315
316    /// Delete a run.
317    pub fn delete(&self, id: &str) -> Result<(), TrainingError> {
318        let mut store = self.runs.write().map_err(|_| TrainingError::LockPoisoned)?;
319        store.remove(id).ok_or(TrainingError::NotFound(id.to_string()))?;
320        Ok(())
321    }
322}
323
324// ============================================================================
325// Errors
326// ============================================================================
327
/// Training errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrainingError {
    /// No training run with the given id.
    NotFound(String),
    /// No model is currently loaded.
    NoModel,
    /// No dataset with the given id.
    NoDataset(String),
    /// An internal `RwLock` was poisoned.
    LockPoisoned,
}
336
337impl std::fmt::Display for TrainingError {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        match self {
340            Self::NotFound(id) => write!(f, "Training run not found: {id}"),
341            Self::NoModel => write!(f, "No model loaded — load a model first"),
342            Self::NoDataset(id) => write!(f, "Dataset not found: {id}"),
343            Self::LockPoisoned => write!(f, "Internal lock error"),
344        }
345    }
346}
347
348impl std::error::Error for TrainingError {}
349
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Falls back to 0 if the system clock reads before the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}