use crate::scann::partitioning::KMeans;
use crate::scann::quantization::AnisotropicQuantizer;
use crate::scann::reranking;
use crate::RetrieveError;
/// Approximate nearest-neighbour index combining k-means partitioning,
/// quantized residual scoring, and a final reranking pass.
///
/// Vectors are appended with `add` / `add_slice`; `build` then fits the
/// partitioner and quantizer. `search` is rejected until `build` has succeeded,
/// and additions are rejected afterwards.
#[derive(Debug)]
pub struct SCANNIndex {
/// All added vectors stored back to back; vector `i` occupies
/// `i * dimension .. (i + 1) * dimension`.
pub(crate) vectors: Vec<f32>,
/// Length every added vector must have (checked to be non-zero in `new`).
pub(crate) dimension: usize,
/// Number of vectors appended so far.
pub(crate) num_vectors: usize,
/// Settings read by `build` and `search`.
params: SCANNParams,
/// Set to `true` at the end of a successful `build`.
built: bool,
/// Per-partition membership lists and quantization codes, filled by `build`.
partitions: Vec<Partition>,
/// Centroids copied from the fitted `KMeans` during `build`; `search` scores
/// the query against each of them.
pub(crate) partition_centroids: Vec<Vec<f32>>,
/// Quantizer fitted on residuals during `build`; `None` until then.
quantizer: Option<AnisotropicQuantizer>,
}
/// Tunable settings for [`SCANNIndex`].
#[derive(Clone, Debug)]
pub struct SCANNParams {
/// Cluster count handed to `KMeans::new`; also the number of partition
/// slots allocated by `build`.
pub num_partitions: usize,
/// How many of the best-scoring partitions `search` scans (capped at
/// `num_partitions`).
pub nprobe: usize,
/// How many approximate candidates are kept for reranking; `search` uses
/// `max(num_reorder, k)`.
pub num_reorder: usize,
/// Codebook count handed to `AnisotropicQuantizer::new`; `search` reads this
/// many code bytes per stored vector.
pub num_codebooks: usize,
/// Codebook size handed to `AnisotropicQuantizer::new`.
/// NOTE(review): codes are stored as `u8`, so values above 256 presumably
/// cannot be represented — confirm against the quantizer.
pub codebook_size: usize,
/// Seed passed to both `KMeans` and `AnisotropicQuantizer`.
pub seed: u64,
}
impl Default for SCANNParams {
fn default() -> Self {
Self {
num_partitions: 256,
nprobe: 20,
num_reorder: 100,
num_codebooks: 16,
codebook_size: 256,
seed: 42,
}
}
}
/// Members of one k-means partition together with their quantization codes.
#[derive(Clone, Debug)]
struct Partition {
/// Positions (insertion order) of the vectors assigned to this partition,
/// as recorded by `build`.
vector_indices: Vec<u32>,
/// Quantizer output for every member, concatenated in the same order as
/// `vector_indices`; `search` reads `num_codebooks` bytes per member.
codes: Vec<u8>,
}
impl SCANNIndex {
/// Creates an empty, not-yet-built index for vectors of length `dimension`.
///
/// `params` is stored as given; this constructor does not inspect it.
///
/// # Errors
///
/// Returns `RetrieveError::InvalidParameter` when `dimension` is zero.
pub fn new(dimension: usize, params: SCANNParams) -> Result<Self, RetrieveError> {
    match dimension {
        0 => Err(RetrieveError::InvalidParameter(
            "dimension must be > 0".into(),
        )),
        _ => Ok(Self {
            dimension,
            params,
            num_vectors: 0,
            built: false,
            quantizer: None,
            vectors: Vec::new(),
            partitions: Vec::new(),
            partition_centroids: Vec::new(),
        }),
    }
}
/// Appends an owned vector to the index.
///
/// Convenience wrapper that borrows `vector` and forwards both arguments to
/// `add_slice`; it succeeds or fails exactly as that call does.
pub fn add(&mut self, _doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
    let values: &[f32] = vector.as_slice();
    self.add_slice(_doc_id, values)
}
/// Appends a copy of `vector` to the flat vector buffer.
///
/// `_doc_id` is not read by this method: the vector is identified only by
/// the order in which it was added.
///
/// # Errors
///
/// * `RetrieveError::InvalidParameter` if the index has already been built.
/// * `RetrieveError::DimensionMismatch` if `vector.len()` differs from the
///   index dimension.
pub fn add_slice(&mut self, _doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
    // A built index is frozen; this check takes precedence over the length check.
    if self.built {
        return Err(RetrieveError::InvalidParameter(
            "index already built".into(),
        ));
    }

    match vector.len() {
        len if len == self.dimension => {
            self.vectors.extend_from_slice(vector);
            self.num_vectors += 1;
            Ok(())
        }
        len => Err(RetrieveError::DimensionMismatch {
            query_dim: len,
            doc_dim: self.dimension,
        }),
    }
}
/// Fits the partitioner and quantizer and encodes every stored vector.
///
/// Steps:
/// 1. Fit `KMeans` with `num_partitions` clusters on all vectors and assign
///    each vector to a cluster.
/// 2. Record each vector's position in its partition and collect its residual
///    (vector minus the assigned centroid).
/// 3. Fit an `AnisotropicQuantizer` on all residuals.
/// 4. Quantize each vector's residual and store the codes per partition, in
///    the same order as the partition's member list.
///
/// Calling `build` on an already-built index is a no-op returning `Ok(())`.
/// The index is only modified once every step has succeeded, so a failed
/// build leaves it exactly as it was.
///
/// # Errors
///
/// * `RetrieveError::EmptyIndex` if no vectors have been added.
/// * `RetrieveError::InvalidParameter` if there are more vectors than can be
///   addressed with a `u32` position.
/// * Any error returned by `KMeans::new`, `KMeans::fit`,
///   `AnisotropicQuantizer::new` or `AnisotropicQuantizer::fit_residuals`.
pub fn build(&mut self) -> Result<(), RetrieveError> {
    if self.built {
        return Ok(());
    }
    if self.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }
    // Members are stored as `u32` positions; fail fast instead of letting the
    // cast below wrap around silently.
    if self.num_vectors - 1 > u32::MAX as usize {
        return Err(RetrieveError::InvalidParameter(
            "too many vectors: positions must fit in u32".into(),
        ));
    }

    let mut kmeans =
        KMeans::new(self.dimension, self.params.num_partitions)?.with_seed(self.params.seed);
    kmeans.fit(&self.vectors, self.num_vectors)?;
    let centroids: Vec<Vec<f32>> = kmeans.centroids().to_vec();
    let assignments = kmeans.assign_clusters(&self.vectors, self.num_vectors);

    let mut partitions = vec![
        Partition {
            vector_indices: Vec::new(),
            codes: Vec::new()
        };
        self.params.num_partitions
    ];

    // Residuals for all vectors, in insertion order, used to train the quantizer.
    let mut residuals = Vec::with_capacity(self.vectors.len());
    for (i, &partition_idx) in assignments.iter().enumerate() {
        // Lossless: the largest position was checked against `u32::MAX` above.
        partitions[partition_idx].vector_indices.push(i as u32);
        let centroid = &centroids[partition_idx];
        residuals.extend(
            self.get_vector(i)
                .iter()
                .zip(centroid.iter())
                .map(|(x, c)| x - c),
        );
    }

    let mut quantizer = AnisotropicQuantizer::new(
        self.dimension,
        self.params.num_codebooks,
        self.params.codebook_size,
        self.params.seed,
    )?;
    quantizer.fit_residuals(&residuals, self.num_vectors)?;

    let codes_per_vector = self.params.num_codebooks;
    // One scratch buffer reused for every vector instead of a fresh allocation each time.
    let mut residual: Vec<f32> = Vec::with_capacity(self.dimension);
    for (partition, centroid) in partitions.iter_mut().zip(centroids.iter()) {
        let mut all_codes = Vec::with_capacity(partition.vector_indices.len() * codes_per_vector);
        for &vec_idx in &partition.vector_indices {
            residual.clear();
            residual.extend(
                self.get_vector(vec_idx as usize)
                    .iter()
                    .zip(centroid.iter())
                    .map(|(x, c)| x - c),
            );
            all_codes.extend(quantizer.quantize(&residual));
        }
        partition.codes = all_codes;
    }

    // Commit only now, so an early `?` return cannot leave partial state behind.
    self.partition_centroids = centroids;
    self.partitions = partitions;
    self.quantizer = Some(quantizer);
    self.built = true;
    Ok(())
}
/// Searches the built index for the stored vectors scoring highest against `query`.
///
/// The query is scored against every partition centroid with a dot product,
/// and the `nprobe` best partitions (capped at `num_partitions`) are scanned.
/// Each member gets the approximate score
/// `centroid score + sum of the lookup-table entries selected by its codes`.
/// The best `max(num_reorder, k)` candidates are passed to `reranking::rerank`
/// together with the original vectors, and its output is returned.
///
/// Candidate ids are the positions recorded by `build` (insertion order).
/// NOTE(review): presumably `rerank` returns at most `k` `(id, score)` pairs
/// with exact scores, best first — confirm against `reranking::rerank`.
///
/// # Errors
///
/// * `RetrieveError::InvalidParameter` if the index has not been built, or if
///   the quantizer is unexpectedly missing.
/// * `RetrieveError::DimensionMismatch` if `query.len()` differs from the
///   index dimension.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
    if !self.built {
        return Err(RetrieveError::InvalidParameter(
            "index must be built before search".into(),
        ));
    }
    // Same validation `add_slice` applies to stored vectors; without it a
    // wrong-length query would flow straight into the dot products and the
    // lookup-table builder.
    if query.len() != self.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: query.len(),
            doc_dim: self.dimension,
        });
    }
    // `ok_or_else` so the error string is only allocated on the failure path.
    let quantizer = self.quantizer.as_ref().ok_or_else(|| {
        RetrieveError::InvalidParameter("quantizer not initialized".into())
    })?;

    // Rank partitions by centroid score, best first.
    let mut partition_scores: Vec<(usize, f32)> = self
        .partition_centroids
        .iter()
        .enumerate()
        .map(|(idx, c)| (idx, crate::simd::dot(query, c)))
        .collect();
    partition_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    let num_probe = self.params.nprobe.min(self.params.num_partitions);

    let lut = quantizer.build_lut(query);
    let m = self.params.num_codebooks;

    let mut candidates: Vec<(u32, f32)> = Vec::new();
    for &(p_idx, center_score) in partition_scores.iter().take(num_probe) {
        let partition = &self.partitions[p_idx];
        candidates.reserve(partition.vector_indices.len());
        for (i, &vec_idx) in partition.vector_indices.iter().enumerate() {
            // Member `i` owns the `m` code bytes starting at `i * m`.
            let codes = &partition.codes[i * m..(i + 1) * m];
            let mut residual_score = 0.0_f32;
            for (subspace_idx, &code) in codes.iter().enumerate() {
                residual_score += lut[subspace_idx][code as usize];
            }
            candidates.push((vec_idx, center_score + residual_score));
        }
    }

    // Keep only the best approximate candidates, in place, for exact reranking.
    candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    candidates.truncate(self.params.num_reorder.max(k));

    Ok(reranking::rerank(
        query,
        &candidates,
        &self.vectors,
        self.dimension,
        k,
    ))
}
/// Borrows the `idx`-th stored vector out of the flat buffer.
///
/// Panics (slice indexing) if the vector's range lies outside `self.vectors`.
fn get_vector(&self, idx: usize) -> &[f32] {
    let dim = self.dimension;
    let first = idx * dim;
    let past_last = first + dim;
    &self.vectors[first..past_last]
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::RetrieveError;

    /// Small configuration shared by the tests: two partitions (both probed),
    /// two codebooks of 256 entries, ten rerank candidates, seed 42.
    fn tiny_params() -> SCANNParams {
        SCANNParams {
            num_partitions: 2,
            nprobe: 2,
            num_reorder: 10,
            num_codebooks: 2,
            codebook_size: 256,
            seed: 42,
        }
    }

    /// A freshly created index reports the requested dimension and no vectors.
    #[test]
    fn test_create_index() {
        let created = SCANNIndex::new(4, tiny_params());
        assert!(created.is_ok());

        let index = created.unwrap();
        assert_eq!(index.dimension, 4);
        assert_eq!(index.num_vectors, 0);
    }

    /// Adding twenty vectors, building, and searching yields between one and
    /// three results for `k = 3`.
    #[test]
    fn test_add_and_search() {
        let mut index = SCANNIndex::new(4, tiny_params()).unwrap();
        for doc_id in 0..20u32 {
            let x = doc_id as f32;
            index.add(doc_id, vec![x, x * 0.5, 1.0, 0.0]).unwrap();
        }
        index.build().unwrap();

        let query = [0.0, 0.0, 1.0, 0.0];
        let results = index.search(&query, 3).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
    }

    /// A zero dimension is rejected with `InvalidParameter`.
    #[test]
    fn test_zero_dimension_error() {
        let outcome = SCANNIndex::new(0, SCANNParams::default());
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(
            matches!(err, RetrieveError::InvalidParameter(_)),
            "Expected InvalidParameter, got {:?}",
            err
        );
    }
}