use std::{
borrow::Cow,
collections::{HashMap, HashSet},
num::NonZeroUsize,
sync::Arc,
};
use dashmap::{DashMap, mapref::entry::Entry};
use diskann_utils::views::Matrix;
use diskann_vector::distance::Metric;
use thiserror::Error;
use crate::{
ANNError, ANNResult, default_post_processor,
error::{Infallible, RankedError, StandardError, ToRanked, TransientError, message},
graph::{AdjacencyList, SearchOutputBuffer, glue, test::synthetic, workingset},
internal::counter::{Counter, LocalCounter},
neighbor::Neighbor,
provider,
utils::VectorRepr,
};
/// A start point for a [`Config`]: an id paired with the vector stored under that id.
#[derive(Debug)]
pub struct StartPoint {
    /// Identifier of the start point.
    id: u32,
    /// Vector data of the start point.
    vector: Vec<f32>,
}

impl StartPoint {
    /// Bundle `id` and `vector` into a start point. No validation is performed here.
    pub fn new(id: u32, vector: Vec<f32>) -> Self {
        StartPoint { id, vector }
    }

    /// The identifier this start point was created with.
    pub fn id(&self) -> u32 {
        let Self { id, .. } = self;
        *id
    }

    /// Borrow the vector this start point was created with.
    pub fn vector(&self) -> &[f32] {
        self.vector.as_slice()
    }
}
/// Lets a single [`StartPoint`] be passed where an iterator of start points is expected,
/// e.g. as the `start_points` argument of [`Config::new`].
impl IntoIterator for StartPoint {
type Item = Self;
type IntoIter = std::iter::Once<Self>;
/// Returns an iterator that yields `self` exactly once.
fn into_iter(self) -> Self::IntoIter {
std::iter::once(self)
}
}
/// Validated configuration for a [`Provider`]. Construct it with [`Config::new`].
#[derive(Debug, Clone)]
pub struct Config {
/// Start point vectors keyed by their id.
start_points: HashMap<u32, Vec<f32>>,
/// Maximum number of neighbors any adjacency list may hold.
max_degree: NonZeroUsize,
/// Length shared by all start point vectors.
dim: NonZeroUsize,
/// Distance metric handed to distance/query computers.
metric: Metric,
}
impl Config {
    /// Build a configuration from a metric, a maximum degree and a set of start points.
    ///
    /// # Errors
    ///
    /// Checks are performed in this order:
    ///
    /// * [`ConfigError::MaxDegreeCannotBeZero`] if `max_degree` is zero.
    /// * [`ConfigError::MismatchedDims`] if any start point's vector length differs from the
    ///   first one's. This includes the case where the first vector is empty and a later one
    ///   is not.
    /// * [`ConfigError::NeedStartPoint`] if `start_points` yields nothing.
    /// * [`ConfigError::StartPointsNotUnique`] if two start points share an id.
    /// * [`ConfigError::DimCannotBeZero`] if the (common) vector length is zero.
    pub fn new<I>(metric: Metric, max_degree: usize, start_points: I) -> Result<Self, ConfigError>
    where
        I: IntoIterator<Item = StartPoint>,
    {
        let max_degree =
            NonZeroUsize::new(max_degree).ok_or(ConfigError::MaxDegreeCannotBeZero)?;

        // Length of the first vector seen. This is tracked as a plain `usize` (rather than a
        // `NonZeroUsize`) so that an empty first vector is remembered and still compared
        // against every subsequent vector instead of being treated as "no dimension yet".
        let mut expected_len: Option<usize> = None;
        // Number of points yielded by the iterator; compared against the map size below to
        // detect ids that collapsed onto each other.
        let mut count = 0;

        let start_points = start_points
            .into_iter()
            .map(|point| {
                let len = point.vector.len();
                match expected_len {
                    None => expected_len = Some(len),
                    Some(expected) if expected != len => {
                        return Err(ConfigError::MismatchedDims);
                    }
                    Some(_) => {}
                }
                count += 1;
                Ok((point.id, point.vector))
            })
            .collect::<Result<HashMap<u32, Vec<f32>>, ConfigError>>()?;

        if start_points.is_empty() {
            return Err(ConfigError::NeedStartPoint);
        }
        if start_points.len() != count {
            return Err(ConfigError::StartPointsNotUnique);
        }

        // At least one point exists here, so `expected_len` is `Some`; only a zero length can
        // make this fail.
        let dim = expected_len
            .and_then(NonZeroUsize::new)
            .ok_or(ConfigError::DimCannotBeZero)?;

        Ok(Self {
            start_points,
            max_degree,
            dim,
            metric,
        })
    }
}
/// Reasons [`Config::new`] can reject its arguments.
#[derive(Debug, Error)]
pub enum ConfigError {
/// The start point iterator was empty.
#[error("at least one start point must be specified")]
NeedStartPoint,
/// Two or more start points used the same id.
#[error("start points must be unique")]
StartPointsNotUnique,
/// Start point vectors did not all have the same length.
#[error("not all start points have the same dimension")]
MismatchedDims,
/// Start point vectors were empty.
#[error("start point dimension must be non-zero")]
DimCannotBeZero,
/// `max_degree` was zero.
#[error("max degree must be non-zero")]
MaxDegreeCannotBeZero,
}
/// Allows `?` to convert a [`ConfigError`] into an [`ANNError`].
impl From<ConfigError> for ANNError {
// `#[track_caller]` makes `Location::caller()` report the conversion site rather than this
// function; presumably `ANNError::opaque` records that location - confirm in `ANNError`.
#[track_caller]
fn from(err: ConfigError) -> Self {
ANNError::opaque(err)
}
}
/// In-memory data provider storing vectors and adjacency lists per id, with per-operation
/// counters exposed through [`Provider::metrics`].
#[derive(Debug)]
pub struct Provider {
/// Vector data and neighbors for every id (start points included).
terms: DashMap<u32, Term>,
/// Configuration the provider was created with.
config: Config,
/// Incremented through the local counter held by [`Accessor`] on each element retrieval.
pub(crate) get_vector: Counter,
/// Incremented on each successful `set_element`.
pub(crate) set_vector: Counter,
/// Incremented on each successful `get_neighbors`.
pub(crate) get_neighbors: Counter,
/// Incremented by `set_neighbors` once the adjacency list has been replaced.
pub(crate) set_neighbors: Counter,
/// Incremented by `append_vector` once the adjacency list has been extended.
pub(crate) append_neighbors: Counter,
}
impl Provider {
    /// Create a provider that contains only the start points listed in `config`.
    ///
    /// Each start point is stored as a valid (non-deleted) term with an empty adjacency list.
    pub fn new(config: Config) -> Self {
        let terms = DashMap::new();
        for (&id, vector) in &config.start_points {
            terms.insert(
                id,
                Term {
                    data: Vector::Valid(vector.clone()),
                    neighbors: AdjacencyList::new(),
                },
            );
        }

        Self {
            terms,
            config,
            get_vector: Counter::new(),
            set_vector: Counter::new(),
            get_neighbors: Counter::new(),
            set_neighbors: Counter::new(),
            append_neighbors: Counter::new(),
        }
    }

    /// Create a provider pre-populated with adjacency lists and data.
    ///
    /// * `start_points` assigns adjacency lists to start points already present in `config`.
    /// * `points` adds regular points as `(id, vector, neighbors)` triples. A later triple
    ///   with the same id replaces an earlier one.
    ///
    /// # Errors
    ///
    /// Returns an error if any adjacency list is longer than the configured max degree, if an
    /// entry of `start_points` names an id that is not a start point, if an entry of `points`
    /// names an id that is a start point, if a vector's length differs from the provider's
    /// dimension, or if the final graph fails [`Provider::is_consistent`].
    pub fn new_from<I, T>(config: Config, start_points: I, points: T) -> ANNResult<Self>
    where
        I: IntoIterator<Item = (u32, AdjacencyList<u32>)>,
        T: IntoIterator<Item = (u32, Vec<f32>, AdjacencyList<u32>)>,
    {
        let provider = Self::new(config);
        let limit = provider.max_degree();

        for (id, adjacency) in start_points {
            if adjacency.len() > limit {
                return Err(message!(
                    "start point {} has neighbors with length {} when max degree is {}",
                    id,
                    adjacency.len(),
                    limit
                ));
            }
            // Only start points exist in `terms` at this stage, so a miss means `id` is not one.
            let Some(mut term) = provider.terms.get_mut(&id) else {
                return Err(message!("id {} is not a valid start point", id));
            };
            term.neighbors = adjacency;
        }

        for (id, data, adjacency) in points {
            if provider.is_start_point(id) {
                return Err(message!(
                    "cannot assign start point {} through a regular point",
                    id
                ));
            }
            if adjacency.len() > limit {
                return Err(message!(
                    "point {} has neighbors with length {} when max degree is {}",
                    id,
                    adjacency.len(),
                    limit
                ));
            }
            if data.len() != provider.dim() {
                return Err(message!(
                    "data for id {} has length {} but the provider is expecting dim {}",
                    id,
                    data.len(),
                    provider.dim(),
                ));
            }
            provider.terms.insert(
                id,
                Term {
                    data: Vector::Valid(data),
                    neighbors: adjacency,
                },
            );
        }

        provider.is_consistent()?;
        Ok(provider)
    }

    /// Create a provider from a synthetic grid of the given `size`.
    ///
    /// Uses the L2 metric, `u32::MAX` as the start point id and twice the grid's dimension as
    /// the max degree. The data, start point and adjacency lists all come from `grid.setup`.
    pub fn grid(grid: synthetic::Grid, size: usize) -> ANNResult<Self> {
        let max_degree: usize = (grid.dim() * 2).into();
        let setup = grid.setup(size, u32::MAX);
        let start = StartPoint::new(setup.start_id(), setup.start_point());
        let config = Config::new(Metric::L2, max_degree, start)?;
        Self::new_from(config, setup.start_neighbors(), setup.setup())
    }

    /// Length of every vector stored in this provider.
    pub fn dim(&self) -> usize {
        self.config.dim.get()
    }

    /// Maximum permitted length of an adjacency list.
    pub fn max_degree(&self) -> usize {
        self.config.max_degree.get()
    }

    /// Distance metric this provider was configured with.
    pub fn distance_metric(&self) -> Metric {
        self.config.metric
    }

    /// Whether `id` is one of the configured start points.
    fn is_start_point(&self, id: u32) -> bool {
        self.config.start_points.contains_key(&id)
    }

    /// Snapshot of every id currently stored (start points and deleted terms included).
    pub fn all_internal_ids(&self) -> HashSet<u32> {
        let mut ids = HashSet::new();
        for entry in self.terms.iter() {
            ids.insert(*entry.key());
        }
        ids
    }

    /// Verify that every neighbor in every adjacency list refers to an id that is stored.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first dangling neighbor found.
    pub fn is_consistent(&self) -> ANNResult<()> {
        // Membership is checked against a snapshot so the map itself is not queried while it
        // is being iterated.
        let known = self.all_internal_ids();
        for entry in self.terms.iter() {
            let id = entry.key();
            let term = entry.value();
            for neighbor in term.neighbors.iter() {
                if known.contains(neighbor) {
                    continue;
                }
                return Err(message!(
                    "term with id {} has neighbors {:?} \
                     but neighbor {} is not in the provider",
                    id,
                    term.neighbors,
                    neighbor,
                ));
            }
        }
        Ok(())
    }

    /// Whether the term stored under `id` is marked deleted.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidId::Internal`] if nothing is stored under `id`.
    fn is_deleted(&self, id: u32) -> Result<bool, InvalidId> {
        self.terms
            .get(&id)
            .map(|term| term.is_deleted())
            .ok_or(InvalidId::Internal(id))
    }

    /// Current values of all operation counters.
    pub fn metrics(&self) -> Metrics {
        let Self {
            get_vector,
            set_vector,
            get_neighbors,
            set_neighbors,
            append_neighbors,
            ..
        } = self;
        Metrics {
            get_vector: get_vector.value(),
            set_vector: set_vector.value(),
            get_neighbors: get_neighbors.value(),
            set_neighbors: set_neighbors.value(),
            append_neighbors: append_neighbors.value(),
        }
    }

    /// Accessor for reading and writing adjacency lists.
    pub fn neighbors(&self) -> NeighborAccessor<'_> {
        NeighborAccessor::new(self)
    }

    /// Copy out all `(id, adjacency list)` pairs.
    ///
    /// When `sort` is true, each adjacency list is sorted and the pairs are ordered by id;
    /// otherwise the order is whatever the map iteration produces.
    pub fn dump_neighbors(&self, sort: bool) -> Vec<(u32, AdjacencyList<u32>)> {
        let mut dump = Vec::new();
        for entry in self.terms.iter() {
            let mut adjacency = entry.value().neighbors.clone();
            if sort {
                adjacency.sort();
            }
            dump.push((*entry.key(), adjacency));
        }
        if sort {
            dump.sort_unstable_by_key(|(id, _)| *id);
        }
        dump
    }
}
/// Point-in-time copy of a [`Provider`]'s operation counters, as returned by
/// [`Provider::metrics`]. Each field mirrors the provider counter of the same name.
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(serde::Serialize, serde::Deserialize))]
pub struct Metrics {
pub get_vector: usize,
pub set_vector: usize,
pub get_neighbors: usize,
pub set_neighbors: usize,
pub append_neighbors: usize,
}
// Test-only: `verbose_eq!` is defined elsewhere in the crate; presumably it generates a
// field-by-field comparison for `Metrics` - confirm against the macro definition.
#[cfg(test)]
crate::test::cmp::verbose_eq!(Metrics {
get_vector,
set_vector,
get_neighbors,
set_neighbors,
append_neighbors
});
/// Everything stored for a single id: its adjacency list and its (possibly deleted) vector.
#[derive(Debug)]
struct Term {
/// Out-neighbors of this id.
neighbors: AdjacencyList<u32>,
/// Vector data together with its deletion state.
data: Vector,
}
impl Term {
/// Mark the vector as deleted; the data and the neighbors are left in place.
fn mark_deleted(&mut self) {
self.data.mark_deleted()
}
/// Whether the vector has been marked deleted.
fn is_deleted(&self) -> bool {
self.data.is_deleted()
}
}
/// Vector data tagged with its deletion state.
///
/// Both variants carry the data, so a deleted vector can still be read through `Deref`.
#[derive(Debug)]
enum Vector {
    /// Live vector.
    Valid(Vec<f32>),
    /// Vector that has been marked deleted.
    Deleted(Vec<f32>),
}

impl Vector {
    /// Switch to the `Deleted` state, keeping the same underlying allocation.
    ///
    /// Calling this on an already deleted vector changes nothing.
    fn mark_deleted(&mut self) {
        let (Self::Valid(data) | Self::Deleted(data)) = self.take();
        *self = Self::Deleted(data);
    }

    /// Move the data out into a new `Vector` of the same variant, leaving `self` holding an
    /// empty `Vec` in its current variant.
    fn take(&mut self) -> Self {
        let deleted = self.is_deleted();
        let (Self::Valid(slot) | Self::Deleted(slot)) = self;
        let data = std::mem::take(slot);
        if deleted {
            Self::Deleted(data)
        } else {
            Self::Valid(data)
        }
    }

    /// Whether this vector is in the `Deleted` state.
    fn is_deleted(&self) -> bool {
        match self {
            Self::Deleted(_) => true,
            Self::Valid(_) => false,
        }
    }
}

/// Read access to the data regardless of deletion state.
impl std::ops::Deref for Vector {
    type Target = [f32];

    fn deref(&self) -> &[f32] {
        let (Self::Valid(data) | Self::Deleted(data)) = self;
        data.as_slice()
    }
}
/// Execution context that counts how many times it has been cloned and how many futures
/// have been passed through `wrap_spawn`. All clones share the same counters.
#[derive(Debug)]
pub struct Context(Arc<ContextInner>);

impl Context {
    /// Create a context with fresh counters.
    pub fn new() -> Self {
        Self(Arc::new(ContextInner {
            spawns: Counter::new(),
            clones: Counter::new(),
        }))
    }

    /// Current value of the spawn counter.
    pub fn spawns(&self) -> usize {
        let ContextInner { spawns, .. } = &*self.0;
        spawns.value()
    }

    /// Current value of the clone counter.
    pub fn clones(&self) -> usize {
        let ContextInner { clones, .. } = &*self.0;
        clones.value()
    }

    /// Both counters bundled into a [`ContextMetrics`].
    pub fn metrics(&self) -> ContextMetrics {
        let spawns = self.spawns();
        let clones = self.clones();
        ContextMetrics { spawns, clones }
    }
}
/// Point-in-time copy of a [`Context`]'s counters, as returned by [`Context::metrics`].
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(serde::Serialize, serde::Deserialize))]
pub struct ContextMetrics {
/// Value of the spawn counter.
pub spawns: usize,
/// Value of the clone counter.
pub clones: usize,
}
// Test-only: `verbose_eq!` is defined elsewhere in the crate; presumably it generates a
// field-by-field comparison for `ContextMetrics` - confirm against the macro definition.
#[cfg(test)]
crate::test::cmp::verbose_eq!(ContextMetrics { spawns, clones });
/// The default context is the same as [`Context::new`].
impl Default for Context {
fn default() -> Self {
Self::new()
}
}
impl provider::ExecutionContext for Context {
/// Record one spawn and hand the future back unchanged.
///
/// The counter is incremented when this method is called, not when the future is polled.
fn wrap_spawn<F, T>(&self, f: F) -> impl Future<Output = T> + Send + 'static
where
F: Future<Output = T> + Send + 'static,
{
self.0.spawns.increment();
f
}
}
impl Clone for Context {
fn clone(&self) -> Self {
self.0.clones.increment();
Self(self.0.clone())
}
}
/// Counters shared by every clone of a [`Context`].
#[derive(Debug)]
struct ContextInner {
/// Incremented by `wrap_spawn`.
spawns: Counter,
/// Incremented by `Context::clone`.
clones: Counter,
}
/// Error type of the [`Provider`]'s `DataProvider` and `Delete` implementations.
#[derive(Debug, Clone, Copy, Error)]
pub enum InvalidId {
/// Nothing is stored under the given internal id.
#[error("internal id {0} is not initialized")]
Internal(u32),
/// Nothing is stored under the given external id.
#[error("external id {0} is not initialized")]
External(u32),
/// The operation targeted a start point, which cannot be deleted or released.
#[error("cannot delete start point {0}")]
IsStartPoint(u32),
}
// NOTE(review): `always_escalate!` is defined elsewhere; presumably it marks this error as
// never transient - confirm against the macro definition.
crate::always_escalate!(InvalidId);
/// Allows `?` to convert an [`InvalidId`] into an [`ANNError`].
impl From<InvalidId> for ANNError {
#[track_caller]
fn from(err: InvalidId) -> ANNError {
ANNError::opaque(err)
}
}
/// Iterate over a snapshot of the ids of all stored terms except the start points.
///
/// The ids are collected up front, so later changes to the provider do not affect an
/// iterator that has already been created. The order follows the map's iteration order.
impl IntoIterator for &Provider {
    type Item = u32;
    type IntoIter = std::vec::IntoIter<u32>;

    fn into_iter(self) -> Self::IntoIter {
        let mut ids = Vec::new();
        for entry in self.terms.iter() {
            let id = *entry.key();
            if !self.is_start_point(id) {
                ids.push(id);
            }
        }
        ids.into_iter()
    }
}
/// Internal and external ids are the same `u32`; conversion only checks that the id exists.
impl provider::DataProvider for Provider {
    type Context = Context;
    type InternalId = u32;
    type ExternalId = u32;
    type Error = InvalidId;
    type Guard = provider::NoopGuard<u32>;

    /// Return `gid` unchanged if something is stored under it.
    ///
    /// # Errors
    ///
    /// [`InvalidId::External`] if nothing is stored under `gid`.
    fn to_internal_id(&self, _context: &Context, gid: &u32) -> Result<u32, InvalidId> {
        let id = *gid;
        if !self.terms.contains_key(&id) {
            return Err(InvalidId::External(id));
        }
        Ok(id)
    }

    /// Return `id` unchanged if something is stored under it.
    ///
    /// # Errors
    ///
    /// [`InvalidId::Internal`] if nothing is stored under `id`.
    fn to_external_id(&self, _context: &Context, id: u32) -> Result<u32, InvalidId> {
        if !self.terms.contains_key(&id) {
            return Err(InvalidId::Internal(id));
        }
        Ok(id)
    }
}
impl provider::Delete for Provider {
    /// Soft-delete `gid`: the term stays in the provider but its vector is marked deleted.
    ///
    /// # Errors
    ///
    /// [`InvalidId::IsStartPoint`] if `gid` is a start point, [`InvalidId::External`] if
    /// nothing is stored under `gid`.
    async fn delete(
        &self,
        _context: &Self::Context,
        gid: &Self::ExternalId,
    ) -> Result<(), Self::Error> {
        let id = *gid;
        if self.is_start_point(id) {
            return Err(InvalidId::IsStartPoint(id));
        }
        let Some(mut term) = self.terms.get_mut(&id) else {
            return Err(InvalidId::External(id));
        };
        term.mark_deleted();
        Ok(())
    }

    /// Remove the term stored under `id` entirely.
    ///
    /// # Errors
    ///
    /// [`InvalidId::IsStartPoint`] if `id` is a start point, [`InvalidId::Internal`] if
    /// nothing is stored under `id`.
    async fn release(
        &self,
        _context: &Self::Context,
        id: Self::InternalId,
    ) -> Result<(), Self::Error> {
        if self.is_start_point(id) {
            return Err(InvalidId::IsStartPoint(id));
        }
        match self.terms.remove(&id) {
            Some(_) => Ok(()),
            None => Err(InvalidId::Internal(id)),
        }
    }

    /// Report whether `id` is valid or soft-deleted.
    ///
    /// # Errors
    ///
    /// [`InvalidId::Internal`] if nothing is stored under `id`.
    async fn status_by_internal_id(
        &self,
        _context: &Context,
        id: u32,
    ) -> Result<provider::ElementStatus, Self::Error> {
        let deleted = self.is_deleted(id)?;
        Ok(if deleted {
            provider::ElementStatus::Deleted
        } else {
            provider::ElementStatus::Valid
        })
    }

    /// Same as `status_by_internal_id`, since external and internal ids coincide.
    ///
    /// Note that an unknown `gid` is consequently reported as [`InvalidId::Internal`].
    fn status_by_external_id(
        &self,
        context: &Context,
        gid: &u32,
    ) -> impl Future<Output = Result<provider::ElementStatus, Self::Error>> + Send {
        let id = *gid;
        self.status_by_internal_id(context, id)
    }
}
impl provider::SetElement<&[f32]> for Provider {
    type SetError = ANNError;

    /// Store `element` under a previously unused `id`, with an empty adjacency list.
    ///
    /// The `set_vector` counter is incremented only when the element is actually stored.
    ///
    /// # Errors
    ///
    /// Returns an `IndexError`-kind [`ANNError`] if `element` does not have the provider's
    /// dimension or if something is already stored under `id` (start points included).
    async fn set_element(
        &self,
        _context: &Context,
        id: &Self::ExternalId,
        element: &[f32],
    ) -> Result<Self::Guard, Self::SetError> {
        // Error type private to this method; converted to `ANNError` on the way out.
        #[derive(Debug, Clone, Copy, Error)]
        enum SetError {
            #[error("vector id {0} is already assigned")]
            AlreadyAssigned(u32),
            #[error("wrong dim - got {0}, expected {1}")]
            WrongDim(usize, usize),
        }
        crate::always_escalate!(SetError);
        impl From<SetError> for ANNError {
            #[track_caller]
            fn from(err: SetError) -> Self {
                Self::new(crate::ANNErrorKind::IndexError, err)
            }
        }

        let (got, expected) = (element.len(), self.dim());
        if got != expected {
            return Err(SetError::WrongDim(got, expected).into());
        }

        let id = *id;
        // The entry API makes the "is it free?" check and the insertion a single map access.
        let Entry::Vacant(slot) = self.terms.entry(id) else {
            return Err(SetError::AlreadyAssigned(id).into());
        };
        slot.insert(Term {
            neighbors: AdjacencyList::new(),
            data: Vector::Valid(element.to_vec()),
        });
        self.set_vector.increment();
        Ok(provider::NoopGuard::new(id))
    }
}
/// Non-transient access failure carrying the offending id. Returned when nothing is stored
/// under that id, and when a [`TransientAccessError`] is escalated.
#[derive(Debug, Clone, Copy, Error)]
#[error("Attempt to access an invalid id: {0}")]
pub struct AccessedInvalidId(u32);
// NOTE(review): `always_escalate!` is defined elsewhere; presumably it marks this error as
// never transient - confirm against the macro definition.
crate::always_escalate!(AccessedInvalidId);
/// Allows `?` to convert an [`AccessedInvalidId`] into an [`ANNError`].
impl From<AccessedInvalidId> for ANNError {
#[track_caller]
fn from(err: AccessedInvalidId) -> Self {
Self::opaque(err)
}
}
/// A transient access failure for a single id that must be explicitly handled.
///
/// The error remembers whether it has been handled. Dropping an unhandled instance panics,
/// which lets tests detect code paths that silently discard transient errors.
#[derive(Debug)]
pub struct TransientAccessError {
    /// Id whose access failed.
    id: u32,
    /// Set once the error has been handled; checked on drop.
    handled: bool,
}

impl TransientAccessError {
    /// Create an unhandled transient error for `id`.
    fn new(id: u32) -> Self {
        Self { id, handled: false }
    }
}

impl std::fmt::Display for TransientAccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "transient error accessing id: {}", self.id)
    }
}

impl std::error::Error for TransientAccessError {}

impl Drop for TransientAccessError {
    /// Panics if the error was never handled.
    ///
    /// The check is skipped while the thread is already unwinding from another panic:
    /// panicking inside a destructor during unwinding aborts the process and would hide the
    /// original failure.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }
        assert!(
            self.handled,
            "dropping an unhandled transient error for id {}!",
            self.id
        );
    }
}
/// Both ways of handling the error set `handled`, so the subsequent drop does not panic.
impl TransientError<AccessedInvalidId> for TransientAccessError {
/// Handle the error by discarding it. `_why` is ignored.
fn acknowledge<D>(mut self, _why: D)
where
D: std::fmt::Display,
{
self.handled = true;
}
/// Handle the error by converting it into an [`AccessedInvalidId`] for the same id.
/// `_why` is ignored.
fn escalate<D>(mut self, _why: D) -> AccessedInvalidId
where
D: std::fmt::Display,
{
self.handled = true;
AccessedInvalidId(self.id)
}
}
/// Error returned by [`Accessor`]'s `get_element`: either a hard failure or a transient one.
#[derive(Debug)]
pub enum AccessError {
/// Nothing is stored under the requested id.
InvalidId(AccessedInvalidId),
/// The id was configured to fail transiently.
Transient(TransientAccessError),
}
/// Formats by delegating to the wrapped error.
impl std::fmt::Display for AccessError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// NOTE(review): `e.fmt(f)` resolves through whichever `fmt` trait is in scope at this
// point in the file; presumably that is `Display` - confirm.
match self {
Self::InvalidId(e) => e.fmt(f),
Self::Transient(e) => e.fmt(f),
}
}
}
impl std::error::Error for AccessError {}
/// Maps each `AccessError` variant to the matching `RankedError` variant and back.
impl ToRanked for AccessError {
type Transient = TransientAccessError;
type Error = AccessedInvalidId;
/// `InvalidId` becomes `RankedError::Error`; `Transient` becomes `RankedError::Transient`.
fn to_ranked(self) -> RankedError<TransientAccessError, AccessedInvalidId> {
match self {
Self::InvalidId(e) => RankedError::Error(e),
Self::Transient(e) => RankedError::Transient(e),
}
}
/// Wrap a transient error.
fn from_transient(transient: TransientAccessError) -> Self {
Self::Transient(transient)
}
/// Wrap a non-transient error.
fn from_error(error: AccessedInvalidId) -> Self {
Self::InvalidId(error)
}
}
/// The provider's default accessor is a [`NeighborAccessor`] borrowing the provider.
impl provider::DefaultAccessor for Provider {
type Accessor<'a> = NeighborAccessor<'a>;
fn default_accessor(&self) -> Self::Accessor<'_> {
NeighborAccessor::new(self)
}
}
/// Copyable handle for reading and writing the adjacency lists of a [`Provider`].
#[derive(Debug, Clone, Copy)]
pub struct NeighborAccessor<'a> {
/// Provider whose adjacency lists are accessed.
provider: &'a Provider,
}
impl<'a> NeighborAccessor<'a> {
/// Create an accessor borrowing `provider`.
pub fn new(provider: &'a Provider) -> Self {
Self { provider }
}
}
/// Ids handled by this accessor are `u32`.
impl provider::HasId for NeighborAccessor<'_> {
type Id = u32;
}
impl provider::NeighborAccessor for NeighborAccessor<'_> {
    /// Overwrite `neighbors` with the adjacency list stored under `id`.
    ///
    /// The provider's `get_neighbors` counter is incremented only on success.
    ///
    /// # Errors
    ///
    /// Returns an [`AccessedInvalidId`] wrapped in an [`ANNError`] if nothing is stored
    /// under `id`; `neighbors` is left untouched in that case.
    async fn get_neighbors(
        self,
        id: Self::Id,
        neighbors: &mut AdjacencyList<Self::Id>,
    ) -> ANNResult<Self> {
        let Some(term) = self.provider.terms.get(&id) else {
            return Err(ANNError::opaque(AccessedInvalidId(id)));
        };
        self.provider.get_neighbors.increment();
        neighbors.overwrite_trusted(&term.neighbors);
        Ok(self)
    }
}
impl provider::NeighborAccessorMut for NeighborAccessor<'_> {
    /// Replace the adjacency list of `id` with `neighbors`.
    ///
    /// # Errors
    ///
    /// * `neighbors` is longer than the max degree: nothing is modified or counted.
    /// * Nothing is stored under `id`: nothing is modified or counted.
    /// * The stored list ends up with a different length than `neighbors` (reported as
    ///   duplicate neighbors): the list has already been replaced and the `set_neighbors`
    ///   counter already incremented when the error is returned.
    async fn set_neighbors(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult<Self> {
        let limit = self.provider.max_degree();
        if neighbors.len() > limit {
            return Err(message!(
                "trying to assign neighbors with length {} when max degree is {}",
                neighbors.len(),
                limit
            ));
        }

        let Some(mut term) = self.provider.terms.get_mut(&id) else {
            return Err(ANNError::opaque(AccessedInvalidId(id)));
        };

        term.neighbors.clear();
        term.neighbors.extend_from_slice(neighbors);
        self.provider.set_neighbors.increment();

        if term.neighbors.len() == neighbors.len() {
            Ok(self)
        } else {
            Err(message!("duplicate neighbors detected"))
        }
    }

    /// Append `neighbors` to the adjacency list of `id`.
    ///
    /// # Errors
    ///
    /// * Nothing is stored under `id`.
    /// * The current length plus `neighbors.len()` overflows or exceeds the max degree:
    ///   nothing is modified or counted.
    /// * Fewer ids were added than supplied (reported as duplicate ids): the list has already
    ///   been extended and the `append_neighbors` counter already incremented when the error
    ///   is returned.
    async fn append_vector(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult<Self> {
        let Some(mut term) = self.provider.terms.get_mut(&id) else {
            return Err(ANNError::opaque(AccessedInvalidId(id)));
        };

        // Upper bound on the resulting length, assuming every supplied id is new.
        match term.neighbors.len().checked_add(neighbors.len()) {
            None => return Err(message!("the number of neighbors is way too high")),
            Some(estimate) if estimate > self.provider.max_degree() => {
                return Err(message!(
                    "append neighbors to {} will exceed the max degree",
                    id
                ));
            }
            Some(_) => {}
        }

        let added = term.neighbors.extend_from_slice(neighbors);
        self.provider.append_neighbors.increment();

        if added == neighbors.len() {
            Ok(self)
        } else {
            Err(message!("duplicate ids in append-vector"))
        }
    }
}
/// Element accessor for a [`Provider`] that copies vectors into an internal buffer and can
/// optionally be made to fail transiently for selected ids.
#[derive(Debug)]
pub struct Accessor<'a> {
/// Provider the elements are read from.
provider: &'a Provider,
/// Scratch buffer of length `provider.dim()` that returned elements borrow from.
buffer: Box<[f32]>,
/// Local handle obtained from `provider.get_vector.local()`; incremented on each
/// successful `get_element`.
get_vector: LocalCounter<'a>,
/// Ids for which `get_element` returns a transient error; `None` disables this.
transient_ids: Option<Cow<'a, HashSet<u32>>>,
}
impl<'a> Accessor<'a> {
    /// The provider this accessor reads from.
    pub fn provider(&self) -> &'a Provider {
        self.provider
    }

    /// Create an accessor that never reports transient errors.
    pub fn new(provider: &'a Provider) -> Self {
        Self::new_inner(provider, None)
    }

    /// Create an accessor whose `get_element` reports a transient error for every stored id
    /// contained in `transient_ids`.
    pub fn flaky(provider: &'a Provider, transient_ids: Cow<'a, HashSet<u32>>) -> Self {
        Self::new_inner(provider, Some(transient_ids))
    }

    /// Shared constructor: allocates a zero-filled buffer of `provider.dim()` elements and
    /// takes a local handle to the provider's `get_vector` counter.
    fn new_inner(provider: &'a Provider, transient_ids: Option<Cow<'a, HashSet<u32>>>) -> Self {
        let get_vector = provider.get_vector.local();
        let buffer = vec![0.0; provider.dim()].into_boxed_slice();
        Self {
            provider,
            buffer,
            get_vector,
            transient_ids,
        }
    }
}
/// Ids handled by this accessor are `u32`.
impl provider::HasId for Accessor<'_> {
type Id = u32;
}
impl provider::Accessor for Accessor<'_> {
    type Element<'a>
        = &'a [f32]
    where
        Self: 'a;
    type ElementRef<'a> = &'a [f32];
    type GetError = AccessError;

    /// Copy the vector stored under `id` into the internal buffer and return a borrow of it.
    ///
    /// Soft-deleted vectors are returned like valid ones. The `get_vector` counter is
    /// incremented only on success.
    ///
    /// # Errors
    ///
    /// * [`AccessError::InvalidId`] if nothing is stored under `id`. This takes precedence
    ///   over the transient check.
    /// * [`AccessError::Transient`] if `id` is stored and listed in the transient id set.
    async fn get_element(&mut self, id: u32) -> Result<&[f32], AccessError> {
        let Some(term) = self.provider.terms.get(&id) else {
            return Err(AccessError::InvalidId(AccessedInvalidId(id)));
        };

        let flaky = self
            .transient_ids
            .as_ref()
            .is_some_and(|ids| ids.contains(&id));
        if flaky {
            return Err(AccessError::Transient(TransientAccessError::new(id)));
        }

        self.get_vector.increment();
        // Stored vectors are dimension-checked on insertion, so the lengths are expected to
        // match here; `copy_from_slice` panics otherwise.
        self.buffer.copy_from_slice(&term.data);
        Ok(&*self.buffer)
    }
}
/// Neighbor access is delegated to a fresh [`NeighborAccessor`] over the same provider.
impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> {
type Delegate = NeighborAccessor<'a>;
fn delegate_neighbor(&'a mut self) -> Self::Delegate {
NeighborAccessor::new(self.provider)
}
}
impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> {
type QueryComputerError = Infallible;
type QueryComputer = <f32 as VectorRepr>::QueryDistance;
/// Build a query computer for `from` using the provider's configured metric.
/// Never fails.
fn build_query_computer(
&self,
from: &[f32],
) -> Result<Self::QueryComputer, Self::QueryComputerError> {
Ok(f32::query_distance(from, self.provider.config.metric))
}
}
impl provider::BuildDistanceComputer for Accessor<'_> {
type DistanceComputerError = Infallible;
type DistanceComputer = <f32 as VectorRepr>::Distance;
/// Build a distance computer from the provider's metric and dimension.
/// Never fails.
fn build_distance_computer(
&self,
) -> Result<Self::DistanceComputer, Self::DistanceComputerError> {
Ok(f32::distance(
self.provider.distance_metric(),
Some(self.provider.dim()),
))
}
}
impl glue::SearchExt for Accessor<'_> {
/// Returns the ids of all configured start points in an already-resolved future.
/// The order follows the `HashMap` key iteration order and is therefore unspecified.
fn starting_points(&self) -> impl Future<Output = ANNResult<Vec<u32>>> + Send {
futures_util::future::ok(self.provider.config.start_points.keys().copied().collect())
}
}
// Uses the trait's provided items only; nothing is overridden here.
impl glue::ExpandBeam<&[f32]> for Accessor<'_> {}
impl glue::IdIterator<std::vec::IntoIter<u32>> for Accessor<'_> {
/// Returns an iterator over a snapshot of every stored id. Unlike `&Provider`'s
/// `IntoIterator`, start points are included. Never fails.
async fn id_iterator(&mut self) -> Result<std::vec::IntoIter<u32>, ANNError> {
let ids: Vec<u32> = self.provider.terms.iter().map(|r| *r.key()).collect();
Ok(ids.into_iter())
}
}
/// Elements are plain `&[f32]` slices, so both conversions below are identity functions.
impl provider::CacheableAccessor for Accessor<'_> {
type Map = diskann_utils::lifetime::Slice<f32>;
/// Returns `element` unchanged.
fn from_cached<'a>(element: &'a [f32]) -> &'a [f32]
where
Self: 'a,
{
element
}
/// Returns `element` unchanged.
fn as_cached<'a, 'b>(element: &'a &'b [f32]) -> &'a &'b [f32]
where
Self: 'a + 'b,
{
element
}
}
/// Search/insert/prune strategy for the test [`Provider`].
#[derive(Debug, Clone)]
pub struct Strategy {
    /// Selects the working-set capacity mode (`Default` when true, `None` when false).
    working_set_reuse: bool,
    /// Ids that the prune accessor reports as transiently failing, if any.
    transient_ids: Option<Arc<HashSet<u32>>>,
}

impl Strategy {
    /// Strategy with working-set reuse enabled and no transient ids.
    pub fn new() -> Self {
        Self::with_options(true)
    }

    /// Strategy with the given working-set reuse setting and no transient ids.
    pub fn with_options(working_set_reuse: bool) -> Self {
        Self {
            working_set_reuse,
            transient_ids: None,
        }
    }

    /// Strategy with the given working-set reuse setting whose prune accessor fails
    /// transiently for each id in `transient_ids`. Duplicate ids are collapsed.
    pub fn with_transient(
        working_set_reuse: bool,
        transient_ids: impl IntoIterator<Item = u32>,
    ) -> Self {
        let ids: HashSet<u32> = transient_ids.into_iter().collect();
        Self {
            working_set_reuse,
            transient_ids: Some(Arc::new(ids)),
        }
    }
}
/// The default strategy is the same as [`Strategy::new`].
impl Default for Strategy {
fn default() -> Self {
Self::new()
}
}
impl glue::SearchStrategy<Provider, &[f32]> for Strategy {
type QueryComputer = <f32 as VectorRepr>::QueryDistance;
type SearchAccessorError = Infallible;
type SearchAccessor<'a> = Accessor<'a>;
/// Returns a non-flaky accessor: the strategy's transient ids are not applied to search.
/// Never fails.
fn search_accessor<'a>(
&'a self,
provider: &'a Provider,
_context: &'a Context,
) -> Result<Accessor<'a>, Infallible> {
Ok(Accessor::new(provider))
}
}
/// Default post-processing is the pipeline `FilterStartPoints` followed by `CopyIds`.
impl glue::DefaultPostProcessor<Provider, &[f32]> for Strategy {
default_post_processor!(glue::Pipeline<glue::FilterStartPoints, glue::CopyIds>);
}
impl glue::PruneStrategy<Provider> for Strategy {
    type WorkingSet = workingset::Map<u32, Box<[f32]>, workingset::map::Ref<[f32]>>;
    type DistanceComputer<'a> = <f32 as VectorRepr>::Distance;
    type PruneAccessor<'a> = Accessor<'a>;
    type PruneAccessorError = Infallible;

    /// Build a working set for `capacity`, using the `Default` capacity mode when
    /// working-set reuse is enabled and `None` otherwise.
    fn create_working_set(&self, capacity: usize) -> Self::WorkingSet {
        let mode = match self.working_set_reuse {
            true => workingset::map::Capacity::Default,
            false => workingset::map::Capacity::None,
        };
        workingset::map::Builder::new(mode).build(capacity)
    }

    /// Return the accessor used during pruning.
    ///
    /// If the strategy was created with transient ids, the accessor is flaky for exactly
    /// those ids; otherwise it is a regular accessor. Never fails.
    fn prune_accessor<'a>(
        &'a self,
        provider: &'a Provider,
        _context: &'a Context,
    ) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError> {
        let accessor = match self.transient_ids.as_deref() {
            Some(ids) => Accessor::flaky(provider, Cow::Borrowed(ids)),
            None => Accessor::new(provider),
        };
        Ok(accessor)
    }
}
impl glue::InsertStrategy<Provider, &[f32]> for Strategy {
type PruneStrategy = Self;
/// The prune strategy is a clone of this strategy (transient ids included).
fn prune_strategy(&self) -> Self::PruneStrategy {
self.clone()
}
/// Returns a non-flaky accessor for the search phase of an insert. Never fails.
fn insert_search_accessor<'a>(
&'a self,
provider: &'a Provider,
_context: &'a Context,
) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
Ok(Accessor::new(provider))
}
}
impl glue::MultiInsertStrategy<Provider, Matrix<f32>> for Strategy {
    type WorkingSet = workingset::Map<u32, Box<[f32]>, workingset::map::Ref<[f32]>>;
    type Seed = workingset::map::Builder<u32, workingset::map::Ref<[f32]>>;
    type FinishError = Infallible;
    type InsertStrategy = Self;

    /// The per-element insert strategy is a clone of this strategy.
    fn insert_strategy(&self) -> Self::InsertStrategy {
        self.clone()
    }

    /// Produce the working-set seed for a batch insert.
    ///
    /// The seed is a working-set builder (capacity mode chosen from `working_set_reuse`)
    /// carrying an overlay built from `batch` and `ids`. The provider and context are not
    /// consulted, and the returned future is already resolved and never fails.
    fn finish<Itr>(
        &self,
        _provider: &Provider,
        _ctx: &Context,
        batch: &Arc<Matrix<f32>>,
        ids: Itr,
    ) -> impl std::future::Future<Output = Result<Self::Seed, Self::FinishError>> + Send
    where
        Itr: ExactSizeIterator<Item = u32> + Send,
    {
        use workingset::map::{Builder, Capacity, Overlay};

        let mode = match self.working_set_reuse {
            true => Capacity::Default,
            false => Capacity::None,
        };
        let overlay = Overlay::from_batch(Arc::clone(batch), ids);
        let seed = Builder::new(mode).with_overlay(overlay);
        std::future::ready(Ok(seed))
    }
}
/// Post-processing step that removes candidates whose term is deleted or missing.
#[derive(Debug, Clone, Copy, Default)]
pub struct FilterDeleted;
impl<'a, 'b, O> glue::SearchPostProcessStep<Accessor<'a>, &'b [f32], O> for FilterDeleted {
    type Error<NextError>
        = NextError
    where
        NextError: StandardError;
    type NextAccessor = Accessor<'a>;

    /// Forward to `next` only those candidates that are stored and not marked deleted.
    ///
    /// A candidate whose id cannot be looked up is treated as deleted and dropped. Everything
    /// else (accessor, query, computer, output) is passed through unchanged, and the result
    /// is whatever `next` returns.
    fn post_process_step<I, B, Next>(
        &self,
        next: &Next,
        accessor: &mut Accessor<'a>,
        query: &'b [f32],
        computer: &<f32 as VectorRepr>::QueryDistance,
        candidates: I,
        output: &mut B,
    ) -> impl std::future::Future<Output = Result<usize, Self::Error<Next::Error>>> + Send
    where
        I: Iterator<Item = Neighbor<u32>> + Send,
        B: SearchOutputBuffer<O> + Send + ?Sized,
        Next: glue::SearchPostProcess<Self::NextAccessor, &'b [f32], O> + Sync,
    {
        // Copy the provider reference out so the filter does not borrow `accessor`, which is
        // handed to `next` mutably.
        let provider = accessor.provider;
        let live = candidates
            .filter(move |candidate| matches!(provider.is_deleted(candidate.id), Ok(false)));
        next.post_process(accessor, query, computer, live, output)
    }
}
impl glue::InplaceDeleteStrategy<Provider> for Strategy {
    type DeleteElement<'a> = &'a [f32];
    type DeleteElementGuard = Box<[f32]>;
    type DeleteElementError = AccessedInvalidId;
    type PruneStrategy = Self;
    type DeleteSearchAccessor<'a> = Accessor<'a>;
    type SearchStrategy = Self;
    type SearchPostProcessor = glue::Pipeline<FilterDeleted, glue::CopyIds>;

    /// The prune strategy is a clone of this strategy.
    fn prune_strategy(&self) -> Self::PruneStrategy {
        self.clone()
    }

    /// The search strategy is a clone of this strategy.
    fn search_strategy(&self) -> Self::SearchStrategy {
        self.clone()
    }

    /// Post-processing for delete searches: drop deleted candidates, then copy ids.
    fn search_post_processor(&self) -> Self::SearchPostProcessor {
        let filter = FilterDeleted;
        glue::Pipeline::new(filter, glue::CopyIds)
    }

    /// Return an owned copy of the vector stored under `id`, whether or not it has been
    /// marked deleted.
    ///
    /// # Errors
    ///
    /// [`AccessedInvalidId`] if nothing is stored under `id`.
    async fn get_delete_element<'a>(
        &'a self,
        provider: &'a Provider,
        _context: &'a <Provider as provider::DataProvider>::Context,
        id: <Provider as provider::DataProvider>::InternalId,
    ) -> Result<Self::DeleteElementGuard, Self::DeleteElementError> {
        let term = provider.terms.get(&id).ok_or(AccessedInvalidId(id))?;
        let data: &[f32] = &term.data;
        Ok(Box::from(data))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test::{assert_message_contains, tokio::current_thread_runtime};
// `StartPoint` getters return exactly what was passed to `new`.
#[test]
fn test_start_point() {
let start_point = StartPoint::new(42, vec![1.0, 2.0, 3.0]);
assert_eq!(start_point.id(), 42);
assert_eq!(start_point.vector(), &[1.0, 2.0, 3.0]);
}
// Exercises `Config::new`: one successful construction followed by one case per
// `ConfigError` variant.
#[test]
fn test_config_new() {
let metric = Metric::L2;
// Happy path: two distinct start points of equal dimension.
{
let start_points = [
StartPoint::new(0, vec![1.0, 2.0, 3.0]),
StartPoint::new(1, vec![4.0, 5.0, 6.0]),
];
let config = Config::new(metric, 10, start_points).unwrap();
assert_eq!(config.max_degree.get(), 10);
assert_eq!(config.dim.get(), 3);
assert_eq!(config.metric, metric);
assert_eq!(config.start_points.len(), 2);
}
// Zero max degree; also checks that the message survives conversion to `ANNError`.
{
let err = Config::new(metric, 0, StartPoint::new(0, vec![1.0, 2.0])).unwrap_err();
assert!(matches!(err, ConfigError::MaxDegreeCannotBeZero));
let msg = err.to_string();
let ann: ANNError = err.into();
assert!(
ann.to_string().contains(&msg),
"ANNError message \"{}\" does not contain original error: \"{}\"",
ann,
msg,
);
}
// No start points at all.
{
let err = Config::new(metric, 10, []).unwrap_err();
assert!(matches!(err, ConfigError::NeedStartPoint));
}
// Second start point has a different dimension than the first.
{
let start_points = [
StartPoint::new(0, vec![1.0, 2.0, 3.0]),
StartPoint::new(1, vec![4.0, 5.0]), ];
let err = Config::new(metric, 10, start_points).unwrap_err();
assert!(matches!(err, ConfigError::MismatchedDims));
}
// Single start point with an empty vector.
{
let err = Config::new(metric, 10, StartPoint::new(0, vec![])).unwrap_err();
assert!(matches!(err, ConfigError::DimCannotBeZero));
}
// Two start points sharing id 0.
{
let start_points = [
StartPoint::new(0, vec![1.0, 2.0]),
StartPoint::new(0, vec![3.0, 4.0]), ];
let err = Config::new(metric, 10, start_points).unwrap_err();
assert!(matches!(err, ConfigError::StartPointsNotUnique));
}
}
// `Vector::mark_deleted` flips the state, is idempotent, and keeps the same allocation
// (checked by comparing data pointers).
#[test]
fn test_vector() {
let vector = vec![1.0, 2.0, 3.0];
let ptr = vector.as_ptr();
let mut v = Vector::Valid(vector);
assert!(!v.is_deleted());
assert_eq!(v.as_ptr(), ptr);
v.mark_deleted();
assert!(v.is_deleted());
assert_eq!(v.as_ptr(), ptr);
// Marking an already deleted vector changes nothing.
v.mark_deleted();
assert!(v.is_deleted());
assert_eq!(v.as_ptr(), ptr);
}
// Same checks as `test_vector`, but going through `Term`'s delegating methods.
#[test]
fn test_term() {
let vector = vec![1.0, 2.0, 3.0];
let ptr = vector.as_ptr();
let mut t = Term {
neighbors: AdjacencyList::new(),
data: Vector::Valid(vector),
};
assert!(!t.is_deleted());
assert_eq!(t.data.as_ptr(), ptr);
t.mark_deleted();
assert!(t.is_deleted());
assert_eq!(t.data.as_ptr(), ptr);
// Marking an already deleted term changes nothing.
t.mark_deleted();
assert!(t.is_deleted());
assert_eq!(t.data.as_ptr(), ptr);
}
// `Context` counts clones and `wrap_spawn` calls, and clones share the counters.
#[test]
fn test_context() {
use provider::ExecutionContext;
let context = Context::default();
let ContextMetrics { spawns, clones } = context.metrics();
assert_eq!(spawns, 0);
assert_eq!(clones, 0);
// Cloning a clone is recorded on the original's counters.
{
let c0 = context.clone();
let _c1 = c0.clone();
}
let ContextMetrics { spawns, clones } = context.metrics();
assert_eq!(spawns, 0);
assert_eq!(clones, 2);
let rt = current_thread_runtime();
// One more clone plus one spawn; the wrapped future's output is passed through.
let v = rt.block_on(context.clone().wrap_spawn(async { 2usize }));
assert_eq!(v, 2);
let ContextMetrics { spawns, clones } = context.metrics();
assert_eq!(spawns, 1);
assert_eq!(clones, 3);
}
// Exercises `Provider::new_from`: two successful constructions followed by one case per
// validation error.
#[test]
fn test_provider_new_from() {
// Happy path: two start points, one regular point, all neighbors resolvable.
{
let config = Config::new(
Metric::L2,
3,
[
StartPoint::new(0, vec![1.0, 0.0]),
StartPoint::new(1, vec![0.0, 1.0]),
],
)
.unwrap();
let start_points = [(0, AdjacencyList::from_iter_untrusted([1, 2]))];
let points = [(
2,
vec![1.0, 1.0],
AdjacencyList::from_iter_untrusted([0, 1]),
)];
let provider = Provider::new_from(config, start_points, points).unwrap();
assert_eq!(provider.dim(), 2);
assert_eq!(provider.max_degree(), 3);
}
// Happy path: no adjacency lists and no regular points.
{
let config = Config::new(Metric::L2, 5, [StartPoint::new(0, vec![1.0])]).unwrap();
let provider = Provider::new_from(config, [], []).unwrap();
assert_eq!(provider.dim(), 1);
}
// Start point adjacency list longer than the max degree.
{
let config = Config::new(Metric::L2, 2, [StartPoint::new(0, vec![1.0])]).unwrap();
let start_points = [(0, AdjacencyList::from_iter_untrusted([1, 2, 3]))];
let err = Provider::new_from(config, start_points, []).unwrap_err();
assert_message_contains!(err.to_string(), "max degree");
}
// Adjacency list given for an id that is not a start point.
{
let config = Config::new(Metric::L2, 5, [StartPoint::new(0, vec![1.0])]).unwrap();
let start_points = [(999, AdjacencyList::new())]; let err = Provider::new_from(config, start_points, []).unwrap_err();
assert_message_contains!(err.to_string(), "not a valid start point");
}
// Regular point adjacency list longer than the max degree.
{
let config = Config::new(Metric::L2, 2, [StartPoint::new(0, vec![1.0])]).unwrap();
let points = [(1, vec![2.0], AdjacencyList::from_iter_untrusted([0, 2, 3]))];
let err = Provider::new_from(config, [], points).unwrap_err();
assert_message_contains!(err.to_string(), "max degree");
}
// Regular point reusing a start point id.
{
let config = Config::new(Metric::L2, 5, [StartPoint::new(0, vec![1.0])]).unwrap();
let points = [(0, vec![2.0], AdjacencyList::new())]; let err = Provider::new_from(config, [], points).unwrap_err();
assert_message_contains!(err.to_string(), "cannot assign start point");
}
// Regular point whose vector has the wrong dimension.
{
let config = Config::new(Metric::L2, 5, [StartPoint::new(0, vec![1.0, 2.0])]).unwrap();
let points = [(1, vec![3.0], AdjacencyList::new())]; let err = Provider::new_from(config, [], points).unwrap_err();
assert_message_contains!(err.to_string(), "expecting dim");
}
// Neighbor id that is not stored anywhere: caught by the final consistency check.
{
let config = Config::new(Metric::L2, 5, [StartPoint::new(0, vec![1.0])]).unwrap();
let points = [(
1,
vec![2.0],
AdjacencyList::from_iter_unique(std::iter::once(999)),
)]; let err = Provider::new_from(config, [], points).unwrap_err();
assert_message_contains!(err.to_string(), "not in the provider");
}
}
// Shared fixture: a 2-dimensional cosine provider with max degree 4, start points 0 and 1,
// and regular points 2 and 3, each with a non-empty adjacency list.
fn create_test_provider() -> Provider {
let config = Config::new(
Metric::Cosine,
4,
[
StartPoint::new(0, vec![1.0, 0.0]),
StartPoint::new(1, vec![0.0, 1.0]),
],
)
.unwrap();
let start_points = [
(0, AdjacencyList::from_iter_untrusted([1, 2, 3])),
(1, AdjacencyList::from_iter_untrusted([0, 3])),
];
let points = [
(
2,
vec![0.5, 0.5],
AdjacencyList::from_iter_untrusted([0, 3]),
),
(
3,
vec![-1.0, 1.0],
AdjacencyList::from_iter_untrusted([0, 1, 2]),
),
];
let provider = Provider::new_from(config, start_points, points).unwrap();
// Sanity-check the fixture so dependent tests can rely on these values.
assert_eq!(provider.dim(), 2);
assert_eq!(provider.max_degree(), 4);
assert_eq!(provider.distance_metric(), Metric::Cosine);
provider
}
// Id conversion is the identity for stored ids and fails with a direction-specific error
// for unknown ids.
#[test]
fn id_conversion() {
use provider::DataProvider;
let provider = create_test_provider();
let context = Context::default();
for i in 0u32..3u32 {
let internal = provider.to_internal_id(&context, &i).unwrap();
assert_eq!(internal, i);
let external = provider.to_external_id(&context, i).unwrap();
assert_eq!(external, i);
}
// Id 5 is not in the fixture.
let err = provider.to_internal_id(&context, &5).unwrap_err();
let message = err.to_string();
assert_eq!(
message, "external id 5 is not initialized",
"got {}",
message
);
let err = provider.to_external_id(&context, 5).unwrap_err();
let message = err.to_string();
assert_eq!(
message, "internal id 5 is not initialized",
"got {}",
message
);
}
// `set_element` rejects wrong dimensions, stores a new vector, and rejects reuse of an id.
#[test]
fn test_set_element() {
use provider::{Accessor, Guard, SetElement};
let provider = create_test_provider();
let rt = current_thread_runtime();
let context = Context::new();
let mut accessor = super::Accessor::new(&provider);
let id = 5;
// Id 5 is not in the fixture yet.
assert!(rt.block_on(accessor.get_element(5)).is_err());
// Wrong dimension: rejected, and the id stays unassigned.
{
let v = vec![1.0f32; provider.dim() + 1];
let err = rt
.block_on(provider.set_element(&context, &id, &v))
.unwrap_err();
let msg = err.to_string();
assert_message_contains!(msg, "wrong dim");
assert!(rt.block_on(accessor.get_element(id)).is_err());
}
// Correct dimension: stored and readable afterwards.
{
let v = vec![1.0f32; provider.dim()];
let guard = rt
.block_on(provider.set_element(&context, &id, &v))
.unwrap();
rt.block_on(guard.complete());
let element = rt.block_on(accessor.get_element(id)).unwrap();
assert_eq!(v, element);
}
// Setting the same id again is rejected.
{
let v = vec![1.0f32; provider.dim()];
let err = rt
.block_on(provider.set_element(&context, &id, &v))
.unwrap_err();
let msg = err.to_string();
assert_message_contains!(msg, "vector id 5 is already assigned");
}
}
// `get_neighbors` returns the fixture's adjacency lists and fails for an unknown id.
#[test]
fn test_neighbor_accessor() {
use provider::{DefaultAccessor, NeighborAccessor};
let provider = create_test_provider();
let accessor = provider.default_accessor();
let mut v = AdjacencyList::new();
let rt = current_thread_runtime();
rt.block_on(accessor.get_neighbors(0, &mut v)).unwrap();
assert_eq!(&*v, &[1, 2, 3]);
rt.block_on(accessor.get_neighbors(1, &mut v)).unwrap();
assert_eq!(&*v, &[0, 3]);
rt.block_on(accessor.get_neighbors(2, &mut v)).unwrap();
assert_eq!(&*v, &[0, 3]);
rt.block_on(accessor.get_neighbors(3, &mut v)).unwrap();
assert_eq!(&*v, &[0, 1, 2]);
// Id 4 is not in the fixture.
let err = rt.block_on(accessor.get_neighbors(4, &mut v)).unwrap_err();
assert_message_contains!(err.to_string(), "Attempt to access an invalid id");
}
/// Exercises `set_neighbors` on node 2 of the test provider: plain overwrites, a list
/// longer than the max degree, a list containing duplicates, and an invalid node id.
/// The provider's `set_neighbors` counter is checked after each step.
#[test]
fn test_set_neighbors() {
    use provider::{DefaultAccessor, NeighborAccessor, NeighborAccessorMut};

    let provider = create_test_provider();
    let accessor = provider.default_accessor();
    let mut neighbors = AdjacencyList::new();
    let rt = current_thread_runtime();

    let node = 2;

    // Overwrite with an empty list.
    rt.block_on(accessor.set_neighbors(node, &[])).unwrap();
    rt.block_on(accessor.get_neighbors(node, &mut neighbors))
        .unwrap();
    assert!(neighbors.is_empty());
    assert_eq!(provider.set_neighbors.value(), 1);

    // Overwrite with a valid, duplicate-free list.
    rt.block_on(accessor.set_neighbors(node, &[1, 3])).unwrap();
    rt.block_on(accessor.get_neighbors(node, &mut neighbors))
        .unwrap();
    assert_eq!(&*neighbors, &[1, 3]);
    assert_eq!(provider.set_neighbors.value(), 2);

    // Too many neighbors: the call fails, the stored list is untouched and the
    // counter does not move.
    {
        assert_eq!(
            provider.max_degree(),
            4,
            "if this changes - update this mini test"
        );
        let failure = rt
            .block_on(accessor.set_neighbors(node, &[1, 2, 3, 4, 5]))
            .unwrap_err();
        assert_message_contains!(
            failure.to_string(),
            "trying to assign neighbors with length 5"
        );

        rt.block_on(accessor.get_neighbors(node, &mut neighbors))
            .unwrap();
        assert_eq!(
            &*neighbors,
            &[1, 3],
            "original neighbors should be unchanged"
        );
        assert_eq!(
            provider.set_neighbors.value(),
            2,
            "number of successful `set_neighbors` should not change"
        );
    }

    // Duplicate neighbors: the call reports an error, yet the deduplicated list
    // is stored and the counter advances.
    {
        let failure = rt
            .block_on(accessor.set_neighbors(node, &[1, 2, 3, 2]))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "duplicate neighbors detected");

        rt.block_on(accessor.get_neighbors(node, &mut neighbors))
            .unwrap();
        assert_eq!(
            &*neighbors,
            &[1, 2, 3],
            "final neighbors should still be deduplicated"
        );
        assert_eq!(
            provider.set_neighbors.value(),
            3,
            "number of successful `set_neighbors` should change"
        );
    }

    // Writing to an id outside the fixture is rejected.
    {
        let failure = rt
            .block_on(accessor.set_neighbors(10, &[1, 2]))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "access an invalid id");
    }
}
/// Exercises `append_vector` on node 2 of the test provider: successful appends,
/// appends that collide with existing or repeated ids, an append that would exceed
/// the max degree, and an append to an invalid node id. The provider's
/// `append_neighbors` counter is checked after each step.
#[test]
fn test_append_vector() {
    use provider::{DefaultAccessor, NeighborAccessor, NeighborAccessorMut};

    let provider = create_test_provider();
    let accessor = provider.default_accessor();
    let mut neighbors = AdjacencyList::new();
    let rt = current_thread_runtime();

    let node = 2;

    // Starting state of the node in the fixture.
    rt.block_on(accessor.get_neighbors(node, &mut neighbors))
        .unwrap();
    assert_eq!(&*neighbors, &[0, 3]);

    // Appending a new id places it at the end of the list.
    rt.block_on(accessor.append_vector(node, &[1])).unwrap();
    rt.block_on(accessor.get_neighbors(node, &mut neighbors))
        .unwrap();
    assert_eq!(&*neighbors, &[0, 3, 1]);
    assert_eq!(provider.append_neighbors.value(), 1);

    // Appending several ids to an emptied list.
    {
        rt.block_on(accessor.set_neighbors(node, &[])).unwrap();
        rt.block_on(accessor.append_vector(node, &[1, 3, 4]))
            .unwrap();
        rt.block_on(accessor.get_neighbors(node, &mut neighbors))
            .unwrap();
        assert_eq!(&*neighbors, &[1, 3, 4]);
        assert_eq!(provider.append_neighbors.value(), 2);
    }

    // Appending an id that is already present: error, list unchanged, but the
    // call is still counted.
    {
        let failure = rt
            .block_on(accessor.append_vector(node, &[1]))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "duplicate ids in append-vector");

        rt.block_on(accessor.get_neighbors(node, &mut neighbors))
            .unwrap();
        assert_eq!(&*neighbors, &[1, 3, 4]);
        assert_eq!(
            provider.append_neighbors.value(),
            3,
            "number of append calls should still increase",
        );
    }

    // Appending a slice that repeats the same id: error, a single copy ends up
    // stored, and the call is counted.
    {
        rt.block_on(accessor.set_neighbors(node, &[])).unwrap();
        let failure = rt
            .block_on(accessor.append_vector(node, &[1, 1, 1]))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "duplicate ids in append-vector");

        rt.block_on(accessor.get_neighbors(node, &mut neighbors))
            .unwrap();
        assert_eq!(&*neighbors, &[1]);
        assert_eq!(
            provider.append_neighbors.value(),
            4,
            "number of append calls should still increase",
        );
    }

    // Appending past the max degree: error, list unchanged, counter unchanged.
    {
        let failure = rt
            .block_on(accessor.append_vector(node, &[2, 3, 4, 5]))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "will exceed the max degree");

        rt.block_on(accessor.get_neighbors(node, &mut neighbors))
            .unwrap();
        assert_eq!(&*neighbors, &[1]);
        assert_eq!(provider.append_neighbors.value(), 4);
    }

    // Appending to an id outside the fixture is rejected.
    {
        let failure = rt
            .block_on(accessor.append_vector(10, &[1, 2]))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "access an invalid id");
    }
}
/// Walks id 3 of the test provider through its lifecycle — valid, deleted, released —
/// and checks `is_deleted`, `status_by_internal_id` and `status_by_external_id` at
/// each stage, as well as for an id that was never initialized.
#[test]
fn test_delete() {
    use provider::Delete;

    let provider = create_test_provider();
    let rt = current_thread_runtime();
    let ids = [0, 1, 2, 3];
    let invalid_id = 5;

    // Fixture sanity check: exactly ids 0..=3 exist, and 0 and 1 are the start points.
    {
        let mut present: Vec<_> = provider.all_internal_ids().into_iter().collect();
        present.sort_unstable();
        assert_eq!(&*present, &ids);

        for id in [0, 1] {
            assert!(provider.is_start_point(id));
        }
        for id in [2, 3] {
            assert!(!provider.is_start_point(id));
        }
    }

    let context = Context::new();

    // Every id starts out valid and not deleted.
    for i in ids {
        assert!(!provider.is_deleted(i).unwrap());

        let by_internal = rt
            .block_on(provider.status_by_internal_id(&context, i))
            .unwrap();
        assert_eq!(by_internal, provider::ElementStatus::Valid);

        let by_external = rt
            .block_on(provider.status_by_external_id(&context, &i))
            .unwrap();
        assert_eq!(by_external, provider::ElementStatus::Valid);
    }

    // An id that was never assigned is an error for all three queries.
    {
        let failure = provider.is_deleted(invalid_id).unwrap_err();
        assert_message_contains!(failure.to_string(), "not initialized");

        let failure = rt
            .block_on(provider.status_by_internal_id(&context, invalid_id))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "not initialized");

        let failure = rt
            .block_on(provider.status_by_external_id(&context, &invalid_id))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "not initialized");
    }

    // After `delete`, the id is still known but reported as deleted.
    {
        let id = 3;
        rt.block_on(provider.delete(&context, &id)).unwrap();
        assert!(provider.is_deleted(id).unwrap());

        let by_internal = rt
            .block_on(provider.status_by_internal_id(&context, id))
            .unwrap();
        assert_eq!(by_internal, provider::ElementStatus::Deleted);

        let by_external = rt
            .block_on(provider.status_by_external_id(&context, &id))
            .unwrap();
        assert_eq!(by_external, provider::ElementStatus::Deleted);
    }

    // After `release`, the id behaves like one that was never initialized.
    {
        let id = 3;
        rt.block_on(provider.release(&context, id)).unwrap();

        let failure = provider.is_deleted(id).unwrap_err();
        assert_message_contains!(failure.to_string(), "not initialized");

        let failure = rt
            .block_on(provider.status_by_internal_id(&context, id))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "not initialized");

        let failure = rt
            .block_on(provider.status_by_external_id(&context, &id))
            .unwrap_err();
        assert_message_contains!(failure.to_string(), "not initialized");
    }
}
/// Confirms that both `delete` and `release` refuse to remove a start point and that
/// the point is still reported as not deleted afterwards. Only start point 0 is
/// exercised, although the fixture also marks id 1 as a start point.
#[test]
fn test_start_points_cannot_be_deleted() {
    use provider::Delete;

    let provider = create_test_provider();
    let rt = current_thread_runtime();

    // Fixture sanity check: ids 0 and 1 are start points.
    for id in [0, 1] {
        assert!(provider.is_start_point(id));
    }

    let context = Context::new();
    let start = 0;

    // `delete` addresses the point by reference to its external id.
    let delete_failure = rt
        .block_on(provider.delete(&context, &start))
        .unwrap_err();
    assert_message_contains!(delete_failure.to_string(), "cannot delete start point");
    assert!(!provider.is_deleted(start).unwrap());

    // `release` takes the id by value and is rejected with the same message.
    let release_failure = rt
        .block_on(provider.release(&context, start))
        .unwrap_err();
    assert_message_contains!(release_failure.to_string(), "cannot delete start point");
    assert!(!provider.is_deleted(start).unwrap());
}
}