use ahash::AHashMap;
use ordered_float::OrderedFloat;
use rayon::prelude::*;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use crate::error::{Error, Result};
use super::traits::{DistanceType, IndexConfig, SearchResult, VectorIndex};
#[derive(Debug)]
pub struct FlatIndex {
config: IndexConfig,
vectors: AHashMap<String, Vec<f32>>,
parallel: bool,
parallel_threshold: usize,
}
impl FlatIndex {
#[must_use]
pub fn new(config: IndexConfig) -> Self {
Self {
config,
vectors: AHashMap::new(),
parallel: true,
parallel_threshold: 1000,
}
}
#[must_use]
pub fn with_capacity(config: IndexConfig, capacity: usize) -> Self {
Self {
config,
vectors: AHashMap::with_capacity(capacity),
parallel: true,
parallel_threshold: 1000,
}
}
#[must_use]
pub const fn with_parallel(mut self, parallel: bool) -> Self {
self.parallel = parallel;
self
}
#[must_use]
pub const fn with_parallel_threshold(mut self, threshold: usize) -> Self {
self.parallel_threshold = threshold;
self
}
fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
match self.config.distance_type {
DistanceType::L2 => Self::l2_distance(a, b),
DistanceType::InnerProduct => Self::inner_product(a, b),
DistanceType::Cosine => Self::cosine_similarity(a, b),
}
}
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
#[inline]
fn inner_product(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
0.0
} else {
dot / (norm_a * norm_b)
}
}
fn normalize(vector: &mut [f32]) {
let norm: f32 = vector.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
if norm > 0.0 {
for x in vector.iter_mut() {
*x /= norm;
}
}
}
fn search_sequential(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
let mut heap: BinaryHeap<(Reverse<OrderedFloat<f32>>, String)> = BinaryHeap::new();
let is_similarity = matches!(
self.config.distance_type,
DistanceType::InnerProduct | DistanceType::Cosine
);
for (id, vector) in &self.vectors {
let dist = self.compute_distance(query, vector);
let key = if is_similarity {
Reverse(OrderedFloat(-dist))
} else {
Reverse(OrderedFloat(dist))
};
if heap.len() < k {
heap.push((key, id.clone()));
} else if let Some((top_key, _)) = heap.peek() {
if key > *top_key {
heap.pop();
heap.push((key, id.clone()));
}
}
}
let mut results: Vec<_> = heap
.into_iter()
.map(|(Reverse(OrderedFloat(dist)), id)| {
let actual_dist = if is_similarity { -dist } else { dist };
SearchResult::new(id, actual_dist, self.config.distance_type)
})
.collect();
if is_similarity {
results.sort_by(|a, b| b.distance.partial_cmp(&a.distance).unwrap());
} else {
results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
}
results
}
fn search_parallel(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
let is_similarity = matches!(
self.config.distance_type,
DistanceType::InnerProduct | DistanceType::Cosine
);
let mut distances: Vec<_> = self.vectors
.par_iter()
.map(|(id, vector)| {
let dist = self.compute_distance(query, vector);
(id.clone(), dist)
})
.collect();
if is_similarity {
distances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
} else {
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
}
distances
.into_iter()
.take(k)
.map(|(id, dist)| SearchResult::new(id, dist, self.config.distance_type))
.collect()
}
}
impl VectorIndex for FlatIndex {
fn add(&mut self, id: String, vector: &[f32]) -> Result<()> {
if vector.len() != self.config.dimension {
return Err(Error::InvalidQuery {
reason: format!(
"Dimension mismatch: expected {}, got {}",
self.config.dimension,
vector.len()
),
});
}
let mut vec = vector.to_vec();
if self.config.normalize {
Self::normalize(&mut vec);
}
self.vectors.insert(id, vec);
Ok(())
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
if query.len() != self.config.dimension {
return Err(Error::InvalidQuery {
reason: format!(
"Query dimension mismatch: expected {}, got {}",
self.config.dimension,
query.len()
),
});
}
if self.vectors.is_empty() {
return Ok(vec![]);
}
let k = k.min(self.vectors.len());
let query = if self.config.normalize {
let mut q = query.to_vec();
Self::normalize(&mut q);
q
} else {
query.to_vec()
};
let results = if self.parallel && self.vectors.len() > self.parallel_threshold {
self.search_parallel(&query, k)
} else {
self.search_sequential(&query, k)
};
Ok(results)
}
fn remove(&mut self, id: &str) -> Result<bool> {
Ok(self.vectors.remove(id).is_some())
}
fn contains(&self, id: &str) -> bool {
self.vectors.contains_key(id)
}
fn len(&self) -> usize {
self.vectors.len()
}
fn dimension(&self) -> usize {
self.config.dimension
}
fn distance_type(&self) -> DistanceType {
self.config.distance_type
}
fn clear(&mut self) {
self.vectors.clear();
}
fn memory_usage(&self) -> usize {
let per_entry = self.config.dimension * 4 + 32 + 48;
self.vectors.len() * per_entry
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_index() -> FlatIndex {
let config = IndexConfig::new(4);
let mut index = FlatIndex::new(config).with_parallel(false);
index.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
index.add("b".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
index.add("c".to_string(), &[0.0, 0.0, 1.0, 0.0]).unwrap();
index.add("d".to_string(), &[0.5, 0.5, 0.0, 0.0]).unwrap();
index
}
#[test]
fn test_add_and_len() {
let index = create_test_index();
assert_eq!(index.len(), 4);
assert!(index.contains("a"));
assert!(!index.contains("z"));
}
#[test]
fn test_search_l2() {
let index = create_test_index();
let results = index.search(&[0.9, 0.1, 0.0, 0.0], 2).unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].id, "a"); }
#[test]
fn test_search_cosine() {
let config = IndexConfig::new(4).with_distance(DistanceType::Cosine);
let mut index = FlatIndex::new(config).with_parallel(false);
index.add("a".to_string(), &[1.0, 0.0, 0.0, 0.0]).unwrap();
index.add("b".to_string(), &[1.0, 1.0, 0.0, 0.0]).unwrap();
index.add("c".to_string(), &[0.0, 1.0, 0.0, 0.0]).unwrap();
let results = index.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();
assert_eq!(results[0].id, "a"); assert!((results[0].distance - 1.0).abs() < 1e-6); }
#[test]
fn test_remove() {
let mut index = create_test_index();
assert!(index.remove("a").unwrap());
assert!(!index.contains("a"));
assert_eq!(index.len(), 3);
assert!(!index.remove("z").unwrap()); }
#[test]
fn test_dimension_mismatch() {
let mut index = create_test_index();
let result = index.add("e".to_string(), &[1.0, 2.0]); assert!(result.is_err());
}
#[test]
fn test_empty_search() {
let config = IndexConfig::new(4);
let index = FlatIndex::new(config);
let results = index.search(&[1.0, 0.0, 0.0, 0.0], 10).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_clear() {
let mut index = create_test_index();
assert_eq!(index.len(), 4);
index.clear();
assert!(index.is_empty());
}
}