use crate::distance::cosine_distance_normalized;
use crate::RetrieveError;
/// Tuning parameters for [`SQ4Index`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SQ4Params {
    /// Over-fetch multiplier for search: the quantized scan keeps
    /// `k * rerank_factor` candidates (capped at the number of stored vectors)
    /// which are then re-ranked against the full-precision vectors.
    /// Must be > 0; [`SQ4Index::new`] rejects 0.
    pub rerank_factor: usize,
}

impl Default for SQ4Params {
    /// Returns parameters with `rerank_factor = 10`.
    fn default() -> Self {
        Self { rerank_factor: 10 }
    }
}
/// Vector index that scans 4-bit scalar-quantized codes and re-ranks the best
/// candidates against full-precision vectors.
///
/// Vectors are L2-normalized on insertion (unless their norm is vanishingly
/// small). [`SQ4Index::build`] derives per-dimension quantization parameters
/// and packs two 4-bit levels per byte.
pub struct SQ4Index {
    // ---- configuration ----
    /// Length of every stored vector and every query.
    dimension: usize,
    /// Search-time tuning parameters.
    params: SQ4Params,
    /// Set once `build` has produced the codes; inserts are rejected afterwards.
    built: bool,

    // ---- raw data, appended by `add_slice` ----
    /// Number of vectors added so far.
    num_vectors: usize,
    /// Caller-supplied id of each stored vector, in insertion order.
    doc_ids: Vec<u32>,
    /// Stored vectors, row-major: `num_vectors * dimension` floats.
    vectors: Vec<f32>,

    // ---- quantization state, filled by `build` ----
    /// Packed 4-bit codes, `dimension.div_ceil(2)` bytes per vector.
    codes: Vec<u8>,
    /// Per-dimension minimum over all stored vectors.
    mins: Vec<f32>,
    /// Per-dimension `15 / range` (0 for constant dimensions), used to encode.
    inv_scales: Vec<f32>,
    /// Per-dimension `range / 15` (0 for constant dimensions), used to decode.
    steps: Vec<f32>,
}
impl SQ4Index {
    /// Norm at or below which a vector is treated as zero and stored/used
    /// unnormalized, avoiding division by a vanishing norm.
    const NORM_EPSILON: f32 = 1e-10;

    /// Per-dimension value range at or below which the dimension is treated
    /// as constant: its scale and step stay 0, so every value encodes to
    /// level 0 and decodes to the dimension's minimum.
    const RANGE_EPSILON: f32 = 1e-10;

    /// Highest 4-bit quantization level (levels are `0..=15`).
    const MAX_LEVEL: f32 = 15.0;

    /// Creates an empty, unbuilt index for vectors of length `dimension`.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::InvalidParameter`] if `dimension` is 0 or
    /// `params.rerank_factor` is 0.
    pub fn new(dimension: usize, params: SQ4Params) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be > 0".into(),
            ));
        }
        if params.rerank_factor == 0 {
            return Err(RetrieveError::InvalidParameter(
                "rerank_factor must be > 0".into(),
            ));
        }
        Ok(Self {
            dimension,
            params,
            built: false,
            vectors: Vec::new(),
            num_vectors: 0,
            doc_ids: Vec::new(),
            codes: Vec::new(),
            mins: Vec::new(),
            inv_scales: Vec::new(),
            steps: Vec::new(),
        })
    }

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Adds one vector under `doc_id`.
    ///
    /// The vector is stored L2-normalized. A vector whose norm does not exceed
    /// a tiny threshold (e.g. all zeros) is stored unchanged.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::InvalidParameter`] if the index has already been built.
    /// - [`RetrieveError::DimensionMismatch`] if `vector.len()` differs from the
    ///   index dimension.
    pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::InvalidParameter(
                "cannot add after build".into(),
            ));
        }
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vector.len(),
                doc_dim: self.dimension,
            });
        }
        let norm = Self::l2_norm(vector);
        if norm > Self::NORM_EPSILON {
            self.vectors.extend(vector.iter().map(|x| x / norm));
        } else {
            self.vectors.extend_from_slice(vector);
        }
        self.doc_ids.push(doc_id);
        self.num_vectors += 1;
        Ok(())
    }

    /// Adds `ids.len()` vectors laid out contiguously (row-major) in `vectors`.
    /// Row `i` is stored under `ids[i]` via [`Self::add_slice`].
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::InvalidParameter`] if `vectors.len()` is not exactly
    ///   `ids.len() * dimension`, including when that product would overflow.
    /// - Any error returned by [`Self::add_slice`], e.g. when the index is
    ///   already built.
    pub fn add_batch(&mut self, ids: &[u32], vectors: &[f32]) -> Result<(), RetrieveError> {
        // checked_mul: with a very large `dimension` the product must not
        // overflow (a panic in debug builds, a wrapped length in release).
        if ids.len().checked_mul(self.dimension) != Some(vectors.len()) {
            return Err(RetrieveError::InvalidParameter(
                "vectors.len() must equal ids.len() * dimension".into(),
            ));
        }
        // `dimension > 0` is guaranteed by `new`, so `chunks_exact` cannot panic.
        for (&id, vector) in ids.iter().zip(vectors.chunks_exact(self.dimension)) {
            self.add_slice(id, vector)?;
        }
        Ok(())
    }

    /// Freezes the index.
    ///
    /// Computes the per-dimension min/max over all stored vectors, derives
    /// 4-bit encode/decode scales from them, and packs every vector into codes.
    /// Calling `build` again after a successful build is a no-op returning `Ok`.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::EmptyIndex`] if no vectors were added.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        let d = self.dimension;

        // Per-dimension bounds over every stored vector.
        let mut mins = vec![f32::INFINITY; d];
        let mut maxs = vec![f32::NEG_INFINITY; d];
        for v in self.vectors.chunks_exact(d) {
            for ((lo, hi), &val) in mins.iter_mut().zip(maxs.iter_mut()).zip(v) {
                if val < *lo {
                    *lo = val;
                }
                if val > *hi {
                    *hi = val;
                }
            }
        }

        // Map [min, max] onto levels 0..=15. Constant dimensions keep 0/0.
        let mut inv_scales = vec![0.0f32; d];
        let mut steps = vec![0.0f32; d];
        for ((inv_scale, step), (&lo, &hi)) in inv_scales
            .iter_mut()
            .zip(steps.iter_mut())
            .zip(mins.iter().zip(&maxs))
        {
            let range = hi - lo;
            if range > Self::RANGE_EPSILON {
                *inv_scale = Self::MAX_LEVEL / range;
                *step = range / Self::MAX_LEVEL;
            }
        }

        let bytes_per_vec = d.div_ceil(2);
        let mut codes = vec![0u8; self.num_vectors * bytes_per_vec];
        for (v, code) in self
            .vectors
            .chunks_exact(d)
            .zip(codes.chunks_exact_mut(bytes_per_vec))
        {
            pack_vector(v, &mins, &inv_scales, code);
        }

        self.mins = mins;
        self.inv_scales = inv_scales;
        self.steps = steps;
        self.codes = codes;
        self.built = true;
        Ok(())
    }

    /// Returns up to `k` `(doc_id, distance)` pairs, sorted by ascending distance.
    ///
    /// The search runs in two passes:
    ///
    /// 1. The query is L2-normalized (unless its norm is vanishingly small).
    ///    Every stored code is scored with an asymmetric squared-L2 distance.
    ///    The best `k * rerank_factor` candidates are kept, capped at the
    ///    number of stored vectors.
    /// 2. Those candidates are re-scored with [`cosine_distance_normalized`]
    ///    against the stored full-precision vectors.
    ///
    /// `k == 0` returns an empty vector.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::InvalidParameter`] if the index has not been built.
    /// - [`RetrieveError::EmptyQuery`] if `query` is empty.
    /// - [`RetrieveError::DimensionMismatch`] if `query.len()` differs from the
    ///   index dimension.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::InvalidParameter(
                "index must be built before search".into(),
            ));
        }
        if query.is_empty() {
            return Err(RetrieveError::EmptyQuery);
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        if k == 0 {
            return Ok(Vec::new());
        }
        let d = self.dimension;
        let n = self.num_vectors;
        let bytes_per_vec = d.div_ceil(2);

        let norm = Self::l2_norm(query);
        let query_norm: Vec<f32> = if norm > Self::NORM_EPSILON {
            query.iter().map(|x| x / norm).collect()
        } else {
            query.to_vec()
        };

        // saturating_mul: a plain `k * rerank_factor` panics in debug builds
        // for large `k`. In release builds it wraps, possibly to 0, which
        // would silently return no results. The cap at `n` makes saturation
        // harmless.
        let candidates = k.saturating_mul(self.params.rerank_factor).min(n);

        // Coarse pass: approximate distance from the query to every code.
        let mut dists: Vec<(usize, f32)> = self
            .codes
            .chunks_exact(bytes_per_vec)
            .map(|code| asymmetric_l2_sq4(&query_norm, code, &self.mins, &self.steps, d))
            .enumerate()
            .collect();
        if candidates < n {
            // Move the `candidates` smallest to the front (unordered), drop the rest.
            dists.select_nth_unstable_by(candidates, |a, b| a.1.total_cmp(&b.1));
            dists.truncate(candidates);
        }

        // Fine pass: re-rank candidates on the full-precision vectors.
        let mut results: Vec<(u32, f32)> = dists
            .iter()
            .map(|&(idx, _)| {
                let v = &self.vectors[idx * d..(idx + 1) * d];
                (self.doc_ids[idx], cosine_distance_normalized(&query_norm, v))
            })
            .collect();
        results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        results.truncate(k);
        Ok(results)
    }

    /// Number of vectors added so far.
    #[must_use]
    pub fn num_vectors(&self) -> usize {
        self.num_vectors
    }

    /// Size in bytes of the packed 4-bit codes.
    ///
    /// This is 0 until [`Self::build`] runs. It excludes the full-precision
    /// vectors kept for re-ranking.
    #[must_use]
    pub fn code_memory(&self) -> usize {
        self.codes.len()
    }
}
/// Packs `v` into 4-bit levels, two per byte.
///
/// Element `2p` goes into the low nibble of `out[p]` and element `2p + 1` into
/// the high nibble. For odd `v.len()`, the final element occupies the low
/// nibble of `out[v.len() / 2]`, and that byte's high nibble is zero. Only the
/// first `v.len().div_ceil(2)` bytes of `out` are written.
///
/// # Panics
///
/// Panics (via indexing) if `mins` or `inv_scales` is shorter than `v`, or if
/// `out` is shorter than `v.len().div_ceil(2)`.
pub(crate) fn pack_vector(v: &[f32], mins: &[f32], inv_scales: &[f32], out: &mut [u8]) {
    let pairs = v.chunks_exact(2);
    // `remainder` borrows from `v`, not from the iterator, so it outlives the loop.
    let tail = pairs.remainder();
    for (p, pair) in pairs.enumerate() {
        let j = 2 * p;
        let low = quantize_4bit(pair[0], mins[j], inv_scales[j]);
        let high = quantize_4bit(pair[1], mins[j + 1], inv_scales[j + 1]);
        out[p] = (high << 4) | low;
    }
    if let [last] = tail {
        let j = v.len() - 1;
        out[v.len() / 2] = quantize_4bit(*last, mins[j], inv_scales[j]);
    }
}

/// Maps `val` to a 4-bit level in `0..=15`.
///
/// The level is `(val - min) * inv_scale + 0.5`, truncated toward zero and
/// clamped. The `as i32` cast saturates on overflow and turns NaN into 0.
#[inline]
pub(crate) fn quantize_4bit(val: f32, min: f32, inv_scale: f32) -> u8 {
    let scaled = (val - min) * inv_scale;
    let level = (scaled + 0.5) as i32;
    level.clamp(0, 15) as u8
}
/// Squared Euclidean distance between a full-precision `query` and one packed
/// 4-bit code. It is asymmetric because only the stored side is quantized.
///
/// Level `l` in dimension `j` decodes to `mins[j] + l * steps[j]`. The first
/// `dim` entries of `query`, `mins` and `steps` are read, along with
/// `dim.div_ceil(2)` bytes of `code`. For odd `dim`, only the low nibble of the
/// final byte is used. Panics (via indexing) if any slice is too short.
fn asymmetric_l2_sq4(query: &[f32], code: &[u8], mins: &[f32], steps: &[f32], dim: usize) -> f32 {
    let decode = |j: usize, level: u8| mins[j] + f32::from(level) * steps[j];
    let full_bytes = dim / 2;

    let paired = (0..full_bytes).fold(0.0f32, |acc, p| {
        let byte = code[p];
        let (even, odd) = (2 * p, 2 * p + 1);
        let err_even = query[even] - decode(even, byte & 0x0F);
        let err_odd = query[odd] - decode(odd, byte >> 4);
        acc + (err_even * err_even + err_odd * err_odd)
    });

    if dim.is_multiple_of(2) {
        paired
    } else {
        let last = dim - 1;
        let err = query[last] - decode(last, code[full_bytes] & 0x0F);
        paired + err * err
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes both 4-bit levels of a packed byte as `(low, high)`.
    fn split_nibbles(byte: u8) -> (f32, f32) {
        (f32::from(byte & 0x0F), f32::from(byte >> 4))
    }

    #[test]
    fn pack_unpack_roundtrip() {
        const DIM: usize = 8;
        let original: [f32; DIM] = [0.0, 0.5, 0.25, 0.75, 1.0, 0.1, 0.9, 0.6];
        // Range [0, 1] on every dimension: 15 levels, step 1/15.
        let mins = [0.0f32; DIM];
        let inv_scales = [15.0f32; DIM];
        let steps = [1.0f32 / 15.0; DIM];
        let mut packed = [0u8; DIM / 2];
        pack_vector(&original, &mins, &inv_scales, &mut packed);

        for (p, &byte) in packed.iter().enumerate() {
            let (low, high) = split_nibbles(byte);
            for (j, level) in [(2 * p, low), (2 * p + 1, high)] {
                let decoded = mins[j] + level * steps[j];
                let expected = original[j];
                assert!(
                    (decoded - expected).abs() < 0.07,
                    "dim {j}: {decoded} vs {expected}"
                );
            }
        }
    }

    #[test]
    fn odd_dimension() {
        let values = [0.0f32, 0.5, 0.25, 0.75, 1.0];
        let dim = values.len();
        let mins = vec![0.0f32; dim];
        let inv_scales = vec![15.0f32; dim];
        let mut packed = vec![0u8; dim.div_ceil(2)];
        pack_vector(&values, &mins, &inv_scales, &mut packed);

        // The lone trailing level sits in the low nibble; the high nibble stays zero.
        let tail = packed[dim / 2];
        assert_eq!(tail & 0xF0, 0);
        assert_eq!(tail & 0x0F, 15);
    }

    #[test]
    fn build_and_search() {
        let dim = 32;
        let count = 100;
        let mut index = SQ4Index::new(dim, SQ4Params::default()).unwrap();
        for id in 0..count {
            let start = id * dim;
            let row: Vec<f32> = (start..start + dim).map(|x| x as f32 * 0.01).collect();
            index.add_slice(id as u32, &row).unwrap();
        }
        index.build().unwrap();

        // The query is identical to row 0.
        let query: Vec<f32> = (0..dim).map(|x| x as f32 * 0.01).collect();
        let hits = index.search(&query, 5).unwrap();
        assert_eq!(hits.len(), 5);
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn compression_ratio() {
        let (dim, count) = (128, 1000);
        let mut index = SQ4Index::new(dim, SQ4Params::default()).unwrap();
        for id in 0..count {
            let row: Vec<f32> = (id..id + dim).map(|x| x as f32 * 0.001).collect();
            index.add_slice(id as u32, &row).unwrap();
        }
        index.build().unwrap();

        let raw_bytes = (count * dim * std::mem::size_of::<f32>()) as f64;
        let ratio = raw_bytes / index.code_memory() as f64;
        assert!(ratio > 7.5, "expected ~8x compression, got {ratio:.1}x");
    }
}