1use std::{io::Write, num::NonZeroUsize};
7
8use diskann_utils::views::{Matrix, MatrixView};
9use diskann_vector::distance::simd;
10use diskann_wide::Architecture;
11use half::f16;
12use rand::{
13 distr::{Distribution, StandardUniform},
14 rngs::StdRng,
15 SeedableRng,
16};
17use serde::{Deserialize, Serialize};
18use thiserror::Error;
19
20use diskann_benchmark_runner::{
21 describeln,
22 dispatcher::{self, DispatchRule, FailureScore, MatchScore},
23 utils::{
24 datatype::{self, DataType},
25 percentiles, MicroSeconds,
26 },
27 Any, CheckDeserialization, Checker,
28};
29
/// Unit type that implements the benchmark runner's `Input` trait for the
/// `simd-op` input format (see the `Input` implementation further down).
#[derive(Debug)]
pub struct SimdInput;
36
/// Registers the SIMD distance benchmarks defined in this file with `dispatcher`.
///
/// Public entry point; all of the work is done by `register_benchmarks_impl`.
pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    register_benchmarks_impl(dispatcher)
}
40
/// Borrowing newtype around `&T`.
///
/// It exists so that formatting traits can be implemented locally for borrowed
/// values (for example a slice of results) while still exposing the wrapped
/// value's own methods through `Deref`.
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

impl<'a, T: ?Sized> std::ops::Deref for DisplayWrapper<'a, T> {
    type Target = T;

    /// Hands back the wrapped reference.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        *inner
    }
}
54
/// Distance function exercised by a single benchmark run.
///
/// Variant names are (de)serialized by serde using `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMeasure {
    /// Squared L2 distance.
    SquaredL2,
    /// Inner product.
    InnerProduct,
    /// Cosine similarity.
    Cosine,
}
66
67impl std::fmt::Display for SimilarityMeasure {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 let st = match self {
70 Self::SquaredL2 => "squared_l2",
71 Self::InnerProduct => "inner_product",
72 Self::Cosine => "cosine",
73 };
74 write!(f, "{}", st)
75 }
76}
77
/// Implementation back end requested for a benchmark.
///
/// Variants are (de)serialized in `kebab-case`; the two x86 variants carry an
/// explicit `rename` instead of relying on the automatic conversion.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) enum Arch {
    /// The `diskann_wide::arch::x86_64::V4` back end.
    #[serde(rename = "x86-64-v4")]
    // The underscores keep the variant name aligned with the architecture name.
    #[allow(non_camel_case_types)]
    X86_64_V4,
    /// The `diskann_wide::arch::x86_64::V3` back end.
    #[serde(rename = "x86-64-v3")]
    // The underscores keep the variant name aligned with the architecture name.
    #[allow(non_camel_case_types)]
    X86_64_V3,
    /// The `diskann_wide::arch::aarch64::Neon` back end.
    Neon,
    /// The `diskann_wide::arch::Scalar` back end.
    Scalar,
    /// The plain loop implementations in the `reference` module of this file.
    Reference,
}
91
92impl std::fmt::Display for Arch {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 let st = match self {
95 Self::X86_64_V4 => "x86-64-v4",
96 Self::X86_64_V3 => "x86-64-v3",
97 Self::Neon => "neon",
98 Self::Scalar => "scalar",
99 Self::Reference => "reference",
100 };
101 write!(f, "{}", st)
102 }
103}
104
/// Parameters of one timed benchmark configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Run {
    /// Distance function to benchmark.
    pub(crate) distance: SimilarityMeasure,
    /// Number of elements in the generated query vector; also passed as the
    /// second dimension of the generated data matrix.
    pub(crate) dim: NonZeroUsize,
    /// First dimension of the generated data matrix; one distance is computed
    /// per matrix row.
    pub(crate) num_points: NonZeroUsize,
    /// Number of full passes over the data that make up one latency sample.
    pub(crate) loops_per_measurement: NonZeroUsize,
    /// Number of latency samples to record.
    pub(crate) num_measurements: NonZeroUsize,
}
113
/// Deserialized description of a `simd-op` benchmark job.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct SimdOp {
    /// Element type of the query vector.
    pub(crate) query_type: DataType,
    /// Element type of the data matrix.
    pub(crate) data_type: DataType,
    /// Implementation back end to benchmark.
    pub(crate) arch: Arch,
    /// Configurations to time, executed in order.
    pub(crate) runs: Vec<Run>,
}
121
impl CheckDeserialization for SimdOp {
    /// No post-deserialization validation is performed; this always succeeds
    /// and leaves both `self` and the checker untouched.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
127
/// Writes one `label: value` line to the formatter `$f`.
///
/// The label is right-aligned in an 18 character column; the remaining tokens
/// form the expression that is printed with `Display`. Expands to a
/// `writeln!` call, so the result is a `std::fmt::Result`.
macro_rules! write_field {
    ($f:ident, $field:tt, $($expr:tt)*) => {
        writeln!($f, "{:>18}: {}", $field, $($expr)*)
    }
}
133
134impl SimdOp {
135 pub(crate) const fn tag() -> &'static str {
136 "simd-op"
137 }
138
139 fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 write_field!(f, "query type", self.query_type)?;
141 write_field!(f, "data type", self.data_type)?;
142 write_field!(f, "arch", self.arch)?;
143 write_field!(f, "number of runs", self.runs.len())?;
144 Ok(())
145 }
146}
147
impl std::fmt::Display for SimdOp {
    /// Prints a heading followed by a blank line, the tag, and then the
    /// per-field summary produced by `summarize_fields`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "SIMD Operation\n")?;
        write_field!(f, "tag", Self::tag())?;
        self.summarize_fields(f)
    }
}
155
156impl diskann_benchmark_runner::Input for SimdInput {
157 fn tag(&self) -> &'static str {
158 "simd-op"
159 }
160
161 fn try_deserialize(
162 &self,
163 serialized: &serde_json::Value,
164 checker: &mut Checker,
165 ) -> anyhow::Result<Any> {
166 checker.any(SimdOp::deserialize(serialized)?)
167 }
168
169 fn example(&self) -> anyhow::Result<serde_json::Value> {
170 const DIM: [NonZeroUsize; 2] = [
171 NonZeroUsize::new(128).unwrap(),
172 NonZeroUsize::new(150).unwrap(),
173 ];
174
175 const NUM_POINTS: [NonZeroUsize; 2] = [
176 NonZeroUsize::new(1000).unwrap(),
177 NonZeroUsize::new(800).unwrap(),
178 ];
179
180 const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
181 const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
182
183 let runs = vec![
184 Run {
185 distance: SimilarityMeasure::SquaredL2,
186 dim: DIM[0],
187 num_points: NUM_POINTS[0],
188 loops_per_measurement: LOOPS_PER_MEASUREMENT,
189 num_measurements: NUM_MEASUREMENTS,
190 },
191 Run {
192 distance: SimilarityMeasure::InnerProduct,
193 dim: DIM[1],
194 num_points: NUM_POINTS[1],
195 loops_per_measurement: LOOPS_PER_MEASUREMENT,
196 num_measurements: NUM_MEASUREMENTS,
197 },
198 ];
199
200 Ok(serde_json::to_value(&SimdOp {
201 query_type: DataType::Float32,
202 data_type: DataType::Float32,
203 arch: Arch::X86_64_V3,
204 runs,
205 })?)
206 }
207}
208
/// Registers `run_benchmark` under `$name` for the kernel type given by the
/// trailing tokens.
///
/// The four-argument form takes a leading target architecture literal and
/// only emits the registration when compiling for that `target_arch`; the
/// three-argument form registers unconditionally.
macro_rules! register {
    ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        #[cfg(target_arch = $arch)]
        $dispatcher.register::<$($kernel)*>(
            $name,
            run_benchmark,
        )
    };
    ($dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        $dispatcher.register::<$($kernel)*>(
            $name,
            run_benchmark,
        )
    };
}
228
/// Registers one benchmark per (architecture, element type) combination.
///
/// Each architecture is registered for `f32`, `f16`, `u8` and `i8`, with the
/// query and data types always equal. The x86-64 and aarch64 entries are
/// compiled only for their respective targets; the scalar and reference
/// entries are always registered.
fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    // x86-64-v4
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, i8, i8>
    );

    // x86-64-v3
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, i8, i8>
    );

    // aarch64 Neon
    register!(
        "aarch64",
        dispatcher,
        "simd-op-f32xf32-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, f32, f32>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-f16xf16-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, f16, f16>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-u8xu8-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, u8, u8>
    );
    register!(
        "aarch64",
        dispatcher,
        "simd-op-i8xi8-aarch64_neon",
        Kernel<'static, diskann_wide::arch::aarch64::Neon, i8, i8>
    );

    // Scalar (registered on every target)
    register!(
        dispatcher,
        "simd-op-f32xf32-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, i8, i8>
    );

    // Reference loop implementations (registered on every target)
    register!(
        dispatcher,
        "simd-op-f32xf32-reference",
        Kernel<'static, Reference, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-reference",
        Kernel<'static, Reference, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-reference",
        Kernel<'static, Reference, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-reference",
        Kernel<'static, Reference, i8, i8>
    );
}
352
/// Marker type selecting the plain loop implementations in the `reference`
/// module instead of a `diskann_wide` architecture.
struct Reference;
359
/// Newtype around an architecture value.
///
/// `DispatchRule<Arch>` is implemented for `Identity<A>` further down, which
/// is how an `Arch` request is turned into a concrete architecture value.
#[derive(Debug)]
struct Identity<T>(T);

/// The mapped type does not depend on the lifetime: it is always `T`.
impl<T> dispatcher::Map for Identity<T>
where
    T: 'static,
{
    type Type<'a> = T;
}
370
/// A benchmark job bound to a concrete architecture `A`, query element type
/// `Q`, and data element type `D`.
struct Kernel<'a, A, Q, D> {
    /// The benchmark description this kernel was created from.
    input: &'a SimdOp,
    /// Architecture value used to run the distance computations.
    arch: A,
    /// Carries the type parameters; `Q` and `D` are not stored in any field.
    _type: std::marker::PhantomData<(A, Q, D)>,
}

impl<'a, A, Q, D> Kernel<'a, A, Q, D> {
    /// Bundles `input` with the architecture value `arch`.
    fn new(input: &'a SimdOp, arch: A) -> Self {
        Self {
            input,
            arch,
            _type: std::marker::PhantomData,
        }
    }
}
386
/// Maps the `'static` form of the kernel (used when registering) to the
/// kernel borrowing its input for an arbitrary lifetime `'a`.
impl<A, Q, D> dispatcher::Map for Kernel<'static, A, Q, D>
where
    A: 'static,
    Q: 'static,
    D: 'static,
{
    type Type<'a> = Kernel<'a, A, Q, D>;
}
395
/// Error returned when the requested architecture is not available on the
/// CPU that is running the benchmark. Carries the requested architecture.
#[derive(Debug, Error)]
#[error("architecture {0} is not supported by this CPU")]
pub(crate) struct ArchNotSupported(Arch);
400
401impl DispatchRule<Arch> for Identity<Reference> {
402 type Error = ArchNotSupported;
403
404 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
405 if *from == Arch::Reference {
406 Ok(MatchScore(0))
407 } else {
408 Err(FailureScore(1))
409 }
410 }
411
412 fn convert(from: Arch) -> Result<Self, Self::Error> {
413 assert_eq!(from, Arch::Reference);
414 Ok(Identity(Reference))
415 }
416
417 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
418 match from {
419 None => write!(f, "loop based"),
420 Some(arch) => {
421 if Self::try_match(arch).is_ok() {
422 write!(f, "matched {}", arch)
423 } else {
424 write!(f, "expected {}, got {}", Arch::Reference, arch)
425 }
426 }
427 }
428 }
429}
430
431impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
432 type Error = ArchNotSupported;
433
434 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
435 if *from == Arch::Scalar {
436 Ok(MatchScore(0))
437 } else {
438 Err(FailureScore(1))
439 }
440 }
441
442 fn convert(from: Arch) -> Result<Self, Self::Error> {
443 assert_eq!(from, Arch::Scalar);
444 Ok(Identity(diskann_wide::arch::Scalar))
445 }
446
447 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
448 match from {
449 None => write!(f, "scalar (compilation target CPU)"),
450 Some(arch) => {
451 if Self::try_match(arch).is_ok() {
452 write!(f, "matched {}", arch)
453 } else {
454 write!(f, "expected {}, got {}", Arch::Scalar, arch)
455 }
456 }
457 }
458 }
459}
460
/// Implements `DispatchRule<Arch>` for `Identity<$arch>`, compiled only for
/// `$target_arch`.
///
/// The rule matches when the request equals `Arch::$enum` *and*
/// `<$arch>::new_checked()` returns a value. A request for the right
/// architecture that `new_checked` rejects fails with score 0; any other
/// request fails with score 1.
macro_rules! match_arch {
    ($target_arch:literal, $arch:path, $enum:ident) => {
        #[cfg(target_arch = $target_arch)]
        impl DispatchRule<Arch> for Identity<$arch> {
            type Error = ArchNotSupported;

            fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
                let requested = *from == Arch::$enum;
                let available = <$arch>::new_checked().is_some();
                match (requested, available) {
                    (true, true) => Ok(MatchScore(0)),
                    (true, false) => Err(FailureScore(0)),
                    (false, _) => Err(FailureScore(1)),
                }
            }

            fn convert(from: Arch) -> Result<Self, Self::Error> {
                assert_eq!(from, Arch::$enum);
                match <$arch>::new_checked() {
                    Some(arch) => Ok(Identity(arch)),
                    None => Err(ArchNotSupported(from)),
                }
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                from: Option<&Arch>,
            ) -> std::fmt::Result {
                match from {
                    None => write!(f, "{}", Arch::$enum),
                    Some(arch) if *arch != Arch::$enum => {
                        write!(f, "expected {}, got {}", Arch::$enum, arch)
                    }
                    Some(arch) if <$arch>::new_checked().is_some() => {
                        write!(f, "matched {}", arch)
                    }
                    Some(_) => {
                        write!(f, "matched {} but unsupported by this CPU", Arch::$enum)
                    }
                }
            }
        }
    };
}
506
// Dispatch rules for the architectures that need a `new_checked` probe.
// Each one is compiled only for the target architecture named first.
match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
510
511impl<'a, A, Q, D> DispatchRule<&'a SimdOp> for Kernel<'a, A, Q, D>
512where
513 datatype::Type<Q>: DispatchRule<datatype::DataType>,
514 datatype::Type<D>: DispatchRule<datatype::DataType>,
515 Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
516{
517 type Error = ArchNotSupported;
518
519 fn try_match(from: &&'a SimdOp) -> Result<MatchScore, FailureScore> {
521 let mut failscore: Option<u32> = None;
522 if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
523 *failscore.get_or_insert(0) += 10;
524 }
525 if datatype::Type::<D>::try_match(&from.data_type).is_err() {
526 *failscore.get_or_insert(0) += 10;
527 }
528 if let Err(FailureScore(score)) = Identity::<A>::try_match(&from.arch) {
529 *failscore.get_or_insert(0) += 2 + score;
530 }
531
532 match failscore {
533 None => Ok(MatchScore(0)),
534 Some(score) => Err(FailureScore(score)),
535 }
536 }
537
538 fn convert(from: &'a SimdOp) -> Result<Self, Self::Error> {
539 assert!(Self::try_match(&from).is_ok());
540 let arch = Identity::<A>::convert(from.arch)?.0;
541 Ok(Self::new(from, arch))
542 }
543
544 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a SimdOp>) -> std::fmt::Result {
545 match from {
546 None => {
547 describeln!(
548 f,
549 "- Query Type: {}",
550 dispatcher::Description::<datatype::DataType, datatype::Type<Q>>::new()
551 )?;
552 describeln!(
553 f,
554 "- Data Type: {}",
555 dispatcher::Description::<datatype::DataType, datatype::Type<D>>::new()
556 )?;
557 describeln!(
558 f,
559 "- Implementation: {}",
560 dispatcher::Description::<Arch, Identity<A>>::new()
561 )?;
562 }
563 Some(input) => {
564 if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
565 describeln!(f, "\n - Mismatched query type: {}", err)?;
566 }
567 if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
568 describeln!(f, "\n - Mismatched data type: {}", err)?;
569 }
570 if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
571 describeln!(f, "\n - Mismatched architecture: {}", err)?;
572 }
573 }
574 }
575 Ok(())
576 }
577}
578
/// Dispatch from a type-erased `Any`: every method forwards to the `Any`
/// helper of the same name, parameterized on `SimdOp` and this kernel type.
impl<'a, A, Q, D> DispatchRule<&'a diskann_benchmark_runner::Any> for Kernel<'a, A, Q, D>
where
    Kernel<'a, A, Q, D>: DispatchRule<&'a SimdOp>,
    <Kernel<'a, A, Q, D> as DispatchRule<&'a SimdOp>>::Error:
        std::error::Error + Send + Sync + 'static,
{
    type Error = anyhow::Error;

    fn try_match(from: &&'a diskann_benchmark_runner::Any) -> Result<MatchScore, FailureScore> {
        from.try_match::<SimdOp, Self>()
    }

    fn convert(from: &'a diskann_benchmark_runner::Any) -> Result<Self, Self::Error> {
        from.convert::<SimdOp, Self>()
    }

    fn description(
        f: &mut std::fmt::Formatter<'_>,
        from: Option<&&'a diskann_benchmark_runner::Any>,
    ) -> std::fmt::Result {
        // The tag is passed along so the description can name the expected input.
        Any::description::<SimdOp, Self>(f, from, SimdOp::tag())
    }
}
602
603fn run_benchmark<A, Q, D>(
608 kernel: Kernel<'_, A, Q, D>,
609 _: diskann_benchmark_runner::Checkpoint<'_>,
610 mut output: &mut dyn diskann_benchmark_runner::Output,
611) -> Result<serde_json::Value, anyhow::Error>
612where
613 for<'a> Kernel<'a, A, Q, D>: RunBenchmark,
614{
615 writeln!(output, "{}", kernel.input)?;
616 let results = kernel.run()?;
617 writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
618 Ok(serde_json::to_value(results)?)
619}
620
/// Implemented by every concrete kernel instantiation that can be benchmarked.
trait RunBenchmark {
    /// Consumes the kernel and returns one `RunResult` per configured run.
    fn run(self) -> Result<Vec<RunResult>, anyhow::Error>;
}
624
/// Measurements collected for one `Run`.
#[derive(Debug, Serialize)]
struct RunResult {
    /// The configuration that produced these measurements.
    run: Run,
    /// One latency sample per measurement.
    latencies: Vec<MicroSeconds>,
    /// Summary statistics computed from `latencies`.
    percentiles: percentiles::Percentiles<MicroSeconds>,
}
634
635impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
636 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
637 if self.is_empty() {
638 return Ok(());
639 }
640
641 let header = [
642 "Distance",
643 "Dim",
644 "Min Time (ns)",
645 "Mean Time (ns)",
646 "Points",
647 "Loops",
648 "Measurements",
649 ];
650
651 let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
652
653 self.iter().enumerate().for_each(|(row, r)| {
654 let mut row = table.row(row);
655
656 let min_latency = r
657 .latencies
658 .iter()
659 .min()
660 .copied()
661 .unwrap_or(MicroSeconds::new(u64::MAX));
662 let mean_latency = r.percentiles.mean;
663
664 let computations_per_latency: f64 =
665 (r.run.num_points.get() * r.run.loops_per_measurement.get()) as f64;
666
667 let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
669 let mean_time = mean_latency / computations_per_latency * 1000.0;
670
671 row.insert(r.run.distance, 0);
672 row.insert(r.run.dim, 1);
673 row.insert(format!("{:.3}", min_time), 2);
674 row.insert(format!("{:.3}", mean_time), 3);
675 row.insert(r.run.num_points, 4);
676 row.insert(r.run.loops_per_measurement, 5);
677 row.insert(r.run.num_measurements, 6);
678 });
679
680 table.fmt(f)
681 }
682}
683
/// Times `f(query, row)` over every row of `data`.
///
/// One latency sample covers `run.loops_per_measurement` full passes over the
/// data; `run.num_measurements` samples are collected and summarized.
///
/// # Panics
///
/// Panics if `compute_percentiles` returns an error.
fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
where
    F: Fn(&[Q], &[D]) -> f32,
{
    let mut latencies = Vec::with_capacity(run.num_measurements.get());
    // Destination buffer with one slot per data row, reused by every pass.
    let mut dst = vec![0.0; data.nrows()];

    for _ in 0..run.num_measurements.get() {
        let start = std::time::Instant::now();
        for _ in 0..run.loops_per_measurement.get() {
            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
                *d = f(query, r);
            });
            // Keeps the optimizer from treating the computed distances as unused.
            std::hint::black_box(&mut dst);
        }
        latencies.push(start.elapsed().into());
    }

    // NOTE(review): presumably this can only fail for an empty sample set, which
    // a non-zero `num_measurements` rules out — confirm against `compute_percentiles`.
    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
    RunResult {
        run: run.clone(),
        latencies,
        percentiles,
    }
}
709
/// Randomly generated inputs for one run: a single query and a data matrix.
struct Data<Q, D> {
    /// Query vector with `run.dim` elements.
    query: Box<[Q]>,
    /// Data matrix created with dimensions `(run.num_points, run.dim)`.
    data: Matrix<D>,
}

impl<Q, D> Data<Q, D> {
    /// Generates the query and data for `run`.
    ///
    /// The generator is seeded with a fixed value, so the same `run` always
    /// produces the same inputs. The query is drawn first, then the matrix,
    /// both from the same generator.
    fn new(run: &Run) -> Self
    where
        StandardUniform: Distribution<Q>,
        StandardUniform: Distribution<D>,
    {
        let mut rng = StdRng::seed_from_u64(0x12345);
        let query: Box<[Q]> = (0..run.dim.get())
            .map(|_| StandardUniform.sample(&mut rng))
            .collect();
        let data = Matrix::<D>::new(
            diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
            run.num_points.get(),
            run.dim.get(),
        );

        Self { query, data }
    }

    /// Times the distance function `f` on this data using `run_loops`.
    fn run<F>(&self, run: &Run, f: F) -> RunResult
    where
        F: Fn(&[Q], &[D]) -> f32,
    {
        run_loops(&self.query, self.data.as_view(), run, f)
    }
}
741
/// Generates a `RunBenchmark` implementation for one kernel instantiation.
///
/// Three forms are accepted:
/// * `stamp!(reference, Q, D, l2, ip, cosine)` — uses the named functions
///   from the `reference` module.
/// * `stamp!(arch, Q, D)` — uses `simd::simd_op` with the kernel's
///   architecture value.
/// * `stamp!("target", arch, Q, D)` — same as the previous form, but only
///   compiled for the given `target_arch`.
macro_rules! stamp {
    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
        impl RunBenchmark for Kernel<'_, Reference, $Q, $D> {
            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();
                for run in self.input.runs.iter() {
                    // Fresh inputs are generated for every run.
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
                    };
                    results.push(result);
                }
                Ok(results)
            }
        }
    };
    ($arch:path, $Q:ty, $D:ty) => {
        impl RunBenchmark for Kernel<'_, $arch, $Q, $D> {
            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();

                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                for run in self.input.runs.iter() {
                    // Fresh inputs are generated for every run.
                    let data = Data::<$Q, $D>::new(run);
                    // NOTE(review): `run2` presumably executes the closure with the
                    // architecture's target features enabled — confirm against
                    // `diskann_wide::Architecture`.
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d)
                        }),
                    };
                    results.push(result)
                }
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Q, $D);
    };
}
806
// `RunBenchmark` implementations; these mirror the registrations made in
// `register_benchmarks_impl`.

// x86-64-v4
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);

// x86-64-v3
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);

// aarch64 Neon
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);

// Scalar
stamp!(diskann_wide::arch::Scalar, f32, f32);
stamp!(diskann_wide::arch::Scalar, f16, f16);
stamp!(diskann_wide::arch::Scalar, u8, u8);
stamp!(diskann_wide::arch::Scalar, i8, i8);

// Reference loop implementations
stamp!(
    reference,
    f32,
    f32,
    squared_l2_f32,
    inner_product_f32,
    cosine_f32
);
stamp!(
    reference,
    f16,
    f16,
    squared_l2_f16,
    inner_product_f16,
    cosine_f16
);
stamp!(
    reference,
    u8,
    u8,
    squared_l2_u8,
    inner_product_u8,
    cosine_u8
);
stamp!(
    reference,
    i8,
    i8,
    squared_l2_i8,
    inner_product_i8,
    cosine_i8
);
859
860mod reference {
867 use diskann_wide::ARCH;
868 use half::f16;
869
    /// Computes `a * b + c`, fused or unfused depending on the architecture.
    ///
    /// NOTE(review): presumably the split exists because the SIMD
    /// architectures guarantee a hardware FMA instruction while a scalar
    /// build may not — confirm.
    trait MaybeFMA {
        /// Returns `a * b + c`.
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
    }

    /// Scalar: separate multiply and add.
    impl MaybeFMA for diskann_wide::arch::Scalar {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a * b + c
        }
    }

    /// x86-64-v3: `f32::mul_add`.
    #[cfg(target_arch = "x86_64")]
    impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a.mul_add(b, c)
        }
    }

    /// x86-64-v4: `f32::mul_add`.
    #[cfg(target_arch = "x86_64")]
    impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a.mul_add(b, c)
        }
    }

    /// aarch64 Neon: `f32::mul_add`.
    #[cfg(target_arch = "aarch64")]
    impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a.mul_add(b, c)
        }
    }
902
903 pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
908 assert_eq!(x.len(), y.len());
909 std::iter::zip(x.iter(), y.iter())
910 .map(|(&a, &b)| {
911 let a: i32 = a.into();
912 let b: i32 = b.into();
913 let diff = a - b;
914 diff * diff
915 })
916 .sum::<i32>() as f32
917 }
918
919 pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
920 assert_eq!(x.len(), y.len());
921 std::iter::zip(x.iter(), y.iter())
922 .map(|(&a, &b)| {
923 let a: i32 = a.into();
924 let b: i32 = b.into();
925 let diff = a - b;
926 diff * diff
927 })
928 .sum::<i32>() as f32
929 }
930
931 pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
932 assert_eq!(x.len(), y.len());
933 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
934 let a: f32 = a.into();
935 let b: f32 = b.into();
936 let diff = a - b;
937 ARCH.maybe_fma(diff, diff, acc)
938 })
939 }
940
941 pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
942 assert_eq!(x.len(), y.len());
943 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
944 let diff = a - b;
945 ARCH.maybe_fma(diff, diff, acc)
946 })
947 }
948
949 pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
954 assert_eq!(x.len(), y.len());
955 std::iter::zip(x.iter(), y.iter())
956 .map(|(&a, &b)| {
957 let a: i32 = a.into();
958 let b: i32 = b.into();
959 a * b
960 })
961 .sum::<i32>() as f32
962 }
963
964 pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
965 assert_eq!(x.len(), y.len());
966 std::iter::zip(x.iter(), y.iter())
967 .map(|(&a, &b)| {
968 let a: i32 = a.into();
969 let b: i32 = b.into();
970 a * b
971 })
972 .sum::<i32>() as f32
973 }
974
975 pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
976 assert_eq!(x.len(), y.len());
977 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
978 let a: f32 = a.into();
979 let b: f32 = b.into();
980 ARCH.maybe_fma(a, b, acc)
981 })
982 }
983
984 pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
985 assert_eq!(x.len(), y.len());
986 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
987 }
988
    /// Running sums used by the cosine kernels.
    #[derive(Default)]
    struct XY<T> {
        /// Sum of `x[i] * x[i]`.
        xnorm: T,
        /// Sum of `y[i] * y[i]`.
        ynorm: T,
        /// Sum of `x[i] * y[i]`.
        xy: T,
    }
999
1000 pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1001 assert_eq!(x.len(), y.len());
1002 let r: XY<i32> =
1003 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1004 let vx: i32 = vx.into();
1005 let vy: i32 = vy.into();
1006 XY {
1007 xnorm: acc.xnorm + vx * vx,
1008 ynorm: acc.ynorm + vy * vy,
1009 xy: acc.xy + vx * vy,
1010 }
1011 });
1012
1013 if r.xnorm == 0 || r.ynorm == 0 {
1014 return 0.0;
1015 }
1016
1017 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1018 }
1019
1020 pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1021 assert_eq!(x.len(), y.len());
1022 let r: XY<i32> =
1023 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1024 let vx: i32 = vx.into();
1025 let vy: i32 = vy.into();
1026 XY {
1027 xnorm: acc.xnorm + vx * vx,
1028 ynorm: acc.ynorm + vy * vy,
1029 xy: acc.xy + vx * vy,
1030 }
1031 });
1032
1033 if r.xnorm == 0 || r.ynorm == 0 {
1034 return 0.0;
1035 }
1036
1037 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1038 }
1039
1040 pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1041 assert_eq!(x.len(), y.len());
1042 let r: XY<f32> =
1043 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1044 let vx: f32 = vx.into();
1045 let vy: f32 = vy.into();
1046 XY {
1047 xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1048 ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1049 xy: ARCH.maybe_fma(vx, vy, acc.xy),
1050 }
1051 });
1052
1053 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1054 return 0.0;
1055 }
1056
1057 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1058 }
1059
1060 pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1061 assert_eq!(x.len(), y.len());
1062 let r: XY<f32> =
1063 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1064 xnorm: vx.mul_add(vx, acc.xnorm),
1065 ynorm: vy.mul_add(vy, acc.ynorm),
1066 xy: vx.mul_add(vy, acc.xy),
1067 });
1068
1069 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1070 return 0.0;
1071 }
1072
1073 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1074 }
1075}