1use std::fmt::{Display, Formatter};
2use std::time::Instant;
3
4use burn::prelude::*;
5use burn::tensor::Int;
6use serde::{Deserialize, Serialize};
7
8use crate::config::DdlConfig;
9use crate::data::{TokenBatch, TokenDataset, TokenDatasetSummary};
10use crate::lm::{
11 CausalLmLossSummary, CausalLmMetrics, aggregate_causal_lm_summaries,
12 causal_language_model_metrics, causal_language_model_summary_with_lengths,
13};
14use crate::spectral::{SpectralCollector, SpectralDiagnostics};
15use crate::variant::{DiagnosticLevel, ModelVariant};
16
/// Errors reported by the evaluation entry points in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvalError {
    /// The caller supplied an empty slice of model variants.
    EmptyVariantSet,
    /// The evaluation dataset contains no batches.
    EmptyDataset,
}

impl Display for EvalError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed message for the variant, then emit it verbatim.
        let message = match self {
            Self::EmptyVariantSet => "at least one model variant is required",
            Self::EmptyDataset => "evaluation dataset does not contain any batches",
        };
        f.write_str(message)
    }
}

impl std::error::Error for EvalError {}
33
/// Outcome of evaluating one model variant on a single batch of token ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariantEvaluation {
    /// The variant that was evaluated.
    pub variant: ModelVariant,
    /// Configuration returned when building the variant from the base configuration.
    pub config: DdlConfig,
    /// Parameter count reported by the built model.
    pub num_params: usize,
    /// Dimensions of the logits tensor produced by the forward pass.
    pub logits_shape: [usize; 3],
    /// Causal language-model metrics computed from the logits and the input ids.
    pub metrics: CausalLmMetrics,
    /// Per-layer beta values taken from the forward-pass diagnostics; empty when the
    /// forward pass produced no diagnostics.
    pub beta_per_layer: Vec<f32>,
    /// Spectral diagnostics from the forward pass, if any were produced.
    pub spectral: Option<SpectralDiagnostics>,
}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ComparisonReport {
47 pub input_shape: [usize; 2],
48 pub evaluations: Vec<VariantEvaluation>,
49 pub loss_ranking: Vec<ModelVariant>,
50 pub best_loss_variant: ModelVariant,
51}
52
53impl ComparisonReport {
54 pub fn best_loss(&self) -> Option<&VariantEvaluation> {
55 self.evaluations
56 .iter()
57 .min_by(|left, right| left.metrics.loss.total_cmp(&right.metrics.loss))
58 }
59}
60
/// Wall-clock timing of the forward passes performed during a dataset evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct DatasetTiming {
    /// Sum of the per-batch forward-pass latencies, in milliseconds.
    pub total_latency_ms: f64,
    /// `total_latency_ms` divided by the dataset's batch count (0 when there are no batches).
    pub avg_batch_latency_ms: f64,
    /// Dataset predictions per second of measured forward time (0 when no time was measured).
    pub throughput_predictions_per_second: f64,
}
67
/// Outcome of evaluating one model variant over every batch of a token dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetVariantEvaluation {
    /// The variant that was evaluated.
    pub variant: ModelVariant,
    /// Configuration returned when building the variant from the base configuration.
    pub config: DdlConfig,
    /// Parameter count reported by the built model.
    pub num_params: usize,
    /// Batch count taken from the dataset summary.
    pub num_batches: usize,
    /// Sequence count taken from the dataset summary.
    pub num_sequences: usize,
    /// Prediction count taken from the dataset summary.
    pub num_predictions: usize,
    /// Padded-token count taken from the dataset summary.
    pub num_padded_tokens: usize,
    /// Loss metrics aggregated over all per-batch loss summaries.
    pub metrics: CausalLmMetrics,
    /// Forward-pass timing over the whole dataset.
    pub timing: DatasetTiming,
    /// Per-layer beta values averaged over batches, each batch weighted by its prediction
    /// count (at least 1); empty when no batch produced diagnostics.
    pub beta_per_layer: Vec<f32>,
    /// Spectral diagnostics from the last batch that produced any.
    pub spectral: Option<SpectralDiagnostics>,
    /// Spectral diagnostics recorded across batches; only present when the variant uses
    /// DDL and the diagnostic level requests spectral output.
    pub spectral_history: Option<SpectralCollector>,
}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct DatasetComparisonReport {
86 pub dataset: TokenDatasetSummary,
87 pub evaluations: Vec<DatasetVariantEvaluation>,
88 pub loss_ranking: Vec<ModelVariant>,
89 pub best_loss_variant: ModelVariant,
90}
91
92impl DatasetComparisonReport {
93 pub fn best_loss(&self) -> Option<&DatasetVariantEvaluation> {
94 self.evaluations
95 .iter()
96 .min_by(|left, right| left.metrics.loss.total_cmp(&right.metrics.loss))
97 }
98}
99
100pub fn evaluate_variant<B: Backend>(
101 base_config: &DdlConfig,
102 variant: ModelVariant,
103 device: &B::Device,
104 input_ids: &Tensor<B, 2, Int>,
105 mask: Option<&Tensor<B, 3>>,
106 pad_tokens: Option<&[usize]>,
107) -> VariantEvaluation {
108 evaluate_variant_with_diagnostics(
109 base_config,
110 variant,
111 device,
112 input_ids,
113 mask,
114 pad_tokens,
115 DiagnosticLevel::default(),
116 )
117}
118
119pub fn evaluate_variant_with_diagnostics<B: Backend>(
120 base_config: &DdlConfig,
121 variant: ModelVariant,
122 device: &B::Device,
123 input_ids: &Tensor<B, 2, Int>,
124 mask: Option<&Tensor<B, 3>>,
125 pad_tokens: Option<&[usize]>,
126 diagnostic_level: DiagnosticLevel,
127) -> VariantEvaluation {
128 let (config, model) = variant.build::<B>(base_config, device);
129 let output = model.forward_with_diagnostics(input_ids.clone(), mask, diagnostic_level);
130 let logits_shape = output.logits.dims();
131 let metrics = causal_language_model_metrics(
132 &output.logits,
133 input_ids,
134 pad_tokens.map(|tokens| tokens.to_vec()),
135 );
136 let beta_per_layer = output
137 .diagnostics
138 .as_ref()
139 .map_or_else(Vec::new, |diagnostics| diagnostics.beta_per_layer());
140
141 VariantEvaluation {
142 variant,
143 config,
144 num_params: model.num_params(),
145 logits_shape,
146 metrics,
147 beta_per_layer,
148 spectral: output.spectral,
149 }
150}
151
152pub fn compare_variants<B: Backend>(
153 base_config: &DdlConfig,
154 variants: &[ModelVariant],
155 device: &B::Device,
156 input_ids: &Tensor<B, 2, Int>,
157 mask: Option<&Tensor<B, 3>>,
158 pad_tokens: Option<&[usize]>,
159) -> Result<ComparisonReport, EvalError> {
160 compare_variants_with_diagnostics(
161 base_config,
162 variants,
163 device,
164 input_ids,
165 mask,
166 pad_tokens,
167 DiagnosticLevel::default(),
168 )
169}
170
171pub fn compare_variants_with_diagnostics<B: Backend>(
172 base_config: &DdlConfig,
173 variants: &[ModelVariant],
174 device: &B::Device,
175 input_ids: &Tensor<B, 2, Int>,
176 mask: Option<&Tensor<B, 3>>,
177 pad_tokens: Option<&[usize]>,
178 diagnostic_level: DiagnosticLevel,
179) -> Result<ComparisonReport, EvalError> {
180 if variants.is_empty() {
181 return Err(EvalError::EmptyVariantSet);
182 }
183
184 let evaluations = variants
185 .iter()
186 .copied()
187 .map(|variant| {
188 evaluate_variant_with_diagnostics(
189 base_config,
190 variant,
191 device,
192 input_ids,
193 mask,
194 pad_tokens,
195 diagnostic_level,
196 )
197 })
198 .collect::<Vec<_>>();
199 let mut loss_ranking = evaluations
200 .iter()
201 .map(|evaluation| (evaluation.variant, evaluation.metrics.loss))
202 .collect::<Vec<_>>();
203 loss_ranking.sort_by(|left, right| left.1.total_cmp(&right.1));
204
205 Ok(ComparisonReport {
206 input_shape: input_ids.dims(),
207 best_loss_variant: loss_ranking[0].0,
208 loss_ranking: loss_ranking
209 .into_iter()
210 .map(|(variant, _)| variant)
211 .collect(),
212 evaluations,
213 })
214}
215
216pub fn evaluate_variant_on_dataset<B: Backend>(
217 base_config: &DdlConfig,
218 variant: ModelVariant,
219 device: &B::Device,
220 dataset: &TokenDataset,
221) -> Result<DatasetVariantEvaluation, EvalError> {
222 evaluate_variant_on_dataset_with_diagnostics::<B>(
223 base_config,
224 variant,
225 device,
226 dataset,
227 DiagnosticLevel::default(),
228 )
229}
230
231pub fn evaluate_variant_on_dataset_with_diagnostics<B: Backend>(
232 base_config: &DdlConfig,
233 variant: ModelVariant,
234 device: &B::Device,
235 dataset: &TokenDataset,
236 diagnostic_level: DiagnosticLevel,
237) -> Result<DatasetVariantEvaluation, EvalError> {
238 if dataset.batches().is_empty() {
239 return Err(EvalError::EmptyDataset);
240 }
241
242 let (config, model) = variant.build::<B>(base_config, device);
243 let summary = dataset.summary();
244 let mut batch_summaries = Vec::with_capacity(summary.num_batches);
245 let mut total_latency_ms = 0.0;
246 let mut beta_weighted_sums = Vec::<f64>::new();
247 let mut beta_weight = 0.0;
248 let mut latest_spectral = None;
249 let mut spectral_history = (variant.uses_ddl() && diagnostic_level.wants_spectral())
250 .then(|| SpectralCollector::new(summary.num_batches.max(1)));
251
252 for batch in dataset.batches() {
253 let batch_result = evaluate_dataset_batch(&model, batch, device, diagnostic_level);
254 total_latency_ms += batch_result.latency_ms;
255 batch_summaries.push(batch_result.summary);
256
257 if let Some(diagnostics) = batch_result.diagnostics {
258 let betas = diagnostics.beta_per_layer();
259 if beta_weighted_sums.is_empty() {
260 beta_weighted_sums.resize(betas.len(), 0.0);
261 }
262
263 let weight = batch_result.summary.prediction_count.max(1) as f64;
264 for (sum, beta) in beta_weighted_sums.iter_mut().zip(betas.iter()) {
265 *sum += f64::from(*beta) * weight;
266 }
267 beta_weight += weight;
268 }
269
270 if let Some(spectral) = batch_result.spectral {
271 latest_spectral = Some(spectral.clone());
272 if let Some(history) = spectral_history.as_mut() {
273 history.record(spectral);
274 }
275 }
276 }
277
278 let metrics: CausalLmMetrics = aggregate_causal_lm_summaries(&batch_summaries).into();
279 let beta_per_layer = if beta_weight == 0.0 {
280 Vec::new()
281 } else {
282 beta_weighted_sums
283 .into_iter()
284 .map(|sum| (sum / beta_weight) as f32)
285 .collect()
286 };
287 let timing = DatasetTiming {
288 total_latency_ms,
289 avg_batch_latency_ms: if summary.num_batches == 0 {
290 0.0
291 } else {
292 total_latency_ms / summary.num_batches as f64
293 },
294 throughput_predictions_per_second: if total_latency_ms == 0.0 {
295 0.0
296 } else {
297 summary.num_predictions as f64 / (total_latency_ms / 1_000.0)
298 },
299 };
300
301 Ok(DatasetVariantEvaluation {
302 variant,
303 config,
304 num_params: model.num_params(),
305 num_batches: summary.num_batches,
306 num_sequences: summary.num_sequences,
307 num_predictions: summary.num_predictions,
308 num_padded_tokens: summary.num_padded_tokens,
309 metrics,
310 timing,
311 beta_per_layer,
312 spectral: latest_spectral,
313 spectral_history,
314 })
315}
316
317pub fn compare_variants_on_dataset<B: Backend>(
318 base_config: &DdlConfig,
319 variants: &[ModelVariant],
320 device: &B::Device,
321 dataset: &TokenDataset,
322) -> Result<DatasetComparisonReport, EvalError> {
323 compare_variants_on_dataset_with_diagnostics::<B>(
324 base_config,
325 variants,
326 device,
327 dataset,
328 DiagnosticLevel::default(),
329 )
330}
331
332pub fn compare_variants_on_dataset_with_diagnostics<B: Backend>(
333 base_config: &DdlConfig,
334 variants: &[ModelVariant],
335 device: &B::Device,
336 dataset: &TokenDataset,
337 diagnostic_level: DiagnosticLevel,
338) -> Result<DatasetComparisonReport, EvalError> {
339 if variants.is_empty() {
340 return Err(EvalError::EmptyVariantSet);
341 }
342 if dataset.batches().is_empty() {
343 return Err(EvalError::EmptyDataset);
344 }
345
346 let evaluations = variants
347 .iter()
348 .copied()
349 .map(|variant| {
350 evaluate_variant_on_dataset_with_diagnostics::<B>(
351 base_config,
352 variant,
353 device,
354 dataset,
355 diagnostic_level,
356 )
357 })
358 .collect::<Result<Vec<_>, _>>()?;
359 let mut loss_ranking = evaluations
360 .iter()
361 .map(|evaluation| (evaluation.variant, evaluation.metrics.loss))
362 .collect::<Vec<_>>();
363 loss_ranking.sort_by(|left, right| left.1.total_cmp(&right.1));
364
365 Ok(DatasetComparisonReport {
366 dataset: dataset.summary(),
367 best_loss_variant: loss_ranking[0].0,
368 loss_ranking: loss_ranking
369 .into_iter()
370 .map(|(variant, _)| variant)
371 .collect(),
372 evaluations,
373 })
374}
375
/// Per-batch output produced by `evaluate_dataset_batch`.
struct DatasetBatchResult {
    /// Loss summary for the batch, computed using the batch's sequence lengths.
    summary: CausalLmLossSummary,
    /// Time spent in the forward call, in milliseconds.
    latency_ms: f64,
    /// Model diagnostics from the forward pass, if any were produced.
    diagnostics: Option<crate::spectral::ModelDiagnostics>,
    /// Spectral diagnostics from the forward pass, if any were produced.
    spectral: Option<SpectralDiagnostics>,
}
382
383fn evaluate_dataset_batch<B: Backend>(
384 model: &crate::variant::ModelInstance<B>,
385 batch: &TokenBatch,
386 device: &B::Device,
387 diagnostic_level: DiagnosticLevel,
388) -> DatasetBatchResult {
389 let input_ids = batch.to_tensor(device);
390 let start = Instant::now();
391 let output = model.forward_with_diagnostics(input_ids.clone(), None, diagnostic_level);
392 let latency_ms = start.elapsed().as_secs_f64() * 1_000.0;
393 let summary = causal_language_model_summary_with_lengths(
394 &output.logits,
395 &input_ids,
396 batch.sequence_lengths(),
397 );
398
399 DatasetBatchResult {
400 summary,
401 latency_ms,
402 diagnostics: output.diagnostics,
403 spectral: output.spectral,
404 }
405}