use std::collections::{HashMap, HashSet};
use std::path::Path;
use crate::io;
use crate::TurboQuantIndex;
/// A [`TurboQuantIndex`] wrapper that lets callers address vectors by their own
/// `u64` ids instead of by internal slot position.
///
/// Invariant maintained by the methods in this file: `slot_to_id[slot] == id`
/// exactly when `id_to_slot[&id] == slot`, and `slot_to_id` has one entry per
/// vector stored in `inner`.
pub struct IdMapIndex {
// Underlying quantized index; vectors are addressed by dense slot position.
inner: TurboQuantIndex,
// External id of the vector stored at each slot of `inner`.
slot_to_id: Vec<u64>,
// Reverse lookup: external id -> slot in `inner`.
id_to_slot: HashMap<u64, usize>,
}
impl IdMapIndex {
pub fn new(dim: usize, bit_width: usize) -> Self {
Self {
inner: TurboQuantIndex::new(dim, bit_width),
slot_to_id: Vec::new(),
id_to_slot: HashMap::new(),
}
}
pub fn new_lazy(bit_width: usize) -> Self {
Self {
inner: TurboQuantIndex::new_lazy(bit_width),
slot_to_id: Vec::new(),
id_to_slot: HashMap::new(),
}
}
pub fn add_with_ids(&mut self, vectors: &[f32], ids: &[u64]) {
let dim = self.inner.dim_opt().expect(
"IdMapIndex dim is not set; use add_with_ids_2d(vectors, dim, ids) \
on the first add or construct with IdMapIndex::new(dim, bit_width)",
);
self.add_with_ids_2d(vectors, dim, ids);
}
pub fn add_with_ids_2d(&mut self, vectors: &[f32], dim: usize, ids: &[u64]) {
let n = vectors.len() / dim;
assert_eq!(
vectors.len(),
n * dim,
"vector buffer length {} not a multiple of dim {}",
vectors.len(),
dim,
);
assert_eq!(
ids.len(),
n,
"expected {n} ids, got {}",
ids.len(),
);
self.id_to_slot.reserve(n);
self.slot_to_id.reserve(n);
let base_slot = self.inner.len();
for (i, &id) in ids.iter().enumerate() {
let slot = base_slot + i;
if self.id_to_slot.insert(id, slot).is_some() {
panic!("id {id} already present in index");
}
}
self.slot_to_id.extend_from_slice(ids);
self.inner.add_2d(vectors, dim);
}
pub fn remove(&mut self, id: u64) -> bool {
let Some(slot) = self.id_to_slot.remove(&id) else {
return false;
};
let last = self.slot_to_id.len() - 1;
let moved_from = self.inner.swap_remove(slot);
debug_assert_eq!(moved_from, last);
if slot != last {
let moved_id = self.slot_to_id[last];
self.slot_to_id[slot] = moved_id;
self.id_to_slot.insert(moved_id, slot);
}
self.slot_to_id.pop();
true
}
pub fn search(&self, queries: &[f32], k: usize) -> (Vec<f32>, Vec<u64>) {
self.search_with_allowlist(queries, k, None)
}
pub fn search_with_allowlist(
&self,
queries: &[f32],
k: usize,
allowlist: Option<&[u64]>,
) -> (Vec<f32>, Vec<u64>) {
let mask_buf: Option<Vec<bool>> = allowlist.map(|ids| {
assert!(!ids.is_empty(), "allowlist is empty");
let mut mask = vec![false; self.inner.len()];
for &id in ids {
let slot = match self.id_to_slot.get(&id) {
Some(&s) => s,
None => panic!("id {id} in allowlist is not present in index"),
};
mask[slot] = true;
}
mask
});
let res = self
.inner
.search_with_mask(queries, k, mask_buf.as_deref());
let mut ids = Vec::with_capacity(res.indices.len());
for &slot in &res.indices {
let id = self.slot_to_id[slot as usize];
ids.push(id);
}
(res.scores, ids)
}
pub fn contains(&self, id: u64) -> bool {
self.id_to_slot.contains_key(&id)
}
pub fn len(&self) -> usize {
self.slot_to_id.len()
}
pub fn is_empty(&self) -> bool {
self.slot_to_id.is_empty()
}
pub fn dim(&self) -> usize {
self.inner.dim()
}
pub fn dim_opt(&self) -> Option<usize> {
self.inner.dim_opt()
}
pub fn bit_width(&self) -> usize {
self.inner.bit_width()
}
pub fn prepare(&self) {
self.inner.prepare();
}
pub fn write(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
io::write_id_map(
path,
self.inner.bit_width(),
self.inner.dim_opt().unwrap_or(0),
self.inner.len(),
self.inner.packed_codes(),
self.inner.scales(),
&self.slot_to_id,
)
}
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
let (bit_width, dim, n_vectors, packed_codes, scales, slot_to_id) =
io::load_id_map(path)?;
let dim_opt = if dim == 0 { None } else { Some(dim) };
let inner = TurboQuantIndex::from_parts(dim_opt, bit_width, n_vectors, packed_codes, scales);
let id_to_slot: HashMap<u64, usize> = slot_to_id
.iter()
.enumerate()
.map(|(slot, &id)| (id, slot))
.collect();
if id_to_slot.len() != slot_to_id.len() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"duplicate ids in .tvim file",
));
}
Ok(Self {
inner,
slot_to_id,
id_to_slot,
})
}
}