use std::collections::{HashMap, HashSet};
use std::mem::size_of;
use std::sync::Arc;

use iqdb_filter::FilterEvaluator;
use iqdb_index::{Index, IndexCore, IndexStats};
use iqdb_types::{
    DistanceMetric, Filter, Hit, IqdbError, Metadata, Result, SearchParams, VectorId,
};

use crate::topk;
/// Configuration for [`FlatIndex`].
///
/// Brute-force search has nothing to tune, so this is an empty marker type; it
/// exists to fill the [`Index::Config`] slot and is ignored by `Index::new`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct FlatConfig;
/// Brute-force vector index: every search scores the query against every stored
/// vector (or every vector whose metadata passes the search filter).
///
/// Entries live in parallel vectors (`vectors`, `ids`, `metadata`, `seqs`) that
/// share one storage position per entry; `id_to_pos` maps each id back to that
/// position. `insert` pushes to all of them together and `delete` `swap_remove`s
/// from all of them together, so they always have equal length.
#[derive(Debug)]
pub struct FlatIndex {
    /// Length every stored vector and every query must have (never zero).
    dim: usize,
    /// Metric used for scoring; searches requesting a different metric are rejected.
    metric: DistanceMetric,
    /// Vector payloads, kept as the caller's `Arc` rather than copied.
    vectors: Vec<Arc<[f32]>>,
    /// Id of the entry at each storage position.
    ids: Vec<VectorId>,
    /// Optional metadata at each storage position; this is what search filters inspect.
    metadata: Vec<Option<Metadata>>,
    /// Insertion sequence number at each storage position. Passed to the top-k
    /// selector alongside the scores (presumably as a tie-breaker — confirm in `topk`).
    seqs: Vec<u64>,
    /// Sequence number the next insert will receive; only ever incremented.
    next_seq: u64,
    /// Reverse lookup from id to storage position; repaired in `delete` after a
    /// `swap_remove` moves the last entry into the freed slot.
    id_to_pos: HashMap<VectorId, usize>,
}
impl FlatIndex {
pub fn new_unconfigured(dim: usize, metric: DistanceMetric) -> Result<Self> {
if dim == 0 {
return Err(IqdbError::InvalidConfig {
reason: "FlatIndex dim must be greater than zero",
});
}
Ok(Self {
dim,
metric,
vectors: Vec::new(),
ids: Vec::new(),
metadata: Vec::new(),
seqs: Vec::new(),
next_seq: 0,
id_to_pos: HashMap::new(),
})
}
#[must_use]
pub fn dim(&self) -> usize {
self.dim
}
#[must_use]
pub fn metric(&self) -> DistanceMetric {
self.metric
}
#[must_use]
pub fn len(&self) -> usize {
self.ids.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.ids.is_empty()
}
fn check_dim(&self, vector_len: usize) -> Result<()> {
if vector_len != self.dim {
return Err(IqdbError::DimensionMismatch {
expected: self.dim,
found: vector_len,
});
}
Ok(())
}
fn approximate_memory_bytes(&self) -> usize {
let arc_header_bytes = 2 * size_of::<usize>();
let vectors_bytes = self
.vectors
.iter()
.map(|arc| arc.len() * size_of::<f32>() + arc_header_bytes)
.sum::<usize>()
+ self.vectors.capacity() * size_of::<Arc<[f32]>>();
let ids_bytes = self.ids.capacity() * size_of::<VectorId>();
let metadata_bytes = self.metadata.capacity() * size_of::<Option<Metadata>>();
let seqs_bytes = self.seqs.capacity() * size_of::<u64>();
let id_to_pos_bytes =
self.id_to_pos.capacity() * (size_of::<VectorId>() + size_of::<usize>());
vectors_bytes + ids_bytes + metadata_bytes + seqs_bytes + id_to_pos_bytes
}
fn search_unfiltered(&self, query: &[f32], k: usize) -> Result<Vec<Hit>> {
let candidates: Vec<&[f32]> = self.vectors.iter().map(|arc| &arc[..]).collect();
let mut distances = vec![0.0_f32; candidates.len()];
compute_distances(self.metric, query, &candidates, &mut distances)?;
if matches!(self.metric, DistanceMetric::DotProduct) {
for value in distances.iter_mut() {
*value = -*value;
}
}
let chosen = topk::select_topk_indices(&distances, &self.seqs, k);
let mut hits = Vec::with_capacity(chosen.len());
for storage_idx in chosen {
hits.push(Hit {
id: self.ids[storage_idx].clone(),
distance: distances[storage_idx],
metadata: self.metadata[storage_idx].clone(),
});
}
Ok(hits)
}
fn search_filtered(&self, query: &[f32], k: usize, filter: &Filter) -> Result<Vec<Hit>> {
let evaluator = FilterEvaluator::new(filter.clone())?;
let accepted: Vec<usize> = (0..self.ids.len())
.filter(|&i| evaluator.evaluate(self.metadata[i].as_ref()))
.collect();
if accepted.is_empty() {
return Ok(Vec::new());
}
let candidates: Vec<&[f32]> = accepted.iter().map(|&i| &self.vectors[i][..]).collect();
let accepted_seqs: Vec<u64> = accepted.iter().map(|&i| self.seqs[i]).collect();
let mut distances = vec![0.0_f32; candidates.len()];
compute_distances(self.metric, query, &candidates, &mut distances)?;
if matches!(self.metric, DistanceMetric::DotProduct) {
for value in distances.iter_mut() {
*value = -*value;
}
}
let chosen = topk::select_topk_indices(&distances, &accepted_seqs, k);
let mut hits = Vec::with_capacity(chosen.len());
for candidate_idx in chosen {
let storage_idx = accepted[candidate_idx];
hits.push(Hit {
id: self.ids[storage_idx].clone(),
distance: distances[candidate_idx],
metadata: self.metadata[storage_idx].clone(),
});
}
Ok(hits)
}
}
impl IndexCore for FlatIndex {
    /// Stores `vector` under `id`, keeping the caller's `Arc` as-is (no payload copy).
    ///
    /// Every check runs before any storage is touched, so a failed insert leaves
    /// the index unchanged.
    ///
    /// # Errors
    ///
    /// * [`IqdbError::DimensionMismatch`] if `vector.len()` differs from the index dimension.
    /// * [`IqdbError::Duplicate`] if `id` is already stored.
    /// * [`IqdbError::InvalidConfig`] if the insertion sequence counter would overflow `u64`.
    fn insert(
        &mut self,
        id: VectorId,
        vector: Arc<[f32]>,
        metadata: Option<Metadata>,
    ) -> Result<()> {
        self.check_dim(vector.len())?;
        if self.id_to_pos.contains_key(&id) {
            return Err(IqdbError::Duplicate);
        }
        let pos = self.ids.len();
        let seq = self.next_seq;
        // `?` returns before the assignment, so an overflow leaves `next_seq` untouched.
        self.next_seq = self
            .next_seq
            .checked_add(1)
            .ok_or(IqdbError::InvalidConfig {
                reason: "FlatIndex insertion sequence counter overflowed u64",
            })?;
        // The parallel storage vectors stay in lockstep: position `pos` in each is this entry.
        self.vectors.push(vector);
        self.ids.push(id.clone());
        self.metadata.push(metadata);
        self.seqs.push(seq);
        let _prev = self.id_to_pos.insert(id, pos);
        Ok(())
    }

    /// Inserts every item in `items` as one all-or-nothing operation.
    ///
    /// The whole batch is validated before anything is stored, so a bad item
    /// (wrong dimension, an id already in the index, or an id repeated within
    /// the batch) rejects the batch without leaving earlier items half-inserted.
    ///
    /// # Errors
    ///
    /// * [`IqdbError::DimensionMismatch`] for the first item with the wrong length.
    /// * [`IqdbError::Duplicate`] for the first id that is already stored or repeated in the batch.
    /// * [`IqdbError::InvalidConfig`] if the batch would overflow the sequence counter.
    fn insert_batch(&mut self, items: Vec<(VectorId, Arc<[f32]>, Option<Metadata>)>) -> Result<()> {
        let additional = items.len();
        {
            // Scoped so the borrows of `items` end before the batch is consumed below.
            let mut seen_in_batch: HashSet<&VectorId> = HashSet::with_capacity(additional);
            for (id, vector, _) in &items {
                self.check_dim(vector.len())?;
                if self.id_to_pos.contains_key(id) || !seen_in_batch.insert(id) {
                    return Err(IqdbError::Duplicate);
                }
            }
        }
        // Equivalent to every per-item `checked_add(1)` in `insert` succeeding.
        let has_seq_room = u64::try_from(additional)
            .ok()
            .and_then(|n| self.next_seq.checked_add(n))
            .is_some();
        if !has_seq_room {
            return Err(IqdbError::InvalidConfig {
                reason: "FlatIndex insertion sequence counter overflowed u64",
            });
        }
        self.vectors.reserve(additional);
        self.ids.reserve(additional);
        self.metadata.reserve(additional);
        self.seqs.reserve(additional);
        self.id_to_pos.reserve(additional);
        for (id, vector, metadata) in items {
            // Cannot fail after the validation above; `?` kept as a safety net.
            self.insert(id, vector, metadata)?;
        }
        Ok(())
    }

    /// Removes the entry stored under `id` without shifting other entries.
    ///
    /// Uses `swap_remove`, so the last entry moves into the freed slot and its
    /// `id_to_pos` mapping is repointed to the new position.
    ///
    /// # Errors
    ///
    /// Returns [`IqdbError::NotFound`] if `id` is not stored.
    fn delete(&mut self, id: &VectorId) -> Result<()> {
        let pos = self.id_to_pos.remove(id).ok_or(IqdbError::NotFound)?;
        let _v = self.vectors.swap_remove(pos);
        let _i = self.ids.swap_remove(pos);
        let _m = self.metadata.swap_remove(pos);
        let _s = self.seqs.swap_remove(pos);
        // If the removed entry was not last, another entry now occupies `pos`.
        if pos < self.ids.len() {
            let _prev = self.id_to_pos.insert(self.ids[pos].clone(), pos);
        }
        Ok(())
    }

    /// Returns the top `params.k` hits for `query` (as chosen by `topk::select_topk_indices`),
    /// optionally restricted to entries whose metadata passes `params.filter`.
    ///
    /// Returns an empty list when `params.k == 0` or the index is empty.
    ///
    /// # Errors
    ///
    /// * [`IqdbError::DimensionMismatch`] if `query.len()` differs from the index dimension.
    /// * [`IqdbError::InvalidMetric`] if `params.metric` differs from the index metric.
    /// * Any error from building the filter evaluator or computing distances.
    fn search(&self, query: &[f32], params: &SearchParams) -> Result<Vec<Hit>> {
        self.check_dim(query.len())?;
        if params.metric != self.metric {
            return Err(IqdbError::InvalidMetric);
        }
        if params.k == 0 || self.ids.is_empty() {
            return Ok(Vec::new());
        }
        match &params.filter {
            None => self.search_unfiltered(query, params.k),
            Some(filter) => self.search_filtered(query, params.k, filter),
        }
    }

    fn len(&self) -> usize {
        FlatIndex::len(self)
    }

    fn is_empty(&self) -> bool {
        FlatIndex::is_empty(self)
    }

    fn dim(&self) -> usize {
        FlatIndex::dim(self)
    }

    fn metric(&self) -> DistanceMetric {
        FlatIndex::metric(self)
    }

    /// No-op: this index keeps no pending writes.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }

    /// Reports the vector count and an approximate memory footprint; no disk usage is reported.
    fn stats(&self) -> IndexStats {
        IndexStats {
            n_vectors: self.ids.len(),
            memory_bytes: self.approximate_memory_bytes(),
            disk_bytes: None,
            index_type: "flat",
            extra: None,
        }
    }
}
impl Index for FlatIndex {
type Config = FlatConfig;
fn new(dim: usize, metric: DistanceMetric, _config: Self::Config) -> Result<Self> {
Self::new_unconfigured(dim, metric)
}
}
fn compute_distances(
metric: DistanceMetric,
query: &[f32],
candidates: &[&[f32]],
out: &mut [f32],
) -> Result<()> {
#[cfg(feature = "parallel")]
{
crate::parallel::compute_distances(metric, query, candidates, out)
}
#[cfg(not(feature = "parallel"))]
{
iqdb_distance::compute_batch(metric, query, candidates, out)
}
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// The index must keep the caller's `Arc` itself, not a copy of its payload.
    #[test]
    fn insert_stores_caller_arc_without_reallocating_payload() {
        let shared: Arc<[f32]> = Arc::from(vec![1.0_f32, 2.0, 3.0]);
        let expected_ptr = Arc::as_ptr(&shared);
        let mut index = FlatIndex::new_unconfigured(3, DistanceMetric::Euclidean).unwrap();

        index
            .insert(VectorId::from(1u64), Arc::clone(&shared), None)
            .unwrap();

        assert_eq!(
            Arc::as_ptr(&index.vectors[0]),
            expected_ptr,
            "FlatIndex MUST store the caller's Arc verbatim — no fresh \
             allocation, no copy",
        );
        assert_eq!(Arc::strong_count(&shared), 2);
    }

    /// Deleting an entry must release the index's reference to the payload.
    #[test]
    fn delete_drops_the_stored_strong_ref() {
        let key = VectorId::from(9u64);
        let shared: Arc<[f32]> = Arc::from(vec![0.5_f32, -0.5]);
        let mut index = FlatIndex::new_unconfigured(2, DistanceMetric::Cosine).unwrap();

        index.insert(key.clone(), Arc::clone(&shared), None).unwrap();
        assert_eq!(Arc::strong_count(&shared), 2);

        index.delete(&key).unwrap();
        assert_eq!(
            Arc::strong_count(&shared),
            1,
            "delete drops the index's strong ref; only the caller's remains",
        );
    }
}