pub(crate) type PathSegmentInner = [u8; 33];
const BIT_MASK: [u8; 8] = [128, 64, 32, 16, 8, 4, 2, 1];
#[derive(PartialEq)]
pub enum Direction {
Left,
Right,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Path<T>(pub T);
#[derive(Clone, Copy, Debug)]
pub struct PathSegment<T>(pub T);
pub trait PathUtils {
fn direction(&self, index: usize) -> Direction;
fn split_point<S: BitLength + PathUtils>(
&self,
segment_start: usize,
segment: S,
) -> Option<usize>;
}
impl<T: AsMut<[u8]>> PathSegment<T> {
pub fn copy<A: BitLength + PathUtils>(&mut self, src: A, start: usize, end: usize) {
if start == end {
return;
}
assert!(start < end, "start {} must be less than end {}", start, end);
let bit_len = end - start;
self.set_len(bit_len);
let (src, src_start, start_bit) = (src.inner(), start / 8, start % 8);
let (dst_slice, dst_end_idx, dst_end_bit) =
(self.as_mut_inner(), (bit_len - 1) / 8, bit_len % 8);
if start_bit == 0 {
dst_slice[..dst_end_idx + 1]
.copy_from_slice(&src[src_start..src_start + dst_end_idx + 1]);
} else {
for (i, j) in (src_start..src_start + dst_end_idx + 1).zip(0..) {
dst_slice[j] = src[i] << start_bit;
if i + 1 < src.len() {
dst_slice[j] |= src[i + 1] >> (8 - start_bit);
}
}
}
if dst_end_bit != 0 {
dst_slice[dst_end_idx] &= 0xFF << (8 - dst_end_bit);
}
}
pub fn extend<A: BitLength + PathUtils>(&mut self, other: A) {
let inner = other.inner();
let mut remaining_bits = other.bit_len();
let mut index = 0;
while remaining_bits >= 8 {
self.extend_from_byte(inner[index], 8);
remaining_bits -= 8;
index += 1;
}
if remaining_bits > 0 {
self.extend_from_byte(inner[index], remaining_bits as u8);
}
}
pub fn extend_from_byte(&mut self, bits: u8, len: u8) {
assert!(len <= 8, "invalid bit length");
let raw = self.as_mut();
let current_bit_len = raw[0];
let new_bit_len = current_bit_len + len;
raw[0] = new_bit_len;
let inner = &mut raw[1..];
let unfilled_bits = 8 - (current_bit_len % 8);
let mut index = (current_bit_len / 8) as usize;
if len <= unfilled_bits {
inner[index] |= bits >> (8 - len) << (unfilled_bits - len);
} else {
inner[index] |= bits >> (8 - unfilled_bits);
index += 1;
assert!(inner.len() > index, "could not fit all bits");
inner[index] = bits << unfilled_bits;
}
let last_byte_bits = new_bit_len % 8;
if last_byte_bits != 0 {
inner[index] &= 0xFF << (8 - last_byte_bits);
}
}
#[inline(always)]
pub fn set_len(&mut self, len: usize) {
assert!(len <= 255, "PathSegment length must be <= 255");
self.0.as_mut()[0] = len as u8;
}
#[inline(always)]
pub fn as_mut_inner(&mut self) -> &mut [u8] {
&mut self.0.as_mut()[1..]
}
}
impl<T: BitLength + AsRef<[u8]>> PathUtils for T {
#[inline(always)]
fn direction(&self, index: usize) -> Direction {
if self.inner()[index / 8] & BIT_MASK[index % 8] != 0 {
return Direction::Right;
}
Direction::Left
}
fn split_point<S: BitLength + PathUtils>(&self, start: usize, b: S) -> Option<usize> {
let max_bit_len = core::cmp::min(self.bit_len(), b.bit_len());
let (src_start_byte, src_start_bit, seg_end_byte) =
(start / 8, start % 8, max_bit_len.div_ceil(8));
let mut count = 0;
if src_start_bit == 0 {
let (a, b) = (&self.inner()[src_start_byte..], &b.inner()[..seg_end_byte]);
for (a_byte, b_byte) in a.iter().zip(b.iter()) {
if *a_byte != *b_byte {
count += (a_byte ^ b_byte).leading_zeros();
break;
}
count += 8;
}
} else {
let (a, b) = (&self.inner()[src_start_byte..], &b.inner()[..seg_end_byte]);
for (i, b_byte) in b.iter().enumerate() {
let mut a_byte = a[i] << src_start_bit;
if i < a.len() {
a_byte |= a[i + 1] >> (8 - src_start_bit);
}
if a_byte != *b_byte {
count += (a_byte ^ b_byte).leading_zeros();
break;
}
count += 8;
}
}
let count = core::cmp::min(count as usize, max_bit_len);
if count == max_bit_len {
None
} else {
Some(count)
}
}
}
impl PathSegment<[u8; 33]> {
#[inline(always)]
pub fn from_path<A: BitLength + PathUtils>(src: A, from: usize, to: usize) -> Self {
let mut a = PathSegment([0; 33]);
a.copy(src, from, to);
a
}
}
impl<T: AsRef<[u8]>> AsRef<[u8]> for Path<T> {
#[inline(always)]
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl<T: AsMut<[u8]>> AsMut<[u8]> for Path<T> {
#[inline(always)]
fn as_mut(&mut self) -> &mut [u8] {
self.0.as_mut()
}
}
impl<T: AsRef<[u8]>> AsRef<[u8]> for PathSegment<T> {
#[inline(always)]
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl<T: AsMut<[u8]>> AsMut<[u8]> for PathSegment<T> {
#[inline(always)]
fn as_mut(&mut self) -> &mut [u8] {
self.0.as_mut()
}
}
impl<T: AsRef<[u8]>> BitLength for Path<T> {
#[inline(always)]
fn bit_len(&self) -> usize {
256
}
#[inline(always)]
fn inner(&self) -> &[u8] {
self.0.as_ref()
}
#[inline(always)]
fn as_bytes(&self) -> &[u8] {
self.0.as_ref()
}
}
impl<T: AsRef<[u8]>> BitLength for PathSegment<T> {
#[inline(always)]
fn bit_len(&self) -> usize {
self.0.as_ref()[0] as usize
}
#[inline(always)]
fn inner(&self) -> &[u8] {
&self.0.as_ref()[1..]
}
#[inline(always)]
fn as_bytes(&self) -> &[u8] {
let byte_len = self.bit_len().div_ceil(8);
&self.0.as_ref()[..(byte_len + 1)]
}
}
#[cfg(feature = "std")]
impl<T: PartialOrd> PartialOrd for Path<T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.0.partial_cmp(&other.0)
}
}
#[cfg(feature = "std")]
impl<T: Ord> Ord for Path<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.cmp(&other.0)
}
}
pub trait BitLength {
fn bit_len(&self) -> usize;
fn inner(&self) -> &[u8];
fn as_bytes(&self) -> &[u8];
}
#[cfg(test)]
mod tests {
use crate::path::{BitLength, Direction, PathSegment, PathUtils};
use core::fmt::Display;
impl<T: AsRef<[u8]>> Display for PathSegment<T> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.bit_len() {
if self.direction(i) == Direction::Right {
write!(f, "1")?;
} else {
write!(f, "0")?;
}
}
Ok(())
}
}
#[test]
fn test_extend() {
let mut parent = PathSegment([0u8; 33]);
parent.set_len(5);
let mut child = PathSegment([0u8; 33]);
child.set_len(10);
child.0[1] = 0b1111_1010;
parent.extend(child);
assert_eq!(parent.to_string(), "000001111101000");
}
#[test]
fn test_extend_from_byte() {
let mut segment = PathSegment([0u8; 33]);
segment.set_len(2);
let inner = segment.as_mut_inner();
inner[0] = 0b1100_0000;
segment.extend_from_byte(0b1000_1000, 3);
assert_eq!(segment.to_string(), "11100");
segment.extend_from_byte(0b1111_1111, 8);
assert_eq!(segment.to_string(), "1110011111111");
segment.extend_from_byte(0b0011_1111, 2);
assert_eq!(segment.to_string(), "111001111111100");
segment.extend_from_byte(0b1111_1111, 8);
assert_eq!(segment.to_string(), "11100111111110011111111");
segment.extend_from_byte(0b0000_1111, 4);
assert_eq!(segment.to_string(), "111001111111100111111110000");
segment.set_len(segment.bit_len() + 2);
assert_eq!(
segment.to_string(),
"11100111111110011111111000000",
"trailing bits must be cleared"
);
}
}