1use crate::{
2 api::{
3 ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
4 VecZnxDftBytesOf, VecZnxDftZero,
5 },
6 layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
7};
8
9impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
10where
11 Self: BivariateConvolution<BE>,
12 Scratch<BE>: ScratchTakeBasic,
13{
14}
15
16pub trait BivariateTensoring<BE: Backend>
17where
18 Self: BivariateConvolution<BE>,
19 Scratch<BE>: ScratchTakeBasic,
20{
21 fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
22 where
23 R: VecZnxDftToMut<BE>,
24 A: VecZnxToRef,
25 B: VecZnxDftToRef<BE>,
26 {
27 let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
28 let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
29 let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
30
31 let res_cols: usize = res.cols();
32 let a_cols: usize = a.cols();
33 let b_cols: usize = b.cols();
34
35 assert!(res_cols >= a_cols + b_cols - 1);
36
37 for res_col in 0..res_cols {
38 self.vec_znx_dft_zero(res, res_col);
39 }
40
41 for a_col in 0..a_cols {
42 for b_col in 0..b_cols {
43 self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
44 }
45 }
46 }
47}
48
49impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
50where
51 Self: Sized
52 + ModuleN
53 + SvpPPolAlloc<BE>
54 + SvpApplyDftToDft<BE>
55 + SvpPrepare<BE>
56 + SvpPPolBytesOf
57 + VecZnxDftBytesOf
58 + VecZnxDftAddScaledInplace<BE>
59 + VecZnxDftZero<BE>,
60 Scratch<BE>: ScratchTakeBasic,
61{
62}
63
64pub trait BivariateConvolution<BE: Backend>
65where
66 Self: Sized
67 + ModuleN
68 + SvpPPolAlloc<BE>
69 + SvpApplyDftToDft<BE>
70 + SvpPrepare<BE>
71 + SvpPPolBytesOf
72 + VecZnxDftBytesOf
73 + VecZnxDftAddScaledInplace<BE>
74 + VecZnxDftZero<BE>,
75 Scratch<BE>: ScratchTakeBasic,
76{
77 fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
78 self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
79 }
80
81 #[allow(clippy::too_many_arguments)]
82 fn bivariate_convolution_add<R, A, B>(
115 &self,
116 k: i64,
117 res: &mut R,
118 res_col: usize,
119 a: &A,
120 a_col: usize,
121 b: &B,
122 b_col: usize,
123 scratch: &mut Scratch<BE>,
124 ) where
125 R: VecZnxDftToMut<BE>,
126 A: VecZnxToRef,
127 B: VecZnxDftToRef<BE>,
128 {
129 let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
130 let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
131 let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
132
133 let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
134 let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, b.size());
135
136 for a_limb in 0..a.size() {
137 self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
138 self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
139 self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
140 }
141 }
142
143 #[allow(clippy::too_many_arguments)]
144 fn bivariate_convolution<R, A, B>(
145 &self,
146 k: i64,
147 res: &mut R,
148 res_col: usize,
149 a: &A,
150 a_col: usize,
151 b: &B,
152 b_col: usize,
153 scratch: &mut Scratch<BE>,
154 ) where
155 R: VecZnxDftToMut<BE>,
156 A: VecZnxToRef,
157 B: VecZnxDftToRef<BE>,
158 {
159 self.vec_znx_dft_zero(res, res_col);
160 self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
161 }
162}