Skip to main content

ailake_vec/
pq.rs

1// Product Quantization — reduces per-vector storage from dim*4 bytes to num_subvectors bytes.
2// At dim=1536, M=48: 6144 bytes → 48 bytes per vector (128x reduction, ~93-95% recall@10).
3
4use ailake_core::AilakeError;
5use serde::{Deserialize, Serialize};
6
/// Product-quantization codebook: one independent set of centroids per sub-space.
///
/// A vector of dimensionality `num_subvectors * sub_dim` is split into
/// `num_subvectors` contiguous slices of `sub_dim` components; each slice is
/// represented by the index of a centroid in `centroids[m]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Number of sub-vectors (M)
    pub num_subvectors: usize,
    /// Number of centroids per sub-space (K, typically 256 so codes fit in u8).
    ///
    /// Training stores the *requested* value here; when fewer training vectors
    /// than `num_centroids` are supplied, each `centroids[m]` holds fewer
    /// entries than this field says, so use `centroids[m].len()` for bounds.
    pub num_centroids: usize,
    /// Dimensionality of each sub-vector = dim / num_subvectors
    pub sub_dim: usize,
    /// Centroids: [num_subvectors][num_centroids][sub_dim]
    pub centroids: Vec<Vec<Vec<f32>>>,
}
18
19impl PQCodebook {
20    /// Train PQ codebook via k-means on each sub-space independently.
21    pub fn train(
22        vectors: &[Vec<f32>],
23        num_subvectors: usize,
24        num_centroids: usize,
25        max_iter: usize,
26    ) -> Result<Self, AilakeError> {
27        Self::train_with_kmeans(vectors, num_subvectors, num_centroids, max_iter, kmeans)
28    }
29
30    /// Train PQ codebook with a custom k-means backend (e.g. GPU-accelerated).
31    ///
32    /// `kmeans_fn(vecs, k, max_iter)` must return exactly `k` centroids of the
33    /// same dimensionality as `vecs`.  The built-in CPU path passes `kmeans`.
34    pub fn train_with_kmeans<F>(
35        vectors: &[Vec<f32>],
36        num_subvectors: usize,
37        num_centroids: usize,
38        max_iter: usize,
39        kmeans_fn: F,
40    ) -> Result<Self, AilakeError>
41    where
42        F: Fn(&[Vec<f32>], usize, usize) -> Vec<Vec<f32>>,
43    {
44        if vectors.is_empty() {
45            return Err(AilakeError::Catalog(
46                "PQ training requires at least 1 vector".into(),
47            ));
48        }
49        let dim = vectors[0].len();
50        if !dim.is_multiple_of(num_subvectors) {
51            return Err(AilakeError::Catalog(format!(
52                "dim {dim} not divisible by num_subvectors {num_subvectors}"
53            )));
54        }
55        let sub_dim = dim / num_subvectors;
56        let n_train = num_centroids.min(vectors.len());
57
58        let mut centroids = Vec::with_capacity(num_subvectors);
59        for m in 0..num_subvectors {
60            let start = m * sub_dim;
61            let end = start + sub_dim;
62            let sub_vecs: Vec<Vec<f32>> = vectors.iter().map(|v| v[start..end].to_vec()).collect();
63            let sub_centroids = kmeans_fn(&sub_vecs, n_train, max_iter);
64            centroids.push(sub_centroids);
65        }
66
67        Ok(Self {
68            num_subvectors,
69            num_centroids,
70            sub_dim,
71            centroids,
72        })
73    }
74
75    /// Encode a single vector into `num_subvectors` u8 codes.
76    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
77        let mut codes = Vec::with_capacity(self.num_subvectors);
78        for m in 0..self.num_subvectors {
79            let start = m * self.sub_dim;
80            let sub = &vector[start..start + self.sub_dim];
81            let best = self.centroids[m]
82                .iter()
83                .enumerate()
84                .map(|(k, c)| (k, l2_sq(sub, c)))
85                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
86                .map(|(k, _)| k)
87                .unwrap_or(0);
88            codes.push(best as u8);
89        }
90        codes
91    }
92
93    /// Decode codes back into an approximate vector (centroid reconstruction).
94    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
95        let mut out = Vec::with_capacity(self.num_subvectors * self.sub_dim);
96        for (m, &code) in codes.iter().enumerate() {
97            out.extend_from_slice(&self.centroids[m][code as usize]);
98        }
99        out
100    }
101
102    /// Precompute query-to-centroid L2 distances for Asymmetric Distance Computation.
103    /// Returns [num_subvectors][num_centroids] distance table.
104    /// ADC is O(M*K) per query precomputation, then O(M) per encoded vector — much faster
105    /// than symmetric distance which would require decoding each vector first.
106    pub fn compute_adc_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
107        (0..self.num_subvectors)
108            .map(|m| {
109                let start = m * self.sub_dim;
110                let q_sub = &query[start..start + self.sub_dim];
111                self.centroids[m].iter().map(|c| l2_sq(q_sub, c)).collect()
112            })
113            .collect()
114    }
115
116    /// Compute approximate L2 distance using the precomputed ADC table.
117    pub fn adc_distance(&self, codes: &[u8], table: &[Vec<f32>]) -> f32 {
118        codes
119            .iter()
120            .enumerate()
121            .map(|(m, &c)| table[m][c as usize])
122            .sum()
123    }
124}
125
126/// K-means clustering (k-means++ init, up to `max_iter` iterations).
127fn kmeans(points: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
128    let dim = points[0].len();
129    let mut centroids = kmeans_pp_init(points, k);
130
131    for _ in 0..max_iter {
132        // Assign each point to nearest centroid
133        let assignments: Vec<usize> = points
134            .iter()
135            .map(|p| nearest_centroid(p, &centroids))
136            .collect();
137
138        // Update centroids
139        let mut new_centroids = vec![vec![0.0f32; dim]; k];
140        let mut counts = vec![0usize; k];
141        for (point, &assigned) in points.iter().zip(assignments.iter()) {
142            for (d, &v) in point.iter().enumerate() {
143                new_centroids[assigned][d] += v;
144            }
145            counts[assigned] += 1;
146        }
147        let mut converged = true;
148        for (i, count) in counts.iter().enumerate() {
149            if *count > 0 {
150                let scale = *count as f32;
151                for x in new_centroids[i].iter_mut() {
152                    *x /= scale;
153                }
154            } else {
155                // Empty cluster: keep old centroid
156                new_centroids[i] = centroids[i].clone();
157            }
158            if l2_sq(&new_centroids[i], &centroids[i]) > 1e-8 {
159                converged = false;
160            }
161        }
162        centroids = new_centroids;
163        if converged {
164            break;
165        }
166    }
167    centroids
168}
169
170/// K-means++ centroid initialization.
171fn kmeans_pp_init(points: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
172    let mut centroids = vec![points[0].clone()];
173    let mut rng_state = 0x123456789u64;
174
175    while centroids.len() < k {
176        let dists: Vec<f32> = points
177            .iter()
178            .map(|p| {
179                centroids
180                    .iter()
181                    .map(|c| l2_sq(p, c))
182                    .fold(f32::INFINITY, f32::min)
183            })
184            .collect();
185        let total: f32 = dists.iter().sum();
186        // Simple LCG random for deterministic behavior (no rand dep in vec crate)
187        rng_state = rng_state
188            .wrapping_mul(6364136223846793005)
189            .wrapping_add(1442695040888963407);
190        let r = (rng_state >> 33) as f32 / (u32::MAX as f32);
191        let target = r * total;
192        let mut cumsum = 0.0f32;
193        let mut chosen = points.len() - 1;
194        for (i, &d) in dists.iter().enumerate() {
195            cumsum += d;
196            if cumsum >= target {
197                chosen = i;
198                break;
199            }
200        }
201        centroids.push(points[chosen].clone());
202    }
203    centroids
204}
205
206fn nearest_centroid(point: &[f32], centroids: &[Vec<f32>]) -> usize {
207    centroids
208        .iter()
209        .enumerate()
210        .map(|(i, c)| (i, l2_sq(point, c)))
211        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
212        .map(|(i, _)| i)
213        .unwrap_or(0)
214}
215
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length, the
/// surplus components of the longer one are ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta.powi(2)
    });
    squared_diffs.sum()
}
219
220/// Train k-means centroids on `vectors`. Returns `k` centroids of same dimensionality.
221/// Exposed for IVF coarse quantizer training.
222pub fn kmeans_centroids(vectors: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
223    let k_eff = k.min(vectors.len());
224    kmeans(vectors, k_eff, max_iter)
225}
226
// Unit tests for codebook training, encoding/decoding and ADC distances.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` one-hot vectors of length `dim`; vector `i` has its single
    /// `1.0` at position `i % dim`.
    fn unit_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| {
                let mut v = vec![0.0f32; dim];
                v[i % dim] = 1.0;
                v
            })
            .collect()
    }

    // Shape check only: encoding yields one code per sub-vector and decoding
    // restores the original dimensionality. Values are not compared.
    #[test]
    fn encode_decode_roundtrip_approx() {
        let dim = 8;
        let vecs = unit_vecs(64, dim);
        let cb = PQCodebook::train(&vecs, 2, 4, 50).unwrap();
        for v in &vecs {
            let codes = cb.encode(v);
            assert_eq!(codes.len(), 2);
            let decoded = cb.decode(&codes);
            assert_eq!(decoded.len(), dim);
        }
    }

    // ADC distances are sums of squared distances, so they must never be negative.
    #[test]
    fn adc_distance_non_negative() {
        let dim = 8;
        let vecs = unit_vecs(32, dim);
        let cb = PQCodebook::train(&vecs, 2, 4, 50).unwrap();
        let query = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let table = cb.compute_adc_table(&query);
        for v in &vecs {
            let codes = cb.encode(v);
            let dist = cb.adc_distance(&codes, &table);
            assert!(dist >= 0.0, "ADC distance must be non-negative");
        }
    }

    // dim = 9 cannot be split into 4 equal sub-vectors, so training must fail.
    #[test]
    fn dim_not_divisible_errors() {
        let vecs = unit_vecs(16, 9);
        assert!(PQCodebook::train(&vecs, 4, 4, 10).is_err());
    }

    // Quantization must preserve which of two well-separated clusters a query
    // is closer to.
    #[test]
    fn nearest_neighbor_rank_preserved() {
        // Two clusters: vecs around [1,0,...,0] and [0,...,0,1]
        let dim = 8;
        let mut vecs: Vec<Vec<f32>> = Vec::new();
        for _ in 0..20 {
            let mut v = vec![0.0f32; dim];
            v[0] = 1.0;
            vecs.push(v);
        }
        for _ in 0..20 {
            let mut v = vec![0.0f32; dim];
            v[7] = 1.0;
            vecs.push(v);
        }
        let cb = PQCodebook::train(&vecs, 2, 4, 100).unwrap();
        let q1 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let q2 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0];
        let t1 = cb.compute_adc_table(&q1);
        let t2 = cb.compute_adc_table(&q2);
        // vecs[0] belongs to the first cluster, vecs[39] to the second.
        let code1 = cb.encode(&vecs[0]);
        let code2 = cb.encode(&vecs[39]);
        // q1 closer to vecs[0] than to vecs[39]
        assert!(cb.adc_distance(&code1, &t1) < cb.adc_distance(&code2, &t1));
        // q2 closer to vecs[39] than to vecs[0]
        assert!(cb.adc_distance(&code2, &t2) < cb.adc_distance(&code1, &t2));
    }
}