use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;
use iqdb_types::{Metadata, VectorId};
/// One stored entry: an identifier, its vector, and optional metadata.
///
/// The vector is held as `Arc<[f32]>`, so cloning a `Row` shares the vector
/// data (reference-count bump) rather than copying the floats.
#[derive(Debug, Clone)]
pub(crate) struct Row {
/// Identifier of this row; `RowStore` keeps it equal to the key that maps to this row.
pub(crate) id: VectorId,
/// The vector payload, shared via `Arc`.
pub(crate) vector: Arc<[f32]>,
/// Optional metadata supplied alongside the vector; `None` when none was given.
pub(crate) meta: Option<Metadata>,
}
/// Contiguous storage of [`Row`]s with a hash index from id to position.
///
/// Invariant: `index` has exactly one entry per element of `rows`, and
/// `rows[index[&id]].id == id` for every stored id.
///
/// Removal uses `Vec::swap_remove`, so the last row moves into the vacated
/// slot; row order therefore matches insertion order only until the first removal.
#[derive(Debug, Default)]
pub(crate) struct RowStore {
/// Row data, densely packed.
rows: Vec<Row>,
/// Maps each stored id to its position in `rows`.
index: HashMap<VectorId, usize>,
}
impl RowStore {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn with_capacity(cap: usize) -> Self {
Self {
rows: Vec::with_capacity(cap),
index: HashMap::with_capacity(cap),
}
}
pub(crate) fn len(&self) -> usize {
self.rows.len()
}
pub(crate) fn is_empty(&self) -> bool {
self.rows.is_empty()
}
pub(crate) fn contains(&self, id: &VectorId) -> bool {
self.index.contains_key(id)
}
pub(crate) fn get(&self, id: &VectorId) -> Option<&Row> {
self.index.get(id).map(|&pos| &self.rows[pos])
}
pub(crate) fn upsert(
&mut self,
id: VectorId,
vector: Arc<[f32]>,
meta: Option<Metadata>,
) -> bool {
if let Some(&pos) = self.index.get(&id) {
self.rows[pos] = Row { id, vector, meta };
false
} else {
let pos = self.rows.len();
let _ = self.index.insert(id.clone(), pos);
self.rows.push(Row { id, vector, meta });
true
}
}
pub(crate) fn remove(&mut self, id: &VectorId) -> bool {
let Some(pos) = self.index.remove(id) else {
return false;
};
let _removed = self.rows.swap_remove(pos);
if pos < self.rows.len() {
let moved_id = self.rows[pos].id.clone();
let _ = self.index.insert(moved_id, pos);
}
true
}
pub(crate) fn iter(&self) -> impl ExactSizeIterator<Item = &Row> {
self.rows.iter()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn v(xs: &[f32]) -> Arc<[f32]> {
        Arc::from(xs)
    }

    #[test]
    fn upsert_insert_then_replace_keeps_len_and_position() {
        let mut s = RowStore::new();
        assert!(s.upsert(VectorId::from(1u64), v(&[1.0, 0.0]), None));
        assert!(s.upsert(VectorId::from(2u64), v(&[0.0, 1.0]), None));
        assert_eq!(s.len(), 2);
        assert!(!s.upsert(VectorId::from(1u64), v(&[2.0, 2.0]), None));
        assert_eq!(s.len(), 2);
        let order: Vec<_> = s.iter().map(|r| r.id.clone()).collect();
        assert_eq!(order, vec![VectorId::from(1u64), VectorId::from(2u64)]);
        assert_eq!(
            s.get(&VectorId::from(1u64)).unwrap().vector.as_ref(),
            &[2.0, 2.0]
        );
    }

    #[test]
    fn remove_returns_false_when_absent_and_repairs_index() {
        let mut s = RowStore::new();
        assert!(s.upsert(VectorId::from(1u64), v(&[1.0]), None));
        assert!(s.upsert(VectorId::from(2u64), v(&[2.0]), None));
        assert!(s.upsert(VectorId::from(3u64), v(&[3.0]), None));
        assert!(s.remove(&VectorId::from(2u64)));
        assert!(!s.remove(&VectorId::from(2u64)));
        assert_eq!(s.len(), 2);
        assert_eq!(
            s.get(&VectorId::from(3u64)).unwrap().vector.as_ref(),
            &[3.0]
        );
        assert!(!s.contains(&VectorId::from(2u64)));
    }

    #[test]
    fn remove_last_row_needs_no_index_repair() {
        let mut s = RowStore::with_capacity(2);
        assert!(s.upsert(VectorId::from(1u64), v(&[1.0]), None));
        assert!(s.upsert(VectorId::from(2u64), v(&[2.0]), None));
        // Removing the tail row exercises the branch where nothing is swapped in.
        assert!(s.remove(&VectorId::from(2u64)));
        assert_eq!(s.len(), 1);
        assert_eq!(
            s.get(&VectorId::from(1u64)).unwrap().vector.as_ref(),
            &[1.0]
        );
        assert!(s.remove(&VectorId::from(1u64)));
        assert!(s.is_empty());
        assert!(s.get(&VectorId::from(1u64)).is_none());
    }

    #[test]
    fn moved_row_stays_addressable_after_swap_remove() {
        let mut s = RowStore::new();
        assert!(s.upsert(VectorId::from(1u64), v(&[1.0]), None));
        assert!(s.upsert(VectorId::from(2u64), v(&[2.0]), None));
        assert!(s.upsert(VectorId::from(3u64), v(&[3.0]), None));
        // Removing the first row swaps row 3 into position 0.
        assert!(s.remove(&VectorId::from(1u64)));
        let order: Vec<_> = s.iter().map(|r| r.id.clone()).collect();
        assert_eq!(order, vec![VectorId::from(3u64), VectorId::from(2u64)]);
        // The repaired index entry must still support replace and remove.
        assert!(!s.upsert(VectorId::from(3u64), v(&[30.0]), None));
        assert_eq!(
            s.get(&VectorId::from(3u64)).unwrap().vector.as_ref(),
            &[30.0]
        );
        assert!(s.remove(&VectorId::from(3u64)));
        assert_eq!(s.len(), 1);
        assert_eq!(
            s.get(&VectorId::from(2u64)).unwrap().vector.as_ref(),
            &[2.0]
        );
    }

    #[test]
    fn get_absent_is_none() {
        let s = RowStore::new();
        assert!(s.get(&VectorId::from(9u64)).is_none());
        assert!(s.is_empty());
    }
}