Skip to main content

diskann_benchmark_simd/
lib.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{io::Write, num::NonZeroUsize};
7
8use diskann_utils::views::{Matrix, MatrixView};
9use diskann_vector::distance::simd;
10use diskann_wide::Architecture;
11use half::f16;
12use rand::{
13    distr::{Distribution, StandardUniform},
14    rngs::StdRng,
15    SeedableRng,
16};
17use serde::{Deserialize, Serialize};
18use thiserror::Error;
19
20use diskann_benchmark_runner::{
21    describeln,
22    dispatcher::{self, DispatchRule, FailureScore, MatchScore},
23    utils::{
24        datatype::{self, DataType},
25        percentiles, MicroSeconds,
26    },
27    Any, CheckDeserialization, Checker,
28};
29
30////////////////
31// Public API //
32////////////////
33
34#[derive(Debug)]
35pub struct SimdInput;
36
37pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
38    register_benchmarks_impl(dispatcher)
39}
40
41///////////
42// Utils //
43///////////
44
45#[derive(Debug, Clone, Copy)]
46struct DisplayWrapper<'a, T: ?Sized>(&'a T);
47
48impl<T: ?Sized> std::ops::Deref for DisplayWrapper<'_, T> {
49    type Target = T;
50    fn deref(&self) -> &T {
51        self.0
52    }
53}
54
55////////////
56// Inputs //
57////////////
58
59#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
60#[serde(rename_all = "snake_case")]
61pub enum SimilarityMeasure {
62    SquaredL2,
63    InnerProduct,
64    Cosine,
65}
66
67impl std::fmt::Display for SimilarityMeasure {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        let st = match self {
70            Self::SquaredL2 => "squared_l2",
71            Self::InnerProduct => "inner_product",
72            Self::Cosine => "cosine",
73        };
74        write!(f, "{}", st)
75    }
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
79#[serde(rename_all = "kebab-case")]
80pub(crate) enum Arch {
81    #[serde(rename = "x86-64-v4")]
82    #[allow(non_camel_case_types)]
83    X86_64_V4,
84    #[serde(rename = "x86-64-v3")]
85    #[allow(non_camel_case_types)]
86    X86_64_V3,
87    Scalar,
88    Reference,
89}
90
91impl std::fmt::Display for Arch {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        let st = match self {
94            Self::X86_64_V4 => "x86-64-v4",
95            Self::X86_64_V3 => "x86-64-v3",
96            Self::Scalar => "scalar",
97            Self::Reference => "reference",
98        };
99        write!(f, "{}", st)
100    }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub(crate) struct Run {
105    pub(crate) distance: SimilarityMeasure,
106    pub(crate) dim: NonZeroUsize,
107    pub(crate) num_points: NonZeroUsize,
108    pub(crate) loops_per_measurement: NonZeroUsize,
109    pub(crate) num_measurements: NonZeroUsize,
110}
111
112#[derive(Debug, Serialize, Deserialize)]
113pub(crate) struct SimdOp {
114    pub(crate) query_type: DataType,
115    pub(crate) data_type: DataType,
116    pub(crate) arch: Arch,
117    pub(crate) runs: Vec<Run>,
118}
119
120impl CheckDeserialization for SimdOp {
121    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
122        Ok(())
123    }
124}
125
126macro_rules! write_field {
127    ($f:ident, $field:tt, $($expr:tt)*) => {
128        writeln!($f, "{:>18}: {}", $field, $($expr)*)
129    }
130}
131
132impl SimdOp {
133    pub(crate) const fn tag() -> &'static str {
134        "simd-op"
135    }
136
137    fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write_field!(f, "query type", self.query_type)?;
139        write_field!(f, "data type", self.data_type)?;
140        write_field!(f, "arch", self.arch)?;
141        write_field!(f, "number of runs", self.runs.len())?;
142        Ok(())
143    }
144}
145
146impl std::fmt::Display for SimdOp {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        writeln!(f, "SIMD Operation\n")?;
149        write_field!(f, "tag", Self::tag())?;
150        self.summarize_fields(f)
151    }
152}
153
154impl diskann_benchmark_runner::Input for SimdInput {
155    fn tag(&self) -> &'static str {
156        "simd-op"
157    }
158
159    fn try_deserialize(
160        &self,
161        serialized: &serde_json::Value,
162        checker: &mut Checker,
163    ) -> anyhow::Result<Any> {
164        checker.any(SimdOp::deserialize(serialized)?)
165    }
166
167    fn example(&self) -> anyhow::Result<serde_json::Value> {
168        const DIM: [NonZeroUsize; 2] = [
169            NonZeroUsize::new(128).unwrap(),
170            NonZeroUsize::new(150).unwrap(),
171        ];
172
173        const NUM_POINTS: [NonZeroUsize; 2] = [
174            NonZeroUsize::new(1000).unwrap(),
175            NonZeroUsize::new(800).unwrap(),
176        ];
177
178        const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
179        const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
180
181        let runs = vec![
182            Run {
183                distance: SimilarityMeasure::SquaredL2,
184                dim: DIM[0],
185                num_points: NUM_POINTS[0],
186                loops_per_measurement: LOOPS_PER_MEASUREMENT,
187                num_measurements: NUM_MEASUREMENTS,
188            },
189            Run {
190                distance: SimilarityMeasure::InnerProduct,
191                dim: DIM[1],
192                num_points: NUM_POINTS[1],
193                loops_per_measurement: LOOPS_PER_MEASUREMENT,
194                num_measurements: NUM_MEASUREMENTS,
195            },
196        ];
197
198        Ok(serde_json::to_value(&SimdOp {
199            query_type: DataType::Float32,
200            data_type: DataType::Float32,
201            arch: Arch::X86_64_V3,
202            runs,
203        })?)
204    }
205}
206
207////////////////////////////
208// Benchmark Registration //
209////////////////////////////
210
211macro_rules! register {
212    ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => {
213        #[cfg(target_arch = $arch)]
214        $dispatcher.register::<$($kernel)*>(
215            $name,
216            run_benchmark,
217        )
218    };
219    ($dispatcher:ident, $name:literal, $($kernel:tt)*) => {
220        $dispatcher.register::<$($kernel)*>(
221            $name,
222            run_benchmark,
223        )
224    };
225}
226
227fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
228    // x86-64-v4
229    register!(
230        "x86_64",
231        dispatcher,
232        "simd-op-f32xf32-x86_64_V4",
233        Kernel<'static, diskann_wide::arch::x86_64::V4, f32, f32>
234    );
235    register!(
236        "x86_64",
237        dispatcher,
238        "simd-op-f16xf16-x86_64_V4",
239        Kernel<'static, diskann_wide::arch::x86_64::V4, f16, f16>
240    );
241    register!(
242        "x86_64",
243        dispatcher,
244        "simd-op-u8xu8-x86_64_V4",
245        Kernel<'static, diskann_wide::arch::x86_64::V4, u8, u8>
246    );
247    register!(
248        "x86_64",
249        dispatcher,
250        "simd-op-i8xi8-x86_64_V4",
251        Kernel<'static, diskann_wide::arch::x86_64::V4, i8, i8>
252    );
253
254    // x86-64-v3
255    register!(
256        "x86_64",
257        dispatcher,
258        "simd-op-f32xf32-x86_64_V3",
259        Kernel<'static, diskann_wide::arch::x86_64::V3, f32, f32>
260    );
261    register!(
262        "x86_64",
263        dispatcher,
264        "simd-op-f16xf16-x86_64_V3",
265        Kernel<'static, diskann_wide::arch::x86_64::V3, f16, f16>
266    );
267    register!(
268        "x86_64",
269        dispatcher,
270        "simd-op-u8xu8-x86_64_V3",
271        Kernel<'static, diskann_wide::arch::x86_64::V3, u8, u8>
272    );
273    register!(
274        "x86_64",
275        dispatcher,
276        "simd-op-i8xi8-x86_64_V3",
277        Kernel<'static, diskann_wide::arch::x86_64::V3, i8, i8>
278    );
279
280    // scalar
281    register!(
282        dispatcher,
283        "simd-op-f32xf32-scalar",
284        Kernel<'static, diskann_wide::arch::Scalar, f32, f32>
285    );
286    register!(
287        dispatcher,
288        "simd-op-f16xf16-scalar",
289        Kernel<'static, diskann_wide::arch::Scalar, f16, f16>
290    );
291    register!(
292        dispatcher,
293        "simd-op-u8xu8-scalar",
294        Kernel<'static, diskann_wide::arch::Scalar, u8, u8>
295    );
296    register!(
297        dispatcher,
298        "simd-op-i8xi8-scalar",
299        Kernel<'static, diskann_wide::arch::Scalar, i8, i8>
300    );
301
302    // reference
303    register!(
304        dispatcher,
305        "simd-op-f32xf32-reference",
306        Kernel<'static, Reference, f32, f32>
307    );
308    register!(
309        dispatcher,
310        "simd-op-f16xf16-reference",
311        Kernel<'static, Reference, f16, f16>
312    );
313    register!(
314        dispatcher,
315        "simd-op-u8xu8-reference",
316        Kernel<'static, Reference, u8, u8>
317    );
318    register!(
319        dispatcher,
320        "simd-op-i8xi8-reference",
321        Kernel<'static, Reference, i8, i8>
322    );
323}
324
325//////////////
326// Dispatch //
327//////////////
328
329/// Dispatch receiver for the reference implementations.
330struct Reference;
331
332/// A dispatch mapper for `wide` types.
333#[derive(Debug)]
334struct Identity<T>(T);
335
336impl<T> dispatcher::Map for Identity<T>
337where
338    T: 'static,
339{
340    type Type<'a> = T;
341}
342
343struct Kernel<'a, A, Q, D> {
344    input: &'a SimdOp,
345    arch: A,
346    _type: std::marker::PhantomData<(A, Q, D)>,
347}
348
349impl<'a, A, Q, D> Kernel<'a, A, Q, D> {
350    fn new(input: &'a SimdOp, arch: A) -> Self {
351        Self {
352            input,
353            arch,
354            _type: std::marker::PhantomData,
355        }
356    }
357}
358
359impl<A, Q, D> dispatcher::Map for Kernel<'static, A, Q, D>
360where
361    A: 'static,
362    Q: 'static,
363    D: 'static,
364{
365    type Type<'a> = Kernel<'a, A, Q, D>;
366}
367
368// Map Architectures to the enum.
369#[derive(Debug, Error)]
370#[error("architecture {0} is not supported by this CPU")]
371pub(crate) struct ArchNotSupported(Arch);
372
373impl DispatchRule<Arch> for Identity<Reference> {
374    type Error = ArchNotSupported;
375
376    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
377        if *from == Arch::Reference {
378            Ok(MatchScore(0))
379        } else {
380            Err(FailureScore(0))
381        }
382    }
383
384    fn convert(from: Arch) -> Result<Self, Self::Error> {
385        assert_eq!(from, Arch::Reference);
386        Ok(Identity(Reference))
387    }
388
389    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
390        match from {
391            None => write!(f, "loop based"),
392            Some(arch) => {
393                if Self::try_match(arch).is_ok() {
394                    write!(f, "matched {}", arch)
395                } else {
396                    write!(f, "expected {}, got {}", Arch::Reference, arch)
397                }
398            }
399        }
400    }
401}
402
403impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
404    type Error = ArchNotSupported;
405
406    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
407        if *from == Arch::Scalar {
408            Ok(MatchScore(0))
409        } else {
410            Err(FailureScore(0))
411        }
412    }
413
414    fn convert(from: Arch) -> Result<Self, Self::Error> {
415        assert_eq!(from, Arch::Scalar);
416        Ok(Identity(diskann_wide::arch::Scalar))
417    }
418
419    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
420        match from {
421            None => write!(f, "scalar (compilation target CPU)"),
422            Some(arch) => {
423                if Self::try_match(arch).is_ok() {
424                    write!(f, "matched {}", arch)
425                } else {
426                    write!(f, "expected {}, got {}", Arch::Scalar, arch)
427                }
428            }
429        }
430    }
431}
432
433#[cfg(target_arch = "x86_64")]
434impl DispatchRule<Arch> for Identity<diskann_wide::arch::x86_64::V4> {
435    type Error = ArchNotSupported;
436
437    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
438        if *from == Arch::X86_64_V4 {
439            Ok(MatchScore(0))
440        } else {
441            Err(FailureScore(0))
442        }
443    }
444
445    fn convert(from: Arch) -> Result<Self, Self::Error> {
446        assert_eq!(from, Arch::X86_64_V4);
447        diskann_wide::arch::x86_64::V4::new_checked()
448            .ok_or(ArchNotSupported(from))
449            .map(Identity)
450    }
451
452    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
453        match from {
454            None => write!(f, "x86-64-v4"),
455            Some(arch) => {
456                if Self::try_match(arch).is_ok() {
457                    write!(f, "matched {}", arch)
458                } else {
459                    write!(f, "expected {}, got {}", Arch::X86_64_V4, arch)
460                }
461            }
462        }
463    }
464}
465
466#[cfg(target_arch = "x86_64")]
467impl DispatchRule<Arch> for Identity<diskann_wide::arch::x86_64::V3> {
468    type Error = ArchNotSupported;
469
470    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
471        if *from == Arch::X86_64_V3 {
472            Ok(MatchScore(0))
473        } else {
474            Err(FailureScore(0))
475        }
476    }
477
478    fn convert(from: Arch) -> Result<Self, Self::Error> {
479        assert_eq!(from, Arch::X86_64_V3);
480        diskann_wide::arch::x86_64::V3::new_checked()
481            .ok_or(ArchNotSupported(from))
482            .map(Identity)
483    }
484
485    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
486        match from {
487            None => write!(f, "x86-64-v3"),
488            Some(arch) => {
489                if Self::try_match(arch).is_ok() {
490                    write!(f, "matched {}", arch)
491                } else {
492                    write!(f, "expected {}, got {}", Arch::X86_64_V3, arch)
493                }
494            }
495        }
496    }
497}
498
499impl<'a, A, Q, D> DispatchRule<&'a SimdOp> for Kernel<'a, A, Q, D>
500where
501    datatype::Type<Q>: DispatchRule<datatype::DataType>,
502    datatype::Type<D>: DispatchRule<datatype::DataType>,
503    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
504{
505    type Error = ArchNotSupported;
506
507    // Matching simply requires that we match the inner type.
508    fn try_match(from: &&'a SimdOp) -> Result<MatchScore, FailureScore> {
509        let mut failscore: Option<u32> = None;
510        if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
511            *failscore.get_or_insert(0) += 10;
512        }
513        if datatype::Type::<D>::try_match(&from.data_type).is_err() {
514            *failscore.get_or_insert(0) += 10;
515        }
516        if Identity::<A>::try_match(&from.arch).is_err() {
517            *failscore.get_or_insert(0) += 2;
518        }
519        match failscore {
520            None => Ok(MatchScore(0)),
521            Some(score) => Err(FailureScore(score)),
522        }
523    }
524
525    fn convert(from: &'a SimdOp) -> Result<Self, Self::Error> {
526        assert!(Self::try_match(&from).is_ok());
527        let arch = Identity::<A>::convert(from.arch)?.0;
528        Ok(Self::new(from, arch))
529    }
530
531    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a SimdOp>) -> std::fmt::Result {
532        match from {
533            None => {
534                describeln!(
535                    f,
536                    "- Query Type: {}",
537                    dispatcher::Description::<datatype::DataType, datatype::Type<Q>>::new()
538                )?;
539                describeln!(
540                    f,
541                    "- Data Type: {}",
542                    dispatcher::Description::<datatype::DataType, datatype::Type<D>>::new()
543                )?;
544                describeln!(
545                    f,
546                    "- Implementation: {}",
547                    dispatcher::Description::<Arch, Identity<A>>::new()
548                )?;
549            }
550            Some(input) => {
551                if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
552                    describeln!(f, "- Mismatched query type: {}", err)?;
553                }
554                if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
555                    describeln!(f, "- Mismatched data type: {}", err)?;
556                }
557                if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
558                    describeln!(f, "- Mismatched architecture: {}", err)?;
559                }
560            }
561        }
562        Ok(())
563    }
564}
565
566impl<'a, A, Q, D> DispatchRule<&'a diskann_benchmark_runner::Any> for Kernel<'a, A, Q, D>
567where
568    Kernel<'a, A, Q, D>: DispatchRule<&'a SimdOp>,
569    <Kernel<'a, A, Q, D> as DispatchRule<&'a SimdOp>>::Error:
570        std::error::Error + Send + Sync + 'static,
571{
572    type Error = anyhow::Error;
573
574    fn try_match(from: &&'a diskann_benchmark_runner::Any) -> Result<MatchScore, FailureScore> {
575        from.try_match::<SimdOp, Self>()
576    }
577
578    fn convert(from: &'a diskann_benchmark_runner::Any) -> Result<Self, Self::Error> {
579        from.convert::<SimdOp, Self>()
580    }
581
582    fn description(
583        f: &mut std::fmt::Formatter<'_>,
584        from: Option<&&'a diskann_benchmark_runner::Any>,
585    ) -> std::fmt::Result {
586        Any::description::<SimdOp, Self>(f, from, SimdOp::tag())
587    }
588}
589
590///////////////
591// Benchmark //
592///////////////
593
594fn run_benchmark<A, Q, D>(
595    kernel: Kernel<'_, A, Q, D>,
596    _: diskann_benchmark_runner::Checkpoint<'_>,
597    mut output: &mut dyn diskann_benchmark_runner::Output,
598) -> Result<serde_json::Value, anyhow::Error>
599where
600    for<'a> Kernel<'a, A, Q, D>: RunBenchmark,
601{
602    writeln!(output, "{}", kernel.input)?;
603    let results = kernel.run()?;
604    writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
605    Ok(serde_json::to_value(results)?)
606}
607
608trait RunBenchmark {
609    fn run(self) -> Result<Vec<RunResult>, anyhow::Error>;
610}
611
612#[derive(Debug, Serialize)]
613struct RunResult {
614    /// The setup
615    run: Run,
616    /// The latencies of individual runs.
617    latencies: Vec<MicroSeconds>,
618    /// Latency percentiles.
619    percentiles: percentiles::Percentiles<MicroSeconds>,
620}
621
622impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
623    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624        if self.is_empty() {
625            return Ok(());
626        }
627
628        let header = [
629            "Distance",
630            "Dim",
631            "Min Time (ns)",
632            "Mean Time (ns)",
633            "Points",
634            "Loops",
635            "Measurements",
636        ];
637
638        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
639
640        self.iter().enumerate().for_each(|(row, r)| {
641            let mut row = table.row(row);
642
643            let min_latency = r
644                .latencies
645                .iter()
646                .min()
647                .copied()
648                .unwrap_or(MicroSeconds::new(u64::MAX));
649            let mean_latency = r.percentiles.mean;
650
651            let computations_per_latency: f64 =
652                (r.run.num_points.get() * r.run.loops_per_measurement.get()) as f64;
653
654            // Convert time from micro-seconds to nano-seconds.
655            let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
656            let mean_time = mean_latency / computations_per_latency * 1000.0;
657
658            row.insert(r.run.distance, 0);
659            row.insert(r.run.dim, 1);
660            row.insert(format!("{:.3}", min_time), 2);
661            row.insert(format!("{:.3}", mean_time), 3);
662            row.insert(r.run.num_points, 4);
663            row.insert(r.run.loops_per_measurement, 5);
664            row.insert(r.run.num_measurements, 6);
665        });
666
667        table.fmt(f)
668    }
669}
670
671fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
672where
673    F: Fn(&[Q], &[D]) -> f32,
674{
675    let mut latencies = Vec::with_capacity(run.num_measurements.get());
676    let mut dst = vec![0.0; data.nrows()];
677
678    for _ in 0..run.num_measurements.get() {
679        let start = std::time::Instant::now();
680        for _ in 0..run.loops_per_measurement.get() {
681            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
682                *d = f(query, r);
683            });
684            std::hint::black_box(&mut dst);
685        }
686        latencies.push(start.elapsed().into());
687    }
688
689    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
690    RunResult {
691        run: run.clone(),
692        latencies,
693        percentiles,
694    }
695}
696
697struct Data<Q, D> {
698    query: Box<[Q]>,
699    data: Matrix<D>,
700}
701
702impl<Q, D> Data<Q, D> {
703    fn new(run: &Run) -> Self
704    where
705        StandardUniform: Distribution<Q>,
706        StandardUniform: Distribution<D>,
707    {
708        let mut rng = StdRng::seed_from_u64(0x12345);
709        let query: Box<[Q]> = (0..run.dim.get())
710            .map(|_| StandardUniform.sample(&mut rng))
711            .collect();
712        let data = Matrix::<D>::new(
713            diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
714            run.num_points.get(),
715            run.dim.get(),
716        );
717
718        Self { query, data }
719    }
720
721    fn run<F>(&self, run: &Run, f: F) -> RunResult
722    where
723        F: Fn(&[Q], &[D]) -> f32,
724    {
725        run_loops(&self.query, self.data.as_view(), run, f)
726    }
727}
728
729/////////////////////
730// Implementations //
731/////////////////////
732
733macro_rules! stamp {
734    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
735        impl RunBenchmark for Kernel<'_, Reference, $Q, $D> {
736            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
737                let mut results = Vec::new();
738                for run in self.input.runs.iter() {
739                    let data = Data::<$Q, $D>::new(run);
740                    let result = match run.distance {
741                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
742                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
743                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
744                    };
745                    results.push(result);
746                }
747                Ok(results)
748            }
749        }
750    };
751    ($arch:path, $Q:ty, $D:ty) => {
752        impl RunBenchmark for Kernel<'_, $arch, $Q, $D> {
753            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
754                let mut results = Vec::new();
755
756                let l2 = &simd::L2 {};
757                let ip = &simd::IP {};
758                let cosine = &simd::CosineStateless {};
759
760                for run in self.input.runs.iter() {
761                    let data = Data::<$Q, $D>::new(run);
762                    // For each kernel, we need to do a two-step wrapping of closures so
763                    // the inner-most closure is executed by the architecture.
764                    //
765                    // This is required for the implementation of `simd_op` to be inlined
766                    // into the architecture run function so it can properly inherit
767                    // target features.
768                    let result = match run.distance {
769                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
770                            self.arch
771                                .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d)
772                        }),
773                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
774                            self.arch
775                                .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d)
776                        }),
777                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
778                            self.arch
779                                .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d)
780                        }),
781                    };
782                    results.push(result)
783                }
784                Ok(results)
785            }
786        }
787    };
788    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
789        #[cfg(target_arch = $target_arch)]
790        stamp!($arch, $Q, $D);
791    };
792}
793
794stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
795stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
796stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
797stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);
798
799stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
800stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
801stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
802stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);
803
804stamp!(diskann_wide::arch::Scalar, f32, f32);
805stamp!(diskann_wide::arch::Scalar, f16, f16);
806stamp!(diskann_wide::arch::Scalar, u8, u8);
807stamp!(diskann_wide::arch::Scalar, i8, i8);
808
809stamp!(
810    reference,
811    f32,
812    f32,
813    squared_l2_f32,
814    inner_product_f32,
815    cosine_f32
816);
817stamp!(
818    reference,
819    f16,
820    f16,
821    squared_l2_f16,
822    inner_product_f16,
823    cosine_f16
824);
825stamp!(
826    reference,
827    u8,
828    u8,
829    squared_l2_u8,
830    inner_product_u8,
831    cosine_u8
832);
833stamp!(
834    reference,
835    i8,
836    i8,
837    squared_l2_i8,
838    inner_product_i8,
839    cosine_i8
840);
841
842///////////////
843// Reference //
844///////////////
845
846// These are largely copied from the implementations in vector, with a tweak that we don't
847// use FMA when the current architecture is scalar.
848mod reference {
849    use diskann_wide::ARCH;
850    use half::f16;
851
852    trait MaybeFMA {
853        // Perform `a*b + c` using FMA when a hardware instruction is guaranteed to be
854        // available, otherwise decompose into a multiply and add.
855        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
856    }
857
858    impl MaybeFMA for diskann_wide::arch::Scalar {
859        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
860            a * b + c
861        }
862    }
863
864    #[cfg(target_arch = "x86_64")]
865    impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
866        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
867            a.mul_add(b, c)
868        }
869    }
870
871    #[cfg(target_arch = "x86_64")]
872    impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
873        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
874            a.mul_add(b, c)
875        }
876    }
877
878    //------------//
879    // Squared L2 //
880    //------------//
881
882    pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
883        assert_eq!(x.len(), y.len());
884        std::iter::zip(x.iter(), y.iter())
885            .map(|(&a, &b)| {
886                let a: i32 = a.into();
887                let b: i32 = b.into();
888                let diff = a - b;
889                diff * diff
890            })
891            .sum::<i32>() as f32
892    }
893
894    pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
895        assert_eq!(x.len(), y.len());
896        std::iter::zip(x.iter(), y.iter())
897            .map(|(&a, &b)| {
898                let a: i32 = a.into();
899                let b: i32 = b.into();
900                let diff = a - b;
901                diff * diff
902            })
903            .sum::<i32>() as f32
904    }
905
906    pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
907        assert_eq!(x.len(), y.len());
908        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
909            let a: f32 = a.into();
910            let b: f32 = b.into();
911            let diff = a - b;
912            ARCH.maybe_fma(diff, diff, acc)
913        })
914    }
915
916    pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
917        assert_eq!(x.len(), y.len());
918        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
919            let diff = a - b;
920            ARCH.maybe_fma(diff, diff, acc)
921        })
922    }
923
924    //---------------//
925    // Inner Product //
926    //---------------//
927
928    pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
929        assert_eq!(x.len(), y.len());
930        std::iter::zip(x.iter(), y.iter())
931            .map(|(&a, &b)| {
932                let a: i32 = a.into();
933                let b: i32 = b.into();
934                a * b
935            })
936            .sum::<i32>() as f32
937    }
938
939    pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
940        assert_eq!(x.len(), y.len());
941        std::iter::zip(x.iter(), y.iter())
942            .map(|(&a, &b)| {
943                let a: i32 = a.into();
944                let b: i32 = b.into();
945                a * b
946            })
947            .sum::<i32>() as f32
948    }
949
950    pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
951        assert_eq!(x.len(), y.len());
952        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
953            let a: f32 = a.into();
954            let b: f32 = b.into();
955            ARCH.maybe_fma(a, b, acc)
956        })
957    }
958
959    pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
960        assert_eq!(x.len(), y.len());
961        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
962    }
963
964    //--------//
965    // Cosine //
966    //--------//
967
968    #[derive(Default)]
969    struct XY<T> {
970        xnorm: T,
971        ynorm: T,
972        xy: T,
973    }
974
975    pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
976        assert_eq!(x.len(), y.len());
977        let r: XY<i32> =
978            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
979                let vx: i32 = vx.into();
980                let vy: i32 = vy.into();
981                XY {
982                    xnorm: acc.xnorm + vx * vx,
983                    ynorm: acc.ynorm + vy * vy,
984                    xy: acc.xy + vx * vy,
985                }
986            });
987
988        if r.xnorm == 0 || r.ynorm == 0 {
989            return 0.0;
990        }
991
992        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
993    }
994
995    pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
996        assert_eq!(x.len(), y.len());
997        let r: XY<i32> =
998            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
999                let vx: i32 = vx.into();
1000                let vy: i32 = vy.into();
1001                XY {
1002                    xnorm: acc.xnorm + vx * vx,
1003                    ynorm: acc.ynorm + vy * vy,
1004                    xy: acc.xy + vx * vy,
1005                }
1006            });
1007
1008        if r.xnorm == 0 || r.ynorm == 0 {
1009            return 0.0;
1010        }
1011
1012        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1013    }
1014
1015    pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1016        assert_eq!(x.len(), y.len());
1017        let r: XY<f32> =
1018            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1019                let vx: f32 = vx.into();
1020                let vy: f32 = vy.into();
1021                XY {
1022                    xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1023                    ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1024                    xy: ARCH.maybe_fma(vx, vy, acc.xy),
1025                }
1026            });
1027
1028        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1029            return 0.0;
1030        }
1031
1032        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1033    }
1034
1035    pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1036        assert_eq!(x.len(), y.len());
1037        let r: XY<f32> =
1038            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1039                xnorm: vx.mul_add(vx, acc.xnorm),
1040                ynorm: vy.mul_add(vy, acc.ynorm),
1041                xy: vx.mul_add(vy, acc.xy),
1042            });
1043
1044        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1045            return 0.0;
1046        }
1047
1048        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1049    }
1050}