use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
use crate::{
quantize::sq8::Sq8Codec,
rerank::codec::{CodecName, PreparedQuery, RerankCodec},
rerank::types::RerankError,
};
/// Length of the packed-code section of an encoded SQ8 vector with `dim`
/// components, as passed to `UnifiedQuantizedVectorRef::from_bytes`.
///
/// The section holds one code per component, so the result equals `dim`
/// (assumed to be one byte per code — confirm against `Sq8Codec`'s layout).
#[inline]
fn sq8_packed_bits_len(dim: usize) -> usize {
    const CODE_WIDTH: usize = 1;
    dim * CODE_WIDTH
}
/// Rerank codec backed by an [`Sq8Codec`].
///
/// Inputs to `encode` and `prepare_query` are checked against `dim`
/// before being handed to the codec.
pub struct Sq8Rerank {
    /// Vector length this reranker accepts.
    dim: usize,
    /// Underlying quantizer used for encoding and distance computation.
    codec: Sq8Codec,
}
impl Sq8Rerank {
    /// Creates a reranker for `dim`-dimensional vectors with a placeholder
    /// calibration.
    ///
    /// The codec is calibrated on two synthetic samples: all zeros and all
    /// ones. This presumably yields a `[0, 1]` range per component. Use
    /// [`RerankCodec::train`] to recalibrate on real data.
    pub fn new(dim: usize) -> Self {
        let bounds = [vec![0.0f32; dim], vec![1.0f32; dim]];
        let samples: Vec<&[f32]> = bounds.iter().map(Vec::as_slice).collect();
        Self {
            codec: Sq8Codec::calibrate(&samples, dim),
            dim,
        }
    }

    /// Wraps an already-calibrated [`Sq8Codec`].
    ///
    /// The dimensionality is taken from the codec itself.
    pub fn from_codec(codec: Sq8Codec) -> Self {
        // `codec.dim` is copied out before `codec` is moved; struct-expression
        // fields are evaluated in source order.
        Self {
            dim: codec.dim,
            codec,
        }
    }
}
impl RerankCodec for Sq8Rerank {
    /// Quantizes `v` with the SQ8 codec and returns the serialized bytes of
    /// the quantized vector.
    ///
    /// # Errors
    /// Returns [`RerankError::BadInput`] if `v.len()` differs from the codec
    /// dimension.
    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
        if v.len() != self.dim {
            return Err(RerankError::BadInput(format!(
                "sq8 encode: vector len {} != codec dim {}",
                v.len(),
                self.dim
            )));
        }
        // The `encode` method on `Sq8Codec` comes from this trait; import it
        // anonymously just for this call.
        use nodedb_codec::vector_quant::codec::VectorCodec as _;
        let quantized = self.codec.encode(v);
        Ok(quantized.as_ref().as_bytes().to_vec())
    }

    /// Prepares a query for asymmetric distance computation.
    ///
    /// The query is kept at full precision (`PreparedQuery::Raw`). Only the
    /// stored side is quantized.
    ///
    /// # Errors
    /// Returns [`RerankError::BadInput`] if `q.len()` differs from the codec
    /// dimension.
    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
        if q.len() != self.dim {
            return Err(RerankError::BadInput(format!(
                "sq8 prepare_query: query len {} != codec dim {}",
                q.len(),
                self.dim
            )));
        }
        Ok(PreparedQuery::Raw(q.to_vec()))
    }

    /// Computes the asymmetric L2 distance between a prepared raw query and
    /// an encoded SQ8 vector produced by [`RerankCodec::encode`].
    ///
    /// # Errors
    /// Returns [`RerankError::BadInput`] if any of the following hold:
    /// - `prepared` is not `PreparedQuery::Raw`;
    /// - the raw query does not have `self.dim` components (for example, it
    ///   was prepared by a codec with a different dimension);
    /// - `encoded` cannot be parsed as a quantized vector.
    fn distance_prepared(
        &self,
        prepared: &PreparedQuery,
        encoded: &[u8],
    ) -> Result<f32, RerankError> {
        let q = match prepared {
            PreparedQuery::Raw(q) => q,
            _ => {
                return Err(RerankError::BadInput(
                    "sq8 distance: expected PreparedQuery::Raw".to_string(),
                ));
            }
        };
        // A `Raw` payload can be built outside `prepare_query`. Validate it so
        // a mismatched query never reaches the quantizer's distance kernel.
        if q.len() != self.dim {
            return Err(RerankError::BadInput(format!(
                "sq8 distance: query len {} != codec dim {}",
                q.len(),
                self.dim
            )));
        }
        let packed_len = sq8_packed_bits_len(self.dim);
        let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
            RerankError::BadInput(format!("sq8 distance: failed to parse encoded bytes: {e}"))
        })?;
        let packed = uqv_ref.packed_bits();
        Ok(self.codec.asymmetric_l2(q, packed))
    }

    /// Identifies this codec as [`CodecName::Sq8`].
    fn name(&self) -> CodecName {
        CodecName::Sq8
    }

    /// Serializes the underlying SQ8 codec state.
    fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
        Ok(self.codec.to_bytes())
    }

    /// Recalibrates the codec from `samples`, replacing the previous
    /// calibration.
    ///
    /// # Errors
    /// Returns [`RerankError::BadInput`] if `samples` is empty or any sample's
    /// length differs from the codec dimension. On error the existing
    /// calibration is left untouched.
    fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
        if samples.is_empty() {
            return Err(RerankError::BadInput(
                "sq8 train: empty sample set".to_string(),
            ));
        }
        // Reject wrong-length samples up front, matching `encode` and
        // `prepare_query`. Otherwise `calibrate` receives malformed input.
        if let Some((idx, bad)) = samples
            .iter()
            .enumerate()
            .find(|(_, s)| s.len() != self.dim)
        {
            return Err(RerankError::BadInput(format!(
                "sq8 train: sample {idx} len {} != codec dim {}",
                bad.len(),
                self.dim
            )));
        }
        self.codec = Sq8Codec::calibrate(samples, self.dim);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 16;
    const EPS: f32 = 1e-2;

    /// Builds a deterministic `DIM`-length vector: `base, base + 0.01, ...`.
    fn make_vec(base: f32) -> Vec<f32> {
        (0..DIM).map(|i| base + i as f32 * 0.01).collect()
    }

    /// Returns a reranker calibrated on 50 evenly spaced synthetic vectors.
    fn trained_codec() -> Sq8Rerank {
        let samples: Vec<Vec<f32>> = (0..50).map(|i| make_vec(i as f32 * 0.1)).collect();
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        let mut codec = Sq8Rerank::new(DIM);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn round_trip_returns_finite_distance() {
        let codec = trained_codec();
        let v1 = make_vec(0.5);
        let v2 = make_vec(1.0);
        let enc = codec.encode(&v1).expect("encode v1");
        let prepared = codec.prepare_query(&v2).expect("prepare_query v2");
        let dist = codec
            .distance_prepared(&prepared, &enc)
            .expect("distance_prepared");
        assert!(dist.is_finite(), "expected finite distance, got {dist}");
        assert!(dist >= 0.0, "expected non-negative distance, got {dist}");
    }

    #[test]
    fn identical_vectors_small_distance() {
        let codec = trained_codec();
        let v = make_vec(0.5);
        let enc = codec.encode(&v).expect("encode");
        let prepared = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec
            .distance_prepared(&prepared, &enc)
            .expect("distance_prepared");
        assert!(dist.is_finite());
        assert!(
            dist < EPS,
            "identical vectors should have near-zero distance, got {dist}"
        );
    }

    #[test]
    fn wrong_prepared_query_variant_returns_bad_input() {
        let codec = trained_codec();
        let v = make_vec(0.5);
        let enc = codec.encode(&v).expect("encode");
        let bad_prepared = PreparedQuery::Bytes(vec![0u8; 8]);
        let result = codec.distance_prepared(&bad_prepared, &enc);
        assert!(result.is_err(), "expected BadInput error");
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("Raw"),
            "error message should mention Raw, got: {msg}"
        );
    }

    #[test]
    fn name_returns_sq8() {
        let codec = Sq8Rerank::new(DIM);
        assert_eq!(codec.name(), CodecName::Sq8);
    }

    #[test]
    fn train_calibrates_without_error() {
        let mut codec = Sq8Rerank::new(DIM);
        let samples: Vec<Vec<f32>> = (0..20).map(|i| make_vec(i as f32 * 0.05)).collect();
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        codec.train(&refs).expect("train must succeed");
        let v = make_vec(0.5);
        let enc = codec.encode(&v).expect("encode after train");
        let prep = codec.prepare_query(&v).expect("prepare after train");
        let dist = codec
            .distance_prepared(&prep, &enc)
            .expect("distance after train");
        assert!(dist.is_finite());
    }

    #[test]
    fn train_with_empty_samples_returns_error() {
        let mut codec = Sq8Rerank::new(DIM);
        assert!(codec.train(&[]).is_err());
    }

    #[test]
    fn wrong_dim_encode_returns_error() {
        let codec = Sq8Rerank::new(DIM);
        let bad = vec![0.0f32; DIM + 1];
        assert!(codec.encode(&bad).is_err());
    }

    #[test]
    fn wrong_dim_prepare_query_returns_error() {
        let codec = Sq8Rerank::new(DIM);
        let bad = vec![0.0f32; DIM - 1];
        assert!(codec.prepare_query(&bad).is_err());
    }
}