use super::binary::BinaryVector;
use super::int4::Int4Vector;
use super::quantized::{QuantizedVector, cosine_similarity_i8_trusted, dot_product_i8_trusted};
use super::{cosine_similarity, dot_product};
use crate::error::{EmbedError, Result};
/// Caller-supplied knowledge about the L2 norm of a vector.
///
/// This is a hint, not something the distance code verifies; [`is_unit_norm`] can be
/// used to check a vector before claiming `Unit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationHint {
    /// Nothing is claimed about the vector's norm; the general distance path is used.
    Unknown,
    /// The vector is claimed to have unit L2 norm, which lets cosine distance between
    /// two `Full` vectors be computed as a plain dot product.
    Unit,
}
/// Storage precision for a vector, listed from most to least precise.
///
/// The derived `Ord` follows declaration order, so
/// `Full < Int8 < Int4 < Binary` (increasing compression).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QuantizationTier {
    /// Uncompressed `f32` components (32 bits per dimension).
    Full,
    /// 8 bits per dimension.
    Int8,
    /// 4 bits per dimension (two dimensions per byte).
    Int4,
    /// 1 bit per dimension (eight dimensions per byte).
    Binary,
}

impl QuantizationTier {
    /// Average number of bytes needed per dimension at this tier.
    pub fn bytes_per_dim(&self) -> f32 {
        // Expressed in bits so the per-tier widths are easy to read; every quotient
        // below is exactly representable in f32.
        let bits: f32 = match *self {
            Self::Full => 32.0,
            Self::Int8 => 8.0,
            Self::Int4 => 4.0,
            Self::Binary => 1.0,
        };
        bits / 8.0
    }

    /// Size reduction relative to `Full` (e.g. `8.0` means one eighth of the bytes).
    pub fn compression_ratio(&self) -> f32 {
        let full_bytes = Self::Full.bytes_per_dim();
        full_bytes / self.bytes_per_dim()
    }

    /// Bytes required to store `dims` components at this tier, rounding partially
    /// filled bytes up for the sub-byte tiers.
    pub fn storage_bytes(&self, dims: usize) -> usize {
        match *self {
            Self::Full => 4 * dims,
            Self::Int8 => dims,
            Self::Int4 => dims.div_ceil(2),
            Self::Binary => dims.div_ceil(8),
        }
    }

    /// Picks a tier from a vector's age in seconds: younger than one hour stays `Full`,
    /// younger than one day becomes `Int8`, younger than one week becomes `Int4`, and
    /// anything older becomes `Binary`.
    pub fn from_age_seconds(age_secs: u64) -> Self {
        const HOUR: u64 = 60 * 60;
        const DAY: u64 = 24 * HOUR;
        const WEEK: u64 = 7 * DAY;
        match age_secs {
            age if age >= WEEK => Self::Binary,
            age if age >= DAY => Self::Int4,
            age if age >= HOUR => Self::Int8,
            _ => Self::Full,
        }
    }
}
/// A single stored vector encoded at one of the [`QuantizationTier`] precisions.
#[derive(Debug, Clone)]
pub enum QuantizedData {
    /// Uncompressed `f32` components.
    Full(Vec<f32>),
    /// 8-bit quantized representation.
    Int8(QuantizedVector),
    /// 4-bit quantized representation.
    Int4(Int4Vector),
    /// 1-bit representation.
    Binary(BinaryVector),
}

impl QuantizedData {
    /// The tier this vector is currently encoded at.
    pub fn tier(&self) -> QuantizationTier {
        match self {
            Self::Full(_) => QuantizationTier::Full,
            Self::Int8(_) => QuantizationTier::Int8,
            Self::Int4(_) => QuantizationTier::Int4,
            Self::Binary(_) => QuantizationTier::Binary,
        }
    }

    /// Number of logical dimensions of the encoded vector.
    pub fn dims(&self) -> usize {
        match self {
            Self::Full(v) => v.len(),
            Self::Int8(q) => q.data.len(),
            Self::Int4(q) => q.dims,
            Self::Binary(q) => q.dims,
        }
    }

    /// Bytes held in the component buffer.
    ///
    /// For the quantized variants this counts only `data.len()`; any other fields of the
    /// quantized structs are not included.
    pub fn storage_bytes(&self) -> usize {
        match self {
            Self::Full(v) => v.len() * 4,
            Self::Int8(q) => q.data.len(),
            Self::Int4(q) => q.data.len(),
            Self::Binary(q) => q.data.len(),
        }
    }

    /// Encodes `vector` at the requested `tier`.
    pub fn from_f32(vector: &[f32], tier: QuantizationTier) -> Self {
        match tier {
            QuantizationTier::Full => Self::Full(vector.to_vec()),
            QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(vector)),
            QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(vector)),
            QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(vector)),
        }
    }

    /// Decodes the vector back to `f32` components (exact only for `Full`).
    pub fn to_f32(&self) -> Vec<f32> {
        match self {
            Self::Full(v) => v.clone(),
            Self::Int8(q) => q.to_f32(),
            Self::Int4(q) => q.to_f32(),
            Self::Binary(q) => q.to_f32(),
        }
    }

    /// Re-encodes this vector at `target` by decoding to `f32` and encoding again.
    ///
    /// Precision already lost at a coarser tier is not recovered. If `target` equals the
    /// current tier, the vector is returned unchanged (cloned) instead of being decoded
    /// and re-encoded. That round trip is wasteful and, depending on the quantizer, not
    /// guaranteed to reproduce the same encoding.
    pub fn promote(&self, target: QuantizationTier) -> Self {
        if self.tier() == target {
            return self.clone();
        }
        let decoded = self.to_f32();
        match target {
            // The decoded buffer is already owned; move it in rather than copying it.
            QuantizationTier::Full => Self::Full(decoded),
            QuantizationTier::Int8 | QuantizationTier::Int4 | QuantizationTier::Binary => {
                Self::from_f32(&decoded, target)
            }
        }
    }

    /// Re-encodes this vector at `target`.
    ///
    /// Identical to [`QuantizedData::promote`]; the separate name only documents intent.
    /// The direction of the tier change is not checked.
    pub fn demote(&self, target: QuantizationTier) -> Self {
        self.promote(target)
    }
}
/// A query vector encoded once at a fixed tier, for repeated comparison against stored
/// vectors of that same tier without re-encoding the query each time.
#[derive(Debug, Clone)]
pub enum PreparedQuery {
    /// Uncompressed `f32` query.
    Full(Vec<f32>),
    /// 8-bit quantized query.
    Int8(QuantizedVector),
    /// 4-bit quantized query.
    Int4(Int4Vector),
    /// 1-bit query.
    Binary(BinaryVector),
}

impl PreparedQuery {
    /// Encodes `query_f32` at `tier`.
    #[inline]
    pub fn from_f32(query_f32: &[f32], tier: QuantizationTier) -> Self {
        match tier {
            QuantizationTier::Full => PreparedQuery::Full(Vec::from(query_f32)),
            QuantizationTier::Int8 => PreparedQuery::Int8(QuantizedVector::from_f32(query_f32)),
            QuantizationTier::Int4 => PreparedQuery::Int4(Int4Vector::from_f32(query_f32)),
            QuantizationTier::Binary => PreparedQuery::Binary(BinaryVector::from_f32(query_f32)),
        }
    }

    /// The tier this query was encoded at.
    #[inline]
    pub fn tier(&self) -> QuantizationTier {
        match *self {
            PreparedQuery::Full(..) => QuantizationTier::Full,
            PreparedQuery::Int8(..) => QuantizationTier::Int8,
            PreparedQuery::Int4(..) => QuantizationTier::Int4,
            PreparedQuery::Binary(..) => QuantizationTier::Binary,
        }
    }

    /// Number of logical dimensions of the encoded query.
    #[inline]
    pub fn dims(&self) -> usize {
        match self {
            PreparedQuery::Full(values) => values.len(),
            PreparedQuery::Int8(q8) => q8.data.len(),
            PreparedQuery::Int4(q4) => q4.dims,
            PreparedQuery::Binary(bits) => bits.dims,
        }
    }
}
/// Encodes `query_f32` at `tier` for use with the `*_prepared` distance functions.
///
/// Convenience wrapper around [`PreparedQuery::from_f32`].
#[inline]
pub fn prepare_query(query_f32: &[f32], tier: QuantizationTier) -> PreparedQuery {
    PreparedQuery::from_f32(query_f32, tier)
}
/// A [`PreparedQuery`] bundled with a [`NormalizationHint`] describing the original
/// `f32` query.
#[derive(Debug, Clone)]
pub struct PreparedQueryWithMeta {
    /// The tier-encoded query.
    pub query: PreparedQuery,
    /// Caller's claim about the query's L2 norm.
    pub norm: NormalizationHint,
}

impl PreparedQueryWithMeta {
    /// Encodes `query_f32` at `tier` and attaches the `norm` hint.
    #[inline]
    pub fn from_f32(query_f32: &[f32], tier: QuantizationTier, norm: NormalizationHint) -> Self {
        let query = PreparedQuery::from_f32(query_f32, tier);
        Self { query, norm }
    }

    /// Tier of the wrapped query.
    #[inline]
    pub fn tier(&self) -> QuantizationTier {
        PreparedQuery::tier(&self.query)
    }

    /// Dimensions of the wrapped query.
    #[inline]
    pub fn dims(&self) -> usize {
        PreparedQuery::dims(&self.query)
    }
}
/// Returns `true` when the squared L2 norm of `v` is within `1e-4` of `1.0`.
///
/// An empty slice (norm 0) and any slice containing NaN return `false`.
#[inline]
pub fn is_unit_norm(v: &[f32]) -> bool {
    const TOLERANCE: f32 = 1e-4;
    let squared_norm = v.iter().fold(0.0_f32, |acc, &component| acc + component * component);
    (1.0 - squared_norm).abs() < TOLERANCE
}
#[inline]
pub fn prepare_query_with_norm(
query_f32: &[f32],
tier: QuantizationTier,
norm: NormalizationHint,
) -> PreparedQueryWithMeta {
PreparedQueryWithMeta::from_f32(query_f32, tier, norm)
}
/// Cosine distance (`1 - cosine similarity`) between a prepared query and a stored vector.
///
/// The per-tier kernels are as follows:
/// - `Full` uses [`cosine_similarity`].
/// - `Int8` uses `cosine_similarity_i8_trusted`.
/// - `Int4` and `Binary` delegate to the stored vector's own distance methods; `Binary`
///   is an approximation.
///
/// # Panics
///
/// Panics if `query` and `stored` are encoded at different tiers. Use
/// [`try_approximate_cosine_distance_prepared`] for a non-panicking variant.
#[inline]
pub fn approximate_cosine_distance_prepared(query: &PreparedQuery, stored: &QuantizedData) -> f32 {
    match (query, stored) {
        (PreparedQuery::Full(q), QuantizedData::Full(s)) => 1.0 - cosine_similarity(q, s),
        (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => {
            1.0 - cosine_similarity_i8_trusted(s, q)
        }
        (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => s.cosine_distance(q),
        (PreparedQuery::Binary(q), QuantizedData::Binary(s)) => s.cosine_distance_approx(q),
        // The original message is kept as a prefix so existing `should_panic(expected)`
        // checks still match; the tiers make mismatches diagnosable.
        _ => panic!(
            "PreparedQuery tier must match QuantizedData tier (query: {:?}, stored: {:?})",
            query.tier(),
            stored.tier()
        ),
    }
}
/// Fallible form of [`approximate_cosine_distance_prepared`].
///
/// # Errors
///
/// Returns `EmbedError::Internal` when `query` and `stored` are encoded at different tiers.
#[inline]
pub fn try_approximate_cosine_distance_prepared(
    query: &PreparedQuery,
    stored: &QuantizedData,
) -> Result<f32> {
    let distance = match (query, stored) {
        (PreparedQuery::Full(qv), QuantizedData::Full(sv)) => 1.0 - cosine_similarity(qv, sv),
        (PreparedQuery::Int8(qv), QuantizedData::Int8(sv)) => {
            1.0 - cosine_similarity_i8_trusted(sv, qv)
        }
        (PreparedQuery::Int4(qv), QuantizedData::Int4(sv)) => sv.cosine_distance(qv),
        (PreparedQuery::Binary(qv), QuantizedData::Binary(sv)) => sv.cosine_distance_approx(qv),
        _ => {
            return Err(EmbedError::Internal(
                "PreparedQuery tier must match QuantizedData tier for cosine distance".into(),
            ));
        }
    };
    Ok(distance)
}
/// Fallible dot product between a prepared query and a stored vector of the same tier.
///
/// # Errors
///
/// Returns `EmbedError::Internal` when both sides are `Binary` (no prepared dot product
/// exists for that tier) or when the tiers differ.
#[inline]
pub fn try_approximate_dot_product_prepared(
    query: &PreparedQuery,
    stored: &QuantizedData,
) -> Result<f32> {
    let product = match (query, stored) {
        (PreparedQuery::Full(qv), QuantizedData::Full(sv)) => dot_product(qv, sv),
        (PreparedQuery::Int8(qv), QuantizedData::Int8(sv)) => dot_product_i8_trusted(qv, sv),
        (PreparedQuery::Int4(qv), QuantizedData::Int4(sv)) => sv.dot_product(qv),
        (PreparedQuery::Binary(_), QuantizedData::Binary(_)) => {
            return Err(EmbedError::Internal(
                "Binary has no prepared dot product; use try_approximate_cosine_distance_prepared"
                    .into(),
            ));
        }
        _ => {
            return Err(EmbedError::Internal(
                "PreparedQuery tier must match QuantizedData tier for dot product".into(),
            ));
        }
    };
    Ok(product)
}
/// Cosine distance that can exploit unit-norm hints.
///
/// The fast path applies only when both the query and the stored vector are flagged
/// [`NormalizationHint::Unit`] and both are `Full`. In that case the cosine similarity
/// is taken as their plain dot product, clamped to `[-1, 1]`. Every other combination
/// defers to [`approximate_cosine_distance_prepared`], which panics on a tier mismatch.
#[inline]
pub fn approximate_cosine_distance_prepared_with_meta(
    meta: &PreparedQueryWithMeta,
    stored: &QuantizedData,
    stored_norm: NormalizationHint,
) -> f32 {
    let both_unit = matches!(
        (meta.norm, stored_norm),
        (NormalizationHint::Unit, NormalizationHint::Unit)
    );
    match (&meta.query, stored) {
        (PreparedQuery::Full(q), QuantizedData::Full(s)) if both_unit => {
            // Clamping guards against rounding pushing the dot product outside [-1, 1].
            1.0 - dot_product(q, s).clamp(-1.0, 1.0)
        }
        _ => approximate_cosine_distance_prepared(&meta.query, stored),
    }
}
/// Dot product between a prepared query and a stored vector of the same tier.
///
/// `Full` uses [`dot_product`], `Int8` uses `dot_product_i8_trusted`, and `Int4`
/// delegates to the stored vector's `dot_product` method.
///
/// # Panics
///
/// Panics when both sides are `Binary` (there is no prepared dot product for that tier)
/// or when the tiers differ. Use [`try_approximate_dot_product_prepared`] for a
/// non-panicking variant.
#[inline]
pub fn approximate_dot_product_prepared(query: &PreparedQuery, stored: &QuantizedData) -> f32 {
    match (query, stored) {
        (PreparedQuery::Full(q), QuantizedData::Full(s)) => dot_product(q, s),
        (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => dot_product_i8_trusted(q, s),
        (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => s.dot_product(q),
        (PreparedQuery::Binary(_), QuantizedData::Binary(_)) => {
            panic!("Binary has no prepared dot product; use approximate_cosine_distance_prepared")
        }
        // The original message is kept as a prefix so existing `should_panic(expected)`
        // checks still match; the tiers make mismatches diagnosable.
        _ => panic!(
            "PreparedQuery tier must match QuantizedData tier (query: {:?}, stored: {:?})",
            query.tier(),
            stored.tier()
        ),
    }
}
/// Computes [`approximate_cosine_distance_prepared`] against every stored vector,
/// returning the distances in the same order as `stored`.
///
/// Panics on the first stored vector whose tier differs from the query's.
#[inline]
pub fn batch_approximate_cosine_distance_prepared(
    query: &PreparedQuery,
    stored: &[QuantizedData],
) -> Vec<f32> {
    let mut distances = Vec::with_capacity(stored.len());
    for candidate in stored {
        distances.push(approximate_cosine_distance_prepared(query, candidate));
    }
    distances
}
/// Like [`batch_approximate_cosine_distance_prepared`] but writes into a caller-owned
/// buffer.
///
/// `out` is cleared first, so its previous contents are discarded while its allocation
/// can be reused across calls.
#[inline]
pub fn batch_approximate_cosine_distance_prepared_into(
    query: &PreparedQuery,
    stored: &[QuantizedData],
    out: &mut Vec<f32>,
) {
    out.clear();
    out.reserve(stored.len());
    for candidate in stored {
        out.push(approximate_cosine_distance_prepared(query, candidate));
    }
}
/// Cosine distance between a raw `f32` query and a stored vector of any tier.
///
/// For the quantized tiers the query is encoded at `stored`'s tier on every call. When
/// scoring many vectors with one query, [`prepare_query`] plus
/// [`approximate_cosine_distance_prepared`] avoids repeating that work.
pub fn approximate_cosine_distance(query_f32: &[f32], stored: &QuantizedData) -> f32 {
    match stored {
        QuantizedData::Full(values) => 1.0 - cosine_similarity(query_f32, values),
        QuantizedData::Int8(stored_q) => {
            let similarity = stored_q.cosine_similarity(&QuantizedVector::from_f32(query_f32));
            1.0 - similarity
        }
        QuantizedData::Int4(stored_q) => {
            stored_q.cosine_distance(&Int4Vector::from_f32(query_f32))
        }
        QuantizedData::Binary(stored_q) => {
            stored_q.cosine_distance_approx(&BinaryVector::from_f32(query_f32))
        }
    }
}
/// Dot product between a raw `f32` query and a stored vector of any tier.
///
/// The behavior depends on the stored tier:
/// - `Int8` and `Int4`: the query is quantized to the stored tier on every call.
/// - `Binary`: the stored vector is decoded to `f32` and multiplied against the
///   unquantized query.
///
/// For `Full`/`Int8`/`Int4`, [`prepare_query`] plus [`approximate_dot_product_prepared`]
/// avoids re-quantizing the query when scoring many vectors.
pub fn approximate_dot_product(query_f32: &[f32], stored: &QuantizedData) -> f32 {
    match stored {
        QuantizedData::Full(v) => dot_product(query_f32, v),
        QuantizedData::Int8(q) => {
            let query_q = QuantizedVector::from_f32(query_f32);
            q.dot_product(&query_q)
        }
        QuantizedData::Int4(q) => {
            let query_q = Int4Vector::from_f32(query_f32);
            q.dot_product(&query_q)
        }
        // The binding is used, so it must not carry the `_` "unused" prefix.
        QuantizedData::Binary(q) => {
            let stored_f32 = q.to_f32();
            dot_product(query_f32, &stored_f32)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector of length `dim` with components in `[-1.0, 1.0]`.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_tier_bytes_per_dim() {
        assert_eq!(QuantizationTier::Full.bytes_per_dim(), 4.0);
        assert_eq!(QuantizationTier::Int8.bytes_per_dim(), 1.0);
        assert_eq!(QuantizationTier::Int4.bytes_per_dim(), 0.5);
        assert_eq!(QuantizationTier::Binary.bytes_per_dim(), 0.125);
    }

    #[test]
    fn test_tier_compression_ratios() {
        assert_eq!(QuantizationTier::Full.compression_ratio(), 1.0);
        assert_eq!(QuantizationTier::Int8.compression_ratio(), 4.0);
        assert_eq!(QuantizationTier::Int4.compression_ratio(), 8.0);
        assert_eq!(QuantizationTier::Binary.compression_ratio(), 32.0);
    }

    #[test]
    fn test_tier_storage_bytes() {
        assert_eq!(QuantizationTier::Full.storage_bytes(384), 1536);
        assert_eq!(QuantizationTier::Int8.storage_bytes(384), 384);
        assert_eq!(QuantizationTier::Int4.storage_bytes(384), 192);
        assert_eq!(QuantizationTier::Binary.storage_bytes(384), 48);
    }

    #[test]
    fn test_tier_from_age() {
        // Includes both sides of each hour/day/week boundary.
        let cases = [
            (0, QuantizationTier::Full),
            (1800, QuantizationTier::Full),
            (3599, QuantizationTier::Full),
            (3600, QuantizationTier::Int8),
            (7200, QuantizationTier::Int8),
            (86_399, QuantizationTier::Int8),
            (86_400, QuantizationTier::Int4),
            (172_800, QuantizationTier::Int4),
            (604_799, QuantizationTier::Int4),
            (604_800, QuantizationTier::Binary),
            (1_000_000, QuantizationTier::Binary),
        ];
        for (age, expected) in cases {
            assert_eq!(
                QuantizationTier::from_age_seconds(age),
                expected,
                "wrong tier for age {age}s"
            );
        }
    }

    #[test]
    fn test_quantized_data_from_f32_all_tiers() {
        let v = generate_vector(384, 42);
        for tier in [
            QuantizationTier::Full,
            QuantizationTier::Int8,
            QuantizationTier::Int4,
            QuantizationTier::Binary,
        ] {
            let data = QuantizedData::from_f32(&v, tier);
            assert_eq!(data.tier(), tier, "tier mismatch for {tier:?}");
            assert_eq!(data.dims(), 384, "dims mismatch for {tier:?}");
            let expected_bytes = tier.storage_bytes(384);
            assert_eq!(
                data.storage_bytes(),
                expected_bytes,
                "storage bytes mismatch for {tier:?}"
            );
        }
    }

    #[test]
    fn test_approximate_cosine_distance_ordering() {
        let a = generate_vector(384, 1);
        let b: Vec<f32> = a
            .iter()
            .enumerate()
            .map(|(i, &x)| x + 0.05 * (i as f32 * 0.3).sin())
            .collect();
        let c = generate_vector(384, 999);
        for tier in [
            QuantizationTier::Full,
            QuantizationTier::Int8,
            QuantizationTier::Int4,
            QuantizationTier::Binary,
        ] {
            let stored_b = QuantizedData::from_f32(&b, tier);
            let stored_c = QuantizedData::from_f32(&c, tier);
            let dist_ab = approximate_cosine_distance(&a, &stored_b);
            let dist_ac = approximate_cosine_distance(&a, &stored_c);
            assert!(
                dist_ab < dist_ac,
                "{tier:?}: dist(a,b)={dist_ab} should be < dist(a,c)={dist_ac}"
            );
        }
    }

    #[test]
    fn test_promote_demote_roundtrip() {
        let v = generate_vector(384, 42);
        let binary = QuantizedData::from_f32(&v, QuantizationTier::Binary);
        let int4 = binary.promote(QuantizationTier::Int4);
        assert_eq!(int4.tier(), QuantizationTier::Int4);
        let int8 = int4.promote(QuantizationTier::Int8);
        assert_eq!(int8.tier(), QuantizationTier::Int8);
        let full = int8.promote(QuantizationTier::Full);
        assert_eq!(full.tier(), QuantizationTier::Full);
        assert_eq!(full.dims(), 384);
        // Exercise the demote direction as well.
        let back = full.demote(QuantizationTier::Binary);
        assert_eq!(back.tier(), QuantizationTier::Binary);
        assert_eq!(back.dims(), 384);
    }

    #[test]
    fn test_quantized_data_to_f32_roundtrip() {
        let v = generate_vector(384, 55);
        let full_data = QuantizedData::from_f32(&v, QuantizationTier::Full);
        // Whole-vector equality also catches a length mismatch, which a `zip` would hide.
        assert_eq!(full_data.to_f32(), v, "Full tier should be lossless");
    }
}