use std::{ops::Range, path::Path, sync::Arc};

use anyhow::Context as _;
use diskann_utils::views::{Matrix, MatrixView};

use crate::{recall, streaming};

use super::{Args, Delete, Insert, Replace, Search};
/// Trait-object type of the callback [`WithData`] uses to obtain groundtruth rows
/// from a path. The path borrow only needs to live for the duration of the call.
type LoadGroundtruth<I> =
    dyn for<'path> FnMut(&'path Path) -> anyhow::Result<Box<dyn recall::Rows<I>>>;
/// Stream adapter that owns the vectors and the groundtruth loader.
///
/// It turns range/path based runbook operations into calls on `inner` that carry
/// concrete data.
pub struct WithData<T, I, Inner> {
    /// Stream that receives the resolved operations.
    inner: Inner,
    /// Matrix whose rows are addressed by insert/replace offsets.
    dataset: Matrix<T>,
    /// Query matrix; a reference-counted handle to it accompanies every forwarded search.
    queries: Arc<Matrix<T>>,
    /// Callback invoked on every search to load groundtruth from the step's path.
    load_groundtruth: Box<LoadGroundtruth<I>>,
}
impl<T, I, Inner> WithData<T, I, Inner> {
    /// Builds the adapter around `inner`.
    ///
    /// * `inner` — stream that will receive the resolved operations.
    /// * `dataset` — matrix whose rows are selected by insert/replace offsets.
    /// * `queries` — shared query matrix handed to every search.
    /// * `load_groundtruth` — called with the groundtruth path of each search step; it is
    ///   boxed and stored, hence the `'static` requirement.
    pub fn new(
        inner: Inner,
        dataset: Matrix<T>,
        queries: Arc<Matrix<T>>,
        load_groundtruth: impl FnMut(&Path) -> anyhow::Result<Box<dyn recall::Rows<I>>> + 'static,
    ) -> Self {
        // Unsizes to the stored trait object at the field below (a coercion site).
        let boxed_loader = Box::new(load_groundtruth);
        Self {
            load_groundtruth: boxed_loader,
            queries,
            dataset,
            inner,
        }
    }
}
/// Type-level marker selecting the resolved argument types for streams wrapped by
/// [`WithData`].
///
/// `T` is the element type of the dataset/query matrices and `I` is the type parameter
/// of [`recall::Rows`]. The marker never stores a `T` or an `I`, so `Debug`, `Clone`,
/// `Copy` and `Default` are implemented by hand without bounds on either parameter.
/// `#[derive]` would needlessly require e.g. `T: Copy` for `DataArgs<T, I>: Copy`.
pub struct DataArgs<T, I> {
    _marker: std::marker::PhantomData<(T, I)>,
}

impl<T, I> std::fmt::Debug for DataArgs<T, I> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the previous derived output; `PhantomData: Debug` is unconditional.
        f.debug_struct("DataArgs")
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T, I> Clone for DataArgs<T, I> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, I> Copy for DataArgs<T, I> {}

impl<T, I> Default for DataArgs<T, I> {
    fn default() -> Self {
        Self {
            _marker: std::marker::PhantomData,
        }
    }
}
/// Argument types a [`WithData`]-wrapped stream receives for each operation.
///
/// - maintenance: nothing;
/// - delete: only an id range;
/// - insert/replace: a borrowed view of dataset rows paired with an id range;
/// - search: a shared handle to the query matrix plus borrowed groundtruth rows.
impl<T: 'static, I: 'static> streaming::Arguments for DataArgs<T, I> {
    type Maintain<'a> = ();
    type Delete<'a> = Range<usize>;
    type Replace<'a> = (MatrixView<'a, T>, Range<usize>);
    type Insert<'a> = (MatrixView<'a, T>, Range<usize>);
    type Search<'a> = (Arc<Matrix<T>>, &'a dyn recall::Rows<I>);
}
impl<T, I, Inner> streaming::Stream<Args> for WithData<T, I, Inner>
where
Inner: streaming::Stream<DataArgs<T, I>>,
T: 'static,
I: 'static,
{
type Output = Inner::Output;
fn search(&mut self, args: Search<'_>) -> anyhow::Result<Self::Output> {
let groundtruth = (self.load_groundtruth)(args.groundtruth)?;
self.inner.search((self.queries.clone(), &*groundtruth))
}
fn insert(&mut self, args: Insert) -> anyhow::Result<Self::Output> {
let data = self.dataset.subview(args.offsets).unwrap();
self.inner.insert((data, args.ids))
}
fn replace(&mut self, args: Replace) -> anyhow::Result<Self::Output> {
let data = self.dataset.subview(args.offsets).unwrap();
self.inner.replace((data, args.ids))
}
fn delete(&mut self, args: Delete) -> anyhow::Result<Self::Output> {
self.inner.delete(args.ids)
}
fn maintain(&mut self, _args: ()) -> anyhow::Result<Self::Output> {
self.inner.maintain(())
}
fn needs_maintenance(&mut self) -> bool {
self.inner.needs_maintenance()
}
}