use std::hash::{Hash, Hasher};
use iqdb_types::{DistanceMetric, Filter, SearchParams};
/// Key identifying one search request: an owned copy of the query vector
/// together with the search parameters taken from a `SearchParams`.
///
/// NOTE(review): the name suggests this keys a cache of search results —
/// confirm against callers.
///
/// Equality is hand-written rather than derived: query components are
/// compared by their raw bit patterns (`f32::to_bits`), so `0.0` and `-0.0`
/// are different keys, while two NaNs with identical bits are equal.
#[derive(Clone, Debug)]
pub(crate) struct ResultKey {
    /// Owned copy of the query vector; compared bit-for-bit.
    query: Box<[f32]>,
    /// Copied from `SearchParams::k` (presumably the number of results
    /// requested — confirm).
    k: usize,
    /// Copied from `SearchParams::ef` (presumably an optional search-breadth
    /// override — confirm).
    ef: Option<usize>,
    /// Distance metric the search was run with.
    metric: DistanceMetric,
    /// Optional filter; takes part in equality but is not fed to the hasher.
    filter: Option<Filter>,
}
impl ResultKey {
    /// Builds a key from a query vector and the parameters of the search.
    ///
    /// The query slice is copied into an owned boxed slice, and the filter
    /// is cloned. The key therefore does not borrow from either argument.
    pub(crate) fn new(query: &[f32], params: &SearchParams) -> Self {
        let owned_query: Box<[f32]> = query.into();
        let filter = params.filter.clone();
        Self {
            query: owned_query,
            k: params.k,
            ef: params.ef,
            metric: params.metric,
            filter,
        }
    }
}
impl PartialEq for ResultKey {
    /// Keys are equal when `k`, `ef`, `metric` and `filter` all match and the
    /// query vectors have the same length and identical bit patterns.
    ///
    /// Comparing via `to_bits` (rather than `f32 ==`) makes the relation
    /// reflexive even for NaN, which `Eq` requires, and keeps `0.0` and
    /// `-0.0` apart.
    fn eq(&self, other: &Self) -> bool {
        let params_match =
            self.k == other.k && self.ef == other.ef && self.metric == other.metric;
        if !params_match {
            return false;
        }

        // `Iterator::eq` checks both element equality and equal length.
        let lhs_bits = self.query.iter().map(|component| component.to_bits());
        let rhs_bits = other.query.iter().map(|component| component.to_bits());
        lhs_bits.eq(rhs_bits) && self.filter == other.filter
    }
}

/// Sound because `eq` compares floats by bit pattern, which is a true
/// equivalence relation.
impl Eq for ResultKey {}
impl Hash for ResultKey {
    /// Hashes `k`, `ef`, `metric` and the query vector.
    ///
    /// Query components are hashed through `f32::to_bits`, mirroring the
    /// bitwise comparison in `eq`. The query length is written before the
    /// components so the data passed to the hasher is prefix-free, as the
    /// `Hash` trait documentation requires.
    ///
    /// `filter` is deliberately not hashed (NOTE(review): presumably because
    /// `Filter` does not implement `Hash` — confirm). This is still consistent
    /// with `Eq`. Equal keys hash equally, and keys differing only in `filter`
    /// merely share a bucket and are separated by `eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.k.hash(state);
        self.ef.hash(state);
        self.metric.hash(state);
        // Length prefix, matching what std's slice `Hash` impl does.
        state.write_usize(self.query.len());
        for component in self.query.iter() {
            state.write_u32(component.to_bits());
        }
    }
}
#[cfg(test)]
mod tests {
    // Test code may unwrap freely.
    #![allow(clippy::unwrap_used)]

    use std::collections::hash_map::DefaultHasher;

    use iqdb_types::Value;

    use super::*;

    /// Hashes a key with a fresh `DefaultHasher` so two keys can be compared.
    fn hash_of(key: &ResultKey) -> u64 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn identical_searches_are_equal_and_hash_equal() {
        let params = SearchParams::new(5, DistanceMetric::Cosine);
        let a = ResultKey::new(&[1.0, 2.0, 3.0], &params);
        let b = ResultKey::new(&[1.0, 2.0, 3.0], &params);
        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));
    }

    #[test]
    fn different_query_differs() {
        let params = SearchParams::new(5, DistanceMetric::Cosine);
        let a = ResultKey::new(&[1.0, 2.0, 3.0], &params);
        let b = ResultKey::new(&[1.0, 2.0, 3.5], &params);
        assert_ne!(a, b);
    }

    #[test]
    fn different_k_differs() {
        let a = ResultKey::new(&[1.0], &SearchParams::new(5, DistanceMetric::Cosine));
        let b = ResultKey::new(&[1.0], &SearchParams::new(6, DistanceMetric::Cosine));
        assert_ne!(a, b);
    }

    #[test]
    fn different_metric_differs() {
        let a = ResultKey::new(&[1.0], &SearchParams::new(5, DistanceMetric::Cosine));
        let b = ResultKey::new(&[1.0], &SearchParams::new(5, DistanceMetric::Euclidean));
        assert_ne!(a, b);
    }

    #[test]
    fn different_filter_differs() {
        let with_filter = SearchParams {
            filter: Some(Filter::eq("k", Value::Bool(true))),
            ..SearchParams::new(5, DistanceMetric::Cosine)
        };
        let a = ResultKey::new(&[1.0], &SearchParams::new(5, DistanceMetric::Cosine));
        let b = ResultKey::new(&[1.0], &with_filter);
        assert_ne!(a, b);
    }

    #[test]
    fn different_length_differs() {
        let params = SearchParams::new(5, DistanceMetric::Cosine);
        let a = ResultKey::new(&[1.0, 2.0], &params);
        let b = ResultKey::new(&[1.0, 2.0, 3.0], &params);
        assert_ne!(a, b);
    }

    #[test]
    fn negative_zero_distinct_from_zero() {
        let params = SearchParams::new(1, DistanceMetric::Cosine);
        let a = ResultKey::new(&[0.0], &params);
        let b = ResultKey::new(&[-0.0], &params);
        assert_ne!(a, b);
    }

    /// Bitwise comparison makes identical NaNs equal (unlike `f32 ==`), which
    /// `Eq` requires for reflexivity.
    #[test]
    fn identical_nan_bit_patterns_are_equal_and_hash_equal() {
        let params = SearchParams::new(1, DistanceMetric::Cosine);
        let a = ResultKey::new(&[f32::NAN], &params);
        let b = ResultKey::new(&[f32::NAN], &params);
        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));
    }
}