use crate::hnsw::graph::HNSWIndex;
use crate::RetrieveError;
/// HNSW index paired with 4-bit scalar-quantized (SQ4) copies of its vectors.
///
/// Vectors are added to the wrapped [`HNSWIndex`]; `build` then derives
/// per-dimension quantization parameters and packs every stored vector into
/// 4-bit codes. Graph traversal at query time scores nodes against those codes
/// instead of the full-precision vectors.
pub struct HNSWSq4Index {
/// Wrapped full-precision HNSW index (graph layers, raw vectors, doc ids).
index: HNSWIndex,
/// One packed code per stored vector, in the same row order as the inner
/// index's raw vectors. Each code is `dimension.div_ceil(2)` bytes; the
/// distance routine reads the low nibble of byte `p` as dimension `2p` and
/// the high nibble as dimension `2p + 1`.
codes: Vec<Vec<u8>>,
/// Per-dimension minimum over all stored vectors (the value decoded from code 0).
mins: Vec<f32>,
/// Per-dimension decode step: `(max - min) / 15`, or `0.0` when the range
/// is not greater than `1e-10`.
steps: Vec<f32>,
/// Per-dimension encode scale: `15 / (max - min)`, or `0.0` when the range
/// is not greater than `1e-10`.
inv_scales: Vec<f32>,
/// Set to `true` once `build` has completed; searches are rejected before that.
built: bool,
}
impl HNSWSq4Index {
/// Creates an empty, unbuilt index.
///
/// `dimension`, `m` and `m_max` are forwarded unchanged to
/// [`HNSWIndex::new`]. All quantization state starts empty and is only
/// populated by `build`.
///
/// # Errors
///
/// Propagates any error returned by [`HNSWIndex::new`].
pub fn new(dimension: usize, m: usize, m_max: usize) -> Result<Self, RetrieveError> {
    let graph = HNSWIndex::new(dimension, m, m_max)?;
    let unbuilt = Self {
        index: graph,
        codes: vec![],
        mins: vec![],
        steps: vec![],
        inv_scales: vec![],
        built: false,
    };
    Ok(unbuilt)
}
/// Creates an empty, unbuilt index from an explicit HNSW parameter set.
///
/// `dimension` and `params` are forwarded unchanged to
/// [`HNSWIndex::with_params`]. All quantization state starts empty and is
/// only populated by `build`.
///
/// # Errors
///
/// Propagates any error returned by [`HNSWIndex::with_params`].
pub fn with_params(
    dimension: usize,
    params: crate::hnsw::HNSWParams,
) -> Result<Self, RetrieveError> {
    let graph = HNSWIndex::with_params(dimension, params)?;
    let unbuilt = Self {
        index: graph,
        codes: vec![],
        mins: vec![],
        steps: vec![],
        inv_scales: vec![],
        built: false,
    };
    Ok(unbuilt)
}
/// Adds `vector` under `doc_id` by delegating to the inner index.
///
/// No quantization happens here; codes are produced for all stored vectors
/// when `build` runs.
///
/// # Errors
///
/// Returns whatever error the inner index's `add_slice` reports.
// NOTE(review): `built` is not reset by this call, and codes are only
// regenerated in `build`. Whether the inner index accepts additions after it
// has been built is not visible from this file — confirm.
pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
self.index.add_slice(doc_id, vector)
}
/// Builds the inner HNSW graph, then quantizes every stored vector.
///
/// The index is marked as built only after both steps succeed.
///
/// # Errors
///
/// Propagates a failure from the inner index's `build` or from quantization;
/// in either case the index is left unmarked.
pub fn build(&mut self) -> Result<(), RetrieveError> {
    self.index.build()?;
    self.quantize_vectors().map(|()| {
        // Only reached when quantization succeeded.
        self.built = true;
    })
}
/// Approximate k-nearest-neighbour search scored entirely on the 4-bit codes.
///
/// Returns up to `k` `(doc_id, approximate_distance)` pairs sorted by
/// ascending approximate distance. The distance is the sum of squared
/// per-dimension differences between `query` and the dequantized codes.
///
/// `ef` is the beam width handed to the base-layer search. It is raised to
/// at least `k` (mirroring how `search_reranked` widens its beam), so a
/// too-small `ef` cannot by itself cap the result count below `k`.
///
/// # Errors
///
/// Returns an error when the index has not been built, when `query` has the
/// wrong dimension, or when the index is empty (see `check_ready`), and
/// propagates any error from the quantized graph search.
pub fn search(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
    self.check_ready(query)?;
    if k == 0 {
        // Nothing was requested; skip the traversal after validating the call.
        return Ok(Vec::new());
    }

    let mut candidates = self.search_quantized(query, ef.max(k))?;
    // Order before truncating so the k best candidates are kept regardless
    // of the order in which the layer search hands them back.
    candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    candidates.truncate(k);

    // Translate internal node ids into caller-facing document ids.
    Ok(candidates
        .into_iter()
        .map(|(id, dist)| (self.index.doc_ids[id as usize], dist))
        .collect())
}
/// Quantized search followed by exact re-scoring of a candidate shortlist.
///
/// The shortlist holds `max(rerank_pool, k)` candidates and the graph is
/// searched with a beam of at least that size. Each shortlisted candidate is
/// re-scored with the inner index's distance function against its stored
/// full-precision vector; the results are sorted ascending by that exact
/// distance and cut to `k`.
///
/// Returns `(doc_id, exact_distance)` pairs.
///
/// # Errors
///
/// Same conditions as `check_ready`, plus any error from the quantized
/// graph search.
pub fn search_reranked(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    rerank_pool: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
    self.check_ready(query)?;

    let shortlist_len = std::cmp::max(rerank_pool, k);
    let beam = std::cmp::max(ef, shortlist_len);
    let approx_hits = self.search_quantized(query, beam)?;
    let exact_distance = self.index.dist_fn();

    let mut rescored: Vec<(u32, f32)> =
        Vec::with_capacity(approx_hits.len().min(shortlist_len));
    for (node, _approx) in approx_hits.into_iter().take(shortlist_len) {
        let slot = node as usize;
        let stored = self.index.get_vector(slot);
        let exact = exact_distance(query, stored);
        rescored.push((self.index.doc_ids[slot], exact));
    }

    rescored.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    rescored.truncate(k);
    Ok(rescored)
}
/// Number of vectors held by the inner index (its `num_vectors` counter).
pub fn len(&self) -> usize {
self.index.num_vectors
}
/// Returns `true` when the inner index holds no vectors.
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Borrows the wrapped full-precision [`HNSWIndex`].
pub fn inner(&self) -> &HNSWIndex {
&self.index
}
/// Validates that a search with `query` can proceed.
///
/// Checks are made in this order, and the first failure is returned:
/// 1. the index has been built (`InvalidParameter`),
/// 2. `query.len()` equals the index dimension (`DimensionMismatch`),
/// 3. the index holds at least one vector (`EmptyIndex`).
fn check_ready(&self, query: &[f32]) -> Result<(), RetrieveError> {
    let expected_dim = self.index.dimension;
    if !self.built {
        Err(RetrieveError::InvalidParameter(
            "index must be built before search".into(),
        ))
    } else if query.len() != expected_dim {
        Err(RetrieveError::DimensionMismatch {
            query_dim: query.len(),
            doc_dim: expected_dim,
        })
    } else if self.index.num_vectors == 0 {
        Err(RetrieveError::EmptyIndex)
    } else {
        Ok(())
    }
}
/// Derives per-dimension quantization parameters and packs every stored
/// vector into a 4-bit code.
///
/// Pass 1 finds each dimension's minimum and maximum over all stored
/// vectors. From those, the decode step is `(max - min) / 15` and the encode
/// scale is `15 / (max - min)`; both stay `0.0` for a dimension whose range
/// is not greater than `1e-10`. Pass 2 packs each vector with
/// `crate::sq4::pack_vector` into `dimension.div_ceil(2)` bytes.
///
/// The visible code has no failing path; the `Result` return keeps the
/// signature uniform with `build`.
fn quantize_vectors(&mut self) -> Result<(), RetrieveError> {
    let dim = self.index.dimension;
    let count = self.index.num_vectors;
    let flat = self.index.raw_vectors();

    // Pass 1: per-dimension extrema across all rows of the flat buffer.
    let mut lows = vec![f32::INFINITY; dim];
    let mut highs = vec![f32::NEG_INFINITY; dim];
    for row in 0..count {
        let values = &flat[row * dim..(row + 1) * dim];
        for ((&value, low), high) in values.iter().zip(lows.iter_mut()).zip(highs.iter_mut()) {
            if value < *low {
                *low = value;
            }
            if value > *high {
                *high = value;
            }
        }
    }

    // Decode step and encode scale per dimension; degenerate ranges keep 0.0.
    let (steps, inv_scales): (Vec<f32>, Vec<f32>) = lows
        .iter()
        .zip(highs.iter())
        .map(|(&low, &high)| {
            let range = high - low;
            if range > 1e-10 {
                (range / 15.0, 15.0 / range)
            } else {
                (0.0, 0.0)
            }
        })
        .unzip();

    // Pass 2: pack each row through one reused scratch buffer, storing a copy.
    let mut scratch = vec![0u8; dim.div_ceil(2)];
    let codes: Vec<Vec<u8>> = (0..count)
        .map(|row| {
            let values = &flat[row * dim..(row + 1) * dim];
            crate::sq4::pack_vector(values, &lows, &inv_scales, &mut scratch);
            scratch.clone()
        })
        .collect();

    self.codes = codes;
    self.mins = lows;
    self.steps = steps;
    self.inv_scales = inv_scales;
    Ok(())
}
/// Builds the per-query distance lookup table.
///
/// The table has `dimension * 16` entries. Entry `d * 16 + c` holds the
/// squared difference between `query[d]` and the value that 4-bit code `c`
/// decodes to in dimension `d`, i.e. `mins[d] + c * steps[d]`.
#[inline]
#[allow(clippy::needless_range_loop)]
fn precompute_table(&self, query: &[f32]) -> Vec<f32> {
    let dim = self.index.dimension;
    let mut table = vec![0.0f32; dim * 16];
    // One 16-entry row per dimension.
    for (d, row) in table.chunks_exact_mut(16).enumerate() {
        let target = query[d];
        let floor = self.mins[d];
        let stride = self.steps[d];
        for (level, slot) in row.iter_mut().enumerate() {
            let reconstructed = floor + level as f32 * stride;
            let delta = target - reconstructed;
            *slot = delta * delta;
        }
    }
    table
}
/// Sums lookup-table entries for one packed code.
///
/// `table` is laid out as 16 entries per dimension. Byte `p` of `code`
/// carries dimension `2p` in its low nibble and dimension `2p + 1` in its
/// high nibble. When `dim` is odd, only the low nibble of the final byte is
/// used.
#[inline]
fn approx_dist_table(table: &[f32], code: &[u8], dim: usize) -> f32 {
    let full_bytes = dim / 2;
    // Each byte contributes its two nibble lookups as a single addend.
    let paired = (0..full_bytes).fold(0.0f32, |acc, byte_idx| {
        let packed = code[byte_idx];
        let even_row = byte_idx * 32;
        let odd_row = even_row + 16;
        let low = table[even_row + (packed & 0x0F) as usize];
        let high = table[odd_row + (packed >> 4) as usize];
        acc + (low + high)
    });

    if dim % 2 == 1 {
        // Trailing dimension lives in the low nibble of the last byte.
        let tail = (code[full_bytes] & 0x0F) as usize;
        paired + table[(dim - 1) * 16 + tail]
    } else {
        paired
    }
}
/// Traverses the HNSW graph using table-based approximate distances.
///
/// Starting from the inner index's entry point (falling back to node 0 on
/// layer 0 when none is reported), each layer from the entry layer down to
/// layer 1 is descended greedily: the current node is repeatedly replaced by
/// a closer neighbour until a pass over the neighbour list finds none.
/// Layer 0 is then searched with `greedy_search_layer_custom` using beam
/// width `ef` and the same approximate distance.
///
/// Returns that helper's output unchanged; callers in this file treat the
/// first tuple element as an internal node id. Returns an empty list when
/// the index has no layers.
fn search_quantized(&self, query: &[f32], ef: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
    let lut = self.precompute_table(query);
    let packed = &self.codes;
    let dim = self.index.dimension;

    let (mut best, top_layer) = self.index.entry_point().unwrap_or((0, 0));
    let mut best_dist = Self::approx_dist_table(&lut, &packed[best as usize], dim);

    // Greedy descent through the upper layers, highest first.
    for level in (1..=top_layer).rev() {
        if level >= self.index.layers.len() {
            continue;
        }
        let layer = &self.index.layers[level];
        loop {
            let mut improved = false;
            for &candidate in layer.get_neighbors(best).iter() {
                let candidate_dist =
                    Self::approx_dist_table(&lut, &packed[candidate as usize], dim);
                if candidate_dist < best_dist {
                    best_dist = candidate_dist;
                    best = candidate;
                    improved = true;
                }
            }
            if !improved {
                break;
            }
        }
    }

    if self.index.layers.is_empty() {
        return Ok(Vec::new());
    }

    // Base layer: beam search scored by the lookup table; the query argument
    // of the callback is unused because the table already encodes it.
    let ground = &self.index.layers[0];
    let score = |_q: &[f32], node_id: u32| -> f32 {
        Self::approx_dist_table(&lut, &packed[node_id as usize], dim)
    };
    let hits = crate::hnsw::search::greedy_search_layer_custom(
        query,
        best,
        ground,
        &self.index.vectors,
        self.index.dimension,
        ef,
        &score,
    );
    Ok(hits)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use std::collections::HashSet;

    /// Produces `n` seeded pseudo-random `dim`-dimensional vectors, each
    /// passed through `crate::distance::normalize`.
    fn random_normalized(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut out = Vec::with_capacity(n);
        for _ in 0..n {
            let raw: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
            out.push(crate::distance::normalize(&raw));
        }
        out
    }

    /// Builds an index (m = 16, m_max = 32) over `vecs`, using each vector's
    /// position as its document id.
    fn build_index(dim: usize, vecs: &[Vec<f32>]) -> HNSWSq4Index {
        let mut index = HNSWSq4Index::new(dim, 16, 32).unwrap();
        for (i, v) in vecs.iter().enumerate() {
            index.add_slice(i as u32, v).unwrap();
        }
        index.build().unwrap();
        index
    }

    /// Reranked search must recover at least half of the exact top-k.
    #[test]
    fn sq4u_search_reranked_recall() {
        let dim = 32;
        let n = 300;
        let k = 10;
        let vecs = random_normalized(n, dim, 42);
        let index = build_index(dim, &vecs);
        let query = &vecs[0];

        // Brute-force ground truth over the full-precision vectors.
        let mut gt: Vec<(u32, f32)> = Vec::with_capacity(vecs.len());
        for (i, v) in vecs.iter().enumerate() {
            let exact = crate::distance::cosine_distance_normalized(query, v);
            gt.push((i as u32, exact));
        }
        gt.sort_by(|a, b| a.1.total_cmp(&b.1));
        let gt_ids: HashSet<u32> = gt.iter().take(k).map(|(id, _)| *id).collect();

        let results = index.search_reranked(query, k, 64, 50).unwrap();
        let result_ids: HashSet<u32> = results.iter().map(|(id, _)| *id).collect();

        let recall = gt_ids.intersection(&result_ids).count() as f32 / k as f32;
        assert!(
            recall >= 0.50,
            "SQ4U reranked recall@{k} = {recall:.3}, expected >= 0.50"
        );
    }

    /// Searching with an indexed vector must return that vector in the top 5.
    #[test]
    fn sq4u_approx_closer_than_random() {
        let dim = 64;
        let n = 100;
        let vecs = random_normalized(n, dim, 99);
        let index = build_index(dim, &vecs);
        let query = &vecs[0];

        let results = index.search(query, 5, 64).unwrap();
        let found_self = results.iter().any(|(id, _)| *id == 0);
        assert!(
            found_self,
            "Query vector should be in its own top-5 SQ4U results"
        );
    }
}