Skip to main content

diskann_benchmark_simd/
lib.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{io::Write, num::NonZeroUsize};
7
8use diskann_utils::views::{Matrix, MatrixView};
9use diskann_vector::distance::simd;
10use diskann_wide::Architecture;
11use half::f16;
12use rand::{
13    distr::{Distribution, StandardUniform},
14    rngs::StdRng,
15    SeedableRng,
16};
17use serde::{Deserialize, Serialize};
18use thiserror::Error;
19
20use diskann_benchmark_runner::{
21    describeln,
22    dispatcher::{self, DispatchRule, FailureScore, MatchScore},
23    utils::{
24        datatype::{self, DataType},
25        percentiles, MicroSeconds,
26    },
27    Any, CheckDeserialization, Checker,
28};
29
30////////////////
31// Public API //
32////////////////
33
/// Marker type through which the `simd-op` benchmark input is exposed to the
/// benchmark runner (see the `Input` implementation for this type).
#[derive(Debug)]
pub struct SimdInput;
36
/// Register every SIMD benchmark defined in this file with `dispatcher`.
///
/// This is the public entry point; it simply forwards to
/// `register_benchmarks_impl`.
pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    register_benchmarks_impl(dispatcher)
}
40
41///////////
42// Utils //
43///////////
44
/// Borrowing newtype that lets this file attach its own `Display` implementations
/// to values it only holds by reference (for example a slice of results).
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

impl<'a, T: ?Sized> std::ops::Deref for DisplayWrapper<'a, T> {
    type Target = T;

    /// Hand back the wrapped reference so the wrapper can be used like a `&T`.
    fn deref(&self) -> &Self::Target {
        let DisplayWrapper(inner) = *self;
        inner
    }
}
54
55////////////
56// Inputs //
57////////////
58
/// The distance functions that can be benchmarked.
///
/// Variants are (de)serialized using `snake_case` names.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMeasure {
    /// Squared L2 distance (displayed as `squared_l2`).
    SquaredL2,
    /// Inner product (displayed as `inner_product`).
    InnerProduct,
    /// Cosine similarity (displayed as `cosine`).
    Cosine,
}
66
67impl std::fmt::Display for SimilarityMeasure {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        let st = match self {
70            Self::SquaredL2 => "squared_l2",
71            Self::InnerProduct => "inner_product",
72            Self::Cosine => "cosine",
73        };
74        write!(f, "{}", st)
75    }
76}
77
/// The implementation back end a benchmark input asks for.
///
/// Variants are (de)serialized in `kebab-case`; the two x86 variants carry explicit
/// renames so they appear as `x86-64-v4` and `x86-64-v3`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) enum Arch {
    /// Serialized and displayed as `x86-64-v4`.
    #[serde(rename = "x86-64-v4")]
    // The underscored variant name is intentional; silence the naming lint.
    #[allow(non_camel_case_types)]
    X86_64_V4,
    /// Serialized and displayed as `x86-64-v3`.
    #[serde(rename = "x86-64-v3")]
    // The underscored variant name is intentional; silence the naming lint.
    #[allow(non_camel_case_types)]
    X86_64_V3,
    /// Displayed as `neon`.
    Neon,
    /// Displayed as `scalar`.
    Scalar,
    /// Displayed as `reference`; selects the loop-based implementations in the
    /// `reference` module.
    Reference,
}
91
92impl std::fmt::Display for Arch {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        let st = match self {
95            Self::X86_64_V4 => "x86-64-v4",
96            Self::X86_64_V3 => "x86-64-v3",
97            Self::Neon => "neon",
98            Self::Scalar => "scalar",
99            Self::Reference => "reference",
100        };
101        write!(f, "{}", st)
102    }
103}
104
/// The configuration of a single benchmark run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Run {
    /// The distance function to benchmark.
    pub(crate) distance: SimilarityMeasure,
    /// Length of the query vector and of every data vector.
    pub(crate) dim: NonZeroUsize,
    /// Number of data vectors the query is compared against.
    pub(crate) num_points: NonZeroUsize,
    /// Number of full passes over the data folded into one timed measurement.
    pub(crate) loops_per_measurement: NonZeroUsize,
    /// Number of timed measurements to record.
    pub(crate) num_measurements: NonZeroUsize,
}
113
/// The deserialized benchmark input: element types, requested architecture and the
/// list of runs to execute.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct SimdOp {
    /// Element type of the query vector.
    pub(crate) query_type: DataType,
    /// Element type of the data vectors.
    pub(crate) data_type: DataType,
    /// The implementation back end to benchmark.
    pub(crate) arch: Arch,
    /// The runs to execute, in order.
    pub(crate) runs: Vec<Run>,
}
121
impl CheckDeserialization for SimdOp {
    /// No post-deserialization validation is performed; this always returns `Ok(())`
    /// and leaves both `self` and the checker untouched.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
127
// Write one `name: value` line to the formatter `$f`.
//
// The field name is right-aligned in an 18-character column so consecutive lines
// line up. The macro expands to a `writeln!` call, so it evaluates to that call's
// result and is normally followed by `?`.
macro_rules! write_field {
    ($f:ident, $field:tt, $($expr:tt)*) => {
        writeln!($f, "{:>18}: {}", $field, $($expr)*)
    }
}
133
134impl SimdOp {
135    pub(crate) const fn tag() -> &'static str {
136        "simd-op"
137    }
138
139    fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        write_field!(f, "query type", self.query_type)?;
141        write_field!(f, "data type", self.data_type)?;
142        write_field!(f, "arch", self.arch)?;
143        write_field!(f, "number of runs", self.runs.len())?;
144        Ok(())
145    }
146}
147
148impl std::fmt::Display for SimdOp {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        writeln!(f, "SIMD Operation\n")?;
151        write_field!(f, "tag", Self::tag())?;
152        self.summarize_fields(f)
153    }
154}
155
156impl diskann_benchmark_runner::Input for SimdInput {
157    fn tag(&self) -> &'static str {
158        "simd-op"
159    }
160
161    fn try_deserialize(
162        &self,
163        serialized: &serde_json::Value,
164        checker: &mut Checker,
165    ) -> anyhow::Result<Any> {
166        checker.any(SimdOp::deserialize(serialized)?)
167    }
168
169    fn example(&self) -> anyhow::Result<serde_json::Value> {
170        const DIM: [NonZeroUsize; 2] = [
171            NonZeroUsize::new(128).unwrap(),
172            NonZeroUsize::new(150).unwrap(),
173        ];
174
175        const NUM_POINTS: [NonZeroUsize; 2] = [
176            NonZeroUsize::new(1000).unwrap(),
177            NonZeroUsize::new(800).unwrap(),
178        ];
179
180        const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
181        const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
182
183        let runs = vec![
184            Run {
185                distance: SimilarityMeasure::SquaredL2,
186                dim: DIM[0],
187                num_points: NUM_POINTS[0],
188                loops_per_measurement: LOOPS_PER_MEASUREMENT,
189                num_measurements: NUM_MEASUREMENTS,
190            },
191            Run {
192                distance: SimilarityMeasure::InnerProduct,
193                dim: DIM[1],
194                num_points: NUM_POINTS[1],
195                loops_per_measurement: LOOPS_PER_MEASUREMENT,
196                num_measurements: NUM_MEASUREMENTS,
197            },
198        ];
199
200        Ok(serde_json::to_value(&SimdOp {
201            query_type: DataType::Float32,
202            data_type: DataType::Float32,
203            arch: Arch::X86_64_V3,
204            runs,
205        })?)
206    }
207}
208
209////////////////////////////
210// Benchmark Registration //
211////////////////////////////
212
// Register `run_benchmark` with `$dispatcher` under `$name` for the kernel type given
// by the trailing tokens.
//
// Two forms are accepted:
// * with a leading target-architecture literal, in which case the registration is
//   only compiled when `target_arch` equals that literal;
// * without one, in which case the registration is always compiled.
macro_rules! register {
    ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        // Only emitted when compiling for the named target architecture.
        #[cfg(target_arch = $arch)]
        $dispatcher.register::<$($kernel)*>(
            $name,
            run_benchmark,
        )
    };
    ($dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        $dispatcher.register::<$($kernel)*>(
            $name,
            run_benchmark,
        )
    };
}
228
/// Register one benchmark per (architecture, element type) combination.
///
/// For each architecture the element-type pairs `f32 x f32`, `f16 x f16`, `u8 x u8`
/// and `i8 x i8` are registered. The x86-64 and aarch64 registrations are only
/// compiled for their respective target architectures; the scalar and reference
/// registrations are always compiled.
fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    // x86-64-v4
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, i8, i8>
    );

    // x86-64-v3
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, i8, i8>
    );

    // aarch64-neon
    register!(
        "aarch64",
        dispatcher,
        "simd-op-f32xf32-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, f32, f32>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-f16xf16-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, f16, f16>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-u8xu8-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, u8, u8>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-i8xi8-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, i8, i8>
    );

    // scalar (registered on every target)
    register!(
        dispatcher,
        "simd-op-f32xf32-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, i8, i8>
    );

    // reference (loop-based implementations, registered on every target)
    register!(
        dispatcher,
        "simd-op-f32xf32-reference",
        Kernel<'static, Reference, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-reference",
        Kernel<'static, Reference, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-reference",
        Kernel<'static, Reference, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-reference",
        Kernel<'static, Reference, i8, i8>
    );
}
352
353//////////////
354// Dispatch //
355//////////////
356
/// Dispatch receiver for the reference implementations.
///
/// Used in place of an architecture type to select the plain loop-based kernels in
/// the `reference` module; it is matched by `Arch::Reference`.
struct Reference;
359
/// A dispatch mapper for `wide` types.
///
/// Wraps an architecture token `T` so that `DispatchRule<Arch>` can be implemented
/// for it in this file; the wrapped value is recovered through field `0`.
#[derive(Debug)]
struct Identity<T>(T);

impl<T> dispatcher::Map for Identity<T>
where
    T: 'static,
{
    // The mapped type does not depend on the lifetime: it is always the wrapped `T`.
    type Type<'a> = T;
}
370
371struct Kernel<'a, A, Q, D> {
372    input: &'a SimdOp,
373    arch: A,
374    _type: std::marker::PhantomData<(A, Q, D)>,
375}
376
377impl<'a, A, Q, D> Kernel<'a, A, Q, D> {
378    fn new(input: &'a SimdOp, arch: A) -> Self {
379        Self {
380            input,
381            arch,
382            _type: std::marker::PhantomData,
383        }
384    }
385}
386
// `Kernel<'static, ..>` acts as the lifetime-erased name of the kernel type; the
// associated type re-introduces the borrow lifetime of the input.
impl<A, Q, D> dispatcher::Map for Kernel<'static, A, Q, D>
where
    A: 'static,
    Q: 'static,
    D: 'static,
{
    type Type<'a> = Kernel<'a, A, Q, D>;
}
395
// Map Architectures to the enum.
/// Error returned when the requested architecture matched a dispatch rule but the
/// corresponding architecture token could not be obtained at runtime.
///
/// The wrapped [`Arch`] is the architecture that was requested.
#[derive(Debug, Error)]
#[error("architecture {0} is not supported by this CPU")]
pub(crate) struct ArchNotSupported(Arch);
400
401impl DispatchRule<Arch> for Identity<Reference> {
402    type Error = ArchNotSupported;
403
404    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
405        if *from == Arch::Reference {
406            Ok(MatchScore(0))
407        } else {
408            Err(FailureScore(1))
409        }
410    }
411
412    fn convert(from: Arch) -> Result<Self, Self::Error> {
413        assert_eq!(from, Arch::Reference);
414        Ok(Identity(Reference))
415    }
416
417    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
418        match from {
419            None => write!(f, "loop based"),
420            Some(arch) => {
421                if Self::try_match(arch).is_ok() {
422                    write!(f, "matched {}", arch)
423                } else {
424                    write!(f, "expected {}, got {}", Arch::Reference, arch)
425                }
426            }
427        }
428    }
429}
430
431impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
432    type Error = ArchNotSupported;
433
434    fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
435        if *from == Arch::Scalar {
436            Ok(MatchScore(0))
437        } else {
438            Err(FailureScore(1))
439        }
440    }
441
442    fn convert(from: Arch) -> Result<Self, Self::Error> {
443        assert_eq!(from, Arch::Scalar);
444        Ok(Identity(diskann_wide::arch::Scalar))
445    }
446
447    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
448        match from {
449            None => write!(f, "scalar (compilation target CPU)"),
450            Some(arch) => {
451                if Self::try_match(arch).is_ok() {
452                    write!(f, "matched {}", arch)
453                } else {
454                    write!(f, "expected {}, got {}", Arch::Scalar, arch)
455                }
456            }
457        }
458    }
459}
460
461macro_rules! match_arch {
462    ($target_arch:literal, $arch:path, $enum:ident) => {
463        #[cfg(target_arch = $target_arch)]
464        impl DispatchRule<Arch> for Identity<$arch> {
465            type Error = ArchNotSupported;
466
467            fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
468                let available = <$arch>::new_checked().is_some();
469                if available && *from == Arch::$enum {
470                    Ok(MatchScore(0))
471                } else if !available && *from == Arch::$enum {
472                    Err(FailureScore(0))
473                } else {
474                    Err(FailureScore(1))
475                }
476            }
477
478            fn convert(from: Arch) -> Result<Self, Self::Error> {
479                assert_eq!(from, Arch::$enum);
480                <$arch>::new_checked()
481                    .ok_or(ArchNotSupported(from))
482                    .map(Identity)
483            }
484
485            fn description(
486                f: &mut std::fmt::Formatter<'_>,
487                from: Option<&Arch>,
488            ) -> std::fmt::Result {
489                let available = <$arch>::new_checked().is_some();
490                match from {
491                    None => write!(f, "{}", Arch::$enum),
492                    Some(arch) => {
493                        if Self::try_match(arch).is_ok() {
494                            write!(f, "matched {}", arch)
495                        } else if !available && *arch == Arch::$enum {
496                            write!(f, "matched {} but unsupported by this CPU", Arch::$enum)
497                        } else {
498                            write!(f, "expected {}, got {}", Arch::$enum, arch)
499                        }
500                    }
501                }
502            }
503        }
504    };
505}
506
507match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
508match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
509match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
510
511impl<'a, A, Q, D> DispatchRule<&'a SimdOp> for Kernel<'a, A, Q, D>
512where
513    datatype::Type<Q>: DispatchRule<datatype::DataType>,
514    datatype::Type<D>: DispatchRule<datatype::DataType>,
515    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
516{
517    type Error = ArchNotSupported;
518
519    // Matching simply requires that we match the inner type.
520    fn try_match(from: &&'a SimdOp) -> Result<MatchScore, FailureScore> {
521        let mut failscore: Option<u32> = None;
522        if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
523            *failscore.get_or_insert(0) += 10;
524        }
525        if datatype::Type::<D>::try_match(&from.data_type).is_err() {
526            *failscore.get_or_insert(0) += 10;
527        }
528        if let Err(FailureScore(score)) = Identity::<A>::try_match(&from.arch) {
529            *failscore.get_or_insert(0) += 2 + score;
530        }
531
532        match failscore {
533            None => Ok(MatchScore(0)),
534            Some(score) => Err(FailureScore(score)),
535        }
536    }
537
538    fn convert(from: &'a SimdOp) -> Result<Self, Self::Error> {
539        assert!(Self::try_match(&from).is_ok());
540        let arch = Identity::<A>::convert(from.arch)?.0;
541        Ok(Self::new(from, arch))
542    }
543
544    fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a SimdOp>) -> std::fmt::Result {
545        match from {
546            None => {
547                describeln!(
548                    f,
549                    "- Query Type: {}",
550                    dispatcher::Description::<datatype::DataType, datatype::Type<Q>>::new()
551                )?;
552                describeln!(
553                    f,
554                    "- Data Type: {}",
555                    dispatcher::Description::<datatype::DataType, datatype::Type<D>>::new()
556                )?;
557                describeln!(
558                    f,
559                    "- Implementation: {}",
560                    dispatcher::Description::<Arch, Identity<A>>::new()
561                )?;
562            }
563            Some(input) => {
564                if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
565                    describeln!(f, "\n    - Mismatched query type: {}", err)?;
566                }
567                if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
568                    describeln!(f, "\n    - Mismatched data type: {}", err)?;
569                }
570                if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
571                    describeln!(f, "\n    - Mismatched architecture: {}", err)?;
572                }
573            }
574        }
575        Ok(())
576    }
577}
578
579impl<'a, A, Q, D> DispatchRule<&'a diskann_benchmark_runner::Any> for Kernel<'a, A, Q, D>
580where
581    Kernel<'a, A, Q, D>: DispatchRule<&'a SimdOp>,
582    <Kernel<'a, A, Q, D> as DispatchRule<&'a SimdOp>>::Error:
583        std::error::Error + Send + Sync + 'static,
584{
585    type Error = anyhow::Error;
586
587    fn try_match(from: &&'a diskann_benchmark_runner::Any) -> Result<MatchScore, FailureScore> {
588        from.try_match::<SimdOp, Self>()
589    }
590
591    fn convert(from: &'a diskann_benchmark_runner::Any) -> Result<Self, Self::Error> {
592        from.convert::<SimdOp, Self>()
593    }
594
595    fn description(
596        f: &mut std::fmt::Formatter<'_>,
597        from: Option<&&'a diskann_benchmark_runner::Any>,
598    ) -> std::fmt::Result {
599        Any::description::<SimdOp, Self>(f, from, SimdOp::tag())
600    }
601}
602
603///////////////
604// Benchmark //
605///////////////
606
607fn run_benchmark<A, Q, D>(
608    kernel: Kernel<'_, A, Q, D>,
609    _: diskann_benchmark_runner::Checkpoint<'_>,
610    mut output: &mut dyn diskann_benchmark_runner::Output,
611) -> Result<serde_json::Value, anyhow::Error>
612where
613    for<'a> Kernel<'a, A, Q, D>: RunBenchmark,
614{
615    writeln!(output, "{}", kernel.input)?;
616    let results = kernel.run()?;
617    writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
618    Ok(serde_json::to_value(results)?)
619}
620
/// Implemented by every kernel instantiation that can actually be benchmarked.
trait RunBenchmark {
    /// Consume the kernel, execute its configured runs and return one result per run.
    fn run(self) -> Result<Vec<RunResult>, anyhow::Error>;
}
624
/// The outcome of one benchmark run: its configuration together with the recorded
/// timings.
#[derive(Debug, Serialize)]
struct RunResult {
    /// The setup
    run: Run,
    /// The latencies of individual runs (one entry per timed measurement).
    latencies: Vec<MicroSeconds>,
    /// Latency percentiles.
    percentiles: percentiles::Percentiles<MicroSeconds>,
}
634
635impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
636    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
637        if self.is_empty() {
638            return Ok(());
639        }
640
641        let header = [
642            "Distance",
643            "Dim",
644            "Min Time (ns)",
645            "Mean Time (ns)",
646            "Points",
647            "Loops",
648            "Measurements",
649        ];
650
651        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
652
653        self.iter().enumerate().for_each(|(row, r)| {
654            let mut row = table.row(row);
655
656            let min_latency = r
657                .latencies
658                .iter()
659                .min()
660                .copied()
661                .unwrap_or(MicroSeconds::new(u64::MAX));
662            let mean_latency = r.percentiles.mean;
663
664            let computations_per_latency: f64 =
665                (r.run.num_points.get() * r.run.loops_per_measurement.get()) as f64;
666
667            // Convert time from micro-seconds to nano-seconds.
668            let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
669            let mean_time = mean_latency / computations_per_latency * 1000.0;
670
671            row.insert(r.run.distance, 0);
672            row.insert(r.run.dim, 1);
673            row.insert(format!("{:.3}", min_time), 2);
674            row.insert(format!("{:.3}", mean_time), 3);
675            row.insert(r.run.num_points, 4);
676            row.insert(r.run.loops_per_measurement, 5);
677            row.insert(r.run.num_measurements, 6);
678        });
679
680        table.fmt(f)
681    }
682}
683
684fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
685where
686    F: Fn(&[Q], &[D]) -> f32,
687{
688    let mut latencies = Vec::with_capacity(run.num_measurements.get());
689    let mut dst = vec![0.0; data.nrows()];
690
691    for _ in 0..run.num_measurements.get() {
692        let start = std::time::Instant::now();
693        for _ in 0..run.loops_per_measurement.get() {
694            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
695                *d = f(query, r);
696            });
697            std::hint::black_box(&mut dst);
698        }
699        latencies.push(start.elapsed().into());
700    }
701
702    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
703    RunResult {
704        run: run.clone(),
705        latencies,
706        percentiles,
707    }
708}
709
710struct Data<Q, D> {
711    query: Box<[Q]>,
712    data: Matrix<D>,
713}
714
715impl<Q, D> Data<Q, D> {
716    fn new(run: &Run) -> Self
717    where
718        StandardUniform: Distribution<Q>,
719        StandardUniform: Distribution<D>,
720    {
721        let mut rng = StdRng::seed_from_u64(0x12345);
722        let query: Box<[Q]> = (0..run.dim.get())
723            .map(|_| StandardUniform.sample(&mut rng))
724            .collect();
725        let data = Matrix::<D>::new(
726            diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
727            run.num_points.get(),
728            run.dim.get(),
729        );
730
731        Self { query, data }
732    }
733
734    fn run<F>(&self, run: &Run, f: F) -> RunResult
735    where
736        F: Fn(&[Q], &[D]) -> f32,
737    {
738        run_loops(&self.query, self.data.as_view(), run, f)
739    }
740}
741
742/////////////////////
743// Implementations //
744/////////////////////
745
// Generate `RunBenchmark` implementations for `Kernel`.
//
// Three forms are accepted:
// * `stamp!(reference, Q, D, l2, ip, cosine)` implements the trait for the
//   `Reference` receiver using the named functions from the `reference` module;
// * `stamp!(arch, Q, D)` implements the trait for the architecture type `arch` using
//   the `simd` kernels;
// * `stamp!("target", arch, Q, D)` is the second form, compiled only when
//   `target_arch` equals the given literal.
macro_rules! stamp {
    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
        impl RunBenchmark for Kernel<'_, Reference, $Q, $D> {
            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();
                for run in self.input.runs.iter() {
                    // Fresh data is generated for every run.
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
                    };
                    results.push(result);
                }
                Ok(results)
            }
        }
    };
    ($arch:path, $Q:ty, $D:ty) => {
        impl RunBenchmark for Kernel<'_, $arch, $Q, $D> {
            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();

                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                for run in self.input.runs.iter() {
                    // Fresh data is generated for every run.
                    let data = Data::<$Q, $D>::new(run);
                    // For each kernel, we need to do a two-step wrapping of closures so
                    // the inner-most closure is executed by the architecture.
                    //
                    // This is required for the implementation of `simd_op` to be inlined
                    // into the architecture run function so it can properly inherit
                    // target features.
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d)
                        }),
                    };
                    results.push(result)
                }
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
        // Forward to the unconditional form, gated on the target architecture.
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Q, $D);
    };
}

// x86-64-v4 kernels (x86_64 targets only).
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);

// x86-64-v3 kernels (x86_64 targets only).
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);

// Neon kernels (aarch64 targets only).
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);

// Scalar kernels (all targets).
stamp!(diskann_wide::arch::Scalar, f32, f32);
stamp!(diskann_wide::arch::Scalar, f16, f16);
stamp!(diskann_wide::arch::Scalar, u8, u8);
stamp!(diskann_wide::arch::Scalar, i8, i8);

// Reference (loop-based) kernels (all targets).
stamp!(
    reference,
    f32,
    f32,
    squared_l2_f32,
    inner_product_f32,
    cosine_f32
);
stamp!(
    reference,
    f16,
    f16,
    squared_l2_f16,
    inner_product_f16,
    cosine_f16
);
stamp!(
    reference,
    u8,
    u8,
    squared_l2_u8,
    inner_product_u8,
    cosine_u8
);
stamp!(
    reference,
    i8,
    i8,
    squared_l2_i8,
    inner_product_i8,
    cosine_i8
);
859
860///////////////
861// Reference //
862///////////////
863
864// These are largely copied from the implementations in vector, with a tweak that we don't
865// use FMA when the current architecture is scalar.
866mod reference {
867    use diskann_wide::ARCH;
868    use half::f16;
869
870    trait MaybeFMA {
871        // Perform `a*b + c` using FMA when a hardware instruction is guaranteed to be
872        // available, otherwise decompose into a multiply and add.
873        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
874    }
875
876    impl MaybeFMA for diskann_wide::arch::Scalar {
877        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
878            a * b + c
879        }
880    }
881
882    #[cfg(target_arch = "x86_64")]
883    impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
884        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
885            a.mul_add(b, c)
886        }
887    }
888
889    #[cfg(target_arch = "x86_64")]
890    impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
891        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
892            a.mul_add(b, c)
893        }
894    }
895
896    #[cfg(target_arch = "aarch64")]
897    impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
898        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
899            a.mul_add(b, c)
900        }
901    }
902
903    //------------//
904    // Squared L2 //
905    //------------//
906
907    pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
908        assert_eq!(x.len(), y.len());
909        std::iter::zip(x.iter(), y.iter())
910            .map(|(&a, &b)| {
911                let a: i32 = a.into();
912                let b: i32 = b.into();
913                let diff = a - b;
914                diff * diff
915            })
916            .sum::<i32>() as f32
917    }
918
919    pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
920        assert_eq!(x.len(), y.len());
921        std::iter::zip(x.iter(), y.iter())
922            .map(|(&a, &b)| {
923                let a: i32 = a.into();
924                let b: i32 = b.into();
925                let diff = a - b;
926                diff * diff
927            })
928            .sum::<i32>() as f32
929    }
930
931    pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
932        assert_eq!(x.len(), y.len());
933        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
934            let a: f32 = a.into();
935            let b: f32 = b.into();
936            let diff = a - b;
937            ARCH.maybe_fma(diff, diff, acc)
938        })
939    }
940
941    pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
942        assert_eq!(x.len(), y.len());
943        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
944            let diff = a - b;
945            ARCH.maybe_fma(diff, diff, acc)
946        })
947    }
948
949    //---------------//
950    // Inner Product //
951    //---------------//
952
953    pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
954        assert_eq!(x.len(), y.len());
955        std::iter::zip(x.iter(), y.iter())
956            .map(|(&a, &b)| {
957                let a: i32 = a.into();
958                let b: i32 = b.into();
959                a * b
960            })
961            .sum::<i32>() as f32
962    }
963
964    pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
965        assert_eq!(x.len(), y.len());
966        std::iter::zip(x.iter(), y.iter())
967            .map(|(&a, &b)| {
968                let a: i32 = a.into();
969                let b: i32 = b.into();
970                a * b
971            })
972            .sum::<i32>() as f32
973    }
974
975    pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
976        assert_eq!(x.len(), y.len());
977        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
978            let a: f32 = a.into();
979            let b: f32 = b.into();
980            ARCH.maybe_fma(a, b, acc)
981        })
982    }
983
984    pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
985        assert_eq!(x.len(), y.len());
986        std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
987    }
988
989    //--------//
990    // Cosine //
991    //--------//
992
993    #[derive(Default)]
994    struct XY<T> {
995        xnorm: T,
996        ynorm: T,
997        xy: T,
998    }
999
1000    pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1001        assert_eq!(x.len(), y.len());
1002        let r: XY<i32> =
1003            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1004                let vx: i32 = vx.into();
1005                let vy: i32 = vy.into();
1006                XY {
1007                    xnorm: acc.xnorm + vx * vx,
1008                    ynorm: acc.ynorm + vy * vy,
1009                    xy: acc.xy + vx * vy,
1010                }
1011            });
1012
1013        if r.xnorm == 0 || r.ynorm == 0 {
1014            return 0.0;
1015        }
1016
1017        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1018    }
1019
1020    pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1021        assert_eq!(x.len(), y.len());
1022        let r: XY<i32> =
1023            std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1024                let vx: i32 = vx.into();
1025                let vy: i32 = vy.into();
1026                XY {
1027                    xnorm: acc.xnorm + vx * vx,
1028                    ynorm: acc.ynorm + vy * vy,
1029                    xy: acc.xy + vx * vy,
1030                }
1031            });
1032
1033        if r.xnorm == 0 || r.ynorm == 0 {
1034            return 0.0;
1035        }
1036
1037        (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1038    }
1039
1040    pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1041        assert_eq!(x.len(), y.len());
1042        let r: XY<f32> =
1043            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1044                let vx: f32 = vx.into();
1045                let vy: f32 = vy.into();
1046                XY {
1047                    xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1048                    ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1049                    xy: ARCH.maybe_fma(vx, vy, acc.xy),
1050                }
1051            });
1052
1053        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1054            return 0.0;
1055        }
1056
1057        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1058    }
1059
1060    pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1061        assert_eq!(x.len(), y.len());
1062        let r: XY<f32> =
1063            std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1064                xnorm: vx.mul_add(vx, acc.xnorm),
1065                ynorm: vy.mul_add(vy, acc.ynorm),
1066                xy: vx.mul_add(vy, acc.xy),
1067            });
1068
1069        if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1070            return 0.0;
1071        }
1072
1073        (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1074    }
1075}