Skip to main content

diskann_benchmark_simd/
lib.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! SIMD distance kernel benchmarks with regression detection.
7
8use std::{io::Write, num::NonZeroUsize};
9
10use diskann_utils::views::{Matrix, MatrixView};
11use diskann_vector::distance::simd;
12use diskann_wide::Architecture;
13use half::f16;
14use rand::{
15    distr::{Distribution, StandardUniform},
16    rngs::StdRng,
17    SeedableRng,
18};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use diskann_benchmark_runner::{
23    benchmark::{PassFail, Regression},
24    describeln,
25    dispatcher::{Description, DispatchRule, FailureScore, MatchScore},
26    utils::{
27        datatype::{self, DataType},
28        num::{relative_change, NonNegativeFinite},
29        percentiles, MicroSeconds,
30    },
31    Any, Benchmark, CheckDeserialization, Checker, Input,
32};
33
34////////////////
35// Public API //
36////////////////
37
/// Register every SIMD distance-kernel benchmark in this crate with the
/// benchmark dispatcher. Thin public wrapper over `register_benchmarks_impl`.
pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    register_benchmarks_impl(dispatcher)
}
41
42///////////
43// Utils //
44///////////
45
/// Thin wrapper used to attach `Display` implementations to borrowed,
/// possibly unsized data (e.g. `[RunResult]`).
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

impl<T: ?Sized> std::ops::Deref for DisplayWrapper<'_, T> {
    type Target = T;

    /// Expose the wrapped reference so `&*wrapper` and auto-deref method
    /// calls reach the inner value directly.
    fn deref(&self) -> &T {
        let DisplayWrapper(inner) = self;
        inner
    }
}
55
56////////////
57// Inputs //
58////////////
59
/// Distance / similarity function exercised by a benchmark run.
///
/// Serialized in `snake_case` (e.g. `"squared_l2"`), matching the `Display`
/// implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMeasure {
    SquaredL2,
    InnerProduct,
    Cosine,
}
67
68impl std::fmt::Display for SimilarityMeasure {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        let st = match self {
71            Self::SquaredL2 => "squared_l2",
72            Self::InnerProduct => "inner_product",
73            Self::Cosine => "cosine",
74        };
75        write!(f, "{}", st)
76    }
77}
78
/// Target implementation requested for a benchmark.
///
/// Serialized in `kebab-case`; the x86-64 microarchitecture levels carry
/// explicit renames so they round-trip as `"x86-64-v4"` / `"x86-64-v3"`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Arch {
    #[serde(rename = "x86-64-v4")]
    #[allow(non_camel_case_types)]
    X86_64_V4,
    #[serde(rename = "x86-64-v3")]
    #[allow(non_camel_case_types)]
    X86_64_V3,
    Neon,
    Scalar,
    // Loop-based reference implementations (see the `reference` module).
    Reference,
}
92
93impl std::fmt::Display for Arch {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        let st = match self {
96            Self::X86_64_V4 => "x86-64-v4",
97            Self::X86_64_V3 => "x86-64-v3",
98            Self::Neon => "neon",
99            Self::Scalar => "scalar",
100            Self::Reference => "reference",
101        };
102        write!(f, "{}", st)
103    }
104}
105
/// Configuration for a single benchmark run.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct Run {
    /// Similarity measure to benchmark.
    distance: SimilarityMeasure,
    /// Dimensionality of the query and data vectors.
    dim: NonZeroUsize,
    /// Number of data vectors the query is compared against per pass.
    num_points: NonZeroUsize,
    /// Number of full passes over the data per timed measurement.
    loops_per_measurement: NonZeroUsize,
    /// Number of timed measurements (latency samples) to collect.
    num_measurements: NonZeroUsize,
}
114
/// Top-level benchmark specification: a query/data element-type pair, the
/// architecture to dispatch to, and the list of runs to execute.
#[derive(Debug, Serialize, Deserialize)]
pub struct SimdOp {
    /// Element type of the query vector.
    query_type: DataType,
    /// Element type of the data vectors.
    data_type: DataType,
    /// Implementation to dispatch to.
    arch: Arch,
    /// Benchmark runs, executed in order.
    runs: Vec<Run>,
}
122
impl CheckDeserialization for SimdOp {
    /// No cross-field validation needed; field invariants (e.g. non-zero
    /// sizes) are already enforced by the field types.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
128
/// Write one `name: value` summary line with the name right-aligned in an
/// 18-character column, so consecutive fields line up.
macro_rules! write_field {
    ($f:ident, $field:tt, $($expr:tt)*) => {
        writeln!($f, "{:>18}: {}", $field, $($expr)*)
    }
}
134
impl SimdOp {
    /// Write the aligned field summary used by the `Display` implementation.
    fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write_field!(f, "query type", self.query_type)?;
        write_field!(f, "data type", self.data_type)?;
        write_field!(f, "arch", self.arch)?;
        write_field!(f, "number of runs", self.runs.len())?;
        Ok(())
    }
}
144
impl std::fmt::Display for SimdOp {
    /// Header line plus the input tag and the common field summary.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "SIMD Operation\n")?;
        write_field!(f, "tag", Self::tag())?;
        self.summarize_fields(f)
    }
}
152
impl Input for SimdOp {
    /// Tag used to select this input type in serialized benchmark specs.
    fn tag() -> &'static str {
        "simd-op"
    }

    fn try_deserialize(
        serialized: &serde_json::Value,
        checker: &mut Checker,
    ) -> anyhow::Result<Any> {
        checker.any(Self::deserialize(serialized)?)
    }

    /// Build an example specification (two runs: squared-L2 and inner-product
    /// at different sizes) for documentation / spec scaffolding.
    fn example() -> anyhow::Result<serde_json::Value> {
        const DIM: [NonZeroUsize; 2] = [
            NonZeroUsize::new(128).unwrap(),
            NonZeroUsize::new(150).unwrap(),
        ];

        const NUM_POINTS: [NonZeroUsize; 2] = [
            NonZeroUsize::new(1000).unwrap(),
            NonZeroUsize::new(800).unwrap(),
        ];

        const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
        const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();

        let runs = vec![
            Run {
                distance: SimilarityMeasure::SquaredL2,
                dim: DIM[0],
                num_points: NUM_POINTS[0],
                loops_per_measurement: LOOPS_PER_MEASUREMENT,
                num_measurements: NUM_MEASUREMENTS,
            },
            Run {
                distance: SimilarityMeasure::InnerProduct,
                dim: DIM[1],
                num_points: NUM_POINTS[1],
                loops_per_measurement: LOOPS_PER_MEASUREMENT,
                num_measurements: NUM_MEASUREMENTS,
            },
        ];

        Ok(serde_json::to_value(&Self {
            query_type: DataType::Float32,
            data_type: DataType::Float32,
            arch: Arch::X86_64_V3,
            runs,
        })?)
    }
}
204
205//////////////////////
206// Regression Check //
207//////////////////////
208
/// Tolerance thresholds for SIMD benchmark regression detection.
///
/// Each field specifies the maximum allowed relative increase in the corresponding metric.
/// For example, a value of `0.10` means a 10% increase is tolerated.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
struct SimdTolerance {
    /// Maximum allowed relative increase in the (normalized) minimum time.
    min_time_regression: NonNegativeFinite,
}
217
impl CheckDeserialization for SimdTolerance {
    /// Nothing to validate beyond what `NonNegativeFinite` already enforces.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
223
impl Input for SimdTolerance {
    /// Tag used to select this tolerance type in serialized benchmark specs.
    fn tag() -> &'static str {
        "simd-tolerance"
    }

    fn try_deserialize(
        serialized: &serde_json::Value,
        checker: &mut Checker,
    ) -> anyhow::Result<Any> {
        checker.any(Self::deserialize(serialized)?)
    }

    /// Example tolerance permitting up to a 10% regression in minimum time.
    fn example() -> anyhow::Result<serde_json::Value> {
        // Validated in a `const` context so an invalid literal fails the build.
        const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) {
            Ok(v) => v,
            Err(_) => panic!("use a non-negative finite please"),
        };

        Ok(serde_json::to_value(SimdTolerance {
            min_time_regression: EXAMPLE,
        })?)
    }
}
247
/// Per-run comparison result showing before/after percentile differences.
#[derive(Debug, Serialize)]
struct Comparison {
    /// Configuration shared by the paired before/after runs.
    run: Run,
    /// Tolerance in effect when the comparison was made.
    tolerance: SimdTolerance,
    /// Minimum time before, in nanoseconds per distance computation
    /// (see `Regression::check` for the normalization).
    before_min: f64,
    /// Minimum time after, in nanoseconds per distance computation.
    after_min: f64,
}
256
/// Aggregated result of the regression check across all runs.
#[derive(Debug, Serialize)]
struct CheckResult {
    /// One comparison per paired before/after run.
    checks: Vec<Comparison>,
}
262
impl std::fmt::Display for CheckResult {
    /// Render the comparisons as a table.
    ///
    /// The "Remark" column is left blank for passing rows, shows "FAIL" when
    /// the relative change exceeds the row's tolerance, and carries the error
    /// text when the relative change could not be computed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let header = [
            "Distance",
            "Dim",
            "Min Before (ns)",
            "Min After (ns)",
            "Change (%)",
            "Remark",
        ];

        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.checks.len());

        for (i, c) in self.checks.iter().enumerate() {
            let mut row = table.row(i);
            let change = relative_change(c.before_min, c.after_min);

            row.insert(c.run.distance, 0);
            row.insert(c.run.dim, 1);
            row.insert(format!("{:.3}", c.before_min), 2);
            row.insert(format!("{:.3}", c.after_min), 3);
            match change {
                Ok(change) => {
                    row.insert(format!("{:.3} %", change * 100.0), 4);
                    if change > c.tolerance.min_time_regression.get() {
                        row.insert("FAIL", 5);
                    }
                }
                Err(err) => {
                    row.insert("invalid", 4);
                    row.insert(err, 5);
                }
            }
        }

        table.fmt(f)
    }
}
301
302////////////////////////////
303// Benchmark Registration //
304////////////////////////////
305
/// Register a regression benchmark for one `Kernel` instantiation.
///
/// The four-argument form (target-arch literal first) gates the registration
/// behind `#[cfg(target_arch = $arch)]`; the three-argument form registers
/// unconditionally (used for the scalar and reference kernels).
macro_rules! register {
    ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        #[cfg(target_arch = $arch)]
        $dispatcher.register_regression::<$($kernel)*>($name)
    };
    ($dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        $dispatcher.register_regression::<$($kernel)*>($name)
    };
}
315
/// Register all kernel instantiations: one per (architecture, element type)
/// combination, covering f32, f16, u8 and i8 for each architecture plus the
/// scalar and loop-based reference implementations.
fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    // x86-64-v4
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V4",
        Kernel<diskann_wide::arch::x86_64::V4, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V4",
        Kernel<diskann_wide::arch::x86_64::V4, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V4",
        Kernel<diskann_wide::arch::x86_64::V4, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V4",
        Kernel<diskann_wide::arch::x86_64::V4, i8, i8>
    );

    // x86-64-v3
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V3",
        Kernel<diskann_wide::arch::x86_64::V3, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V3",
        Kernel<diskann_wide::arch::x86_64::V3, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V3",
        Kernel<diskann_wide::arch::x86_64::V3, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V3",
        Kernel<diskann_wide::arch::x86_64::V3, i8, i8>
    );

    // aarch64-neon
    register!(
        "aarch64",
        dispatcher,
        "simd-op-f32xf32-aarch64_neon",
        Kernel<diskann_wide::arch::aarch64::Neon, f32, f32>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-f16xf16-aarch64_neon",
        Kernel<diskann_wide::arch::aarch64::Neon, f16, f16>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-u8xu8-aarch64_neon",
        Kernel<diskann_wide::arch::aarch64::Neon, u8, u8>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-i8xi8-aarch64_neon",
        Kernel<diskann_wide::arch::aarch64::Neon, i8, i8>
    );

    // scalar
    register!(
        dispatcher,
        "simd-op-f32xf32-scalar",
        Kernel<diskann_wide::arch::Scalar, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-scalar",
        Kernel<diskann_wide::arch::Scalar, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-scalar",
        Kernel<diskann_wide::arch::Scalar, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-scalar",
        Kernel<diskann_wide::arch::Scalar, i8, i8>
    );

    // reference
    register!(
        dispatcher,
        "simd-op-f32xf32-reference",
        Kernel<Reference, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-reference",
        Kernel<Reference, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-reference",
        Kernel<Reference, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-reference",
        Kernel<Reference, i8, i8>
    );
}
439
440//////////////
441// Dispatch //
442//////////////
443
/// Dispatch receiver for the reference implementations.
struct Reference;

/// A dispatch mapper for `wide` types.
#[derive(Debug)]
struct Identity<T>(T);

/// Benchmark kernel parameterized by architecture token `A`, query element
/// type `Q`, and data element type `D`.
struct Kernel<A, Q, D> {
    /// Architecture token used to run the distance kernels.
    arch: A,
    /// Carries the `Q`/`D` type parameters without storing values of them.
    _type: std::marker::PhantomData<(A, Q, D)>,
}
455
456impl<A, Q, D> Kernel<A, Q, D> {
457    fn new(arch: A) -> Self {
458        Self {
459            arch,
460            _type: std::marker::PhantomData,
461        }
462    }
463}
464
// Map Architectures to the enum.
/// Error returned when the requested [`Arch`] cannot be constructed on the
/// current CPU.
#[derive(Debug, Error)]
#[error("architecture {0} is not supported by this CPU")]
pub(crate) struct ArchNotSupported(Arch);
469
470impl DispatchRule<Arch> for Identity<Reference> {
471    type Error = ArchNotSupported;
472
473    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
474        if *from == Arch::Reference {
475            Ok(MatchScore(0))
476        } else {
477            Err(FailureScore(1))
478        }
479    }
480
481    fn convert(from: Arch) -> Result<Self, Self::Error> {
482        assert_eq!(from, Arch::Reference);
483        Ok(Identity(Reference))
484    }
485
486    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
487        match from {
488            None => write!(f, "loop based"),
489            Some(arch) => {
490                if Self::try_match(arch).is_ok() {
491                    write!(f, "matched {}", arch)
492                } else {
493                    write!(f, "expected {}, got {}", Arch::Reference, arch)
494                }
495            }
496        }
497    }
498}
499
500impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
501    type Error = ArchNotSupported;
502
503    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
504        if *from == Arch::Scalar {
505            Ok(MatchScore(0))
506        } else {
507            Err(FailureScore(1))
508        }
509    }
510
511    fn convert(from: Arch) -> Result<Self, Self::Error> {
512        assert_eq!(from, Arch::Scalar);
513        Ok(Identity(diskann_wide::arch::Scalar))
514    }
515
516    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
517        match from {
518            None => write!(f, "scalar (compilation target CPU)"),
519            Some(arch) => {
520                if Self::try_match(arch).is_ok() {
521                    write!(f, "matched {}", arch)
522                } else {
523                    write!(f, "expected {}, got {}", Arch::Scalar, arch)
524                }
525            }
526        }
527    }
528}
529
/// Implement `DispatchRule<Arch>` for a CPU-feature-gated architecture token.
///
/// Matching requires both that the requested enum variant equals `$enum` and
/// that the running CPU supports the architecture (`new_checked()` returns
/// `Some`). A requested-but-unsupported architecture gets the distinct score
/// `FailureScore(0)` so `description` can report "matched but unsupported by
/// this CPU" rather than a plain mismatch.
macro_rules! match_arch {
    ($target_arch:literal, $arch:path, $enum:ident) => {
        #[cfg(target_arch = $target_arch)]
        impl DispatchRule<Arch> for Identity<$arch> {
            type Error = ArchNotSupported;

            fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
                let available = <$arch>::new_checked().is_some();
                if available && *from == Arch::$enum {
                    Ok(MatchScore(0))
                } else if !available && *from == Arch::$enum {
                    // Right architecture, but the CPU lacks the features.
                    Err(FailureScore(0))
                } else {
                    Err(FailureScore(1))
                }
            }

            fn convert(from: Arch) -> Result<Self, Self::Error> {
                assert_eq!(from, Arch::$enum);
                <$arch>::new_checked()
                    .ok_or(ArchNotSupported(from))
                    .map(Identity)
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                from: Option<&Arch>,
            ) -> std::fmt::Result {
                let available = <$arch>::new_checked().is_some();
                match from {
                    None => write!(f, "{}", Arch::$enum),
                    Some(arch) => {
                        if Self::try_match(arch).is_ok() {
                            write!(f, "matched {}", arch)
                        } else if !available && *arch == Arch::$enum {
                            write!(f, "matched {} but unsupported by this CPU", Arch::$enum)
                        } else {
                            write!(f, "expected {}, got {}", Arch::$enum, arch)
                        }
                    }
                }
            }
        }
    };
}
575
// Stamp out `DispatchRule<Arch>` for each CPU-feature-gated architecture token.
match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
579
impl<A, Q, D> Benchmark for Kernel<A, Q, D>
where
    datatype::Type<Q>: DispatchRule<datatype::DataType>,
    datatype::Type<D>: DispatchRule<datatype::DataType>,
    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
    Kernel<A, Q, D>: RunBenchmark,
{
    type Input = SimdOp;
    type Output = Vec<RunResult>;

    // Matching simply requires that we match the inner types. Failure scores
    // accumulate: each mismatched element type adds 10, and a mismatched
    // architecture adds 2 plus the architecture's own failure score.
    fn try_match(from: &SimdOp) -> Result<MatchScore, FailureScore> {
        let mut failscore: Option<u32> = None;
        if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
            *failscore.get_or_insert(0) += 10;
        }
        if datatype::Type::<D>::try_match(&from.data_type).is_err() {
            *failscore.get_or_insert(0) += 10;
        }
        if let Err(FailureScore(score)) = Identity::<A>::try_match(&from.arch) {
            *failscore.get_or_insert(0) += 2 + score;
        }

        // `None` means every component matched.
        match failscore {
            None => Ok(MatchScore(0)),
            Some(score) => Err(FailureScore(score)),
        }
    }

    /// Convert the requested architecture, execute all runs, and echo the
    /// input summary plus the result table to `output`.
    fn run(
        input: &SimdOp,
        _: diskann_benchmark_runner::Checkpoint<'_>,
        mut output: &mut dyn diskann_benchmark_runner::Output,
    ) -> anyhow::Result<Self::Output> {
        let arch = Identity::<A>::convert(input.arch)?.0;
        let kernel = Self::new(arch);
        writeln!(output, "{}", input)?;
        let results = kernel.run(input)?;
        writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
        Ok(results)
    }

    /// With no input: describe this kernel's dispatch requirements.
    /// With an input: explain which components failed to match.
    fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&SimdOp>) -> std::fmt::Result {
        match input {
            None => {
                describeln!(
                    f,
                    "- Query Type: {}",
                    Description::<datatype::DataType, datatype::Type<Q>>::new()
                )?;
                describeln!(
                    f,
                    "- Data Type: {}",
                    Description::<datatype::DataType, datatype::Type<D>>::new()
                )?;
                describeln!(
                    f,
                    "- Implementation: {}",
                    Description::<Arch, Identity<A>>::new()
                )?;
            }
            Some(input) => {
                if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
                    describeln!(f, "\n    - Mismatched query type: {}", err)?;
                }
                if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
                    describeln!(f, "\n    - Mismatched data type: {}", err)?;
                }
                if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
                    describeln!(f, "\n    - Mismatched architecture: {}", err)?;
                }
            }
        }
        Ok(())
    }
}
656
impl<A, Q, D> Regression for Kernel<A, Q, D>
where
    datatype::Type<Q>: DispatchRule<datatype::DataType>,
    datatype::Type<D>: DispatchRule<datatype::DataType>,
    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
    Kernel<A, Q, D>: RunBenchmark,
{
    type Tolerances = SimdTolerance;
    type Pass = CheckResult;
    type Fail = CheckResult;

    /// Compare `before` and `after` results run-by-run.
    ///
    /// Minimum latencies are normalized to nanoseconds per distance
    /// computation before comparison. The whole check fails when any run's
    /// relative increase exceeds `tolerance.min_time_regression`, or when the
    /// relative change cannot be computed.
    fn check(
        tolerance: &SimdTolerance,
        _input: &SimdOp,
        before: &Vec<RunResult>,
        after: &Vec<RunResult>,
    ) -> anyhow::Result<PassFail<CheckResult, CheckResult>> {
        anyhow::ensure!(
            before.len() == after.len(),
            "before has {} runs but after has {}",
            before.len(),
            after.len(),
        );

        let mut passed = true;
        let checks: Vec<Comparison> = std::iter::zip(before.iter(), after.iter())
            .enumerate()
            .map(|(i, (b, a))| {
                // Runs must pair up exactly; the normalization below assumes
                // identical configurations on both sides.
                anyhow::ensure!(b.run == a.run, "run {i} mismatched");

                let computations_per_latency = b.computations_per_latency() as f64;

                // Micro-seconds per measurement -> nano-seconds per computation.
                let before_min = b.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
                let after_min = a.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;

                let comparison = Comparison {
                    run: b.run.clone(),
                    tolerance: *tolerance,
                    before_min,
                    after_min,
                };

                // Determine whether or not we pass.
                match relative_change(before_min, after_min) {
                    Ok(change) => {
                        if change > tolerance.min_time_regression.get() {
                            passed = false;
                        }
                    }
                    Err(_) => passed = false,
                };

                Ok(comparison)
            })
            .collect::<anyhow::Result<Vec<Comparison>>>()?;

        let check = CheckResult { checks };

        if passed {
            Ok(PassFail::Pass(check))
        } else {
            Ok(PassFail::Fail(check))
        }
    }
}
722
723///////////////
724// Benchmark //
725///////////////
726
/// Execute every run described by a `SimdOp`, returning one `RunResult` per run.
trait RunBenchmark {
    fn run(self, input: &SimdOp) -> Result<Vec<RunResult>, anyhow::Error>;
}
730
/// Timing results for a single benchmark run.
#[derive(Debug, Serialize, Deserialize)]
struct RunResult {
    /// The configuration for this run.
    run: Run,
    /// The latencies of individual measurements.
    latencies: Vec<MicroSeconds>,
    /// Latency percentiles computed over `latencies`.
    percentiles: percentiles::Percentiles<MicroSeconds>,
}
740
741impl RunResult {
742    fn computations_per_latency(&self) -> usize {
743        self.run.num_points.get() * self.run.loops_per_measurement.get()
744    }
745}
746
impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
    /// Render one table row per run; an empty slice renders nothing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.is_empty() {
            return Ok(());
        }

        let header = [
            "Distance",
            "Dim",
            "Min Time (ns)",
            "Mean Time (ns)",
            "Points",
            "Loops",
            "Measurements",
        ];

        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());

        self.iter().enumerate().for_each(|(row, r)| {
            let mut row = table.row(row);

            // Minimum over the raw latency samples; falls back to `u64::MAX`
            // for an (unexpected) empty sample set.
            let min_latency = r
                .latencies
                .iter()
                .min()
                .copied()
                .unwrap_or(MicroSeconds::new(u64::MAX));
            let mean_latency = r.percentiles.mean;

            let computations_per_latency = r.computations_per_latency() as f64;

            // Convert time from micro-seconds to nano-seconds (per computation).
            let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
            let mean_time = mean_latency / computations_per_latency * 1000.0;

            row.insert(r.run.distance, 0);
            row.insert(r.run.dim, 1);
            row.insert(format!("{:.3}", min_time), 2);
            row.insert(format!("{:.3}", mean_time), 3);
            row.insert(r.run.num_points, 4);
            row.insert(r.run.loops_per_measurement, 5);
            row.insert(r.run.num_measurements, 6);
        });

        table.fmt(f)
    }
}
794
/// Collect `run.num_measurements` latency samples of the distance kernel `f`.
///
/// Each sample times `run.loops_per_measurement` full passes in which `f` is
/// applied between `query` and every row of `data`. Results are written into
/// a reused buffer and passed through `black_box` so the optimizer cannot
/// elide the work.
fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
where
    F: Fn(&[Q], &[D]) -> f32,
{
    let mut latencies = Vec::with_capacity(run.num_measurements.get());
    let mut dst = vec![0.0; data.nrows()];

    for _ in 0..run.num_measurements.get() {
        let start = std::time::Instant::now();
        for _ in 0..run.loops_per_measurement.get() {
            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
                *d = f(query, r);
            });
            std::hint::black_box(&mut dst);
        }
        latencies.push(start.elapsed().into());
    }

    // `num_measurements` is non-zero, so `latencies` is non-empty here.
    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
    RunResult {
        run: run.clone(),
        latencies,
        percentiles,
    }
}
820
/// Randomly generated benchmark fixture: one query vector plus a data matrix.
struct Data<Q, D> {
    /// Query vector of length `run.dim`.
    query: Box<[Q]>,
    /// `num_points` x `dim` matrix of data vectors.
    data: Matrix<D>,
}
825
impl<Q, D> Data<Q, D> {
    /// Generate a random query and data matrix matching `run`'s dimensions.
    ///
    /// Uses a fixed RNG seed so every kernel benchmarks against the same data.
    fn new(run: &Run) -> Self
    where
        StandardUniform: Distribution<Q>,
        StandardUniform: Distribution<D>,
    {
        let mut rng = StdRng::seed_from_u64(0x12345);
        let query: Box<[Q]> = (0..run.dim.get())
            .map(|_| StandardUniform.sample(&mut rng))
            .collect();
        let data = Matrix::<D>::new(
            diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
            run.num_points.get(),
            run.dim.get(),
        );

        Self { query, data }
    }

    /// Benchmark `f` against this fixture using `run`'s loop configuration.
    fn run<F>(&self, run: &Run, f: F) -> RunResult
    where
        F: Fn(&[Q], &[D]) -> f32,
    {
        run_loops(&self.query, self.data.as_view(), run, f)
    }
}
852
853/////////////////////
854// Implementations //
855/////////////////////
856
/// Stamp out `RunBenchmark` implementations for a `Kernel` instantiation.
///
/// Three forms:
/// - `stamp!(reference, Q, D, l2, ip, cos)`: loop-based reference kernels.
/// - `stamp!(arch_path, Q, D)`: SIMD kernels, always available.
/// - `stamp!("target_arch", arch_path, Q, D)`: SIMD kernels gated behind
///   `#[cfg(target_arch = ...)]`.
macro_rules! stamp {
    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
        impl RunBenchmark for Kernel<Reference, $Q, $D> {
            fn run(self, input: &SimdOp) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();
                for run in input.runs.iter() {
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
                    };
                    results.push(result);
                }
                Ok(results)
            }
        }
    };
    ($arch:path, $Q:ty, $D:ty) => {
        impl RunBenchmark for Kernel<$arch, $Q, $D> {
            fn run(self, input: &SimdOp) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();

                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                for run in input.runs.iter() {
                    let data = Data::<$Q, $D>::new(run);
                    // For each kernel, we need to do a two-step wrapping of closures so
                    // the inner-most closure is executed by the architecture.
                    //
                    // This is required for the implementation of `simd_op` to be inlined
                    // into the architecture run function so it can properly inherit
                    // target features.
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d)
                        }),
                    };
                    results.push(result)
                }
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Q, $D);
    };
}
917
// SIMD kernels for the x86-64-v4 architecture token.
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);

// SIMD kernels for the x86-64-v3 architecture token.
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);

// SIMD kernels for aarch64 NEON.
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);

// Scalar kernels (always available).
stamp!(diskann_wide::arch::Scalar, f32, f32);
stamp!(diskann_wide::arch::Scalar, f16, f16);
stamp!(diskann_wide::arch::Scalar, u8, u8);
stamp!(diskann_wide::arch::Scalar, i8, i8);

// Loop-based reference kernels (see the `reference` module below).
stamp!(
    reference,
    f32,
    f32,
    squared_l2_f32,
    inner_product_f32,
    cosine_f32
);
stamp!(
    reference,
    f16,
    f16,
    squared_l2_f16,
    inner_product_f16,
    cosine_f16
);
stamp!(
    reference,
    u8,
    u8,
    squared_l2_u8,
    inner_product_u8,
    cosine_u8
);
stamp!(
    reference,
    i8,
    i8,
    squared_l2_i8,
    inner_product_i8,
    cosine_i8
);
970
971///////////////
972// Reference //
973///////////////
974
975// These are largely copied from the implementations in vector, with a tweak that we don't
976// use FMA when the current architecture is scalar.
977mod reference {
978    use diskann_wide::ARCH;
979    use half::f16;
980
    trait MaybeFMA {
        /// Perform `a*b + c` using FMA when a hardware instruction is guaranteed to be
        /// available, otherwise decompose into a multiply and add.
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
    }
986
987    impl MaybeFMA for diskann_wide::arch::Scalar {
988        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
989            a * b + c
990        }
991    }
992
993    #[cfg(target_arch = "x86_64")]
994    impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
995        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
996            a.mul_add(b, c)
997        }
998    }
999
1000    #[cfg(target_arch = "x86_64")]
1001    impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
1002        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
1003            a.mul_add(b, c)
1004        }
1005    }
1006
1007    #[cfg(target_arch = "aarch64")]
1008    impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
1009        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
1010            a.mul_add(b, c)
1011        }
1012    }
1013
1014    //------------//
1015    // Squared L2 //
1016    //------------//
1017
1018    pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
1019        assert_eq!(x.len(), y.len());
1020        std::iter::zip(x.iter(), y.iter())
1021            .map(|(&a, &b)| {
1022                let a: i32 = a.into();
1023                let b: i32 = b.into();
1024                let diff = a - b;
1025                diff * diff
1026            })
1027            .sum::<i32>() as f32
1028    }
1029
1030    pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
1031        assert_eq!(x.len(), y.len());
1032        std::iter::zip(x.iter(), y.iter())
1033            .map(|(&a, &b)| {
1034                let a: i32 = a.into();
1035                let b: i32 = b.into();
1036                let diff = a - b;
1037                diff * diff
1038            })
1039            .sum::<i32>() as f32
1040    }
1041
1042    pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
1043        assert_eq!(x.len(), y.len());
1044        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1045            let a: f32 = a.into();
1046            let b: f32 = b.into();
1047            let diff = a - b;
1048            ARCH.maybe_fma(diff, diff, acc)
1049        })
1050    }
1051
1052    pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
1053        assert_eq!(x.len(), y.len());
1054        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1055            let diff = a - b;
1056            ARCH.maybe_fma(diff, diff, acc)
1057        })
1058    }
1059
1060    //---------------//
1061    // Inner Product //
1062    //---------------//
1063
1064    pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
1065        assert_eq!(x.len(), y.len());
1066        std::iter::zip(x.iter(), y.iter())
1067            .map(|(&a, &b)| {
1068                let a: i32 = a.into();
1069                let b: i32 = b.into();
1070                a * b
1071            })
1072            .sum::<i32>() as f32
1073    }
1074
1075    pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
1076        assert_eq!(x.len(), y.len());
1077        std::iter::zip(x.iter(), y.iter())
1078            .map(|(&a, &b)| {
1079                let a: i32 = a.into();
1080                let b: i32 = b.into();
1081                a * b
1082            })
1083            .sum::<i32>() as f32
1084    }
1085
1086    pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
1087        assert_eq!(x.len(), y.len());
1088        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1089            let a: f32 = a.into();
1090            let b: f32 = b.into();
1091            ARCH.maybe_fma(a, b, acc)
1092        })
1093    }
1094
1095    pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
1096        assert_eq!(x.len(), y.len());
1097        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
1098    }
1099
1100    //--------//
1101    // Cosine //
1102    //--------//
1103
1104    #[derive(Default)]
1105    struct XY<T> {
1106        xnorm: T,
1107        ynorm: T,
1108        xy: T,
1109    }
1110
1111    pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1112        assert_eq!(x.len(), y.len());
1113        let r: XY<i32> =
1114            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1115                let vx: i32 = vx.into();
1116                let vy: i32 = vy.into();
1117                XY {
1118                    xnorm: acc.xnorm + vx * vx,
1119                    ynorm: acc.ynorm + vy * vy,
1120                    xy: acc.xy + vx * vy,
1121                }
1122            });
1123
1124        if r.xnorm == 0 || r.ynorm == 0 {
1125            return 0.0;
1126        }
1127
1128        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1129    }
1130
1131    pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1132        assert_eq!(x.len(), y.len());
1133        let r: XY<i32> =
1134            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1135                let vx: i32 = vx.into();
1136                let vy: i32 = vy.into();
1137                XY {
1138                    xnorm: acc.xnorm + vx * vx,
1139                    ynorm: acc.ynorm + vy * vy,
1140                    xy: acc.xy + vx * vy,
1141                }
1142            });
1143
1144        if r.xnorm == 0 || r.ynorm == 0 {
1145            return 0.0;
1146        }
1147
1148        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1149    }
1150
1151    pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1152        assert_eq!(x.len(), y.len());
1153        let r: XY<f32> =
1154            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1155                let vx: f32 = vx.into();
1156                let vy: f32 = vy.into();
1157                XY {
1158                    xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1159                    ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1160                    xy: ARCH.maybe_fma(vx, vy, acc.xy),
1161                }
1162            });
1163
1164        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1165            return 0.0;
1166        }
1167
1168        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1169    }
1170
1171    pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1172        assert_eq!(x.len(), y.len());
1173        let r: XY<f32> =
1174            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1175                xnorm: vx.mul_add(vx, acc.xnorm),
1176                ynorm: vy.mul_add(vy, acc.ynorm),
1177                xy: vx.mul_add(vy, acc.xy),
1178            });
1179
1180        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1181            return 0.0;
1182        }
1183
1184        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1185    }
1186}
1187
1188///////////
1189// Tests //
1190///////////
1191
#[cfg(test)]
mod tests {
    use super::*;

    use diskann_benchmark_runner::{
        benchmark::{PassFail, Regression},
        utils::percentiles::compute_percentiles,
    };

    /// Smallest useful run configuration: one point, one loop, one measurement.
    fn tiny_run(distance: SimilarityMeasure) -> Run {
        let one = NonZeroUsize::new(1).unwrap();
        Run {
            distance,
            dim: NonZeroUsize::new(8).unwrap(),
            num_points: one,
            loops_per_measurement: one,
            num_measurements: one,
        }
    }

    /// A scalar f32/f32 op containing a single squared-L2 run.
    fn tiny_op() -> SimdOp {
        SimdOp {
            query_type: DataType::Float32,
            data_type: DataType::Float32,
            arch: Arch::Scalar,
            runs: vec![tiny_run(SimilarityMeasure::SquaredL2)],
        }
    }

    /// A single-sample result whose only latency is `minimum` microseconds.
    fn tiny_result(distance: SimilarityMeasure, minimum: u64) -> RunResult {
        let mut latencies = vec![MicroSeconds::new(minimum)];
        let percentiles = compute_percentiles(&mut latencies).unwrap();
        RunResult {
            run: tiny_run(distance),
            latencies,
            percentiles,
        }
    }

    /// Tolerance allowing at most `limit` relative slowdown of the minimum time.
    fn tolerance(limit: f64) -> SimdTolerance {
        SimdTolerance {
            min_time_regression: NonNegativeFinite::new(limit).unwrap(),
        }
    }

    #[test]
    fn check_rejects_mismatched_runs() {
        type Bench = Kernel<diskann_wide::arch::Scalar, f32, f32>;

        // "before" and "after" describe different distances: check must error out.
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::Cosine, 100)];
        let failure = Bench::check(&tolerance(0.0), &tiny_op(), &before, &after).unwrap_err();

        assert_eq!(failure.to_string(), "run 0 mismatched");
    }

    #[test]
    fn check_allows_negative_relative_change() {
        type Bench = Kernel<diskann_wide::arch::Scalar, f32, f32>;

        // A speed-up (95 < 100) must pass even with zero tolerance.
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 95)];
        let outcome = Bench::check(&tolerance(0.0), &tiny_op(), &before, &after).unwrap();

        assert!(matches!(outcome, PassFail::Pass(_)));
    }

    #[test]
    fn check_passes_on_tolerance_boundary() {
        type Bench = Kernel<diskann_wide::arch::Scalar, f32, f32>;

        // Exactly +5% against a 5% tolerance is still a pass (boundary is inclusive).
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 105)];
        let outcome = Bench::check(&tolerance(0.05), &tiny_op(), &before, &after).unwrap();

        assert!(matches!(outcome, PassFail::Pass(_)));
    }

    #[test]
    fn check_fails_above_tolerance_boundary() {
        type Bench = Kernel<diskann_wide::arch::Scalar, f32, f32>;

        // +6% against a 5% tolerance is a regression.
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 106)];
        let outcome = Bench::check(&tolerance(0.05), &tiny_op(), &before, &after).unwrap();

        assert!(matches!(outcome, PassFail::Fail(_)));
    }

    #[test]
    fn check_result_display_includes_failure_details() {
        let check = CheckResult {
            checks: vec![Comparison {
                run: tiny_run(SimilarityMeasure::SquaredL2),
                tolerance: tolerance(0.05),
                before_min: 100.0,
                after_min: 106.0,
            }],
        };

        // The rendering must surface the distance, both timings, the relative
        // change, and the verdict.
        let rendered = check.to_string();
        for needle in ["Distance", "squared_l2", "100.000", "106.000", "6.000 %", "FAIL"] {
            assert!(rendered.contains(needle), "rendered = {rendered}");
        }
    }

    // If a "before" value is 0, we should fail with an error because this means the
    // measurement was too fast for us to obtain a reliable signal, so we *could* be letting
    // a regression through.
    //
    // We require at least a non-zero value.
    #[test]
    fn zero_values_rejected() {
        type Bench = Kernel<diskann_wide::arch::Scalar, f32, f32>;

        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 0)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 0)];
        let outcome = Bench::check(&tolerance(0.05), &tiny_op(), &before, &after).unwrap();

        assert!(matches!(outcome, PassFail::Fail(_)));
    }
}