use crate::error::{Result, ZiporaError};
use std::fmt;
/// A vector of unsigned integers packed at one common bit width.
///
/// Every element occupies `bits` bits (at most 64) inside `data`. Reads decode
/// an element with a single unaligned word load, which is why the reading
/// methods restrict the width to 58 bits.
#[derive(Clone)]
pub struct UintVecMin0 {
    /// Number of logical elements.
    size: usize,
    /// Bit width shared by every element.
    bits: usize,
    /// Value with the low `bits` bits set: the largest storable element.
    mask: usize,
    /// Packed element bits, followed by slack/padding bytes.
    data: Vec<u8>,
}
impl UintVecMin0 {
    /// Creates a vector of `num` zero-valued elements. The bit width is the
    /// smallest one able to hold every value in `0..=max_val`.
    ///
    /// # Panics
    /// If the storage size for `num` elements overflows `usize`.
    pub fn new(num: usize, max_val: usize) -> Self {
        let bits = Self::compute_uintbits(max_val);
        let mut vec = Self::new_empty();
        vec.resize_with_uintbits(num, bits);
        vec
    }

    /// Creates an empty vector: no elements, bit width 0, no allocation.
    pub fn new_empty() -> Self {
        Self {
            data: Vec::new(),
            bits: 0,
            mask: 0,
            size: 0,
        }
    }

    /// Returns the number of significant bits in `value`, i.e. the bit width
    /// needed to store it (0 for `value == 0`).
    #[inline]
    pub fn compute_uintbits(value: usize) -> usize {
        // `leading_zeros` counts over `usize::BITS` bits, so subtract from that
        // rather than a hard-coded 64. `value == 0` naturally yields 0.
        (usize::BITS - value.leading_zeros()) as usize
    }

    /// Returns the number of backing bytes needed for `num` elements of
    /// `bits` bits each.
    ///
    /// The packed payload is followed by `size_of::<u64>() - 1` slack bytes.
    /// This lets a full word be loaded or stored starting at the byte that
    /// holds any element's first bit. The total is rounded up to a multiple
    /// of 16.
    ///
    /// # Panics
    /// If `bits > 64`, or if `bits * num` overflows `usize`. The overflow is
    /// checked so it cannot silently produce an undersized buffer in release
    /// builds.
    #[inline]
    pub fn compute_mem_size(bits: usize, num: usize) -> usize {
        assert!(bits <= 64, "bits must be <= 64");
        let total_bits = bits
            .checked_mul(num)
            .expect("compute_mem_size: bits * num overflows usize");
        // Ceiling division written so it cannot overflow for large inputs.
        let using_size = total_bits / 8 + usize::from(total_bits % 8 != 0);
        let touch_size = using_size + std::mem::size_of::<u64>() - 1;
        (touch_size + 15) & !15
    }

    /// Returns the number of backing bytes needed for `num` elements wide
    /// enough to hold `max_val`. See [`Self::compute_mem_size`].
    #[inline]
    pub fn compute_mem_size_by_max_val(max_val: usize, num: usize) -> usize {
        let bits = Self::compute_uintbits(max_val);
        Self::compute_mem_size(bits, num)
    }

    /// Returns a value with the low `bits` bits set, for `bits` in `0..=64`.
    ///
    /// `(1usize << bits) - 1` cannot be used: shifting by 64 overflows. That
    /// panics in debug builds and yields a mask of 0 in release builds.
    #[inline]
    fn mask_for_bits(bits: usize) -> usize {
        if bits == 0 {
            0
        } else {
            usize::MAX >> (usize::BITS as usize - bits)
        }
    }

    /// Returns the element at `idx`.
    ///
    /// # Panics
    /// If `idx >= self.size()`, or if the bit width exceeds 58. Wider elements
    /// may not fit in one unaligned word load; use `BigUintVecMin0` for them.
    #[inline]
    pub fn get(&self, idx: usize) -> usize {
        assert!(idx < self.size, "Index {} out of bounds {}", idx, self.size);
        assert!(self.bits <= 58, "Use BigUintVecMin0 for >58 bits");
        self.fast_get_internal(idx)
    }

    /// Returns the element at `idx`. Checks run only in debug builds.
    ///
    /// # Safety
    /// The caller must guarantee `idx < self.size()`; otherwise the word load
    /// may read past the backing buffer. The bit width should also be at most
    /// 58, or the returned value may be wrong. That condition is checked only
    /// by `debug_assert!`.
    #[inline]
    pub unsafe fn get_unchecked(&self, idx: usize) -> usize {
        debug_assert!(idx < self.size, "Index {} out of bounds {}", idx, self.size);
        debug_assert!(self.bits <= 58);
        self.fast_get_internal(idx)
    }

    /// Returns the elements at `idx` and `idx + 1`.
    ///
    /// # Panics
    /// If `idx + 1 >= self.size()`, or if the bit width exceeds 58.
    #[inline]
    pub fn get2(&self, idx: usize) -> [usize; 2] {
        // `idx < size - 1` is `idx + 1 < size` without the `idx + 1` overflow.
        // The overflow would wrap in release builds and let an out-of-bounds
        // read through.
        assert!(
            idx < self.size.saturating_sub(1),
            "Index {} out of bounds for get2",
            idx
        );
        assert!(self.bits <= 58, "Use BigUintVecMin0 for >58 bits");
        [self.fast_get_internal(idx), self.fast_get_internal(idx + 1)]
    }

    /// Returns the elements at `idx` and `idx + 1`. Checks run only in debug
    /// builds.
    ///
    /// # Safety
    /// The caller must guarantee `idx + 1 < self.size()`; otherwise the word
    /// loads may read past the backing buffer. The bit width should be at
    /// most 58 for the values to be correct.
    #[inline]
    pub unsafe fn get2_unchecked(&self, idx: usize) -> [usize; 2] {
        debug_assert!(idx < self.size.saturating_sub(1));
        debug_assert!(self.bits <= 58);
        [self.fast_get_internal(idx), self.fast_get_internal(idx + 1)]
    }

    /// Decodes element `idx` from a raw packed buffer with `bits`-bit
    /// elements, masking the result with `mask`.
    ///
    /// # Errors
    /// Returns an out-of-bounds error when `bits * idx` overflows, or when
    /// the word load for `idx` would run past the end of `data`.
    ///
    /// # Panics
    /// If `bits > 58`.
    #[inline]
    pub fn fast_get(data: &[u8], bits: usize, mask: usize, idx: usize) -> Result<usize> {
        assert!(bits <= 58, "fast_get requires bits <= 58");
        let bit_idx = match bits.checked_mul(idx) {
            Some(b) if b / 8 + std::mem::size_of::<usize>() <= data.len() => b,
            _ => return Err(ZiporaError::out_of_bounds(idx, data.len())),
        };
        let byte_idx = bit_idx / 8;
        // SAFETY: the guard above proves `byte_idx + size_of::<usize>() <= data.len()`.
        let val = unsafe {
            std::ptr::read_unaligned(data.as_ptr().add(byte_idx) as *const usize)
        };
        Ok((val >> (bit_idx % 8)) & mask)
    }

    /// Decodes element `idx` with one unaligned word load. The caller must
    /// ensure `idx < self.size`.
    #[inline]
    fn fast_get_internal(&self, idx: usize) -> usize {
        let bit_idx = self.bits * idx;
        let byte_idx = bit_idx / 8;
        // SAFETY: callers guarantee `idx < self.size`. `data.len()` is kept
        // `>= compute_mem_size(bits, size)`, which reserves
        // `size_of::<u64>() - 1` slack bytes past the payload. Hence
        // `byte_idx + size_of::<usize>() <= data.len()`.
        let val = unsafe {
            std::ptr::read_unaligned(self.data.as_ptr().add(byte_idx) as *const usize)
        };
        (val >> (bit_idx % 8)) & self.mask
    }

    /// Stores `val` at `idx`.
    ///
    /// # Panics
    /// If `idx >= self.size()` or `val > self.uintmask()`.
    #[inline]
    pub fn set(&mut self, idx: usize, val: usize) {
        assert!(idx < self.size, "Index {} out of bounds {}", idx, self.size);
        assert!(val <= self.mask, "Value {} exceeds max {}", val, self.mask);
        assert!(self.bits <= 64, "Bits must be <= 64");
        self.set_wire(idx, val);
    }

    /// Writes `val` into slot `idx` at the current width, without checks.
    fn set_wire(&mut self, idx: usize, val: usize) {
        let bits = self.bits;
        let bit_idx = bits * idx;
        self.set_uint_bits(bit_idx, bits, val);
    }

    /// Overwrites the `bits`-bit field starting at bit `bit_pos` with the low
    /// `bits` bits of `val`. All other bits are left untouched.
    fn set_uint_bits(&mut self, bit_pos: usize, bits: usize, val: usize) {
        if bits == 0 {
            return;
        }
        let byte_idx = bit_pos / 8;
        let bit_offset = bit_pos % 8;
        if bit_offset + bits <= 64 {
            // The field fits in the word starting at `byte_idx`, so a single
            // unaligned read-modify-write suffices.
            let field_mask = Self::mask_for_bits(bits);
            let shifted_val = (val & field_mask) << bit_offset;
            let shifted_mask = field_mask << bit_offset;
            // SAFETY: `bit_pos` addresses an element below `size`, and
            // `compute_mem_size` reserves slack so that
            // `byte_idx + size_of::<usize>() <= data.len()`.
            unsafe {
                let ptr = self.data.as_mut_ptr().add(byte_idx) as *mut usize;
                let current = std::ptr::read_unaligned(ptr);
                let new_val = (current & !shifted_mask) | shifted_val;
                std::ptr::write_unaligned(ptr, new_val);
            }
        } else {
            // The field spills past one 64-bit window. This only happens when
            // `bit_offset + bits > 64`, i.e. for widths above 57. Patch it
            // byte by byte.
            let mut remaining_bits = bits;
            let mut remaining_val = val;
            let mut curr_byte = byte_idx;
            let mut curr_bit_offset = bit_offset;
            while remaining_bits > 0 {
                let bits_in_byte = (8 - curr_bit_offset).min(remaining_bits);
                // `u8::MAX >> (8 - n)` gives the low `n` bits for n in 1..=8.
                // This avoids `1u8 << 8`, which overflows on full bytes.
                let low_mask = u8::MAX >> (8 - bits_in_byte);
                let byte_mask = low_mask << curr_bit_offset;
                let byte_val = ((remaining_val as u8) & low_mask) << curr_bit_offset;
                self.data[curr_byte] = (self.data[curr_byte] & !byte_mask) | byte_val;
                remaining_val >>= bits_in_byte;
                remaining_bits -= bits_in_byte;
                curr_byte += 1;
                curr_bit_offset = 0;
            }
        }
    }

    /// Clears stored bits from `bit_begin` up to the byte boundary at or after
    /// `bit_end`, clamped to the buffer.
    ///
    /// Bits below `bit_begin` in its byte are preserved. Bits cleared past
    /// `bit_end` only round up to the end of that byte.
    fn zero_bit_range(&mut self, bit_begin: usize, bit_end: usize) {
        let end_byte = (bit_end / 8 + usize::from(bit_end % 8 != 0)).min(self.data.len());
        let mut first_byte = bit_begin / 8;
        if first_byte >= end_byte {
            return;
        }
        let lead = bit_begin % 8;
        if lead != 0 {
            // Keep the low `lead` bits: they belong to the preceding element.
            self.data[first_byte] &= (1u8 << lead) - 1;
            first_byte += 1;
        }
        self.data[first_byte..end_byte].fill(0);
    }

    /// Packs `src` as offsets from its minimum.
    ///
    /// Returns the packed offsets and that minimum. Empty input yields an
    /// empty vector and 0.
    pub fn build_from_usize(src: &[usize]) -> (Self, usize) {
        if src.is_empty() {
            return (Self::new_empty(), 0);
        }
        let &min_val = src.iter().min().expect("non-empty input");
        let &max_val = src.iter().max().expect("non-empty input");
        let wire_max = max_val - min_val;
        let mut vec = Self::new(src.len(), wire_max);
        for (i, &val) in src.iter().enumerate() {
            vec.set(i, val - min_val);
        }
        (vec, min_val)
    }

    /// Packs `src` as offsets from its minimum.
    ///
    /// Returns the packed offsets and that minimum. Empty input yields an
    /// empty vector and 0.
    pub fn build_from_i32(src: &[i32]) -> (Self, i32) {
        if src.is_empty() {
            return (Self::new_empty(), 0);
        }
        let &min_val = src.iter().min().expect("non-empty input");
        let &max_val = src.iter().max().expect("non-empty input");
        // `max_val - min_val` can exceed `i32::MAX` (e.g. `[i32::MIN, i32::MAX]`).
        // `abs_diff` gives the exact distance as a `u32` without overflowing.
        let wire_max = max_val.abs_diff(min_val) as usize;
        let mut vec = Self::new(src.len(), wire_max);
        for (i, &val) in src.iter().enumerate() {
            vec.set(i, val.abs_diff(min_val) as usize);
        }
        (vec, min_val)
    }

    /// Packs `src` as offsets from its minimum.
    ///
    /// Returns the packed offsets and that minimum. Empty input yields an
    /// empty vector and 0.
    pub fn build_from_u32(src: &[u32]) -> (Self, u32) {
        if src.is_empty() {
            return (Self::new_empty(), 0);
        }
        let &min_val = src.iter().min().expect("non-empty input");
        let &max_val = src.iter().max().expect("non-empty input");
        let wire_max = (max_val - min_val) as usize;
        let mut vec = Self::new(src.len(), wire_max);
        for (i, &val) in src.iter().enumerate() {
            vec.set(i, (val - min_val) as usize);
        }
        (vec, min_val)
    }

    /// Appends `val`.
    ///
    /// If `val` does not fit the current bit width, every element is repacked
    /// at the wider width first.
    pub fn push_back(&mut self, val: usize) {
        if Self::compute_mem_size(self.bits, self.size + 1) <= self.data.len() && val <= self.mask {
            self.set_wire(self.size, val);
            self.size += 1;
        } else {
            self.push_back_slow_path(val);
        }
    }

    /// Handles `push_back` when the buffer is too small or `val` exceeds the
    /// current mask.
    ///
    /// Either repacks all elements at the width `val` needs, or grows the
    /// length by one at the current width. Repacking reads through `get`, so
    /// it panics if the current width exceeds 58.
    fn push_back_slow_path(&mut self, val: usize) {
        let new_bits = Self::compute_uintbits(val.max(self.mask));
        if new_bits > self.bits {
            let old_size = self.size;
            let mut new_vec = Self::new(old_size + 1, val);
            for i in 0..old_size {
                new_vec.set(i, self.get(i));
            }
            new_vec.set(old_size, val);
            *self = new_vec;
        } else {
            self.resize(self.size + 1);
            self.set(self.size - 1, val);
        }
    }

    /// Returns the last element.
    ///
    /// # Panics
    /// If the vector is empty, or if the bit width exceeds 58.
    #[inline]
    pub fn back(&self) -> usize {
        assert!(self.size > 0, "Vector is empty");
        self.get(self.size - 1)
    }

    /// Removes all elements and resets the bit width to 0.
    ///
    /// The byte buffer is emptied; its capacity is kept.
    pub fn clear(&mut self) {
        self.data.clear();
        self.bits = 0;
        self.mask = 0;
        self.size = 0;
    }

    /// Sets the number of elements to `new_size`, keeping the bit width.
    ///
    /// Elements below `min(old size, new_size)` keep their values. Elements
    /// exposed by growing read as 0. The buffer grows as needed but is never
    /// shrunk here; use [`Self::shrink_to_fit`] for that.
    pub fn resize(&mut self, new_size: usize) {
        let new_mem_size = Self::compute_mem_size(self.bits, new_size);
        if new_mem_size > self.data.len() {
            self.data.resize(new_mem_size, 0);
        }
        if new_size > self.size {
            // The buffer is not shrunk on truncation, so bytes past the old
            // logical end may still hold earlier values. Clear them so the
            // newly exposed elements read as 0.
            self.zero_bit_range(self.bits * self.size, self.bits * new_size);
        }
        self.size = new_size;
    }

    /// Sets both the length (`num`) and the bit width (`bits`).
    ///
    /// The buffer is resized to `compute_mem_size(bits, num)` bytes; any new
    /// bytes are zero. Existing bytes are kept as-is (neither repacked nor
    /// cleared) and are reinterpreted with the new width.
    ///
    /// # Panics
    /// If `bits > 64`.
    pub fn resize_with_uintbits(&mut self, num: usize, bits: usize) {
        assert!(bits <= 64, "Bits must be <= 64");
        self.bits = bits;
        self.mask = Self::mask_for_bits(bits);
        self.size = num;
        let mem_size = Self::compute_mem_size(bits, num);
        self.data.resize(mem_size, 0);
    }

    /// Same as [`Self::resize_with_uintbits`], with the width derived from
    /// `max_val`.
    pub fn resize_with_wire_max_val(&mut self, num: usize, max_val: usize) {
        let bits = Self::compute_uintbits(max_val);
        self.resize_with_uintbits(num, bits);
    }

    /// Truncates the buffer to the bytes needed for the current length and
    /// width, and releases excess capacity.
    pub fn shrink_to_fit(&mut self) {
        let needed_size = Self::compute_mem_size(self.bits, self.size);
        self.data.truncate(needed_size);
        self.data.shrink_to_fit();
    }

    /// Adopts the raw buffer `data` as `num` elements of `bits` bits each.
    ///
    /// The previous buffer of `self` is dropped.
    ///
    /// # Safety
    /// All of the following must hold:
    /// - `data` must come from the global allocator with the layout of a
    ///   `Vec<u8>` whose capacity is exactly `compute_mem_size(bits, num)`.
    /// - All of those bytes must be initialised.
    /// - Ownership of the allocation passes to `self`. It must not be used or
    ///   freed elsewhere afterwards.
    ///
    /// These are the requirements of `Vec::from_raw_parts`.
    ///
    /// # Panics
    /// If `bits > 58`.
    pub unsafe fn risk_set_data(&mut self, data: *mut u8, num: usize, bits: usize) {
        assert!(bits <= 58, "bits={} is too large (max_allowed=58)", bits);
        let mem_size = Self::compute_mem_size(bits, num);
        // SAFETY: forwarded to the caller; see `# Safety` above.
        unsafe {
            self.data = Vec::from_raw_parts(data, mem_size, mem_size);
        }
        self.bits = bits;
        self.mask = Self::mask_for_bits(bits);
        self.size = num;
    }

    /// Raw backing bytes, including slack/padding past the packed payload.
    #[inline]
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Bit width of each element.
    #[inline]
    pub fn uintbits(&self) -> usize {
        self.bits
    }

    /// Value with the low `uintbits()` bits set: the largest storable element.
    #[inline]
    pub fn uintmask(&self) -> usize {
        self.mask
    }

    /// Number of elements.
    #[inline]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Size of the backing buffer in bytes.
    #[inline]
    pub fn mem_size(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when the vector holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}
impl Default for UintVecMin0 {
fn default() -> Self {
Self::new_empty()
}
}
/// Indexing is not supported.
///
/// Elements are bit-packed, so there is no stored `usize` for a reference to
/// point to. `vec[i]` always panics; call [`UintVecMin0::get`] instead.
impl std::ops::Index<usize> for UintVecMin0 {
    type Output = usize;

    /// Always panics; see [`UintVecMin0::get`].
    fn index(&self, _idx: usize) -> &Self::Output {
        panic!("Use get() method instead of indexing for UintVecMin0");
    }
}
impl fmt::Debug for UintVecMin0 {
    /// Prints the shape and storage size only. Element values are not listed;
    /// the mask is shown in hexadecimal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let byte_len = self.data.len();
        let mut out = f.debug_struct("UintVecMin0");
        out.field("size", &self.size);
        out.field("bits", &self.bits);
        out.field("mask", &format_args!("{:#x}", self.mask));
        out.field("mem_size", &byte_len);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fills a fresh `len`-element vector (width derived from `max_val`) with
    /// `i % modulus`, then checks every slot reads back unchanged.
    fn check_round_trip(len: usize, max_val: usize, modulus: usize) {
        let mut packed = UintVecMin0::new(len, max_val);
        for i in 0..len {
            packed.set(i, i % modulus);
        }
        for i in 0..len {
            assert_eq!(packed.get(i), i % modulus);
        }
    }

    #[test]
    fn test_compute_uintbits() {
        let expectations: [(usize, usize); 10] = [
            (0, 0),
            (1, 1),
            (2, 2),
            (3, 2),
            (7, 3),
            (8, 4),
            (15, 4),
            (255, 8),
            (256, 9),
            (65535, 16),
        ];
        for &(value, bits) in expectations.iter() {
            assert_eq!(UintVecMin0::compute_uintbits(value), bits);
        }
    }

    #[test]
    fn test_compute_mem_size() {
        assert_eq!(UintVecMin0::compute_mem_size(0, 100), 16);
        assert!(UintVecMin0::compute_mem_size(1, 100) >= 13);
        let byte_wide = UintVecMin0::compute_mem_size(8, 100);
        assert!(byte_wide >= 100);
        // Buffer sizes are always padded to a multiple of 16 bytes.
        assert_eq!(byte_wide % 16, 0);
    }

    #[test]
    fn test_new_and_basic_ops() {
        let packed = UintVecMin0::new(10, 255);
        assert_eq!(packed.size(), 10);
        assert_eq!(packed.uintbits(), 8);
        assert_eq!(packed.uintmask(), 255);
    }

    #[test]
    fn test_set_and_get() {
        let writes = [(0usize, 42usize), (50, 128), (99, 255)];
        let mut packed = UintVecMin0::new(100, 255);
        for &(idx, val) in writes.iter() {
            packed.set(idx, val);
        }
        for &(idx, val) in writes.iter() {
            assert_eq!(packed.get(idx), val);
        }
    }

    #[test]
    fn test_round_trip_various_bit_widths() {
        check_round_trip(64, 1, 2);
        check_round_trip(100, 15, 16);
        check_round_trip(1000, 65535, 65536);
    }

    #[test]
    fn test_edge_case_zero_bits() {
        let packed = UintVecMin0::new(100, 0);
        assert_eq!(packed.uintbits(), 0);
        assert_eq!(packed.uintmask(), 0);
        for i in 0..100 {
            assert_eq!(packed.get(i), 0);
        }
    }

    #[test]
    fn test_edge_case_one_bit() {
        let slots = 128;
        assert_eq!(UintVecMin0::new(slots, 1).uintbits(), 1);
        check_round_trip(slots, 1, 2);
    }

    #[test]
    fn test_58_bits_max_fast_path() {
        let top = (1usize << 58) - 1;
        let mut packed = UintVecMin0::new(10, top);
        assert_eq!(packed.uintbits(), 58);
        packed.set(0, top);
        packed.set(5, top / 2);
        assert_eq!(packed.get(0), top);
        assert_eq!(packed.get(5), top / 2);
    }

    #[test]
    fn test_get2() {
        let mut packed = UintVecMin0::new(100, 255);
        packed.set(10, 42);
        packed.set(11, 43);
        assert_eq!(packed.get2(10), [42, 43]);
    }

    #[test]
    fn test_build_from_usize() {
        let input: [usize; 5] = [100, 105, 103, 108, 101];
        let (packed, base) = UintVecMin0::build_from_usize(&input);
        assert_eq!(base, 100);
        assert_eq!(packed.size(), 5);
        assert_eq!(packed.uintbits(), 4);
        let offsets: Vec<usize> = (0..packed.size()).map(|i| packed.get(i)).collect();
        assert_eq!(offsets, vec![0, 5, 3, 8, 1]);
    }

    #[test]
    fn test_push_back_fast_path() {
        let mut packed = UintVecMin0::new(10, 255);
        packed.resize(0);
        (0..10).for_each(|i| packed.push_back(i * 10));
        assert_eq!(packed.size(), 10);
        for i in 0..10 {
            assert_eq!(packed.get(i), i * 10);
        }
    }

    #[test]
    fn test_push_back_slow_path_capacity() {
        let mut packed = UintVecMin0::new(2, 10);
        packed.resize(0);
        (0..5).for_each(|v| packed.push_back(v));
        assert_eq!(packed.size(), 5);
        for v in 0..5 {
            assert_eq!(packed.get(v), v);
        }
    }

    #[test]
    fn test_push_back_slow_path_bit_expansion() {
        let mut packed = UintVecMin0::new(10, 15);
        packed.resize(0);
        (0..5).for_each(|v| packed.push_back(v));
        assert_eq!(packed.uintbits(), 4);
        packed.push_back(255);
        assert_eq!(packed.uintbits(), 8);
        for v in 0..5 {
            assert_eq!(packed.get(v), v);
        }
        assert_eq!(packed.get(5), 255);
    }

    #[test]
    fn test_back() {
        let mut packed = UintVecMin0::new(10, 255);
        packed.set(9, 123);
        assert_eq!(packed.back(), 123);
    }

    #[test]
    fn test_clear() {
        let mut packed = UintVecMin0::new(100, 255);
        (0..100).for_each(|i| packed.set(i, i));
        packed.clear();
        assert_eq!(packed.size(), 0);
        assert_eq!(packed.uintbits(), 0);
        assert!(packed.is_empty());
    }

    #[test]
    fn test_resize() {
        let mut packed = UintVecMin0::new(10, 255);
        (0..10).for_each(|i| packed.set(i, i));
        packed.resize(20);
        assert_eq!(packed.size(), 20);
        for i in 0..10 {
            assert_eq!(packed.get(i), i);
        }
    }

    #[test]
    fn test_shrink_to_fit() {
        let mut packed = UintVecMin0::new(1000, 255);
        packed.resize(10);
        let bytes_before = packed.mem_size();
        packed.shrink_to_fit();
        assert!(packed.mem_size() <= bytes_before);
        assert_eq!(packed.size(), 10);
    }

    #[test]
    fn test_memory_efficiency() {
        let bytes = UintVecMin0::new(1000, 255).mem_size();
        // 8-bit packing must use less than a quarter of a plain `usize` array,
        // while still reserving at least one byte per element.
        assert!(bytes < 1000 * std::mem::size_of::<usize>() / 4);
        assert!(bytes >= 1000);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_out_of_bounds() {
        UintVecMin0::new(10, 255).get(10);
    }

    #[test]
    #[should_panic(expected = "exceeds max")]
    fn test_set_value_too_large() {
        UintVecMin0::new(10, 255).set(0, 256);
    }

    #[test]
    #[should_panic(expected = "Vector is empty")]
    fn test_back_empty() {
        UintVecMin0::new_empty().back();
    }

    #[test]
    fn test_fast_get_static() {
        let mut packed = UintVecMin0::new(100, 255);
        (0..100).for_each(|i| packed.set(i, i));
        let (raw, bits, mask) = (packed.data(), packed.uintbits(), packed.uintmask());
        for i in 0..100 {
            let decoded = UintVecMin0::fast_get(raw, bits, mask, i)
                .expect("fast_get should succeed for valid index");
            assert_eq!(decoded, i);
        }
    }
}