Skip to main content

turbo_quant/
turbo.rs

1//! TurboQuant: profile-selected vector compression using PolarQuant with optional QJL.
2//!
3//! TurboQuant can split its bit budget across two stages:
4//!
//! 1. **PolarQuant stage** (b−1 bits with QJL enabled, all b bits in
//!    `PolarOnly` mode): Compress the vector via polar encoding.
//!    Captures the main signal with high fidelity.
7//!
8//! 2. **Optional QJL stage** (1 bit per projection): Apply the Quantized Johnson-Lindenstrauss
9//!    transform to the *residual* (original minus PolarQuant reconstruction).
10//!    This provides a residual correction path whose quality must be benchmarked
11//!    for the workload.
12//!
13//! # Inner Product Estimation
14//!
//! The combined estimator for ⟨x, y⟩ given TurboCode(x) and raw query y
//! (the QJL term is zero in `PolarOnly` mode):
16//!
17//! ```text
18//! ⟨x, y⟩ ≈ IP_polar(code, y) + IP_qjl(residual_sketch, y)
19//! ```
20//!
21//! This estimator is approximate; retrieval quality still needs
22//! workload-specific recall/ranking measurement.
23//!
24//! # Reference
25//!
26//! TurboQuant-style two-stage compression with a polar code and residual
27//! quantized Johnson-Lindenstrauss sketch.
28
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31
32use crate::{
33    error::{Result, TurboQuantError},
34    polar::{PolarCode, PolarQuantizer},
35    profile::{CodecProfileV1, CompressionReceiptV1, ValidationState},
36    qjl::{QjlProjectedQuery, QjlQuantizer, QjlSketch},
37    rotation::RotationKind,
38};
39
40/// TurboQuant mode.
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
42pub enum TurboMode {
43    /// PolarQuant only. No QJL residual sketch is present.
44    PolarOnly,
45    /// PolarQuant plus a QJL residual sketch.
46    PolarWithQjl,
47}
48
49/// A TurboQuant-compressed vector: polar code plus a QJL residual sketch.
50#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
51pub struct TurboCode {
52    /// PolarQuant code capturing the main signal.
53    pub polar_code: PolarCode,
54    /// QJL sketch of the reconstruction residual (1 bit per projection).
55    pub residual_sketch: QjlSketch,
56}
57
58impl TurboCode {
59    /// Total serialized payload bytes used by this code.
60    pub fn encoded_bytes(&self) -> usize {
61        self.polar_code.encoded_bytes() + self.residual_sketch.encoded_bytes()
62    }
63
64    /// Compression ratio relative to f32 storage of the original vector.
65    pub fn compression_ratio(&self) -> f32 {
66        let original = self.polar_code.dim * std::mem::size_of::<f32>();
67        original as f32 / self.encoded_bytes() as f32
68    }
69
70    /// Validate this code against an expected TurboQuant profile.
71    pub fn validate_for(
72        &self,
73        dim: usize,
74        bits: u8,
75        projections: usize,
76        mode: TurboMode,
77    ) -> Result<()> {
78        let polar_bits = match mode {
79            TurboMode::PolarOnly => bits,
80            TurboMode::PolarWithQjl => bits.saturating_sub(1),
81        };
82        self.polar_code.validate_for(dim, polar_bits)?;
83        match mode {
84            TurboMode::PolarOnly => self.residual_sketch.validate_for(dim, 0),
85            TurboMode::PolarWithQjl => self.residual_sketch.validate_for(dim, projections),
86        }
87    }
88}
89
90/// TurboQuant compressor: encodes vectors and estimates inner products.
91///
92/// Configuration `(dim, bits, projections, seed)` fully determines the
93/// quantizer state. Only these four values need to be persisted; all internal
94/// matrices are regenerated on demand.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct TurboQuantizer {
97    dim: usize,
98    /// Total bits per value: (bits-1) go to PolarQuant, 1 to QJL.
99    bits: u8,
100    /// Number of QJL projections for the residual sketch.
101    projections: usize,
102    seed: u64,
103    mode: TurboMode,
104    polar: PolarQuantizer,
105    qjl: Option<QjlQuantizer>,
106}
107
108/// Query state prepared once for scoring multiple TurboQuant codes.
109#[derive(Debug, Clone, PartialEq)]
110pub struct TurboProjectedQuery {
111    polar: crate::polar::PolarProjectedQuery,
112    qjl: Option<QjlProjectedQuery>,
113}
114
115impl TurboQuantizer {
116    /// Create a new TurboQuant compressor.
117    ///
118    /// - `dim`: vector dimension (must be even, non-zero)
119    /// - `bits`: total bit budget per scalar (2–16). PolarQuant uses `bits-1`,
120    ///   QJL uses 1 bit on the residual.
121    /// - `projections`: QJL sketch dimension. Rule of thumb: `dim / 4` to `dim / 2`.
122    ///   More projections reduce variance but increase sketch size.
123    /// - `seed`: deterministic seed for all random matrices
124    pub fn new(dim: usize, bits: u8, projections: usize, seed: u64) -> Result<Self> {
125        Self::new_with_mode(dim, bits, projections, seed, TurboMode::PolarWithQjl)
126    }
127
128    /// Create a new TurboQuant compressor with explicit QJL mode.
129    pub fn new_with_mode(
130        dim: usize,
131        bits: u8,
132        projections: usize,
133        seed: u64,
134        mode: TurboMode,
135    ) -> Result<Self> {
136        Self::new_with_mode_and_rotation(dim, bits, projections, seed, mode, RotationKind::Auto)
137    }
138
139    /// Create a new TurboQuant compressor with explicit QJL and rotation policies.
140    pub fn new_with_mode_and_rotation(
141        dim: usize,
142        bits: u8,
143        projections: usize,
144        seed: u64,
145        mode: TurboMode,
146        rotation_kind: RotationKind,
147    ) -> Result<Self> {
148        if dim == 0 {
149            return Err(TurboQuantError::ZeroDimension);
150        }
151        if dim % 2 != 0 {
152            return Err(TurboQuantError::OddDimension { got: dim });
153        }
154        let valid_bits = match mode {
155            TurboMode::PolarOnly => (1..=16).contains(&bits),
156            TurboMode::PolarWithQjl => (2..=16).contains(&bits),
157        };
158        if !valid_bits {
159            return Err(TurboQuantError::InvalidBitWidth { got: bits });
160        }
161        if mode == TurboMode::PolarWithQjl && projections == 0 {
162            return Err(TurboQuantError::ZeroProjectionCount);
163        }
164
165        // Separate seeds for polar and QJL so they use independent random matrices.
166        let polar_seed = seed;
167        let qjl_seed = seed.wrapping_add(0xCAFE_BABE_0000_0001);
168
169        let polar_bits = match mode {
170            TurboMode::PolarOnly => bits,
171            TurboMode::PolarWithQjl => bits - 1,
172        };
173        let polar = PolarQuantizer::new_with_rotation(dim, polar_bits, polar_seed, rotation_kind)?;
174        let qjl = match mode {
175            TurboMode::PolarOnly => None,
176            TurboMode::PolarWithQjl => Some(QjlQuantizer::new(dim, projections, qjl_seed)?),
177        };
178
179        Ok(Self {
180            dim,
181            bits,
182            projections,
183            seed,
184            mode,
185            polar,
186            qjl,
187        })
188    }
189
190    /// Create a TurboQuant compressor using dense QR reference rotation.
191    pub fn new_with_stored_rotation(
192        dim: usize,
193        bits: u8,
194        projections: usize,
195        seed: u64,
196    ) -> Result<Self> {
197        Self::new_with_mode_and_rotation(
198            dim,
199            bits,
200            projections,
201            seed,
202            TurboMode::PolarWithQjl,
203            RotationKind::StoredQr,
204        )
205    }
206
207    /// The vector dimension this quantizer operates on.
208    pub fn dim(&self) -> usize {
209        self.dim
210    }
211
212    /// Total bit budget per scalar value.
213    pub fn bits(&self) -> u8 {
214        self.bits
215    }
216
217    /// Number of QJL projections for the residual sketch.
218    pub fn projections(&self) -> usize {
219        self.projections
220    }
221
222    /// Deterministic seed used to derive TurboQuant internal projection state.
223    pub fn seed(&self) -> u64 {
224        self.seed
225    }
226
227    /// TurboQuant mode. QJL is optional and must be benchmark-gated by callers.
228    pub fn mode(&self) -> TurboMode {
229        self.mode
230    }
231
232    /// Resolved PolarQuant rotation backend.
233    pub fn rotation_kind(&self) -> RotationKind {
234        self.polar.rotation_kind()
235    }
236
237    /// Stable profile for this quantizer.
238    pub fn profile(&self) -> CodecProfileV1 {
239        CodecProfileV1::turbo(
240            self.dim,
241            self.bits,
242            self.projections,
243            self.seed,
244            self.mode == TurboMode::PolarWithQjl,
245            self.polar.rotation_kind_label(),
246        )
247    }
248
249    /// Encode a vector into a [`TurboCode`].
250    ///
251    /// # Steps
252    /// 1. Compress via PolarQuant (b-1 bits).
253    /// 2. Reconstruct the PolarQuant approximation.
254    /// 3. Compute residual = original - reconstruction.
255    /// 4. Sketch the residual with QJL.
256    pub fn encode(&self, vector: &[f32]) -> Result<TurboCode> {
257        if vector.len() != self.dim {
258            return Err(TurboQuantError::DimensionMismatch {
259                expected: self.dim,
260                got: vector.len(),
261            });
262        }
263        check_finite_vector(vector)?;
264
265        let polar_code = self.polar.encode(vector)?;
266
267        // Reconstruct to get the residual.
268        let reconstruction = self.polar.decode(&polar_code)?;
269        let residual: Vec<f32> = vector
270            .iter()
271            .zip(reconstruction.iter())
272            .map(|(orig, rec)| orig - rec)
273            .collect();
274
275        let residual_sketch = match &self.qjl {
276            Some(qjl) => qjl.sketch(&residual)?,
277            None => QjlSketch {
278                dim: self.dim,
279                projections: 0,
280                signs: Vec::new(),
281            },
282        };
283
284        Ok(TurboCode {
285            polar_code,
286            residual_sketch,
287        })
288    }
289
290    /// Encode a vector and return a receipt bound to the quantizer profile.
291    pub fn encode_with_receipt(
292        &self,
293        vector: &[f32],
294        source_digest: Option<String>,
295    ) -> Result<(TurboCode, CompressionReceiptV1)> {
296        let code = self.encode(vector)?;
297        let receipt = CompressionReceiptV1::new(
298            self.profile(),
299            source_digest,
300            vector.len(),
301            code.encoded_bytes(),
302            ValidationState::Validated,
303        );
304        Ok((code, receipt))
305    }
306
307    /// Encode a batch of vectors using the same quantizer profile.
308    pub fn encode_batch(&self, vectors: &[&[f32]]) -> Result<Vec<TurboCode>> {
309        vectors.iter().map(|vector| self.encode(vector)).collect()
310    }
311
312    /// Estimate ⟨original_vector, query⟩ from a [`TurboCode`] and raw query.
313    ///
314    /// Combines the PolarQuant inner product estimate with the optional QJL
315    /// residual correction.
316    pub fn inner_product_estimate(&self, code: &TurboCode, query: &[f32]) -> Result<f32> {
317        let projected = self.prepare_query(query)?;
318        self.inner_product_estimate_prepared(code, &projected)
319    }
320
321    /// Prepare a query once for repeated TurboQuant scoring.
322    pub fn prepare_query(&self, query: &[f32]) -> Result<TurboProjectedQuery> {
323        if query.len() != self.dim {
324            return Err(TurboQuantError::DimensionMismatch {
325                expected: self.dim,
326                got: query.len(),
327            });
328        }
329        check_finite_vector(query)?;
330        Ok(TurboProjectedQuery {
331            polar: self.polar.project_query(query)?,
332            qjl: match &self.qjl {
333                Some(qjl) => Some(qjl.project_query(query)?),
334                None => None,
335            },
336        })
337    }
338
339    /// Estimate inner product using a prepared query.
340    pub fn inner_product_estimate_prepared(
341        &self,
342        code: &TurboCode,
343        query: &TurboProjectedQuery,
344    ) -> Result<f32> {
345        code.validate_for(self.dim, self.bits, self.projections, self.mode)?;
346
347        let polar_estimate = self
348            .polar
349            .inner_product_estimate_with_projected_query(&code.polar_code, &query.polar)?;
350        let qjl_correction = match (&self.qjl, &query.qjl) {
351            (Some(qjl), Some(qjl_query)) => {
352                qjl.inner_product_estimate_with_projected_query(&code.residual_sketch, qjl_query)?
353            }
354            (None, None) => 0.0,
355            _ => {
356                return Err(TurboQuantError::MalformedCode {
357                    reason: "TurboQuant QJL mode/query/code mismatch".into(),
358                });
359            }
360        };
361
362        let score = polar_estimate + qjl_correction;
363        if !score.is_finite() {
364            return Err(TurboQuantError::MalformedCode {
365                reason: "turbo score is not finite".into(),
366            });
367        }
368        Ok(score)
369    }
370
371    /// Score a batch of codes against a prepared query.
372    pub fn score_batch_prepared(
373        &self,
374        query: &TurboProjectedQuery,
375        codes: &[TurboCode],
376    ) -> Result<Vec<f32>> {
377        codes
378            .iter()
379            .map(|code| self.inner_product_estimate_prepared(code, query))
380            .collect()
381    }
382
383    /// Estimate squared L2 distance between the encoded vector and query.
384    ///
385    /// Uses: ||x - y||² = ||x||² + ||y||² - 2⟨x, y⟩.
386    /// The code's norm is derived from stored polar radii (lossless).
387    pub fn l2_distance_estimate(&self, code: &TurboCode, query: &[f32]) -> Result<f32> {
388        let ip = self.inner_product_estimate(code, query)?;
389        let query_norm_sq: f32 = query.iter().map(|x| x * x).sum();
390        let code_norm_sq: f32 = code.polar_code.radii.iter().map(|r| r * r).sum();
391        let distance = (query_norm_sq + code_norm_sq - 2.0 * ip).max(0.0);
392        if !distance.is_finite() {
393            return Err(TurboQuantError::MalformedCode {
394                reason: "turbo l2 distance is not finite".into(),
395            });
396        }
397        Ok(distance)
398    }
399
400    /// Decode the TurboCode to an approximate reconstruction.
401    ///
402    /// Note: this only reconstructs the PolarQuant component. The QJL sketch
403    /// is designed for inner product correction, not reconstruction.
404    pub fn decode_approximate(&self, code: &TurboCode) -> Result<Vec<f32>> {
405        code.validate_for(self.dim, self.bits, self.projections, self.mode)?;
406        self.polar.decode(&code.polar_code)
407    }
408
409    /// Decode a batch of TurboCodes to approximate reconstructions in one
410    /// call. Bit-exact identical to `decode_approximate` for each code
411    /// in turn; the win is amortizing the per-call branch / lookup
412    /// overhead and keeping the rotation's signs (or matrix) hot in
413    /// cache across the whole batch.
414    pub fn decode_approximate_batch(&self, codes: &[TurboCode]) -> Result<Vec<Vec<f32>>> {
415        for code in codes {
416            code.validate_for(self.dim, self.bits, self.projections, self.mode)?;
417        }
418        let polar_refs: Vec<PolarCode> = codes.iter().map(|c| c.polar_code.clone()).collect();
419        self.polar.decode_batch(&polar_refs)
420    }
421
422    /// Encode a vector into deterministic TurboQuant wire bytes.
423    pub fn encode_to_bytes(&self, vector: &[f32]) -> Result<Vec<u8>> {
424        let code = self.encode(vector)?;
425        crate::wire::TurboCodeWireV1::encode(&code, self)
426    }
427
428    /// Decode deterministic TurboQuant wire bytes into a validated code.
429    pub fn decode_code_from_bytes(&self, bytes: &[u8]) -> Result<TurboCode> {
430        crate::wire::TurboCodeWireV1::decode(bytes, self)
431    }
432
433    /// Score deterministic TurboQuant wire bytes against a raw query.
434    pub fn score_inner_product_from_bytes(&self, bytes: &[u8], query: &[f32]) -> Result<f32> {
435        let code = self.decode_code_from_bytes(bytes)?;
436        let prepared = self.prepare_query(query)?;
437        self.inner_product_estimate_prepared(&code, &prepared)
438    }
439
440    /// Summary statistics for a batch of encoded vectors.
441    pub fn batch_stats(&self, codes: &[TurboCode]) -> BatchStats {
442        let total_bytes: usize = codes.iter().map(|c| c.encoded_bytes()).sum();
443        let original_bytes = codes.len() * self.dim * std::mem::size_of::<f32>();
444        BatchStats {
445            count: codes.len(),
446            total_encoded_bytes: total_bytes,
447            total_original_bytes: original_bytes,
448            compression_ratio: if total_bytes > 0 {
449                original_bytes as f32 / total_bytes as f32
450            } else {
451                0.0
452            },
453        }
454    }
455}
456
457fn check_finite_vector(vector: &[f32]) -> Result<()> {
458    if let Some((index, _)) = vector
459        .iter()
460        .enumerate()
461        .find(|(_, value)| !value.is_finite())
462    {
463        return Err(TurboQuantError::NonFiniteInput { index });
464    }
465    Ok(())
466}
467
468/// Compression statistics for a batch of encoded vectors.
469#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
470pub struct BatchStats {
471    pub count: usize,
472    pub total_encoded_bytes: usize,
473    pub total_original_bytes: usize,
474    pub compression_ratio: f32,
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::{Distribution, StandardNormal};
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect()
    }

    #[test]
    fn encode_is_deterministic() {
        let q = TurboQuantizer::new(16, 8, 16, 42).unwrap();
        let x = random_vector(16, 1);
        let c1 = q.encode(&x).unwrap();
        let c2 = q.encode(&x).unwrap();
        assert_eq!(c1.polar_code, c2.polar_code);
        assert_eq!(c1.residual_sketch.signs, c2.residual_sketch.signs);
    }

    #[test]
    fn inner_product_estimate_is_competitive_with_polar_alone_at_low_bits() {
        // Compare PolarQuant alone at `bits` against TurboQuant with the same
        // PolarQuant budget plus one QJL bit. The QJL correction is not
        // guaranteed to win on every draw, so this only checks that the
        // average TurboQuant error stays within 1.5x of PolarQuant's.
        let dim = 64;
        let bits = 4u8; // deliberately low
        let projections = 64;

        let polar_only = PolarQuantizer::new(dim, bits, 0).unwrap();
        let turbo = TurboQuantizer::new(dim, bits + 1, projections, 0).unwrap();

        let mut polar_errors = Vec::new();
        let mut turbo_errors = Vec::new();

        for seed in 0..20u64 {
            let x = random_vector(dim, seed * 2);
            let y = random_vector(dim, seed * 2 + 1);

            let exact: f32 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();

            let polar_code = polar_only.encode(&x).unwrap();
            let polar_est = polar_only.inner_product_estimate(&polar_code, &y).unwrap();

            let turbo_code = turbo.encode(&x).unwrap();
            let turbo_est = turbo.inner_product_estimate(&turbo_code, &y).unwrap();

            polar_errors.push((polar_est - exact).abs());
            turbo_errors.push((turbo_est - exact).abs());
        }

        let avg_polar: f32 = polar_errors.iter().sum::<f32>() / polar_errors.len() as f32;
        let avg_turbo: f32 = turbo_errors.iter().sum::<f32>() / turbo_errors.len() as f32;

        assert!(
            avg_turbo <= avg_polar * 1.5,
            "TurboQuant should be competitive with PolarQuant: turbo_avg={avg_turbo:.3}, polar_avg={avg_polar:.3}"
        );
    }

    #[test]
    fn nearest_neighbor_ordering_is_preserved() {
        let q = TurboQuantizer::new(16, 8, 16, 7).unwrap();
        let query = random_vector(16, 99);

        let close = {
            let mut v = query.clone();
            v.iter_mut().for_each(|x| *x += 0.05);
            v
        };
        let far1 = random_vector(16, 200);
        let far2 = random_vector(16, 201);

        let code_close = q.encode(&close).unwrap();
        let code_far1 = q.encode(&far1).unwrap();
        let code_far2 = q.encode(&far2).unwrap();

        let ip_close = q.inner_product_estimate(&code_close, &query).unwrap();
        let ip_far1 = q.inner_product_estimate(&code_far1, &query).unwrap();
        let ip_far2 = q.inner_product_estimate(&code_far2, &query).unwrap();

        assert!(
            ip_close > ip_far1 && ip_close > ip_far2,
            "close={ip_close:.3}, far1={ip_far1:.3}, far2={ip_far2:.3}"
        );
    }

    #[test]
    fn compression_ratio_is_positive() {
        let q = TurboQuantizer::new(64, 8, 32, 0).unwrap();
        let x = random_vector(64, 1);
        let code = q.encode(&x).unwrap();
        assert!(code.compression_ratio() > 1.0);
    }

    #[test]
    fn batch_stats_sums_correctly() {
        let dim = 64;
        let q = TurboQuantizer::new(dim, 8, 16, 0).unwrap();
        let codes: Vec<_> = (0..10)
            .map(|i| q.encode(&random_vector(dim, i)).unwrap())
            .collect();
        let stats = q.batch_stats(&codes);
        assert_eq!(stats.count, 10);
        assert!(stats.compression_ratio > 1.0);
        assert_eq!(
            stats.total_original_bytes,
            10 * dim * std::mem::size_of::<f32>()
        );
    }

    #[test]
    fn turbo_code_serialization_roundtrip() {
        let q = TurboQuantizer::new(16, 8, 16, 42).unwrap();
        let x = random_vector(16, 1);
        let code = q.encode(&x).unwrap();
        let json = serde_json::to_string(&code).unwrap();
        let restored: TurboCode = serde_json::from_str(&json).unwrap();
        assert_eq!(code, restored);
    }

    #[test]
    fn invalid_config_rejected() {
        assert!(TurboQuantizer::new(0, 8, 16, 0).is_err()); // zero dim
        assert!(TurboQuantizer::new(7, 8, 16, 0).is_err()); // odd dim
        assert!(TurboQuantizer::new(8, 1, 16, 0).is_err()); // bits < 2
        assert!(TurboQuantizer::new(8, 8, 0, 0).is_err()); // zero projections
    }

    #[test]
    fn mode_specific_config_validation() {
        // PolarOnly has no QJL stage, so zero projections are accepted.
        assert!(TurboQuantizer::new_with_mode(8, 4, 0, 0, TurboMode::PolarOnly).is_ok());
        // Bit budgets outside the mode's range are rejected in both modes.
        assert!(TurboQuantizer::new_with_mode(8, 0, 0, 0, TurboMode::PolarOnly).is_err());
        assert!(TurboQuantizer::new_with_mode(8, 17, 0, 0, TurboMode::PolarOnly).is_err());
        assert!(TurboQuantizer::new_with_mode(8, 17, 16, 0, TurboMode::PolarWithQjl).is_err());
    }

    #[test]
    fn polar_only_mode_has_empty_residual_sketch() {
        let dim = 16;
        let q = TurboQuantizer::new_with_mode(dim, 4, 0, 3, TurboMode::PolarOnly).unwrap();
        assert_eq!(q.mode(), TurboMode::PolarOnly);

        let x = random_vector(dim, 5);
        let y = random_vector(dim, 6);
        let code = q.encode(&x).unwrap();
        assert_eq!(code.residual_sketch.projections, 0);
        assert!(code.residual_sketch.signs.is_empty());

        let estimate = q.inner_product_estimate(&code, &y).unwrap();
        assert!(estimate.is_finite());
        assert_eq!(q.decode_approximate(&code).unwrap().len(), dim);
    }
}