Skip to main content

deep_delta_learning/
training.rs

1use std::fmt::{Display, Formatter};
2
3use burn::config::Config;
4use burn::module::{AutodiffModule, Module};
5use burn::nn::loss::CrossEntropyLossConfig;
6use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
7use burn::prelude::*;
8use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};
9use burn::tensor::activation::log_softmax;
10use burn::tensor::backend::AutodiffBackend;
11use serde::{Deserialize, Serialize};
12
13use crate::baseline::BaselineTransformer;
14use crate::checkpoint::LoadedTrainingArtifact;
15use crate::config::DdlConfig;
16use crate::data::{TokenDataset, TokenDatasetSummary};
17use crate::lm::{
18    CausalLmMetrics, aggregate_causal_lm_summaries, causal_language_model_summary_with_lengths,
19};
20use crate::spectral::{BetaHistogram, DeltaRegime, SpectralDiagnostics};
21use crate::transformer::DdlTransformer;
22use crate::variant::{ModelInstance, ModelVariant};
23
24#[derive(Config, Debug, PartialEq)]
25pub struct TrainingConfig {
26    #[config(default = 32)]
27    pub max_steps: usize,
28    #[config(default = 4)]
29    pub eval_interval: usize,
30    #[config(default = 1e-3)]
31    pub learning_rate: f64,
32    #[config(default = 0.0)]
33    pub min_learning_rate: f64,
34    #[config(default = 0)]
35    pub warmup_steps: usize,
36    #[config(default = 0.1)]
37    pub weight_decay: f64,
38    #[config(default = 0.9)]
39    pub beta1: f64,
40    #[config(default = 0.95)]
41    pub beta2: f64,
42    #[config(default = 1e-5)]
43    pub epsilon: f64,
44    #[config(default = 1.0)]
45    pub grad_clip: f64,
46}
47
48impl TrainingConfig {
49    pub fn validate(&self) -> Result<(), TrainingError> {
50        if self.max_steps == 0 {
51            return Err(TrainingError::InvalidConfig(
52                "max_steps must be greater than zero",
53            ));
54        }
55        if self.eval_interval == 0 {
56            return Err(TrainingError::InvalidConfig(
57                "eval_interval must be greater than zero",
58            ));
59        }
60        if !(self.learning_rate.is_finite() && self.learning_rate > 0.0) {
61            return Err(TrainingError::InvalidConfig(
62                "learning_rate must be finite and positive",
63            ));
64        }
65        if !(self.min_learning_rate.is_finite() && self.min_learning_rate >= 0.0) {
66            return Err(TrainingError::InvalidConfig(
67                "min_learning_rate must be finite and non-negative",
68            ));
69        }
70        if self.min_learning_rate > self.learning_rate {
71            return Err(TrainingError::InvalidConfig(
72                "min_learning_rate cannot exceed learning_rate",
73            ));
74        }
75        if self.warmup_steps > self.max_steps {
76            return Err(TrainingError::InvalidConfig(
77                "warmup_steps cannot exceed max_steps",
78            ));
79        }
80        if !(self.weight_decay.is_finite() && self.weight_decay >= 0.0) {
81            return Err(TrainingError::InvalidConfig(
82                "weight_decay must be finite and non-negative",
83            ));
84        }
85        if !(self.beta1.is_finite() && (0.0..1.0).contains(&self.beta1)) {
86            return Err(TrainingError::InvalidConfig("beta1 must be in [0, 1)"));
87        }
88        if !(self.beta2.is_finite() && (0.0..1.0).contains(&self.beta2)) {
89            return Err(TrainingError::InvalidConfig("beta2 must be in [0, 1)"));
90        }
91        if !(self.epsilon.is_finite() && self.epsilon > 0.0) {
92            return Err(TrainingError::InvalidConfig(
93                "epsilon must be finite and positive",
94            ));
95        }
96        if !(self.grad_clip.is_finite() && self.grad_clip > 0.0) {
97            return Err(TrainingError::InvalidConfig(
98                "grad_clip must be finite and positive",
99            ));
100        }
101
102        Ok(())
103    }
104
105    fn optimizer<M, B>(&self) -> impl Optimizer<M, B>
106    where
107        B: AutodiffBackend,
108        M: AutodiffModule<B>,
109    {
110        AdamWConfig::new()
111            .with_beta_1(self.beta1 as f32)
112            .with_beta_2(self.beta2 as f32)
113            .with_epsilon(self.epsilon as f32)
114            .with_weight_decay(self.weight_decay as f32)
115            .with_grad_clipping(Some(burn::grad_clipping::GradientClippingConfig::Norm(
116                self.grad_clip as f32,
117            )))
118            .init::<B, M>()
119    }
120}
121
122#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
123pub struct TrainingStepMetrics {
124    pub step: usize,
125    pub learning_rate: f64,
126    pub train: CausalLmMetrics,
127    pub validation: Option<CausalLmMetrics>,
128    pub train_spectral: Option<TrainingSpectralSnapshot>,
129    pub validation_spectral: Option<TrainingSpectralSnapshot>,
130}
131
132#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
133pub struct TrainingSpectralSnapshot {
134    pub beta_per_layer: Vec<f32>,
135    pub k_eigenvalue_per_layer: Vec<f32>,
136    pub spatial_determinant_per_layer: Vec<f32>,
137    pub lifted_determinant_per_layer: Vec<f32>,
138    pub regime_per_layer: Vec<DeltaRegime>,
139    pub beta_histogram: BetaHistogram,
140    pub k_coherence_per_layer: Vec<f32>,
141    pub correction_norm_per_layer: Vec<f32>,
142}
143
144impl From<&SpectralDiagnostics> for TrainingSpectralSnapshot {
145    fn from(spectral: &SpectralDiagnostics) -> Self {
146        Self {
147            beta_per_layer: spectral.beta_per_layer.clone(),
148            k_eigenvalue_per_layer: spectral.k_eigenvalue_per_layer.clone(),
149            spatial_determinant_per_layer: spectral.spatial_determinant_per_layer.clone(),
150            lifted_determinant_per_layer: spectral.lifted_determinant_per_layer.clone(),
151            regime_per_layer: spectral.regime_per_layer.clone(),
152            beta_histogram: spectral.beta_histogram,
153            k_coherence_per_layer: spectral.k_coherence_per_layer.clone(),
154            correction_norm_per_layer: spectral.correction_norm_per_layer.clone(),
155        }
156    }
157}
158
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160pub struct TrainingReport {
161    pub variant: ModelVariant,
162    pub config: DdlConfig,
163    pub training: TrainingConfig,
164    pub num_params: usize,
165    pub train_dataset: TokenDatasetSummary,
166    pub validation_dataset: Option<TokenDatasetSummary>,
167    pub steps_completed: usize,
168    pub initial_train: CausalLmMetrics,
169    pub initial_validation: Option<CausalLmMetrics>,
170    pub final_train: CausalLmMetrics,
171    pub final_validation: Option<CausalLmMetrics>,
172    pub best_validation: Option<CausalLmMetrics>,
173    pub initial_train_spectral: Option<TrainingSpectralSnapshot>,
174    pub initial_validation_spectral: Option<TrainingSpectralSnapshot>,
175    pub final_train_spectral: Option<TrainingSpectralSnapshot>,
176    pub final_validation_spectral: Option<TrainingSpectralSnapshot>,
177    pub best_validation_spectral: Option<TrainingSpectralSnapshot>,
178    pub best_validation_step: Option<usize>,
179    pub history: Vec<TrainingStepMetrics>,
180}
181
182#[derive(Debug)]
183pub struct TrainingOutcome<B: Backend> {
184    pub report: TrainingReport,
185    pub model: ModelInstance<B>,
186    pub best_validation_model: Option<ModelInstance<B>>,
187    pub optimizer_state: Vec<u8>,
188}
189
190/// Ranked summary for a multi-variant RFC-007 smoke run.
191#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
192pub struct TrainingComparisonReport {
193    pub train_dataset: TokenDatasetSummary,
194    pub validation_dataset: Option<TokenDatasetSummary>,
195    pub training: TrainingConfig,
196    pub reports: Vec<TrainingReport>,
197    pub final_train_loss_ranking: Vec<ModelVariant>,
198    pub best_final_train_variant: ModelVariant,
199    pub final_validation_loss_ranking: Option<Vec<ModelVariant>>,
200    pub best_final_validation_variant: Option<ModelVariant>,
201}
202
203impl TrainingComparisonReport {
204    pub fn report(&self, variant: ModelVariant) -> Option<&TrainingReport> {
205        self.reports.iter().find(|report| report.variant == variant)
206    }
207
208    pub fn best_final_train(&self) -> Option<&TrainingReport> {
209        self.report(self.best_final_train_variant)
210    }
211
212    pub fn best_final_validation(&self) -> Option<&TrainingReport> {
213        self.best_final_validation_variant
214            .and_then(|variant| self.report(variant))
215    }
216}
217
218#[derive(Debug)]
219pub struct TrainingSweepOutcome<B: Backend> {
220    pub report: TrainingComparisonReport,
221    pub outcomes: Vec<TrainingOutcome<B>>,
222}
223
224impl<B: Backend> TrainingSweepOutcome<B> {
225    pub fn outcome(&self, variant: ModelVariant) -> Option<&TrainingOutcome<B>> {
226        self.outcomes
227            .iter()
228            .find(|outcome| outcome.report.variant == variant)
229    }
230}
231
232#[derive(Debug, Clone, Copy, PartialEq)]
233pub struct CosineWarmupSchedule {
234    base_learning_rate: f64,
235    min_learning_rate: f64,
236    warmup_steps: usize,
237    total_steps: usize,
238}
239
240impl CosineWarmupSchedule {
241    pub fn new(
242        base_learning_rate: f64,
243        min_learning_rate: f64,
244        warmup_steps: usize,
245        total_steps: usize,
246    ) -> Result<Self, TrainingError> {
247        if total_steps == 0 {
248            return Err(TrainingError::InvalidConfig(
249                "total_steps must be greater than zero",
250            ));
251        }
252        if warmup_steps > total_steps {
253            return Err(TrainingError::InvalidConfig(
254                "warmup_steps cannot exceed total_steps",
255            ));
256        }
257        if min_learning_rate > base_learning_rate {
258            return Err(TrainingError::InvalidConfig(
259                "min_learning_rate cannot exceed base_learning_rate",
260            ));
261        }
262
263        Ok(Self {
264            base_learning_rate,
265            min_learning_rate,
266            warmup_steps,
267            total_steps,
268        })
269    }
270
271    pub fn learning_rate(&self, step: usize) -> f64 {
272        let step = step.min(self.total_steps.saturating_sub(1));
273
274        if self.warmup_steps > 0 && step < self.warmup_steps {
275            let progress = (step + 1) as f64 / self.warmup_steps as f64;
276            return self.base_learning_rate * progress;
277        }
278
279        let cosine_steps = self.total_steps.saturating_sub(self.warmup_steps);
280        if cosine_steps <= 1 {
281            return self.base_learning_rate;
282        }
283
284        let cosine_step = step.saturating_sub(self.warmup_steps);
285        let progress = cosine_step as f64 / (cosine_steps - 1) as f64;
286        let cosine = (std::f64::consts::PI * progress).cos();
287        self.min_learning_rate
288            + (self.base_learning_rate - self.min_learning_rate) * 0.5 * (1.0 + cosine)
289    }
290}
291
292#[derive(Debug, Clone)]
293struct DatasetTelemetry {
294    metrics: CausalLmMetrics,
295    spectral: Option<TrainingSpectralSnapshot>,
296}
297
298#[derive(Debug)]
299struct ResumeState<M> {
300    report: TrainingReport,
301    best_validation_model: Option<M>,
302    optimizer_state: Vec<u8>,
303}
304
305#[derive(Debug)]
306struct ForwardPassOutput<B: Backend> {
307    logits: Tensor<B, 3>,
308    spectral: Option<TrainingSpectralSnapshot>,
309}
310
311#[derive(Debug, Default)]
312struct SpectralAccumulator {
313    beta_per_layer: Vec<f64>,
314    k_eigenvalue_per_layer: Vec<f64>,
315    spatial_determinant_per_layer: Vec<f64>,
316    lifted_determinant_per_layer: Vec<f64>,
317    k_coherence_per_layer: Vec<f64>,
318    correction_norm_per_layer: Vec<f64>,
319    total_weight: f64,
320}
321
322impl SpectralAccumulator {
323    fn observe(&mut self, spectral: &TrainingSpectralSnapshot, weight: f64) {
324        if self.beta_per_layer.is_empty() {
325            self.beta_per_layer = vec![0.0; spectral.beta_per_layer.len()];
326            self.k_eigenvalue_per_layer = vec![0.0; spectral.k_eigenvalue_per_layer.len()];
327            self.spatial_determinant_per_layer =
328                vec![0.0; spectral.spatial_determinant_per_layer.len()];
329            self.lifted_determinant_per_layer =
330                vec![0.0; spectral.lifted_determinant_per_layer.len()];
331            self.k_coherence_per_layer = vec![0.0; spectral.k_coherence_per_layer.len()];
332            self.correction_norm_per_layer = vec![0.0; spectral.correction_norm_per_layer.len()];
333        }
334
335        accumulate_weighted(
336            &mut self.beta_per_layer,
337            &spectral.beta_per_layer,
338            weight,
339            "beta_per_layer",
340        );
341        accumulate_weighted(
342            &mut self.k_eigenvalue_per_layer,
343            &spectral.k_eigenvalue_per_layer,
344            weight,
345            "k_eigenvalue_per_layer",
346        );
347        accumulate_weighted(
348            &mut self.spatial_determinant_per_layer,
349            &spectral.spatial_determinant_per_layer,
350            weight,
351            "spatial_determinant_per_layer",
352        );
353        accumulate_weighted(
354            &mut self.lifted_determinant_per_layer,
355            &spectral.lifted_determinant_per_layer,
356            weight,
357            "lifted_determinant_per_layer",
358        );
359        accumulate_weighted(
360            &mut self.k_coherence_per_layer,
361            &spectral.k_coherence_per_layer,
362            weight,
363            "k_coherence_per_layer",
364        );
365        accumulate_weighted(
366            &mut self.correction_norm_per_layer,
367            &spectral.correction_norm_per_layer,
368            weight,
369            "correction_norm_per_layer",
370        );
371
372        self.total_weight += weight;
373    }
374
375    fn finish(self) -> Option<TrainingSpectralSnapshot> {
376        if self.total_weight == 0.0 {
377            return None;
378        }
379
380        let beta_per_layer = average_weighted(self.beta_per_layer, self.total_weight);
381        let k_eigenvalue_per_layer =
382            average_weighted(self.k_eigenvalue_per_layer, self.total_weight);
383        let spatial_determinant_per_layer =
384            average_weighted(self.spatial_determinant_per_layer, self.total_weight);
385        let lifted_determinant_per_layer =
386            average_weighted(self.lifted_determinant_per_layer, self.total_weight);
387        let k_coherence_per_layer = average_weighted(self.k_coherence_per_layer, self.total_weight);
388        let correction_norm_per_layer =
389            average_weighted(self.correction_norm_per_layer, self.total_weight);
390        let regime_per_layer = beta_per_layer
391            .iter()
392            .map(|beta| DeltaRegime::from_beta(*beta))
393            .collect::<Vec<_>>();
394        let beta_histogram = BetaHistogram::from_betas(&beta_per_layer);
395
396        Some(TrainingSpectralSnapshot {
397            beta_per_layer,
398            k_eigenvalue_per_layer,
399            spatial_determinant_per_layer,
400            lifted_determinant_per_layer,
401            regime_per_layer,
402            beta_histogram,
403            k_coherence_per_layer,
404            correction_norm_per_layer,
405        })
406    }
407}
408
/// Errors produced while configuring, running or resuming training.
#[derive(Debug)]
pub enum TrainingError {
    /// The training dataset has no batches.
    EmptyTrainingDataset,
    /// A validation dataset was supplied but has no batches.
    EmptyValidationDataset,
    /// A sweep was requested with no model variants.
    EmptyVariantSet,
    /// A configuration value failed validation; the payload names the rule.
    InvalidConfig(&'static str),
    /// Optimizer-state serialization failed; the payload carries the underlying message.
    OptimizerState(String),
    /// A saved artifact is inconsistent with the resume request.
    ResumeStateMismatch(String),
}

impl Display for TrainingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            TrainingError::EmptyTrainingDataset => {
                f.write_str("training dataset does not contain any batches")
            }
            TrainingError::EmptyValidationDataset => {
                f.write_str("validation dataset does not contain any batches")
            }
            TrainingError::EmptyVariantSet => f.write_str("at least one model variant is required"),
            TrainingError::InvalidConfig(reason) => {
                write!(f, "invalid training configuration: {reason}")
            }
            TrainingError::OptimizerState(reason) => {
                write!(f, "optimizer state serialization failed: {reason}")
            }
            TrainingError::ResumeStateMismatch(reason) => {
                write!(f, "invalid resume state: {reason}")
            }
        }
    }
}

impl std::error::Error for TrainingError {}
441
442pub fn train_variant<B>(
443    base_config: &DdlConfig,
444    variant: ModelVariant,
445    device: &B::Device,
446    train_dataset: &TokenDataset,
447    validation_dataset: Option<&TokenDataset>,
448    training_config: &TrainingConfig,
449) -> Result<TrainingOutcome<B::InnerBackend>, TrainingError>
450where
451    B: AutodiffBackend,
452{
453    training_config.validate()?;
454    if train_dataset.batches().is_empty() {
455        return Err(TrainingError::EmptyTrainingDataset);
456    }
457    if validation_dataset.is_some_and(|dataset| dataset.batches().is_empty()) {
458        return Err(TrainingError::EmptyValidationDataset);
459    }
460
461    let resolved_config = variant.resolve_config(base_config);
462    let datasets = DatasetPair {
463        train: train_dataset,
464        validation: validation_dataset,
465    };
466    match variant {
467        ModelVariant::Baseline => train_model(
468            variant,
469            resolved_config.clone(),
470            BaselineTransformer::<B>::new(&resolved_config, device),
471            |model| ModelInstance::Baseline(Box::new(model)),
472            device,
473            datasets,
474            training_config,
475            None,
476        ),
477        _ => train_model(
478            variant,
479            resolved_config.clone(),
480            resolved_config.init::<B>(device),
481            |model| ModelInstance::Ddl(Box::new(model)),
482            device,
483            datasets,
484            training_config,
485            None,
486        ),
487    }
488}
489
490pub fn resume_training<B>(
491    artifact: LoadedTrainingArtifact<B>,
492    device: &B::Device,
493    train_dataset: &TokenDataset,
494    validation_dataset: Option<&TokenDataset>,
495    training_config: &TrainingConfig,
496) -> Result<TrainingOutcome<B::InnerBackend>, TrainingError>
497where
498    B: AutodiffBackend,
499{
500    training_config.validate()?;
501    if train_dataset.batches().is_empty() {
502        return Err(TrainingError::EmptyTrainingDataset);
503    }
504    if validation_dataset.is_some_and(|dataset| dataset.batches().is_empty()) {
505        return Err(TrainingError::EmptyValidationDataset);
506    }
507
508    let train_summary = train_dataset.summary();
509    if artifact.report.train_dataset != train_summary {
510        return Err(TrainingError::ResumeStateMismatch(format!(
511            "training dataset summary {:?} does not match saved artifact {:?}",
512            train_summary, artifact.report.train_dataset
513        )));
514    }
515
516    let validation_summary = validation_dataset.map(TokenDataset::summary);
517    if artifact.report.validation_dataset != validation_summary {
518        return Err(TrainingError::ResumeStateMismatch(format!(
519            "validation dataset summary {:?} does not match saved artifact {:?}",
520            validation_summary, artifact.report.validation_dataset
521        )));
522    }
523
524    if artifact.report.steps_completed != artifact.report.history.len() {
525        return Err(TrainingError::ResumeStateMismatch(format!(
526            "saved report recorded {} completed steps but {} history entries",
527            artifact.report.steps_completed,
528            artifact.report.history.len()
529        )));
530    }
531
532    let datasets = DatasetPair {
533        train: train_dataset,
534        validation: validation_dataset,
535    };
536    let optimizer_state = if artifact.report.steps_completed == 0 {
537        Vec::new()
538    } else {
539        artifact.optimizer_state.ok_or_else(|| {
540            TrainingError::ResumeStateMismatch(
541                "saved training artifact is missing optimizer state".to_string(),
542            )
543        })?
544    };
545    match artifact.model {
546        ModelInstance::Baseline(model) => {
547            if artifact.report.variant.uses_ddl() {
548                return Err(TrainingError::ResumeStateMismatch(format!(
549                    "saved report expects variant {} but checkpoint contains a baseline model",
550                    artifact.report.variant.slug()
551                )));
552            }
553            let best_validation_model = match artifact.best_validation_model {
554                Some(ModelInstance::Baseline(model)) => Some(model.valid()),
555                Some(ModelInstance::Ddl(_)) => {
556                    return Err(TrainingError::ResumeStateMismatch(
557                        "best-validation checkpoint kind does not match the baseline artifact"
558                            .to_string(),
559                    ));
560                }
561                None => None,
562            };
563            train_model(
564                artifact.report.variant,
565                artifact.report.config.clone(),
566                *model,
567                |model| ModelInstance::Baseline(Box::new(model)),
568                device,
569                datasets,
570                training_config,
571                Some(ResumeState {
572                    report: artifact.report,
573                    best_validation_model,
574                    optimizer_state,
575                }),
576            )
577        }
578        ModelInstance::Ddl(model) => {
579            if !artifact.report.variant.uses_ddl() {
580                return Err(TrainingError::ResumeStateMismatch(format!(
581                    "saved report expects variant {} but checkpoint contains a DDL model",
582                    artifact.report.variant.slug()
583                )));
584            }
585            let best_validation_model = match artifact.best_validation_model {
586                Some(ModelInstance::Ddl(model)) => Some(model.valid()),
587                Some(ModelInstance::Baseline(_)) => {
588                    return Err(TrainingError::ResumeStateMismatch(
589                        "best-validation checkpoint kind does not match the DDL artifact"
590                            .to_string(),
591                    ));
592                }
593                None => None,
594            };
595            train_model(
596                artifact.report.variant,
597                artifact.report.config.clone(),
598                *model,
599                |model| ModelInstance::Ddl(Box::new(model)),
600                device,
601                datasets,
602                training_config,
603                Some(ResumeState {
604                    report: artifact.report,
605                    best_validation_model,
606                    optimizer_state,
607                }),
608            )
609        }
610    }
611}
612
613/// Train a set of model variants under the same data/configuration and rank them by loss.
614pub fn train_variants<B>(
615    base_config: &DdlConfig,
616    variants: &[ModelVariant],
617    device: &B::Device,
618    train_dataset: &TokenDataset,
619    validation_dataset: Option<&TokenDataset>,
620    training_config: &TrainingConfig,
621) -> Result<TrainingSweepOutcome<B::InnerBackend>, TrainingError>
622where
623    B: AutodiffBackend,
624{
625    if variants.is_empty() {
626        return Err(TrainingError::EmptyVariantSet);
627    }
628
629    let mut outcomes = Vec::with_capacity(variants.len());
630    for variant in variants {
631        outcomes.push(train_variant::<B>(
632            base_config,
633            *variant,
634            device,
635            train_dataset,
636            validation_dataset,
637            training_config,
638        )?);
639    }
640
641    let reports = outcomes
642        .iter()
643        .map(|outcome| outcome.report.clone())
644        .collect::<Vec<_>>();
645    let final_train_loss_ranking = rank_variants_by_loss(
646        reports
647            .iter()
648            .map(|report| (report.variant, report.final_train.loss)),
649    );
650    let final_validation_loss_ranking = reports
651        .iter()
652        .map(|report| {
653            report
654                .final_validation
655                .map(|metrics| (report.variant, metrics.loss))
656        })
657        .collect::<Option<Vec<_>>>()
658        .map(rank_variants_by_loss);
659
660    Ok(TrainingSweepOutcome {
661        report: TrainingComparisonReport {
662            train_dataset: train_dataset.summary(),
663            validation_dataset: validation_dataset.map(TokenDataset::summary),
664            training: training_config.clone(),
665            reports,
666            best_final_train_variant: final_train_loss_ranking[0],
667            final_train_loss_ranking,
668            best_final_validation_variant: final_validation_loss_ranking
669                .as_ref()
670                .and_then(|ranking| ranking.first().copied()),
671            final_validation_loss_ranking,
672        },
673        outcomes,
674    })
675}
676
677#[derive(Clone, Copy)]
678struct DatasetPair<'a> {
679    train: &'a TokenDataset,
680    validation: Option<&'a TokenDataset>,
681}
682
683#[allow(clippy::too_many_arguments)]
684fn train_model<B, M, F>(
685    variant: ModelVariant,
686    resolved_config: DdlConfig,
687    mut model: M,
688    into_model_instance: F,
689    device: &B::Device,
690    datasets: DatasetPair<'_>,
691    training_config: &TrainingConfig,
692    resume: Option<ResumeState<M::InnerModule>>,
693) -> Result<TrainingOutcome<B::InnerBackend>, TrainingError>
694where
695    B: AutodiffBackend,
696    M: CausalLmModel<B> + AutodiffModule<B> + Module<B>,
697    M::InnerModule: CausalLmModel<B::InnerBackend> + Clone,
698    F: Fn(M::InnerModule) -> ModelInstance<B::InnerBackend> + Copy,
699{
700    let mut optimizer = training_config.optimizer::<M, B>();
701    let (
702        initial_train,
703        initial_validation,
704        initial_train_spectral,
705        initial_validation_spectral,
706        mut best_validation,
707        mut best_validation_spectral,
708        mut best_validation_step,
709        mut best_validation_model,
710        mut history,
711        step_offset,
712        resume_optimizer_state,
713    ) = match resume {
714        Some(resume) => (
715            resume.report.initial_train,
716            resume.report.initial_validation,
717            resume.report.initial_train_spectral,
718            resume.report.initial_validation_spectral,
719            resume.report.best_validation,
720            resume.report.best_validation_spectral,
721            resume.report.best_validation_step,
722            resume.best_validation_model,
723            resume.report.history,
724            resume.report.steps_completed,
725            Some(resume.optimizer_state),
726        ),
727        None => {
728            let initial_model = model.valid();
729            let initial_train_eval =
730                evaluate_dataset_with_telemetry(&initial_model, datasets.train, device);
731            let initial_validation_eval = datasets
732                .validation
733                .map(|dataset| evaluate_dataset_with_telemetry(&initial_model, dataset, device));
734            let initial_train = initial_train_eval.metrics;
735            let initial_validation = initial_validation_eval
736                .as_ref()
737                .map(|evaluation| evaluation.metrics);
738
739            (
740                initial_train,
741                initial_validation,
742                initial_train_eval.spectral,
743                initial_validation_eval
744                    .clone()
745                    .and_then(|evaluation| evaluation.spectral),
746                initial_validation,
747                initial_validation_eval
748                    .as_ref()
749                    .and_then(|evaluation| evaluation.spectral.clone()),
750                initial_validation.map(|_| 0),
751                initial_validation.map(|_| initial_model.clone()),
752                Vec::with_capacity(training_config.max_steps),
753                0,
754                None,
755            )
756        }
757    };
758    if let Some(optimizer_state) = resume_optimizer_state {
759        optimizer = load_optimizer_state::<_, M, B>(optimizer, &optimizer_state, device)?;
760    }
761    let schedule = schedule_for_phase(
762        training_config,
763        history.last().map(|step| step.learning_rate),
764    )?;
765
766    for step in 0..training_config.max_steps {
767        let record_spectral = (step + 1) % training_config.eval_interval == 0
768            || step + 1 == training_config.max_steps;
769        let batch = &datasets.train.batches()[step % datasets.train.num_batches()];
770        let input_ids = batch.to_tensor(device);
771        let forward = if record_spectral {
772            model.forward_with_spectral_snapshot(input_ids.clone(), None)
773        } else {
774            ForwardPassOutput {
775                logits: model.forward_logits(input_ids.clone(), None),
776                spectral: None,
777            }
778        };
779        let logits = forward.logits;
780        let loss = causal_language_model_training_loss(logits, input_ids, batch.sequence_lengths());
781        let train = CausalLmMetrics {
782            loss: scalar_from_autodiff_tensor(loss.clone()),
783            perplexity: scalar_from_autodiff_tensor(loss.clone()).exp(),
784        };
785        let grads = GradientsParams::from_grads(loss.backward(), &model);
786        let lr = schedule.learning_rate(step);
787        model = optimizer.step(lr, model, grads);
788
789        let mut validation_model = None;
790        let validation_eval = if record_spectral {
791            let valid_model = model.valid();
792            let evaluation = datasets
793                .validation
794                .map(|dataset| evaluate_dataset_with_telemetry(&valid_model, dataset, device));
795            validation_model = Some(valid_model);
796            evaluation
797        } else {
798            None
799        };
800        let validation = validation_eval
801            .as_ref()
802            .map(|evaluation| evaluation.metrics);
803
804        if let Some((metrics, spectral)) = validation_eval
805            .as_ref()
806            .map(|evaluation| (evaluation.metrics, evaluation.spectral.clone()))
807        {
808            let should_update = match best_validation {
809                Some(best) => metrics.loss < best.loss,
810                None => true,
811            };
812            if should_update {
813                best_validation = Some(metrics);
814                best_validation_spectral = spectral;
815                best_validation_step = Some(step_offset + step + 1);
816                best_validation_model = validation_model.clone();
817            }
818        }
819
820        history.push(TrainingStepMetrics {
821            step: step_offset + step + 1,
822            learning_rate: lr,
823            train,
824            validation,
825            train_spectral: forward.spectral,
826            validation_spectral: validation_eval.and_then(|evaluation| evaluation.spectral),
827        });
828    }
829
830    let model = model.valid();
831    let final_train_eval = evaluate_dataset_with_telemetry(&model, datasets.train, device);
832    let final_validation_eval = datasets
833        .validation
834        .map(|dataset| evaluate_dataset_with_telemetry(&model, dataset, device));
835    let final_train = final_train_eval.metrics;
836    let final_validation = final_validation_eval
837        .as_ref()
838        .map(|evaluation| evaluation.metrics);
839    if let Some((metrics, spectral)) = final_validation_eval
840        .as_ref()
841        .map(|evaluation| (evaluation.metrics, evaluation.spectral.clone()))
842    {
843        let should_update = match best_validation {
844            Some(best) => metrics.loss < best.loss,
845            None => true,
846        };
847        if should_update {
848            best_validation = Some(metrics);
849            best_validation_spectral = spectral;
850            best_validation_step = Some(step_offset + training_config.max_steps);
851            best_validation_model = Some(model.clone());
852        }
853    }
854    let num_params = model.num_params();
855
856    Ok(TrainingOutcome {
857        report: TrainingReport {
858            variant,
859            config: resolved_config,
860            training: training_config.clone(),
861            num_params,
862            train_dataset: datasets.train.summary(),
863            validation_dataset: datasets.validation.map(TokenDataset::summary),
864            steps_completed: history.len(),
865            initial_train,
866            initial_validation,
867            final_train,
868            final_validation,
869            best_validation,
870            initial_train_spectral,
871            initial_validation_spectral,
872            final_train_spectral: final_train_eval.spectral,
873            final_validation_spectral: final_validation_eval
874                .and_then(|evaluation| evaluation.spectral),
875            best_validation_spectral,
876            best_validation_step,
877            history,
878        },
879        model: into_model_instance(model),
880        best_validation_model: best_validation_model.map(into_model_instance),
881        optimizer_state: save_optimizer_state::<_, M, B>(&optimizer)?,
882    })
883}
884
885trait CausalLmModel<B: Backend> {
886    fn forward_logits(
887        &self,
888        input_ids: Tensor<B, 2, Int>,
889        mask: Option<&Tensor<B, 3>>,
890    ) -> Tensor<B, 3>;
891
892    fn forward_with_spectral_snapshot(
893        &self,
894        input_ids: Tensor<B, 2, Int>,
895        mask: Option<&Tensor<B, 3>>,
896    ) -> ForwardPassOutput<B> {
897        ForwardPassOutput {
898            logits: self.forward_logits(input_ids, mask),
899            spectral: None,
900        }
901    }
902}
903
// Only `forward_logits` is implemented here, so the trait's default
// `forward_with_spectral_snapshot` applies and reports `spectral: None` for this model.
impl<B: Backend> CausalLmModel<B> for BaselineTransformer<B> {
    // Delegates to `BaselineTransformer`'s own (inherent) `forward_logits`. Method resolution
    // prefers an inherent method over this trait method, so this call does not recurse — do
    // not rename or remove the inherent method without updating this impl.
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        self.forward_logits(input_ids, mask)
    }
}
913
914impl<B: Backend> CausalLmModel<B> for DdlTransformer<B> {
915    fn forward_logits(
916        &self,
917        input_ids: Tensor<B, 2, Int>,
918        mask: Option<&Tensor<B, 3>>,
919    ) -> Tensor<B, 3> {
920        self.forward_logits(input_ids, mask)
921    }
922
923    #[cfg(feature = "spectral")]
924    fn forward_with_spectral_snapshot(
925        &self,
926        input_ids: Tensor<B, 2, Int>,
927        mask: Option<&Tensor<B, 3>>,
928    ) -> ForwardPassOutput<B> {
929        let (logits, _, spectral) = self.forward_with_spectral_diagnostics(input_ids, mask);
930        ForwardPassOutput {
931            logits,
932            spectral: Some(TrainingSpectralSnapshot::from(&spectral)),
933        }
934    }
935}
936
937fn evaluate_dataset_with_telemetry<B, M>(
938    model: &M,
939    dataset: &TokenDataset,
940    device: &B::Device,
941) -> DatasetTelemetry
942where
943    B: Backend,
944    M: CausalLmModel<B>,
945{
946    let mut summaries = Vec::with_capacity(dataset.num_batches());
947    let mut spectral = SpectralAccumulator::default();
948    for batch in dataset.batches() {
949        let input_ids = batch.to_tensor(device);
950        let output = model.forward_with_spectral_snapshot(input_ids.clone(), None);
951        let summary = causal_language_model_summary_with_lengths(
952            &output.logits,
953            &input_ids,
954            batch.sequence_lengths(),
955        );
956        if let Some(snapshot) = output.spectral.as_ref() {
957            spectral.observe(snapshot, summary.prediction_count.max(1) as f64);
958        }
959        summaries.push(summary);
960    }
961
962    DatasetTelemetry {
963        metrics: aggregate_causal_lm_summaries(&summaries).into(),
964        spectral: spectral.finish(),
965    }
966}
967
/// Adds `values[i] * weight` (widened to `f64`) into `target[i]` for every index.
///
/// # Panics
///
/// Panics if `target` and `values` differ in length; the message is prefixed with `label`.
fn accumulate_weighted(target: &mut [f64], values: &[f32], weight: f64, label: &str) {
    assert_eq!(
        target.len(),
        values.len(),
        "{label} length must remain stable across spectral snapshots"
    );

    target
        .iter_mut()
        .zip(values)
        .for_each(|(total, &sample)| *total += f64::from(sample) * weight);
}
979
/// Divides each accumulated total by `total_weight` and narrows the result to `f32`.
///
/// There is no guard for `total_weight == 0.0`; such a call yields infinities/NaNs.
fn average_weighted(values: Vec<f64>, total_weight: f64) -> Vec<f32> {
    let mut averaged = Vec::with_capacity(values.len());
    for total in values {
        averaged.push((total / total_weight) as f32);
    }
    averaged
}
986
987fn causal_language_model_training_loss<B: Backend>(
988    logits: Tensor<B, 3>,
989    input_ids: Tensor<B, 2, Int>,
990    sequence_lengths: &[usize],
991) -> Tensor<B, 1> {
992    let [batch_size, seq_len, vocab_size] = logits.dims();
993    if seq_len < 2 {
994        return Tensor::<B, 1>::zeros([1], &logits.device());
995    }
996    assert_eq!(
997        sequence_lengths.len(),
998        batch_size,
999        "sequence_lengths must match the batch size"
1000    );
1001
1002    let positions_per_row = seq_len - 1;
1003    let num_predictions = batch_size * positions_per_row;
1004    let valid_predictions = sequence_lengths
1005        .iter()
1006        .map(|length| length.saturating_sub(1).min(positions_per_row))
1007        .sum::<usize>();
1008    if valid_predictions == 0 {
1009        return Tensor::<B, 1>::zeros([1], &logits.device());
1010    }
1011
1012    let shifted_logits = logits
1013        .slice([0..batch_size, 0..positions_per_row, 0..vocab_size])
1014        .reshape([num_predictions, vocab_size]);
1015    let shifted_targets = input_ids
1016        .slice([0..batch_size, 1..seq_len])
1017        .reshape([num_predictions]);
1018    let losses = CrossEntropyLossConfig::new()
1019        .init(&shifted_logits.device())
1020        .forward(shifted_logits.clone(), shifted_targets.clone());
1021    if valid_predictions == num_predictions {
1022        return losses;
1023    }
1024
1025    let gathered_losses = log_softmax(shifted_logits, 1)
1026        .gather(1, shifted_targets.reshape([num_predictions, 1]))
1027        .reshape([num_predictions])
1028        .neg();
1029    let mask = prediction_mask(
1030        sequence_lengths,
1031        positions_per_row,
1032        &gathered_losses.device(),
1033    );
1034    gathered_losses.mul(mask).sum() / valid_predictions as f32
1035}
1036
1037fn prediction_mask<B: Backend>(
1038    sequence_lengths: &[usize],
1039    positions_per_row: usize,
1040    device: &B::Device,
1041) -> Tensor<B, 1> {
1042    let values = sequence_lengths
1043        .iter()
1044        .flat_map(|length| {
1045            let valid_predictions = length.saturating_sub(1).min(positions_per_row);
1046            (0..positions_per_row).map(move |position| {
1047                if position < valid_predictions {
1048                    1.0
1049                } else {
1050                    0.0
1051                }
1052            })
1053        })
1054        .collect::<Vec<_>>();
1055    Tensor::<B, 1>::from_floats(values.as_slice(), device)
1056}
1057
1058fn scalar_from_autodiff_tensor<B: AutodiffBackend>(tensor: Tensor<B, 1>) -> f32 {
1059    let values = tensor
1060        .inner()
1061        .into_data()
1062        .to_vec::<f32>()
1063        .expect("training loss tensor should convert to f32");
1064    assert_eq!(
1065        values.len(),
1066        1,
1067        "training loss tensor should contain exactly one scalar"
1068    );
1069    values[0]
1070}
1071
1072fn schedule_for_phase(
1073    training_config: &TrainingConfig,
1074    previous_learning_rate: Option<f64>,
1075) -> Result<CosineWarmupSchedule, TrainingError> {
1076    let (base_learning_rate, warmup_steps) = match previous_learning_rate {
1077        Some(previous_learning_rate) => (
1078            previous_learning_rate.clamp(
1079                training_config.min_learning_rate,
1080                training_config.learning_rate,
1081            ),
1082            0,
1083        ),
1084        None => (training_config.learning_rate, training_config.warmup_steps),
1085    };
1086
1087    CosineWarmupSchedule::new(
1088        base_learning_rate,
1089        training_config.min_learning_rate,
1090        warmup_steps,
1091        training_config.max_steps,
1092    )
1093}
1094
1095fn save_optimizer_state<O, M, B>(optimizer: &O) -> Result<Vec<u8>, TrainingError>
1096where
1097    O: Optimizer<M, B>,
1098    M: AutodiffModule<B>,
1099    B: AutodiffBackend,
1100{
1101    BinBytesRecorder::<FullPrecisionSettings>::default()
1102        .record(optimizer.to_record(), ())
1103        .map_err(|error| TrainingError::OptimizerState(error.to_string()))
1104}
1105
1106fn load_optimizer_state<O, M, B>(
1107    optimizer: O,
1108    bytes: &[u8],
1109    device: &B::Device,
1110) -> Result<O, TrainingError>
1111where
1112    O: Optimizer<M, B>,
1113    M: AutodiffModule<B>,
1114    B: AutodiffBackend,
1115{
1116    let record = BinBytesRecorder::<FullPrecisionSettings>::default()
1117        .load::<O::Record>(bytes.to_vec(), device)
1118        .map_err(|error| TrainingError::OptimizerState(error.to_string()))?;
1119    Ok(optimizer.load_record(record))
1120}
1121
1122fn rank_variants_by_loss<I>(entries: I) -> Vec<ModelVariant>
1123where
1124    I: IntoIterator<Item = (ModelVariant, f32)>,
1125{
1126    let mut ranked = entries.into_iter().collect::<Vec<_>>();
1127    ranked.sort_by(|left, right| left.1.total_cmp(&right.1));
1128    ranked.into_iter().map(|(variant, _)| variant).collect()
1129}
1130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_schedule_warms_up_then_decays_to_floor() {
        let schedule = CosineWarmupSchedule::new(1.0, 0.1, 2, 6).unwrap();
        let lrs = (0..6)
            .map(|step| schedule.learning_rate(step))
            .collect::<Vec<_>>();

        assert_eq!(lrs[0], 0.5);
        assert_eq!(lrs[1], 1.0);
        assert!(lrs[2] <= 1.0);
        assert!(lrs[3] < lrs[2]);
        assert_eq!(lrs[5], 0.1);
    }

    // The spectral accumulation helpers are pure numeric code; pin their arithmetic and the
    // length-stability invariant so changes to snapshot aggregation are caught here.
    #[test]
    fn accumulate_weighted_adds_scaled_values_in_place() {
        let mut totals = vec![1.0, -2.0];
        accumulate_weighted(&mut totals, &[0.5, 4.0], 2.0, "betas");
        assert_eq!(totals, vec![2.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "betas length must remain stable across spectral snapshots")]
    fn accumulate_weighted_rejects_length_changes() {
        let mut totals = vec![0.0; 2];
        accumulate_weighted(&mut totals, &[1.0], 1.0, "betas");
    }

    #[test]
    fn average_weighted_divides_by_total_weight() {
        assert_eq!(average_weighted(vec![3.0, 6.0], 3.0), vec![1.0_f32, 2.0]);
        assert!(average_weighted(Vec::new(), 1.0).is_empty());
    }
}