use std::sync::RwLock;
use crate::storage::{StorageReadProvider, StorageWriteProvider};
use diskann::{
ANNError, ANNResult,
graph::AdjacencyList,
provider::HasId,
utils::{TryIntoVectorId, VectorId},
};
use tracing::trace;
use super::common::{AlignedMemoryVectorStore, TestCallCount};
use crate::storage::{
self, AsyncIndexMetadata, AsyncQuantLoadContext, DiskGraphOnly, LoadWith, SaveWith,
};
/// In-memory store of per-node adjacency lists, with one lock per node row.
///
/// Each row of `graph` holds the neighbor ids in its leading slots and the
/// current neighbor count in its last slot (`graph.dim() - 1`).
pub struct SimpleNeighborProviderAsync<I: VectorId> {
/// Backing rows: neighbor ids first, neighbor count in the final slot.
graph: AlignedMemoryVectorStore<I>,
/// One lock per row; the accessors in this file take the row's lock (read
/// or write) before touching the corresponding row of `graph`.
locks: Vec<RwLock<()>>,
/// Number of start points given at construction. Their rows are counted in
/// both `graph` and `locks`.
num_start_points: usize,
/// Call counter incremented by `get_neighbors_sync` in `cfg(test)` builds only.
pub num_get_calls: TestCallCount,
}
impl<I: VectorId> SimpleNeighborProviderAsync<I> {
    /// Creates a provider with `max_points + num_start_points` rows.
    ///
    /// Every row is `(max_degree as f32 * graph_slack_factor) as usize + 1`
    /// slots wide: the extra trailing slot stores the row's neighbor count.
    pub fn new(
        max_points: usize,
        num_start_points: usize,
        max_degree: u32,
        graph_slack_factor: f32,
    ) -> Self {
        let total_rows = max_points + num_start_points;
        // Neighbor capacity plus one slot reserved for the stored length.
        let row_width = (max_degree as f32 * graph_slack_factor) as usize + 1;
        Self {
            graph: AlignedMemoryVectorStore::with_capacity(total_rows, row_width),
            locks: std::iter::repeat_with(|| RwLock::new(()))
                .take(total_rows)
                .collect(),
            num_start_points,
            num_get_calls: TestCallCount::default(),
        }
    }

    /// Returns the populated prefix of row `index`, i.e. the first `n` slots
    /// where `n` is the count stored in the row's last slot.
    ///
    /// # Safety
    ///
    /// Forwards to the store's unsafe `get_slice`; the caller must uphold that
    /// accessor's requirements. The one caller in this file holds the row's
    /// read lock while the returned slice is in use.
    ///
    /// # Panics
    ///
    /// Panics if the stored count exceeds the row length.
    unsafe fn get_slice(&self, index: usize) -> &[I] {
        let row = unsafe { self.graph.get_slice(index) };
        let count = row[self.graph.dim() - 1].into_usize();
        &row[..count]
    }

    /// Replaces the adjacency list of `id` with `neighbors` under the row's
    /// write lock, then records the new length in the row's last slot.
    ///
    /// # Panics
    ///
    /// Panics if `neighbors.len()` is not strictly less than the row width,
    /// if `id` is out of range, if the row lock is poisoned, or if the length
    /// cannot be converted to `I`.
    pub fn set_neighbors_sync(&self, id: usize, neighbors: &[I]) -> ANNResult<()> {
        let width = self.graph.dim();
        let count = neighbors.len();
        assert!(
            count < width,
            "neighbors ({}) exceeded max adjacency list size ({})",
            count,
            width - 1,
        );

        #[allow(clippy::unwrap_used)]
        let _guard = self.locks[id].write().unwrap();
        // NOTE(review): soundness presumably rests on the write lock above
        // giving exclusive access to this row — confirm against the contract
        // of `get_mut_slice`.
        let row = unsafe { self.graph.get_mut_slice(id) };
        row[..count].copy_from_slice(neighbors);

        #[allow(clippy::unwrap_used)]
        let stored_len: I = count.try_into_vector_id().unwrap();
        row[width - 1] = stored_len;
        Ok(())
    }

    /// Reads the adjacency list of `id` under the row's read lock and hands it
    /// to `neighbors.overwrite_trusted`.
    ///
    /// In `cfg(test)` builds this also increments `num_get_calls`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range or the row lock is poisoned.
    pub fn get_neighbors_sync(&self, id: usize, neighbors: &mut AdjacencyList<I>) -> ANNResult<()> {
        #[cfg(test)]
        self.num_get_calls.increment();

        #[allow(clippy::unwrap_used)]
        let _guard = self.locks[id].read().unwrap();
        // The read guard stays alive until the end of the function, covering
        // the whole time the borrowed row is in use.
        neighbors.overwrite_trusted(unsafe { self.get_slice(id) });
        Ok(())
    }

    /// Appends each id in `new_neighbor_ids` to the adjacency list of `id`
    /// unless it is already present.
    ///
    /// Ids are processed in order. Once the list is full the remaining ids are
    /// dropped (with a trace message) and the call still returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range, if the row lock is poisoned, if the
    /// stored length exceeds the row length, or if the new length cannot be
    /// converted to `I`.
    pub fn append_vector_sync(&self, id: usize, new_neighbor_ids: &[I]) -> ANNResult<()> {
        #[allow(clippy::unwrap_used)]
        let _guard = self.locks[id].write().unwrap();
        // NOTE(review): soundness presumably rests on the write lock above
        // giving exclusive access to this row — confirm against the contract
        // of `get_mut_slice`.
        let row = unsafe { self.graph.get_mut_slice(id) };
        let len_slot = self.graph.dim() - 1;
        let mut count = row[len_slot].into_usize();

        // View of the ids currently in the list; refreshed after every append
        // so duplicates inside `new_neighbor_ids` are detected as well.
        let mut present: &[I] = &row[..count];
        for &candidate in new_neighbor_ids {
            if I::contains_simd(present, candidate) {
                trace!("append_vector: new neighbor already exists");
                continue;
            }
            if count >= len_slot {
                trace!("append_vector: some new neighbors discarded; adjacency list full");
                break;
            }
            row[count] = candidate;
            count += 1;
            present = &row[..count];
        }

        #[allow(clippy::unwrap_used)]
        let stored_len: I = count.try_into_vector_id().unwrap();
        row[len_slot] = stored_len;
        Ok(())
    }
}
/// Declares `I` as the id type associated with this provider.
impl<I: VectorId> HasId for SimpleNeighborProviderAsync<I> {
type Id = I;
}
impl SimpleNeighborProviderAsync<u32> {
    /// Loads a graph from `path` through `storage::bin::load_graph`.
    ///
    /// The callback passed to the loader builds the provider from the reported
    /// sizes, using a slack factor of `1.0`.
    ///
    /// # Errors
    ///
    /// Returns an index error when the reported number of start points is
    /// larger than the reported total number of points, and forwards any error
    /// produced by `storage::bin::load_graph`.
    pub fn load_direct<P>(provider: &P, path: &str) -> ANNResult<Self>
    where
        P: StorageReadProvider,
    {
        storage::bin::load_graph(
            provider,
            path,
            |num_points, max_degree, num_start_points| {
                // The total includes the start points; the constructor wants
                // the two counts separately.
                match num_points.checked_sub(num_start_points) {
                    Some(max_points) => Ok(Self::new(
                        max_points,
                        num_start_points,
                        max_degree as u32,
                        1.0,
                    )),
                    None => Err(ANNError::log_index_error(format_args!(
                        "expected {} start points but the on-disk dataset only has {} total points",
                        num_start_points, num_points,
                    ))),
                }
            },
        )
    }

    /// Saves the graph to `path` through `storage::bin::save_graph`, passing
    /// `start_point` along, and returns whatever that call returns.
    pub fn save_direct<P>(&self, provider: &P, start_point: u32, path: &str) -> ANNResult<usize>
    where
        P: StorageWriteProvider,
    {
        storage::bin::save_graph(self, provider, start_point, path)
    }
}
/// Saves the graph given a `(start_point, metadata)` pair; the file path is
/// `metadata.prefix()`.
impl SaveWith<(u32, AsyncIndexMetadata)> for SimpleNeighborProviderAsync<u32> {
    type Ok = usize;
    type Error = ANNError;

    /// Delegates to [`SimpleNeighborProviderAsync::save_direct`].
    async fn save_with<P>(
        &self,
        provider: &P,
        args: &(u32, AsyncIndexMetadata),
    ) -> ANNResult<usize>
    where
        P: StorageWriteProvider,
    {
        let (start_point, metadata) = args;
        let path = metadata.prefix();
        self.save_direct(provider, *start_point, path)
    }
}
/// Saves the graph given `(in-memory start point, actual start point, metadata)`.
///
/// The graph is written through a `DiskAdaptor`, with the actual start point
/// passed to `storage::bin::save_graph` and `metadata.prefix()` as the path.
impl SaveWith<(u32, u32, DiskGraphOnly)> for SimpleNeighborProviderAsync<u32> {
    type Ok = usize;
    type Error = ANNError;

    async fn save_with<P>(
        &self,
        provider: &P,
        args: &(u32, u32, DiskGraphOnly),
    ) -> Result<Self::Ok, Self::Error>
    where
        P: StorageWriteProvider,
    {
        let &(inmem_start_point, actual_start_point, ref metadata) = args;
        let adaptor = DiskAdaptor {
            provider: self,
            inmem_start_point,
            actual_start_point,
        };
        storage::bin::save_graph(&adaptor, provider, actual_start_point, metadata.prefix())
    }
}
/// Loads the graph from the path given by `metadata.prefix()`.
impl LoadWith<AsyncIndexMetadata> for SimpleNeighborProviderAsync<u32> {
    type Error = ANNError;

    /// Delegates to [`SimpleNeighborProviderAsync::load_direct`].
    async fn load_with<P>(provider: &P, metadata: &AsyncIndexMetadata) -> ANNResult<Self>
    where
        P: StorageReadProvider,
    {
        let path = metadata.prefix();
        Self::load_direct(provider, path)
    }
}
impl LoadWith<AsyncQuantLoadContext> for SimpleNeighborProviderAsync<u32> {
type Error = ANNError;
async fn load_with<P>(provider: &P, ctx: &AsyncQuantLoadContext) -> ANNResult<Self>
where
P: StorageReadProvider,
{
Self::load_with(provider, &ctx.metadata).await
}
}
/// Lets the binary graph loader write adjacency lists into this provider.
impl storage::bin::SetAdjacencyList for SimpleNeighborProviderAsync<u32> {
    type Item = u32;

    /// Stores `element` as the adjacency list of row `i` via
    /// `set_neighbors_sync`, with the same panics as that method.
    fn set_adjacency_list(&mut self, i: usize, element: &[u32]) -> ANNResult<()> {
        self.set_neighbors_sync(i, element)
    }
}
/// Exposes every row of the provider, start points included, to the binary
/// graph writer.
impl storage::bin::GetAdjacencyList for SimpleNeighborProviderAsync<u32> {
    type Element = u32;
    type Item<'a> = AdjacencyList<u32>;

    /// Returns a freshly allocated copy of the adjacency list of row `i`.
    fn get_adjacency_list(&self, i: usize) -> ANNResult<Self::Item<'_>> {
        let mut neighbors = AdjacencyList::new();
        self.get_neighbors_sync(i, &mut neighbors)
            .map(|()| neighbors)
    }

    /// Number of rows, which equals the number of per-row locks.
    fn total(&self) -> usize {
        self.locks.len()
    }

    /// Number of start points given at construction.
    fn additional_points(&self) -> u64 {
        self.num_start_points as u64
    }

    /// Row width minus the slot that stores the list length.
    fn max_degree(&self) -> Option<u32> {
        let capacity = self.graph.dim() - 1;
        Some(capacity as u32)
    }
}
/// View over a provider used when saving with `DiskGraphOnly`.
///
/// It reports `num_start_points` fewer rows than the provider holds and
/// replaces one neighbor id with another in every list it returns.
struct DiskAdaptor<'a> {
/// Provider whose adjacency lists are read.
provider: &'a SimpleNeighborProviderAsync<u32>,
/// Neighbor id that is replaced wherever it appears in a returned list.
inmem_start_point: u32,
/// Id written in place of `inmem_start_point`.
actual_start_point: u32,
}
/// Adjacency-list source that remaps the in-memory start point id and leaves
/// the start-point rows out of the reported total.
impl storage::bin::GetAdjacencyList for DiskAdaptor<'_> {
    type Element = u32;
    type Item<'item>
        = Vec<u32>
    where
        Self: 'item;

    /// Returns the adjacency list of row `i` with every occurrence of
    /// `inmem_start_point` replaced by `actual_start_point`.
    fn get_adjacency_list(&self, i: usize) -> ANNResult<Self::Item<'_>> {
        let mut adjacency = AdjacencyList::new();
        self.provider.get_neighbors_sync(i, &mut adjacency)?;

        let mut remapped: Vec<u32> = adjacency.into();
        remapped
            .iter_mut()
            .filter(|neighbor| **neighbor == self.inmem_start_point)
            .for_each(|neighbor| *neighbor = self.actual_start_point);
        Ok(remapped)
    }

    /// Number of provider rows minus the provider's start-point count.
    fn total(&self) -> usize {
        let inner = self.provider;
        inner.locks.len() - inner.num_start_points
    }

    /// Always zero: this view reports no additional points.
    fn additional_points(&self) -> u64 {
        0
    }

    /// This view does not report a maximum degree.
    fn max_degree(&self) -> Option<u32> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::VirtualStorageProvider;

    /// A list written with `set_neighbors_sync` is read back unchanged, and a
    /// second write to the same row replaces the first.
    #[test]
    fn test_neighbor_provider() {
        let provider = SimpleNeighborProviderAsync::<u32>::new(10, 1, 5, 1.0);
        let mut fetched = AdjacencyList::new();

        let first = vec![1, 2, 3];
        provider.set_neighbors_sync(1, &first).unwrap();
        provider.get_neighbors_sync(1, &mut fetched).unwrap();
        assert_eq!(&first, &*fetched);

        let second = AdjacencyList::from_iter_untrusted([4, 5, 6]);
        provider.set_neighbors_sync(1, &second).unwrap();
        provider.get_neighbors_sync(1, &mut fetched).unwrap();
        assert_eq!(second, fetched);
    }

    /// Saving through `SaveWith` and loading through `load_direct` reproduces
    /// every adjacency list, start-point rows included.
    #[tokio::test]
    async fn test_save_load() {
        let max_degree = 5;
        let max_points = 8;
        let additional_points = 2;
        let total_points = max_points + additional_points;

        let storage = VirtualStorageProvider::new_memory();
        let provider =
            SimpleNeighborProviderAsync::<u32>::new(max_points, additional_points, max_degree, 1.0);

        // Give every row three neighbors: node + 1, node + 2, node + 3.
        for node in 0..total_points {
            let base = node as u32;
            let neighbors: Vec<u32> = (1..4).map(|offset| base + offset).collect();
            provider.set_neighbors_sync(node, &neighbors).unwrap();
        }

        let metadata = AsyncIndexMetadata::new("/resumable_test");
        let start_point = 0;
        let saved = provider
            .save_with(&storage, &(start_point, metadata.clone()))
            .await;
        assert!(saved.is_ok(), "Failed to save with resumable context");
        assert!(
            storage.exists(metadata.prefix()),
            "Resumable graph file was not created"
        );

        let reloaded =
            SimpleNeighborProviderAsync::<u32>::load_direct(&storage, metadata.prefix()).unwrap();
        for node in 0..total_points {
            let mut expected = AdjacencyList::new();
            let mut actual = AdjacencyList::new();
            provider.get_neighbors_sync(node, &mut expected).unwrap();
            reloaded.get_neighbors_sync(node, &mut actual).unwrap();
            assert_eq!(
                expected, actual,
                "Adjacency list for node {} doesn't match after loading",
                node
            );
        }
    }
}