use std::fmt::{Display, Formatter};

use burn::config::Config;
use burn::module::{AutodiffModule, Module};
use burn::nn::loss::CrossEntropyLossConfig;
use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::activation::log_softmax;
use burn::tensor::backend::AutodiffBackend;
use serde::{Deserialize, Serialize};

use crate::baseline::BaselineTransformer;
use crate::checkpoint::LoadedTrainingArtifact;
use crate::config::DdlConfig;
use crate::data::{TokenDataset, TokenDatasetSummary};
use crate::lm::{
    CausalLmMetrics, aggregate_causal_lm_summaries, causal_language_model_summary_with_lengths,
};
use crate::spectral::{BetaHistogram, DeltaRegime, SpectralDiagnostics};
use crate::transformer::DdlTransformer;
use crate::variant::{ModelInstance, ModelVariant};

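/// Optimizer, learning-rate schedule, and loop hyper-parameters for a single
/// training run. `validate` is called before training starts; the `#[config]`
/// defaults below are used for any field that is not overridden.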
#[derive(Config, Debug, PartialEq)]
pub struct TrainingConfig {
    #[config(default = 32)]
    pub max_steps: usize,
    #[config(default = 4)]
    pub eval_interval: usize,
    #[config(default = 1e-3)]
    pub learning_rate: f64,
    #[config(default = 0.0)]
    pub min_learning_rate: f64,
    #[config(default = 0)]
    pub warmup_steps: usize,
    #[config(default = 0.1)]
    pub weight_decay: f64,
    #[config(default = 0.9)]
    pub beta1: f64,
    #[config(default = 0.95)]
    pub beta2: f64,
    #[config(default = 1e-5)]
    pub epsilon: f64,
    #[config(default = 1.0)]
    pub grad_clip: f64,
}

impl TrainingConfig {
    pub fn validate(&self) -> Result<(), TrainingError> {
        if self.max_steps == 0 {
            return Err(TrainingError::InvalidConfig(
                "max_steps must be greater than zero",
            ));
        }
        if self.eval_interval == 0 {
            return Err(TrainingError::InvalidConfig(
                "eval_interval must be greater than zero",
            ));
        }
        if !(self.learning_rate.is_finite() && self.learning_rate > 0.0) {
            return Err(TrainingError::InvalidConfig(
                "learning_rate must be finite and positive",
            ));
        }
        if !(self.min_learning_rate.is_finite() && self.min_learning_rate >= 0.0) {
            return Err(TrainingError::InvalidConfig(
                "min_learning_rate must be finite and non-negative",
            ));
        }
        if self.min_learning_rate > self.learning_rate {
            return Err(TrainingError::InvalidConfig(
                "min_learning_rate cannot exceed learning_rate",
            ));
        }
        if self.warmup_steps > self.max_steps {
            return Err(TrainingError::InvalidConfig(
                "warmup_steps cannot exceed max_steps",
            ));
        }
        if !(self.weight_decay.is_finite() && self.weight_decay >= 0.0) {
            return Err(TrainingError::InvalidConfig(
                "weight_decay must be finite and non-negative",
            ));
        }
        if !(self.beta1.is_finite() && (0.0..1.0).contains(&self.beta1)) {
            return Err(TrainingError::InvalidConfig("beta1 must be in [0, 1)"));
        }
        if !(self.beta2.is_finite() && (0.0..1.0).contains(&self.beta2)) {
            return Err(TrainingError::InvalidConfig("beta2 must be in [0, 1)"));
        }
        if !(self.epsilon.is_finite() && self.epsilon > 0.0) {
            return Err(TrainingError::InvalidConfig(
                "epsilon must be finite and positive",
            ));
        }
        if !(self.grad_clip.is_finite() && self.grad_clip > 0.0) {
            return Err(TrainingError::InvalidConfig(
                "grad_clip must be finite and positive",
            ));
        }

        Ok(())
    }

    fn optimizer<M, B>(&self) -> impl Optimizer<M, B>
    where
        B: AutodiffBackend,
        M: AutodiffModule<B>,
    {
        AdamWConfig::new()
            .with_beta_1(self.beta1 as f32)
            .with_beta_2(self.beta2 as f32)
            .with_epsilon(self.epsilon as f32)
            .with_weight_decay(self.weight_decay as f32)
            .with_grad_clipping(Some(burn::grad_clipping::GradientClippingConfig::Norm(
                self.grad_clip as f32,
            )))
            .init::<B, M>()
    }
}

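/// Per-step record appended to the training history: the step index, the
/// learning rate that was applied, the train-batch metrics, and optional
/// validation metrics plus spectral snapshots on evaluation steps.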
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrainingStepMetrics {
    pub step: usize,
    pub learning_rate: f64,
    pub train: CausalLmMetrics,
    pub validation: Option<CausalLmMetrics>,
    pub train_spectral: Option<TrainingSpectralSnapshot>,
    pub validation_spectral: Option<TrainingSpectralSnapshot>,
}

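/// Serializable copy of the per-layer spectral diagnostics captured during a
/// forward pass, or averaged over a dataset by `SpectralAccumulator`.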
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrainingSpectralSnapshot {
    pub beta_per_layer: Vec<f32>,
    pub k_eigenvalue_per_layer: Vec<f32>,
    pub spatial_determinant_per_layer: Vec<f32>,
    pub lifted_determinant_per_layer: Vec<f32>,
    pub regime_per_layer: Vec<DeltaRegime>,
    pub beta_histogram: BetaHistogram,
    pub k_coherence_per_layer: Vec<f32>,
    pub correction_norm_per_layer: Vec<f32>,
}

impl From<&SpectralDiagnostics> for TrainingSpectralSnapshot {
    fn from(spectral: &SpectralDiagnostics) -> Self {
        Self {
            beta_per_layer: spectral.beta_per_layer.clone(),
            k_eigenvalue_per_layer: spectral.k_eigenvalue_per_layer.clone(),
            spatial_determinant_per_layer: spectral.spatial_determinant_per_layer.clone(),
            lifted_determinant_per_layer: spectral.lifted_determinant_per_layer.clone(),
            regime_per_layer: spectral.regime_per_layer.clone(),
            beta_histogram: spectral.beta_histogram,
            k_coherence_per_layer: spectral.k_coherence_per_layer.clone(),
            correction_norm_per_layer: spectral.correction_norm_per_layer.clone(),
        }
    }
}

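/// Full summary of one training run: resolved configuration, dataset
/// summaries, initial/final/best metrics with their spectral snapshots, and
/// the per-step history.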
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrainingReport {
    pub variant: ModelVariant,
    pub config: DdlConfig,
    pub training: TrainingConfig,
    pub num_params: usize,
    pub train_dataset: TokenDatasetSummary,
    pub validation_dataset: Option<TokenDatasetSummary>,
    pub steps_completed: usize,
    pub initial_train: CausalLmMetrics,
    pub initial_validation: Option<CausalLmMetrics>,
    pub final_train: CausalLmMetrics,
    pub final_validation: Option<CausalLmMetrics>,
    pub best_validation: Option<CausalLmMetrics>,
    pub initial_train_spectral: Option<TrainingSpectralSnapshot>,
    pub initial_validation_spectral: Option<TrainingSpectralSnapshot>,
    pub final_train_spectral: Option<TrainingSpectralSnapshot>,
    pub final_validation_spectral: Option<TrainingSpectralSnapshot>,
    pub best_validation_spectral: Option<TrainingSpectralSnapshot>,
    pub best_validation_step: Option<usize>,
    pub history: Vec<TrainingStepMetrics>,
}

#[derive(Debug)]
pub struct TrainingOutcome<B: Backend> {
    pub report: TrainingReport,
    pub model: ModelInstance<B>,
    pub best_validation_model: Option<ModelInstance<B>>,
    pub optimizer_state: Vec<u8>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrainingComparisonReport {
    pub train_dataset: TokenDatasetSummary,
    pub validation_dataset: Option<TokenDatasetSummary>,
    pub training: TrainingConfig,
    pub reports: Vec<TrainingReport>,
    pub final_train_loss_ranking: Vec<ModelVariant>,
    pub best_final_train_variant: ModelVariant,
    pub final_validation_loss_ranking: Option<Vec<ModelVariant>>,
    pub best_final_validation_variant: Option<ModelVariant>,
}

impl TrainingComparisonReport {
    pub fn report(&self, variant: ModelVariant) -> Option<&TrainingReport> {
        self.reports.iter().find(|report| report.variant == variant)
    }

    pub fn best_final_train(&self) -> Option<&TrainingReport> {
        self.report(self.best_final_train_variant)
    }

    pub fn best_final_validation(&self) -> Option<&TrainingReport> {
        self.best_final_validation_variant
            .and_then(|variant| self.report(variant))
    }
}

#[derive(Debug)]
pub struct TrainingSweepOutcome<B: Backend> {
    pub report: TrainingComparisonReport,
    pub outcomes: Vec<TrainingOutcome<B>>,
}

impl<B: Backend> TrainingSweepOutcome<B> {
    pub fn outcome(&self, variant: ModelVariant) -> Option<&TrainingOutcome<B>> {
        self.outcomes
            .iter()
            .find(|outcome| outcome.report.variant == variant)
    }
}

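/// Learning-rate schedule with linear warmup followed by cosine decay.
/// During warmup the rate ramps linearly up to `base_learning_rate`; after
/// warmup it follows
/// `min_lr + (base_lr - min_lr) * 0.5 * (1 + cos(pi * progress))`,
/// reaching `min_learning_rate` on the final step.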
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CosineWarmupSchedule {
    base_learning_rate: f64,
    min_learning_rate: f64,
    warmup_steps: usize,
    total_steps: usize,
}

impl CosineWarmupSchedule {
    pub fn new(
        base_learning_rate: f64,
        min_learning_rate: f64,
        warmup_steps: usize,
        total_steps: usize,
    ) -> Result<Self, TrainingError> {
        if total_steps == 0 {
            return Err(TrainingError::InvalidConfig(
                "total_steps must be greater than zero",
            ));
        }
        if warmup_steps > total_steps {
            return Err(TrainingError::InvalidConfig(
                "warmup_steps cannot exceed total_steps",
            ));
        }
        if min_learning_rate > base_learning_rate {
            return Err(TrainingError::InvalidConfig(
                "min_learning_rate cannot exceed base_learning_rate",
            ));
        }

        Ok(Self {
            base_learning_rate,
            min_learning_rate,
            warmup_steps,
            total_steps,
        })
    }

    pub fn learning_rate(&self, step: usize) -> f64 {
        let step = step.min(self.total_steps.saturating_sub(1));

        if self.warmup_steps > 0 && step < self.warmup_steps {
            let progress = (step + 1) as f64 / self.warmup_steps as f64;
            return self.base_learning_rate * progress;
        }

        let cosine_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if cosine_steps <= 1 {
            return self.base_learning_rate;
        }

        let cosine_step = step.saturating_sub(self.warmup_steps);
        let progress = cosine_step as f64 / (cosine_steps - 1) as f64;
        let cosine = (std::f64::consts::PI * progress).cos();
        self.min_learning_rate
            + (self.base_learning_rate - self.min_learning_rate) * 0.5 * (1.0 + cosine)
    }
}

#[derive(Debug, Clone)]
struct DatasetTelemetry {
    metrics: CausalLmMetrics,
    spectral: Option<TrainingSpectralSnapshot>,
}

#[derive(Debug)]
struct ResumeState<M> {
    report: TrainingReport,
    best_validation_model: Option<M>,
    optimizer_state: Vec<u8>,
}

#[derive(Debug)]
struct ForwardPassOutput<B: Backend> {
    logits: Tensor<B, 3>,
    spectral: Option<TrainingSpectralSnapshot>,
}

#[derive(Debug, Default)]
struct SpectralAccumulator {
    beta_per_layer: Vec<f64>,
    k_eigenvalue_per_layer: Vec<f64>,
    spatial_determinant_per_layer: Vec<f64>,
    lifted_determinant_per_layer: Vec<f64>,
    k_coherence_per_layer: Vec<f64>,
    correction_norm_per_layer: Vec<f64>,
    total_weight: f64,
}

impl SpectralAccumulator {
    fn observe(&mut self, spectral: &TrainingSpectralSnapshot, weight: f64) {
        if self.beta_per_layer.is_empty() {
            self.beta_per_layer = vec![0.0; spectral.beta_per_layer.len()];
            self.k_eigenvalue_per_layer = vec![0.0; spectral.k_eigenvalue_per_layer.len()];
            self.spatial_determinant_per_layer =
                vec![0.0; spectral.spatial_determinant_per_layer.len()];
            self.lifted_determinant_per_layer =
                vec![0.0; spectral.lifted_determinant_per_layer.len()];
            self.k_coherence_per_layer = vec![0.0; spectral.k_coherence_per_layer.len()];
            self.correction_norm_per_layer = vec![0.0; spectral.correction_norm_per_layer.len()];
        }

        accumulate_weighted(
            &mut self.beta_per_layer,
            &spectral.beta_per_layer,
            weight,
            "beta_per_layer",
        );
        accumulate_weighted(
            &mut self.k_eigenvalue_per_layer,
            &spectral.k_eigenvalue_per_layer,
            weight,
            "k_eigenvalue_per_layer",
        );
        accumulate_weighted(
            &mut self.spatial_determinant_per_layer,
            &spectral.spatial_determinant_per_layer,
            weight,
            "spatial_determinant_per_layer",
        );
        accumulate_weighted(
            &mut self.lifted_determinant_per_layer,
            &spectral.lifted_determinant_per_layer,
            weight,
            "lifted_determinant_per_layer",
        );
        accumulate_weighted(
            &mut self.k_coherence_per_layer,
            &spectral.k_coherence_per_layer,
            weight,
            "k_coherence_per_layer",
        );
        accumulate_weighted(
            &mut self.correction_norm_per_layer,
            &spectral.correction_norm_per_layer,
            weight,
            "correction_norm_per_layer",
        );

        self.total_weight += weight;
    }

    fn finish(self) -> Option<TrainingSpectralSnapshot> {
        if self.total_weight == 0.0 {
            return None;
        }

        let beta_per_layer = average_weighted(self.beta_per_layer, self.total_weight);
        let k_eigenvalue_per_layer =
            average_weighted(self.k_eigenvalue_per_layer, self.total_weight);
        let spatial_determinant_per_layer =
            average_weighted(self.spatial_determinant_per_layer, self.total_weight);
        let lifted_determinant_per_layer =
            average_weighted(self.lifted_determinant_per_layer, self.total_weight);
        let k_coherence_per_layer =
            average_weighted(self.k_coherence_per_layer, self.total_weight);
        let correction_norm_per_layer =
            average_weighted(self.correction_norm_per_layer, self.total_weight);
        let regime_per_layer = beta_per_layer
            .iter()
            .map(|beta| DeltaRegime::from_beta(*beta))
            .collect::<Vec<_>>();
        let beta_histogram = BetaHistogram::from_betas(&beta_per_layer);

        Some(TrainingSpectralSnapshot {
            beta_per_layer,
            k_eigenvalue_per_layer,
            spatial_determinant_per_layer,
            lifted_determinant_per_layer,
            regime_per_layer,
            beta_histogram,
            k_coherence_per_layer,
            correction_norm_per_layer,
        })
    }
}

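/// Errors produced while validating configuration, checking datasets,
/// (de)serializing optimizer state, or resuming from a saved artifact.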
#[derive(Debug)]
pub enum TrainingError {
    EmptyTrainingDataset,
    EmptyValidationDataset,
    EmptyVariantSet,
    InvalidConfig(&'static str),
    OptimizerState(String),
    ResumeStateMismatch(String),
}

impl Display for TrainingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyTrainingDataset => {
                write!(f, "training dataset does not contain any batches")
            }
            Self::EmptyValidationDataset => {
                write!(f, "validation dataset does not contain any batches")
            }
            Self::EmptyVariantSet => write!(f, "at least one model variant is required"),
            Self::InvalidConfig(message) => write!(f, "invalid training configuration: {message}"),
            Self::OptimizerState(message) => {
                write!(f, "optimizer state serialization failed: {message}")
            }
            Self::ResumeStateMismatch(message) => {
                write!(f, "invalid resume state: {message}")
            }
        }
    }
}

impl std::error::Error for TrainingError {}

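/// Trains a single model variant from scratch on `train_dataset`, evaluating
/// on `validation_dataset` when provided, and returns the trained model
/// together with its report and serialized optimizer state.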
pub fn train_variant<B>(
    base_config: &DdlConfig,
    variant: ModelVariant,
    device: &B::Device,
    train_dataset: &TokenDataset,
    validation_dataset: Option<&TokenDataset>,
    training_config: &TrainingConfig,
) -> Result<TrainingOutcome<B::InnerBackend>, TrainingError>
where
    B: AutodiffBackend,
{
    training_config.validate()?;
    if train_dataset.batches().is_empty() {
        return Err(TrainingError::EmptyTrainingDataset);
    }
    if validation_dataset.is_some_and(|dataset| dataset.batches().is_empty()) {
        return Err(TrainingError::EmptyValidationDataset);
    }

    let resolved_config = variant.resolve_config(base_config);
    let datasets = DatasetPair {
        train: train_dataset,
        validation: validation_dataset,
    };
    match variant {
        ModelVariant::Baseline => train_model(
            variant,
            resolved_config.clone(),
            BaselineTransformer::<B>::new(&resolved_config, device),
            |model| ModelInstance::Baseline(Box::new(model)),
            device,
            datasets,
            training_config,
            None,
        ),
        _ => train_model(
            variant,
            resolved_config.clone(),
            resolved_config.init::<B>(device),
            |model| ModelInstance::Ddl(Box::new(model)),
            device,
            datasets,
            training_config,
            None,
        ),
    }
}

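/// Continues training from a previously saved artifact. The datasets and the
/// checkpointed model kind must match the saved report, otherwise a
/// `ResumeStateMismatch` error is returned.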
pub fn resume_training<B>(
    artifact: LoadedTrainingArtifact<B>,
    device: &B::Device,
    train_dataset: &TokenDataset,
    validation_dataset: Option<&TokenDataset>,
    training_config: &TrainingConfig,
) -> Result<TrainingOutcome<B::InnerBackend>, TrainingError>
where
    B: AutodiffBackend,
{
    training_config.validate()?;
    if train_dataset.batches().is_empty() {
        return Err(TrainingError::EmptyTrainingDataset);
    }
    if validation_dataset.is_some_and(|dataset| dataset.batches().is_empty()) {
        return Err(TrainingError::EmptyValidationDataset);
    }

    let train_summary = train_dataset.summary();
    if artifact.report.train_dataset != train_summary {
        return Err(TrainingError::ResumeStateMismatch(format!(
            "training dataset summary {:?} does not match saved artifact {:?}",
            train_summary, artifact.report.train_dataset
        )));
    }

    let validation_summary = validation_dataset.map(TokenDataset::summary);
    if artifact.report.validation_dataset != validation_summary {
        return Err(TrainingError::ResumeStateMismatch(format!(
            "validation dataset summary {:?} does not match saved artifact {:?}",
            validation_summary, artifact.report.validation_dataset
        )));
    }

    if artifact.report.steps_completed != artifact.report.history.len() {
        return Err(TrainingError::ResumeStateMismatch(format!(
            "saved report recorded {} completed steps but {} history entries",
            artifact.report.steps_completed,
            artifact.report.history.len()
        )));
    }

    let datasets = DatasetPair {
        train: train_dataset,
        validation: validation_dataset,
    };
    let optimizer_state = if artifact.report.steps_completed == 0 {
        Vec::new()
    } else {
        artifact.optimizer_state.ok_or_else(|| {
            TrainingError::ResumeStateMismatch(
                "saved training artifact is missing optimizer state".to_string(),
            )
        })?
    };
    match artifact.model {
        ModelInstance::Baseline(model) => {
            if artifact.report.variant.uses_ddl() {
                return Err(TrainingError::ResumeStateMismatch(format!(
                    "saved report expects variant {} but checkpoint contains a baseline model",
                    artifact.report.variant.slug()
                )));
            }
            let best_validation_model = match artifact.best_validation_model {
                Some(ModelInstance::Baseline(model)) => Some(model.valid()),
                Some(ModelInstance::Ddl(_)) => {
                    return Err(TrainingError::ResumeStateMismatch(
                        "best-validation checkpoint kind does not match the baseline artifact"
                            .to_string(),
                    ));
                }
                None => None,
            };
            train_model(
                artifact.report.variant,
                artifact.report.config.clone(),
                *model,
                |model| ModelInstance::Baseline(Box::new(model)),
                device,
                datasets,
                training_config,
                Some(ResumeState {
                    report: artifact.report,
                    best_validation_model,
                    optimizer_state,
                }),
            )
        }
        ModelInstance::Ddl(model) => {
            if !artifact.report.variant.uses_ddl() {
                return Err(TrainingError::ResumeStateMismatch(format!(
                    "saved report expects variant {} but checkpoint contains a DDL model",
                    artifact.report.variant.slug()
                )));
            }
            let best_validation_model = match artifact.best_validation_model {
                Some(ModelInstance::Ddl(model)) => Some(model.valid()),
                Some(ModelInstance::Baseline(_)) => {
                    return Err(TrainingError::ResumeStateMismatch(
                        "best-validation checkpoint kind does not match the DDL artifact"
                            .to_string(),
                    ));
                }
                None => None,
            };
            train_model(
                artifact.report.variant,
                artifact.report.config.clone(),
                *model,
                |model| ModelInstance::Ddl(Box::new(model)),
                device,
                datasets,
                training_config,
                Some(ResumeState {
                    report: artifact.report,
                    best_validation_model,
                    optimizer_state,
                }),
            )
        }
    }
}

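/// Trains every variant in `variants` with the same datasets and training
/// configuration, then ranks the results by final train loss and, when
/// validation data is available, by final validation loss.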
pub fn train_variants<B>(
    base_config: &DdlConfig,
    variants: &[ModelVariant],
    device: &B::Device,
    train_dataset: &TokenDataset,
    validation_dataset: Option<&TokenDataset>,
    training_config: &TrainingConfig,
) -> Result<TrainingSweepOutcome<B::InnerBackend>, TrainingError>
where
    B: AutodiffBackend,
{
    if variants.is_empty() {
        return Err(TrainingError::EmptyVariantSet);
    }

    let mut outcomes = Vec::with_capacity(variants.len());
    for variant in variants {
        outcomes.push(train_variant::<B>(
            base_config,
            *variant,
            device,
            train_dataset,
            validation_dataset,
            training_config,
        )?);
    }

    let reports = outcomes
        .iter()
        .map(|outcome| outcome.report.clone())
        .collect::<Vec<_>>();
    let final_train_loss_ranking = rank_variants_by_loss(
        reports
            .iter()
            .map(|report| (report.variant, report.final_train.loss)),
    );
    let final_validation_loss_ranking = reports
        .iter()
        .map(|report| {
            report
                .final_validation
                .map(|metrics| (report.variant, metrics.loss))
        })
        .collect::<Option<Vec<_>>>()
        .map(rank_variants_by_loss);

    Ok(TrainingSweepOutcome {
        report: TrainingComparisonReport {
            train_dataset: train_dataset.summary(),
            validation_dataset: validation_dataset.map(TokenDataset::summary),
            training: training_config.clone(),
            reports,
            best_final_train_variant: final_train_loss_ranking[0],
            final_train_loss_ranking,
            best_final_validation_variant: final_validation_loss_ranking
                .as_ref()
                .and_then(|ranking| ranking.first().copied()),
            final_validation_loss_ranking,
        },
        outcomes,
    })
}

#[derive(Clone, Copy)]
struct DatasetPair<'a> {
    train: &'a TokenDataset,
    validation: Option<&'a TokenDataset>,
}

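/// Shared training loop behind `train_variant` and `resume_training`: runs
/// `max_steps` optimizer steps, evaluates validation data and spectral
/// telemetry every `eval_interval` steps, and tracks the best-validation
/// checkpoint seen so far.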
#[allow(clippy::too_many_arguments)]
fn train_model<B, M, F>(
    variant: ModelVariant,
    resolved_config: DdlConfig,
    mut model: M,
    into_model_instance: F,
    device: &B::Device,
    datasets: DatasetPair<'_>,
    training_config: &TrainingConfig,
    resume: Option<ResumeState<M::InnerModule>>,
) -> Result<TrainingOutcome<B::InnerBackend>, TrainingError>
where
    B: AutodiffBackend,
    M: CausalLmModel<B> + AutodiffModule<B> + Module<B>,
    M::InnerModule: CausalLmModel<B::InnerBackend> + Clone,
    F: Fn(M::InnerModule) -> ModelInstance<B::InnerBackend> + Copy,
{
    let mut optimizer = training_config.optimizer::<M, B>();
    let (
        initial_train,
        initial_validation,
        initial_train_spectral,
        initial_validation_spectral,
        mut best_validation,
        mut best_validation_spectral,
        mut best_validation_step,
        mut best_validation_model,
        mut history,
        step_offset,
        resume_optimizer_state,
    ) = match resume {
        Some(resume) => (
            resume.report.initial_train,
            resume.report.initial_validation,
            resume.report.initial_train_spectral,
            resume.report.initial_validation_spectral,
            resume.report.best_validation,
            resume.report.best_validation_spectral,
            resume.report.best_validation_step,
            resume.best_validation_model,
            resume.report.history,
            resume.report.steps_completed,
            Some(resume.optimizer_state),
        ),
        None => {
            let initial_model = model.valid();
            let initial_train_eval =
                evaluate_dataset_with_telemetry(&initial_model, datasets.train, device);
            let initial_validation_eval = datasets
                .validation
                .map(|dataset| evaluate_dataset_with_telemetry(&initial_model, dataset, device));
            let initial_train = initial_train_eval.metrics;
            let initial_validation = initial_validation_eval
                .as_ref()
                .map(|evaluation| evaluation.metrics);

            (
                initial_train,
                initial_validation,
                initial_train_eval.spectral,
                initial_validation_eval
                    .clone()
                    .and_then(|evaluation| evaluation.spectral),
                initial_validation,
                initial_validation_eval
                    .as_ref()
                    .and_then(|evaluation| evaluation.spectral.clone()),
                initial_validation.map(|_| 0),
                initial_validation.map(|_| initial_model.clone()),
                Vec::with_capacity(training_config.max_steps),
                0,
                None,
            )
        }
    };
    if let Some(optimizer_state) = resume_optimizer_state {
        optimizer = load_optimizer_state::<_, M, B>(optimizer, &optimizer_state, device)?;
    }
    let schedule = schedule_for_phase(
        training_config,
        history.last().map(|step| step.learning_rate),
    )?;

    for step in 0..training_config.max_steps {
        let record_spectral = (step + 1) % training_config.eval_interval == 0
            || step + 1 == training_config.max_steps;
        let batch = &datasets.train.batches()[step % datasets.train.num_batches()];
        let input_ids = batch.to_tensor(device);
        let forward = if record_spectral {
            model.forward_with_spectral_snapshot(input_ids.clone(), None)
        } else {
            ForwardPassOutput {
                logits: model.forward_logits(input_ids.clone(), None),
                spectral: None,
            }
        };
        let logits = forward.logits;
        let loss =
            causal_language_model_training_loss(logits, input_ids, batch.sequence_lengths());
        let train = CausalLmMetrics {
            loss: scalar_from_autodiff_tensor(loss.clone()),
            perplexity: scalar_from_autodiff_tensor(loss.clone()).exp(),
        };
        let grads = GradientsParams::from_grads(loss.backward(), &model);
        let lr = schedule.learning_rate(step);
        model = optimizer.step(lr, model, grads);

        let mut validation_model = None;
        let validation_eval = if record_spectral {
            let valid_model = model.valid();
            let evaluation = datasets
                .validation
                .map(|dataset| evaluate_dataset_with_telemetry(&valid_model, dataset, device));
            validation_model = Some(valid_model);
            evaluation
        } else {
            None
        };
        let validation = validation_eval
            .as_ref()
            .map(|evaluation| evaluation.metrics);

        if let Some((metrics, spectral)) = validation_eval
            .as_ref()
            .map(|evaluation| (evaluation.metrics, evaluation.spectral.clone()))
        {
            let should_update = match best_validation {
                Some(best) => metrics.loss < best.loss,
                None => true,
            };
            if should_update {
                best_validation = Some(metrics);
                best_validation_spectral = spectral;
                best_validation_step = Some(step_offset + step + 1);
                best_validation_model = validation_model.clone();
            }
        }

        history.push(TrainingStepMetrics {
            step: step_offset + step + 1,
            learning_rate: lr,
            train,
            validation,
            train_spectral: forward.spectral,
            validation_spectral: validation_eval.and_then(|evaluation| evaluation.spectral),
        });
    }

    let model = model.valid();
    let final_train_eval = evaluate_dataset_with_telemetry(&model, datasets.train, device);
    let final_validation_eval = datasets
        .validation
        .map(|dataset| evaluate_dataset_with_telemetry(&model, dataset, device));
    let final_train = final_train_eval.metrics;
    let final_validation = final_validation_eval
        .as_ref()
        .map(|evaluation| evaluation.metrics);
    if let Some((metrics, spectral)) = final_validation_eval
        .as_ref()
        .map(|evaluation| (evaluation.metrics, evaluation.spectral.clone()))
    {
        let should_update = match best_validation {
            Some(best) => metrics.loss < best.loss,
            None => true,
        };
        if should_update {
            best_validation = Some(metrics);
            best_validation_spectral = spectral;
            best_validation_step = Some(step_offset + training_config.max_steps);
            best_validation_model = Some(model.clone());
        }
    }
    let num_params = model.num_params();

    Ok(TrainingOutcome {
        report: TrainingReport {
            variant,
            config: resolved_config,
            training: training_config.clone(),
            num_params,
            train_dataset: datasets.train.summary(),
            validation_dataset: datasets.validation.map(TokenDataset::summary),
            steps_completed: history.len(),
            initial_train,
            initial_validation,
            final_train,
            final_validation,
            best_validation,
            initial_train_spectral,
            initial_validation_spectral,
            final_train_spectral: final_train_eval.spectral,
            final_validation_spectral: final_validation_eval
                .and_then(|evaluation| evaluation.spectral),
            best_validation_spectral,
            best_validation_step,
            history,
        },
        model: into_model_instance(model),
        best_validation_model: best_validation_model.map(into_model_instance),
        optimizer_state: save_optimizer_state::<_, M, B>(&optimizer)?,
    })
}

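/// Minimal forward interface shared by the baseline and DDL transformers so
/// the training loop can stay generic; the spectral hook falls back to plain
/// logits for models without diagnostics.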
trait CausalLmModel<B: Backend> {
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3>;

    fn forward_with_spectral_snapshot(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> ForwardPassOutput<B> {
        ForwardPassOutput {
            logits: self.forward_logits(input_ids, mask),
            spectral: None,
        }
    }
}

impl<B: Backend> CausalLmModel<B> for BaselineTransformer<B> {
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        self.forward_logits(input_ids, mask)
    }
}

impl<B: Backend> CausalLmModel<B> for DdlTransformer<B> {
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        self.forward_logits(input_ids, mask)
    }

    #[cfg(feature = "spectral")]
    fn forward_with_spectral_snapshot(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> ForwardPassOutput<B> {
        let (logits, _, spectral) = self.forward_with_spectral_diagnostics(input_ids, mask);
        ForwardPassOutput {
            logits,
            spectral: Some(TrainingSpectralSnapshot::from(&spectral)),
        }
    }
}

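/// Runs the model over every batch in the dataset, aggregating language-model
/// metrics and, when snapshots are available, prediction-count-weighted
/// spectral averages.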
fn evaluate_dataset_with_telemetry<B, M>(
    model: &M,
    dataset: &TokenDataset,
    device: &B::Device,
) -> DatasetTelemetry
where
    B: Backend,
    M: CausalLmModel<B>,
{
    let mut summaries = Vec::with_capacity(dataset.num_batches());
    let mut spectral = SpectralAccumulator::default();
    for batch in dataset.batches() {
        let input_ids = batch.to_tensor(device);
        let output = model.forward_with_spectral_snapshot(input_ids.clone(), None);
        let summary = causal_language_model_summary_with_lengths(
            &output.logits,
            &input_ids,
            batch.sequence_lengths(),
        );
        if let Some(snapshot) = output.spectral.as_ref() {
            spectral.observe(snapshot, summary.prediction_count.max(1) as f64);
        }
        summaries.push(summary);
    }

    DatasetTelemetry {
        metrics: aggregate_causal_lm_summaries(&summaries).into(),
        spectral: spectral.finish(),
    }
}

fn accumulate_weighted(target: &mut [f64], values: &[f32], weight: f64, label: &str) {
    assert_eq!(
        target.len(),
        values.len(),
        "{label} length must remain stable across spectral snapshots"
    );

    for (slot, value) in target.iter_mut().zip(values.iter()) {
        *slot += f64::from(*value) * weight;
    }
}

fn average_weighted(values: Vec<f64>, total_weight: f64) -> Vec<f32> {
    values
        .into_iter()
        .map(|value| (value / total_weight) as f32)
        .collect()
}

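/// Next-token cross-entropy: logits at position `t` are scored against the
/// token at `t + 1`. When some sequences are shorter than the padded length,
/// padded positions are masked out and the loss is averaged over valid
/// predictions only.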
fn causal_language_model_training_loss<B: Backend>(
    logits: Tensor<B, 3>,
    input_ids: Tensor<B, 2, Int>,
    sequence_lengths: &[usize],
) -> Tensor<B, 1> {
    let [batch_size, seq_len, vocab_size] = logits.dims();
    if seq_len < 2 {
        return Tensor::<B, 1>::zeros([1], &logits.device());
    }
    assert_eq!(
        sequence_lengths.len(),
        batch_size,
        "sequence_lengths must match the batch size"
    );

    let positions_per_row = seq_len - 1;
    let num_predictions = batch_size * positions_per_row;
    let valid_predictions = sequence_lengths
        .iter()
        .map(|length| length.saturating_sub(1).min(positions_per_row))
        .sum::<usize>();
    if valid_predictions == 0 {
        return Tensor::<B, 1>::zeros([1], &logits.device());
    }

    let shifted_logits = logits
        .slice([0..batch_size, 0..positions_per_row, 0..vocab_size])
        .reshape([num_predictions, vocab_size]);
    let shifted_targets = input_ids
        .slice([0..batch_size, 1..seq_len])
        .reshape([num_predictions]);
    let losses = CrossEntropyLossConfig::new()
        .init(&shifted_logits.device())
        .forward(shifted_logits.clone(), shifted_targets.clone());
    if valid_predictions == num_predictions {
        return losses;
    }

    let gathered_losses = log_softmax(shifted_logits, 1)
        .gather(1, shifted_targets.reshape([num_predictions, 1]))
        .reshape([num_predictions])
        .neg();
    let mask = prediction_mask(
        sequence_lengths,
        positions_per_row,
        &gathered_losses.device(),
    );
    gathered_losses.mul(mask).sum() / valid_predictions as f32
}

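/// Flat 0/1 mask over the `batch_size * positions_per_row` shifted positions,
/// marking which next-token predictions fall inside each sequence's true
/// length.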
fn prediction_mask<B: Backend>(
    sequence_lengths: &[usize],
    positions_per_row: usize,
    device: &B::Device,
) -> Tensor<B, 1> {
    let values = sequence_lengths
        .iter()
        .flat_map(|length| {
            let valid_predictions = length.saturating_sub(1).min(positions_per_row);
            (0..positions_per_row).map(move |position| {
                if position < valid_predictions {
                    1.0
                } else {
                    0.0
                }
            })
        })
        .collect::<Vec<_>>();
    Tensor::<B, 1>::from_floats(values.as_slice(), device)
}

fn scalar_from_autodiff_tensor<B: AutodiffBackend>(tensor: Tensor<B, 1>) -> f32 {
    let values = tensor
        .inner()
        .into_data()
        .to_vec::<f32>()
        .expect("training loss tensor should convert to f32");
    assert_eq!(
        values.len(),
        1,
        "training loss tensor should contain exactly one scalar"
    );
    values[0]
}

fn schedule_for_phase(
    training_config: &TrainingConfig,
    previous_learning_rate: Option<f64>,
) -> Result<CosineWarmupSchedule, TrainingError> {
    let (base_learning_rate, warmup_steps) = match previous_learning_rate {
        Some(previous_learning_rate) => (
            previous_learning_rate.clamp(
                training_config.min_learning_rate,
                training_config.learning_rate,
            ),
            0,
        ),
        None => (training_config.learning_rate, training_config.warmup_steps),
    };

    CosineWarmupSchedule::new(
        base_learning_rate,
        training_config.min_learning_rate,
        warmup_steps,
        training_config.max_steps,
    )
}

fn save_optimizer_state<O, M, B>(optimizer: &O) -> Result<Vec<u8>, TrainingError>
where
    O: Optimizer<M, B>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    BinBytesRecorder::<FullPrecisionSettings>::default()
        .record(optimizer.to_record(), ())
        .map_err(|error| TrainingError::OptimizerState(error.to_string()))
}

fn load_optimizer_state<O, M, B>(
    optimizer: O,
    bytes: &[u8],
    device: &B::Device,
) -> Result<O, TrainingError>
where
    O: Optimizer<M, B>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    let record = BinBytesRecorder::<FullPrecisionSettings>::default()
        .load::<O::Record>(bytes.to_vec(), device)
        .map_err(|error| TrainingError::OptimizerState(error.to_string()))?;
    Ok(optimizer.load_record(record))
}

fn rank_variants_by_loss<I>(entries: I) -> Vec<ModelVariant>
where
    I: IntoIterator<Item = (ModelVariant, f32)>,
{
    let mut ranked = entries.into_iter().collect::<Vec<_>>();
    ranked.sort_by(|left, right| left.1.total_cmp(&right.1));
    ranked.into_iter().map(|(variant, _)| variant).collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_schedule_warms_up_then_decays_to_floor() {
        let schedule = CosineWarmupSchedule::new(1.0, 0.1, 2, 6).unwrap();
        let lrs = (0..6)
            .map(|step| schedule.learning_rate(step))
            .collect::<Vec<_>>();

        assert_eq!(lrs[0], 0.5);
        assert_eq!(lrs[1], 1.0);
        assert!(lrs[2] <= 1.0);
        assert!(lrs[3] < lrs[2]);
        assert_eq!(lrs[5], 0.1);
    }
1148}