use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{
ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero,
},
source::Source,
};
/// Returns the number of scratch bytes the normalization routines of this
/// module expect for polynomials of `n` coefficients.
///
/// The amount corresponds to three `i64` buffers of `n` coefficients each:
/// the cross-base normalization splits its carry slice into `a_norm`,
/// `res_carry` and `a_carry`, each of length `n`.
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
    /// Number of `n`-coefficient `i64` buffers carved out of the scratch space.
    const SCRATCH_BUFFERS: usize = 3;
    SCRATCH_BUFFERS * n * size_of::<i64>()
}
/// Normalizes column `a_col` of `a` into column `res_col` of `res`.
///
/// This is a dispatcher: when both operands use the same limb width
/// (`res_base2k == a_base2k`) the work is delegated to
/// `vec_znx_normalize_inter_base2k`, otherwise to
/// `vec_znx_normalize_cross_base2k`.
///
/// # Arguments
/// * `res` / `res_base2k` / `res_col` - destination, its limb width in bits
///   and the column to write.
/// * `res_offset` - signed bit offset forwarded unchanged to the selected
///   routine (the tests in this file compare the result against
///   `a * 2^res_offset` reduced to `[-1/2, 1/2)`).
/// * `a` / `a_base2k` / `a_col` - source, its limb width in bits and the
///   column to read.
/// * `carry` - scratch space; both routines debug-assert that it holds at
///   least `vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()` elements.
// The argument list mirrors the backend API, hence the clippy allowance.
#[allow(clippy::too_many_arguments)]
pub fn vec_znx_normalize<R, A, ZNXARI>(
    res: &mut R,
    res_base2k: usize,
    res_offset: i64,
    res_col: usize,
    a: &A,
    a_base2k: usize,
    a_col: usize,
    carry: &mut [i64],
) where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxAddInplace
        + ZnxMulPowerOfTwoInplace
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStep
        + ZnxNormalizeFirstStep
        + ZnxExtractDigitAddMul
        + ZnxNormalizeMiddleStepInplace
        + ZnxNormalizeFinalStepInplace
        + ZnxNormalizeDigit,
{
    if res_base2k == a_base2k {
        // Identical limb width on both sides: limb-to-limb normalization.
        vec_znx_normalize_inter_base2k::<R, A, ZNXARI>(res_base2k, res, res_offset, res_col, a, a_col, carry);
    } else {
        // Different limb widths: bits have to be re-packed across limb boundaries.
        vec_znx_normalize_cross_base2k::<R, A, ZNXARI>(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
    }
}
/// Normalizes column `a_col` of `a` into column `res_col` of `res` when both
/// operands share the same limb width `base2k`.
///
/// `res_offset` is split into a whole-limb offset (floor division by
/// `base2k`) and a non-negative sub-limb shift that is handed to every
/// normalization step. Only the first `n` elements of `carry` are used.
///
/// The routine proceeds in four phases, always walking limbs from the highest
/// index towards index 0:
/// 1. limbs of `a` at indices `a_start..a_size` only feed the carry;
/// 2. limbs of `res` at indices `res_start..res_size` are zeroed;
/// 3. the overlapping limbs are normalized from `a` into `res`;
/// 4. limbs of `res` at indices `0..res_end` are zeroed and absorb the
///    remaining carry.
fn vec_znx_normalize_inter_base2k<R, A, ZNXARI>(
    base2k: usize,
    res: &mut R,
    res_offset: i64,
    res_col: usize,
    a: &A,
    a_col: usize,
    carry: &mut [i64],
) where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStepInplace
        + ZnxNormalizeMiddleStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::<i64>());
        assert_eq!(res.n(), a.n());
    }

    let n: usize = res.n();
    let res_size: usize = res.size();
    let a_size: usize = a.size();

    // Only one `n`-sized carry buffer is needed when the bases match.
    let carry: &mut [i64] = carry.split_at_mut(n).0;

    // Split the bit offset into (sub-limb shift, whole-limb offset). For a
    // negative offset that is not a multiple of `base2k`, the remainder is
    // made non-negative and the limb offset is lowered by one to compensate.
    let base: i64 = base2k as i64;
    let (lsh, limbs_offset): (i64, i64) = {
        let rem: i64 = res_offset % base;
        let quo: i64 = res_offset / base;
        if res_offset < 0 && rem != 0 {
            ((rem + base) % base, quo - 1)
        } else {
            (rem, quo)
        }
    };
    let lsh_pos: usize = lsh as usize;

    // Limb index boundaries, clamped to the sizes of `res` and `a`.
    let res_limbs: i64 = res_size as i64;
    let a_limbs: i64 = a_size as i64;
    let res_end: usize = (-limbs_offset).clamp(0, res_limbs) as usize;
    let res_start: usize = (a_limbs - limbs_offset).clamp(0, res_limbs) as usize;
    let a_end: usize = limbs_offset.clamp(0, a_limbs) as usize;
    let a_start: usize = (res_limbs + limbs_offset).clamp(0, a_limbs) as usize;

    // Phase 1: limbs of `a` that have no counterpart in `res` only contribute
    // to the carry. `a_start <= a_size` is guaranteed by the clamp above.
    if a_start < a_size {
        for limb in (a_start..a_size).rev() {
            if limb == a_size - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh_pos, a.at(a_col, limb), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh_pos, a.at(a_col, limb), carry);
            }
        }
    } else {
        // No limb was consumed above, so the carry has to start from zero.
        ZNXARI::znx_zero(carry);
    }

    // Phase 2: trailing limbs of `res` that receive nothing from `a`.
    for limb in res_start..res_size {
        ZNXARI::znx_zero(res.at_mut(res_col, limb));
    }

    // Phase 3: overlapping limbs, paired from the highest index downwards.
    let overlap: usize = a_start.saturating_sub(a_end);
    for back in 1..=overlap {
        ZNXARI::znx_normalize_middle_step(
            base2k,
            lsh_pos,
            res.at_mut(res_col, res_start - back),
            a.at(a_col, a_start - back),
            carry,
        );
    }

    // Phase 4: leading limbs of `res` are cleared and then absorb the carry;
    // limb 0 gets the final step.
    for limb in (0..res_end).rev() {
        ZNXARI::znx_zero(res.at_mut(res_col, limb));
        if limb == 0 {
            ZNXARI::znx_normalize_final_step_inplace(base2k, lsh_pos, res.at_mut(res_col, limb), carry);
        } else {
            ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh_pos, res.at_mut(res_col, limb), carry);
        }
    }
}
/// Normalizes column `a_col` of `a` (limbs of `a_base2k` bits) into column
/// `res_col` of `res` (limbs of `res_base2k` bits), taking the signed bit
/// offset `res_offset` into account.
///
/// `carry` is used as scratch space and is split into three buffers of `n`
/// coefficients each (`a_norm`, `res_carry`, `a_carry`); in debug builds its
/// length is asserted to be at least
/// `vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()`.
///
/// All limbs of `res[res_col]` are zeroed first, then limbs of `a` are
/// normalized one at a time (from the highest limb index towards limb 0) and
/// their bits are accumulated into the limbs of `res`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_normalize_cross_base2k<R, A, ZNXARI>(
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a: &A,
a_base2k: usize,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxExtractDigitAddMul
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFinalStepInplace
+ ZnxNormalizeDigit,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::<i64>());
assert_eq!(res.n(), a.n());
}
let n: usize = res.n();
let res_size: usize = res.size();
let a_size: usize = a.size();
// Carve the scratch into three `n`-sized buffers:
//  * `a_norm`:    receives the current limb of `a` after normalization,
//  * `res_carry`: carry used when normalizing limbs of `res`,
//  * `a_carry`:   carry propagated while normalizing limbs of `a`.
// Only `res_carry` is zeroed here; `a_carry` is initialized further below.
let (a_norm, carry) = carry.split_at_mut(n);
let (res_carry, a_carry) = carry[..2 * n].split_at_mut(n);
ZNXARI::znx_zero(res_carry);
// Total number of bits spanned by the limbs of `a` and of `res`.
let a_tot_bits: usize = a_size * a_base2k;
let res_tot_bits: usize = res_size * res_base2k;
// Split `res_offset` into a sub-limb shift and a whole-limb offset, both
// expressed relative to the limb width of `a`.
let mut lsh: i64 = res_offset % a_base2k as i64;
let mut limbs_offset: i64 = res_offset / a_base2k as i64;
// For a negative offset that is not a multiple of `a_base2k`, make the
// remainder non-negative and lower the limb offset by one to compensate.
if res_offset < 0 && lsh != 0 {
lsh = (lsh + a_base2k as i64) % (a_base2k as i64);
limbs_offset -= 1;
}
let lsh_pos: usize = lsh as usize;
// Bit positions derived from the whole-limb part of the offset, clamped to
// the total bit size of `res` (for `res_*_bit`) or of `a` (for `a_*_bit`).
// NOTE(review): these appear to delimit the overlap between `a` and `res`
// once the offset is applied — confirm against the design notes.
let res_end_bit: usize = (-limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; let res_start_bit: usize = (a_tot_bits as i64 - limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; let a_end_bit: usize = (limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize; let a_start_bit: usize = (res_tot_bits as i64 + limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize;
// Convert bit positions to limb indexes: `*_end` rounds down, `*_start`
// rounds up.
let res_end: usize = res_end_bit / res_base2k;
let res_start: usize = res_start_bit.div_ceil(res_base2k);
let a_end: usize = a_end_bit / a_base2k;
let a_start: usize = a_start_bit.div_ceil(a_base2k);
// Every limb of `res` is cleared up front: the loop below adds into `res`
// limbs rather than overwriting them.
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
// `res_start == 0` means `res_start_bit == 0`: there is no limb of `res` to
// fill, so the all-zero result is returned as is.
if res_start == 0 {
return;
}
// Limbs of `a` at indices `a_start..a_size` are not written to `res`; they
// are only run through the carry-only steps, from the last limb upwards.
let a_out_range: usize = a_size.saturating_sub(a_start);
for j in 0..a_out_range {
if j == 0 {
ZNXARI::znx_normalize_first_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry);
}
}
// If the loop above did not run, `a_carry` still has to be initialized.
if a_out_range == 0 {
ZNXARI::znx_zero(a_carry);
}
// Number of bits still missing to fill the current limb of `res`.
let mut res_acc_left: usize = res_base2k;
// Limb of `res` currently being filled; starts at the highest index.
let mut res_limb: usize = res_start - 1;
// Number of limbs of `a` processed by the main loop.
let mid_range: usize = a_start.saturating_sub(a_end);
'outer: for j in 0..mid_range {
// Limbs of `a` are visited from index `a_start - 1` downwards.
let a_limb: usize = a_start - j - 1;
let a_slice: &[i64] = a.at(a_col, a_limb);
// Number of bits of `a_norm` not yet moved into `res`.
let mut a_take_left: usize = a_base2k;
// Normalize the current limb of `a` into `a_norm`, updating `a_carry`.
ZNXARI::znx_normalize_middle_step(a_base2k, lsh_pos, a_norm, a_slice, a_carry);
// First iteration only: account for `a` and `res` not ending on the same
// bit boundary.
if j == 0 {
if !(a_tot_bits - a_start_bit).is_multiple_of(a_base2k) {
// `a_start_bit` falls inside this limb: scale `a_norm` by 2^-take and
// reduce the number of bits left to take accordingly.
let take: usize = (a_tot_bits - a_start_bit) % a_base2k;
ZNXARI::znx_mul_power_of_two_inplace(-(take as i64), a_norm);
a_take_left -= take;
} else if !(res_tot_bits - res_start_bit).is_multiple_of(res_base2k) {
// `res_start_bit` falls inside the first limb of `res`: that limb has
// fewer bits left to fill.
res_acc_left -= (res_tot_bits - res_start_bit) % res_base2k;
}
}
// Move bits from `a_norm` into `res` limbs until `a_norm` is exhausted.
'inner: loop {
let res_slice: &mut [i64] = res.at_mut(res_col, res_limb);
// Take at most what is left in `a_norm` and what is left to fill in the
// current limb of `res`.
let a_take: usize = a_base2k.min(a_take_left).min(res_acc_left);
if a_take != 0 {
// `scale` is the number of bits already accumulated in `res_slice`.
// NOTE(review): presumably the helper extracts an `a_take`-bit digit
// from `a_norm` and adds it to `res_slice` shifted by `scale` — confirm
// against `ZnxExtractDigitAddMul`.
let scale: usize = res_base2k - res_acc_left;
ZNXARI::znx_extract_digit_addmul(a_take, scale, res_slice, a_norm);
a_take_left -= a_take;
res_acc_left -= a_take;
}
// Either the current limb of `res` is full, or limb 0 of `a` is being
// processed.
if res_acc_left == 0 || a_limb == 0 {
// Limb 0 of `a` fully consumed: fold what remains of `a_norm` into
// `a_carry`, top up the current limb of `res` from `a_carry` if it is
// not full, normalize that limb with `res_carry`, add `a_carry` to
// `res_carry`, and stop.
if a_limb == 0 && a_take_left == 0 {
ZNXARI::znx_add_inplace(a_carry, a_norm);
if res_acc_left != 0 {
let scale: usize = res_base2k - res_acc_left;
ZNXARI::znx_extract_digit_addmul(res_acc_left, scale, res_slice, a_carry);
}
ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res_slice, res_carry);
ZNXARI::znx_add_inplace(res_carry, a_carry);
break 'outer;
}
// Limb 0 of `res` has been reached: nothing further can be written.
if res_limb == 0 {
break 'outer;
}
// Move on to the next limb of `res` (lower index).
res_acc_left += res_base2k;
res_limb -= 1;
}
// `a_norm` has no bits left to take: add what it still holds to `a_carry`
// and fetch the next limb of `a`.
if a_take_left == 0 {
ZNXARI::znx_add_inplace(a_carry, a_norm);
break 'inner;
}
}
}
// Propagate the remaining carry through limbs `res_end - 1` down to 0 of
// `res`; limb 0 gets the final step. `a_carry` is used when `a_start ==
// a_end`, `res_carry` otherwise.
if res_end != 0 {
let carry_to_use = if a_start == a_end { a_carry } else { res_carry };
for j in 0..res_end {
if j == res_end - 1 {
ZNXARI::znx_normalize_final_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use);
} else {
ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use);
}
}
}
}
/// Normalizes column `res_col` of `res` in place, with limb width `base2k`.
///
/// Limbs are processed from the last index down to index 0, threading `carry`
/// through every step with a shift of 0:
/// * the last limb gets the first step,
/// * limb 0 gets the final step,
/// * every limb in between gets the middle step.
///
/// When `res` has a single limb, only the first step is applied; when it has
/// no limb at all, nothing happens.
///
/// In debug builds, `carry` is asserted to hold at least `res.n()` elements.
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= res.n());
    }

    let limbs: usize = res.size();
    if limbs == 0 {
        return;
    }
    let last: usize = limbs - 1;

    // The last limb starts the carry chain.
    ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, last), carry);

    // Interior limbs, still walking towards index 0.
    for limb in (1..last).rev() {
        ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, limb), carry);
    }

    // Limb 0 closes the chain, unless it was also the last limb.
    if last != 0 {
        ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, 0), carry);
    }
}
/// Checks `vec_znx_normalize_cross_base2k` against arbitrary-precision
/// arithmetic for every pair of input/output limb widths in `1..=51` and a
/// ladder of bit offsets placed symmetrically around zero.
///
/// For each configuration the input is decoded, scaled by `2^offset`, reduced
/// to `[-1/2, 1/2)` and compared with the decoded output; the error must not
/// exceed `2^(-min_prec + 1)`, where `min_prec` is the smaller of the input
/// and output bit sizes.
#[test]
fn test_vec_znx_normalize_cross_base2k() {
    let n: usize = 8;
    let mut carry: Vec<i64> = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()];
    use crate::reference::znx::ZnxRef;
    use dashu_float::{FBig, ops::Abs, round::mode::HalfEven};
    let prec: usize = 128;

    // 2^exp, built from 63-bit chunks so that every factor fits in a `u64`.
    let pow2 = |exp: u32| -> FBig<HalfEven> {
        let mut result = FBig::<HalfEven>::ONE;
        let chunk = FBig::<HalfEven>::from(1u64 << 63);
        let rem = exp % 63;
        let full = exp / 63;
        for _ in 0..full {
            result *= chunk.clone();
        }
        result * FBig::from(1u64 << rem)
    };

    // Maps `x` to its representative in [-1/2, 1/2) modulo 1.
    let reduce = |x: FBig<HalfEven>| -> FBig<HalfEven> {
        let fl = x.floor();
        let mut r = x - fl;
        if r >= FBig::<HalfEven>::from(1u64) / FBig::from(2u64) {
            r -= FBig::<HalfEven>::from(1u64);
        }
        r
    };

    for in_base2k in 1..=51 {
        for out_base2k in 1..=51 {
            for offset in [
                -(prec as i64),
                -(prec as i64 - 1),
                -(prec as i64 - in_base2k as i64),
                -(in_base2k as i64 + 1),
                // Shift right by exactly one input limb. This entry previously
                // duplicated the positive `in_base2k` case below, leaving the
                // negative one untested.
                -(in_base2k as i64),
                -(in_base2k as i64 - 1),
                0,
                (in_base2k as i64 - 1),
                in_base2k as i64,
                (in_base2k as i64 + 1),
                (prec as i64 - in_base2k as i64),
                (prec - 1) as i64,
                prec as i64,
            ] {
                // Fresh, identically seeded source for every configuration.
                let mut source: Source = Source::new([1u8; 32]);

                // Smallest number of limbs covering `prec` bits on the input
                // side, and enough output limbs to cover the input bit size.
                let in_size: usize = prec.div_ceil(in_base2k);
                let in_prec: u32 = (in_size * in_base2k) as u32;
                let out_size: usize = (in_prec as usize).div_ceil(out_base2k);
                let min_prec: u32 = (in_size * in_base2k).min(out_size * out_base2k) as u32;

                let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, in_size);
                want.fill_uniform(60, &mut source);

                // `have` is filled with random data on purpose: the routine
                // must overwrite whatever the destination held.
                let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, out_size);
                have.fill_uniform(60, &mut source);

                vec_znx_normalize_cross_base2k::<_, _, ZnxRef>(&mut have, out_base2k, offset, 0, &want, in_base2k, 0, &mut carry);

                let mut data_have: Vec<FBig<HalfEven>> = (0..n).map(|_| FBig::ZERO).collect();
                let mut data_want: Vec<FBig<HalfEven>> = (0..n).map(|_| FBig::ZERO).collect();
                have.decode_vec_float(out_base2k, 0, &mut data_have);
                want.decode_vec_float(in_base2k, 0, &mut data_want);

                // Apply the offset to the reference values, then reduce them.
                let scale: FBig<HalfEven> = pow2(offset.unsigned_abs() as u32);
                if offset > 0 {
                    for x in &mut data_want {
                        *x = reduce(x.clone() * scale.clone());
                    }
                } else if offset < 0 {
                    for x in &mut data_want {
                        *x = reduce(x.clone() / scale.clone());
                    }
                } else {
                    for x in &mut data_want {
                        *x = reduce(x.clone());
                    }
                }

                // Bring the decoded output into the same [-1/2, 1/2) window.
                let half: FBig<HalfEven> = FBig::from(1u64) / FBig::from(2u64);
                let neg_half: FBig<HalfEven> = -FBig::from(1u64) / FBig::from(2u64);
                for x in &mut data_have {
                    if *x >= half {
                        *x = x.clone() - FBig::from(1u64);
                    } else if *x < neg_half {
                        *x = x.clone() + FBig::from(1u64);
                    }
                }

                for i in 0..n {
                    let err = (data_have[i].clone() - data_want[i].clone()).abs();
                    // A zero (or unconvertible) error is floored at 1e-60 so
                    // that `log2` stays finite.
                    let err_log2: f64 = f64::try_from(err).unwrap_or(0.0).max(1e-60_f64).log2();
                    assert!(err_log2 <= -(min_prec as f64) + 1.0, "{} {}", err_log2, -(min_prec as f64))
                }
            }
        }
    }
}
/// Checks `vec_znx_normalize_inter_base2k` against arbitrary-precision
/// arithmetic for every limb width in `1..=51` and for bit offsets sweeping
/// `-128..=128` in steps of `base2k + 1`.
///
/// The decoded input is scaled by `2^offset`, reduced to `[-1/2, 1/2)` and
/// compared with the decoded output; the error must not exceed
/// `2^(-size * base2k)`.
#[test]
fn test_vec_znx_normalize_inter_base2k() {
    use crate::reference::znx::ZnxRef;
    use dashu_float::{FBig, ops::Abs, round::mode::HalfEven};

    let n: usize = 8;
    let mut carry: Vec<i64> = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()];

    // A single source is shared by all configurations.
    let mut source: Source = Source::new([1u8; 32]);
    let prec: usize = 128;
    let offset_range: i64 = prec as i64;

    // 2^exp, built from 63-bit chunks so that every factor fits in a `u64`.
    let pow2 = |exp: u32| -> FBig<HalfEven> {
        let chunk = FBig::<HalfEven>::from(1u64 << 63);
        let mut acc = FBig::<HalfEven>::ONE;
        for _ in 0..(exp / 63) {
            acc *= chunk.clone();
        }
        acc * FBig::from(1u64 << (exp % 63))
    };

    // Maps `x` to its representative in [-1/2, 1/2) modulo 1.
    let reduce = |x: FBig<HalfEven>| -> FBig<HalfEven> {
        let floor = x.floor();
        let mut frac = x - floor;
        if frac >= FBig::<HalfEven>::from(1u64) / FBig::from(2u64) {
            frac -= FBig::<HalfEven>::from(1u64);
        }
        frac
    };

    // Window bounds used to re-center the decoded output.
    let half: FBig<HalfEven> = FBig::from(1u64) / FBig::from(2u64);
    let neg_half: FBig<HalfEven> = -FBig::from(1u64) / FBig::from(2u64);

    for base2k in 1..=51 {
        // Smallest number of limbs covering `prec` bits, and the bit size
        // those limbs actually span.
        let size: usize = prec.div_ceil(base2k);
        let out_prec: u32 = (size * base2k) as u32;

        for offset in (-offset_range..=offset_range).step_by(base2k + 1) {
            let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
            want.fill_uniform(60, &mut source);

            // `have` starts out random: the routine must overwrite it.
            let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
            have.fill_uniform(60, &mut source);

            vec_znx_normalize_inter_base2k::<_, _, ZnxRef>(base2k, &mut have, offset, 0, &want, 0, &mut carry);

            let mut data_have: Vec<FBig<HalfEven>> = (0..n).map(|_| FBig::ZERO).collect();
            let mut data_want: Vec<FBig<HalfEven>> = (0..n).map(|_| FBig::ZERO).collect();
            have.decode_vec_float(base2k, 0, &mut data_have);
            want.decode_vec_float(base2k, 0, &mut data_want);

            // Apply the offset to the reference values, then reduce them.
            let scale: FBig<HalfEven> = pow2(offset.unsigned_abs() as u32);
            for x in data_want.iter_mut() {
                let shifted: FBig<HalfEven> = if offset > 0 {
                    x.clone() * scale.clone()
                } else if offset < 0 {
                    x.clone() / scale.clone()
                } else {
                    x.clone()
                };
                *x = reduce(shifted);
            }

            // Bring the decoded output into the same [-1/2, 1/2) window.
            for x in data_have.iter_mut() {
                if *x >= half {
                    *x = x.clone() - FBig::from(1u64);
                } else if *x < neg_half {
                    *x = x.clone() + FBig::from(1u64);
                }
            }

            for (got, expected) in data_have.iter().zip(data_want.iter()) {
                let err = (got.clone() - expected.clone()).abs();
                // A zero (or unconvertible) error is floored at 1e-60 so that
                // `log2` stays finite.
                let err_log2: f64 = f64::try_from(err).unwrap_or(0.0).max(1e-60_f64).log2();
                assert!(err_log2 <= -(out_prec as f64), "{} {}", err_log2, -(out_prec as f64))
            }
        }
    }
}
/// Registers the Criterion benchmark group `vec_znx_normalize::{label}` for
/// backend `B`.
///
/// Each case is described by `[log2(n), cols, size]`. One measured iteration
/// normalizes every column of a uniformly filled `VecZnx` into a second one,
/// with `base2k = 50` on both sides and a zero offset.
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    /// Builds the closure executed on every benchmark iteration; all
    /// allocations and random fills happen here, outside the measured code.
    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let [log_n, cols, size] = params;
        let n: usize = 1 << log_n;

        let module: Module<B> = Module::<B>::new(n as u64);
        let base2k: usize = 50;
        let res_offset: i64 = 0;

        let mut source: Source = Source::new([0u8; 32]);
        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for col in 0..cols {
                module.vec_znx_normalize(&mut res, base2k, res_offset, col, &a, base2k, col, scratch.borrow());
            }
            black_box(());
        }
    }

    let mut group = c.benchmark_group(format!("vec_znx_normalize::{label}"));

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let [log_n, cols, size] = params;
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << log_n, cols, size));
        let mut run = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut run));
    }

    group.finish();
}
/// Registers the Criterion benchmark group
/// `vec_znx_normalize_inplace::{label}` for backend `B`.
///
/// Each case is described by `[log2(n), cols, size]`. One measured iteration
/// normalizes every column of a uniformly filled `VecZnx` in place with
/// `base2k = 50`. The same buffer is reused across iterations.
pub fn bench_vec_znx_normalize_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    /// Builds the closure executed on every benchmark iteration; all
    /// allocations and random fills happen here, outside the measured code.
    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let [log_n, cols, size] = params;
        let n: usize = 1 << log_n;

        let module: Module<B> = Module::<B>::new(n as u64);
        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);
        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        a.fill_uniform(50, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for col in 0..cols {
                module.vec_znx_normalize_inplace(base2k, &mut a, col, scratch.borrow());
            }
            black_box(());
        }
    }

    let mut group = c.benchmark_group(format!("vec_znx_normalize_inplace::{label}"));

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let [log_n, cols, size] = params;
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << log_n, cols, size));
        let mut run = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut run));
    }

    group.finish();
}