use std::collections::BinaryHeap;
use crate::distance::distance;
use nodedb_types::vector_distance::DistanceMetric;
/// The set of prefix lengths ("truncation dimensions") a vector may be cut to.
///
/// When built through [`MatryoshkaSpec::new`] the list is sorted ascending and
/// free of duplicates; [`MatryoshkaSpec::pick`] relies on that ordering.
#[derive(Debug, Clone)]
pub struct MatryoshkaSpec {
    /// Allowed truncation dimensions, smallest first when built via `new`.
    pub truncation_dims: Vec<u32>,
}

impl MatryoshkaSpec {
    /// Creates a spec from `truncation_dims`, sorting the values ascending and
    /// dropping duplicates.
    pub fn new(truncation_dims: Vec<u32>) -> Self {
        let mut dims = truncation_dims;
        dims.sort_unstable();
        dims.dedup();
        Self {
            truncation_dims: dims,
        }
    }

    /// Resolves a requested dimension to one of the configured dimensions.
    ///
    /// * `None` yields the last configured dimension, or `0` when none are
    ///   configured.
    /// * `Some(req)` yields the last configured dimension that is `<= req`.
    ///   If every configured dimension exceeds `req`, the first configured
    ///   dimension is returned; if nothing is configured, `req` is returned
    ///   unchanged.
    pub fn pick(&self, requested: Option<u32>) -> u32 {
        let dims = self.truncation_dims.as_slice();
        match requested {
            None => dims.last().copied().unwrap_or(0),
            Some(req) => match dims.iter().rfind(|&&dim| dim <= req) {
                Some(&fitting) => fitting,
                None => dims.first().copied().unwrap_or(req),
            },
        }
    }

    /// Returns `true` when `dim` is exactly one of the configured dimensions.
    pub fn is_valid(&self, dim: u32) -> bool {
        let known = &self.truncation_dims;
        known.contains(&dim)
    }
}
/// Returns the first `dim` components of `v`, or all of `v` when it has fewer
/// than `dim` components. Never panics.
#[inline]
pub fn truncate(v: &[f32], dim: usize) -> &[f32] {
    // `get(..dim)` is `None` only when `dim` exceeds the length, in which case
    // the whole slice is the answer.
    v.get(..dim).unwrap_or(v)
}
/// Tuning knobs read by `matryoshka_search`.
pub struct MatryoshkaSearchOptions {
/// Prefix length that the query and every candidate are truncated to for the
/// first-pass (coarse) distance.
pub coarse_dim: u32,
/// Prefix length that the query and the surviving candidates are truncated to
/// for the second-pass (rerank) distance.
pub full_dim: u32,
/// Multiplier applied to `k` to size the pool of first-pass survivors; a
/// value of `0` is treated as `1`.
pub oversample: u8,
/// Maximum number of results returned. The survivor pool is sized as if `k`
/// were at least `1`, but the final result list is cut to exactly `k`.
pub k: usize,
}
/// One candidate tracked by the coarse-pass max-heap.
///
/// Ordering and equality look **only** at `dist`, so the entry with the
/// largest distance sits at the top of a `BinaryHeap` and is the first to be
/// evicted. NaN distances are ranked after every number (and equal to each
/// other), which keeps the order total as `Ord`/`BinaryHeap` require.
struct HeapEntry {
    /// Distance used as the heap priority.
    dist: f32,
    /// Caller-supplied candidate identifier.
    id: u32,
    /// Index of this candidate's vector in the survivor table.
    vec_idx: usize,
}

// Equality is defined through `cmp` so that `a == b` holds exactly when
// `a.cmp(&b) == Equal`, as the `Ord` contract demands. A derived `PartialEq`
// would compare `id`/`vec_idx` too and would make NaN entries unequal to
// themselves, contradicting `Eq`.
impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `partial_cmp` is `None` only when at least one side is NaN. Falling
        // back to comparing the "is NaN" flags (false < true) puts NaN last
        // regardless of its sign bit, while leaving the ordering of ordinary
        // numbers (including 0.0 == -0.0) exactly as `partial_cmp` defines it.
        self.dist
            .partial_cmp(&other.dist)
            .unwrap_or_else(|| self.dist.is_nan().cmp(&other.dist.is_nan()))
    }
}
/// Two-pass nearest-neighbour search over truncated vector prefixes.
///
/// 1. **Coarse pass** – every candidate and the query are truncated to
///    `options.coarse_dim` components and compared with `metric`. The
///    `max(oversample, 1) * max(k, 1)` candidates with the smallest coarse
///    distance are kept.
/// 2. **Rerank pass** – the survivors and the query are truncated to
///    `options.full_dim` components, distances are recomputed, and the result
///    is sorted ascending and cut to `options.k` entries.
///
/// Truncation never reads past the end of a vector: a vector shorter than the
/// requested dimension is used in full.
///
/// A NaN coarse distance is ranked as the worst possible score, and NaN rerank
/// distances are sorted after all numeric ones, so the ordering is always
/// total. The relative order of results with equal distance is unspecified.
///
/// # Parameters
/// * `candidates` – `(id, vector)` pairs; the vectors are only borrowed.
/// * `query` – the query vector.
/// * `options` – dimensions, oversampling factor and `k`.
/// * `metric` – metric forwarded to `distance` for both passes.
///
/// # Returns
/// At most `options.k` pairs of `(id, rerank distance)`, smallest distance
/// first.
pub fn matryoshka_search<'a, I>(
    candidates: I,
    query: &[f32],
    options: &MatryoshkaSearchOptions,
    metric: DistanceMetric,
) -> Vec<(u32, f32)>
where
    I: Iterator<Item = (u32, &'a [f32])>,
{
    /// Total order on distances that places every NaN after every number.
    fn nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    let coarse = options.coarse_dim as usize;
    let full = options.full_dim as usize;
    // Saturate so that an extreme `k` cannot overflow the pool size.
    let pool_size = (options.oversample as usize)
        .max(1)
        .saturating_mul(options.k.max(1));
    let query_coarse = truncate(query, coarse);

    // Reserve no more than the iterator guarantees to yield, so a large `k`
    // with few candidates does not trigger a huge up-front allocation.
    let reserve = pool_size.min(candidates.size_hint().0);
    let mut coarse_heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(reserve);
    // Survivors are borrowed for `'a`, not copied, and the table never grows
    // beyond `pool_size` because an evicted entry's slot is reused.
    let mut survivor_vecs: Vec<&'a [f32]> = Vec::with_capacity(reserve);

    for (id, vec) in candidates {
        let d = distance(query_coarse, truncate(vec, coarse), metric);
        // Rank NaN as +inf so the heap only ever sees totally ordered keys.
        let coarse_key = if d.is_nan() { f32::INFINITY } else { d };
        let vec_full = truncate(vec, full);

        if coarse_heap.len() < pool_size {
            let vec_idx = survivor_vecs.len();
            survivor_vecs.push(vec_full);
            coarse_heap.push(HeapEntry {
                dist: coarse_key,
                id,
                vec_idx,
            });
        } else if let Some(mut worst) = coarse_heap.peek_mut() {
            // Pool is full: overwrite the current worst survivor in place when
            // the newcomer is strictly better. `PeekMut` restores the heap
            // invariant when it is dropped.
            if coarse_key < worst.dist {
                survivor_vecs[worst.vec_idx] = vec_full;
                worst.dist = coarse_key;
                worst.id = id;
            }
        }
    }

    let query_full = truncate(query, full);
    let mut reranked: Vec<(u32, f32)> = coarse_heap
        .into_iter()
        .map(|entry| {
            let d_full = distance(query_full, survivor_vecs[entry.vec_idx], metric);
            (entry.id, d_full)
        })
        .collect();
    // A total comparator is required: sorting with an inconsistent order gives
    // unspecified results and may panic.
    reranked.sort_unstable_by(|a, b| nan_last(a.1, b.1));
    reranked.truncate(options.k);
    reranked
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-level spec shared by the `pick` / `is_valid` tests.
    fn sample_spec() -> MatryoshkaSpec {
        MatryoshkaSpec::new(vec![256, 512, 1024])
    }

    /// Deterministic query vector with `dim` components.
    fn sample_query(dim: usize) -> Vec<f32> {
        (0..dim).map(|i| (i as f32 * 0.007).cos()).collect()
    }

    /// Deterministic corpus of `n` vectors with `dim` components each.
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for row in 0..n {
            let base = row * dim;
            rows.push(
                (0..dim)
                    .map(|col| ((base + col) as f32 * 0.01).sin())
                    .collect(),
            );
        }
        rows
    }

    #[test]
    fn pick_largest_leq_requested() {
        let picked = sample_spec().pick(Some(300));
        assert_eq!(picked, 256);
    }

    #[test]
    fn pick_exact_match() {
        let picked = sample_spec().pick(Some(512));
        assert_eq!(picked, 512);
    }

    #[test]
    fn pick_none_returns_full_dim() {
        let picked = sample_spec().pick(None);
        assert_eq!(picked, 1024);
    }

    #[test]
    fn pick_smaller_than_all_returns_smallest() {
        let picked = sample_spec().pick(Some(10));
        assert_eq!(picked, 256);
    }

    #[test]
    fn is_valid_known_dim() {
        assert!(sample_spec().is_valid(512));
    }

    #[test]
    fn is_valid_unknown_dim() {
        assert!(!sample_spec().is_valid(100));
    }

    #[test]
    fn truncate_clips_to_requested_dim() {
        let source: Vec<f32> = (0..1536u16).map(f32::from).collect();
        let head = truncate(&source, 256);
        assert_eq!(head.len(), 256);
        assert_eq!(head.first(), Some(&0.0));
        assert_eq!(head.last(), Some(&255.0));
    }

    #[test]
    fn truncate_does_not_exceed_vec_len() {
        let short = [1.0f32; 10];
        assert_eq!(truncate(&short, 9999).len(), 10);
    }

    #[test]
    fn search_returns_k_results() {
        let corpus = make_vecs(100, 128);
        let query = sample_query(128);
        let opts = MatryoshkaSearchOptions {
            coarse_dim: 64,
            full_dim: 128,
            oversample: 3,
            k: 10,
        };
        let labelled = corpus.iter().zip(0u32..).map(|(v, id)| (id, v.as_slice()));
        let results = matryoshka_search(labelled, &query, &opts, DistanceMetric::L2);
        assert_eq!(results.len(), 10, "expected exactly k=10 results");
    }

    #[test]
    fn coarse_equal_to_full_matches_direct_search() {
        let corpus = make_vecs(100, 128);
        let query = sample_query(128);

        // Reference ranking: brute-force distance over the full vectors.
        let mut direct: Vec<(u32, f32)> = Vec::with_capacity(corpus.len());
        for (i, v) in corpus.iter().enumerate() {
            direct.push((i as u32, distance(&query, v.as_slice(), DistanceMetric::L2)));
        }
        direct.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        direct.truncate(10);

        // With coarse == full and no oversampling the two passes see the same
        // distances, so the ranking must match the reference.
        let opts = MatryoshkaSearchOptions {
            coarse_dim: 128,
            full_dim: 128,
            oversample: 1,
            k: 10,
        };
        let labelled = corpus.iter().zip(0u32..).map(|(v, id)| (id, v.as_slice()));
        let mrl = matryoshka_search(labelled, &query, &opts, DistanceMetric::L2);

        let direct_ids: Vec<u32> = direct.into_iter().map(|(id, _)| id).collect();
        let mrl_ids: Vec<u32> = mrl.into_iter().map(|(id, _)| id).collect();
        assert_eq!(
            direct_ids, mrl_ids,
            "coarse==full should equal direct search"
        );
    }
}