// ripvec_core/turbo_quant.rs

1//! TurboQuant: high-throughput compressed vector search.
2//!
3//! PolarQuant at 4 bits: rotated embedding pairs are encoded as (radius, angle_index).
4//! The scan uses a pre-computed centroid-query dot product table (24 KB, fits in L1)
5//! and streams sequentially through packed radii + indices — cache-line optimal.
6//!
7//! # Memory layout (SoA, not AoS)
8//!
9//! ```text
10//! CompressedCorpus:
11//!   radii:   [n × pairs] f32, contiguous — sequential streaming reads
12//!   indices: [n × pairs] u8,  contiguous — sequential streaming reads
13//!   (future: 4-bit packed indices → [n × pairs / 2] u8 for 2× index bandwidth)
14//! ```
15//!
16//! This layout enables:
17//! - GPU: one thread per vector, coalesced reads across threads
18//! - CPU NEON: process 4 pairs per SIMD iteration, amortize centroid loads
19//! - Cache: centroid table (24 KB) stays in L1 throughout the scan
20
21use std::f32::consts::PI;
22
23use ndarray::{Array1, Array2};
24use rand::SeedableRng;
25use rand_chacha::ChaCha8Rng;
26use rand_distr::{Distribution, StandardNormal};
27
28// ---------------------------------------------------------------------------
29// Compressed corpus — flat SoA layout for cache-friendly scanning
30// ---------------------------------------------------------------------------
31
/// Flat, contiguous compressed embeddings for maximum scan throughput.
///
/// Structure-of-arrays layout: all radii packed, then all indices packed.
/// No per-vector heap allocations, no pointer chasing — a scan streams both
/// arrays front to back in lockstep.
///
/// At 4-bit, d=768: 384 pairs × (4 + 1) bytes = 1920 bytes/vector.
/// For 100K vectors: 192 MB (vs 300 MB FP32 or 150 MB FP16).
pub struct CompressedCorpus {
    /// Number of vectors.
    pub n: usize,
    /// Number of pairs per vector (dim / 2).
    pub pairs: usize,
    /// Flat radii: `[n × pairs]` f32, row-major (vector `v` starts at `v * pairs`).
    pub radii: Vec<f32>,
    /// Flat angle indices: `[n × pairs]` u8, row-major; each value is < `levels`.
    pub indices: Vec<u8>,
}
49
/// Compressed representation of a single vector (for the old per-vector API).
///
/// Prefer [`CompressedCorpus`] for bulk scanning; this type carries one heap
/// allocation per field per vector.
#[derive(Clone)]
pub struct CompressedCode {
    /// Per-pair radii (f32), length = dim / 2.
    pub radii: Vec<f32>,
    /// Quantized angle indices in \[0, 2^bits), length = dim / 2.
    pub angle_indices: Vec<u8>,
}
58
59impl CompressedCode {
60    /// Approximate memory in bytes.
61    #[must_use]
62    pub fn encoded_bytes(&self) -> usize {
63        self.radii.len() * 4 + self.angle_indices.len()
64    }
65}
66
67// ---------------------------------------------------------------------------
68// PolarCodec — encode, prepare query, scan
69// ---------------------------------------------------------------------------
70
/// PolarQuant codec: batch encode, query preparation, and high-throughput scan.
///
/// Construction is deterministic for a given `(dim, bits, seed)` triple, so the
/// same rotation (and therefore the same codes) can be reproduced later.
pub struct PolarCodec {
    /// Embedding dimensionality; must be even.
    dim: usize,
    #[expect(dead_code, reason = "stored for serialization / reconstruction")]
    bits: u8,
    /// Number of quantization levels = 2^bits.
    levels: usize,
    /// Number of coordinate pairs per vector (dim / 2).
    pairs: usize,
    /// Row-major orthogonal rotation matrix \[dim × dim\].
    rotation: Array2<f32>,
    /// Pre-computed cos for each quantized angle level (length = levels).
    cos_table: Vec<f32>,
    /// Pre-computed sin for each quantized angle level (length = levels).
    sin_table: Vec<f32>,
}
84
85impl PolarCodec {
86    /// Create a new codec.
87    ///
88    /// # Panics
89    ///
90    /// Panics if `dim` is 0, odd, or `bits` is 0 or > 8.
91    #[must_use]
92    pub fn new(dim: usize, bits: u8, seed: u64) -> Self {
93        assert!(
94            dim > 0 && dim.is_multiple_of(2),
95            "dim must be even and non-zero"
96        );
97        assert!(bits > 0 && bits <= 8, "bits must be 1..=8");
98
99        let levels = 1usize << bits;
100        let pairs = dim / 2;
101        let rotation = generate_rotation(dim, seed);
102
103        let mut cos_table = Vec::with_capacity(levels);
104        let mut sin_table = Vec::with_capacity(levels);
105        for j in 0..levels {
106            let theta = (j as f32 / levels as f32) * 2.0 * PI - PI;
107            cos_table.push(theta.cos());
108            sin_table.push(theta.sin());
109        }
110
111        Self {
112            dim,
113            bits,
114            levels,
115            pairs,
116            rotation,
117            cos_table,
118            sin_table,
119        }
120    }
121
    /// Number of pairs per vector.
    ///
    /// Always `dim / 2`; this is also the length of each code's `radii`
    /// and `angle_indices` arrays.
    #[must_use]
    pub fn pairs(&self) -> usize {
        self.pairs
    }
127
128    /// Encode a single vector (convenience, allocates).
129    #[must_use]
130    pub fn encode(&self, vector: &[f32]) -> CompressedCode {
131        assert_eq!(vector.len(), self.dim);
132        let x = Array1::from_vec(vector.to_vec());
133        let rotated = self.rotation.dot(&x);
134
135        let mut radii = Vec::with_capacity(self.pairs);
136        let mut angle_indices = Vec::with_capacity(self.pairs);
137        for i in 0..self.pairs {
138            let (r, idx) = self.encode_pair(rotated[2 * i], rotated[2 * i + 1]);
139            radii.push(r);
140            angle_indices.push(idx);
141        }
142        CompressedCode {
143            radii,
144            angle_indices,
145        }
146    }
147
148    /// Batch-encode into a flat [`CompressedCorpus`] (SoA layout).
149    ///
150    /// Uses BLAS for the rotation: `rotated = vectors × Πᵀ` (one GEMM).
151    /// Quantization is scalar but cache-friendly (sequential writes).
152    #[must_use]
153    pub fn encode_batch(&self, vectors: &Array2<f32>) -> CompressedCorpus {
154        assert_eq!(vectors.ncols(), self.dim);
155        let n = vectors.nrows();
156
157        // Batch rotation via BLAS: [n, dim] × [dim, dim]ᵀ → [n, dim]
158        let rotated = vectors.dot(&self.rotation.t());
159
160        let total = n * self.pairs;
161        let mut radii = Vec::with_capacity(total);
162        let mut indices = Vec::with_capacity(total);
163
164        for row in 0..n {
165            for i in 0..self.pairs {
166                let (r, idx) = self.encode_pair(rotated[[row, 2 * i]], rotated[[row, 2 * i + 1]]);
167                radii.push(r);
168                indices.push(idx);
169            }
170        }
171
172        CompressedCorpus {
173            n,
174            pairs: self.pairs,
175            radii,
176            indices,
177        }
178    }
179
180    /// Also produce the old per-vector codes (for backward compat with index.rs).
181    #[must_use]
182    pub fn encode_batch_codes(&self, vectors: &Array2<f32>) -> Vec<CompressedCode> {
183        let corpus = self.encode_batch(vectors);
184        (0..corpus.n)
185            .map(|v| {
186                let off = v * corpus.pairs;
187                CompressedCode {
188                    radii: corpus.radii[off..off + corpus.pairs].to_vec(),
189                    angle_indices: corpus.indices[off..off + corpus.pairs].to_vec(),
190                }
191            })
192            .collect()
193    }
194
195    /// Prepare query-dependent centroid lookup table.
196    ///
197    /// Cost: one 768×768 matvec + 384×16 multiply-adds = ~0.08ms.
198    /// The returned [`QueryState`] is reused for ALL vectors in the scan.
199    #[must_use]
200    pub fn prepare_query(&self, query: &[f32]) -> QueryState {
201        assert_eq!(query.len(), self.dim);
202        let q = Array1::from_vec(query.to_vec());
203        let rotated = self.rotation.dot(&q);
204
205        // centroid_q[pair * levels + level] = q_a·cos(θ_level) + q_b·sin(θ_level)
206        // Layout: pairs × levels, contiguous. Fits in L1 (384 × 16 × 4 = 24 KB).
207        let mut centroid_q = vec![0.0f32; self.pairs * self.levels];
208        for i in 0..self.pairs {
209            let q_a = rotated[2 * i];
210            let q_b = rotated[2 * i + 1];
211            let base = i * self.levels;
212            for j in 0..self.levels {
213                centroid_q[base + j] = q_a * self.cos_table[j] + q_b * self.sin_table[j];
214            }
215        }
216
217        QueryState {
218            centroid_q,
219            pairs: self.pairs,
220            levels: self.levels,
221        }
222    }
223
224    /// High-throughput scan of a [`CompressedCorpus`] against a prepared query.
225    ///
226    /// Returns approximate inner product scores for all vectors.
227    /// Memory access: sequential streaming through radii + indices (cache-optimal).
228    /// Centroid table: 24 KB, stays in L1 throughout.
229    ///
230    /// At 100K vectors, d=768: ~3.3ms on CPU, ~0.1ms on GPU (future Metal kernel).
231    #[must_use]
232    pub fn scan_corpus(&self, corpus: &CompressedCorpus, qs: &QueryState) -> Vec<f32> {
233        let n = corpus.n;
234        let pairs = corpus.pairs;
235        let mut scores = vec![0.0f32; n];
236
237        // Hot loop: sequential access to radii + indices,
238        // random-but-L1-hot access to centroid table.
239        for v in 0..n {
240            let base = v * pairs;
241            let mut score = 0.0f32;
242
243            // Process 4 pairs per iteration (manual unroll for ILP).
244            let chunks = pairs / 4;
245            let remainder = pairs % 4;
246
247            for c in 0..chunks {
248                let i = base + c * 4;
249                let i0 = corpus.indices[i] as usize;
250                let i1 = corpus.indices[i + 1] as usize;
251                let i2 = corpus.indices[i + 2] as usize;
252                let i3 = corpus.indices[i + 3] as usize;
253
254                let p = c * 4;
255                score += corpus.radii[i] * qs.centroid_q[p * qs.levels + i0];
256                score += corpus.radii[i + 1] * qs.centroid_q[(p + 1) * qs.levels + i1];
257                score += corpus.radii[i + 2] * qs.centroid_q[(p + 2) * qs.levels + i2];
258                score += corpus.radii[i + 3] * qs.centroid_q[(p + 3) * qs.levels + i3];
259            }
260            for r in 0..remainder {
261                let i = base + chunks * 4 + r;
262                let p = chunks * 4 + r;
263                let j = corpus.indices[i] as usize;
264                score += corpus.radii[i] * qs.centroid_q[p * qs.levels + j];
265            }
266
267            scores[v] = score;
268        }
269
270        scores
271    }
272
273    /// Scan per-vector codes (old API, for backward compat).
274    #[must_use]
275    pub fn batch_scan(&self, codes: &[CompressedCode], qs: &QueryState) -> Vec<f32> {
276        codes
277            .iter()
278            .map(|code| {
279                let mut score = 0.0f32;
280                for i in 0..qs.pairs {
281                    let j = code.angle_indices[i] as usize;
282                    score += code.radii[i] * qs.centroid_q[i * qs.levels + j];
283                }
284                score
285            })
286            .collect()
287    }
288
    /// Quantize one rotated coordinate pair `(a, b)` into polar form.
    ///
    /// Returns the exact radius `r = √(a² + b²)` and the floor-quantized
    /// angle index in `[0, levels)`.
    #[inline]
    fn encode_pair(&self, a: f32, b: f32) -> (f32, u8) {
        let r = (a * a + b * b).sqrt();
        // atan2 returns θ ∈ [-π, π]; shift/scale to [0, 1] before quantizing.
        let theta = b.atan2(a);
        let normalized = (theta + PI) / (2.0 * PI);
        // `as usize` truncates (floor for non-negative values); the `min`
        // clamp handles the θ = +π endpoint where `normalized` reaches 1.0.
        let idx = ((normalized * self.levels as f32) as usize).min(self.levels - 1);
        (r, idx as u8)
    }
297}
298
/// Pre-computed query state for fast scanning.
///
/// Built once per query by [`PolarCodec::prepare_query`] and reused across
/// every vector in the scan.
pub struct QueryState {
    /// Flat `[pairs × levels]` centroid-query dot products (24 KB at d=768, 4-bit).
    /// Entry `[pair * levels + level]` = q_a·cos(θ_level) + q_b·sin(θ_level).
    pub centroid_q: Vec<f32>,
    /// Number of pairs.
    pub pairs: usize,
    /// Number of quantization levels.
    pub levels: usize,
}
308
309// ---------------------------------------------------------------------------
310// Rotation matrix generation (seeded, deterministic)
311// ---------------------------------------------------------------------------
312
313/// Generate a d×d orthogonal matrix via QR on a seeded Gaussian matrix.
314fn generate_rotation(dim: usize, seed: u64) -> Array2<f32> {
315    let mut rng = ChaCha8Rng::seed_from_u64(seed);
316    let mut data = Vec::with_capacity(dim * dim);
317    for _ in 0..(dim * dim) {
318        data.push(StandardNormal.sample(&mut rng));
319    }
320    let a = Array2::from_shape_vec((dim, dim), data).expect("shape matches data length");
321    gram_schmidt_qr(a)
322}
323
/// Modified Gram-Schmidt → Q (orthogonal).
///
/// Orthonormalizes the columns of `q` in place, left to right: normalize
/// column `i`, then immediately subtract its projection from every later
/// column (the "modified" right-looking variant, more numerically stable
/// than classical Gram-Schmidt).
fn gram_schmidt_qr(mut q: Array2<f32>) -> Array2<f32> {
    let n = q.ncols();
    for i in 0..n {
        let norm: f32 = q.column(i).iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm < 1e-10 {
            // Degenerate (near-zero) column: skip rather than divide by ~0 and
            // spray NaNs. Result is then only approximately orthogonal, which
            // is vanishingly unlikely for Gaussian input.
            continue;
        }
        let inv = 1.0 / norm;
        for row in 0..q.nrows() {
            q[[row, i]] *= inv;
        }
        // Project the now unit-length column i out of all columns to its right.
        for j in (i + 1)..n {
            let dot: f32 = (0..q.nrows()).map(|row| q[[row, i]] * q[[row, j]]).sum();
            for row in 0..q.nrows() {
                q[[row, j]] -= dot * q[[row, i]];
            }
        }
    }
    q
}
345
346// ---------------------------------------------------------------------------
347// Tests
348// ---------------------------------------------------------------------------
349
350#[cfg(test)]
351mod tests {
352    use super::*;
353
354    fn l2_normalize(v: &mut [f32]) {
355        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
356        if norm > 1e-10 {
357            for x in v.iter_mut() {
358                *x /= norm;
359            }
360        }
361    }
362
363    #[test]
364    fn rotation_is_orthogonal() {
365        let r = generate_rotation(8, 42);
366        let eye = r.dot(&r.t());
367        for i in 0..8 {
368            for j in 0..8 {
369                let expected = if i == j { 1.0 } else { 0.0 };
370                assert!(
371                    (eye[[i, j]] - expected).abs() < 1e-5,
372                    "Q×Qᵀ[{i},{j}] = {}, expected {expected}",
373                    eye[[i, j]]
374                );
375            }
376        }
377    }
378
379    #[test]
380    fn encode_decode_roundtrip() {
381        let codec = PolarCodec::new(8, 4, 42);
382        let mut v = vec![0.3, -0.1, 0.5, 0.2, -0.4, 0.1, 0.3, -0.2];
383        l2_normalize(&mut v);
384        let code = codec.encode(&v);
385        assert_eq!(code.radii.len(), 4);
386        assert_eq!(code.angle_indices.len(), 4);
387    }
388
    /// End-to-end check: encode a synthetic corpus, scan it, and verify the
    /// approximate top-10 overlaps the exact top-10. Timing numbers are
    /// printed for information only — nothing asserts on them.
    #[test]
    fn corpus_scan_recall_and_throughput() {
        let dim = 768;
        let n = 1000;
        let codec = PolarCodec::new(dim, 4, 42);

        // Generate deterministic pseudo-random L2-normalized vectors
        // (sin of a linear index mix — reproducible without an RNG).
        let mut vecs = Array2::<f32>::zeros((n, dim));
        for i in 0..n {
            for d in 0..dim {
                vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
            }
            let norm: f32 = vecs.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
            for d in 0..dim {
                vecs[[i, d]] /= norm;
            }
        }

        // Encode to SoA corpus
        let t0 = std::time::Instant::now();
        let corpus = codec.encode_batch(&vecs);
        let encode_ms = t0.elapsed().as_secs_f64() * 1000.0;
        eprintln!(
            "encode {n} → SoA corpus: {encode_ms:.1}ms ({:.1}µs/vec)",
            encode_ms * 1000.0 / n as f64
        );

        // Query, built the same deterministic way as the corpus
        let mut query = vec![0.0f32; dim];
        for d in 0..dim {
            query[d] = ((42 * 7 + d * 13) as f32).sin();
        }
        l2_normalize(&mut query);

        // Exact ranking (brute-force dot products) as ground truth
        let query_arr = Array1::from_vec(query.clone());
        let mut exact: Vec<(usize, f32)> =
            (0..n).map(|i| (i, vecs.row(i).dot(&query_arr))).collect();
        exact.sort_by(|a, b| b.1.total_cmp(&a.1));

        // TurboQuant corpus scan
        let t1 = std::time::Instant::now();
        let qs = codec.prepare_query(&query);
        let prep_us = t1.elapsed().as_secs_f64() * 1e6;

        let t2 = std::time::Instant::now();
        let scores = codec.scan_corpus(&corpus, &qs);
        let scan_us = t2.elapsed().as_secs_f64() * 1e6;

        eprintln!(
            "prepare: {prep_us:.0}µs, scan {n}: {scan_us:.0}µs ({:.2}µs/vec)",
            scan_us / n as f64
        );
        // vectors per µs == millions of vectors per second
        eprintln!("scan throughput: {:.1}M vec/s", n as f64 / scan_us);

        // Recall@10: overlap between approximate and exact top-10 sets
        let mut approx: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
        approx.sort_by(|a, b| b.1.total_cmp(&a.1));
        let exact_top10: Vec<usize> = exact.iter().take(10).map(|(i, _)| *i).collect();
        let approx_top10: Vec<usize> = approx.iter().take(10).map(|(i, _)| *i).collect();
        let recall = exact_top10
            .iter()
            .filter(|i| approx_top10.contains(i))
            .count();
        eprintln!("Recall@10: {recall}/10");
        // Raw scan recall (no re-rank) is 4-7/10 for PolarQuant-only 4-bit.
        // With exact re-rank of top-100 (SearchIndex::rank_turboquant), recall is 10/10.
        assert!(
            recall >= 4,
            "raw scan recall should be >= 4/10, got {recall}/10"
        );
    }
461
    /// GPU vs CPU scan benchmark (Metal only).
    ///
    /// Compares CPU `scan_corpus` against the Metal driver's scan on the same
    /// corpus, prints cold/warm/hot timings, and asserts the score arrays
    /// agree within 0.01 (f32 accumulation order differs between backends).
    #[test]
    #[cfg(feature = "metal")]
    fn metal_turboquant_scan() {
        let dim = 768;
        let n = 10_000;
        let codec = PolarCodec::new(dim, 4, 42);

        // Deterministic synthetic corpus (same construction as the CPU test).
        let mut vecs = Array2::<f32>::zeros((n, dim));
        for i in 0..n {
            for d in 0..dim {
                vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
            }
            let norm: f32 = vecs.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
            for d in 0..dim {
                vecs[[i, d]] /= norm;
            }
        }

        let corpus = codec.encode_batch(&vecs);
        let mut query = vec![0.0f32; dim];
        for d in 0..dim {
            query[d] = ((42 * 7 + d * 13) as f32).sin();
        }
        l2_normalize(&mut query);
        let qs = codec.prepare_query(&query);

        // CPU scan (reference)
        let t0 = std::time::Instant::now();
        let cpu_scores = codec.scan_corpus(&corpus, &qs);
        let cpu_us = t0.elapsed().as_secs_f64() * 1e6;

        // GPU scan — upload once, scan twice to measure warm vs cold
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();

        // Cold: upload + scan (includes buffer creation)
        let t_cold = std::time::Instant::now();
        let gpu_corpus = driver
            .turboquant_upload_corpus(&corpus.radii, &corpus.indices)
            .unwrap();
        let upload_us = t_cold.elapsed().as_secs_f64() * 1e6;

        let t_warm = std::time::Instant::now();
        let gpu_scores = driver
            .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
            .unwrap();
        let warm_us = t_warm.elapsed().as_secs_f64() * 1e6;

        // Second scan — fully warm (centroid upload only)
        let t_hot = std::time::Instant::now();
        let _ = driver
            .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
            .unwrap();
        let hot_us = t_hot.elapsed().as_secs_f64() * 1e6;

        // vectors per µs == millions of vectors per second
        eprintln!("10K vectors:");
        eprintln!("  CPU:        {cpu_us:.0}µs ({:.1}M/s)", n as f64 / cpu_us);
        eprintln!("  GPU upload: {upload_us:.0}µs (one-time)");
        eprintln!(
            "  GPU warm:   {warm_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
            n as f64 / warm_us,
            cpu_us / warm_us
        );
        eprintln!(
            "  GPU hot:    {hot_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
            n as f64 / hot_us,
            cpu_us / hot_us
        );

        // Verify GPU matches CPU (approximate — f32 accumulation order differs)
        let mut max_diff = 0.0f32;
        for i in 0..n {
            let diff = (cpu_scores[i] - gpu_scores[i]).abs();
            if diff > max_diff {
                max_diff = diff;
            }
        }
        eprintln!("max CPU/GPU score diff: {max_diff:.6}");
        assert!(
            max_diff < 0.01,
            "GPU scores should match CPU within 0.01, got {max_diff}"
        );
    }
546}