// Submodules declared by this module.
pub mod codec;
pub mod embedder;
// `falsification` is gated on `cfg(test)`, so it is only compiled for test builds.
#[cfg(test)]
pub mod falsification;
pub mod index;
pub mod search;
pub mod types;
// Re-exports: selected items from the submodules above are surfaced here so
// callers can name them from this module without spelling out the submodule path.
pub use codec::{ResidualCodec, ResidualCodecBuilder};
pub use embedder::{MockMultiVectorEmbedder, MultiVectorEmbedder};
pub use index::WarpIndex;
pub use search::{exact_maxsim, CandidateScorer, CentroidSelector, ScoreMerger};
pub use types::{MultiVectorEmbedding, WarpIndexConfig, WarpSearchConfig};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Chunk, DocumentId};

    /// End-to-end smoke test: train, insert, build and search an index, then
    /// check the result count, the score ordering, and that every hit can be
    /// resolved back to a stored chunk.
    #[test]
    fn test_full_pipeline() {
        let embedder = MockMultiVectorEmbedder::new(32, 128);
        let index_config = WarpIndexConfig::new(2, 4, 32).with_kmeans_iterations(5);
        let mut index = WarpIndex::new(index_config);

        let corpus = [
            "machine learning algorithms are powerful tools for data science",
            "deep neural networks have revolutionized computer vision tasks",
            "natural language processing enables machines to understand text",
            "computer vision systems detect objects in images and video",
            "reinforcement learning agents learn through trial and error",
            "transformer architectures power modern large language models",
            "attention mechanisms allow models to focus on relevant inputs",
            "gradient descent optimization updates neural network parameters",
        ];

        // Train on one embedding per corpus entry.
        let mut training_embeddings = Vec::with_capacity(corpus.len());
        for text in &corpus {
            training_embeddings.push(embedder.embed_tokens(text).unwrap());
        }
        index.train(&training_embeddings).unwrap();

        // Every text is embedded again at insertion time rather than reusing
        // the training embeddings.
        // NOTE(review): `0` and `text.len()` are presumably the chunk's
        // start/end offsets — confirm against `Chunk::new`.
        for text in &corpus {
            let chunk = Chunk::new(DocumentId::new(), text.to_string(), 0, text.len());
            let embedding = embedder.embed_tokens(text).unwrap();
            index.insert(chunk, embedding).unwrap();
        }
        index.build().unwrap();

        let query_embedding = embedder.embed_tokens("neural network learning").unwrap();
        let results = index
            .search(&query_embedding, &WarpSearchConfig::with_k(3))
            .unwrap();

        // Something must come back, but never more than the 3 requested.
        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        // The second tuple field of each result must be non-increasing.
        for pair in results.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }

        // Every returned id must be resolvable through the index.
        for (chunk_id, _) in &results {
            assert!(index.get_chunk(chunk_id).is_some());
        }
    }

    /// Checks `exact_maxsim` against a hand-computed value on tiny inputs.
    ///
    /// NOTE(review): the row layout below assumes the trailing constructor
    /// arguments are (token count, dimension) and that the data is row-major —
    /// confirm against `MultiVectorEmbedding::new`.
    #[test]
    fn test_exact_maxsim_calculation() {
        let query = MultiVectorEmbedding::new(
            vec![
                1.0, 0.0, 0.0, 0.0, // first query row
                0.0, 1.0, 0.0, 0.0, // second query row
            ],
            2,
            4,
        );
        let doc = MultiVectorEmbedding::new(
            vec![
                0.5, 0.5, 0.0, 0.0, // first document row
                1.0, 0.0, 0.0, 0.0, // second document row
                0.0, 0.0, 1.0, 0.0, // third document row
            ],
            3,
            4,
        );

        // Under that reading, the best dot products per query row are 1.0 and
        // 0.5, which sum to the expected 1.5.
        let deviation = (exact_maxsim(&query, &doc) - 1.5).abs();
        assert!(deviation < 1e-6);
    }

    /// Verifies that a closely matching document outscores an unrelated one.
    ///
    /// Only `exact_maxsim` scores are compared here; no codec or index is
    /// involved in this test body.
    #[test]
    fn test_compression_preserves_ordering() {
        let embedder = MockMultiVectorEmbedder::new(32, 128);
        let query = embedder.embed_tokens("machine learning").unwrap();
        let [doc_relevant, doc_partial, doc_irrelevant] = [
            "machine learning algorithms",
            "learning systems",
            "cooking recipes",
        ]
        .map(|text| embedder.embed_tokens(text).unwrap());

        let relevant_score = exact_maxsim(&query, &doc_relevant);
        // Computed only to exercise the call; nothing is asserted about it.
        let _partial_score = exact_maxsim(&query, &doc_partial);
        let irrelevant_score = exact_maxsim(&query, &doc_irrelevant);

        assert!(
            relevant_score > irrelevant_score,
            "Relevant doc should score higher: {} vs {}",
            relevant_score,
            irrelevant_score
        );
    }

    /// Runs the same query with several `nprobe` settings and checks that the
    /// number of results never exceeds the requested 5.
    #[test]
    fn test_search_nprobe_variations() {
        let embedder = MockMultiVectorEmbedder::new(16, 64);
        let index_config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
        let mut index = WarpIndex::new(index_config);

        let texts: Vec<String> = (0..50).map(|i| format!("document number {}", i)).collect();
        let embeddings: Vec<_> = texts
            .iter()
            .map(|text| embedder.embed_tokens(text).unwrap())
            .collect();
        index.train(&embeddings).unwrap();

        // Texts and embeddings were produced in lockstep, so zipping pairs
        // each text with its own embedding.
        for (text, embedding) in texts.iter().zip(&embeddings) {
            let chunk = Chunk::new(DocumentId::new(), text.clone(), 0, text.len());
            index.insert(chunk, embedding.clone()).unwrap();
        }
        index.build().unwrap();

        let query = embedder.embed_tokens("document number").unwrap();
        for nprobe in [1, 2, 4, 8] {
            let search_config = WarpSearchConfig::with_k(5).nprobe(nprobe);
            let hit_count = index.search(&query, &search_config).unwrap().len();
            assert!(hit_count <= 5, "nprobe={}: got {} results", nprobe, hit_count);
        }
    }

    /// Builds a 50-document index and checks that its reported memory usage
    /// stays below three times a per-token baseline.
    #[test]
    fn test_memory_efficiency() {
        let embedder = MockMultiVectorEmbedder::new(128, 512);
        let index_config = WarpIndexConfig::new(2, 8, 128).with_kmeans_iterations(5);
        let mut index = WarpIndex::new(index_config);

        let texts: Vec<String> = (0..50)
            .map(|i| {
                format!("document number {} contains important information about topic {}", i, i)
            })
            .collect();
        let embeddings: Vec<_> = texts
            .iter()
            .map(|text| embedder.embed_tokens(text).unwrap())
            .collect();
        index.train(&embeddings).unwrap();

        for (text, embedding) in texts.iter().zip(&embeddings) {
            let chunk = Chunk::new(DocumentId::new(), text.clone(), 0, text.len());
            index.insert(chunk, embedding.clone()).unwrap();
        }
        index.build().unwrap();

        let memory = index.memory_usage();
        let num_tokens = index.num_tokens();
        // NOTE(review): 32 is presumably bytes per token (2 bits x 128
        // dimensions / 8) — confirm against the `WarpIndexConfig::new`
        // parameter order.
        let theoretical_min = num_tokens * 32;
        let overhead_factor = 3.0;
        let memory_budget = (theoretical_min as f64 * overhead_factor) as usize;
        assert!(
            memory < memory_budget,
            "Memory {} too high for {} tokens (theoretical min {})",
            memory,
            num_tokens,
            theoretical_min
        );
    }
}