use std::{
fmt::Debug,
future::Future,
io::{Read, Write},
num::NonZeroUsize,
str::FromStr,
sync::Arc,
};
use serde::{Deserialize, Serialize};
use bf_tree::{BfTree, Config};
use diskann::{
ANNError, ANNResult, default_post_processor,
graph::{
AdjacencyList, DiskANNIndex, SearchOutputBuffer,
glue::{
self, Batch, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy,
MultiInsertStrategy, PruneStrategy, SearchExt, SearchStrategy,
},
workingset::{self, map},
},
neighbor::Neighbor,
provider::{
Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext,
DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut,
NoopGuard, SetElement,
},
utils::{IntoUsize, VectorRepr},
};
use diskann_utils::{future::AsyncFriendly, views::MatrixView};
use diskann_vector::{DistanceFunction, distance::Metric};
use crate::model::{
graph::provider::async_::{
TableDeleteProviderAsync,
bf_tree::{
neighbor_provider::NeighborProvider, quant_vector_provider::QuantVectorProvider,
vector_provider::VectorProvider,
},
common::{CreateDeleteProvider, FullPrecision, Hybrid, NoDeletes, NoStore, Panics},
distances,
postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy},
},
pq::{self, FixedChunkPQTable, NUM_PQ_CENTROIDS},
};
use crate::storage::{LoadWith, PQStorage, SaveWith};
use crate::storage::{StorageReadProvider, StorageWriteProvider};
/// Vector-index data provider backed by Bf-Tree stores.
///
/// Holds three stores: full-precision vectors (`full_vectors`), an optional
/// quantized store `Q` (defaults to [`QuantVectorProvider`]; `NoStore`
/// disables it), and adjacency lists (`neighbor_provider`). `D` selects the
/// deletion strategy (default: [`NoDeletes`]).
pub struct BfTreeProvider<T, Q = QuantVectorProvider, D = NoDeletes>
where
    T: VectorRepr,
{
    // Quantized vector store, or `NoStore` when quantization is disabled.
    pub(super) quant_vectors: Q,
    // Full-precision vector store.
    pub(super) full_vectors: VectorProvider<T>,
    // Adjacency lists keyed by internal `u32` id.
    pub(crate) neighbor_provider: NeighborProvider<u32>,
    // Deletion bookkeeping; behavior chosen by the `D` type parameter.
    pub(super) deleted: D,
    // Per-`fill` cap on full-precision fetches; candidates past this index
    // fall back to the quantized store (see the `workingset::Fill` impl).
    pub(super) max_fp_vecs_per_fill: usize,
    // Distance metric used by every distance computer built from this provider.
    pub(super) metric: Metric,
    // Optional build-time graph parameters, persisted with the index.
    pub(crate) graph_params: Option<GraphParams>,
}
/// Construction parameters for [`BfTreeProvider::new`] / `new_empty`.
#[derive(Debug, Clone)]
pub struct BfTreeProviderParameters {
    pub max_points: usize,
    pub num_start_points: NonZeroUsize,
    pub dim: usize,
    pub metric: Metric,
    // `None` means "no cap" (treated as `usize::MAX`).
    pub max_fp_vecs_per_fill: Option<usize>,
    pub max_degree: u32,
    pub vector_provider_config: Config,
    pub quant_vector_provider_config: Config,
    pub neighbor_list_provider_config: Config,
    pub graph_params: Option<GraphParams>,
}
/// Index over a provider with no quantized store.
pub type Index<T, D = NoDeletes> = Arc<DiskANNIndex<BfTreeProvider<T, NoStore, D>>>;
/// Index over a provider with a quantized store `Q`.
pub type QuantIndex<T, Q, D = NoDeletes> = Arc<DiskANNIndex<BfTreeProvider<T, Q, D>>>;
impl<T, Q, D> BfTreeProvider<T, Q, D>
where
    T: VectorRepr,
{
    /// Allocate all backing stores without writing any start-point vectors.
    ///
    /// `quant_precursor` builds the quantized store `Q` and `delete_precursor`
    /// builds the deletion tracker `D`; the tracker is sized for user slots
    /// plus start-point slots.
    pub fn new_empty<TQ, TD>(
        params: BfTreeProviderParameters,
        quant_precursor: TQ,
        delete_precursor: TD,
    ) -> ANNResult<Self>
    where
        TQ: CreateQuantProvider<Target = Q>,
        TD: CreateDeleteProvider<Target = D>,
    {
        let num_start_points = params.num_start_points.get();
        Ok(Self {
            quant_vectors: quant_precursor.create(
                params.max_points,
                num_start_points,
                params.metric,
                params.quant_vector_provider_config,
            )?,
            full_vectors: VectorProvider::new_with_config(
                params.max_points,
                params.dim,
                num_start_points,
                params.vector_provider_config,
            )?,
            neighbor_provider: NeighborProvider::new_with_config(
                params.max_degree,
                params.neighbor_list_provider_config,
            )?,
            // Deletion state covers user slots plus the start-point slots.
            deleted: delete_precursor.create(params.max_points + num_start_points),
            // `None` means no cap on full-precision fetches per fill.
            max_fp_vecs_per_fill: params.max_fp_vecs_per_fill.unwrap_or(usize::MAX),
            metric: params.metric,
            graph_params: params.graph_params,
        })
    }
    /// Allocate stores and write the given start-point vectors.
    ///
    /// Fails if `start_points` does not have exactly
    /// `params.num_start_points` rows. Every user slot's adjacency list is
    /// pre-seeded empty.
    pub fn new<TQ, TD>(
        params: BfTreeProviderParameters,
        start_points: MatrixView<'_, T>,
        quant_precursor: TQ,
        delete_precursor: TD,
    ) -> ANNResult<Self>
    where
        Self: StartPoint<T>,
        TQ: CreateQuantProvider<Target = Q>,
        TD: CreateDeleteProvider<Target = D>,
    {
        if start_points.nrows() != params.num_start_points.get() {
            return Err(ANNError::log_async_index_error(format!(
                "start_points matrix has {} rows, but params.num_start_points is {}",
                start_points.nrows(),
                params.num_start_points.get(),
            )));
        }
        let provider = Self::new_empty(params.clone(), quant_precursor, delete_precursor)?;
        // `Hidden(())` restricts `set_start_points` to this module.
        provider.set_start_points(Hidden(()), start_points)?;
        {
            // Seed an empty adjacency list for every user id up front.
            for i in 0..params.max_points {
                let vector_id = i as u32;
                provider.neighbor_provider.set_neighbors(vector_id, &[])?;
            }
        }
        Ok(provider)
    }
    /// Internal ids of the start (entry) points.
    pub fn starting_points(&self) -> ANNResult<Vec<u32>> {
        Ok(self.full_vectors.starting_points()?)
    }
    /// Range over every id currently tracked by the full-precision store.
    pub fn iter(&self) -> std::ops::Range<u32> {
        0..(self.full_vectors.total() as u32)
    }
    pub fn num_start_points(&self) -> usize {
        self.full_vectors.num_start_points
    }
    pub fn max_points(&self) -> usize {
        self.full_vectors.max_vectors
    }
    pub fn dim(&self) -> usize {
        self.full_vectors.dim()
    }
    pub fn metric(&self) -> Metric {
        self.metric
    }
    pub fn max_degree(&self) -> u32 {
        self.neighbor_provider.max_degree()
    }
}
impl<T, Q> BfTreeProvider<T, Q, TableDeleteProviderAsync>
where
    T: VectorRepr,
{
    /// Forget every recorded deletion (does not touch stored vectors).
    pub fn clear_delete_set(&self) {
        self.deleted.clear();
    }
}
impl<T, D> BfTreeProvider<T, QuantVectorProvider, D>
where
    T: VectorRepr,
{
    /// Instrumentation: `(full-precision, quantized)` get-call counters.
    pub fn counts_for_get_vector(&self) -> (usize, usize) {
        (
            self.full_vectors.num_get_calls.get(),
            self.quant_vectors.num_get_calls.get(),
        )
    }
}
impl<T, D> BfTreeProvider<T, NoStore, D>
where
    T: VectorRepr,
{
    /// Instrumentation: with no quantized store the second counter is always 0.
    pub fn counts_for_get_vector(&self) -> (usize, usize) {
        (self.full_vectors.num_get_calls.get(), 0)
    }
}
/// Iterating a provider reference yields every tracked internal id.
impl<T, Q, D> IntoIterator for &BfTreeProvider<T, Q, D>
where
    T: VectorRepr,
{
    type Item = u32;
    type IntoIter = std::ops::Range<u32>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
/// Consumes a "precursor" value to build the quantized vector store used by a
/// [`BfTreeProvider`].
pub trait CreateQuantProvider {
    type Target;
    fn create(
        self,
        max_points: usize,
        frozen_points: usize,
        metric: Metric,
        bf_tree_config: Config,
    ) -> ANNResult<Self::Target>;
}
/// No quantized store: `NoStore` passes through unchanged.
impl CreateQuantProvider for NoStore {
    type Target = NoStore;
    fn create(
        self,
        _max_points: usize,
        _frozen_points: usize,
        _metric: Metric,
        _bf_tree_config: Config,
    ) -> ANNResult<Self::Target> {
        Ok(self)
    }
}
/// A PQ table precursor builds a [`QuantVectorProvider`] around itself.
impl CreateQuantProvider for FixedChunkPQTable {
    type Target = QuantVectorProvider;
    fn create(
        self,
        max_points: usize,
        frozen_points: usize,
        metric: Metric,
        bf_tree_config: Config,
    ) -> ANNResult<Self::Target> {
        QuantVectorProvider::new_with_config(
            metric,
            max_points,
            frozen_points,
            self,
            bf_tree_config,
        )
    }
}
impl<T, Q, D> BfTreeProvider<T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    /// Borrow the adjacency-list store.
    pub fn neighbors(&self) -> &NeighborProvider<u32> {
        &self.neighbor_provider
    }
}
/// Internal and external ids are both `u32`, so id translation is identity.
impl<T, Q, D> DataProvider for BfTreeProvider<T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type Context = DefaultContext;
    type InternalId = u32;
    type ExternalId = u32;
    type Error = ANNError;
    // No locking needed: guards are no-ops.
    type Guard = NoopGuard<u32>;
    fn to_internal_id(
        &self,
        _context: &DefaultContext,
        gid: &Self::ExternalId,
    ) -> Result<Self::InternalId, Self::Error> {
        Ok(*gid)
    }
    fn to_external_id(
        &self,
        _context: &DefaultContext,
        id: Self::InternalId,
    ) -> Result<Self::ExternalId, Self::Error> {
        Ok(id)
    }
}
impl<T, Q, D> HasId for BfTreeProvider<T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type Id = u32;
}
/// Neighbor operations delegate to the adjacency-list store.
impl<'a, T, Q, D> DelegateNeighbor<'a> for BfTreeProvider<T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type Delegate = &'a NeighborProvider<u32>;
    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        self.neighbors()
    }
}
/// Deletion support when `D` is the async delete-table tracker.
///
/// All operations complete synchronously and are returned as ready futures.
impl<T, Q> Delete for BfTreeProvider<T, Q, TableDeleteProviderAsync>
where
    Q: AsyncFriendly,
    T: VectorRepr,
{
    /// Reclaim a slot: drop its adjacency entry, clear its deleted bit, then
    /// re-seed an empty neighbor list so the slot is ready for reuse.
    fn release(
        &self,
        _: &DefaultContext,
        id: Self::InternalId,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        if let Err(e) = self.neighbor_provider.delete_vector(id) {
            return std::future::ready(Err(e));
        }
        // Mark the slot live again before reinitializing its adjacency list.
        self.deleted.undelete(id.into_usize());
        let res = self
            .neighbor_provider
            .set_neighbors(id, &[])
            .map_err(|err| err.context(format!("resetting neighbors for undeleted id {}", id)));
        std::future::ready(res)
    }
    /// Soft delete: only flips the bit in the delete table; vectors and
    /// adjacency data remain until the slot is released.
    #[inline]
    fn delete(
        &self,
        _context: &DefaultContext,
        gid: &Self::ExternalId,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.deleted.delete(gid.into_usize());
        std::future::ready(Ok(()))
    }
    #[inline]
    fn status_by_external_id(
        &self,
        context: &DefaultContext,
        gid: &Self::ExternalId,
    ) -> impl Future<Output = Result<ElementStatus, Self::Error>> + Send {
        // External and internal ids are identical, so delegate directly.
        self.status_by_internal_id(context, *gid)
    }
    #[inline]
    fn status_by_internal_id(
        &self,
        _context: &DefaultContext,
        id: Self::InternalId,
    ) -> impl Future<Output = Result<ElementStatus, Self::Error>> + Send {
        let status = if self.deleted.is_deleted(id.into_usize()) {
            ElementStatus::Deleted
        } else {
            ElementStatus::Valid
        };
        std::future::ready(Ok(status))
    }
}
/// Async facade over the neighbor store's synchronous read API.
impl NeighborAccessor for &NeighborProvider<u32> {
    fn get_neighbors(
        self,
        id: Self::Id,
        neighbors: &mut AdjacencyList<Self::Id>,
    ) -> impl Future<Output = ANNResult<Self>> + Send {
        // Calls the inherent `get_neighbors` (inherent methods take priority
        // over same-named trait methods), returning `self` for chaining.
        std::future::ready(self.get_neighbors(id, neighbors).map(|_| self))
    }
}
/// Async facade over the neighbor store's synchronous write API.
impl NeighborAccessorMut for &NeighborProvider<u32> {
    fn set_neighbors(
        self,
        vector_id: u32,
        neighbors: &[u32],
    ) -> impl Future<Output = ANNResult<Self>> + Send {
        std::future::ready(self.set_neighbors(vector_id, neighbors).map(|_| self))
    }
    fn append_vector(
        self,
        vector_id: u32,
        new_neighbor_ids: &[u32],
    ) -> impl Future<Output = ANNResult<Self>> + Send {
        std::future::ready(
            self.append_vector(vector_id, new_neighbor_ids)
                .map(|_| self),
        )
    }
}
/// Storing an element with a quantized store writes BOTH representations:
/// the quantized encoding and the full-precision vector.
impl<T, D> SetElement<&[T]> for BfTreeProvider<T, QuantVectorProvider, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type SetError = ANNError;
    fn set_element(
        &self,
        _context: &Self::Context,
        id: &u32,
        element: &[T],
    ) -> impl Future<Output = Result<Self::Guard, Self::SetError>> + Send {
        if let Err(err) = self.quant_vectors.set_vector_sync(id.into_usize(), element) {
            return std::future::ready(Err(err));
        }
        if let Err(err) = self.full_vectors.set_vector_sync(id.into_usize(), element) {
            return std::future::ready(Err(err));
        }
        // No locking in this provider; the guard is a no-op carrying the id.
        std::future::ready(Ok(NoopGuard::new(*id)))
    }
}
/// Without a quantized store, only the full-precision vector is written.
impl<T, D> SetElement<&[T]> for BfTreeProvider<T, NoStore, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type SetError = ANNError;
    fn set_element(
        &self,
        _context: &Self::Context,
        id: &u32,
        element: &[T],
    ) -> impl Future<Output = Result<Self::Guard, Self::SetError>> + Send {
        if let Err(err) = self.full_vectors.set_vector_sync(id.into_usize(), element) {
            return std::future::ready(Err(err));
        }
        std::future::ready(Ok(NoopGuard::new(*id)))
    }
}
/// Opaque token: only this module can construct it, so external code cannot
/// call [`StartPoint::set_start_points`] even though the trait is public.
pub struct Hidden(());
/// Writes start-point vectors into a freshly created provider.
pub trait StartPoint<T> {
    #[doc(hidden)]
    fn set_start_points(&self, hidden: Hidden, start_points: MatrixView<'_, T>) -> ANNResult<()>;
}
/// With a quantized store, each start point is written to both vector stores
/// and gets an empty adjacency list.
impl<T, D> StartPoint<T> for BfTreeProvider<T, QuantVectorProvider, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    fn set_start_points(&self, _hidden: Hidden, start_points: MatrixView<'_, T>) -> ANNResult<()> {
        let start_point_ids = self.full_vectors.starting_points()?;
        if start_points.nrows() != start_point_ids.len() {
            return Err(ANNError::log_async_index_error(format!(
                "expected start_points to contain `{}` rows, instead it has {}",
                start_point_ids.len(),
                start_points.nrows(),
            )));
        }
        for (id, v) in std::iter::zip(start_point_ids, start_points.row_iter()) {
            self.full_vectors.set_vector_sync(id.into_usize(), v)?;
            self.quant_vectors.set_vector_sync(id.into_usize(), v)?;
            self.neighbor_provider.set_neighbors(id, &[])?;
        }
        Ok(())
    }
}
/// Without a quantized store, only the full-precision store is written.
impl<T, D> StartPoint<T> for BfTreeProvider<T, NoStore, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    fn set_start_points(&self, _hidden: Hidden, start_points: MatrixView<'_, T>) -> ANNResult<()> {
        let start_point_ids = self.full_vectors.starting_points()?;
        if start_points.nrows() != start_point_ids.len() {
            return Err(ANNError::log_async_index_error(format!(
                "expected start_points to contain `{}` rows, instead it has {}",
                start_point_ids.len(),
                start_points.nrows(),
            )));
        }
        for (id, v) in std::iter::zip(start_point_ids, start_points.row_iter()) {
            self.full_vectors.set_vector_sync(id.into_usize(), v)?;
            self.neighbor_provider.set_neighbors(id, &[])?;
        }
        Ok(())
    }
}
/// Accessor over the full-precision vector store.
///
/// Reuses a single scratch buffer (`element`) across `get_element` calls to
/// avoid per-lookup allocation.
pub struct FullAccessor<'a, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    provider: &'a BfTreeProvider<T, Q, D>,
    // Scratch buffer of length `provider.full_vectors.dim()`.
    element: Box<[T]>,
}
impl<T, Q, D> HasId for FullAccessor<'_, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type Id = u32;
}
impl<T, Q, D> SearchExt for FullAccessor<'_, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    /// Search entry points are the provider's start points.
    fn starting_points(&self) -> impl Future<Output = ANNResult<Vec<u32>>> {
        std::future::ready(self.provider.starting_points())
    }
}
impl<'a, T, Q, D> FullAccessor<'a, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    /// Build an accessor over `provider`, allocating a default-initialized
    /// scratch buffer sized to the full-precision dimensionality.
    pub(crate) fn new(provider: &'a BfTreeProvider<T, Q, D>) -> Self {
        let dim = provider.full_vectors.dim();
        let element = std::iter::repeat_with(T::default).take(dim).collect();
        Self { provider, element }
    }
}
/// Neighbor operations delegate to the provider's adjacency-list store.
impl<'a, T, Q, D> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type Delegate = &'a NeighborProvider<u32>;
    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        self.provider.neighbors()
    }
}
impl<T, Q, D> Accessor for FullAccessor<'_, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type Element<'a>
        = &'a [T]
    where
        Self: 'a;
    type ElementRef<'a> = &'a [T];
    // Store failures surface as panics (see `expect` below), not errors.
    type GetError = Panics;
    #[inline(always)]
    fn get_element(
        &mut self,
        id: Self::Id,
    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
        // Read into the reusable scratch buffer, then lend out a borrow of it.
        #[allow(clippy::expect_used)]
        self.provider
            .full_vectors
            .get_vector_into(id.into_usize(), &mut self.element)
            .expect("Full vector provider failed to retrieve element");
        std::future::ready(Ok(&*self.element))
    }
}
impl<T, Q, D> BuildDistanceComputer for FullAccessor<'_, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type DistanceComputerError = Panics;
    type DistanceComputer = T::Distance;
    /// Full-precision distance computer for the provider's metric and dim.
    fn build_distance_computer(
        &self,
    ) -> Result<Self::DistanceComputer, Self::DistanceComputerError> {
        Ok(T::distance(
            self.provider.metric,
            Some(self.provider.full_vectors.dim()),
        ))
    }
}
impl<T, Q, D> BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    type QueryComputerError = Panics;
    type QueryComputer = T::QueryDistance;
    /// Query-to-element distance computer anchored at `from`.
    fn build_query_computer(
        &self,
        from: &[T],
    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
        Ok(T::query_distance(from, self.provider.metric))
    }
}
// Uses the trait's default beam-expansion behavior.
impl<T, Q, D> ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
}
/// Deletion checks are answered by the provider's delete tracker.
impl<'a, T, Q, D> AsDeletionCheck for FullAccessor<'a, T, Q, D>
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
{
    type Checker = D;
    fn as_deletion_check(&self) -> &D {
        &self.provider.deleted
    }
}
/// Accessor over the quantized vector store.
///
/// Reuses a single scratch buffer (`element`) of `pq_chunks()` bytes across
/// `get_element` calls.
pub struct QuantAccessor<'a, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    provider: &'a BfTreeProvider<T, QuantVectorProvider, D>,
    // Scratch buffer holding one quantized code.
    element: Box<[u8]>,
}
impl<T, D> HasId for QuantAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type Id = u32;
}
impl<T, D> SearchExt for QuantAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    /// Search entry points are the provider's start points.
    fn starting_points(&self) -> impl Future<Output = ANNResult<Vec<u32>>> {
        std::future::ready(self.provider.starting_points())
    }
}
impl<'a, T, D> QuantAccessor<'a, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    /// Build an accessor over `provider`, allocating a zeroed scratch buffer
    /// sized to one quantized code (`pq_chunks()` bytes).
    pub(crate) fn new(provider: &'a BfTreeProvider<T, QuantVectorProvider, D>) -> Self {
        let element = vec![0u8; provider.quant_vectors.pq_chunks()].into_boxed_slice();
        Self { provider, element }
    }
}
/// Neighbor operations delegate to the provider's adjacency-list store.
impl<'a, T, D> DelegateNeighbor<'a> for QuantAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type Delegate = &'a NeighborProvider<u32>;
    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        self.provider.neighbors()
    }
}
impl<T, D> Accessor for QuantAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type Element<'a>
        = &'a [u8]
    where
        Self: 'a;
    type ElementRef<'a> = &'a [u8];
    // Unlike FullAccessor, quantized lookups report failures as errors.
    type GetError = ANNError;
    fn get_element(
        &mut self,
        id: Self::Id,
    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
        // Read the code into the scratch buffer and lend out a borrow of it.
        let v = self
            .provider
            .quant_vectors
            .get_vector_into(id.into_usize(), &mut self.element)
            .map(|_: ()| &*self.element);
        std::future::ready(v)
    }
    /// Visit each id's quantized code, stopping at the first lookup error.
    fn on_elements_unordered<Itr, F>(
        &mut self,
        itr: Itr,
        mut f: F,
    ) -> impl Future<Output = Result<(), Self::GetError>> + Send
    where
        Self: Sync,
        Itr: Iterator<Item = Self::Id> + Send,
        F: Send + FnMut(Self::ElementRef<'_>, Self::Id),
    {
        for i in itr {
            match self
                .provider
                .quant_vectors
                .get_vector_into(i.into_usize(), &mut self.element)
            {
                Ok(()) => f(&self.element, i),
                Err(e) => {
                    return std::future::ready(Err(e));
                }
            }
        }
        std::future::ready(Ok(()))
    }
}
impl<T, D> BuildQueryComputer<&[T]> for QuantAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type QueryComputerError = ANNError;
    type QueryComputer = pq::distance::QueryComputer<Arc<FixedChunkPQTable>>;
    /// PQ query computer built from the quantized store's table.
    fn build_query_computer(
        &self,
        from: &[T],
    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
        self.provider.quant_vectors.query_computer(from)
    }
}
// Uses the trait's default beam-expansion behavior.
impl<T, D> ExpandBeam<&[T]> for QuantAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
}
/// Deletion checks are answered by the provider's delete tracker.
impl<'a, T, D> AsDeletionCheck for QuantAccessor<'a, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
{
    type Checker = D;
    fn as_deletion_check(&self) -> &D {
        &self.provider.deleted
    }
}
/// Accessor that mixes full-precision and quantized representations
/// (see the `workingset::Fill` impl below for the selection logic).
pub struct HybridAccessor<'a, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    provider: &'a BfTreeProvider<T, QuantVectorProvider, D>,
}
impl<'a, T, D> HybridAccessor<'a, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    fn new(provider: &'a BfTreeProvider<T, QuantVectorProvider, D>) -> Self {
        Self { provider }
    }
}
impl<T, D> HasId for HybridAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type Id = u32;
}
/// Neighbor operations delegate to the provider's adjacency-list store.
impl<'a, T, D> DelegateNeighbor<'a> for HybridAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type Delegate = &'a NeighborProvider<u32>;
    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        self.provider.neighbors()
    }
}
impl<T, D> Accessor for HybridAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type Element<'a>
        = distances::pq::Hybrid<Vec<T>, Vec<u8>>
    where
        Self: 'a;
    type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>;
    type GetError = Panics;
    /// Single-element lookups always return the full-precision variant; the
    /// quantized fallback only comes into play in the batched `fill` path.
    fn get_element(
        &mut self,
        id: Self::Id,
    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
        #[allow(clippy::expect_used)]
        std::future::ready(Ok(distances::pq::Hybrid::Full(
            self.provider
                .full_vectors
                .get_vector_sync(id.into_usize())
                .expect("Full vector provider failed to retrieve element"),
        )))
    }
}
impl<T, D> BuildDistanceComputer for HybridAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type DistanceComputerError = ANNError;
    type DistanceComputer = distances::pq::HybridComputer<T>;
    /// Pairs a PQ computer (for quantized operands) with a full-precision
    /// computer, both using the same metric.
    fn build_distance_computer(
        &self,
    ) -> Result<Self::DistanceComputer, Self::DistanceComputerError> {
        let metric = self.provider.quant_vectors.metric();
        Ok(distances::pq::HybridComputer::new(
            self.provider.quant_vectors.distance_computer(),
            T::distance(metric, Some(self.provider.full_vectors.dim())),
        ))
    }
}
/// Populates a working-set map with hybrid elements: the first
/// `max_fp_vecs_per_fill` candidates get full-precision vectors, the rest get
/// quantized codes. Already-full entries are never downgraded.
impl<T, D> workingset::Fill<distances::pq::HybridMap<T, u8>> for HybridAccessor<'_, T, D>
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type Error = ANNError;
    type View<'a>
        = distances::pq::View<'a, T, u8>
    where
        Self: 'a;
    async fn fill<'a, Itr>(
        &'a mut self,
        state: &'a mut distances::pq::HybridMap<T, u8>,
        itr: Itr,
    ) -> Result<Self::View<'a>, Self::Error>
    where
        Itr: ExactSizeIterator<Item = Self::Id> + Clone + Send + Sync,
        Self: 'a,
    {
        let map = state.get_mut();
        // Let the map pre-reserve / pre-key itself for this batch of ids.
        map.prepare(itr.clone());
        let threshold = self.provider.max_fp_vecs_per_fill;
        itr.enumerate().try_for_each(|(i, id)| -> ANNResult<()> {
            match map.entry(id) {
                // Seeded entries were supplied externally; leave them alone.
                workingset::map::Entry::Seeded(_) => {}
                workingset::map::Entry::Occupied(occupied) => {
                    // Upgrade an existing quantized entry to full precision
                    // while within the per-fill budget.
                    if i < threshold && !occupied.get().is_full() {
                        *occupied.into_mut() = distances::pq::Hybrid::Full(
                            self.provider
                                .full_vectors
                                .get_vector_sync(id.into_usize())?,
                        );
                    }
                }
                workingset::map::Entry::Vacant(vacant) => {
                    // New entries: full precision within budget, quantized after.
                    let element = if i < threshold {
                        let vec = self
                            .provider
                            .full_vectors
                            .get_vector_sync(id.into_usize())?;
                        distances::pq::Hybrid::Full(vec)
                    } else {
                        let vec = self
                            .provider
                            .quant_vectors
                            .get_vector_sync(id.into_usize())?;
                        distances::pq::Hybrid::Quant(vec)
                    };
                    vacant.insert(element);
                }
            }
            Ok(())
        })?;
        Ok(map.view())
    }
}
/// Full-precision search: traverse and score with the full vector store.
impl<T, Q, D> SearchStrategy<BfTreeProvider<T, Q, D>, &[T]> for FullPrecision
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
{
    type QueryComputer = T::QueryDistance;
    type SearchAccessor<'a> = FullAccessor<'a, T, Q, D>;
    type SearchAccessorError = Panics;
    fn search_accessor<'a>(
        &'a self,
        provider: &'a BfTreeProvider<T, Q, D>,
        _context: &'a DefaultContext,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        Ok(FullAccessor::new(provider))
    }
}
/// Default post-processing: drop start points, then strip deleted ids while
/// copying results out.
impl<T, Q, D> DefaultPostProcessor<BfTreeProvider<T, Q, D>, &[T]> for FullPrecision
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
{
    default_post_processor!(glue::Pipeline<glue::FilterStartPoints, RemoveDeletedIdsAndCopy>);
}
/// Post-processor that re-scores quantized search candidates with
/// full-precision distances before emitting them.
#[derive(Debug, Default, Clone, Copy)]
pub struct Rerank;
impl<'a, T, D> glue::SearchPostProcess<QuantAccessor<'a, T, D>, &[T]> for Rerank
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
{
    type Error = Panics;
    /// Drop deleted candidates, recompute each survivor's similarity to
    /// `query` from the full-precision store, sort ascending by that score,
    /// and copy the result into `output` (returning the number written).
    ///
    /// Panics (via `expect`) if a full-precision vector cannot be fetched.
    fn post_process<I, B>(
        &self,
        accessor: &mut QuantAccessor<'a, T, D>,
        query: &[T],
        _computer: &pq::distance::QueryComputer<Arc<FixedChunkPQTable>>,
        candidates: I,
        output: &mut B,
    ) -> impl Future<Output = Result<usize, Self::Error>> + Send
    where
        I: Iterator<Item = Neighbor<u32>>,
        B: SearchOutputBuffer<u32> + ?Sized,
    {
        let provider = &accessor.provider;
        let checker = accessor.as_deletion_check();
        let f = T::distance(provider.metric, Some(provider.full_vectors.dim()));
        let mut reranked: Vec<(u32, f32)> = candidates
            .filter_map(|n| {
                if checker.deletion_check(n.id) {
                    None
                } else {
                    #[allow(clippy::expect_used)]
                    let vec = provider
                        .full_vectors
                        .get_vector_sync(n.id.into_usize())
                        .expect("Full vector provider failed to retrieve element");
                    Some((n.id, f.evaluate_similarity(query, &vec)))
                }
            })
            .collect();
        // `total_cmp` gives the total order `sort_unstable_by` requires.
        // The previous `partial_cmp(..).unwrap_or(Equal)` comparator is not
        // total when a NaN score appears, which `sort_unstable_by` documents
        // may panic or produce an inconsistent order; for non-NaN inputs the
        // resulting order is unchanged.
        reranked.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        std::future::ready(Ok(output.extend(reranked)))
    }
}
/// Hybrid search: traverse the graph scoring with quantized codes.
impl<T, D> SearchStrategy<BfTreeProvider<T, QuantVectorProvider, D>, &[T]> for Hybrid
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
{
    type QueryComputer = pq::distance::QueryComputer<Arc<FixedChunkPQTable>>;
    type SearchAccessor<'a> = QuantAccessor<'a, T, D>;
    type SearchAccessorError = Panics;
    fn search_accessor<'a>(
        &'a self,
        provider: &'a BfTreeProvider<T, QuantVectorProvider, D>,
        _context: &'a DefaultContext,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        Ok(QuantAccessor::new(provider))
    }
}
/// Default post-processing: drop start points, then rerank the survivors with
/// full-precision distances.
impl<T, D> DefaultPostProcessor<BfTreeProvider<T, QuantVectorProvider, D>, &[T]> for Hybrid
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
{
    default_post_processor!(glue::Pipeline<glue::FilterStartPoints, Rerank>);
}
/// Pruning with full-precision vectors only.
impl<T, Q, D> PruneStrategy<BfTreeProvider<T, Q, D>> for FullPrecision
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly,
{
    // Working set caches owned full-precision vectors keyed by id.
    type WorkingSet = map::Map<u32, Box<[T]>, map::Ref<[T]>>;
    type DistanceComputer<'a> = T::Distance;
    type PruneAccessor<'a> = FullAccessor<'a, T, Q, D>;
    type PruneAccessorError = diskann::error::Infallible;
    fn prune_accessor<'a>(
        &'a self,
        provider: &'a BfTreeProvider<T, Q, D>,
        _context: &'a DefaultContext,
    ) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError> {
        Ok(FullAccessor::new(provider))
    }
    fn create_working_set(&self, capacity: usize) -> Self::WorkingSet {
        map::Builder::new(map::Capacity::Default).build(capacity)
    }
}
/// Pruning with mixed full-precision / quantized vectors.
impl<T, D> PruneStrategy<BfTreeProvider<T, QuantVectorProvider, D>> for Hybrid
where
    T: VectorRepr,
    D: AsyncFriendly,
{
    type WorkingSet = distances::pq::HybridMap<T, u8>;
    type DistanceComputer<'a> = distances::pq::HybridComputer<T>;
    type PruneAccessor<'a> = HybridAccessor<'a, T, D>;
    type PruneAccessorError = diskann::error::Infallible;
    fn prune_accessor<'a>(
        &'a self,
        provider: &'a BfTreeProvider<T, QuantVectorProvider, D>,
        _context: &'a DefaultContext,
    ) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError> {
        Ok(HybridAccessor::new(provider))
    }
    fn create_working_set(&self, capacity: usize) -> Self::WorkingSet {
        distances::pq::HybridMap::with_capacity(capacity)
    }
}
/// Insertion prunes with the same (full-precision) strategy.
impl<T, Q, D> InsertStrategy<BfTreeProvider<T, Q, D>, &[T]> for FullPrecision
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
{
    type PruneStrategy = Self;
    fn prune_strategy(&self) -> Self::PruneStrategy {
        *self
    }
}
/// Insertion prunes with the same (hybrid) strategy.
impl<T, D> InsertStrategy<BfTreeProvider<T, QuantVectorProvider, D>, &[T]> for Hybrid
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
{
    type PruneStrategy = Self;
    fn prune_strategy(&self) -> Self::PruneStrategy {
        *self
    }
}
/// Batched insertion, full-precision flavor: after a batch is inserted, seed
/// future working sets with an overlay over the batch's elements so they are
/// served from memory instead of the store.
impl<T, Q, D, B> MultiInsertStrategy<BfTreeProvider<T, Q, D>, B> for FullPrecision
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
    B: for<'a> Batch<Element<'a> = &'a [T]> + Debug,
{
    type Seed = map::Builder<u32, map::Ref<[T]>>;
    type WorkingSet = map::Map<u32, Box<[T]>, map::Ref<[T]>>;
    type FinishError = diskann::error::Infallible;
    type InsertStrategy = Self;
    fn insert_strategy(&self) -> Self::InsertStrategy {
        *self
    }
    fn finish<Itr>(
        &self,
        _provider: &BfTreeProvider<T, Q, D>,
        _ctx: &DefaultContext,
        batch: &std::sync::Arc<B>,
        ids: Itr,
    ) -> impl std::future::Future<Output = Result<Self::Seed, Self::FinishError>> + Send
    where
        Itr: ExactSizeIterator<Item = u32> + Send,
    {
        let overlay = map::Overlay::from_batch(batch.clone(), ids);
        let builder = map::Builder::new(map::Capacity::Default).with_overlay(overlay);
        std::future::ready(Ok(builder))
    }
}
/// Batched insertion, hybrid flavor: same overlay idea over hybrid elements.
impl<T, D, B> MultiInsertStrategy<BfTreeProvider<T, QuantVectorProvider, D>, B> for Hybrid
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
    B: for<'a> Batch<Element<'a> = &'a [T]> + Debug,
{
    type Seed = distances::pq::Overlay<T, u8>;
    type WorkingSet = distances::pq::HybridMap<T, u8>;
    type FinishError = diskann::error::Infallible;
    type InsertStrategy = Self;
    fn insert_strategy(&self) -> Self::InsertStrategy {
        *self
    }
    fn finish<Itr>(
        &self,
        _provider: &BfTreeProvider<T, QuantVectorProvider, D>,
        _ctx: &DefaultContext,
        batch: &std::sync::Arc<B>,
        ids: Itr,
    ) -> impl std::future::Future<Output = Result<Self::Seed, Self::FinishError>> + Send
    where
        Itr: ExactSizeIterator<Item = u32> + Send,
    {
        let overlay = Self::Seed::from_batch(batch.clone(), ids);
        std::future::ready(Ok(overlay))
    }
}
/// In-place delete, full-precision flavor: search, prune, and post-process all
/// use the full-precision strategy.
impl<T, Q, D> InplaceDeleteStrategy<BfTreeProvider<T, Q, D>> for FullPrecision
where
    T: VectorRepr,
    Q: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
{
    type DeleteElementError = Panics;
    type DeleteElement<'a> = &'a [T];
    // Owned copy of the deleted element, detached from the store.
    type DeleteElementGuard = Box<[T]>;
    type PruneStrategy = Self;
    type DeleteSearchAccessor<'a> = FullAccessor<'a, T, Q, D>;
    type SearchPostProcessor = RemoveDeletedIdsAndCopy;
    type SearchStrategy = Self;
    fn search_strategy(&self) -> Self::SearchStrategy {
        Self
    }
    fn prune_strategy(&self) -> Self::PruneStrategy {
        Self
    }
    fn search_post_processor(&self) -> Self::SearchPostProcessor {
        RemoveDeletedIdsAndCopy
    }
    /// Fetch the full-precision vector of the element being deleted; panics
    /// (via `expect`) if the store cannot produce it.
    async fn get_delete_element<'a>(
        &'a self,
        provider: &'a BfTreeProvider<T, Q, D>,
        _context: &'a DefaultContext,
        id: u32,
    ) -> Result<Self::DeleteElementGuard, Self::DeleteElementError> {
        #[allow(clippy::expect_used)]
        let elt = provider
            .full_vectors
            .get_vector_sync(id.into_usize())
            .expect("Failed to get delete element")
            .into();
        Ok(elt)
    }
}
/// In-place delete, hybrid flavor: searches with quantized codes, reranks with
/// full precision.
impl<T, D> InplaceDeleteStrategy<BfTreeProvider<T, QuantVectorProvider, D>> for Hybrid
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
{
    type DeleteElementError = Panics;
    type DeleteElement<'a> = &'a [T];
    type DeleteElementGuard = Box<[T]>;
    type PruneStrategy = Self;
    type DeleteSearchAccessor<'a> = QuantAccessor<'a, T, D>;
    type SearchPostProcessor = Rerank;
    type SearchStrategy = Self;
    fn search_strategy(&self) -> Self::SearchStrategy {
        *self
    }
    fn prune_strategy(&self) -> Self::PruneStrategy {
        *self
    }
    fn search_post_processor(&self) -> Self::SearchPostProcessor {
        Rerank
    }
    /// The deleted element is always fetched at full precision, even though
    /// the search side of this strategy uses quantized codes.
    async fn get_delete_element<'a>(
        &'a self,
        provider: &'a BfTreeProvider<T, QuantVectorProvider, D>,
        _context: &'a DefaultContext,
        id: u32,
    ) -> Result<Self::DeleteElementGuard, Self::DeleteElementError> {
        #[allow(clippy::expect_used)]
        let elt = provider
            .full_vectors
            .get_vector_sync(id.into_usize())
            .expect("Failed to get delete element")
            .into();
        Ok(elt)
    }
}
/// Serializable sizing parameters for one Bf-Tree store.
#[derive(Serialize, Deserialize, Clone)]
pub struct BfTreeParams {
    // Circular-buffer size in bytes.
    pub bytes: usize,
    pub max_record_size: usize,
    pub leaf_page_size: usize,
}
impl BfTreeParams {
    /// Materialize these parameters into a Bf-Tree [`Config`] rooted at
    /// `path`, selecting the memory or standard (disk) storage backend.
    pub fn to_config(&self, path: &std::path::Path, is_memory: bool) -> Config {
        let backend = if is_memory {
            bf_tree::StorageBackend::Memory
        } else {
            bf_tree::StorageBackend::Std
        };
        let mut config = Config::new(path, self.bytes);
        config.cb_max_record_size(self.max_record_size);
        config.leaf_page_size(self.leaf_page_size);
        config.storage_backend(backend);
        config
    }
}
/// Quantized-store parameters, saved only when quantization is enabled.
#[derive(Serialize, Deserialize, Clone)]
pub struct QuantParams {
    pub num_pq_bytes: usize,
    pub max_fp_vecs_per_fill: usize,
    pub params_quant: BfTreeParams,
}
/// Everything needed to reopen a saved index
/// (serialized to `<prefix>_params.json`).
#[derive(Serialize, Deserialize, Clone)]
pub struct SavedParams {
    pub max_points: usize,
    pub frozen_points: NonZeroUsize,
    pub dim: usize,
    // Metric name as produced by `Metric::as_str`.
    pub metric: String,
    pub max_degree: u32,
    pub prefix: String,
    pub params_vector: BfTreeParams,
    pub params_neighbor: BfTreeParams,
    // `None` when the index has no quantized store.
    pub quant_params: Option<QuantParams>,
    pub graph_params: Option<GraphParams>,
    pub is_memory: bool,
}
/// Element dtype tag, serialized lowercase ("f32", "f16", "u8", "i8").
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum VectorDtype {
    F32,
    F16,
    U8,
    I8,
}
/// Maps a Rust element type to its [`VectorDtype`] tag at compile time.
pub trait AsVectorDtype {
    const DATA_TYPE: VectorDtype;
}
impl AsVectorDtype for f32 {
    const DATA_TYPE: VectorDtype = VectorDtype::F32;
}
impl AsVectorDtype for half::f16 {
    const DATA_TYPE: VectorDtype = VectorDtype::F16;
}
impl AsVectorDtype for i8 {
    const DATA_TYPE: VectorDtype = VectorDtype::I8;
}
impl AsVectorDtype for u8 {
    const DATA_TYPE: VectorDtype = VectorDtype::U8;
}
/// Graph-construction hyperparameters persisted alongside the index.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GraphParams {
    pub l_build: usize,
    pub alpha: f32,
    pub backedge_ratio: f32,
    pub vector_dtype: VectorDtype,
}
/// Well-known on-disk file names, all derived from a common `prefix`.
pub struct BfTreePaths;
impl BfTreePaths {
    /// Concatenate `prefix` with a fixed suffix.
    fn with_suffix(prefix: &str, suffix: &str) -> String {
        format!("{prefix}{suffix}")
    }
    /// `<prefix>_params.json` — serialized [`SavedParams`].
    pub fn params_json(prefix: &str) -> String {
        Self::with_suffix(prefix, "_params.json")
    }
    /// `<prefix>_vectors.bftree` — full-precision vector store snapshot.
    pub fn vectors_bftree(prefix: &str) -> std::path::PathBuf {
        Self::with_suffix(prefix, "_vectors.bftree").into()
    }
    /// `<prefix>_neighbors.bftree` — adjacency-list store snapshot.
    pub fn neighbors_bftree(prefix: &str) -> std::path::PathBuf {
        Self::with_suffix(prefix, "_neighbors.bftree").into()
    }
    /// `<prefix>_quant.bftree` — quantized vector store snapshot.
    pub fn quant_bftree(prefix: &str) -> std::path::PathBuf {
        Self::with_suffix(prefix, "_quant.bftree").into()
    }
    /// `<prefix>_delete.bin` — serialized deletion bitmap.
    pub fn delete_bin(prefix: &str) -> String {
        Self::with_suffix(prefix, "_delete.bin")
    }
    /// `<prefix>_pq_pivots.bin` — PQ pivot table.
    pub fn pq_pivots_bin(prefix: &str) -> String {
        Self::with_suffix(prefix, "_pq_pivots.bin")
    }
}
/// Copy `snapshot_path` to `target_path` on a blocking thread, skipping the
/// copy entirely when both arguments already name the same path.
async fn copy_snapshot_if_needed(
    snapshot_path: std::path::PathBuf,
    target_path: std::path::PathBuf,
) -> ANNResult<()> {
    if snapshot_path == target_path {
        return Ok(());
    }
    // `std::fs::copy` is blocking I/O; run it off the async executor.
    let join = tokio::task::spawn_blocking(move || {
        std::fs::copy(&snapshot_path, &target_path).map_err(|e| {
            ANNError::log_index_error(format!(
                "Failed to copy snapshot from {:?} to {:?}: {}",
                snapshot_path, target_path, e
            ))
        })
    });
    // Outer `?` covers task-join failure, inner `?` the copy itself.
    join.await
        .map_err(|e| ANNError::log_index_error(format!("Blocking copy task failed: {}", e)))??;
    Ok(())
}
/// Persist a Bf-Tree to `target_path`: memory-backed trees are dumped
/// directly; disk-backed trees snapshot first and the snapshot is copied if
/// it landed elsewhere.
async fn save_bftree(tree: &BfTree, target_path: std::path::PathBuf) -> ANNResult<()> {
    if tree.config().is_memory_backend() {
        // NOTE(review): the result of this call (if any) is not checked —
        // confirm `snapshot_memory_to_disk` cannot fail or reports errors itself.
        tree.snapshot_memory_to_disk(&target_path);
    } else {
        let snapshot_path = tree.snapshot();
        copy_snapshot_if_needed(snapshot_path, target_path).await?;
    }
    Ok(())
}
/// Reopen a Bf-Tree from a snapshot, either loading it fully into memory
/// (`is_memory`) or attaching to it on disk.
fn load_bftree(
    params: &BfTreeParams,
    snapshot_path: std::path::PathBuf,
    is_memory: bool,
) -> Result<BfTree, ANNError> {
    let config = params.to_config(&snapshot_path, is_memory);
    if is_memory {
        BfTree::new_from_snapshot_disk_to_memory(snapshot_path, config)
            .map_err(|e| ANNError::from(super::ConfigError(e)))
    } else {
        BfTree::new_from_snapshot(config, None).map_err(|e| ANNError::from(super::ConfigError(e)))
    }
}
/// Persist a no-quantization provider under `prefix`: a params JSON, two
/// Bf-Tree snapshots (vectors, neighbors), and the deletion bitmap.
impl<T> SaveWith<String> for BfTreeProvider<T, NoStore, TableDeleteProviderAsync>
where
    T: VectorRepr,
{
    type Ok = usize;
    type Error = ANNError;
    async fn save_with<P>(&self, storage: &P, prefix: &String) -> Result<Self::Ok, Self::Error>
    where
        P: StorageWriteProvider,
    {
        let saved_params = SavedParams {
            max_points: self.max_points(),
            frozen_points: NonZeroUsize::new(self.num_start_points())
                .ok_or_else(|| ANNError::log_index_error("num_start_points is zero"))?,
            dim: self.dim(),
            metric: self.metric().as_str().to_string(),
            max_degree: self.max_degree(),
            prefix: prefix.clone(),
            params_vector: BfTreeParams {
                bytes: self.full_vectors.config().get_cb_size_byte(),
                max_record_size: self.full_vectors.config().get_cb_max_record_size(),
                leaf_page_size: self.full_vectors.config().get_leaf_page_size(),
            },
            params_neighbor: BfTreeParams {
                bytes: self.neighbor_provider.config().get_cb_size_byte(),
                max_record_size: self.neighbor_provider.config().get_cb_max_record_size(),
                leaf_page_size: self.neighbor_provider.config().get_leaf_page_size(),
            },
            // No quantized store in this specialization.
            quant_params: None,
            graph_params: self.graph_params.clone(),
            is_memory: self.full_vectors.config().is_memory_backend(),
        };
        debug_assert_eq!(
            self.full_vectors.config().is_memory_backend(),
            self.neighbor_provider.config().is_memory_backend(),
            "Vector and neighbor stores have mismatched storage backends"
        );
        {
            // Params JSON goes through the abstract storage provider ...
            let params_filename = BfTreePaths::params_json(&saved_params.prefix);
            let params_json = serde_json::to_string(&saved_params).map_err(|e| {
                ANNError::log_index_error(format!("Failed to serialize params: {}", e))
            })?;
            let mut params_writer = storage.create_for_write(&params_filename)?;
            params_writer.write_all(params_json.as_bytes())?;
        }
        // ... while the Bf-Tree snapshots are written directly to the local
        // filesystem, bypassing `storage`.
        // NOTE(review): confirm that is intended when `storage` is not local.
        save_bftree(
            self.full_vectors.bftree(),
            BfTreePaths::vectors_bftree(&saved_params.prefix),
        )
        .await?;
        save_bftree(
            self.neighbor_provider.bftree(),
            BfTreePaths::neighbors_bftree(&saved_params.prefix),
        )
        .await?;
        {
            // Deletion bitmap also goes through the storage provider.
            let filename = BfTreePaths::delete_bin(&saved_params.prefix);
            let bitmap_bytes = self.deleted.to_bytes();
            let mut writer = storage.create_for_write(&filename)?;
            writer.write_all(&bitmap_bytes)?;
        }
        Ok(0)
    }
}
impl<T> LoadWith<String> for BfTreeProvider<T, NoStore, TableDeleteProviderAsync>
where
    T: VectorRepr,
{
    type Error = ANNError;

    /// Rebuilds a quant-less provider previously written by `save_with`.
    ///
    /// `prefix` is only used to locate the params JSON; every other artifact
    /// path is derived from the prefix stored inside that file.
    ///
    /// # Errors
    /// Fails if the params file is missing or malformed, the metric string
    /// cannot be parsed, a snapshot cannot be opened, or the delete bitmap is
    /// corrupt.
    async fn load_with<P>(storage: &P, prefix: &String) -> Result<Self, Self::Error>
    where
        P: StorageReadProvider,
    {
        let saved_params: SavedParams = {
            let params_filename = BfTreePaths::params_json(prefix);
            let mut params_reader = storage.open_reader(&params_filename)?;
            let mut params_json = String::new();
            params_reader.read_to_string(&mut params_json)?;
            serde_json::from_str(&params_json).map_err(|e| {
                ANNError::log_index_error(format!("Failed to deserialize params: {}", e))
            })?
        };
        let metric = Metric::from_str(&saved_params.metric)
            .map_err(|e| ANNError::log_index_error(format!("Failed to parse metric: {}", e)))?;
        let vector_index = load_bftree(
            &saved_params.params_vector,
            BfTreePaths::vectors_bftree(&saved_params.prefix),
            saved_params.is_memory,
        )?;
        let full_vectors = VectorProvider::<T>::new_from_bftree(
            saved_params.max_points,
            saved_params.dim,
            saved_params.frozen_points.get(),
            vector_index,
        );
        let adjacency_list_index = load_bftree(
            &saved_params.params_neighbor,
            BfTreePaths::neighbors_bftree(&saved_params.prefix),
            saved_params.is_memory,
        )?;
        let neighbor_provider =
            NeighborProvider::<u32>::new_from_bftree(saved_params.max_degree, adjacency_list_index);
        // The delete bitmap covers user points plus frozen start points.
        let total_points = saved_params.max_points + saved_params.frozen_points.get();
        let filename = BfTreePaths::delete_bin(&saved_params.prefix);
        // A missing bitmap file is treated as "no deletions".
        let deleted = if storage.exists(&filename) {
            let mut reader = storage.open_reader(&filename)?;
            let mut bitmap_bytes = Vec::new();
            reader.read_to_end(&mut bitmap_bytes)?;
            TableDeleteProviderAsync::from_bytes(&bitmap_bytes, total_points)
                .map_err(ANNError::log_index_error)?
        } else {
            TableDeleteProviderAsync::new(total_points)
        };
        Ok(Self {
            quant_vectors: NoStore,
            full_vectors,
            neighbor_provider,
            deleted,
            // Only meaningful for the quantized variant; unused here.
            max_fp_vecs_per_fill: 0,
            metric,
            graph_params: saved_params.graph_params,
        })
    }
}
impl<T> SaveWith<String> for BfTreeProvider<T, QuantVectorProvider, TableDeleteProviderAsync>
where
    T: VectorRepr,
{
    type Ok = usize;
    type Error = ANNError;

    /// Persists a quantized provider under `prefix`: a JSON params file, the
    /// full-vector / neighbor / quantized-vector BfTree snapshots, the PQ
    /// pivot table, and the delete bitmap.
    ///
    /// # Errors
    /// Fails if the provider has zero start points, if serialization or any
    /// write fails, or if a snapshot cannot be produced/copied.
    async fn save_with<P>(&self, storage: &P, prefix: &String) -> Result<Self::Ok, Self::Error>
    where
        P: StorageWriteProvider,
    {
        // Capture everything load_with needs to rebuild the provider.
        let saved_params = SavedParams {
            max_points: self.max_points(),
            frozen_points: NonZeroUsize::new(self.num_start_points())
                .ok_or_else(|| ANNError::log_index_error("num_start_points is zero"))?,
            dim: self.dim(),
            metric: self.metric().as_str().to_string(),
            max_degree: self.max_degree(),
            // NOTE(review): the absolute prefix is baked into the params file
            // and reused by load_with for all artifact paths, so a saved index
            // cannot simply be relocated on disk — confirm this is intended.
            prefix: prefix.clone(),
            params_vector: BfTreeParams {
                bytes: self.full_vectors.config().get_cb_size_byte(),
                max_record_size: self.full_vectors.config().get_cb_max_record_size(),
                leaf_page_size: self.full_vectors.config().get_leaf_page_size(),
            },
            params_neighbor: BfTreeParams {
                bytes: self.neighbor_provider.config().get_cb_size_byte(),
                max_record_size: self.neighbor_provider.config().get_cb_max_record_size(),
                leaf_page_size: self.neighbor_provider.config().get_leaf_page_size(),
            },
            // Extra parameters required to reconstruct the quantized store.
            quant_params: Some(QuantParams {
                num_pq_bytes: self.quant_vectors.pq_chunks(),
                max_fp_vecs_per_fill: self.max_fp_vecs_per_fill,
                params_quant: BfTreeParams {
                    bytes: self.quant_vectors.config().get_cb_size_byte(),
                    max_record_size: self.quant_vectors.config().get_cb_max_record_size(),
                    leaf_page_size: self.quant_vectors.config().get_leaf_page_size(),
                },
            }),
            graph_params: self.graph_params.clone(),
            is_memory: self.full_vectors.config().is_memory_backend(),
        };
        debug_assert_eq!(
            self.full_vectors.config().is_memory_backend(),
            self.neighbor_provider.config().is_memory_backend(),
            "Vector and neighbor stores have mismatched storage backends"
        );
        debug_assert_eq!(
            self.full_vectors.config().is_memory_backend(),
            self.quant_vectors.config().is_memory_backend(),
            "Vector and quant stores have mismatched storage backends"
        );
        // Write the parameter file first.
        {
            let params_filename = BfTreePaths::params_json(&saved_params.prefix);
            let params_json = serde_json::to_string(&saved_params).map_err(|e| {
                ANNError::log_index_error(format!("Failed to serialize params: {}", e))
            })?;
            let mut params_writer = storage.create_for_write(&params_filename)?;
            params_writer.write_all(params_json.as_bytes())?;
        }
        // Snapshot all three BfTrees next to the prefix.
        save_bftree(
            self.full_vectors.bftree(),
            BfTreePaths::vectors_bftree(&saved_params.prefix),
        )
        .await?;
        save_bftree(
            self.neighbor_provider.bftree(),
            BfTreePaths::neighbors_bftree(&saved_params.prefix),
        )
        .await?;
        save_bftree(
            self.quant_vectors.bftree(),
            BfTreePaths::quant_bftree(&saved_params.prefix),
        )
        .await?;
        // Persist the PQ pivot table so the quantizer can be restored exactly.
        let filename = BfTreePaths::pq_pivots_bin(&saved_params.prefix);
        let pq_storage = PQStorage::new(&filename, "", None);
        let pq_table = &self.quant_vectors.pq_chunk_table;
        pq_storage.write_pivot_data(
            pq_table.get_pq_table(),
            pq_table.get_centroids(),
            pq_table.get_chunk_offsets(),
            NUM_PQ_CENTROIDS,
            pq_table.get_dim(),
            storage,
        )?;
        // Persist the delete bitmap so deletions survive a reload.
        {
            let filename = BfTreePaths::delete_bin(&saved_params.prefix);
            let bitmap_bytes = self.deleted.to_bytes();
            let mut writer = storage.create_for_write(&filename)?;
            writer.write_all(&bitmap_bytes)?;
        }
        // The byte count is not tracked; callers only check for success.
        Ok(0)
    }
}
impl<T> LoadWith<String> for BfTreeProvider<T, QuantVectorProvider, TableDeleteProviderAsync>
where
    T: VectorRepr,
{
    type Error = ANNError;

    /// Rebuilds a quantized provider previously written by `save_with`.
    ///
    /// `prefix` is only used to locate the params JSON; every other artifact
    /// path is derived from the prefix stored inside that file.
    ///
    /// # Errors
    /// Fails if the params file is missing, malformed, or lacks the quant
    /// section, if the metric cannot be parsed, if any snapshot or the PQ
    /// pivot table cannot be opened, or if the delete bitmap is corrupt.
    async fn load_with<P>(storage: &P, prefix: &String) -> Result<Self, Self::Error>
    where
        P: StorageReadProvider,
    {
        let saved_params: SavedParams = {
            let params_filename = BfTreePaths::params_json(prefix);
            let mut params_reader = storage.open_reader(&params_filename)?;
            let mut params_json = String::new();
            params_reader.read_to_string(&mut params_json)?;
            serde_json::from_str(&params_json).map_err(|e| {
                ANNError::log_index_error(format!("Failed to deserialize params: {}", e))
            })?
        };
        // The quant section is mandatory for this specialization; its absence
        // means the index was saved without a quantized store.
        let quant_params = saved_params.quant_params.ok_or_else(|| {
            ANNError::log_index_error("Missing quant_params in saved params for quantized provider")
        })?;
        let metric = Metric::from_str(&saved_params.metric)
            .map_err(|e| ANNError::log_index_error(format!("Failed to parse metric: {}", e)))?;
        let vector_index = load_bftree(
            &saved_params.params_vector,
            BfTreePaths::vectors_bftree(&saved_params.prefix),
            saved_params.is_memory,
        )?;
        let full_vectors = VectorProvider::<T>::new_from_bftree(
            saved_params.max_points,
            saved_params.dim,
            saved_params.frozen_points.get(),
            vector_index,
        );
        let adjacency_list_index = load_bftree(
            &saved_params.params_neighbor,
            BfTreePaths::neighbors_bftree(&saved_params.prefix),
            saved_params.is_memory,
        )?;
        let neighbor_provider =
            NeighborProvider::<u32>::new_from_bftree(saved_params.max_degree, adjacency_list_index);
        // Restore the PQ pivot table before rebuilding the quantized store.
        let filename = BfTreePaths::pq_pivots_bin(&saved_params.prefix);
        let pq_storage = PQStorage::new(&filename, "", None);
        let pq_table =
            pq_storage.load_pq_pivots_bin(&filename, quant_params.num_pq_bytes, storage)?;
        let quant_vector_index = load_bftree(
            &quant_params.params_quant,
            BfTreePaths::quant_bftree(&saved_params.prefix),
            saved_params.is_memory,
        )?;
        // `pq_table` is moved here — this is its last use, so no clone needed.
        let quant_vectors = QuantVectorProvider::new_from_bftree(
            metric,
            saved_params.max_points,
            saved_params.frozen_points.get(),
            pq_table,
            quant_vector_index,
        );
        // The delete bitmap covers user points plus frozen start points.
        let total_points = saved_params.max_points + saved_params.frozen_points.get();
        let filename = BfTreePaths::delete_bin(&saved_params.prefix);
        // A missing bitmap file is treated as "no deletions".
        let deleted = if storage.exists(&filename) {
            let mut reader = storage.open_reader(&filename)?;
            let mut bitmap_bytes = Vec::new();
            reader.read_to_end(&mut bitmap_bytes)?;
            TableDeleteProviderAsync::from_bytes(&bitmap_bytes, total_points)
                .map_err(ANNError::log_index_error)?
        } else {
            TableDeleteProviderAsync::new(total_points)
        };
        Ok(Self {
            quant_vectors,
            full_vectors,
            neighbor_provider,
            deleted,
            max_fp_vecs_per_fill: quant_params.max_fp_vecs_per_fill,
            metric,
            graph_params: saved_params.graph_params,
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::model::graph::provider::async_::common::TableBasedDeletes;
use crate::storage::file_storage_provider::FileStorageProvider;
/// Exercises the id-mapping and delete/release lifecycle of a provider with
/// no quantized store: internal/external ids are the identity mapping,
/// `delete` flips status to `Deleted` in both id spaces, `release` restores
/// `Valid` and leaves the stored neighbor list empty, `clear_delete_set`
/// resets all statuses, and `set_element` rejects an out-of-range id.
#[tokio::test]
async fn test_data_provider_and_delete_interface() {
    let ctx = &DefaultContext;
    let provider = BfTreeProvider::new_empty(
        BfTreeProviderParameters {
            max_points: 10,
            num_start_points: NonZeroUsize::new(2).unwrap(),
            dim: 5,
            metric: Metric::L2,
            max_fp_vecs_per_fill: None,
            max_degree: 64,
            vector_provider_config: Config::default(),
            quant_vector_provider_config: Config::default(),
            neighbor_list_provider_config: Config::default(),
            graph_params: None,
        },
        NoStore,
        TableBasedDeletes,
    )
    .unwrap();
    // The iteration space is max_points plus the frozen start points.
    assert_eq!((&provider).into_iter(), 0..(10 + 2));
    let iter = provider.iter();
    for i in iter.clone() {
        // Ids map to themselves in both directions.
        assert_eq!(provider.to_external_id(ctx, i).unwrap(), i);
        assert_eq!(provider.to_internal_id(ctx, &i).unwrap(), i);
        assert_eq!(
            provider.status_by_internal_id(ctx, i).await.unwrap(),
            ElementStatus::Valid
        );
        assert_eq!(
            provider.status_by_external_id(ctx, &i).await.unwrap(),
            ElementStatus::Valid
        );
        // Deleting is observable through both id spaces.
        provider.delete(ctx, &i).await.unwrap();
        assert_eq!(
            provider.status_by_internal_id(ctx, i).await.unwrap(),
            ElementStatus::Deleted
        );
        assert_eq!(
            provider.status_by_external_id(ctx, &i).await.unwrap(),
            ElementStatus::Deleted
        );
    }
    for i in iter.clone() {
        // Give the (deleted) slot some neighbors, then release it.
        provider
            .neighbor_provider
            .set_neighbors(i, &[1, 2])
            .unwrap();
        provider.release(ctx, i).await.unwrap();
        assert_eq!(
            provider.status_by_internal_id(ctx, i).await.unwrap(),
            ElementStatus::Valid
        );
        assert_eq!(
            provider.status_by_external_id(ctx, &i).await.unwrap(),
            ElementStatus::Valid
        );
        // After release the stored neighbor list reads back empty.
        let mut neighbors = AdjacencyList::new();
        provider
            .neighbor_provider
            .get_neighbors(i, &mut neighbors)
            .unwrap();
        assert!(neighbors.to_vec().is_empty());
        provider.delete(ctx, &i).await.unwrap();
    }
    // Clearing the delete set makes every id Valid again.
    provider.clear_delete_set();
    for i in iter.clone() {
        assert_eq!(
            provider.status_by_internal_id(ctx, i).await.unwrap(),
            ElementStatus::Valid
        );
        assert_eq!(
            provider.status_by_external_id(ctx, &i).await.unwrap(),
            ElementStatus::Valid
        );
    }
    // Id 100 is outside the 0..12 slot range, so inserting must fail.
    assert!(
        provider
            .set_element(ctx, &100, &[1.0, 2.0, 3.0, 4.0])
            .await
            .is_err()
    );
}
/// Verifies neighbor-list semantics through the async accessor: reading an
/// id with no list stored is an error, an explicitly stored empty list reads
/// back empty, lists round-trip, overwriting with an empty list clears, and
/// reading an out-of-range id fails while leaving the output buffer cleared.
#[tokio::test]
async fn test_empty_neighbor_list() {
    let num_points = 100u32;
    let ctx = &DefaultContext;
    let provider = BfTreeProvider::<f32, _, _>::new_empty(
        BfTreeProviderParameters {
            max_points: num_points as usize,
            num_start_points: NonZeroUsize::new(2).unwrap(),
            dim: 3,
            metric: Metric::L2,
            max_fp_vecs_per_fill: None,
            max_degree: 64,
            vector_provider_config: Config::default(),
            quant_vector_provider_config: Config::default(),
            neighbor_list_provider_config: Config::default(),
            graph_params: None,
        },
        NoStore,
        TableBasedDeletes,
    )
    .unwrap();
    let neighbor_accessor = &mut provider.neighbors();
    for i in 0..num_points {
        let vector = vec![i as f32, (i + 1) as f32, (i + 2) as f32];
        provider.set_element(ctx, &i, &vector).await.unwrap();
        let mut out = AdjacencyList::new();
        // No list has been written yet: the read is an error, not empty.
        assert!(neighbor_accessor.get_neighbors(i, &mut out).await.is_err());
        // An explicitly stored empty list is distinguishable: it reads back Ok.
        neighbor_accessor.set_neighbors(i, &[]).await.unwrap();
        neighbor_accessor.get_neighbors(i, &mut out).await.unwrap();
        assert!(out.is_empty());
    }
    for i in 0..num_points {
        let mut out = AdjacencyList::new();
        let neighbors = vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100];
        neighbor_accessor
            .set_neighbors(i, &neighbors)
            .await
            .unwrap();
        neighbor_accessor.get_neighbors(i, &mut out).await.unwrap();
        assert_eq!(&*out, &[10, 20, 30, 40, 50, 60, 70, 80, 90, 100]);
        // Overwriting with an empty slice clears the stored list.
        neighbor_accessor.set_neighbors(i, &[]).await.unwrap();
        neighbor_accessor.get_neighbors(i, &mut out).await.unwrap();
        assert!(out.is_empty());
    }
    // Out-of-range read fails and empties the (pre-populated) output buffer.
    let mut out = AdjacencyList::from_iter_untrusted([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]);
    assert!(
        neighbor_accessor
            .get_neighbors(200, &mut out)
            .await
            .is_err()
    );
    assert!(out.is_empty());
}
use tempfile::tempdir;
/// Round-trips a quant-less provider (Std disk backend) through
/// `save_with`/`load_with` and checks that vectors, neighbor lists, and the
/// delete bitmap all survive the reload unchanged.
#[tokio::test]
async fn test_bf_tree_provider_save_load_no_quant() {
    let num_points = 50usize;
    let dim = 4usize;
    let max_degree = 32u32;
    let num_start_points = NonZeroUsize::new(2).unwrap();
    let ctx = &DefaultContext;
    let temp_dir = tempdir().unwrap();
    let temp_path = temp_dir.path();
    let prefix = temp_path
        .join("test_bf_tree_provider")
        .to_string_lossy()
        .to_string();
    let vector_path = BfTreePaths::vectors_bftree(&prefix);
    let neighbor_path = BfTreePaths::neighbors_bftree(&prefix);
    // Non-default page/record sizes so the round-trip of BfTreeParams is
    // actually exercised (checked again below before saving).
    let bytes_vector = 1024 * 1024;
    let mut vector_config = Config::new(&vector_path, bytes_vector);
    vector_config.leaf_page_size(8192);
    vector_config.cb_max_record_size(1024);
    vector_config.storage_backend(bf_tree::StorageBackend::Std);
    let bytes_neighbor = 1024 * 1024;
    let mut neighbor_config = Config::new(&neighbor_path, bytes_neighbor);
    neighbor_config.storage_backend(bf_tree::StorageBackend::Std);
    let params = BfTreeProviderParameters {
        max_points: num_points,
        num_start_points,
        dim,
        metric: Metric::L2,
        max_fp_vecs_per_fill: None,
        max_degree,
        vector_provider_config: vector_config.clone(),
        quant_vector_provider_config: Config::default(),
        neighbor_list_provider_config: neighbor_config.clone(),
        graph_params: None,
    };
    let provider = BfTreeProvider::<f32, NoStore, TableDeleteProviderAsync>::new_empty(
        params.clone(),
        NoStore,
        TableBasedDeletes,
    )
    .unwrap();
    // Populate deterministic vectors: element j of point i is (i*dim+j)*0.1.
    for i in 0..num_points {
        let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.1).collect();
        provider
            .set_element(ctx, &(i as u32), &vector)
            .await
            .unwrap();
    }
    // Populate variable-length neighbor lists (length min(i, max_degree)).
    let neighbor_accessor = &mut provider.neighbors();
    for i in 0..num_points as u32 {
        let neighbors: Vec<u32> = (0..std::cmp::min(i, max_degree))
            .map(|j| (i + j) % num_points as u32)
            .collect();
        neighbor_accessor
            .set_neighbors(i, &neighbors)
            .await
            .unwrap();
    }
    // Mark a handful of points deleted before saving.
    let deleted_ids = vec![5u32, 10u32, 15u32, 20u32, 25u32];
    for id in &deleted_ids {
        provider.delete(ctx, id).await.unwrap();
        assert_eq!(
            provider.status_by_internal_id(ctx, *id).await.unwrap(),
            ElementStatus::Deleted
        );
    }
    assert_eq!(vector_config.get_leaf_page_size(), 8192);
    assert_eq!(vector_config.get_cb_max_record_size(), 1024);
    // Save to a second temp dir and load back.
    let storage = FileStorageProvider;
    let save_dir = tempdir().unwrap();
    let save_prefix = save_dir
        .path()
        .join("saved_bf_tree_provider")
        .to_string_lossy()
        .to_string();
    provider.save_with(&storage, &save_prefix).await.unwrap();
    let loaded_provider = BfTreeProvider::<f32, NoStore, TableDeleteProviderAsync>::load_with(
        &storage,
        &save_prefix,
    )
    .await
    .unwrap();
    // Full-precision vectors must match the originals exactly.
    for i in 0..num_points as u32 {
        let original = provider.full_vectors.get_vector_sync(i as usize).unwrap();
        let loaded = loaded_provider
            .full_vectors
            .get_vector_sync(i as usize)
            .unwrap();
        assert_eq!(original, loaded, "Vector mismatch at index {}", i);
    }
    // Neighbor lists must match the originals exactly.
    for i in 0..num_points as u32 {
        let mut original_list = AdjacencyList::new();
        let mut loaded_list = AdjacencyList::new();
        provider
            .neighbor_provider
            .get_neighbors(i, &mut original_list)
            .unwrap();
        loaded_provider
            .neighbor_provider
            .get_neighbors(i, &mut loaded_list)
            .unwrap();
        assert_eq!(
            &*original_list, &*loaded_list,
            "Neighbor list mismatch at index {}",
            i
        );
    }
    // Deleted statuses are preserved; all other points remain Valid.
    for id in &deleted_ids {
        assert_eq!(
            loaded_provider
                .status_by_internal_id(ctx, *id)
                .await
                .unwrap(),
            ElementStatus::Deleted,
            "Deletion status not preserved for id {}",
            id
        );
    }
    for i in 0..num_points as u32 {
        if !deleted_ids.contains(&i) {
            assert_eq!(
                loaded_provider.status_by_internal_id(ctx, i).await.unwrap(),
                ElementStatus::Valid,
                "Non-deleted vector {} incorrectly marked as deleted",
                i
            );
        }
    }
}
/// Round-trips a quantized provider (Std disk backend) through
/// `save_with`/`load_with` and checks that the PQ table, full and quantized
/// vectors, neighbor lists, and the delete bitmap all survive the reload.
#[tokio::test]
async fn test_bf_tree_provider_save_load_quant() {
    let num_points = 50usize;
    let dim = 8usize;
    let max_degree = 32u32;
    let num_start_points = NonZeroUsize::new(2).unwrap();
    let ctx = &DefaultContext;
    let temp_dir = tempdir().unwrap();
    let temp_path = temp_dir.path();
    let prefix = temp_path
        .join("test_bf_tree_provider_quant")
        .to_string_lossy()
        .to_string();
    let vector_path = BfTreePaths::vectors_bftree(&prefix);
    let neighbor_path = BfTreePaths::neighbors_bftree(&prefix);
    let quant_path = BfTreePaths::quant_bftree(&prefix);
    let bytes_vector = 1024 * 1024;
    let mut vector_config = Config::new(&vector_path, bytes_vector);
    vector_config.storage_backend(bf_tree::StorageBackend::Std);
    let bytes_neighbor = 1024 * 1024;
    let mut neighbor_config = Config::new(&neighbor_path, bytes_neighbor);
    neighbor_config.storage_backend(bf_tree::StorageBackend::Std);
    let bytes_quant = 1024 * 1024;
    let mut quant_config = Config::new(&quant_path, bytes_quant);
    quant_config.storage_backend(bf_tree::StorageBackend::Std);
    // Minimal all-zero PQ table; chunk offsets [0, 4, dim] presumably define
    // two chunks — the values only need to round-trip, not be meaningful.
    let pq_table = FixedChunkPQTable::new(
        dim,
        vec![0.0; dim * 256].into_boxed_slice(),
        vec![0.0; dim].into_boxed_slice(),
        Box::new([0, 4, dim]),
    )
    .unwrap();
    let params = BfTreeProviderParameters {
        max_points: num_points,
        num_start_points,
        dim,
        metric: Metric::L2,
        max_fp_vecs_per_fill: Some(10),
        max_degree,
        vector_provider_config: vector_config.clone(),
        quant_vector_provider_config: quant_config.clone(),
        neighbor_list_provider_config: neighbor_config.clone(),
        graph_params: None,
    };
    let provider =
        BfTreeProvider::<f32, QuantVectorProvider, TableDeleteProviderAsync>::new_empty(
            params.clone(),
            pq_table.clone(),
            TableBasedDeletes,
        )
        .unwrap();
    // Populate deterministic vectors: element j of point i is (i*dim+j)*0.1.
    for i in 0..num_points {
        let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.1).collect();
        provider
            .set_element(ctx, &(i as u32), &vector)
            .await
            .unwrap();
    }
    // Populate variable-length neighbor lists (length min(i, max_degree)).
    let neighbor_accessor = &mut provider.neighbors();
    for i in 0..num_points as u32 {
        let neighbors: Vec<u32> = (0..std::cmp::min(i, max_degree))
            .map(|j| (i + j) % num_points as u32)
            .collect();
        neighbor_accessor
            .set_neighbors(i, &neighbors)
            .await
            .unwrap();
    }
    // Mark a handful of points deleted before saving.
    let deleted_ids = vec![3u32, 8u32, 15u32, 22u32, 30u32];
    for id in &deleted_ids {
        provider.delete(ctx, id).await.unwrap();
        assert_eq!(
            provider.status_by_internal_id(ctx, *id).await.unwrap(),
            ElementStatus::Deleted
        );
    }
    // Save to a second temp dir and load back.
    let storage = FileStorageProvider;
    let save_dir = tempdir().unwrap();
    let save_prefix = save_dir
        .path()
        .join("saved_bf_tree_provider_quant")
        .to_string_lossy()
        .to_string();
    provider.save_with(&storage, &save_prefix).await.unwrap();
    let loaded_provider =
        BfTreeProvider::<f32, QuantVectorProvider, TableDeleteProviderAsync>::load_with(
            &storage,
            &save_prefix,
        )
        .await
        .unwrap();
    // The PQ pivot table must round-trip field by field.
    let original_pq = &provider.quant_vectors.pq_chunk_table;
    let loaded_pq = &loaded_provider.quant_vectors.pq_chunk_table;
    assert_eq!(
        original_pq.get_dim(),
        loaded_pq.get_dim(),
        "PQ table dim mismatch"
    );
    assert_eq!(
        original_pq.get_num_chunks(),
        loaded_pq.get_num_chunks(),
        "PQ table num_chunks mismatch"
    );
    assert_eq!(
        original_pq.get_num_centers(),
        loaded_pq.get_num_centers(),
        "PQ table num_centers mismatch"
    );
    assert_eq!(
        original_pq.get_pq_table(),
        loaded_pq.get_pq_table(),
        "PQ table data mismatch"
    );
    assert_eq!(
        original_pq.get_centroids(),
        loaded_pq.get_centroids(),
        "PQ table centroids mismatch"
    );
    assert_eq!(
        original_pq.get_chunk_offsets(),
        loaded_pq.get_chunk_offsets(),
        "PQ table chunk_offsets mismatch"
    );
    // Full-precision vectors must match the originals exactly.
    for i in 0..num_points as u32 {
        let original = provider.full_vectors.get_vector_sync(i as usize).unwrap();
        let loaded = loaded_provider
            .full_vectors
            .get_vector_sync(i as usize)
            .unwrap();
        assert_eq!(original, loaded, "Vector mismatch at index {}", i);
    }
    // Quantized vectors must match the originals exactly.
    for i in 0..num_points as u32 {
        let original = provider.quant_vectors.get_vector_sync(i as usize).unwrap();
        let loaded = loaded_provider
            .quant_vectors
            .get_vector_sync(i as usize)
            .unwrap();
        assert_eq!(original, loaded, "Quant vector mismatch at index {}", i);
    }
    // Neighbor lists must match the originals exactly.
    for i in 0..num_points as u32 {
        let mut original_list = AdjacencyList::new();
        let mut loaded_list = AdjacencyList::new();
        provider
            .neighbor_provider
            .get_neighbors(i, &mut original_list)
            .unwrap();
        loaded_provider
            .neighbor_provider
            .get_neighbors(i, &mut loaded_list)
            .unwrap();
        assert_eq!(
            &*original_list, &*loaded_list,
            "Neighbor list mismatch at index {}",
            i
        );
    }
    // Deleted statuses are preserved; all other points remain Valid.
    for id in &deleted_ids {
        assert_eq!(
            loaded_provider
                .status_by_internal_id(ctx, *id)
                .await
                .unwrap(),
            ElementStatus::Deleted,
            "Deletion status not preserved for id {}",
            id
        );
    }
    for i in 0..num_points as u32 {
        if !deleted_ids.contains(&i) {
            assert_eq!(
                loaded_provider.status_by_internal_id(ctx, i).await.unwrap(),
                ElementStatus::Valid,
                "Non-deleted vector {} incorrectly marked as deleted",
                i
            );
        }
    }
}
/// Same save/load round-trip as the no-quant test above, but with default
/// (memory-backend) configs, exercising the `snapshot_memory_to_disk` /
/// `new_from_snapshot_disk_to_memory` path in `save_bftree`/`load_bftree`.
#[tokio::test]
async fn test_bf_tree_provider_memory_save_load_no_quant() {
    let num_points = 20usize;
    let dim = 4usize;
    let max_degree = 16u32;
    let num_start_points = NonZeroUsize::new(1).unwrap();
    let ctx = &DefaultContext;
    // Default Config gives the in-memory backend for all three stores.
    let provider = BfTreeProvider::<f32, NoStore, TableDeleteProviderAsync>::new_empty(
        BfTreeProviderParameters {
            max_points: num_points,
            num_start_points,
            dim,
            metric: Metric::L2,
            max_fp_vecs_per_fill: None,
            max_degree,
            vector_provider_config: Config::default(),
            quant_vector_provider_config: Config::default(),
            neighbor_list_provider_config: Config::default(),
            graph_params: None,
        },
        NoStore,
        TableBasedDeletes,
    )
    .unwrap();
    // Populate deterministic vectors and variable-length neighbor lists.
    for i in 0..num_points {
        let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.1).collect();
        provider
            .set_element(ctx, &(i as u32), &vector)
            .await
            .unwrap();
    }
    let neighbor_accessor = &mut provider.neighbors();
    for i in 0..num_points as u32 {
        let neighbors: Vec<u32> = (0..std::cmp::min(i, max_degree))
            .map(|j| (i + j) % num_points as u32)
            .collect();
        neighbor_accessor
            .set_neighbors(i, &neighbors)
            .await
            .unwrap();
    }
    // Delete two points so bitmap persistence is covered.
    provider.delete(ctx, &3u32).await.unwrap();
    provider.delete(ctx, &7u32).await.unwrap();
    let save_dir = tempdir().unwrap();
    let save_prefix = save_dir
        .path()
        .join("mem_no_quant")
        .to_string_lossy()
        .to_string();
    let storage = FileStorageProvider;
    provider.save_with(&storage, &save_prefix).await.unwrap();
    let loaded = BfTreeProvider::<f32, NoStore, TableDeleteProviderAsync>::load_with(
        &storage,
        &save_prefix,
    )
    .await
    .unwrap();
    // Vectors, neighbor lists, and delete statuses must all round-trip.
    for i in 0..num_points as u32 {
        assert_eq!(
            provider.full_vectors.get_vector_sync(i as usize).unwrap(),
            loaded.full_vectors.get_vector_sync(i as usize).unwrap(),
            "Vector mismatch at {}",
            i
        );
    }
    for i in 0..num_points as u32 {
        let mut orig = AdjacencyList::new();
        let mut load = AdjacencyList::new();
        provider
            .neighbor_provider
            .get_neighbors(i, &mut orig)
            .unwrap();
        loaded
            .neighbor_provider
            .get_neighbors(i, &mut load)
            .unwrap();
        assert_eq!(&*orig, &*load, "Neighbor mismatch at {}", i);
    }
    assert_eq!(
        loaded.status_by_internal_id(ctx, 3).await.unwrap(),
        ElementStatus::Deleted
    );
    assert_eq!(
        loaded.status_by_internal_id(ctx, 7).await.unwrap(),
        ElementStatus::Deleted
    );
    assert_eq!(
        loaded.status_by_internal_id(ctx, 0).await.unwrap(),
        ElementStatus::Valid
    );
}
/// Same save/load round-trip as the quantized test above, but with default
/// (memory-backend) configs, exercising the memory-snapshot path for all
/// three BfTrees plus the PQ pivot table and delete bitmap.
#[tokio::test]
async fn test_bf_tree_provider_memory_save_load_quant() {
    let num_points = 20usize;
    let dim = 8usize;
    let max_degree = 16u32;
    let num_start_points = NonZeroUsize::new(1).unwrap();
    let ctx = &DefaultContext;
    // Minimal all-zero PQ table; only its round-trip matters here.
    let pq_table = FixedChunkPQTable::new(
        dim,
        vec![0.0; dim * 256].into_boxed_slice(),
        vec![0.0; dim].into_boxed_slice(),
        Box::new([0, 4, dim]),
    )
    .unwrap();
    // Default Config gives the in-memory backend for all three stores.
    let provider =
        BfTreeProvider::<f32, QuantVectorProvider, TableDeleteProviderAsync>::new_empty(
            BfTreeProviderParameters {
                max_points: num_points,
                num_start_points,
                dim,
                metric: Metric::L2,
                max_fp_vecs_per_fill: Some(5),
                max_degree,
                vector_provider_config: Config::default(),
                quant_vector_provider_config: Config::default(),
                neighbor_list_provider_config: Config::default(),
                graph_params: None,
            },
            pq_table,
            TableBasedDeletes,
        )
        .unwrap();
    // Populate deterministic vectors and variable-length neighbor lists.
    for i in 0..num_points {
        let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.1).collect();
        provider
            .set_element(ctx, &(i as u32), &vector)
            .await
            .unwrap();
    }
    let neighbor_accessor = &mut provider.neighbors();
    for i in 0..num_points as u32 {
        let neighbors: Vec<u32> = (0..std::cmp::min(i, max_degree))
            .map(|j| (i + j) % num_points as u32)
            .collect();
        neighbor_accessor
            .set_neighbors(i, &neighbors)
            .await
            .unwrap();
    }
    // A single deletion covers bitmap persistence.
    provider.delete(ctx, &2u32).await.unwrap();
    let save_dir = tempdir().unwrap();
    let save_prefix = save_dir
        .path()
        .join("mem_quant")
        .to_string_lossy()
        .to_string();
    let storage = FileStorageProvider;
    provider.save_with(&storage, &save_prefix).await.unwrap();
    let loaded =
        BfTreeProvider::<f32, QuantVectorProvider, TableDeleteProviderAsync>::load_with(
            &storage,
            &save_prefix,
        )
        .await
        .unwrap();
    // Full vectors, quant vectors, neighbor lists, and delete statuses must
    // all round-trip.
    for i in 0..num_points as u32 {
        assert_eq!(
            provider.full_vectors.get_vector_sync(i as usize).unwrap(),
            loaded.full_vectors.get_vector_sync(i as usize).unwrap(),
            "Vector mismatch at {}",
            i
        );
    }
    for i in 0..num_points as u32 {
        assert_eq!(
            provider.quant_vectors.get_vector_sync(i as usize).unwrap(),
            loaded.quant_vectors.get_vector_sync(i as usize).unwrap(),
            "Quant vector mismatch at {}",
            i
        );
    }
    for i in 0..num_points as u32 {
        let mut orig = AdjacencyList::new();
        let mut load = AdjacencyList::new();
        provider
            .neighbor_provider
            .get_neighbors(i, &mut orig)
            .unwrap();
        loaded
            .neighbor_provider
            .get_neighbors(i, &mut load)
            .unwrap();
        assert_eq!(&*orig, &*load, "Neighbor mismatch at {}", i);
    }
    assert_eq!(
        loaded.status_by_internal_id(ctx, 2).await.unwrap(),
        ElementStatus::Deleted
    );
    assert_eq!(
        loaded.status_by_internal_id(ctx, 0).await.unwrap(),
        ElementStatus::Valid
    );
}
}