Skip to main content

diskann_quantization/multi_vector/distance/query_computer/
f32.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4use diskann_wide::Architecture;
5use diskann_wide::arch::Scalar;
6#[cfg(target_arch = "aarch64")]
7use diskann_wide::arch::aarch64::Neon;
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10
11use super::{DynQueryComputer, Prepared, QueryComputer, build_prepared};
12use crate::multi_vector::distance::kernels::f32::F32Kernel;
13use crate::multi_vector::{BlockTransposed, BlockTransposedRef, MatRef, Standard};
14use diskann_utils::Reborrow;
15
16impl QueryComputer<f32> {
17    /// Build an f32 query computer, selecting the optimal architecture and
18    /// GROUP for the current CPU at runtime.
19    pub fn new(query: MatRef<'_, Standard<f32>>) -> Self {
20        diskann_wide::arch::dispatch1_no_features(BuildComputer, query)
21    }
22}
23
24impl<A, const GROUP: usize> DynQueryComputer<f32> for Prepared<A, BlockTransposed<f32, GROUP>>
25where
26    A: Architecture,
27    F32Kernel<GROUP>: for<'a> diskann_wide::arch::Target3<
28            A,
29            (),
30            BlockTransposedRef<'a, f32, GROUP>,
31            MatRef<'a, Standard<f32>>,
32            &'a mut [f32],
33        >,
34{
35    fn compute_max_sim(&self, doc: MatRef<'_, Standard<f32>>, scores: &mut [f32]) {
36        let mut scratch = vec![f32::MIN; self.prepared.padded_nrows()];
37        self.arch.run3(
38            F32Kernel::<GROUP>,
39            self.prepared.reborrow(),
40            doc,
41            &mut scratch,
42        );
43        for (dst, &src) in scores.iter_mut().zip(&scratch[..self.prepared.nrows()]) {
44            *dst = -src;
45        }
46    }
47
48    fn nrows(&self) -> usize {
49        self.prepared.nrows()
50    }
51}
52
53#[derive(Debug, Clone, Copy)]
54pub(super) struct BuildComputer;
55
56impl diskann_wide::arch::Target1<Scalar, QueryComputer<f32>, MatRef<'_, Standard<f32>>>
57    for BuildComputer
58{
59    fn run(self, arch: Scalar, query: MatRef<'_, Standard<f32>>) -> QueryComputer<f32> {
60        QueryComputer {
61            inner: Box::new(build_prepared::<f32, _, 8>(arch, query)),
62        }
63    }
64}
65
66#[cfg(target_arch = "x86_64")]
67impl diskann_wide::arch::Target1<V3, QueryComputer<f32>, MatRef<'_, Standard<f32>>>
68    for BuildComputer
69{
70    fn run(self, arch: V3, query: MatRef<'_, Standard<f32>>) -> QueryComputer<f32> {
71        QueryComputer {
72            inner: Box::new(build_prepared::<f32, _, 16>(arch, query)),
73        }
74    }
75}
76
77#[cfg(target_arch = "x86_64")]
78impl diskann_wide::arch::Target1<V4, QueryComputer<f32>, MatRef<'_, Standard<f32>>>
79    for BuildComputer
80{
81    fn run(self, arch: V4, query: MatRef<'_, Standard<f32>>) -> QueryComputer<f32> {
82        // V4 delegates to V3 — the V3 micro-kernel is valid on V4 hardware.
83        let arch = arch.retarget();
84        QueryComputer {
85            inner: Box::new(build_prepared::<f32, _, 16>(arch, query)),
86        }
87    }
88}
89
#[cfg(target_arch = "aarch64")]
impl diskann_wide::arch::Target1<Neon, QueryComputer<f32>, MatRef<'_, Standard<f32>>>
    for BuildComputer
{
    /// Builds the query computer for the aarch64 `Neon` architecture with a
    /// group size of 8, after retargeting the architecture token.
    fn run(self, arch: Neon, query: MatRef<'_, Standard<f32>>) -> QueryComputer<f32> {
        // Neon has no dedicated path here; it delegates to Scalar.
        let arch = arch.retarget();
        let prepared = build_prepared::<f32, _, 8>(arch, query);
        QueryComputer {
            inner: Box::new(prepared),
        }
    }
}
101}