use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorHit {
pub vector_id: u32,
pub distance: f32,
pub shard_id: u32,
pub doc_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardSearchResult {
pub shard_id: u32,
pub hits: Vec<VectorHit>,
pub success: bool,
pub error: Option<String>,
}
pub struct VectorMerger {
all_hits: Vec<VectorHit>,
responded: usize,
expected: usize,
}
impl VectorMerger {
pub fn new(expected_shards: usize) -> Self {
Self {
all_hits: Vec::new(),
responded: 0,
expected: expected_shards,
}
}
pub fn add_shard_result(&mut self, result: &ShardSearchResult) {
if result.success {
self.all_hits.extend_from_slice(&result.hits);
}
self.responded += 1;
}
pub fn all_responded(&self) -> bool {
self.responded >= self.expected
}
pub fn top_k(&mut self, top_k: usize) -> Vec<VectorHit> {
self.all_hits.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
.then(a.shard_id.cmp(&b.shard_id))
});
self.all_hits.truncate(top_k);
self.all_hits.clone()
}
pub fn total_hits(&self) -> usize {
self.all_hits.len()
}
pub fn response_count(&self) -> usize {
self.responded
}
}
pub fn over_fetch_factor(has_filter: bool, estimated_pass_rate: f64) -> usize {
if !has_filter {
return 1;
}
if estimated_pass_rate > 0.5 { 2 } else { 3 }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn merge_two_shards() {
let mut merger = VectorMerger::new(2);
merger.add_shard_result(&ShardSearchResult {
shard_id: 0,
hits: vec![
VectorHit {
vector_id: 1,
distance: 0.1,
shard_id: 0,
doc_id: None,
},
VectorHit {
vector_id: 2,
distance: 0.3,
shard_id: 0,
doc_id: None,
},
VectorHit {
vector_id: 3,
distance: 0.5,
shard_id: 0,
doc_id: None,
},
],
success: true,
error: None,
});
merger.add_shard_result(&ShardSearchResult {
shard_id: 1,
hits: vec![
VectorHit {
vector_id: 10,
distance: 0.05,
shard_id: 1,
doc_id: None,
},
VectorHit {
vector_id: 11,
distance: 0.2,
shard_id: 1,
doc_id: None,
},
VectorHit {
vector_id: 12,
distance: 0.4,
shard_id: 1,
doc_id: None,
},
],
success: true,
error: None,
});
assert!(merger.all_responded());
let top3 = merger.top_k(3);
assert_eq!(top3.len(), 3);
assert_eq!(top3[0].vector_id, 10);
assert_eq!(top3[0].shard_id, 1);
assert!((top3[0].distance - 0.05).abs() < f32::EPSILON);
assert_eq!(top3[1].vector_id, 1);
assert_eq!(top3[2].vector_id, 11);
}
#[test]
fn merge_with_failed_shard() {
let mut merger = VectorMerger::new(2);
merger.add_shard_result(&ShardSearchResult {
shard_id: 0,
hits: vec![VectorHit {
vector_id: 1,
distance: 0.1,
shard_id: 0,
doc_id: None,
}],
success: true,
error: None,
});
merger.add_shard_result(&ShardSearchResult {
shard_id: 1,
hits: vec![],
success: false,
error: Some("timeout".into()),
});
assert!(merger.all_responded());
let top1 = merger.top_k(1);
assert_eq!(top1.len(), 1);
assert_eq!(top1[0].vector_id, 1);
}
#[test]
fn top_k_truncation() {
let mut merger = VectorMerger::new(1);
merger.add_shard_result(&ShardSearchResult {
shard_id: 0,
hits: (0..100)
.map(|i| VectorHit {
vector_id: i,
distance: i as f32 * 0.01,
shard_id: 0,
doc_id: None,
})
.collect(),
success: true,
error: None,
});
let top5 = merger.top_k(5);
assert_eq!(top5.len(), 5);
for (i, hit) in top5.iter().enumerate() {
assert_eq!(hit.vector_id, i as u32);
}
}
#[test]
fn over_fetch_no_filter() {
assert_eq!(over_fetch_factor(false, 1.0), 1);
}
#[test]
fn over_fetch_light_filter() {
assert_eq!(over_fetch_factor(true, 0.7), 2);
}
#[test]
fn over_fetch_heavy_filter() {
assert_eq!(over_fetch_factor(true, 0.2), 3);
}
#[test]
fn deterministic_tie_breaking() {
let mut merger = VectorMerger::new(2);
merger.add_shard_result(&ShardSearchResult {
shard_id: 0,
hits: vec![VectorHit {
vector_id: 1,
distance: 0.5,
shard_id: 0,
doc_id: None,
}],
success: true,
error: None,
});
merger.add_shard_result(&ShardSearchResult {
shard_id: 1,
hits: vec![VectorHit {
vector_id: 2,
distance: 0.5,
shard_id: 1,
doc_id: None,
}],
success: true,
error: None,
});
let top2 = merger.top_k(2);
assert_eq!(top2[0].shard_id, 0);
assert_eq!(top2[1].shard_id, 1);
}
}