Skip to main content

diskann_quantization/multi_vector/distance/query_computer/f16.rs

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
3
4use diskann_wide::Architecture;
5use diskann_wide::arch::Scalar;
6#[cfg(target_arch = "aarch64")]
7use diskann_wide::arch::aarch64::Neon;
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10
11use super::{DynQueryComputer, Prepared, QueryComputer, build_prepared};
12use crate::multi_vector::distance::kernels::f16::F16Entry;
13use crate::multi_vector::{BlockTransposed, BlockTransposedRef, MatRef, Standard};
14use diskann_utils::Reborrow;
15
16impl QueryComputer<half::f16> {
17    /// Build an f16 query computer, selecting the optimal architecture and
18    /// GROUP for the current CPU at runtime.
19    pub fn new(query: MatRef<'_, Standard<half::f16>>) -> Self {
20        diskann_wide::arch::dispatch1_no_features(BuildComputer, query)
21    }
22}
23
24impl<A, const GROUP: usize> DynQueryComputer<half::f16>
25    for Prepared<A, BlockTransposed<half::f16, GROUP>>
26where
27    A: Architecture,
28    F16Entry<GROUP>: for<'a> diskann_wide::arch::Target3<
29            A,
30            (),
31            BlockTransposedRef<'a, half::f16, GROUP>,
32            MatRef<'a, Standard<half::f16>>,
33            &'a mut [f32],
34        >,
35{
36    fn compute_max_sim(&self, doc: MatRef<'_, Standard<half::f16>>, scores: &mut [f32]) {
37        let mut scratch = vec![f32::MIN; self.prepared.padded_nrows()];
38        self.arch.run3(
39            F16Entry::<GROUP>,
40            self.prepared.reborrow(),
41            doc,
42            &mut scratch,
43        );
44        for (dst, &src) in scores.iter_mut().zip(&scratch[..self.prepared.nrows()]) {
45            *dst = -src;
46        }
47    }
48
49    fn nrows(&self) -> usize {
50        self.prepared.nrows()
51    }
52}
53
54#[derive(Debug, Clone, Copy)]
55pub(super) struct BuildComputer;
56
57impl diskann_wide::arch::Target1<Scalar, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
58    for BuildComputer
59{
60    fn run(self, arch: Scalar, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
61        QueryComputer {
62            inner: Box::new(build_prepared::<half::f16, _, 8>(arch, query)),
63        }
64    }
65}
66
67#[cfg(target_arch = "x86_64")]
68impl diskann_wide::arch::Target1<V3, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
69    for BuildComputer
70{
71    fn run(self, arch: V3, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
72        QueryComputer {
73            inner: Box::new(build_prepared::<half::f16, _, 16>(arch, query)),
74        }
75    }
76}
77
78#[cfg(target_arch = "x86_64")]
79impl diskann_wide::arch::Target1<V4, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
80    for BuildComputer
81{
82    fn run(self, arch: V4, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
83        let arch = arch.retarget();
84        QueryComputer {
85            inner: Box::new(build_prepared::<half::f16, _, 16>(arch, query)),
86        }
87    }
88}
89
#[cfg(target_arch = "aarch64")]
impl diskann_wide::arch::Target1<Neon, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
    for BuildComputer
{
    /// Builds a computer for the aarch64 `Neon` target, using a block `GROUP` of 8.
    fn run(self, arch: Neon, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
        // The computer is built on `arch.retarget()`, not on `Neon` itself.
        // NOTE(review): the retarget result type is not visible here. It may
        // share the `Scalar` path (GROUP 8 matches it); confirm.
        let retargeted = arch.retarget();
        let prepared = build_prepared::<half::f16, _, 8>(retargeted, query);
        QueryComputer {
            inner: Box::new(prepared),
        }
    }
}