use diskann::graph::{SearchOutputBuffer, search_output_buffer::BufferState};
#[derive(Debug)]
pub(crate) struct Buffer<'a, I>(Inner<'a, I>);
#[derive(Debug)]
enum Inner<'a, I> {
Slice { slice: &'a mut [I], written: usize },
Vec(&'a mut Vec<I>),
}
impl<'a, I> Buffer<'a, I> {
pub(crate) fn slice(slice: &'a mut [I]) -> Self {
Self(Inner::Slice { slice, written: 0 })
}
pub(crate) fn vector(v: &'a mut Vec<I>) -> Self {
v.clear();
Self(Inner::Vec(v))
}
pub(crate) fn current_len(&self) -> usize {
match &self.0 {
Inner::Slice { written, .. } => *written,
Inner::Vec(vec) => vec.len(),
}
}
}
impl<I, D> SearchOutputBuffer<I, D> for Buffer<'_, I> {
fn size_hint(&self) -> Option<usize> {
match &self.0 {
Inner::Slice { slice, written } => Some(slice.len() - written),
Inner::Vec(_) => None,
}
}
fn current_len(&self) -> usize {
<Buffer<I>>::current_len(self)
}
fn push(&mut self, id: I, _distance: D) -> BufferState {
match &mut self.0 {
Inner::Slice { slice, written } => match slice.get_mut(*written) {
Some(slot) => {
*slot = id;
*written += 1;
if *written == slice.len() {
BufferState::Full
} else {
BufferState::Available
}
}
None => BufferState::Full,
},
Inner::Vec(vec) => {
vec.push(id);
BufferState::Available
}
}
}
fn extend<Itr>(&mut self, itr: Itr) -> usize
where
Itr: IntoIterator<Item = (I, D)>,
{
match &mut self.0 {
Inner::Slice { slice, written } => match slice.get_mut(*written..) {
Some(left) => {
let count = std::iter::zip(left.iter_mut(), itr)
.map(|(dst, src)| {
*dst = src.0;
})
.count();
*written += count;
count
}
None => 0,
},
Inner::Vec(vec) => {
let before = vec.len();
vec.extend(itr.into_iter().map(|i| i.0));
vec.len() - before
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn size_hint<T>(buffer: &T) -> Option<usize>
where
T: SearchOutputBuffer<u32, f32>,
{
buffer.size_hint()
}
#[test]
fn test_slice_buffer_creation() {
let mut data = [0u32; 5];
let buffer = Buffer::slice(&mut data);
assert_eq!(buffer.current_len(), 0);
assert_eq!(size_hint(&buffer), Some(5));
}
#[test]
fn test_vec_buffer_creation() {
let mut vec = vec![1u32, 2, 3, 4];
let buffer = Buffer::vector(&mut vec);
assert_eq!(buffer.current_len(), 0);
assert_eq!(size_hint(&buffer), None);
assert!(vec.is_empty(), "vector should be cleared on construction");
}
#[test]
fn test_slice_buffer_set_single_element() {
let mut data = [0u32; 5];
let mut buffer = Buffer::slice(&mut data);
assert_eq!(buffer.push(42, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 1);
assert_eq!(data, [42, 0, 0, 0, 0]);
}
#[test]
fn test_slice_buffer_set_updates_written() {
let mut data = [0u32; 5];
let mut buffer = Buffer::slice(&mut data);
assert_eq!(buffer.current_len(), 0);
assert_eq!(size_hint(&buffer), Some(5));
assert_eq!(buffer.push(100, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 1);
assert_eq!(size_hint(&buffer), Some(4));
assert_eq!(buffer.push(200, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 2);
assert_eq!(size_hint(&buffer), Some(3));
assert_eq!(buffer.push(300, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 3);
assert_eq!(size_hint(&buffer), Some(2));
assert_eq!(buffer.push(400, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 4);
assert_eq!(size_hint(&buffer), Some(1));
assert_eq!(buffer.push(500, 0.0), BufferState::Full);
assert_eq!(buffer.current_len(), 5);
assert_eq!(size_hint(&buffer), Some(0));
assert_eq!(buffer.push(600, 0.0), BufferState::Full);
assert_eq!(buffer.current_len(), 5);
assert_eq!(size_hint(&buffer), Some(0));
assert_eq!(data, [100, 200, 300, 400, 500]);
}
#[test]
fn test_slice_buffer_empty() {
let mut data = [0u32; 0];
let mut buffer = Buffer::slice(&mut data);
assert_eq!(buffer.push(42, 0.0), BufferState::Full);
}
#[test]
fn test_vec_buffer_push() {
let mut vec = vec![0u32, 0, 0, 0, 0];
let mut buffer = Buffer::vector(&mut vec);
assert_eq!(
size_hint(&buffer),
None,
"vector-type buffers have no upper bound"
);
assert_eq!(buffer.push(42, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 1);
assert_eq!(buffer.push(50, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 2);
assert_eq!(buffer.push(3, 0.0), BufferState::Available);
assert_eq!(buffer.current_len(), 3);
assert_eq!(&vec, &[42, 50, 3]);
}
#[test]
fn test_slice_buffer_set_from() {
let mut data = [0u32; 5];
let mut buffer = Buffer::slice(&mut data);
let items = vec![(10, 1.0), (20, 2.0), (30, 3.0)];
let count = buffer.extend(items);
assert_eq!(count, 3);
assert_eq!(buffer.current_len(), 3);
assert_eq!(data, [10, 20, 30, 0, 0]);
}
#[test]
fn test_slice_buffer_set_from_more_than_capacity() {
let mut data = [0u32; 3];
let mut buffer = Buffer::slice(&mut data);
let items = vec![(10, 1.0), (20, 2.0), (30, 3.0), (40, 4.0), (50, 5.0)];
let count = buffer.extend(items);
assert_eq!(count, 3);
assert_eq!(buffer.current_len(), 3);
assert_eq!(data, [10, 20, 30]);
}
#[test]
fn test_vec_buffer_extend() {
let mut vec = Vec::<u32>::new();
let mut buffer = Buffer::vector(&mut vec);
let items = vec![(100, 1.0), (200, 2.0)];
let count = buffer.extend(items);
assert_eq!(count, 2);
assert_eq!(buffer.current_len(), 2);
assert_eq!(&vec, &[100, 200]);
}
#[test]
fn test_vec_buffer_extend_cascades() {
let mut vec = vec![1u32, 2, 3, 4, 5];
let mut buffer = Buffer::vector(&mut vec);
let items = vec![(10, 1.0), (20, 2.0)];
assert_eq!(buffer.extend(items), 2);
let items = vec![(21, 1.0), (22, 2.0)];
assert_eq!(buffer.extend(items), 2);
assert_eq!(
buffer.extend::<[(u32, f32); 0]>([]),
0,
"empty iterator should add nothing"
);
assert_eq!(&vec, &[10, 20, 21, 22]);
}
#[test]
fn test_slice_push_and_extend_combinations() {
let mut data = [0u32; 5];
let mut buffer = Buffer::slice(&mut data);
assert_eq!(buffer.push(1, 0.0), BufferState::Available);
assert_eq!(buffer.push(2, 0.0), BufferState::Available);
assert_eq!(buffer.extend(vec![(3, 0.0), (4, 0.0)]), 2);
assert_eq!(data, [1, 2, 3, 4, 0]);
let mut data = [0u32; 5];
let mut buffer = Buffer::slice(&mut data);
buffer.extend(vec![(10, 0.0), (20, 0.0)]);
assert_eq!(buffer.push(30, 0.0), BufferState::Available);
assert_eq!(buffer.push(40, 0.0), BufferState::Available);
assert_eq!(buffer.push(50, 0.0), BufferState::Full);
assert_eq!(data, [10, 20, 30, 40, 50]);
let mut data = [0u32; 6];
let mut buffer = Buffer::slice(&mut data);
assert_eq!(buffer.push(1, 0.0), BufferState::Available);
assert_eq!(buffer.extend(vec![(2, 0.0), (3, 0.0)]), 2);
assert_eq!(buffer.push(4, 0.0), BufferState::Available);
assert_eq!(buffer.extend(vec![(5, 0.0), (6, 0.0), (7, 0.0)]), 2);
assert_eq!(data, [1, 2, 3, 4, 5, 6]);
}
#[test]
fn test_slice_push_and_extend_at_capacity() {
let mut data = [0u32; 2];
let mut buffer = Buffer::slice(&mut data);
buffer.extend(vec![(1, 0.0), (2, 0.0)]);
assert_eq!(buffer.push(99, 0.0), BufferState::Full);
assert_eq!(data, [1, 2]);
let mut data = [0u32; 2];
let mut buffer = Buffer::slice(&mut data);
assert_eq!(buffer.push(1, 0.0), BufferState::Available);
assert_eq!(buffer.push(2, 0.0), BufferState::Full);
assert_eq!(buffer.extend(vec![(99, 0.0)]), 0);
assert_eq!(data, [1, 2]);
let mut data = [0u32; 4];
let mut buffer = Buffer::slice(&mut data);
assert_eq!(buffer.push(1, 0.0), BufferState::Available);
assert_eq!(buffer.push(2, 0.0), BufferState::Available);
assert_eq!(
buffer.extend(vec![(3, 0.0), (4, 0.0), (5, 0.0), (6, 0.0)]),
2
);
assert_eq!(data, [1, 2, 3, 4]);
}
#[test]
fn test_vec_push_and_extend_combinations() {
let mut vec = Vec::<u32>::new();
let mut buffer = Buffer::vector(&mut vec);
assert_eq!(buffer.push(1, 0.0), BufferState::Available);
assert_eq!(buffer.extend(vec![(2, 0.0), (3, 0.0)]), 2);
assert_eq!(buffer.push(4, 0.0), BufferState::Available);
assert_eq!(buffer.extend::<[(u32, f32); 0]>([]), 0); assert_eq!(buffer.extend(vec![(5, 0.0)]), 1);
assert_eq!(&vec, &[1, 2, 3, 4, 5]);
}
}