poulpy_hal/reference/vec_znx/
split_ring.rs1use crate::{
2 layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
3 reference::znx::{ZnxRotate, ZnxSwitchRing, ZnxZero},
4};
5
/// Returns the number of bytes occupied by a scratch buffer of `n` `i64` values.
///
/// This matches the `tmp: &mut [i64]` slice that [`vec_znx_split_ring`] expects,
/// whose length must equal the input ring degree (`tmp.len() == a.n()`).
pub fn vec_znx_split_ring_tmp_bytes(n: usize) -> usize {
    let bytes_per_coeff: usize = std::mem::size_of::<i64>();
    bytes_per_coeff * n
}
9
10pub fn vec_znx_split_ring<R, A, ZNXARI>(res: &mut [R], res_col: usize, a: &A, a_col: usize, tmp: &mut [i64])
11where
12 R: VecZnxToMut,
13 A: VecZnxToRef,
14 ZNXARI: ZnxSwitchRing + ZnxRotate + ZnxZero,
15{
16 let a: VecZnx<&[u8]> = a.to_ref();
17 let a_size = a.size();
18
19 let (_n_in, _n_out) = (a.n(), res[0].to_mut().n());
20
21 #[cfg(debug_assertions)]
22 {
23 assert_eq!(tmp.len(), a.n());
24
25 assert!(
26 _n_out < _n_in,
27 "invalid a: output ring degree should be smaller"
28 );
29
30 res[1..].iter_mut().for_each(|bi| {
31 assert_eq!(
32 bi.to_mut().n(),
33 _n_out,
34 "invalid input a: all VecZnx must have the same degree"
35 )
36 });
37
38 assert!(_n_in.is_multiple_of(_n_out));
39 assert_eq!(res.len(), _n_in / _n_out);
40 }
41
42 res.iter_mut().enumerate().for_each(|(i, bi)| {
43 let mut bi: VecZnx<&mut [u8]> = bi.to_mut();
44
45 let min_size = bi.size().min(a_size);
46
47 if i == 0 {
48 for j in 0..min_size {
49 ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), a.at(a_col, j));
50 }
51 } else {
52 for j in 0..min_size {
53 ZNXARI::znx_rotate(-(i as i64), tmp, a.at(a_col, j));
54 ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), tmp);
55 }
56 }
57
58 for j in min_size..bi.size() {
59 ZNXARI::znx_zero(bi.at_mut(res_col, j));
60 }
61 })
62}