1use std::{io::Write, num::NonZeroUsize};
9
10use diskann_utils::views::{Matrix, MatrixView};
11use diskann_vector::distance::simd;
12use diskann_wide::Architecture;
13use half::f16;
14use rand::{
15 distr::{Distribution, StandardUniform},
16 rngs::StdRng,
17 SeedableRng,
18};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use diskann_benchmark_runner::{
23 benchmark::{FailureScore, MatchScore, PassFail, Regression},
24 utils::{
25 datatype::{AsDataType, DataType},
26 num::{relative_change, NonNegativeFinite},
27 percentiles, MicroSeconds,
28 },
29 Benchmark, Checker, Input, Registry,
30};
31
32pub fn register(registry: &mut Registry) -> anyhow::Result<()> {
37 Ok(register_benchmarks_impl(registry)?)
38}
39
/// Borrowed newtype used to attach local `Display` implementations to foreign types
/// (such as `[RunResult]`). Dereferences to the wrapped value.
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

impl<T: ?Sized> std::ops::Deref for DisplayWrapper<'_, T> {
    type Target = T;

    /// Expose the wrapped reference so the wrapper can be used like `&T`.
    fn deref(&self) -> &Self::Target {
        let DisplayWrapper(inner) = *self;
        inner
    }
}
53
/// Distance function evaluated by a benchmark run.
///
/// Serialized in `snake_case`: `"squared_l2"`, `"inner_product"`, `"cosine"`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMeasure {
    /// Squared Euclidean distance.
    SquaredL2,
    /// Inner (dot) product.
    InnerProduct,
    /// Cosine similarity.
    Cosine,
}
65
66impl std::fmt::Display for SimilarityMeasure {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 let st = match self {
69 Self::SquaredL2 => "squared_l2",
70 Self::InnerProduct => "inner_product",
71 Self::Cosine => "cosine",
72 };
73 write!(f, "{}", st)
74 }
75}
76
/// Instruction-set target requested by a benchmark input.
///
/// Serialized in kebab-case (`"neon"`, `"scalar"`, `"reference"`), with the x86-64 levels
/// spelled explicitly as `"x86-64-v4"` and `"x86-64-v3"`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Arch {
    /// x86-64 level V4 (`diskann_wide::arch::x86_64::V4`).
    #[serde(rename = "x86-64-v4")]
    // The variant name mirrors the serialized spelling rather than CamelCase.
    #[allow(non_camel_case_types)]
    X86_64_V4,
    /// x86-64 level V3 (`diskann_wide::arch::x86_64::V3`).
    #[serde(rename = "x86-64-v3")]
    #[allow(non_camel_case_types)]
    X86_64_V3,
    /// aarch64 Neon (`diskann_wide::arch::aarch64::Neon`).
    Neon,
    /// Portable scalar implementation (`diskann_wide::arch::Scalar`).
    Scalar,
    /// Loop-based kernels from the `reference` module.
    Reference,
}
90
91impl std::fmt::Display for Arch {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 let st = match self {
94 Self::X86_64_V4 => "x86-64-v4",
95 Self::X86_64_V3 => "x86-64-v3",
96 Self::Neon => "neon",
97 Self::Scalar => "scalar",
98 Self::Reference => "reference",
99 };
100 write!(f, "{}", st)
101 }
102}
103
/// One benchmark configuration: a distance function evaluated between one query and
/// `num_points` data vectors of dimension `dim`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct Run {
    /// Distance function to benchmark.
    distance: SimilarityMeasure,
    /// Vector dimension.
    dim: NonZeroUsize,
    /// Number of data vectors compared against the query in each pass.
    num_points: NonZeroUsize,
    /// Full passes over the data timed together as a single measurement.
    loops_per_measurement: NonZeroUsize,
    /// Number of timed measurements (latency samples) collected.
    num_measurements: NonZeroUsize,
}
112
/// Benchmark input: element types, target architecture, and the list of runs to execute.
#[derive(Debug, Serialize, Deserialize)]
pub struct SimdOp {
    /// Element type of the query vector.
    query_type: DataType,
    /// Element type of the data vectors.
    data_type: DataType,
    /// Architecture whose kernels should be benchmarked.
    arch: Arch,
    /// Runs executed in order; one result is produced per run.
    runs: Vec<Run>,
}
120
/// Write one `label: value` line to `$out`, with the label right-aligned in 18 columns.
///
/// Expands to the result of the underlying `writeln!`.
macro_rules! write_field {
    ($out:ident, $label:tt, $($value:tt)*) => {
        writeln!($out, "{:>18}: {}", $label, $($value)*)
    };
}
126
127impl SimdOp {
128 fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 write_field!(f, "query type", self.query_type)?;
130 write_field!(f, "data type", self.data_type)?;
131 write_field!(f, "arch", self.arch)?;
132 write_field!(f, "number of runs", self.runs.len())?;
133 Ok(())
134 }
135}
136
137impl std::fmt::Display for SimdOp {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 writeln!(f, "SIMD Operation\n")?;
140 write_field!(f, "tag", Self::tag())?;
141 self.summarize_fields(f)
142 }
143}
144
145impl Input for SimdOp {
146 type Raw = Self;
147
148 fn tag() -> &'static str {
149 "simd-op"
150 }
151
152 fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<Self> {
153 Ok(raw)
154 }
155
156 fn serialize(&self) -> anyhow::Result<serde_json::Value> {
157 Ok(serde_json::to_value(self)?)
158 }
159
160 fn example() -> Self::Raw {
161 const DIM: [NonZeroUsize; 2] = [
162 NonZeroUsize::new(128).unwrap(),
163 NonZeroUsize::new(150).unwrap(),
164 ];
165
166 const NUM_POINTS: [NonZeroUsize; 2] = [
167 NonZeroUsize::new(1000).unwrap(),
168 NonZeroUsize::new(800).unwrap(),
169 ];
170
171 const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
172 const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
173
174 let runs = vec![
175 Run {
176 distance: SimilarityMeasure::SquaredL2,
177 dim: DIM[0],
178 num_points: NUM_POINTS[0],
179 loops_per_measurement: LOOPS_PER_MEASUREMENT,
180 num_measurements: NUM_MEASUREMENTS,
181 },
182 Run {
183 distance: SimilarityMeasure::InnerProduct,
184 dim: DIM[1],
185 num_points: NUM_POINTS[1],
186 loops_per_measurement: LOOPS_PER_MEASUREMENT,
187 num_measurements: NUM_MEASUREMENTS,
188 },
189 ];
190
191 Self {
192 query_type: DataType::Float32,
193 data_type: DataType::Float32,
194 arch: Arch::X86_64_V3,
195 runs,
196 }
197 }
198}
199
/// Regression tolerances for [`SimdOp`] benchmarks.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
struct SimdTolerance {
    /// Largest allowed relative increase in minimum time per computation (as a fraction, e.g.
    /// `0.10` for 10%) before a run is reported as a regression.
    min_time_regression: NonNegativeFinite,
}
212
213impl Input for SimdTolerance {
214 type Raw = Self;
215
216 fn tag() -> &'static str {
217 "simd-tolerance"
218 }
219
220 fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<Self> {
221 Ok(raw)
222 }
223
224 fn serialize(&self) -> anyhow::Result<serde_json::Value> {
225 Ok(serde_json::to_value(self)?)
226 }
227
228 fn example() -> Self {
229 const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) {
230 Ok(v) => v,
231 Err(_) => panic!("use a non-negative finite please"),
232 };
233
234 SimdTolerance {
235 min_time_regression: EXAMPLE,
236 }
237 }
238}
239
/// Before/after comparison for a single run, produced by the regression check.
#[derive(Debug, Serialize)]
struct Comparison {
    /// The run being compared (identical in both result sets).
    run: Run,
    /// Tolerance the comparison is judged against.
    tolerance: SimdTolerance,
    /// Baseline minimum time per distance computation (reported in the "(ns)" column).
    before_min: f64,
    /// Candidate minimum time per distance computation (reported in the "(ns)" column).
    after_min: f64,
}

/// Outcome of a regression check: one [`Comparison`] per run, in input order.
#[derive(Debug, Serialize)]
struct CheckResult {
    /// Per-run comparisons.
    checks: Vec<Comparison>,
}
254
255impl std::fmt::Display for CheckResult {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 let header = [
258 "Distance",
259 "Dim",
260 "Min Before (ns)",
261 "Min After (ns)",
262 "Change (%)",
263 "Remark",
264 ];
265
266 let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.checks.len());
267
268 for (i, c) in self.checks.iter().enumerate() {
269 let mut row = table.row(i);
270 let change = relative_change(c.before_min, c.after_min);
271
272 row.insert(c.run.distance, 0);
273 row.insert(c.run.dim, 1);
274 row.insert(format!("{:.3}", c.before_min), 2);
275 row.insert(format!("{:.3}", c.after_min), 3);
276 match change {
277 Ok(change) => {
278 row.insert(format!("{:.3} %", change * 100.0), 4);
279 if change > c.tolerance.min_time_regression.get() {
280 row.insert("FAIL", 5);
281 }
282 }
283 Err(err) => {
284 row.insert("invalid", 4);
285 row.insert(err, 5);
286 }
287 }
288 }
289
290 table.fmt(f)
291 }
292}
293
294fn register_benchmarks_impl(
299 registry: &mut diskann_benchmark_runner::Registry,
300) -> Result<(), diskann_benchmark_runner::RegistryError> {
301 #[cfg(target_arch = "x86_64")]
303 {
304 registry.register_regression(
305 "simd-op-f32xf32-x86_64_V4",
306 Kernel::<diskann_wide::arch::x86_64::V4, f32, f32>::new(),
307 )?;
308 registry.register_regression(
309 "simd-op-f16xf16-x86_64_V4",
310 Kernel::<diskann_wide::arch::x86_64::V4, f16, f16>::new(),
311 )?;
312 registry.register_regression(
313 "simd-op-u8xu8-x86_64_V4",
314 Kernel::<diskann_wide::arch::x86_64::V4, u8, u8>::new(),
315 )?;
316 registry.register_regression(
317 "simd-op-i8xi8-x86_64_V4",
318 Kernel::<diskann_wide::arch::x86_64::V4, i8, i8>::new(),
319 )?;
320 }
321
322 #[cfg(target_arch = "x86_64")]
324 {
325 registry.register_regression(
326 "simd-op-f32xf32-x86_64_V3",
327 Kernel::<diskann_wide::arch::x86_64::V3, f32, f32>::new(),
328 )?;
329 registry.register_regression(
330 "simd-op-f16xf16-x86_64_V3",
331 Kernel::<diskann_wide::arch::x86_64::V3, f16, f16>::new(),
332 )?;
333 registry.register_regression(
334 "simd-op-u8xu8-x86_64_V3",
335 Kernel::<diskann_wide::arch::x86_64::V3, u8, u8>::new(),
336 )?;
337 registry.register_regression(
338 "simd-op-i8xi8-x86_64_V3",
339 Kernel::<diskann_wide::arch::x86_64::V3, i8, i8>::new(),
340 )?;
341 }
342
343 #[cfg(target_arch = "aarch64")]
345 {
346 registry.register_regression(
347 "simd-op-f32xf32-aarch64_neon",
348 Kernel::<diskann_wide::arch::aarch64::Neon, f32, f32>::new(),
349 )?;
350 registry.register_regression(
351 "simd-op-f16xf16-aarch64_neon",
352 Kernel::<diskann_wide::arch::aarch64::Neon, f16, f16>::new(),
353 )?;
354 registry.register_regression(
355 "simd-op-u8xu8-aarch64_neon",
356 Kernel::<diskann_wide::arch::aarch64::Neon, u8, u8>::new(),
357 )?;
358 registry.register_regression(
359 "simd-op-i8xi8-aarch64_neon",
360 Kernel::<diskann_wide::arch::aarch64::Neon, i8, i8>::new(),
361 )?;
362 }
363
364 registry.register_regression(
366 "simd-op-f32xf32-scalar",
367 Kernel::<diskann_wide::arch::Scalar, f32, f32>::new(),
368 )?;
369 registry.register_regression(
370 "simd-op-f16xf16-scalar",
371 Kernel::<diskann_wide::arch::Scalar, f16, f16>::new(),
372 )?;
373 registry.register_regression(
374 "simd-op-u8xu8-scalar",
375 Kernel::<diskann_wide::arch::Scalar, u8, u8>::new(),
376 )?;
377 registry.register_regression(
378 "simd-op-i8xi8-scalar",
379 Kernel::<diskann_wide::arch::Scalar, i8, i8>::new(),
380 )?;
381
382 registry.register_regression(
384 "simd-op-f32xf32-reference",
385 Kernel::<Reference, f32, f32>::new(),
386 )?;
387 registry.register_regression(
388 "simd-op-f16xf16-reference",
389 Kernel::<Reference, f16, f16>::new(),
390 )?;
391 registry.register_regression(
392 "simd-op-u8xu8-reference",
393 Kernel::<Reference, u8, u8>::new(),
394 )?;
395 registry.register_regression(
396 "simd-op-i8xi8-reference",
397 Kernel::<Reference, i8, i8>::new(),
398 )?;
399 Ok(())
400}
401
/// Marker architecture selecting the plain loop-based kernels in the `reference` module.
struct Reference;
408
/// Benchmark kernel for architecture `A`, query element type `Q`, and data element type `D`.
///
/// Zero-sized: the type parameters alone select the implementation.
struct Kernel<A, Q, D> {
    _type: std::marker::PhantomData<(A, Q, D)>,
}

impl<A, Q, D> Kernel<A, Q, D> {
    /// Create the stateless kernel.
    fn new() -> Self {
        Kernel {
            _type: Default::default(),
        }
    }
}
420
/// Error returned when the requested architecture cannot be instantiated on the running CPU.
#[derive(Debug, Error)]
#[error("architecture {0} is not supported by this CPU")]
pub(crate) struct ArchNotSupported(Arch);
424
425trait AsArch: Sized + 'static {
427 const ARCH: Arch;
428 const DISPLAY_NAME: &'static str;
429
430 fn is_available() -> bool {
431 true
432 }
433
434 fn try_new() -> Result<Self, ArchNotSupported>;
435
436 fn is_match(arch: Arch) -> bool {
437 arch == Self::ARCH && Self::is_available()
438 }
439
440 fn describe(arch: Arch) -> ArchDescribe {
441 if arch != Self::ARCH {
442 ArchDescribe::Mismatch {
443 expected: Self::ARCH,
444 got: arch,
445 }
446 } else if !Self::is_available() {
447 ArchDescribe::Unsupported(Self::ARCH)
448 } else {
449 ArchDescribe::Match(arch)
450 }
451 }
452}
453
454#[derive(Debug, Clone, Copy)]
455enum ArchDescribe {
456 Match(Arch),
457 Unsupported(Arch),
458 Mismatch { expected: Arch, got: Arch },
459}
460
461impl ArchDescribe {
462 fn is_match(&self) -> bool {
463 matches!(self, ArchDescribe::Match(_))
464 }
465}
466
467impl std::fmt::Display for ArchDescribe {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 match self {
470 Self::Match(arch) => write!(f, "matched {}", arch),
471 Self::Unsupported(arch) => {
472 write!(f, "matched {} but unsupported by this CPU", arch)
473 }
474 Self::Mismatch { expected, got } => write!(f, "expected {}, got {}", expected, got),
475 }
476 }
477}
478
479impl AsArch for Reference {
480 const ARCH: Arch = Arch::Reference;
481 const DISPLAY_NAME: &'static str = "loop based";
482
483 fn try_new() -> Result<Self, ArchNotSupported> {
484 Ok(Reference)
485 }
486}
487
488impl AsArch for diskann_wide::arch::Scalar {
489 const ARCH: Arch = Arch::Scalar;
490 const DISPLAY_NAME: &'static str = "scalar (compilation target CPU)";
491
492 fn try_new() -> Result<Self, ArchNotSupported> {
493 Ok(diskann_wide::arch::Scalar)
494 }
495}
496
497macro_rules! match_arch {
498 ($target_arch:literal, $arch:path, $enum:ident) => {
499 #[cfg(target_arch = $target_arch)]
500 impl AsArch for $arch {
501 const ARCH: Arch = Arch::$enum;
502 const DISPLAY_NAME: &'static str = stringify!($enum);
503
504 fn is_available() -> bool {
505 <$arch>::new_checked().is_some()
506 }
507
508 fn try_new() -> Result<Self, ArchNotSupported> {
509 <$arch>::new_checked().ok_or(ArchNotSupported(Arch::$enum))
510 }
511 }
512 };
513}
514
515match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
516match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
517match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
518
519impl<A, Q, D> Benchmark for Kernel<A, Q, D>
520where
521 Q: AsDataType,
522 D: AsDataType,
523 A: AsArch,
524 Kernel<A, Q, D>: RunBenchmark<A>,
525{
526 type Input = SimdOp;
527 type Output = Vec<RunResult>;
528
529 fn try_match(&self, from: &SimdOp) -> Result<MatchScore, FailureScore> {
531 let mut failscore: Option<u32> = None;
532 if !Q::is_match(from.query_type) {
533 *failscore.get_or_insert(0) += 10;
534 }
535 if !D::is_match(from.data_type) {
536 *failscore.get_or_insert(0) += 10;
537 }
538 if !A::is_match(from.arch) {
539 let penalty = if from.arch == A::ARCH { 2 } else { 3 };
540 *failscore.get_or_insert(0) += penalty;
541 }
542
543 match failscore {
544 None => Ok(MatchScore(0)),
545 Some(score) => Err(FailureScore(score)),
546 }
547 }
548
549 fn run(
550 &self,
551 input: &SimdOp,
552 _: diskann_benchmark_runner::Checkpoint<'_>,
553 mut output: &mut dyn diskann_benchmark_runner::Output,
554 ) -> anyhow::Result<Self::Output> {
555 if input.arch != A::ARCH {
556 anyhow::bail!(
557 "architecture mismatch: input requested {:?}, but kernel implementation requires {:?}",
558 input.arch,
559 A::ARCH
560 );
561 }
562
563 let arch = A::try_new()?;
564 writeln!(output, "{}", input)?;
565 let results = self.run_benchmark(input, arch)?;
566 writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
567 Ok(results)
568 }
569
570 fn description(
571 &self,
572 f: &mut std::fmt::Formatter<'_>,
573 input: Option<&SimdOp>,
574 ) -> std::fmt::Result {
575 match input {
576 None => {
577 writeln!(f, "- Query Type: {}", Q::DATA_TYPE)?;
578 writeln!(f, "- Data Type: {}", D::DATA_TYPE)?;
579 writeln!(f, "- Implementation: {}", A::DISPLAY_NAME)?;
580 }
581 Some(input) => {
582 let desc = Q::describe(input.query_type);
583 if !desc.is_match() {
584 writeln!(f, "\n - Mismatched query type: {}", desc)?;
585 }
586
587 let desc = D::describe(input.data_type);
588 if !desc.is_match() {
589 writeln!(f, "\n - Mismatched data type: {}", desc)?;
590 }
591
592 let desc = A::describe(input.arch);
593 if !desc.is_match() {
594 writeln!(f, "\n - Mismatched architecture: {}", desc)?;
595 }
596 }
597 }
598 Ok(())
599 }
600}
601
602impl<A, Q, D> Regression for Kernel<A, Q, D>
603where
604 Q: AsDataType,
605 D: AsDataType,
606 A: AsArch,
607 Kernel<A, Q, D>: RunBenchmark<A>,
608{
609 type Tolerances = SimdTolerance;
610 type Pass = CheckResult;
611 type Fail = CheckResult;
612
613 fn check(
614 &self,
615 tolerance: &SimdTolerance,
616 _input: &SimdOp,
617 before: &Vec<RunResult>,
618 after: &Vec<RunResult>,
619 ) -> anyhow::Result<PassFail<CheckResult, CheckResult>> {
620 anyhow::ensure!(
621 before.len() == after.len(),
622 "before has {} runs but after has {}",
623 before.len(),
624 after.len(),
625 );
626
627 let mut passed = true;
628 let checks: Vec<Comparison> = std::iter::zip(before.iter(), after.iter())
629 .enumerate()
630 .map(|(i, (b, a))| {
631 anyhow::ensure!(b.run == a.run, "run {i} mismatched");
632
633 let computations_per_latency = b.computations_per_latency() as f64;
634
635 let before_min = b.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
636 let after_min = a.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
637
638 let comparison = Comparison {
639 run: b.run.clone(),
640 tolerance: *tolerance,
641 before_min,
642 after_min,
643 };
644
645 match relative_change(before_min, after_min) {
647 Ok(change) => {
648 if change > tolerance.min_time_regression.get() {
649 passed = false;
650 }
651 }
652 Err(_) => passed = false,
653 };
654
655 Ok(comparison)
656 })
657 .collect::<anyhow::Result<Vec<Comparison>>>()?;
658
659 let check = CheckResult { checks };
660
661 if passed {
662 Ok(PassFail::Pass(check))
663 } else {
664 Ok(PassFail::Fail(check))
665 }
666 }
667}
668
/// Executes the runs of a [`SimdOp`] using architecture token `A`.
trait RunBenchmark<A> {
    /// Execute `input.runs` and return one [`RunResult`] per run. The implementations generated
    /// by `stamp!` process runs in order.
    fn run_benchmark(&self, input: &SimdOp, arch: A) -> Result<Vec<RunResult>, anyhow::Error>;
}

/// Timing results for a single [`Run`].
#[derive(Debug, Serialize, Deserialize)]
struct RunResult {
    /// The configuration that was measured.
    run: Run,
    /// One elapsed-time sample per measurement. Each sample covers
    /// `num_points * loops_per_measurement` distance computations. Not necessarily in
    /// measurement order, since `compute_percentiles` receives this vector mutably.
    latencies: Vec<MicroSeconds>,
    /// Summary statistics computed from `latencies`.
    percentiles: percentiles::Percentiles<MicroSeconds>,
}

impl RunResult {
    /// Number of distance computations timed by each entry of `latencies`.
    fn computations_per_latency(&self) -> usize {
        self.run.num_points.get() * self.run.loops_per_measurement.get()
    }
}
692
693impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
694 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
695 if self.is_empty() {
696 return Ok(());
697 }
698
699 let header = [
700 "Distance",
701 "Dim",
702 "Min Time (ns)",
703 "Mean Time (ns)",
704 "Points",
705 "Loops",
706 "Measurements",
707 ];
708
709 let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
710
711 self.iter().enumerate().for_each(|(row, r)| {
712 let mut row = table.row(row);
713
714 let min_latency = r
715 .latencies
716 .iter()
717 .min()
718 .copied()
719 .unwrap_or(MicroSeconds::new(u64::MAX));
720 let mean_latency = r.percentiles.mean;
721
722 let computations_per_latency = r.computations_per_latency() as f64;
723
724 let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
726 let mean_time = mean_latency / computations_per_latency * 1000.0;
727
728 row.insert(r.run.distance, 0);
729 row.insert(r.run.dim, 1);
730 row.insert(format!("{:.3}", min_time), 2);
731 row.insert(format!("{:.3}", mean_time), 3);
732 row.insert(r.run.num_points, 4);
733 row.insert(r.run.loops_per_measurement, 5);
734 row.insert(r.run.num_measurements, 6);
735 });
736
737 table.fmt(f)
738 }
739}
740
/// Time `f(query, row)` against every row of `data`, as configured by `run`.
///
/// Each of the `run.num_measurements` measurements times `run.loops_per_measurement` full
/// passes over `data`, writing results into a scratch buffer. The elapsed time of one
/// measurement becomes one latency sample.
///
/// # Panics
///
/// Panics if `compute_percentiles` returns an error. `latencies` is non-empty because
/// `num_measurements` is non-zero. NOTE(review): confirm it has no other failure mode.
fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
where
    F: Fn(&[Q], &[D]) -> f32,
{
    let mut latencies = Vec::with_capacity(run.num_measurements.get());
    // Scratch output: one slot per data row, overwritten on every pass.
    let mut dst = vec![0.0; data.nrows()];

    for _ in 0..run.num_measurements.get() {
        let start = std::time::Instant::now();
        for _ in 0..run.loops_per_measurement.get() {
            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
                *d = f(query, r);
            });
            // Treat the results as observed so the optimizer cannot discard the computation.
            std::hint::black_box(&mut dst);
        }
        latencies.push(start.elapsed().into());
    }

    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
    RunResult {
        run: run.clone(),
        latencies,
        percentiles,
    }
}
766
767struct Data<Q, D> {
768 query: Box<[Q]>,
769 data: Matrix<D>,
770}
771
772impl<Q, D> Data<Q, D> {
773 fn new(run: &Run) -> Self
774 where
775 StandardUniform: Distribution<Q>,
776 StandardUniform: Distribution<D>,
777 {
778 let mut rng = StdRng::seed_from_u64(0x12345);
779 let query: Box<[Q]> = (0..run.dim.get())
780 .map(|_| StandardUniform.sample(&mut rng))
781 .collect();
782 let data = Matrix::<D>::new(
783 diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
784 run.num_points.get(),
785 run.dim.get(),
786 );
787
788 Self { query, data }
789 }
790
791 fn run<F>(&self, run: &Run, f: F) -> RunResult
792 where
793 F: Fn(&[Q], &[D]) -> f32,
794 {
795 run_loops(&self.query, self.data.as_view(), run, f)
796 }
797}
798
/// Implement [`RunBenchmark`] for one `Kernel` instantiation.
///
/// Arms:
/// - `(reference, Q, D, l2, ip, cosine)`: dispatch to the named functions in the `reference`
///   module.
/// - `(arch, Q, D)`: dispatch to `simd::simd_op` with the `L2`, `IP` and `CosineStateless`
///   operators, invoked through `arch.run2`.
/// - `("target_arch", arch, Q, D)`: same as the previous arm, compiled only for that
///   `target_arch`.
macro_rules! stamp {
    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
        impl RunBenchmark<Reference> for Kernel<Reference, $Q, $D> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                _arch: Reference,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();
                for run in input.runs.iter() {
                    // Inputs are regenerated (from a fixed seed) for each run's shape.
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
                    };
                    results.push(result);
                }
                Ok(results)
            }
        }
    };
    ($arch:path, $Q:ty, $D:ty) => {
        impl RunBenchmark<$arch> for Kernel<$arch, $Q, $D> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                arch: $arch,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();

                // Distance operators are built once and shared by every run.
                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                for run in input.runs.iter() {
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        // NOTE(review): `run2` presumably executes the kernel within the
                        // architecture's target-feature context; confirm against diskann_wide.
                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(l2, arch, q, d), q, d)
                        }),
                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(ip, arch, q, d), q, d)
                        }),
                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(cosine, arch, q, d), q, d)
                        }),
                    };
                    results.push(result)
                }
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Q, $D);
    };
}
868
// x86-64 V4 kernels (compiled only for x86_64).
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);

// x86-64 V3 kernels (compiled only for x86_64).
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);

// aarch64 Neon kernels (compiled only for aarch64).
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);

// Scalar kernels, compiled on every target.
stamp!(diskann_wide::arch::Scalar, f32, f32);
stamp!(diskann_wide::arch::Scalar, f16, f16);
stamp!(diskann_wide::arch::Scalar, u8, u8);
stamp!(diskann_wide::arch::Scalar, i8, i8);

// Loop-based reference kernels from the `reference` module.
stamp!(
    reference,
    f32,
    f32,
    squared_l2_f32,
    inner_product_f32,
    cosine_f32
);
stamp!(
    reference,
    f16,
    f16,
    squared_l2_f16,
    inner_product_f16,
    cosine_f16
);
stamp!(
    reference,
    u8,
    u8,
    squared_l2_u8,
    inner_product_u8,
    cosine_u8
);
stamp!(
    reference,
    i8,
    i8,
    squared_l2_i8,
    inner_product_i8,
    cosine_i8
);
921
922mod reference {
929 use diskann_wide::ARCH;
930 use half::f16;
931
932 trait MaybeFMA {
933 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
936 }
937
938 impl MaybeFMA for diskann_wide::arch::Scalar {
939 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
940 a * b + c
941 }
942 }
943
944 #[cfg(target_arch = "x86_64")]
945 impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
946 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
947 a.mul_add(b, c)
948 }
949 }
950
951 #[cfg(target_arch = "x86_64")]
952 impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
953 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
954 a.mul_add(b, c)
955 }
956 }
957
958 #[cfg(target_arch = "aarch64")]
959 impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
960 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
961 a.mul_add(b, c)
962 }
963 }
964
965 pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
970 assert_eq!(x.len(), y.len());
971 std::iter::zip(x.iter(), y.iter())
972 .map(|(&a, &b)| {
973 let a: i32 = a.into();
974 let b: i32 = b.into();
975 let diff = a - b;
976 diff * diff
977 })
978 .sum::<i32>() as f32
979 }
980
981 pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
982 assert_eq!(x.len(), y.len());
983 std::iter::zip(x.iter(), y.iter())
984 .map(|(&a, &b)| {
985 let a: i32 = a.into();
986 let b: i32 = b.into();
987 let diff = a - b;
988 diff * diff
989 })
990 .sum::<i32>() as f32
991 }
992
993 pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
994 assert_eq!(x.len(), y.len());
995 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
996 let a: f32 = a.into();
997 let b: f32 = b.into();
998 let diff = a - b;
999 ARCH.maybe_fma(diff, diff, acc)
1000 })
1001 }
1002
1003 pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
1004 assert_eq!(x.len(), y.len());
1005 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1006 let diff = a - b;
1007 ARCH.maybe_fma(diff, diff, acc)
1008 })
1009 }
1010
1011 pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
1016 assert_eq!(x.len(), y.len());
1017 std::iter::zip(x.iter(), y.iter())
1018 .map(|(&a, &b)| {
1019 let a: i32 = a.into();
1020 let b: i32 = b.into();
1021 a * b
1022 })
1023 .sum::<i32>() as f32
1024 }
1025
1026 pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
1027 assert_eq!(x.len(), y.len());
1028 std::iter::zip(x.iter(), y.iter())
1029 .map(|(&a, &b)| {
1030 let a: i32 = a.into();
1031 let b: i32 = b.into();
1032 a * b
1033 })
1034 .sum::<i32>() as f32
1035 }
1036
1037 pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
1038 assert_eq!(x.len(), y.len());
1039 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1040 let a: f32 = a.into();
1041 let b: f32 = b.into();
1042 ARCH.maybe_fma(a, b, acc)
1043 })
1044 }
1045
1046 pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
1047 assert_eq!(x.len(), y.len());
1048 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
1049 }
1050
1051 #[derive(Default)]
1056 struct XY<T> {
1057 xnorm: T,
1058 ynorm: T,
1059 xy: T,
1060 }
1061
1062 pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1063 assert_eq!(x.len(), y.len());
1064 let r: XY<i32> =
1065 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1066 let vx: i32 = vx.into();
1067 let vy: i32 = vy.into();
1068 XY {
1069 xnorm: acc.xnorm + vx * vx,
1070 ynorm: acc.ynorm + vy * vy,
1071 xy: acc.xy + vx * vy,
1072 }
1073 });
1074
1075 if r.xnorm == 0 || r.ynorm == 0 {
1076 return 0.0;
1077 }
1078
1079 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1080 }
1081
1082 pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1083 assert_eq!(x.len(), y.len());
1084 let r: XY<i32> =
1085 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1086 let vx: i32 = vx.into();
1087 let vy: i32 = vy.into();
1088 XY {
1089 xnorm: acc.xnorm + vx * vx,
1090 ynorm: acc.ynorm + vy * vy,
1091 xy: acc.xy + vx * vy,
1092 }
1093 });
1094
1095 if r.xnorm == 0 || r.ynorm == 0 {
1096 return 0.0;
1097 }
1098
1099 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1100 }
1101
1102 pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1103 assert_eq!(x.len(), y.len());
1104 let r: XY<f32> =
1105 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1106 let vx: f32 = vx.into();
1107 let vy: f32 = vy.into();
1108 XY {
1109 xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1110 ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1111 xy: ARCH.maybe_fma(vx, vy, acc.xy),
1112 }
1113 });
1114
1115 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1116 return 0.0;
1117 }
1118
1119 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1120 }
1121
1122 pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1123 assert_eq!(x.len(), y.len());
1124 let r: XY<f32> =
1125 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1126 xnorm: vx.mul_add(vx, acc.xnorm),
1127 ynorm: vy.mul_add(vy, acc.ynorm),
1128 xy: vx.mul_add(vy, acc.xy),
1129 });
1130
1131 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1132 return 0.0;
1133 }
1134
1135 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1136 }
1137}
1138
#[cfg(test)]
mod tests {
    use super::*;

    use diskann_benchmark_runner::{
        benchmark::{PassFail, Regression},
        utils::percentiles::compute_percentiles,
    };

    /// Smallest valid run: a single 8-dimensional point, timed once for one loop.
    fn tiny_run(distance: SimilarityMeasure) -> Run {
        let one = NonZeroUsize::MIN;
        Run {
            distance,
            dim: NonZeroUsize::new(8).unwrap(),
            num_points: one,
            loops_per_measurement: one,
            num_measurements: one,
        }
    }

    /// Scalar `f32 x f32` input containing a single squared-L2 run.
    fn tiny_op() -> SimdOp {
        SimdOp {
            query_type: DataType::Float32,
            data_type: DataType::Float32,
            arch: Arch::Scalar,
            runs: vec![tiny_run(SimilarityMeasure::SquaredL2)],
        }
    }

    /// Result for `tiny_run(distance)` whose only latency sample is `minimum`.
    fn tiny_result(distance: SimilarityMeasure, minimum: u64) -> RunResult {
        let mut latencies = vec![MicroSeconds::new(minimum)];
        let percentiles = compute_percentiles(&mut latencies).unwrap();
        RunResult {
            run: tiny_run(distance),
            latencies,
            percentiles,
        }
    }

    /// Tolerance allowing a relative slowdown of at most `limit`.
    fn tolerance(limit: f64) -> SimdTolerance {
        SimdTolerance {
            min_time_regression: NonNegativeFinite::new(limit).unwrap(),
        }
    }

    /// Run the scalar `f32 x f32` kernel's regression check on one before/after pair.
    fn check_pair(
        limit: f64,
        before: RunResult,
        after: RunResult,
    ) -> anyhow::Result<PassFail<CheckResult, CheckResult>> {
        Kernel::<diskann_wide::arch::Scalar, f32, f32>::new().check(
            &tolerance(limit),
            &tiny_op(),
            &vec![before],
            &vec![after],
        )
    }

    #[test]
    fn check_rejects_mismatched_runs() {
        let err = check_pair(
            0.0,
            tiny_result(SimilarityMeasure::SquaredL2, 100),
            tiny_result(SimilarityMeasure::Cosine, 100),
        )
        .unwrap_err();

        assert_eq!(err.to_string(), "run 0 mismatched");
    }

    #[test]
    fn check_allows_negative_relative_change() {
        let result = check_pair(
            0.0,
            tiny_result(SimilarityMeasure::SquaredL2, 100),
            tiny_result(SimilarityMeasure::SquaredL2, 95),
        )
        .unwrap();

        assert!(matches!(result, PassFail::Pass(_)));
    }

    #[test]
    fn check_passes_on_tolerance_boundary() {
        let result = check_pair(
            0.05,
            tiny_result(SimilarityMeasure::SquaredL2, 100),
            tiny_result(SimilarityMeasure::SquaredL2, 105),
        )
        .unwrap();

        assert!(matches!(result, PassFail::Pass(_)));
    }

    #[test]
    fn check_fails_above_tolerance_boundary() {
        let result = check_pair(
            0.05,
            tiny_result(SimilarityMeasure::SquaredL2, 100),
            tiny_result(SimilarityMeasure::SquaredL2, 106),
        )
        .unwrap();

        assert!(matches!(result, PassFail::Fail(_)));
    }

    #[test]
    fn check_result_display_includes_failure_details() {
        let check = CheckResult {
            checks: vec![Comparison {
                run: tiny_run(SimilarityMeasure::SquaredL2),
                tolerance: tolerance(0.05),
                before_min: 100.0,
                after_min: 106.0,
            }],
        };

        let rendered = check.to_string();
        assert!(rendered.contains("Distance"), "rendered = {rendered}");
        assert!(rendered.contains("squared_l2"), "rendered = {rendered}");
        assert!(rendered.contains("100.000"), "rendered = {rendered}");
        assert!(rendered.contains("106.000"), "rendered = {rendered}");
        assert!(rendered.contains("6.000 %"), "rendered = {rendered}");
        assert!(rendered.contains("FAIL"), "rendered = {rendered}");
    }

    /// A zero baseline leaves the relative change undefined, which must count as a failure.
    #[test]
    fn zero_values_rejected() {
        let result = check_pair(
            0.05,
            tiny_result(SimilarityMeasure::SquaredL2, 0),
            tiny_result(SimilarityMeasure::SquaredL2, 0),
        )
        .unwrap();

        assert!(matches!(result, PassFail::Fail(_)));
    }
}