use std::{ops::Range, sync::Arc};
use diskann::utils::IntoUsize;
use diskann_benchmark_core::{
recall::Rows,
streaming::{self, executors},
};
use diskann_benchmark_runner::{timed, utils::MicroSeconds};
use diskann_utils::views::{Matrix, MatrixView};
use crate::utils::streaming::TagSlotManager;
/// A streaming index that is addressed by slot ids (`u32`) instead of tags.
///
/// [`Managed`] owns the tag-to-slot bookkeeping. It converts incoming tag ranges into
/// slot ids before calling into this trait, so implementations only ever see slots.
///
/// Every method takes `&self`. An implementation that needs to mutate its state must
/// therefore use interior mutability.
pub(crate) trait ManagedStream<T> {
    /// Per-operation result type. `Managed` wraps it in [`Stats`] together with the
    /// time it spent on its own bookkeeping.
    type Output;
    /// Run `queries` against the index.
    ///
    /// When invoked through `Managed`, each row of `groundtruth` has already been
    /// translated from tags into slot ids.
    fn search(
        &self,
        queries: Arc<Matrix<T>>,
        groundtruth: &dyn Rows<u32>,
    ) -> anyhow::Result<Self::Output>;
    /// Insert the rows of `data` at the given `slots`.
    ///
    /// NOTE(review): presumably row `i` of `data` goes to `slots[i]`; confirm against
    /// implementations. `Managed` obtains these slots from its bookkeeping and records
    /// the tag assignment only after this call returns `Ok`.
    fn insert(&self, data: MatrixView<'_, T>, slots: &[u32]) -> anyhow::Result<Self::Output>;
    /// Replace the entries at `slots` with the rows of `data`.
    ///
    /// When invoked through `Managed`, `slots` are the existing slots of the tags being
    /// replaced.
    fn replace(&self, data: MatrixView<'_, T>, slots: &[u32]) -> anyhow::Result<Self::Output>;
    /// Delete the entries at `slots`.
    ///
    /// `Managed` marks the corresponding tags as deleted only after this call returns
    /// `Ok`.
    fn delete(&self, slots: &[u32]) -> anyhow::Result<Self::Output>;
    /// Perform index maintenance.
    ///
    /// `Managed` calls this before consolidating its own tag/slot bookkeeping.
    fn maintain(&self) -> anyhow::Result<Self::Output>;
}
/// Adapts a slot-addressed [`ManagedStream`] to the tag-based [`streaming::Stream`]
/// interface by maintaining a mapping between tags and slots.
pub(crate) struct Managed<T, O> {
    /// Tag/slot bookkeeping: hands out empty slots, maps tags to slots, and tracks
    /// deleted and active entries.
    book_keeping: TagSlotManager,
    /// Scratch buffer reused across searches. It holds the groundtruth rows
    /// translated from tags into slot ids.
    translated: Vec<Vec<u32>>,
    /// Maintenance is requested once the number of deleted entries exceeds
    /// `num_active * threshold` (truncated to an integer).
    threshold: f32,
    /// The wrapped slot-addressed stream that performs the actual index operations.
    stream: Box<dyn ManagedStream<T, Output = O>>,
}
impl<T, O> Managed<T, O> {
pub(crate) fn new(
max: usize,
threshold: f32,
stream: impl ManagedStream<T, Output = O> + 'static,
) -> Self {
Self {
book_keeping: TagSlotManager::new(max),
translated: Vec::new(),
threshold,
stream: Box::new(stream),
}
}
}
impl<T, O> streaming::Stream<executors::bigann::DataArgs<T, u32>> for Managed<T, O>
where
    T: 'static,
    O: 'static,
{
    type Output = Stats<O>;

    /// Translate every groundtruth row from tags into slot ids, then forward the search
    /// to the inner stream.
    ///
    /// Only the translation step counts as manager overhead.
    ///
    /// # Errors
    ///
    /// Fails if a groundtruth tag is missing from the tag-to-slot mapping, or if the
    /// inner search fails.
    fn search(
        &mut self,
        (queries, groundtruth): (Arc<Matrix<T>>, &dyn Rows<u32>),
    ) -> anyhow::Result<Self::Output> {
        let (overhead, _): (_, ()) = timed! {
            // Keep one buffer per groundtruth row and reuse the allocations between calls.
            self.translated.resize(groundtruth.nrows(), Vec::new());
            for (row, slots) in self.translated.iter_mut().enumerate() {
                slots.clear();
                for tag in groundtruth.row(row) {
                    match self.book_keeping.tag_to_slot.get(&tag.into_usize()) {
                        Some(slot) => slots.push(*slot),
                        None => anyhow::bail!("Tag {} not found in tag-to-slot mapping", tag),
                    }
                }
            }
        };
        let found = self.stream.search(queries, &self.translated)?;
        Ok(Stats::new(overhead, found))
    }

    /// Reserve empty slots for `tags`, insert `data` into them, and then bind the tags
    /// to those slots.
    ///
    /// Slot reservation and tag assignment count as overhead. The inner insert does
    /// not. Tags are bound only after the inner insert succeeds.
    fn insert(
        &mut self,
        (data, tags): (MatrixView<'_, T>, Range<usize>),
    ) -> anyhow::Result<Self::Output> {
        let (reserve_time, slots) = timed!(self.book_keeping.get_n_empty_slots(tags.len())?);
        let inserted = self.stream.insert(data, &slots)?;
        let (bind_time, _) = timed!(self.book_keeping.assign_slots_to_tags(tags, slots)?);
        Ok(Stats::new(reserve_time + bind_time, inserted))
    }

    /// Look up the slots currently held by `tags` and replace their contents with `data`.
    ///
    /// Only the slot lookup counts as overhead.
    fn replace(
        &mut self,
        (data, tags): (MatrixView<'_, T>, Range<usize>),
    ) -> anyhow::Result<Self::Output> {
        let (lookup_time, slots) = timed!(self.book_keeping.find_slots_by_tags(tags)?);
        let replaced = self.stream.replace(data, &slots)?;
        Ok(Stats::new(lookup_time, replaced))
    }

    /// Delete the slots held by `tags` from the inner stream, then mark the tags as
    /// deleted in the bookkeeping.
    ///
    /// The slot lookup and the marking count as overhead.
    fn delete(&mut self, tags: Range<usize>) -> anyhow::Result<Self::Output> {
        let (lookup_time, slots) = timed!(self.book_keeping.find_slots_by_tags(tags.clone())?);
        let deleted = self.stream.delete(&slots)?;
        let (mark_time, _) = timed!(self.book_keeping.mark_tags_deleted(tags)?);
        Ok(Stats::new(lookup_time + mark_time, deleted))
    }

    /// Run inner-stream maintenance first, then consolidate the tag/slot bookkeeping.
    ///
    /// Only the consolidation counts as overhead.
    fn maintain(&mut self, _: ()) -> anyhow::Result<Self::Output> {
        let maintained = self.stream.maintain()?;
        let (consolidate_time, _) = timed!(self.book_keeping.consolidate());
        Ok(Stats::new(consolidate_time, maintained))
    }

    /// Returns `true` once deleted entries outnumber `num_active * threshold`
    /// (truncated to an integer).
    fn needs_maintenance(&mut self) -> bool {
        let active = self.book_keeping.num_active() as f32;
        let limit = (active * self.threshold) as usize;
        limit < self.book_keeping.num_deleted()
    }
}
/// Result of one `Managed` operation. It pairs the inner stream's output with the
/// time the manager spent on its own tag/slot bookkeeping.
#[derive(Debug, serde::Serialize)]
pub(crate) struct Stats<T> {
    /// Time spent in `Managed`'s bookkeeping for this operation. The inner stream
    /// call is not included.
    pub(crate) manager_overhead: MicroSeconds,
    /// Output of the inner stream. `serde(flatten)` serializes its fields alongside
    /// `manager_overhead` instead of nesting them under an `inner` key.
    #[serde(flatten)]
    pub(crate) inner: T,
}
impl<T> Stats<T> {
fn new(manager_overhead: MicroSeconds, inner: T) -> Self {
Self {
manager_overhead,
inner,
}
}
pub(crate) fn inner(&self) -> &T {
&self.inner
}
}
impl<T> std::fmt::Display for Stats<T>
where
    T: std::fmt::Display,
{
    /// Print the manager overhead in seconds on its own line, followed by the inner
    /// stream's own `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let seconds = self.manager_overhead.as_seconds();
        writeln!(f, "Manager Overhead: {}s", seconds)?;
        std::fmt::Display::fmt(&self.inner, f)
    }
}