1use crate::error::{Error, Result};
38use crate::model::{
39 DefaultMultiscreenModel, ModelTrainingConfig, ModelTrainingReport, MultiscreenModelConfig,
40};
41use crate::runtime::{Device, default_device};
42use std::fs;
43use std::path::Path;
44
45pub use crate::model::MultiscreenParameterBudget as ParameterBudget;
47
/// Summary of a training run, returned by the `Trainer::train_on_*` methods.
///
/// Loss and step figures are copied from the model's own training report; the
/// checkpoint path is filled in by the trainer.
#[derive(Clone, Debug, PartialEq)]
pub struct TrainingReport {
    /// Number of training steps, as reported by the model.
    pub steps: usize,
    /// Final loss value, as reported by the model.
    pub final_loss: f32,
    /// Best loss value, as reported by the model.
    pub best_loss: f32,
    /// Step associated with `best_loss`, as reported by the model.
    pub best_loss_step: usize,
    /// Number of model parameters, as reported by the model.
    pub parameter_count: usize,
    /// Path of the checkpoint saved after training, or `None` when no checkpoint
    /// directory was configured.
    pub checkpoint_path: Option<String>,
}
68
69impl TrainingReport {
70 fn from_model_report(report: &ModelTrainingReport, checkpoint_path: Option<String>) -> Self {
71 Self {
72 steps: report.steps,
73 final_loss: report.final_loss,
74 best_loss: report.best_loss,
75 best_loss_step: report.best_loss_step,
76 parameter_count: report.parameter_count,
77 checkpoint_path,
78 }
79 }
80}
81
82pub struct Trainer {
91 model: DefaultMultiscreenModel,
92 training_config: ModelTrainingConfig,
93 checkpoint_dir: Option<String>,
94 checkpoint_interval: usize,
95 #[allow(dead_code)]
96 run_dir: Option<String>,
97}
98
99impl std::fmt::Debug for Trainer {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("Trainer")
102 .field("training_config", &self.training_config)
103 .field("checkpoint_dir", &self.checkpoint_dir)
104 .field("checkpoint_interval", &self.checkpoint_interval)
105 .field("run_dir", &self.run_dir)
106 .finish_non_exhaustive()
107 }
108}
109
110impl Trainer {
111 pub fn builder() -> TrainerBuilder {
113 TrainerBuilder::new()
114 }
115
116 pub fn train_on_token_sequences_with_callback(
124 &mut self,
125 sequences: &[Vec<u32>],
126 on_step: impl FnMut(usize, f32),
127 ) -> Result<TrainingReport> {
128 if sequences.is_empty() {
129 return Err(Error::Training("no training sequences provided".into()));
130 }
131
132 let mut config = self.training_config.clone();
133 config.checkpoint_dir = self.checkpoint_dir.clone();
134 config.checkpoint_interval = self.checkpoint_interval;
135
136 let device = self.model_device();
137 let report = self
138 .model
139 .train_token_sequences(sequences, &config, &device, on_step)?;
140
141 let checkpoint_path = match &self.checkpoint_dir {
142 Some(dir) => {
143 let dir_path = Path::new(dir);
144 fs::create_dir_all(dir_path).map_err(|e| {
145 Error::Io(format!(
146 "failed to create checkpoint directory {:?}: {}",
147 dir, e
148 ))
149 })?;
150 let path = dir_path.join("checkpoint.mpk");
151 self.model.save_parameters(&path)?;
152 Some(path.to_string_lossy().into_owned())
153 }
154 None => None,
155 };
156
157 Ok(TrainingReport::from_model_report(&report, checkpoint_path))
158 }
159
160 pub fn train_on_token_sequences(&mut self, sequences: &[Vec<u32>]) -> Result<TrainingReport> {
162 self.train_on_token_sequences_with_callback(sequences, |_, _| {})
163 }
164
165 pub fn train_on_chat_sequences_with_callback(
174 &mut self,
175 chat_pairs: &[(Vec<u32>, Vec<u32>)],
176 on_step: impl FnMut(usize, f32),
177 ) -> Result<TrainingReport> {
178 if chat_pairs.is_empty() {
179 return Err(Error::Training("no training chat pairs provided".into()));
180 }
181
182 let mut config = self.training_config.clone();
183 config.checkpoint_dir = self.checkpoint_dir.clone();
184 config.checkpoint_interval = self.checkpoint_interval;
185
186 let device = self.model_device();
187 let report = self
188 .model
189 .train_chat_sequences(chat_pairs, &config, &device, on_step)?;
190
191 let checkpoint_path = match &self.checkpoint_dir {
192 Some(dir) => {
193 let dir_path = Path::new(dir);
194 fs::create_dir_all(dir_path).map_err(|e| {
195 Error::Io(format!(
196 "failed to create checkpoint directory {:?}: {}",
197 dir, e
198 ))
199 })?;
200 let path = dir_path.join("checkpoint.mpk");
201 self.model.save_parameters(&path)?;
202 Some(path.to_string_lossy().into_owned())
203 }
204 None => None,
205 };
206
207 Ok(TrainingReport::from_model_report(&report, checkpoint_path))
208 }
209
210 pub fn train_on_chat_sequences(
212 &mut self,
213 chat_pairs: &[(Vec<u32>, Vec<u32>)],
214 ) -> Result<TrainingReport> {
215 self.train_on_chat_sequences_with_callback(chat_pairs, |_, _| {})
216 }
217
218 pub fn save_checkpoint(&self, path: &str) -> Result<()> {
220 if let Some(parent) = Path::new(path).parent() {
221 fs::create_dir_all(parent).map_err(|e| {
222 Error::Io(format!(
223 "failed to create checkpoint directory {:?}: {}",
224 parent, e
225 ))
226 })?;
227 }
228 self.model.save_parameters(path)
229 }
230
231 pub fn model(&self) -> &DefaultMultiscreenModel {
233 &self.model
234 }
235
236 pub fn model_mut(&mut self) -> &mut DefaultMultiscreenModel {
238 &mut self.model
239 }
240
241 pub fn training_config(&self) -> &ModelTrainingConfig {
243 &self.training_config
244 }
245
246 fn model_device(&self) -> Device {
248 Device::default()
249 }
250}
251
252pub struct TrainerBuilder {
258 vocab_size: Option<usize>,
259 budget: ParameterBudget,
260 device: Option<Device>,
261 batch_size: usize,
262 seq_len: usize,
263 steps: usize,
264 learning_rate: f64,
265 weight_decay: f64,
266 grad_clip_norm: Option<f64>,
267 checkpoint_dir: Option<String>,
268 checkpoint_interval: usize,
269 run_dir: Option<String>,
270}
271
272impl TrainerBuilder {
273 fn new() -> Self {
275 Self {
276 vocab_size: None,
277 budget: ParameterBudget::Params10M,
278 device: None,
279 batch_size: 4,
280 seq_len: 128,
281 steps: 1000,
282 learning_rate: 2e-4,
283 weight_decay: 0.01,
284 grad_clip_norm: Some(1.0),
285 checkpoint_dir: None,
286 checkpoint_interval: 1000,
287 run_dir: None,
288 }
289 }
290
291 pub fn vocab_size(mut self, size: usize) -> Self {
295 self.vocab_size = Some(size);
296 self
297 }
298
299 pub fn budget(mut self, budget: ParameterBudget) -> Self {
301 self.budget = budget;
302 self
303 }
304
305 pub fn device(mut self, device: Device) -> Self {
307 self.device = Some(device);
308 self
309 }
310
311 pub fn batch_size(mut self, size: usize) -> Self {
313 self.batch_size = size;
314 self
315 }
316
317 pub fn seq_len(mut self, len: usize) -> Self {
319 self.seq_len = len;
320 self
321 }
322
323 pub fn steps(mut self, steps: usize) -> Self {
325 self.steps = steps;
326 self
327 }
328
329 pub fn learning_rate(mut self, lr: f64) -> Self {
331 self.learning_rate = lr;
332 self
333 }
334
335 pub fn weight_decay(mut self, wd: f64) -> Self {
337 self.weight_decay = wd;
338 self
339 }
340
341 pub fn grad_clip_norm(mut self, norm: Option<f64>) -> Self {
343 self.grad_clip_norm = norm;
344 self
345 }
346
347 pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
349 self.checkpoint_dir = Some(dir.into());
350 self
351 }
352
353 pub fn checkpoint_interval(mut self, steps: usize) -> Self {
355 self.checkpoint_interval = steps;
356 self
357 }
358
359 pub fn run_dir(mut self, dir: impl Into<String>) -> Self {
361 self.run_dir = Some(dir.into());
362 self
363 }
364
365 pub fn build(self) -> Result<Trainer> {
370 let vocab_size = self.vocab_size.ok_or_else(|| {
371 Error::Config("vocab_size is required; call .vocab_size(n) before .build()".into())
372 })?;
373
374 let device = match self.device {
375 Some(d) => d,
376 None => default_device()?,
377 };
378
379 let config =
380 MultiscreenModelConfig::for_parameter_budget(self.budget, vocab_size, self.seq_len);
381 let model = DefaultMultiscreenModel::new(config, &device)?;
382
383 let training_config = ModelTrainingConfig {
384 steps: self.steps,
385 batch_size: self.batch_size,
386 learning_rate: self.learning_rate,
387 weight_decay: self.weight_decay,
388 grad_clip_norm: self.grad_clip_norm,
389 pad_token_id: 0,
390 checkpoint_dir: None, checkpoint_interval: 0,
392 };
393
394 let run_dir = self.run_dir.or_else(|| Some("runs/latest".to_string()));
395
396 Ok(Trainer {
397 model,
398 training_config,
399 checkpoint_dir: self.checkpoint_dir,
400 checkpoint_interval: self.checkpoint_interval,
401 run_dir,
402 })
403 }
404}
405
406impl Default for TrainerBuilder {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_requires_vocab_size() {
        let err = Trainer::builder()
            .build()
            .expect_err("build should fail without vocab_size");
        let msg = err.to_string();
        assert!(
            msg.contains("vocab_size"),
            "error should mention vocab_size: {}",
            msg
        );
    }

    #[test]
    fn training_report_from_model_report() {
        let model_report = ModelTrainingReport {
            steps: 500,
            final_loss: 0.123,
            best_loss: 0.100,
            best_loss_step: 420,
            training_window_count: 100,
            parameter_count: 10_000_000,
        };
        let report =
            TrainingReport::from_model_report(&model_report, Some("runs/checkpoint.mpk".into()));
        assert_eq!(report.steps, 500);
        assert!((report.final_loss - 0.123).abs() < f32::EPSILON);
        assert!((report.best_loss - 0.100).abs() < f32::EPSILON);
        assert_eq!(report.best_loss_step, 420);
        assert_eq!(report.parameter_count, 10_000_000);
        assert_eq!(
            report.checkpoint_path.as_deref(),
            Some("runs/checkpoint.mpk")
        );
    }

    #[test]
    fn training_report_without_checkpoint() {
        let model_report = ModelTrainingReport {
            steps: 1,
            final_loss: 1.0,
            best_loss: 1.0,
            best_loss_step: 0,
            training_window_count: 1,
            parameter_count: 1,
        };
        let report = TrainingReport::from_model_report(&model_report, None);
        assert!(report.checkpoint_path.is_none());
        assert_eq!(report.steps, 1);
    }

    #[test]
    fn builder_defaults() {
        let builder = TrainerBuilder::new();
        assert!(builder.vocab_size.is_none());
        assert!(matches!(builder.budget, ParameterBudget::Params10M));
        assert!(builder.device.is_none());
        assert_eq!(builder.batch_size, 4);
        assert_eq!(builder.seq_len, 128);
        assert_eq!(builder.steps, 1000);
        assert!((builder.learning_rate - 2e-4).abs() < f64::EPSILON);
        assert!((builder.weight_decay - 0.01).abs() < f64::EPSILON);
        assert_eq!(builder.grad_clip_norm, Some(1.0));
        assert!(builder.checkpoint_dir.is_none());
        assert_eq!(builder.checkpoint_interval, 1000);
        assert!(builder.run_dir.is_none());
    }

    #[test]
    fn default_builder_matches_new() {
        let from_default = TrainerBuilder::default();
        let from_new = TrainerBuilder::new();
        assert_eq!(from_default.vocab_size, from_new.vocab_size);
        assert_eq!(from_default.batch_size, from_new.batch_size);
        assert_eq!(from_default.seq_len, from_new.seq_len);
        assert_eq!(from_default.steps, from_new.steps);
        assert_eq!(from_default.checkpoint_interval, from_new.checkpoint_interval);
        assert_eq!(from_default.checkpoint_dir, from_new.checkpoint_dir);
        assert_eq!(from_default.run_dir, from_new.run_dir);
    }

    #[test]
    fn builder_setters_override_defaults() {
        let builder = Trainer::builder()
            .vocab_size(32)
            .batch_size(2)
            .seq_len(16)
            .steps(5)
            .learning_rate(1e-3)
            .weight_decay(0.5)
            .grad_clip_norm(None)
            .checkpoint_dir("ckpt")
            .checkpoint_interval(10)
            .run_dir("runs/test");
        assert_eq!(builder.vocab_size, Some(32));
        assert_eq!(builder.batch_size, 2);
        assert_eq!(builder.seq_len, 16);
        assert_eq!(builder.steps, 5);
        assert!((builder.learning_rate - 1e-3).abs() < f64::EPSILON);
        assert!((builder.weight_decay - 0.5).abs() < f64::EPSILON);
        assert!(builder.grad_clip_norm.is_none());
        assert_eq!(builder.checkpoint_dir.as_deref(), Some("ckpt"));
        assert_eq!(builder.checkpoint_interval, 10);
        assert_eq!(builder.run_dir.as_deref(), Some("runs/test"));
    }
}