Skip to main content

multiscreen_rs/
training.rs

1//! High-level training API with a builder pattern.
2//!
3//! The [`Trainer`] struct wraps model construction, token-sequence training,
4//! and checkpoint management behind a single ergonomic interface with sensible
5//! defaults.
6//!
7//! Users encode their own text into token IDs using any tokenizer they choose,
8//! then pass the `Vec<Vec<u32>>` sequences to [`Trainer::train_on_token_sequences`].
9//!
10//! # Example
11//!
12//! ```rust,no_run
13//! use multiscreen_rs::prelude::*;
14//!
15//! fn main() -> multiscreen_rs::Result<()> {
16//!     let mut trainer = Trainer::builder()
17//!         .vocab_size(1000)
18//!         .budget(ParameterBudget::Params10M)
19//!         .batch_size(4)
20//!         .seq_len(64)
21//!         .steps(100)
22//!         .device(auto_device()?)
23//!         .build()?;
24//!
25//!     // Token sequences from YOUR tokenizer
26//!     let sequences = vec![
27//!         vec![1, 2, 3, 4, 5],
28//!         vec![1, 2, 6, 7, 5],
29//!     ];
30//!
31//!     let report = trainer.train_on_token_sequences(&sequences)?;
32//!     println!("trained {} steps, final loss {:.4}", report.steps, report.final_loss);
33//!     Ok(())
34//! }
35//! ```
36
37use crate::error::{Error, Result};
38use crate::model::{
39    DefaultMultiscreenModel, ModelTrainingConfig, ModelTrainingReport, MultiscreenModelConfig,
40};
41use crate::runtime::{default_device, Device};
42use std::fs;
43use std::path::Path;
44
45/// Re-export of the parameter budget enum for convenience.
46pub use crate::model::MultiscreenParameterBudget as ParameterBudget;
47
48// ---------------------------------------------------------------------------
49// TrainingReport
50// ---------------------------------------------------------------------------
51
52/// Summary returned after training via the high-level [`Trainer`].
53#[derive(Clone, Debug)]
54pub struct TrainingReport {
55    /// Number of training steps completed.
56    pub steps: usize,
57    /// Final training loss.
58    pub final_loss: f32,
59    /// Total number of parameters in the model.
60    pub parameter_count: usize,
61    /// Path the checkpoint was saved to, if any.
62    pub checkpoint_path: Option<String>,
63}
64
65impl TrainingReport {
66    fn from_model_report(report: &ModelTrainingReport, checkpoint_path: Option<String>) -> Self {
67        Self {
68            steps: report.steps,
69            final_loss: report.final_loss,
70            parameter_count: report.parameter_count,
71            checkpoint_path,
72        }
73    }
74}
75
76// ---------------------------------------------------------------------------
77// Trainer
78// ---------------------------------------------------------------------------
79
80/// High-level trainer that bundles a model and training config.
81///
82/// Construct via [`Trainer::builder()`] and the [`TrainerBuilder`] struct.
83/// Users provide token sequences (`Vec<Vec<u32>>`) from their own tokenizer.
84pub struct Trainer {
85    model: DefaultMultiscreenModel,
86    training_config: ModelTrainingConfig,
87    checkpoint_dir: Option<String>,
88    #[allow(dead_code)]
89    checkpoint_interval: usize,
90    #[allow(dead_code)]
91    run_dir: Option<String>,
92}
93
94impl std::fmt::Debug for Trainer {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("Trainer")
97            .field("training_config", &self.training_config)
98            .field("checkpoint_dir", &self.checkpoint_dir)
99            .field("checkpoint_interval", &self.checkpoint_interval)
100            .field("run_dir", &self.run_dir)
101            .finish_non_exhaustive()
102    }
103}
104
105impl Trainer {
106    /// Returns a new [`TrainerBuilder`] with sensible defaults.
107    pub fn builder() -> TrainerBuilder {
108        TrainerBuilder::new()
109    }
110
111    /// Trains the model on the provided token sequences.
112    ///
113    /// Each inner `Vec<u32>` is a tokenized text sample. Use your own
114    /// tokenizer to produce these sequences before calling this method.
115    ///
116    /// The `on_step` callback is invoked after each optimizer step with
117    /// `(step_index, loss_value)`.
118    pub fn train_on_token_sequences_with_callback(
119        &mut self,
120        sequences: &[Vec<u32>],
121        on_step: impl FnMut(usize, f32),
122    ) -> Result<TrainingReport> {
123        if sequences.is_empty() {
124            return Err(Error::Training("no training sequences provided".into()));
125        }
126
127        let device = self.model_device();
128        let report =
129            self.model
130                .train_token_sequences(sequences, &self.training_config, &device, on_step)?;
131
132        let checkpoint_path = match &self.checkpoint_dir {
133            Some(dir) => {
134                let dir_path = Path::new(dir);
135                fs::create_dir_all(dir_path).map_err(|e| {
136                    Error::Io(format!(
137                        "failed to create checkpoint directory {:?}: {}",
138                        dir, e
139                    ))
140                })?;
141                let path = dir_path.join("checkpoint.mpk");
142                self.model.save_parameters(&path)?;
143                Some(path.to_string_lossy().into_owned())
144            }
145            None => None,
146        };
147
148        Ok(TrainingReport::from_model_report(&report, checkpoint_path))
149    }
150
151    /// Convenience wrapper that trains without a per-step callback.
152    pub fn train_on_token_sequences(&mut self, sequences: &[Vec<u32>]) -> Result<TrainingReport> {
153        self.train_on_token_sequences_with_callback(sequences, |_, _| {})
154    }
155
156    /// Trains the model on chat-style (prompt, response) token-ID pairs.
157    ///
158    /// Each element of `chat_pairs` is `(prompt_token_ids, response_token_ids)`.
159    /// The model sees the full context but loss is computed only on response
160    /// tokens, preventing the model from learning to generate role labels.
161    ///
162    /// The `on_step` callback is invoked after each optimizer step with
163    /// `(step_index, loss_value)`.
164    pub fn train_on_chat_sequences_with_callback(
165        &mut self,
166        chat_pairs: &[(Vec<u32>, Vec<u32>)],
167        on_step: impl FnMut(usize, f32),
168    ) -> Result<TrainingReport> {
169        if chat_pairs.is_empty() {
170            return Err(Error::Training("no training chat pairs provided".into()));
171        }
172
173        let device = self.model_device();
174        let report =
175            self.model
176                .train_chat_sequences(chat_pairs, &self.training_config, &device, on_step)?;
177
178        let checkpoint_path = match &self.checkpoint_dir {
179            Some(dir) => {
180                let dir_path = Path::new(dir);
181                fs::create_dir_all(dir_path).map_err(|e| {
182                    Error::Io(format!(
183                        "failed to create checkpoint directory {:?}: {}",
184                        dir, e
185                    ))
186                })?;
187                let path = dir_path.join("checkpoint.mpk");
188                self.model.save_parameters(&path)?;
189                Some(path.to_string_lossy().into_owned())
190            }
191            None => None,
192        };
193
194        Ok(TrainingReport::from_model_report(&report, checkpoint_path))
195    }
196
197    /// Convenience wrapper that trains on chat pairs without a per-step callback.
198    pub fn train_on_chat_sequences(
199        &mut self,
200        chat_pairs: &[(Vec<u32>, Vec<u32>)],
201    ) -> Result<TrainingReport> {
202        self.train_on_chat_sequences_with_callback(chat_pairs, |_, _| {})
203    }
204
205    /// Saves a model checkpoint to the given path.
206    pub fn save_checkpoint(&self, path: &str) -> Result<()> {
207        if let Some(parent) = Path::new(path).parent() {
208            fs::create_dir_all(parent).map_err(|e| {
209                Error::Io(format!(
210                    "failed to create checkpoint directory {:?}: {}",
211                    parent, e
212                ))
213            })?;
214        }
215        self.model.save_parameters(path)
216    }
217
218    /// Returns a reference to the underlying model.
219    pub fn model(&self) -> &DefaultMultiscreenModel {
220        &self.model
221    }
222
223    /// Returns a mutable reference to the underlying model.
224    pub fn model_mut(&mut self) -> &mut DefaultMultiscreenModel {
225        &mut self.model
226    }
227
228    /// Returns a reference to the training configuration.
229    pub fn training_config(&self) -> &ModelTrainingConfig {
230        &self.training_config
231    }
232
233    /// Obtains the device the model is currently on.
234    fn model_device(&self) -> Device {
235        Device::default()
236    }
237}
238
239// ---------------------------------------------------------------------------
240// TrainerBuilder
241// ---------------------------------------------------------------------------
242
243/// Builder for constructing a [`Trainer`] with sensible defaults.
244pub struct TrainerBuilder {
245    vocab_size: Option<usize>,
246    budget: ParameterBudget,
247    device: Option<Device>,
248    batch_size: usize,
249    seq_len: usize,
250    steps: usize,
251    learning_rate: f64,
252    weight_decay: f64,
253    grad_clip_norm: Option<f64>,
254    checkpoint_dir: Option<String>,
255    checkpoint_interval: usize,
256    run_dir: Option<String>,
257}
258
259impl TrainerBuilder {
260    /// Creates a new builder with default values.
261    fn new() -> Self {
262        Self {
263            vocab_size: None,
264            budget: ParameterBudget::Params10M,
265            device: None,
266            batch_size: 4,
267            seq_len: 128,
268            steps: 1000,
269            learning_rate: 2e-4,
270            weight_decay: 0.01,
271            grad_clip_norm: Some(1.0),
272            checkpoint_dir: None,
273            checkpoint_interval: 1000,
274            run_dir: None,
275        }
276    }
277
278    /// Sets the vocabulary size (required).
279    ///
280    /// This is the number of distinct token IDs your tokenizer produces.
281    pub fn vocab_size(mut self, size: usize) -> Self {
282        self.vocab_size = Some(size);
283        self
284    }
285
286    /// Sets the parameter budget. Defaults to [`ParameterBudget::Params10M`].
287    pub fn budget(mut self, budget: ParameterBudget) -> Self {
288        self.budget = budget;
289        self
290    }
291
292    /// Sets the compute device. Defaults to [`default_device()`].
293    pub fn device(mut self, device: Device) -> Self {
294        self.device = Some(device);
295        self
296    }
297
298    /// Sets the batch size. Defaults to 4.
299    pub fn batch_size(mut self, size: usize) -> Self {
300        self.batch_size = size;
301        self
302    }
303
304    /// Sets the sequence length. Defaults to 128.
305    pub fn seq_len(mut self, len: usize) -> Self {
306        self.seq_len = len;
307        self
308    }
309
310    /// Sets the number of training steps. Defaults to 1000.
311    pub fn steps(mut self, steps: usize) -> Self {
312        self.steps = steps;
313        self
314    }
315
316    /// Sets the learning rate. Defaults to `2e-4`.
317    pub fn learning_rate(mut self, lr: f64) -> Self {
318        self.learning_rate = lr;
319        self
320    }
321
322    /// Sets the weight decay. Defaults to `0.01`.
323    pub fn weight_decay(mut self, wd: f64) -> Self {
324        self.weight_decay = wd;
325        self
326    }
327
328    /// Sets gradient clip norm. Defaults to `Some(1.0)`.
329    pub fn grad_clip_norm(mut self, norm: Option<f64>) -> Self {
330        self.grad_clip_norm = norm;
331        self
332    }
333
334    /// Sets the directory where checkpoints are saved.
335    pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
336        self.checkpoint_dir = Some(dir.into());
337        self
338    }
339
340    /// Sets how often (in steps) to save checkpoints. Defaults to 1000.
341    pub fn checkpoint_interval(mut self, steps: usize) -> Self {
342        self.checkpoint_interval = steps;
343        self
344    }
345
346    /// Overrides the run directory.
347    pub fn run_dir(mut self, dir: impl Into<String>) -> Self {
348        self.run_dir = Some(dir.into());
349        self
350    }
351
352    /// Builds the [`Trainer`].
353    ///
354    /// This constructs the model config from the parameter budget + vocab size,
355    /// and creates the model.
356    pub fn build(self) -> Result<Trainer> {
357        let vocab_size = self.vocab_size.ok_or_else(|| {
358            Error::Config("vocab_size is required; call .vocab_size(n) before .build()".into())
359        })?;
360
361        let device = match self.device {
362            Some(d) => d,
363            None => default_device()?,
364        };
365
366        let config =
367            MultiscreenModelConfig::for_parameter_budget(self.budget, vocab_size, self.seq_len);
368        let model = DefaultMultiscreenModel::new(config, &device)?;
369
370        let training_config = ModelTrainingConfig {
371            steps: self.steps,
372            batch_size: self.batch_size,
373            learning_rate: self.learning_rate,
374            weight_decay: self.weight_decay,
375            grad_clip_norm: self.grad_clip_norm,
376            pad_token_id: 0,
377        };
378
379        let run_dir = self.run_dir.or_else(|| Some("runs/latest".to_string()));
380
381        Ok(Trainer {
382            model,
383            training_config,
384            checkpoint_dir: self.checkpoint_dir,
385            checkpoint_interval: self.checkpoint_interval,
386            run_dir,
387        })
388    }
389}
390
391impl Default for TrainerBuilder {
392    fn default() -> Self {
393        Self::new()
394    }
395}
396
397// ---------------------------------------------------------------------------
398// Tests
399// ---------------------------------------------------------------------------
400
#[cfg(test)]
mod tests {
    use super::*;

    /// `build()` must refuse a builder that never received a vocabulary size,
    /// and the error text must name the missing setting.
    #[test]
    fn builder_requires_vocab_size() {
        let outcome = Trainer::builder().build();
        assert!(outcome.is_err(), "build should fail without vocab_size");

        let rendered = outcome.unwrap_err().to_string();
        assert!(
            rendered.contains("vocab_size"),
            "error should mention vocab_size: {}",
            rendered
        );
    }

    /// The high-level report copies steps, loss and parameter count from the
    /// model-level report and keeps the supplied checkpoint path.
    #[test]
    fn training_report_from_model_report() {
        let source = ModelTrainingReport {
            steps: 500,
            final_loss: 0.123,
            training_window_count: 100,
            parameter_count: 10_000_000,
        };

        let TrainingReport {
            steps,
            final_loss,
            parameter_count,
            checkpoint_path,
        } = TrainingReport::from_model_report(&source, Some("runs/checkpoint.mpk".into()));

        assert_eq!(steps, 500);
        assert!((final_loss - 0.123).abs() < f32::EPSILON);
        assert_eq!(parameter_count, 10_000_000);
        assert_eq!(checkpoint_path.as_deref(), Some("runs/checkpoint.mpk"));
    }

    /// A freshly created builder carries the documented default values.
    #[test]
    fn builder_defaults() {
        let TrainerBuilder {
            vocab_size,
            budget,
            device,
            batch_size,
            seq_len,
            steps,
            learning_rate,
            checkpoint_interval,
            ..
        } = TrainerBuilder::new();

        assert!(vocab_size.is_none());
        assert!(matches!(budget, ParameterBudget::Params10M));
        assert!(device.is_none());
        assert_eq!(batch_size, 4);
        assert_eq!(seq_len, 128);
        assert_eq!(steps, 1000);
        assert!((learning_rate - 2e-4).abs() < f64::EPSILON);
        assert_eq!(checkpoint_interval, 1000);
    }
}