Skip to main content

ailake_vec/
pq.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
// Product Quantization — reduces per-vector storage from dim*4 bytes to num_subvectors bytes.
// At dim=1536, M=48: 6144 bytes → 48 bytes per vector (128x reduction, ~93-95% recall@10).

5use ailake_core::AilakeError;
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8
/// Product-quantization (PQ) codebook: one centroid table per sub-space.
///
/// A vector of length `num_subvectors * sub_dim` is split into `num_subvectors`
/// contiguous chunks of `sub_dim` values. `encode` replaces each chunk with the `u8`
/// index of its nearest centroid, and `decode` concatenates the indexed centroids back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Number of sub-vectors (M); also the number of `u8` codes per encoded vector.
    pub num_subvectors: usize,
    /// Number of centroids per sub-space (K). Codes are stored as `u8`, so K must not
    /// exceed 256 for every centroid index to be representable. `centroids[m].len()`
    /// is the number of centroids actually stored for sub-space `m`.
    pub num_centroids: usize,
    /// Dimensionality of each sub-vector = dim / num_subvectors
    pub sub_dim: usize,
    /// Centroids indexed `[sub-space m][centroid k][coordinate]`, i.e.
    /// `[num_subvectors][centroids per sub-space][sub_dim]`.
    pub centroids: Vec<Vec<Vec<f32>>>,
}
20
21impl PQCodebook {
22    /// Train PQ codebook via k-means on each sub-space independently.
23    pub fn train(
24        vectors: &[Vec<f32>],
25        num_subvectors: usize,
26        num_centroids: usize,
27        max_iter: usize,
28    ) -> Result<Self, AilakeError> {
29        Self::train_with_kmeans(vectors, num_subvectors, num_centroids, max_iter, kmeans)
30    }
31
32    /// Train PQ codebook with a custom k-means backend (e.g. GPU-accelerated).
33    ///
34    /// `kmeans_fn(vecs, k, max_iter)` must return exactly `k` centroids of the
35    /// same dimensionality as `vecs`.  The built-in CPU path passes `kmeans`.
36    pub fn train_with_kmeans<F>(
37        vectors: &[Vec<f32>],
38        num_subvectors: usize,
39        num_centroids: usize,
40        max_iter: usize,
41        kmeans_fn: F,
42    ) -> Result<Self, AilakeError>
43    where
44        F: Fn(&[Vec<f32>], usize, usize) -> Vec<Vec<f32>>,
45    {
46        if vectors.is_empty() {
47            return Err(AilakeError::Catalog(
48                "PQ training requires at least 1 vector".into(),
49            ));
50        }
51        let dim = vectors[0].len();
52        if !dim.is_multiple_of(num_subvectors) {
53            return Err(AilakeError::Catalog(format!(
54                "dim {dim} not divisible by num_subvectors {num_subvectors}"
55            )));
56        }
57        let sub_dim = dim / num_subvectors;
58        let n_train = num_centroids.min(vectors.len());
59
60        let mut centroids = Vec::with_capacity(num_subvectors);
61        for m in 0..num_subvectors {
62            let start = m * sub_dim;
63            let end = start + sub_dim;
64            let sub_vecs: Vec<Vec<f32>> = vectors.iter().map(|v| v[start..end].to_vec()).collect();
65            let sub_centroids = kmeans_fn(&sub_vecs, n_train, max_iter);
66            centroids.push(sub_centroids);
67        }
68
69        Ok(Self {
70            num_subvectors,
71            num_centroids,
72            sub_dim,
73            centroids,
74        })
75    }
76
77    /// Encode a single vector into `num_subvectors` u8 codes.
78    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
79        let mut codes = Vec::with_capacity(self.num_subvectors);
80        for m in 0..self.num_subvectors {
81            let start = m * self.sub_dim;
82            let sub = &vector[start..start + self.sub_dim];
83            let best = self.centroids[m]
84                .iter()
85                .enumerate()
86                .map(|(k, c)| (k, l2_sq(sub, c)))
87                .min_by(|a, b| a.1.total_cmp(&b.1))
88                .map(|(k, _)| k)
89                .unwrap_or(0);
90            codes.push(best as u8);
91        }
92        codes
93    }
94
95    /// Decode codes back into an approximate vector (centroid reconstruction).
96    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
97        let mut out = Vec::with_capacity(self.num_subvectors * self.sub_dim);
98        for (m, &code) in codes.iter().enumerate() {
99            out.extend_from_slice(&self.centroids[m][code as usize]);
100        }
101        out
102    }
103
104    /// Precompute query-to-centroid L2 distances for Asymmetric Distance Computation.
105    /// Returns [num_subvectors][num_centroids] distance table.
106    /// ADC is O(M*K) per query precomputation, then O(M) per encoded vector — much faster
107    /// than symmetric distance which would require decoding each vector first.
108    pub fn compute_adc_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
109        (0..self.num_subvectors)
110            .map(|m| {
111                let start = m * self.sub_dim;
112                let q_sub = &query[start..start + self.sub_dim];
113                self.centroids[m].iter().map(|c| l2_sq(q_sub, c)).collect()
114            })
115            .collect()
116    }
117
118    /// Compute approximate L2 distance using the precomputed ADC table.
119    pub fn adc_distance(&self, codes: &[u8], table: &[Vec<f32>]) -> f32 {
120        codes
121            .iter()
122            .enumerate()
123            .map(|(m, &c)| table[m][c as usize])
124            .sum()
125    }
126}
127
128/// K-means clustering (k-means++ init, up to `max_iter` iterations).
129fn kmeans(points: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
130    let dim = points[0].len();
131    let mut centroids = kmeans_pp_init(points, k);
132
133    for _ in 0..max_iter {
134        // Parallel assignment: each point finds its nearest centroid independently.
135        let assignments: Vec<usize> = points
136            .par_iter()
137            .map(|p| nearest_centroid(p, &centroids))
138            .collect();
139
140        // Update centroids (serial reduction — n×dim is cache-friendly enough here)
141        let mut new_centroids = vec![vec![0.0f32; dim]; k];
142        let mut counts = vec![0usize; k];
143        for (point, &assigned) in points.iter().zip(assignments.iter()) {
144            for (d, &v) in point.iter().enumerate() {
145                new_centroids[assigned][d] += v;
146            }
147            counts[assigned] += 1;
148        }
149        let mut converged = true;
150        for (i, count) in counts.iter().enumerate() {
151            if *count > 0 {
152                let scale = *count as f32;
153                for x in new_centroids[i].iter_mut() {
154                    *x /= scale;
155                }
156            } else {
157                // Empty cluster: keep old centroid
158                new_centroids[i] = centroids[i].clone();
159            }
160            if l2_sq(&new_centroids[i], &centroids[i]) > 1e-8 {
161                converged = false;
162            }
163        }
164        centroids = new_centroids;
165        if converged {
166            break;
167        }
168    }
169    centroids
170}
171
172/// K-means++ centroid initialization — O(n × k) via incremental min-dist update.
173fn kmeans_pp_init(points: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
174    let mut centroids = Vec::with_capacity(k);
175    let mut rng_state = 0x123456789u64;
176
177    centroids.push(points[0].clone());
178    // Track min distance from each point to the nearest centroid chosen so far.
179    let mut min_dists: Vec<f32> = points.par_iter().map(|p| l2_sq(p, &centroids[0])).collect();
180
181    while centroids.len() < k {
182        let total: f32 = min_dists.iter().sum();
183        rng_state = rng_state
184            .wrapping_mul(6364136223846793005)
185            .wrapping_add(1442695040888963407);
186        let r = (rng_state >> 33) as f32 / (u32::MAX as f32);
187        let target = r * total;
188        let mut cumsum = 0.0f32;
189        let mut chosen = points.len() - 1;
190        for (i, &d) in min_dists.iter().enumerate() {
191            cumsum += d;
192            if cumsum >= target {
193                chosen = i;
194                break;
195            }
196        }
197        let new_centroid = points[chosen].clone();
198        // Incremental update: only recompute distance to the newly added centroid.
199        points
200            .par_iter()
201            .zip(min_dists.par_iter_mut())
202            .for_each(|(p, min_d)| {
203                let d = l2_sq(p, &new_centroid);
204                if d < *min_d {
205                    *min_d = d;
206                }
207            });
208        centroids.push(new_centroid);
209    }
210    centroids
211}
212
213fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
214    centroids
215        .iter()
216        .enumerate()
217        .map(|(i, c)| (i, l2_sq(point, c)))
218        .min_by(|a, b| a.1.total_cmp(&b.1))
219        .map(|(i, _)| i)
220        .unwrap_or(0)
221}
222
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` coordinates are compared. Any excess in the
/// longer slice is ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum()
}
226
227/// Train k-means centroids on `vectors`. Returns `k` centroids of same dimensionality.
228/// Exposed for IVF coarse quantizer training.
229pub fn kmeans_centroids(vectors: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
230    let k_eff = k.min(vectors.len());
231    kmeans(vectors, k_eff, max_iter)
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// One-hot vector of length `dim` with a single 1.0 at index `hot`.
    fn basis(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot] = 1.0;
        v
    }

    /// `n` one-hot vectors of length `dim`, cycling the hot index through `0..dim`.
    fn unit_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| basis(dim, i % dim)).collect()
    }

    #[test]
    fn encode_decode_roundtrip_approx() {
        const DIM: usize = 8;
        let training = unit_vecs(64, DIM);
        let codebook = PQCodebook::train(&training, 2, 4, 50).unwrap();
        for vector in &training {
            let codes = codebook.encode(vector);
            assert_eq!(codes.len(), 2);
            assert_eq!(codebook.decode(&codes).len(), DIM);
        }
    }

    #[test]
    fn adc_distance_non_negative() {
        const DIM: usize = 8;
        let training = unit_vecs(32, DIM);
        let codebook = PQCodebook::train(&training, 2, 4, 50).unwrap();
        let table = codebook.compute_adc_table(&basis(DIM, 0));
        training
            .iter()
            .map(|v| codebook.adc_distance(&codebook.encode(v), &table))
            .for_each(|dist| assert!(dist >= 0.0, "ADC distance must be non-negative"));
    }

    #[test]
    fn dim_not_divisible_errors() {
        // 9 dimensions cannot be split evenly into 4 sub-vectors.
        let training = unit_vecs(16, 9);
        assert!(PQCodebook::train(&training, 4, 4, 10).is_err());
    }

    #[test]
    fn nearest_neighbor_rank_preserved() {
        // Two tight clusters: 20 copies of e0 followed by 20 copies of e7.
        const DIM: usize = 8;
        let vecs = [vec![basis(DIM, 0); 20], vec![basis(DIM, 7); 20]].concat();
        let codebook = PQCodebook::train(&vecs, 2, 4, 100).unwrap();

        let table_first = codebook.compute_adc_table(&basis(DIM, 0));
        let table_last = codebook.compute_adc_table(&basis(DIM, 7));
        let first_codes = codebook.encode(&vecs[0]);
        let last_codes = codebook.encode(&vecs[39]);

        // A query at e0 ranks vecs[0] ahead of vecs[39] ...
        assert!(
            codebook.adc_distance(&first_codes, &table_first)
                < codebook.adc_distance(&last_codes, &table_first)
        );
        // ... and a query at e7 ranks vecs[39] ahead of vecs[0].
        assert!(
            codebook.adc_distance(&last_codes, &table_last)
                < codebook.adc_distance(&first_codes, &table_last)
        );
    }
}