use crate::fixed::atomic::AtomicFixedVec;
use crate::fixed::traits::Storable;
use crate::fixed::{BitWidth, Error};
use num_traits::ToPrimitive;
use std::marker::PhantomData;
use std::sync::atomic::AtomicU64;
/// Builder that configures and then constructs an [`AtomicFixedVec`].
///
/// The only configurable setting is the [`BitWidth`] strategy that decides how many
/// bits each stored element occupies.
#[derive(Clone, Debug)]
pub struct AtomicFixedVecBuilder<T>
where
    T: Storable<u64>,
{
    /// Strategy consulted by `build` to pick the per-element bit width.
    bit_width_strategy: BitWidth,
    /// Ties the builder to its element type without storing any `T`.
    _phantom: PhantomData<T>,
}
/// A default builder starts from `BitWidth::default()` as its bit-width strategy.
impl<T: Storable<u64>> Default for AtomicFixedVecBuilder<T> {
    fn default() -> Self {
        let bit_width_strategy = BitWidth::default();
        Self {
            bit_width_strategy,
            _phantom: PhantomData,
        }
    }
}
impl<T> AtomicFixedVecBuilder<T>
where
    T: Storable<u64> + Copy + ToPrimitive,
{
    /// Creates a builder that uses the default [`BitWidth`] strategy.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the strategy used to choose the per-element bit width.
    pub fn bit_width(mut self, strategy: BitWidth) -> Self {
        self.bit_width_strategy = strategy;
        self
    }

    /// Builds an [`AtomicFixedVec`] containing every element of `input`, in order.
    ///
    /// The bit width is resolved from the configured strategy:
    /// - `BitWidth::Explicit(n)` uses `n` as given;
    /// - `BitWidth::Minimal` uses the fewest bits (at least 1) that hold the largest
    ///   encoded element;
    /// - `BitWidth::PowerOfTwo` rounds that minimal width up to a power of two, capped at 64.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidParameters` if the resolved width is zero while `input` is
    ///   non-empty (only possible with `BitWidth::Explicit(0)`).
    /// - `Error::ValueTooLarge` if an element's encoded word does not fit in the resolved
    ///   width (only possible with an explicit width).
    /// - Any error returned by `AtomicFixedVec::new`.
    pub fn build(self, input: &[T]) -> Result<AtomicFixedVec<T>, Error> {
        let bits_per_word = u64::BITS as usize;
        let final_bit_width = match self.bit_width_strategy {
            BitWidth::Explicit(n) => n,
            BitWidth::Minimal => Self::minimal_bit_width(input),
            BitWidth::PowerOfTwo => Self::minimal_bit_width(input)
                .next_power_of_two()
                .min(bits_per_word),
        };

        if !input.is_empty() && final_bit_width == 0 {
            return Err(Error::InvalidParameters(
                "bit_width cannot be zero for a non-empty vector".to_string(),
            ));
        }

        let mut atomic_vec = AtomicFixedVec::new(final_bit_width, input.len())?;
        if input.is_empty() {
            return Ok(atomic_vec);
        }

        // Exclusive upper bound for encoded values. Widths of 64 bits or more accept every
        // `u64`, so no bound is applied (and `1 << 64` would overflow).
        let limit = (final_bit_width < bits_per_word).then(|| 1u64 << final_bit_width);

        for (i, &value_t) in input.iter().enumerate() {
            let value_w = T::into_word(value_t);
            if matches!(limit, Some(limit) if value_w >= limit) {
                return Err(Error::ValueTooLarge {
                    value: value_w as u128,
                    index: i,
                    bit_width: final_bit_width,
                });
            }
            // SAFETY: `i < input.len()`, the element count `atomic_vec` was created with, and
            // `value_w` fits in `final_bit_width` bits (checked above for widths below 64).
            // NOTE(review): this relies on `AtomicFixedVec::new` allocating every word a slot
            // can touch (including the word a straddling slot spills into), on it rejecting
            // widths above 64, and on `atomic_vec.mask` being the low-`final_bit_width`-bit
            // mask. Confirm all three in `AtomicFixedVec::new`.
            unsafe {
                set_unchecked_non_atomic(
                    &mut atomic_vec.storage,
                    i,
                    value_w,
                    final_bit_width,
                    atomic_vec.mask,
                );
            }
        }
        Ok(atomic_vec)
    }

    /// Smallest bit width (never less than 1) that can represent the largest encoded
    /// word in `input`; an empty or all-zero input yields 1.
    fn minimal_bit_width(input: &[T]) -> usize {
        let max_val = input
            .iter()
            .map(|&val| T::into_word(val))
            .max()
            .unwrap_or(0);
        if max_val == 0 {
            1
        } else {
            (u64::BITS - max_val.leading_zeros()) as usize
        }
    }
}
/// Writes `value` into the `index`-th `bit_width`-bit slot of the packed word array
/// `limbs`, using plain (non-atomic) access through the exclusive `&mut` borrow.
///
/// Slot `index` starts at bit `index * bit_width`. A slot that crosses a 64-bit word
/// boundary is split between two consecutive words. The low bits go into the upper part
/// of the first word and the remaining high bits into the lower part of the next word.
/// Bits outside the slot are preserved.
///
/// # Safety
///
/// The caller must guarantee that:
/// - `limbs` contains the word at index `(index * bit_width) / 64` and, when the slot
///   straddles a word boundary, the word after it (no bounds checks are performed);
/// - `1 <= bit_width <= 64`;
/// - `mask` has exactly the low `bit_width` bits set;
/// - `value` fits in `bit_width` bits, otherwise neighbouring slots are corrupted.
unsafe fn set_unchecked_non_atomic(
    limbs: &mut [AtomicU64],
    index: usize,
    value: u64,
    bit_width: usize,
    mask: u64,
) {
    let bits_per_word = u64::BITS as usize;
    let bit_pos = index * bit_width;
    let word_index = bit_pos / bits_per_word;
    let bit_offset = bit_pos % bits_per_word;

    if bit_offset + bit_width <= bits_per_word {
        // SAFETY: the caller guarantees `word_index` is in bounds.
        let word = unsafe { limbs.get_unchecked_mut(word_index) }.get_mut();
        *word &= !(mask << bit_offset);
        *word |= value << bit_offset;
    } else {
        // Straddling slot. Each word is borrowed in turn rather than holding two
        // raw-pointer-derived `&mut` at once. Calling `as_mut_ptr()` twice reborrows the
        // whole slice and invalidates the first pointer under Rust's aliasing model.
        {
            // SAFETY: the caller guarantees `word_index` is in bounds.
            let low_word = unsafe { limbs.get_unchecked_mut(word_index) }.get_mut();
            // The slot owns every bit from `bit_offset` to the top of this word.
            *low_word &= !(u64::MAX << bit_offset);
            *low_word |= value << bit_offset;
        }

        // In this branch 1 <= bits_in_high < bit_width <= 64 and 1 <= bit_offset < 64,
        // so both shifts below are in range.
        let bits_in_high = (bit_offset + bit_width) - bits_per_word;
        let high_mask = (1u64 << bits_in_high).wrapping_sub(1);
        // SAFETY: the caller guarantees the word after `word_index` exists for a
        // straddling slot.
        let high_word = unsafe { limbs.get_unchecked_mut(word_index + 1) }.get_mut();
        *high_word &= !high_mask;
        *high_word |= value >> (bits_per_word - bit_offset);
    }
}