use super::bit_mask_utils::{is_bit_mask_negative_representation, make_bit_mask};
use crate::base::{
proof::ProofError,
scalar::{Scalar, ScalarExt},
standard_serializations::limbs::{deserialize_to_limbs, serialize_limbs},
};
use ark_std::iterable::Iterable;
use bit_iter::BitIter;
use bnum::types::U256;
use core::{
convert::Into,
ops::{Shl, Shr},
};
use itertools::Itertools;
#[cfg(feature = "rayon")]
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
/// Summary of the bit decomposition of a sequence of scalar values,
/// stored as two 256-bit masks.
///
/// Each mask is kept as four `u64` limbs; limb `i` holds bit indices
/// `64 * i .. 64 * (i + 1)` (see [`BitDistribution::vary_mask_iter`]).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct BitDistribution {
    /// Bits that are set in some of the (sign-adjusted) bit masks of the data
    /// but not in all of them. Bit 255 additionally flags that the sign
    /// representation varies across the data.
    #[serde(
        serialize_with = "serialize_limbs",
        deserialize_with = "deserialize_to_limbs"
    )]
    pub(crate) vary_mask: [u64; 4],
    /// Bits that are set in *every* (sign-adjusted) bit mask of the data —
    /// the AND of all adjusted masks computed in [`BitDistribution::new`].
    #[serde(
        serialize_with = "serialize_limbs",
        deserialize_with = "deserialize_to_limbs"
    )]
    pub(crate) leading_bit_mask: [u64; 4],
}
/// Errors that can arise when working with a [`BitDistribution`].
#[derive(Debug)]
pub enum BitDistributionError {
    /// No leading bit is available even though the lead bit is marked as
    /// variable. Treated as an invariant violation: converting this variant
    /// into a `ProofError` panics (see the `From` impl below).
    NoLeadBit,
    /// The bit decomposition failed verification; converts to
    /// `ProofError::VerificationError`.
    Verification,
}
impl From<BitDistributionError> for ProofError {
fn from(err: BitDistributionError) -> Self {
match err {
BitDistributionError::NoLeadBit => {
panic!("No lead bit available despite variable lead bit.")
}
BitDistributionError::Verification => ProofError::VerificationError {
error: "invalid bit_decomposition",
},
}
}
}
impl BitDistribution {
    /// Computes the bit distribution of `data`.
    ///
    /// Each element is converted to a scalar `S` and then to a 256-bit mask
    /// via [`make_bit_mask`]. Masks in the negative representation have their
    /// low 255 bits flipped (`^ U256::MAX >> 1`) before aggregation, so both
    /// signs aggregate on the same footing.
    ///
    /// - `leading_bit_mask` receives the bits set in *every* adjusted mask
    ///   (the AND of all adjusted masks).
    /// - `vary_mask` receives the bits that are set in some adjusted masks
    ///   but not all; bit 255 is additionally set iff the inputs do not all
    ///   share the same sign representation.
    ///
    /// For empty `data` both aggregates default to `U256::MAX`, so
    /// `vary_mask` comes out all zero.
    pub fn new<S: Scalar, T: Into<S> + Clone + Sync>(data: &[T]) -> Self {
        // Parallel fold of (AND of adjusted masks, AND of inverted adjusted
        // masks). The reduce identity `(U256::MAX, U256::MAX)` doubles as the
        // empty-input result, matching the sequential branch's `unwrap_or`.
        #[cfg(feature = "rayon")]
        let (sign_mask, inverse_sign_mask) = data
            .par_iter()
            .map(|item| {
                let scalar: S = item.clone().into();
                let bit_mask = make_bit_mask(scalar);
                // Flip the low 255 bits for negative representations so the
                // aggregation treats both signs uniformly.
                let adjusted_mask = if is_bit_mask_negative_representation(bit_mask) {
                    bit_mask ^ U256::MAX.shr(1)
                } else {
                    bit_mask
                };
                (adjusted_mask, !adjusted_mask)
            })
            .reduce(
                || (U256::MAX, U256::MAX),
                |(acc_sign, acc_inverse), (mask, inverse_mask)| {
                    (acc_sign & mask, acc_inverse & inverse_mask)
                },
            );
        // Lazy iterator over the raw (unadjusted) masks. Used below for the
        // sign-variation check, and cloned by the sequential branch.
        let bit_masks = data.iter().cloned().map(Into::<S>::into).map(make_bit_mask);
        // Sequential fallback: same fold as the rayon branch above.
        #[cfg(not(feature = "rayon"))]
        let (sign_mask, inverse_sign_mask) = bit_masks
            .clone()
            .map(|bit_mask| {
                let adjusted_mask = if is_bit_mask_negative_representation(bit_mask) {
                    bit_mask ^ U256::MAX.shr(1)
                } else {
                    bit_mask
                };
                (adjusted_mask, !adjusted_mask)
            })
            .reduce(|acc, item| (acc.0 & item.0, acc.1 & item.1))
            .unwrap_or((U256::MAX, U256::MAX));
        // Bit 255 records whether the sign representation varies across the
        // data. `all_equal` is true on an empty iterator, so empty data
        // yields a zero bit here.
        let vary_mask_bit = U256::from(
            !bit_masks
                .map(is_bit_mask_negative_representation)
                .all_equal(),
        ) << 255;
        // A bit varies iff it is neither always set nor always clear.
        let vary_mask: U256 = !(sign_mask | inverse_sign_mask) | vary_mask_bit;
        Self {
            leading_bit_mask: sign_mask.into(),
            vary_mask: vary_mask.into(),
        }
    }
    /// Returns `vary_mask` as a [`U256`].
    pub fn vary_mask(&self) -> U256 {
        U256::from(self.vary_mask)
    }
    /// Returns `leading_bit_mask` as a [`U256`] with bit 255 forced on.
    pub fn leading_bit_mask(&self) -> U256 {
        U256::from(self.leading_bit_mask) | (U256::ONE.shl(255))
    }
    /// Returns the low 255 bits that neither vary nor belong to
    /// [`Self::leading_bit_mask`]. For a distribution satisfying
    /// [`Self::is_valid`], these are exactly the bits that are clear in every
    /// adjusted mask of the data.
    pub fn leading_bit_inverse_mask(&self) -> U256 {
        (!self.vary_mask() ^ self.leading_bit_mask()) & U256::MAX.shr(1)
    }
    /// Number of set bits in `vary_mask` (including bit 255 if set).
    pub fn num_varying_bits(&self) -> usize {
        self.vary_mask().count_ones() as usize
    }
    /// If the leading (sign) bit is constant across the data, returns its
    /// evaluation: `Some(S::ZERO)` when the constant bit is 0, and
    /// `Some(chi_eval)` when it is 1. Returns `None` when bit 255 of
    /// `vary_mask` indicates the leading bit varies.
    pub fn try_constant_leading_bit_eval<S: ScalarExt>(&self, chi_eval: S) -> Option<S> {
        if U256::from(self.vary_mask) & (U256::ONE.shl(255)) != U256::ZERO {
            None
        } else if U256::from(self.leading_bit_mask) & U256::ONE.shl(255) == U256::ZERO {
            Some(S::ZERO)
        } else {
            Some(chi_eval)
        }
    }
    /// Consistency check: no low (non-sign) bit may be flagged both as
    /// varying and as constantly set.
    pub fn is_valid(&self) -> bool {
        (self.vary_mask() & self.leading_bit_mask()) & U256::MAX.shr(1) == U256::ZERO
    }
    /// Checks that bits 128..=254 of the inverse mask are all set, i.e. those
    /// bits are constantly clear across the data.
    /// NOTE(review): the exact value-range this implies depends on
    /// `make_bit_mask`'s encoding — confirm against its definition.
    pub fn is_within_acceptable_range(&self) -> bool {
        (self.leading_bit_inverse_mask() >> 128) == (U256::MAX.shr(129))
    }
    /// Iterates over the indices (0..=255, ascending) of the set bits of
    /// `vary_mask`, walking the four limbs from least significant upward.
    #[expect(clippy::missing_panics_doc)]
    pub fn vary_mask_iter(&self) -> impl Iterator<Item = u8> + '_ {
        (0..4).flat_map(|i| {
            BitIter::from(self.vary_mask[i])
                .iter()
                // i * 64 + pos is at most 3 * 64 + 63 = 255, so the u8
                // conversion cannot actually fail.
                .map(move |pos| u8::try_from(i * 64 + pos).expect("index greater than 255"))
        })
    }
}