Skip to main content

diskann_benchmark_simd/
lib.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! SIMD distance kernel benchmarks with regression detection.
7
8use std::{io::Write, num::NonZeroUsize};
9
10use diskann_utils::views::{Matrix, MatrixView};
11use diskann_vector::distance::simd;
12use diskann_wide::Architecture;
13use half::f16;
14use rand::{
15    distr::{Distribution, StandardUniform},
16    rngs::StdRng,
17    SeedableRng,
18};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use diskann_benchmark_runner::{
23    benchmark::{PassFail, Regression},
24    dispatcher::{Description, DispatchRule, FailureScore, MatchScore},
25    utils::{
26        datatype::{self, DataType},
27        num::{relative_change, NonNegativeFinite},
28        percentiles, MicroSeconds,
29    },
30    Any, Benchmark, CheckDeserialization, Checker, Input,
31};
32
33////////////////
34// Public API //
35////////////////
36
/// Registers all SIMD distance-kernel benchmarks with the runner's dispatcher.
///
/// Thin public entry point; the per-architecture registrations live in
/// `register_benchmarks_impl` below.
pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    register_benchmarks_impl(dispatcher)
}
40
41///////////
42// Utils //
43///////////
44
/// Borrowing newtype used to attach local `Display` implementations to
/// types (e.g. `[RunResult]`) this crate does not own.
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

impl<T: ?Sized> std::ops::Deref for DisplayWrapper<'_, T> {
    type Target = T;

    /// Transparently expose the wrapped reference.
    fn deref(&self) -> &Self::Target {
        self.0
    }
}
54
55////////////
56// Inputs //
57////////////
58
/// Distance / similarity function exercised by a benchmark run.
///
/// Serialized in `snake_case` (e.g. `"squared_l2"`), matching the `Display`
/// implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMeasure {
    /// Squared Euclidean (L2) distance.
    SquaredL2,
    /// Inner-product (dot-product) similarity.
    InnerProduct,
    /// Cosine similarity.
    Cosine,
}
66
67impl std::fmt::Display for SimilarityMeasure {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        let st = match self {
70            Self::SquaredL2 => "squared_l2",
71            Self::InnerProduct => "inner_product",
72            Self::Cosine => "cosine",
73        };
74        write!(f, "{}", st)
75    }
76}
77
/// Implementation target selected for a benchmark, serialized in
/// `kebab-case` (e.g. `"x86-64-v4"`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Arch {
    /// x86-64-v4 micro-architecture level.
    #[serde(rename = "x86-64-v4")]
    #[allow(non_camel_case_types)]
    X86_64_V4,
    /// x86-64-v3 micro-architecture level.
    #[serde(rename = "x86-64-v3")]
    #[allow(non_camel_case_types)]
    X86_64_V3,
    /// AArch64 NEON.
    Neon,
    /// Architecture-agnostic scalar kernels (compilation target CPU).
    Scalar,
    /// Loop-based reference implementations from the `reference` module.
    Reference,
}
91
92impl std::fmt::Display for Arch {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        let st = match self {
95            Self::X86_64_V4 => "x86-64-v4",
96            Self::X86_64_V3 => "x86-64-v3",
97            Self::Neon => "neon",
98            Self::Scalar => "scalar",
99            Self::Reference => "reference",
100        };
101        write!(f, "{}", st)
102    }
103}
104
/// Configuration for a single benchmark measurement loop.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct Run {
    /// Distance function to benchmark.
    distance: SimilarityMeasure,
    /// Vector dimensionality.
    dim: NonZeroUsize,
    /// Number of data vectors compared against the query per sweep.
    num_points: NonZeroUsize,
    /// Inner-loop sweeps folded into a single latency measurement.
    loops_per_measurement: NonZeroUsize,
    /// Number of latency samples to collect.
    num_measurements: NonZeroUsize,
}
113
/// Top-level benchmark input: element types, target architecture, and the
/// list of runs to execute.
#[derive(Debug, Serialize, Deserialize)]
pub struct SimdOp {
    /// Element type of the query vector.
    query_type: DataType,
    /// Element type of the data vectors.
    data_type: DataType,
    /// Architecture whose kernels should be benchmarked.
    arch: Arch,
    /// Individual benchmark configurations.
    runs: Vec<Run>,
}

impl CheckDeserialization for SimdOp {
    // No cross-field invariants beyond what serde already enforces.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
127
// Writes one `label: value` line with the label right-aligned in an
// 18-character column; used by the summary/`Display` implementations below.
macro_rules! write_field {
    ($f:ident, $field:tt, $($expr:tt)*) => {
        writeln!($f, "{:>18}: {}", $field, $($expr)*)
    }
}
133
134impl SimdOp {
135    fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        write_field!(f, "query type", self.query_type)?;
137        write_field!(f, "data type", self.data_type)?;
138        write_field!(f, "arch", self.arch)?;
139        write_field!(f, "number of runs", self.runs.len())?;
140        Ok(())
141    }
142}
143
144impl std::fmt::Display for SimdOp {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        writeln!(f, "SIMD Operation\n")?;
147        write_field!(f, "tag", Self::tag())?;
148        self.summarize_fields(f)
149    }
150}
151
152impl Input for SimdOp {
153    fn tag() -> &'static str {
154        "simd-op"
155    }
156
157    fn try_deserialize(
158        serialized: &serde_json::Value,
159        checker: &mut Checker,
160    ) -> anyhow::Result<Any> {
161        checker.any(Self::deserialize(serialized)?)
162    }
163
164    fn example() -> anyhow::Result<serde_json::Value> {
165        const DIM: [NonZeroUsize; 2] = [
166            NonZeroUsize::new(128).unwrap(),
167            NonZeroUsize::new(150).unwrap(),
168        ];
169
170        const NUM_POINTS: [NonZeroUsize; 2] = [
171            NonZeroUsize::new(1000).unwrap(),
172            NonZeroUsize::new(800).unwrap(),
173        ];
174
175        const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
176        const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
177
178        let runs = vec![
179            Run {
180                distance: SimilarityMeasure::SquaredL2,
181                dim: DIM[0],
182                num_points: NUM_POINTS[0],
183                loops_per_measurement: LOOPS_PER_MEASUREMENT,
184                num_measurements: NUM_MEASUREMENTS,
185            },
186            Run {
187                distance: SimilarityMeasure::InnerProduct,
188                dim: DIM[1],
189                num_points: NUM_POINTS[1],
190                loops_per_measurement: LOOPS_PER_MEASUREMENT,
191                num_measurements: NUM_MEASUREMENTS,
192            },
193        ];
194
195        Ok(serde_json::to_value(&Self {
196            query_type: DataType::Float32,
197            data_type: DataType::Float32,
198            arch: Arch::X86_64_V3,
199            runs,
200        })?)
201    }
202}
203
204//////////////////////
205// Regression Check //
206//////////////////////
207
/// Tolerance thresholds for SIMD benchmark regression detection.
///
/// Each field specifies the maximum allowed relative increase in the corresponding metric.
/// For example, a value of `0.10` means a 10% increase is tolerated.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
struct SimdTolerance {
    /// Maximum allowed relative increase in a run's minimum time.
    min_time_regression: NonNegativeFinite,
}

impl CheckDeserialization for SimdTolerance {
    // `NonNegativeFinite` already enforces the value invariant at
    // deserialization time, so nothing further to check here.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
222
223impl Input for SimdTolerance {
224    fn tag() -> &'static str {
225        "simd-tolerance"
226    }
227
228    fn try_deserialize(
229        serialized: &serde_json::Value,
230        checker: &mut Checker,
231    ) -> anyhow::Result<Any> {
232        checker.any(Self::deserialize(serialized)?)
233    }
234
235    fn example() -> anyhow::Result<serde_json::Value> {
236        const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) {
237            Ok(v) => v,
238            Err(_) => panic!("use a non-negative finite please"),
239        };
240
241        Ok(serde_json::to_value(SimdTolerance {
242            min_time_regression: EXAMPLE,
243        })?)
244    }
245}
246
/// Per-run comparison result showing before/after percentile differences.
///
/// Times are normalized to nanoseconds per individual distance computation;
/// see `Regression::check` for the derivation.
#[derive(Debug, Serialize)]
struct Comparison {
    /// Run configuration shared by both sides of the comparison.
    run: Run,
    /// Tolerance used when judging this comparison.
    tolerance: SimdTolerance,
    /// Baseline minimum time (ns per distance computation).
    before_min: f64,
    /// Candidate minimum time (ns per distance computation).
    after_min: f64,
}

/// Aggregated result of the regression check across all runs.
#[derive(Debug, Serialize)]
struct CheckResult {
    /// One entry per benchmarked run, in input order.
    checks: Vec<Comparison>,
}
261
262impl std::fmt::Display for CheckResult {
263    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264        let header = [
265            "Distance",
266            "Dim",
267            "Min Before (ns)",
268            "Min After (ns)",
269            "Change (%)",
270            "Remark",
271        ];
272
273        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.checks.len());
274
275        for (i, c) in self.checks.iter().enumerate() {
276            let mut row = table.row(i);
277            let change = relative_change(c.before_min, c.after_min);
278
279            row.insert(c.run.distance, 0);
280            row.insert(c.run.dim, 1);
281            row.insert(format!("{:.3}", c.before_min), 2);
282            row.insert(format!("{:.3}", c.after_min), 3);
283            match change {
284                Ok(change) => {
285                    row.insert(format!("{:.3} %", change * 100.0), 4);
286                    if change > c.tolerance.min_time_regression.get() {
287                        row.insert("FAIL", 5);
288                    }
289                }
290                Err(err) => {
291                    row.insert("invalid", 4);
292                    row.insert(err, 5);
293                }
294            }
295        }
296
297        table.fmt(f)
298    }
299}
300
301////////////////////////////
302// Benchmark Registration //
303////////////////////////////
304
305fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
306    // x86-64-v4
307    #[cfg(target_arch = "x86_64")]
308    {
309        dispatcher.register_regression(
310            "simd-op-f32xf32-x86_64_V4",
311            Kernel::<diskann_wide::arch::x86_64::V4, f32, f32>::new(),
312        );
313        dispatcher.register_regression(
314            "simd-op-f16xf16-x86_64_V4",
315            Kernel::<diskann_wide::arch::x86_64::V4, f16, f16>::new(),
316        );
317        dispatcher.register_regression(
318            "simd-op-u8xu8-x86_64_V4",
319            Kernel::<diskann_wide::arch::x86_64::V4, u8, u8>::new(),
320        );
321        dispatcher.register_regression(
322            "simd-op-i8xi8-x86_64_V4",
323            Kernel::<diskann_wide::arch::x86_64::V4, i8, i8>::new(),
324        );
325    }
326
327    // x86-64-v3
328    #[cfg(target_arch = "x86_64")]
329    {
330        dispatcher.register_regression(
331            "simd-op-f32xf32-x86_64_V3",
332            Kernel::<diskann_wide::arch::x86_64::V3, f32, f32>::new(),
333        );
334        dispatcher.register_regression(
335            "simd-op-f16xf16-x86_64_V3",
336            Kernel::<diskann_wide::arch::x86_64::V3, f16, f16>::new(),
337        );
338        dispatcher.register_regression(
339            "simd-op-u8xu8-x86_64_V3",
340            Kernel::<diskann_wide::arch::x86_64::V3, u8, u8>::new(),
341        );
342        dispatcher.register_regression(
343            "simd-op-i8xi8-x86_64_V3",
344            Kernel::<diskann_wide::arch::x86_64::V3, i8, i8>::new(),
345        );
346    }
347
348    // aarch64-neon
349    #[cfg(target_arch = "aarch64")]
350    {
351        dispatcher.register_regression(
352            "simd-op-f32xf32-aarch64_neon",
353            Kernel::<diskann_wide::arch::aarch64::Neon, f32, f32>::new(),
354        );
355        dispatcher.register_regression(
356            "simd-op-f16xf16-aarch64_neon",
357            Kernel::<diskann_wide::arch::aarch64::Neon, f16, f16>::new(),
358        );
359        dispatcher.register_regression(
360            "simd-op-u8xu8-aarch64_neon",
361            Kernel::<diskann_wide::arch::aarch64::Neon, u8, u8>::new(),
362        );
363        dispatcher.register_regression(
364            "simd-op-i8xi8-aarch64_neon",
365            Kernel::<diskann_wide::arch::aarch64::Neon, i8, i8>::new(),
366        );
367    }
368
369    // scalar
370    dispatcher.register_regression(
371        "simd-op-f32xf32-scalar",
372        Kernel::<diskann_wide::arch::Scalar, f32, f32>::new(),
373    );
374    dispatcher.register_regression(
375        "simd-op-f16xf16-scalar",
376        Kernel::<diskann_wide::arch::Scalar, f16, f16>::new(),
377    );
378    dispatcher.register_regression(
379        "simd-op-u8xu8-scalar",
380        Kernel::<diskann_wide::arch::Scalar, u8, u8>::new(),
381    );
382    dispatcher.register_regression(
383        "simd-op-i8xi8-scalar",
384        Kernel::<diskann_wide::arch::Scalar, i8, i8>::new(),
385    );
386
387    // reference
388    dispatcher.register_regression(
389        "simd-op-f32xf32-reference",
390        Kernel::<Reference, f32, f32>::new(),
391    );
392    dispatcher.register_regression(
393        "simd-op-f16xf16-reference",
394        Kernel::<Reference, f16, f16>::new(),
395    );
396    dispatcher.register_regression(
397        "simd-op-u8xu8-reference",
398        Kernel::<Reference, u8, u8>::new(),
399    );
400    dispatcher.register_regression(
401        "simd-op-i8xi8-reference",
402        Kernel::<Reference, i8, i8>::new(),
403    );
404}
405
406//////////////
407// Dispatch //
408//////////////
409
/// Dispatch receiver for the reference implementations.
struct Reference;

/// A dispatch mapper for `wide` types.
///
/// Wraps an architecture token so `DispatchRule<Arch>` can be implemented
/// once per architecture without overlapping impls.
#[derive(Debug)]
struct Identity<T>(T);

/// Zero-sized benchmark kernel parameterized by architecture `A`, query
/// element type `Q`, and data element type `D`. The type parameters select
/// the dispatch and `RunBenchmark` implementations.
struct Kernel<A, Q, D> {
    _type: std::marker::PhantomData<(A, Q, D)>,
}

impl<A, Q, D> Kernel<A, Q, D> {
    /// Creates a new (stateless) kernel marker.
    fn new() -> Self {
        Self {
            _type: std::marker::PhantomData,
        }
    }
}

// Map Architectures to the enum.
/// Error returned when the requested architecture is unavailable on the
/// executing CPU.
#[derive(Debug, Error)]
#[error("architecture {0} is not supported by this CPU")]
pub(crate) struct ArchNotSupported(Arch);
433
434impl DispatchRule<Arch> for Identity<Reference> {
435    type Error = ArchNotSupported;
436
437    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
438        if *from == Arch::Reference {
439            Ok(MatchScore(0))
440        } else {
441            Err(FailureScore(1))
442        }
443    }
444
445    fn convert(from: Arch) -> Result<Self, Self::Error> {
446        assert_eq!(from, Arch::Reference);
447        Ok(Identity(Reference))
448    }
449
450    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
451        match from {
452            None => write!(f, "loop based"),
453            Some(arch) => {
454                if Self::try_match(arch).is_ok() {
455                    write!(f, "matched {}", arch)
456                } else {
457                    write!(f, "expected {}, got {}", Arch::Reference, arch)
458                }
459            }
460        }
461    }
462}
463
464impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
465    type Error = ArchNotSupported;
466
467    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
468        if *from == Arch::Scalar {
469            Ok(MatchScore(0))
470        } else {
471            Err(FailureScore(1))
472        }
473    }
474
475    fn convert(from: Arch) -> Result<Self, Self::Error> {
476        assert_eq!(from, Arch::Scalar);
477        Ok(Identity(diskann_wide::arch::Scalar))
478    }
479
480    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
481        match from {
482            None => write!(f, "scalar (compilation target CPU)"),
483            Some(arch) => {
484                if Self::try_match(arch).is_ok() {
485                    write!(f, "matched {}", arch)
486                } else {
487                    write!(f, "expected {}, got {}", Arch::Scalar, arch)
488                }
489            }
490        }
491    }
492}
493
// Implements `DispatchRule<Arch>` for one hardware architecture token.
// The impl is compiled only when `target_arch` matches, and performs runtime
// feature detection via `new_checked` since the ISA may still be absent on a
// matching build target.
macro_rules! match_arch {
    ($target_arch:literal, $arch:path, $enum:ident) => {
        #[cfg(target_arch = $target_arch)]
        impl DispatchRule<Arch> for Identity<$arch> {
            type Error = ArchNotSupported;

            fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
                let available = <$arch>::new_checked().is_some();
                if available && *from == Arch::$enum {
                    Ok(MatchScore(0))
                } else if !available && *from == Arch::$enum {
                    // Requested-but-unsupported ranks as a closer failure
                    // (score 0) than requesting a different architecture.
                    Err(FailureScore(0))
                } else {
                    Err(FailureScore(1))
                }
            }

            fn convert(from: Arch) -> Result<Self, Self::Error> {
                assert_eq!(from, Arch::$enum);
                <$arch>::new_checked()
                    .ok_or(ArchNotSupported(from))
                    .map(Identity)
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                from: Option<&Arch>,
            ) -> std::fmt::Result {
                let available = <$arch>::new_checked().is_some();
                match from {
                    None => write!(f, "{}", Arch::$enum),
                    Some(arch) => {
                        if Self::try_match(arch).is_ok() {
                            write!(f, "matched {}", arch)
                        } else if !available && *arch == Arch::$enum {
                            write!(f, "matched {} but unsupported by this CPU", Arch::$enum)
                        } else {
                            write!(f, "expected {}, got {}", Arch::$enum, arch)
                        }
                    }
                }
            }
        }
    };
}
539
// Instantiate the dispatch rules for each supported hardware architecture.
match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
543
impl<A, Q, D> Benchmark for Kernel<A, Q, D>
where
    // Both element types must be expressible as runtime `DataType`s.
    datatype::Type<Q>: DispatchRule<datatype::DataType>,
    datatype::Type<D>: DispatchRule<datatype::DataType>,
    // The architecture token must be resolvable from the `Arch` enum.
    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
    // The concrete benchmark body for this (arch, query, data) combination.
    Kernel<A, Q, D>: RunBenchmark<A>,
    A: 'static,
    Q: 'static,
    D: 'static,
{
    type Input = SimdOp;
    type Output = Vec<RunResult>;

    // Matching simply requires that we match the inner type.
    //
    // Scoring: each mismatched element type adds 10 to the failure score; a
    // mismatched architecture adds 2 plus the architecture's own failure
    // score, so type mismatches dominate the ranking. `None` means nothing
    // mismatched, i.e. a perfect match.
    fn try_match(&self, from: &SimdOp) -> Result<MatchScore, FailureScore> {
        let mut failscore: Option<u32> = None;
        if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
            *failscore.get_or_insert(0) += 10;
        }
        if datatype::Type::<D>::try_match(&from.data_type).is_err() {
            *failscore.get_or_insert(0) += 10;
        }
        if let Err(FailureScore(score)) = Identity::<A>::try_match(&from.arch) {
            *failscore.get_or_insert(0) += 2 + score;
        }

        match failscore {
            None => Ok(MatchScore(0)),
            Some(score) => Err(FailureScore(score)),
        }
    }

    /// Resolves the architecture, executes every configured run, and writes
    /// a summary of the input followed by the results table to `output`.
    fn run(
        &self,
        input: &SimdOp,
        _: diskann_benchmark_runner::Checkpoint<'_>,
        mut output: &mut dyn diskann_benchmark_runner::Output,
    ) -> anyhow::Result<Self::Output> {
        // Fails with `ArchNotSupported` if the executing CPU lacks the ISA.
        let arch = Identity::<A>::convert(input.arch)?.0;
        writeln!(output, "{}", input)?;
        let results = self.run_benchmark(input, arch)?;
        writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
        Ok(results)
    }

    /// Describes this kernel generically (when `input` is `None`) or
    /// explains which parts of a concrete `input` failed to match.
    fn description(
        &self,
        f: &mut std::fmt::Formatter<'_>,
        input: Option<&SimdOp>,
    ) -> std::fmt::Result {
        match input {
            None => {
                writeln!(
                    f,
                    "- Query Type: {}",
                    Description::<datatype::DataType, datatype::Type<Q>>::new()
                )?;
                writeln!(
                    f,
                    "- Data Type: {}",
                    Description::<datatype::DataType, datatype::Type<D>>::new()
                )?;
                writeln!(
                    f,
                    "- Implementation: {}",
                    Description::<Arch, Identity<A>>::new()
                )?;
            }
            Some(input) => {
                if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
                    writeln!(f, "\n    - Mismatched query type: {}", err)?;
                }
                if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
                    writeln!(f, "\n    - Mismatched data type: {}", err)?;
                }
                if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
                    writeln!(f, "\n    - Mismatched architecture: {}", err)?;
                }
            }
        }
        Ok(())
    }
}
627
impl<A, Q, D> Regression for Kernel<A, Q, D>
where
    datatype::Type<Q>: DispatchRule<datatype::DataType>,
    datatype::Type<D>: DispatchRule<datatype::DataType>,
    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
    Kernel<A, Q, D>: RunBenchmark<A>,
    A: 'static,
    Q: 'static,
    D: 'static,
{
    type Tolerances = SimdTolerance;
    // Both outcomes carry the full comparison table for reporting.
    type Pass = CheckResult;
    type Fail = CheckResult;

    /// Compares `before`/`after` minimum times run-by-run; fails if any run
    /// regressed by more than the configured tolerance or if a relative
    /// change could not be computed.
    fn check(
        &self,
        tolerance: &SimdTolerance,
        _input: &SimdOp,
        before: &Vec<RunResult>,
        after: &Vec<RunResult>,
    ) -> anyhow::Result<PassFail<CheckResult, CheckResult>> {
        anyhow::ensure!(
            before.len() == after.len(),
            "before has {} runs but after has {}",
            before.len(),
            after.len(),
        );

        let mut passed = true;
        let checks: Vec<Comparison> = std::iter::zip(before.iter(), after.iter())
            .enumerate()
            .map(|(i, (b, a))| {
                // Runs must be identical so the normalization factor below is
                // valid for both sides of the comparison.
                anyhow::ensure!(b.run == a.run, "run {i} mismatched");

                let computations_per_latency = b.computations_per_latency() as f64;

                // Convert the minimum latency (micro-seconds) into
                // nano-seconds per individual distance computation.
                let before_min = b.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
                let after_min = a.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;

                let comparison = Comparison {
                    run: b.run.clone(),
                    tolerance: *tolerance,
                    before_min,
                    after_min,
                };

                // Determine whether or not we pass. An invalid relative
                // change (see `relative_change`) also counts as a failure.
                match relative_change(before_min, after_min) {
                    Ok(change) => {
                        if change > tolerance.min_time_regression.get() {
                            passed = false;
                        }
                    }
                    Err(_) => passed = false,
                };

                Ok(comparison)
            })
            .collect::<anyhow::Result<Vec<Comparison>>>()?;

        let check = CheckResult { checks };

        if passed {
            Ok(PassFail::Pass(check))
        } else {
            Ok(PassFail::Fail(check))
        }
    }
}
697
698///////////////
699// Benchmark //
700///////////////
701
/// Architecture-specific benchmark body, implemented per `(A, Q, D)`
/// combination by the `stamp!` macro below.
trait RunBenchmark<A> {
    /// Executes every run in `input` on `arch`, returning one result per run.
    fn run_benchmark(&self, input: &SimdOp, arch: A) -> Result<Vec<RunResult>, anyhow::Error>;
}

/// Timing results for a single run.
#[derive(Debug, Serialize, Deserialize)]
struct RunResult {
    /// The configuration for this run.
    run: Run,
    /// The latencies of individual runs.
    latencies: Vec<MicroSeconds>,
    /// Latency percentiles.
    percentiles: percentiles::Percentiles<MicroSeconds>,
}
715
impl RunResult {
    /// Number of individual distance computations folded into each latency
    /// sample: one per data point per inner loop.
    fn computations_per_latency(&self) -> usize {
        self.run.num_points.get() * self.run.loops_per_measurement.get()
    }
}
721
722impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
723    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
724        if self.is_empty() {
725            return Ok(());
726        }
727
728        let header = [
729            "Distance",
730            "Dim",
731            "Min Time (ns)",
732            "Mean Time (ns)",
733            "Points",
734            "Loops",
735            "Measurements",
736        ];
737
738        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
739
740        self.iter().enumerate().for_each(|(row, r)| {
741            let mut row = table.row(row);
742
743            let min_latency = r
744                .latencies
745                .iter()
746                .min()
747                .copied()
748                .unwrap_or(MicroSeconds::new(u64::MAX));
749            let mean_latency = r.percentiles.mean;
750
751            let computations_per_latency = r.computations_per_latency() as f64;
752
753            // Convert time from micro-seconds to nano-seconds.
754            let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
755            let mean_time = mean_latency / computations_per_latency * 1000.0;
756
757            row.insert(r.run.distance, 0);
758            row.insert(r.run.dim, 1);
759            row.insert(format!("{:.3}", min_time), 2);
760            row.insert(format!("{:.3}", mean_time), 3);
761            row.insert(r.run.num_points, 4);
762            row.insert(r.run.loops_per_measurement, 5);
763            row.insert(r.run.num_measurements, 6);
764        });
765
766        table.fmt(f)
767    }
768}
769
770fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
771where
772    F: Fn(&[Q], &[D]) -> f32,
773{
774    let mut latencies = Vec::with_capacity(run.num_measurements.get());
775    let mut dst = vec![0.0; data.nrows()];
776
777    for _ in 0..run.num_measurements.get() {
778        let start = std::time::Instant::now();
779        for _ in 0..run.loops_per_measurement.get() {
780            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
781                *d = f(query, r);
782            });
783            std::hint::black_box(&mut dst);
784        }
785        latencies.push(start.elapsed().into());
786    }
787
788    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
789    RunResult {
790        run: run.clone(),
791        latencies,
792        percentiles,
793    }
794}
795
/// Randomly generated benchmark inputs: one query vector and a matrix of
/// data vectors.
struct Data<Q, D> {
    /// Query vector of length `run.dim`.
    query: Box<[Q]>,
    /// `run.num_points` rows, each of dimension `run.dim`.
    data: Matrix<D>,
}

impl<Q, D> Data<Q, D> {
    /// Generates the query and data matrix for `run` using a fixed seed so
    /// repeated benchmark invocations see identical inputs.
    fn new(run: &Run) -> Self
    where
        StandardUniform: Distribution<Q>,
        StandardUniform: Distribution<D>,
    {
        let mut rng = StdRng::seed_from_u64(0x12345);
        let query: Box<[Q]> = (0..run.dim.get())
            .map(|_| StandardUniform.sample(&mut rng))
            .collect();
        let data = Matrix::<D>::new(
            diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
            run.num_points.get(),
            run.dim.get(),
        );

        Self { query, data }
    }

    /// Times `f` over the generated query/data according to `run`.
    fn run<F>(&self, run: &Run, f: F) -> RunResult
    where
        F: Fn(&[Q], &[D]) -> f32,
    {
        run_loops(&self.query, self.data.as_view(), run, f)
    }
}
827
828/////////////////////
829// Implementations //
830/////////////////////
831
// Stamps out `RunBenchmark` implementations. Three arms:
//   * `reference, $Q, $D, ...` — dispatches to the named loop-based
//     functions in the `reference` module.
//   * `$arch, $Q, $D` — dispatches to the SIMD kernels through the
//     architecture's `run2` entry point.
//   * `$target_arch, $arch, $Q, $D` — same as the previous arm, but gated on
//     `cfg(target_arch = ...)`.
macro_rules! stamp {
    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
        impl RunBenchmark<Reference> for Kernel<Reference, $Q, $D> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                _arch: Reference,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();
                for run in input.runs.iter() {
                    let data = Data::<$Q, $D>::new(run);
                    // Select the reference function matching the requested
                    // distance measure.
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
                    };
                    results.push(result);
                }
                Ok(results)
            }
        }
    };
    ($arch:path, $Q:ty, $D:ty) => {
        impl RunBenchmark<$arch> for Kernel<$arch, $Q, $D> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                arch: $arch,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();

                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                for run in input.runs.iter() {
                    let data = Data::<$Q, $D>::new(run);
                    // For each kernel, we need to do a two-step wrapping of closures so
                    // the inner-most closure is executed by the architecture.
                    //
                    // This is required for the implementation of `simd_op` to be inlined
                    // into the architecture run function so it can properly inherit
                    // target features.
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(l2, arch, q, d), q, d)
                        }),
                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(ip, arch, q, d), q, d)
                        }),
                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(cosine, arch, q, d), q, d)
                        }),
                    };
                    results.push(result)
                }
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Q, $D);
    };
}
897
// SIMD kernels: x86-64-v4.
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);

// SIMD kernels: x86-64-v3.
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);

// SIMD kernels: aarch64 NEON.
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);

// SIMD kernels: architecture-agnostic scalar.
stamp!(diskann_wide::arch::Scalar, f32, f32);
stamp!(diskann_wide::arch::Scalar, f16, f16);
stamp!(diskann_wide::arch::Scalar, u8, u8);
stamp!(diskann_wide::arch::Scalar, i8, i8);

// Loop-based reference implementations.
stamp!(
    reference,
    f32,
    f32,
    squared_l2_f32,
    inner_product_f32,
    cosine_f32
);
stamp!(
    reference,
    f16,
    f16,
    squared_l2_f16,
    inner_product_f16,
    cosine_f16
);
stamp!(
    reference,
    u8,
    u8,
    squared_l2_u8,
    inner_product_u8,
    cosine_u8
);
stamp!(
    reference,
    i8,
    i8,
    squared_l2_i8,
    inner_product_i8,
    cosine_i8
);
950
951///////////////
952// Reference //
953///////////////
954
955// These are largely copied from the implementations in vector, with a tweak that we don't
956// use FMA when the current architecture is scalar.
957mod reference {
958    use diskann_wide::ARCH;
959    use half::f16;
960
961    trait MaybeFMA {
962        // Perform `a*b + c` using FMA when a hardware instruction is guaranteed to be
963        // available, otherwise decompose into a multiply and add.
964        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
965    }
966
967    impl MaybeFMA for diskann_wide::arch::Scalar {
968        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
969            a * b + c
970        }
971    }
972
973    #[cfg(target_arch = "x86_64")]
974    impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
975        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
976            a.mul_add(b, c)
977        }
978    }
979
980    #[cfg(target_arch = "x86_64")]
981    impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
982        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
983            a.mul_add(b, c)
984        }
985    }
986
987    #[cfg(target_arch = "aarch64")]
988    impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
989        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
990            a.mul_add(b, c)
991        }
992    }
993
994    //------------//
995    // Squared L2 //
996    //------------//
997
998    pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
999        assert_eq!(x.len(), y.len());
1000        std::iter::zip(x.iter(), y.iter())
1001            .map(|(&a, &b)| {
1002                let a: i32 = a.into();
1003                let b: i32 = b.into();
1004                let diff = a - b;
1005                diff * diff
1006            })
1007            .sum::<i32>() as f32
1008    }
1009
1010    pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
1011        assert_eq!(x.len(), y.len());
1012        std::iter::zip(x.iter(), y.iter())
1013            .map(|(&a, &b)| {
1014                let a: i32 = a.into();
1015                let b: i32 = b.into();
1016                let diff = a - b;
1017                diff * diff
1018            })
1019            .sum::<i32>() as f32
1020    }
1021
1022    pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
1023        assert_eq!(x.len(), y.len());
1024        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1025            let a: f32 = a.into();
1026            let b: f32 = b.into();
1027            let diff = a - b;
1028            ARCH.maybe_fma(diff, diff, acc)
1029        })
1030    }
1031
1032    pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
1033        assert_eq!(x.len(), y.len());
1034        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1035            let diff = a - b;
1036            ARCH.maybe_fma(diff, diff, acc)
1037        })
1038    }
1039
1040    //---------------//
1041    // Inner Product //
1042    //---------------//
1043
1044    pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
1045        assert_eq!(x.len(), y.len());
1046        std::iter::zip(x.iter(), y.iter())
1047            .map(|(&a, &b)| {
1048                let a: i32 = a.into();
1049                let b: i32 = b.into();
1050                a * b
1051            })
1052            .sum::<i32>() as f32
1053    }
1054
1055    pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
1056        assert_eq!(x.len(), y.len());
1057        std::iter::zip(x.iter(), y.iter())
1058            .map(|(&a, &b)| {
1059                let a: i32 = a.into();
1060                let b: i32 = b.into();
1061                a * b
1062            })
1063            .sum::<i32>() as f32
1064    }
1065
1066    pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
1067        assert_eq!(x.len(), y.len());
1068        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1069            let a: f32 = a.into();
1070            let b: f32 = b.into();
1071            ARCH.maybe_fma(a, b, acc)
1072        })
1073    }
1074
1075    pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
1076        assert_eq!(x.len(), y.len());
1077        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
1078    }
1079
1080    //--------//
1081    // Cosine //
1082    //--------//
1083
1084    #[derive(Default)]
1085    struct XY<T> {
1086        xnorm: T,
1087        ynorm: T,
1088        xy: T,
1089    }
1090
1091    pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1092        assert_eq!(x.len(), y.len());
1093        let r: XY<i32> =
1094            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1095                let vx: i32 = vx.into();
1096                let vy: i32 = vy.into();
1097                XY {
1098                    xnorm: acc.xnorm + vx * vx,
1099                    ynorm: acc.ynorm + vy * vy,
1100                    xy: acc.xy + vx * vy,
1101                }
1102            });
1103
1104        if r.xnorm == 0 || r.ynorm == 0 {
1105            return 0.0;
1106        }
1107
1108        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1109    }
1110
1111    pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1112        assert_eq!(x.len(), y.len());
1113        let r: XY<i32> =
1114            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1115                let vx: i32 = vx.into();
1116                let vy: i32 = vy.into();
1117                XY {
1118                    xnorm: acc.xnorm + vx * vx,
1119                    ynorm: acc.ynorm + vy * vy,
1120                    xy: acc.xy + vx * vy,
1121                }
1122            });
1123
1124        if r.xnorm == 0 || r.ynorm == 0 {
1125            return 0.0;
1126        }
1127
1128        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1129    }
1130
1131    pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1132        assert_eq!(x.len(), y.len());
1133        let r: XY<f32> =
1134            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1135                let vx: f32 = vx.into();
1136                let vy: f32 = vy.into();
1137                XY {
1138                    xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1139                    ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1140                    xy: ARCH.maybe_fma(vx, vy, acc.xy),
1141                }
1142            });
1143
1144        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1145            return 0.0;
1146        }
1147
1148        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1149    }
1150
1151    pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1152        assert_eq!(x.len(), y.len());
1153        let r: XY<f32> =
1154            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1155                xnorm: vx.mul_add(vx, acc.xnorm),
1156                ynorm: vy.mul_add(vy, acc.ynorm),
1157                xy: vx.mul_add(vy, acc.xy),
1158            });
1159
1160        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1161            return 0.0;
1162        }
1163
1164        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1165    }
1166}
1167
1168///////////
1169// Tests //
1170///////////
1171
#[cfg(test)]
mod tests {
    use super::*;

    use diskann_benchmark_runner::{
        benchmark::{PassFail, Regression},
        utils::percentiles::compute_percentiles,
    };

    /// Smallest usable run configuration: 8 dimensions, one point, one
    /// single-loop measurement.
    fn tiny_run(distance: SimilarityMeasure) -> Run {
        let one = NonZeroUsize::new(1).unwrap();
        Run {
            distance,
            dim: NonZeroUsize::new(8).unwrap(),
            num_points: one,
            loops_per_measurement: one,
            num_measurements: one,
        }
    }

    /// Scalar f32-vs-f32 op containing a single squared-L2 tiny run.
    fn tiny_op() -> SimdOp {
        SimdOp {
            query_type: DataType::Float32,
            data_type: DataType::Float32,
            arch: Arch::Scalar,
            runs: vec![tiny_run(SimilarityMeasure::SquaredL2)],
        }
    }

    /// Result for `tiny_run(distance)` whose only latency sample is
    /// `minimum` microseconds.
    fn tiny_result(distance: SimilarityMeasure, minimum: u64) -> RunResult {
        let mut latencies = vec![MicroSeconds::new(minimum)];
        let percentiles = compute_percentiles(&mut latencies).unwrap();
        RunResult {
            run: tiny_run(distance),
            latencies,
            percentiles,
        }
    }

    /// Tolerance allowing at most `limit` relative regression of the
    /// minimum time.
    fn tolerance(limit: f64) -> SimdTolerance {
        SimdTolerance {
            min_time_regression: NonNegativeFinite::new(limit).unwrap(),
        }
    }

    #[test]
    fn check_rejects_mismatched_runs() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        // "Before" and "after" describe different distance functions, so the
        // comparison itself is invalid and must error out.
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::Cosine, 100)];

        let err = kernel
            .check(&tolerance(0.0), &tiny_op(), &before, &after)
            .unwrap_err();

        assert_eq!(err.to_string(), "run 0 mismatched");
    }

    #[test]
    fn check_allows_negative_relative_change() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        // 100us -> 95us is an improvement; even a zero tolerance passes.
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 95)];

        let outcome = kernel
            .check(&tolerance(0.0), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Pass(_)));
    }

    #[test]
    fn check_passes_on_tolerance_boundary() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        // 100us -> 105us is exactly a 5% regression; a 5% tolerance is
        // inclusive and passes.
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 105)];

        let outcome = kernel
            .check(&tolerance(0.05), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Pass(_)));
    }

    #[test]
    fn check_fails_above_tolerance_boundary() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        // 100us -> 106us exceeds the 5% tolerance and must fail.
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 106)];

        let outcome = kernel
            .check(&tolerance(0.05), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Fail(_)));
    }

    #[test]
    fn check_result_display_includes_failure_details() {
        // A 6% regression against a 5% tolerance: the rendered report should
        // contain the header, the run's distance, both measurements, the
        // relative change, and the FAIL marker.
        let check = CheckResult {
            checks: vec![Comparison {
                run: tiny_run(SimilarityMeasure::SquaredL2),
                tolerance: tolerance(0.05),
                before_min: 100.0,
                after_min: 106.0,
            }],
        };

        let rendered = check.to_string();
        for needle in ["Distance", "squared_l2", "100.000", "106.000", "6.000 %", "FAIL"] {
            assert!(rendered.contains(needle), "rendered = {rendered}");
        }
    }

    // If a "before" value is 0, we should fail with an error because this means the
    // measurement was too fast for us to obtain a reliable signal, so we *could* be letting
    // a regression through.
    //
    // We require at least a non-zero value.
    #[test]
    fn zero_values_rejected() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 0)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 0)];

        let outcome = kernel
            .check(&tolerance(0.05), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Fail(_)));
    }
}