use crate::types::Key;
use crate::vector::{score, Metric};
use crate::Result;
#[derive(Debug)]
pub struct Candidate<'a> {
pub key: &'a Key,
pub vector: &'a [f32],
}
/// One search hit: the matching key together with the score it received.
#[derive(Clone, Debug, PartialEq)]
pub struct ScoredItem {
    /// Owned copy of the matching candidate's key.
    pub key: Key,
    /// Score computed for the candidate against the query.
    pub score: f32,
}
/// Exhaustively scores `candidates` against `query` and returns the best `top_k` matches.
///
/// Each candidate accepted by `filter` (or every candidate when `filter` is `None`) is
/// scored with [`score`] under `metric`. Rejected candidates are skipped *before* scoring,
/// so they can never cause a scoring error.
///
/// The result holds at most `top_k` items ordered by descending score; equal scores are
/// ordered by ascending key, so the output does not depend on input order. Scores are
/// compared with [`f32::total_cmp`], under which a positive NaN ranks above every other
/// score. NOTE(review): confirm whether `score` can ever return NaN.
///
/// When `top_k` is zero an empty vector is returned immediately, without consuming
/// `candidates` (so no scoring error can surface in that case).
///
/// Working memory is bounded by roughly `2 * top_k` entries rather than the number of
/// candidates, and keys are cloned only for the items actually returned.
///
/// # Errors
///
/// Returns the first error produced by [`score`] for a candidate that passed the filter,
/// e.g. a dimension mismatch between `query` and the candidate's vector.
pub fn search_flat<'a, F>(
    query: &[f32],
    metric: Metric,
    top_k: usize,
    candidates: impl IntoIterator<Item = Candidate<'a>>,
    mut filter: Option<F>,
) -> Result<Vec<ScoredItem>>
where
    F: FnMut(&Candidate<'a>) -> bool,
{
    /// Ranking order: higher score first, then ascending key. This is a total order, so
    /// selecting the best entries and then sorting them equals a full sort + truncate.
    fn rank<K: Ord>(a: &(f32, &K), b: &(f32, &K)) -> std::cmp::Ordering {
        b.0.total_cmp(&a.0).then_with(|| a.1.cmp(b.1))
    }

    /// Shrinks `buf` to its best `top_k` entries (left unordered) in linear average time.
    /// Callers must guarantee `top_k >= 1`.
    fn keep_best<K: Ord>(buf: &mut Vec<(f32, &K)>, top_k: usize) {
        if buf.len() > top_k {
            // `buf.len() > top_k >= 1`, so `top_k - 1` is a valid index.
            buf.select_nth_unstable_by(top_k - 1, |a, b| rank(a, b));
            buf.truncate(top_k);
        }
    }

    if top_k == 0 {
        return Ok(Vec::new());
    }

    // Prune whenever the buffer reaches twice the requested size, which amortizes each
    // selection over `top_k` pushes. Saturating so a huge `top_k` simply disables pruning.
    let prune_at = top_k.saturating_mul(2);
    let mut best = Vec::new();
    for cand in candidates {
        if let Some(pred) = filter.as_mut() {
            if !pred(&cand) {
                continue;
            }
        }
        let s = score(metric, query, cand.vector)?;
        // Borrow the key during the scan; only the survivors are cloned at the end.
        best.push((s, cand.key));
        if best.len() >= prune_at {
            keep_best(&mut best, top_k);
        }
    }

    keep_best(&mut best, top_k);
    best.sort_unstable_by(|a, b| rank(a, b));
    Ok(best
        .into_iter()
        .map(|(s, key)| ScoredItem {
            key: key.clone(),
            score: s,
        })
        .collect())
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// Builds an owned [`Key`] from a byte string.
    fn key(bytes: &[u8]) -> Key {
        bytes.to_vec()
    }

    #[test]
    fn respects_filter_before_scoring() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        // `b` has the wrong dimension: if it were scored before being filtered out,
        // the search would fail with `DimensionMismatch` instead of succeeding.
        let items = vec![
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
            Candidate {
                key: &kb,
                vector: &[0.0, 1.0, 0.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            10,
            items,
            Some(|c: &Candidate| c.key != b"b"),
        )
        .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].key, b"a");
    }

    #[test]
    fn orders_by_score_then_key() {
        let query = [1.0, 0.0];
        let kb = key(b"b");
        let ka = key(b"a");
        // Identical vectors give identical scores, so the key decides the order.
        let items = vec![
            Candidate {
                key: &kb,
                vector: &[1.0, 0.0],
            },
            Candidate {
                key: &ka,
                vector: &[1.0, 0.0],
            },
        ];
        let res = search_flat(
            &query,
            Metric::Cosine,
            10,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap();
        assert_eq!(res[0].key, b"a");
        assert_eq!(res[1].key, b"b");
    }

    #[test]
    fn switches_metric() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let kb = key(b"b");
        let items = vec![
            Candidate {
                key: &ka,
                vector: &[2.0, 0.0],
            },
            Candidate {
                key: &kb,
                vector: &[0.0, 2.0],
            },
        ];
        let res =
            search_flat(&query, Metric::L2, 1, items, None::<fn(&Candidate) -> bool>).unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].key, b"a");
    }

    #[test]
    fn keeps_best_top_k_from_many_candidates() {
        let query = [1.0, 0.0];
        let keys: Vec<Key> = (0u8..10).map(|i| vec![i]).collect();
        // Candidate `i` lies at L2 distance `i` from the query.
        let vectors: Vec<[f32; 2]> = (0u8..10).map(|i| [1.0 + f32::from(i), 0.0]).collect();
        // Feed the farthest candidates first so the best ones arrive last.
        let items = keys
            .iter()
            .zip(&vectors)
            .rev()
            .map(|(k, v)| Candidate {
                key: k,
                vector: &v[..],
            });
        let res =
            search_flat(&query, Metric::L2, 3, items, None::<fn(&Candidate) -> bool>).unwrap();
        let got: Vec<Key> = res.into_iter().map(|r| r.key).collect();
        assert_eq!(got, vec![key(&[0]), key(&[1]), key(&[2])]);
    }

    #[test]
    fn zero_top_k_returns_empty() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let items = vec![Candidate {
            key: &ka,
            vector: &[1.0, 0.0],
        }];
        let res =
            search_flat(&query, Metric::Cosine, 0, items, None::<fn(&Candidate) -> bool>).unwrap();
        assert!(res.is_empty());
    }

    #[test]
    fn enforces_dimension_match() {
        let query = [1.0, 0.0];
        let ka = key(b"a");
        let items = vec![Candidate {
            key: &ka,
            vector: &[1.0, 0.0, 1.0],
        }];
        use crate::Error;
        let err = search_flat(
            &query,
            Metric::Cosine,
            1,
            items,
            None::<fn(&Candidate) -> bool>,
        )
        .unwrap_err();
        assert!(matches!(err, Error::DimensionMismatch { .. }));
    }
}