use parking_lot::RwLock;
use rustc_hash::FxHashMap;
/// Number of independently locked shards in `ShardedVectors`.
///
/// An index `idx` is assigned to shard `idx % NUM_SHARDS`.
pub(crate) const NUM_SHARDS: usize = 1 << 4;
/// Storage for a single shard: every vector whose index maps to this shard.
#[derive(Default, Debug)]
struct VectorShard {
    /// Vectors keyed by the caller-supplied index.
    vectors: FxHashMap<usize, Vec<f32>>,
}
/// A map from `usize` indices to `Vec<f32>` vectors, split across
/// `NUM_SHARDS` independently locked shards.
///
/// Index `idx` always lives in shard `idx % NUM_SHARDS`, so operations on
/// indices that fall in different shards take different locks.
#[derive(Debug)]
pub struct ShardedVectors {
    /// One reader/writer lock per shard.
    shards: [RwLock<VectorShard>; NUM_SHARDS],
    /// Dimension passed to the constructor. `insert` does not check vector
    /// lengths against it; `dead_code` is allowed because the field is
    /// otherwise unread.
    #[allow(dead_code)]
    dimension: usize,
}
impl Default for ShardedVectors {
fn default() -> Self {
Self::new(0)
}
}
// Lint attributes on an `impl` apply to every item inside it, so this single
// allow covers the whole API surface; some accessors may not be called
// elsewhere in the crate yet.
#[allow(dead_code)]
impl ShardedVectors {
    /// Creates an empty store with all `NUM_SHARDS` shards empty.
    ///
    /// `dimension` is recorded, but vectors passed to the insert methods are
    /// not checked against it.
    #[must_use]
    pub fn new(dimension: usize) -> Self {
        Self {
            shards: std::array::from_fn(|_| RwLock::new(VectorShard::default())),
            dimension,
        }
    }

    /// Returns the shard that owns `idx` (`idx % NUM_SHARDS`).
    #[inline]
    pub(crate) const fn shard_index(idx: usize) -> usize {
        idx % NUM_SHARDS
    }

    /// Stores a copy of `vector` under `idx`, replacing any previous vector.
    pub fn insert(&self, idx: usize, vector: &[f32]) {
        let mut shard = self.shards[Self::shard_index(idx)].write();
        shard.vectors.insert(idx, vector.to_vec());
    }

    /// Inserts many `(idx, vector)` pairs, taking each shard's write lock at
    /// most once.
    ///
    /// Pairs are first bucketed by shard. Each non-empty bucket is then
    /// written under a single lock acquisition. When the same `idx` occurs
    /// more than once, the last occurrence wins. Shards are locked one at a
    /// time, so a concurrent reader may observe a partially applied batch.
    pub fn insert_batch(&self, vectors: impl IntoIterator<Item = (usize, Vec<f32>)>) {
        let mut by_shard: [Vec<(usize, Vec<f32>)>; NUM_SHARDS] =
            std::array::from_fn(|_| Vec::new());
        for (idx, vec) in vectors {
            by_shard[Self::shard_index(idx)].push((idx, vec));
        }
        for (lock, batch) in self.shards.iter().zip(by_shard) {
            // Skip untouched shards so we don't contend on their locks.
            if batch.is_empty() {
                continue;
            }
            lock.write().vectors.extend(batch);
        }
    }

    /// Returns a clone of the vector stored under `idx`, if any.
    ///
    /// Use `with_vector` to inspect a vector without allocating a copy.
    #[must_use]
    pub fn get(&self, idx: usize) -> Option<Vec<f32>> {
        let shard = self.shards[Self::shard_index(idx)].read();
        shard.vectors.get(&idx).cloned()
    }

    /// Returns `true` if a vector is stored under `idx`.
    #[must_use]
    pub fn contains(&self, idx: usize) -> bool {
        let shard = self.shards[Self::shard_index(idx)].read();
        shard.vectors.contains_key(&idx)
    }

    /// Calls `f` on the vector stored under `idx` without cloning it.
    ///
    /// Returns `None` without calling `f` when `idx` is absent.
    ///
    /// The shard's read lock is held while `f` runs. Calling a write method
    /// (`insert`, `insert_batch`, `remove`, `clear`) on this store from inside
    /// `f` can therefore deadlock.
    pub fn with_vector<F, R>(&self, idx: usize, f: F) -> Option<R>
    where
        F: FnOnce(&[f32]) -> R,
    {
        let shard = self.shards[Self::shard_index(idx)].read();
        shard.vectors.get(&idx).map(|v| f(v.as_slice()))
    }

    /// Removes and returns the vector stored under `idx`, if any.
    pub fn remove(&self, idx: usize) -> Option<Vec<f32>> {
        let mut shard = self.shards[Self::shard_index(idx)].write();
        shard.vectors.remove(&idx)
    }

    /// Total number of stored vectors.
    ///
    /// Shards are locked and counted one at a time. Under concurrent writes,
    /// the result is therefore not a consistent snapshot of the whole store.
    #[must_use]
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.read().vectors.len()).sum()
    }

    /// Returns `true` if no shard holds any vector.
    ///
    /// Like `len`, this inspects shards one at a time.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.read().vectors.is_empty())
    }

    /// Removes every vector, clearing one shard at a time.
    ///
    /// Not atomic across shards: an insert racing with `clear` into an
    /// already-cleared shard is kept.
    pub fn clear(&self) {
        for shard in &self.shards {
            shard.write().vectors.clear();
        }
    }

    /// Returns a cloned `(idx, vector)` list of all entries, in no particular
    /// order. Same as `collect_for_parallel`.
    pub fn iter_all(&self) -> Vec<(usize, Vec<f32>)> {
        self.collect_for_parallel()
    }

    /// Calls `f(idx, vector)` for every stored entry, in no particular order.
    ///
    /// Despite the name, entries are visited sequentially on the calling
    /// thread, shard by shard. Each shard's read lock is held while `f` runs
    /// on that shard's entries. Calling a write method on this store from
    /// inside `f` can therefore deadlock.
    pub fn for_each_parallel<F>(&self, mut f: F)
    where
        F: FnMut(usize, &[f32]),
    {
        for shard in &self.shards {
            let guard = shard.read();
            for (&idx, vec) in &guard.vectors {
                f(idx, vec);
            }
        }
    }

    /// Clones every `(idx, vector)` pair into a new `Vec`, in no particular
    /// order.
    ///
    /// The capacity is pre-sized from `len()`. Entries inserted concurrently
    /// after sizing are still copied; the `Vec` simply grows.
    #[must_use]
    pub fn collect_for_parallel(&self) -> Vec<(usize, Vec<f32>)> {
        let mut result = Vec::with_capacity(self.len());
        self.drain_shards_into(&mut result);
        result
    }

    /// Replaces the contents of `buffer` with clones of every
    /// `(idx, vector)` pair, reusing its allocation where possible.
    pub fn collect_into(&self, buffer: &mut Vec<(usize, Vec<f32>)>) {
        buffer.clear();
        // After `clear()` the length is 0, so `reserve(n)` guarantees
        // `capacity >= n`. The previous `n - capacity` under-reserved
        // whenever the buffer already had some capacity.
        buffer.reserve(self.len());
        self.drain_shards_into(buffer);
    }

    /// Appends a clone of every `(idx, vector)` pair to `buffer`.
    ///
    /// Despite the name, shards are only read, never emptied.
    fn drain_shards_into(&self, buffer: &mut Vec<(usize, Vec<f32>)>) {
        for shard in &self.shards {
            let guard = shard.read();
            buffer.extend(guard.vectors.iter().map(|(&idx, vec)| (idx, vec.clone())));
        }
    }

    /// Returns the indices of all stored vectors, in no particular order.
    ///
    /// Shards are read one at a time, so this is not an atomic snapshot
    /// under concurrent writes.
    #[must_use]
    pub fn snapshot_indices(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.len());
        for shard in &self.shards {
            let guard = shard.read();
            indices.extend(guard.vectors.keys().copied());
        }
        indices
    }
}