#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::doc_markdown)]
use super::graph::GraphError;
use serde::{Deserialize, Serialize};
use std::vec::Vec;
/// Byte arena holding encoded neighbor lists, plus per-size-class free lists of
/// recycled slots.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NeighborPool {
    /// Backing storage; every slot handed out by `alloc` is a range of this buffer.
    #[serde(with = "serde_bytes")]
    pub(crate) buffer: Vec<u8>,
    /// Free lists indexed by size class: entry `i` holds start offsets of freed slots
    /// whose capacity is `(i + 1) * 16` bytes.
    buckets: Vec<Vec<u32>>,
}
/// Lazily decodes one delta-encoded neighbor list from a byte slice.
///
/// Created by `NeighborPool::iter_layer`. Each item is the running sum of the stored
/// deltas, i.e. the absolute neighbor id (ascending when the bytes came from
/// `NeighborPool::encode_neighbors`).
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct NeighborIter<'a> {
    /// Encoded input; may continue past the end of this list (e.g. later layers).
    data: &'a [u8],
    /// Byte index of the next delta to decode.
    cursor: usize,
    /// Number of entries declared by the list header.
    count: u32,
    /// Last decoded id; the next delta is added to it.
    prev: u32,
    /// Number of entries yielded so far.
    current_idx: u32,
}
impl NeighborIter<'_> {
fn empty() -> Self {
Self {
data: &[],
cursor: 0,
count: 0,
prev: 0,
current_idx: 0,
}
}
}
impl Iterator for NeighborIter<'_> {
    type Item = u32;

    /// Decodes the next delta and returns the accumulated (absolute) neighbor id.
    ///
    /// Stops after the header's `count` entries, or earlier if the input runs out.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current_idx >= self.count {
            return None;
        }
        if self.cursor >= self.data.len() {
            // Input ended before `count` entries: mark the iterator exhausted so that
            // `size_hint`/`len` agree with the `None` returned here.
            self.current_idx = self.count;
            return None;
        }
        let (delta, bytes) = NeighborPool::vbyte_decode(self.data, self.cursor);
        self.cursor += bytes;
        self.prev = self.prev.wrapping_add(delta);
        self.current_idx += 1;
        Some(self.prev)
    }

    /// Remaining entries, bounded by both the header count and the unread bytes.
    ///
    /// Every entry occupies at least one byte, so for well-formed input the header
    /// count is the binding term and the hint is exact. The byte bound stops a corrupt
    /// header (e.g. `u32::MAX`) from making `collect` reserve gigabytes up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let by_header = (self.count - self.current_idx) as usize;
        let by_bytes = self.data.len().saturating_sub(self.cursor);
        let remaining = by_header.min(by_bytes);
        (remaining, Some(remaining))
    }
}
/// `size_hint` always reports equal lower and upper bounds, so the length is simply
/// its lower bound.
impl ExactSizeIterator for NeighborIter<'_> {
    fn len(&self) -> usize {
        self.size_hint().0
    }
}
impl NeighborPool {
    /// Slot capacities are always a multiple of this many bytes.
    const GRANULARITY: usize = 16;
    /// Index of the largest size class that has a free list. Bucket `i` recycles slots
    /// of capacity `(i + 1) * GRANULARITY` bytes, so recycled classes span 16..=512.
    const MAX_BUCKET_IDX: usize = 31;

    /// Creates an empty pool: no buffer bytes and one empty free list per size class.
    #[must_use]
    pub fn new() -> Self {
        Self {
            buffer: Vec::new(),
            buckets: vec![Vec::new(); Self::MAX_BUCKET_IDX + 1],
        }
    }

    /// Reserves a slot of at least `size` bytes inside `buffer`.
    ///
    /// The capacity is `size` rounded up to a multiple of 16 bytes (minimum 16). If a
    /// previously freed slot of the same capacity exists, it is reused and its old
    /// contents are left in place. Otherwise zeroed bytes are appended to `buffer`.
    ///
    /// Returns `(offset, capacity)`, where `offset` is the slot's start index in `buffer`.
    ///
    /// # Errors
    ///
    /// * [`GraphError::NeighborError`] if `size` exceeds `u16::MAX`, or its rounded-up
    ///   capacity does not fit in a `u16`.
    /// * [`GraphError::CapacityExceeded`] if appending the slot would push the end of
    ///   `buffer` past `u32::MAX`.
    pub fn alloc(&mut self, size: usize) -> Result<(u32, u16), GraphError> {
        // Checked first so the rounding arithmetic below cannot overflow.
        if size > usize::from(u16::MAX) {
            return Err(GraphError::NeighborError);
        }
        let remainder = size % Self::GRANULARITY;
        let padded = if remainder == 0 {
            size
        } else {
            size + (Self::GRANULARITY - remainder)
        };
        // Zero-byte requests still get one granule (`free` ignores zero capacities).
        let target_cap = padded.max(Self::GRANULARITY);
        // Sizes 65_521..=65_535 round up to 65_536, which no longer fits in a u16.
        let capacity = u16::try_from(target_cap).map_err(|_| GraphError::NeighborError)?;

        let bucket_idx = target_cap / Self::GRANULARITY - 1;
        if bucket_idx <= Self::MAX_BUCKET_IDX {
            // `get_mut` instead of indexing: `buckets` is deserialized as-is, so a loaded
            // pool may hold fewer buckets than `new` creates, and that must not panic.
            if let Some(offset) = self.buckets.get_mut(bucket_idx).and_then(Vec::pop) {
                return Ok((offset, capacity));
            }
        }

        let current_len = self.buffer.len();
        // Offsets are handed out as u32, so the whole slot must end at or below u32::MAX.
        let new_len = current_len
            .checked_add(target_cap)
            .filter(|&end| end <= u32::MAX as usize)
            .ok_or(GraphError::CapacityExceeded)?;
        let offset = u32::try_from(current_len).map_err(|_| GraphError::CapacityExceeded)?;
        self.buffer.resize(new_len, 0);
        Ok((offset, capacity))
    }

    /// Puts the slot at `offset` with the given `capacity` on its size-class free list,
    /// so a later `alloc` of the same capacity can reuse it.
    ///
    /// Capacities that are zero, not a multiple of 16, or larger than the biggest size
    /// class are ignored. Their bytes stay in `buffer` but are never handed out again.
    /// No check is made that the slot was allocated, or that it is not already free.
    pub fn free(&mut self, offset: u32, capacity: u16) {
        let cap = usize::from(capacity);
        if cap == 0 || cap % Self::GRANULARITY != 0 {
            return;
        }
        let bucket_idx = cap / Self::GRANULARITY - 1;
        if bucket_idx > Self::MAX_BUCKET_IDX {
            return;
        }
        // A deserialized pool may carry fewer buckets than `new` creates; grow the list
        // instead of panicking on the index below.
        if bucket_idx >= self.buckets.len() {
            self.buckets.resize_with(Self::MAX_BUCKET_IDX + 1, Vec::new);
        }
        self.buckets[bucket_idx].push(offset);
    }

    /// Encodes `neighbors` as one list in a fresh buffer.
    ///
    /// See `encode_neighbors_to_buf` for the byte format.
    pub fn encode_neighbors(neighbors: &[u32]) -> Vec<u8> {
        let mut buf = Vec::new();
        Self::encode_neighbors_to_buf(neighbors, &mut buf);
        buf
    }

    /// Appends `neighbors` to `buf` as one encoded list.
    ///
    /// Format: a varint entry count, then one varint per entry holding the difference
    /// from the previous value in ascending order (the first difference is from 0).
    /// The input is sorted first, so its original order is not preserved; duplicates
    /// are kept (encoded as a zero delta).
    pub fn encode_neighbors_to_buf(neighbors: &[u32], buf: &mut Vec<u8>) {
        let mut sorted = neighbors.to_vec();
        sorted.sort_unstable();
        // The header cannot describe more than u32::MAX entries. Emit exactly as many
        // deltas as the header claims so the stream stays parseable.
        let count = u32::try_from(sorted.len()).unwrap_or(u32::MAX);
        Self::vbyte_encode(count, buf);
        let mut prev = 0u32;
        for &curr in sorted.iter().take(count as usize) {
            Self::vbyte_encode(curr.wrapping_sub(prev), buf);
            prev = curr;
        }
    }

    /// Decodes the first list in `data` into a new vector (empty for empty input).
    pub fn decode_neighbors(data: &[u8]) -> Vec<u32> {
        let mut buf = Vec::new();
        Self::decode_neighbors_to_buf(data, &mut buf);
        buf
    }

    /// Appends the entries of the first list in `data` to `buf`.
    ///
    /// Does nothing for empty input. Any bytes after the first list are ignored.
    pub fn decode_neighbors_to_buf(data: &[u8], buf: &mut Vec<u32>) {
        if data.is_empty() {
            return;
        }
        // Only the entries are wanted here, not the number of bytes consumed.
        let _ = Self::decode_one_list_to_buf(data, 0, buf);
    }

    /// Decodes the list at index `target_level` of a blob of back-to-back lists
    /// (level 0 first).
    ///
    /// Returns an empty vector if the blob has fewer levels.
    pub fn decode_layer(data: &[u8], target_level: u8) -> Vec<u32> {
        Self::iter_layer(data, target_level).collect()
    }

    /// Like `decode_layer`, but appends the entries to `buf` instead of allocating.
    pub fn decode_layer_to_buf(data: &[u8], target_level: u8, buf: &mut Vec<u32>) {
        buf.extend(Self::iter_layer(data, target_level));
    }

    /// Returns a lazy iterator over the list at index `target_level` of a blob of
    /// back-to-back lists.
    ///
    /// Earlier lists are skipped by walking their varints without materializing them.
    /// Yields nothing if the blob has fewer levels.
    pub fn iter_layer(data: &[u8], target_level: u8) -> NeighborIter<'_> {
        let mut cursor = 0;
        // Never exceeds `target_level`: we return as soon as the two are equal.
        let mut level = 0u8;
        while cursor < data.len() {
            let (count, header_len) = Self::vbyte_decode(data, cursor);
            cursor += header_len;
            if level == target_level {
                return NeighborIter {
                    data,
                    cursor,
                    count,
                    prev: 0,
                    current_idx: 0,
                };
            }
            for _ in 0..count {
                if cursor >= data.len() {
                    break;
                }
                cursor += Self::vbyte_decode(data, cursor).1;
            }
            level += 1;
        }
        NeighborIter::empty()
    }

    /// Appends the entries of the list whose header starts at `cursor` to `buf`.
    ///
    /// Returns the number of bytes consumed, or 0 if `cursor` is already at or past the
    /// end of `data`. Stops early if the input ends before the declared count is reached.
    fn decode_one_list_to_buf(data: &[u8], mut cursor: usize, buf: &mut Vec<u32>) -> usize {
        let start_cursor = cursor;
        if cursor >= data.len() {
            return 0;
        }
        let (count, header_len) = Self::vbyte_decode(data, cursor);
        cursor += header_len;
        if count == 0 {
            return cursor - start_cursor;
        }
        // Each entry takes at least one byte, so the unread input bounds how many entries
        // can actually follow. Trusting a corrupt header here could reserve ~16 GiB.
        let remaining = data.len().saturating_sub(cursor);
        buf.reserve((count as usize).min(remaining));
        let mut prev = 0u32;
        for _ in 0..count {
            if cursor >= data.len() {
                break;
            }
            let (delta, bytes_read) = Self::vbyte_decode(data, cursor);
            cursor += bytes_read;
            prev = prev.wrapping_add(delta);
            buf.push(prev);
        }
        cursor - start_cursor
    }

    /// Appends `val` to `buf` as a varint of 1-5 bytes.
    ///
    /// The value is split into 7-bit groups, least-significant first. Every byte except
    /// the last has its high bit set.
    fn vbyte_encode(mut val: u32, buf: &mut Vec<u8>) {
        loop {
            if val < 128 {
                buf.push(val as u8);
                break;
            }
            buf.push((val as u8 & 0x7F) | 0x80);
            val >>= 7;
        }
    }

    /// Approximate number of bytes owned by the pool.
    ///
    /// Counts the struct itself plus the allocated capacity of the buffer, of every free
    /// list, and of the bucket vector.
    pub fn memory_usage(&self) -> usize {
        let buffer_size = self.buffer.capacity();
        let buckets_size = self
            .buckets
            .iter()
            .map(|b| b.capacity() * std::mem::size_of::<u32>())
            .sum::<usize>();
        let buckets_overhead = self.buckets.capacity() * std::mem::size_of::<Vec<u32>>();
        std::mem::size_of::<Self>() + buffer_size + buckets_size + buckets_overhead
    }

    /// Decodes one varint starting at `start` and returns `(value, bytes_consumed)`.
    ///
    /// At most five bytes are read. Bits beyond the 32nd are discarded. If the input ends
    /// mid-varint, the partial value is returned along with the bytes that were read.
    /// If `start` is at or past the end of `data`, the result is `(0, 0)`.
    fn vbyte_decode(data: &[u8], start: usize) -> (u32, usize) {
        let mut val = 0u32;
        let mut shift = 0u32;
        let mut bytes_read = 0;
        for byte in data.iter().skip(start) {
            bytes_read += 1;
            val |= u32::from(byte & 0x7F) << shift;
            if byte & 0x80 == 0 {
                return (val, bytes_read);
            }
            shift += 7;
            if shift >= 35 {
                return (val, bytes_read);
            }
        }
        (val, bytes_read)
    }
}
impl Default for NeighborPool {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a blob of three back-to-back encoded layers: [1, 2], [3], [4, 5, 6].
    fn three_layer_blob() -> Vec<u8> {
        let layers: [&[u32]; 3] = [&[1, 2], &[3], &[4, 5, 6]];
        let mut blob = Vec::new();
        for layer in layers {
            NeighborPool::encode_neighbors_to_buf(layer, &mut blob);
        }
        blob
    }

    #[test]
    fn test_vbyte_roundtrip() {
        let original = vec![10, 100, 1000, 10000, 100_000, 1_000_000];
        let encoded = NeighborPool::encode_neighbors(&original);
        let decoded = NeighborPool::decode_neighbors(&encoded);
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_vbyte_unsorted() {
        let original = vec![50, 10, 30, 20];
        let expected = vec![10, 20, 30, 50];
        let encoded = NeighborPool::encode_neighbors(&original);
        let decoded = NeighborPool::decode_neighbors(&encoded);
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_vbyte_empty_and_duplicates() {
        let encoded = NeighborPool::encode_neighbors(&[]);
        assert_eq!(encoded, vec![0_u8]);
        assert!(NeighborPool::decode_neighbors(&encoded).is_empty());
        assert!(NeighborPool::decode_neighbors(&[]).is_empty());
        let dupes = NeighborPool::encode_neighbors(&[5, 5, 1]);
        assert_eq!(NeighborPool::decode_neighbors(&dupes), vec![1, 5, 5]);
    }

    #[test]
    fn test_vbyte_extremes() {
        let original = vec![0, u32::MAX];
        let encoded = NeighborPool::encode_neighbors(&original);
        // 1 byte count + 1 byte for delta 0 + 5 bytes for delta u32::MAX.
        assert_eq!(encoded.len(), 7);
        assert_eq!(NeighborPool::decode_neighbors(&encoded), original);
    }

    #[test]
    fn test_alloc_free_recycle() {
        let mut pool = NeighborPool::new();
        let (offset1, cap1) = pool.alloc(30).unwrap();
        assert_eq!(offset1, 0);
        assert_eq!(cap1, 32);
        assert_eq!(pool.buffer.len(), 32);
        pool.free(offset1, cap1);
        let (offset2, cap2) = pool.alloc(10).unwrap();
        assert_eq!(offset2, 32);
        assert_eq!(cap2, 16);
        let (offset3, cap3) = pool.alloc(32).unwrap();
        assert_eq!(offset3, 0);
        assert_eq!(cap3, 32);
        // 608 bytes is above the largest recycled size class, so it is always appended.
        let (offset4, cap4) = pool.alloc(600).unwrap();
        assert_eq!(offset4, 48);
        assert_eq!(cap4, 608);
        assert_eq!(pool.buffer.len(), 48 + 608);
    }

    #[test]
    fn test_alloc_zero_and_unaligned_free() {
        let mut pool = NeighborPool::new();
        assert_eq!(pool.alloc(0).unwrap(), (0, 16));
        // Capacities that are not a multiple of the granularity are never recycled.
        pool.free(0, 17);
        assert_eq!(pool.alloc(1).unwrap(), (16, 16));
        pool.free(0, 16);
        assert_eq!(pool.alloc(16).unwrap(), (0, 16));
    }

    #[test]
    fn test_alloc_too_large() {
        let mut pool = NeighborPool::new();
        let res = pool.alloc(100_000);
        assert!(matches!(res, Err(GraphError::NeighborError)));
        // 65_535 rounds up to 65_536, which no longer fits in a u16 capacity.
        let res = pool.alloc(65_535);
        assert!(matches!(res, Err(GraphError::NeighborError)));
        assert!(pool.buffer.is_empty());
        assert_eq!(pool.alloc(65_520).unwrap(), (0, 65_520));
    }

    #[test]
    fn test_decode_layer() {
        let blob = three_layer_blob();
        assert_eq!(NeighborPool::decode_layer(&blob, 0), vec![1, 2]);
        assert_eq!(NeighborPool::decode_layer(&blob, 1), vec![3]);
        assert_eq!(NeighborPool::decode_layer(&blob, 2), vec![4, 5, 6]);
        assert_eq!(NeighborPool::decode_layer(&blob, 3), Vec::<u32>::new());
    }

    #[test]
    fn test_decode_layer_to_buf_appends() {
        let blob = three_layer_blob();
        let mut buf = vec![99];
        NeighborPool::decode_layer_to_buf(&blob, 1, &mut buf);
        assert_eq!(buf, vec![99, 3]);
    }

    #[test]
    fn test_iter_layer() {
        let blob = three_layer_blob();
        let iter0: Vec<u32> = NeighborPool::iter_layer(&blob, 0).collect();
        assert_eq!(iter0, vec![1, 2]);
        let iter1: Vec<u32> = NeighborPool::iter_layer(&blob, 1).collect();
        assert_eq!(iter1, vec![3]);
        let iter2: Vec<u32> = NeighborPool::iter_layer(&blob, 2).collect();
        assert_eq!(iter2, vec![4, 5, 6]);
    }

    #[test]
    fn test_iter_layer_len() {
        let blob = three_layer_blob();
        assert_eq!(NeighborPool::iter_layer(&blob, 2).len(), 3);
        let mut it = NeighborPool::iter_layer(&blob, 0);
        assert_eq!(it.next(), Some(1));
        assert_eq!(it.len(), 1);
        let mut missing = NeighborPool::iter_layer(&blob, 3);
        assert_eq!(missing.len(), 0);
        assert_eq!(missing.next(), None);
        assert_eq!(NeighborPool::iter_layer(&[], 0).len(), 0);
    }
}