use crate::RetrieveError;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Describes a Matryoshka (nested) embedding layout: the full vector length
/// plus the truncation points at which prefixes of the vector remain usable.
#[derive(Debug, Clone)]
pub struct MatryoshkaConfig {
    pub full_dimension: usize,
    pub supported_dimensions: Vec<usize>,
    pub default_dimension: usize,
    pub cascaded_search: bool,
    pub cascade_expansion: usize,
}
impl Default for MatryoshkaConfig {
    /// The default layout matches the sentence-transformers preset exactly.
    fn default() -> Self {
        Self::sentence_transformers()
    }
}
impl MatryoshkaConfig {
    /// Preset mirroring OpenAI-style 3072-dim embeddings.
    pub fn openai_style() -> Self {
        Self {
            cascaded_search: true,
            cascade_expansion: 4,
            full_dimension: 3072,
            default_dimension: 1024,
            supported_dimensions: vec![256, 512, 1024, 1536, 3072],
        }
    }
    /// Preset mirroring sentence-transformers-style 768-dim embeddings.
    pub fn sentence_transformers() -> Self {
        Self {
            cascaded_search: true,
            cascade_expansion: 4,
            full_dimension: 768,
            default_dimension: 256,
            supported_dimensions: vec![64, 128, 256, 384, 512, 768],
        }
    }
    /// Preset mirroring Cohere-style 1024-dim embeddings.
    pub fn cohere_style() -> Self {
        Self {
            cascaded_search: true,
            cascade_expansion: 4,
            full_dimension: 1024,
            default_dimension: 512,
            supported_dimensions: vec![256, 512, 768, 1024],
        }
    }
    /// Smallest supported dimension that is >= `requested`, falling back to
    /// the full dimension when the request exceeds every supported size.
    /// Assumes `supported_dimensions` is sorted ascending.
    pub fn find_dimension(&self, requested: usize) -> usize {
        self.supported_dimensions
            .iter()
            .copied()
            .find(|&dim| dim >= requested)
            .unwrap_or(self.full_dimension)
    }
    /// Coarsest (smallest) supported dimension, used as the first cascade
    /// stage; defaults to a quarter of the full dimension when the list is
    /// empty.
    pub fn coarse_dimension(&self) -> usize {
        match self.supported_dimensions.first() {
            Some(&dim) => dim,
            None => self.full_dimension / 4,
        }
    }
}
/// A full-length embedding whose prefixes can be served at any truncation
/// point, per the Matryoshka representation scheme.
#[derive(Debug, Clone)]
pub struct MatryoshkaEmbedding {
    data: Vec<f32>,
    #[allow(dead_code)]
    config: MatryoshkaConfig,
}
impl MatryoshkaEmbedding {
    /// Wraps a raw vector, validating that its length equals the configured
    /// full dimension.
    ///
    /// # Errors
    /// Returns `RetrieveError::DimensionMismatch` when the lengths disagree.
    pub fn new(data: Vec<f32>, config: MatryoshkaConfig) -> Result<Self, RetrieveError> {
        match data.len() {
            n if n == config.full_dimension => Ok(Self { data, config }),
            n => Err(RetrieveError::DimensionMismatch {
                query_dim: n,
                doc_dim: config.full_dimension,
            }),
        }
    }
    /// Prefix of the embedding of length `dim`, clamped to the stored length.
    #[inline]
    pub fn at_dimension(&self, dim: usize) -> &[f32] {
        let end = dim.min(self.data.len());
        &self.data[..end]
    }
    /// The entire stored vector.
    #[inline]
    pub fn full(&self) -> &[f32] {
        self.data.as_slice()
    }
    /// Stored (full) dimensionality.
    #[inline]
    pub fn dimension(&self) -> usize {
        self.data.len()
    }
    /// Cosine similarity between the two embeddings' `dim`-length prefixes.
    pub fn cosine_similarity_at(&self, other: &Self, dim: usize) -> f32 {
        crate::simd::cosine(self.at_dimension(dim), other.at_dimension(dim))
    }
    /// L2 (Euclidean) distance between the two embeddings' `dim`-length
    /// prefixes.
    pub fn l2_distance_at(&self, other: &Self, dim: usize) -> f32 {
        crate::distance::l2_distance(self.at_dimension(dim), other.at_dimension(dim))
    }
    /// Inner (dot) product between the two embeddings' `dim`-length prefixes.
    pub fn inner_product_at(&self, other: &Self, dim: usize) -> f32 {
        crate::simd::dot(self.at_dimension(dim), other.at_dimension(dim))
    }
}
/// Flat (brute-force) vector index over Matryoshka embeddings, supporting
/// single-dimension search and a coarse-to-fine cascaded search.
#[derive(Debug)]
pub struct MatryoshkaIndex {
    // Stored embeddings; position-aligned with `doc_ids`.
    embeddings: Vec<MatryoshkaEmbedding>,
    // Document id for each stored embedding (same slot = same document).
    doc_ids: Vec<u32>,
    // Dimension layout every stored embedding must conform to.
    config: MatryoshkaConfig,
    // Running counters updated by the search methods.
    stats: MatryoshkaStats,
}
/// Counters describing search activity on a `MatryoshkaIndex`.
#[derive(Debug, Default, Clone)]
pub struct MatryoshkaStats {
    // Total searches issued (any strategy).
    pub total_searches: u64,
    // Searches that used the coarse-then-fine cascade.
    pub cascaded_searches: u64,
    // Running mean of the coarse-stage shortlist size per cascaded search.
    pub avg_coarse_candidates: f64,
    // Intended estimate of distance work avoided by cascading; maintained by
    // `search_cascaded`.
    pub distance_computations_saved: u64,
}
/// Heap entry pairing a document id with its distance to the query.
///
/// `Ord` sorts by distance ascending (larger distance = greater), so a
/// `BinaryHeap<Candidate>` is a max-heap whose top is the *farthest*
/// candidate seen so far. The bounded top-k pattern used by the index —
/// push, then pop when the heap exceeds `k` — therefore evicts the worst
/// entry and retains the `k` nearest neighbors.
#[derive(Debug, Clone)]
struct Candidate {
    id: u32,
    distance: f32,
}
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to the same total order as `Ord::cmp` so Eq/Ord stay
        // consistent even for NaN distances (plain `==` would disagree with
        // `total_cmp`, which treats NaN as equal to itself).
        self.distance.total_cmp(&other.distance) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // BUG FIX: the previous `.reverse()` made the heap a min-heap, so the
        // pop-when-over-capacity pattern in the index evicted the *nearest*
        // candidates and searches returned the k farthest results. Natural
        // ascending order puts the farthest candidate on top, which is the
        // one that must be evicted.
        self.distance.total_cmp(&other.distance)
    }
}
impl MatryoshkaIndex {
    /// Creates an empty index using the given dimension layout.
    pub fn new(config: MatryoshkaConfig) -> Self {
        Self {
            embeddings: Vec::new(),
            doc_ids: Vec::new(),
            config,
            stats: MatryoshkaStats::default(),
        }
    }
    /// Adds one document.
    ///
    /// # Errors
    /// Returns `RetrieveError::DimensionMismatch` when `embedding.len()` does
    /// not equal `config.full_dimension`.
    pub fn add(&mut self, doc_id: u32, embedding: Vec<f32>) -> Result<(), RetrieveError> {
        let emb = MatryoshkaEmbedding::new(embedding, self.config.clone())?;
        self.embeddings.push(emb);
        self.doc_ids.push(doc_id);
        Ok(())
    }
    /// Adds several documents, stopping at the first invalid embedding.
    pub fn add_batch(&mut self, items: Vec<(u32, Vec<f32>)>) -> Result<(), RetrieveError> {
        for (doc_id, embedding) in items {
            self.add(doc_id, embedding)?;
        }
        Ok(())
    }
    /// Brute-force top-k by L2 distance over the first `dimension` components,
    /// returning *internal slot indices* (not doc ids) so callers such as the
    /// cascade can address `self.embeddings` directly without an id lookup.
    fn search_indices_at_dimension(
        &self,
        query: &[f32],
        k: usize,
        dimension: usize,
    ) -> Vec<(usize, f32)> {
        let query_slice = &query[..dimension.min(query.len())];
        // Bounded max-heap: once k entries are held, each push is followed by
        // popping the current worst (farthest) candidate.
        let mut heap: BinaryHeap<Candidate> = BinaryHeap::with_capacity(k + 1);
        for (idx, emb) in self.embeddings.iter().enumerate() {
            let dist = crate::distance::l2_distance(query_slice, emb.at_dimension(dimension));
            heap.push(Candidate {
                // NOTE(review): slot index stored in the u32 id field; assumes
                // fewer than u32::MAX stored embeddings.
                id: idx as u32,
                distance: dist,
            });
            if heap.len() > k {
                heap.pop();
            }
        }
        let mut results: Vec<(usize, f32)> = heap
            .into_iter()
            .map(|c| (c.id as usize, c.distance))
            .collect();
        results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        results
    }
    /// Exhaustive search at a fixed truncation dimension.
    ///
    /// Returns up to `k` `(doc_id, l2_distance)` pairs sorted by ascending
    /// distance.
    pub fn search_at_dimension(
        &self,
        query: &[f32],
        k: usize,
        dimension: usize,
    ) -> Vec<(u32, f32)> {
        self.search_indices_at_dimension(query, k, dimension)
            .into_iter()
            .map(|(idx, dist)| (self.doc_ids[idx], dist))
            .collect()
    }
    /// Two-stage (coarse-to-fine) search: a cheap pass at the coarse
    /// dimension shortlists `k * cascade_expansion` candidates, which are then
    /// re-ranked at the default (fine) dimension.
    pub fn search_cascaded(&mut self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
        self.stats.total_searches += 1;
        self.stats.cascaded_searches += 1;
        let coarse_dim = self.config.coarse_dimension();
        let fine_dim = self.config.default_dimension;
        let coarse_k = k * self.config.cascade_expansion;
        // Stage 1: shortlist by slot index. Using indices (rather than doc
        // ids, as the old code did) removes the O(candidates * n) `position`
        // scan previously needed to map ids back to embeddings.
        let coarse = self.search_indices_at_dimension(query, coarse_k, coarse_dim);
        // Incremental running mean of shortlist sizes.
        self.stats.avg_coarse_candidates = (self.stats.avg_coarse_candidates
            * (self.stats.cascaded_searches - 1) as f64
            + coarse.len() as f64)
            / self.stats.cascaded_searches as f64;
        // Stage 2: re-rank the shortlist with more components.
        let query_fine = &query[..fine_dim.min(query.len())];
        let mut refined: Vec<(u32, f32)> = coarse
            .iter()
            .map(|&(idx, _)| {
                let dist = crate::distance::l2_distance(
                    query_fine,
                    self.embeddings[idx].at_dimension(fine_dim),
                );
                (self.doc_ids[idx], dist)
            })
            .collect();
        refined.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        refined.truncate(k);
        // Savings accounting, weighted by vector length: a single-pass search
        // at fine_dim costs n * fine_dim scalar ops, while the cascade costs
        // n * coarse_dim + shortlist * fine_dim. (The old comparison counted
        // unweighted distance calls — n vs. n + shortlist — which can never be
        // positive, so the counter never advanced.)
        let single_pass_cost = (self.embeddings.len() * fine_dim) as u64;
        let cascade_cost =
            (self.embeddings.len() * coarse_dim + coarse.len() * fine_dim) as u64;
        if single_pass_cost > cascade_cost {
            self.stats.distance_computations_saved += single_pass_cost - cascade_cost;
        }
        refined
    }
    /// Searches with the configured strategy: the cascade for collections
    /// larger than 1000 entries (when enabled), otherwise a single pass at
    /// the default dimension.
    pub fn search(&mut self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
        if self.config.cascaded_search && self.embeddings.len() > 1000 {
            // `search_cascaded` bumps `total_searches` itself; incrementing it
            // here too (as the old code did) double-counted cascaded searches.
            self.search_cascaded(query, k)
        } else {
            self.stats.total_searches += 1;
            self.search_at_dimension(query, k, self.config.default_dimension)
        }
    }
    /// Number of stored embeddings.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }
    /// True when no embeddings have been added.
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }
    /// Read-only view of the accumulated search statistics.
    pub fn stats(&self) -> &MatryoshkaStats {
        &self.stats
    }
}
/// Chooses the cheapest embedding dimension whose estimated accuracy meets a
/// target, trading estimated accuracy against estimated latency.
#[derive(Debug)]
pub struct AdaptiveDimensionSelector {
    // Candidate dimensions; assumed sorted ascending (last = full dimension).
    dimensions: Vec<usize>,
    // Estimated accuracy per dimension, seeded analytically in `new` and
    // refined online by `calibrate`.
    accuracy_estimates: Vec<f32>,
    // Estimated latency per dimension, relative to the full dimension (1.0).
    latency_estimates: Vec<f32>,
    // Minimum acceptable accuracy estimate for selection.
    target_accuracy: f32,
}
impl AdaptiveDimensionSelector {
    /// Builds a selector with analytic priors: accuracy ~ sqrt(d / d_max),
    /// latency ~ d / d_max. The accuracy target defaults to 0.95.
    pub fn new(dimensions: Vec<usize>) -> Self {
        // Hoist the loop-invariant full dimension out of the per-element
        // closures (the old code re-read `dimensions.last()` on every item
        // and left an unused `_n` binding).
        let full = dimensions.last().copied().unwrap_or(1) as f32;
        let full_sqrt = full.sqrt();
        let accuracy_estimates: Vec<f32> = dimensions
            .iter()
            .map(|&d| (d as f32).sqrt() / full_sqrt)
            .collect();
        let latency_estimates: Vec<f32> =
            dimensions.iter().map(|&d| d as f32 / full).collect();
        Self {
            dimensions,
            accuracy_estimates,
            latency_estimates,
            target_accuracy: 0.95,
        }
    }
    /// Builder-style override for the accuracy target.
    pub fn with_target_accuracy(mut self, accuracy: f32) -> Self {
        self.target_accuracy = accuracy;
        self
    }
    /// Blends an observed accuracy into the stored estimate (exponential
    /// moving average, 10% weight on the new observation). Out-of-range
    /// indices are ignored.
    pub fn calibrate(&mut self, dimension_idx: usize, observed_accuracy: f32) {
        if dimension_idx < self.accuracy_estimates.len() {
            self.accuracy_estimates[dimension_idx] =
                0.9 * self.accuracy_estimates[dimension_idx] + 0.1 * observed_accuracy;
        }
    }
    /// Picks the best dimension for `speed_preference` in [0, 1]
    /// (0 = favor accuracy, 1 = favor speed), considering only dimensions
    /// whose accuracy estimate meets the target. Falls back to the largest
    /// dimension (or 768 when the list is empty) if nothing qualifies.
    pub fn select_dimension(&self, speed_preference: f32) -> usize {
        let mut best_dim = *self.dimensions.last().unwrap_or(&768);
        let mut best_score = f32::NEG_INFINITY;
        for (i, &dim) in self.dimensions.iter().enumerate() {
            let accuracy = self.accuracy_estimates.get(i).copied().unwrap_or(0.5);
            let latency = self.latency_estimates.get(i).copied().unwrap_or(1.0);
            if accuracy < self.target_accuracy {
                continue;
            }
            // Linear blend: accuracy rewarded, latency penalized.
            let score = (1.0 - speed_preference) * accuracy - speed_preference * latency;
            if score > best_score {
                best_score = score;
                best_dim = dim;
            }
        }
        best_dim
    }
    /// Smallest dimension whose accuracy estimate meets the target, falling
    /// back to the largest dimension (or 768 when the list is empty).
    pub fn minimum_dimension(&self) -> usize {
        for (i, &dim) in self.dimensions.iter().enumerate() {
            if self.accuracy_estimates.get(i).copied().unwrap_or(0.0) >= self.target_accuracy {
                return dim;
            }
        }
        *self.dimensions.last().unwrap_or(&768)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Deterministic pseudo-embedding: a sinusoid parameterized by `seed`.
    fn make_embedding(dim: usize, seed: u32) -> Vec<f32> {
        let mut out = Vec::with_capacity(dim);
        for i in 0..dim {
            out.push(((seed as f32 * 0.1 + i as f32) * 0.01).sin());
        }
        out
    }

    #[test]
    fn test_matryoshka_embedding_truncation() {
        let config = MatryoshkaConfig::default();
        let data = make_embedding(768, 42);
        let emb = MatryoshkaEmbedding::new(data.clone(), config).unwrap();
        // Prefixes are served at any requested length; requests beyond the
        // stored length are clamped to the full dimension.
        for (requested, expected) in [(64, 64), (256, 256), (768, 768), (1000, 768)] {
            assert_eq!(emb.at_dimension(requested).len(), expected);
        }
    }

    #[test]
    fn test_matryoshka_index_search() {
        let config = MatryoshkaConfig::default();
        let mut index = MatryoshkaIndex::new(config.clone());
        for i in 0..100 {
            index.add(i, make_embedding(768, i)).unwrap();
        }
        let query = make_embedding(768, 50);
        let results = index.search_at_dimension(&query, 5, 256);
        assert_eq!(results.len(), 5);
        assert!(
            results.windows(2).all(|pair| pair[0].1 <= pair[1].1),
            "Results should be sorted by distance"
        );
    }

    #[test]
    fn test_cascaded_search() {
        let config = MatryoshkaConfig {
            cascaded_search: true,
            ..Default::default()
        };
        let mut index = MatryoshkaIndex::new(config);
        for i in 0..2000 {
            index.add(i, make_embedding(768, i)).unwrap();
        }
        let query = make_embedding(768, 1000);
        let results = index.search_cascaded(&query, 10);
        assert_eq!(results.len(), 10);
        assert!(
            results.windows(2).all(|pair| pair[0].1 <= pair[1].1),
            "Results should be sorted by distance"
        );
        assert!(index.stats.cascaded_searches > 0);
    }

    #[test]
    fn test_dimension_selector() {
        let selector =
            AdaptiveDimensionSelector::new(vec![64, 128, 256, 512, 768]).with_target_accuracy(0.9);
        let fast_dim = selector.select_dimension(0.8);
        let accurate_dim = selector.select_dimension(0.2);
        assert!(fast_dim <= accurate_dim);
    }

    #[test]
    fn test_config_presets() {
        for (preset, expected_full) in [
            (MatryoshkaConfig::openai_style(), 3072),
            (MatryoshkaConfig::cohere_style(), 1024),
            (MatryoshkaConfig::sentence_transformers(), 768),
        ] {
            assert_eq!(preset.full_dimension, expected_full);
        }
    }
}