use std::collections::HashMap;
use std::sync::Arc;
use crate::core::{Id, Point};
use crate::core::proximity::Proximity;
use crate::ports::{Near, NearError, NearResult, SearchResult};
/// Exhaustive (brute-force) nearest-neighbour index.
///
/// Every search scores the query against every stored point, so results are exact
/// but each query costs time linear in the number of stored points.
pub struct FlatIndex {
    /// Number of components every stored point and every query must have.
    dimensionality: usize,
    /// Scoring function applied between a query and each stored point.
    proximity: Arc<dyn Proximity>,
    /// `true` when larger scores mean closer matches (similarity),
    /// `false` when smaller scores do (distance).
    higher_is_better: bool,
    /// Stored points, keyed by their id.
    points: HashMap<Id, Point>,
}
impl FlatIndex {
    /// Creates an empty flat index.
    ///
    /// * `dimensionality` – number of components every stored point and query must have.
    /// * `proximity` – scoring function applied between a query and each stored point.
    /// * `higher_is_better` – `true` if larger scores rank first (similarity measures),
    ///   `false` if smaller scores rank first (distance measures).
    pub fn new(
        dimensionality: usize,
        proximity: Arc<dyn Proximity>,
        higher_is_better: bool,
    ) -> Self {
        Self {
            points: HashMap::new(),
            dimensionality,
            proximity,
            higher_is_better,
        }
    }

    /// Creates an empty index scored with `Cosine`, where higher scores rank first.
    pub fn cosine(dimensionality: usize) -> Self {
        use crate::core::proximity::Cosine;
        Self::new(dimensionality, Arc::new(Cosine), true)
    }

    /// Creates an empty index scored with `Euclidean`, where lower scores rank first.
    pub fn euclidean(dimensionality: usize) -> Self {
        use crate::core::proximity::Euclidean;
        Self::new(dimensionality, Arc::new(Euclidean), false)
    }

    /// Sorts `results` best-first according to `higher_is_better`.
    ///
    /// The sort is stable. A NaN score (which a `Proximity` implementation may
    /// produce for degenerate inputs) is ranked after every real score instead of
    /// panicking, as a bare `partial_cmp(..).unwrap()` would.
    fn sort_results(&self, results: &mut [SearchResult]) {
        results.sort_by(|a, b| self.rank(a.score, b.score));
    }

    /// Orders two scores so that the better one compares as `Less`; NaN is always worst.
    fn rank(&self, a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                // Neither side is NaN, so `partial_cmp` always returns `Some`.
                let ascending = a.partial_cmp(&b).unwrap_or(Ordering::Equal);
                if self.higher_is_better {
                    ascending.reverse()
                } else {
                    ascending
                }
            }
        }
    }
}
/// Returns `Ok(())` when `got` equals `expected`, otherwise a `DimensionalityMismatch` error.
fn check_dimensionality(expected: usize, got: usize) -> NearResult<()> {
    if got == expected {
        Ok(())
    } else {
        Err(NearError::DimensionalityMismatch { expected, got })
    }
}

impl Near for FlatIndex {
    /// Returns up to `k` stored points ranked best-first against `query`.
    ///
    /// Every stored point is scored, then the results are sorted and truncated to `k`.
    ///
    /// # Errors
    /// `NearError::DimensionalityMismatch` if `query` has the wrong number of components.
    fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
        check_dimensionality(self.dimensionality, query.dimensionality())?;
        if k == 0 {
            // Nothing can be returned; skip scoring and sorting the whole index.
            return Ok(Vec::new());
        }
        let mut results: Vec<SearchResult> = self
            .points
            .iter()
            .map(|(id, point)| SearchResult::new(*id, self.proximity.proximity(query, point)))
            .collect();
        self.sort_results(&mut results);
        results.truncate(k);
        Ok(results)
    }

    /// Returns every stored point whose score passes `threshold`, ranked best-first.
    ///
    /// A score passes when it is `>= threshold` (if `higher_is_better`) or
    /// `<= threshold` (otherwise). Both comparisons are false for NaN, so NaN scores
    /// never pass.
    ///
    /// # Errors
    /// `NearError::DimensionalityMismatch` if `query` has the wrong number of components.
    fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
        check_dimensionality(self.dimensionality, query.dimensionality())?;
        let mut results: Vec<SearchResult> = self
            .points
            .iter()
            .filter_map(|(id, point)| {
                let score = self.proximity.proximity(query, point);
                let passes = if self.higher_is_better {
                    score >= threshold
                } else {
                    score <= threshold
                };
                passes.then(|| SearchResult::new(*id, score))
            })
            .collect();
        self.sort_results(&mut results);
        Ok(results)
    }

    /// Stores a copy of `point` under `id`, replacing any point already stored there.
    ///
    /// # Errors
    /// `NearError::DimensionalityMismatch` if `point` has the wrong number of components.
    fn add(&mut self, id: Id, point: &Point) -> NearResult<()> {
        check_dimensionality(self.dimensionality, point.dimensionality())?;
        self.points.insert(id, point.clone());
        Ok(())
    }

    /// Removes the point stored under `id`; removing an absent id is a no-op and succeeds.
    fn remove(&mut self, id: Id) -> NearResult<()> {
        self.points.remove(&id);
        Ok(())
    }

    /// No-op: a flat index keeps no auxiliary structure that would need rebuilding.
    fn rebuild(&mut self) -> NearResult<()> {
        Ok(())
    }

    /// Always `true`: points are searchable as soon as they are added.
    fn is_ready(&self) -> bool {
        true
    }

    /// Number of stored points.
    fn len(&self) -> usize {
        self.points.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-D cosine index holding the three unit axes (ids 1..=3) and a
    /// normalized diagonal in the x/y plane (id 4).
    fn setup_index() -> FlatIndex {
        let mut index = FlatIndex::cosine(3);
        let points = vec![
            (Id::from_bytes([1; 16]), Point::new(vec![1.0, 0.0, 0.0])),
            (Id::from_bytes([2; 16]), Point::new(vec![0.0, 1.0, 0.0])),
            (Id::from_bytes([3; 16]), Point::new(vec![0.0, 0.0, 1.0])),
            (Id::from_bytes([4; 16]), Point::new(vec![0.7, 0.7, 0.0]).normalize()),
        ];
        for (id, point) in points {
            index.add(id, &point).unwrap();
        }
        index
    }

    /// Asserts that `result` is a `DimensionalityMismatch` with the given sizes.
    fn assert_mismatch<T>(result: NearResult<T>, want_expected: usize, want_got: usize) {
        if let Err(NearError::DimensionalityMismatch { expected, got }) = result {
            assert_eq!(expected, want_expected);
            assert_eq!(got, want_got);
        } else {
            panic!("Expected DimensionalityMismatch error");
        }
    }

    #[test]
    fn test_flat_index_near() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, Id::from_bytes([1; 16]));
        assert!((results[0].score - 1.0).abs() < 0.0001);
        // The x/y diagonal is the only other point with a positive cosine to the x axis.
        assert_eq!(results[1].id, Id::from_bytes([4; 16]));
    }

    #[test]
    fn test_flat_index_near_k_larger_than_len() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.near(&query, 10).unwrap();
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_flat_index_within_cosine() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.within(&query, 0.5).unwrap();
        assert_eq!(results.len(), 2);
        // Results must come back best-first.
        assert_eq!(results[0].id, Id::from_bytes([1; 16]));
        assert_eq!(results[1].id, Id::from_bytes([4; 16]));
    }

    #[test]
    fn test_flat_index_euclidean() {
        let mut index = FlatIndex::euclidean(2);
        index.add(Id::from_bytes([1; 16]), &Point::new(vec![0.0, 0.0])).unwrap();
        index.add(Id::from_bytes([2; 16]), &Point::new(vec![1.0, 0.0])).unwrap();
        index.add(Id::from_bytes([3; 16]), &Point::new(vec![5.0, 0.0])).unwrap();
        let query = Point::new(vec![0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();
        assert_eq!(results[0].id, Id::from_bytes([1; 16]));
        assert!((results[0].score - 0.0).abs() < 0.0001);
        assert_eq!(results[1].id, Id::from_bytes([2; 16]));
        assert!((results[1].score - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_flat_index_add_remove() {
        let mut index = FlatIndex::cosine(3);
        let id = Id::from_bytes([1; 16]);
        let point = Point::new(vec![1.0, 0.0, 0.0]);
        index.add(id, &point).unwrap();
        assert_eq!(index.len(), 1);
        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_add_same_id_replaces() {
        let mut index = FlatIndex::cosine(3);
        let id = Id::from_bytes([1; 16]);
        index.add(id, &Point::new(vec![1.0, 0.0, 0.0])).unwrap();
        index.add(id, &Point::new(vec![0.0, 1.0, 0.0])).unwrap();
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_flat_index_remove_missing_is_ok() {
        let mut index = FlatIndex::cosine(3);
        assert!(index.remove(Id::from_bytes([9; 16])).is_ok());
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_dimensionality_check() {
        let mut index = FlatIndex::cosine(3);
        let wrong_dims = Point::new(vec![1.0, 0.0]);
        let result = index.add(Id::now(), &wrong_dims);
        assert_mismatch(result, 3, 2);
    }

    #[test]
    fn test_flat_index_query_dimensionality_check() {
        let index = setup_index();
        let wrong_dims = Point::new(vec![1.0, 0.0]);
        assert_mismatch(index.near(&wrong_dims, 1), 3, 2);
        assert_mismatch(index.within(&wrong_dims, 0.5), 3, 2);
    }

    #[test]
    fn test_flat_index_ready() {
        let index = FlatIndex::cosine(3);
        assert!(index.is_ready());
    }
}