use anyhow::{anyhow, Result};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScaNNConfig {
pub num_leaves: usize,
pub num_leaves_to_search: usize,
pub quantization_bits: u8,
pub rerank_k: usize,
pub dimensions_per_block: usize,
}
impl Default for ScaNNConfig {
fn default() -> Self {
Self {
num_leaves: 1000,
num_leaves_to_search: 100,
quantization_bits: 4,
rerank_k: 100,
dimensions_per_block: 2,
}
}
}
#[derive(Clone, Serialize, Deserialize)]
struct TreeNode {
center: Vec<f32>,
vectors: Vec<(String, Vec<u8>, Vec<f32>)>,
}
impl TreeNode {
fn new(center: Vec<f32>) -> Self {
Self {
center,
vectors: Vec::new(),
}
}
}
#[derive(Clone, Serialize, Deserialize)]
struct AnisotropicQuantizer {
dimension_weights: Vec<f32>,
bits: u8,
levels: usize,
min_vals: Vec<f32>,
max_vals: Vec<f32>,
}
impl AnisotropicQuantizer {
fn train(vectors: &[Vec<f32>], bits: u8, dimensions_per_block: usize) -> Result<Self> {
if vectors.is_empty() {
return Err(anyhow!("Need training vectors"));
}
let dimension = vectors[0].len();
let levels = 1 << bits;
let mut dimension_weights = vec![0.0; dimension];
let mut means = vec![0.0; dimension];
for vector in vectors {
for (i, &val) in vector.iter().enumerate() {
means[i] += val;
}
}
for mean in &mut means {
*mean /= vectors.len() as f32;
}
for vector in vectors {
for (i, &val) in vector.iter().enumerate() {
let diff = val - means[i];
dimension_weights[i] += diff * diff;
}
}
let total_weight: f32 = dimension_weights.iter().sum();
if total_weight > 0.0 {
for weight in &mut dimension_weights {
*weight = (*weight / total_weight).sqrt();
}
} else {
dimension_weights.fill(1.0);
}
let mut min_vals = vec![f32::INFINITY; dimension];
let mut max_vals = vec![f32::NEG_INFINITY; dimension];
for vector in vectors {
for (i, &val) in vector.iter().enumerate() {
let weighted_val = val * dimension_weights[i];
min_vals[i] = min_vals[i].min(weighted_val);
max_vals[i] = max_vals[i].max(weighted_val);
}
}
Ok(Self {
dimension_weights,
bits,
levels,
min_vals,
max_vals,
})
}
fn quantize(&self, vector: &[f32]) -> Vec<u8> {
let mut quantized = Vec::with_capacity(vector.len());
for (i, &val) in vector.iter().enumerate() {
let weighted_val = val * self.dimension_weights[i];
let range = self.max_vals[i] - self.min_vals[i];
let normalized = if range > 0.0 {
((weighted_val - self.min_vals[i]) / range).clamp(0.0, 1.0)
} else {
0.5
};
let quantized_val = (normalized * (self.levels - 1) as f32).round() as u8;
quantized.push(quantized_val);
}
quantized
}
fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
let mut vector = Vec::with_capacity(quantized.len());
for (i, &q) in quantized.iter().enumerate() {
let normalized = q as f32 / (self.levels - 1) as f32;
let range = self.max_vals[i] - self.min_vals[i];
let weighted_val = normalized * range + self.min_vals[i];
let original_val = if self.dimension_weights[i] > 0.0 {
weighted_val / self.dimension_weights[i]
} else {
weighted_val
};
vector.push(original_val);
}
vector
}
fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
let reconstructed = self.dequantize(quantized);
euclidean_distance(query, &reconstructed)
}
}
pub struct ScaNNIndex {
dimension: usize,
config: ScaNNConfig,
tree: Vec<TreeNode>,
quantizer: Option<AnisotropicQuantizer>,
is_trained: bool,
}
impl ScaNNIndex {
pub fn new(dimension: usize, config: ScaNNConfig) -> Result<Self> {
if config.quantization_bits != 2
&& config.quantization_bits != 4
&& config.quantization_bits != 8
{
return Err(anyhow!("Quantization bits must be 2, 4, or 8"));
}
Ok(Self {
dimension,
config,
tree: Vec::new(),
quantizer: None,
is_trained: false,
})
}
pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
if training_vectors.is_empty() {
return Err(anyhow!("Training data is empty"));
}
if training_vectors[0].len() != self.dimension {
return Err(anyhow!(
"Training vector dimension {} doesn't match index dimension {}",
training_vectors[0].len(),
self.dimension
));
}
tracing::info!(
"Training ScaNN index on {} vectors, dimension {}",
training_vectors.len(),
self.dimension
);
tracing::info!("Learning tree with {} leaves", self.config.num_leaves);
self.tree = self.learn_tree(training_vectors)?;
tracing::info!(
"Training anisotropic quantizer ({} bits)",
self.config.quantization_bits
);
self.quantizer = Some(AnisotropicQuantizer::train(
training_vectors,
self.config.quantization_bits,
self.config.dimensions_per_block,
)?);
self.is_trained = true;
tracing::info!("ScaNN training complete");
Ok(())
}
fn learn_tree(&self, vectors: &[Vec<f32>]) -> Result<Vec<TreeNode>> {
use crate::vectors::KMeans;
let kmeans = KMeans::new(self.config.num_leaves).with_max_iterations(25);
let (centers, _) = kmeans.fit(vectors)?;
let tree: Vec<TreeNode> = centers.into_iter().map(TreeNode::new).collect();
Ok(tree)
}
pub fn add(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
if !self.is_trained {
return Err(anyhow!("Index must be trained before adding vectors"));
}
if vector.len() != self.dimension {
return Err(anyhow!(
"Vector dimension {} doesn't match index dimension {}",
vector.len(),
self.dimension
));
}
let leaf_idx = self.find_nearest_leaf(&vector);
let quantizer = self.quantizer.as_ref().unwrap();
let quantized = quantizer.quantize(&vector);
self.tree[leaf_idx].vectors.push((id, quantized, vector));
Ok(())
}
pub fn add_batch(&mut self, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
if !self.is_trained {
return Err(anyhow!("Index must be trained before adding vectors"));
}
let quantizer = self.quantizer.as_ref().unwrap();
let processed: Vec<(usize, String, Vec<u8>, Vec<f32>)> = vectors
.par_iter()
.map(|(id, vector)| {
if vector.len() != self.dimension {
return Err(anyhow!(
"Vector dimension {} doesn't match index dimension {}",
vector.len(),
self.dimension
));
}
let leaf_idx = self.find_nearest_leaf(vector);
let quantized = quantizer.quantize(vector);
Ok((leaf_idx, id.clone(), quantized, vector.clone()))
})
.collect::<Result<Vec<_>>>()?;
for (leaf_idx, id, quantized, vector) in processed {
self.tree[leaf_idx].vectors.push((id, quantized, vector));
}
Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
if !self.is_trained {
return Err(anyhow!("Index must be trained before searching"));
}
if query.len() != self.dimension {
return Err(anyhow!(
"Query dimension {} doesn't match index dimension {}",
query.len(),
self.dimension
));
}
let quantizer = self.quantizer.as_ref().unwrap();
let leaves_to_search = self.find_nearest_leaves(query, self.config.num_leaves_to_search);
let mut candidates: Vec<(String, f32, Vec<f32>)> = Vec::new();
for &leaf_idx in &leaves_to_search {
for (id, quantized, original) in &self.tree[leaf_idx].vectors {
let distance = quantizer.asymmetric_distance(query, quantized);
candidates.push((id.clone(), distance, original.clone()));
}
}
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
candidates.truncate(self.config.rerank_k);
let mut results: Vec<(String, f32)> = candidates
.into_iter()
.map(|(id, _, original)| {
let exact_distance = euclidean_distance(query, &original);
(id, exact_distance)
})
.collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
results.truncate(k);
Ok(results)
}
fn find_nearest_leaf(&self, vector: &[f32]) -> usize {
let mut min_distance = f32::INFINITY;
let mut nearest_leaf = 0;
for (idx, node) in self.tree.iter().enumerate() {
let distance = euclidean_distance(vector, &node.center);
if distance < min_distance {
min_distance = distance;
nearest_leaf = idx;
}
}
nearest_leaf
}
fn find_nearest_leaves(&self, query: &[f32], k: usize) -> Vec<usize> {
let mut distances: Vec<(usize, f32)> = self
.tree
.iter()
.enumerate()
.map(|(idx, node)| (idx, euclidean_distance(query, &node.center)))
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
distances.into_iter().take(k).map(|(idx, _)| idx).collect()
}
pub fn stats(&self) -> ScaNNStats {
let total_vectors: usize = self.tree.iter().map(|node| node.vectors.len()).sum();
let avg_vectors_per_leaf = if !self.tree.is_empty() {
total_vectors as f32 / self.tree.len() as f32
} else {
0.0
};
let max_vectors_per_leaf = self
.tree
.iter()
.map(|node| node.vectors.len())
.max()
.unwrap_or(0);
let non_empty_leaves = self
.tree
.iter()
.filter(|node| !node.vectors.is_empty())
.count();
let bytes_per_vector = if self.config.quantization_bits == 8 {
self.dimension
} else if self.config.quantization_bits == 4 {
(self.dimension + 1) / 2
} else {
(self.dimension + 3) / 4
};
let quantized_memory = total_vectors * bytes_per_vector;
let original_memory = total_vectors * self.dimension * 4; let total_memory = quantized_memory + original_memory;
let compression_ratio = if total_memory > 0 {
(total_vectors * self.dimension * 4) as f32 / quantized_memory as f32
} else {
0.0
};
ScaNNStats {
num_vectors: total_vectors,
num_leaves: self.tree.len(),
avg_vectors_per_leaf,
max_vectors_per_leaf,
non_empty_leaves,
memory_bytes: total_memory,
compression_ratio,
is_trained: self.is_trained,
}
}
pub fn len(&self) -> usize {
self.tree.iter().map(|node| node.vectors.len()).sum()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn is_trained(&self) -> bool {
self.is_trained
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn config(&self) -> &ScaNNConfig {
&self.config
}
}
#[derive(Debug, Clone)]
pub struct ScaNNStats {
pub num_vectors: usize,
pub num_leaves: usize,
pub avg_vectors_per_leaf: f32,
pub max_vectors_per_leaf: usize,
pub non_empty_leaves: usize,
pub memory_bytes: usize,
pub compression_ratio: f32,
pub is_trained: bool,
}
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
use rand::Rng;
let mut rng = rand::thread_rng();
(0..n)
.map(|_| (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect())
.collect()
}
#[test]
fn test_scann_basic() {
let config = ScaNNConfig {
num_leaves: 10,
num_leaves_to_search: 5,
quantization_bits: 4,
rerank_k: 20,
dimensions_per_block: 2,
};
let mut index = ScaNNIndex::new(64, config).unwrap();
let training_vectors = generate_random_vectors(200, 64);
index.train(&training_vectors).unwrap();
assert!(index.is_trained());
for (i, vector) in training_vectors.iter().enumerate().take(100) {
index.add(format!("vec_{}", i), vector.clone()).unwrap();
}
assert_eq!(index.len(), 100);
let query = &training_vectors[0];
let results = index.search(query, 10).unwrap();
assert_eq!(results.len(), 10);
assert_eq!(results[0].0, "vec_0");
assert!(results[0].1 < 0.01);
}
#[test]
fn test_anisotropic_quantizer() {
let vectors = vec![
vec![0.1, 0.9, 0.2],
vec![0.2, 0.8, 0.3],
vec![0.15, 0.85, 0.25],
vec![0.12, 0.88, 0.22],
vec![0.18, 0.82, 0.28],
vec![0.11, 0.89, 0.21],
vec![0.19, 0.81, 0.29],
vec![0.13, 0.87, 0.23],
];
let quantizer = AnisotropicQuantizer::train(&vectors, 4, 2).unwrap();
let original = vec![0.15, 0.85, 0.25];
let quantized = quantizer.quantize(&original);
let reconstructed = quantizer.dequantize(&quantized);
for (orig, recon) in original.iter().zip(reconstructed.iter()) {
assert!(
(orig - recon).abs() < 0.5,
"orig: {}, recon: {}, diff: {}",
orig,
recon,
(orig - recon).abs()
);
}
for &q in &quantized {
assert!(q < 16); }
}
#[test]
fn test_scann_batch_add() {
let config = ScaNNConfig {
num_leaves: 20,
num_leaves_to_search: 10,
quantization_bits: 4,
rerank_k: 50,
dimensions_per_block: 2,
};
let mut index = ScaNNIndex::new(32, config).unwrap();
let training_vectors = generate_random_vectors(200, 32);
index.train(&training_vectors).unwrap();
let batch: Vec<(String, Vec<f32>)> = training_vectors
.iter()
.enumerate()
.take(100)
.map(|(i, v)| (format!("vec_{}", i), v.clone()))
.collect();
index.add_batch(batch).unwrap();
assert_eq!(index.len(), 100);
}
#[test]
fn test_scann_search_accuracy() {
let config = ScaNNConfig {
num_leaves: 20,
num_leaves_to_search: 10,
quantization_bits: 8,
rerank_k: 50,
dimensions_per_block: 2,
};
let mut index = ScaNNIndex::new(64, config).unwrap();
let training_vectors = generate_random_vectors(500, 64);
index.train(&training_vectors).unwrap();
for (i, vector) in training_vectors.iter().enumerate().take(300) {
index.add(format!("vec_{}", i), vector.clone()).unwrap();
}
let query = &training_vectors[5];
let results = index.search(query, 10).unwrap();
assert_eq!(results[0].0, "vec_5");
assert!(results[0].1 < 0.1);
}
#[test]
fn test_scann_stats() {
let config = ScaNNConfig {
num_leaves: 10,
num_leaves_to_search: 5,
quantization_bits: 4,
rerank_k: 20,
dimensions_per_block: 2,
};
let mut index = ScaNNIndex::new(32, config).unwrap();
let training_vectors = generate_random_vectors(200, 32);
index.train(&training_vectors).unwrap();
for (i, vector) in training_vectors.iter().enumerate().take(100) {
index.add(format!("vec_{}", i), vector.clone()).unwrap();
}
let stats = index.stats();
assert_eq!(stats.num_vectors, 100);
assert_eq!(stats.num_leaves, 10);
assert!(stats.is_trained);
assert!(stats.compression_ratio > 1.0);
assert!(stats.non_empty_leaves > 0);
}
#[test]
fn test_different_quantization_bits() {
for &bits in &[2, 4, 8] {
let config = ScaNNConfig {
num_leaves: 10,
num_leaves_to_search: 5,
quantization_bits: bits,
rerank_k: 20,
dimensions_per_block: 2,
};
let mut index = ScaNNIndex::new(32, config).unwrap();
let training_vectors = generate_random_vectors(100, 32);
index.train(&training_vectors).unwrap();
for (i, vector) in training_vectors.iter().enumerate().take(50) {
index.add(format!("vec_{}", i), vector.clone()).unwrap();
}
let query = &training_vectors[0];
let results = index.search(query, 5).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].0, "vec_0");
}
}
}