1use crate::error::{Error, Result};
38use crate::model::{
39 DefaultMultiscreenModel, ModelTrainingConfig, ModelTrainingReport, MultiscreenModelConfig,
40};
41use crate::runtime::{default_device, Device};
42use std::fs;
43use std::path::Path;
44
45pub use crate::model::MultiscreenParameterBudget as ParameterBudget;
47
/// Summary of a completed training run.
///
/// Derives `PartialEq` so whole reports can be compared directly (for example
/// in tests). Because `final_loss` is an `f32`, a report whose loss is NaN
/// never compares equal to anything, including itself.
#[derive(Clone, Debug, PartialEq)]
pub struct TrainingReport {
    /// Step count copied from the model-level training report.
    pub steps: usize,
    /// Final loss value copied from the model-level training report.
    pub final_loss: f32,
    /// Parameter count copied from the model-level training report.
    pub parameter_count: usize,
    /// Path of the checkpoint written after training, or `None` when no
    /// checkpoint was written.
    pub checkpoint_path: Option<String>,
}
64
65impl TrainingReport {
66 fn from_model_report(report: &ModelTrainingReport, checkpoint_path: Option<String>) -> Self {
67 Self {
68 steps: report.steps,
69 final_loss: report.final_loss,
70 parameter_count: report.parameter_count,
71 checkpoint_path,
72 }
73 }
74}
75
76pub struct Trainer {
85 model: DefaultMultiscreenModel,
86 training_config: ModelTrainingConfig,
87 checkpoint_dir: Option<String>,
88 #[allow(dead_code)]
89 checkpoint_interval: usize,
90 #[allow(dead_code)]
91 run_dir: Option<String>,
92}
93
94impl std::fmt::Debug for Trainer {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("Trainer")
97 .field("training_config", &self.training_config)
98 .field("checkpoint_dir", &self.checkpoint_dir)
99 .field("checkpoint_interval", &self.checkpoint_interval)
100 .field("run_dir", &self.run_dir)
101 .finish_non_exhaustive()
102 }
103}
104
105impl Trainer {
106 pub fn builder() -> TrainerBuilder {
108 TrainerBuilder::new()
109 }
110
111 pub fn train_on_token_sequences_with_callback(
119 &mut self,
120 sequences: &[Vec<u32>],
121 on_step: impl FnMut(usize, f32),
122 ) -> Result<TrainingReport> {
123 if sequences.is_empty() {
124 return Err(Error::Training("no training sequences provided".into()));
125 }
126
127 let device = self.model_device();
128 let report =
129 self.model
130 .train_token_sequences(sequences, &self.training_config, &device, on_step)?;
131
132 let checkpoint_path = match &self.checkpoint_dir {
133 Some(dir) => {
134 let dir_path = Path::new(dir);
135 fs::create_dir_all(dir_path).map_err(|e| {
136 Error::Io(format!(
137 "failed to create checkpoint directory {:?}: {}",
138 dir, e
139 ))
140 })?;
141 let path = dir_path.join("checkpoint.mpk");
142 self.model.save_parameters(&path)?;
143 Some(path.to_string_lossy().into_owned())
144 }
145 None => None,
146 };
147
148 Ok(TrainingReport::from_model_report(&report, checkpoint_path))
149 }
150
151 pub fn train_on_token_sequences(&mut self, sequences: &[Vec<u32>]) -> Result<TrainingReport> {
153 self.train_on_token_sequences_with_callback(sequences, |_, _| {})
154 }
155
156 pub fn train_on_chat_sequences_with_callback(
165 &mut self,
166 chat_pairs: &[(Vec<u32>, Vec<u32>)],
167 on_step: impl FnMut(usize, f32),
168 ) -> Result<TrainingReport> {
169 if chat_pairs.is_empty() {
170 return Err(Error::Training("no training chat pairs provided".into()));
171 }
172
173 let device = self.model_device();
174 let report =
175 self.model
176 .train_chat_sequences(chat_pairs, &self.training_config, &device, on_step)?;
177
178 let checkpoint_path = match &self.checkpoint_dir {
179 Some(dir) => {
180 let dir_path = Path::new(dir);
181 fs::create_dir_all(dir_path).map_err(|e| {
182 Error::Io(format!(
183 "failed to create checkpoint directory {:?}: {}",
184 dir, e
185 ))
186 })?;
187 let path = dir_path.join("checkpoint.mpk");
188 self.model.save_parameters(&path)?;
189 Some(path.to_string_lossy().into_owned())
190 }
191 None => None,
192 };
193
194 Ok(TrainingReport::from_model_report(&report, checkpoint_path))
195 }
196
197 pub fn train_on_chat_sequences(
199 &mut self,
200 chat_pairs: &[(Vec<u32>, Vec<u32>)],
201 ) -> Result<TrainingReport> {
202 self.train_on_chat_sequences_with_callback(chat_pairs, |_, _| {})
203 }
204
205 pub fn save_checkpoint(&self, path: &str) -> Result<()> {
207 if let Some(parent) = Path::new(path).parent() {
208 fs::create_dir_all(parent).map_err(|e| {
209 Error::Io(format!(
210 "failed to create checkpoint directory {:?}: {}",
211 parent, e
212 ))
213 })?;
214 }
215 self.model.save_parameters(path)
216 }
217
218 pub fn model(&self) -> &DefaultMultiscreenModel {
220 &self.model
221 }
222
223 pub fn model_mut(&mut self) -> &mut DefaultMultiscreenModel {
225 &mut self.model
226 }
227
228 pub fn training_config(&self) -> &ModelTrainingConfig {
230 &self.training_config
231 }
232
233 fn model_device(&self) -> Device {
235 Device::default()
236 }
237}
238
239pub struct TrainerBuilder {
245 vocab_size: Option<usize>,
246 budget: ParameterBudget,
247 device: Option<Device>,
248 batch_size: usize,
249 seq_len: usize,
250 steps: usize,
251 learning_rate: f64,
252 weight_decay: f64,
253 grad_clip_norm: Option<f64>,
254 checkpoint_dir: Option<String>,
255 checkpoint_interval: usize,
256 run_dir: Option<String>,
257}
258
259impl TrainerBuilder {
260 fn new() -> Self {
262 Self {
263 vocab_size: None,
264 budget: ParameterBudget::Params10M,
265 device: None,
266 batch_size: 4,
267 seq_len: 128,
268 steps: 1000,
269 learning_rate: 2e-4,
270 weight_decay: 0.01,
271 grad_clip_norm: Some(1.0),
272 checkpoint_dir: None,
273 checkpoint_interval: 1000,
274 run_dir: None,
275 }
276 }
277
278 pub fn vocab_size(mut self, size: usize) -> Self {
282 self.vocab_size = Some(size);
283 self
284 }
285
286 pub fn budget(mut self, budget: ParameterBudget) -> Self {
288 self.budget = budget;
289 self
290 }
291
292 pub fn device(mut self, device: Device) -> Self {
294 self.device = Some(device);
295 self
296 }
297
298 pub fn batch_size(mut self, size: usize) -> Self {
300 self.batch_size = size;
301 self
302 }
303
304 pub fn seq_len(mut self, len: usize) -> Self {
306 self.seq_len = len;
307 self
308 }
309
310 pub fn steps(mut self, steps: usize) -> Self {
312 self.steps = steps;
313 self
314 }
315
316 pub fn learning_rate(mut self, lr: f64) -> Self {
318 self.learning_rate = lr;
319 self
320 }
321
322 pub fn weight_decay(mut self, wd: f64) -> Self {
324 self.weight_decay = wd;
325 self
326 }
327
328 pub fn grad_clip_norm(mut self, norm: Option<f64>) -> Self {
330 self.grad_clip_norm = norm;
331 self
332 }
333
334 pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
336 self.checkpoint_dir = Some(dir.into());
337 self
338 }
339
340 pub fn checkpoint_interval(mut self, steps: usize) -> Self {
342 self.checkpoint_interval = steps;
343 self
344 }
345
346 pub fn run_dir(mut self, dir: impl Into<String>) -> Self {
348 self.run_dir = Some(dir.into());
349 self
350 }
351
352 pub fn build(self) -> Result<Trainer> {
357 let vocab_size = self.vocab_size.ok_or_else(|| {
358 Error::Config("vocab_size is required; call .vocab_size(n) before .build()".into())
359 })?;
360
361 let device = match self.device {
362 Some(d) => d,
363 None => default_device()?,
364 };
365
366 let config =
367 MultiscreenModelConfig::for_parameter_budget(self.budget, vocab_size, self.seq_len);
368 let model = DefaultMultiscreenModel::new(config, &device)?;
369
370 let training_config = ModelTrainingConfig {
371 steps: self.steps,
372 batch_size: self.batch_size,
373 learning_rate: self.learning_rate,
374 weight_decay: self.weight_decay,
375 grad_clip_norm: self.grad_clip_norm,
376 pad_token_id: 0,
377 };
378
379 let run_dir = self.run_dir.or_else(|| Some("runs/latest".to_string()));
380
381 Ok(Trainer {
382 model,
383 training_config,
384 checkpoint_dir: self.checkpoint_dir,
385 checkpoint_interval: self.checkpoint_interval,
386 run_dir,
387 })
388 }
389}
390
391impl Default for TrainerBuilder {
392 fn default() -> Self {
393 Self::new()
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// `build` must reject a builder that never received a vocabulary size,
    /// and the error text must name the missing setting.
    #[test]
    fn builder_requires_vocab_size() {
        let result = Trainer::builder().build();
        assert!(result.is_err(), "build should fail without vocab_size");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("vocab_size"),
            "error should mention vocab_size: {}",
            msg
        );
    }

    /// The public report copies steps, loss and parameter count from the
    /// model-level report and keeps the checkpoint path it is given.
    #[test]
    fn training_report_from_model_report() {
        let model_report = ModelTrainingReport {
            steps: 500,
            final_loss: 0.123,
            training_window_count: 100,
            parameter_count: 10_000_000,
        };
        let report =
            TrainingReport::from_model_report(&model_report, Some("runs/checkpoint.mpk".into()));
        assert_eq!(report.steps, 500);
        assert!((report.final_loss - 0.123).abs() < f32::EPSILON);
        assert_eq!(report.parameter_count, 10_000_000);
        assert_eq!(
            report.checkpoint_path.as_deref(),
            Some("runs/checkpoint.mpk")
        );

        // Without a checkpoint path the report must carry `None`.
        let without_checkpoint = TrainingReport::from_model_report(&model_report, None);
        assert!(without_checkpoint.checkpoint_path.is_none());
    }

    /// Every default set in `TrainerBuilder::new` is pinned here.
    #[test]
    fn builder_defaults() {
        let builder = TrainerBuilder::new();
        assert!(builder.vocab_size.is_none());
        assert!(matches!(builder.budget, ParameterBudget::Params10M));
        assert!(builder.device.is_none());
        assert_eq!(builder.batch_size, 4);
        assert_eq!(builder.seq_len, 128);
        assert_eq!(builder.steps, 1000);
        assert!((builder.learning_rate - 2e-4).abs() < f64::EPSILON);
        assert!((builder.weight_decay - 0.01).abs() < f64::EPSILON);
        assert_eq!(builder.grad_clip_norm, Some(1.0));
        assert!(builder.checkpoint_dir.is_none());
        assert_eq!(builder.checkpoint_interval, 1000);
        assert!(builder.run_dir.is_none());
    }

    /// Each setter must overwrite exactly the field it is named after.
    #[test]
    fn builder_setters_override_defaults() {
        let builder = Trainer::builder()
            .vocab_size(256)
            .batch_size(8)
            .seq_len(64)
            .steps(10)
            .learning_rate(1e-3)
            .weight_decay(0.0)
            .grad_clip_norm(None)
            .checkpoint_dir("ckpt")
            .checkpoint_interval(5)
            .run_dir("runs/custom");

        assert_eq!(builder.vocab_size, Some(256));
        assert_eq!(builder.batch_size, 8);
        assert_eq!(builder.seq_len, 64);
        assert_eq!(builder.steps, 10);
        assert!((builder.learning_rate - 1e-3).abs() < f64::EPSILON);
        assert!(builder.weight_decay.abs() < f64::EPSILON);
        assert!(builder.grad_clip_norm.is_none());
        assert_eq!(builder.checkpoint_dir.as_deref(), Some("ckpt"));
        assert_eq!(builder.checkpoint_interval, 5);
        assert_eq!(builder.run_dir.as_deref(), Some("runs/custom"));
    }
}