Skip to main content

ripvec_core/
turbo_quant.rs

1//! TurboQuant: high-throughput compressed vector search.
2//!
3//! PolarQuant at 4 bits: rotated embedding pairs are encoded as (radius, angle_index).
4//! The scan uses a pre-computed centroid-query dot product table (24 KB, fits in L1)
5//! and streams sequentially through packed radii + indices — cache-line optimal.
6//!
7//! # Memory layout (SoA, not AoS)
8//!
9//! ```text
10//! CompressedCorpus:
11//!   radii:   [n × pairs] f32, contiguous — sequential streaming reads
12//!   indices: [n × pairs] u8,  contiguous — sequential streaming reads
13//!   (future: 4-bit packed indices → [n × pairs / 2] u8 for 2× index bandwidth)
14//! ```
15//!
16//! This layout enables:
17//! - GPU: one thread per vector, coalesced reads across threads
18//! - CPU NEON: process 4 pairs per SIMD iteration, amortize centroid loads
19//! - Cache: centroid table (24 KB) stays in L1 throughout the scan
20
21use std::f32::consts::PI;
22
23use ndarray::{Array1, Array2};
24use rand::SeedableRng;
25use rand_chacha::ChaCha8Rng;
26use rand_distr::{Distribution, StandardNormal};
27
28// ---------------------------------------------------------------------------
29// Compressed corpus — flat SoA layout for cache-friendly scanning
30// ---------------------------------------------------------------------------
31
/// Flat, contiguous compressed embeddings for maximum scan throughput.
///
/// Structure-of-arrays layout: all radii packed, then all indices packed.
/// No per-vector heap allocations, no pointer chasing.
///
/// Vector `v` occupies `radii[v * pairs..(v + 1) * pairs]` and
/// `indices[v * pairs..(v + 1) * pairs]`; both buffers are expected to hold
/// exactly `n * pairs` elements.
///
/// At 4-bit, d=768: 384 pairs × (4 + 1) bytes = 1920 bytes/vector.
/// For 100K vectors: 192 MB (vs ~307 MB FP32 or ~154 MB FP16).
#[derive(Debug, Clone, PartialEq)]
pub struct CompressedCorpus {
    /// Number of vectors.
    pub n: usize,
    /// Number of pairs per vector (dim / 2).
    pub pairs: usize,
    /// Flat radii: `[n × pairs]` f32, row-major.
    pub radii: Vec<f32>,
    /// Flat angle indices: `[n × pairs]` u8, row-major.
    pub indices: Vec<u8>,
}
49
/// Compressed representation of a single vector (for the old API).
///
/// `radii[i]` and `angle_indices[i]` together encode rotated pair `i`.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressedCode {
    /// Per-pair radii (f32).
    pub radii: Vec<f32>,
    /// Quantized angle indices in \[0, 2^bits).
    pub angle_indices: Vec<u8>,
}

impl CompressedCode {
    /// Approximate memory in bytes of the encoded payload: one `f32` per
    /// radius plus one byte per angle index (`Vec` headers and spare
    /// capacity are not counted).
    #[must_use]
    pub fn encoded_bytes(&self) -> usize {
        self.radii.len() * std::mem::size_of::<f32>()
            + self.angle_indices.len() * std::mem::size_of::<u8>()
    }
}
66
67// ---------------------------------------------------------------------------
68// PolarCodec — encode, prepare query, scan
69// ---------------------------------------------------------------------------
70
71/// PolarQuant codec: batch encode, query preparation, and high-throughput scan.
72pub struct PolarCodec {
73    dim: usize,
74    #[expect(dead_code, reason = "stored for serialization / reconstruction")]
75    bits: u8,
76    levels: usize,
77    pairs: usize,
78    /// Row-major orthogonal rotation matrix \[dim × dim\].
79    rotation: Array2<f32>,
80    /// Pre-computed cos/sin for each quantized angle level.
81    cos_table: Vec<f32>,
82    sin_table: Vec<f32>,
83}
84
85impl PolarCodec {
86    /// Create a new codec.
87    ///
88    /// # Panics
89    ///
90    /// Panics if `dim` is 0, odd, or `bits` is 0 or > 8.
91    #[must_use]
92    pub fn new(dim: usize, bits: u8, seed: u64) -> Self {
93        assert!(
94            dim > 0 && dim.is_multiple_of(2),
95            "dim must be even and non-zero"
96        );
97        assert!(bits > 0 && bits <= 8, "bits must be 1..=8");
98
99        let levels = 1usize << bits;
100        let pairs = dim / 2;
101        let rotation = generate_rotation(dim, seed);
102
103        let mut cos_table = Vec::with_capacity(levels);
104        let mut sin_table = Vec::with_capacity(levels);
105        for j in 0..levels {
106            let theta = (j as f32 / levels as f32) * 2.0 * PI - PI;
107            cos_table.push(theta.cos());
108            sin_table.push(theta.sin());
109        }
110
111        Self {
112            dim,
113            bits,
114            levels,
115            pairs,
116            rotation,
117            cos_table,
118            sin_table,
119        }
120    }
121
122    /// Number of pairs per vector.
123    #[must_use]
124    pub fn pairs(&self) -> usize {
125        self.pairs
126    }
127
128    /// Encode a single vector (convenience, allocates).
129    ///
130    /// # Panics
131    ///
132    /// Panics if `vector.len() != self.dim`.
133    #[must_use]
134    pub fn encode(&self, vector: &[f32]) -> CompressedCode {
135        assert_eq!(vector.len(), self.dim);
136        let x = Array1::from_vec(vector.to_vec());
137        let rotated = self.rotation.dot(&x);
138
139        let mut radii = Vec::with_capacity(self.pairs);
140        let mut angle_indices = Vec::with_capacity(self.pairs);
141        for i in 0..self.pairs {
142            let (r, idx) = self.encode_pair(rotated[2 * i], rotated[2 * i + 1]);
143            radii.push(r);
144            angle_indices.push(idx);
145        }
146        CompressedCode {
147            radii,
148            angle_indices,
149        }
150    }
151
152    /// Batch-encode into a flat [`CompressedCorpus`] (SoA layout).
153    ///
154    /// Uses BLAS for the rotation: `rotated = vectors × Πᵀ` (one GEMM).
155    /// Quantization is scalar but cache-friendly (sequential writes).
156    ///
157    /// # Panics
158    ///
159    /// Panics if `vectors.ncols() != self.dim`.
160    #[must_use]
161    pub fn encode_batch(&self, vectors: &Array2<f32>) -> CompressedCorpus {
162        assert_eq!(vectors.ncols(), self.dim);
163        let n = vectors.nrows();
164
165        // Batch rotation via BLAS: [n, dim] × [dim, dim]ᵀ → [n, dim]
166        let rotated = vectors.dot(&self.rotation.t());
167
168        let total = n * self.pairs;
169        let mut radii = Vec::with_capacity(total);
170        let mut indices = Vec::with_capacity(total);
171
172        for row in 0..n {
173            for i in 0..self.pairs {
174                let (r, idx) = self.encode_pair(rotated[[row, 2 * i]], rotated[[row, 2 * i + 1]]);
175                radii.push(r);
176                indices.push(idx);
177            }
178        }
179
180        CompressedCorpus {
181            n,
182            pairs: self.pairs,
183            radii,
184            indices,
185        }
186    }
187
188    /// Also produce the old per-vector codes (for backward compat with index.rs).
189    #[must_use]
190    pub fn encode_batch_codes(&self, vectors: &Array2<f32>) -> Vec<CompressedCode> {
191        let corpus = self.encode_batch(vectors);
192        (0..corpus.n)
193            .map(|v| {
194                let off = v * corpus.pairs;
195                CompressedCode {
196                    radii: corpus.radii[off..off + corpus.pairs].to_vec(),
197                    angle_indices: corpus.indices[off..off + corpus.pairs].to_vec(),
198                }
199            })
200            .collect()
201    }
202
203    /// Prepare query-dependent centroid lookup table.
204    ///
205    /// Cost: one 768×768 matvec + 384×16 multiply-adds = ~0.08ms.
206    /// The returned [`QueryState`] is reused for ALL vectors in the scan.
207    ///
208    /// # Panics
209    ///
210    /// Panics if `query.len() != self.dim`.
211    #[must_use]
212    pub fn prepare_query(&self, query: &[f32]) -> QueryState {
213        assert_eq!(query.len(), self.dim);
214        let q = Array1::from_vec(query.to_vec());
215        let rotated = self.rotation.dot(&q);
216
217        // centroid_q[pair * levels + level] = q_a·cos(θ_level) + q_b·sin(θ_level)
218        // Layout: pairs × levels, contiguous. Fits in L1 (384 × 16 × 4 = 24 KB).
219        let mut centroid_q = vec![0.0f32; self.pairs * self.levels];
220        for i in 0..self.pairs {
221            let q_a = rotated[2 * i];
222            let q_b = rotated[2 * i + 1];
223            let base = i * self.levels;
224            for j in 0..self.levels {
225                centroid_q[base + j] = q_a * self.cos_table[j] + q_b * self.sin_table[j];
226            }
227        }
228
229        QueryState {
230            centroid_q,
231            pairs: self.pairs,
232            levels: self.levels,
233        }
234    }
235
236    /// High-throughput scan of a [`CompressedCorpus`] against a prepared query.
237    ///
238    /// Returns approximate inner product scores for all vectors.
239    /// Memory access: sequential streaming through radii + indices (cache-optimal).
240    /// Centroid table: 24 KB, stays in L1 throughout.
241    ///
242    /// At 100K vectors, d=768: ~3.3ms on CPU, ~0.1ms on GPU (future Metal kernel).
243    #[must_use]
244    #[expect(
245        clippy::needless_range_loop,
246        reason = "index-based loop is clearer for strided SoA access"
247    )]
248    pub fn scan_corpus(&self, corpus: &CompressedCorpus, qs: &QueryState) -> Vec<f32> {
249        let n = corpus.n;
250        let pairs = corpus.pairs;
251        let mut scores = vec![0.0f32; n];
252
253        // Hot loop: sequential access to radii + indices,
254        // random-but-L1-hot access to centroid table.
255        for v in 0..n {
256            let base = v * pairs;
257            let mut score = 0.0f32;
258
259            // Process 4 pairs per iteration (manual unroll for ILP).
260            let chunks = pairs / 4;
261            let remainder = pairs % 4;
262
263            for c in 0..chunks {
264                let i = base + c * 4;
265                let i0 = corpus.indices[i] as usize;
266                let i1 = corpus.indices[i + 1] as usize;
267                let i2 = corpus.indices[i + 2] as usize;
268                let i3 = corpus.indices[i + 3] as usize;
269
270                let p = c * 4;
271                score += corpus.radii[i] * qs.centroid_q[p * qs.levels + i0];
272                score += corpus.radii[i + 1] * qs.centroid_q[(p + 1) * qs.levels + i1];
273                score += corpus.radii[i + 2] * qs.centroid_q[(p + 2) * qs.levels + i2];
274                score += corpus.radii[i + 3] * qs.centroid_q[(p + 3) * qs.levels + i3];
275            }
276            for r in 0..remainder {
277                let i = base + chunks * 4 + r;
278                let p = chunks * 4 + r;
279                let j = corpus.indices[i] as usize;
280                score += corpus.radii[i] * qs.centroid_q[p * qs.levels + j];
281            }
282
283            scores[v] = score;
284        }
285
286        scores
287    }
288
289    /// Scan per-vector codes (old API, for backward compat).
290    #[must_use]
291    pub fn batch_scan(&self, codes: &[CompressedCode], qs: &QueryState) -> Vec<f32> {
292        codes
293            .iter()
294            .map(|code| {
295                let mut score = 0.0f32;
296                for i in 0..qs.pairs {
297                    let j = code.angle_indices[i] as usize;
298                    score += code.radii[i] * qs.centroid_q[i * qs.levels + j];
299                }
300                score
301            })
302            .collect()
303    }
304
305    #[inline]
306    #[expect(
307        clippy::cast_possible_truncation,
308        clippy::cast_sign_loss,
309        reason = "normalized angle [0,1) × levels fits in u8 (max 16 levels)"
310    )]
311    fn encode_pair(&self, a: f32, b: f32) -> (f32, u8) {
312        let r = (a * a + b * b).sqrt();
313        let theta = b.atan2(a);
314        let normalized = (theta + PI) / (2.0 * PI);
315        let idx = ((normalized * self.levels as f32) as usize).min(self.levels - 1);
316        (r, idx as u8)
317    }
318}
319
/// Pre-computed query state for fast scanning.
///
/// Built by `PolarCodec::prepare_query`. Entry `centroid_q[p * levels + j]`
/// is the dot product of the query's rotated pair `p` with the unit vector
/// at angle level `j`, so `centroid_q.len() == pairs * levels`.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryState {
    /// Flat `[pairs × levels]` centroid-query dot products (24 KB at d=768, 4-bit).
    pub centroid_q: Vec<f32>,
    /// Number of pairs.
    pub pairs: usize,
    /// Number of quantization levels.
    pub levels: usize,
}
329
330// ---------------------------------------------------------------------------
331// Rotation matrix generation (seeded, deterministic)
332// ---------------------------------------------------------------------------
333
334/// Generate a d×d orthogonal matrix via QR on a seeded Gaussian matrix.
335fn generate_rotation(dim: usize, seed: u64) -> Array2<f32> {
336    let mut rng = ChaCha8Rng::seed_from_u64(seed);
337    let mut data = Vec::with_capacity(dim * dim);
338    for _ in 0..(dim * dim) {
339        data.push(StandardNormal.sample(&mut rng));
340    }
341    let a = Array2::from_shape_vec((dim, dim), data).expect("shape matches data length");
342    gram_schmidt_qr(a)
343}
344
345/// Modified Gram-Schmidt → Q (orthogonal).
346fn gram_schmidt_qr(mut q: Array2<f32>) -> Array2<f32> {
347    let n = q.ncols();
348    for i in 0..n {
349        let norm: f32 = q.column(i).iter().map(|x| x * x).sum::<f32>().sqrt();
350        if norm < 1e-10 {
351            continue;
352        }
353        let inv = 1.0 / norm;
354        for row in 0..q.nrows() {
355            q[[row, i]] *= inv;
356        }
357        for j in (i + 1)..n {
358            let dot: f32 = (0..q.nrows()).map(|row| q[[row, i]] * q[[row, j]]).sum();
359            for row in 0..q.nrows() {
360                q[[row, j]] -= dot * q[[row, i]];
361            }
362        }
363    }
364    q
365}
366
367// ---------------------------------------------------------------------------
368// Tests
369// ---------------------------------------------------------------------------
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Scale `v` in place to unit L2 norm; vectors with norm <= 1e-10 are left
    /// unchanged to avoid dividing by (near) zero.
    fn l2_normalize(v: &mut [f32]) {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }

    /// The seeded rotation for d = 8 satisfies Q·Qᵀ ≈ I within 1e-5.
    #[test]
    fn rotation_is_orthogonal() {
        let r = generate_rotation(8, 42);
        let eye = r.dot(&r.t());
        for i in 0..8 {
            for j in 0..8 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (eye[[i, j]] - expected).abs() < 1e-5,
                    "Q×Qᵀ[{i},{j}] = {}, expected {expected}",
                    eye[[i, j]]
                );
            }
        }
    }

    /// Encoding an 8-dim vector yields one radius and one angle index per
    /// pair (4 of each).
    /// NOTE(review): despite the name, no decode step is exercised — only the
    /// output lengths are checked.
    #[test]
    fn encode_decode_roundtrip() {
        let codec = PolarCodec::new(8, 4, 42);
        let mut v = vec![0.3, -0.1, 0.5, 0.2, -0.4, 0.1, 0.3, -0.2];
        l2_normalize(&mut v);
        let code = codec.encode(&v);
        assert_eq!(code.radii.len(), 4);
        assert_eq!(code.angle_indices.len(), 4);
    }

    /// Encodes 1000 unit vectors at d = 768, 4 bits, and compares the top-10
    /// of the compressed scan against exact dot-product ranking.
    /// Only the recall threshold is asserted; timings are printed to stderr
    /// for information and depend on build profile and hardware.
    #[test]
    fn corpus_scan_recall_and_throughput() {
        let dim = 768;
        let n = 1000;
        let codec = PolarCodec::new(dim, 4, 42);

        // Generate deterministic, sin-based L2-normalized vectors (no RNG involved)
        let mut vecs = Array2::<f32>::zeros((n, dim));
        for i in 0..n {
            for d in 0..dim {
                vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
            }
            let norm: f32 = vecs.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
            for d in 0..dim {
                vecs[[i, d]] /= norm;
            }
        }

        // Encode to SoA corpus
        let t0 = std::time::Instant::now();
        let corpus = codec.encode_batch(&vecs);
        let encode_ms = t0.elapsed().as_secs_f64() * 1000.0;
        eprintln!(
            "encode {n} → SoA corpus: {encode_ms:.1}ms ({:.1}µs/vec)",
            encode_ms * 1000.0 / n as f64
        );

        // Query (deterministic, sin-based, normalized)
        let mut query = vec![0.0f32; dim];
        for d in 0..dim {
            query[d] = ((42 * 7 + d * 13) as f32).sin();
        }
        l2_normalize(&mut query);

        // Exact ranking: full-precision dot products, sorted descending
        let query_arr = Array1::from_vec(query.clone());
        let mut exact: Vec<(usize, f32)> =
            (0..n).map(|i| (i, vecs.row(i).dot(&query_arr))).collect();
        exact.sort_by(|a, b| b.1.total_cmp(&a.1));

        // TurboQuant corpus scan (query preparation and scan timed separately)
        let t1 = std::time::Instant::now();
        let qs = codec.prepare_query(&query);
        let prep_us = t1.elapsed().as_secs_f64() * 1e6;

        let t2 = std::time::Instant::now();
        let scores = codec.scan_corpus(&corpus, &qs);
        let scan_us = t2.elapsed().as_secs_f64() * 1e6;

        eprintln!(
            "prepare: {prep_us:.0}µs, scan {n}: {scan_us:.0}µs ({:.2}µs/vec)",
            scan_us / n as f64
        );
        // vectors per microsecond == millions of vectors per second
        eprintln!("scan throughput: {:.1}M vec/s", n as f64 / scan_us);

        // Recall@10: how many of the exact top-10 appear in the approximate top-10
        let mut approx: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
        approx.sort_by(|a, b| b.1.total_cmp(&a.1));
        let exact_top10: Vec<usize> = exact.iter().take(10).map(|(i, _)| *i).collect();
        let approx_top10: Vec<usize> = approx.iter().take(10).map(|(i, _)| *i).collect();
        let recall = exact_top10
            .iter()
            .filter(|i| approx_top10.contains(i))
            .count();
        eprintln!("Recall@10: {recall}/10");
        // Raw scan recall (no re-rank) is 4-7/10 for PolarQuant-only 4-bit.
        // With exact re-rank of top-100 (SearchIndex::rank_turboquant), recall is 10/10.
        assert!(
            recall >= 4,
            "raw scan recall should be >= 4/10, got {recall}/10"
        );
    }

    /// GPU vs CPU scan benchmark (Metal only).
    ///
    /// Asserts only that GPU scores match the CPU scan within 0.01; all
    /// timings are printed for information.
    #[test]
    #[cfg(feature = "metal")]
    fn metal_turboquant_scan() {
        let dim = 768;
        let n = 10_000;
        let codec = PolarCodec::new(dim, 4, 42);

        // Generate corpus (same deterministic sin-based construction as above)
        let mut vecs = Array2::<f32>::zeros((n, dim));
        for i in 0..n {
            for d in 0..dim {
                vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
            }
            let norm: f32 = vecs.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
            for d in 0..dim {
                vecs[[i, d]] /= norm;
            }
        }

        let corpus = codec.encode_batch(&vecs);
        let mut query = vec![0.0f32; dim];
        for d in 0..dim {
            query[d] = ((42 * 7 + d * 13) as f32).sin();
        }
        l2_normalize(&mut query);
        let qs = codec.prepare_query(&query);

        // CPU scan
        let t0 = std::time::Instant::now();
        let cpu_scores = codec.scan_corpus(&corpus, &qs);
        let cpu_us = t0.elapsed().as_secs_f64() * 1e6;

        // GPU scan — upload once, scan twice to measure warm vs cold
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();

        // Cold: this window times only the one-time corpus upload (per the
        // original note, it includes buffer creation); the first scan is
        // timed separately below as `warm_us`.
        let t_cold = std::time::Instant::now();
        let gpu_corpus = driver
            .turboquant_upload_corpus(&corpus.radii, &corpus.indices)
            .unwrap();
        let upload_us = t_cold.elapsed().as_secs_f64() * 1e6;

        // First scan against the already-uploaded corpus
        let t_warm = std::time::Instant::now();
        let gpu_scores = driver
            .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
            .unwrap();
        let warm_us = t_warm.elapsed().as_secs_f64() * 1e6;

        // Second scan — fully warm (centroid upload only)
        let t_hot = std::time::Instant::now();
        let _ = driver
            .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
            .unwrap();
        let hot_us = t_hot.elapsed().as_secs_f64() * 1e6;

        eprintln!("10K vectors:");
        eprintln!("  CPU:        {cpu_us:.0}µs ({:.1}M/s)", n as f64 / cpu_us);
        eprintln!("  GPU upload: {upload_us:.0}µs (one-time)");
        eprintln!(
            "  GPU warm:   {warm_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
            n as f64 / warm_us,
            cpu_us / warm_us
        );
        eprintln!(
            "  GPU hot:    {hot_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
            n as f64 / hot_us,
            cpu_us / hot_us
        );

        // Verify GPU matches CPU (approximate — f32 accumulation order differs)
        let mut max_diff = 0.0f32;
        for i in 0..n {
            let diff = (cpu_scores[i] - gpu_scores[i]).abs();
            if diff > max_diff {
                max_diff = diff;
            }
        }
        eprintln!("max CPU/GPU score diff: {max_diff:.6}");
        assert!(
            max_diff < 0.01,
            "GPU scores should match CPU within 0.01, got {max_diff}"
        );
    }
}