//! Benchmark harness for `deep_delta_learning` (benchmark.rs): per-component
//! and end-to-end model benchmarks, plus report comparison and regression gating.
1use std::fmt::{Display, Formatter};
2use std::hint::black_box;
3use std::mem::size_of;
4use std::time::Instant;
5
6use burn::prelude::*;
7use burn::tensor::{Int, Shape, TensorData};
8use serde::{Deserialize, Serialize};
9
10use crate::compressor::{ChannelConvCompressor, TokenConvCompressor};
11use crate::config::{CompressionVariant, DdlConfig};
12use crate::delta_operator::DeltaOperator;
13use crate::delta_res_block::DeltaResBlock;
14use crate::rms_norm::RmsNormConfig;
15use crate::variant::ModelVariant;
16
/// The independently selectable benchmark suites.
///
/// Each variant corresponds to one `benchmark_*` function in this module that
/// measures a single component (or the whole model) in isolation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BenchmarkSuite {
    /// Raw delta-operator application (vector and matrix forms).
    DeltaOperator,
    /// Delta-res block forward passes (vector and matrix hidden states).
    DeltaResBlock,
    /// Token-conv and channel-conv state compressors.
    Compressor,
    /// K-normalization variants (RMS-based vs explicit L2).
    Normalization,
    /// End-to-end forward passes for each `ModelVariant`.
    Model,
}
25
26impl BenchmarkSuite {
27    pub const ALL: [Self; 5] = [
28        Self::DeltaOperator,
29        Self::DeltaResBlock,
30        Self::Compressor,
31        Self::Normalization,
32        Self::Model,
33    ];
34
35    pub fn all() -> &'static [Self] {
36        &Self::ALL
37    }
38
39    pub fn slug(&self) -> &'static str {
40        match self {
41            Self::DeltaOperator => "operator",
42            Self::DeltaResBlock => "block",
43            Self::Compressor => "compressor",
44            Self::Normalization => "normalization",
45            Self::Model => "model",
46        }
47    }
48}
49
/// Knobs controlling how each benchmark case is executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Number of measured iterations per case (must be > 0; see `validate`).
    pub iterations: usize,
    /// Untimed iterations run before measurement begins.
    pub warmup_iterations: usize,
    /// Batch dimension of the generated inputs (must be > 0).
    pub batch_size: usize,
    /// Sequence length of the generated inputs (must be > 0).
    pub seq_len: usize,
}
57
58impl BenchmarkConfig {
59    pub fn new(iterations: usize) -> Self {
60        Self {
61            iterations,
62            ..Default::default()
63        }
64    }
65
66    pub fn with_warmup_iterations(mut self, warmup_iterations: usize) -> Self {
67        self.warmup_iterations = warmup_iterations;
68        self
69    }
70
71    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
72        self.batch_size = batch_size;
73        self
74    }
75
76    pub fn with_seq_len(mut self, seq_len: usize) -> Self {
77        self.seq_len = seq_len;
78        self
79    }
80
81    pub fn validate(&self) -> Result<(), BenchmarkError> {
82        if self.iterations == 0 {
83            return Err(BenchmarkError::InvalidConfig(
84                "iterations must be greater than zero".to_string(),
85            ));
86        }
87        if self.batch_size == 0 {
88            return Err(BenchmarkError::InvalidConfig(
89                "batch_size must be greater than zero".to_string(),
90            ));
91        }
92        if self.seq_len == 0 {
93            return Err(BenchmarkError::InvalidConfig(
94                "seq_len must be greater than zero".to_string(),
95            ));
96        }
97
98        Ok(())
99    }
100}
101
102impl Default for BenchmarkConfig {
103    fn default() -> Self {
104        Self {
105            iterations: 10,
106            warmup_iterations: 2,
107            batch_size: 2,
108            seq_len: 16,
109        }
110    }
111}
112
/// Errors produced while configuring, running, or comparing benchmarks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BenchmarkError {
    /// A `BenchmarkConfig` or threshold value failed validation.
    InvalidConfig(String),
    /// Two reports could not be compared (missing suite, variant, etc.).
    ComparisonMismatch(String),
}
118
119impl Display for BenchmarkError {
120    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
121        match self {
122            Self::InvalidConfig(message) => f.write_str(message),
123            Self::ComparisonMismatch(message) => f.write_str(message),
124        }
125    }
126}
127
// Marker impl: the `Debug` derive and `Display` impl above satisfy the
// `std::error::Error` supertraits; no methods need overriding.
impl std::error::Error for BenchmarkError {}
129
/// Optional ratio limits used to gate a current report against a baseline.
///
/// Each `None` field disables that check; see `is_configured`.
/// NOTE(review): the exact ratio direction (current/baseline vs inverse) is
/// decided by `evaluate_gate`, which is defined elsewhere — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct BenchmarkRegressionThresholds {
    /// Minimum acceptable throughput ratio.
    pub min_throughput_ratio: Option<f64>,
    /// Maximum acceptable average-latency ratio.
    pub max_avg_latency_ratio: Option<f64>,
    /// Maximum acceptable p95-latency ratio.
    pub max_p95_latency_ratio: Option<f64>,
    /// Maximum acceptable peak-memory ratio.
    pub max_peak_memory_ratio: Option<f64>,
}
137
138impl BenchmarkRegressionThresholds {
139    pub fn new() -> Self {
140        Self::default()
141    }
142
143    pub fn with_min_throughput_ratio(mut self, min_throughput_ratio: f64) -> Self {
144        self.min_throughput_ratio = Some(min_throughput_ratio);
145        self
146    }
147
148    pub fn with_max_avg_latency_ratio(mut self, max_avg_latency_ratio: f64) -> Self {
149        self.max_avg_latency_ratio = Some(max_avg_latency_ratio);
150        self
151    }
152
153    pub fn with_max_p95_latency_ratio(mut self, max_p95_latency_ratio: f64) -> Self {
154        self.max_p95_latency_ratio = Some(max_p95_latency_ratio);
155        self
156    }
157
158    pub fn with_max_peak_memory_ratio(mut self, max_peak_memory_ratio: f64) -> Self {
159        self.max_peak_memory_ratio = Some(max_peak_memory_ratio);
160        self
161    }
162
163    pub fn is_configured(&self) -> bool {
164        self.min_throughput_ratio.is_some()
165            || self.max_avg_latency_ratio.is_some()
166            || self.max_p95_latency_ratio.is_some()
167            || self.max_peak_memory_ratio.is_some()
168    }
169
170    pub fn validate(&self) -> Result<(), BenchmarkError> {
171        validate_positive_threshold(self.min_throughput_ratio, "min_throughput_ratio")?;
172        validate_positive_threshold(self.max_avg_latency_ratio, "max_avg_latency_ratio")?;
173        validate_positive_threshold(self.max_p95_latency_ratio, "max_p95_latency_ratio")?;
174        validate_positive_threshold(self.max_peak_memory_ratio, "max_peak_memory_ratio")?;
175        Ok(())
176    }
177}
178
/// What a single `BenchmarkDelta` refers to: one case of a component suite,
/// or one model variant.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BenchmarkTarget {
    /// A named case (e.g. "vector" / "matrix") inside a component suite.
    Case {
        suite: BenchmarkSuite,
        label: String,
    },
    /// An end-to-end model benchmark for one variant.
    Model {
        variant: ModelVariant,
    },
}
189
190impl BenchmarkTarget {
191    pub fn label(&self) -> String {
192        match self {
193            Self::Case { suite, label } => format!("{}:{label}", suite.slug()),
194            Self::Model { variant } => format!("model:{}", variant.slug()),
195        }
196    }
197}
198
/// Per-target comparison between a current and a baseline report.
///
/// NOTE(review): the ratios are produced by `compare_case` / `compare_model`
/// (defined elsewhere); presumably current-over-baseline — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BenchmarkDelta {
    /// Which case or model variant this delta describes.
    pub target: BenchmarkTarget,
    pub input_shape: Vec<usize>,
    pub output_shape: Vec<usize>,
    pub avg_latency_ratio: f64,
    pub p95_latency_ratio: f64,
    pub throughput_ratio: f64,
    pub peak_memory_ratio: f64,
    /// Only available for model targets, where parameter counts exist.
    pub parameter_ratio: Option<f64>,
}
210
/// Result of applying regression thresholds to a set of deltas.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BenchmarkGateOutcome {
    /// True exactly when `failures` is empty (see `from_failures`).
    pub passed: bool,
    /// One message per threshold violation.
    pub failures: Vec<String>,
}
216
217impl BenchmarkGateOutcome {
218    fn from_failures(failures: Vec<String>) -> Self {
219        Self {
220            passed: failures.is_empty(),
221            failures,
222        }
223    }
224}
225
/// Full output of `compare_benchmark_reports`: the configurations and suite
/// lists of both reports, every per-target delta, and the gate verdict.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BenchmarkComparisonReport {
    pub current_benchmark: BenchmarkConfig,
    pub baseline_benchmark: BenchmarkConfig,
    pub current_suites: Vec<BenchmarkSuite>,
    pub baseline_suites: Vec<BenchmarkSuite>,
    /// The thresholds the gate was evaluated against.
    pub thresholds: BenchmarkRegressionThresholds,
    /// Deltas in current-report suite order.
    pub deltas: Vec<BenchmarkDelta>,
    pub gate: BenchmarkGateOutcome,
}
236
/// Latency statistics for one benchmark case. All durations are milliseconds;
/// warmup iterations are recorded but excluded from the statistics
/// (produced by the file-local `measure` helper — definition not shown here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkTiming {
    pub iterations: usize,
    pub warmup_iterations: usize,
    pub total_duration_ms: f64,
    pub avg_duration_ms: f64,
    pub min_duration_ms: f64,
    pub median_duration_ms: f64,
    pub p95_duration_ms: f64,
    pub max_duration_ms: f64,
    /// Measured iterations completed per second.
    pub iterations_per_second: f64,
}
249
/// Analytic (not measured) memory estimate for one benchmark case, in bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct BenchmarkMemory {
    pub input_bytes: usize,
    pub output_bytes: usize,
    pub parameter_bytes: usize,
    /// Estimated intermediate/workspace bytes for the operation.
    pub working_set_bytes: usize,
    /// Saturating sum of the four components above (see `BenchmarkMemory::new`).
    pub peak_live_bytes: usize,
}
258
259impl BenchmarkMemory {
260    fn new(
261        input_bytes: usize,
262        output_bytes: usize,
263        parameter_bytes: usize,
264        working_set_bytes: usize,
265    ) -> Self {
266        Self {
267            input_bytes,
268            output_bytes,
269            parameter_bytes,
270            working_set_bytes,
271            peak_live_bytes: input_bytes
272                .saturating_add(output_bytes)
273                .saturating_add(parameter_bytes)
274                .saturating_add(working_set_bytes),
275        }
276    }
277}
278
/// One measured benchmark case: a label, the shapes involved, timing
/// statistics, and an analytic memory estimate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkCase {
    /// Stable case identifier (e.g. "delta-operator-vector").
    pub label: String,
    pub input_shape: Vec<usize>,
    pub output_shape: Vec<usize>,
    pub timing: BenchmarkTiming,
    pub memory: BenchmarkMemory,
}
287
/// Estimated per-component memory breakdown (bytes) for one model benchmark.
///
/// Produced by `estimate_model_memory`; all values are analytic estimates
/// derived from the resolved config and benchmark shape, not allocator
/// measurements.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelMemoryBreakdown {
    pub input_ids_bytes: usize,
    pub parameter_bytes: usize,
    pub embedding_bytes: usize,
    /// Hidden-state bytes; rank depends on whether the variant uses a matrix state.
    pub state_bytes: usize,
    pub residual_stream_bytes: usize,
    pub mask_bytes: usize,
    pub attention_scores_bytes: usize,
    pub compressor_workspace_bytes: usize,
    pub delta_branch_bytes: usize,
    pub mlp_activation_bytes: usize,
    pub logits_bytes: usize,
    /// Peak of activation-only memory during the forward pass.
    pub peak_activation_bytes: usize,
    /// Peak of activations plus parameters.
    pub peak_total_bytes: usize,
}
304
/// Results of the delta-operator suite: one case per state form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaOperatorBenchmarks {
    /// `apply_vector` over a rank-2 input.
    pub vector: BenchmarkCase,
    /// `apply` over a rank-3 (matrix-state) input.
    pub matrix: BenchmarkCase,
}
310
/// Results of the delta-res block suite: one case per hidden-state form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaResBlockBenchmarks {
    /// `forward_vector_hidden` with `d_value = 1`.
    pub vector: BenchmarkCase,
    /// `forward_matrix_hidden` with the matrix `d_value`.
    pub matrix: BenchmarkCase,
}
316
/// Results of the compressor suite, one case per compressor implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressorBenchmarks {
    /// `TokenConvCompressor` forward pass.
    pub token_conv: BenchmarkCase,
    /// `ChannelConvCompressor` forward pass.
    pub channel_conv: BenchmarkCase,
}
322
/// Results of the normalization suite, one case per normalization strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationBenchmarks {
    /// RMS-norm with scaled epsilon followed by a 1/sqrt(d_model) rescale.
    pub precision_friendly: BenchmarkCase,
    /// Explicit L2 normalization (`explicit_l2_normalize`).
    pub explicit_l2: BenchmarkCase,
}
328
/// End-to-end benchmark result for one model variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBenchmark {
    pub variant: ModelVariant,
    /// The config after the variant's adjustments were applied.
    pub resolved_config: DdlConfig,
    pub num_params: usize,
    /// `[batch_size, seq_len]` of the token-id input.
    pub input_shape: [usize; 2],
    /// `[batch_size, seq_len, vocab]`-shaped logits output.
    pub output_shape: [usize; 3],
    pub timing: BenchmarkTiming,
    /// Tokens processed per second over the measured iterations
    /// (0.0 when total duration is zero; see `benchmark_models`).
    pub tokens_per_second: f64,
    pub estimated_hidden_state_bytes: usize,
    pub estimated_logit_bytes: usize,
    /// `num_params` × size of f32.
    pub estimated_parameter_bytes: usize,
    pub memory: ModelMemoryBreakdown,
}
343
/// One model variant's metrics relative to a baseline variant.
///
/// NOTE(review): which variant serves as the baseline is decided by
/// `summarize_model_benchmarks` (not visible here); the `None` ratios
/// presumably mark the baseline row itself — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelBenchmarkComparison {
    pub variant: ModelVariant,
    pub tokens_per_second: f64,
    pub throughput_ratio_vs_baseline: Option<f64>,
    pub parameter_ratio_vs_baseline: Option<f64>,
    pub peak_activation_ratio_vs_baseline: Option<f64>,
    pub peak_total_ratio_vs_baseline: Option<f64>,
}
353
/// Cross-variant summary computed from the model suite results.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelBenchmarkSummary {
    /// Variant with the highest throughput.
    pub fastest_variant: ModelVariant,
    /// Variant with the smallest estimated peak total memory.
    pub lowest_peak_total_variant: ModelVariant,
    /// One comparison row per benchmarked variant.
    pub comparisons: Vec<ModelBenchmarkComparison>,
}
360
/// Everything produced by one `run_benchmarks` call.
///
/// Per-suite fields are `None` (or empty, for `models`) when the suite was
/// not selected; `model_summary` is `None` exactly when `models` is empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkReport {
    pub benchmark: BenchmarkConfig,
    pub base_config: DdlConfig,
    /// The normalized suite list that was actually run.
    pub suites: Vec<BenchmarkSuite>,
    pub delta_operator: Option<DeltaOperatorBenchmarks>,
    pub delta_res_block: Option<DeltaResBlockBenchmarks>,
    pub compressors: Option<CompressorBenchmarks>,
    pub normalization: Option<NormalizationBenchmarks>,
    pub models: Vec<ModelBenchmark>,
    pub model_summary: Option<ModelBenchmarkSummary>,
}
373
374pub fn run_benchmarks<B: Backend>(
375    base_config: &DdlConfig,
376    benchmark: &BenchmarkConfig,
377    suites: &[BenchmarkSuite],
378    variants: &[ModelVariant],
379    device: &B::Device,
380) -> Result<BenchmarkReport, BenchmarkError> {
381    benchmark.validate()?;
382    if benchmark.seq_len > base_config.max_seq_len {
383        return Err(BenchmarkError::InvalidConfig(format!(
384            "benchmark seq_len {} exceeds model max_seq_len {}",
385            benchmark.seq_len, base_config.max_seq_len
386        )));
387    }
388
389    let suites = normalize_suites(suites);
390    let variants = normalize_variants(variants);
391    let delta_operator = suites
392        .contains(&BenchmarkSuite::DeltaOperator)
393        .then(|| benchmark_delta_operator::<B>(base_config, benchmark, device));
394    let delta_res_block = suites
395        .contains(&BenchmarkSuite::DeltaResBlock)
396        .then(|| benchmark_delta_res_block::<B>(base_config, benchmark, device));
397    let compressors = suites
398        .contains(&BenchmarkSuite::Compressor)
399        .then(|| benchmark_compressors::<B>(base_config, benchmark, device));
400    let normalization = suites
401        .contains(&BenchmarkSuite::Normalization)
402        .then(|| benchmark_normalization::<B>(base_config, benchmark, device));
403    let models = if suites.contains(&BenchmarkSuite::Model) {
404        benchmark_models::<B>(base_config, benchmark, &variants, device)
405    } else {
406        Vec::new()
407    };
408    let model_summary = (!models.is_empty()).then(|| summarize_model_benchmarks(&models));
409
410    Ok(BenchmarkReport {
411        benchmark: *benchmark,
412        base_config: base_config.clone(),
413        suites,
414        delta_operator,
415        delta_res_block,
416        compressors,
417        normalization,
418        models,
419        model_summary,
420    })
421}
422
423pub fn compare_benchmark_reports(
424    current: &BenchmarkReport,
425    baseline: &BenchmarkReport,
426    thresholds: &BenchmarkRegressionThresholds,
427) -> Result<BenchmarkComparisonReport, BenchmarkError> {
428    thresholds.validate()?;
429
430    let mut deltas = Vec::new();
431    for suite in &current.suites {
432        match suite {
433            BenchmarkSuite::DeltaOperator => {
434                let current_suite = current.delta_operator.as_ref().ok_or_else(|| {
435                    BenchmarkError::ComparisonMismatch(
436                        "current report is missing the delta operator suite".to_string(),
437                    )
438                })?;
439                let baseline_suite = baseline.delta_operator.as_ref().ok_or_else(|| {
440                    BenchmarkError::ComparisonMismatch(
441                        "baseline report is missing the delta operator suite".to_string(),
442                    )
443                })?;
444                deltas.push(compare_case(
445                    *suite,
446                    &current_suite.vector,
447                    &baseline_suite.vector,
448                )?);
449                deltas.push(compare_case(
450                    *suite,
451                    &current_suite.matrix,
452                    &baseline_suite.matrix,
453                )?);
454            }
455            BenchmarkSuite::DeltaResBlock => {
456                let current_suite = current.delta_res_block.as_ref().ok_or_else(|| {
457                    BenchmarkError::ComparisonMismatch(
458                        "current report is missing the delta-res block suite".to_string(),
459                    )
460                })?;
461                let baseline_suite = baseline.delta_res_block.as_ref().ok_or_else(|| {
462                    BenchmarkError::ComparisonMismatch(
463                        "baseline report is missing the delta-res block suite".to_string(),
464                    )
465                })?;
466                deltas.push(compare_case(
467                    *suite,
468                    &current_suite.vector,
469                    &baseline_suite.vector,
470                )?);
471                deltas.push(compare_case(
472                    *suite,
473                    &current_suite.matrix,
474                    &baseline_suite.matrix,
475                )?);
476            }
477            BenchmarkSuite::Compressor => {
478                let current_suite = current.compressors.as_ref().ok_or_else(|| {
479                    BenchmarkError::ComparisonMismatch(
480                        "current report is missing the compressor suite".to_string(),
481                    )
482                })?;
483                let baseline_suite = baseline.compressors.as_ref().ok_or_else(|| {
484                    BenchmarkError::ComparisonMismatch(
485                        "baseline report is missing the compressor suite".to_string(),
486                    )
487                })?;
488                deltas.push(compare_case(
489                    *suite,
490                    &current_suite.token_conv,
491                    &baseline_suite.token_conv,
492                )?);
493                deltas.push(compare_case(
494                    *suite,
495                    &current_suite.channel_conv,
496                    &baseline_suite.channel_conv,
497                )?);
498            }
499            BenchmarkSuite::Normalization => {
500                let current_suite = current.normalization.as_ref().ok_or_else(|| {
501                    BenchmarkError::ComparisonMismatch(
502                        "current report is missing the normalization suite".to_string(),
503                    )
504                })?;
505                let baseline_suite = baseline.normalization.as_ref().ok_or_else(|| {
506                    BenchmarkError::ComparisonMismatch(
507                        "baseline report is missing the normalization suite".to_string(),
508                    )
509                })?;
510                deltas.push(compare_case(
511                    *suite,
512                    &current_suite.precision_friendly,
513                    &baseline_suite.precision_friendly,
514                )?);
515                deltas.push(compare_case(
516                    *suite,
517                    &current_suite.explicit_l2,
518                    &baseline_suite.explicit_l2,
519                )?);
520            }
521            BenchmarkSuite::Model => {
522                if current.models.is_empty() {
523                    return Err(BenchmarkError::ComparisonMismatch(
524                        "current report is missing the model suite".to_string(),
525                    ));
526                }
527                if baseline.models.is_empty() {
528                    return Err(BenchmarkError::ComparisonMismatch(
529                        "baseline report is missing the model suite".to_string(),
530                    ));
531                }
532                for current_model in &current.models {
533                    let baseline_model = baseline
534                        .models
535                        .iter()
536                        .find(|model| model.variant == current_model.variant)
537                        .ok_or_else(|| {
538                            BenchmarkError::ComparisonMismatch(format!(
539                                "baseline report is missing model variant {}",
540                                current_model.variant.slug()
541                            ))
542                        })?;
543                    deltas.push(compare_model(current_model, baseline_model)?);
544                }
545            }
546        }
547    }
548
549    let failures = evaluate_gate(&deltas, thresholds);
550    Ok(BenchmarkComparisonReport {
551        current_benchmark: current.benchmark,
552        baseline_benchmark: baseline.benchmark,
553        current_suites: current.suites.clone(),
554        baseline_suites: baseline.suites.clone(),
555        thresholds: thresholds.clone(),
556        deltas,
557        gate: BenchmarkGateOutcome::from_failures(failures),
558    })
559}
560
/// Benchmark the raw delta operator in both its vector and matrix forms.
///
/// All inputs are synthetic constant-filled tensors (`direction_tensor`,
/// `beta_tensor`, `float_tensor` are file-local generators — definitions not
/// shown here). Memory figures are analytic estimates via `tensor_bytes`.
fn benchmark_delta_operator<B: Backend>(
    base_config: &DdlConfig,
    benchmark: &BenchmarkConfig,
    device: &B::Device,
) -> DeltaOperatorBenchmarks {
    // Both forms flatten batch and sequence into a single token axis.
    let batch_tokens = benchmark.batch_size * benchmark.seq_len;
    let d_model = base_config.d_model;
    let d_value = matrix_d_value(base_config);
    let k = direction_tensor::<B>(batch_tokens, d_model, device);
    let beta = beta_tensor::<B>(batch_tokens, 0.75, device);
    let x_vector = float_tensor::<B, 2>([batch_tokens, d_model], 0.125, device);
    let x_matrix = float_tensor::<B, 3>([batch_tokens, d_model, d_value], 0.375, device);

    // One untimed call first, to record the output shape.
    let vector_operator = DeltaOperator::new(k.clone(), beta.clone());
    let vector_output = vector_operator.apply_vector(&x_vector);
    let vector = BenchmarkCase {
        label: "delta-operator-vector".to_string(),
        input_shape: vec![batch_tokens, d_model],
        output_shape: vector_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = vector_operator.apply_vector(&x_vector);
            // black_box keeps the optimizer from eliding the timed work.
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            tensor_bytes(&[batch_tokens, d_model]),
            tensor_bytes(&[batch_tokens, d_model]),
            0, // the operator holds no learned parameters
            // Working set: k, beta, the k·x reduction, and the output-sized update.
            tensor_bytes(&[batch_tokens, d_model])
                + tensor_bytes(&[batch_tokens])
                + tensor_bytes(&[batch_tokens])
                + tensor_bytes(&[batch_tokens, d_model]),
        ),
    };

    // Reuse k/beta (moved here) for the matrix-state form.
    let matrix_operator = DeltaOperator::new(k, beta);
    let matrix_output = matrix_operator.apply(&x_matrix);
    let matrix = BenchmarkCase {
        label: "delta-operator-matrix".to_string(),
        input_shape: vec![batch_tokens, d_model, d_value],
        output_shape: matrix_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = matrix_operator.apply(&x_matrix);
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            tensor_bytes(&[batch_tokens, d_model, d_value]),
            tensor_bytes(&[batch_tokens, d_model, d_value]),
            0,
            tensor_bytes(&[batch_tokens, d_model])
                + tensor_bytes(&[batch_tokens])
                + tensor_bytes(&[batch_tokens, d_value])
                + tensor_bytes(&[batch_tokens, d_model, d_value]),
        ),
    };

    DeltaOperatorBenchmarks { vector, matrix }
}
618
/// Benchmark a full delta-res block with a vector hidden state (`d_value = 1`)
/// and with a matrix hidden state (`d_value = matrix_d_value(config)`).
///
/// Inputs are constant-filled tensors; the per-iteration `.clone()` calls are
/// part of the measured work because both forward methods take tensors by value.
fn benchmark_delta_res_block<B: Backend>(
    base_config: &DdlConfig,
    benchmark: &BenchmarkConfig,
    device: &B::Device,
) -> DeltaResBlockBenchmarks {
    let vector_config = base_config.clone().with_d_value(1);
    let matrix_config = base_config
        .clone()
        .with_d_value(matrix_d_value(base_config));
    let [batch_size, seq_len, d_model] =
        [benchmark.batch_size, benchmark.seq_len, base_config.d_model];

    // --- Vector-hidden case ---
    let vector_block = DeltaResBlock::<B>::new(&vector_config, device);
    let x_vector = float_tensor::<B, 3>([batch_size, seq_len, d_model], 0.125, device);
    let vector_backbone = float_tensor::<B, 3>([batch_size, seq_len, d_model], 0.25, device);
    let vector_ctx = float_tensor::<B, 3>([batch_size, seq_len, d_model], 0.5, device);
    let vector_residual = float_tensor::<B, 3>([batch_size, seq_len, d_model], 0.75, device);
    // Untimed call to capture the output shape.
    let vector_output = vector_block.forward_vector_hidden(
        x_vector.clone(),
        vector_backbone.clone(),
        vector_ctx.clone(),
        vector_residual.clone(),
    );
    let vector = BenchmarkCase {
        label: "delta-res-block-vector".to_string(),
        input_shape: vec![batch_size, seq_len, d_model],
        output_shape: vector_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = vector_block.forward_vector_hidden(
                x_vector.clone(),
                vector_backbone.clone(),
                vector_ctx.clone(),
                vector_residual.clone(),
            );
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            // Four [B, S, D] inputs: x, backbone, ctx, residual.
            tensor_bytes(&[batch_size, seq_len, d_model]) * 4,
            tensor_bytes(&[batch_size, seq_len, d_model]),
            f32_bytes(vector_block.num_params()),
            tensor_bytes(&[batch_size, seq_len, d_model])
                + tensor_bytes(&[batch_size, seq_len]) * 4
                + tensor_bytes(&[batch_size, seq_len, 1]),
        ),
    };

    // --- Matrix-hidden case ---
    let matrix_block = DeltaResBlock::<B>::new(&matrix_config, device);
    let d_value = matrix_config.d_value;
    let x_state = float_tensor::<B, 4>([batch_size, seq_len, d_model, d_value], 0.125, device);
    let matrix_backbone = float_tensor::<B, 3>([batch_size, seq_len, d_model], 0.25, device);
    let matrix_ctx = float_tensor::<B, 3>([batch_size, seq_len, d_model], 0.5, device);
    let matrix_residual = float_tensor::<B, 3>([batch_size, seq_len, d_model], 0.75, device);
    let matrix_output = matrix_block.forward_matrix_hidden(
        x_state.clone(),
        matrix_backbone.clone(),
        matrix_ctx.clone(),
        matrix_residual.clone(),
    );
    let matrix = BenchmarkCase {
        label: "delta-res-block-matrix".to_string(),
        input_shape: vec![batch_size, seq_len, d_model, d_value],
        output_shape: matrix_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = matrix_block.forward_matrix_hidden(
                x_state.clone(),
                matrix_backbone.clone(),
                matrix_ctx.clone(),
                matrix_residual.clone(),
            );
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            // One [B, S, D, V] state plus three [B, S, D] conditioning inputs.
            tensor_bytes(&[batch_size, seq_len, d_model, d_value])
                + tensor_bytes(&[batch_size, seq_len, d_model]) * 3,
            tensor_bytes(&[batch_size, seq_len, d_model, d_value]),
            f32_bytes(matrix_block.num_params()),
            tensor_bytes(&[batch_size, seq_len, d_model])
                + tensor_bytes(&[batch_size, seq_len, d_value]) * 3
                + tensor_bytes(&[batch_size, seq_len])
                + tensor_bytes(&[batch_size, seq_len, d_model, d_value]),
        ),
    };

    DeltaResBlockBenchmarks { vector, matrix }
}
704
/// Benchmark both state compressors on the same `[B, S, D, V]` matrix state.
///
/// The memory figures assume each compressor emits a `[B, S, D]` tensor;
/// the recorded `output_shape` fields capture the actual shapes.
fn benchmark_compressors<B: Backend>(
    base_config: &DdlConfig,
    benchmark: &BenchmarkConfig,
    device: &B::Device,
) -> CompressorBenchmarks {
    let d_model = base_config.d_model;
    let d_value = matrix_d_value(base_config);
    let batch_size = benchmark.batch_size;
    let seq_len = benchmark.seq_len;
    // Shared constant-filled input state for both compressors.
    let state = float_tensor::<B, 4>([batch_size, seq_len, d_model, d_value], 0.5, device);

    let token =
        TokenConvCompressor::<B>::new(d_model, d_value, base_config.shortconv_kernel_size, device);
    let token_output = token.forward(state.clone());
    let token_conv = BenchmarkCase {
        label: "token-conv-compressor".to_string(),
        input_shape: vec![batch_size, seq_len, d_model, d_value],
        output_shape: token_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = token.forward(state.clone());
            // black_box keeps the optimizer from eliding the timed work.
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            tensor_bytes(&[batch_size, seq_len, d_model, d_value]),
            tensor_bytes(&[batch_size, seq_len, d_model]),
            f32_bytes(token.num_params()),
            tensor_bytes(&[batch_size, seq_len, d_model, d_value]),
        ),
    };

    let channel = ChannelConvCompressor::<B>::new(d_model, d_value, device);
    let channel_output = channel.forward(state.clone());
    let channel_conv = BenchmarkCase {
        label: "channel-conv-compressor".to_string(),
        input_shape: vec![batch_size, seq_len, d_model, d_value],
        output_shape: channel_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = channel.forward(state.clone());
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            tensor_bytes(&[batch_size, seq_len, d_model, d_value]),
            tensor_bytes(&[batch_size, seq_len, d_model]),
            f32_bytes(channel.num_params()),
            // Workspace sized as the flattened [B*S, D, 1] view
            // (presumably the conv's internal layout — confirm in the compressor).
            tensor_bytes(&[batch_size * seq_len, d_model, 1]),
        ),
    };

    CompressorBenchmarks {
        token_conv,
        channel_conv,
    }
}
758
/// Benchmark the two k-normalization strategies over the same input:
/// an RMS-norm (with rescaled epsilon, followed by a 1/sqrt(d_model)
/// post-scale) versus an explicit L2 normalization.
fn benchmark_normalization<B: Backend>(
    base_config: &DdlConfig,
    benchmark: &BenchmarkConfig,
    device: &B::Device,
) -> NormalizationBenchmarks {
    let d_model = base_config.d_model;
    let input = float_tensor::<B, 3>(
        [benchmark.batch_size, benchmark.seq_len, d_model],
        0.875,
        device,
    );
    // Epsilon chosen as k_eps² / d_model so that (RMS-norm ∘ 1/sqrt(d)) matches
    // the L2 form's epsilon scale — presumably; confirm against the derivation.
    let epsilon = base_config.k_eps * base_config.k_eps / d_model as f64;
    let scale = 1.0 / (d_model as f32).sqrt();
    let precision_friendly_norm = RmsNormConfig::new(d_model)
        .with_epsilon(epsilon)
        .init(device);
    // Untimed call to capture the output shape.
    let precision_output = precision_friendly_norm
        .forward(input.clone())
        .mul_scalar(scale);
    let precision_friendly = BenchmarkCase {
        label: "k-normalization-rms".to_string(),
        input_shape: vec![benchmark.batch_size, benchmark.seq_len, d_model],
        output_shape: precision_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = precision_friendly_norm
                .forward(input.clone())
                .mul_scalar(scale);
            // black_box keeps the optimizer from eliding the timed work.
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            tensor_bytes(&[benchmark.batch_size, benchmark.seq_len, d_model]),
            tensor_bytes(&[benchmark.batch_size, benchmark.seq_len, d_model]),
            f32_bytes(precision_friendly_norm.num_params()),
            tensor_bytes(&[benchmark.batch_size, benchmark.seq_len, d_model]),
        ),
    };

    let explicit_output = explicit_l2_normalize(input.clone());
    let explicit_l2 = BenchmarkCase {
        label: "k-normalization-l2".to_string(),
        input_shape: vec![benchmark.batch_size, benchmark.seq_len, d_model],
        output_shape: explicit_output.dims().to_vec(),
        timing: measure(benchmark, || {
            let output = explicit_l2_normalize(input.clone());
            black_box(output);
        }),
        memory: BenchmarkMemory::new(
            tensor_bytes(&[benchmark.batch_size, benchmark.seq_len, d_model]),
            tensor_bytes(&[benchmark.batch_size, benchmark.seq_len, d_model]),
            0, // explicit L2 normalization has no learned parameters
            // Workspace: the per-token norm values.
            tensor_bytes(&[benchmark.batch_size, benchmark.seq_len]),
        ),
    };

    NormalizationBenchmarks {
        precision_friendly,
        explicit_l2,
    }
}
818
/// Benchmark an end-to-end `forward_logits` pass for each requested variant,
/// all sharing the same random-free token-id input (`int_tensor` is a
/// file-local generator — definition not shown here).
fn benchmark_models<B: Backend>(
    base_config: &DdlConfig,
    benchmark: &BenchmarkConfig,
    variants: &[ModelVariant],
    device: &B::Device,
) -> Vec<ModelBenchmark> {
    let input_shape = [benchmark.batch_size, benchmark.seq_len];
    let input_ids = int_tensor::<B>(input_shape, base_config.vocab_size, device);

    variants
        .iter()
        .copied()
        .map(|variant| {
            // Each variant may adjust the base config; keep the resolved copy.
            let (resolved_config, model) = variant.build::<B>(base_config, device);
            // Untimed call to capture the logits shape.
            let output = model.forward_logits(input_ids.clone(), None);
            let output_shape = output.dims();
            let timing = measure(benchmark, || {
                let logits = model.forward_logits(input_ids.clone(), None);
                // black_box keeps the optimizer from eliding the timed work.
                black_box(logits);
            });
            // Tokens across the *measured* iterations only (warmup excluded).
            let total_tokens = benchmark.batch_size * benchmark.seq_len * benchmark.iterations;
            let estimated_hidden_state_bytes = f32_bytes(
                benchmark.batch_size
                    * benchmark.seq_len
                    * resolved_config.d_model
                    * resolved_config.d_value,
            );
            let estimated_logit_bytes =
                f32_bytes(benchmark.batch_size * benchmark.seq_len * resolved_config.vocab_size);
            let memory =
                estimate_model_memory(variant, &resolved_config, model.num_params(), benchmark);

            ModelBenchmark {
                variant,
                resolved_config,
                num_params: model.num_params(),
                input_shape,
                output_shape,
                // Guard the division: a zero total duration yields 0 tok/s.
                tokens_per_second: if timing.total_duration_ms == 0.0 {
                    0.0
                } else {
                    total_tokens as f64 / (timing.total_duration_ms / 1_000.0)
                },
                estimated_hidden_state_bytes,
                estimated_logit_bytes,
                estimated_parameter_bytes: f32_bytes(model.num_params()),
                memory,
                timing,
            }
        })
        .collect()
}
871
872fn estimate_model_memory(
873    variant: ModelVariant,
874    config: &DdlConfig,
875    num_params: usize,
876    benchmark: &BenchmarkConfig,
877) -> ModelMemoryBreakdown {
878    let batch_size = benchmark.batch_size;
879    let seq_len = benchmark.seq_len;
880    let d_model = config.d_model;
881    let d_value = config.d_value;
882    let hidden_bytes = tensor_bytes(&[batch_size, seq_len, d_model]);
883    let state_bytes = if config.uses_matrix_state() {
884        tensor_bytes(&[batch_size, seq_len, d_model, d_value])
885    } else {
886        hidden_bytes
887    };
888    let input_ids_bytes = int_bytes(&[batch_size, seq_len]);
889    let parameter_bytes = f32_bytes(num_params);
890    let mask_bytes = tensor_bytes(&[batch_size, seq_len, seq_len]);
891    let attention_scores_bytes = tensor_bytes(&[batch_size, config.num_heads, seq_len, seq_len]);
892    let compressor_workspace_bytes = compressor_workspace_bytes(config, batch_size, seq_len);
893    let delta_branch_bytes = if variant.uses_ddl() {
894        ddl_branch_bytes(config, batch_size, seq_len)
895    } else {
896        0
897    };
898    let mlp_activation_bytes =
899        hidden_bytes + tensor_bytes(&[batch_size, seq_len, config.effective_mlp_hidden()]) * 2;
900    let logits_bytes = tensor_bytes(&[batch_size, seq_len, config.vocab_size]);
901    let qkv_bytes = hidden_bytes * 3;
902    let attn_stage_peak = state_bytes
903        + compressor_workspace_bytes
904        + hidden_bytes
905        + qkv_bytes
906        + attention_scores_bytes
907        + delta_branch_bytes;
908    let mlp_stage_peak = state_bytes
909        + compressor_workspace_bytes
910        + hidden_bytes
911        + mlp_activation_bytes
912        + delta_branch_bytes;
913    let embedding_stage_peak = input_ids_bytes
914        + hidden_bytes
915        + state_bytes
916        + if config.uses_matrix_state() && config.embed_conv {
917            state_bytes
918        } else {
919            0
920        };
921    let final_stage_peak = state_bytes + compressor_workspace_bytes + hidden_bytes + logits_bytes;
922    let peak_activation_bytes = embedding_stage_peak
923        .max(attn_stage_peak.max(mlp_stage_peak))
924        .max(final_stage_peak);
925
926    ModelMemoryBreakdown {
927        input_ids_bytes,
928        parameter_bytes,
929        embedding_bytes: hidden_bytes,
930        state_bytes,
931        residual_stream_bytes: hidden_bytes,
932        mask_bytes,
933        attention_scores_bytes,
934        compressor_workspace_bytes,
935        delta_branch_bytes,
936        mlp_activation_bytes,
937        logits_bytes,
938        peak_activation_bytes,
939        peak_total_bytes: parameter_bytes
940            .saturating_add(mask_bytes)
941            .saturating_add(peak_activation_bytes),
942    }
943}
944
945fn summarize_model_benchmarks(models: &[ModelBenchmark]) -> ModelBenchmarkSummary {
946    let baseline = models
947        .iter()
948        .find(|model| model.variant == ModelVariant::Baseline);
949    let fastest_variant = models
950        .iter()
951        .max_by(|left, right| left.tokens_per_second.total_cmp(&right.tokens_per_second))
952        .map(|model| model.variant)
953        .unwrap_or(ModelVariant::Baseline);
954    let lowest_peak_total_variant = models
955        .iter()
956        .min_by_key(|model| model.memory.peak_total_bytes)
957        .map(|model| model.variant)
958        .unwrap_or(ModelVariant::Baseline);
959    let comparisons = models
960        .iter()
961        .map(|model| ModelBenchmarkComparison {
962            variant: model.variant,
963            tokens_per_second: model.tokens_per_second,
964            throughput_ratio_vs_baseline: baseline.and_then(|baseline| {
965                ratio_f64(model.tokens_per_second, baseline.tokens_per_second)
966            }),
967            parameter_ratio_vs_baseline: baseline
968                .map(|baseline| ratio_usize(model.num_params, baseline.num_params)),
969            peak_activation_ratio_vs_baseline: baseline.map(|baseline| {
970                ratio_usize(
971                    model.memory.peak_activation_bytes,
972                    baseline.memory.peak_activation_bytes,
973                )
974            }),
975            peak_total_ratio_vs_baseline: baseline.map(|baseline| {
976                ratio_usize(
977                    model.memory.peak_total_bytes,
978                    baseline.memory.peak_total_bytes,
979                )
980            }),
981        })
982        .collect();
983
984    ModelBenchmarkSummary {
985        fastest_variant,
986        lowest_peak_total_variant,
987        comparisons,
988    }
989}
990
991fn explicit_l2_normalize<B: Backend>(input: Tensor<B, 3>) -> Tensor<B, 3> {
992    let norms = (input.clone() * input.clone()).sum_dim(2).sqrt();
993    input / norms
994}
995
996fn measure<F>(benchmark: &BenchmarkConfig, mut bench_fn: F) -> BenchmarkTiming
997where
998    F: FnMut(),
999{
1000    for _ in 0..benchmark.warmup_iterations {
1001        bench_fn();
1002    }
1003
1004    let mut samples_ms = Vec::with_capacity(benchmark.iterations);
1005    let total_start = Instant::now();
1006    for _ in 0..benchmark.iterations {
1007        let sample_start = Instant::now();
1008        bench_fn();
1009        samples_ms.push(sample_start.elapsed().as_secs_f64() * 1_000.0);
1010    }
1011    let total_duration_ms = total_start.elapsed().as_secs_f64() * 1_000.0;
1012    let avg_duration_ms = samples_ms.iter().copied().sum::<f64>() / samples_ms.len() as f64;
1013    let min_duration_ms = samples_ms.iter().copied().fold(f64::INFINITY, f64::min);
1014    let max_duration_ms = samples_ms.iter().copied().fold(0.0, f64::max);
1015
1016    BenchmarkTiming {
1017        iterations: benchmark.iterations,
1018        warmup_iterations: benchmark.warmup_iterations,
1019        total_duration_ms,
1020        avg_duration_ms,
1021        min_duration_ms,
1022        median_duration_ms: percentile(&samples_ms, 0.5),
1023        p95_duration_ms: percentile(&samples_ms, 0.95),
1024        max_duration_ms,
1025        iterations_per_second: if total_duration_ms == 0.0 {
1026            0.0
1027        } else {
1028            benchmark.iterations as f64 / (total_duration_ms / 1_000.0)
1029        },
1030    }
1031}
1032
/// Nearest-rank percentile of `samples` (`percentile` expected in
/// `[0.0, 1.0]`): sorts a copy, then picks the element at the rounded
/// fractional index of the last position.
///
/// Returns 0.0 for an empty slice — mirroring the zero-guards used
/// elsewhere in this module — instead of panicking on the `len() - 1`
/// usize underflow. The index is also clamped so an out-of-range
/// `percentile` can never index past the end.
fn percentile(samples: &[f64], percentile: f64) -> f64 {
    if samples.is_empty() {
        return 0.0;
    }
    let mut ordered = samples.to_vec();
    ordered.sort_by(|left, right| left.total_cmp(right));
    let last = ordered.len() - 1;
    let index = ((last as f64 * percentile).round() as usize).min(last);
    ordered[index]
}
1039
1040fn normalize_suites(suites: &[BenchmarkSuite]) -> Vec<BenchmarkSuite> {
1041    let suites = if suites.is_empty() {
1042        BenchmarkSuite::all().to_vec()
1043    } else {
1044        suites.to_vec()
1045    };
1046
1047    dedup_preserve_order(suites)
1048}
1049
1050fn normalize_variants(variants: &[ModelVariant]) -> Vec<ModelVariant> {
1051    let variants = if variants.is_empty() {
1052        ModelVariant::all().to_vec()
1053    } else {
1054        variants.to_vec()
1055    };
1056
1057    dedup_preserve_order(variants)
1058}
1059
/// Removes duplicates while preserving the order of first appearance.
///
/// Uses a linear `contains` scan per element (O(n²)), which is fine for
/// the tiny suite/variant lists this module deduplicates; requiring only
/// `PartialEq` keeps the bound minimal for the callers' types.
fn dedup_preserve_order<T>(items: Vec<T>) -> Vec<T>
where
    T: Copy + PartialEq,
{
    let capacity = items.len();
    items
        .into_iter()
        .fold(Vec::with_capacity(capacity), |mut unique, item| {
            if !unique.contains(&item) {
                unique.push(item);
            }
            unique
        })
}
1072
1073fn matrix_d_value(base_config: &DdlConfig) -> usize {
1074    base_config.d_value.max(4)
1075}
1076
1077fn direction_tensor<B: Backend>(rows: usize, d_model: usize, device: &B::Device) -> Tensor<B, 2> {
1078    let row = normalized_row(d_model);
1079    let mut data = Vec::with_capacity(rows * d_model);
1080    for _ in 0..rows {
1081        data.extend(row.iter().copied());
1082    }
1083    Tensor::<B, 2>::from_data(TensorData::new(data, Shape::new([rows, d_model])), device)
1084}
1085
1086fn beta_tensor<B: Backend>(rows: usize, value: f32, device: &B::Device) -> Tensor<B, 1> {
1087    let data = vec![value; rows];
1088    Tensor::<B, 1>::from_data(TensorData::new(data, Shape::new([rows])), device)
1089}
1090
1091fn float_tensor<B: Backend, const D: usize>(
1092    shape: [usize; D],
1093    offset: f32,
1094    device: &B::Device,
1095) -> Tensor<B, D> {
1096    let numel = shape.iter().product::<usize>();
1097    let data = (0..numel)
1098        .map(|idx| ((idx % 97) as f32 + 1.0) / 97.0 + offset)
1099        .collect::<Vec<_>>();
1100    Tensor::<B, D>::from_data(TensorData::new(data, Shape::new(shape)), device)
1101}
1102
1103fn int_tensor<B: Backend>(
1104    shape: [usize; 2],
1105    vocab_size: usize,
1106    device: &B::Device,
1107) -> Tensor<B, 2, Int> {
1108    let numel = shape.iter().product::<usize>();
1109    let data = (0..numel)
1110        .map(|idx| ((idx * 7 + 3) % vocab_size.max(2)) as i64)
1111        .collect::<Vec<_>>();
1112    Tensor::<B, 2, Int>::from_data(TensorData::new(data, Shape::new(shape)), device)
1113}
1114
/// Returns the vector `[1, 2, ..., d_model]` scaled to unit L2 norm.
/// An empty vector comes back unchanged for `d_model == 0`.
fn normalized_row(d_model: usize) -> Vec<f32> {
    let values: Vec<f32> = (1..=d_model).map(|idx| idx as f32).collect();
    let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
    values.into_iter().map(|v| v / norm).collect()
}
1123
/// Size in bytes of `elements` f32 values.
fn f32_bytes(elements: usize) -> usize {
    size_of::<f32>() * elements
}
1127
/// Size in bytes of an `i64` tensor with the given shape.
fn int_bytes(shape: &[usize]) -> usize {
    let numel: usize = shape.iter().product();
    numel * size_of::<i64>()
}
1131
1132fn tensor_bytes(shape: &[usize]) -> usize {
1133    f32_bytes(shape.iter().product())
1134}
1135
1136fn compressor_workspace_bytes(config: &DdlConfig, batch_size: usize, seq_len: usize) -> usize {
1137    if !config.uses_matrix_state() {
1138        return 0;
1139    }
1140
1141    match config.compression {
1142        CompressionVariant::TokenConv => {
1143            tensor_bytes(&[batch_size, seq_len, config.d_model, config.d_value])
1144        }
1145        CompressionVariant::ChannelConv => tensor_bytes(&[batch_size * seq_len, config.d_model, 1]),
1146    }
1147}
1148
1149fn ddl_branch_bytes(config: &DdlConfig, batch_size: usize, seq_len: usize) -> usize {
1150    let direction_bytes = tensor_bytes(&[batch_size, seq_len, config.d_model]);
1151    let beta_bytes = tensor_bytes(&[batch_size, seq_len]);
1152
1153    if config.d_value == 1 {
1154        let scalar_bytes = tensor_bytes(&[batch_size, seq_len]);
1155        direction_bytes + beta_bytes + scalar_bytes * 4 + tensor_bytes(&[batch_size, seq_len, 1])
1156    } else {
1157        let value_bytes = tensor_bytes(&[batch_size, seq_len, config.d_value]);
1158        direction_bytes
1159            + beta_bytes
1160            + value_bytes * 3
1161            + tensor_bytes(&[batch_size, seq_len, config.d_model, config.d_value])
1162    }
1163}
1164
/// `value / baseline` as `f64`; a zero baseline yields 0.0 rather than
/// infinity so report tables stay finite.
fn ratio_usize(value: usize, baseline: usize) -> f64 {
    match baseline {
        0 => 0.0,
        _ => value as f64 / baseline as f64,
    }
}
1172
/// `Some(value / baseline)` when the baseline is strictly positive;
/// zero, negative, and NaN baselines all yield `None`.
fn ratio_f64(value: f64, baseline: f64) -> Option<f64> {
    if baseline > 0.0 {
        Some(value / baseline)
    } else {
        None
    }
}
1176
1177fn validate_positive_threshold(value: Option<f64>, label: &str) -> Result<(), BenchmarkError> {
1178    if let Some(value) = value
1179        && !(value.is_finite() && value > 0.0)
1180    {
1181        return Err(BenchmarkError::InvalidConfig(format!(
1182            "{label} must be finite and positive"
1183        )));
1184    }
1185
1186    Ok(())
1187}
1188
1189fn compare_case(
1190    suite: BenchmarkSuite,
1191    current: &BenchmarkCase,
1192    baseline: &BenchmarkCase,
1193) -> Result<BenchmarkDelta, BenchmarkError> {
1194    if current.label != baseline.label {
1195        return Err(BenchmarkError::ComparisonMismatch(format!(
1196            "suite {} case label mismatch: current `{}` vs baseline `{}`",
1197            suite.slug(),
1198            current.label,
1199            baseline.label
1200        )));
1201    }
1202    ensure_shapes_match(
1203        &BenchmarkTarget::Case {
1204            suite,
1205            label: current.label.clone(),
1206        },
1207        &current.input_shape,
1208        &baseline.input_shape,
1209        &current.output_shape,
1210        &baseline.output_shape,
1211    )?;
1212
1213    Ok(BenchmarkDelta {
1214        target: BenchmarkTarget::Case {
1215            suite,
1216            label: current.label.clone(),
1217        },
1218        input_shape: current.input_shape.clone(),
1219        output_shape: current.output_shape.clone(),
1220        avg_latency_ratio: compare_ratio(
1221            current.timing.avg_duration_ms,
1222            baseline.timing.avg_duration_ms,
1223            "avg latency",
1224            &BenchmarkTarget::Case {
1225                suite,
1226                label: current.label.clone(),
1227            },
1228        )?,
1229        p95_latency_ratio: compare_ratio(
1230            current.timing.p95_duration_ms,
1231            baseline.timing.p95_duration_ms,
1232            "p95 latency",
1233            &BenchmarkTarget::Case {
1234                suite,
1235                label: current.label.clone(),
1236            },
1237        )?,
1238        throughput_ratio: compare_ratio(
1239            current.timing.iterations_per_second,
1240            baseline.timing.iterations_per_second,
1241            "iterations per second",
1242            &BenchmarkTarget::Case {
1243                suite,
1244                label: current.label.clone(),
1245            },
1246        )?,
1247        peak_memory_ratio: compare_usize_ratio(
1248            current.memory.peak_live_bytes,
1249            baseline.memory.peak_live_bytes,
1250            "peak live bytes",
1251            &BenchmarkTarget::Case {
1252                suite,
1253                label: current.label.clone(),
1254            },
1255        )?,
1256        parameter_ratio: compare_optional_usize_ratio(
1257            current.memory.parameter_bytes,
1258            baseline.memory.parameter_bytes,
1259        ),
1260    })
1261}
1262
1263fn compare_model(
1264    current: &ModelBenchmark,
1265    baseline: &ModelBenchmark,
1266) -> Result<BenchmarkDelta, BenchmarkError> {
1267    let target = BenchmarkTarget::Model {
1268        variant: current.variant,
1269    };
1270    if current.variant != baseline.variant {
1271        return Err(BenchmarkError::ComparisonMismatch(format!(
1272            "model variant mismatch: current {} vs baseline {}",
1273            current.variant.slug(),
1274            baseline.variant.slug()
1275        )));
1276    }
1277    if current.resolved_config != baseline.resolved_config {
1278        return Err(BenchmarkError::ComparisonMismatch(format!(
1279            "model variant {} uses a different resolved config in the baseline report",
1280            current.variant.slug()
1281        )));
1282    }
1283
1284    ensure_shapes_match(
1285        &target,
1286        &current.input_shape,
1287        &baseline.input_shape,
1288        &current.output_shape,
1289        &baseline.output_shape,
1290    )?;
1291
1292    Ok(BenchmarkDelta {
1293        target,
1294        input_shape: current.input_shape.to_vec(),
1295        output_shape: current.output_shape.to_vec(),
1296        avg_latency_ratio: compare_ratio(
1297            current.timing.avg_duration_ms,
1298            baseline.timing.avg_duration_ms,
1299            "avg latency",
1300            &BenchmarkTarget::Model {
1301                variant: current.variant,
1302            },
1303        )?,
1304        p95_latency_ratio: compare_ratio(
1305            current.timing.p95_duration_ms,
1306            baseline.timing.p95_duration_ms,
1307            "p95 latency",
1308            &BenchmarkTarget::Model {
1309                variant: current.variant,
1310            },
1311        )?,
1312        throughput_ratio: compare_ratio(
1313            current.tokens_per_second,
1314            baseline.tokens_per_second,
1315            "tokens per second",
1316            &BenchmarkTarget::Model {
1317                variant: current.variant,
1318            },
1319        )?,
1320        peak_memory_ratio: compare_usize_ratio(
1321            current.memory.peak_total_bytes,
1322            baseline.memory.peak_total_bytes,
1323            "peak total bytes",
1324            &BenchmarkTarget::Model {
1325                variant: current.variant,
1326            },
1327        )?,
1328        parameter_ratio: compare_optional_usize_ratio(current.num_params, baseline.num_params),
1329    })
1330}
1331
1332fn ensure_shapes_match(
1333    target: &BenchmarkTarget,
1334    current_input_shape: &[usize],
1335    baseline_input_shape: &[usize],
1336    current_output_shape: &[usize],
1337    baseline_output_shape: &[usize],
1338) -> Result<(), BenchmarkError> {
1339    if current_input_shape != baseline_input_shape {
1340        return Err(BenchmarkError::ComparisonMismatch(format!(
1341            "{} input shape mismatch: current {:?} vs baseline {:?}",
1342            target.label(),
1343            current_input_shape,
1344            baseline_input_shape
1345        )));
1346    }
1347    if current_output_shape != baseline_output_shape {
1348        return Err(BenchmarkError::ComparisonMismatch(format!(
1349            "{} output shape mismatch: current {:?} vs baseline {:?}",
1350            target.label(),
1351            current_output_shape,
1352            baseline_output_shape
1353        )));
1354    }
1355    Ok(())
1356}
1357
1358fn compare_ratio(
1359    value: f64,
1360    baseline: f64,
1361    metric: &str,
1362    target: &BenchmarkTarget,
1363) -> Result<f64, BenchmarkError> {
1364    if !(baseline.is_finite() && baseline > 0.0) {
1365        return Err(BenchmarkError::ComparisonMismatch(format!(
1366            "{} has a non-positive baseline {} value: {baseline}",
1367            target.label(),
1368            metric
1369        )));
1370    }
1371    if !value.is_finite() {
1372        return Err(BenchmarkError::ComparisonMismatch(format!(
1373            "{} has a non-finite current {} value: {value}",
1374            target.label(),
1375            metric
1376        )));
1377    }
1378
1379    Ok(value / baseline)
1380}
1381
1382fn compare_usize_ratio(
1383    value: usize,
1384    baseline: usize,
1385    metric: &str,
1386    target: &BenchmarkTarget,
1387) -> Result<f64, BenchmarkError> {
1388    if baseline == 0 {
1389        return Err(BenchmarkError::ComparisonMismatch(format!(
1390            "{} has a zero baseline {metric} value",
1391            target.label()
1392        )));
1393    }
1394
1395    Ok(value as f64 / baseline as f64)
1396}
1397
/// Optional ratio for parameter-style counts: both values zero counts as
/// a perfect 1.0 ratio, a zero baseline alone makes the ratio undefined.
fn compare_optional_usize_ratio(value: usize, baseline: usize) -> Option<f64> {
    if baseline == 0 {
        if value == 0 { Some(1.0) } else { None }
    } else {
        Some(value as f64 / baseline as f64)
    }
}
1405
1406fn evaluate_gate(
1407    deltas: &[BenchmarkDelta],
1408    thresholds: &BenchmarkRegressionThresholds,
1409) -> Vec<String> {
1410    let mut failures = Vec::new();
1411    for delta in deltas {
1412        let target = delta.target.label();
1413        if let Some(min_throughput_ratio) = thresholds.min_throughput_ratio
1414            && delta.throughput_ratio < min_throughput_ratio
1415        {
1416            failures.push(format!(
1417                "{target} throughput ratio {:.3} fell below {:.3}",
1418                delta.throughput_ratio, min_throughput_ratio
1419            ));
1420        }
1421        if let Some(max_avg_latency_ratio) = thresholds.max_avg_latency_ratio
1422            && delta.avg_latency_ratio > max_avg_latency_ratio
1423        {
1424            failures.push(format!(
1425                "{target} avg latency ratio {:.3} exceeded {:.3}",
1426                delta.avg_latency_ratio, max_avg_latency_ratio
1427            ));
1428        }
1429        if let Some(max_p95_latency_ratio) = thresholds.max_p95_latency_ratio
1430            && delta.p95_latency_ratio > max_p95_latency_ratio
1431        {
1432            failures.push(format!(
1433                "{target} p95 latency ratio {:.3} exceeded {:.3}",
1434                delta.p95_latency_ratio, max_p95_latency_ratio
1435            ));
1436        }
1437        if let Some(max_peak_memory_ratio) = thresholds.max_peak_memory_ratio
1438            && delta.peak_memory_ratio > max_peak_memory_ratio
1439        {
1440            failures.push(format!(
1441                "{target} peak memory ratio {:.3} exceeded {:.3}",
1442                delta.peak_memory_ratio, max_peak_memory_ratio
1443            ));
1444        }
1445    }
1446
1447    failures
1448}