use half::f16;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::distance::Distance;
/// Tag identifying which encoding an [`EncodedVector`] payload uses.
///
/// Serialised through serde with `snake_case` variant names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Codec {
/// One byte per component plus a per-vector `[min, scale]` pair (see [`Int8Quantized`]).
Int8Quantized,
/// Two little-endian bytes per component holding a half-precision float (see [`Fp16`]).
Fp16,
/// [`Turbovec`] row tagged with a 2-bit width.
Turbovec2Bit,
/// [`Turbovec`] row tagged with a 3-bit width.
Turbovec3Bit,
/// [`Turbovec`] row tagged with a 4-bit width.
Turbovec4Bit,
}
impl Codec {
#[must_use]
pub fn encoder(self) -> Box<dyn Encoder> {
match self {
Self::Int8Quantized => Box::new(Int8Quantized),
Self::Fp16 => Box::new(Fp16),
Self::Turbovec2Bit => Box::new(Turbovec::new(2)),
Self::Turbovec3Bit => Box::new(Turbovec::new(3)),
Self::Turbovec4Bit => Box::new(Turbovec::new(4)),
}
}
#[must_use]
pub fn name(self) -> &'static str {
match self {
Self::Int8Quantized => "int8q",
Self::Fp16 => "fp16",
Self::Turbovec2Bit => "turbovec2",
Self::Turbovec3Bit => "turbovec3",
Self::Turbovec4Bit => "turbovec4",
}
}
#[must_use]
pub fn turbovec_bits(self) -> Option<u8> {
match self {
Self::Turbovec2Bit => Some(2),
Self::Turbovec3Bit => Some(3),
Self::Turbovec4Bit => Some(4),
_ => None,
}
}
}
/// A vector in encoded form, together with everything needed to decode it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EncodedVector {
/// Codec that produced `bytes` and `params`.
pub codec: Codec,
/// Number of components in the original vector.
pub dim: u16,
/// Codec-specific payload: `dim` bytes for int8, `2 * dim` for fp16, `4 * dim` for turbovec.
pub bytes: Vec<u8>,
/// Codec-specific side parameters: `[min, scale]` for int8, empty for fp16, `[bits]` for turbovec.
pub params: Vec<f32>,
}
impl EncodedVector {
    /// Returns the Euclidean (L2) norm of the decoded vector.
    ///
    /// The payload is decoded with the encoder for `self.codec`. If decoding
    /// fails (for example because the payload length disagrees with `dim`) the
    /// vector is treated as empty and the result is zero, so callers cannot
    /// tell a malformed vector from a genuine zero vector through this method.
    ///
    /// Squares are accumulated in `f64`: squaring in `f32` overflows for
    /// components above roughly `1.8e19` and underflows for very small ones,
    /// which would turn a finite, non-zero norm into infinity or zero.
    #[must_use]
    pub fn l2_norm(&self) -> f32 {
        // Best-effort by design: the signature has no error channel.
        let decoded = self.codec.encoder().decode(self).unwrap_or_default();
        let sum_sq: f64 = decoded
            .iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum();
        #[allow(
            clippy::cast_possible_truncation,
            reason = "narrowing the final norm to the f32 return type is intended"
        )]
        let norm = sum_sq.sqrt() as f32;
        norm
    }
}
/// Errors reported while encoding or decoding vectors.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum EncodingError {
/// The input has more components than fit in the `u16` dimension field; carries the length.
#[error("vector has {0} dimensions; max supported is 65535")]
DimensionTooLarge(usize),
/// The input slice is empty.
#[error("vector has zero dimensions")]
EmptyVector,
/// The payload length (or parameter count) does not match what the codec expects for `dim`.
#[error("malformed encoded vector: dim={dim} payload_bytes={bytes}")]
Malformed {
/// Dimension recorded in the encoded vector.
dim: u16,
/// Actual payload length in bytes.
bytes: usize,
},
/// An encoder was asked to decode a vector tagged with a different codec.
#[error("codec mismatch: expected {expected:?}, got {got:?}")]
CodecMismatch {
/// Codec the decoder handles.
expected: Codec,
/// Codec recorded in the encoded vector.
got: Codec,
},
/// The input contains a NaN or infinite component.
#[error("vector contains non-finite component")]
NonFinite,
/// A turbovec bit width outside `2..=4` was requested; carries the offending width.
#[error("turbovec bit width must be 2, 3, or 4, got {0}")]
UnsupportedBitWidth(u8),
/// The turbovec codec was given a dimension that is not a multiple of 8.
#[error("turbovec requires dim to be a positive multiple of 8, got {0}")]
UnsupportedDim(u16),
}
/// Converts between `f32` slices and their [`EncodedVector`] form for a single codec.
pub trait Encoder: Send + Sync {
/// Returns the codec tag this encoder writes and accepts.
fn codec(&self) -> Codec;
/// Encodes `values`, or reports why they cannot be encoded.
fn encode(&self, values: &[f32]) -> Result<EncodedVector, EncodingError>;
/// Decodes `ev` back into `f32` components, or reports why it cannot be decoded.
fn decode(&self, ev: &EncodedVector) -> Result<Vec<f32>, EncodingError>;
}
/// Per-vector affine quantiser: one byte per component plus `[min, scale]`.
///
/// A component `v` is stored as `round((v - min) / scale)` clamped to
/// `0..=255` and reconstructed as `min + scale * byte`. A constant vector is
/// stored with `scale == 0.0` and all-zero bytes.
#[derive(Clone, Copy, Debug, Default)]
pub struct Int8Quantized;

impl Encoder for Int8Quantized {
    fn codec(&self) -> Codec {
        Codec::Int8Quantized
    }

    /// Quantises `values` to one byte per component.
    ///
    /// # Errors
    ///
    /// * [`EncodingError::EmptyVector`] if `values` is empty.
    /// * [`EncodingError::DimensionTooLarge`] if the length does not fit in `u16`.
    /// * [`EncodingError::NonFinite`] if any component is NaN or infinite.
    fn encode(&self, values: &[f32]) -> Result<EncodedVector, EncodingError> {
        if values.is_empty() {
            return Err(EncodingError::EmptyVector);
        }
        let dim = u16::try_from(values.len())
            .map_err(|_| EncodingError::DimensionTooLarge(values.len()))?;
        if values.iter().any(|v| !v.is_finite()) {
            return Err(EncodingError::NonFinite);
        }

        let mut min = values[0];
        let mut max = values[0];
        for &v in &values[1..] {
            if v < min {
                min = v;
            }
            if v > max {
                max = v;
            }
        }

        // All components are finite, so `max - min` is either finite or has
        // overflowed to +inf (the components span more than `f32::MAX`).
        // In the overflow case the scale is derived without forming the
        // difference, so it stays finite instead of poisoning the encoding
        // with `inf`/NaN.
        let range = max - min;
        let overflowed = !range.is_finite();
        let scale = if overflowed {
            max / 255.0 - min / 255.0
        } else if range > 0.0 {
            range / 255.0
        } else {
            0.0
        };

        let bytes: Vec<u8> = if scale == 0.0 {
            // Constant vector (or a range too small to resolve): every level is 0.
            vec![0_u8; values.len()]
        } else {
            let offset = min / scale;
            values
                .iter()
                .map(|&v| {
                    let level = if overflowed {
                        // `v - min` could itself overflow here; divide first.
                        v / scale - offset
                    } else {
                        (v - min) / scale
                    };
                    let clamped = level.round().clamp(0.0, 255.0);
                    #[allow(
                        clippy::cast_possible_truncation,
                        clippy::cast_sign_loss,
                        reason = "clamped to [0, 255]"
                    )]
                    let byte = clamped as u8;
                    byte
                })
                .collect()
        };

        Ok(EncodedVector {
            codec: Codec::Int8Quantized,
            dim,
            bytes,
            params: vec![min, scale],
        })
    }

    /// Reconstructs `min + scale * byte` for every payload byte.
    ///
    /// When that expression overflows `f32` although both parameters are
    /// finite, the value is recomputed in `f64` and saturated to the finite
    /// `f32` range. Results for non-finite parameters are passed through
    /// unchanged.
    ///
    /// # Errors
    ///
    /// * [`EncodingError::CodecMismatch`] if `ev` is not tagged as int8.
    /// * [`EncodingError::Malformed`] if the payload is not exactly `dim`
    ///   bytes or `params` does not hold exactly two values.
    fn decode(&self, ev: &EncodedVector) -> Result<Vec<f32>, EncodingError> {
        if ev.codec != Codec::Int8Quantized {
            return Err(EncodingError::CodecMismatch {
                expected: Codec::Int8Quantized,
                got: ev.codec,
            });
        }
        if ev.bytes.len() != usize::from(ev.dim) || ev.params.len() != 2 {
            return Err(EncodingError::Malformed {
                dim: ev.dim,
                bytes: ev.bytes.len(),
            });
        }
        let min = ev.params[0];
        let scale = ev.params[1];
        Ok(ev
            .bytes
            .iter()
            .map(|&b| {
                let fast = min + scale * f32::from(b);
                if fast.is_finite() {
                    return fast;
                }
                let wide = f64::from(min) + f64::from(scale) * f64::from(b);
                if !wide.is_finite() {
                    // The parameters themselves are non-finite; keep the raw result.
                    return fast;
                }
                let clamped = wide.clamp(f64::from(f32::MIN), f64::from(f32::MAX));
                #[allow(
                    clippy::cast_possible_truncation,
                    reason = "clamped to the finite f32 range"
                )]
                let narrowed = clamped as f32;
                narrowed
            })
            .collect())
    }
}
/// Half-precision codec: each component is stored as two little-endian bytes.
#[derive(Clone, Copy, Debug, Default)]
pub struct Fp16;

impl Encoder for Fp16 {
    fn codec(&self) -> Codec {
        Codec::Fp16
    }

    /// Converts every component to `f16` and appends its little-endian bytes.
    ///
    /// # Errors
    ///
    /// * [`EncodingError::EmptyVector`] if `values` is empty.
    /// * [`EncodingError::DimensionTooLarge`] if the length does not fit in `u16`.
    /// * [`EncodingError::NonFinite`] if any component is NaN or infinite.
    fn encode(&self, values: &[f32]) -> Result<EncodedVector, EncodingError> {
        let count = values.len();
        if count == 0 {
            return Err(EncodingError::EmptyVector);
        }
        let Ok(dim) = u16::try_from(count) else {
            return Err(EncodingError::DimensionTooLarge(count));
        };
        if values.iter().any(|v| !v.is_finite()) {
            return Err(EncodingError::NonFinite);
        }

        // NOTE(review): finite inputs beyond the f16 range are not rejected
        // here; `f16::from_f32` presumably maps them to +/-infinity — confirm
        // whether that is intended.
        let mut bytes = Vec::with_capacity(count * 2);
        bytes.extend(
            values
                .iter()
                .flat_map(|&v| f16::from_f32(v).to_le_bytes()),
        );

        Ok(EncodedVector {
            codec: Codec::Fp16,
            dim,
            bytes,
            params: Vec::new(),
        })
    }

    /// Widens every stored half-precision value back to `f32`.
    ///
    /// # Errors
    ///
    /// * [`EncodingError::CodecMismatch`] if `ev` is not tagged as fp16.
    /// * [`EncodingError::Malformed`] if the payload is not exactly `2 * dim` bytes.
    fn decode(&self, ev: &EncodedVector) -> Result<Vec<f32>, EncodingError> {
        if ev.codec != Codec::Fp16 {
            return Err(EncodingError::CodecMismatch {
                expected: Codec::Fp16,
                got: ev.codec,
            });
        }
        let expected_len = usize::from(ev.dim) * 2;
        if ev.bytes.len() != expected_len {
            return Err(EncodingError::Malformed {
                dim: ev.dim,
                bytes: ev.bytes.len(),
            });
        }
        let decoded = ev
            .bytes
            .chunks_exact(2)
            .map(|pair| f16::from_le_bytes([pair[0], pair[1]]).to_f32())
            .collect();
        Ok(decoded)
    }
}
/// Encoder for the turbovec codecs, parameterised by a bit width of 2, 3, or 4.
#[derive(Clone, Copy, Debug)]
pub struct Turbovec {
    bits: u8,
}

impl Turbovec {
    /// Creates an encoder for the given bit width.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is not 2, 3, or 4.
    #[must_use]
    pub fn new(bits: u8) -> Self {
        assert!(
            matches!(bits, 2..=4),
            "turbovec bit width must be 2, 3, or 4"
        );
        Self { bits }
    }

    /// Returns the bit width this encoder was created with.
    #[must_use]
    pub fn bits(self) -> u8 {
        let Self { bits } = self;
        bits
    }

    /// Maps a validated bit width to its codec tag.
    fn codec_for(bits: u8) -> Codec {
        match bits {
            2 => Codec::Turbovec2Bit,
            3 => Codec::Turbovec3Bit,
            4 => Codec::Turbovec4Bit,
            _ => unreachable!("turbovec bits validated in Turbovec::new"),
        }
    }
}
impl Encoder for Turbovec {
    fn codec(&self) -> Codec {
        Self::codec_for(self.bits)
    }

    /// Stores the components verbatim as little-endian `f32` (four bytes each)
    /// and records the bit width as the single entry of `params`.
    ///
    /// # Errors
    ///
    /// * [`EncodingError::EmptyVector`] if `values` is empty.
    /// * [`EncodingError::DimensionTooLarge`] if the length does not fit in `u16`.
    /// * [`EncodingError::NonFinite`] if any component is NaN or infinite.
    /// * [`EncodingError::UnsupportedDim`] if the length is not a multiple of 8.
    fn encode(&self, values: &[f32]) -> Result<EncodedVector, EncodingError> {
        let count = values.len();
        if count == 0 {
            return Err(EncodingError::EmptyVector);
        }
        let Ok(dim) = u16::try_from(count) else {
            return Err(EncodingError::DimensionTooLarge(count));
        };
        if values.iter().any(|v| !v.is_finite()) {
            return Err(EncodingError::NonFinite);
        }
        if dim == 0 || !dim.is_multiple_of(8) {
            return Err(EncodingError::UnsupportedDim(dim));
        }

        let mut bytes = Vec::with_capacity(count * 4);
        bytes.extend(values.iter().flat_map(|v| v.to_le_bytes()));

        Ok(EncodedVector {
            codec: self.codec(),
            dim,
            bytes,
            params: vec![f32::from(self.bits)],
        })
    }

    /// Reads the payload back as little-endian `f32` components.
    ///
    /// # Errors
    ///
    /// * [`EncodingError::CodecMismatch`] if `ev` carries a different codec
    ///   (including a turbovec codec of another bit width).
    /// * [`EncodingError::Malformed`] if the payload is not exactly `4 * dim` bytes.
    fn decode(&self, ev: &EncodedVector) -> Result<Vec<f32>, EncodingError> {
        let expected = self.codec();
        if ev.codec != expected {
            return Err(EncodingError::CodecMismatch {
                expected,
                got: ev.codec,
            });
        }
        let expected_len = usize::from(ev.dim) * 4;
        if ev.bytes.len() != expected_len {
            return Err(EncodingError::Malformed {
                dim: ev.dim,
                bytes: ev.bytes.len(),
            });
        }
        let decoded = ev
            .bytes
            .chunks_exact(4)
            .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect();
        Ok(decoded)
    }
}
/// Encodes `vector` with the turbovec codec of the given bit width.
///
/// # Errors
///
/// Returns [`EncodingError::UnsupportedBitWidth`] if `bits` is not 2, 3, or 4;
/// otherwise forwards any error from [`Turbovec`]'s `encode`.
pub fn encode_turbovec(vector: &[f32], bits: u8) -> Result<EncodedVector, EncodingError> {
    match bits {
        // Checked here so that `Turbovec::new` can never hit its assertion.
        2..=4 => Turbovec::new(bits).encode(vector),
        unsupported => Err(EncodingError::UnsupportedBitWidth(unsupported)),
    }
}
/// Decodes `ev` with the turbovec codec of the given bit width.
///
/// # Errors
///
/// Returns [`EncodingError::UnsupportedBitWidth`] if `bits` is not 2, 3, or 4;
/// otherwise forwards any error from [`Turbovec`]'s `decode`.
pub fn decode_turbovec(ev: &EncodedVector, bits: u8) -> Result<Vec<f32>, EncodingError> {
    match bits {
        // Checked here so that `Turbovec::new` can never hit its assertion.
        2..=4 => Turbovec::new(bits).decode(ev),
        unsupported => Err(EncodingError::UnsupportedBitWidth(unsupported)),
    }
}
/// Computes the distance between `query` and a turbovec-encoded `stored` vector.
///
/// The stored vector is decoded, inserted into a fresh single-entry
/// `turbovec::TurboQuantIndex`, and the top score of a search for `query` is
/// converted to a distance:
///
/// * `DotProduct` — `-score`, on the raw vectors.
/// * `Cosine` — `1 - score`, on L2-normalised copies.
/// * `Euclidean` — `sqrt(max(0, 2 - 2 * score))`, also on L2-normalised
///   copies, so the magnitudes of the inputs do not contribute.
///
/// These formulas treat the score as an inner-product similarity —
/// NOTE(review): confirm against the `turbovec` crate.
///
/// Returns `f32::INFINITY` as a "no usable distance" sentinel when:
/// `stored` is not a turbovec codec, the dimensions differ, the dimension is
/// zero or not a multiple of 8, decoding fails, the index cannot be built or
/// filled, or the search yields no score.
#[must_use]
pub fn distance_turbovec(query: &[f32], stored: &EncodedVector, metric: Distance) -> f32 {
    let Some(bits) = stored.codec.turbovec_bits() else {
        return f32::INFINITY;
    };
    let dim = usize::from(stored.dim);
    if dim != query.len() {
        return f32::INFINITY;
    }
    if dim == 0 || !dim.is_multiple_of(8) {
        return f32::INFINITY;
    }
    let Ok(stored_vec) = decode_turbovec(stored, bits) else {
        return f32::INFINITY;
    };
    let Ok(mut index) = turbovec::TurboQuantIndex::new(dim, usize::from(bits)) else {
        return f32::INFINITY;
    };
    let (q_input, s_input) = match metric {
        // The decoded vector is not needed afterwards, so it is moved rather
        // than cloned.
        Distance::DotProduct => (query.to_vec(), stored_vec),
        Distance::Cosine | Distance::Euclidean => {
            (l2_normalise(query), l2_normalise(&stored_vec))
        }
    };
    if index.add_2d(&s_input, dim).is_err() {
        return f32::INFINITY;
    }
    let results = index.search(&q_input, 1);
    if results.scores.is_empty() {
        return f32::INFINITY;
    }
    let similarity = results.scores[0];
    match metric {
        Distance::DotProduct => -similarity,
        Distance::Cosine => 1.0 - similarity,
        // Clamp before the square root: rounding can push the score slightly above 1.
        Distance::Euclidean => (2.0 - 2.0 * similarity).max(0.0).sqrt(),
    }
}
/// Returns `v` scaled to unit Euclidean length.
///
/// A vector whose norm is zero (including the empty slice) is returned
/// unchanged. The norm is computed in `f64`: squaring in `f32` overflows for
/// components above roughly `1.8e19`, which would collapse the result to all
/// zeros, and underflows for very small components, which would make a
/// non-zero vector look like the zero vector.
fn l2_normalise(v: &[f32]) -> Vec<f32> {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if norm <= 0.0 {
        return v.to_vec();
    }
    v.iter()
        .map(|&x| {
            #[allow(
                clippy::cast_possible_truncation,
                reason = "a finite component divided by the norm lies in [-1, 1]"
            )]
            let unit = (f64::from(x) / norm) as f32;
            unit
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when both slices have the same length and every pair of
    /// components differs by at most `eps`.
    fn slices_within(a: &[f32], b: &[f32], eps: f32) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).all(|(x, y)| (x - y).abs() <= eps)
    }

    #[test]
    fn int8_round_trip_within_budget() {
        let input: Vec<f32> = (0_i16..64).map(|i| f32::from(i) * 0.1 - 3.2).collect();
        let encoded = Int8Quantized.encode(&input).unwrap();
        let decoded = Int8Quantized.decode(&encoded).unwrap();
        // One quantisation step of a 6.4-wide range is the allowed error.
        let bucket = 6.4_f32 / 255.0;
        assert!(slices_within(&input, &decoded, bucket));
        assert_eq!(usize::from(encoded.dim), input.len());
    }

    #[test]
    fn fp16_round_trip_within_budget() {
        let input: Vec<f32> = (0_i16..32).map(|i| f32::from(i) * 0.05 - 0.8).collect();
        let encoded = Fp16.encode(&input).unwrap();
        let decoded = Fp16.decode(&encoded).unwrap();
        assert!(slices_within(&input, &decoded, 1e-2));
    }

    #[test]
    fn rejects_non_finite() {
        let with_nan = Int8Quantized.encode(&[1.0, f32::NAN]);
        assert!(matches!(with_nan, Err(EncodingError::NonFinite)));

        let with_inf = Fp16.encode(&[1.0, f32::INFINITY]);
        assert!(matches!(with_inf, Err(EncodingError::NonFinite)));
    }

    #[test]
    fn rejects_empty() {
        let outcome = Int8Quantized.encode(&[]);
        assert!(matches!(outcome, Err(EncodingError::EmptyVector)));
    }

    #[test]
    fn codec_mismatch_detected() {
        let input = [0.1_f32, 0.2, 0.3];
        let encoded = Fp16.encode(&input).unwrap();
        let outcome = Int8Quantized.decode(&encoded);
        assert!(matches!(outcome, Err(EncodingError::CodecMismatch { .. })));
    }

    #[test]
    fn constant_vector_quantises_cleanly() {
        let input = vec![1.5_f32; 8];
        let encoded = Int8Quantized.encode(&input).unwrap();
        let decoded = Int8Quantized.decode(&encoded).unwrap();
        assert_eq!(decoded, vec![1.5_f32; 8]);
    }

    #[test]
    fn l2_norm_matches_decoded() {
        // 3-4-5 triangle: the exact norm is 5.
        let input = [3.0_f32, 4.0, 0.0];
        let encoded = Fp16.encode(&input).unwrap();
        let error = (encoded.l2_norm() - 5.0).abs();
        assert!(error < 1e-2);
    }

    #[test]
    fn turbovec_encode_round_trips_at_row_layer() {
        let input: Vec<f32> = (0_i16..64).map(|i| f32::from(i) * 0.05 - 1.6).collect();
        let encoded = encode_turbovec(&input, 4).unwrap();
        assert_eq!(encoded.codec, Codec::Turbovec4Bit);
        assert_eq!(usize::from(encoded.dim), input.len());
        let decoded = decode_turbovec(&encoded, 4).unwrap();
        assert_eq!(decoded, input);
    }

    #[test]
    fn turbovec_rejects_unsupported_bit_width() {
        let input = vec![0.1_f32; 8];
        let outcome = encode_turbovec(&input, 5);
        assert!(matches!(
            outcome,
            Err(EncodingError::UnsupportedBitWidth(5))
        ));
    }

    #[test]
    fn turbovec_rejects_non_multiple_of_eight_dim() {
        let input = vec![0.1_f32; 7];
        let outcome = encode_turbovec(&input, 4);
        assert!(matches!(outcome, Err(EncodingError::UnsupportedDim(7))));
    }
}