use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::backend::Backend;
use crate::error::Result;
use crate::payload::Payload;
use crate::record::{Record, RecordId};
use crate::vector::{DistanceMetric, Vector};
/// A single hit produced by a search.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SearchResult {
/// Identifier of the matched record.
pub id: RecordId,
/// Value returned by the distance metric for the query and this record's vector.
/// `flat_search` returns hits in ascending order of this value, with NaN scores last.
pub score: f32,
/// Clone of the record's payload, or `None` when the record carries no payload.
pub payload: Option<Payload>,
}
/// Exhaustively scans every record in `backend` and returns the `k` best matches for `query`.
///
/// Each record is scored with `metric.distance(query, record.vector())`; lower scores rank
/// first and NaN scores rank last. Records for which `filter` returns `false` are never
/// admitted. Both the selection of the top `k` and the final ordering use the same
/// `(score, id)` order, so the output does not depend on the iteration order of the store,
/// even when more than `k` records share a score.
///
/// Returns an empty vector when `k == 0` without touching the backend.
///
/// # Errors
///
/// Propagates any error returned by `metric.distance`. The distance is evaluated before
/// `filter` is consulted, so such an error is reported even for a record the filter would
/// have rejected.
pub(crate) fn flat_search<F>(
    backend: &Backend,
    query: &Vector,
    k: usize,
    metric: DistanceMetric,
    filter: F,
) -> Result<Vec<SearchResult>>
where
    F: Fn(&Record) -> bool,
{
    if k == 0 {
        return Ok(Vec::new());
    }

    backend.with_records(|records| {
        // Max-heap on (score, id): the root is always the worst candidate retained so far.
        // The heap never holds more than `k` entries and never more than the store contains,
        // so bounding the capacity this way avoids overflow / huge allocations for large `k`.
        let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k.min(records.len()));

        for record in records.values() {
            let score = metric.distance(query, record.vector())?;
            if !filter(record) {
                continue;
            }

            let candidate = HeapEntry {
                score,
                id: record.id(),
            };

            if heap.len() < k {
                heap.push(candidate);
            } else if let Some(mut worst) = heap.peek_mut() {
                // Compare with the full (score, id) order so that ties on score are resolved
                // by id here exactly as they are in the final ordering.
                if candidate < *worst {
                    // Overwriting through `PeekMut` restores the heap invariant on drop.
                    *worst = candidate;
                }
            }
        }

        // `into_sorted_vec` yields ascending `Ord` order: best score first, ties by id.
        let results: Vec<SearchResult> = heap
            .into_sorted_vec()
            .into_iter()
            .map(|entry| SearchResult {
                id: entry.id,
                score: entry.score,
                payload: records.get(&entry.id).and_then(|r| r.payload().cloned()),
            })
            .collect();

        Ok(results)
    })
}
/// Candidate tracked by `flat_search` while scanning: a record id paired with its score.
/// Only these two `Copy` fields are kept in the heap; the payload is looked up afterwards
/// for the entries that remain.
#[derive(Debug, Clone, Copy)]
struct HeapEntry {
// Score computed by the distance metric for this record.
score: f32,
// Identifier of the scored record; also used to break ties between equal scores.
id: RecordId,
}
/// Entries are equal when they refer to the same record and their scores tie under
/// `compare_score` (so two NaN scores count as equal, unlike raw `f32` equality).
impl PartialEq for HeapEntry {
    fn eq(&self, rhs: &Self) -> bool {
        self.id == rhs.id && matches!(compare_score(self.score, rhs.score), Ordering::Equal)
    }
}

// `compare_score` treats NaN as equal to NaN, so the equality above is reflexive.
impl Eq for HeapEntry {}
/// Delegates to the total order defined by `Ord`, so the result is always `Some`.
impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

/// Orders entries by score first (via `compare_score`, NaN last) and by id on a score tie.
/// In a `BinaryHeap` this puts the worst-scoring entry at the root.
impl Ord for HeapEntry {
    fn cmp(&self, rhs: &Self) -> Ordering {
        compare_score(self.score, rhs.score).then_with(|| self.id.cmp(&rhs.id))
    }
}
/// Total order over `f32` scores: ordinary numeric comparison for numbers, with NaN placed
/// after every non-NaN value and any two NaNs comparing as equal.
#[inline]
fn compare_score(a: f32, b: f32) -> Ordering {
    // `partial_cmp` is `None` exactly when at least one operand is NaN. In that case rank
    // by "is NaN" (`false < true`): a lone NaN sorts last, and two NaNs tie.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}
// Unit tests for `flat_search`, run against an in-memory store.
#[cfg(test)]
mod tests {
use super::*;
use crate::store::MemoryStore;
// Builds a memory-backed `Backend` holding one payload-less record per `(id, components)` pair.
fn build_backend(records: &[(u64, Vec<f32>)]) -> Backend {
let store = MemoryStore::new();
for (id, components) in records {
let v = Vector::new(components.clone()).expect("finite");
store
.upsert(Record::new(RecordId::new(*id), v))
.expect("upsert ok");
}
Backend::Memory(store)
}
// Filter that admits every record.
fn always_accept(_record: &Record) -> bool {
true
}
// `k == 0` short-circuits to an empty result even when the store has records.
#[test]
fn k_zero_returns_empty() {
let backend = build_backend(&[(1, vec![1.0, 0.0])]);
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 0, DistanceMetric::L2, always_accept).unwrap();
assert!(out.is_empty());
}
// Searching a store with no records yields no hits rather than an error.
#[test]
fn empty_store_returns_empty() {
let backend = Backend::Memory(MemoryStore::new());
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 5, DistanceMetric::L2, always_accept).unwrap();
assert!(out.is_empty());
}
// With more records than `k`, exactly `k` hits come back.
#[test]
fn returns_at_most_k_results() {
let backend = build_backend(&[
(1, vec![1.0, 0.0]),
(2, vec![0.0, 1.0]),
(3, vec![0.5, 0.5]),
(4, vec![1.0, 1.0]),
]);
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 2, DistanceMetric::L2, always_accept).unwrap();
assert_eq!(out.len(), 2);
}
// With fewer records than `k`, every record is returned.
#[test]
fn k_larger_than_store_returns_all() {
let backend = build_backend(&[(1, vec![1.0, 0.0]), (2, vec![0.0, 1.0])]);
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 100, DistanceMetric::L2, always_accept).unwrap();
assert_eq!(out.len(), 2);
}
// Hits are ordered from the closest record to the farthest.
#[test]
fn results_are_sorted_ascending_by_score() {
let backend = build_backend(&[
(1, vec![10.0, 0.0]),
(2, vec![1.0, 0.0]),
(3, vec![5.0, 0.0]),
]);
let q = Vector::new(vec![0.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 3, DistanceMetric::L2, always_accept).unwrap();
assert_eq!(out.len(), 3);
assert_eq!(out[0].id, RecordId::new(2));
assert_eq!(out[1].id, RecordId::new(3));
assert_eq!(out[2].id, RecordId::new(1));
assert!(out[0].score <= out[1].score);
assert!(out[1].score <= out[2].score);
}
// A rejected record (id 1, the exact match) must not take one of the `k` slots:
// the two remaining records fill the result instead.
#[test]
fn filter_excludes_records_before_admission() {
let backend = build_backend(&[
(1, vec![1.0, 0.0]),
(2, vec![0.99, 0.0]),
(3, vec![0.5, 0.0]),
]);
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 2, DistanceMetric::L2, |r| r.id().get() != 1).unwrap();
assert_eq!(out.len(), 2);
assert!(out.iter().all(|r| r.id != RecordId::new(1)));
}
// A 3-component query against a 2-component record surfaces the metric's
// `DimensionMismatch` error instead of a result.
#[test]
fn dimension_mismatch_returns_error() {
let backend = build_backend(&[(1, vec![1.0, 0.0])]);
let q = Vector::new(vec![1.0, 0.0, 0.0]).unwrap();
let err = flat_search(&backend, &q, 1, DistanceMetric::L2, always_accept).unwrap_err();
assert!(matches!(
err,
crate::Error::DimensionMismatch { left: 3, right: 2 }
));
}
// Expects the cosine score against the all-zero record (id 2) to be NaN and
// checks that this NaN hit is placed after the finite ones.
#[test]
fn cosine_against_zero_vector_yields_nan_at_tail() {
let backend = build_backend(&[
(1, vec![1.0, 0.0]),
(2, vec![0.0, 0.0]), (3, vec![0.5, 0.5]),
]);
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 3, DistanceMetric::Cosine, always_accept).unwrap();
assert_eq!(out.len(), 3);
assert_eq!(out[2].id, RecordId::new(2));
assert!(out[2].score.is_nan());
assert!(out[0].score.is_finite());
assert!(out[1].score.is_finite());
assert!(out[0].score <= out[1].score);
}
// Records with identical vectors (hence identical scores) come back in ascending id order,
// regardless of insertion order.
#[test]
fn tie_break_by_id_is_deterministic() {
let backend = build_backend(&[
(10, vec![1.0, 0.0]),
(5, vec![1.0, 0.0]),
(20, vec![1.0, 0.0]),
]);
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 3, DistanceMetric::L2, always_accept).unwrap();
let ids: Vec<u64> = out.iter().map(|r| r.id.get()).collect();
assert_eq!(ids, vec![5, 10, 20]);
}
// A record stored with a payload has that payload attached to its hit.
#[test]
fn payload_is_attached_to_results() {
let store = MemoryStore::new();
let mut payload = crate::Payload::new();
payload.insert("kind", "doc");
store
.upsert(Record::with_payload(
RecordId::new(1),
Vector::new(vec![1.0, 0.0]).unwrap(),
payload,
))
.unwrap();
let backend = Backend::Memory(store);
let q = Vector::new(vec![1.0, 0.0]).unwrap();
let out = flat_search(&backend, &q, 1, DistanceMetric::L2, always_accept).unwrap();
let hit = &out[0];
let attached = hit.payload.as_ref().expect("payload attached");
assert!(attached.contains_key("kind"));
}
}