poulpy_hal/reference/vec_znx/
merge_rings.rs1use crate::{
2 layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos},
3 reference::{
4 vec_znx::{vec_znx_rotate_inplace, vec_znx_switch_ring},
5 znx::{ZnxCopy, ZnxRotate, ZnxSwitchRing, ZnxZero},
6 },
7};
8
/// Size in bytes of the scratch buffer needed by [`vec_znx_merge_rings`]
/// when the output has ring degree `n`.
///
/// The scratch holds one `i64` per coefficient of the output ring.
pub fn vec_znx_merge_rings_tmp_bytes(n: usize) -> usize {
    std::mem::size_of::<i64>() * n
}
12
13pub fn vec_znx_merge_rings<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &[A], a_col: usize, tmp: &mut [i64])
14where
15 R: VecZnxToMut,
16 A: VecZnxToRef,
17 ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxRotate + ZnxZero,
18{
19 let mut res: VecZnx<&mut [u8]> = res.to_mut();
20
21 let (_n_out, _n_in) = (res.n(), a[0].to_ref().n());
22
23 #[cfg(debug_assertions)]
24 {
25 assert_eq!(tmp.len(), res.n());
26
27 debug_assert!(
28 _n_out > _n_in,
29 "invalid a: output ring degree should be greater"
30 );
31 a[1..].iter().for_each(|ai| {
32 debug_assert_eq!(
33 ai.to_ref().n(),
34 _n_in,
35 "invalid input a: all VecZnx must have the same degree"
36 )
37 });
38
39 assert!(_n_out.is_multiple_of(_n_in));
40 assert_eq!(a.len(), _n_out / _n_in);
41 }
42
43 a.iter().for_each(|ai| {
44 vec_znx_switch_ring::<_, _, ZNXARI>(&mut res, res_col, ai, a_col);
45 vec_znx_rotate_inplace::<_, ZNXARI>(-1, &mut res, res_col, tmp);
46 });
47
48 vec_znx_rotate_inplace::<_, ZNXARI>(a.len() as i64, &mut res, res_col, tmp);
49}