1use std::{io::Write, num::NonZeroUsize};
7
8use diskann_utils::views::{Matrix, MatrixView};
9use diskann_vector::distance::simd;
10use diskann_wide::Architecture;
11use half::f16;
12use rand::{
13 distr::{Distribution, StandardUniform},
14 rngs::StdRng,
15 SeedableRng,
16};
17use serde::{Deserialize, Serialize};
18use thiserror::Error;
19
20use diskann_benchmark_runner::{
21 describeln,
22 dispatcher::{self, DispatchRule, FailureScore, MatchScore},
23 utils::{
24 datatype::{self, DataType},
25 percentiles, MicroSeconds,
26 },
27 Any, CheckDeserialization, Checker,
28};
29
/// Marker type under which the `simd-op` benchmark input is exposed to the
/// benchmark runner (see the `Input` implementation for this type).
#[derive(Debug)]
pub struct SimdInput;
36
/// Registers every SIMD distance benchmark defined in this file with `dispatcher`.
///
/// Public entry point; all of the work is done by `register_benchmarks_impl`.
pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    register_benchmarks_impl(dispatcher)
}
40
/// Local newtype around a borrowed value.
///
/// Wrapping the reference in a type owned by this file allows `Display` to be
/// implemented here for borrowed types such as `[RunResult]`.
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

impl<T: ?Sized> std::ops::Deref for DisplayWrapper<'_, T> {
    type Target = T;

    /// Hands back the wrapped reference so the wrapper can be used like a `&T`.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
54
/// The distance function exercised by a benchmark run.
///
/// Serialized with snake_case variant names (`squared_l2`, `inner_product`,
/// `cosine`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMeasure {
    /// Squared Euclidean distance.
    SquaredL2,
    /// Inner (dot) product.
    InnerProduct,
    /// Cosine similarity.
    Cosine,
}
66
67impl std::fmt::Display for SimilarityMeasure {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 let st = match self {
70 Self::SquaredL2 => "squared_l2",
71 Self::InnerProduct => "inner_product",
72 Self::Cosine => "cosine",
73 };
74 write!(f, "{}", st)
75 }
76}
77
/// The implementation backend requested for a benchmark.
///
/// Variants serialize in kebab-case; the x86 variants carry explicit renames so
/// that they appear as `x86-64-v4` and `x86-64-v3`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) enum Arch {
    /// Dispatches to `diskann_wide::arch::x86_64::V4`.
    #[serde(rename = "x86-64-v4")]
    #[allow(non_camel_case_types)]
    X86_64_V4,
    /// Dispatches to `diskann_wide::arch::x86_64::V3`.
    #[serde(rename = "x86-64-v3")]
    #[allow(non_camel_case_types)]
    X86_64_V3,
    /// Dispatches to `diskann_wide::arch::Scalar`.
    Scalar,
    /// Dispatches to the loop-based implementations in the `reference` module.
    Reference,
}
90
91impl std::fmt::Display for Arch {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 let st = match self {
94 Self::X86_64_V4 => "x86-64-v4",
95 Self::X86_64_V3 => "x86-64-v3",
96 Self::Scalar => "scalar",
97 Self::Reference => "reference",
98 };
99 write!(f, "{}", st)
100 }
101}
102
/// Parameters of a single benchmark run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Run {
    /// Distance function to measure.
    pub(crate) distance: SimilarityMeasure,
    /// Number of elements in the query and in every data vector.
    pub(crate) dim: NonZeroUsize,
    /// Number of data vectors the query is compared against.
    pub(crate) num_points: NonZeroUsize,
    /// Number of full passes over the data covered by one latency sample.
    pub(crate) loops_per_measurement: NonZeroUsize,
    /// Number of latency samples to record.
    pub(crate) num_measurements: NonZeroUsize,
}
111
/// Deserialized description of a `simd-op` benchmark job.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct SimdOp {
    /// Element type of the query vector.
    pub(crate) query_type: DataType,
    /// Element type of the data vectors.
    pub(crate) data_type: DataType,
    /// Implementation backend to benchmark.
    pub(crate) arch: Arch,
    /// The runs to execute, in order.
    pub(crate) runs: Vec<Run>,
}
119
impl CheckDeserialization for SimdOp {
    /// No post-deserialization checks are performed; this always returns `Ok(())`.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
125
/// Writes one `name: value` line to the formatter `$f`.
///
/// The field name is right-aligned in an 18 character wide column so that
/// consecutive fields line up. Expands to a `writeln!` call and therefore
/// evaluates to a `std::fmt::Result`.
macro_rules! write_field {
    ($f:ident, $field:tt, $($expr:tt)*) => {
        writeln!($f, "{:>18}: {}", $field, $($expr)*)
    }
}
131
132impl SimdOp {
133 pub(crate) const fn tag() -> &'static str {
134 "simd-op"
135 }
136
137 fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 write_field!(f, "query type", self.query_type)?;
139 write_field!(f, "data type", self.data_type)?;
140 write_field!(f, "arch", self.arch)?;
141 write_field!(f, "number of runs", self.runs.len())?;
142 Ok(())
143 }
144}
145
146impl std::fmt::Display for SimdOp {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 writeln!(f, "SIMD Operation\n")?;
149 write_field!(f, "tag", Self::tag())?;
150 self.summarize_fields(f)
151 }
152}
153
154impl diskann_benchmark_runner::Input for SimdInput {
155 fn tag(&self) -> &'static str {
156 "simd-op"
157 }
158
159 fn try_deserialize(
160 &self,
161 serialized: &serde_json::Value,
162 checker: &mut Checker,
163 ) -> anyhow::Result<Any> {
164 checker.any(SimdOp::deserialize(serialized)?)
165 }
166
167 fn example(&self) -> anyhow::Result<serde_json::Value> {
168 const DIM: [NonZeroUsize; 2] = [
169 NonZeroUsize::new(128).unwrap(),
170 NonZeroUsize::new(150).unwrap(),
171 ];
172
173 const NUM_POINTS: [NonZeroUsize; 2] = [
174 NonZeroUsize::new(1000).unwrap(),
175 NonZeroUsize::new(800).unwrap(),
176 ];
177
178 const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
179 const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
180
181 let runs = vec![
182 Run {
183 distance: SimilarityMeasure::SquaredL2,
184 dim: DIM[0],
185 num_points: NUM_POINTS[0],
186 loops_per_measurement: LOOPS_PER_MEASUREMENT,
187 num_measurements: NUM_MEASUREMENTS,
188 },
189 Run {
190 distance: SimilarityMeasure::InnerProduct,
191 dim: DIM[1],
192 num_points: NUM_POINTS[1],
193 loops_per_measurement: LOOPS_PER_MEASUREMENT,
194 num_measurements: NUM_MEASUREMENTS,
195 },
196 ];
197
198 Ok(serde_json::to_value(&SimdOp {
199 query_type: DataType::Float32,
200 data_type: DataType::Float32,
201 arch: Arch::X86_64_V3,
202 runs,
203 })?)
204 }
205}
206
/// Registers `run_benchmark`, instantiated for the given kernel type, under `$name`.
///
/// Two forms are accepted:
///
/// * `register!("target_arch", dispatcher, "name", KernelType)`: the registration is
///   only compiled when `target_arch` matches the compilation target.
/// * `register!(dispatcher, "name", KernelType)`: the registration is always compiled.
macro_rules! register {
    ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        #[cfg(target_arch = $arch)]
        $dispatcher.register::<$($kernel)*>(
            $name,
            run_benchmark,
        )
    };
    ($dispatcher:ident, $name:literal, $($kernel:tt)*) => {
        $dispatcher.register::<$($kernel)*>(
            $name,
            run_benchmark,
        )
    };
}
226
/// Registers one benchmark per (backend, element type) combination.
///
/// Backends are x86-64-v4 and x86-64-v3 (compiled only for `x86_64` targets),
/// scalar, and reference. Element types are `f32`, `f16`, `u8` and `i8`, always
/// with identical query and data types.
fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
    // x86-64-v4
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V4",
        Kernel<'static, diskann_wide::arch::x86_64::V4, i8, i8>
    );

    // x86-64-v3
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f32xf32-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, f32, f32>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-f16xf16-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, f16, f16>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-u8xu8-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, u8, u8>
    );
    register!(
        "x86_64",
        dispatcher,
        "simd-op-i8xi8-x86_64_V3",
        Kernel<'static, diskann_wide::arch::x86_64::V3, i8, i8>
    );

    // scalar
    register!(
        dispatcher,
        "simd-op-f32xf32-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-scalar",
        Kernel<'static, diskann_wide::arch::Scalar, i8, i8>
    );

    // reference (loop based)
    register!(
        dispatcher,
        "simd-op-f32xf32-reference",
        Kernel<'static, Reference, f32, f32>
    );
    register!(
        dispatcher,
        "simd-op-f16xf16-reference",
        Kernel<'static, Reference, f16, f16>
    );
    register!(
        dispatcher,
        "simd-op-u8xu8-reference",
        Kernel<'static, Reference, u8, u8>
    );
    register!(
        dispatcher,
        "simd-op-i8xi8-reference",
        Kernel<'static, Reference, i8, i8>
    );
}
324
/// Marker type selecting the loop-based implementations in the `reference`
/// module instead of a `diskann_wide` architecture.
struct Reference;
331
/// Newtype used as the target of the `DispatchRule<Arch>` implementations in
/// this file; it carries the architecture value produced by a successful
/// conversion.
#[derive(Debug)]
struct Identity<T>(T);
335
/// `Identity<T>` maps to `T` itself for every lifetime.
impl<T> dispatcher::Map for Identity<T>
where
    T: 'static,
{
    type Type<'a> = T;
}
342
/// A benchmark bound to a backend `A`, a query element type `Q` and a data
/// element type `D`.
struct Kernel<'a, A, Q, D> {
    /// The job description this kernel executes.
    input: &'a SimdOp,
    /// The architecture value passed to the SIMD routines.
    arch: A,
    /// Ties the otherwise unused `Q` and `D` parameters to the type.
    _type: std::marker::PhantomData<(A, Q, D)>,
}
348
impl<'a, A, Q, D> Kernel<'a, A, Q, D> {
    /// Creates a kernel that runs `input` using `arch`.
    fn new(input: &'a SimdOp, arch: A) -> Self {
        Self {
            input,
            arch,
            _type: std::marker::PhantomData,
        }
    }
}
358
/// `Kernel<'static, ..>` is the registration key; for a lifetime `'a` it maps to
/// the same kernel borrowing its input for `'a`.
impl<A, Q, D> dispatcher::Map for Kernel<'static, A, Q, D>
where
    A: 'static,
    Q: 'static,
    D: 'static,
{
    type Type<'a> = Kernel<'a, A, Q, D>;
}
367
/// Error returned when the requested architecture matched a kernel but the
/// runtime check for that architecture (`new_checked`) failed.
#[derive(Debug, Error)]
#[error("architecture {0} is not supported by this CPU")]
pub(crate) struct ArchNotSupported(Arch);
372
373impl DispatchRule<Arch> for Identity<Reference> {
374 type Error = ArchNotSupported;
375
376 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
377 if *from == Arch::Reference {
378 Ok(MatchScore(0))
379 } else {
380 Err(FailureScore(0))
381 }
382 }
383
384 fn convert(from: Arch) -> Result<Self, Self::Error> {
385 assert_eq!(from, Arch::Reference);
386 Ok(Identity(Reference))
387 }
388
389 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
390 match from {
391 None => write!(f, "loop based"),
392 Some(arch) => {
393 if Self::try_match(arch).is_ok() {
394 write!(f, "matched {}", arch)
395 } else {
396 write!(f, "expected {}, got {}", Arch::Reference, arch)
397 }
398 }
399 }
400 }
401}
402
403impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
404 type Error = ArchNotSupported;
405
406 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
407 if *from == Arch::Scalar {
408 Ok(MatchScore(0))
409 } else {
410 Err(FailureScore(0))
411 }
412 }
413
414 fn convert(from: Arch) -> Result<Self, Self::Error> {
415 assert_eq!(from, Arch::Scalar);
416 Ok(Identity(diskann_wide::arch::Scalar))
417 }
418
419 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
420 match from {
421 None => write!(f, "scalar (compilation target CPU)"),
422 Some(arch) => {
423 if Self::try_match(arch).is_ok() {
424 write!(f, "matched {}", arch)
425 } else {
426 write!(f, "expected {}, got {}", Arch::Scalar, arch)
427 }
428 }
429 }
430 }
431}
432
433#[cfg(target_arch = "x86_64")]
434impl DispatchRule<Arch> for Identity<diskann_wide::arch::x86_64::V4> {
435 type Error = ArchNotSupported;
436
437 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
438 if *from == Arch::X86_64_V4 {
439 Ok(MatchScore(0))
440 } else {
441 Err(FailureScore(0))
442 }
443 }
444
445 fn convert(from: Arch) -> Result<Self, Self::Error> {
446 assert_eq!(from, Arch::X86_64_V4);
447 diskann_wide::arch::x86_64::V4::new_checked()
448 .ok_or(ArchNotSupported(from))
449 .map(Identity)
450 }
451
452 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
453 match from {
454 None => write!(f, "x86-64-v4"),
455 Some(arch) => {
456 if Self::try_match(arch).is_ok() {
457 write!(f, "matched {}", arch)
458 } else {
459 write!(f, "expected {}, got {}", Arch::X86_64_V4, arch)
460 }
461 }
462 }
463 }
464}
465
466#[cfg(target_arch = "x86_64")]
467impl DispatchRule<Arch> for Identity<diskann_wide::arch::x86_64::V3> {
468 type Error = ArchNotSupported;
469
470 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
471 if *from == Arch::X86_64_V3 {
472 Ok(MatchScore(0))
473 } else {
474 Err(FailureScore(0))
475 }
476 }
477
478 fn convert(from: Arch) -> Result<Self, Self::Error> {
479 assert_eq!(from, Arch::X86_64_V3);
480 diskann_wide::arch::x86_64::V3::new_checked()
481 .ok_or(ArchNotSupported(from))
482 .map(Identity)
483 }
484
485 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
486 match from {
487 None => write!(f, "x86-64-v3"),
488 Some(arch) => {
489 if Self::try_match(arch).is_ok() {
490 write!(f, "matched {}", arch)
491 } else {
492 write!(f, "expected {}, got {}", Arch::X86_64_V3, arch)
493 }
494 }
495 }
496 }
497}
498
499impl<'a, A, Q, D> DispatchRule<&'a SimdOp> for Kernel<'a, A, Q, D>
500where
501 datatype::Type<Q>: DispatchRule<datatype::DataType>,
502 datatype::Type<D>: DispatchRule<datatype::DataType>,
503 Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
504{
505 type Error = ArchNotSupported;
506
507 fn try_match(from: &&'a SimdOp) -> Result<MatchScore, FailureScore> {
509 let mut failscore: Option<u32> = None;
510 if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
511 *failscore.get_or_insert(0) += 10;
512 }
513 if datatype::Type::<D>::try_match(&from.data_type).is_err() {
514 *failscore.get_or_insert(0) += 10;
515 }
516 if Identity::<A>::try_match(&from.arch).is_err() {
517 *failscore.get_or_insert(0) += 2;
518 }
519 match failscore {
520 None => Ok(MatchScore(0)),
521 Some(score) => Err(FailureScore(score)),
522 }
523 }
524
525 fn convert(from: &'a SimdOp) -> Result<Self, Self::Error> {
526 assert!(Self::try_match(&from).is_ok());
527 let arch = Identity::<A>::convert(from.arch)?.0;
528 Ok(Self::new(from, arch))
529 }
530
531 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a SimdOp>) -> std::fmt::Result {
532 match from {
533 None => {
534 describeln!(
535 f,
536 "- Query Type: {}",
537 dispatcher::Description::<datatype::DataType, datatype::Type<Q>>::new()
538 )?;
539 describeln!(
540 f,
541 "- Data Type: {}",
542 dispatcher::Description::<datatype::DataType, datatype::Type<D>>::new()
543 )?;
544 describeln!(
545 f,
546 "- Implementation: {}",
547 dispatcher::Description::<Arch, Identity<A>>::new()
548 )?;
549 }
550 Some(input) => {
551 if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
552 describeln!(f, "- Mismatched query type: {}", err)?;
553 }
554 if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
555 describeln!(f, "- Mismatched data type: {}", err)?;
556 }
557 if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
558 describeln!(f, "- Mismatched architecture: {}", err)?;
559 }
560 }
561 }
562 Ok(())
563 }
564}
565
/// Dispatch from a type-erased `Any` by forwarding to the `&SimdOp` rule.
impl<'a, A, Q, D> DispatchRule<&'a diskann_benchmark_runner::Any> for Kernel<'a, A, Q, D>
where
    Kernel<'a, A, Q, D>: DispatchRule<&'a SimdOp>,
    <Kernel<'a, A, Q, D> as DispatchRule<&'a SimdOp>>::Error:
        std::error::Error + Send + Sync + 'static,
{
    type Error = anyhow::Error;

    /// Forwards to `Any::try_match` with `SimdOp` as the expected payload type.
    fn try_match(from: &&'a diskann_benchmark_runner::Any) -> Result<MatchScore, FailureScore> {
        from.try_match::<SimdOp, Self>()
    }

    /// Forwards to `Any::convert` with `SimdOp` as the expected payload type.
    fn convert(from: &'a diskann_benchmark_runner::Any) -> Result<Self, Self::Error> {
        from.convert::<SimdOp, Self>()
    }

    /// Forwards to `Any::description`, passing the `SimdOp` tag.
    fn description(
        f: &mut std::fmt::Formatter<'_>,
        from: Option<&&'a diskann_benchmark_runner::Any>,
    ) -> std::fmt::Result {
        Any::description::<SimdOp, Self>(f, from, SimdOp::tag())
    }
}
589
590fn run_benchmark<A, Q, D>(
595 kernel: Kernel<'_, A, Q, D>,
596 _: diskann_benchmark_runner::Checkpoint<'_>,
597 mut output: &mut dyn diskann_benchmark_runner::Output,
598) -> Result<serde_json::Value, anyhow::Error>
599where
600 for<'a> Kernel<'a, A, Q, D>: RunBenchmark,
601{
602 writeln!(output, "{}", kernel.input)?;
603 let results = kernel.run()?;
604 writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
605 Ok(serde_json::to_value(results)?)
606}
607
/// Implemented by every concrete `Kernel` instantiation that can be benchmarked.
trait RunBenchmark {
    /// Consumes the kernel and executes its runs, returning one result per run.
    fn run(self) -> Result<Vec<RunResult>, anyhow::Error>;
}
611
/// The measurements collected for a single `Run`.
#[derive(Debug, Serialize)]
struct RunResult {
    /// The configuration that produced these measurements.
    run: Run,
    /// One latency per measurement; each covers `loops_per_measurement` passes
    /// over all data points.
    latencies: Vec<MicroSeconds>,
    /// Summary statistics computed from `latencies`.
    percentiles: percentiles::Percentiles<MicroSeconds>,
}
621
622impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
623 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624 if self.is_empty() {
625 return Ok(());
626 }
627
628 let header = [
629 "Distance",
630 "Dim",
631 "Min Time (ns)",
632 "Mean Time (ns)",
633 "Points",
634 "Loops",
635 "Measurements",
636 ];
637
638 let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
639
640 self.iter().enumerate().for_each(|(row, r)| {
641 let mut row = table.row(row);
642
643 let min_latency = r
644 .latencies
645 .iter()
646 .min()
647 .copied()
648 .unwrap_or(MicroSeconds::new(u64::MAX));
649 let mean_latency = r.percentiles.mean;
650
651 let computations_per_latency: f64 =
652 (r.run.num_points.get() * r.run.loops_per_measurement.get()) as f64;
653
654 let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
656 let mean_time = mean_latency / computations_per_latency * 1000.0;
657
658 row.insert(r.run.distance, 0);
659 row.insert(r.run.dim, 1);
660 row.insert(format!("{:.3}", min_time), 2);
661 row.insert(format!("{:.3}", mean_time), 3);
662 row.insert(r.run.num_points, 4);
663 row.insert(r.run.loops_per_measurement, 5);
664 row.insert(r.run.num_measurements, 6);
665 });
666
667 table.fmt(f)
668 }
669}
670
671fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
672where
673 F: Fn(&[Q], &[D]) -> f32,
674{
675 let mut latencies = Vec::with_capacity(run.num_measurements.get());
676 let mut dst = vec![0.0; data.nrows()];
677
678 for _ in 0..run.num_measurements.get() {
679 let start = std::time::Instant::now();
680 for _ in 0..run.loops_per_measurement.get() {
681 std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
682 *d = f(query, r);
683 });
684 std::hint::black_box(&mut dst);
685 }
686 latencies.push(start.elapsed().into());
687 }
688
689 let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
690 RunResult {
691 run: run.clone(),
692 latencies,
693 percentiles,
694 }
695}
696
/// Randomly generated inputs for one run.
struct Data<Q, D> {
    /// The query vector (`dim` elements).
    query: Box<[Q]>,
    /// The data vectors (`num_points` rows of `dim` elements).
    data: Matrix<D>,
}
701
702impl<Q, D> Data<Q, D> {
703 fn new(run: &Run) -> Self
704 where
705 StandardUniform: Distribution<Q>,
706 StandardUniform: Distribution<D>,
707 {
708 let mut rng = StdRng::seed_from_u64(0x12345);
709 let query: Box<[Q]> = (0..run.dim.get())
710 .map(|_| StandardUniform.sample(&mut rng))
711 .collect();
712 let data = Matrix::<D>::new(
713 diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
714 run.num_points.get(),
715 run.dim.get(),
716 );
717
718 Self { query, data }
719 }
720
721 fn run<F>(&self, run: &Run, f: F) -> RunResult
722 where
723 F: Fn(&[Q], &[D]) -> f32,
724 {
725 run_loops(&self.query, self.data.as_view(), run, f)
726 }
727}
728
/// Generates `RunBenchmark` implementations for `Kernel`.
///
/// Three forms are accepted:
///
/// * `stamp!(reference, Q, D, l2, ip, cosine)`: implementation for the `Reference`
///   backend, calling the named functions of the `reference` module.
/// * `stamp!(Arch, Q, D)`: implementation for a `diskann_wide` architecture type,
///   calling `simd::simd_op` through the architecture's `run2`.
/// * `stamp!("target_arch", Arch, Q, D)`: same as the previous form, but only
///   compiled when `target_arch` matches the compilation target.
macro_rules! stamp {
    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
        impl RunBenchmark for Kernel<'_, Reference, $Q, $D> {
            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();
                for run in self.input.runs.iter() {
                    // Fresh inputs are generated for every run.
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
                    };
                    results.push(result);
                }
                Ok(results)
            }
        }
    };
    ($arch:path, $Q:ty, $D:ty) => {
        impl RunBenchmark for Kernel<'_, $arch, $Q, $D> {
            fn run(self) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();

                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                for run in self.input.runs.iter() {
                    // Fresh inputs are generated for every run.
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        // Each distance call is issued through `self.arch.run2`.
                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d)
                        }),
                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
                            self.arch
                                .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d)
                        }),
                    };
                    results.push(result)
                }
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Q, $D);
    };
}
793
// x86-64-v4 kernels (x86_64 targets only).
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);

// x86-64-v3 kernels (x86_64 targets only).
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);

// Scalar kernels (all targets).
stamp!(diskann_wide::arch::Scalar, f32, f32);
stamp!(diskann_wide::arch::Scalar, f16, f16);
stamp!(diskann_wide::arch::Scalar, u8, u8);
stamp!(diskann_wide::arch::Scalar, i8, i8);

// Reference kernels, backed by the functions in the `reference` module.
stamp!(
    reference,
    f32,
    f32,
    squared_l2_f32,
    inner_product_f32,
    cosine_f32
);
stamp!(
    reference,
    f16,
    f16,
    squared_l2_f16,
    inner_product_f16,
    cosine_f16
);
stamp!(
    reference,
    u8,
    u8,
    squared_l2_u8,
    inner_product_u8,
    cosine_u8
);
stamp!(
    reference,
    i8,
    i8,
    squared_l2_i8,
    inner_product_i8,
    cosine_i8
);
841
842mod reference {
849 use diskann_wide::ARCH;
850 use half::f16;
851
852 trait MaybeFMA {
853 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
856 }
857
858 impl MaybeFMA for diskann_wide::arch::Scalar {
859 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
860 a * b + c
861 }
862 }
863
864 #[cfg(target_arch = "x86_64")]
865 impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
866 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
867 a.mul_add(b, c)
868 }
869 }
870
871 #[cfg(target_arch = "x86_64")]
872 impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
873 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
874 a.mul_add(b, c)
875 }
876 }
877
878 pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
883 assert_eq!(x.len(), y.len());
884 std::iter::zip(x.iter(), y.iter())
885 .map(|(&a, &b)| {
886 let a: i32 = a.into();
887 let b: i32 = b.into();
888 let diff = a - b;
889 diff * diff
890 })
891 .sum::<i32>() as f32
892 }
893
894 pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
895 assert_eq!(x.len(), y.len());
896 std::iter::zip(x.iter(), y.iter())
897 .map(|(&a, &b)| {
898 let a: i32 = a.into();
899 let b: i32 = b.into();
900 let diff = a - b;
901 diff * diff
902 })
903 .sum::<i32>() as f32
904 }
905
906 pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
907 assert_eq!(x.len(), y.len());
908 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
909 let a: f32 = a.into();
910 let b: f32 = b.into();
911 let diff = a - b;
912 ARCH.maybe_fma(diff, diff, acc)
913 })
914 }
915
916 pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
917 assert_eq!(x.len(), y.len());
918 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
919 let diff = a - b;
920 ARCH.maybe_fma(diff, diff, acc)
921 })
922 }
923
924 pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
929 assert_eq!(x.len(), y.len());
930 std::iter::zip(x.iter(), y.iter())
931 .map(|(&a, &b)| {
932 let a: i32 = a.into();
933 let b: i32 = b.into();
934 a * b
935 })
936 .sum::<i32>() as f32
937 }
938
939 pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
940 assert_eq!(x.len(), y.len());
941 std::iter::zip(x.iter(), y.iter())
942 .map(|(&a, &b)| {
943 let a: i32 = a.into();
944 let b: i32 = b.into();
945 a * b
946 })
947 .sum::<i32>() as f32
948 }
949
950 pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
951 assert_eq!(x.len(), y.len());
952 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
953 let a: f32 = a.into();
954 let b: f32 = b.into();
955 ARCH.maybe_fma(a, b, acc)
956 })
957 }
958
959 pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
960 assert_eq!(x.len(), y.len());
961 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
962 }
963
964 #[derive(Default)]
969 struct XY<T> {
970 xnorm: T,
971 ynorm: T,
972 xy: T,
973 }
974
975 pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
976 assert_eq!(x.len(), y.len());
977 let r: XY<i32> =
978 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
979 let vx: i32 = vx.into();
980 let vy: i32 = vy.into();
981 XY {
982 xnorm: acc.xnorm + vx * vx,
983 ynorm: acc.ynorm + vy * vy,
984 xy: acc.xy + vx * vy,
985 }
986 });
987
988 if r.xnorm == 0 || r.ynorm == 0 {
989 return 0.0;
990 }
991
992 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
993 }
994
995 pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
996 assert_eq!(x.len(), y.len());
997 let r: XY<i32> =
998 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
999 let vx: i32 = vx.into();
1000 let vy: i32 = vy.into();
1001 XY {
1002 xnorm: acc.xnorm + vx * vx,
1003 ynorm: acc.ynorm + vy * vy,
1004 xy: acc.xy + vx * vy,
1005 }
1006 });
1007
1008 if r.xnorm == 0 || r.ynorm == 0 {
1009 return 0.0;
1010 }
1011
1012 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1013 }
1014
1015 pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1016 assert_eq!(x.len(), y.len());
1017 let r: XY<f32> =
1018 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1019 let vx: f32 = vx.into();
1020 let vy: f32 = vy.into();
1021 XY {
1022 xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1023 ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1024 xy: ARCH.maybe_fma(vx, vy, acc.xy),
1025 }
1026 });
1027
1028 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1029 return 0.0;
1030 }
1031
1032 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1033 }
1034
1035 pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1036 assert_eq!(x.len(), y.len());
1037 let r: XY<f32> =
1038 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1039 xnorm: vx.mul_add(vx, acc.xnorm),
1040 ynorm: vy.mul_add(vy, acc.ynorm),
1041 xy: vx.mul_add(vy, acc.xy),
1042 });
1043
1044 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1045 return 0.0;
1046 }
1047
1048 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1049 }
1050}