Skip to main content

diskann_benchmark_simd/
lib.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! SIMD distance kernel benchmarks with regression detection.
7
8use std::{io::Write, num::NonZeroUsize};
9
10use diskann_utils::views::{Matrix, MatrixView};
11use diskann_vector::distance::simd;
12use diskann_wide::Architecture;
13use half::f16;
14use rand::{
15    distr::{Distribution, StandardUniform},
16    rngs::StdRng,
17    SeedableRng,
18};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use diskann_benchmark_runner::{
23    benchmark::{FailureScore, MatchScore, PassFail, Regression},
24    utils::{
25        datatype::{AsDataType, DataType},
26        num::{relative_change, NonNegativeFinite},
27        percentiles, MicroSeconds,
28    },
29    Benchmark, Checker, Input, Registry,
30};
31
32////////////////
33// Public API //
34////////////////
35
36pub fn register(registry: &mut Registry) -> anyhow::Result<()> {
37    Ok(register_benchmarks_impl(registry)?)
38}
39
40///////////
41// Utils //
42///////////
43
/// Borrowing newtype used to attach `Display` impls to types this crate does not own
/// (for example, slices of results). Dereferences to the wrapped value.
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

impl<'a, T: ?Sized> std::ops::Deref for DisplayWrapper<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
53
54////////////
55// Inputs //
56////////////
57
58#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum SimilarityMeasure {
61    SquaredL2,
62    InnerProduct,
63    Cosine,
64}
65
66impl std::fmt::Display for SimilarityMeasure {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        let st = match self {
69            Self::SquaredL2 => "squared_l2",
70            Self::InnerProduct => "inner_product",
71            Self::Cosine => "cosine",
72        };
73        write!(f, "{}", st)
74    }
75}
76
77#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
78#[serde(rename_all = "kebab-case")]
79enum Arch {
80    #[serde(rename = "x86-64-v4")]
81    #[allow(non_camel_case_types)]
82    X86_64_V4,
83    #[serde(rename = "x86-64-v3")]
84    #[allow(non_camel_case_types)]
85    X86_64_V3,
86    Neon,
87    Scalar,
88    Reference,
89}
90
91impl std::fmt::Display for Arch {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        let st = match self {
94            Self::X86_64_V4 => "x86-64-v4",
95            Self::X86_64_V3 => "x86-64-v3",
96            Self::Neon => "neon",
97            Self::Scalar => "scalar",
98            Self::Reference => "reference",
99        };
100        write!(f, "{}", st)
101    }
102}
103
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105struct Run {
106    distance: SimilarityMeasure,
107    dim: NonZeroUsize,
108    num_points: NonZeroUsize,
109    loops_per_measurement: NonZeroUsize,
110    num_measurements: NonZeroUsize,
111}
112
113#[derive(Debug, Serialize, Deserialize)]
114pub struct SimdOp {
115    query_type: DataType,
116    data_type: DataType,
117    arch: Arch,
118    runs: Vec<Run>,
119}
120
/// Write one `label: value` line, right-aligning the label to 18 columns.
///
/// Expands to the `fmt::Result` of the underlying `writeln!`, so callers can apply `?`.
macro_rules! write_field {
    ($out:ident, $label:tt, $($value:tt)*) => {
        writeln!($out, "{:>18}: {}", $label, $($value)*)
    };
}
126
127impl SimdOp {
128    fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        write_field!(f, "query type", self.query_type)?;
130        write_field!(f, "data type", self.data_type)?;
131        write_field!(f, "arch", self.arch)?;
132        write_field!(f, "number of runs", self.runs.len())?;
133        Ok(())
134    }
135}
136
137impl std::fmt::Display for SimdOp {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        writeln!(f, "SIMD Operation\n")?;
140        write_field!(f, "tag", Self::tag())?;
141        self.summarize_fields(f)
142    }
143}
144
145impl Input for SimdOp {
146    type Raw = Self;
147
148    fn tag() -> &'static str {
149        "simd-op"
150    }
151
152    fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<Self> {
153        Ok(raw)
154    }
155
156    fn serialize(&self) -> anyhow::Result<serde_json::Value> {
157        Ok(serde_json::to_value(self)?)
158    }
159
160    fn example() -> Self::Raw {
161        const DIM: [NonZeroUsize; 2] = [
162            NonZeroUsize::new(128).unwrap(),
163            NonZeroUsize::new(150).unwrap(),
164        ];
165
166        const NUM_POINTS: [NonZeroUsize; 2] = [
167            NonZeroUsize::new(1000).unwrap(),
168            NonZeroUsize::new(800).unwrap(),
169        ];
170
171        const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
172        const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
173
174        let runs = vec![
175            Run {
176                distance: SimilarityMeasure::SquaredL2,
177                dim: DIM[0],
178                num_points: NUM_POINTS[0],
179                loops_per_measurement: LOOPS_PER_MEASUREMENT,
180                num_measurements: NUM_MEASUREMENTS,
181            },
182            Run {
183                distance: SimilarityMeasure::InnerProduct,
184                dim: DIM[1],
185                num_points: NUM_POINTS[1],
186                loops_per_measurement: LOOPS_PER_MEASUREMENT,
187                num_measurements: NUM_MEASUREMENTS,
188            },
189        ];
190
191        Self {
192            query_type: DataType::Float32,
193            data_type: DataType::Float32,
194            arch: Arch::X86_64_V3,
195            runs,
196        }
197    }
198}
199
200//////////////////////
201// Regression Check //
202//////////////////////
203
204/// Tolerance thresholds for SIMD benchmark regression detection.
205///
206/// Each field specifies the maximum allowed relative increase in the corresponding metric.
207/// For example, a value of `0.10` means a 10% increase is tolerated.
208#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
209struct SimdTolerance {
210    min_time_regression: NonNegativeFinite,
211}
212
213impl Input for SimdTolerance {
214    type Raw = Self;
215
216    fn tag() -> &'static str {
217        "simd-tolerance"
218    }
219
220    fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<Self> {
221        Ok(raw)
222    }
223
224    fn serialize(&self) -> anyhow::Result<serde_json::Value> {
225        Ok(serde_json::to_value(self)?)
226    }
227
228    fn example() -> Self {
229        const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) {
230            Ok(v) => v,
231            Err(_) => panic!("use a non-negative finite please"),
232        };
233
234        SimdTolerance {
235            min_time_regression: EXAMPLE,
236        }
237    }
238}
239
240/// Per-run comparison result showing before/after percentile differences.
241#[derive(Debug, Serialize)]
242struct Comparison {
243    run: Run,
244    tolerance: SimdTolerance,
245    before_min: f64,
246    after_min: f64,
247}
248
249/// Aggregated result of the regression check across all runs.
250#[derive(Debug, Serialize)]
251struct CheckResult {
252    checks: Vec<Comparison>,
253}
254
255impl std::fmt::Display for CheckResult {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        let header = [
258            "Distance",
259            "Dim",
260            "Min Before (ns)",
261            "Min After (ns)",
262            "Change (%)",
263            "Remark",
264        ];
265
266        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.checks.len());
267
268        for (i, c) in self.checks.iter().enumerate() {
269            let mut row = table.row(i);
270            let change = relative_change(c.before_min, c.after_min);
271
272            row.insert(c.run.distance, 0);
273            row.insert(c.run.dim, 1);
274            row.insert(format!("{:.3}", c.before_min), 2);
275            row.insert(format!("{:.3}", c.after_min), 3);
276            match change {
277                Ok(change) => {
278                    row.insert(format!("{:.3} %", change * 100.0), 4);
279                    if change > c.tolerance.min_time_regression.get() {
280                        row.insert("FAIL", 5);
281                    }
282                }
283                Err(err) => {
284                    row.insert("invalid", 4);
285                    row.insert(err, 5);
286                }
287            }
288        }
289
290        table.fmt(f)
291    }
292}
293
294////////////////////////////
295// Benchmark Registration //
296////////////////////////////
297
298fn register_benchmarks_impl(
299    registry: &mut diskann_benchmark_runner::Registry,
300) -> Result<(), diskann_benchmark_runner::RegistryError> {
301    // x86-64-v4
302    #[cfg(target_arch = "x86_64")]
303    {
304        registry.register_regression(
305            "simd-op-f32xf32-x86_64_V4",
306            Kernel::<diskann_wide::arch::x86_64::V4, f32, f32>::new(),
307        )?;
308        registry.register_regression(
309            "simd-op-f16xf16-x86_64_V4",
310            Kernel::<diskann_wide::arch::x86_64::V4, f16, f16>::new(),
311        )?;
312        registry.register_regression(
313            "simd-op-u8xu8-x86_64_V4",
314            Kernel::<diskann_wide::arch::x86_64::V4, u8, u8>::new(),
315        )?;
316        registry.register_regression(
317            "simd-op-i8xi8-x86_64_V4",
318            Kernel::<diskann_wide::arch::x86_64::V4, i8, i8>::new(),
319        )?;
320    }
321
322    // x86-64-v3
323    #[cfg(target_arch = "x86_64")]
324    {
325        registry.register_regression(
326            "simd-op-f32xf32-x86_64_V3",
327            Kernel::<diskann_wide::arch::x86_64::V3, f32, f32>::new(),
328        )?;
329        registry.register_regression(
330            "simd-op-f16xf16-x86_64_V3",
331            Kernel::<diskann_wide::arch::x86_64::V3, f16, f16>::new(),
332        )?;
333        registry.register_regression(
334            "simd-op-u8xu8-x86_64_V3",
335            Kernel::<diskann_wide::arch::x86_64::V3, u8, u8>::new(),
336        )?;
337        registry.register_regression(
338            "simd-op-i8xi8-x86_64_V3",
339            Kernel::<diskann_wide::arch::x86_64::V3, i8, i8>::new(),
340        )?;
341    }
342
343    // aarch64-neon
344    #[cfg(target_arch = "aarch64")]
345    {
346        registry.register_regression(
347            "simd-op-f32xf32-aarch64_neon",
348            Kernel::<diskann_wide::arch::aarch64::Neon, f32, f32>::new(),
349        )?;
350        registry.register_regression(
351            "simd-op-f16xf16-aarch64_neon",
352            Kernel::<diskann_wide::arch::aarch64::Neon, f16, f16>::new(),
353        )?;
354        registry.register_regression(
355            "simd-op-u8xu8-aarch64_neon",
356            Kernel::<diskann_wide::arch::aarch64::Neon, u8, u8>::new(),
357        )?;
358        registry.register_regression(
359            "simd-op-i8xi8-aarch64_neon",
360            Kernel::<diskann_wide::arch::aarch64::Neon, i8, i8>::new(),
361        )?;
362    }
363
364    // scalar
365    registry.register_regression(
366        "simd-op-f32xf32-scalar",
367        Kernel::<diskann_wide::arch::Scalar, f32, f32>::new(),
368    )?;
369    registry.register_regression(
370        "simd-op-f16xf16-scalar",
371        Kernel::<diskann_wide::arch::Scalar, f16, f16>::new(),
372    )?;
373    registry.register_regression(
374        "simd-op-u8xu8-scalar",
375        Kernel::<diskann_wide::arch::Scalar, u8, u8>::new(),
376    )?;
377    registry.register_regression(
378        "simd-op-i8xi8-scalar",
379        Kernel::<diskann_wide::arch::Scalar, i8, i8>::new(),
380    )?;
381
382    // reference
383    registry.register_regression(
384        "simd-op-f32xf32-reference",
385        Kernel::<Reference, f32, f32>::new(),
386    )?;
387    registry.register_regression(
388        "simd-op-f16xf16-reference",
389        Kernel::<Reference, f16, f16>::new(),
390    )?;
391    registry.register_regression(
392        "simd-op-u8xu8-reference",
393        Kernel::<Reference, u8, u8>::new(),
394    )?;
395    registry.register_regression(
396        "simd-op-i8xi8-reference",
397        Kernel::<Reference, i8, i8>::new(),
398    )?;
399    Ok(())
400}
401
402//////////////
403// Dispatch //
404//////////////
405
/// Dispatch receiver for the reference implementations.
///
/// Plays the role of a `diskann_wide` architecture token for the plain-loop kernels in the
/// `reference` module.
struct Reference;

/// Zero-sized benchmark handle for architecture `A`, query element type `Q`, and data
/// element type `D`.
struct Kernel<A, Q, D> {
    _type: std::marker::PhantomData<(A, Q, D)>,
}

impl<A, Q, D> Kernel<A, Q, D> {
    /// Create a new, stateless kernel handle.
    fn new() -> Self {
        let _type = std::marker::PhantomData;
        Kernel { _type }
    }
}
420
421#[derive(Debug, Error)]
422#[error("architecture {0} is not supported by this CPU")]
423pub(crate) struct ArchNotSupported(Arch);
424
425/// Lifting architecture enum variants into the Rust type domain.
426trait AsArch: Sized + 'static {
427    const ARCH: Arch;
428    const DISPLAY_NAME: &'static str;
429
430    fn is_available() -> bool {
431        true
432    }
433
434    fn try_new() -> Result<Self, ArchNotSupported>;
435
436    fn is_match(arch: Arch) -> bool {
437        arch == Self::ARCH && Self::is_available()
438    }
439
440    fn describe(arch: Arch) -> ArchDescribe {
441        if arch != Self::ARCH {
442            ArchDescribe::Mismatch {
443                expected: Self::ARCH,
444                got: arch,
445            }
446        } else if !Self::is_available() {
447            ArchDescribe::Unsupported(Self::ARCH)
448        } else {
449            ArchDescribe::Match(arch)
450        }
451    }
452}
453
454#[derive(Debug, Clone, Copy)]
455enum ArchDescribe {
456    Match(Arch),
457    Unsupported(Arch),
458    Mismatch { expected: Arch, got: Arch },
459}
460
461impl ArchDescribe {
462    fn is_match(&self) -> bool {
463        matches!(self, ArchDescribe::Match(_))
464    }
465}
466
467impl std::fmt::Display for ArchDescribe {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        match self {
470            Self::Match(arch) => write!(f, "matched {}", arch),
471            Self::Unsupported(arch) => {
472                write!(f, "matched {} but unsupported by this CPU", arch)
473            }
474            Self::Mismatch { expected, got } => write!(f, "expected {}, got {}", expected, got),
475        }
476    }
477}
478
479impl AsArch for Reference {
480    const ARCH: Arch = Arch::Reference;
481    const DISPLAY_NAME: &'static str = "loop based";
482
483    fn try_new() -> Result<Self, ArchNotSupported> {
484        Ok(Reference)
485    }
486}
487
488impl AsArch for diskann_wide::arch::Scalar {
489    const ARCH: Arch = Arch::Scalar;
490    const DISPLAY_NAME: &'static str = "scalar (compilation target CPU)";
491
492    fn try_new() -> Result<Self, ArchNotSupported> {
493        Ok(diskann_wide::arch::Scalar)
494    }
495}
496
497macro_rules! match_arch {
498    ($target_arch:literal, $arch:path, $enum:ident) => {
499        #[cfg(target_arch = $target_arch)]
500        impl AsArch for $arch {
501            const ARCH: Arch = Arch::$enum;
502            const DISPLAY_NAME: &'static str = stringify!($enum);
503
504            fn is_available() -> bool {
505                <$arch>::new_checked().is_some()
506            }
507
508            fn try_new() -> Result<Self, ArchNotSupported> {
509                <$arch>::new_checked().ok_or(ArchNotSupported(Arch::$enum))
510            }
511        }
512    };
513}
514
515match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
516match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
517match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
518
519impl<A, Q, D> Benchmark for Kernel<A, Q, D>
520where
521    Q: AsDataType,
522    D: AsDataType,
523    A: AsArch,
524    Kernel<A, Q, D>: RunBenchmark<A>,
525{
526    type Input = SimdOp;
527    type Output = Vec<RunResult>;
528
529    // Matching simply requires that we match the inner type.
530    fn try_match(&self, from: &SimdOp) -> Result<MatchScore, FailureScore> {
531        let mut failscore: Option<u32> = None;
532        if !Q::is_match(from.query_type) {
533            *failscore.get_or_insert(0) += 10;
534        }
535        if !D::is_match(from.data_type) {
536            *failscore.get_or_insert(0) += 10;
537        }
538        if !A::is_match(from.arch) {
539            let penalty = if from.arch == A::ARCH { 2 } else { 3 };
540            *failscore.get_or_insert(0) += penalty;
541        }
542
543        match failscore {
544            None => Ok(MatchScore(0)),
545            Some(score) => Err(FailureScore(score)),
546        }
547    }
548
549    fn run(
550        &self,
551        input: &SimdOp,
552        _: diskann_benchmark_runner::Checkpoint<'_>,
553        mut output: &mut dyn diskann_benchmark_runner::Output,
554    ) -> anyhow::Result<Self::Output> {
555        if input.arch != A::ARCH {
556            anyhow::bail!(
557                "architecture mismatch: input requested {:?}, but kernel implementation requires {:?}",
558                input.arch,
559                A::ARCH
560            );
561        }
562
563        let arch = A::try_new()?;
564        writeln!(output, "{}", input)?;
565        let results = self.run_benchmark(input, arch)?;
566        writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
567        Ok(results)
568    }
569
570    fn description(
571        &self,
572        f: &mut std::fmt::Formatter<'_>,
573        input: Option<&SimdOp>,
574    ) -> std::fmt::Result {
575        match input {
576            None => {
577                writeln!(f, "- Query Type: {}", Q::DATA_TYPE)?;
578                writeln!(f, "- Data Type: {}", D::DATA_TYPE)?;
579                writeln!(f, "- Implementation: {}", A::DISPLAY_NAME)?;
580            }
581            Some(input) => {
582                let desc = Q::describe(input.query_type);
583                if !desc.is_match() {
584                    writeln!(f, "\n    - Mismatched query type: {}", desc)?;
585                }
586
587                let desc = D::describe(input.data_type);
588                if !desc.is_match() {
589                    writeln!(f, "\n    - Mismatched data type: {}", desc)?;
590                }
591
592                let desc = A::describe(input.arch);
593                if !desc.is_match() {
594                    writeln!(f, "\n    - Mismatched architecture: {}", desc)?;
595                }
596            }
597        }
598        Ok(())
599    }
600}
601
602impl<A, Q, D> Regression for Kernel<A, Q, D>
603where
604    Q: AsDataType,
605    D: AsDataType,
606    A: AsArch,
607    Kernel<A, Q, D>: RunBenchmark<A>,
608{
609    type Tolerances = SimdTolerance;
610    type Pass = CheckResult;
611    type Fail = CheckResult;
612
613    fn check(
614        &self,
615        tolerance: &SimdTolerance,
616        _input: &SimdOp,
617        before: &Vec<RunResult>,
618        after: &Vec<RunResult>,
619    ) -> anyhow::Result<PassFail<CheckResult, CheckResult>> {
620        anyhow::ensure!(
621            before.len() == after.len(),
622            "before has {} runs but after has {}",
623            before.len(),
624            after.len(),
625        );
626
627        let mut passed = true;
628        let checks: Vec<Comparison> = std::iter::zip(before.iter(), after.iter())
629            .enumerate()
630            .map(|(i, (b, a))| {
631                anyhow::ensure!(b.run == a.run, "run {i} mismatched");
632
633                let computations_per_latency = b.computations_per_latency() as f64;
634
635                let before_min = b.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
636                let after_min = a.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
637
638                let comparison = Comparison {
639                    run: b.run.clone(),
640                    tolerance: *tolerance,
641                    before_min,
642                    after_min,
643                };
644
645                // Determine whether or not we pass.
646                match relative_change(before_min, after_min) {
647                    Ok(change) => {
648                        if change > tolerance.min_time_regression.get() {
649                            passed = false;
650                        }
651                    }
652                    Err(_) => passed = false,
653                };
654
655                Ok(comparison)
656            })
657            .collect::<anyhow::Result<Vec<Comparison>>>()?;
658
659        let check = CheckResult { checks };
660
661        if passed {
662            Ok(PassFail::Pass(check))
663        } else {
664            Ok(PassFail::Fail(check))
665        }
666    }
667}
668
669///////////////
670// Benchmark //
671///////////////
672
673trait RunBenchmark<A> {
674    fn run_benchmark(&self, input: &SimdOp, arch: A) -> Result<Vec<RunResult>, anyhow::Error>;
675}
676
677#[derive(Debug, Serialize, Deserialize)]
678struct RunResult {
679    /// The configuration for this run.
680    run: Run,
681    /// The latencies of individual runs.
682    latencies: Vec<MicroSeconds>,
683    /// Latency percentiles.
684    percentiles: percentiles::Percentiles<MicroSeconds>,
685}
686
687impl RunResult {
688    fn computations_per_latency(&self) -> usize {
689        self.run.num_points.get() * self.run.loops_per_measurement.get()
690    }
691}
692
693impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
694    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
695        if self.is_empty() {
696            return Ok(());
697        }
698
699        let header = [
700            "Distance",
701            "Dim",
702            "Min Time (ns)",
703            "Mean Time (ns)",
704            "Points",
705            "Loops",
706            "Measurements",
707        ];
708
709        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
710
711        self.iter().enumerate().for_each(|(row, r)| {
712            let mut row = table.row(row);
713
714            let min_latency = r
715                .latencies
716                .iter()
717                .min()
718                .copied()
719                .unwrap_or(MicroSeconds::new(u64::MAX));
720            let mean_latency = r.percentiles.mean;
721
722            let computations_per_latency = r.computations_per_latency() as f64;
723
724            // Convert time from micro-seconds to nano-seconds.
725            let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
726            let mean_time = mean_latency / computations_per_latency * 1000.0;
727
728            row.insert(r.run.distance, 0);
729            row.insert(r.run.dim, 1);
730            row.insert(format!("{:.3}", min_time), 2);
731            row.insert(format!("{:.3}", mean_time), 3);
732            row.insert(r.run.num_points, 4);
733            row.insert(r.run.loops_per_measurement, 5);
734            row.insert(r.run.num_measurements, 6);
735        });
736
737        table.fmt(f)
738    }
739}
740
741fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
742where
743    F: Fn(&[Q], &[D]) -> f32,
744{
745    let mut latencies = Vec::with_capacity(run.num_measurements.get());
746    let mut dst = vec![0.0; data.nrows()];
747
748    for _ in 0..run.num_measurements.get() {
749        let start = std::time::Instant::now();
750        for _ in 0..run.loops_per_measurement.get() {
751            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
752                *d = f(query, r);
753            });
754            std::hint::black_box(&mut dst);
755        }
756        latencies.push(start.elapsed().into());
757    }
758
759    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
760    RunResult {
761        run: run.clone(),
762        latencies,
763        percentiles,
764    }
765}
766
767struct Data<Q, D> {
768    query: Box<[Q]>,
769    data: Matrix<D>,
770}
771
772impl<Q, D> Data<Q, D> {
773    fn new(run: &Run) -> Self
774    where
775        StandardUniform: Distribution<Q>,
776        StandardUniform: Distribution<D>,
777    {
778        let mut rng = StdRng::seed_from_u64(0x12345);
779        let query: Box<[Q]> = (0..run.dim.get())
780            .map(|_| StandardUniform.sample(&mut rng))
781            .collect();
782        let data = Matrix::<D>::new(
783            diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
784            run.num_points.get(),
785            run.dim.get(),
786        );
787
788        Self { query, data }
789    }
790
791    fn run<F>(&self, run: &Run, f: F) -> RunResult
792    where
793        F: Fn(&[Q], &[D]) -> f32,
794    {
795        run_loops(&self.query, self.data.as_view(), run, f)
796    }
797}
798
799/////////////////////
800// Implementations //
801/////////////////////
802
/// Implement [`RunBenchmark`] for a concrete `Kernel<Arch, Query, Data>`.
///
/// Three forms are accepted:
/// - `reference, Q, D, l2, ip, cosine` uses the three named plain-loop functions from the
///   `reference` module;
/// - `Arch, Q, D` uses `diskann_vector`'s SIMD kernels, dispatched through `Arch`;
/// - `"target_arch", Arch, Q, D` is the same as the second form, but compiled only for
///   that `target_arch`.
macro_rules! stamp {
    (reference, $Query:ty, $Data:ty, $l2_fn:ident, $ip_fn:ident, $cosine_fn:ident) => {
        impl RunBenchmark<Reference> for Kernel<Reference, $Query, $Data> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                _arch: Reference,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let results = input
                    .runs
                    .iter()
                    .map(|run| {
                        let data = Data::<$Query, $Data>::new(run);
                        match run.distance {
                            SimilarityMeasure::SquaredL2 => data.run(run, reference::$l2_fn),
                            SimilarityMeasure::InnerProduct => data.run(run, reference::$ip_fn),
                            SimilarityMeasure::Cosine => data.run(run, reference::$cosine_fn),
                        }
                    })
                    .collect();
                Ok(results)
            }
        }
    };
    ($arch:path, $Query:ty, $Data:ty) => {
        impl RunBenchmark<$arch> for Kernel<$arch, $Query, $Data> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                arch: $arch,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                let results = input
                    .runs
                    .iter()
                    .map(|run| {
                        let data = Data::<$Query, $Data>::new(run);
                        // Two layers of closures: the outer one is what gets timed, and
                        // the inner one is handed to `arch.run2`. This lets `simd_op` be
                        // inlined into the architecture's entry point and inherit its
                        // target features.
                        match run.distance {
                            SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                                arch.run2(|q, d| simd::simd_op(l2, arch, q, d), q, d)
                            }),
                            SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                                arch.run2(|q, d| simd::simd_op(ip, arch, q, d), q, d)
                            }),
                            SimilarityMeasure::Cosine => data.run(run, |q, d| {
                                arch.run2(|q, d| simd::simd_op(cosine, arch, q, d), q, d)
                            }),
                        }
                    })
                    .collect();
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Query:ty, $Data:ty) => {
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Query, $Data);
    };
}
868
869stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
870stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
871stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
872stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);
873
874stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
875stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
876stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
877stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);
878
879stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
880stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
881stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
882stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);
883
884stamp!(diskann_wide::arch::Scalar, f32, f32);
885stamp!(diskann_wide::arch::Scalar, f16, f16);
886stamp!(diskann_wide::arch::Scalar, u8, u8);
887stamp!(diskann_wide::arch::Scalar, i8, i8);
888
889stamp!(
890    reference,
891    f32,
892    f32,
893    squared_l2_f32,
894    inner_product_f32,
895    cosine_f32
896);
897stamp!(
898    reference,
899    f16,
900    f16,
901    squared_l2_f16,
902    inner_product_f16,
903    cosine_f16
904);
905stamp!(
906    reference,
907    u8,
908    u8,
909    squared_l2_u8,
910    inner_product_u8,
911    cosine_u8
912);
913stamp!(
914    reference,
915    i8,
916    i8,
917    squared_l2_i8,
918    inner_product_i8,
919    cosine_i8
920);
921
922///////////////
923// Reference //
924///////////////
925
// These are largely copied from the implementations in `diskann_vector`, with one tweak:
// FMA is not used when `diskann_wide::ARCH` (the compile-time architecture) is scalar.
928mod reference {
929    use diskann_wide::ARCH;
930    use half::f16;
931
932    trait MaybeFMA {
933        // Perform `a*b + c` using FMA when a hardware instruction is guaranteed to be
934        // available, otherwise decompose into a multiply and add.
935        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
936    }
937
938    impl MaybeFMA for diskann_wide::arch::Scalar {
939        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
940            a * b + c
941        }
942    }
943
944    #[cfg(target_arch = "x86_64")]
945    impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
946        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
947            a.mul_add(b, c)
948        }
949    }
950
951    #[cfg(target_arch = "x86_64")]
952    impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
953        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
954            a.mul_add(b, c)
955        }
956    }
957
958    #[cfg(target_arch = "aarch64")]
959    impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
960        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
961            a.mul_add(b, c)
962        }
963    }
964
965    //------------//
966    // Squared L2 //
967    //------------//
968
969    pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
970        assert_eq!(x.len(), y.len());
971        std::iter::zip(x.iter(), y.iter())
972            .map(|(&a, &b)| {
973                let a: i32 = a.into();
974                let b: i32 = b.into();
975                let diff = a - b;
976                diff * diff
977            })
978            .sum::<i32>() as f32
979    }
980
981    pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
982        assert_eq!(x.len(), y.len());
983        std::iter::zip(x.iter(), y.iter())
984            .map(|(&a, &b)| {
985                let a: i32 = a.into();
986                let b: i32 = b.into();
987                let diff = a - b;
988                diff * diff
989            })
990            .sum::<i32>() as f32
991    }
992
993    pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
994        assert_eq!(x.len(), y.len());
995        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
996            let a: f32 = a.into();
997            let b: f32 = b.into();
998            let diff = a - b;
999            ARCH.maybe_fma(diff, diff, acc)
1000        })
1001    }
1002
1003    pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
1004        assert_eq!(x.len(), y.len());
1005        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1006            let diff = a - b;
1007            ARCH.maybe_fma(diff, diff, acc)
1008        })
1009    }
1010
1011    //---------------//
1012    // Inner Product //
1013    //---------------//
1014
1015    pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
1016        assert_eq!(x.len(), y.len());
1017        std::iter::zip(x.iter(), y.iter())
1018            .map(|(&a, &b)| {
1019                let a: i32 = a.into();
1020                let b: i32 = b.into();
1021                a * b
1022            })
1023            .sum::<i32>() as f32
1024    }
1025
1026    pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
1027        assert_eq!(x.len(), y.len());
1028        std::iter::zip(x.iter(), y.iter())
1029            .map(|(&a, &b)| {
1030                let a: i32 = a.into();
1031                let b: i32 = b.into();
1032                a * b
1033            })
1034            .sum::<i32>() as f32
1035    }
1036
1037    pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
1038        assert_eq!(x.len(), y.len());
1039        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1040            let a: f32 = a.into();
1041            let b: f32 = b.into();
1042            ARCH.maybe_fma(a, b, acc)
1043        })
1044    }
1045
1046    pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
1047        assert_eq!(x.len(), y.len());
1048        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
1049    }
1050
1051    //--------//
1052    // Cosine //
1053    //--------//
1054
1055    #[derive(Default)]
1056    struct XY<T> {
1057        xnorm: T,
1058        ynorm: T,
1059        xy: T,
1060    }
1061
1062    pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1063        assert_eq!(x.len(), y.len());
1064        let r: XY<i32> =
1065            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1066                let vx: i32 = vx.into();
1067                let vy: i32 = vy.into();
1068                XY {
1069                    xnorm: acc.xnorm + vx * vx,
1070                    ynorm: acc.ynorm + vy * vy,
1071                    xy: acc.xy + vx * vy,
1072                }
1073            });
1074
1075        if r.xnorm == 0 || r.ynorm == 0 {
1076            return 0.0;
1077        }
1078
1079        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1080    }
1081
1082    pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1083        assert_eq!(x.len(), y.len());
1084        let r: XY<i32> =
1085            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1086                let vx: i32 = vx.into();
1087                let vy: i32 = vy.into();
1088                XY {
1089                    xnorm: acc.xnorm + vx * vx,
1090                    ynorm: acc.ynorm + vy * vy,
1091                    xy: acc.xy + vx * vy,
1092                }
1093            });
1094
1095        if r.xnorm == 0 || r.ynorm == 0 {
1096            return 0.0;
1097        }
1098
1099        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1100    }
1101
1102    pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1103        assert_eq!(x.len(), y.len());
1104        let r: XY<f32> =
1105            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1106                let vx: f32 = vx.into();
1107                let vy: f32 = vy.into();
1108                XY {
1109                    xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1110                    ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1111                    xy: ARCH.maybe_fma(vx, vy, acc.xy),
1112                }
1113            });
1114
1115        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1116            return 0.0;
1117        }
1118
1119        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1120    }
1121
1122    pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1123        assert_eq!(x.len(), y.len());
1124        let r: XY<f32> =
1125            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1126                xnorm: vx.mul_add(vx, acc.xnorm),
1127                ynorm: vy.mul_add(vy, acc.ynorm),
1128                xy: vx.mul_add(vy, acc.xy),
1129            });
1130
1131        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1132            return 0.0;
1133        }
1134
1135        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1136    }
1137}
1138
1139///////////
1140// Tests //
1141///////////
1142
#[cfg(test)]
mod tests {
    use super::*;

    use diskann_benchmark_runner::{
        benchmark::{PassFail, Regression},
        utils::percentiles::compute_percentiles,
    };

    /// The kernel under test in every regression-check case below.
    type ScalarKernel = Kernel<diskann_wide::arch::Scalar, f32, f32>;

    /// Minimal run configuration for `distance`: 8 dimensions, one point, and a single
    /// measurement of a single loop.
    fn tiny_run(distance: SimilarityMeasure) -> Run {
        let one = NonZeroUsize::new(1).unwrap();
        Run {
            distance,
            dim: NonZeroUsize::new(8).unwrap(),
            num_points: one,
            loops_per_measurement: one,
            num_measurements: one,
        }
    }

    /// A scalar `f32`/`f32` operation containing a single squared-L2 run.
    fn tiny_op() -> SimdOp {
        SimdOp {
            runs: vec![tiny_run(SimilarityMeasure::SquaredL2)],
            arch: Arch::Scalar,
            query_type: DataType::Float32,
            data_type: DataType::Float32,
        }
    }

    /// A result for `tiny_run(distance)` whose only latency sample is `minimum`
    /// microseconds.
    fn tiny_result(distance: SimilarityMeasure, minimum: u64) -> RunResult {
        let mut latencies = vec![MicroSeconds::new(minimum)];
        let percentiles = compute_percentiles(&mut latencies).unwrap();
        RunResult {
            run: tiny_run(distance),
            latencies,
            percentiles,
        }
    }

    /// Tolerance allowing a relative increase of at most `limit` in the minimum time.
    fn tolerance(limit: f64) -> SimdTolerance {
        let min_time_regression = NonNegativeFinite::new(limit).unwrap();
        SimdTolerance {
            min_time_regression,
        }
    }

    #[test]
    fn check_rejects_mismatched_runs() {
        let kernel = ScalarKernel::new();
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::Cosine, 100)];

        let err = kernel
            .check(&tolerance(0.0), &tiny_op(), &before, &after)
            .unwrap_err();

        assert_eq!(err.to_string(), "run 0 mismatched");
    }

    #[test]
    fn check_allows_negative_relative_change() {
        let kernel = ScalarKernel::new();
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 95)];

        let outcome = kernel
            .check(&tolerance(0.0), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Pass(_)));
    }

    #[test]
    fn check_passes_on_tolerance_boundary() {
        let kernel = ScalarKernel::new();
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 105)];

        let outcome = kernel
            .check(&tolerance(0.05), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Pass(_)));
    }

    #[test]
    fn check_fails_above_tolerance_boundary() {
        let kernel = ScalarKernel::new();
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 100)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 106)];

        let outcome = kernel
            .check(&tolerance(0.05), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Fail(_)));
    }

    #[test]
    fn check_result_display_includes_failure_details() {
        let comparison = Comparison {
            run: tiny_run(SimilarityMeasure::SquaredL2),
            tolerance: tolerance(0.05),
            before_min: 100.0,
            after_min: 106.0,
        };
        let check = CheckResult {
            checks: vec![comparison],
        };

        let rendered = check.to_string();
        for needle in ["Distance", "squared_l2", "100.000", "106.000", "6.000 %", "FAIL"] {
            assert!(rendered.contains(needle), "rendered = {rendered}");
        }
    }

    // A "before" minimum of 0 means the measurement was too fast to give a reliable
    // signal, so passing it could let a real regression slip through. The check must
    // fail in that case; a non-zero baseline is required.
    #[test]
    fn zero_values_rejected() {
        let kernel = ScalarKernel::new();
        let before = vec![tiny_result(SimilarityMeasure::SquaredL2, 0)];
        let after = vec![tiny_result(SimilarityMeasure::SquaredL2, 0)];

        let outcome = kernel
            .check(&tolerance(0.05), &tiny_op(), &before, &after)
            .unwrap();

        assert!(matches!(outcome, PassFail::Fail(_)));
    }
}