1use std::borrow::Cow;
2use std::fmt;
3
4pub use binary_quantized_cosine::{BinaryQuantizedCosine, NodeHeaderBinaryQuantizedCosine};
5pub use binary_quantized_euclidean::{
6 BinaryQuantizedEuclidean, NodeHeaderBinaryQuantizedEuclidean,
7};
8pub use binary_quantized_manhattan::{
9 BinaryQuantizedManhattan, NodeHeaderBinaryQuantizedManhattan,
10};
11use bytemuck::{Pod, Zeroable};
12pub use cosine::{Cosine, NodeHeaderCosine};
13pub use dot_product::{DotProduct, NodeHeaderDotProduct};
14pub use euclidean::{Euclidean, NodeHeaderEuclidean};
15use heed::{RwPrefix, RwTxn};
16pub use manhattan::{Manhattan, NodeHeaderManhattan};
17use rand::Rng;
18
19use crate::internals::{KeyCodec, Side};
20use crate::node::Leaf;
21use crate::parallel::ImmutableSubsetLeafs;
22use crate::unaligned_vector::{UnalignedVector, UnalignedVectorCodec};
23use crate::NodeCodec;
24
25mod binary_quantized_cosine;
26mod binary_quantized_euclidean;
27mod binary_quantized_manhattan;
28mod cosine;
29mod dot_product;
30mod euclidean;
31mod manhattan;
32
33fn new_leaf<D: Distance>(vec: Vec<f32>) -> Leaf<'static, D> {
34 let vector = UnalignedVector::from_vec(vec);
35 Leaf { header: D::new_header(&vector), vector }
36}
37
38#[allow(missing_docs)]
41pub trait Distance: Send + Sync + Sized + Clone + fmt::Debug + 'static {
42 const DEFAULT_OVERSAMPLING: usize = 1;
43
44 type Header: Pod + Zeroable + fmt::Debug;
46 type VectorCodec: UnalignedVectorCodec;
47
48 fn name() -> &'static str;
49
50 fn new_header(vector: &UnalignedVector<Self::VectorCodec>) -> Self::Header;
51
52 fn built_distance(p: &Leaf<Self>, q: &Leaf<Self>) -> f32;
54
55 fn non_built_distance(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
56 Self::built_distance(p, q)
57 }
58
59 fn normalized_distance(d: f32, _dimensions: usize) -> f32 {
61 d.sqrt()
62 }
63
64 fn pq_distance(distance: f32, margin: f32, side: Side) -> f32 {
65 match side {
66 Side::Left => (-margin).min(distance),
67 Side::Right => margin.min(distance),
68 }
69 }
70
71 fn norm(leaf: &Leaf<Self>) -> f32 {
72 Self::norm_no_header(&leaf.vector)
73 }
74
75 fn norm_no_header(v: &UnalignedVector<Self::VectorCodec>) -> f32;
76
77 fn normalize(node: &mut Leaf<Self>) {
78 let norm = Self::norm(node);
79 if norm > 0.0 {
80 let vec: Vec<_> = node.vector.iter().map(|x| x / norm).collect();
81 node.vector = UnalignedVector::from_vec(vec);
82 }
83 }
84
85 fn init(node: &mut Leaf<Self>);
86
87 fn update_mean(mean: &mut Leaf<Self>, new_node: &Leaf<Self>, norm: f32, c: f32) {
88 let vec: Vec<_> = mean
89 .vector
90 .iter()
91 .zip(new_node.vector.iter())
92 .map(|(x, n)| (x * c + n / norm) / (c + 1.0))
93 .collect();
94 mean.vector = UnalignedVector::from_vec(vec);
95 }
96
97 fn create_split<'a, R: Rng>(
98 children: &'a ImmutableSubsetLeafs<Self>,
99 rng: &mut R,
100 ) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>>;
101
102 fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
103 Self::margin_no_header(&p.vector, &q.vector)
104 }
105
106 fn margin_no_header(
107 p: &UnalignedVector<Self::VectorCodec>,
108 q: &UnalignedVector<Self::VectorCodec>,
109 ) -> f32;
110
111 fn side<R: Rng>(
112 normal_plane: &UnalignedVector<Self::VectorCodec>,
113 node: &Leaf<Self>,
114 rng: &mut R,
115 ) -> Side {
116 let dot = Self::margin_no_header(&node.vector, normal_plane);
117 if dot > 0.0 {
118 Side::Right
119 } else if dot < 0.0 {
120 Side::Left
121 } else {
122 Side::random(rng)
123 }
124 }
125
126 fn preprocess(
127 _wtxn: &mut RwTxn,
128 _new_iter: impl for<'a> Fn(
129 &'a mut RwTxn,
130 ) -> heed::Result<RwPrefix<'a, KeyCodec, NodeCodec<Self>>>,
131 ) -> heed::Result<()> {
132 Ok(())
133 }
134}
135
136fn two_means<D: Distance, R: Rng>(
137 rng: &mut R,
138 leafs: &ImmutableSubsetLeafs<D>,
139 cosine: bool,
140) -> heed::Result<[Leaf<'static, D>; 2]> {
141 const ITERATION_STEPS: usize = 200;
147
148 let [leaf_p, leaf_q] = leafs.choose_two(rng)?.unwrap();
149 let (mut leaf_p, mut leaf_q) = (leaf_p.into_owned(), leaf_q.into_owned());
150
151 if cosine {
152 D::normalize(&mut leaf_p);
153 D::normalize(&mut leaf_q);
154 }
155
156 D::init(&mut leaf_p);
157 D::init(&mut leaf_q);
158
159 let mut ic = 1.0;
160 let mut jc = 1.0;
161 for _ in 0..ITERATION_STEPS {
162 let node_k = leafs.choose(rng)?.unwrap();
163 let di = ic * D::non_built_distance(&leaf_p, &node_k);
164 let dj = jc * D::non_built_distance(&leaf_q, &node_k);
165 let norm = if cosine { D::norm(&node_k) } else { 1.0 };
166 if norm.is_nan() || norm <= 0.0 {
167 continue;
168 }
169 if di < dj {
170 Distance::update_mean(&mut leaf_p, &node_k, norm, ic);
171 Distance::init(&mut leaf_p);
172 ic += 1.0;
173 } else if dj < di {
174 Distance::update_mean(&mut leaf_q, &node_k, norm, jc);
175 Distance::init(&mut leaf_q);
176 jc += 1.0;
177 }
178 }
179
180 Ok([leaf_p, leaf_q])
181}
182
183pub fn two_means_binary_quantized<D: Distance, NonBqDist: Distance, R: Rng>(
184 rng: &mut R,
185 leafs: &ImmutableSubsetLeafs<D>,
186 cosine: bool,
187) -> heed::Result<[Leaf<'static, NonBqDist>; 2]> {
188 const ITERATION_STEPS: usize = 200;
197
198 let [leaf_p, leaf_q] = leafs.choose_two(rng)?.unwrap();
199 let mut leaf_p: Leaf<'static, NonBqDist> = new_leaf(leaf_p.vector.to_vec());
200 let mut leaf_q: Leaf<'static, NonBqDist> = new_leaf(leaf_q.vector.to_vec());
201
202 if cosine {
203 NonBqDist::normalize(&mut leaf_p);
204 NonBqDist::normalize(&mut leaf_q);
205 }
206
207 NonBqDist::init(&mut leaf_p);
208 NonBqDist::init(&mut leaf_q);
209
210 let mut ic = 1.0;
211 let mut jc = 1.0;
212 for _ in 0..ITERATION_STEPS {
213 let node_k = leafs.choose(rng)?.unwrap();
214 let node_k: Leaf<'static, NonBqDist> = new_leaf(node_k.vector.to_vec());
215 let di = ic * NonBqDist::non_built_distance(&leaf_p, &node_k);
216 let dj = jc * NonBqDist::non_built_distance(&leaf_q, &node_k);
217 let norm = if cosine { NonBqDist::norm(&node_k) } else { 1.0 };
218 if norm.is_nan() || norm <= 0.0 {
219 continue;
220 }
221 if di < dj {
222 Distance::update_mean(&mut leaf_p, &node_k, norm, ic);
223 Distance::init(&mut leaf_p);
224 ic += 1.0;
225 } else if dj < di {
226 Distance::update_mean(&mut leaf_q, &node_k, norm, jc);
227 Distance::init(&mut leaf_q);
228 jc += 1.0;
229 }
230 }
231
232 Ok([leaf_p, leaf_q])
233}