use poulpy_hal::{
api::ModuleLogN,
layouts::{Backend, CyclotomicOrder, GaloisElement, ScratchArena, galois_element},
};
use crate::{
GLWEAutomorphism, GLWECopy, GLWENormalize, GLWEShift, ScratchArenaTakeCore,
layouts::{
GGLWEInfos, GGLWELayout, GLWE, GLWEAutomorphismKeyHelper, GLWEInfos, GLWELayout, GLWEToBackendMut, GLWEToBackendRef,
GetGaloisElement, LWEInfos, prepared::GGLWEPreparedToBackendRef,
},
};
/// Lists the Galois elements visited by the trace, one per step `i` in `0..log_n`.
///
/// Step `0` contributes `-1`; every later step `i` contributes
/// `galois_element(1 << (i - 1), cyclotomic_order)`. The returned vector therefore
/// has exactly `log_n` entries (it is empty when `log_n == 0`).
#[inline(always)]
pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec<i64> {
    let mut elements: Vec<i64> = Vec::with_capacity(log_n);
    for step in 0..log_n {
        let element = match step {
            // The first step has no generator power; it uses the fixed element -1.
            0 => -1,
            _ => galois_element(1 << (step - 1), cyclotomic_order),
        };
        elements.push(element);
    }
    elements
}
/// Shared in-place trace routine: transforms `res` using the automorphism keys in `keys`.
///
/// When `res` and the keys share the same `base2k`, each step `i` in `skip..log_n`
/// calls `glwe_rsh(1, res, ..)` and then `glwe_automorphism_add_assign` with the key
/// registered for the step's Galois element (`-1` for step `0`, otherwise
/// `module.galois_element(1 << (i - 1))`).
///
/// When the `base2k` values differ, a temporary GLWE using the key `base2k` is taken
/// from `scratch`, `res` is normalized into it, this function recurses on the
/// temporary, and the result is normalized back into `res`.
///
/// # Panics
///
/// Panics if the ring degree of `res` or of the keys differs from the module's, if
/// `skip > log_n`, if the key input/output ranks differ from the rank of `res`, if
/// `scratch` has fewer bytes available than `glwe_trace_assign_tmp_bytes_default`
/// reports, or if no key is registered for a required Galois element.
fn trace_assign_internal<M, K, H, R, BE: Backend>(
    module: &M,
    res: &mut R,
    skip: usize,
    keys: &H,
    key_size: usize,
    scratch: &mut ScratchArena<'_, BE>,
) where
    M: ModuleLogN
        + GaloisElement
        + GLWEAutomorphism<BE>
        + GLWEShift<BE>
        + GLWECopy<BE>
        + CyclotomicOrder
        + GLWENormalize<BE>
        + GLWETraceDefault<BE>
        + ?Sized,
    K: GGLWEPreparedToBackendRef<BE> + GetGaloisElement + GGLWEInfos,
    H: GLWEAutomorphismKeyHelper<K, BE>,
    R: GLWEToBackendMut<BE> + GLWEInfos,
{
    let key_layout: &GGLWELayout = &keys.automorphism_key_infos();
    let log_n: usize = module.log_n();

    // Layout preconditions shared by both code paths.
    assert_eq!(res.n(), module.n() as u32);
    assert_eq!(key_layout.n(), module.n() as u32);
    assert!(skip <= log_n);
    assert_eq!(key_layout.rank_in(), res.rank());
    assert_eq!(key_layout.rank_out(), res.rank());
    assert!(
        scratch.available() >= module.glwe_trace_assign_tmp_bytes_default(res, key_layout),
        "scratch.available(): {} < GLWETrace::glwe_trace_assign_tmp_bytes: {}",
        scratch.available(),
        module.glwe_trace_assign_tmp_bytes_default(res, key_layout)
    );

    // Fast path: `res` already uses the key base2k, so the steps run directly on it.
    if res.base2k() == key_layout.base2k() {
        for i in skip..log_n {
            let p: i64 = match i {
                0 => -1,
                _ => module.galois_element(1 << (i - 1)),
            };
            module.glwe_rsh(1, res, scratch);
            match keys.get_automorphism_key(p) {
                Some(key) => {
                    module.glwe_automorphism_add_assign(res, key, key_size, scratch);
                }
                None => panic!("keys[{p}] is empty"),
            }
        }
        return;
    }

    // Conversion path: work on a temporary that matches the key base2k.
    let conv_layout = GLWELayout {
        n: module.n().into(),
        base2k: key_layout.base2k(),
        k: res.max_k(),
        rank: res.rank(),
    };
    let arena = scratch.borrow();
    let (mut res_conv, mut rest) = arena.take_glwe_scratch(&conv_layout);
    rest = rest.apply_mut(|inner| {
        module.glwe_normalize(&mut res_conv, res, inner);
    });
    // The recursive call lands on the fast path because `res_conv` uses the key base2k.
    rest = rest.apply_mut(|inner| {
        trace_assign_internal::<M, K, H, _, BE>(module, &mut res_conv, skip, keys, key_size, inner);
    });
    rest.apply_mut(|inner| {
        module.glwe_normalize(res, &res_conv, inner);
    });
}
/// Backend-facing surface for the default GLWE trace routines.
///
/// NOTE(review): free functions with the same names and matching generic shapes live
/// in `glwe_trace_defaults_impl`; presumably implementors forward to them — confirm
/// against the implementing code, which is not visible here.
#[doc(hidden)]
pub trait GLWETraceDefault<BE: Backend> {
/// Scratch-space size, in bytes, for an in-place trace on a GLWE described by
/// `a_infos` with automorphism keys described by `key_infos`.
///
/// `trace_assign_internal` asserts that `scratch.available()` is at least this value.
fn glwe_trace_assign_tmp_bytes_default<A, K>(&self, a_infos: &A, key_infos: &K) -> usize
where
A: GLWEInfos,
K: GGLWEInfos;
/// Galois elements associated with the trace.
///
/// NOTE(review): the free function of the same name returns
/// `trace_galois_elements(log_n, cyclotomic_order)`; presumably these are the
/// elements for which automorphism keys must exist — confirm with callers.
fn glwe_trace_galois_elements_default(&self) -> Vec<i64>;
/// Scratch-space size, in bytes, for an out-of-place trace from a GLWE described by
/// `a_infos` into one described by `res_infos`, with keys described by `key_infos`.
fn glwe_trace_tmp_bytes_default<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
K: GGLWEInfos;
/// Out-of-place trace: reads `a` and writes the result into `res`.
///
/// `skip` is the first trace step to execute, `keys` supplies the automorphism keys,
/// and `key_size` is forwarded to the automorphism routine (its exact meaning is not
/// visible here). `scratch` provides the temporary storage.
fn glwe_trace_default<R, A, K, H>(
&self,
res: &mut R,
skip: usize,
a: &A,
keys: &H,
key_size: usize,
scratch: &mut ScratchArena<'_, BE>,
) where
R: GLWEToBackendMut<BE> + GLWEInfos,
A: GLWEToBackendRef<BE> + GLWEInfos,
K: GGLWEPreparedToBackendRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>;
/// In-place trace: `res` is both the input and the output.
///
/// Parameters have the same roles as in `glwe_trace_default`, minus the separate
/// input `a`.
fn glwe_trace_assign_default<R, K, H>(
&self,
res: &mut R,
skip: usize,
keys: &H,
key_size: usize,
scratch: &mut ScratchArena<'_, BE>,
) where
R: GLWEToBackendMut<BE> + GLWEInfos,
K: GGLWEPreparedToBackendRef<BE> + GetGaloisElement + GGLWEInfos,
H: GLWEAutomorphismKeyHelper<K, BE>;
}
pub mod glwe_trace_defaults_impl {
use super::*;
/// Scratch bytes needed by the in-place trace for a GLWE described by `a_infos`.
///
/// * Same `base2k` as the keys: the larger of the shift and automorphism requirements.
/// * Different `base2k`: the size of a temporary GLWE in the key `base2k`, plus the
///   larger of the normalization requirement and the in-place trace requirement for
///   that temporary.
///
/// # Panics
///
/// Panics if the ring degree of `a_infos` or `key_infos` differs from the module's.
pub fn glwe_trace_assign_tmp_bytes_default<BE, M, A, K>(module: &M, a_infos: &A, key_infos: &K) -> usize
where
    BE: Backend,
    M: GLWETraceDefault<BE>
        + ModuleLogN
        + GaloisElement
        + GLWEAutomorphism<BE>
        + GLWEShift<BE>
        + GLWECopy<BE>
        + CyclotomicOrder
        + GLWENormalize<BE>,
    A: GLWEInfos,
    K: GGLWEInfos,
{
    assert_eq!(module.n() as u32, a_infos.n());
    assert_eq!(module.n() as u32, key_infos.n());

    // No conversion needed: only the per-step operations consume scratch.
    if a_infos.base2k() == key_infos.base2k() {
        let shift_bytes: usize = module.glwe_shift_tmp_bytes();
        let automorphism_bytes: usize = module.glwe_automorphism_tmp_bytes(a_infos, a_infos, key_infos);
        return shift_bytes.max(automorphism_bytes);
    }

    // Layout of the temporary the trace operates on when base2k differs.
    let converted: GLWELayout = GLWELayout {
        n: a_infos.n(),
        base2k: key_infos.base2k(),
        k: a_infos.max_k(),
        rank: a_infos.rank(),
    };
    let converted_bytes: usize = GLWE::<Vec<u8>>::bytes_of_from_infos(&converted);
    let normalize_bytes: usize = module.glwe_normalize_tmp_bytes();
    let trace_bytes: usize = module.glwe_trace_assign_tmp_bytes_default(&converted, key_infos);
    converted_bytes + normalize_bytes.max(trace_bytes)
}
/// Returns `trace_galois_elements` evaluated with the module's `log_n` and
/// cyclotomic order.
pub fn glwe_trace_galois_elements_default<BE, M>(module: &M) -> Vec<i64>
where
    BE: Backend,
    M: ModuleLogN + CyclotomicOrder,
{
    let log_n: usize = module.log_n();
    let order: i64 = module.cyclotomic_order();
    trace_galois_elements(log_n, order)
}
/// Scratch bytes needed by the out-of-place trace (`glwe_trace_default`).
///
/// The total is the size of a working GLWE (key `base2k`, precision
/// `max(a.max_k, res.max_k)`, rank of `res_infos`) plus the largest of:
///
/// * normalizing the input into it (only when `a_infos` uses another `base2k`),
/// * the in-place trace on the working GLWE,
/// * a GLWE shaped like `res_infos` plus one normalization (only when `res_infos`
///   uses another `base2k`).
///
/// # Panics
///
/// Panics if the ring degree of any of the three layouts differs from the module's.
pub fn glwe_trace_tmp_bytes_default<BE, M, R, A, K>(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
    BE: Backend,
    M: GLWETraceDefault<BE>
        + ModuleLogN
        + GaloisElement
        + GLWEAutomorphism<BE>
        + GLWEShift<BE>
        + GLWECopy<BE>
        + CyclotomicOrder
        + GLWENormalize<BE>,
    R: GLWEInfos,
    A: GLWEInfos,
    K: GGLWEInfos,
{
    assert_eq!(module.n() as u32, res_infos.n());
    assert_eq!(module.n() as u32, a_infos.n());
    assert_eq!(module.n() as u32, key_infos.n());

    let working: GLWELayout = GLWELayout {
        n: res_infos.n(),
        base2k: key_infos.base2k(),
        k: a_infos.max_k().max(res_infos.max_k()),
        rank: res_infos.rank(),
    };
    let working_bytes: usize = GLWE::<Vec<u8>>::bytes_of_from_infos(&working);

    // Phase 1: bring the input into the key base2k (skipped when already there).
    let input_phase: usize = if a_infos.base2k() != key_infos.base2k() {
        module.glwe_normalize_tmp_bytes()
    } else {
        0
    };

    // Phase 2: the in-place trace on the working GLWE.
    let trace_phase: usize = module.glwe_trace_assign_tmp_bytes_default(&working, key_infos);

    // Phase 3: normalize into a result-shaped temporary (skipped when base2k matches).
    let output_phase: usize = if res_infos.base2k() != key_infos.base2k() {
        GLWE::<Vec<u8>>::bytes_of_from_infos(res_infos) + module.glwe_normalize_tmp_bytes()
    } else {
        0
    };

    // The phases run one after another, so only the largest one counts.
    let peak_phase: usize = input_phase.max(trace_phase).max(output_phase);
    working_bytes + peak_phase
}
/// Out-of-place trace: reads `a`, writes the result to `res`.
///
/// A working GLWE in the key `base2k` is taken from `scratch`. The input is copied
/// into it (same `base2k`) or normalized into it (different `base2k`), the in-place
/// trace runs on the working GLWE, and the result is copied to `res` (same `base2k`)
/// or normalized into a second scratch GLWE shaped like `res` and then copied.
///
/// # Panics
///
/// Panics if `scratch` has fewer bytes available than `glwe_trace_tmp_bytes_default`
/// reports, and whenever `trace_assign_internal` panics.
pub fn glwe_trace_default<BE, M, R, A, K, H>(
    module: &M,
    res: &mut R,
    skip: usize,
    a: &A,
    keys: &H,
    key_size: usize,
    scratch: &mut ScratchArena<'_, BE>,
) where
    BE: Backend,
    M: GLWETraceDefault<BE>
        + ModuleLogN
        + GaloisElement
        + GLWEAutomorphism<BE>
        + GLWEShift<BE>
        + GLWECopy<BE>
        + CyclotomicOrder
        + GLWENormalize<BE>,
    R: GLWEToBackendMut<BE> + GLWEInfos,
    A: GLWEToBackendRef<BE> + GLWEInfos,
    K: GGLWEPreparedToBackendRef<BE> + GetGaloisElement + GGLWEInfos,
    H: GLWEAutomorphismKeyHelper<K, BE>,
{
    let key_layout: &GGLWELayout = &keys.automorphism_key_infos();
    assert!(
        scratch.available() >= module.glwe_trace_tmp_bytes_default(res, a, key_layout),
        "scratch.available(): {} < GLWETrace::glwe_trace_tmp_bytes: {}",
        scratch.available(),
        module.glwe_trace_tmp_bytes_default(res, a, key_layout)
    );

    // Working GLWE in the key base2k, wide enough for both the input and the output.
    let working_layout = GLWELayout {
        n: res.n(),
        base2k: key_layout.base2k(),
        k: a.max_k().max(res.max_k()),
        rank: res.rank(),
    };
    let arena = scratch.borrow();
    let (mut tmp, mut rest) = arena.take_glwe_scratch(&working_layout);

    // Load the input into the working GLWE.
    if a.base2k() != key_layout.base2k() {
        rest = rest.apply_mut(|inner| {
            module.glwe_normalize(&mut tmp, a, inner);
        });
    } else {
        module.glwe_copy(&mut tmp, a);
    }

    // Run the in-place trace on the working GLWE.
    rest = rest.apply_mut(|inner| {
        trace_assign_internal::<M, K, H, _, BE>(module, &mut tmp, skip, keys, key_size, inner);
    });

    // Store the result, converting back to the output base2k when required.
    if res.base2k() != key_layout.base2k() {
        let (mut res_out, tail) = rest.take_glwe_scratch(&res.glwe_layout());
        tail.apply_mut(|inner| {
            module.glwe_normalize(&mut res_out, &tmp, inner);
        });
        module.glwe_copy(res, &res_out);
    } else {
        module.glwe_copy(res, &tmp);
    }
}
/// In-place trace on `res`; forwards every argument unchanged to
/// `trace_assign_internal`, which performs all validation and the computation.
pub fn glwe_trace_assign_default<BE, M, R, K, H>(
    module: &M,
    res: &mut R,
    skip: usize,
    keys: &H,
    key_size: usize,
    scratch: &mut ScratchArena<'_, BE>,
) where
    BE: Backend,
    M: GLWETraceDefault<BE>
        + ModuleLogN
        + GaloisElement
        + GLWEAutomorphism<BE>
        + GLWEShift<BE>
        + GLWECopy<BE>
        + CyclotomicOrder
        + GLWENormalize<BE>,
    R: GLWEToBackendMut<BE> + GLWEInfos,
    K: GGLWEPreparedToBackendRef<BE> + GetGaloisElement + GGLWEInfos,
    H: GLWEAutomorphismKeyHelper<K, BE>,
{
    // `K` cannot be inferred from the arguments, so the type parameters are spelled out.
    trace_assign_internal::<M, K, H, R, BE>(module, res, skip, keys, key_size, scratch);
}
}