use diskann_utils::views::{self, Matrix};
use crate::recall;
/// Owned table of result ids whose rows are read through `ResultIds::as_rows`.
///
/// The concrete storage (a fixed-width matrix with per-row lengths, or ragged
/// vectors) is hidden behind the crate-private `ResultIdsInner`.
#[derive(Debug)]
pub struct ResultIds<I> {
/// Backing storage; selects the fixed-width or ragged representation.
inner: ResultIdsInner<I>,
}
impl<I> ResultIds<I> {
pub fn as_rows(&self) -> &dyn recall::Rows<I> {
self.inner.as_rows()
}
pub(crate) fn new(inner: ResultIdsInner<I>) -> Self {
Self { inner }
}
}
/// A matrix of ids in which only a leading prefix of each row is exposed.
///
/// `lengths[i]` is the number of leading entries of row `i` of `ids` returned
/// when that row is read; a length larger than the row width exposes the whole
/// row. `Bounded::new` asserts that there is exactly one length per matrix row.
#[derive(Debug)]
pub(crate) struct Bounded<I> {
/// Fixed-width storage holding one row of ids per entry.
ids: Matrix<I>,
/// Per-row prefix length, parallel to the rows of `ids`.
lengths: Vec<usize>,
}
impl<I> Bounded<I> {
pub(crate) fn new(ids: Matrix<I>, lengths: Vec<usize>) -> Self {
assert_eq!(
ids.nrows(),
lengths.len(),
"an internal invariant was not upheld",
);
Self { ids, lengths }
}
pub(crate) fn len(&self) -> usize {
self.lengths.len()
}
pub(crate) fn iter(&self) -> impl ExactSizeIterator<Item = &[I]> {
std::iter::zip(self.ids.row_iter(), self.lengths.iter()).map(|(row, len)| {
match row.get(..*len) {
Some(v) => v,
None => row,
}
})
}
}
impl<I> recall::Rows<I> for Bounded<I> {
    /// Number of rows, one per recorded length.
    fn nrows(&self) -> usize {
        self.lengths.len()
    }

    /// Returns row `index` truncated to its recorded length.
    ///
    /// A recorded length larger than the row width returns the whole row.
    /// Panics if `index` has no recorded length.
    fn row(&self, index: usize) -> &[I] {
        // Look up the length first so an out-of-range index fails on `lengths`,
        // exactly as before.
        let valid = self.lengths[index];
        let full_row = self.ids.row(index);
        if valid <= full_row.len() {
            &full_row[..valid]
        } else {
            full_row
        }
    }

    /// Always `None`: rows may be truncated to different lengths, so no single
    /// column count is reported.
    fn ncols(&self) -> Option<usize> {
        None
    }
}
/// Storage representation behind a `ResultIds`.
#[derive(Debug)]
pub(crate) enum ResultIdsInner<I> {
/// Rows held in a single matrix together with a per-row prefix length.
Fixed(Bounded<I>),
/// Rows held as independently sized vectors.
Dynamic(Vec<Vec<I>>),
}
impl<I> ResultIdsInner<I> {
    /// Borrows whichever storage variant is active as a row-oriented view.
    pub(crate) fn as_rows(&self) -> &dyn recall::Rows<I> {
        let view: &dyn recall::Rows<I> = match self {
            Self::Dynamic(ragged) => ragged,
            Self::Fixed(matrix) => matrix,
        };
        view
    }
}
/// Accumulates batches of `ResultIdsInner` and merges them into one `ResultIds`.
///
/// The aggregator stays in the `Fixed` state while every pushed batch is a
/// `Bounded` with the same column count, and switches to `Dynamic` as soon as a
/// batch with a different column count or a ragged batch is pushed.
#[derive(Debug, Default)]
pub(crate) enum IdAggregator<I> {
/// No batch has been pushed yet.
#[default]
Empty,
/// Every batch so far is a `Bounded` with `num_ids` columns.
Fixed {
/// Pushed batches, in push order.
matrices: Vec<Bounded<I>>,
/// Total number of rows across `matrices`.
len: usize,
/// Column count shared by every matrix in `matrices`.
num_ids: usize,
},
/// Batches of mixed shape, in push order; merged into ragged rows on finish.
Dynamic(Vec<ResultIdsInner<I>>),
}
impl<I> IdAggregator<I>
where
I: Clone + Default,
{
pub(crate) fn new() -> Self {
Self::Empty
}
pub(crate) fn push(&mut self, ids: ResultIdsInner<I>) {
*self = match std::mem::take(self) {
Self::Empty => match ids {
ResultIdsInner::Fixed(bounded) => {
let len = bounded.ids.nrows();
let num_ids = bounded.ids.ncols();
Self::Fixed {
matrices: vec![bounded],
len,
num_ids,
}
}
ResultIdsInner::Dynamic(ids) => Self::Dynamic(vec![ResultIdsInner::Dynamic(ids)]),
},
Self::Fixed {
mut matrices,
len,
num_ids,
} => match ids {
ResultIdsInner::Fixed(bounded) => {
if bounded.ids.ncols() == num_ids {
let len = len + bounded.len();
matrices.push(bounded);
Self::Fixed {
matrices,
len,
num_ids,
}
} else {
let mut dynamic: Vec<_> =
matrices.into_iter().map(ResultIdsInner::Fixed).collect();
dynamic.push(ResultIdsInner::Fixed(bounded));
Self::Dynamic(dynamic)
}
}
ResultIdsInner::Dynamic(ids) => {
let mut dynamic: Vec<_> =
matrices.into_iter().map(ResultIdsInner::Fixed).collect();
dynamic.push(ResultIdsInner::Dynamic(ids));
Self::Dynamic(dynamic)
}
},
Self::Dynamic(mut dynamic) => {
dynamic.push(ids);
Self::Dynamic(dynamic)
}
};
}
pub(crate) fn finish(self) -> ResultIds<I> {
match self {
Self::Empty => ResultIds::new(ResultIdsInner::Dynamic(Vec::new())),
Self::Fixed {
matrices,
len,
num_ids,
} => {
let mut dst = Matrix::new(views::Init(|| I::default()), len, num_ids);
let mut lengths = Vec::with_capacity(len);
let mut output_row = 0;
for bounded in matrices {
for row in bounded.ids.row_iter() {
dst.row_mut(output_row).clone_from_slice(row);
output_row += 1;
}
lengths.extend_from_slice(&bounded.lengths);
}
ResultIds::new(ResultIdsInner::Fixed(Bounded::new(dst, lengths)))
}
Self::Dynamic(all) => {
let mut dst = Vec::<Vec<I>>::new();
for ids in all {
match ids {
ResultIdsInner::Fixed(bounded) => {
bounded.iter().for_each(|row| dst.push(row.into()));
}
ResultIdsInner::Dynamic(dynamic) => {
dynamic.into_iter().for_each(|i| dst.push(i));
}
}
}
ResultIds::new(ResultIdsInner::Dynamic(dst))
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recall::Rows;

    /// Builds a `Bounded<u32>` from ragged rows.
    ///
    /// The matrix is as wide as the longest row; shorter rows are zero-padded
    /// and their recorded length is the number of values supplied.
    fn make_bounded(data: Vec<Vec<u32>>) -> Bounded<u32> {
        let nrows = data.len();
        let ncols = data.iter().map(Vec::len).max().unwrap_or(0);
        let mut matrix = Matrix::new(0u32, nrows, ncols);
        let mut lengths = Vec::with_capacity(nrows);
        for (dst_row, src_row) in std::iter::zip(matrix.row_iter_mut(), &data) {
            let mut written = 0;
            for (slot, &value) in std::iter::zip(dst_row.iter_mut(), src_row) {
                *slot = value;
                written += 1;
            }
            lengths.push(written);
        }
        Bounded::new(matrix, lengths)
    }

    /// Copies every row of a view into owned vectors so a whole table can be
    /// compared with a single assertion.
    fn collect_rows<I: Clone>(rows: &dyn Rows<I>) -> Vec<Vec<I>> {
        (0..rows.nrows()).map(|i| rows.row(i).to_vec()).collect()
    }

    /// One length per row is accepted.
    #[test]
    fn test_bounded_new_valid() {
        let bounded = Bounded::new(Matrix::new(0u32, 3, 5), vec![2, 3, 1]);
        assert_eq!(bounded.len(), 3);
    }

    /// A recorded length beyond the row width yields the full row.
    #[test]
    fn test_bounded_length_clamping() {
        // The last length (5) exceeds the matrix width (3).
        let bounded = Bounded::new(Matrix::new(0u32, 3, 3), vec![2, 3, 5]);
        assert_eq!(bounded.len(), 3);

        let expected: [&[u32]; 3] = [&[0, 0], &[0, 0, 0], &[0, 0, 0]];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(bounded.row(index), *want);
        }

        let via_iter: Vec<&[u32]> = bounded.iter().collect();
        assert_eq!(via_iter, expected);
    }

    /// Fewer lengths than rows violates the constructor invariant.
    #[test]
    #[should_panic(expected = "an internal invariant was not upheld")]
    fn test_bounded_new_mismatched_lengths() {
        let too_few = vec![2, 3];
        let _ = Bounded::new(Matrix::new(0u32, 3, 5), too_few);
    }

    /// Row access and iteration both honour the recorded lengths.
    #[test]
    fn test_bounded() {
        let bounded = make_bounded(vec![vec![1, 2], vec![3, 4, 5], vec![6]]);
        assert_eq!(bounded.len(), 3);
        assert_eq!(bounded.nrows(), 3);
        assert_eq!(bounded.ncols(), None);

        let expected: [&[u32]; 3] = [&[1, 2], &[3, 4, 5], &[6]];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(bounded.row(index), *want);
        }

        let via_iter: Vec<&[u32]> = bounded.iter().collect();
        assert_eq!(via_iter.len(), 3);
        assert_eq!(via_iter, expected);
    }

    /// The fixed variant exposes its truncated rows.
    #[test]
    fn test_result_ids_inner_fixed() {
        let inner = ResultIdsInner::Fixed(make_bounded(vec![vec![1, 2], vec![3, 4, 5]]));
        assert_eq!(
            collect_rows(inner.as_rows()),
            vec![vec![1, 2], vec![3, 4, 5]]
        );
    }

    /// The dynamic variant exposes its vectors unchanged.
    #[test]
    fn test_result_ids_inner_dynamic() {
        let inner = ResultIdsInner::Dynamic(vec![vec![1, 2, 3], vec![4], vec![5, 6]]);
        assert_eq!(
            collect_rows(inner.as_rows()),
            vec![vec![1, 2, 3], vec![4], vec![5, 6]]
        );
    }

    /// The public wrapper forwards to the inner storage.
    #[test]
    fn test_result_ids_wrapper() {
        let result = ResultIds::new(ResultIdsInner::Fixed(make_bounded(vec![
            vec![10],
            vec![20, 30],
        ])));
        assert_eq!(collect_rows(result.as_rows()), vec![vec![10], vec![20, 30]]);
    }

    /// Finishing without any push yields an empty table.
    #[test]
    fn test_aggregator_empty_finish() {
        let finished = IdAggregator::<u32>::new().finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 0);
        assert_eq!(rows.ncols(), None);
    }

    /// A first fixed batch puts the aggregator in the fixed state.
    #[test]
    fn test_aggregator_empty_to_fixed() {
        let mut aggregator = IdAggregator::new();
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![
            vec![1, 2],
            vec![3],
            vec![4, 5],
        ])));

        if let IdAggregator::Fixed { len, num_ids, .. } = &aggregator {
            assert_eq!(*len, 3);
            assert_eq!(*num_ids, 2);
        } else {
            panic!("Expected Fixed state");
        }

        let finished = aggregator.finish();
        assert_eq!(
            collect_rows(finished.as_rows()),
            vec![vec![1, 2], vec![3], vec![4, 5]]
        );
    }

    /// A first dynamic batch puts the aggregator in the dynamic state.
    #[test]
    fn test_aggregator_empty_to_dynamic() {
        let mut aggregator = IdAggregator::new();
        aggregator.push(ResultIdsInner::Dynamic(vec![vec![1, 2, 3], vec![4]]));

        if let IdAggregator::Dynamic(parts) = &aggregator {
            assert_eq!(parts.len(), 1);
        } else {
            panic!("Expected Dynamic state");
        }

        let finished = aggregator.finish();
        assert_eq!(
            collect_rows(finished.as_rows()),
            vec![vec![1, 2, 3], vec![4]]
        );
    }

    /// Fixed batches of equal width keep the aggregator in the fixed state.
    #[test]
    fn test_aggregator_fixed_stays_fixed_same_size() {
        let mut aggregator = IdAggregator::new();
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![
            vec![1, 2, 3],
            vec![4, 5],
        ])));
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![vec![6, 7, 8]])));

        if let IdAggregator::Fixed {
            len,
            num_ids,
            matrices,
        } = &aggregator
        {
            // Two rows from the first batch plus one from the second.
            assert_eq!(*len, 3);
            assert_eq!(*num_ids, 3);
            assert_eq!(matrices.len(), 2);
        } else {
            panic!("Expected Fixed state");
        }

        let finished = aggregator.finish();
        assert_eq!(
            collect_rows(finished.as_rows()),
            vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]]
        );
    }

    /// A fixed batch of a different width demotes the aggregator to dynamic.
    #[test]
    fn test_aggregator_fixed_to_dynamic_different_sizes() {
        let mut aggregator = IdAggregator::new();
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![
            vec![1, 2],
            vec![3, 4],
        ])));
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![vec![5, 6, 7]])));

        if let IdAggregator::Dynamic(parts) = &aggregator {
            assert_eq!(parts.len(), 2);
        } else {
            panic!("Expected Dynamic state after size mismatch");
        }

        let finished = aggregator.finish();
        assert_eq!(
            collect_rows(finished.as_rows()),
            vec![vec![1, 2], vec![3, 4], vec![5, 6, 7]]
        );
    }

    /// A dynamic batch arriving in the fixed state demotes the aggregator.
    #[test]
    fn test_aggregator_fixed_to_dynamic_incoming_dynamic() {
        let mut aggregator = IdAggregator::new();
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![
            vec![1, 2],
            vec![3, 4],
        ])));
        aggregator.push(ResultIdsInner::Dynamic(vec![vec![5, 6, 7]]));

        if let IdAggregator::Dynamic(parts) = &aggregator {
            assert_eq!(parts.len(), 2);
        } else {
            panic!("Expected Dynamic state");
        }

        let finished = aggregator.finish();
        assert_eq!(
            collect_rows(finished.as_rows()),
            vec![vec![1, 2], vec![3, 4], vec![5, 6, 7]]
        );
    }

    /// Once dynamic, the aggregator stays dynamic whatever is pushed.
    #[test]
    fn test_aggregator_dynamic_stays_dynamic() {
        let mut aggregator = IdAggregator::new();
        aggregator.push(ResultIdsInner::Dynamic(vec![vec![1, 2]]));
        aggregator.push(ResultIdsInner::Dynamic(vec![vec![3, 4, 5]]));
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![vec![6, 7]])));

        if let IdAggregator::Dynamic(parts) = &aggregator {
            assert_eq!(parts.len(), 3);
        } else {
            panic!("Expected Dynamic state");
        }

        let finished = aggregator.finish();
        assert_eq!(
            collect_rows(finished.as_rows()),
            vec![vec![1, 2], vec![3, 4, 5], vec![6, 7]]
        );
    }
}