use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::sparse::{sparse_dot_product, SparseId, SparseStorage, SparseVector};
/// One hit returned by a sparse search: the matched vector's id and its score.
#[derive(Debug, Clone)]
pub struct SparseSearchResult {
    /// Identifier of the matched vector.
    pub id: SparseId,
    /// Score the search assigned to this match.
    pub score: f32,
}

impl SparseSearchResult {
    /// Builds a result from an `id` and the `score` computed for it.
    #[must_use]
    #[inline]
    pub fn new(id: SparseId, score: f32) -> Self {
        SparseSearchResult { id, score }
    }
}
/// Heap element pairing a vector id with its score.
///
/// The ordering is *reversed* on `score`, so a [`BinaryHeap`] (a max-heap)
/// of these entries keeps the lowest-scoring entry at the top, i.e. it acts
/// as a min-heap over scores.
///
/// `PartialEq`, `PartialOrd` and `Ord` are all derived from the single
/// comparison in [`Ord::cmp`], so they always agree with each other as the
/// standard library's trait contracts require.
#[derive(Clone, Debug)]
struct MinHeapEntry {
    /// Identifier of the candidate vector.
    id: SparseId,
    /// Score of the candidate.
    score: f32,
}

impl PartialEq for MinHeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Delegating to `cmp` guarantees `a == b` exactly when
        // `a.cmp(&b) == Ordering::Equal`, and keeps equality reflexive even
        // for NaN scores (which `f32`'s own `==` would not).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MinHeapEntry {}

impl PartialOrd for MinHeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MinHeapEntry {
    /// Orders entries by descending score, then by ascending id.
    ///
    /// Scores are compared with [`f32::total_cmp`], which is a total order
    /// over every `f32` bit pattern (including NaN), so the comparison is
    /// always transitive. Among entries with the same score, the one with
    /// the larger id compares greater, which makes ties deterministic.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .score
            // Operands are swapped on purpose: a lower score must compare
            // as "greater" so that it surfaces at the top of the max-heap.
            .total_cmp(&self.score)
            .then_with(|| self.id.as_u64().cmp(&other.id.as_u64()))
    }
}
pub struct SparseSearcher<'a> {
storage: &'a SparseStorage,
}
impl<'a> SparseSearcher<'a> {
#[inline]
#[must_use]
pub fn new(storage: &'a SparseStorage) -> Self {
Self { storage }
}
#[must_use]
pub fn search(&self, query: &SparseVector, k: usize) -> Vec<SparseSearchResult> {
if k == 0 {
return Vec::new();
}
let mut heap: BinaryHeap<MinHeapEntry> = BinaryHeap::with_capacity(k + 1);
for (id, vector) in self.storage {
let score = sparse_dot_product(query, &vector);
if score <= 0.0 {
continue;
}
if heap.len() < k {
heap.push(MinHeapEntry { id, score });
} else if let Some(min_entry) = heap.peek() {
if score > min_entry.score {
heap.pop();
heap.push(MinHeapEntry { id, score });
}
}
}
let mut results: Vec<SparseSearchResult> = heap
.into_iter()
.map(|entry| SparseSearchResult::new(entry.id, entry.score))
.collect();
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
results
}
#[must_use]
pub fn search_raw(&self, query: &SparseVector, k: usize) -> Vec<(SparseId, f32)> {
self.search(query, k)
.into_iter()
.map(|r| (r.id, r.score))
.collect()
}
#[must_use]
pub fn search_u64(&self, query: &SparseVector, k: usize) -> Vec<(u64, f32)> {
self.search(query, k)
.into_iter()
.map(|r| (r.id.as_u64(), r.score))
.collect()
}
#[inline]
#[must_use]
pub fn storage(&self) -> &SparseStorage {
self.storage
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the four-vector fixture shared by most tests.
    ///
    /// The assertions in this module rely on the storage handing out ids
    /// 0, 1, 2, 3 in insertion order.
    fn create_test_storage() -> SparseStorage {
        let fixtures = [
            SparseVector::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap(),
            SparseVector::new(vec![5, 10, 20], vec![0.5, 1.5, 2.0], 100).unwrap(),
            SparseVector::new(vec![30, 40, 50], vec![1.0, 1.0, 1.0], 100).unwrap(),
            SparseVector::new(vec![0], vec![5.0], 100).unwrap(),
        ];

        let mut storage = SparseStorage::new();
        for vector in &fixtures {
            storage.insert(vector).unwrap();
        }
        storage
    }

    /// Query with weight 1.0 on dimensions 0 and 5, used by several tests.
    fn query_on_dims_0_and_5() -> SparseVector {
        SparseVector::new(vec![0, 5], vec![1.0, 1.0], 100).unwrap()
    }

    #[test]
    fn test_search_basic() {
        let storage = create_test_storage();
        let hits = SparseSearcher::new(&storage).search(&query_on_dims_0_and_5(), 10);

        // Expected (id, score) pairs, best first; vector 2 shares no
        // dimension with the query and must not appear.
        let expected = [(3u64, 5.0f32), (0, 3.0), (1, 0.5)];
        assert_eq!(hits.len(), 3);
        for (hit, &(id, score)) in hits.iter().zip(expected.iter()) {
            assert_eq!(hit.id.as_u64(), id);
            assert!((hit.score - score).abs() < 1e-6);
        }
    }

    #[test]
    fn test_search_empty_storage() {
        let empty = SparseStorage::new();
        let hits = SparseSearcher::new(&empty).search(&query_on_dims_0_and_5(), 10);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_search_k_zero() {
        let storage = create_test_storage();
        let hits = SparseSearcher::new(&storage).search(&query_on_dims_0_and_5(), 0);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_search_k_larger_than_count() {
        let storage = create_test_storage();
        let hits = SparseSearcher::new(&storage).search(&query_on_dims_0_and_5(), 1000);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_search_k_equals_one() {
        let storage = create_test_storage();
        let hits = SparseSearcher::new(&storage).search(&query_on_dims_0_and_5(), 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id.as_u64(), 3);
    }

    #[test]
    fn test_search_skips_deleted() {
        let mut storage = create_test_storage();
        let query = SparseVector::new(vec![0], vec![1.0], 100).unwrap();

        let present_before = SparseSearcher::new(&storage)
            .search(&query, 10)
            .iter()
            .any(|hit| hit.id.as_u64() == 3);
        assert!(present_before);

        storage.delete(SparseId::new(3)).unwrap();

        let present_after = SparseSearcher::new(&storage)
            .search(&query, 10)
            .iter()
            .any(|hit| hit.id.as_u64() == 3);
        assert!(!present_after);
    }

    #[test]
    fn test_search_ordering_descending() {
        let storage = create_test_storage();
        let query = SparseVector::new(vec![0, 5, 10], vec![1.0, 1.0, 1.0], 100).unwrap();
        let hits = SparseSearcher::new(&storage).search(&query, 10);

        // Every adjacent pair must be in non-increasing score order.
        for pair in hits.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "Results not sorted: {} < {}",
                pair[0].score,
                pair[1].score
            );
        }
    }

    #[test]
    fn test_search_raw_format() {
        let storage = create_test_storage();
        let pairs = SparseSearcher::new(&storage).search_raw(&query_on_dims_0_and_5(), 10);
        assert!(!pairs.is_empty());
        assert_eq!(pairs[0].0.as_u64(), 3);
    }

    #[test]
    fn test_search_u64_format() {
        let storage = create_test_storage();
        let pairs = SparseSearcher::new(&storage).search_u64(&query_on_dims_0_and_5(), 10);
        assert!(!pairs.is_empty());
        assert_eq!(pairs[0].0, 3u64);
    }

    #[test]
    fn test_search_no_overlap_returns_empty() {
        let storage = create_test_storage();
        // Dimension 99 is not used by any fixture vector.
        let query = SparseVector::new(vec![99], vec![1.0], 100).unwrap();
        let hits = SparseSearcher::new(&storage).search(&query, 10);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_search_score_correctness() {
        let first = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0], 10).unwrap();
        let second = SparseVector::new(vec![1, 2, 3], vec![4.0, 5.0, 6.0], 10).unwrap();
        let mut storage = SparseStorage::new();
        storage.insert(&first).unwrap();
        storage.insert(&second).unwrap();

        let query = SparseVector::new(vec![1, 2], vec![1.0, 1.0], 10).unwrap();
        let hits = SparseSearcher::new(&storage).search(&query, 10);

        // second: 4.0 + 5.0 = 9.0; first: 2.0 + 3.0 = 5.0.
        let expected = [(1u64, 9.0f32), (0, 5.0)];
        assert_eq!(hits.len(), 2);
        for (hit, &(id, score)) in hits.iter().zip(expected.iter()) {
            assert_eq!(hit.id.as_u64(), id);
            assert!((hit.score - score).abs() < 1e-6);
        }
    }

    #[test]
    fn test_search_many_vectors() {
        let mut storage = SparseStorage::new();
        // Every vector has dimension 0 plus one private dimension 1..=1000.
        for private_dim in 1..=1000u32 {
            let vector = SparseVector::new(vec![0, private_dim], vec![1.0, 1.0], 2000).unwrap();
            storage.insert(&vector).unwrap();
        }

        let query = SparseVector::new(vec![0, 1], vec![1.0, 2.0], 2000).unwrap();
        let hits = SparseSearcher::new(&storage).search(&query, 10);

        assert_eq!(hits.len(), 10);
        // Only the first vector also has dimension 1, giving 1.0 + 2.0 = 3.0.
        assert_eq!(hits[0].id.as_u64(), 0);
        assert!((hits[0].score - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_storage_accessor() {
        let storage = create_test_storage();
        let searcher = SparseSearcher::new(&storage);
        assert_eq!(searcher.storage().len(), 4);
    }
}