1use crate::error::{Error, Result};
38use crate::model::{
39 DefaultMultiscreenModel, ModelTrainingConfig, ModelTrainingReport, MultiscreenModelConfig,
40};
41use crate::runtime::{default_device, Device};
42use std::fs;
43use std::path::Path;
44
45pub use crate::model::MultiscreenParameterBudget as ParameterBudget;
47
/// Summary of a completed training run, as returned by [`Trainer`].
#[derive(Debug, Clone)]
pub struct TrainingReport {
    /// Number of steps reported by the model's training loop.
    pub steps: usize,
    /// Loss reported by the model's training loop at the end of the run.
    pub final_loss: f32,
    /// Parameter count reported by the model.
    pub parameter_count: usize,
    /// Location of the checkpoint written after training, or `None` when no
    /// checkpoint directory was configured.
    pub checkpoint_path: Option<String>,
}
64
65impl TrainingReport {
66 fn from_model_report(report: &ModelTrainingReport, checkpoint_path: Option<String>) -> Self {
67 Self {
68 steps: report.steps,
69 final_loss: report.final_loss,
70 parameter_count: report.parameter_count,
71 checkpoint_path,
72 }
73 }
74}
75
76pub struct Trainer {
85 model: DefaultMultiscreenModel,
86 training_config: ModelTrainingConfig,
87 checkpoint_dir: Option<String>,
88 #[allow(dead_code)]
89 checkpoint_interval: usize,
90 #[allow(dead_code)]
91 run_dir: Option<String>,
92}
93
94impl std::fmt::Debug for Trainer {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("Trainer")
97 .field("training_config", &self.training_config)
98 .field("checkpoint_dir", &self.checkpoint_dir)
99 .field("checkpoint_interval", &self.checkpoint_interval)
100 .field("run_dir", &self.run_dir)
101 .finish_non_exhaustive()
102 }
103}
104
105impl Trainer {
106 pub fn builder() -> TrainerBuilder {
108 TrainerBuilder::new()
109 }
110
111 pub fn train_on_token_sequences_with_callback(
119 &mut self,
120 sequences: &[Vec<u32>],
121 on_step: impl FnMut(usize, f32),
122 ) -> Result<TrainingReport> {
123 if sequences.is_empty() {
124 return Err(Error::Training("no training sequences provided".into()));
125 }
126
127 let device = self.model_device();
128 let report =
129 self.model
130 .train_token_sequences(sequences, &self.training_config, &device, on_step)?;
131
132 let checkpoint_path = match &self.checkpoint_dir {
133 Some(dir) => {
134 let dir_path = Path::new(dir);
135 fs::create_dir_all(dir_path).map_err(|e| {
136 Error::Io(format!(
137 "failed to create checkpoint directory {:?}: {}",
138 dir, e
139 ))
140 })?;
141 let path = dir_path.join("checkpoint.mpk");
142 self.model.save_parameters(&path)?;
143 Some(path.to_string_lossy().into_owned())
144 }
145 None => None,
146 };
147
148 Ok(TrainingReport::from_model_report(&report, checkpoint_path))
149 }
150
151 pub fn train_on_token_sequences(&mut self, sequences: &[Vec<u32>]) -> Result<TrainingReport> {
153 self.train_on_token_sequences_with_callback(sequences, |_, _| {})
154 }
155
156 pub fn save_checkpoint(&self, path: &str) -> Result<()> {
158 if let Some(parent) = Path::new(path).parent() {
159 fs::create_dir_all(parent).map_err(|e| {
160 Error::Io(format!(
161 "failed to create checkpoint directory {:?}: {}",
162 parent, e
163 ))
164 })?;
165 }
166 self.model.save_parameters(path)
167 }
168
169 pub fn model(&self) -> &DefaultMultiscreenModel {
171 &self.model
172 }
173
174 pub fn model_mut(&mut self) -> &mut DefaultMultiscreenModel {
176 &mut self.model
177 }
178
179 pub fn training_config(&self) -> &ModelTrainingConfig {
181 &self.training_config
182 }
183
184 fn model_device(&self) -> Device {
186 Device::default()
187 }
188}
189
190pub struct TrainerBuilder {
196 vocab_size: Option<usize>,
197 budget: ParameterBudget,
198 device: Option<Device>,
199 batch_size: usize,
200 seq_len: usize,
201 steps: usize,
202 learning_rate: f64,
203 weight_decay: f64,
204 grad_clip_norm: Option<f64>,
205 checkpoint_dir: Option<String>,
206 checkpoint_interval: usize,
207 run_dir: Option<String>,
208}
209
210impl TrainerBuilder {
211 fn new() -> Self {
213 Self {
214 vocab_size: None,
215 budget: ParameterBudget::Params10M,
216 device: None,
217 batch_size: 4,
218 seq_len: 128,
219 steps: 1000,
220 learning_rate: 2e-4,
221 weight_decay: 0.01,
222 grad_clip_norm: Some(1.0),
223 checkpoint_dir: None,
224 checkpoint_interval: 1000,
225 run_dir: None,
226 }
227 }
228
229 pub fn vocab_size(mut self, size: usize) -> Self {
233 self.vocab_size = Some(size);
234 self
235 }
236
237 pub fn budget(mut self, budget: ParameterBudget) -> Self {
239 self.budget = budget;
240 self
241 }
242
243 pub fn device(mut self, device: Device) -> Self {
245 self.device = Some(device);
246 self
247 }
248
249 pub fn batch_size(mut self, size: usize) -> Self {
251 self.batch_size = size;
252 self
253 }
254
255 pub fn seq_len(mut self, len: usize) -> Self {
257 self.seq_len = len;
258 self
259 }
260
261 pub fn steps(mut self, steps: usize) -> Self {
263 self.steps = steps;
264 self
265 }
266
267 pub fn learning_rate(mut self, lr: f64) -> Self {
269 self.learning_rate = lr;
270 self
271 }
272
273 pub fn weight_decay(mut self, wd: f64) -> Self {
275 self.weight_decay = wd;
276 self
277 }
278
279 pub fn grad_clip_norm(mut self, norm: Option<f64>) -> Self {
281 self.grad_clip_norm = norm;
282 self
283 }
284
285 pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
287 self.checkpoint_dir = Some(dir.into());
288 self
289 }
290
291 pub fn checkpoint_interval(mut self, steps: usize) -> Self {
293 self.checkpoint_interval = steps;
294 self
295 }
296
297 pub fn run_dir(mut self, dir: impl Into<String>) -> Self {
299 self.run_dir = Some(dir.into());
300 self
301 }
302
303 pub fn build(self) -> Result<Trainer> {
308 let vocab_size = self.vocab_size.ok_or_else(|| {
309 Error::Config("vocab_size is required; call .vocab_size(n) before .build()".into())
310 })?;
311
312 let device = match self.device {
313 Some(d) => d,
314 None => default_device()?,
315 };
316
317 let config =
318 MultiscreenModelConfig::for_parameter_budget(self.budget, vocab_size, self.seq_len);
319 let model = DefaultMultiscreenModel::new(config, &device)?;
320
321 let training_config = ModelTrainingConfig {
322 steps: self.steps,
323 batch_size: self.batch_size,
324 learning_rate: self.learning_rate,
325 weight_decay: self.weight_decay,
326 grad_clip_norm: self.grad_clip_norm,
327 pad_token_id: 0,
328 };
329
330 let run_dir = self.run_dir.or_else(|| Some("runs/latest".to_string()));
331
332 Ok(Trainer {
333 model,
334 training_config,
335 checkpoint_dir: self.checkpoint_dir,
336 checkpoint_interval: self.checkpoint_interval,
337 run_dir,
338 })
339 }
340}
341
342impl Default for TrainerBuilder {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Building without a vocabulary size must fail with a message naming it.
    #[test]
    fn builder_requires_vocab_size() {
        let result = Trainer::builder().build();
        assert!(result.is_err(), "build should fail without vocab_size");
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("vocab_size"),
            "error should mention vocab_size: {}",
            msg
        );
    }

    /// The public report copies the model report's fields and keeps the
    /// supplied checkpoint path (or its absence).
    #[test]
    fn training_report_from_model_report() {
        let model_report = ModelTrainingReport {
            steps: 500,
            final_loss: 0.123,
            training_window_count: 100,
            parameter_count: 10_000_000,
        };
        let report =
            TrainingReport::from_model_report(&model_report, Some("runs/checkpoint.mpk".into()));
        assert_eq!(report.steps, 500);
        assert!((report.final_loss - 0.123).abs() < f32::EPSILON);
        assert_eq!(report.parameter_count, 10_000_000);
        assert_eq!(
            report.checkpoint_path.as_deref(),
            Some("runs/checkpoint.mpk")
        );

        let without_checkpoint = TrainingReport::from_model_report(&model_report, None);
        assert!(without_checkpoint.checkpoint_path.is_none());
    }

    /// Every default set by `TrainerBuilder::new` is pinned here.
    #[test]
    fn builder_defaults() {
        let builder = TrainerBuilder::new();
        assert!(builder.vocab_size.is_none());
        assert!(matches!(builder.budget, ParameterBudget::Params10M));
        assert!(builder.device.is_none());
        assert_eq!(builder.batch_size, 4);
        assert_eq!(builder.seq_len, 128);
        assert_eq!(builder.steps, 1000);
        assert!((builder.learning_rate - 2e-4).abs() < f64::EPSILON);
        assert!((builder.weight_decay - 0.01).abs() < f64::EPSILON);
        assert_eq!(builder.grad_clip_norm, Some(1.0));
        assert!(builder.checkpoint_dir.is_none());
        assert_eq!(builder.checkpoint_interval, 1000);
        assert!(builder.run_dir.is_none());
    }

    /// `Default` must produce the same settings as `new`.
    #[test]
    fn builder_default_matches_new() {
        let from_default = TrainerBuilder::default();
        let from_new = TrainerBuilder::new();
        assert_eq!(from_default.vocab_size, from_new.vocab_size);
        assert_eq!(from_default.batch_size, from_new.batch_size);
        assert_eq!(from_default.seq_len, from_new.seq_len);
        assert_eq!(from_default.steps, from_new.steps);
        assert_eq!(from_default.grad_clip_norm, from_new.grad_clip_norm);
        assert_eq!(from_default.checkpoint_interval, from_new.checkpoint_interval);
    }
}
399}