Skip to main content

multiscreen_rs/
training.rs

1//! High-level training API with a builder pattern.
2//!
3//! The [`Trainer`] struct wraps model construction, token-sequence training,
4//! and checkpoint management behind a single ergonomic interface with sensible
5//! defaults.
6//!
7//! Users encode their own text into token IDs using any tokenizer they choose,
8//! then pass the `Vec<Vec<u32>>` sequences to [`Trainer::train_on_token_sequences`].
9//!
10//! # Example
11//!
12//! ```rust,no_run
13//! use multiscreen_rs::prelude::*;
14//!
15//! fn main() -> multiscreen_rs::Result<()> {
16//!     let mut trainer = Trainer::builder()
17//!         .vocab_size(1000)
18//!         .budget(ParameterBudget::Params10M)
19//!         .batch_size(4)
20//!         .seq_len(64)
21//!         .steps(100)
22//!         .device(cpu()?)
23//!         .build()?;
24//!
25//!     // Token sequences from YOUR tokenizer
26//!     let sequences = vec![
27//!         vec![1, 2, 3, 4, 5],
28//!         vec![1, 2, 6, 7, 5],
29//!     ];
30//!
31//!     let report = trainer.train_on_token_sequences(&sequences)?;
32//!     println!("trained {} steps, final loss {:.4}", report.steps, report.final_loss);
33//!     Ok(())
34//! }
35//! ```
36
37use crate::error::{Error, Result};
38use crate::model::{
39    DefaultMultiscreenModel, ModelTrainingConfig, ModelTrainingReport, MultiscreenModelConfig,
40};
41use crate::runtime::{default_device, Device};
42use std::fs;
43use std::path::Path;
44
45/// Re-export of the parameter budget enum for convenience.
46pub use crate::model::MultiscreenParameterBudget as ParameterBudget;
47
48// ---------------------------------------------------------------------------
49// TrainingReport
50// ---------------------------------------------------------------------------
51
/// Outcome of a training run started through the high-level [`Trainer`].
///
/// All fields are plain values, so the report can be cloned, logged with
/// `{:?}`, or stored after the trainer itself has been dropped.
#[derive(Debug, Clone)]
pub struct TrainingReport {
    /// How many training steps the run reported as completed.
    pub steps: usize,
    /// Loss value reported at the end of the run.
    pub final_loss: f32,
    /// Parameter count of the trained model.
    pub parameter_count: usize,
    /// Location of the checkpoint written for this run, or `None` when no
    /// checkpoint was saved.
    pub checkpoint_path: Option<String>,
}
64
65impl TrainingReport {
66    fn from_model_report(report: &ModelTrainingReport, checkpoint_path: Option<String>) -> Self {
67        Self {
68            steps: report.steps,
69            final_loss: report.final_loss,
70            parameter_count: report.parameter_count,
71            checkpoint_path,
72        }
73    }
74}
75
76// ---------------------------------------------------------------------------
77// Trainer
78// ---------------------------------------------------------------------------
79
80/// High-level trainer that bundles a model and training config.
81///
82/// Construct via [`Trainer::builder()`] and the [`TrainerBuilder`] struct.
83/// Users provide token sequences (`Vec<Vec<u32>>`) from their own tokenizer.
84pub struct Trainer {
85    model: DefaultMultiscreenModel,
86    training_config: ModelTrainingConfig,
87    checkpoint_dir: Option<String>,
88    #[allow(dead_code)]
89    checkpoint_interval: usize,
90    #[allow(dead_code)]
91    run_dir: Option<String>,
92}
93
94impl std::fmt::Debug for Trainer {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("Trainer")
97            .field("training_config", &self.training_config)
98            .field("checkpoint_dir", &self.checkpoint_dir)
99            .field("checkpoint_interval", &self.checkpoint_interval)
100            .field("run_dir", &self.run_dir)
101            .finish_non_exhaustive()
102    }
103}
104
105impl Trainer {
106    /// Returns a new [`TrainerBuilder`] with sensible defaults.
107    pub fn builder() -> TrainerBuilder {
108        TrainerBuilder::new()
109    }
110
111    /// Trains the model on the provided token sequences.
112    ///
113    /// Each inner `Vec<u32>` is a tokenized text sample. Use your own
114    /// tokenizer to produce these sequences before calling this method.
115    ///
116    /// The `on_step` callback is invoked after each optimizer step with
117    /// `(step_index, loss_value)`.
118    pub fn train_on_token_sequences_with_callback(
119        &mut self,
120        sequences: &[Vec<u32>],
121        on_step: impl FnMut(usize, f32),
122    ) -> Result<TrainingReport> {
123        if sequences.is_empty() {
124            return Err(Error::Training("no training sequences provided".into()));
125        }
126
127        let device = self.model_device();
128        let report =
129            self.model
130                .train_token_sequences(sequences, &self.training_config, &device, on_step)?;
131
132        let checkpoint_path = match &self.checkpoint_dir {
133            Some(dir) => {
134                let dir_path = Path::new(dir);
135                fs::create_dir_all(dir_path).map_err(|e| {
136                    Error::Io(format!(
137                        "failed to create checkpoint directory {:?}: {}",
138                        dir, e
139                    ))
140                })?;
141                let path = dir_path.join("checkpoint.mpk");
142                self.model.save_parameters(&path)?;
143                Some(path.to_string_lossy().into_owned())
144            }
145            None => None,
146        };
147
148        Ok(TrainingReport::from_model_report(&report, checkpoint_path))
149    }
150
151    /// Convenience wrapper that trains without a per-step callback.
152    pub fn train_on_token_sequences(&mut self, sequences: &[Vec<u32>]) -> Result<TrainingReport> {
153        self.train_on_token_sequences_with_callback(sequences, |_, _| {})
154    }
155
156    /// Saves a model checkpoint to the given path.
157    pub fn save_checkpoint(&self, path: &str) -> Result<()> {
158        if let Some(parent) = Path::new(path).parent() {
159            fs::create_dir_all(parent).map_err(|e| {
160                Error::Io(format!(
161                    "failed to create checkpoint directory {:?}: {}",
162                    parent, e
163                ))
164            })?;
165        }
166        self.model.save_parameters(path)
167    }
168
169    /// Returns a reference to the underlying model.
170    pub fn model(&self) -> &DefaultMultiscreenModel {
171        &self.model
172    }
173
174    /// Returns a mutable reference to the underlying model.
175    pub fn model_mut(&mut self) -> &mut DefaultMultiscreenModel {
176        &mut self.model
177    }
178
179    /// Returns a reference to the training configuration.
180    pub fn training_config(&self) -> &ModelTrainingConfig {
181        &self.training_config
182    }
183
184    /// Obtains the device the model is currently on.
185    fn model_device(&self) -> Device {
186        Device::default()
187    }
188}
189
190// ---------------------------------------------------------------------------
191// TrainerBuilder
192// ---------------------------------------------------------------------------
193
194/// Builder for constructing a [`Trainer`] with sensible defaults.
195pub struct TrainerBuilder {
196    vocab_size: Option<usize>,
197    budget: ParameterBudget,
198    device: Option<Device>,
199    batch_size: usize,
200    seq_len: usize,
201    steps: usize,
202    learning_rate: f64,
203    weight_decay: f64,
204    grad_clip_norm: Option<f64>,
205    checkpoint_dir: Option<String>,
206    checkpoint_interval: usize,
207    run_dir: Option<String>,
208}
209
210impl TrainerBuilder {
211    /// Creates a new builder with default values.
212    fn new() -> Self {
213        Self {
214            vocab_size: None,
215            budget: ParameterBudget::Params10M,
216            device: None,
217            batch_size: 4,
218            seq_len: 128,
219            steps: 1000,
220            learning_rate: 2e-4,
221            weight_decay: 0.01,
222            grad_clip_norm: Some(1.0),
223            checkpoint_dir: None,
224            checkpoint_interval: 1000,
225            run_dir: None,
226        }
227    }
228
229    /// Sets the vocabulary size (required).
230    ///
231    /// This is the number of distinct token IDs your tokenizer produces.
232    pub fn vocab_size(mut self, size: usize) -> Self {
233        self.vocab_size = Some(size);
234        self
235    }
236
237    /// Sets the parameter budget. Defaults to [`ParameterBudget::Params10M`].
238    pub fn budget(mut self, budget: ParameterBudget) -> Self {
239        self.budget = budget;
240        self
241    }
242
243    /// Sets the compute device. Defaults to [`default_device()`].
244    pub fn device(mut self, device: Device) -> Self {
245        self.device = Some(device);
246        self
247    }
248
249    /// Sets the batch size. Defaults to 4.
250    pub fn batch_size(mut self, size: usize) -> Self {
251        self.batch_size = size;
252        self
253    }
254
255    /// Sets the sequence length. Defaults to 128.
256    pub fn seq_len(mut self, len: usize) -> Self {
257        self.seq_len = len;
258        self
259    }
260
261    /// Sets the number of training steps. Defaults to 1000.
262    pub fn steps(mut self, steps: usize) -> Self {
263        self.steps = steps;
264        self
265    }
266
267    /// Sets the learning rate. Defaults to `2e-4`.
268    pub fn learning_rate(mut self, lr: f64) -> Self {
269        self.learning_rate = lr;
270        self
271    }
272
273    /// Sets the weight decay. Defaults to `0.01`.
274    pub fn weight_decay(mut self, wd: f64) -> Self {
275        self.weight_decay = wd;
276        self
277    }
278
279    /// Sets gradient clip norm. Defaults to `Some(1.0)`.
280    pub fn grad_clip_norm(mut self, norm: Option<f64>) -> Self {
281        self.grad_clip_norm = norm;
282        self
283    }
284
285    /// Sets the directory where checkpoints are saved.
286    pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
287        self.checkpoint_dir = Some(dir.into());
288        self
289    }
290
291    /// Sets how often (in steps) to save checkpoints. Defaults to 1000.
292    pub fn checkpoint_interval(mut self, steps: usize) -> Self {
293        self.checkpoint_interval = steps;
294        self
295    }
296
297    /// Overrides the run directory.
298    pub fn run_dir(mut self, dir: impl Into<String>) -> Self {
299        self.run_dir = Some(dir.into());
300        self
301    }
302
303    /// Builds the [`Trainer`].
304    ///
305    /// This constructs the model config from the parameter budget + vocab size,
306    /// and creates the model.
307    pub fn build(self) -> Result<Trainer> {
308        let vocab_size = self.vocab_size.ok_or_else(|| {
309            Error::Config("vocab_size is required; call .vocab_size(n) before .build()".into())
310        })?;
311
312        let device = match self.device {
313            Some(d) => d,
314            None => default_device()?,
315        };
316
317        let config =
318            MultiscreenModelConfig::for_parameter_budget(self.budget, vocab_size, self.seq_len);
319        let model = DefaultMultiscreenModel::new(config, &device)?;
320
321        let training_config = ModelTrainingConfig {
322            steps: self.steps,
323            batch_size: self.batch_size,
324            learning_rate: self.learning_rate,
325            weight_decay: self.weight_decay,
326            grad_clip_norm: self.grad_clip_norm,
327            pad_token_id: 0,
328        };
329
330        let run_dir = self.run_dir.or_else(|| Some("runs/latest".to_string()));
331
332        Ok(Trainer {
333            model,
334            training_config,
335            checkpoint_dir: self.checkpoint_dir,
336            checkpoint_interval: self.checkpoint_interval,
337            run_dir,
338        })
339    }
340}
341
342impl Default for TrainerBuilder {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348// ---------------------------------------------------------------------------
349// Tests
350// ---------------------------------------------------------------------------
351
#[cfg(test)]
mod tests {
    use super::*;

    /// `build()` must fail, and say why, when no vocabulary size was given.
    #[test]
    fn builder_requires_vocab_size() {
        let outcome = Trainer::builder().build();
        assert!(outcome.is_err(), "build should fail without vocab_size");

        let rendered = outcome.unwrap_err().to_string();
        assert!(
            rendered.contains("vocab_size"),
            "error should mention vocab_size: {}",
            rendered
        );
    }

    /// The high-level report copies the model-level numbers and keeps the
    /// checkpoint path it was handed.
    #[test]
    fn training_report_from_model_report() {
        let low_level = ModelTrainingReport {
            steps: 500,
            final_loss: 0.123,
            training_window_count: 100,
            parameter_count: 10_000_000,
        };
        let saved_to = Some(String::from("runs/checkpoint.mpk"));

        let summary = TrainingReport::from_model_report(&low_level, saved_to);

        assert_eq!(summary.steps, 500);
        assert!((summary.final_loss - 0.123).abs() < f32::EPSILON);
        assert_eq!(summary.parameter_count, 10_000_000);
        assert_eq!(
            summary.checkpoint_path.as_deref(),
            Some("runs/checkpoint.mpk")
        );
    }

    /// A freshly created builder carries the documented default values.
    #[test]
    fn builder_defaults() {
        let TrainerBuilder {
            vocab_size,
            budget,
            device,
            batch_size,
            seq_len,
            steps,
            learning_rate,
            checkpoint_interval,
            ..
        } = TrainerBuilder::new();

        assert!(vocab_size.is_none());
        assert!(matches!(budget, ParameterBudget::Params10M));
        assert!(device.is_none());
        assert_eq!(batch_size, 4);
        assert_eq!(seq_len, 128);
        assert_eq!(steps, 1000);
        assert!((learning_rate - 2e-4).abs() < f64::EPSILON);
        assert_eq!(checkpoint_interval, 1000);
    }
}
394        assert_eq!(builder.seq_len, 128);
395        assert_eq!(builder.steps, 1000);
396        assert!((builder.learning_rate - 2e-4).abs() < f64::EPSILON);
397        assert_eq!(builder.checkpoint_interval, 1000);
398    }
399}