use crate::hnsw::graph::HNSWIndex;
use crate::RetrieveError;
/// HNSW index whose graph traversal scores nodes with 8-bit scalar-quantized
/// (SQ8) codes instead of the full-precision vectors.
///
/// The wrapped [`HNSWIndex`] still holds the original `f32` vectors, doc ids
/// and graph layers; this wrapper adds one `u8` code per dimension per vector
/// together with the per-dimension affine parameters used to encode/decode them.
pub struct HNSWSq8Index {
    /// Underlying graph index (vectors, doc ids, layers).
    index: HNSWIndex,
    /// `true` once a build has both constructed the graph and quantized the vectors.
    built: bool,
    /// Row-major SQ8 codes: vector `i` occupies `codes[i * dim..(i + 1) * dim]`.
    codes: Vec<u8>,
    /// Per-dimension minimum; a code decodes to `mins[d] + code as f32 * steps[d]`.
    mins: Vec<f32>,
    /// Per-dimension step `(max - min) / 255`, left at 0 when the range is <= 1e-10.
    steps: Vec<f32>,
    /// Per-dimension encoding scale `255 / (max - min)`, left at 0 when the range is <= 1e-10.
    inv_scales: Vec<f32>,
}
/// Memory accounting for the vector storage of an SQ8 HNSW index.
///
/// Only the full-precision vectors and the SQ8 codes are counted (this is what
/// `HNSWSq8Index::memory_report` fills in); graph links are not included.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SQ8MemoryReport {
    /// Bytes held by the full-precision `f32` vectors.
    pub vectors_bytes: usize,
    /// Bytes held by the SQ8 codes.
    pub codes_bytes: usize,
    /// `vectors_bytes + codes_bytes` as reported by the producer.
    pub total_bytes: usize,
    /// Number of indexed vectors.
    pub n: usize,
    /// Vector dimensionality.
    pub dim: usize,
}

impl SQ8MemoryReport {
    /// Average total bytes per vector; `0.0` for an empty index
    /// (the divisor is clamped to at least 1 to avoid dividing by zero).
    pub fn bytes_per_vector(&self) -> f64 {
        self.total_bytes as f64 / self.n.max(1) as f64
    }
}

impl std::fmt::Display for SQ8MemoryReport {
    /// Formats a one-line human-readable summary with sizes in MiB (1024 * 1024 bytes).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
        let mib = |bytes: usize| bytes as f64 / BYTES_PER_MIB;
        write!(
            f,
            "SQ8 memory: {:.1} MB total ({:.1} MB vectors + {:.1} MB codes), \
             {:.1} bytes/vector (n={}, d={})",
            mib(self.total_bytes),
            mib(self.vectors_bytes),
            mib(self.codes_bytes),
            self.bytes_per_vector(),
            self.n,
            self.dim,
        )
    }
}
impl HNSWSq8Index {
    /// Creates an empty SQ8 index backed by `HNSWIndex::new(dimension, m, m_max)`.
    ///
    /// # Errors
    /// Propagates any error returned by `HNSWIndex::new`.
    pub fn new(dimension: usize, m: usize, m_max: usize) -> Result<Self, RetrieveError> {
        let index = HNSWIndex::new(dimension, m, m_max)?;
        Ok(Self::from_index(index))
    }

    /// Creates an empty SQ8 index backed by `HNSWIndex::with_params(dimension, params)`.
    ///
    /// # Errors
    /// Propagates any error returned by `HNSWIndex::with_params`.
    pub fn with_params(
        dimension: usize,
        params: crate::hnsw::HNSWParams,
    ) -> Result<Self, RetrieveError> {
        let index = HNSWIndex::with_params(dimension, params)?;
        Ok(Self::from_index(index))
    }

    /// Wraps a graph index with empty, not-yet-quantized SQ8 state.
    fn from_index(index: HNSWIndex) -> Self {
        Self {
            index,
            codes: Vec::new(),
            mins: Vec::new(),
            steps: Vec::new(),
            inv_scales: Vec::new(),
            built: false,
        }
    }

    /// Adds one vector to the underlying graph index (delegates to `HNSWIndex::add_slice`).
    ///
    /// Quantization happens in [`Self::build`]; if the stored vector count changes after a
    /// build, searches return an error until `build` is called again.
    pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
        self.index.add_slice(doc_id, vector)
    }

    /// Adds several vectors at once (delegates to `HNSWIndex::add_batch`).
    ///
    /// NOTE(review): `vectors` is presumably `ids.len() * dimension` values laid out
    /// row-major — confirm against `HNSWIndex::add_batch`.
    pub fn add_batch(&mut self, ids: &[u32], vectors: &[f32]) -> Result<(), RetrieveError> {
        self.index.add_batch(ids, vectors)
    }

    /// Builds the HNSW graph, then SQ8-quantizes every stored vector.
    ///
    /// The index becomes searchable only if both steps succeed.
    ///
    /// # Errors
    /// Propagates errors from `HNSWIndex::build`.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        // A failed (re)build must not leave stale codes looking searchable.
        self.built = false;
        self.index.build()?;
        self.finish_build()
    }

    /// Parallel variant of [`Self::build`] (delegates to `HNSWIndex::build_parallel`).
    ///
    /// # Errors
    /// Propagates errors from `HNSWIndex::build_parallel`.
    #[cfg(feature = "parallel")]
    pub fn build_parallel(&mut self, batch_size: usize) -> Result<(), RetrieveError> {
        self.built = false;
        self.index.build_parallel(batch_size)?;
        self.finish_build()
    }

    /// Quantizes the freshly built vectors and marks the index searchable.
    fn finish_build(&mut self) -> Result<(), RetrieveError> {
        self.quantize_vectors()?;
        self.built = true;
        Ok(())
    }

    /// Approximate k-NN search using only the SQ8 codes.
    ///
    /// Returns up to `k` `(doc_id, distance)` pairs sorted by ascending distance. The
    /// distance is the squared L2 distance between `query` and the *decoded* codes,
    /// whatever metric the underlying index was configured with.
    /// NOTE(review): these scores are therefore not directly comparable with the exact
    /// scores returned by [`Self::search_reranked`].
    ///
    /// `ef` is the beam width; it is raised to at least `k` so that up to `k` hits can
    /// be returned.
    ///
    /// # Errors
    /// See `check_ready`: not built, dimension mismatch, empty index, or stale codes.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        self.check_ready(query)?;
        let mut candidates = self.search_quantized(query, ef.max(k))?;
        // Order by distance *before* truncating so the kept hits are the best `k`,
        // independent of the order in which the graph search emits candidates.
        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        candidates.truncate(k);
        let output: Vec<(u32, f32)> = candidates
            .into_iter()
            .map(|(id, dist)| (self.index.doc_ids[id as usize], dist))
            .collect();
        Ok(output)
    }

    /// SQ8 search followed by exact reranking.
    ///
    /// The best `max(rerank_pool, k)` candidates by approximate distance are rescored
    /// with the underlying index's `dist_fn` on the full-precision vectors, and the
    /// top `k` by exact score are returned in ascending order.
    ///
    /// # Errors
    /// Same as [`Self::search`].
    pub fn search_reranked(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
        rerank_pool: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        self.check_ready(query)?;
        let pool = rerank_pool.max(k);
        let mut candidates = self.search_quantized(query, ef.max(pool))?;
        // Keep the `pool` closest candidates by approximate distance, not the first `pool`
        // the graph search happened to emit.
        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        candidates.truncate(pool);

        let dist_fn = self.index.dist_fn();
        let mut reranked: Vec<(u32, f32)> = candidates
            .into_iter()
            .map(|(internal_id, _approx)| {
                let vec = self.index.get_vector(internal_id as usize);
                let exact = dist_fn(query, vec);
                (self.index.doc_ids[internal_id as usize], exact)
            })
            .collect();
        reranked.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        reranked.truncate(k);
        Ok(reranked)
    }

    /// Runs [`Self::search_reranked`] for every query in parallel (rayon).
    ///
    /// # Errors
    /// Returns the first error produced by any query.
    #[cfg(feature = "parallel")]
    pub fn search_reranked_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
        ef: usize,
        rerank_pool: usize,
    ) -> Result<Vec<Vec<(u32, f32)>>, RetrieveError> {
        use rayon::prelude::*;
        queries
            .par_iter()
            .map(|q| self.search_reranked(q, k, ef, rerank_pool))
            .collect()
    }

    /// Number of vectors in the underlying index.
    pub fn len(&self) -> usize {
        self.index.num_vectors
    }

    /// `true` when the underlying index holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.index.num_vectors == 0
    }

    /// Borrow of the underlying full-precision graph index.
    pub fn inner(&self) -> &HNSWIndex {
        &self.index
    }

    /// Bytes used by the SQ8 codes alone (one byte per dimension per vector).
    pub fn code_memory(&self) -> usize {
        self.codes.len()
    }

    /// Memory held by the full-precision vectors and the SQ8 codes.
    ///
    /// Per-dimension quantizer parameters and graph links are not counted.
    pub fn memory_report(&self) -> SQ8MemoryReport {
        let n = self.index.num_vectors;
        let dim = self.index.dimension;
        let vectors_bytes = n * dim * std::mem::size_of::<f32>();
        let codes_bytes = self.codes.len();
        SQ8MemoryReport {
            vectors_bytes,
            codes_bytes,
            total_bytes: vectors_bytes + codes_bytes,
            n,
            dim,
        }
    }

    /// Validates that a search can run against the current state.
    ///
    /// # Errors
    /// - `InvalidParameter` if the index has not been built (or a rebuild failed);
    /// - `DimensionMismatch` if `query` has the wrong length;
    /// - `EmptyIndex` if there are no vectors;
    /// - `InvalidParameter` if the SQ8 codes no longer cover every stored vector
    ///   (vectors were added after the last build). Searching in that state could
    ///   slice past the end of `codes`.
    fn check_ready(&self, query: &[f32]) -> Result<(), RetrieveError> {
        if !self.built {
            return Err(RetrieveError::InvalidParameter(
                "index must be built before search".into(),
            ));
        }
        if query.len() != self.index.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.index.dimension,
            });
        }
        if self.index.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        if self.codes.len() != self.index.num_vectors * self.index.dimension {
            return Err(RetrieveError::InvalidParameter(
                "quantized codes are stale: rebuild the index after adding vectors".into(),
            ));
        }
        Ok(())
    }

    /// Learns per-dimension min/max over all stored vectors and encodes each
    /// component as `round((v - min) * 255 / (max - min))` clamped to `0..=255`.
    ///
    /// Dimensions whose range is <= 1e-10 keep step and scale at 0, so every code is 0
    /// and decodes back to that dimension's minimum.
    fn quantize_vectors(&mut self) -> Result<(), RetrieveError> {
        let dim = self.index.dimension;
        let n = self.index.num_vectors;
        let vectors = self.index.raw_vectors();

        // Per-dimension value range.
        let mut mins = vec![f32::INFINITY; dim];
        let mut maxs = vec![f32::NEG_INFINITY; dim];
        for i in 0..n {
            let v = &vectors[i * dim..(i + 1) * dim];
            for ((&val, lo), hi) in v.iter().zip(mins.iter_mut()).zip(maxs.iter_mut()) {
                if val < *lo {
                    *lo = val;
                }
                if val > *hi {
                    *hi = val;
                }
            }
        }

        // Affine quantizer parameters; degenerate ranges stay at 0.
        let mut steps = vec![0.0f32; dim];
        let mut inv_scales = vec![0.0f32; dim];
        for (((step, inv), &lo), &hi) in steps
            .iter_mut()
            .zip(inv_scales.iter_mut())
            .zip(&mins)
            .zip(&maxs)
        {
            let range = hi - lo;
            if range > 1e-10 {
                *step = range / 255.0;
                *inv = 255.0 / range;
            }
        }

        // Encode: `+ 0.5` then truncation rounds to the nearest level.
        let mut codes = vec![0u8; n * dim];
        for i in 0..n {
            let v = &vectors[i * dim..(i + 1) * dim];
            let c = &mut codes[i * dim..(i + 1) * dim];
            for (((out, &val), &lo), &inv) in c.iter_mut().zip(v).zip(&mins).zip(&inv_scales) {
                let q = ((val - lo) * inv + 0.5) as i32;
                *out = q.clamp(0, 255) as u8;
            }
        }

        self.codes = codes;
        self.mins = mins;
        self.steps = steps;
        self.inv_scales = inv_scales;
        Ok(())
    }

    /// Squared L2 distance between `query` and the vector decoded from `code`
    /// (`mins[d] + code[d] * steps[d]` per dimension). No square root is taken.
    ///
    /// Manually unrolled by 4; callers pass slices of equal length.
    #[inline]
    fn approx_dist(query: &[f32], code: &[u8], mins: &[f32], steps: &[f32]) -> f32 {
        debug_assert_eq!(query.len(), code.len());
        let mut sum = 0.0f32;
        let chunks = query.len() / 4;
        let remainder = query.len() % 4;
        for i in 0..chunks {
            let base = i * 4;
            let d0 = query[base] - (mins[base] + code[base] as f32 * steps[base]);
            let d1 = query[base + 1] - (mins[base + 1] + code[base + 1] as f32 * steps[base + 1]);
            let d2 = query[base + 2] - (mins[base + 2] + code[base + 2] as f32 * steps[base + 2]);
            let d3 = query[base + 3] - (mins[base + 3] + code[base + 3] as f32 * steps[base + 3]);
            sum += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
        }
        let base = chunks * 4;
        for i in 0..remainder {
            let d = query[base + i] - (mins[base + i] + code[base + i] as f32 * steps[base + i]);
            sum += d * d;
        }
        sum
    }

    /// Standard HNSW descent scored with [`Self::approx_dist`].
    ///
    /// Greedy hill-climbing runs through the upper layers from the entry point, then
    /// `greedy_search_layer_custom` runs with beam width `ef` on layer 0. Returns
    /// `(internal_id, approx_distance)` pairs in whatever order that helper produces.
    /// Callers sort the pairs themselves.
    fn search_quantized(&self, query: &[f32], ef: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        // Without any layer there is no graph to search.
        if self.index.layers.is_empty() {
            return Ok(Vec::new());
        }
        let codes = &self.codes;
        let dim = self.index.dimension;
        let mins = &self.mins;
        let steps = &self.steps;

        let (entry_point, entry_layer) = self.index.entry_point().unwrap_or((0, 0));
        let mut current = entry_point;
        let code_slice = &codes[current as usize * dim..(current as usize + 1) * dim];
        let mut current_dist = Self::approx_dist(query, code_slice, mins, steps);

        // Greedy descent: on each upper layer move to any closer neighbour until none improves.
        for layer_idx in (1..=entry_layer).rev() {
            if layer_idx >= self.index.layers.len() {
                continue;
            }
            let layer = &self.index.layers[layer_idx];
            let mut changed = true;
            while changed {
                changed = false;
                let neighbors = layer.get_neighbors(current);
                for &neighbor_id in neighbors.iter() {
                    let ncode =
                        &codes[neighbor_id as usize * dim..(neighbor_id as usize + 1) * dim];
                    let dist = Self::approx_dist(query, ncode, mins, steps);
                    if dist < current_dist {
                        current_dist = dist;
                        current = neighbor_id;
                        changed = true;
                    }
                }
            }
        }

        // Beam search on the base layer, scoring nodes by their SQ8 codes.
        let base_layer = &self.index.layers[0];
        let dist_fn = |_q: &[f32], node_id: u32| -> f32 {
            let offset = node_id as usize * dim;
            let ncode = &codes[offset..offset + dim];
            Self::approx_dist(query, ncode, mins, steps)
        };
        Ok(crate::hnsw::search::greedy_search_layer_custom(
            query,
            current,
            base_layer,
            &self.index.vectors,
            self.index.dimension,
            ef,
            &dist_fn,
        ))
    }
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// `n` seeded random vectors with components in [-1, 1).
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() * 2.0 - 1.0).collect())
            .collect()
    }

    /// `n` seeded random vectors passed through `crate::distance::normalize`.
    fn random_normalized(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| {
                let v: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
                crate::distance::normalize(&v)
            })
            .collect()
    }

    /// Default HNSW parameters with the L2 metric.
    fn l2_params() -> crate::hnsw::HNSWParams {
        crate::hnsw::HNSWParams {
            metric: crate::distance::DistanceMetric::L2,
            ..Default::default()
        }
    }

    /// Inserts `vecs` into `index` (doc id = position in `vecs`) and builds it.
    fn fill_and_build(mut index: HNSWSq8Index, vecs: &[Vec<f32>]) -> HNSWSq8Index {
        for (i, v) in vecs.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        index.build().unwrap();
        index
    }

    /// Convenience: empty L2 index of dimension `dim`, filled with `vecs` and built.
    fn build_l2(dim: usize, vecs: &[Vec<f32>]) -> HNSWSq8Index {
        fill_and_build(HNSWSq8Index::with_params(dim, l2_params()).unwrap(), vecs)
    }

    #[test]
    fn sq8_encode_decode_roundtrip() {
        let dim = 128;
        let n = 100;
        let vecs = random_vectors(n, dim, 42);
        let index = build_l2(dim, &vecs);
        let vectors = index.index.raw_vectors();
        for i in 0..n {
            let v = &vectors[i * dim..(i + 1) * dim];
            let c = &index.codes[i * dim..(i + 1) * dim];
            for d in 0..dim {
                let decoded = index.mins[d] + c[d] as f32 * index.steps[d];
                // Round-to-nearest encoding bounds the error by half a quantization step;
                // the small slack absorbs f32 rounding in encode/decode.
                let max_err = 0.5 * index.steps[d];
                let err = (decoded - v[d]).abs();
                assert!(
                    err <= max_err + 1e-5,
                    "vec {i} dim {d}: decoded={decoded} vs original={}, err={err}, max_err={max_err}",
                    v[d],
                );
            }
        }
    }

    #[test]
    fn sq8_self_retrieval() {
        let dim = 64;
        let vecs = random_vectors(200, dim, 99);
        let index = build_l2(dim, &vecs);
        let results = index.search(&vecs[0], 5, 64).unwrap();
        assert!(
            results.iter().any(|(id, _)| *id == 0),
            "Query vector should be in its own top-5 SQ8 results"
        );
    }

    #[test]
    fn sq8_reranked_recall() {
        let dim = 32;
        let k = 10;
        let vecs = random_vectors(500, dim, 42);
        let index = build_l2(dim, &vecs);
        let query = &vecs[0];

        // Brute-force ground truth.
        let mut gt: Vec<(u32, f32)> = vecs
            .iter()
            .enumerate()
            .map(|(i, v)| (i as u32, crate::distance::l2_distance(query, v)))
            .collect();
        gt.sort_by(|a, b| a.1.total_cmp(&b.1));
        let gt_ids: std::collections::HashSet<u32> =
            gt.iter().take(k).map(|(id, _)| *id).collect();

        let results = index.search_reranked(query, k, 64, 50).unwrap();
        let result_ids: std::collections::HashSet<u32> =
            results.iter().map(|(id, _)| *id).collect();
        let recall = gt_ids.intersection(&result_ids).count() as f32 / k as f32;
        assert!(
            recall >= 0.70,
            "SQ8 reranked recall@{k} = {recall:.3}, expected >= 0.70"
        );
    }

    #[test]
    fn sq8_with_l2_params() {
        let dim = 32;
        let vecs = random_vectors(200, dim, 55);
        let index = build_l2(dim, &vecs);
        let results = index.search_reranked(&vecs[0], 5, 64, 50).unwrap();
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 < 1e-3, "self-distance should be near 0");
    }

    #[test]
    fn sq8_compression_ratio() {
        let dim = 128;
        let n = 1000;
        let vecs = random_vectors(n, dim, 42);
        let index = build_l2(dim, &vecs);
        let float_bytes = n * dim * 4;
        let code_bytes = index.code_memory();
        let ratio = float_bytes as f64 / code_bytes as f64;
        assert!(
            ratio > 3.8 && ratio < 4.2,
            "expected ~4x compression, got {ratio:.1}x"
        );
    }

    #[test]
    fn sq8_approx_dist_correlates_with_exact() {
        let dim = 128;
        let n = 100;
        let vecs = random_vectors(n, dim, 77);
        let index = build_l2(dim, &vecs);
        let query = &vecs[0];
        let vectors = index.index.raw_vectors();

        let mut exact_dists: Vec<(usize, f32)> = (0..n)
            .map(|i| {
                let v = &vectors[i * dim..(i + 1) * dim];
                (i, crate::distance::l2_distance(query, v))
            })
            .collect();
        let mut approx_dists: Vec<(usize, f32)> = (0..n)
            .map(|i| {
                let code = &index.codes[i * dim..(i + 1) * dim];
                (
                    i,
                    HNSWSq8Index::approx_dist(query, code, &index.mins, &index.steps),
                )
            })
            .collect();
        exact_dists.sort_by(|a, b| a.1.total_cmp(&b.1));
        approx_dists.sort_by(|a, b| a.1.total_cmp(&b.1));

        let exact_top10: std::collections::HashSet<usize> =
            exact_dists.iter().take(10).map(|(i, _)| *i).collect();
        let approx_top20: std::collections::HashSet<usize> =
            approx_dists.iter().take(20).map(|(i, _)| *i).collect();
        let overlap = exact_top10.intersection(&approx_top20).count();
        assert!(
            overlap >= 8,
            "SQ8 approx should preserve ranking: {overlap}/10 of exact top-10 in approx top-20"
        );
    }

    #[test]
    fn sq8_cosine_metric() {
        let dim = 32;
        let vecs = random_normalized(200, dim, 42);
        let index = fill_and_build(HNSWSq8Index::new(dim, 16, 32).unwrap(), &vecs);
        let results = index.search_reranked(&vecs[0], 5, 64, 50).unwrap();
        assert_eq!(results[0].0, 0, "self should be closest");
    }
}