use crate::error::IndexResult;
use crate::hnsw::{HnswConfig, HnswIndex, Neighbor};
use crate::metric::Metric;
use crate::PointId;
use parking_lot::RwLock;
use std::sync::Arc;
/// Thread-safe, cheaply clonable handle to a shared [`HnswIndex`].
///
/// The index lives behind an `Arc<RwLock<_>>`: searches and size queries take
/// a shared read lock, while inserts take the exclusive write lock. Cloning a
/// handle shares the same underlying index rather than copying it.
pub struct ConcurrentHnsw<P, M: Metric<Point = P>> {
    /// Shared index; every clone of this handle points at the same allocation.
    inner: Arc<RwLock<HnswIndex<P, M>>>,
}
// Written by hand instead of `#[derive(Clone)]`: the derive would add `P: Clone`
// and `M: Clone` bounds, which are unnecessary since only the `Arc` is cloned.
impl<P, M: Metric<Point = P>> Clone for ConcurrentHnsw<P, M> {
    /// Returns another handle to the same index; only the `Arc` reference count
    /// is bumped, no index data is copied.
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}
impl<P, M> ConcurrentHnsw<P, M>
where
M: Metric<Point = P>,
{
pub fn new(config: HnswConfig, metric: M) -> IndexResult<Self> {
Ok(Self {
inner: Arc::new(RwLock::new(HnswIndex::new(config, metric)?)),
})
}
pub fn insert(&self, id: PointId, point: P) -> IndexResult<()> {
self.inner.write().insert(id, point)
}
pub fn search(&self, query: &P, k: usize) -> Vec<Neighbor> {
self.inner.read().search(query, k)
}
pub fn len(&self) -> usize {
self.inner.read().len()
}
pub fn is_empty(&self) -> bool {
self.inner.read().is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metric::L2;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Searches from several reader threads while a writer keeps inserting,
    /// then checks that every insert landed.
    #[test]
    fn concurrent_reads_during_inserts() {
        const READERS: usize = 4;

        let idx: ConcurrentHnsw<Vec<f32>, L2> =
            ConcurrentHnsw::new(HnswConfig::default(), L2).unwrap();
        // Seed the index so readers always have something to find.
        for i in 0..50 {
            idx.insert(i, vec![i as f32, 0.0]).unwrap();
        }

        // Release the writer and all readers at the same moment. Without this,
        // the writer may finish its inserts before the readers are even spawned,
        // and the test would not exercise reads during inserts at all.
        let start = Arc::new(Barrier::new(READERS + 1));

        let writer = {
            let idx = idx.clone();
            let start = Arc::clone(&start);
            thread::spawn(move || {
                start.wait();
                for i in 50..150 {
                    idx.insert(i, vec![i as f32, 0.0]).unwrap();
                }
            })
        };
        let readers: Vec<_> = (0..READERS)
            .map(|_| {
                let idx = idx.clone();
                let start = Arc::clone(&start);
                thread::spawn(move || {
                    // Built once per thread rather than once per search.
                    let query = vec![25.0, 0.0];
                    start.wait();
                    for _ in 0..200 {
                        assert!(!idx.search(&query, 5).is_empty());
                    }
                })
            })
            .collect();

        writer.join().unwrap();
        for r in readers {
            r.join().unwrap();
        }
        assert_eq!(idx.len(), 150);
    }
}