use crate::base::{bit::bit_mask_utils::make_bit_mask, scalar::Scalar};
use ark_std::iterable::Iterable;
use bit_iter::BitIter;
use bnum::types::U256;
use core::{convert::Into, ops::Shl};
use serde::{Deserialize, Serialize};
/// Per-byte summary of the 256-bit masks of a column of values.
///
/// Each of the 32 bytes of the masks is either *varying* (at least two values differ in that
/// byte) or *constant* (every value shares it). Built with [`ByteDistribution::new`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ByteDistribution {
    /// Bit `i` is set when byte `i` of the masks varies across the column.
    vary_mask: u32,
    /// The shared value of every constant byte, kept in place; varying bytes are zero.
    /// Stored as the four `u64` limbs of a [`U256`] (limb order is whatever bnum's
    /// `U256` <-> `[u64; 4]` conversion uses — presumably least-significant first, TODO confirm).
    constant_mask: [u64; 4],
}
impl ByteDistribution {
    /// Builds the byte distribution of `data`.
    ///
    /// Each element is converted into the scalar type `S` and then into a 256-bit mask with
    /// [`make_bit_mask`]. The 32 bytes of those masks are then classified:
    ///
    /// * a byte is *varying* if at least two masks disagree on it; bit `i` of `vary_mask` is set
    ///   for varying byte `i`, and that byte is left zero in `constant_mask`;
    /// * otherwise the byte is *constant* and its shared value is kept, in place, in
    ///   `constant_mask`.
    ///
    /// An empty `data` has no varying bytes and an all-zero constant mask.
    pub fn new<S: Scalar, T: Into<S> + Copy>(data: &[T]) -> Self {
        let mut bit_masks = data.iter().copied().map(Into::<S>::into).map(make_bit_mask);
        // Compare every mask against the first one in a single pass, so each element is
        // converted exactly once. For empty input `first` is zero and no bits differ, which
        // yields the all-constant, all-zero distribution.
        let first = bit_masks.next().unwrap_or(U256::ZERO);
        let differing_bits = bit_masks.fold(U256::ZERO, |acc, bit_mask| acc | (bit_mask ^ first));
        let (vary_mask, constant_mask) = (0u8..32).fold(
            (0u32, U256::ZERO),
            |(vary_mask, constant_mask), u| {
                let shifted_max_byte = U256::from(255u8).shl(u * 8);
                if (differing_bits & shifted_max_byte) == U256::ZERO {
                    // Every mask agrees on byte `u`: keep the shared value in place.
                    (vary_mask, constant_mask | (first & shifted_max_byte))
                } else {
                    (vary_mask | (1u32 << u), constant_mask)
                }
            },
        );
        Self {
            vary_mask,
            constant_mask: constant_mask.into(),
        }
    }

    /// Returns the bit offset (`8 * byte_index`) of every varying byte, in ascending order.
    ///
    /// Despite the name, the items are bit offsets rather than byte indices: varying byte `i`
    /// is reported as `8 * i`.
    ///
    /// # Panics
    ///
    /// Does not panic in practice: `vary_mask` is a `u32`, so every set-bit position is below
    /// 32 and `8 * position` is at most 248, which always fits in a `u8`.
    pub fn varying_byte_indices(&self) -> impl Iterator<Item = u8> + '_ {
        BitIter::from(self.vary_mask)
            .iter()
            .map(|u| u8::try_from(u * 8).expect("bit offsets of a u32 mask fit in u8"))
    }

    /// Returns the number of varying bytes (at most 32).
    ///
    /// # Panics
    ///
    /// Does not panic in practice: the popcount of a `u32` is at most 32, which fits in a `u8`.
    #[cfg_attr(not(test), expect(dead_code))]
    pub fn varying_byte_count(&self) -> u8 {
        self.vary_mask
            .count_ones()
            .try_into()
            .expect("popcount of a u32 fits in u8")
    }

    /// Returns the constant bytes as a [`U256`]; varying bytes are zero.
    #[cfg_attr(not(test), expect(dead_code))]
    pub fn constant_mask(&self) -> U256 {
        U256::from(self.constant_mask)
    }
}
#[cfg(test)]
mod tests {
    use super::ByteDistribution;
    use crate::base::scalar::{test_scalar::TestScalar, ScalarExt};
    use bnum::types::U256;
    use core::ops::{Neg, Shl, Shr};
    use itertools::Itertools;
    #[test]
    fn we_can_get_byte_distribution_from_empty_column() {
        let byte_distribution = ByteDistribution::new::<TestScalar, TestScalar>(&[]);
        assert_eq!(byte_distribution.vary_mask, 0);
        assert_eq!(byte_distribution.constant_mask(), U256::ZERO);
        assert_eq!(byte_distribution.varying_byte_count(), 0);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            Vec::<u8>::new()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_single_positive_value_column() {
        let value = U256::from(u64::MAX);
        let column = [value].map(TestScalar::from_wrapping);
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, 0);
        assert_eq!(
            byte_distribution.constant_mask(),
            value | U256::ONE.shl(255)
        );
        assert_eq!(byte_distribution.varying_byte_count(), 0);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            Vec::<u8>::new()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_single_negative_value_column() {
        let value = U256::from(u64::MAX);
        let const_scalar = -TestScalar::from_wrapping(value);
        let column = [const_scalar];
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, 0);
        assert_eq!(
            byte_distribution.constant_mask(),
            U256::ONE.shl(255) - value
        );
        assert_eq!(byte_distribution.varying_byte_count(), 0);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            Vec::<u8>::new()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_constant_positive_column() {
        let value = U256::from(u64::MAX);
        let column = [value; 3].map(TestScalar::from_wrapping);
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, 0);
        assert_eq!(
            byte_distribution.constant_mask(),
            value | U256::ONE.shl(255)
        );
        assert_eq!(byte_distribution.varying_byte_count(), 0);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            Vec::<u8>::new()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_constant_negative_column() {
        let value = U256::from(u64::MAX);
        let const_scalar = -TestScalar::from_wrapping(value);
        let column = [const_scalar; 3];
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, 0);
        assert_eq!(
            byte_distribution.constant_mask(),
            U256::ONE.shl(255) - value
        );
        assert_eq!(byte_distribution.varying_byte_count(), 0);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            Vec::<u8>::new()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_bitwise_inverse_column() {
        let value = U256::from(u64::MAX);
        let negative_scalar = -TestScalar::from_wrapping(value + U256::ONE);
        let positive_scalar = TestScalar::from_wrapping(value);
        let column = [positive_scalar, negative_scalar];
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, u32::MAX);
        assert_eq!(byte_distribution.constant_mask(), U256::ZERO);
        assert_eq!(byte_distribution.varying_byte_count(), 32);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            (0u8..32).map(|i| i * 8).collect_vec()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_variable_positive_column() {
        // Every value has 149 (0x95) in byte 1; bytes 0, 2 and 3 differ.
        let column = [
            1_974_179_072u32,
            2_518_259_060,
            1_394_578_845,
            1_000_510_769,
            1_675_728_301,
        ]
        .map(TestScalar::from);
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, 13);
        assert_eq!(
            byte_distribution.constant_mask(),
            U256::from(149u8).shl(8) | U256::ONE.shl(255)
        );
        assert_eq!(byte_distribution.varying_byte_count(), 3);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            [0u8, 16, 24].into_iter().collect_vec()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_variable_negative_column() {
        let column = [
            1_974_179_073u32,
            2_518_259_061,
            1_394_578_846,
            1_000_510_770,
            1_675_728_302,
        ]
        .map(TestScalar::from)
        .map(Neg::neg);
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, 13);
        assert_eq!(
            byte_distribution.constant_mask(),
            U256::from(106u8).shl(8) | U256::MAX.shr(33u8).shl(32)
        );
        assert_eq!(byte_distribution.varying_byte_count(), 3);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            [0u8, 16, 24].into_iter().collect_vec()
        );
    }
    #[test]
    fn we_can_get_byte_distribution_from_variable_column() {
        let negative_column = [
            1_974_179_073u32,
            2_518_259_061,
            1_394_578_846,
            1_000_510_770,
            1_675_728_302,
        ]
        .map(TestScalar::from)
        .map(Neg::neg);
        let positive_column = [
            1_974_179_072u32,
            2_518_259_060,
            1_394_578_845,
            1_000_510_769,
            1_675_728_301,
        ]
        .map(TestScalar::from);
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(
            &negative_column
                .iter()
                .chain(positive_column.iter())
                .collect_vec(),
        );
        assert_eq!(byte_distribution.vary_mask, u32::MAX);
        assert_eq!(byte_distribution.constant_mask(), U256::ZERO);
        assert_eq!(byte_distribution.varying_byte_count(), 32);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            (0u8..32).map(|i| i * 8).collect_vec()
        );
    }
    #[test]
    fn we_can_detect_a_byte_that_differs_only_in_a_later_element() {
        // 261 = 0x105: byte 0 (5) matches the other elements, byte 1 (1 vs 0) does not,
        // and the difference only appears in the last element.
        let column = [5u32, 5, 261].map(TestScalar::from);
        let byte_distribution = ByteDistribution::new::<TestScalar, _>(&column);
        assert_eq!(byte_distribution.vary_mask, 2);
        assert_eq!(
            byte_distribution.constant_mask(),
            U256::from(5u8) | U256::ONE.shl(255)
        );
        assert_eq!(byte_distribution.varying_byte_count(), 1);
        assert_eq!(
            byte_distribution.varying_byte_indices().collect_vec(),
            vec![8u8]
        );
    }
}