use thiserror::Error;
use super::SortedNeighbors;
use crate::{
ANNError, ANNErrorKind, error, graph::AdjacencyList, neighbor::Neighbor, utils::VectorId,
};
/// Tunable knobs for the prune procedure.
///
/// NOTE(review): the exact meaning of `force_saturate` is defined by the prune
/// implementation elsewhere in `crate::graph`. Presumably it asks the prune to
/// fill the output neighbor list even when occlusion would drop candidates.
/// Confirm against the caller.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Options {
    /// Whether the prune should "saturate" its output (see the type-level note).
    pub(in crate::graph) force_saturate: bool,
}
/// Owned buffers backing a prune pass.
///
/// [`Scratch::as_context`] borrows (rather than consumes) these buffers, so a
/// single `Scratch` can be kept around and reused across calls.
#[derive(Debug)]
pub(crate) struct Scratch<I: VectorId> {
    /// Candidate pool; `as_context` wraps it in a [`SortedNeighbors`].
    pub(in crate::graph) pool: Vec<Neighbor<I>>,
    /// Presumably one occlusion factor per candidate. TODO confirm against the prune loop.
    pub(in crate::graph) occlude_factor: Vec<f32>,
    /// NOTE(review): `u16` entries, presumably a per-candidate progress marker. Confirm.
    pub(in crate::graph) last_checked: Vec<u16>,
    /// Adjacency list handed to the prune, presumably as its output. Verify.
    pub(in crate::graph) neighbors: AdjacencyList<I>,
}
impl<I: VectorId> Scratch<I> {
    /// Builds a scratch space whose vectors all start out empty.
    pub(in crate::graph) fn new() -> Self {
        Self {
            pool: Vec::new(),
            occlude_factor: Vec::new(),
            last_checked: Vec::new(),
            neighbors: AdjacencyList::new(),
        }
    }

    /// Lends every buffer out as a [`Context`] tied to this borrow of `self`.
    ///
    /// `max_candidates` is forwarded, together with the pool buffer, to
    /// [`SortedNeighbors::new`]. The remaining buffers are handed out as plain
    /// mutable borrows.
    pub(in crate::graph) fn as_context(&mut self, max_candidates: usize) -> Context<'_, I> {
        // Split the single `&mut self` into disjoint field borrows.
        let Self {
            pool,
            occlude_factor,
            last_checked,
            neighbors,
        } = self;

        Context {
            pool: SortedNeighbors::new(pool, max_candidates),
            occlude_factor,
            last_checked,
            neighbors,
        }
    }
}
/// Mutable view over a set of prune buffers, valid for `'ctx`.
///
/// [`Scratch::as_context`] builds one of these. In that case each field
/// borrows the same-named buffer of the originating `Scratch`.
#[derive(Debug)]
pub(crate) struct Context<'ctx, I: VectorId> {
    /// Candidate pool, kept in sorted form by [`SortedNeighbors`].
    pub(in crate::graph) pool: SortedNeighbors<'ctx, I>,
    /// Borrowed occlusion-factor buffer.
    pub(in crate::graph) occlude_factor: &'ctx mut Vec<f32>,
    /// Borrowed `last_checked` buffer.
    pub(in crate::graph) last_checked: &'ctx mut Vec<u16>,
    /// Borrowed adjacency list.
    pub(in crate::graph) neighbors: &'ctx mut AdjacencyList<I>,
}
/// Transient failure: the vector for the wrapped id could not be retrieved
/// during prune aggregation.
///
/// The rendered message names the offending id. The `TransientError` impl
/// decides whether the failure is silently acknowledged or escalated into an
/// [`ANNError`].
#[derive(Debug, Clone, Copy, Error)]
#[error("retrieval of main vector id {} failed during prune aggregation", self.0)]
pub(crate) struct FailedVectorRetrieval<I: VectorId>(I);
impl<I> error::TransientError<ANNError> for FailedVectorRetrieval<I>
where
I: VectorId,
{
fn acknowledge<D>(self, _why: D)
where
D: std::fmt::Display,
{
}
#[track_caller]
#[inline(never)]
fn escalate<D>(self, why: D) -> ANNError
where
D: std::fmt::Display,
{
ANNError::new(ANNErrorKind::IndexError, self).context(why.to_string())
}
}
/// Error type that separates a recoverable vector-retrieval failure from any
/// other [`ANNError`].
///
/// Its [`error::ToRanked`] impl reports the first variant as transient and the
/// second as a regular error.
#[derive(Debug)]
pub(crate) enum ListError<I: VectorId> {
    /// A single vector could not be fetched; ranked as transient.
    FailedVectorRetrieval(FailedVectorRetrieval<I>),
    /// Any other failure; ranked as `RankedError::Error`.
    Other(ANNError),
}
impl<I: VectorId> ListError<I> {
    /// Convenience constructor: a [`ListError::FailedVectorRetrieval`] for `id`.
    pub(in crate::graph) fn failed_retrieval(id: I) -> Self {
        let inner = FailedVectorRetrieval(id);
        Self::FailedVectorRetrieval(inner)
    }
}
impl<I> From<ANNError> for ListError<I>
where
I: VectorId,
{
fn from(err: ANNError) -> Self {
Self::Other(err)
}
}
impl<I> error::ToRanked for ListError<I>
where
I: VectorId,
{
type Transient = FailedVectorRetrieval<I>;
type Error = ANNError;
fn to_ranked(self) -> error::RankedError<Self::Transient, Self::Error> {
match self {
Self::FailedVectorRetrieval(err) => error::RankedError::Transient(err),
Self::Other(err) => error::RankedError::Error(err),
}
}
fn from_transient(transient: Self::Transient) -> Self {
Self::FailedVectorRetrieval(transient)
}
fn from_error(error: Self::Error) -> Self {
Self::Other(error)
}
}