use std::iter::{Extend, FromIterator};
use std::mem;
use ff::PrimeField;
/// Number of flags packed into each `usize` word of a [`Mask`].
///
/// NOTE(review): this is the *byte* size of `usize` (8 on 64-bit targets), not its bit
/// width, so each word carries only this many flags and its remaining high bits stay
/// unused. Changing it would change how many words a mask allocates.
const LEN_USIZE: u32 = mem::size_of::<usize>() as u32;
/// Base-2 logarithm of [`LEN_USIZE`]: `index >> LOG2_LEN_USIZE` is the word holding `index`.
/// `LEN_USIZE` is always a power of two, so its trailing-zero count equals its `log2`.
const LOG2_LEN_USIZE: u32 = LEN_USIZE.trailing_zeros();
/// A growable, bit-packed sequence of boolean flags.
///
/// Flags are appended one at a time and stored in `usize` words. Only the low
/// `LEN_USIZE` bits of each word are used.
///
/// `Default` yields the same empty mask as `Mask::new()`.
#[derive(Debug, Clone, Default)]
pub struct Mask {
    /// Number of flags pushed so far.
    len: u32,
    /// Backing words: flag `i` lives in word `i >> LOG2_LEN_USIZE` at bit `i % LEN_USIZE`.
    set: Vec<usize>,
}
impl Mask {
    /// Creates an empty mask that holds no flags and has allocated no words.
    pub const fn new() -> Self {
        Self {
            len: 0,
            set: Vec::new(),
        }
    }

    /// Returns the number of flags pushed so far.
    pub const fn len(&self) -> usize {
        self.len as usize
    }

    /// Returns `true` when no flags have been pushed.
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Appends `val` as the flag at index `self.len()`.
    ///
    /// # Panics
    ///
    /// The length is tracked as a `u32`. Pushing beyond `u32::MAX` flags overflows it,
    /// which panics when overflow checks are enabled.
    pub fn push(&mut self, val: bool) {
        let global_insert_pos = self.len;
        self.len += 1;
        // The new flag opens a fresh word. Words start zeroed, so a `false` flag needs
        // no further write.
        if fast_rem::is_divisible(global_insert_pos, LEN_USIZE) {
            self.set.push(0);
        }
        if val {
            let usize_insert_pos = fast_rem::fast_mod(global_insert_pos, LEN_USIZE);
            let last_word = self
                .set
                .last_mut()
                .expect("index 0 always allocates a word, so `set` is non-empty here");
            *last_word |= 1 << usize_insert_pos;
        }
    }

    /// Returns the flag stored at `global_index`.
    ///
    /// No explicit bounds check is performed. An index at or beyond `len()` that still
    /// falls inside the last allocated word reads as `false`, because `push` and `set`
    /// never write there. An index past the last word panics on the `Vec` access.
    pub fn is_set(&self, global_index: usize) -> bool {
        let (vec_index, bit) = Self::locate(global_index);
        self.set[vec_index] & bit != 0
    }

    /// Overwrites the flag at `global_index` with `to`. Both setting and clearing
    /// are supported.
    ///
    /// # Panics
    ///
    /// Panics if `global_index >= self.len()`. Writing past the length would leave a
    /// stale bit that a later `push(false)` could not clear.
    pub fn set(&mut self, global_index: usize, to: bool) {
        assert!(
            global_index < self.len(),
            "Mask::set: index {} out of bounds for mask of length {}",
            global_index,
            self.len()
        );
        let (vec_index, bit) = Self::locate(global_index);
        let word = &mut self.set[vec_index];
        if to {
            *word |= bit;
        } else {
            *word &= !bit;
        }
    }

    /// Maps a flag index to its word index and the single-bit mask selecting it
    /// within that word.
    fn locate(global_index: usize) -> (usize, usize) {
        let vec_index = global_index >> LOG2_LEN_USIZE;
        // The `as u32` truncation cannot change the result. LEN_USIZE is a power of two
        // that divides 2^32, so the remainder depends only on the low bits the cast keeps.
        let usize_index = fast_rem::fast_mod(global_index as u32, LEN_USIZE);
        (vec_index, 1 << usize_index)
    }
}
impl FromIterator<bool> for Mask {
fn from_iter<T>(iter: T) -> Self
where
T: IntoIterator<Item = bool>,
{
let mut mask = Mask::new();
mask.extend(iter);
mask
}
}
impl Extend<bool> for Mask {
    /// Appends every flag yielded by `iter` to the end of the mask, one `push` at a time.
    fn extend<I: IntoIterator<Item = bool>>(&mut self, iter: I) {
        iter.into_iter().for_each(|flag| self.push(flag));
    }
}
/// Gathers the masked-in values of `inputs` followed by those of `aux` into one vector.
///
/// A value at index `i` is kept when `i < mask.len()` and `mask.is_set(i)`. Values at
/// indices the mask does not cover are dropped. Relative order is preserved within each
/// slice, and all `inputs` entries precede all `aux` entries.
pub fn dense_assignment_from_masks<S: PrimeField>(
    inputs: &[S],
    aux: &[S],
    input_mask: &Mask,
    aux_mask: &Mask,
) -> Vec<S> {
    /// Yields the values of `values` whose index is covered by and set in `mask`.
    fn selected<'a, T: Copy>(values: &'a [T], mask: &'a Mask) -> impl Iterator<Item = T> + 'a {
        values
            .iter()
            .take(mask.len())
            .enumerate()
            .filter(move |&(idx, _)| mask.is_set(idx))
            .map(|(_, value)| *value)
    }

    selected(inputs, input_mask)
        .chain(selected(aux, aux_mask))
        .collect()
}
/// Division-free remainder and divisibility tests for `u32` operands.
///
/// This is the 64-bit "fastmod" technique of Lemire, Kaser & Kurz, "Faster Remainder by
/// Direct Computation". A magic constant `c` is computed once per divisor; after that,
/// `n % d` and `n % d == 0` need only multiplications.
mod fast_rem {
    /// Computes the magic constant `c = floor((2^64 - 1) / d) + 1` for divisor `d`.
    ///
    /// For `d == 1` the exact value is `2^64`, which wraps to `0`. This wrapping is
    /// intentional: the functions below return correct results with `c == 0`.
    ///
    /// # Panics
    ///
    /// Panics if `d == 0`.
    #[inline]
    pub const fn compute_c(d: u32) -> u64 {
        (u64::MAX / (d as u64)).wrapping_add(1)
    }

    /// Returns `n % d`, given `c == compute_c(d)`.
    #[inline]
    pub const fn fast_mod_with_c(n: u32, d: u32, c: u64) -> u32 {
        // The low 64 bits of `c * n` hold the scaled fractional part of `n / d`.
        let lowbits: u64 = c.wrapping_mul(n as u64);
        // Multiplying the fraction back by `d` puts the remainder in the high 64 bits.
        let result: u128 = (lowbits as u128) * (d as u128);
        (result >> 64) as u32
    }

    /// Returns `n % d`.
    ///
    /// # Panics
    ///
    /// Panics if `d == 0`.
    #[inline]
    pub const fn fast_mod(n: u32, d: u32) -> u32 {
        let c = compute_c(d);
        fast_mod_with_c(n, d, c)
    }

    /// Returns whether `n % d == 0`, given `c == compute_c(d)`.
    #[inline]
    pub const fn is_divisible_with_c(n: u32, c: u64) -> bool {
        let lowbits: u64 = c.wrapping_mul(n as u64);
        lowbits <= c.wrapping_sub(1)
    }

    /// Returns whether `n` is a multiple of `d`.
    ///
    /// # Panics
    ///
    /// Panics if `d == 0`.
    #[inline]
    pub const fn is_divisible(n: u32, d: u32) -> bool {
        let c = compute_c(d);
        is_divisible_with_c(n, c)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes the low 17 bits of a fixed pattern one by one. It then checks the stored
    /// length, the number of backing words, and that every flag reads back unchanged.
    #[test]
    fn test_density_mask() {
        const PATTERN: u32 = 0b11111011010101010;
        const FLAG_COUNT: usize = 17;
        let expected = |idx: usize| (PATTERN & (1 << idx)) != 0;

        let mut mask = Mask::new();
        for i in 0..FLAG_COUNT {
            let value = expected(i);
            mask.push(value);
            println!("value at {i} is {value}");
        }

        assert_eq!(mask.len, 17);
        assert_eq!(mask.set.len(), 3);

        for i in 0..FLAG_COUNT {
            assert_eq!(
                expected(i),
                mask.is_set(i),
                "failed at index {i}, current set has {:?}",
                mask.set
            );
        }
    }
}