Skip to main content

multiscreen_rs/
training.rs

1//! High-level training API with a builder pattern.
2//!
3//! The [`Trainer`] struct wraps model construction, token-sequence training,
4//! and checkpoint management behind a single ergonomic interface with sensible
5//! defaults.
6//!
7//! Users encode their own text into token IDs using any tokenizer they choose,
8//! then pass the `Vec<Vec<u32>>` sequences to [`Trainer::train_on_token_sequences`].
9//!
10//! # Example
11//!
12//! ```rust,no_run
13//! use multiscreen_rs::prelude::*;
14//!
15//! fn main() -> multiscreen_rs::Result<()> {
16//!     let mut trainer = Trainer::builder()
17//!         .vocab_size(1000)
18//!         .budget(ParameterBudget::Params10M)
19//!         .batch_size(4)
20//!         .seq_len(64)
21//!         .steps(100)
22//!         .device(auto_device()?)
23//!         .build()?;
24//!
25//!     // Token sequences from YOUR tokenizer
26//!     let sequences = vec![
27//!         vec![1, 2, 3, 4, 5],
28//!         vec![1, 2, 6, 7, 5],
29//!     ];
30//!
31//!     let report = trainer.train_on_token_sequences(&sequences)?;
32//!     println!("trained {} steps, final loss {:.4}", report.steps, report.final_loss);
33//!     Ok(())
34//! }
35//! ```
36
37use crate::error::{Error, Result};
38use crate::model::{
39    DefaultMultiscreenModel, ModelTrainingConfig, ModelTrainingReport, MultiscreenModelConfig,
40};
41use crate::runtime::{Device, default_device};
42use std::fs;
43use std::path::Path;
44
45/// Re-export of the parameter budget enum for convenience.
46pub use crate::model::MultiscreenParameterBudget as ParameterBudget;
47
48// ---------------------------------------------------------------------------
49// TrainingReport
50// ---------------------------------------------------------------------------
51
/// Summary returned after training via the high-level [`Trainer`].
///
/// Derives `PartialEq` so callers (and tests) can compare whole reports
/// directly, e.g. with `assert_eq!`.
#[derive(Clone, Debug, PartialEq)]
pub struct TrainingReport {
    /// Number of training steps completed.
    pub steps: usize,
    /// Final training loss.
    pub final_loss: f32,
    /// The lowest loss observed across all training steps.
    pub best_loss: f32,
    /// The step at which `best_loss` was recorded.
    pub best_loss_step: usize,
    /// Total number of parameters in the model.
    pub parameter_count: usize,
    /// Path the final checkpoint was saved to, or `None` when no checkpoint
    /// directory was configured.
    pub checkpoint_path: Option<String>,
}
68
69impl TrainingReport {
70    fn from_model_report(report: &ModelTrainingReport, checkpoint_path: Option<String>) -> Self {
71        Self {
72            steps: report.steps,
73            final_loss: report.final_loss,
74            best_loss: report.best_loss,
75            best_loss_step: report.best_loss_step,
76            parameter_count: report.parameter_count,
77            checkpoint_path,
78        }
79    }
80}
81
82// ---------------------------------------------------------------------------
83// Trainer
84// ---------------------------------------------------------------------------
85
86/// High-level trainer that bundles a model and training config.
87///
88/// Construct via [`Trainer::builder()`] and the [`TrainerBuilder`] struct.
89/// Users provide token sequences (`Vec<Vec<u32>>`) from their own tokenizer.
90pub struct Trainer {
91    model: DefaultMultiscreenModel,
92    training_config: ModelTrainingConfig,
93    checkpoint_dir: Option<String>,
94    checkpoint_interval: usize,
95    #[allow(dead_code)]
96    run_dir: Option<String>,
97}
98
99impl std::fmt::Debug for Trainer {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("Trainer")
102            .field("training_config", &self.training_config)
103            .field("checkpoint_dir", &self.checkpoint_dir)
104            .field("checkpoint_interval", &self.checkpoint_interval)
105            .field("run_dir", &self.run_dir)
106            .finish_non_exhaustive()
107    }
108}
109
110impl Trainer {
111    /// Returns a new [`TrainerBuilder`] with sensible defaults.
112    pub fn builder() -> TrainerBuilder {
113        TrainerBuilder::new()
114    }
115
116    /// Trains the model on the provided token sequences.
117    ///
118    /// Each inner `Vec<u32>` is a tokenized text sample. Use your own
119    /// tokenizer to produce these sequences before calling this method.
120    ///
121    /// The `on_step` callback is invoked after each optimizer step with
122    /// `(step_index, loss_value)`.
123    pub fn train_on_token_sequences_with_callback(
124        &mut self,
125        sequences: &[Vec<u32>],
126        on_step: impl FnMut(usize, f32),
127    ) -> Result<TrainingReport> {
128        if sequences.is_empty() {
129            return Err(Error::Training("no training sequences provided".into()));
130        }
131
132        let mut config = self.training_config.clone();
133        config.checkpoint_dir = self.checkpoint_dir.clone();
134        config.checkpoint_interval = self.checkpoint_interval;
135
136        let device = self.model_device();
137        let report = self
138            .model
139            .train_token_sequences(sequences, &config, &device, on_step)?;
140
141        let checkpoint_path = match &self.checkpoint_dir {
142            Some(dir) => {
143                let dir_path = Path::new(dir);
144                fs::create_dir_all(dir_path).map_err(|e| {
145                    Error::Io(format!(
146                        "failed to create checkpoint directory {:?}: {}",
147                        dir, e
148                    ))
149                })?;
150                let path = dir_path.join("checkpoint.mpk");
151                self.model.save_parameters(&path)?;
152                Some(path.to_string_lossy().into_owned())
153            }
154            None => None,
155        };
156
157        Ok(TrainingReport::from_model_report(&report, checkpoint_path))
158    }
159
160    /// Convenience wrapper that trains without a per-step callback.
161    pub fn train_on_token_sequences(&mut self, sequences: &[Vec<u32>]) -> Result<TrainingReport> {
162        self.train_on_token_sequences_with_callback(sequences, |_, _| {})
163    }
164
165    /// Trains the model on chat-style (prompt, response) token-ID pairs.
166    ///
167    /// Each element of `chat_pairs` is `(prompt_token_ids, response_token_ids)`.
168    /// The model sees the full context but loss is computed only on response
169    /// tokens, preventing the model from learning to generate role labels.
170    ///
171    /// The `on_step` callback is invoked after each optimizer step with
172    /// `(step_index, loss_value)`.
173    pub fn train_on_chat_sequences_with_callback(
174        &mut self,
175        chat_pairs: &[(Vec<u32>, Vec<u32>)],
176        on_step: impl FnMut(usize, f32),
177    ) -> Result<TrainingReport> {
178        if chat_pairs.is_empty() {
179            return Err(Error::Training("no training chat pairs provided".into()));
180        }
181
182        let mut config = self.training_config.clone();
183        config.checkpoint_dir = self.checkpoint_dir.clone();
184        config.checkpoint_interval = self.checkpoint_interval;
185
186        let device = self.model_device();
187        let report = self
188            .model
189            .train_chat_sequences(chat_pairs, &config, &device, on_step)?;
190
191        let checkpoint_path = match &self.checkpoint_dir {
192            Some(dir) => {
193                let dir_path = Path::new(dir);
194                fs::create_dir_all(dir_path).map_err(|e| {
195                    Error::Io(format!(
196                        "failed to create checkpoint directory {:?}: {}",
197                        dir, e
198                    ))
199                })?;
200                let path = dir_path.join("checkpoint.mpk");
201                self.model.save_parameters(&path)?;
202                Some(path.to_string_lossy().into_owned())
203            }
204            None => None,
205        };
206
207        Ok(TrainingReport::from_model_report(&report, checkpoint_path))
208    }
209
210    /// Convenience wrapper that trains on chat pairs without a per-step callback.
211    pub fn train_on_chat_sequences(
212        &mut self,
213        chat_pairs: &[(Vec<u32>, Vec<u32>)],
214    ) -> Result<TrainingReport> {
215        self.train_on_chat_sequences_with_callback(chat_pairs, |_, _| {})
216    }
217
218    /// Saves a model checkpoint to the given path.
219    pub fn save_checkpoint(&self, path: &str) -> Result<()> {
220        if let Some(parent) = Path::new(path).parent() {
221            fs::create_dir_all(parent).map_err(|e| {
222                Error::Io(format!(
223                    "failed to create checkpoint directory {:?}: {}",
224                    parent, e
225                ))
226            })?;
227        }
228        self.model.save_parameters(path)
229    }
230
231    /// Returns a reference to the underlying model.
232    pub fn model(&self) -> &DefaultMultiscreenModel {
233        &self.model
234    }
235
236    /// Returns a mutable reference to the underlying model.
237    pub fn model_mut(&mut self) -> &mut DefaultMultiscreenModel {
238        &mut self.model
239    }
240
241    /// Returns a reference to the training configuration.
242    pub fn training_config(&self) -> &ModelTrainingConfig {
243        &self.training_config
244    }
245
246    /// Obtains the device the model is currently on.
247    fn model_device(&self) -> Device {
248        Device::default()
249    }
250}
251
252// ---------------------------------------------------------------------------
253// TrainerBuilder
254// ---------------------------------------------------------------------------
255
256/// Builder for constructing a [`Trainer`] with sensible defaults.
257pub struct TrainerBuilder {
258    vocab_size: Option<usize>,
259    budget: ParameterBudget,
260    device: Option<Device>,
261    batch_size: usize,
262    seq_len: usize,
263    steps: usize,
264    learning_rate: f64,
265    weight_decay: f64,
266    grad_clip_norm: Option<f64>,
267    checkpoint_dir: Option<String>,
268    checkpoint_interval: usize,
269    run_dir: Option<String>,
270}
271
272impl TrainerBuilder {
273    /// Creates a new builder with default values.
274    fn new() -> Self {
275        Self {
276            vocab_size: None,
277            budget: ParameterBudget::Params10M,
278            device: None,
279            batch_size: 4,
280            seq_len: 128,
281            steps: 1000,
282            learning_rate: 2e-4,
283            weight_decay: 0.01,
284            grad_clip_norm: Some(1.0),
285            checkpoint_dir: None,
286            checkpoint_interval: 1000,
287            run_dir: None,
288        }
289    }
290
291    /// Sets the vocabulary size (required).
292    ///
293    /// This is the number of distinct token IDs your tokenizer produces.
294    pub fn vocab_size(mut self, size: usize) -> Self {
295        self.vocab_size = Some(size);
296        self
297    }
298
299    /// Sets the parameter budget. Defaults to [`ParameterBudget::Params10M`].
300    pub fn budget(mut self, budget: ParameterBudget) -> Self {
301        self.budget = budget;
302        self
303    }
304
305    /// Sets the compute device. Defaults to [`default_device()`].
306    pub fn device(mut self, device: Device) -> Self {
307        self.device = Some(device);
308        self
309    }
310
311    /// Sets the batch size. Defaults to 4.
312    pub fn batch_size(mut self, size: usize) -> Self {
313        self.batch_size = size;
314        self
315    }
316
317    /// Sets the sequence length. Defaults to 128.
318    pub fn seq_len(mut self, len: usize) -> Self {
319        self.seq_len = len;
320        self
321    }
322
323    /// Sets the number of training steps. Defaults to 1000.
324    pub fn steps(mut self, steps: usize) -> Self {
325        self.steps = steps;
326        self
327    }
328
329    /// Sets the learning rate. Defaults to `2e-4`.
330    pub fn learning_rate(mut self, lr: f64) -> Self {
331        self.learning_rate = lr;
332        self
333    }
334
335    /// Sets the weight decay. Defaults to `0.01`.
336    pub fn weight_decay(mut self, wd: f64) -> Self {
337        self.weight_decay = wd;
338        self
339    }
340
341    /// Sets gradient clip norm. Defaults to `Some(1.0)`.
342    pub fn grad_clip_norm(mut self, norm: Option<f64>) -> Self {
343        self.grad_clip_norm = norm;
344        self
345    }
346
347    /// Sets the directory where checkpoints are saved.
348    pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
349        self.checkpoint_dir = Some(dir.into());
350        self
351    }
352
353    /// Sets how often (in steps) to save checkpoints. Defaults to 1000.
354    pub fn checkpoint_interval(mut self, steps: usize) -> Self {
355        self.checkpoint_interval = steps;
356        self
357    }
358
359    /// Overrides the run directory.
360    pub fn run_dir(mut self, dir: impl Into<String>) -> Self {
361        self.run_dir = Some(dir.into());
362        self
363    }
364
365    /// Builds the [`Trainer`].
366    ///
367    /// This constructs the model config from the parameter budget + vocab size,
368    /// and creates the model.
369    pub fn build(self) -> Result<Trainer> {
370        let vocab_size = self.vocab_size.ok_or_else(|| {
371            Error::Config("vocab_size is required; call .vocab_size(n) before .build()".into())
372        })?;
373
374        let device = match self.device {
375            Some(d) => d,
376            None => default_device()?,
377        };
378
379        let config =
380            MultiscreenModelConfig::for_parameter_budget(self.budget, vocab_size, self.seq_len);
381        let model = DefaultMultiscreenModel::new(config, &device)?;
382
383        let training_config = ModelTrainingConfig {
384            steps: self.steps,
385            batch_size: self.batch_size,
386            learning_rate: self.learning_rate,
387            weight_decay: self.weight_decay,
388            grad_clip_norm: self.grad_clip_norm,
389            pad_token_id: 0,
390            checkpoint_dir: None, // injected by Trainer during training
391            checkpoint_interval: 0,
392        };
393
394        let run_dir = self.run_dir.or_else(|| Some("runs/latest".to_string()));
395
396        Ok(Trainer {
397            model,
398            training_config,
399            checkpoint_dir: self.checkpoint_dir,
400            checkpoint_interval: self.checkpoint_interval,
401            run_dir,
402        })
403    }
404}
405
406impl Default for TrainerBuilder {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
412// ---------------------------------------------------------------------------
413// Tests
414// ---------------------------------------------------------------------------
415
#[cfg(test)]
mod tests {
    use super::*;

    /// `build()` must refuse to proceed when no vocabulary size was given.
    #[test]
    fn builder_requires_vocab_size() {
        let err = Trainer::builder()
            .build()
            .expect_err("build should fail without vocab_size");
        let message = err.to_string();
        assert!(
            message.contains("vocab_size"),
            "error should mention vocab_size: {}",
            message
        );
    }

    /// Every shared field is copied and the checkpoint path is attached.
    #[test]
    fn training_report_from_model_report() {
        let source = ModelTrainingReport {
            steps: 500,
            final_loss: 0.123,
            best_loss: 0.100,
            best_loss_step: 420,
            training_window_count: 100,
            parameter_count: 10_000_000,
        };
        let checkpoint = Some("runs/checkpoint.mpk".to_string());
        let report = TrainingReport::from_model_report(&source, checkpoint);

        assert_eq!(report.steps, 500);
        assert!((report.final_loss - 0.123).abs() < f32::EPSILON);
        assert!((report.best_loss - 0.100).abs() < f32::EPSILON);
        assert_eq!(report.best_loss_step, 420);
        assert_eq!(report.parameter_count, 10_000_000);
        assert_eq!(
            report.checkpoint_path.as_deref(),
            Some("runs/checkpoint.mpk")
        );
    }

    /// A fresh builder carries the documented defaults.
    #[test]
    fn builder_defaults() {
        let defaults = TrainerBuilder::new();
        assert!(defaults.vocab_size.is_none());
        assert!(matches!(defaults.budget, ParameterBudget::Params10M));
        assert!(defaults.device.is_none());
        assert_eq!(defaults.batch_size, 4);
        assert_eq!(defaults.seq_len, 128);
        assert_eq!(defaults.steps, 1000);
        assert!((defaults.learning_rate - 2e-4).abs() < f64::EPSILON);
        assert_eq!(defaults.checkpoint_interval, 1000);
    }
}