1use std::{io::Write, num::NonZeroUsize};
9
10use diskann_utils::views::{Matrix, MatrixView};
11use diskann_vector::distance::simd;
12use diskann_wide::Architecture;
13use half::f16;
14use rand::{
15 distr::{Distribution, StandardUniform},
16 rngs::StdRng,
17 SeedableRng,
18};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22use diskann_benchmark_runner::{
23 benchmark::{PassFail, Regression},
24 describeln,
25 dispatcher::{Description, DispatchRule, FailureScore, MatchScore},
26 utils::{
27 datatype::{self, DataType},
28 num::{relative_change, NonNegativeFinite},
29 percentiles, MicroSeconds,
30 },
31 Any, Benchmark, CheckDeserialization, Checker, Input,
32};
33
34pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
39 register_benchmarks_impl(dispatcher)
40}
41
/// Borrowing newtype that lets this file attach its own `Display` implementation to
/// the wrapped value (for example `DisplayWrapper<'_, [RunResult]>`).
#[derive(Clone, Copy, Debug)]
struct DisplayWrapper<'a, T: ?Sized>(&'a T);

/// Dereferences straight to the wrapped value so its methods stay reachable.
impl<'a, T: ?Sized> std::ops::Deref for DisplayWrapper<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        *inner
    }
}
55
56#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
61#[serde(rename_all = "snake_case")]
62pub enum SimilarityMeasure {
63 SquaredL2,
64 InnerProduct,
65 Cosine,
66}
67
68impl std::fmt::Display for SimilarityMeasure {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 let st = match self {
71 Self::SquaredL2 => "squared_l2",
72 Self::InnerProduct => "inner_product",
73 Self::Cosine => "cosine",
74 };
75 write!(f, "{}", st)
76 }
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
80#[serde(rename_all = "kebab-case")]
81enum Arch {
82 #[serde(rename = "x86-64-v4")]
83 #[allow(non_camel_case_types)]
84 X86_64_V4,
85 #[serde(rename = "x86-64-v3")]
86 #[allow(non_camel_case_types)]
87 X86_64_V3,
88 Neon,
89 Scalar,
90 Reference,
91}
92
93impl std::fmt::Display for Arch {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 let st = match self {
96 Self::X86_64_V4 => "x86-64-v4",
97 Self::X86_64_V3 => "x86-64-v3",
98 Self::Neon => "neon",
99 Self::Scalar => "scalar",
100 Self::Reference => "reference",
101 };
102 write!(f, "{}", st)
103 }
104}
105
106#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
107struct Run {
108 distance: SimilarityMeasure,
109 dim: NonZeroUsize,
110 num_points: NonZeroUsize,
111 loops_per_measurement: NonZeroUsize,
112 num_measurements: NonZeroUsize,
113}
114
115#[derive(Debug, Serialize, Deserialize)]
116pub struct SimdOp {
117 query_type: DataType,
118 data_type: DataType,
119 arch: Arch,
120 runs: Vec<Run>,
121}
122
123impl CheckDeserialization for SimdOp {
124 fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
125 Ok(())
126 }
127}
128
/// Writes one `label: value` line to `$out`, right-aligning the label in 18 columns.
///
/// Expands to a `writeln!` call, so the result must be handled by the caller.
macro_rules! write_field {
    ($out:ident, $label:tt, $($value:tt)*) => {
        writeln!($out, "{:>18}: {}", $label, $($value)*)
    }
}
134
135impl SimdOp {
136 fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 write_field!(f, "query type", self.query_type)?;
138 write_field!(f, "data type", self.data_type)?;
139 write_field!(f, "arch", self.arch)?;
140 write_field!(f, "number of runs", self.runs.len())?;
141 Ok(())
142 }
143}
144
145impl std::fmt::Display for SimdOp {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 writeln!(f, "SIMD Operation\n")?;
148 write_field!(f, "tag", Self::tag())?;
149 self.summarize_fields(f)
150 }
151}
152
153impl Input for SimdOp {
154 fn tag() -> &'static str {
155 "simd-op"
156 }
157
158 fn try_deserialize(
159 serialized: &serde_json::Value,
160 checker: &mut Checker,
161 ) -> anyhow::Result<Any> {
162 checker.any(Self::deserialize(serialized)?)
163 }
164
165 fn example() -> anyhow::Result<serde_json::Value> {
166 const DIM: [NonZeroUsize; 2] = [
167 NonZeroUsize::new(128).unwrap(),
168 NonZeroUsize::new(150).unwrap(),
169 ];
170
171 const NUM_POINTS: [NonZeroUsize; 2] = [
172 NonZeroUsize::new(1000).unwrap(),
173 NonZeroUsize::new(800).unwrap(),
174 ];
175
176 const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(100).unwrap();
177 const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap();
178
179 let runs = vec![
180 Run {
181 distance: SimilarityMeasure::SquaredL2,
182 dim: DIM[0],
183 num_points: NUM_POINTS[0],
184 loops_per_measurement: LOOPS_PER_MEASUREMENT,
185 num_measurements: NUM_MEASUREMENTS,
186 },
187 Run {
188 distance: SimilarityMeasure::InnerProduct,
189 dim: DIM[1],
190 num_points: NUM_POINTS[1],
191 loops_per_measurement: LOOPS_PER_MEASUREMENT,
192 num_measurements: NUM_MEASUREMENTS,
193 },
194 ];
195
196 Ok(serde_json::to_value(&Self {
197 query_type: DataType::Float32,
198 data_type: DataType::Float32,
199 arch: Arch::X86_64_V3,
200 runs,
201 })?)
202 }
203}
204
205#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
214struct SimdTolerance {
215 min_time_regression: NonNegativeFinite,
216}
217
218impl CheckDeserialization for SimdTolerance {
219 fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
220 Ok(())
221 }
222}
223
224impl Input for SimdTolerance {
225 fn tag() -> &'static str {
226 "simd-tolerance"
227 }
228
229 fn try_deserialize(
230 serialized: &serde_json::Value,
231 checker: &mut Checker,
232 ) -> anyhow::Result<Any> {
233 checker.any(Self::deserialize(serialized)?)
234 }
235
236 fn example() -> anyhow::Result<serde_json::Value> {
237 const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) {
238 Ok(v) => v,
239 Err(_) => panic!("use a non-negative finite please"),
240 };
241
242 Ok(serde_json::to_value(SimdTolerance {
243 min_time_regression: EXAMPLE,
244 })?)
245 }
246}
247
248#[derive(Debug, Serialize)]
250struct Comparison {
251 run: Run,
252 tolerance: SimdTolerance,
253 before_min: f64,
254 after_min: f64,
255}
256
257#[derive(Debug, Serialize)]
259struct CheckResult {
260 checks: Vec<Comparison>,
261}
262
263impl std::fmt::Display for CheckResult {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 let header = [
266 "Distance",
267 "Dim",
268 "Min Before (ns)",
269 "Min After (ns)",
270 "Change (%)",
271 "Remark",
272 ];
273
274 let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.checks.len());
275
276 for (i, c) in self.checks.iter().enumerate() {
277 let mut row = table.row(i);
278 let change = relative_change(c.before_min, c.after_min);
279
280 row.insert(c.run.distance, 0);
281 row.insert(c.run.dim, 1);
282 row.insert(format!("{:.3}", c.before_min), 2);
283 row.insert(format!("{:.3}", c.after_min), 3);
284 match change {
285 Ok(change) => {
286 row.insert(format!("{:.3} %", change * 100.0), 4);
287 if change > c.tolerance.min_time_regression.get() {
288 row.insert("FAIL", 5);
289 }
290 }
291 Err(err) => {
292 row.insert("invalid", 4);
293 row.insert(err, 5);
294 }
295 }
296 }
297
298 table.fmt(f)
299 }
300}
301
/// Registers the type `$kernel` under `$name` via `register_regression` on `$registry`.
///
/// The four-argument form takes a leading `target_arch` literal and compiles the
/// registration only when building for that architecture.
macro_rules! register {
    ($target_arch:literal, $registry:ident, $name:literal, $($kernel:tt)*) => {
        #[cfg(target_arch = $target_arch)]
        $registry.register_regression::<$($kernel)*>($name)
    };
    ($registry:ident, $name:literal, $($kernel:tt)*) => {
        $registry.register_regression::<$($kernel)*>($name)
    };
}
315
316fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) {
317 register!(
319 "x86_64",
320 dispatcher,
321 "simd-op-f32xf32-x86_64_V4",
322 Kernel<diskann_wide::arch::x86_64::V4, f32, f32>
323 );
324 register!(
325 "x86_64",
326 dispatcher,
327 "simd-op-f16xf16-x86_64_V4",
328 Kernel<diskann_wide::arch::x86_64::V4, f16, f16>
329 );
330 register!(
331 "x86_64",
332 dispatcher,
333 "simd-op-u8xu8-x86_64_V4",
334 Kernel<diskann_wide::arch::x86_64::V4, u8, u8>
335 );
336 register!(
337 "x86_64",
338 dispatcher,
339 "simd-op-i8xi8-x86_64_V4",
340 Kernel<diskann_wide::arch::x86_64::V4, i8, i8>
341 );
342
343 register!(
345 "x86_64",
346 dispatcher,
347 "simd-op-f32xf32-x86_64_V3",
348 Kernel<diskann_wide::arch::x86_64::V3, f32, f32>
349 );
350 register!(
351 "x86_64",
352 dispatcher,
353 "simd-op-f16xf16-x86_64_V3",
354 Kernel<diskann_wide::arch::x86_64::V3, f16, f16>
355 );
356 register!(
357 "x86_64",
358 dispatcher,
359 "simd-op-u8xu8-x86_64_V3",
360 Kernel<diskann_wide::arch::x86_64::V3, u8, u8>
361 );
362 register!(
363 "x86_64",
364 dispatcher,
365 "simd-op-i8xi8-x86_64_V3",
366 Kernel<diskann_wide::arch::x86_64::V3, i8, i8>
367 );
368
369 register!(
371 "aarch64",
372 dispatcher,
373 "simd-op-f32xf32-aarch64_neon",
374 Kernel<diskann_wide::arch::aarch64::Neon, f32, f32>
375 );
376 register!(
377 "aarch64",
378 dispatcher,
379 "simd-op-f16xf16-aarch64_neon",
380 Kernel<diskann_wide::arch::aarch64::Neon, f16, f16>
381 );
382 register!(
383 "aarch64",
384 dispatcher,
385 "simd-op-u8xu8-aarch64_neon",
386 Kernel<diskann_wide::arch::aarch64::Neon, u8, u8>
387 );
388 register!(
389 "aarch64",
390 dispatcher,
391 "simd-op-i8xi8-aarch64_neon",
392 Kernel<diskann_wide::arch::aarch64::Neon, i8, i8>
393 );
394
395 register!(
397 dispatcher,
398 "simd-op-f32xf32-scalar",
399 Kernel<diskann_wide::arch::Scalar, f32, f32>
400 );
401 register!(
402 dispatcher,
403 "simd-op-f16xf16-scalar",
404 Kernel<diskann_wide::arch::Scalar, f16, f16>
405 );
406 register!(
407 dispatcher,
408 "simd-op-u8xu8-scalar",
409 Kernel<diskann_wide::arch::Scalar, u8, u8>
410 );
411 register!(
412 dispatcher,
413 "simd-op-i8xi8-scalar",
414 Kernel<diskann_wide::arch::Scalar, i8, i8>
415 );
416
417 register!(
419 dispatcher,
420 "simd-op-f32xf32-reference",
421 Kernel<Reference, f32, f32>
422 );
423 register!(
424 dispatcher,
425 "simd-op-f16xf16-reference",
426 Kernel<Reference, f16, f16>
427 );
428 register!(
429 dispatcher,
430 "simd-op-u8xu8-reference",
431 Kernel<Reference, u8, u8>
432 );
433 register!(
434 dispatcher,
435 "simd-op-i8xi8-reference",
436 Kernel<Reference, i8, i8>
437 );
438}
439
/// Marker selecting the loop-based implementations in the `reference` module.
struct Reference;

/// Newtype around an architecture token; the `DispatchRule<Arch>` implementations
/// in this file are written for `Identity<...>`.
#[derive(Debug)]
struct Identity<T>(T);

/// A benchmark for architecture `A`, query element type `Q`, and data element type `D`.
struct Kernel<A, Q, D> {
    /// Architecture token handed to the distance kernels.
    arch: A,
    /// Ties the otherwise unused type parameters to the struct.
    _type: std::marker::PhantomData<(A, Q, D)>,
}

impl<A, Q, D> Kernel<A, Q, D> {
    /// Creates a kernel that runs with the given architecture token.
    fn new(arch: A) -> Self {
        let _type = std::marker::PhantomData;
        Self { arch, _type }
    }
}
464
465#[derive(Debug, Error)]
467#[error("architecture {0} is not supported by this CPU")]
468pub(crate) struct ArchNotSupported(Arch);
469
470impl DispatchRule<Arch> for Identity<Reference> {
471 type Error = ArchNotSupported;
472
473 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
474 if *from == Arch::Reference {
475 Ok(MatchScore(0))
476 } else {
477 Err(FailureScore(1))
478 }
479 }
480
481 fn convert(from: Arch) -> Result<Self, Self::Error> {
482 assert_eq!(from, Arch::Reference);
483 Ok(Identity(Reference))
484 }
485
486 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
487 match from {
488 None => write!(f, "loop based"),
489 Some(arch) => {
490 if Self::try_match(arch).is_ok() {
491 write!(f, "matched {}", arch)
492 } else {
493 write!(f, "expected {}, got {}", Arch::Reference, arch)
494 }
495 }
496 }
497 }
498}
499
500impl DispatchRule<Arch> for Identity<diskann_wide::arch::Scalar> {
501 type Error = ArchNotSupported;
502
503 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
504 if *from == Arch::Scalar {
505 Ok(MatchScore(0))
506 } else {
507 Err(FailureScore(1))
508 }
509 }
510
511 fn convert(from: Arch) -> Result<Self, Self::Error> {
512 assert_eq!(from, Arch::Scalar);
513 Ok(Identity(diskann_wide::arch::Scalar))
514 }
515
516 fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Arch>) -> std::fmt::Result {
517 match from {
518 None => write!(f, "scalar (compilation target CPU)"),
519 Some(arch) => {
520 if Self::try_match(arch).is_ok() {
521 write!(f, "matched {}", arch)
522 } else {
523 write!(f, "expected {}, got {}", Arch::Scalar, arch)
524 }
525 }
526 }
527 }
528}
529
530macro_rules! match_arch {
531 ($target_arch:literal, $arch:path, $enum:ident) => {
532 #[cfg(target_arch = $target_arch)]
533 impl DispatchRule<Arch> for Identity<$arch> {
534 type Error = ArchNotSupported;
535
536 fn try_match(from: &Arch) -> Result<MatchScore, FailureScore> {
537 let available = <$arch>::new_checked().is_some();
538 if available && *from == Arch::$enum {
539 Ok(MatchScore(0))
540 } else if !available && *from == Arch::$enum {
541 Err(FailureScore(0))
542 } else {
543 Err(FailureScore(1))
544 }
545 }
546
547 fn convert(from: Arch) -> Result<Self, Self::Error> {
548 assert_eq!(from, Arch::$enum);
549 <$arch>::new_checked()
550 .ok_or(ArchNotSupported(from))
551 .map(Identity)
552 }
553
554 fn description(
555 f: &mut std::fmt::Formatter<'_>,
556 from: Option<&Arch>,
557 ) -> std::fmt::Result {
558 let available = <$arch>::new_checked().is_some();
559 match from {
560 None => write!(f, "{}", Arch::$enum),
561 Some(arch) => {
562 if Self::try_match(arch).is_ok() {
563 write!(f, "matched {}", arch)
564 } else if !available && *arch == Arch::$enum {
565 write!(f, "matched {} but unsupported by this CPU", Arch::$enum)
566 } else {
567 write!(f, "expected {}, got {}", Arch::$enum, arch)
568 }
569 }
570 }
571 }
572 }
573 };
574}
575
576match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4);
577match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3);
578match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon);
579
580impl<A, Q, D> Benchmark for Kernel<A, Q, D>
581where
582 datatype::Type<Q>: DispatchRule<datatype::DataType>,
583 datatype::Type<D>: DispatchRule<datatype::DataType>,
584 Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
585 Kernel<A, Q, D>: RunBenchmark,
586{
587 type Input = SimdOp;
588 type Output = Vec<RunResult>;
589
590 fn try_match(from: &SimdOp) -> Result<MatchScore, FailureScore> {
592 let mut failscore: Option<u32> = None;
593 if datatype::Type::<Q>::try_match(&from.query_type).is_err() {
594 *failscore.get_or_insert(0) += 10;
595 }
596 if datatype::Type::<D>::try_match(&from.data_type).is_err() {
597 *failscore.get_or_insert(0) += 10;
598 }
599 if let Err(FailureScore(score)) = Identity::<A>::try_match(&from.arch) {
600 *failscore.get_or_insert(0) += 2 + score;
601 }
602
603 match failscore {
604 None => Ok(MatchScore(0)),
605 Some(score) => Err(FailureScore(score)),
606 }
607 }
608
609 fn run(
610 input: &SimdOp,
611 _: diskann_benchmark_runner::Checkpoint<'_>,
612 mut output: &mut dyn diskann_benchmark_runner::Output,
613 ) -> anyhow::Result<Self::Output> {
614 let arch = Identity::<A>::convert(input.arch)?.0;
615 let kernel = Self::new(arch);
616 writeln!(output, "{}", input)?;
617 let results = kernel.run(input)?;
618 writeln!(output, "\n\n{}", DisplayWrapper(&*results))?;
619 Ok(results)
620 }
621
622 fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&SimdOp>) -> std::fmt::Result {
623 match input {
624 None => {
625 describeln!(
626 f,
627 "- Query Type: {}",
628 Description::<datatype::DataType, datatype::Type<Q>>::new()
629 )?;
630 describeln!(
631 f,
632 "- Data Type: {}",
633 Description::<datatype::DataType, datatype::Type<D>>::new()
634 )?;
635 describeln!(
636 f,
637 "- Implementation: {}",
638 Description::<Arch, Identity<A>>::new()
639 )?;
640 }
641 Some(input) => {
642 if let Err(err) = datatype::Type::<Q>::try_match_verbose(&input.query_type) {
643 describeln!(f, "\n - Mismatched query type: {}", err)?;
644 }
645 if let Err(err) = datatype::Type::<D>::try_match_verbose(&input.data_type) {
646 describeln!(f, "\n - Mismatched data type: {}", err)?;
647 }
648 if let Err(err) = Identity::<A>::try_match_verbose(&input.arch) {
649 describeln!(f, "\n - Mismatched architecture: {}", err)?;
650 }
651 }
652 }
653 Ok(())
654 }
655}
656
657impl<A, Q, D> Regression for Kernel<A, Q, D>
658where
659 datatype::Type<Q>: DispatchRule<datatype::DataType>,
660 datatype::Type<D>: DispatchRule<datatype::DataType>,
661 Identity<A>: DispatchRule<Arch, Error = ArchNotSupported>,
662 Kernel<A, Q, D>: RunBenchmark,
663{
664 type Tolerances = SimdTolerance;
665 type Pass = CheckResult;
666 type Fail = CheckResult;
667
668 fn check(
669 tolerance: &SimdTolerance,
670 _input: &SimdOp,
671 before: &Vec<RunResult>,
672 after: &Vec<RunResult>,
673 ) -> anyhow::Result<PassFail<CheckResult, CheckResult>> {
674 anyhow::ensure!(
675 before.len() == after.len(),
676 "before has {} runs but after has {}",
677 before.len(),
678 after.len(),
679 );
680
681 let mut passed = true;
682 let checks: Vec<Comparison> = std::iter::zip(before.iter(), after.iter())
683 .enumerate()
684 .map(|(i, (b, a))| {
685 anyhow::ensure!(b.run == a.run, "run {i} mismatched");
686
687 let computations_per_latency = b.computations_per_latency() as f64;
688
689 let before_min = b.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
690 let after_min = a.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency;
691
692 let comparison = Comparison {
693 run: b.run.clone(),
694 tolerance: *tolerance,
695 before_min,
696 after_min,
697 };
698
699 match relative_change(before_min, after_min) {
701 Ok(change) => {
702 if change > tolerance.min_time_regression.get() {
703 passed = false;
704 }
705 }
706 Err(_) => passed = false,
707 };
708
709 Ok(comparison)
710 })
711 .collect::<anyhow::Result<Vec<Comparison>>>()?;
712
713 let check = CheckResult { checks };
714
715 if passed {
716 Ok(PassFail::Pass(check))
717 } else {
718 Ok(PassFail::Fail(check))
719 }
720 }
721}
722
723trait RunBenchmark {
728 fn run(self, input: &SimdOp) -> Result<Vec<RunResult>, anyhow::Error>;
729}
730
731#[derive(Debug, Serialize, Deserialize)]
732struct RunResult {
733 run: Run,
735 latencies: Vec<MicroSeconds>,
737 percentiles: percentiles::Percentiles<MicroSeconds>,
739}
740
741impl RunResult {
742 fn computations_per_latency(&self) -> usize {
743 self.run.num_points.get() * self.run.loops_per_measurement.get()
744 }
745}
746
747impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> {
748 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
749 if self.is_empty() {
750 return Ok(());
751 }
752
753 let header = [
754 "Distance",
755 "Dim",
756 "Min Time (ns)",
757 "Mean Time (ns)",
758 "Points",
759 "Loops",
760 "Measurements",
761 ];
762
763 let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len());
764
765 self.iter().enumerate().for_each(|(row, r)| {
766 let mut row = table.row(row);
767
768 let min_latency = r
769 .latencies
770 .iter()
771 .min()
772 .copied()
773 .unwrap_or(MicroSeconds::new(u64::MAX));
774 let mean_latency = r.percentiles.mean;
775
776 let computations_per_latency = r.computations_per_latency() as f64;
777
778 let min_time = min_latency.as_f64() / computations_per_latency * 1000.0;
780 let mean_time = mean_latency / computations_per_latency * 1000.0;
781
782 row.insert(r.run.distance, 0);
783 row.insert(r.run.dim, 1);
784 row.insert(format!("{:.3}", min_time), 2);
785 row.insert(format!("{:.3}", mean_time), 3);
786 row.insert(r.run.num_points, 4);
787 row.insert(r.run.loops_per_measurement, 5);
788 row.insert(r.run.num_measurements, 6);
789 });
790
791 table.fmt(f)
792 }
793}
794
795fn run_loops<Q, D, F>(query: &[Q], data: MatrixView<D>, run: &Run, f: F) -> RunResult
796where
797 F: Fn(&[Q], &[D]) -> f32,
798{
799 let mut latencies = Vec::with_capacity(run.num_measurements.get());
800 let mut dst = vec![0.0; data.nrows()];
801
802 for _ in 0..run.num_measurements.get() {
803 let start = std::time::Instant::now();
804 for _ in 0..run.loops_per_measurement.get() {
805 std::iter::zip(dst.iter_mut(), data.row_iter()).for_each(|(d, r)| {
806 *d = f(query, r);
807 });
808 std::hint::black_box(&mut dst);
809 }
810 latencies.push(start.elapsed().into());
811 }
812
813 let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap();
814 RunResult {
815 run: run.clone(),
816 latencies,
817 percentiles,
818 }
819}
820
821struct Data<Q, D> {
822 query: Box<[Q]>,
823 data: Matrix<D>,
824}
825
826impl<Q, D> Data<Q, D> {
827 fn new(run: &Run) -> Self
828 where
829 StandardUniform: Distribution<Q>,
830 StandardUniform: Distribution<D>,
831 {
832 let mut rng = StdRng::seed_from_u64(0x12345);
833 let query: Box<[Q]> = (0..run.dim.get())
834 .map(|_| StandardUniform.sample(&mut rng))
835 .collect();
836 let data = Matrix::<D>::new(
837 diskann_utils::views::Init(|| StandardUniform.sample(&mut rng)),
838 run.num_points.get(),
839 run.dim.get(),
840 );
841
842 Self { query, data }
843 }
844
845 fn run<F>(&self, run: &Run, f: F) -> RunResult
846 where
847 F: Fn(&[Q], &[D]) -> f32,
848 {
849 run_loops(&self.query, self.data.as_view(), run, f)
850 }
851}
852
853macro_rules! stamp {
858 (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => {
859 impl RunBenchmark for Kernel<Reference, $Q, $D> {
860 fn run(self, input: &SimdOp) -> Result<Vec<RunResult>, anyhow::Error> {
861 let mut results = Vec::new();
862 for run in input.runs.iter() {
863 let data = Data::<$Q, $D>::new(run);
864 let result = match run.distance {
865 SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2),
866 SimilarityMeasure::InnerProduct => data.run(run, reference::$f_ip),
867 SimilarityMeasure::Cosine => data.run(run, reference::$f_cosine),
868 };
869 results.push(result);
870 }
871 Ok(results)
872 }
873 }
874 };
875 ($arch:path, $Q:ty, $D:ty) => {
876 impl RunBenchmark for Kernel<$arch, $Q, $D> {
877 fn run(self, input: &SimdOp) -> Result<Vec<RunResult>, anyhow::Error> {
878 let mut results = Vec::new();
879
880 let l2 = &simd::L2 {};
881 let ip = &simd::IP {};
882 let cosine = &simd::CosineStateless {};
883
884 for run in input.runs.iter() {
885 let data = Data::<$Q, $D>::new(run);
886 let result = match run.distance {
893 SimilarityMeasure::SquaredL2 => data.run(run, |q, d| {
894 self.arch
895 .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d)
896 }),
897 SimilarityMeasure::InnerProduct => data.run(run, |q, d| {
898 self.arch
899 .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d)
900 }),
901 SimilarityMeasure::Cosine => data.run(run, |q, d| {
902 self.arch
903 .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d)
904 }),
905 };
906 results.push(result)
907 }
908 Ok(results)
909 }
910 }
911 };
912 ($target_arch:literal, $arch:path, $Q:ty, $D:ty) => {
913 #[cfg(target_arch = $target_arch)]
914 stamp!($arch, $Q, $D);
915 };
916}
917
918stamp!("x86_64", diskann_wide::arch::x86_64::V4, f32, f32);
919stamp!("x86_64", diskann_wide::arch::x86_64::V4, f16, f16);
920stamp!("x86_64", diskann_wide::arch::x86_64::V4, u8, u8);
921stamp!("x86_64", diskann_wide::arch::x86_64::V4, i8, i8);
922
923stamp!("x86_64", diskann_wide::arch::x86_64::V3, f32, f32);
924stamp!("x86_64", diskann_wide::arch::x86_64::V3, f16, f16);
925stamp!("x86_64", diskann_wide::arch::x86_64::V3, u8, u8);
926stamp!("x86_64", diskann_wide::arch::x86_64::V3, i8, i8);
927
928stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f32, f32);
929stamp!("aarch64", diskann_wide::arch::aarch64::Neon, f16, f16);
930stamp!("aarch64", diskann_wide::arch::aarch64::Neon, u8, u8);
931stamp!("aarch64", diskann_wide::arch::aarch64::Neon, i8, i8);
932
933stamp!(diskann_wide::arch::Scalar, f32, f32);
934stamp!(diskann_wide::arch::Scalar, f16, f16);
935stamp!(diskann_wide::arch::Scalar, u8, u8);
936stamp!(diskann_wide::arch::Scalar, i8, i8);
937
938stamp!(
939 reference,
940 f32,
941 f32,
942 squared_l2_f32,
943 inner_product_f32,
944 cosine_f32
945);
946stamp!(
947 reference,
948 f16,
949 f16,
950 squared_l2_f16,
951 inner_product_f16,
952 cosine_f16
953);
954stamp!(
955 reference,
956 u8,
957 u8,
958 squared_l2_u8,
959 inner_product_u8,
960 cosine_u8
961);
962stamp!(
963 reference,
964 i8,
965 i8,
966 squared_l2_i8,
967 inner_product_i8,
968 cosine_i8
969);
970
971mod reference {
978 use diskann_wide::ARCH;
979 use half::f16;
980
981 trait MaybeFMA {
982 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32;
985 }
986
987 impl MaybeFMA for diskann_wide::arch::Scalar {
988 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
989 a * b + c
990 }
991 }
992
993 #[cfg(target_arch = "x86_64")]
994 impl MaybeFMA for diskann_wide::arch::x86_64::V3 {
995 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
996 a.mul_add(b, c)
997 }
998 }
999
1000 #[cfg(target_arch = "x86_64")]
1001 impl MaybeFMA for diskann_wide::arch::x86_64::V4 {
1002 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
1003 a.mul_add(b, c)
1004 }
1005 }
1006
1007 #[cfg(target_arch = "aarch64")]
1008 impl MaybeFMA for diskann_wide::arch::aarch64::Neon {
1009 fn maybe_fma(self, a: f32, b: f32, c: f32) -> f32 {
1010 a.mul_add(b, c)
1011 }
1012 }
1013
1014 pub(super) fn squared_l2_i8(x: &[i8], y: &[i8]) -> f32 {
1019 assert_eq!(x.len(), y.len());
1020 std::iter::zip(x.iter(), y.iter())
1021 .map(|(&a, &b)| {
1022 let a: i32 = a.into();
1023 let b: i32 = b.into();
1024 let diff = a - b;
1025 diff * diff
1026 })
1027 .sum::<i32>() as f32
1028 }
1029
1030 pub(super) fn squared_l2_u8(x: &[u8], y: &[u8]) -> f32 {
1031 assert_eq!(x.len(), y.len());
1032 std::iter::zip(x.iter(), y.iter())
1033 .map(|(&a, &b)| {
1034 let a: i32 = a.into();
1035 let b: i32 = b.into();
1036 let diff = a - b;
1037 diff * diff
1038 })
1039 .sum::<i32>() as f32
1040 }
1041
1042 pub(super) fn squared_l2_f16(x: &[f16], y: &[f16]) -> f32 {
1043 assert_eq!(x.len(), y.len());
1044 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1045 let a: f32 = a.into();
1046 let b: f32 = b.into();
1047 let diff = a - b;
1048 ARCH.maybe_fma(diff, diff, acc)
1049 })
1050 }
1051
1052 pub(super) fn squared_l2_f32(x: &[f32], y: &[f32]) -> f32 {
1053 assert_eq!(x.len(), y.len());
1054 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1055 let diff = a - b;
1056 ARCH.maybe_fma(diff, diff, acc)
1057 })
1058 }
1059
1060 pub(super) fn inner_product_i8(x: &[i8], y: &[i8]) -> f32 {
1065 assert_eq!(x.len(), y.len());
1066 std::iter::zip(x.iter(), y.iter())
1067 .map(|(&a, &b)| {
1068 let a: i32 = a.into();
1069 let b: i32 = b.into();
1070 a * b
1071 })
1072 .sum::<i32>() as f32
1073 }
1074
1075 pub(super) fn inner_product_u8(x: &[u8], y: &[u8]) -> f32 {
1076 assert_eq!(x.len(), y.len());
1077 std::iter::zip(x.iter(), y.iter())
1078 .map(|(&a, &b)| {
1079 let a: i32 = a.into();
1080 let b: i32 = b.into();
1081 a * b
1082 })
1083 .sum::<i32>() as f32
1084 }
1085
1086 pub(super) fn inner_product_f16(x: &[f16], y: &[f16]) -> f32 {
1087 assert_eq!(x.len(), y.len());
1088 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| {
1089 let a: f32 = a.into();
1090 let b: f32 = b.into();
1091 ARCH.maybe_fma(a, b, acc)
1092 })
1093 }
1094
1095 pub(super) fn inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
1096 assert_eq!(x.len(), y.len());
1097 std::iter::zip(x.iter(), y.iter()).fold(0.0f32, |acc, (&a, &b)| ARCH.maybe_fma(a, b, acc))
1098 }
1099
1100 #[derive(Default)]
1105 struct XY<T> {
1106 xnorm: T,
1107 ynorm: T,
1108 xy: T,
1109 }
1110
1111 pub(super) fn cosine_i8(x: &[i8], y: &[i8]) -> f32 {
1112 assert_eq!(x.len(), y.len());
1113 let r: XY<i32> =
1114 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1115 let vx: i32 = vx.into();
1116 let vy: i32 = vy.into();
1117 XY {
1118 xnorm: acc.xnorm + vx * vx,
1119 ynorm: acc.ynorm + vy * vy,
1120 xy: acc.xy + vx * vy,
1121 }
1122 });
1123
1124 if r.xnorm == 0 || r.ynorm == 0 {
1125 return 0.0;
1126 }
1127
1128 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1129 }
1130
1131 pub(super) fn cosine_u8(x: &[u8], y: &[u8]) -> f32 {
1132 assert_eq!(x.len(), y.len());
1133 let r: XY<i32> =
1134 std::iter::zip(x.iter(), y.iter()).fold(XY::<i32>::default(), |acc, (&vx, &vy)| {
1135 let vx: i32 = vx.into();
1136 let vy: i32 = vy.into();
1137 XY {
1138 xnorm: acc.xnorm + vx * vx,
1139 ynorm: acc.ynorm + vy * vy,
1140 xy: acc.xy + vx * vy,
1141 }
1142 });
1143
1144 if r.xnorm == 0 || r.ynorm == 0 {
1145 return 0.0;
1146 }
1147
1148 (r.xy as f32 / ((r.xnorm as f32).sqrt() * (r.ynorm as f32).sqrt())).clamp(-1.0, 1.0)
1149 }
1150
1151 pub(super) fn cosine_f16(x: &[f16], y: &[f16]) -> f32 {
1152 assert_eq!(x.len(), y.len());
1153 let r: XY<f32> =
1154 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| {
1155 let vx: f32 = vx.into();
1156 let vy: f32 = vy.into();
1157 XY {
1158 xnorm: ARCH.maybe_fma(vx, vx, acc.xnorm),
1159 ynorm: ARCH.maybe_fma(vy, vy, acc.ynorm),
1160 xy: ARCH.maybe_fma(vx, vy, acc.xy),
1161 }
1162 });
1163
1164 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1165 return 0.0;
1166 }
1167
1168 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1169 }
1170
1171 pub(super) fn cosine_f32(x: &[f32], y: &[f32]) -> f32 {
1172 assert_eq!(x.len(), y.len());
1173 let r: XY<f32> =
1174 std::iter::zip(x.iter(), y.iter()).fold(XY::<f32>::default(), |acc, (&vx, &vy)| XY {
1175 xnorm: vx.mul_add(vx, acc.xnorm),
1176 ynorm: vy.mul_add(vy, acc.ynorm),
1177 xy: vx.mul_add(vy, acc.xy),
1178 });
1179
1180 if r.xnorm < f32::EPSILON || r.ynorm < f32::EPSILON {
1181 return 0.0;
1182 }
1183
1184 (r.xy / (r.xnorm.sqrt() * r.ynorm.sqrt())).clamp(-1.0, 1.0)
1185 }
1186}
1187
/// Unit tests for the regression check and its table rendering.
#[cfg(test)]
mod tests {
    use super::*;

    use diskann_benchmark_runner::{
        benchmark::{PassFail, Regression},
        utils::percentiles::compute_percentiles,
    };

    /// Kernel whose `Regression::check` implementation is exercised by the tests.
    type Bench = Kernel<diskann_wide::arch::Scalar, f32, f32>;

    /// Smallest possible run: one 8-dimensional point, one loop, one measurement.
    fn tiny_run(distance: SimilarityMeasure) -> Run {
        let one = NonZeroUsize::new(1).unwrap();
        Run {
            distance,
            dim: NonZeroUsize::new(8).unwrap(),
            num_points: one,
            loops_per_measurement: one,
            num_measurements: one,
        }
    }

    /// Scalar `f32` input holding a single squared-L2 run.
    fn tiny_op() -> SimdOp {
        SimdOp {
            query_type: DataType::Float32,
            data_type: DataType::Float32,
            arch: Arch::Scalar,
            runs: vec![tiny_run(SimilarityMeasure::SquaredL2)],
        }
    }

    /// Result with a single latency of `minimum` micro-seconds.
    fn tiny_result(distance: SimilarityMeasure, minimum: u64) -> RunResult {
        let mut latencies = vec![MicroSeconds::new(minimum)];
        let percentiles = compute_percentiles(&mut latencies).unwrap();
        RunResult {
            run: tiny_run(distance),
            latencies,
            percentiles,
        }
    }

    /// Tolerance with the given regression limit.
    fn tolerance(limit: f64) -> SimdTolerance {
        SimdTolerance {
            min_time_regression: NonNegativeFinite::new(limit).unwrap(),
        }
    }

    /// Runs the check on two squared-L2 results with the given minimum latencies.
    fn check_squared_l2(limit: f64, before: u64, after: u64) -> PassFail<CheckResult, CheckResult> {
        Bench::check(
            &tolerance(limit),
            &tiny_op(),
            &vec![tiny_result(SimilarityMeasure::SquaredL2, before)],
            &vec![tiny_result(SimilarityMeasure::SquaredL2, after)],
        )
        .unwrap()
    }

    #[test]
    fn check_rejects_mismatched_runs() {
        let err = Bench::check(
            &tolerance(0.0),
            &tiny_op(),
            &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)],
            &vec![tiny_result(SimilarityMeasure::Cosine, 100)],
        )
        .unwrap_err();

        assert_eq!(err.to_string(), "run 0 mismatched");
    }

    #[test]
    fn check_allows_negative_relative_change() {
        let result = check_squared_l2(0.0, 100, 95);
        assert!(matches!(result, PassFail::Pass(_)));
    }

    #[test]
    fn check_passes_on_tolerance_boundary() {
        let result = check_squared_l2(0.05, 100, 105);
        assert!(matches!(result, PassFail::Pass(_)));
    }

    #[test]
    fn check_fails_above_tolerance_boundary() {
        let result = check_squared_l2(0.05, 100, 106);
        assert!(matches!(result, PassFail::Fail(_)));
    }

    #[test]
    fn check_result_display_includes_failure_details() {
        let check = CheckResult {
            checks: vec![Comparison {
                run: tiny_run(SimilarityMeasure::SquaredL2),
                tolerance: tolerance(0.05),
                before_min: 100.0,
                after_min: 106.0,
            }],
        };

        let rendered = check.to_string();
        let expected_fragments = [
            "Distance",
            "squared_l2",
            "100.000",
            "106.000",
            "6.000 %",
            "FAIL",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment), "rendered = {rendered}");
        }
    }

    /// Two zero minimums must be reported as a failure.
    #[test]
    fn zero_values_rejected() {
        let result = check_squared_l2(0.05, 0, 0);
        assert!(matches!(result, PassFail::Fail(_)));
    }
}