// poulpy_hal/api/convolution.rs

1use crate::{
2    api::{
3        ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
4        VecZnxDftBytesOf,
5    },
6    layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxZero},
7};
8
9impl<BE: Backend> Convolution<BE> for Module<BE>
10where
11    Self: Sized
12        + ModuleN
13        + SvpPPolAlloc<BE>
14        + SvpApplyDftToDft<BE>
15        + SvpPrepare<BE>
16        + SvpPPolBytesOf
17        + VecZnxDftBytesOf
18        + VecZnxDftAddScaledInplace<BE>,
19    Scratch<BE>: ScratchTakeBasic,
20{
21}
22
23pub trait Convolution<BE: Backend>
24where
25    Self: Sized
26        + ModuleN
27        + SvpPPolAlloc<BE>
28        + SvpApplyDftToDft<BE>
29        + SvpPrepare<BE>
30        + SvpPPolBytesOf
31        + VecZnxDftBytesOf
32        + VecZnxDftAddScaledInplace<BE>,
33    Scratch<BE>: ScratchTakeBasic,
34{
35    fn convolution_tmp_bytes(&self, res_size: usize) -> usize {
36        self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, res_size)
37    }
38
39    /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K
40    /// and scales the result by 2^{res_scale * K}
41    ///
42    /// # Example
43    /// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
44    ///     [a01, a11, a21, a31]
45    ///
46    /// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
47    ///     [b01, b11, b21, b31]
48    ///
49    /// If res_scale = 0:
50    /// res = [  0,   0,   0,   0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
51    ///       [r01, r11, r21, r31]
52    ///       [r02, r12, r22, r32]
53    ///       [r03, r13, r23, r33]
54    ///       [r04, r14, r24, r34]
55    ///
56    /// If res_scale = 1:
57    /// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
58    ///       [r02, r12, r22, r32]
59    ///       [r03, r13, r23, r33]
60    ///       [r04, r14, r24, r34]
61    ///       [r05, r15, r25, r35]
62    ///
63    /// If res_scale = -1:
64    /// res = [  0,   0,   0,   0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
65    ///       [  0,   0,   0,   0]
66    ///       [r01, r11, r21, r31]
67    ///       [r02, r12, r22, r32]
68    ///       [r03, r13, r23, r33]
69    ///
70    /// If res.size() < a.size() + b.size() + 1 + res_scale, result is truncated accordingly in the Y dimension.
71    fn convolution<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>)
72    where
73        R: VecZnxDftToMut<BE>,
74        A: VecZnxToRef,
75        B: VecZnxDftToRef<BE>,
76    {
77        let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
78        let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
79        let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
80
81        assert!(res.cols() >= a.cols() + b.cols() - 1);
82
83        res.zero();
84
85        let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
86        let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, res.size());
87
88        for a_col in 0..a.cols() {
89            for a_limb in 0..a.size() {
90                // Prepares the j-th limb of the i-th col of A
91                self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
92
93                for b_col in 0..b.cols() {
94                    // Multiplies with the i-th col of B
95                    self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
96
97                    // Adds on the [a_col + b_col] of res, scaled by 2^{-(a_limb + 1) * Base2K}
98                    self.vec_znx_dft_add_scaled_inplace(
99                        res,
100                        a_col + b_col,
101                        &res_tmp,
102                        0,
103                        -(1 + a_limb as i64) + res_scale,
104                    );
105                }
106            }
107        }
108    }
109}