arroy/distance/mod.rs

1use std::borrow::Cow;
2use std::fmt;
3
4pub use binary_quantized_cosine::{BinaryQuantizedCosine, NodeHeaderBinaryQuantizedCosine};
5pub use binary_quantized_euclidean::{
6    BinaryQuantizedEuclidean, NodeHeaderBinaryQuantizedEuclidean,
7};
8pub use binary_quantized_manhattan::{
9    BinaryQuantizedManhattan, NodeHeaderBinaryQuantizedManhattan,
10};
11use bytemuck::{Pod, Zeroable};
12pub use cosine::{Cosine, NodeHeaderCosine};
13pub use dot_product::{DotProduct, NodeHeaderDotProduct};
14pub use euclidean::{Euclidean, NodeHeaderEuclidean};
15use heed::{RwPrefix, RwTxn};
16pub use manhattan::{Manhattan, NodeHeaderManhattan};
17use rand::Rng;
18
19use crate::internals::{KeyCodec, Side};
20use crate::node::Leaf;
21use crate::parallel::ImmutableSubsetLeafs;
22use crate::unaligned_vector::{UnalignedVector, UnalignedVectorCodec};
23use crate::NodeCodec;
24
25mod binary_quantized_cosine;
26mod binary_quantized_euclidean;
27mod binary_quantized_manhattan;
28mod cosine;
29mod dot_product;
30mod euclidean;
31mod manhattan;
32
33fn new_leaf<D: Distance>(vec: Vec<f32>) -> Leaf<'static, D> {
34    let vector = UnalignedVector::from_vec(vec);
35    Leaf { header: D::new_header(&vector), vector }
36}
37
38/// A trait used by arroy to compute the distances,
39/// compute the split planes, and normalize user vectors.
40#[allow(missing_docs)]
41pub trait Distance: Send + Sync + Sized + Clone + fmt::Debug + 'static {
42    const DEFAULT_OVERSAMPLING: usize = 1;
43
44    /// A header structure with informations related to the
45    type Header: Pod + Zeroable + fmt::Debug;
46    type VectorCodec: UnalignedVectorCodec;
47
48    fn name() -> &'static str;
49
50    fn new_header(vector: &UnalignedVector<Self::VectorCodec>) -> Self::Header;
51
52    /// Returns a non-normalized distance.
53    fn built_distance(p: &Leaf<Self>, q: &Leaf<Self>) -> f32;
54
55    fn non_built_distance(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
56        Self::built_distance(p, q)
57    }
58
59    /// Normalizes the distance returned by the distance method.
60    fn normalized_distance(d: f32, _dimensions: usize) -> f32 {
61        d.sqrt()
62    }
63
64    fn pq_distance(distance: f32, margin: f32, side: Side) -> f32 {
65        match side {
66            Side::Left => (-margin).min(distance),
67            Side::Right => margin.min(distance),
68        }
69    }
70
71    fn norm(leaf: &Leaf<Self>) -> f32 {
72        Self::norm_no_header(&leaf.vector)
73    }
74
75    fn norm_no_header(v: &UnalignedVector<Self::VectorCodec>) -> f32;
76
77    fn normalize(node: &mut Leaf<Self>) {
78        let norm = Self::norm(node);
79        if norm > 0.0 {
80            let vec: Vec<_> = node.vector.iter().map(|x| x / norm).collect();
81            node.vector = UnalignedVector::from_vec(vec);
82        }
83    }
84
85    fn init(node: &mut Leaf<Self>);
86
87    fn update_mean(mean: &mut Leaf<Self>, new_node: &Leaf<Self>, norm: f32, c: f32) {
88        let vec: Vec<_> = mean
89            .vector
90            .iter()
91            .zip(new_node.vector.iter())
92            .map(|(x, n)| (x * c + n / norm) / (c + 1.0))
93            .collect();
94        mean.vector = UnalignedVector::from_vec(vec);
95    }
96
97    fn create_split<'a, R: Rng>(
98        children: &'a ImmutableSubsetLeafs<Self>,
99        rng: &mut R,
100    ) -> heed::Result<Cow<'a, UnalignedVector<Self::VectorCodec>>>;
101
102    fn margin(p: &Leaf<Self>, q: &Leaf<Self>) -> f32 {
103        Self::margin_no_header(&p.vector, &q.vector)
104    }
105
106    fn margin_no_header(
107        p: &UnalignedVector<Self::VectorCodec>,
108        q: &UnalignedVector<Self::VectorCodec>,
109    ) -> f32;
110
111    fn side<R: Rng>(
112        normal_plane: &UnalignedVector<Self::VectorCodec>,
113        node: &Leaf<Self>,
114        rng: &mut R,
115    ) -> Side {
116        let dot = Self::margin_no_header(&node.vector, normal_plane);
117        if dot > 0.0 {
118            Side::Right
119        } else if dot < 0.0 {
120            Side::Left
121        } else {
122            Side::random(rng)
123        }
124    }
125
126    fn preprocess(
127        _wtxn: &mut RwTxn,
128        _new_iter: impl for<'a> Fn(
129            &'a mut RwTxn,
130        ) -> heed::Result<RwPrefix<'a, KeyCodec, NodeCodec<Self>>>,
131    ) -> heed::Result<()> {
132        Ok(())
133    }
134}
135
136fn two_means<D: Distance, R: Rng>(
137    rng: &mut R,
138    leafs: &ImmutableSubsetLeafs<D>,
139    cosine: bool,
140) -> heed::Result<[Leaf<'static, D>; 2]> {
141    // This algorithm is a huge heuristic. Empirically it works really well, but I
142    // can't motivate it well. The basic idea is to keep two centroids and assign
143    // points to either one of them. We weight each centroid by the number of points
144    // assigned to it, so to balance it.
145
146    const ITERATION_STEPS: usize = 200;
147
148    let [leaf_p, leaf_q] = leafs.choose_two(rng)?.unwrap();
149    let (mut leaf_p, mut leaf_q) = (leaf_p.into_owned(), leaf_q.into_owned());
150
151    if cosine {
152        D::normalize(&mut leaf_p);
153        D::normalize(&mut leaf_q);
154    }
155
156    D::init(&mut leaf_p);
157    D::init(&mut leaf_q);
158
159    let mut ic = 1.0;
160    let mut jc = 1.0;
161    for _ in 0..ITERATION_STEPS {
162        let node_k = leafs.choose(rng)?.unwrap();
163        let di = ic * D::non_built_distance(&leaf_p, &node_k);
164        let dj = jc * D::non_built_distance(&leaf_q, &node_k);
165        let norm = if cosine { D::norm(&node_k) } else { 1.0 };
166        if norm.is_nan() || norm <= 0.0 {
167            continue;
168        }
169        if di < dj {
170            Distance::update_mean(&mut leaf_p, &node_k, norm, ic);
171            Distance::init(&mut leaf_p);
172            ic += 1.0;
173        } else if dj < di {
174            Distance::update_mean(&mut leaf_q, &node_k, norm, jc);
175            Distance::init(&mut leaf_q);
176            jc += 1.0;
177        }
178    }
179
180    Ok([leaf_p, leaf_q])
181}
182
183pub fn two_means_binary_quantized<D: Distance, NonBqDist: Distance, R: Rng>(
184    rng: &mut R,
185    leafs: &ImmutableSubsetLeafs<D>,
186    cosine: bool,
187) -> heed::Result<[Leaf<'static, NonBqDist>; 2]> {
188    // This algorithm is a huge heuristic. Empirically it works really well, but I
189    // can't motivate it well. The basic idea is to keep two centroids and assign
190    // points to either one of them. We weight each centroid by the number of points
191    // assigned to it, so to balance it.
192    // Even though the points we're working on are binary quantized, for the centroid
193    // to move, we need to store it as f32. This requires us to convert every binary quantized
194    // vectors to f32 vectors, but the recall suffers too much if we don't do it.
195
196    const ITERATION_STEPS: usize = 200;
197
198    let [leaf_p, leaf_q] = leafs.choose_two(rng)?.unwrap();
199    let mut leaf_p: Leaf<'static, NonBqDist> = new_leaf(leaf_p.vector.to_vec());
200    let mut leaf_q: Leaf<'static, NonBqDist> = new_leaf(leaf_q.vector.to_vec());
201
202    if cosine {
203        NonBqDist::normalize(&mut leaf_p);
204        NonBqDist::normalize(&mut leaf_q);
205    }
206
207    NonBqDist::init(&mut leaf_p);
208    NonBqDist::init(&mut leaf_q);
209
210    let mut ic = 1.0;
211    let mut jc = 1.0;
212    for _ in 0..ITERATION_STEPS {
213        let node_k = leafs.choose(rng)?.unwrap();
214        let node_k: Leaf<'static, NonBqDist> = new_leaf(node_k.vector.to_vec());
215        let di = ic * NonBqDist::non_built_distance(&leaf_p, &node_k);
216        let dj = jc * NonBqDist::non_built_distance(&leaf_q, &node_k);
217        let norm = if cosine { NonBqDist::norm(&node_k) } else { 1.0 };
218        if norm.is_nan() || norm <= 0.0 {
219            continue;
220        }
221        if di < dj {
222            Distance::update_mean(&mut leaf_p, &node_k, norm, ic);
223            Distance::init(&mut leaf_p);
224            ic += 1.0;
225        } else if dj < di {
226            Distance::update_mean(&mut leaf_q, &node_k, norm, jc);
227            Distance::init(&mut leaf_q);
228            jc += 1.0;
229        }
230    }
231
232    Ok([leaf_p, leaf_q])
233}