Skip to main content

ailake_vec/
pq.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// Product Quantization — reduces per-vector storage from dim*4 bytes to num_subvectors bytes.
3// At dim=1536, M=48: 6144 bytes → 48 bytes per vector (128x reduction, ~93-95% recall@10).
4
5use ailake_core::AilakeError;
6use serde::{Deserialize, Serialize};
7
/// Trained Product Quantization codebook: one independent set of centroids per
/// sub-space, used to encode a vector into one `u8` code per sub-vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Number of sub-vectors (M); an encoded vector is `num_subvectors` bytes.
    pub num_subvectors: usize,
    /// Requested number of centroids per sub-space (K, typically 256 so codes fit in u8).
    ///
    /// NOTE: training caps the per-sub-space centroid count at the number of
    /// training vectors but stores the requested value here, so
    /// `centroids[m].len()` may be smaller than `num_centroids`.
    pub num_centroids: usize,
    /// Dimensionality of each sub-vector = dim / num_subvectors
    pub sub_dim: usize,
    /// Centroids: [num_subvectors][<= num_centroids][sub_dim]
    pub centroids: Vec<Vec<Vec<f32>>>,
}
19
20impl PQCodebook {
21    /// Train PQ codebook via k-means on each sub-space independently.
22    pub fn train(
23        vectors: &[Vec<f32>],
24        num_subvectors: usize,
25        num_centroids: usize,
26        max_iter: usize,
27    ) -> Result<Self, AilakeError> {
28        Self::train_with_kmeans(vectors, num_subvectors, num_centroids, max_iter, kmeans)
29    }
30
31    /// Train PQ codebook with a custom k-means backend (e.g. GPU-accelerated).
32    ///
33    /// `kmeans_fn(vecs, k, max_iter)` must return exactly `k` centroids of the
34    /// same dimensionality as `vecs`.  The built-in CPU path passes `kmeans`.
35    pub fn train_with_kmeans<F>(
36        vectors: &[Vec<f32>],
37        num_subvectors: usize,
38        num_centroids: usize,
39        max_iter: usize,
40        kmeans_fn: F,
41    ) -> Result<Self, AilakeError>
42    where
43        F: Fn(&[Vec<f32>], usize, usize) -> Vec<Vec<f32>>,
44    {
45        if vectors.is_empty() {
46            return Err(AilakeError::Catalog(
47                "PQ training requires at least 1 vector".into(),
48            ));
49        }
50        let dim = vectors[0].len();
51        if !dim.is_multiple_of(num_subvectors) {
52            return Err(AilakeError::Catalog(format!(
53                "dim {dim} not divisible by num_subvectors {num_subvectors}"
54            )));
55        }
56        let sub_dim = dim / num_subvectors;
57        let n_train = num_centroids.min(vectors.len());
58
59        let mut centroids = Vec::with_capacity(num_subvectors);
60        for m in 0..num_subvectors {
61            let start = m * sub_dim;
62            let end = start + sub_dim;
63            let sub_vecs: Vec<Vec<f32>> = vectors.iter().map(|v| v[start..end].to_vec()).collect();
64            let sub_centroids = kmeans_fn(&sub_vecs, n_train, max_iter);
65            centroids.push(sub_centroids);
66        }
67
68        Ok(Self {
69            num_subvectors,
70            num_centroids,
71            sub_dim,
72            centroids,
73        })
74    }
75
76    /// Encode a single vector into `num_subvectors` u8 codes.
77    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
78        let mut codes = Vec::with_capacity(self.num_subvectors);
79        for m in 0..self.num_subvectors {
80            let start = m * self.sub_dim;
81            let sub = &vector[start..start + self.sub_dim];
82            let best = self.centroids[m]
83                .iter()
84                .enumerate()
85                .map(|(k, c)| (k, l2_sq(sub, c)))
86                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
87                .map(|(k, _)| k)
88                .unwrap_or(0);
89            codes.push(best as u8);
90        }
91        codes
92    }
93
94    /// Decode codes back into an approximate vector (centroid reconstruction).
95    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
96        let mut out = Vec::with_capacity(self.num_subvectors * self.sub_dim);
97        for (m, &code) in codes.iter().enumerate() {
98            out.extend_from_slice(&self.centroids[m][code as usize]);
99        }
100        out
101    }
102
103    /// Precompute query-to-centroid L2 distances for Asymmetric Distance Computation.
104    /// Returns [num_subvectors][num_centroids] distance table.
105    /// ADC is O(M*K) per query precomputation, then O(M) per encoded vector — much faster
106    /// than symmetric distance which would require decoding each vector first.
107    pub fn compute_adc_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
108        (0..self.num_subvectors)
109            .map(|m| {
110                let start = m * self.sub_dim;
111                let q_sub = &query[start..start + self.sub_dim];
112                self.centroids[m].iter().map(|c| l2_sq(q_sub, c)).collect()
113            })
114            .collect()
115    }
116
117    /// Compute approximate L2 distance using the precomputed ADC table.
118    pub fn adc_distance(&self, codes: &[u8], table: &[Vec<f32>]) -> f32 {
119        codes
120            .iter()
121            .enumerate()
122            .map(|(m, &c)| table[m][c as usize])
123            .sum()
124    }
125}
126
127/// K-means clustering (k-means++ init, up to `max_iter` iterations).
128fn kmeans(points: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
129    let dim = points[0].len();
130    let mut centroids = kmeans_pp_init(points, k);
131
132    for _ in 0..max_iter {
133        // Assign each point to nearest centroid
134        let assignments: Vec<usize> = points
135            .iter()
136            .map(|p| nearest_centroid(p, &centroids))
137            .collect();
138
139        // Update centroids
140        let mut new_centroids = vec![vec![0.0f32; dim]; k];
141        let mut counts = vec![0usize; k];
142        for (point, &assigned) in points.iter().zip(assignments.iter()) {
143            for (d, &v) in point.iter().enumerate() {
144                new_centroids[assigned][d] += v;
145            }
146            counts[assigned] += 1;
147        }
148        let mut converged = true;
149        for (i, count) in counts.iter().enumerate() {
150            if *count > 0 {
151                let scale = *count as f32;
152                for x in new_centroids[i].iter_mut() {
153                    *x /= scale;
154                }
155            } else {
156                // Empty cluster: keep old centroid
157                new_centroids[i] = centroids[i].clone();
158            }
159            if l2_sq(&new_centroids[i], &centroids[i]) > 1e-8 {
160                converged = false;
161            }
162        }
163        centroids = new_centroids;
164        if converged {
165            break;
166        }
167    }
168    centroids
169}
170
171/// K-means++ centroid initialization.
172fn kmeans_pp_init(points: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
173    let mut centroids = vec![points[0].clone()];
174    let mut rng_state = 0x123456789u64;
175
176    while centroids.len() < k {
177        let dists: Vec<f32> = points
178            .iter()
179            .map(|p| {
180                centroids
181                    .iter()
182                    .map(|c| l2_sq(p, c))
183                    .fold(f32::INFINITY, f32::min)
184            })
185            .collect();
186        let total: f32 = dists.iter().sum();
187        // Simple LCG random for deterministic behavior (no rand dep in vec crate)
188        rng_state = rng_state
189            .wrapping_mul(6364136223846793005)
190            .wrapping_add(1442695040888963407);
191        let r = (rng_state >> 33) as f32 / (u32::MAX as f32);
192        let target = r * total;
193        let mut cumsum = 0.0f32;
194        let mut chosen = points.len() - 1;
195        for (i, &d) in dists.iter().enumerate() {
196            cumsum += d;
197            if cumsum >= target {
198                chosen = i;
199                break;
200            }
201        }
202        centroids.push(points[chosen].clone());
203    }
204    centroids
205}
206
207fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
208    centroids
209        .iter()
210        .enumerate()
211        .map(|(i, c)| (i, l2_sq(point, c)))
212        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
213        .map(|(i, _)| i)
214        .unwrap_or(0)
215}
216
/// Squared Euclidean distance between `a` and `b`.
///
/// Pairs components positionally and stops at the shorter slice; no square
/// root is taken.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum()
}
220
221/// Train k-means centroids on `vectors`. Returns `k` centroids of same dimensionality.
222/// Exposed for IVF coarse quantizer training.
223pub fn kmeans_centroids(vectors: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
224    let k_eff = k.min(vectors.len());
225    kmeans(vectors, k_eff, max_iter)
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` one-hot vectors of length `dim`; the hot index cycles through `0..dim`.
    fn unit_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let one_hot = |hot: usize| -> Vec<f32> {
            (0..dim).map(|j| if j == hot { 1.0 } else { 0.0 }).collect()
        };
        (0..n).map(|i| one_hot(i % dim)).collect()
    }

    #[test]
    fn encode_decode_roundtrip_approx() {
        let dim = 8;
        let data = unit_vecs(64, dim);
        let codebook = PQCodebook::train(&data, 2, 4, 50).unwrap();
        data.iter().for_each(|v| {
            let codes = codebook.encode(v);
            assert_eq!(codes.len(), 2);
            assert_eq!(codebook.decode(&codes).len(), dim);
        });
    }

    #[test]
    fn adc_distance_non_negative() {
        let data = unit_vecs(32, 8);
        let codebook = PQCodebook::train(&data, 2, 4, 50).unwrap();
        let mut query = vec![0.0f32; 8];
        query[0] = 1.0;
        let table = codebook.compute_adc_table(&query);
        for v in data.iter() {
            let dist = codebook.adc_distance(&codebook.encode(v), &table);
            assert!(dist >= 0.0, "ADC distance must be non-negative");
        }
    }

    #[test]
    fn dim_not_divisible_errors() {
        let result = PQCodebook::train(&unit_vecs(16, 9), 4, 4, 10);
        assert!(result.is_err());
    }

    #[test]
    fn nearest_neighbor_rank_preserved() {
        // Two tight clusters: 20 copies of e0 followed by 20 copies of e7.
        let dim = 8;
        let one_hot = |hot: usize| {
            let mut v = vec![0.0f32; dim];
            v[hot] = 1.0;
            v
        };
        let vecs: Vec<Vec<f32>> = std::iter::repeat(one_hot(0))
            .take(20)
            .chain(std::iter::repeat(one_hot(7)).take(20))
            .collect();
        let cb = PQCodebook::train(&vecs, 2, 4, 100).unwrap();

        let (q1, q2) = (one_hot(0), one_hot(7));
        let (t1, t2) = (cb.compute_adc_table(&q1), cb.compute_adc_table(&q2));
        let (code1, code2) = (cb.encode(&vecs[0]), cb.encode(&vecs[39]));

        // q1 closer to vecs[0] than to vecs[39]
        assert!(cb.adc_distance(&code1, &t1) < cb.adc_distance(&code2, &t1));
        // q2 closer to vecs[39] than to vecs[0]
        assert!(cb.adc_distance(&code2, &t2) < cb.adc_distance(&code1, &t2));
    }
}