use std::collections::HashMap;
use std::path::Path;
use crate::io;
use crate::{AddError, ConstructError, TurboQuantIndex};
/// Wrapper around [`TurboQuantIndex`] that addresses vectors by caller-chosen `u64` ids
/// instead of by their internal slot position.
///
/// Invariant kept by every method: `slot_to_id` and `id_to_slot` are exact inverses of
/// each other, and `slot_to_id.len() == inner.len()`.
pub struct IdMapIndex {
/// Underlying index; vectors live at dense slots `0..inner.len()`.
inner: TurboQuantIndex,
/// `slot_to_id[slot]` is the external id of the vector stored at `slot`.
slot_to_id: Vec<u64>,
/// Reverse mapping of `slot_to_id`: external id -> slot.
id_to_slot: HashMap<u64, usize>,
}
impl IdMapIndex {
pub fn new(dim: usize, bit_width: usize) -> Result<Self, ConstructError> {
Ok(Self {
inner: TurboQuantIndex::new(dim, bit_width)?,
slot_to_id: Vec::new(),
id_to_slot: HashMap::new(),
})
}
pub fn new_lazy(bit_width: usize) -> Result<Self, ConstructError> {
Ok(Self {
inner: TurboQuantIndex::new_lazy(bit_width)?,
slot_to_id: Vec::new(),
id_to_slot: HashMap::new(),
})
}
pub fn add_with_ids(&mut self, vectors: &[f32], ids: &[u64]) -> Result<(), AddError> {
let dim = self.inner.dim_opt().expect(
"IdMapIndex dim is not set; use add_with_ids_2d(vectors, dim, ids) \
on the first add or construct with IdMapIndex::new(dim, bit_width)",
);
self.add_with_ids_2d(vectors, dim, ids)
}
pub fn add_with_ids_2d(
&mut self,
vectors: &[f32],
dim: usize,
ids: &[u64],
) -> Result<(), AddError> {
if dim == 0 || vectors.len() % dim != 0 {
return Err(AddError::VectorBufferNotMultipleOfDim {
vectors_len: vectors.len(),
dim,
});
}
let n = vectors.len() / dim;
if ids.len() != n {
return Err(AddError::IdsCountMismatch {
expected: n,
got: ids.len(),
});
}
let mut seen_this_call: std::collections::HashSet<u64> =
std::collections::HashSet::with_capacity(n);
for &id in ids {
if self.id_to_slot.contains_key(&id) || !seen_this_call.insert(id) {
return Err(AddError::IdAlreadyPresent(id));
}
}
let base_slot = self.inner.len();
self.inner.add_2d(vectors, dim)?;
self.id_to_slot.reserve(n);
self.slot_to_id.reserve(n);
for (i, &id) in ids.iter().enumerate() {
self.id_to_slot.insert(id, base_slot + i);
}
self.slot_to_id.extend_from_slice(ids);
Ok(())
}
pub fn remove(&mut self, id: u64) -> bool {
let Some(slot) = self.id_to_slot.remove(&id) else {
return false;
};
let last = self.slot_to_id.len() - 1;
let moved_from = self.inner.swap_remove(slot);
debug_assert_eq!(moved_from, last);
if slot != last {
let moved_id = self.slot_to_id[last];
self.slot_to_id[slot] = moved_id;
self.id_to_slot.insert(moved_id, slot);
}
self.slot_to_id.pop();
true
}
pub fn search(&self, queries: &[f32], k: usize) -> (Vec<f32>, Vec<u64>) {
self.search_with_allowlist(queries, k, None)
}
pub fn search_with_allowlist(
&self,
queries: &[f32],
k: usize,
allowlist: Option<&[u64]>,
) -> (Vec<f32>, Vec<u64>) {
let mask_buf: Option<Vec<bool>> = allowlist.map(|ids| {
assert!(!ids.is_empty(), "allowlist is empty");
let mut mask = vec![false; self.inner.len()];
for &id in ids {
let slot = match self.id_to_slot.get(&id) {
Some(&s) => s,
None => panic!("id {id} in allowlist is not present in index"),
};
mask[slot] = true;
}
mask
});
let res = self
.inner
.search_with_mask(queries, k, mask_buf.as_deref());
let mut ids = Vec::with_capacity(res.indices.len());
for &slot in &res.indices {
let id = self.slot_to_id[slot as usize];
ids.push(id);
}
(res.scores, ids)
}
pub fn contains(&self, id: u64) -> bool {
self.id_to_slot.contains_key(&id)
}
pub fn len(&self) -> usize {
self.slot_to_id.len()
}
pub fn is_empty(&self) -> bool {
self.slot_to_id.is_empty()
}
pub fn dim(&self) -> usize {
self.inner.dim()
}
pub fn dim_opt(&self) -> Option<usize> {
self.inner.dim_opt()
}
pub fn bit_width(&self) -> usize {
self.inner.bit_width()
}
pub fn prepare(&self) {
self.inner.prepare();
}
pub fn write(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
io::write_id_map(
path,
self.inner.bit_width(),
self.inner.dim_opt().unwrap_or(0),
self.inner.len(),
self.inner.packed_codes(),
self.inner.scales(),
self.inner.tqplus_shift(),
self.inner.tqplus_scale(),
&self.slot_to_id,
)
}
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
let (bit_width, dim, n_vectors, packed_codes, scales, tqplus_shift, tqplus_scale, slot_to_id) =
io::load_id_map(path)?;
let dim_opt = if dim == 0 { None } else { Some(dim) };
let inner = TurboQuantIndex::from_parts(
dim_opt, bit_width, n_vectors, packed_codes, scales, tqplus_shift, tqplus_scale,
);
let id_to_slot: HashMap<u64, usize> = slot_to_id
.iter()
.enumerate()
.map(|(slot, &id)| (id, slot))
.collect();
if id_to_slot.len() != slot_to_id.len() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"duplicate ids in .tvim file",
));
}
Ok(Self {
inner,
slot_to_id,
id_to_slot,
})
}
}