poulpy_hal/reference/vec_znx/merge_rings.rs

1use crate::{
2    layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos},
3    reference::{
4        vec_znx::{vec_znx_rotate_inplace, vec_znx_switch_ring},
5        znx::{ZnxCopy, ZnxRotate, ZnxSwitchRing, ZnxZero},
6    },
7};
8
/// Size, in bytes, of the scratch buffer consumed by `vec_znx_merge_rings`.
///
/// One `i64` is reserved per coefficient. `n` is presumably the output ring
/// degree, since `vec_znx_merge_rings` checks (in debug builds) that its
/// `tmp: &mut [i64]` has exactly `res.n()` elements.
pub fn vec_znx_merge_rings_tmp_bytes(n: usize) -> usize {
    let bytes_per_coeff: usize = std::mem::size_of::<i64>();
    bytes_per_coeff * n
}
12
13pub fn vec_znx_merge_rings<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &[A], a_col: usize, tmp: &mut [i64])
14where
15    R: VecZnxToMut,
16    A: VecZnxToRef,
17    ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxRotate + ZnxZero,
18{
19    let mut res: VecZnx<&mut [u8]> = res.to_mut();
20
21    let (_n_out, _n_in) = (res.n(), a[0].to_ref().n());
22
23    #[cfg(debug_assertions)]
24    {
25        assert_eq!(tmp.len(), res.n());
26
27        debug_assert!(
28            _n_out > _n_in,
29            "invalid a: output ring degree should be greater"
30        );
31        a[1..].iter().for_each(|ai| {
32            debug_assert_eq!(
33                ai.to_ref().n(),
34                _n_in,
35                "invalid input a: all VecZnx must have the same degree"
36            )
37        });
38
39        assert!(_n_out.is_multiple_of(_n_in));
40        assert_eq!(a.len(), _n_out / _n_in);
41    }
42
43    a.iter().for_each(|ai| {
44        vec_znx_switch_ring::<_, _, ZNXARI>(&mut res, res_col, ai, a_col);
45        vec_znx_rotate_inplace::<_, ZNXARI>(-1, &mut res, res_col, tmp);
46    });
47
48    vec_znx_rotate_inplace::<_, ZNXARI>(a.len() as i64, &mut res, res_col, tmp);
49}