use crate::error::{Result, TurboQuantError};
use crate::tiered::Tier;
use crate::traits::{SerializableCode, VectorQuantizer};
use crate::TurboQuantizer;
/// Holds one [`TurboQuantizer`] per storage tier so the same vector space can be
/// encoded at different bit widths and moved between tiers.
///
/// Instances are created through [`AdaptiveQuantizer::builder`].
pub struct AdaptiveQuantizer {
    /// Vector dimensionality every tier quantizer was constructed with.
    dim: usize,
    /// Quantizer used for [`Tier::Hot`] codes.
    hot: TurboQuantizer,
    /// Quantizer used for [`Tier::Warm`] codes.
    warm: TurboQuantizer,
    /// Quantizer used for [`Tier::Cold`] codes.
    cold: TurboQuantizer,
}
/// A [`crate::turbo::TurboCode`] tagged with the tier whose quantizer produced it.
///
/// `AdaptiveQuantizer` dispatches on the variant to choose which tier quantizer
/// decodes or scores the code.
///
/// NOTE: variant order is part of the serialized form for index-based serde
/// formats; do not reorder.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum AdaptiveCode {
    /// Code produced by the hot-tier quantizer.
    Hot(crate::turbo::TurboCode),
    /// Code produced by the warm-tier quantizer.
    Warm(crate::turbo::TurboCode),
    /// Code produced by the cold-tier quantizer.
    Cold(crate::turbo::TurboCode),
}
impl AdaptiveCode {
    /// Returns the tier this code belongs to.
    pub fn tier(&self) -> Tier {
        match *self {
            Self::Hot(_) => Tier::Hot,
            Self::Warm(_) => Tier::Warm,
            Self::Cold(_) => Tier::Cold,
        }
    }

    /// Borrows the wrapped [`crate::turbo::TurboCode`], whatever the tier.
    pub fn inner(&self) -> &crate::turbo::TurboCode {
        let (Self::Hot(code) | Self::Warm(code) | Self::Cold(code)) = self;
        code
    }

    /// Serializes the code as a single tier tag byte (`Tier as u8`) followed by
    /// the inner code's compact byte representation.
    pub fn to_bytes(&self) -> Vec<u8> {
        let tag = self.tier() as u8;
        let payload = self.inner().to_compact_bytes();
        let mut out = Vec::new();
        out.push(tag);
        out.extend(payload);
        out
    }
}
/// Configures and constructs an [`AdaptiveQuantizer`].
///
/// Obtain one from [`AdaptiveQuantizer::builder`], adjust it with the setter
/// methods, then call `build`.
// All fields are plain values, so the common traits are derived eagerly to make
// the builder printable, cloneable and comparable (e.g. in tests or configs).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AdaptiveBuilder {
    /// Vector dimensionality shared by all tiers.
    dim: usize,
    /// Seed passed to every tier's quantizer.
    seed: u64,
    /// Bit width of the hot tier.
    hot_bits: u8,
    /// Bit width of the warm tier.
    warm_bits: u8,
    /// Bit width of the cold tier.
    cold_bits: u8,
    /// Explicit projection count; `None` lets `build` derive a default from `dim`.
    projections: Option<usize>,
}
impl AdaptiveQuantizer {
pub fn builder(dim: usize, seed: u64) -> AdaptiveBuilder {
AdaptiveBuilder {
dim,
seed,
hot_bits: 8,
warm_bits: 4,
cold_bits: 3,
projections: None,
}
}
pub fn encode_adaptive(
&self,
vector: &[f32],
tier: Tier,
) -> Result<AdaptiveCode> {
match tier {
Tier::Hot => Ok(AdaptiveCode::Hot(self.hot.encode(vector)?)),
Tier::Warm => Ok(AdaptiveCode::Warm(self.warm.encode(vector)?)),
Tier::Cold => Ok(AdaptiveCode::Cold(self.cold.encode(vector)?)),
}
}
pub fn decode_adaptive(&self, code: &AdaptiveCode) -> Vec<f32> {
match code {
AdaptiveCode::Hot(c) => self.hot.decode(c),
AdaptiveCode::Warm(c) => self.warm.decode(c),
AdaptiveCode::Cold(c) => self.cold.decode(c),
}
}
pub fn inner_product_estimate(
&self,
code: &AdaptiveCode,
query: &[f32],
) -> Result<f32> {
match code {
AdaptiveCode::Hot(c) => self.hot.inner_product_estimate(c, query),
AdaptiveCode::Warm(c) => self.warm.inner_product_estimate(c, query),
AdaptiveCode::Cold(c) => self.cold.inner_product_estimate(c, query),
}
}
pub fn promote(&self, code: &AdaptiveCode) -> Result<AdaptiveCode> {
let decoded = self.decode_adaptive(code);
match code.tier() {
Tier::Cold => self.encode_adaptive(&decoded, Tier::Warm),
Tier::Warm => self.encode_adaptive(&decoded, Tier::Hot),
Tier::Hot => Ok(code.clone()), }
}
pub fn demote(&self, code: &AdaptiveCode) -> Result<AdaptiveCode> {
let decoded = self.decode_adaptive(code);
match code.tier() {
Tier::Hot => self.encode_adaptive(&decoded, Tier::Warm),
Tier::Warm => self.encode_adaptive(&decoded, Tier::Cold),
Tier::Cold => Ok(code.clone()), }
}
pub fn dim(&self) -> usize {
self.dim
}
}
impl AdaptiveBuilder {
    /// Sets the bit width of the hot tier.
    #[must_use = "builder setters return the updated builder; the original is consumed"]
    pub fn hot_bits(mut self, bits: u8) -> Self {
        self.hot_bits = bits;
        self
    }

    /// Sets the bit width of the warm tier.
    #[must_use = "builder setters return the updated builder; the original is consumed"]
    pub fn warm_bits(mut self, bits: u8) -> Self {
        self.warm_bits = bits;
        self
    }

    /// Sets the bit width of the cold tier.
    #[must_use = "builder setters return the updated builder; the original is consumed"]
    pub fn cold_bits(mut self, bits: u8) -> Self {
        self.cold_bits = bits;
        self
    }

    /// Sets the projection count shared by all three tiers.
    ///
    /// A value of 0 is raised to 1 by [`AdaptiveBuilder::build`].
    #[must_use = "builder setters return the updated builder; the original is consumed"]
    pub fn projections(mut self, proj: usize) -> Self {
        self.projections = Some(proj);
        self
    }

    /// Builds the three per-tier quantizers.
    ///
    /// Every tier uses the same `dim`, `seed` and projection count. When no
    /// projection count was set it defaults to `dim / 4`; in all cases it is
    /// clamped to at least 1.
    ///
    /// # Errors
    /// - `TurboQuantError::InvalidBitWidth(warm_bits)` when the widths are not
    ///   ordered `hot_bits >= warm_bits >= cold_bits`. The warm width is the one
    ///   reported because it takes part in both comparisons.
    /// - Any error returned by `TurboQuantizer::new` for one of the tiers.
    pub fn build(self) -> Result<AdaptiveQuantizer> {
        let widths_ordered = self.hot_bits >= self.warm_bits && self.warm_bits >= self.cold_bits;
        if !widths_ordered {
            return Err(TurboQuantError::InvalidBitWidth(self.warm_bits));
        }
        // `max(1)` guards both tiny `dim` (dim / 4 == 0) and an explicit 0.
        let proj = self.projections.unwrap_or(self.dim / 4).max(1);
        Ok(AdaptiveQuantizer {
            hot: TurboQuantizer::new(self.dim, self.hot_bits, proj, self.seed)?,
            warm: TurboQuantizer::new(self.dim, self.warm_bits, proj, self.seed)?,
            cold: TurboQuantizer::new(self.dim, self.cold_bits, proj, self.seed)?,
            dim: self.dim,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 64-element ramp `[0, step, 2*step, ...]` shared by the tests.
    fn ramp(step: f32) -> Vec<f32> {
        (0..64).map(|i| i as f32 * step).collect()
    }

    #[test]
    fn test_adaptive_encode_decode() {
        let aq = AdaptiveQuantizer::builder(64, 42).build().unwrap();
        assert_eq!(aq.dim(), 64);
        let v = ramp(0.01);
        for tier in [Tier::Hot, Tier::Warm, Tier::Cold] {
            let code = aq.encode_adaptive(&v, tier).unwrap();
            assert_eq!(code.tier(), tier);
            let decoded = aq.decode_adaptive(&code);
            assert_eq!(decoded.len(), 64);
        }
    }

    #[test]
    fn test_promote_demote() {
        let aq = AdaptiveQuantizer::builder(64, 42).build().unwrap();
        let v = ramp(0.01);
        let cold = aq.encode_adaptive(&v, Tier::Cold).unwrap();
        assert_eq!(cold.tier(), Tier::Cold);
        let warm = aq.promote(&cold).unwrap();
        assert_eq!(warm.tier(), Tier::Warm);
        let hot = aq.promote(&warm).unwrap();
        assert_eq!(hot.tier(), Tier::Hot);
        // Promoting past the top tier is a no-op.
        let still_hot = aq.promote(&hot).unwrap();
        assert_eq!(still_hot.tier(), Tier::Hot);
        let demoted = aq.demote(&hot).unwrap();
        assert_eq!(demoted.tier(), Tier::Warm);
        let demoted2 = aq.demote(&demoted).unwrap();
        assert_eq!(demoted2.tier(), Tier::Cold);
        // Demoting past the bottom tier is a no-op.
        let still_cold = aq.demote(&demoted2).unwrap();
        assert_eq!(still_cold.tier(), Tier::Cold);
    }

    #[test]
    fn test_inner_product() {
        let aq = AdaptiveQuantizer::builder(64, 42).build().unwrap();
        let v = ramp(0.01);
        let q = ramp(0.02);
        let code = aq.encode_adaptive(&v, Tier::Warm).unwrap();
        let score = aq.inner_product_estimate(&code, &q).unwrap();
        assert!(score.is_finite());
    }

    #[test]
    fn test_custom_bits() {
        let aq = AdaptiveQuantizer::builder(64, 42)
            .hot_bits(6)
            .warm_bits(4)
            .cold_bits(3)
            .build()
            .unwrap();
        let v = ramp(0.01);
        let code = aq.encode_adaptive(&v, Tier::Hot).unwrap();
        assert_eq!(code.tier(), Tier::Hot);
    }

    #[test]
    fn test_rejects_unordered_bits() {
        // hot (2) < warm (4): the warm width is reported.
        let hot_too_low = AdaptiveQuantizer::builder(64, 42).hot_bits(2).build();
        assert!(matches!(hot_too_low, Err(TurboQuantError::InvalidBitWidth(4))));
        // warm (2) < cold (3): the warm width is reported.
        let warm_too_low = AdaptiveQuantizer::builder(64, 42)
            .warm_bits(2)
            .cold_bits(3)
            .build();
        assert!(matches!(warm_too_low, Err(TurboQuantError::InvalidBitWidth(2))));
    }

    #[test]
    fn test_to_bytes_tier_tag() {
        let aq = AdaptiveQuantizer::builder(64, 42).build().unwrap();
        let v = ramp(0.01);
        for tier in [Tier::Hot, Tier::Warm, Tier::Cold] {
            let code = aq.encode_adaptive(&v, tier).unwrap();
            let bytes = code.to_bytes();
            assert_eq!(bytes.first().copied(), Some(tier as u8));
        }
    }
}