use std::{fmt::Debug, sync::Arc};
use futures_util::FutureExt;
use diskann::{
ANNResult,
error::{self as core_error, IntoANNResult, StandardError},
graph::{
AdjacencyList, SearchOutputBuffer,
glue::{
self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy,
Pipeline, PruneStrategy, SearchExt, SearchPostProcessStep, SearchStrategy,
},
workingset,
},
neighbor::Neighbor,
provider::{
Accessor, AsNeighbor, BuildDistanceComputer, BuildQueryComputer, CacheableAccessor,
DataProvider, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor,
NeighborAccessorMut, SetElement,
},
};
use diskann_utils::{
WithLifetime,
future::{AssertSend, AsyncFriendly, SendFuture},
};
use thiserror::Error;
/// Handle describing a cache lookup that found no entry for `key`.
///
/// Built by [`ElementCache::try_get`] on a miss. It keeps the exclusive borrow of the cache
/// so that the entry can be filled in afterwards through [`Missing::set`].
#[derive(Debug)]
pub struct Missing<'a, C, I> {
/// Exclusive borrow of the cache that reported the miss.
cache: &'a mut C,
/// Key that was looked up.
key: I,
}
impl<'a, C, I> Missing<'a, C, I>
where
    I: Clone,
{
    /// Writes `element` into the cache under the key whose lookup missed.
    ///
    /// The handle is consumed, releasing the borrow of the cache.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`ElementCache::set_cached`].
    pub fn set<E>(self, element: &E::Of<'_>) -> Result<(), C::Error>
    where
        C: ElementCache<I, E>,
        E: WithLifetime,
    {
        let Self { cache, key } = self;
        cache.set_cached(key, element)
    }
}
/// Result of [`ElementCache::try_get`]: either the cached element or a handle for filling
/// the missing entry.
#[derive(Debug)]
pub enum MaybeCached<'a, C, I, E>
where
E: WithLifetime,
C: ElementCache<I, E>,
I: Clone,
{
/// The cache held an element for the key; it borrows from the cache for `'a`.
Present(E::Of<'a>),
/// The cache held nothing for the key; use [`Missing::set`] to populate it.
Missing(Missing<'a, C, I>),
}
/// Cache of elements described by `E`, keyed by `I`.
///
/// Implementors provide [`get_cached`](Self::get_cached) and
/// [`set_cached`](Self::set_cached); [`try_get`](Self::try_get) is a provided method built
/// on top of `get_cached`.
pub trait ElementCache<I, E>: Send + Sync + Sized
where
E: WithLifetime,
I: Clone,
{
/// Error type reported by cache operations.
type Error: StandardError;
/// Looks up `key`. [`try_get`](Self::try_get) treats `Ok(Some(_))` as a hit and
/// `Ok(None)` as a miss.
fn get_cached(&mut self, key: I) -> Result<Option<E::Of<'_>>, Self::Error>;
/// Intended to store `v` under `key`; invoked by [`Missing::set`] after a miss.
fn set_cached(&mut self, key: I, v: &E::Of<'_>) -> Result<(), Self::Error>;
/// Looks up `key` and returns either the cached element or a [`Missing`] handle that
/// still owns the borrow of `self`, so the caller can fill the entry without a second
/// borrow.
///
/// # Errors
///
/// Returns the error from [`get_cached`](Self::get_cached) unchanged.
fn try_get(&mut self, key: I) -> Result<MaybeCached<'_, Self, I, E>, Self::Error> {
use polonius_the_crab as ptc;
// Lifetime-generic description of the borrowed value `polonius` may hand back.
type Output<E> = ptc::ForLt!(<E as WithLifetime>::Of<'_>);
// NOTE(review): `polonius` is presumably needed because returning the borrow in one
// branch while reusing `self` in the other is rejected when written directly.
let result_or_cache =
ptc::polonius::<_, Result<(), Self::Error>, Output<E>>(self, |cache| {
// `key` is cloned because the original is still needed for the `Missing` handle.
match cache.get_cached(key.clone()) {
Ok(Some(element)) => ptc::PoloniusResult::Borrowing(element),
Ok(None) => ptc::PoloniusResult::Owned(Ok(())),
Err(err) => ptc::PoloniusResult::Owned(Err(err)),
}
});
match result_or_cache {
ptc::PoloniusResult::Borrowing(v) => Ok(MaybeCached::Present(v)),
ptc::PoloniusResult::Owned {
value,
input_borrow: cache, } => {
// Surface a lookup error; otherwise this was a plain miss.
value?;
Ok(MaybeCached::Missing(Missing { cache, key }))
}
}
}
}
/// Returns the element for `id`, consulting `cache` first and falling back to `accessor`.
///
/// On a cache hit the cached representation is converted with `A::from_cached`. On a miss
/// the element is fetched through `accessor.get_element`, written to the cache in the form
/// produced by `A::as_cached`, and then returned.
///
/// # Errors
///
/// * [`CachingError::Cache`] if the cache lookup or the cache write fails.
/// * [`CachingError::Inner`] if `accessor` fails to produce the element.
pub fn get_or_insert<'a, A, C>(
    accessor: &'a mut A,
    cache: &'a mut C,
    id: A::Id,
) -> impl SendFuture<Result<A::Element<'a>, CachingError<A::GetError, C::Error>>>
where
    A: CacheableAccessor,
    C: ElementCache<A::Id, A::Map>,
{
    async move {
        // Resolve the lookup first: a hit returns immediately, a miss yields the slot to fill.
        let slot = match cache.try_get(id) {
            Ok(MaybeCached::Present(hit)) => return Ok(A::from_cached(hit)),
            Ok(MaybeCached::Missing(slot)) => slot,
            Err(err) => return Err(CachingError::Cache(err)),
        };

        let fetched = match accessor.get_element(id).await {
            Ok(fetched) => fetched,
            Err(err) => return Err(CachingError::Inner(err)),
        };

        // Populate the cache before handing the element back to the caller.
        let stored = slot.set(A::as_cached(&fetched));
        stored.map(|()| fetched).map_err(CachingError::Cache)
    }
}
/// Outcome of a neighbor-list lookup in a [`NeighborCache`].
///
/// `Eq` is derived alongside `PartialEq`: the enum is fieldless, so equality is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[must_use = "NeighborStatus must be observed and acted on"]
pub enum NeighborStatus {
    /// The cache answered the lookup; the inner accessor is not consulted.
    Hit,
    /// The cache had no entry; the caching accessor fetches from the inner accessor and
    /// then stores the result in the cache.
    Miss,
    /// The cache had no entry and the fetched result is not written back to the cache.
    Uncacheable,
}
/// Cache of adjacency lists keyed by `I`.
pub trait NeighborCache<I>: Send + Sync {
/// Error type reported by cache operations.
type Error: StandardError;
/// Attempts to serve the neighbors of `id` into `neighbors` and reports how the lookup
/// went. The caching accessor only falls back to the inner accessor when the status is
/// not [`NeighborStatus::Hit`].
fn try_get_neighbors(
&mut self,
id: I,
neighbors: &mut AdjacencyList<I>,
) -> Result<NeighborStatus, Self::Error>;
/// Records `neighbors` as the adjacency list of `id`.
fn set_neighbors(&mut self, id: I, neighbors: &[I]) -> Result<(), Self::Error>;
/// Drops any cached adjacency list for `id`; called after the inner neighbors change.
fn invalidate_neighbors(&mut self, id: I);
}
/// Cache-side hook for discarding everything cached for an id.
///
/// [`CachingProvider`] calls it when an id is released and after an element is set.
pub trait Evict<I> {
/// Evicts cached state associated with `id`. Takes `&self`, so implementors need their
/// own interior mutability.
fn evict(&self, id: I);
}
/// Ability of a cache to produce a per-accessor cache view for accessors of type `A`.
///
/// The `Cached` strategies call this on the cache stored in a [`CachingProvider`] to wrap
/// the inner strategy's accessor in a [`CachingAccessor`].
pub trait AsCacheAccessorFor<'a, A>
where
A: CacheableAccessor,
{
/// Cache view paired with the accessor; it must be able to cache `A`'s elements.
type Accessor: ElementCache<A::Id, A::Map>;
/// Error raised when the cache view cannot be created.
type Error: StandardError;
/// Consumes `accessor` and returns it wrapped together with a cache view borrowed
/// from `self`.
fn as_cache_accessor_for(
&'a self,
accessor: A,
) -> Result<CachingAccessor<A, Self::Accessor>, Self::Error>;
}
/// Cache-aware counterpart of `workingset::Fill::fill` for cacheable accessors.
///
/// [`CachingAccessor`]'s `Fill` implementation forwards to this method, passing its cache
/// alongside the wrapped fill state.
pub trait CachedFill<C, State>: CacheableAccessor
where
C: ElementCache<Self::Id, Self::Map>,
Self: workingset::Fill<State>,
{
/// Fills `state` for the ids yielded by `itr`, with `cache` available to the
/// implementation, and returns the resulting view.
///
/// Errors distinguish failures of the accessor (`Inner`) from failures of the cache
/// (`Cache`).
fn cached_fill<'a, Itr>(
&'a mut self,
cache: &'a mut C,
state: &'a mut State,
itr: Itr,
) -> impl SendFuture<Result<Self::View<'a>, CachingError<Self::Error, C::Error>>>
where
Itr: ExactSizeIterator<Item = Self::Id> + Clone + Send + Sync;
}
/// A data provider paired with a cache.
///
/// Trait implementations in this module forward provider operations to `provider` and use
/// `cache` for eviction and for building caching accessors.
///
/// `Debug` is derived so the type can be inspected the same way as [`CachingAccessor`].
#[derive(Debug)]
pub struct CachingProvider<T, C> {
    /// The wrapped provider that owns the data.
    provider: T,
    /// The cache consulted on behalf of `provider`.
    cache: C,
}

impl<T, C> CachingProvider<T, C> {
    /// Wraps `provider` together with `cache`.
    pub fn new(provider: T, cache: C) -> Self {
        Self { provider, cache }
    }

    /// Returns the wrapped provider.
    pub fn inner(&self) -> &T {
        &self.provider
    }

    /// Returns the cache.
    pub fn cache(&self) -> &C {
        &self.cache
    }
}
/// An accessor paired with the cache view it should consult.
#[derive(Debug)]
pub struct CachingAccessor<A, C> {
    /// The wrapped accessor.
    inner: A,
    /// The cache view used by this accessor.
    cache: C,
}

impl<A, C> CachingAccessor<A, C> {
    /// Pairs `inner` with `cache`.
    pub fn new(inner: A, cache: C) -> Self {
        CachingAccessor { inner, cache }
    }

    /// Borrows the wrapped accessor.
    pub fn inner(&self) -> &A {
        let Self { inner, .. } = self;
        inner
    }

    /// Borrows the cache view.
    pub fn cache(&self) -> &C {
        let Self { cache, .. } = self;
        cache
    }
}
/// Marker wrapper that turns a strategy (or strategy state) `S` into its caching variant.
#[derive(Debug, Clone, Copy)]
pub struct Cached<S> {
    /// The wrapped strategy or state.
    strategy: S,
}

impl<S> Cached<S> {
    /// Wraps `strategy`.
    pub fn new(strategy: S) -> Self {
        Cached { strategy }
    }
}
impl<T, U> workingset::AsWorkingSet<Cached<T>> for Cached<U>
where
U: workingset::AsWorkingSet<T>,
{
fn as_working_set(&self, capacity: usize) -> Cached<T> {
Cached::new(self.strategy.as_working_set(capacity))
}
}
/// Error produced by the caching wrappers: either the wrapped component failed or the
/// cache did.
#[derive(Debug, Error)]
pub enum CachingError<E, C> {
/// Failure reported by the wrapped (inner) provider, accessor or strategy.
#[error("encountered error from backing provider")]
Inner(#[source] E),
/// Failure reported by the cache.
#[error("encountered error while accessing cache")]
Cache(#[source] C),
}
#[cfg(test)]
impl<E, C> CachingError<E, C>
where
    E: Debug,
    C: Debug,
{
    /// Test helper: unwraps the `Inner` payload.
    ///
    /// # Panics
    ///
    /// Panics if `self` is the `Cache` variant.
    fn expect_inner(self) -> E {
        match self {
            Self::Cache(cache_err) => {
                panic!("expected an `Inner` error but got a `Cache`: {:?}", cache_err)
            }
            Self::Inner(inner) => inner,
        }
    }
}
/// Wrapper around an inner transient error `T`.
///
/// It is the `Transient` associated type of `CachingError`'s `ToRanked` implementation;
/// escalating it yields a `CachingError::Inner`. `#[error(transparent)]` forwards the
/// error presentation to the wrapped value.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Transient<T>(T);
/// Collapses a [`CachingError`] into an `ANNError` by converting whichever payload it holds.
impl<E, C> From<CachingError<E, C>> for diskann::ANNError
where
    E: Into<diskann::ANNError>,
    C: StandardError,
{
    /// Converts the contained error; the `CachingError` wrapper itself adds nothing to
    /// the result.
    // NOTE(review): `#[track_caller]` presumably lets the payload conversion attribute
    // the error to the caller of this function - confirm against `ANNError`.
    #[track_caller]
    fn from(err: CachingError<E, C>) -> Self {
        match err {
            CachingError::Cache(cache_err) => cache_err.into(),
            CachingError::Inner(source) => source.into(),
        }
    }
}
impl<E, C, T> core_error::TransientError<CachingError<E, C>> for Transient<T>
where
T: core_error::TransientError<E>,
{
#[track_caller]
fn acknowledge<D>(self, why: D)
where
D: std::fmt::Display,
{
self.0.acknowledge(why)
}
#[track_caller]
fn escalate<D>(self, why: D) -> CachingError<E, C>
where
D: std::fmt::Display,
{
CachingError::Inner(self.0.escalate(why))
}
#[track_caller]
fn acknowledge_with<F, D>(self, why: F)
where
F: FnOnce() -> D,
D: std::fmt::Display,
{
self.0.acknowledge_with(why)
}
#[track_caller]
fn escalate_with<F, D>(self, why: F) -> CachingError<E, C>
where
F: FnOnce() -> D,
D: std::fmt::Display,
{
CachingError::Inner(self.0.escalate_with(why))
}
}
impl<E, C> core_error::ToRanked for CachingError<E, C>
where
E: core_error::ToRanked,
C: StandardError,
{
type Error = CachingError<E::Error, C>;
type Transient = Transient<E::Transient>;
fn to_ranked(self) -> core_error::RankedError<Self::Transient, Self::Error> {
use core_error::RankedError;
match self {
Self::Inner(err) => match err.to_ranked() {
RankedError::Transient(v) => core_error::RankedError::Transient(Transient(v)),
RankedError::Error(v) => core_error::RankedError::Error(CachingError::Inner(v)),
},
Self::Cache(err) => core_error::RankedError::Error(CachingError::Cache(err)),
}
}
fn from_transient(transient: Self::Transient) -> Self {
Self::Inner(E::from_transient(transient.0))
}
fn from_error(error: Self::Error) -> Self {
match error {
CachingError::Inner(err) => Self::Inner(E::from_error(err)),
CachingError::Cache(err) => Self::Cache(err),
}
}
}
/// The caching provider exposes exactly the id model of the provider it wraps; id
/// translation never touches the cache.
impl<T, C> DataProvider for CachingProvider<T, C>
where
    T: DataProvider,
    C: AsyncFriendly,
{
    type Context = T::Context;
    type Error = T::Error;
    type ExternalId = T::ExternalId;
    type InternalId = T::InternalId;
    type Guard = T::Guard;

    /// Delegates the internal-to-external id translation to the wrapped provider.
    fn to_external_id(
        &self,
        context: &Self::Context,
        id: Self::InternalId,
    ) -> Result<Self::ExternalId, Self::Error> {
        self.inner().to_external_id(context, id)
    }

    /// Delegates the external-to-internal id translation to the wrapped provider.
    fn to_internal_id(
        &self,
        context: &Self::Context,
        gid: &Self::ExternalId,
    ) -> Result<Self::InternalId, Self::Error> {
        self.inner().to_internal_id(context, gid)
    }
}
/// Deletion support: everything is forwarded to the wrapped provider, and releasing an
/// id additionally evicts it from the cache.
impl<DP, C> Delete for CachingProvider<DP, C>
where
    DP: DataProvider + Delete,
    C: Evict<DP::InternalId> + AsyncFriendly,
{
    /// Forwards the delete to the wrapped provider; the cache is not touched here.
    fn delete(
        &self,
        context: &DP::Context,
        gid: &DP::ExternalId,
    ) -> impl Future<Output = Result<(), DP::Error>> + Send {
        self.inner().delete(context, gid)
    }

    /// Evicts `id` from the cache and then forwards the release.
    ///
    /// The eviction runs when this method is called, before the returned future is
    /// created, so it happens even if the future is never polled.
    fn release(
        &self,
        context: &DP::Context,
        id: DP::InternalId,
    ) -> impl Future<Output = Result<(), DP::Error>> + Send {
        self.cache().evict(id);
        self.inner().release(context, id)
    }

    /// Forwards the status query by internal id to the wrapped provider.
    fn status_by_internal_id(
        &self,
        context: &DP::Context,
        id: DP::InternalId,
    ) -> impl Future<Output = Result<ElementStatus, DP::Error>> + Send {
        self.inner().status_by_internal_id(context, id)
    }

    /// Forwards the status query by external id to the wrapped provider.
    fn status_by_external_id(
        &self,
        context: &DP::Context,
        gid: &DP::ExternalId,
    ) -> impl Future<Output = Result<ElementStatus, DP::Error>> + Send {
        self.inner().status_by_external_id(context, gid)
    }
}
/// Setting an element goes to the wrapped provider first; on success the cache entry for
/// the affected id is evicted so later reads cannot observe a stale value.
impl<DP, C, T> SetElement<T> for CachingProvider<DP, C>
where
    DP: SetElement<T>,
    T: Send + Sync,
    C: AsyncFriendly + Evict<DP::InternalId>,
{
    type SetError = DP::SetError;

    /// Stores `element` through the wrapped provider and evicts the id reported by the
    /// returned guard.
    ///
    /// # Errors
    ///
    /// Returns the wrapped provider's error; in that case the cache is left untouched.
    async fn set_element(
        &self,
        context: &Self::Context,
        id: &Self::ExternalId,
        element: T,
    ) -> Result<Self::Guard, Self::SetError> {
        use diskann::provider::Guard;

        let stored = self.inner().set_element(context, id, element).await?;
        self.cache().evict(stored.id());
        Ok(stored)
    }
}
/// The caching accessor uses the same id type as the accessor it wraps.
impl<A, C> HasId for CachingAccessor<A, C>
where
A: HasId,
{
type Id = A::Id;
}
/// Read path for adjacency lists: the cache is consulted first and the wrapped accessor
/// is only used when the cache does not report a hit.
impl<A, C> NeighborAccessor for CachingAccessor<A, &mut C>
where
    A: NeighborAccessor,
    C: NeighborCache<A::Id>,
{
    /// Fills `neighbors` with the adjacency list of `id`.
    ///
    /// * `Hit`: the wrapped accessor is not consulted.
    /// * `Miss`: the list is fetched from the wrapped accessor and stored in the cache.
    /// * `Uncacheable`: the list is fetched but not stored.
    ///
    /// # Errors
    ///
    /// Cache errors are converted with `into_ann_result`; errors of the wrapped accessor
    /// are returned unchanged.
    async fn get_neighbors(
        mut self,
        id: Self::Id,
        neighbors: &mut AdjacencyList<Self::Id>,
    ) -> ANNResult<Self> {
        let status = self
            .cache
            .try_get_neighbors(id, neighbors)
            .into_ann_result()?;

        if status == NeighborStatus::Hit {
            return Ok(self);
        }

        // Not served from the cache: fall back to the wrapped accessor.
        self.inner = self.inner.get_neighbors(id, neighbors).await?;

        if status != NeighborStatus::Uncacheable {
            self.cache.set_neighbors(id, neighbors).into_ann_result()?;
        }

        Ok(self)
    }
}
/// Write path for adjacency lists: the wrapped accessor is updated first and the cached
/// list for the id is invalidated afterwards.
impl<A, C> NeighborAccessorMut for CachingAccessor<A, &mut C>
where
    A: NeighborAccessorMut,
    C: NeighborCache<A::Id>,
{
    /// Replaces the neighbors of `id` in the wrapped accessor, then invalidates the
    /// cached list. If the wrapped accessor fails, the cache is not touched.
    async fn set_neighbors(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult<Self> {
        let Self { inner, cache } = self;
        let inner = inner.set_neighbors(id, neighbors).await?;
        cache.invalidate_neighbors(id);
        Ok(Self { inner, cache })
    }

    /// Appends `neighbors` to `id` in the wrapped accessor, then invalidates the cached
    /// list. If the wrapped accessor fails, the cache is not touched.
    async fn append_vector(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult<Self> {
        let Self { inner, cache } = self;
        let inner = inner.append_vector(id, neighbors).await?;
        cache.invalidate_neighbors(id);
        Ok(Self { inner, cache })
    }
}
impl<'a, A, C> DelegateNeighbor<'a> for CachingAccessor<A, C>
where
A: DelegateNeighbor<'a>,
C: NeighborCache<Self::Id>,
{
type Delegate = CachingAccessor<A::Delegate, &'a mut C>;
fn delegate_neighbor(&'a mut self) -> Self::Delegate {
CachingAccessor::new(self.inner.delegate_neighbor(), &mut self.cache)
}
}
/// Element access through the cache: element types are those of the wrapped accessor,
/// errors are widened to [`CachingError`].
impl<A, C> Accessor for CachingAccessor<A, C>
where
    A: CacheableAccessor,
    C: ElementCache<A::Id, A::Map>,
{
    type Element<'a>
        = A::Element<'a>
    where
        Self: 'a;
    type ElementRef<'a> = A::ElementRef<'a>;
    type GetError = CachingError<A::GetError, C::Error>;

    /// Fetches the element for `id` via [`get_or_insert`], i.e. from the cache when
    /// present and from the wrapped accessor (populating the cache) otherwise.
    async fn get_element(&mut self, id: Self::Id) -> Result<A::Element<'_>, Self::GetError> {
        let lookup = get_or_insert(&mut self.inner, &mut self.cache, id);
        lookup.send().await
    }
}
/// Distance computers come straight from the wrapped accessor; the cache plays no role.
impl<A, C> BuildDistanceComputer for CachingAccessor<A, C>
where
    A: BuildDistanceComputer + CacheableAccessor,
    C: ElementCache<A::Id, A::Map>,
{
    type DistanceComputerError = A::DistanceComputerError;
    type DistanceComputer = A::DistanceComputer;

    /// Delegates to the wrapped accessor.
    fn build_distance_computer(
        &self,
    ) -> Result<Self::DistanceComputer, Self::DistanceComputerError> {
        self.inner().build_distance_computer()
    }
}
/// Query computers come straight from the wrapped accessor; the cache plays no role.
impl<T, A, C> BuildQueryComputer<T> for CachingAccessor<A, C>
where
    A: BuildQueryComputer<T> + CacheableAccessor,
    C: ElementCache<A::Id, A::Map>,
{
    type QueryComputerError = A::QueryComputerError;
    type QueryComputer = A::QueryComputer;

    /// Delegates to the wrapped accessor, passing `from` through unchanged.
    fn build_query_computer(
        &self,
        from: T,
    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
        self.inner().build_query_computer(from)
    }
}
/// Working-set filling for the caching accessor: the wrapped accessor performs the fill
/// through [`CachedFill`], which is handed this accessor's cache.
impl<A, C, State> workingset::Fill<Cached<State>> for CachingAccessor<A, C>
where
    A: workingset::Fill<State>,
    A: CacheableAccessor + CachedFill<C, State>,
    C: ElementCache<A::Id, A::Map>,
{
    type Error = CachingError<A::Error, C::Error>;
    type View<'a>
        = A::View<'a>
    where
        Self: 'a,
        State: 'a;

    /// Unwraps the [`Cached`] state and forwards to `cached_fill` on the wrapped accessor.
    fn fill<'a, Itr>(
        &'a mut self,
        state: &'a mut Cached<State>,
        itr: Itr,
    ) -> impl SendFuture<Result<Self::View<'a>, Self::Error>>
    where
        Itr: ExactSizeIterator<Item = Self::Id> + Clone + Send + Sync,
        Self: 'a,
    {
        <A as CachedFill<C, State>>::cached_fill(
            &mut self.inner,
            &mut self.cache,
            &mut state.strategy,
            itr,
        )
    }
}
/// Opts the caching accessor into `ExpandBeam`. No items are overridden, so the trait's
/// provided behavior is used as-is.
impl<A, C, T> ExpandBeam<T> for CachingAccessor<A, C>
where
A: BuildQueryComputer<T> + CacheableAccessor + AsNeighbor,
C: ElementCache<A::Id, A::Map> + NeighborCache<A::Id>,
{
}
/// Post-processing step that strips the caching layer: it hands the wrapped accessor of a
/// [`CachingAccessor`] to the next stage.
#[derive(Debug, Default, Clone, Copy)]
pub struct Unwrap;
impl<A, C, T> SearchPostProcessStep<CachingAccessor<A, C>, T> for Unwrap
where
A: BuildQueryComputer<T> + CacheableAccessor,
C: ElementCache<A::Id, A::Map>,
{
type Error<NextError>
= NextError
where
NextError: StandardError;
type NextAccessor = A;
fn post_process_step<I, B, Next>(
&self,
next: &Next,
accessor: &mut CachingAccessor<A, C>,
query: T,
computer: &<A as BuildQueryComputer<T>>::QueryComputer,
candidates: I,
output: &mut B,
) -> impl Future<Output = Result<usize, Self::Error<Next::Error>>> + Send
where
I: Iterator<Item = Neighbor<A::Id>> + Send,
B: SearchOutputBuffer<A::Id> + Send + ?Sized,
Next: glue::SearchPostProcess<Self::NextAccessor, T, A::Id> + Sync,
{
next.post_process(&mut accessor.inner, query, computer, candidates, output)
}
}
/// Search extensions are answered entirely by the wrapped accessor.
impl<A, C> SearchExt for CachingAccessor<A, C>
where
    A: SearchExt + CacheableAccessor,
    C: ElementCache<A::Id, A::Map>,
{
    /// Delegates to the wrapped accessor.
    fn starting_points(&self) -> impl Future<Output = ANNResult<Vec<Self::Id>>> + Send {
        <A as SearchExt>::starting_points(&self.inner)
    }

    /// Delegates to the wrapped accessor.
    fn terminate_early(&mut self) -> bool {
        <A as SearchExt>::terminate_early(&mut self.inner)
    }

    /// Delegates to the wrapped accessor.
    fn is_not_start_point(
        &self,
    ) -> impl Future<Output = ANNResult<impl Fn(Self::Id) -> bool + Send + Sync + 'static>> + Send
    {
        <A as SearchExt>::is_not_start_point(&self.inner)
    }
}
/// Shorthand for the search accessor type that strategy `S` yields for provider `DP` and
/// query type `T`.
type SearchAccessor<'a, S, DP, T> = <S as SearchStrategy<DP, T>>::SearchAccessor<'a>;
/// Shorthand for the prune accessor type that strategy `S` yields for provider `DP`.
type PruneAccessor<'a, S, DP> = <S as PruneStrategy<DP>>::PruneAccessor<'a>;
/// Search over a [`CachingProvider`]: the wrapped strategy builds its accessor against the
/// wrapped provider, and the provider's cache wraps that accessor in a [`CachingAccessor`].
impl<DP, C, T, S, E> SearchStrategy<CachingProvider<DP, C>, T> for Cached<S>
where
    DP: DataProvider,
    S: for<'a> SearchStrategy<DP, T, SearchAccessor<'a>: CacheableAccessor>,
    C: for<'a> AsCacheAccessorFor<
            'a,
            SearchAccessor<'a, S, DP, T>,
            Accessor: NeighborCache<DP::InternalId>,
            Error = E,
        > + AsyncFriendly,
    E: StandardError,
{
    type QueryComputer = S::QueryComputer;
    type SearchAccessor<'a> = CachingAccessor<
        SearchAccessor<'a, S, DP, T>,
        <C as AsCacheAccessorFor<'a, SearchAccessor<'a, S, DP, T>>>::Accessor,
    >;
    type SearchAccessorError = CachingError<S::SearchAccessorError, E>;

    /// Creates the caching search accessor.
    ///
    /// # Errors
    ///
    /// * [`CachingError::Inner`] if the wrapped strategy cannot create its accessor.
    /// * [`CachingError::Cache`] if the cache cannot create its view for that accessor.
    fn search_accessor<'a>(
        &'a self,
        provider: &'a CachingProvider<DP, C>,
        context: &'a DP::Context,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        match self.strategy.search_accessor(&provider.provider, context) {
            Ok(inner) => provider
                .cache
                .as_cache_accessor_for(inner)
                .map_err(CachingError::Cache),
            Err(source) => Err(CachingError::Inner(source)),
        }
    }
}
/// Default post-processing for the caching strategy: an [`Unwrap`] step combined with the
/// wrapped strategy's default processor.
impl<DP, C, T, S, E> glue::DefaultPostProcessor<CachingProvider<DP, C>, T> for Cached<S>
where
    DP: DataProvider,
    S: glue::DefaultPostProcessor<DP, T>
        + for<'a> SearchStrategy<DP, T, SearchAccessor<'a>: CacheableAccessor>,
    C: for<'a> AsCacheAccessorFor<
            'a,
            SearchAccessor<'a, S, DP, T>,
            Accessor: NeighborCache<DP::InternalId>,
            Error = E,
        > + AsyncFriendly,
    E: StandardError,
{
    type Processor = Pipeline<Unwrap, S::Processor>;

    /// Builds the pipeline from `Unwrap` and the wrapped strategy's default processor.
    fn default_post_processor(&self) -> Self::Processor {
        let wrapped = self.strategy.default_post_processor();
        Pipeline::new(Unwrap, wrapped)
    }
}
impl<DP, C, S, E> PruneStrategy<CachingProvider<DP, C>> for Cached<S>
where
DP: DataProvider,
S: for<'a> PruneStrategy<DP, PruneAccessor<'a>: CacheableAccessor>,
C: for<'a> AsCacheAccessorFor<
'a,
PruneAccessor<'a, S, DP>,
Accessor: NeighborCache<DP::InternalId>,
Error = E,
> + AsyncFriendly,
for<'a> S::PruneAccessor<'a>: CachedFill<<C as AsCacheAccessorFor<'a, PruneAccessor<'a, S, DP>>>::Accessor, S::WorkingSet>,
E: StandardError,
{
type WorkingSet = Cached<S::WorkingSet>;
type DistanceComputer = S::DistanceComputer;
type PruneAccessor<'a> = CachingAccessor<
PruneAccessor<'a, S, DP>,
<C as AsCacheAccessorFor<'a, PruneAccessor<'a, S, DP>>>::Accessor,
>;
type PruneAccessorError = CachingError<S::PruneAccessorError, E>;
fn prune_accessor<'a>(
&'a self,
provider: &'a CachingProvider<DP, C>,
context: &'a DP::Context,
) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError> {
let inner = self
.strategy
.prune_accessor(&provider.provider, context)
.map_err(CachingError::Inner)?;
provider
.cache
.as_cache_accessor_for(inner)
.map_err(CachingError::Cache)
}
fn create_working_set(&self, capacity: usize) -> Self::WorkingSet {
Cached::new(self.strategy.create_working_set(capacity))
}
}
impl<DP, C, T, S> InsertStrategy<CachingProvider<DP, C>, T> for Cached<S>
where
DP: DataProvider,
S: InsertStrategy<DP, T>,
Cached<S>: SearchStrategy<CachingProvider<DP, C>, T>,
Cached<S::PruneStrategy>: PruneStrategy<CachingProvider<DP, C>>,
C: AsyncFriendly,
{
type PruneStrategy = Cached<S::PruneStrategy>;
fn prune_strategy(&self) -> Self::PruneStrategy {
Cached {
strategy: self.strategy.prune_strategy(),
}
}
}
impl<DP, C, S, E> InplaceDeleteStrategy<CachingProvider<DP, C>> for Cached<S>
where
DP: DataProvider,
S: InplaceDeleteStrategy<DP>,
for<'a> S::DeleteSearchAccessor<'a>: CacheableAccessor,
Cached<S::PruneStrategy>: PruneStrategy<CachingProvider<DP, C>>,
for<'a> Cached<S::SearchStrategy>: SearchStrategy<
CachingProvider<DP, C>,
S::DeleteElement<'a>,
SearchAccessor<'a> = CachingAccessor<
S::DeleteSearchAccessor<'a>,
<C as AsCacheAccessorFor<'a, S::DeleteSearchAccessor<'a>>>::Accessor,
>,
>,
C: for<'a> AsCacheAccessorFor<
'a,
S::DeleteSearchAccessor<'a>,
Accessor: NeighborCache<DP::InternalId>,
Error = E,
> + AsyncFriendly,
E: StandardError,
{
type DeleteElement<'a> = S::DeleteElement<'a>;
type DeleteElementGuard = S::DeleteElementGuard;
type DeleteElementError = S::DeleteElementError;
type PruneStrategy = Cached<S::PruneStrategy>;
type DeleteSearchAccessor<'a> = CachingAccessor<
S::DeleteSearchAccessor<'a>,
<C as AsCacheAccessorFor<'a, S::DeleteSearchAccessor<'a>>>::Accessor,
>;
type SearchStrategy = Cached<S::SearchStrategy>;
type SearchPostProcessor = Pipeline<Unwrap, S::SearchPostProcessor>;
fn prune_strategy(&self) -> Self::PruneStrategy {
Cached {
strategy: self.strategy.prune_strategy(),
}
}
fn search_strategy(&self) -> Self::SearchStrategy {
Cached {
strategy: self.strategy.search_strategy(),
}
}
fn search_post_processor(&self) -> Self::SearchPostProcessor {
Pipeline::new(Unwrap, self.strategy.search_post_processor())
}
fn get_delete_element<'a>(
&'a self,
provider: &'a CachingProvider<DP, C>,
context: &'a DP::Context,
id: DP::InternalId,
) -> impl Future<Output = Result<Self::DeleteElementGuard, Self::DeleteElementError>> + Send
{
self.strategy
.get_delete_element(&provider.provider, context, id)
}
}
/// Batch insertion over a [`CachingProvider`]: the wrapped strategy does the work against
/// the wrapped provider, and its seed, working set and insert strategy are wrapped in
/// [`Cached`].
impl<DP, C, S, B> MultiInsertStrategy<CachingProvider<DP, C>, B> for Cached<S>
where
    DP: DataProvider,
    B: Batch,
    S: MultiInsertStrategy<DP, B>,
    Cached<S::InsertStrategy>: for<'a> InsertStrategy<
            CachingProvider<DP, C>,
            B::Element<'a>,
            PruneStrategy: PruneStrategy<
                CachingProvider<DP, C>,
                WorkingSet = Cached<S::WorkingSet>,
            >,
        >,
    C: AsyncFriendly,
{
    type Seed = Cached<S::Seed>;
    type WorkingSet = Cached<S::WorkingSet>;
    type FinishError = S::FinishError;
    type InsertStrategy = Cached<S::InsertStrategy>;

    /// Returns the wrapped strategy's insert strategy, wrapped in [`Cached`].
    fn insert_strategy(&self) -> Self::InsertStrategy {
        Cached::new(self.strategy.insert_strategy())
    }

    /// Runs the wrapped strategy's `finish` against the wrapped provider and wraps a
    /// successful seed in [`Cached`]; errors pass through unchanged.
    fn finish<Itr>(
        &self,
        provider: &CachingProvider<DP, C>,
        context: &DP::Context,
        batch: &Arc<B>,
        ids: Itr,
    ) -> impl std::future::Future<Output = Result<Self::Seed, Self::FinishError>> + Send
    where
        Itr: ExactSizeIterator<Item = DP::InternalId> + Send,
    {
        let pending = self
            .strategy
            .finish(&provider.provider, context, batch, ids);
        pending.map(|outcome| outcome.map(Cached::new))
    }
}
#[cfg(test)]
mod tests {
    //! Tests for the error plumbing of the caching layer: ranking of [`CachingError`] and
    //! forwarding of `TransientError` operations through [`Transient`].

    use super::*;
    use std::{
        fmt::Display,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use diskann::{
        ANNError,
        error::{RankedError, ToRanked, TransientError},
    };

    /// Per-method call counters used to verify which `TransientError` method was forwarded.
    #[derive(Debug, Default)]
    struct Counters {
        acknowledge: AtomicUsize,
        acknowledge_with: AtomicUsize,
        escalate: AtomicUsize,
        escalate_with: AtomicUsize,
    }

    /// Transient error that records every call in shared [`Counters`] and carries a
    /// `token` that is copied into the escalated [`Critical`] error.
    #[derive(Debug)]
    struct TransientErr {
        counters: Arc<Counters>,
        token: usize,
    }

    impl TransientErr {
        fn new(counters: &Arc<Counters>, token: usize) -> Self {
            Self {
                counters: counters.clone(),
                token,
            }
        }
    }

    /// Non-transient error carrying the token of the transient error it came from.
    #[derive(Debug, Error)]
    #[error("super critical error: {0}")]
    struct Critical(usize);

    impl From<Critical> for ANNError {
        fn from(err: Critical) -> Self {
            ANNError::opaque(err)
        }
    }

    /// Inner error type that can be either transient or critical.
    #[derive(Debug)]
    enum Generic {
        Transient(TransientErr),
        Critical(Critical),
    }

    impl TransientError<Critical> for TransientErr {
        fn acknowledge<D>(self, _why: D)
        where
            D: Display,
        {
            self.counters.acknowledge.fetch_add(1, Ordering::Relaxed);
        }

        fn acknowledge_with<F, D>(self, _why: F)
        where
            F: FnOnce() -> D,
            D: Display,
        {
            self.counters
                .acknowledge_with
                .fetch_add(1, Ordering::Relaxed);
        }

        fn escalate<D>(self, _why: D) -> Critical
        where
            D: Display,
        {
            self.counters.escalate.fetch_add(1, Ordering::Relaxed);
            Critical(self.token)
        }

        fn escalate_with<F, D>(self, _why: F) -> Critical
        where
            F: FnOnce() -> D,
            D: Display,
        {
            self.counters.escalate_with.fetch_add(1, Ordering::Relaxed);
            Critical(self.token)
        }
    }

    impl ToRanked for Generic {
        type Transient = TransientErr;
        type Error = Critical;

        fn to_ranked(self) -> RankedError<TransientErr, Critical> {
            match self {
                Self::Transient(e) => RankedError::Transient(e),
                Self::Critical(e) => RankedError::Error(e),
            }
        }

        fn from_transient(transient: TransientErr) -> Self {
            Self::Transient(transient)
        }

        fn from_error(error: Critical) -> Self {
            Self::Critical(error)
        }
    }

    /// Cache-side error used in the `Cache` variant.
    #[derive(Debug, Error)]
    #[error("always a critical error")]
    struct AlwaysCritical;

    impl From<AlwaysCritical> for ANNError {
        fn from(err: AlwaysCritical) -> Self {
            ANNError::opaque(err)
        }
    }

    /// Each `TransientError` method on [`Transient`] must forward to the matching method
    /// of the wrapped error, and escalation must produce an `Inner` error carrying the token.
    #[test]
    fn test_caching_error() {
        type TestError = CachingError<Critical, AlwaysCritical>;

        let err = CachingError::<Generic, AlwaysCritical>::Cache(AlwaysCritical);
        assert!(matches!(
            err.to_ranked(),
            RankedError::Error(CachingError::Cache(AlwaysCritical))
        ));

        let counters = Arc::new(Counters::default());
        let make_transient = || Transient(TransientErr::new(&counters, 10));

        <_ as TransientError<TestError>>::acknowledge(make_transient(), "");
        assert_eq!(counters.acknowledge.load(Ordering::Relaxed), 1);

        <_ as TransientError<TestError>>::acknowledge_with(make_transient(), || "");
        assert_eq!(counters.acknowledge_with.load(Ordering::Relaxed), 1);

        let err = <_ as TransientError<TestError>>::escalate(make_transient(), "").expect_inner();
        assert_eq!(counters.escalate.load(Ordering::Relaxed), 1);
        assert_eq!(err.0, 10);

        let err =
            <_ as TransientError<TestError>>::escalate_with(make_transient(), || "").expect_inner();
        // `escalate_with` has its own counter; checking `escalate` here would pass even
        // if the call were never forwarded.
        assert_eq!(counters.escalate_with.load(Ordering::Relaxed), 1);
        // The plain `escalate` counter must not have moved.
        assert_eq!(counters.escalate.load(Ordering::Relaxed), 1);
        assert_eq!(err.0, 10);
    }

    /// `to_ranked` and the `from_*` constructors must preserve both the variant
    /// (`Inner` / `Cache`) and the rank (transient / critical).
    #[test]
    fn test_caching_error_to_ranked() {
        type Top = CachingError<Generic, AlwaysCritical>;
        type Crit = CachingError<Critical, AlwaysCritical>;

        let err = Top::Cache(AlwaysCritical);
        assert!(
            matches!(
                err.to_ranked(),
                RankedError::Error(CachingError::<Critical, AlwaysCritical>::Cache(
                    AlwaysCritical
                ))
            ),
            "cache errors are always critical"
        );
        assert!(
            matches!(Top::from_error(Crit::Cache(AlwaysCritical)), Top::Cache(_)),
            "reassembling from Cache should preserve Cache"
        );

        let counters = Arc::new(Counters::default());
        let err = Top::Inner(Generic::Transient(TransientErr::new(&counters, 5)));
        assert!(
            matches!(
                err.to_ranked(),
                RankedError::Transient(Transient(TransientErr { .. }))
            ),
            "transient inner errors are transient"
        );
        assert!(
            matches!(
                Top::from_transient(Transient(TransientErr::new(&counters, 5))),
                Top::Inner(Generic::Transient(_)),
            ),
            "transient errors are still transient",
        );

        let err = Top::Inner(Generic::Critical(Critical(2)));
        assert!(
            matches!(
                err.to_ranked(),
                RankedError::Error(CachingError::<Critical, AlwaysCritical>::Inner(_))
            ),
            "critical errors are critical"
        );
        assert!(
            matches!(
                Top::from_error(Crit::Inner(Critical(2))),
                Top::Inner(Generic::Critical(_)),
            ),
            "critical errors are still critical",
        );
    }
}