1use std::{io::Write, num::NonZeroUsize};
9
10use diskann_utils::views::{Matrix, MatrixView};
11use diskann_vector::distance::simd;
12use diskann_wide::Architecture;
13use half::f16;
14use rand::{
15 distr::{Distribution, StandardUniform},
16 rngs::StdRng,
17 SeedableRng,
18};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use diskann_benchmark_runner::{
23 benchmark::{PassFail, Regression},
24 dispatcher::{Description, DispatchRule, FailureScore, MatchScore},
25 utils::{
26 datatype::{self, DataType},
27 num::{relative_change, NonNegativeFinite},
28 percentiles, MicroSeconds,
29 },
30 Any, Benchmark, CheckDeserialization, Checker, Input,
31};
32
33pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
38 register_benchmarks_impl(dispatcher)
39}
40
/// Thin wrapper enabling `Display` impls on borrowed data (e.g. `[RunResult]`)
/// without taking ownership.
#[derive(Debug, Clone, Copy)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);
47
48impl<T: ?Sized> std::ops::Deref for DisplayWrapper<'_, T> {
49 type Target = T;
50 fn deref(&self) -> &T {
51 self.0
52 }
53}
54
/// Distance/similarity function exercised by a benchmark run.
///
/// Serialized in snake_case (e.g. `"squared_l2"`), matching the `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMeasure {
    SquaredL2,
    InnerProduct,
    Cosine,
}
66
67impl std::fmt::Display for SimilarityMeasure {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 let st = match self {
70 Self::SquaredL2 => "squared_l2",
71 Self::InnerProduct => "inner_product",
72 Self::Cosine => "cosine",
73 };
74 write!(f, "{}", st)
75 }
76}
77
/// SIMD instruction-set flavor a kernel is dispatched to.
///
/// Serialized in kebab-case; the x86-64 microarchitecture levels carry
/// explicit renames so the encoded names match the level naming convention.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Arch {
    #[serde(rename = "x86-64-v4")]
    #[allow(non_camel_case_types)]
    X86_64_V4,
    #[serde(rename = "x86-64-v3")]
    #[allow(non_camel_case_types)]
    X86_64_V3,
    Neon,
    Scalar,
    // Plain-loop reference implementation (no SIMD abstraction at all).
    Reference,
}
91
92impl std::fmt::Display for Arch {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 let st = match self {
95 Self::X86_64_V4 => "x86-64-v4",
96 Self::X86_64_V3 => "x86-64-v3",
97 Self::Neon => "neon",
98 Self::Scalar => "scalar",
99 Self::Reference => "reference",
100 };
101 write!(f, "{}", st)
102 }
103}
104
/// Parameters for a single timing experiment.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct Run {
    // Similarity function to benchmark.
    distance: SimilarityMeasure,
    // Vector dimensionality.
    dim: NonZeroUsize,
    // Number of data vectors evaluated per inner loop.
    num_points: NonZeroUsize,
    // Inner-loop repetitions folded into one latency sample.
    loops_per_measurement: NonZeroUsize,
    // Number of latency samples collected.
    num_measurements: NonZeroUsize,
}

/// Top-level benchmark input: element types, target architecture, and the
/// list of runs to execute.
#[derive(Debug, Serialize, Deserialize)]
pub struct SimdOp {
    query_type: DataType,
    data_type: DataType,
    arch: Arch,
    runs: Vec<Run>,
}

impl CheckDeserialization for SimdOp {
    // No cross-field invariants to validate beyond what serde enforces.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
127
// Write one right-aligned `name: value` summary line (18-char label column).
macro_rules! write_field {
    ($f:ident, $field:tt, $($expr:tt)*) => {
        writeln!($f, "{:>18}: {}", $field, $($expr)*)
    }
}
133
134impl SimdOp {
135 fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 write_field!(f, "query type", self.query_type)?;
137 write_field!(f, "data type", self.data_type)?;
138 write_field!(f, "arch", self.arch)?;
139 write_field!(f, "number of runs", self.runs.len())?;
140 Ok(())
141 }
142}
143
144impl std::fmt::Display for SimdOp {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 writeln!(f, "SIMD Operation\n")?;
147 write_field!(f, "tag", Self::tag())?;
148 self.summarize_fields(f)
149 }
150}
151
152impl Input for SimdOp {
153 fn tag() -> &'static str {
154 "simd-op"
155 }
156
157 fn try_deserialize(
158 serialized: &serde_json::Value,
159 checker: &mut Checker,
160 ) -> anyhow::Result<Any> {
161 checker.any(Self::deserialize(serialized)?)
162 }
163
164 fn example() -> anyhow::Result<serde_json::Value> {
165 const DIM: [NonZeroUsize; 2] = [
166 NonZeroUsize::new(128).unwrap(),
167 NonZeroUsize::new(150).unwrap(),
168 ];
169
170 const NUM_POINTS: [NonZeroUsize; 2] = [
171 NonZeroUsize::new(1000).unwrap(),
172 NonZeroUsize::new(800).unwrap(),
173 ];
174
175 const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
176 const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
177
178 let runs = vec![
179 Run {
180 distance: SimilarityMeasure::SquaredL2,
181 dim: DIM[0],
182 num_points: NUM_POINTS[0],
183 loops_per_measurement: LOOPS_PER_MEASUREMENT,
184 num_measurements: NUM_MEASUREMENTS,
185 },
186 Run {
187 distance: SimilarityMeasure::InnerProduct,
188 dim: DIM[1],
189 num_points: NUM_POINTS[1],
190 loops_per_measurement: LOOPS_PER_MEASUREMENT,
191 num_measurements: NUM_MEASUREMENTS,
192 },
193 ];
194
195 Ok(serde_json::to_value(&Self {
196 query_type: DataType::Float32,
197 data_type: DataType::Float32,
198 arch: Arch::X86_64_V3,
199 runs,
200 })?)
201 }
202}
203
/// Regression-check tolerance: the maximum allowed relative increase of the
/// per-computation minimum time before a run is marked as failing.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
struct SimdTolerance {
    min_time_regression: NonNegativeFinite,
}

impl CheckDeserialization for SimdTolerance {
    // `NonNegativeFinite` already enforces the only invariant.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
222
223impl Input for SimdTolerance {
224 fn tag() -> &'static str {
225 "simd-tolerance"
226 }
227
228 fn try_deserialize(
229 serialized: &serde_json::Value,
230 checker: &mut Checker,
231 ) -> anyhow::Result<Any> {
232 checker.any(Self::deserialize(serialized)?)
233 }
234
235 fn example() -> anyhow::Result<serde_json::Value> {
236 const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) {
237 Ok(v) => v,
238 Err(_) => panic!("use a non-negative finite please"),
239 };
240
241 Ok(serde_json::to_value(SimdTolerance {
242 min_time_regression: EXAMPLE,
243 })?)
244 }
245}
246
/// Before/after comparison for one run, in nanoseconds per distance
/// computation, carrying the tolerance it was judged against.
#[derive(Debug, Serialize)]
struct Comparison {
    run: Run,
    tolerance: SimdTolerance,
    before_min: f64,
    after_min: f64,
}

/// All per-run comparisons from one regression check.
#[derive(Debug, Serialize)]
struct CheckResult {
    checks: Vec<Comparison>,
}
261
262impl std::fmt::Display for CheckResult {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 let header = [
265 "Distance",
266 "Dim",
267 "Min Before (ns)",
268 "Min After (ns)",
269 "Change (%)",
270 "Remark",
271 ];
272
273 let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.checks.len());
274
275 for (i, c) in self.checks.iter().enumerate() {
276 let mut row = table.row(i);
277 let change = relative_change(c.before_min, c.after_min);
278
279 row.insert(c.run.distance, 0);
280 row.insert(c.run.dim, 1);
281 row.insert(format!("{:.3}", c.before_min), 2);
282 row.insert(format!("{:.3}", c.after_min), 3);
283 match change {
284 Ok(change) => {
285 row.insert(format!("{:.3} %", change * 100.0), 4);
286 if change > c.tolerance.min_time_regression.get() {
287 row.insert("FAIL", 5);
288 }
289 }
290 Err(err) => {
291 row.insert("invalid", 4);
292 row.insert(err, 5);
293 }
294 }
295 }
296
297 table.fmt(f)
298 }
299}
300
301fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
306 #[cfg(target_arch = "x86_64")]
308 {
309 dispatcher.register_regression(
310 "simd-op-f32xf32-x86_64_V4",
311 Kernel::<diskann_wide::arch::x86_64::V4, f32, f32>::new(),
312 );
313 dispatcher.register_regression(
314 "simd-op-f16xf16-x86_64_V4",
315 Kernel::<diskann_wide::arch::x86_64::V4, f16, f16>::new(),
316 );
317 dispatcher.register_regression(
318 "simd-op-u8xu8-x86_64_V4",
319 Kernel::<diskann_wide::arch::x86_64::V4, u8, u8>::new(),
320 );
321 dispatcher.register_regression(
322 "simd-op-i8xi8-x86_64_V4",
323 Kernel::<diskann_wide::arch::x86_64::V4, i8, i8>::new(),
324 );
325 }
326
327 #[cfg(target_arch = "x86_64")]
329 {
330 dispatcher.register_regression(
331 "simd-op-f32xf32-x86_64_V3",
332 Kernel::<diskann_wide::arch::x86_64::V3, f32, f32>::new(),
333 );
334 dispatcher.register_regression(
335 "simd-op-f16xf16-x86_64_V3",
336 Kernel::<diskann_wide::arch::x86_64::V3, f16, f16>::new(),
337 );
338 dispatcher.register_regression(
339 "simd-op-u8xu8-x86_64_V3",
340 Kernel::<diskann_wide::arch::x86_64::V3, u8, u8>::new(),
341 );
342 dispatcher.register_regression(
343 "simd-op-i8xi8-x86_64_V3",
344 Kernel::<diskann_wide::arch::x86_64::V3, i8, i8>::new(),
345 );
346 }
347
348 #[cfg(target_arch = "aarch64")]
350 {
351 dispatcher.register_regression(
352 "simd-op-f32xf32-aarch64_neon",
353 Kernel::<diskann_wide::arch::aarch64::Neon, f32, f32>::new(),
354 );
355 dispatcher.register_regression(
356 "simd-op-f16xf16-aarch64_neon",
357 Kernel::<diskann_wide::arch::aarch64::Neon, f16, f16>::new(),
358 );
359 dispatcher.register_regression(
360 "simd-op-u8xu8-aarch64_neon",
361 Kernel::<diskann_wide::arch::aarch64::Neon, u8, u8>::new(),
362 );
363 dispatcher.register_regression(
364 "simd-op-i8xi8-aarch64_neon",
365 Kernel::<diskann_wide::arch::aarch64::Neon, i8, i8>::new(),
366 );
367 }
368
369 dispatcher.register_regression(
371 "simd-op-f32xf32-scalar",
372 Kernel::<diskann_wide::arch::Scalar, f32, f32>::new(),
373 );
374 dispatcher.register_regression(
375 "simd-op-f16xf16-scalar",
376 Kernel::<diskann_wide::arch::Scalar, f16, f16>::new(),
377 );
378 dispatcher.register_regression(
379 "simd-op-u8xu8-scalar",
380 Kernel::<diskann_wide::arch::Scalar, u8, u8>::new(),
381 );
382 dispatcher.register_regression(
383 "simd-op-i8xi8-scalar",
384 Kernel::<diskann_wide::arch::Scalar, i8, i8>::new(),
385 );
386
387 dispatcher.register_regression(
389 "simd-op-f32xf32-reference",
390 Kernel::<Reference, f32, f32>::new(),
391 );
392 dispatcher.register_regression(
393 "simd-op-f16xf16-reference",
394 Kernel::<Reference, f16, f16>::new(),
395 );
396 dispatcher.register_regression(
397 "simd-op-u8xu8-reference",
398 Kernel::<Reference, u8, u8>::new(),
399 );
400 dispatcher.register_regression(
401 "simd-op-i8xi8-reference",
402 Kernel::<Reference, i8, i8>::new(),
403 );
404}
405
/// Marker type selecting the plain-loop implementations in `mod reference`.
struct Reference;

/// Newtype used to attach `DispatchRule<Arch>` impls to architecture tokens.
#[derive(Debug)]
struct Identity<T>(T);
416
/// Zero-sized benchmark kernel binding an architecture `A`, query element
/// type `Q`, and data element type `D` at compile time.
struct Kernel<A, Q, D> {
    _type: std::marker::PhantomData<(A, Q, D)>,
}

impl<A, Q, D> Kernel<A, Q, D> {
    /// Construct the stateless kernel.
    fn new() -> Self {
        Kernel {
            _type: std::marker::PhantomData,
        }
    }
}
428
/// Error returned when the requested architecture cannot run on this CPU.
#[derive(Debug, Error)]
#[error("architecture {0} is not supported by this CPU")]
pub(crate) struct ArchNotSupported(Arch);
433
434impl DispatchRule<Arch> for Identity<Reference> {
435 type Error = ArchNotSupported;
436
437 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
438 if *from == Arch::Reference {
439 Ok(MatchScore(0))
440 } else {
441 Err(FailureScore(1))
442 }
443 }
444
445 fn convert(from: Arch) -> Result<Self, Self::Error> {
446 assert_eq!(from, Arch::Reference);
447 Ok(Identity(Reference))
448 }
449
450 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
451 match from {
452 None => write!(f, "loop based"),
453 Some(arch) => {
454 if Self::try_match(arch).is_ok() {
455 write!(f, "matched {}", arch)
456 } else {
457 write!(f, "expected {}, got {}", Arch::Reference, arch)
458 }
459 }
460 }
461 }
462}
463
464impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
465 type Error = ArchNotSupported;
466
467 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
468 if *from == Arch::Scalar {
469 Ok(MatchScore(0))
470 } else {
471 Err(FailureScore(1))
472 }
473 }
474
475 fn convert(from: Arch) -> Result<Self, Self::Error> {
476 assert_eq!(from, Arch::Scalar);
477 Ok(Identity(diskann_wide::arch::Scalar))
478 }
479
480 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
481 match from {
482 None => write!(f, "scalar (compilation target CPU)"),
483 Some(arch) => {
484 if Self::try_match(arch).is_ok() {
485 write!(f, "matched {}", arch)
486 } else {
487 write!(f, "expected {}, got {}", Arch::Scalar, arch)
488 }
489 }
490 }
491 }
492}
493
// Generate a `DispatchRule<Arch>` impl for a CPU-feature-gated architecture
// token. Unlike the scalar/reference rules, these must also probe at runtime
// (`new_checked`) whether the current CPU actually supports the feature set.
macro_rules! match_arch {
    ($target_arch:literal, $arch:path, $enum:ident) => {
        #[cfg(target_arch = $target_arch)]
        impl DispatchRule<Arch> for Identity<$arch> {
            type Error = ArchNotSupported;

            fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
                let available = <$arch>::new_checked().is_some();
                if available && *from == Arch::$enum {
                    Ok(MatchScore(0))
                } else if !available && *from == Arch::$enum {
                    // Requested this arch but the CPU lacks it: lowest-priority
                    // failure so error messages can single this case out.
                    Err(FailureScore(0))
                } else {
                    Err(FailureScore(1))
                }
            }

            fn convert(from: Arch) -> Result<Self, Self::Error> {
                assert_eq!(from, Arch::$enum);
                <$arch>::new_checked()
                    .ok_or(ArchNotSupported(from))
                    .map(Identity)
            }

            fn description(
                f: &mut std::fmt::Formatter<'_>,
                from: Option<&Arch>,
            ) -> std::fmt::Result {
                let available = <$arch>::new_checked().is_some();
                match from {
                    None => write!(f, "{}", Arch::$enum),
                    Some(arch) => {
                        if Self::try_match(arch).is_ok() {
                            write!(f, "matched {}", arch)
                        } else if !available && *arch == Arch::$enum {
                            write!(f, "matched {} but unsupported by this CPU", Arch::$enum)
                        } else {
                            write!(f, "expected {}, got {}", Arch::$enum, arch)
                        }
                    }
                }
            }
        }
    };
}
539
// Instantiate the dispatch rules for each CPU-feature-gated architecture.
match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
543
impl<A, Q, D> Benchmark for Kernel<A, Q, D>
where
    datatype::Type<Q>: DispatchRule<datatype::DataType>,
    datatype::Type<D>: DispatchRule<datatype::DataType>,
    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
    Kernel<A, Q, D>: RunBenchmark<A>,
    A: 'static,
    Q: 'static,
    D: 'static,
{
    type Input = SimdOp;
    type Output = Vec<RunResult>;

    /// Score how well this kernel matches the requested input.
    ///
    /// A perfect match returns `MatchScore(0)`. Otherwise a failure score is
    /// accumulated: 10 per mismatched element type, and `2 + score` for an
    /// architecture mismatch (so "unsupported by this CPU" and "wrong arch"
    /// rank differently).
    fn try_match(&self, from: &SimdOp) -> Result<MatchScore, FailureScore> {
        // `None` means no mismatch has been seen yet.
        let mut failscore: Option<u32> = None;
        if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
            *failscore.get_or_insert(0) += 10;
        }
        if datatype::Type::<D>::try_match(&from.data_type).is_err() {
            *failscore.get_or_insert(0) += 10;
        }
        if let Err(FailureScore(score)) = Identity::<A>::try_match(&from.arch) {
            *failscore.get_or_insert(0) += 2 + score;
        }

        match failscore {
            None => Ok(MatchScore(0)),
            Some(score) => Err(FailureScore(score)),
        }
    }

    /// Execute all runs, echoing the input summary before and the results
    /// table after the benchmark to `output`.
    fn run(
        &self,
        input: &SimdOp,
        _: diskann_benchmark_runner::Checkpoint<'_>,
        mut output: &mut dyn diskann_benchmark_runner::Output,
    ) -> anyhow::Result<Self::Output> {
        // Fails with `ArchNotSupported` when the CPU lacks the requested arch.
        let arch = Identity::<A>::convert(input.arch)?.0;
        writeln!(output, "{}", input)?;
        let results = self.run_benchmark(input, arch)?;
        writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
        Ok(results)
    }

    /// Describe this kernel (no input), or explain which parts of `input`
    /// mismatch it.
    fn description(
        &self,
        f: &mut std::fmt::Formatter<'_>,
        input: Option<&SimdOp>,
    ) -> std::fmt::Result {
        match input {
            None => {
                writeln!(
                    f,
                    "- Query Type: {}",
                    Description::<datatype::DataType, datatype::Type<Q>>::new()
                )?;
                writeln!(
                    f,
                    "- Data Type: {}",
                    Description::<datatype::DataType, datatype::Type<D>>::new()
                )?;
                writeln!(
                    f,
                    "- Implementation: {}",
                    Description::<Arch, Identity<A>>::new()
                )?;
            }
            Some(input) => {
                if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
                    writeln!(f, "\n - Mismatched query type: {}", err)?;
                }
                if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
                    writeln!(f, "\n - Mismatched data type: {}", err)?;
                }
                if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
                    writeln!(f, "\n - Mismatched architecture: {}", err)?;
                }
            }
        }
        Ok(())
    }
}
627
impl<A, Q, D> Regression for Kernel<A, Q, D>
where
    datatype::Type<Q>: DispatchRule<datatype::DataType>,
    datatype::Type<D>: DispatchRule<datatype::DataType>,
    Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
    Kernel<A, Q, D>: RunBenchmark<A>,
    A: 'static,
    Q: 'static,
    D: 'static,
{
    type Tolerances = SimdTolerance;
    type Pass = CheckResult;
    type Fail = CheckResult;

    /// Compare `before`/`after` results run-by-run.
    ///
    /// Each run's minimum latency (µs) is converted to nanoseconds per single
    /// distance computation. The check fails when any run's relative change
    /// exceeds `tolerance.min_time_regression`, or cannot be computed at all
    /// (e.g. a zero baseline).
    fn check(
        &self,
        tolerance: &SimdTolerance,
        _input: &SimdOp,
        before: &Vec<RunResult>,
        after: &Vec<RunResult>,
    ) -> anyhow::Result<PassFail<CheckResult, CheckResult>> {
        anyhow::ensure!(
            before.len() == after.len(),
            "before has {} runs but after has {}",
            before.len(),
            after.len(),
        );

        let mut passed = true;
        let checks: Vec<Comparison> = std::iter::zip(before.iter(), after.iter())
            .enumerate()
            .map(|(i, (b, a))| {
                // Runs must describe the same experiment to be comparable.
                anyhow::ensure!(b.run == a.run, "run {i} mismatched");

                // Equal runs, so using `b`'s count for both sides is sound.
                let computations_per_latency = b.computations_per_latency() as f64;

                // µs per measurement -> ns per distance computation.
                let before_min = b.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
                let after_min = a.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;

                let comparison = Comparison {
                    run: b.run.clone(),
                    tolerance: *tolerance,
                    before_min,
                    after_min,
                };

                match relative_change(before_min, after_min) {
                    Ok(change) => {
                        if change > tolerance.min_time_regression.get() {
                            passed = false;
                        }
                    }
                    // An incomputable change (e.g. zero baseline) is a failure.
                    Err(_) => passed = false,
                };

                Ok(comparison)
            })
            .collect::<anyhow::Result<Vec<Comparison>>>()?;

        let check = CheckResult { checks };

        if passed {
            Ok(PassFail::Pass(check))
        } else {
            Ok(PassFail::Fail(check))
        }
    }
}
697
/// Dispatch hook: executes all of `input.runs` on the concrete architecture `A`.
trait RunBenchmark<A> {
    fn run_benchmark(&self, input: &SimdOp, arch: A) -> Result<Vec<RunResult>, anyhow::Error>;
}

/// Raw timing results for one `Run`.
#[derive(Debug, Serialize, Deserialize)]
struct RunResult {
    // Parameters that produced these measurements.
    run: Run,
    // One latency sample per measurement (microseconds).
    latencies: Vec<MicroSeconds>,
    // Percentile summary computed over `latencies`.
    percentiles: percentiles::Percentiles<MicroSeconds>,
}
715
716impl RunResult {
717 fn computations_per_latency(&self) -> usize {
718 self.run.num_points.get() * self.run.loops_per_measurement.get()
719 }
720}
721
impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
    /// Render the results as a table of per-computation times in nanoseconds.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.is_empty() {
            return Ok(());
        }

        let header = [
            "Distance",
            "Dim",
            "Min Time (ns)",
            "Mean Time (ns)",
            "Points",
            "Loops",
            "Measurements",
        ];

        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());

        self.iter().enumerate().for_each(|(row, r)| {
            let mut row = table.row(row);

            // `latencies` is only empty if no measurements ran; the sentinel
            // keeps the display total rather than panicking.
            let min_latency = r
                .latencies
                .iter()
                .min()
                .copied()
                .unwrap_or(MicroSeconds::new(u64::MAX));
            let mean_latency = r.percentiles.mean;

            let computations_per_latency = r.computations_per_latency() as f64;

            // µs per measurement -> ns per single distance computation.
            let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
            // NOTE(review): `percentiles.mean` is divided without `.as_f64()`,
            // so it presumably is already an f64 in µs — confirm the units.
            let mean_time = mean_latency / computations_per_latency * 1000.0;

            row.insert(r.run.distance, 0);
            row.insert(r.run.dim, 1);
            row.insert(format!("{:.3}", min_time), 2);
            row.insert(format!("{:.3}", mean_time), 3);
            row.insert(r.run.num_points, 4);
            row.insert(r.run.loops_per_measurement, 5);
            row.insert(r.run.num_measurements, 6);
        });

        table.fmt(f)
    }
}
769
/// Collect `run.num_measurements` latency samples, each timing
/// `run.loops_per_measurement` passes of `f(query, row)` over every row of
/// `data`, and summarize them into a `RunResult`.
fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
where
    F: Fn(&[Q], &[D]) -> f32,
{
    let mut latencies = Vec::with_capacity(run.num_measurements.get());
    let mut dst = vec![0.0; data.nrows()];

    for _ in 0..run.num_measurements.get() {
        let start = std::time::Instant::now();
        for _ in 0..run.loops_per_measurement.get() {
            std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
                *d = f(query, r);
            });
            // Prevent the optimizer from eliding the distance computations.
            std::hint::black_box(&mut dst);
        }
        latencies.push(start.elapsed().into());
    }

    // `num_measurements` is NonZero, so `latencies` is never empty here.
    let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
    RunResult {
        run: run.clone(),
        latencies,
        percentiles,
    }
}
795
/// Randomly generated query vector plus data matrix for one run.
struct Data<Q, D> {
    query: Box<[Q]>,
    data: Matrix<D>,
}

impl<Q, D> Data<Q, D> {
    /// Generate inputs sized by `run` from a fixed seed so before/after
    /// benchmark invocations see identical data.
    fn new(run: &Run) -> Self
    where
        StandardUniform: Distribution<Q>,
        StandardUniform: Distribution<D>,
    {
        let mut rng = StdRng::seed_from_u64(0x12345);
        // Sampling order matters for reproducibility: query first, then data.
        let query: Box<[Q]> = (0..run.dim.get())
            .map(|_| StandardUniform.sample(&mut rng))
            .collect();
        let data = Matrix::<D>::new(
            diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
            run.num_points.get(),
            run.dim.get(),
        );

        Self { query, data }
    }

    /// Time `f` over this data according to `run`.
    fn run<F>(&self, run: &Run, f: F) -> RunResult
    where
        F: Fn(&[Q], &[D]) -> f32,
    {
        run_loops(&self.query, self.data.as_view(), run, f)
    }
}
827
// Generate `RunBenchmark` implementations:
// - `stamp!(reference, Q, D, l2, ip, cos)` wires the plain-loop kernels from
//   `mod reference` for the `Reference` marker.
// - `stamp!(arch, Q, D)` wires the SIMD kernels for an architecture token.
// - `stamp!("target", arch, Q, D)` additionally gates the SIMD form behind
//   `#[cfg(target_arch = "target")]`.
macro_rules! stamp {
    (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
        impl RunBenchmark<Reference> for Kernel<Reference, $Q, $D> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                _arch: Reference,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();
                for run in input.runs.iter() {
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
                        SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
                        SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
                    };
                    results.push(result);
                }
                Ok(results)
            }
        }
    };
    ($arch:path, $Q:ty, $D:ty) => {
        impl RunBenchmark<$arch> for Kernel<$arch, $Q, $D> {
            fn run_benchmark(
                &self,
                input: &SimdOp,
                arch: $arch,
            ) -> Result<Vec<RunResult>, anyhow::Error> {
                let mut results = Vec::new();

                let l2 = &simd::L2 {};
                let ip = &simd::IP {};
                let cosine = &simd::CosineStateless {};

                for run in input.runs.iter() {
                    let data = Data::<$Q, $D>::new(run);
                    let result = match run.distance {
                        SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(l2, arch, q, d), q, d)
                        }),
                        SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(ip, arch, q, d), q, d)
                        }),
                        SimilarityMeasure::Cosine => data.run(run, |q, d| {
                            arch.run2(|q, d| simd::simd_op(cosine, arch, q, d), q, d)
                        }),
                    };
                    results.push(result)
                }
                Ok(results)
            }
        }
    };
    ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
        #[cfg(target_arch = $target_arch)]
        stamp!($arch, $Q, $D);
    };
}
897
// SIMD kernels, gated per compilation target architecture.
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);

stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);

stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);

// Scalar SIMD-abstraction kernels (always available).
stamp!(diskann_wide::arch::Scalar, f32, f32);
stamp!(diskann_wide::arch::Scalar, f16, f16);
stamp!(diskann_wide::arch::Scalar, u8, u8);
stamp!(diskann_wide::arch::Scalar, i8, i8);

// Plain-loop reference kernels from `mod reference`.
stamp!(
    reference,
    f32,
    f32,
    squared_l2_f32,
    inner_product_f32,
    cosine_f32
);
stamp!(
    reference,
    f16,
    f16,
    squared_l2_f16,
    inner_product_f16,
    cosine_f16
);
stamp!(
    reference,
    u8,
    u8,
    squared_l2_u8,
    inner_product_u8,
    cosine_u8
);
stamp!(
    reference,
    i8,
    i8,
    squared_l2_i8,
    inner_product_i8,
    cosine_i8
);
950
951mod reference {
958 use diskann_wide::ARCH;
959 use half::f16;
960
    /// Fused-multiply-add shim so each architecture's reference kernel mirrors
    /// whether its SIMD counterpart is expected to fuse `a * b + c`.
    trait MaybeFMA {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
    }

    // Scalar: plain multiply-then-add (two roundings, no fusing).
    impl MaybeFMA for diskann_wide::arch::Scalar {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a * b + c
        }
    }

    #[cfg(target_arch = "x86_64")]
    impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a.mul_add(b, c)
        }
    }

    #[cfg(target_arch = "x86_64")]
    impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a.mul_add(b, c)
        }
    }

    #[cfg(target_arch = "aarch64")]
    impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
        fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
            a.mul_add(b, c)
        }
    }
993
994 pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
999 assert_eq!(x.len(), y.len());
1000 std::iter::zip(x.iter(), y.iter())
1001 .map(|(&a, &b)| {
1002 let a: i32 = a.into();
1003 let b: i32 = b.into();
1004 let diff = a - b;
1005 diff * diff
1006 })
1007 .sum::<i32>() as f32
1008 }
1009
1010 pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
1011 assert_eq!(x.len(), y.len());
1012 std::iter::zip(x.iter(), y.iter())
1013 .map(|(&a, &b)| {
1014 let a: i32 = a.into();
1015 let b: i32 = b.into();
1016 let diff = a - b;
1017 diff * diff
1018 })
1019 .sum::<i32>() as f32
1020 }
1021
1022 pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
1023 assert_eq!(x.len(), y.len());
1024 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1025 let a: f32 = a.into();
1026 let b: f32 = b.into();
1027 let diff = a - b;
1028 ARCH.maybe_fma(diff, diff, acc)
1029 })
1030 }
1031
1032 pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
1033 assert_eq!(x.len(), y.len());
1034 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1035 let diff = a - b;
1036 ARCH.maybe_fma(diff, diff, acc)
1037 })
1038 }
1039
1040 pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
1045 assert_eq!(x.len(), y.len());
1046 std::iter::zip(x.iter(), y.iter())
1047 .map(|(&a, &b)| {
1048 let a: i32 = a.into();
1049 let b: i32 = b.into();
1050 a * b
1051 })
1052 .sum::<i32>() as f32
1053 }
1054
1055 pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
1056 assert_eq!(x.len(), y.len());
1057 std::iter::zip(x.iter(), y.iter())
1058 .map(|(&a, &b)| {
1059 let a: i32 = a.into();
1060 let b: i32 = b.into();
1061 a * b
1062 })
1063 .sum::<i32>() as f32
1064 }
1065
1066 pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
1067 assert_eq!(x.len(), y.len());
1068 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1069 let a: f32 = a.into();
1070 let b: f32 = b.into();
1071 ARCH.maybe_fma(a, b, acc)
1072 })
1073 }
1074
1075 pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
1076 assert_eq!(x.len(), y.len());
1077 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
1078 }
1079
    /// Accumulator for a single cosine-similarity pass: squared norms of both
    /// inputs plus their dot product.
    #[derive(Default)]
    struct XY<T> {
        xnorm: T,
        ynorm: T,
        xy: T,
    }
1090
1091 pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1092 assert_eq!(x.len(), y.len());
1093 let r: XY<i32> =
1094 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1095 let vx: i32 = vx.into();
1096 let vy: i32 = vy.into();
1097 XY {
1098 xnorm: acc.xnorm + vx * vx,
1099 ynorm: acc.ynorm + vy * vy,
1100 xy: acc.xy + vx * vy,
1101 }
1102 });
1103
1104 if r.xnorm == 0 || r.ynorm == 0 {
1105 return 0.0;
1106 }
1107
1108 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1109 }
1110
1111 pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1112 assert_eq!(x.len(), y.len());
1113 let r: XY<i32> =
1114 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1115 let vx: i32 = vx.into();
1116 let vy: i32 = vy.into();
1117 XY {
1118 xnorm: acc.xnorm + vx * vx,
1119 ynorm: acc.ynorm + vy * vy,
1120 xy: acc.xy + vx * vy,
1121 }
1122 });
1123
1124 if r.xnorm == 0 || r.ynorm == 0 {
1125 return 0.0;
1126 }
1127
1128 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1129 }
1130
1131 pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1132 assert_eq!(x.len(), y.len());
1133 let r: XY<f32> =
1134 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1135 let vx: f32 = vx.into();
1136 let vy: f32 = vy.into();
1137 XY {
1138 xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1139 ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1140 xy: ARCH.maybe_fma(vx, vy, acc.xy),
1141 }
1142 });
1143
1144 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1145 return 0.0;
1146 }
1147
1148 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1149 }
1150
1151 pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1152 assert_eq!(x.len(), y.len());
1153 let r: XY<f32> =
1154 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1155 xnorm: vx.mul_add(vx, acc.xnorm),
1156 ynorm: vy.mul_add(vy, acc.ynorm),
1157 xy: vx.mul_add(vy, acc.xy),
1158 });
1159
1160 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1161 return 0.0;
1162 }
1163
1164 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1165 }
1166}
1167
#[cfg(test)]
mod tests {
    use super::*;

    use diskann_benchmark_runner::{
        benchmark::{PassFail, Regression},
        utils::percentiles::compute_percentiles,
    };

    /// Smallest possible run description for unit tests.
    fn tiny_run(distance: SimilarityMeasure) -> Run {
        Run {
            distance,
            dim: NonZeroUsize::new(8).unwrap(),
            num_points: NonZeroUsize::new(1).unwrap(),
            loops_per_measurement: NonZeroUsize::new(1).unwrap(),
            num_measurements: NonZeroUsize::new(1).unwrap(),
        }
    }

    /// Minimal `SimdOp` input targeting the always-available scalar arch.
    fn tiny_op() -> SimdOp {
        SimdOp {
            query_type: DataType::Float32,
            data_type: DataType::Float32,
            arch: Arch::Scalar,
            runs: vec![tiny_run(SimilarityMeasure::SquaredL2)],
        }
    }

    /// Build a `RunResult` with a single latency sample of `minimum` µs.
    fn tiny_result(distance: SimilarityMeasure, minimum: u64) -> RunResult {
        let run = tiny_run(distance);
        let minimum = MicroSeconds::new(minimum);
        let mut latencies = vec![minimum];
        let percentiles = compute_percentiles(&mut latencies).unwrap();
        RunResult {
            run,
            latencies,
            percentiles,
        }
    }

    /// Tolerance wrapper; `limit` is the max allowed relative regression.
    fn tolerance(limit: f64) -> SimdTolerance {
        SimdTolerance {
            min_time_regression: NonNegativeFinite::new(limit).unwrap(),
        }
    }

    // `check` must refuse to compare results from different run descriptions.
    #[test]
    fn check_rejects_mismatched_runs() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        let err = kernel
            .check(
                &tolerance(0.0),
                &tiny_op(),
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)],
                &vec![tiny_result(SimilarityMeasure::Cosine, 100)],
            )
            .unwrap_err();

        assert_eq!(err.to_string(), "run 0 mismatched");
    }

    // A speed-up (negative relative change) always passes.
    #[test]
    fn check_allows_negative_relative_change() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        let result = kernel
            .check(
                &tolerance(0.0),
                &tiny_op(),
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)],
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 95)],
            )
            .unwrap();

        assert!(matches!(result, PassFail::Pass(_)));
    }

    // The comparison is strict (`>`), so a change exactly at the limit passes.
    #[test]
    fn check_passes_on_tolerance_boundary() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        let result = kernel
            .check(
                &tolerance(0.05),
                &tiny_op(),
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)],
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 105)],
            )
            .unwrap();

        assert!(matches!(result, PassFail::Pass(_)));
    }

    // One unit over the 5% limit must flip the result to Fail.
    #[test]
    fn check_fails_above_tolerance_boundary() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        let result = kernel
            .check(
                &tolerance(0.05),
                &tiny_op(),
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)],
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 106)],
            )
            .unwrap();

        assert!(matches!(result, PassFail::Fail(_)));
    }

    // The rendered table must show headers, values, the percentage change,
    // and a FAIL remark for an over-tolerance comparison.
    #[test]
    fn check_result_display_includes_failure_details() {
        let check = CheckResult {
            checks: vec![Comparison {
                run: tiny_run(SimilarityMeasure::SquaredL2),
                tolerance: tolerance(0.05),
                before_min: 100.0,
                after_min: 106.0,
            }],
        };

        let rendered = check.to_string();
        assert!(rendered.contains("Distance"), "rendered = {rendered}");
        assert!(rendered.contains("squared_l2"), "rendered = {rendered}");
        assert!(rendered.contains("100.000"), "rendered = {rendered}");
        assert!(rendered.contains("106.000"), "rendered = {rendered}");
        assert!(rendered.contains("6.000 %"), "rendered = {rendered}");
        assert!(rendered.contains("FAIL"), "rendered = {rendered}");
    }

    // Zero baselines make the relative change incomputable, which is treated
    // as a failure rather than a pass.
    #[test]
    fn zero_values_rejected() {
        let kernel = Kernel::<diskann_wide::arch::Scalar, f32, f32>::new();

        let result = kernel
            .check(
                &tolerance(0.05),
                &tiny_op(),
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)],
                &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)],
            )
            .unwrap();

        assert!(matches!(result, PassFail::Fail(_)));
    }
}