use poulpy_hal::{
api::{
CnvPVecBytesOf, Convolution, ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace,
VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply,
VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxBig},
reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
};
use crate::{
GGLWEProduct, ScratchTakeCore,
layouts::{
Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWEPlaintext, GLWETensor, GLWETensorKeyPrepared, GLWEToMut, GLWEToRef, LWEInfos,
TorusPrecision,
},
};
/// Multiplication of a GLWE ciphertext by a vector of small integer constants `b`.
///
/// `res_offset` controls how the product is shifted when it is written into `res`
/// (implementations split it into a limb part and a residual using `base2k`).
pub trait GLWEMulConst<BE: Backend> {
    /// Number of scratch bytes required by [`GLWEMulConst::glwe_mul_const`] /
    /// [`GLWEMulConst::glwe_mul_const_inplace`] for the given layouts and `b.len() == b_size`.
    fn glwe_mul_const_tmp_bytes<R: GLWEInfos, A: GLWEInfos>(&self, res: &R, res_offset: usize, a: &A, b_size: usize) -> usize;

    /// Writes `a * b` into `res`, every column of `a` being multiplied by the same constant `b`.
    fn glwe_mul_const<R: DataMut, A: DataRef>(
        &self,
        res: &mut GLWE<R>,
        res_offset: usize,
        a: &GLWE<A>,
        b: &[i64],
        scratch: &mut Scratch<BE>,
    );

    /// In-place variant of [`GLWEMulConst::glwe_mul_const`]: `res <- res * b`.
    fn glwe_mul_const_inplace<R: DataMut>(&self, res: &mut GLWE<R>, res_offset: usize, b: &[i64], scratch: &mut Scratch<BE>);
}
impl<BE: Backend> GLWEMulConst<BE> for Module<BE>
where
    Self: Convolution<BE> + VecZnxBigBytesOf + VecZnxBigNormalize<BE> + VecZnxBigNormalizeTmpBytes,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    /// Scratch bytes for [`GLWEMulConst::glwe_mul_const`]: a one-column `VecZnxBig`
    /// accumulator (sized for `res`'s limbs re-expressed in `a`'s `base2k`) plus the
    /// larger of the convolution and big-normalization temporaries.
    ///
    /// # Panics
    /// If the ring degree of `res` or `a` differs from the module's.
    fn glwe_mul_const_tmp_bytes<R, A>(&self, res: &R, res_offset: usize, a: &A, b_size: usize) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
    {
        let module_n: u32 = self.n() as u32;
        assert_eq!(module_n, res.n());
        assert_eq!(module_n, a.n());
        let a_base2k: usize = a.base2k().as_usize();
        let res_bits: usize = res.size() * res.base2k().as_usize();
        let big_size: usize = res_bits.div_ceil(a_base2k);
        let accumulator: usize = self.bytes_of_vec_znx_big(1, big_size);
        let convolution: usize = self.cnv_by_const_apply_tmp_bytes(big_size, res_offset, a.size(), b_size);
        let normalization: usize = self.vec_znx_big_normalize_tmp_bytes();
        accumulator + normalization.max(convolution)
    }

    /// Computes `res = a * b` column by column: each column of `a` is convolved with
    /// `b` into a big accumulator, then normalized from `a`'s into `res`'s `base2k`.
    ///
    /// # Panics
    /// If the ranks of `res` and `a` differ or `scratch` is too small.
    fn glwe_mul_const<R, A>(&self, res: &mut GLWE<R>, res_offset: usize, a: &GLWE<A>, b: &[i64], scratch: &mut Scratch<BE>)
    where
        R: DataMut,
        A: DataRef,
    {
        assert_eq!(res.rank(), a.rank());
        let required: usize = self.glwe_mul_const_tmp_bytes(res, res_offset, a, b.len());
        assert!(
            scratch.available() >= required,
            "scratch.available(): {} < GLWEMulConst::glwe_mul_const_tmp_bytes: {}",
            scratch.available(),
            required
        );

        let a_base2k: usize = a.base2k().as_usize();
        let res_base2k: usize = res.base2k().as_usize();

        // Limb shift for the convolution and signed residual shift for the normalization.
        let (offset_limbs, offset_rem): (usize, i64) = if res_offset >= a_base2k {
            ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
        } else {
            (0, -((a_base2k - (res_offset % a_base2k)) as i64))
        };

        let big_size: usize = res
            .k()
            .as_usize()
            .div_ceil(a_base2k)
            .min(a.size() + b.len() - offset_limbs);

        let (mut acc, scratch_rest) = scratch.take_vec_znx_big(self, 1, big_size);
        let col_count: usize = res.rank().as_usize() + 1;
        for col in 0..col_count {
            self.cnv_by_const_apply(&mut acc, offset_limbs, 0, a.data(), col, b, scratch_rest);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, offset_rem, col, &acc, a_base2k, 0, scratch_rest);
        }
    }

    /// In-place variant of [`GLWEMulConst::glwe_mul_const`]; `res` is both input and output.
    ///
    /// # Panics
    /// If `scratch` is too small.
    fn glwe_mul_const_inplace<R>(&self, res: &mut GLWE<R>, res_offset: usize, b: &[i64], scratch: &mut Scratch<BE>)
    where
        R: DataMut,
    {
        let required: usize = {
            let view: &GLWE<&[u8]> = &res.to_ref();
            self.glwe_mul_const_tmp_bytes(view, res_offset, view, b.len())
        };
        assert!(
            scratch.available() >= required,
            "scratch.available(): {} < GLWEMulConst::glwe_mul_const_tmp_bytes: {}",
            scratch.available(),
            required
        );

        let base2k: usize = res.base2k().as_usize();
        let (offset_limbs, offset_rem): (usize, i64) = if res_offset >= base2k {
            ((res_offset / base2k).saturating_sub(1), (res_offset % base2k) as i64)
        } else {
            (0, -((base2k - (res_offset % base2k)) as i64))
        };

        let (mut acc, scratch_rest) = scratch.take_vec_znx_big(self, 1, res.size());
        let col_count: usize = res.rank().as_usize() + 1;
        for col in 0..col_count {
            // Column `col` is read before it is overwritten; later columns are untouched so far.
            self.cnv_by_const_apply(&mut acc, offset_limbs, 0, res.data(), col, b, scratch_rest);
            self.vec_znx_big_normalize(res.data_mut(), base2k, offset_rem, col, &acc, base2k, 0, scratch_rest);
        }
    }
}
impl<BE: Backend> GLWEMulPlain<BE> for Module<BE>
where
    Self: Sized
        + ModuleN
        + CnvPVecBytesOf
        + VecZnxDftBytesOf
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigNormalize<BE>
        + Convolution<BE>
        + VecZnxBigNormalizeTmpBytes,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    /// Scratch bytes required by [`GLWEMulPlain::glwe_mul_plain`] (and by
    /// [`GLWEMulPlain::glwe_mul_plain_inplace`], which calls this with `a = res`).
    ///
    /// The prepared left operand (`cols` columns of `a`) and right operand (one
    /// column of `b`) stay live for the whole call. On top of them come either the
    /// preparation temporaries or, per output column, a one-column DFT buffer plus
    /// the larger of the convolution / normalization temporaries.
    ///
    /// # Panics
    /// If a ring degree differs from the module's or `a` and `b` use different `base2k`.
    fn glwe_mul_plain_tmp_bytes<R, A, B>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GLWEInfos,
    {
        assert_eq!(self.n() as u32, res.n());
        assert_eq!(self.n() as u32, a.n());
        assert_eq!(self.n() as u32, b.n());
        let ab_base2k: Base2K = a.base2k();
        assert_eq!(b.base2k(), ab_base2k);
        let base2k: usize = ab_base2k.as_usize();
        let cols: usize = res.rank().as_usize() + 1;
        let a_size: usize = a.size();
        let b_size: usize = b.size();
        let res_size: usize = res.size();
        let lvl_0: usize = self.bytes_of_cnv_pvec_left(cols, a_size) + self.bytes_of_cnv_pvec_right(1, b_size);
        // `a` is prepared on the left and `b` on the right, so each preparation is
        // sized after its own operand.
        let lvl_1: usize = self
            .cnv_prepare_left_tmp_bytes(a_size, a_size)
            .max(self.cnv_prepare_right_tmp_bytes(b_size, b_size));
        let lvl_2_cnv_apply: usize = self.cnv_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size);
        // Same limb offset as the one computed in `glwe_mul_plain`, so the DFT buffer
        // sized here is exactly the one taken from scratch there (never one limb short).
        let res_offset_hi: usize = (res_offset / base2k).saturating_sub(1);
        let res_dft_size: usize = res
            .k()
            .as_usize()
            .div_ceil(base2k)
            .min(a_size + b_size - res_offset_hi);
        let lvl_2_res_dft: usize = self.bytes_of_vec_znx_dft(1, res_dft_size);
        let lvl_2_norm: usize = self.vec_znx_big_normalize_tmp_bytes();
        let lvl_2: usize = lvl_2_res_dft + lvl_2_cnv_apply.max(lvl_2_norm);
        lvl_0 + lvl_1.max(lvl_2)
    }

    /// Computes `res = a * b` where `b` is a plaintext: every column of `a` is
    /// convolved (in the DFT domain) with the single column of `b`, then normalized
    /// from `a`'s into `res`'s `base2k`.
    ///
    /// # Panics
    /// If the ranks of `res` and `a` differ, `a` and `b` use different `base2k`,
    /// or `scratch` is too small.
    fn glwe_mul_plain<R, A, B>(
        &self,
        res: &mut GLWE<R>,
        res_offset: usize,
        a: &GLWE<A>,
        b: &GLWEPlaintext<B>,
        scratch: &mut Scratch<BE>,
    ) where
        R: DataMut,
        A: DataRef,
        B: DataRef,
    {
        assert_eq!(res.rank(), a.rank());
        assert!(
            scratch.available() >= self.glwe_mul_plain_tmp_bytes(res, res_offset, a, b),
            "scratch.available(): {} < GLWEMulPlain::glwe_mul_plain_tmp_bytes: {}",
            scratch.available(),
            self.glwe_mul_plain_tmp_bytes(res, res_offset, a, b)
        );
        let a_base2k: usize = a.base2k().as_usize();
        assert_eq!(b.base2k().as_usize(), a_base2k);
        let res_base2k: usize = res.base2k().as_usize();
        let cols: usize = res.rank().as_usize() + 1;
        let (mut a_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, a.size());
        let (mut b_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, 1, b.size());
        self.cnv_prepare_left(&mut a_prep, a.data(), scratch_2);
        self.cnv_prepare_right(&mut b_prep, b.data(), scratch_2);
        // Limb shift for the convolution and signed residual shift for the normalization.
        let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k {
            (0, -((a_base2k - (res_offset % a_base2k)) as i64))
        } else {
            ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
        };
        let res_dft_size = res
            .k()
            .as_usize()
            .div_ceil(a.base2k().as_usize())
            .min(a.size() + b.size() - res_offset_hi);
        for i in 0..cols {
            let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size);
            self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, i, &b_prep, 0, scratch_3);
            let res_big = self.vec_znx_idft_apply_consume(res_dft);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_3);
        }
    }

    /// In-place variant of [`GLWEMulPlain::glwe_mul_plain`]: `res <- res * a`.
    ///
    /// `res` is prepared up front, so overwriting its columns during the loop does
    /// not affect the remaining products.
    ///
    /// # Panics
    /// If `res` and `a` use different `base2k` or `scratch` is too small.
    fn glwe_mul_plain_inplace<R, A>(&self, res: &mut GLWE<R>, res_offset: usize, a: &GLWEPlaintext<A>, scratch: &mut Scratch<BE>)
    where
        R: DataMut,
        A: DataRef,
    {
        let res_ref: &GLWE<&[u8]> = &res.to_ref();
        assert!(
            scratch.available() >= self.glwe_mul_plain_tmp_bytes(res_ref, res_offset, res_ref, a),
            "scratch.available(): {} < GLWEMulPlain::glwe_mul_plain_tmp_bytes: {}",
            scratch.available(),
            self.glwe_mul_plain_tmp_bytes(res_ref, res_offset, res_ref, a)
        );
        let a_base2k: usize = a.base2k().as_usize();
        let res_base2k: usize = res.base2k().as_usize();
        assert_eq!(res_base2k, a_base2k);
        let cols: usize = res.rank().as_usize() + 1;
        let (mut res_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, res.size());
        let (mut a_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, 1, a.size());
        self.cnv_prepare_left(&mut res_prep, res.data(), scratch_2);
        self.cnv_prepare_right(&mut a_prep, a.data(), scratch_2);
        let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k {
            (0, -((a_base2k - (res_offset % a_base2k)) as i64))
        } else {
            ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
        };
        let res_dft_size = res
            .k()
            .as_usize()
            .div_ceil(a.base2k().as_usize())
            .min(a.size() + res.size() - res_offset_hi);
        for i in 0..cols {
            let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size);
            // The convolution takes the limb part of the offset (as in `glwe_mul_plain`);
            // the residual part is applied by the normalization below.
            self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &res_prep, i, &a_prep, 0, scratch_3);
            let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_3);
        }
    }
}
/// Multiplication of a GLWE ciphertext by a GLWE plaintext.
pub trait GLWEMulPlain<BE: Backend> {
    /// Number of scratch bytes required by [`GLWEMulPlain::glwe_mul_plain`] for the given layouts.
    fn glwe_mul_plain_tmp_bytes<R: GLWEInfos, A: GLWEInfos, B: GLWEInfos>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize;

    /// Writes `a * b` into `res`, shifted according to `res_offset`.
    fn glwe_mul_plain<R: DataMut, A: DataRef, B: DataRef>(
        &self,
        res: &mut GLWE<R>,
        res_offset: usize,
        a: &GLWE<A>,
        b: &GLWEPlaintext<B>,
        scratch: &mut Scratch<BE>,
    );

    /// In-place variant of [`GLWEMulPlain::glwe_mul_plain`]: `res <- res * a`.
    fn glwe_mul_plain_inplace<R: DataMut, A: DataRef>(
        &self,
        res: &mut GLWE<R>,
        res_offset: usize,
        a: &GLWEPlaintext<A>,
        scratch: &mut Scratch<BE>,
    );
}
/// Tensor product of two GLWE ciphertexts and relinearization of such a tensor.
pub trait GLWETensoring<BE: Backend> {
    /// Number of scratch bytes required by [`GLWETensoring::glwe_tensor_apply`].
    fn glwe_tensor_apply_tmp_bytes<R: GLWEInfos, A: GLWEInfos, B: GLWEInfos>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize;

    /// Writes the tensor product of `a` and `b` into `res`, shifted according to `res_offset`.
    fn glwe_tensor_apply<R: DataMut, A: DataRef, B: DataRef>(
        &self,
        res: &mut GLWETensor<R>,
        res_offset: usize,
        a: &GLWE<A>,
        b: &GLWE<B>,
        scratch: &mut Scratch<BE>,
    );

    /// Relinearizes the tensor `a` into the GLWE `res` using the prepared tensor key `tsk`,
    /// keeping `tsk_size` limbs of the key-switched part.
    fn glwe_tensor_relinearize<R: DataMut, A: DataRef, B: DataRef>(
        &self,
        res: &mut GLWE<R>,
        a: &GLWETensor<A>,
        tsk: &GLWETensorKeyPrepared<B, BE>,
        tsk_size: usize,
        scratch: &mut Scratch<BE>,
    );

    /// Number of scratch bytes required by [`GLWETensoring::glwe_tensor_relinearize`].
    fn glwe_tensor_relinearize_tmp_bytes<R: GLWEInfos, A: GLWEInfos, B: GGLWEInfos>(&self, res: &R, a: &A, tsk: &B) -> usize;
}
impl<BE: Backend> GLWETensoring<BE> for Module<BE>
where
    Self: Sized
        + ModuleN
        + CnvPVecBytesOf
        + VecZnxDftBytesOf
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigNormalize<BE>
        + Convolution<BE>
        + VecZnxSubInplace
        + VecZnxNegate
        + VecZnxAddInplace
        + VecZnxBigNormalizeTmpBytes
        + VecZnxCopy
        + VecZnxNormalize<BE>
        + VecZnxDftApply<BE>
        + GGLWEProduct<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxNormalizeTmpBytes,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    /// Scratch bytes required by [`GLWETensoring::glwe_tensor_apply`].
    ///
    /// The prepared operands (`cols` columns each) stay live for the whole call. On
    /// top of them come either the preparation temporaries or the per-product
    /// buffers of the diagonal pass (`lvl_2a`) / pairwise pass (`lvl_2b`).
    ///
    /// # Panics
    /// If a ring degree differs from the module's or `a` and `b` use different `base2k`.
    fn glwe_tensor_apply_tmp_bytes<R, A, B>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GLWEInfos,
    {
        assert_eq!(self.n() as u32, res.n());
        assert_eq!(self.n() as u32, a.n());
        assert_eq!(self.n() as u32, b.n());
        let ab_base2k: Base2K = a.base2k();
        assert_eq!(b.base2k(), ab_base2k);
        let base2k: usize = ab_base2k.as_usize();
        let cols: usize = res.rank().as_usize() + 1;
        let a_size: usize = a.size();
        let b_size: usize = b.size();
        let res_size: usize = res.size();
        let lvl_0: usize = self.bytes_of_cnv_pvec_left(cols, a_size) + self.bytes_of_cnv_pvec_right(cols, b_size);
        // `a` is prepared on the left and `b` on the right, so each preparation is
        // sized after its own operand.
        let lvl_1: usize = self
            .cnv_prepare_left_tmp_bytes(a_size, a_size)
            .max(self.cnv_prepare_right_tmp_bytes(b_size, b_size));
        let lvl_2_apply: usize = self.cnv_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size);
        let lvl_2_pairwise: usize = self.cnv_pairwise_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size);
        // Same limb offset as the one computed in `glwe_tensor_apply`, so the DFT/temp
        // buffers sized here match those taken from scratch there (never one limb short).
        let res_offset_hi: usize = (res_offset / base2k).saturating_sub(1);
        let res_dft_size: usize = res
            .k()
            .as_usize()
            .div_ceil(base2k)
            .min(a_size + b_size - res_offset_hi);
        let lvl_2a: usize = self.bytes_of_vec_znx_dft(1, res_dft_size)
            + lvl_2_apply.max(VecZnx::bytes_of(self.n(), 1, res_dft_size) + self.vec_znx_big_normalize_tmp_bytes());
        let lvl_2b: usize = self.bytes_of_vec_znx_dft(1, res.size())
            + lvl_2_pairwise.max(VecZnx::bytes_of(self.n(), 1, res.size()) + self.vec_znx_big_normalize_tmp_bytes());
        let lvl_2: usize = lvl_2a.max(lvl_2b);
        lvl_0 + lvl_1.max(lvl_2)
    }

    /// Scratch bytes required by [`GLWETensoring::glwe_tensor_relinearize`].
    ///
    /// # Panics
    /// If a ring degree differs from the module's.
    fn glwe_tensor_relinearize_tmp_bytes<R, A, B>(&self, res: &R, a: &A, tsk: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GGLWEInfos,
    {
        assert_eq!(self.n() as u32, res.n());
        assert_eq!(self.n() as u32, a.n());
        assert_eq!(self.n() as u32, tsk.n());
        let a_base2k: usize = a.base2k().into();
        let key_base2k: usize = tsk.base2k().into();
        let cols: usize = tsk.rank_out().as_usize() + 1;
        let pairs: usize = tsk.rank_in().as_usize();
        let a_dft_size: usize = (a.size() * a_base2k).div_ceil(key_base2k);
        let lvl_0: usize = self.bytes_of_vec_znx_dft(pairs, a_dft_size);
        // Conversion of the key-switched columns of `a` to the key's base2k.
        let lvl_1_pre_conv: usize = if a_base2k != key_base2k {
            VecZnx::bytes_of(self.n(), 1, a_dft_size) + self.vec_znx_normalize_tmp_bytes()
        } else {
            0
        };
        let lvl_1_res_dft: usize = self.bytes_of_vec_znx_dft(cols, tsk.size());
        let lvl_1_gglwe_product: usize = self.gglwe_product_dft_tmp_bytes(res.size(), a_dft_size, tsk);
        // Conversion of the directly-added columns of `a` to the key's base2k; needed
        // exactly when `a` is not already in that base (mirrors the operation).
        let lvl_1_post_conv: usize = if a_base2k != key_base2k {
            VecZnx::bytes_of(self.n(), 1, a_dft_size) + self.vec_znx_normalize_tmp_bytes()
        } else {
            0
        };
        let lvl_1_big_norm: usize = self.vec_znx_big_normalize_tmp_bytes();
        let lvl_1_main: usize = lvl_1_res_dft + lvl_1_gglwe_product.max(lvl_1_post_conv).max(lvl_1_big_norm);
        let lvl_1: usize = lvl_1_pre_conv.max(lvl_1_main);
        lvl_0 + lvl_1
    }

    /// Relinearizes the tensor `a` into the GLWE `res` using the prepared tensor key `tsk`.
    ///
    /// Columns `cols..cols + pairs` of `a` (`cols = rank + 1`, `pairs = tsk.rank_in()`)
    /// are brought to `tsk`'s `base2k` if needed, moved to the DFT domain, and
    /// key-switched with `tsk` into `tsk_size` limbs. The first `cols` columns of `a`
    /// are then added (again converted to `tsk`'s `base2k` if needed), and the sum is
    /// normalized into `res`.
    ///
    /// # Panics
    /// If the ranks of `res`/`a` differ from `tsk.rank_out()` or `scratch` is too small.
    fn glwe_tensor_relinearize<R, A, B>(
        &self,
        res: &mut GLWE<R>,
        a: &GLWETensor<A>,
        tsk: &GLWETensorKeyPrepared<B, BE>,
        tsk_size: usize,
        scratch: &mut Scratch<BE>,
    ) where
        R: DataMut,
        A: DataRef,
        B: DataRef,
    {
        assert!(
            scratch.available() >= self.glwe_tensor_relinearize_tmp_bytes(res, a, tsk),
            "scratch.available(): {} < GLWETensoring::glwe_tensor_relinearize_tmp_bytes: {}",
            scratch.available(),
            self.glwe_tensor_relinearize_tmp_bytes(res, a, tsk)
        );
        let a_base2k: usize = a.base2k().into();
        let key_base2k: usize = tsk.base2k().into();
        let res_base2k: usize = res.base2k().into();
        assert_eq!(res.rank(), tsk.rank_out());
        assert_eq!(a.rank(), tsk.rank_out());
        let cols: usize = tsk.rank_out().as_usize() + 1;
        let pairs: usize = tsk.rank_in().as_usize();
        let a_dft_size: usize = (a.size() * a_base2k).div_ceil(key_base2k);
        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, pairs, a_dft_size);
        // Both branches read the same source columns `cols + i`; they only differ in
        // whether a base2k conversion is needed first.
        if a_base2k != key_base2k {
            let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_dft_size);
            for i in 0..pairs {
                self.vec_znx_normalize(&mut a_conv, key_base2k, 0, 0, a.data(), a_base2k, cols + i, scratch_2);
                self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0);
            }
        } else {
            for i in 0..pairs {
                self.vec_znx_dft_apply(1, 0, &mut a_dft, i, a.data(), cols + i);
            }
        }
        let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk_size);
        self.gglwe_product_dft(&mut res_dft, &a_dft, &tsk.0, scratch_2);
        let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
        // `res_big` is expressed in `key_base2k`, so the columns of `a` added to it must
        // be in that base too; convert them whenever `a`'s base2k differs.
        if a_base2k == key_base2k {
            for i in 0..cols {
                self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i);
            }
        } else {
            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), 1, a_dft_size);
            for i in 0..cols {
                self.vec_znx_normalize(&mut a_conv, key_base2k, 0, 0, a.data(), a_base2k, i, scratch_3);
                self.vec_znx_big_add_small_inplace(&mut res_big, i, &a_conv, 0);
            }
        }
        // Final conversion from the key's base2k to `res`'s base2k.
        for i in 0..cols {
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
        }
    }

    /// Computes the tensor product of `a` and `b` into `res`.
    ///
    /// Output columns are packed upper-triangularly: pair `(i, j)`, `i <= j`, lives at
    /// column `i * cols - i * (i + 1) / 2 + j`. A first pass writes each per-column
    /// product `t_i` (`a_i` convolved with `b_i`) to `(i, i)` and subtracts it from
    /// every off-diagonal entry of row/column `i`. A second pass adds the
    /// `cnv_pairwise_apply_dft` result for each `(i, j)`, `i < j` (presumably
    /// `(a_i + a_j) * (b_i + b_j)`, leaving `a_i b_j + a_j b_i`; confirm against the
    /// backend).
    ///
    /// # Panics
    /// If `a` and `b` use different `base2k` or `scratch` is too small.
    fn glwe_tensor_apply<R, A, B>(
        &self,
        res: &mut GLWETensor<R>,
        res_offset: usize,
        a: &GLWE<A>,
        b: &GLWE<B>,
        scratch: &mut Scratch<BE>,
    ) where
        R: DataMut,
        A: DataRef,
        B: DataRef,
    {
        assert!(
            scratch.available() >= self.glwe_tensor_apply_tmp_bytes(res, res_offset, a, b),
            "scratch.available(): {} < GLWETensoring::glwe_tensor_apply_tmp_bytes: {}",
            scratch.available(),
            self.glwe_tensor_apply_tmp_bytes(res, res_offset, a, b)
        );
        let a_base2k: usize = a.base2k().as_usize();
        assert_eq!(b.base2k().as_usize(), a_base2k);
        let res_base2k: usize = res.base2k().as_usize();
        let cols: usize = res.rank().as_usize() + 1;
        let (mut a_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, a.size());
        let (mut b_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, cols, b.size());
        self.cnv_prepare_left(&mut a_prep, a.data(), scratch_2);
        self.cnv_prepare_right(&mut b_prep, b.data(), scratch_2);
        // Limb shift for the convolutions and signed residual shift for the normalizations.
        let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k {
            (0, -((a_base2k - (res_offset % a_base2k)) as i64))
        } else {
            ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
        };
        let res_dft_size = res
            .k()
            .as_usize()
            .div_ceil(a.base2k().as_usize())
            .min(a.size() + b.size() - res_offset_hi);
        for i in 0..cols {
            let col_i: usize = i * cols - (i * (i + 1) / 2);
            let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size);
            self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, i, &b_prep, i, scratch_3);
            let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
            let (mut tmp, scratch_4) = scratch_3.take_vec_znx(self.n(), 1, res_dft_size);
            self.vec_znx_big_normalize(&mut tmp, res_base2k, res_offset_lo, 0, &res_big, a_base2k, 0, scratch_4);
            self.vec_znx_copy(res.data_mut(), col_i + i, &tmp, 0);
            for j in 0..cols {
                if j != i {
                    if j < i {
                        // Entry (j, i) was initialised to -t_j when row j was processed.
                        let col_j = j * cols - (j * (j + 1) / 2);
                        self.vec_znx_sub_inplace(res.data_mut(), col_j + i, &tmp, 0);
                    } else {
                        // First write to entry (i, j): initialise it to -t_i.
                        self.vec_znx_negate(res.data_mut(), col_i + j, &tmp, 0);
                    }
                }
            }
        }
        for i in 0..cols {
            let col_i: usize = i * cols - (i * (i + 1) / 2);
            for j in i..cols {
                if j != i {
                    let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res.size());
                    self.cnv_pairwise_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, &b_prep, i, j, scratch_3);
                    let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
                    let (mut tmp, scratch_3) = scratch_3.take_vec_znx(self.n(), 1, res.size());
                    self.vec_znx_big_normalize(&mut tmp, res_base2k, res_offset_lo, 0, &res_big, a_base2k, 0, scratch_3);
                    self.vec_znx_add_inplace(res.data_mut(), col_i + j, &tmp, 0);
                }
            }
        }
    }
}
/// Column-wise addition of GLWE values.
///
/// A rank-0 operand (a lone body column) may be combined with a higher-rank one;
/// the columns that only the higher-rank operand has are carried over unchanged.
pub trait GLWEAdd
where
    Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
{
    /// Computes `res = a + b`.
    ///
    /// The first `min(rank(a), rank(b)) + 1` columns are summed. The remaining
    /// columns of the higher-rank operand are copied, and any further columns of
    /// `res` are zeroed.
    ///
    /// # Panics
    /// If a ring degree differs from the module's, if the `base2k` of `a`, `b` and
    /// `res` differ, or if `res.rank()` does not match the non-zero operand rank(s).
    fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        B: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let b: &GLWE<&[u8]> = &b.to_ref();
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(b.n(), self.n() as u32);
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.base2k(), b.base2k());
        assert_eq!(res.base2k(), b.base2k());
        if a.rank() == 0 {
            assert_eq!(res.rank(), b.rank());
        } else if b.rank() == 0 {
            assert_eq!(res.rank(), a.rank());
        } else {
            assert_eq!(res.rank(), a.rank());
            assert_eq!(res.rank(), b.rank());
        }
        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        // The `+ 1` is applied after taking the max so the last column of the
        // higher-rank operand is included in the copy below.
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (res.rank() + 1).into();
        for i in 0..min_col {
            self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
        }
        if a.rank() > b.rank() {
            for i in min_col..max_col {
                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
            }
        } else {
            for i in min_col..max_col {
                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
            }
        }
        for i in max_col..self_col {
            self.vec_znx_zero(res.data_mut(), i);
        }
    }

    /// Computes `res += a` over the first `rank(a) + 1` columns.
    ///
    /// # Panics
    /// If a ring degree differs from the module's, the `base2k` values differ, or
    /// `res.rank() < a.rank()`.
    fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() >= a.rank());
        for i in 0..(a.rank() + 1).into() {
            self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
        }
    }
}
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}
impl<BE: Backend> GLWESub for Module<BE> where
    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
{
}
/// Column-wise subtraction of GLWE values.
///
/// A rank-0 operand (a lone body column) may be combined with a higher-rank one.
pub trait GLWESub
where
    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
{
    /// Computes `res = a - b`.
    ///
    /// The first `min(rank(a), rank(b)) + 1` columns are subtracted. The remaining
    /// columns are copied from `a` (if `a` has the higher rank) or negated from `b`,
    /// and any further columns of `res` are zeroed.
    ///
    /// # Panics
    /// If a ring degree differs from the module's, if the `base2k` of `a`, `b` and
    /// `res` differ, or if `res.rank()` does not match the non-zero operand rank(s).
    fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        B: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let b: &GLWE<&[u8]> = &b.to_ref();
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(b.n(), self.n() as u32);
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.base2k(), res.base2k());
        assert_eq!(b.base2k(), res.base2k());
        if a.rank() == 0 {
            assert_eq!(res.rank(), b.rank());
        } else if b.rank() == 0 {
            assert_eq!(res.rank(), a.rank());
        } else {
            assert_eq!(res.rank(), a.rank());
            assert_eq!(res.rank(), b.rank());
        }
        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        // The `+ 1` is applied after taking the max so the last column of the
        // higher-rank operand is included below.
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (res.rank() + 1).into();
        for i in 0..min_col {
            self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
        }
        if a.rank() > b.rank() {
            for i in min_col..max_col {
                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
            }
        } else {
            for i in min_col..max_col {
                self.vec_znx_negate(res.data_mut(), i, b.data(), i);
            }
        }
        for i in max_col..self_col {
            self.vec_znx_zero(res.data_mut(), i);
        }
    }

    /// Applies `vec_znx_sub_inplace` (res -= a) over the first `rank(a) + 1` columns.
    ///
    /// # Panics
    /// If a ring degree differs from the module's, the `base2k` values differ, or
    /// `a` is neither rank 0 nor of the same rank as `res`.
    fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() == a.rank() || a.rank() == 0);
        for i in 0..(a.rank() + 1).into() {
            self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
        }
    }

    /// Applies `vec_znx_sub_negate_inplace` over the first `rank(a) + 1` columns.
    ///
    /// # Panics
    /// Same conditions as [`GLWESub::glwe_sub_inplace`].
    fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() == a.rank() || a.rank() == 0);
        for i in 0..(a.rank() + 1).into() {
            self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
        }
    }
}
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}
/// Applies `vec_znx_rotate` with shift `k` to every column of a GLWE.
pub trait GLWERotate<BE: Backend>
where
    Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
{
    /// Scratch bytes needed by [`GLWERotate::glwe_rotate_inplace`].
    fn glwe_rotate_tmp_bytes(&self) -> usize {
        vec_znx_rotate_inplace_tmp_bytes(self.n())
    }

    /// Rotates every column of `a` by `k` into `res`; extra columns of `res`
    /// (when `a` has rank 0) are zeroed.
    ///
    /// # Panics
    /// If a ring degree differs from the module's, or `a` is neither rank 0 nor of
    /// the same rank as `res`.
    fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.n(), self.n() as u32);
        assert!(res.rank() == a.rank() || a.rank() == 0);
        let dst_cols: usize = res.rank().as_usize() + 1;
        let src_cols: usize = a.rank().as_usize() + 1;
        (0..src_cols).for_each(|col| self.vec_znx_rotate(k, res.data_mut(), col, a.data(), col));
        (src_cols..dst_cols).for_each(|col| self.vec_znx_zero(res.data_mut(), col));
    }

    /// Rotates every column of `res` by `k` in place.
    ///
    /// # Panics
    /// If `scratch` holds fewer than [`GLWERotate::glwe_rotate_tmp_bytes`] bytes.
    fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let required: usize = self.glwe_rotate_tmp_bytes();
        assert!(
            scratch.available() >= required,
            "scratch.available(): {} < GLWERotate::glwe_rotate_tmp_bytes: {}",
            scratch.available(),
            required
        );
        let col_count: usize = res.rank().as_usize() + 1;
        for col in 0..col_count {
            self.vec_znx_rotate_inplace(k, res.data_mut(), col, scratch);
        }
    }
}
impl<BE: Backend> GLWEMulXpMinusOne<BE> for Module<BE> where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE> {}
/// Applies `vec_znx_mul_xp_minus_one` (by its name, multiplication by `X^k - 1`;
/// confirm against the backend) to every column of a GLWE.
pub trait GLWEMulXpMinusOne<BE: Backend>
where
    Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE>,
{
    /// Out-of-place variant: writes the result for each column of `a` into `res`.
    ///
    /// # Panics
    /// If a ring degree differs from the module's or the ranks differ.
    fn glwe_mul_xp_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let module_n: u32 = self.n() as u32;
        assert_eq!(res.n(), module_n);
        assert_eq!(a.n(), module_n);
        assert_eq!(res.rank(), a.rank());
        let col_count: usize = res.rank().as_usize() + 1;
        (0..col_count).for_each(|col| self.vec_znx_mul_xp_minus_one(k, res.data_mut(), col, a.data(), col));
    }

    /// In-place variant, using `scratch` for the per-column temporaries.
    ///
    /// # Panics
    /// If the ring degree of `res` differs from the module's.
    fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        assert_eq!(res.n(), self.n() as u32);
        let col_count: usize = res.rank().as_usize() + 1;
        for col in 0..col_count {
            self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), col, scratch);
        }
    }
}
impl<BE: Backend> GLWECopy for Module<BE> where Self: ModuleN + VecZnxCopy + VecZnxZero {}
/// Copy of a GLWE into another GLWE of the same ring degree.
pub trait GLWECopy
where
    Self: ModuleN + VecZnxCopy + VecZnxZero,
{
    /// Copies the shared columns of `a` into `res` and zeroes any extra columns of
    /// `res` (which only exist when `a` has rank 0).
    ///
    /// # Panics
    /// If a ring degree differs from the module's, or `a` is neither rank 0 nor of
    /// the same rank as `res`.
    fn glwe_copy<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert!(res.rank() == a.rank() || a.rank() == 0);
        let shared_cols: usize = res.rank().min(a.rank()).as_usize() + 1;
        let res_cols: usize = res.rank().as_usize() + 1;
        for col in 0..res_cols {
            if col < shared_cols {
                self.vec_znx_copy(res.data_mut(), col, a.data(), col);
            } else {
                self.vec_znx_zero(res.data_mut(), col);
            }
        }
    }
}
impl<BE: Backend> GLWEShift<BE> for Module<BE> where Self: ModuleN + VecZnxRshInplace<BE> {}
/// Right shift of every column of a GLWE.
pub trait GLWEShift<BE: Backend>
where
    Self: ModuleN + VecZnxRshInplace<BE>,
{
    /// Scratch bytes needed by [`GLWEShift::glwe_rsh`].
    fn glwe_rsh_tmp_byte(&self) -> usize {
        VecZnx::rsh_tmp_bytes(self.n())
    }

    /// Applies `vec_znx_rsh_inplace` with shift `k` (and `res`'s `base2k`) to every column of `res`.
    ///
    /// # Panics
    /// If `scratch` holds fewer than [`GLWEShift::glwe_rsh_tmp_byte`] bytes.
    fn glwe_rsh<R>(&self, k: usize, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let required: usize = self.glwe_rsh_tmp_byte();
        assert!(
            scratch.available() >= required,
            "scratch.available(): {} < GLWEShift::glwe_rsh_tmp_byte: {}",
            scratch.available(),
            required
        );
        let limb_bits: usize = res.base2k().into();
        let col_count: usize = res.rank().as_usize() + 1;
        for col in 0..col_count {
            self.vec_znx_rsh_inplace(limb_bits, k, res.data_mut(), col, scratch);
        }
    }
}
impl GLWE<Vec<u8>> {
    /// Scratch bytes needed by [`GLWEShift::glwe_rsh`] on `module`.
    pub fn rsh_tmp_bytes<M, BE: Backend>(module: &M) -> usize
    where
        M: GLWEShift<BE>,
    {
        <M as GLWEShift<BE>>::glwe_rsh_tmp_byte(module)
    }
}
impl<BE: Backend> GLWENormalize<BE> for Module<BE> where
    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxNormalizeTmpBytes
{
}
/// Normalization of GLWE values, including conversion between `base2k` values.
pub trait GLWENormalize<BE: Backend>
where
    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxNormalizeTmpBytes,
{
    /// Scratch bytes required by [`GLWENormalize::glwe_normalize`] and
    /// [`GLWENormalize::glwe_normalize_inplace`]: the module's `vec_znx_normalize_tmp_bytes`.
    fn glwe_normalize_tmp_bytes(&self) -> usize {
        let lvl_0: usize = self.vec_znx_normalize_tmp_bytes();
        lvl_0
    }
    /// Returns a read-only view of `glwe` expressed in `target_base2k`.
    ///
    /// - If `glwe` already uses `target_base2k`: `tmp_slot` is cleared and a view of
    ///   `glwe` itself is returned together with the untouched `scratch`.
    /// - Otherwise: a GLWE with `glwe`'s layout but `base2k = target_base2k` is taken
    ///   from `scratch` and stored in `tmp_slot`. `glwe` is normalized into it, and a
    ///   view of that temporary is returned with the remaining scratch.
    ///
    /// `tmp_slot` exists so that the temporary outlives this call for as long as the
    /// returned view does.
    fn glwe_maybe_cross_normalize_to_ref<'a, A>(
        &self,
        glwe: &'a A,
        target_base2k: usize,
        tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>, scratch: &'a mut Scratch<BE>,
    ) -> (GLWE<&'a [u8]>, &'a mut Scratch<BE>)
    where
        A: GLWEToRef + GLWEInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        if glwe.base2k().as_usize() == target_base2k {
            tmp_slot.take();
            return (glwe.to_ref(), scratch);
        }
        // Same layout as `glwe`, only the base2k changes.
        let mut layout = glwe.glwe_layout();
        layout.base2k = target_base2k.into();
        let (tmp, scratch2) = scratch.take_glwe(&layout);
        *tmp_slot = Some(tmp);
        let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot.as_mut().expect("tmp_slot just set to Some, but found None");
        self.glwe_normalize(tmp_ref, glwe, scratch2);
        (tmp_ref.to_ref(), scratch2)
    }
    /// Mutable counterpart of [`GLWENormalize::glwe_maybe_cross_normalize_to_ref`].
    ///
    /// Returns a mutable handle either on `glwe` itself (same base2k, `tmp_slot`
    /// cleared) or on a temporary, taken from `scratch` and stored in `tmp_slot`,
    /// that holds `glwe` normalized to `target_base2k`.
    ///
    /// NOTE(review): in the converted case, writes through the returned handle go to
    /// the temporary; nothing in this function copies them back into `glwe`.
    fn glwe_maybe_cross_normalize_to_mut<'a, A>(
        &self,
        glwe: &'a mut A,
        target_base2k: usize,
        tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>, scratch: &'a mut Scratch<BE>,
    ) -> (GLWE<&'a mut [u8]>, &'a mut Scratch<BE>)
    where
        A: GLWEToMut + GLWEInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        if glwe.base2k().as_usize() == target_base2k {
            tmp_slot.take();
            return (glwe.to_mut(), scratch);
        }
        // Same layout as `glwe`, only the base2k changes.
        let mut layout = glwe.glwe_layout();
        layout.base2k = target_base2k.into();
        let (tmp, scratch2) = scratch.take_glwe(&layout);
        *tmp_slot = Some(tmp);
        let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot.as_mut().expect("tmp_slot just set to Some, but found None");
        self.glwe_normalize(tmp_ref, glwe, scratch2);
        (tmp_ref.to_mut(), scratch2)
    }
    /// Normalizes every column of `a` (in `a`'s base2k) into `res` (in `res`'s base2k).
    ///
    /// # Panics
    /// If a ring degree differs from the module's, the ranks differ, or `scratch`
    /// holds fewer than [`GLWENormalize::glwe_normalize_tmp_bytes`] bytes.
    fn glwe_normalize<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.rank(), a.rank());
        assert!(
            scratch.available() >= self.glwe_normalize_tmp_bytes(),
            "scratch.available(): {} < GLWENormalize::glwe_normalize_tmp_bytes: {}",
            scratch.available(),
            self.glwe_normalize_tmp_bytes()
        );
        let res_base2k = res.base2k().into();
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_normalize(res.data_mut(), res_base2k, 0, i, a.data(), a.base2k().into(), i, scratch);
        }
    }
    /// Normalizes every column of `res` in place, keeping its base2k.
    ///
    /// # Panics
    /// If `scratch` holds fewer than [`GLWENormalize::glwe_normalize_tmp_bytes`] bytes.
    fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        assert!(
            scratch.available() >= self.glwe_normalize_tmp_bytes(),
            "scratch.available(): {} < GLWENormalize::glwe_normalize_tmp_bytes: {}",
            scratch.available(),
            self.glwe_normalize_tmp_bytes()
        );
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch);
        }
    }
}
// Not referenced in this part of the file; kept for binary operations that need
// to derive the output torus precision.
#[allow(dead_code)]
/// Torus precision of the result of a binary operation `c = op(a, b)`.
///
/// Rank-0 operands do not constrain the precision. If both are rank 0, `c`'s own
/// precision is kept; otherwise it is the minimum of `c.k()` and the constraining
/// operand(s).
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
    let a_is_rank0: bool = a.rank() == 0;
    let b_is_rank0: bool = b.rank() == 0;
    match (a_is_rank0, b_is_rank0) {
        (true, true) => c.k(),
        (true, false) => b.k().min(c.k()),
        (false, true) => a.k().min(c.k()),
        (false, false) => a.k().min(b.k()).min(c.k()),
    }
}
// Not referenced in this part of the file; kept for unary operations that need
// to derive the output torus precision.
#[allow(dead_code)]
/// Torus precision of the result of a unary operation from `b` into `a`: `a.k()` if
/// both are rank 0, otherwise `min(a.k(), b.k())`.
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
    let both_rank0: bool = a.rank() == 0 && b.rank() == 0;
    if both_rank0 { a.k() } else { a.k().min(b.k()) }
}