use crate::error::Result;
use crate::turbo::{TurboCode, TurboQuantizer};
use crate::traits::VectorQuantizer;
/// Storage tier selector used to pick one of the per-tier quantizers.
///
/// `TieredQuantization::new` builds the `Hot`, `Warm` and `Cold` quantizers
/// by passing `8`, `4` and `3` respectively as the second argument of
/// `TurboQuantizer::new` (presumably a bit width — confirm against
/// `TurboQuantizer`).
///
/// The type is a plain field-less enum, so it is `Copy`, comparable with
/// `==`, and hashable (usable as a `HashMap`/`HashSet` key, e.g. for
/// per-tier counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde-support", derive(serde::Serialize, serde::Deserialize))]
pub enum Tier {
    /// Tier served by the quantizer constructed with `8`.
    Hot,
    /// Tier served by the quantizer constructed with `4`.
    Warm,
    /// Tier served by the quantizer constructed with `3`.
    Cold,
}
/// A [`TurboCode`] tagged with the tier it belongs to.
///
/// Every variant wraps the same payload type; only the variant records the
/// tier, which `TieredQuantization` uses to pick the matching quantizer.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde-support", derive(serde::Serialize, serde::Deserialize))]
pub enum TieredCode {
    /// Code tagged as belonging to [`Tier::Hot`].
    Hot(TurboCode),
    /// Code tagged as belonging to [`Tier::Warm`].
    Warm(TurboCode),
    /// Code tagged as belonging to [`Tier::Cold`].
    Cold(TurboCode),
}
impl TieredCode {
    /// Reports which [`Tier`] this code is tagged with.
    ///
    /// The answer is derived from the enum variant alone; the wrapped
    /// [`TurboCode`] is never inspected.
    pub fn tier(&self) -> Tier {
        match *self {
            Self::Hot(_) => Tier::Hot,
            Self::Warm(_) => Tier::Warm,
            Self::Cold(_) => Tier::Cold,
        }
    }
}
/// Bundle of three [`TurboQuantizer`]s, one per [`Tier`].
///
/// The fields are private; a [`TieredCode`]'s variant decides which of the
/// three quantizers handles it.
#[derive(Clone, Debug)]
pub struct TieredQuantization {
    /// Quantizer backing [`Tier::Hot`].
    hot: TurboQuantizer,
    /// Quantizer backing [`Tier::Warm`].
    warm: TurboQuantizer,
    /// Quantizer backing [`Tier::Cold`].
    cold: TurboQuantizer,
}
impl TieredQuantization {
pub fn new(dim: usize, projections: usize, seed: u64) -> Result<Self> {
let hot = TurboQuantizer::new(dim, 8, projections, seed)?;
let warm = TurboQuantizer::new(dim, 4, projections, seed)?;
let cold = TurboQuantizer::new(dim, 3, projections, seed)?;
Ok(Self { hot, warm, cold })
}
pub fn dim(&self) -> usize {
self.hot.dim()
}
pub fn encode(&self, vector: &[f32], tier: Tier) -> Result<TieredCode> {
match tier {
Tier::Hot => Ok(TieredCode::Hot(self.hot.encode(vector)?)),
Tier::Warm => Ok(TieredCode::Warm(self.warm.encode(vector)?)),
Tier::Cold => Ok(TieredCode::Cold(self.cold.encode(vector)?)),
}
}
pub fn decode(&self, code: &TieredCode) -> Vec<f32> {
match code {
TieredCode::Hot(c) => self.hot.decode(c),
TieredCode::Warm(c) => self.warm.decode(c),
TieredCode::Cold(c) => self.cold.decode(c),
}
}
pub fn inner_product_estimate(&self, code: &TieredCode, query: &[f32]) -> Result<f32> {
match code {
TieredCode::Hot(c) => self.hot.inner_product_estimate(c, query),
TieredCode::Warm(c) => self.warm.inner_product_estimate(c, query),
TieredCode::Cold(c) => self.cold.inner_product_estimate(c, query),
}
}
pub fn l2_distance_estimate(&self, code: &TieredCode, query: &[f32]) -> Result<f32> {
match code {
TieredCode::Hot(c) => self.hot.l2_distance_estimate(c, query),
TieredCode::Warm(c) => self.warm.l2_distance_estimate(c, query),
TieredCode::Cold(c) => self.cold.l2_distance_estimate(c, query),
}
}
pub fn tier(&self, code: &TieredCode) -> Tier {
code.tier()
}
pub fn code_size_bytes(&self, code: &TieredCode) -> usize {
match code {
TieredCode::Hot(c) => self.hot.code_size_bytes(c),
TieredCode::Warm(c) => self.warm.code_size_bytes(c),
TieredCode::Cold(c) => self.cold.code_size_bytes(c),
}
}
pub fn recompress(&self, code: &TieredCode, new_tier: Tier) -> Result<TieredCode> {
let approx = self.decode(code);
self.encode(&approx, new_tier)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tier, in the order the per-tier loops below visit them.
    const ALL_TIERS: [Tier; 3] = [Tier::Hot, Tier::Warm, Tier::Cold];

    /// Deterministic ramp fixture: element `i` equals `(i + 1) * 0.05`.
    fn make_vector(dim: usize) -> Vec<f32> {
        let mut ramp = Vec::with_capacity(dim);
        for i in 0..dim {
            ramp.push((i as f32 + 1.0) * 0.05);
        }
        ramp
    }

    /// Deterministic query fixture: element `i` equals `sin(i * 0.03)`.
    fn make_query(dim: usize) -> Vec<f32> {
        let mut wave = Vec::with_capacity(dim);
        for i in 0..dim {
            wave.push((i as f32 * 0.03).sin());
        }
        wave
    }

    /// Builds a quantizer bundle that the test expects to be constructible.
    fn build(dim: usize, projections: usize, seed: u64) -> TieredQuantization {
        TieredQuantization::new(dim, projections, seed).unwrap()
    }

    // ---- construction -------------------------------------------------

    #[test]
    fn test_new_valid() {
        let built = TieredQuantization::new(64, 32, 42);
        assert!(built.is_ok());
    }

    #[test]
    fn test_new_zero_dim() {
        let built = TieredQuantization::new(0, 32, 42);
        assert!(built.is_err());
    }

    #[test]
    fn test_new_odd_dim() {
        let built = TieredQuantization::new(7, 32, 42);
        assert!(built.is_err());
    }

    #[test]
    fn test_dim() {
        let quantizers = build(64, 32, 42);
        assert_eq!(quantizers.dim(), 64);
    }

    // ---- encode / decode ----------------------------------------------

    #[test]
    fn test_encode_hot() {
        let quantizers = build(16, 8, 1);
        let input = make_vector(16);
        let encoded = quantizers.encode(&input, Tier::Hot).unwrap();
        assert!(matches!(encoded, TieredCode::Hot(_)));
    }

    #[test]
    fn test_encode_warm() {
        let quantizers = build(16, 8, 1);
        let input = make_vector(16);
        let encoded = quantizers.encode(&input, Tier::Warm).unwrap();
        assert!(matches!(encoded, TieredCode::Warm(_)));
    }

    #[test]
    fn test_encode_cold() {
        let quantizers = build(16, 8, 1);
        let input = make_vector(16);
        let encoded = quantizers.encode(&input, Tier::Cold).unwrap();
        assert!(matches!(encoded, TieredCode::Cold(_)));
    }

    #[test]
    fn test_decode_shape() {
        let quantizers = build(16, 8, 1);
        let input = make_vector(16);
        for tier in ALL_TIERS {
            let encoded = quantizers.encode(&input, tier).unwrap();
            let decoded = quantizers.decode(&encoded);
            assert_eq!(decoded.len(), 16, "tier {tier:?} decode len mismatch");
        }
    }

    // ---- similarity estimates -----------------------------------------

    #[test]
    fn test_inner_product_estimate() {
        let quantizers = build(32, 16, 42);
        let input = make_vector(32);
        let query = make_query(32);
        for tier in ALL_TIERS {
            let encoded = quantizers.encode(&input, tier).unwrap();
            let ip = quantizers.inner_product_estimate(&encoded, &query).unwrap();
            assert!(ip.is_finite(), "tier {tier:?} IP should be finite");
        }
    }

    #[test]
    fn test_l2_distance_estimate() {
        let quantizers = build(32, 16, 42);
        let input = make_vector(32);
        let query = make_query(32);
        for tier in ALL_TIERS {
            let encoded = quantizers.encode(&input, tier).unwrap();
            let l2 = quantizers.l2_distance_estimate(&encoded, &query).unwrap();
            assert!(l2 >= 0.0, "tier {tier:?} L2 should be non-negative");
            assert!(l2.is_finite(), "tier {tier:?} L2 should be finite");
        }
    }

    // ---- code size and tier tagging -----------------------------------

    #[test]
    fn test_code_size_bytes_same_across_tiers() {
        let quantizers = build(64, 32, 42);
        let input = make_vector(64);
        let hot_code = quantizers.encode(&input, Tier::Hot).unwrap();
        let warm_code = quantizers.encode(&input, Tier::Warm).unwrap();
        let cold_code = quantizers.encode(&input, Tier::Cold).unwrap();
        assert_eq!(
            quantizers.code_size_bytes(&hot_code),
            quantizers.code_size_bytes(&warm_code)
        );
        assert_eq!(
            quantizers.code_size_bytes(&warm_code),
            quantizers.code_size_bytes(&cold_code)
        );
    }

    #[test]
    fn test_tier_returns_correct_variant() {
        let quantizers = build(16, 8, 1);
        let input = make_vector(16);

        let hot_code = quantizers.encode(&input, Tier::Hot).unwrap();
        assert_eq!(quantizers.tier(&hot_code), Tier::Hot);

        let warm_code = quantizers.encode(&input, Tier::Warm).unwrap();
        assert_eq!(quantizers.tier(&warm_code), Tier::Warm);

        let cold_code = quantizers.encode(&input, Tier::Cold).unwrap();
        assert_eq!(quantizers.tier(&cold_code), Tier::Cold);
    }

    // ---- recompression ------------------------------------------------

    #[test]
    fn test_recompress_hot_to_cold() {
        let quantizers = build(32, 16, 42);
        let input = make_vector(32);
        let hot_code = quantizers.encode(&input, Tier::Hot).unwrap();
        let cold_code = quantizers.recompress(&hot_code, Tier::Cold).unwrap();
        assert_eq!(quantizers.tier(&cold_code), Tier::Cold);
        assert_eq!(quantizers.decode(&cold_code).len(), 32);
    }

    #[test]
    fn test_recompress_cold_to_hot() {
        let quantizers = build(32, 16, 42);
        let input = make_vector(32);
        let cold_code = quantizers.encode(&input, Tier::Cold).unwrap();
        let hot_code = quantizers.recompress(&cold_code, Tier::Hot).unwrap();
        assert_eq!(quantizers.tier(&hot_code), Tier::Hot);
    }

    // ---- input validation ---------------------------------------------

    #[test]
    fn test_encode_dimension_mismatch() {
        let quantizers = build(16, 8, 1);
        // Half as many components as the quantizers were built for.
        let too_short = vec![0.0_f32; 8];
        assert!(quantizers.encode(&too_short, Tier::Hot).is_err());
    }

    #[test]
    fn test_inner_product_self_positive() {
        let quantizers = build(32, 16, 42);
        let v: Vec<f32> = (1..=32).map(|i| i as f32 * 0.1).collect();
        for tier in ALL_TIERS {
            let encoded = quantizers.encode(&v, tier).unwrap();
            let ip = quantizers.inner_product_estimate(&encoded, &v).unwrap();
            assert!(ip > 0.0, "tier {tier:?} self-IP should be positive, got {ip}");
        }
    }

    #[test]
    fn test_code_size_bytes_positive() {
        let quantizers = build(16, 8, 1);
        let input = make_vector(16);
        for tier in ALL_TIERS {
            let encoded = quantizers.encode(&input, tier).unwrap();
            assert!(quantizers.code_size_bytes(&encoded) > 0);
        }
    }

    #[test]
    fn test_tiered_code_clone() {
        let quantizers = build(16, 8, 1);
        let input = make_vector(16);
        let hot_code = quantizers.encode(&input, Tier::Hot).unwrap();
        let duplicate = hot_code.clone();
        assert_eq!(quantizers.tier(&duplicate), Tier::Hot);
    }

    #[test]
    fn test_recompress_same_tier() {
        let quantizers = build(32, 16, 42);
        let input = make_vector(32);
        let warm_code = quantizers.encode(&input, Tier::Warm).unwrap();
        let warm_again = quantizers.recompress(&warm_code, Tier::Warm).unwrap();
        assert_eq!(quantizers.tier(&warm_again), Tier::Warm);
        assert_eq!(quantizers.decode(&warm_again).len(), 32);
    }

    #[test]
    fn test_recompress_warm_to_hot() {
        let quantizers = build(32, 16, 42);
        let input = make_vector(32);
        let warm_code = quantizers.encode(&input, Tier::Warm).unwrap();
        let hot_code = quantizers.recompress(&warm_code, Tier::Hot).unwrap();
        assert_eq!(quantizers.tier(&hot_code), Tier::Hot);
    }

    #[test]
    fn test_l2_estimate_self_is_near_zero() {
        let quantizers = build(32, 16, 42);
        let input = make_vector(32);
        // Use an already-decoded vector as both the encoded item and the query.
        let first_pass = quantizers.encode(&input, Tier::Hot).unwrap();
        let decoded = quantizers.decode(&first_pass);
        let second_pass = quantizers.encode(&decoded, Tier::Hot).unwrap();
        let l2 = quantizers
            .l2_distance_estimate(&second_pass, &decoded)
            .unwrap();
        assert!(l2 < 1e-3, "L2 self-distance should be ~0, got {l2}");
    }

    #[test]
    fn test_encode_non_finite_rejected() {
        let quantizers = build(16, 8, 1);
        let mut poisoned = make_vector(16);
        poisoned[3] = f32::NAN;
        assert!(quantizers.encode(&poisoned, Tier::Hot).is_err());
        assert!(quantizers.encode(&poisoned, Tier::Warm).is_err());
        assert!(quantizers.encode(&poisoned, Tier::Cold).is_err());
    }
}