use std::ops::{Deref, DerefMut, Index};
use diskann_vector::contains::ContainsSimd;
/// A list of neighbor ids (`edges`) belonging to a single node.
///
/// The mutating API is designed to keep entries free of duplicates: `push` and
/// `extend_from_slice` skip items that are already present, and the `*_trusted`
/// methods verify uniqueness in debug builds.
///
/// `#[repr(transparent)]` gives this type the same layout as `Vec<I>`.
#[repr(transparent)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AdjacencyList<I> {
    /// Neighbor ids, in insertion order unless explicitly sorted.
    edges: Vec<I>,
}
impl<I> AdjacencyList<I> {
    /// Creates an empty list without allocating.
    pub fn new() -> Self {
        Self { edges: Vec::new() }
    }

    /// Creates an empty list that can hold at least `capacity` neighbors before
    /// reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            edges: Vec::with_capacity(capacity),
        }
    }

    /// Returns the number of neighbors the list can hold without reallocating.
    pub fn capacity(&self) -> usize {
        self.edges.capacity()
    }

    /// Returns the trailing `count` neighbors.
    ///
    /// Returns `None` if the list holds fewer than `count` entries. `last(0)` is
    /// always `Some(&[])`.
    pub fn last(&self, count: usize) -> Option<&[I]> {
        // `checked_sub` guarantees `start <= len`, so this slice can never panic. A
        // plain (bounds-checked) slice is used instead of `get_unchecked`: the check
        // is cheap and this avoids an `unsafe` block that needs justification.
        let start = self.edges.len().checked_sub(count)?;
        Some(&self.edges[start..])
    }

    /// Keeps only the neighbors for which `f` returns `true`, preserving order.
    pub fn retain<F>(&mut self, f: F)
    where
        F: FnMut(&I) -> bool,
    {
        self.edges.retain(f)
    }

    /// Removes every neighbor while keeping the allocated capacity.
    pub fn clear(&mut self) {
        self.edges.clear();
    }

    /// Shortens the list to at most `len` neighbors; does nothing if it is already
    /// that short.
    pub fn truncate(&mut self, len: usize) {
        self.edges.truncate(len)
    }
}
impl<I> AdjacencyList<I>
where
    I: Copy + std::fmt::Debug,
{
    /// Appends `i` if it is not already present.
    ///
    /// Returns `true` if `i` was inserted. Returns `false` if it was already in the
    /// list, in which case the list is unchanged.
    pub fn push(&mut self, i: I) -> bool
    where
        I: ContainsSimd,
    {
        if self.contains(i) {
            return false;
        }
        self.edges.push(i);
        true
    }

    /// Appends every element of `is` that is not already present, preserving order.
    ///
    /// Each element is checked against the list as it stands after the preceding
    /// elements were processed, so duplicates within `is` itself are also skipped.
    /// Returns the number of elements actually inserted.
    pub fn extend_from_slice(&mut self, is: &[I]) -> usize
    where
        I: ContainsSimd,
    {
        // An explicit loop rather than a side-effecting `filter(..).count()`, so the
        // mutation performed by `push` is visible at a glance.
        let mut inserted = 0;
        for &i in is {
            if self.push(i) {
                inserted += 1;
            }
        }
        inserted
    }

    /// Returns `true` if `i` is one of the neighbors.
    pub fn contains(&self, i: I) -> bool
    where
        I: ContainsSimd,
    {
        I::contains_simd(self, i)
    }

    /// Sorts the neighbors in ascending order.
    ///
    /// An unstable sort is used. Entries are expected to be unique, so stability
    /// would make no observable difference.
    pub fn sort(&mut self)
    where
        I: Ord,
    {
        self.edges.sort_unstable()
    }

    /// Builds a list from an iterator whose items are known to be distinct.
    ///
    /// Uniqueness is promised by the [`UniqueIter`] marker and is NOT re-checked
    /// here.
    pub fn from_iter_unique<Itr>(itr: Itr) -> Self
    where
        Itr: UniqueIter<Item = I>,
    {
        Self {
            edges: itr.collect(),
        }
    }

    /// Resizes the list to exactly `len` entries and returns a guard for writing them.
    ///
    /// Slots beyond the current length are filled with `I::default()`. If `len` is
    /// smaller than the current length, the list is truncated. Call
    /// [`ResizeGuard::finish`] to keep the result; dropping the guard instead clears
    /// the list.
    #[must_use = "dropping the guard immediately clears the list; call `finish` to keep it"]
    pub fn resize(&mut self, len: usize) -> ResizeGuard<'_, I>
    where
        I: Default + ContainsSimd,
    {
        self.edges.resize(len, I::default());
        ResizeGuard(self)
    }

    /// Builds a list from arbitrary input, returning it sorted and deduplicated.
    pub fn from_iter_untrusted<Itr>(itr: Itr) -> Self
    where
        Itr: IntoIterator<Item = I>,
        I: Ord,
    {
        let mut edges: Vec<_> = itr.into_iter().collect();
        // `dedup` only removes *consecutive* equal items, hence the sort first.
        edges.sort_unstable();
        edges.dedup();
        Self { edges }
    }

    /// Replaces the contents with `is`, which the caller guarantees is duplicate-free.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `is` contains duplicates.
    pub fn overwrite_trusted(&mut self, is: &[I])
    where
        I: Clone + ContainsSimd,
    {
        self.clear();
        self.edges.extend_from_slice(is);
        self.debug_check_uniqueness();
    }

    /// Applies `f` to every neighbor in place.
    ///
    /// The caller guarantees that the mapping keeps all entries distinct.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the remapped list contains duplicates.
    pub fn remap_trusted<F>(&mut self, f: F)
    where
        F: FnMut(&mut I),
        I: ContainsSimd,
    {
        self.edges.iter_mut().for_each(f);
        self.debug_check_uniqueness();
    }

    /// Returns `true` if no neighbor appears more than once.
    pub fn all_unique(&self) -> bool
    where
        I: ContainsSimd,
    {
        // Every entry must be absent from the prefix that precedes it. This performs
        // the same membership checks as inserting into a fresh list, but without
        // allocating that list.
        self.edges
            .iter()
            .enumerate()
            .all(|(position, &item)| !I::contains_simd(&self.edges[..position], item))
    }

    /// Panics in debug builds if the list contains duplicates. The check is not
    /// executed in release builds.
    fn debug_check_uniqueness(&self)
    where
        I: ContainsSimd,
    {
        debug_assert!(self.all_unique(), "duplicate items detected: {:?}", self);
    }
}
/// Forwards indexing to the underlying edge storage.
///
/// An [`AdjacencyList`] accepts any index a slice accepts: a single position, or
/// ranges such as `a..b`, `..n` and `..`.
impl<I, Idx> Index<Idx> for AdjacencyList<I>
where
    Idx: std::slice::SliceIndex<[I]>,
{
    type Output = Idx::Output;

    fn index(&self, position: Idx) -> &Self::Output {
        let all: &[I] = &self.edges;
        Index::index(all, position)
    }
}
/// The default list is empty and has not allocated.
impl<I> Default for AdjacencyList<I> {
    fn default() -> Self {
        Self {
            edges: Vec::default(),
        }
    }
}
/// Exposes the neighbors as a read-only slice, giving access to `len`, `iter`,
/// `is_empty`, `as_ptr` and the rest of the slice API.
impl<I> Deref for AdjacencyList<I> {
    type Target = [I];

    fn deref(&self) -> &[I] {
        self.edges.as_slice()
    }
}
impl<I> From<AdjacencyList<I>> for Vec<I> {
fn from(list: AdjacencyList<I>) -> Self {
list.edges
}
}
/// Builds a one-element list. A single item is trivially unique.
impl<I> From<I> for AdjacencyList<I> {
    fn from(value: I) -> Self {
        let mut edges = Vec::with_capacity(1);
        edges.push(value);
        Self { edges }
    }
}
/// Guard returned by `AdjacencyList::resize` that grants mutable access to the
/// resized edge storage.
///
/// Call `finish` to keep the written contents. If the guard is dropped without
/// `finish` being called (e.g. during unwinding), its `Drop` impl clears the list.
#[derive(Debug)]
// NOTE: these `where` bounds must stay identical to those on the `Drop` impl for
// this type; Rust rejects `Drop` impls whose bounds differ from the type's own.
pub struct ResizeGuard<'a, I>(&'a mut AdjacencyList<I>)
where
    I: Copy + ContainsSimd;
impl<'a, I> ResizeGuard<'a, I>
where
    I: Copy + ContainsSimd + std::fmt::Debug,
{
    /// Commits the resize and keeps the list's contents.
    ///
    /// Truncates the list to at most `at_most` entries (no-op if it is already that
    /// short), then runs the list's debug-only uniqueness check.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if duplicate entries remain. The guard is still alive
    /// at that point, so its `Drop` impl clears the list during unwinding.
    pub fn finish(self, at_most: usize) {
        self.0.truncate(at_most);
        self.0.debug_check_uniqueness();
        // Defuse the guard only after the checks pass: skipping `Drop` here is what
        // keeps the contents instead of clearing them.
        std::mem::forget(self);
    }
}
/// Read-only view of the list being filled in through the guard.
impl<I> Deref for ResizeGuard<'_, I>
where
    I: Copy + ContainsSimd,
{
    type Target = [I];

    fn deref(&self) -> &[I] {
        self.0.edges.as_slice()
    }
}
/// Mutable view of the list, used to write the resized entries in place.
impl<I> DerefMut for ResizeGuard<'_, I>
where
    I: Copy + ContainsSimd,
{
    fn deref_mut(&mut self) -> &mut [I] {
        self.0.edges.as_mut_slice()
    }
}
/// Rolls back an unfinished resize.
///
/// `finish` forgets the guard, so reaching this `drop` means the resize was never
/// committed. The partially written contents are discarded; capacity is kept.
impl<'a, I> Drop for ResizeGuard<'a, I>
where
    I: Copy + ContainsSimd,
{
    fn drop(&mut self) {
        let list = &mut *self.0;
        list.edges.clear();
    }
}
/// Marker for iterators that never yield the same item twice.
///
/// Implementing this trait is a promise that nothing checks:
/// [`AdjacencyList::from_iter_unique`] stores the yielded items verbatim. An
/// iterator that can repeat items would therefore let duplicates into an
/// adjacency list.
pub trait UniqueIter: Iterator {}

// Set iterators: a set holds each key at most once.
impl<K> UniqueIter for std::collections::hash_set::Iter<'_, K> {}
impl<K> UniqueIter for std::collections::hash_set::IntoIter<K> {}
impl<K> UniqueIter for hashbrown::hash_set::IntoIter<K> {}

// A single-element iterator trivially has no repeats.
impl<T> UniqueIter for std::iter::Once<T> {}

// Half-open integer ranges are strictly increasing.
impl UniqueIter for std::ops::Range<u32> {}
impl UniqueIter for std::ops::Range<u64> {}

// Copying the items of a unique iterator keeps them distinct (assuming `==` on a
// copy agrees with `==` on the original, as it does for plain id types).
impl<I> UniqueIter for std::iter::Copied<I>
where
    I: UniqueIter,
    std::iter::Copied<I>: Iterator,
{
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;

    #[test]
    fn test_new() {
        let x = AdjacencyList::<u32>::new();
        assert_eq!(x.len(), 0);
        assert!(x.is_empty());
        assert_eq!(x.capacity(), 0);
    }

    #[test]
    fn test_with_capacity() {
        for cap in [0, 1, 2, 5, 10, 100] {
            let mut x = AdjacencyList::<u32>::with_capacity(cap);
            assert_eq!(x.len(), 0);
            assert!(x.is_empty());
            assert!(
                x.capacity() >= cap,
                "got {}, expected at least {}",
                x.capacity(),
                cap
            );

            // Filling up to the requested capacity must not reallocate.
            let ptr = x.as_ptr();
            for i in 0..cap {
                assert!(x.push(i.try_into().unwrap()));
            }
            assert_eq!(x.len(), cap);
            assert_eq!(ptr, x.as_ptr());
        }
    }

    #[test]
    fn test_last() {
        let x = AdjacencyList::<u32>::from_iter_unique(0..10);
        for i in 0..=10 {
            let last = x.last(i).unwrap();
            let expected: Vec<_> = ((10 - i) as u32..10).collect();
            assert_eq!(last, &*expected);
        }
        for i in 11..15 {
            assert!(x.last(i).is_none(), "failed for length {}", i);
        }
    }

    #[test]
    fn test_retain() {
        let mut x = AdjacencyList::<u32>::new();
        x.retain(|_| false);
        assert!(x.is_empty());

        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        x.retain(|_| false);
        assert!(x.is_empty());

        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        x.retain(|_| true);
        assert_eq!(x.len(), 10);
        assert_eq!(&*x, &*((0..10).collect::<Vec<u32>>()));

        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        x.retain(|i| i % 2 == 0);
        assert_eq!(x.len(), 5);
        assert_eq!(&*x, &[0, 2, 4, 6, 8]);
    }

    #[test]
    fn test_clear() {
        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        assert!(x.all_unique());
        let cap = x.capacity();
        assert_eq!(x.len(), 10);
        x.clear();
        assert!(x.is_empty());
        assert_eq!(x.capacity(), cap, "capacity should remain unchanged");
    }

    #[test]
    fn test_truncate() {
        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        assert!(x.all_unique());
        let ptr = x.as_ptr();
        for i in 0..10 {
            let len = 10 - i;
            x.truncate(len);
            assert_eq!(x.len(), len);
            assert_eq!(ptr, x.as_ptr(), "truncating should not reallocate");
            assert_eq!(&*x, &*((0..len as u32).collect::<Vec<_>>()));
        }
    }

    #[test]
    fn test_to_vec() {
        let x = AdjacencyList::<u32>::from_iter_unique(0..10);
        let ptr = x.as_ptr();
        let y: Vec<u32> = x.into();
        assert_eq!(&*y, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        // The conversion hands over the buffer instead of copying it.
        assert_eq!(y.as_ptr(), ptr);
    }

    #[test]
    fn test_push_directed() {
        let mut x = AdjacencyList::<u32>::new();
        assert!(x.push(10));
        assert_eq!(&*x, &[10]);
        assert!(!x.push(10));
        assert_eq!(&*x, &[10]);
        assert!(x.push(0));
        assert_eq!(&*x, &[10, 0]);
        assert!(x.push(12));
        assert_eq!(&*x, &[10, 0, 12]);
        assert!(!x.push(0));
        assert_eq!(&*x, &[10, 0, 12]);
        assert!(!x.push(12));
        assert_eq!(&*x, &[10, 0, 12]);
        x.sort();
        assert_eq!(&*x, &[0, 10, 12]);
    }

    /// Cross-checks `push` against a `HashSet` reference model on random input.
    fn test_push_fuzz_impl(domain: Uniform<u32>, ntrials: usize, rng: &mut StdRng) {
        let mut set = HashSet::new();
        let mut list = AdjacencyList::new();
        for _ in 0..ntrials {
            let v = domain.sample(rng);
            let should_insert = set.insert(v);
            let inserted = list.push(v);
            assert_eq!(should_insert, inserted);
            assert_eq!(set.len(), list.len());
            if inserted {
                assert_eq!(list[list.len() - 1], v);
            }
        }
    }

    #[test]
    fn test_push_fuzz() {
        let mut rng = StdRng::seed_from_u64(0x50e02da44abc56c3);
        let domain = Uniform::new(0, 100).unwrap();
        for _ in 0..10 {
            test_push_fuzz_impl(domain, 200, &mut rng);
        }
    }

    #[test]
    fn test_extend_from_slice() {
        let mut x = AdjacencyList::from_iter_untrusted([1, 2, 3, 4]);
        assert!(x.contains(1));
        assert!(!x.contains(5));
        assert!(!x.contains(9));
        assert_eq!(x.extend_from_slice(&[1, 5, 9]), 2);
        assert_eq!(&*x, &[1, 2, 3, 4, 5, 9]);

        fn some(y: &[u32]) -> Option<&[u32]> {
            Some(y)
        }

        assert_eq!(x.last(0), some(&[]));
        assert_eq!(x.last(1), some(&[9]));
        assert_eq!(x.last(2), some(&[5, 9]));

        // Duplicates inside the input slice itself are also skipped.
        assert_eq!(x.extend_from_slice(&[1, 10, 9, 10, 8]), 2);
        assert_eq!(&*x, &[1, 2, 3, 4, 5, 9, 10, 8]);
        assert_eq!(x.last(2), some(&[10, 8]));

        assert_eq!(x.extend_from_slice(&[]), 0);
        assert_eq!(&*x, &[1, 2, 3, 4, 5, 9, 10, 8]);
    }

    /// Cross-checks `extend_from_slice` against a `HashSet` reference model on
    /// random batches.
    fn test_extend_from_slice_fuzz_impl(
        domain: Uniform<u32>,
        length_distribution: Uniform<usize>,
        ntrials: usize,
        rng: &mut StdRng,
    ) {
        let mut set = HashSet::new();
        let mut list = AdjacencyList::new();
        for _ in 0..ntrials {
            let len = length_distribution.sample(rng);
            let to_insert: Vec<_> = (0..len).map(|_| domain.sample(rng)).collect();
            let should_be_inserted: Vec<u32> = to_insert
                .iter()
                .copied()
                .filter(|i| set.insert(*i))
                .collect();
            let num_inserted = list.extend_from_slice(&to_insert);
            assert_eq!(num_inserted, should_be_inserted.len());
            assert_eq!(list.last(num_inserted).unwrap(), &*should_be_inserted);
            assert_eq!(set.len(), list.len());
        }
    }

    #[test]
    fn test_extend_from_slice_fuzz() {
        let mut rng = StdRng::seed_from_u64(0x50e02da44abc56c3);
        let domain = Uniform::new(0, 100).unwrap();
        let length_distribution = Uniform::new(0, 10).unwrap();
        for _ in 0..10 {
            test_extend_from_slice_fuzz_impl(domain, length_distribution, 50, &mut rng);
        }
    }

    #[test]
    fn test_from_iter_untrusted() {
        let x = AdjacencyList::<u32>::from_iter_untrusted([]);
        assert!(x.is_empty());
        let x = AdjacencyList::<u32>::from_iter_untrusted([1]);
        assert_eq!(&*x, &[1]);
        let x = AdjacencyList::<u32>::from_iter_untrusted([2, 1]);
        assert_eq!(&*x, &[1, 2]);
        let x = AdjacencyList::<u32>::from_iter_untrusted([1, 2, 1]);
        assert_eq!(&*x, &[1, 2]);
    }

    #[test]
    fn test_overwrite() {
        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        x.overwrite_trusted(&[]);
        assert!(x.is_empty());

        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        x.overwrite_trusted(&[10, 2, 3, 4]);
        assert_eq!(&*x, &[10, 2, 3, 4]);

        let mut x = AdjacencyList::<u32>::new();
        x.overwrite_trusted(&[4, 3, 10, 9]);
        assert_eq!(&*x, &[4, 3, 10, 9]);
    }

    #[test]
    fn test_remap() {
        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        x.remap_trusted(|i| *i += 1);
        assert_eq!(&*x, &*((1..11).collect::<Vec<u32>>()));
    }

    #[test]
    fn test_resize() {
        let mut x = AdjacencyList::<u32>::new();
        {
            let mut guard = x.resize(4);
            assert_eq!(guard.len(), 4);
            guard[0] = 1;
            guard[1] = 2;
            guard[2] = 3;
            guard.finish(3);
        }
        assert_eq!(&*x, &[1, 2, 3]);

        // Dropping the guard without `finish` clears the list.
        {
            let _guard = x.resize(10);
        }
        assert!(x.is_empty());

        {
            let mut guard = x.resize(3);
            guard.copy_from_slice(&[3, 2, 1]);
            guard.finish(10);
        }
        assert_eq!(&*x, &[3, 2, 1]);

        {
            let guard = x.resize(10);
            guard.finish(0);
        }
        assert!(x.is_empty());
    }

    #[test]
    fn test_resize_shrink() {
        // Resizing below the current length truncates, keeping the prefix.
        let mut x = AdjacencyList::<u32>::from_iter_unique(0..10);
        {
            let guard = x.resize(3);
            assert_eq!(&*guard, &[0, 1, 2]);
            guard.finish(3);
        }
        assert_eq!(&*x, &[0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let x: AdjacencyList<u32> = Default::default();
        assert!(x.is_empty());
        assert_eq!(x.capacity(), 0);
    }

    #[test]
    fn test_from_single() {
        let x = AdjacencyList::<u32>::from(7u32);
        assert_eq!(&*x, &[7]);
        assert!(x.all_unique());
    }

    #[test]
    fn test_index() {
        let x = AdjacencyList::<u32>::from_iter_unique(0..10);
        assert_eq!(x[3], 3);
        assert_eq!(&x[2..5], &[2, 3, 4]);
        assert_eq!(&x[..2], &[0, 1]);
        assert_eq!(&x[8..], &[8, 9]);
    }

    #[test]
    fn test_last_empty() {
        let x = AdjacencyList::<u32>::new();
        assert_eq!(x.last(0), Some(&[][..]));
        assert!(x.last(1).is_none());
    }

    #[test]
    fn test_all_unique() {
        assert!(AdjacencyList::<u32>::new().all_unique());
        assert!(AdjacencyList::<u32>::from(4u32).all_unique());
        assert!(AdjacencyList::<u32>::from_iter_untrusted([3, 1, 2, 3]).all_unique());
    }

    #[test]
    fn test_from_iter_unique_sources() {
        let set: HashSet<u32> = HashSet::from([5, 1, 3]);

        let mut x = AdjacencyList::from_iter_unique(set.iter().copied());
        x.sort();
        assert_eq!(&*x, &[1, 3, 5]);

        let mut y = AdjacencyList::from_iter_unique(set.into_iter());
        y.sort();
        assert_eq!(&*y, &[1, 3, 5]);

        let z = AdjacencyList::from_iter_unique(std::iter::once(7u32));
        assert_eq!(&*z, &[7]);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "duplicate items detected")]
    fn test_remap_duplicates_panics_in_debug() {
        let mut x = AdjacencyList::<u32>::from_iter_unique(0..4);
        // Maps [0, 1, 2, 3] to [0, 1, 0, 1], breaking the uniqueness contract.
        x.remap_trusted(|i| *i %= 2);
    }
}