// deep_delta_learning/eval.rs

1use std::fmt::{Display, Formatter};
2use std::time::Instant;
3
4use burn::prelude::*;
5use burn::tensor::Int;
6use serde::{Deserialize, Serialize};
7
8use crate::config::DdlConfig;
9use crate::data::{TokenBatch, TokenDataset, TokenDatasetSummary};
10use crate::lm::{
11    CausalLmLossSummary, CausalLmMetrics, aggregate_causal_lm_summaries,
12    causal_language_model_metrics, causal_language_model_summary_with_lengths,
13};
14use crate::spectral::{SpectralCollector, SpectralDiagnostics};
15use crate::variant::{DiagnosticLevel, ModelVariant};
16
/// Failures reported by the comparison and dataset evaluation entry points.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvalError {
    /// The caller supplied an empty slice of model variants.
    EmptyVariantSet,
    /// The dataset exposes no batches to evaluate.
    EmptyDataset,
}

impl Display for EvalError {
    /// Writes the fixed, human-readable message associated with the variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::EmptyVariantSet => "at least one model variant is required",
            Self::EmptyDataset => "evaluation dataset does not contain any batches",
        };
        f.write_str(message)
    }
}

// Neither variant wraps an underlying cause, so the default trait methods suffice.
impl std::error::Error for EvalError {}
33
/// Outcome of running one model variant on a single batch of token ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariantEvaluation {
    /// The variant that was evaluated.
    pub variant: ModelVariant,
    /// The configuration returned by `ModelVariant::build` for this variant.
    pub config: DdlConfig,
    /// Value reported by the model's `num_params()`.
    pub num_params: usize,
    /// Dimensions of the logits tensor produced by the forward pass.
    pub logits_shape: [usize; 3],
    /// Causal language-model metrics computed from the logits and the input ids.
    pub metrics: CausalLmMetrics,
    /// Per-layer beta values from the model diagnostics; empty when the forward pass
    /// returned no diagnostics.
    pub beta_per_layer: Vec<f32>,
    /// Spectral diagnostics from the forward pass, if any were produced.
    pub spectral: Option<SpectralDiagnostics>,
}
44
/// Side-by-side evaluation of several variants on the same input batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonReport {
    /// Dimensions of the `input_ids` tensor every variant was evaluated on.
    pub input_shape: [usize; 2],
    /// One evaluation per requested variant, in the order the variants were supplied.
    pub evaluations: Vec<VariantEvaluation>,
    /// Variants ordered from lowest to highest loss.
    pub loss_ranking: Vec<ModelVariant>,
    /// The first entry of `loss_ranking`.
    pub best_loss_variant: ModelVariant,
}
52
53impl ComparisonReport {
54    pub fn best_loss(&self) -> Option<&VariantEvaluation> {
55        self.evaluations
56            .iter()
57            .min_by(|left, right| left.metrics.loss.total_cmp(&right.metrics.loss))
58    }
59}
60
/// Latency and throughput figures accumulated while evaluating a dataset.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct DatasetTiming {
    /// Sum of the per-batch forward-pass latencies, in milliseconds.
    pub total_latency_ms: f64,
    /// `total_latency_ms` divided by the number of batches (0.0 when there are none).
    pub avg_batch_latency_ms: f64,
    /// Dataset prediction count divided by the total latency in seconds
    /// (0.0 when the total latency is zero).
    pub throughput_predictions_per_second: f64,
}
67
/// Outcome of running one model variant over every batch of a dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetVariantEvaluation {
    /// The variant that was evaluated.
    pub variant: ModelVariant,
    /// The configuration returned by `ModelVariant::build` for this variant.
    pub config: DdlConfig,
    /// Value reported by the model's `num_params()`.
    pub num_params: usize,
    /// Batch count copied from the dataset summary.
    pub num_batches: usize,
    /// Sequence count copied from the dataset summary.
    pub num_sequences: usize,
    /// Prediction count copied from the dataset summary.
    pub num_predictions: usize,
    /// Padded-token count copied from the dataset summary.
    pub num_padded_tokens: usize,
    /// Metrics aggregated from the per-batch loss summaries.
    pub metrics: CausalLmMetrics,
    /// Latency and throughput measured across all batches.
    pub timing: DatasetTiming,
    /// Per-layer beta values averaged over the batches that returned diagnostics,
    /// weighted by each batch's prediction count; empty when no batch did.
    pub beta_per_layer: Vec<f32>,
    /// Spectral diagnostics of the last batch that produced any.
    pub spectral: Option<SpectralDiagnostics>,
    /// Collector holding the spectral diagnostics recorded per batch; only created when
    /// the variant uses DDL and the diagnostic level requests spectral output.
    pub spectral_history: Option<SpectralCollector>,
}
83
/// Side-by-side evaluation of several variants on the same dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetComparisonReport {
    /// Summary of the dataset every variant was evaluated on.
    pub dataset: TokenDatasetSummary,
    /// One evaluation per requested variant, in the order the variants were supplied.
    pub evaluations: Vec<DatasetVariantEvaluation>,
    /// Variants ordered from lowest to highest loss.
    pub loss_ranking: Vec<ModelVariant>,
    /// The first entry of `loss_ranking`.
    pub best_loss_variant: ModelVariant,
}
91
92impl DatasetComparisonReport {
93    pub fn best_loss(&self) -> Option<&DatasetVariantEvaluation> {
94        self.evaluations
95            .iter()
96            .min_by(|left, right| left.metrics.loss.total_cmp(&right.metrics.loss))
97    }
98}
99
100pub fn evaluate_variant<B: Backend>(
101    base_config: &DdlConfig,
102    variant: ModelVariant,
103    device: &B::Device,
104    input_ids: &Tensor<B, 2, Int>,
105    mask: Option<&Tensor<B, 3>>,
106    pad_tokens: Option<&[usize]>,
107) -> VariantEvaluation {
108    evaluate_variant_with_diagnostics(
109        base_config,
110        variant,
111        device,
112        input_ids,
113        mask,
114        pad_tokens,
115        DiagnosticLevel::default(),
116    )
117}
118
119pub fn evaluate_variant_with_diagnostics<B: Backend>(
120    base_config: &DdlConfig,
121    variant: ModelVariant,
122    device: &B::Device,
123    input_ids: &Tensor<B, 2, Int>,
124    mask: Option<&Tensor<B, 3>>,
125    pad_tokens: Option<&[usize]>,
126    diagnostic_level: DiagnosticLevel,
127) -> VariantEvaluation {
128    let (config, model) = variant.build::<B>(base_config, device);
129    let output = model.forward_with_diagnostics(input_ids.clone(), mask, diagnostic_level);
130    let logits_shape = output.logits.dims();
131    let metrics = causal_language_model_metrics(
132        &output.logits,
133        input_ids,
134        pad_tokens.map(|tokens| tokens.to_vec()),
135    );
136    let beta_per_layer = output
137        .diagnostics
138        .as_ref()
139        .map_or_else(Vec::new, |diagnostics| diagnostics.beta_per_layer());
140
141    VariantEvaluation {
142        variant,
143        config,
144        num_params: model.num_params(),
145        logits_shape,
146        metrics,
147        beta_per_layer,
148        spectral: output.spectral,
149    }
150}
151
152pub fn compare_variants<B: Backend>(
153    base_config: &DdlConfig,
154    variants: &[ModelVariant],
155    device: &B::Device,
156    input_ids: &Tensor<B, 2, Int>,
157    mask: Option<&Tensor<B, 3>>,
158    pad_tokens: Option<&[usize]>,
159) -> Result<ComparisonReport, EvalError> {
160    compare_variants_with_diagnostics(
161        base_config,
162        variants,
163        device,
164        input_ids,
165        mask,
166        pad_tokens,
167        DiagnosticLevel::default(),
168    )
169}
170
171pub fn compare_variants_with_diagnostics<B: Backend>(
172    base_config: &DdlConfig,
173    variants: &[ModelVariant],
174    device: &B::Device,
175    input_ids: &Tensor<B, 2, Int>,
176    mask: Option<&Tensor<B, 3>>,
177    pad_tokens: Option<&[usize]>,
178    diagnostic_level: DiagnosticLevel,
179) -> Result<ComparisonReport, EvalError> {
180    if variants.is_empty() {
181        return Err(EvalError::EmptyVariantSet);
182    }
183
184    let evaluations = variants
185        .iter()
186        .copied()
187        .map(|variant| {
188            evaluate_variant_with_diagnostics(
189                base_config,
190                variant,
191                device,
192                input_ids,
193                mask,
194                pad_tokens,
195                diagnostic_level,
196            )
197        })
198        .collect::<Vec<_>>();
199    let mut loss_ranking = evaluations
200        .iter()
201        .map(|evaluation| (evaluation.variant, evaluation.metrics.loss))
202        .collect::<Vec<_>>();
203    loss_ranking.sort_by(|left, right| left.1.total_cmp(&right.1));
204
205    Ok(ComparisonReport {
206        input_shape: input_ids.dims(),
207        best_loss_variant: loss_ranking[0].0,
208        loss_ranking: loss_ranking
209            .into_iter()
210            .map(|(variant, _)| variant)
211            .collect(),
212        evaluations,
213    })
214}
215
216pub fn evaluate_variant_on_dataset<B: Backend>(
217    base_config: &DdlConfig,
218    variant: ModelVariant,
219    device: &B::Device,
220    dataset: &TokenDataset,
221) -> Result<DatasetVariantEvaluation, EvalError> {
222    evaluate_variant_on_dataset_with_diagnostics::<B>(
223        base_config,
224        variant,
225        device,
226        dataset,
227        DiagnosticLevel::default(),
228    )
229}
230
231pub fn evaluate_variant_on_dataset_with_diagnostics<B: Backend>(
232    base_config: &DdlConfig,
233    variant: ModelVariant,
234    device: &B::Device,
235    dataset: &TokenDataset,
236    diagnostic_level: DiagnosticLevel,
237) -> Result<DatasetVariantEvaluation, EvalError> {
238    if dataset.batches().is_empty() {
239        return Err(EvalError::EmptyDataset);
240    }
241
242    let (config, model) = variant.build::<B>(base_config, device);
243    let summary = dataset.summary();
244    let mut batch_summaries = Vec::with_capacity(summary.num_batches);
245    let mut total_latency_ms = 0.0;
246    let mut beta_weighted_sums = Vec::<f64>::new();
247    let mut beta_weight = 0.0;
248    let mut latest_spectral = None;
249    let mut spectral_history = (variant.uses_ddl() && diagnostic_level.wants_spectral())
250        .then(|| SpectralCollector::new(summary.num_batches.max(1)));
251
252    for batch in dataset.batches() {
253        let batch_result = evaluate_dataset_batch(&model, batch, device, diagnostic_level);
254        total_latency_ms += batch_result.latency_ms;
255        batch_summaries.push(batch_result.summary);
256
257        if let Some(diagnostics) = batch_result.diagnostics {
258            let betas = diagnostics.beta_per_layer();
259            if beta_weighted_sums.is_empty() {
260                beta_weighted_sums.resize(betas.len(), 0.0);
261            }
262
263            let weight = batch_result.summary.prediction_count.max(1) as f64;
264            for (sum, beta) in beta_weighted_sums.iter_mut().zip(betas.iter()) {
265                *sum += f64::from(*beta) * weight;
266            }
267            beta_weight += weight;
268        }
269
270        if let Some(spectral) = batch_result.spectral {
271            latest_spectral = Some(spectral.clone());
272            if let Some(history) = spectral_history.as_mut() {
273                history.record(spectral);
274            }
275        }
276    }
277
278    let metrics: CausalLmMetrics = aggregate_causal_lm_summaries(&batch_summaries).into();
279    let beta_per_layer = if beta_weight == 0.0 {
280        Vec::new()
281    } else {
282        beta_weighted_sums
283            .into_iter()
284            .map(|sum| (sum / beta_weight) as f32)
285            .collect()
286    };
287    let timing = DatasetTiming {
288        total_latency_ms,
289        avg_batch_latency_ms: if summary.num_batches == 0 {
290            0.0
291        } else {
292            total_latency_ms / summary.num_batches as f64
293        },
294        throughput_predictions_per_second: if total_latency_ms == 0.0 {
295            0.0
296        } else {
297            summary.num_predictions as f64 / (total_latency_ms / 1_000.0)
298        },
299    };
300
301    Ok(DatasetVariantEvaluation {
302        variant,
303        config,
304        num_params: model.num_params(),
305        num_batches: summary.num_batches,
306        num_sequences: summary.num_sequences,
307        num_predictions: summary.num_predictions,
308        num_padded_tokens: summary.num_padded_tokens,
309        metrics,
310        timing,
311        beta_per_layer,
312        spectral: latest_spectral,
313        spectral_history,
314    })
315}
316
317pub fn compare_variants_on_dataset<B: Backend>(
318    base_config: &DdlConfig,
319    variants: &[ModelVariant],
320    device: &B::Device,
321    dataset: &TokenDataset,
322) -> Result<DatasetComparisonReport, EvalError> {
323    compare_variants_on_dataset_with_diagnostics::<B>(
324        base_config,
325        variants,
326        device,
327        dataset,
328        DiagnosticLevel::default(),
329    )
330}
331
332pub fn compare_variants_on_dataset_with_diagnostics<B: Backend>(
333    base_config: &DdlConfig,
334    variants: &[ModelVariant],
335    device: &B::Device,
336    dataset: &TokenDataset,
337    diagnostic_level: DiagnosticLevel,
338) -> Result<DatasetComparisonReport, EvalError> {
339    if variants.is_empty() {
340        return Err(EvalError::EmptyVariantSet);
341    }
342    if dataset.batches().is_empty() {
343        return Err(EvalError::EmptyDataset);
344    }
345
346    let evaluations = variants
347        .iter()
348        .copied()
349        .map(|variant| {
350            evaluate_variant_on_dataset_with_diagnostics::<B>(
351                base_config,
352                variant,
353                device,
354                dataset,
355                diagnostic_level,
356            )
357        })
358        .collect::<Result<Vec<_>, _>>()?;
359    let mut loss_ranking = evaluations
360        .iter()
361        .map(|evaluation| (evaluation.variant, evaluation.metrics.loss))
362        .collect::<Vec<_>>();
363    loss_ranking.sort_by(|left, right| left.1.total_cmp(&right.1));
364
365    Ok(DatasetComparisonReport {
366        dataset: dataset.summary(),
367        best_loss_variant: loss_ranking[0].0,
368        loss_ranking: loss_ranking
369            .into_iter()
370            .map(|(variant, _)| variant)
371            .collect(),
372        evaluations,
373    })
374}
375
/// Everything `evaluate_dataset_batch` produces for a single batch.
struct DatasetBatchResult {
    /// Loss summary computed from the batch logits, input ids and sequence lengths.
    summary: CausalLmLossSummary,
    /// Wall-clock time spent in the model's forward call, in milliseconds.
    latency_ms: f64,
    /// Model diagnostics returned by the forward pass, if any.
    diagnostics: Option<crate::spectral::ModelDiagnostics>,
    /// Spectral diagnostics returned by the forward pass, if any.
    spectral: Option<SpectralDiagnostics>,
}
382
383fn evaluate_dataset_batch<B: Backend>(
384    model: &crate::variant::ModelInstance<B>,
385    batch: &TokenBatch,
386    device: &B::Device,
387    diagnostic_level: DiagnosticLevel,
388) -> DatasetBatchResult {
389    let input_ids = batch.to_tensor(device);
390    let start = Instant::now();
391    let output = model.forward_with_diagnostics(input_ids.clone(), None, diagnostic_level);
392    let latency_ms = start.elapsed().as_secs_f64() * 1_000.0;
393    let summary = causal_language_model_summary_with_lengths(
394        &output.logits,
395        &input_ids,
396        batch.sequence_lengths(),
397    );
398
399    DatasetBatchResult {
400        summary,
401        latency_ms,
402        diagnostics: output.diagnostics,
403        spectral: output.spectral,
404    }
405}