diskann_quantization/multi_vector/distance/query_computer/
f16.rs1use diskann_wide::Architecture;
5use diskann_wide::arch::Scalar;
6#[cfg(target_arch = "aarch64")]
7use diskann_wide::arch::aarch64::Neon;
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10
11use super::{DynQueryComputer, Prepared, QueryComputer, build_prepared};
12use crate::multi_vector::distance::kernels::f16::F16Entry;
13use crate::multi_vector::{BlockTransposed, BlockTransposedRef, MatRef, Standard};
14use diskann_utils::Reborrow;
15
16impl QueryComputer<half::f16> {
17 pub fn new(query: MatRef<'_, Standard<half::f16>>) -> Self {
20 diskann_wide::arch::dispatch1_no_features(BuildComputer, query)
21 }
22}
23
24impl<A, const GROUP: usize> DynQueryComputer<half::f16>
25 for Prepared<A, BlockTransposed<half::f16, GROUP>>
26where
27 A: Architecture,
28 F16Entry<GROUP>: for<'a> diskann_wide::arch::Target3<
29 A,
30 (),
31 BlockTransposedRef<'a, half::f16, GROUP>,
32 MatRef<'a, Standard<half::f16>>,
33 &'a mut [f32],
34 >,
35{
36 fn compute_max_sim(&self, doc: MatRef<'_, Standard<half::f16>>, scores: &mut [f32]) {
37 let mut scratch = vec![f32::MIN; self.prepared.padded_nrows()];
38 self.arch.run3(
39 F16Entry::<GROUP>,
40 self.prepared.reborrow(),
41 doc,
42 &mut scratch,
43 );
44 for (dst, &src) in scores.iter_mut().zip(&scratch[..self.prepared.nrows()]) {
45 *dst = -src;
46 }
47 }
48
49 fn nrows(&self) -> usize {
50 self.prepared.nrows()
51 }
52}
53
54#[derive(Debug, Clone, Copy)]
55pub(super) struct BuildComputer;
56
57impl diskann_wide::arch::Target1<Scalar, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
58 for BuildComputer
59{
60 fn run(self, arch: Scalar, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
61 QueryComputer {
62 inner: Box::new(build_prepared::<half::f16, _, 8>(arch, query)),
63 }
64 }
65}
66
67#[cfg(target_arch = "x86_64")]
68impl diskann_wide::arch::Target1<V3, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
69 for BuildComputer
70{
71 fn run(self, arch: V3, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
72 QueryComputer {
73 inner: Box::new(build_prepared::<half::f16, _, 16>(arch, query)),
74 }
75 }
76}
77
78#[cfg(target_arch = "x86_64")]
79impl diskann_wide::arch::Target1<V4, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
80 for BuildComputer
81{
82 fn run(self, arch: V4, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
83 let arch = arch.retarget();
84 QueryComputer {
85 inner: Box::new(build_prepared::<half::f16, _, 16>(arch, query)),
86 }
87 }
88}
89
#[cfg(target_arch = "aarch64")]
impl diskann_wide::arch::Target1<Neon, QueryComputer<half::f16>, MatRef<'_, Standard<half::f16>>>
    for BuildComputer
{
    /// Builds the computer when dispatched with the aarch64 `Neon` token.
    ///
    /// The token is passed through `retarget()` before the query is prepared
    /// with a group size of 8.
    // NOTE(review): presumably `retarget()` maps `Neon` onto a fallback
    // architecture -- confirm against `diskann_wide`.
    fn run(self, arch: Neon, query: MatRef<'_, Standard<half::f16>>) -> QueryComputer<half::f16> {
        let retargeted = arch.retarget();
        let prepared = build_prepared::<half::f16, _, 8>(retargeted, query);
        QueryComputer {
            inner: Box::new(prepared),
        }
    }
}