poulpy_hal/api/convolution.rs
1use crate::{
2 api::{
3 ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
4 VecZnxDftBytesOf,
5 },
6 layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxZero},
7};
8
9impl<BE: Backend> Convolution<BE> for Module<BE>
10where
11 Self: Sized
12 + ModuleN
13 + SvpPPolAlloc<BE>
14 + SvpApplyDftToDft<BE>
15 + SvpPrepare<BE>
16 + SvpPPolBytesOf
17 + VecZnxDftBytesOf
18 + VecZnxDftAddScaledInplace<BE>,
19 Scratch<BE>: ScratchTakeBasic,
20{
21}
22
23pub trait Convolution<BE: Backend>
24where
25 Self: Sized
26 + ModuleN
27 + SvpPPolAlloc<BE>
28 + SvpApplyDftToDft<BE>
29 + SvpPrepare<BE>
30 + SvpPPolBytesOf
31 + VecZnxDftBytesOf
32 + VecZnxDftAddScaledInplace<BE>,
33 Scratch<BE>: ScratchTakeBasic,
34{
35 fn convolution_tmp_bytes(&self, res_size: usize) -> usize {
36 self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, res_size)
37 }
38
39 /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K
40 /// and scales the result by 2^{res_scale * K}
41 ///
42 /// # Example
43 /// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
44 /// [a01, a11, a21, a31]
45 ///
46 /// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
47 /// [b01, b11, b21, b31]
48 ///
49 /// If res_scale = 0:
50 /// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
51 /// [r01, r11, r21, r31]
52 /// [r02, r12, r22, r32]
53 /// [r03, r13, r23, r33]
54 /// [r04, r14, r24, r34]
55 ///
56 /// If res_scale = 1:
57 /// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
58 /// [r02, r12, r22, r32]
59 /// [r03, r13, r23, r33]
60 /// [r04, r14, r24, r34]
61 /// [r05, r15, r25, r35]
62 ///
63 /// If res_scale = -1:
64 /// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
65 /// [ 0, 0, 0, 0]
66 /// [r01, r11, r21, r31]
67 /// [r02, r12, r22, r32]
68 /// [r03, r13, r23, r33]
69 ///
70 /// If res.size() < a.size() + b.size() + 1 + res_scale, result is truncated accordingly in the Y dimension.
71 fn convolution<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>)
72 where
73 R: VecZnxDftToMut<BE>,
74 A: VecZnxToRef,
75 B: VecZnxDftToRef<BE>,
76 {
77 let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
78 let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
79 let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
80
81 assert!(res.cols() >= a.cols() + b.cols() - 1);
82
83 res.zero();
84
85 let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
86 let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, res.size());
87
88 for a_col in 0..a.cols() {
89 for a_limb in 0..a.size() {
90 // Prepares the j-th limb of the i-th col of A
91 self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
92
93 for b_col in 0..b.cols() {
94 // Multiplies with the i-th col of B
95 self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
96
97 // Adds on the [a_col + b_col] of res, scaled by 2^{-(a_limb + 1) * Base2K}
98 self.vec_znx_dft_add_scaled_inplace(
99 res,
100 a_col + b_col,
101 &res_tmp,
102 0,
103 -(1 + a_limb as i64) + res_scale,
104 );
105 }
106 }
107 }
108 }
109}