// poulpy_hal/api/convolution.rs

1use crate::{
2    api::{
3        ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
4        VecZnxDftBytesOf, VecZnxDftZero,
5    },
6    layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
7};
8
9impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
10where
11    Self: BivariateConvolution<BE>,
12    Scratch<BE>: ScratchTakeBasic,
13{
14}
15
16pub trait BivariateTensoring<BE: Backend>
17where
18    Self: BivariateConvolution<BE>,
19    Scratch<BE>: ScratchTakeBasic,
20{
21    fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
22    where
23        R: VecZnxDftToMut<BE>,
24        A: VecZnxToRef,
25        B: VecZnxDftToRef<BE>,
26    {
27        let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
28        let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
29        let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
30
31        let res_cols: usize = res.cols();
32        let a_cols: usize = a.cols();
33        let b_cols: usize = b.cols();
34
35        assert!(res_cols >= a_cols + b_cols - 1);
36
37        for res_col in 0..res_cols {
38            self.vec_znx_dft_zero(res, res_col);
39        }
40
41        for a_col in 0..a_cols {
42            for b_col in 0..b_cols {
43                self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
44            }
45        }
46    }
47}
48
49impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
50where
51    Self: Sized
52        + ModuleN
53        + SvpPPolAlloc<BE>
54        + SvpApplyDftToDft<BE>
55        + SvpPrepare<BE>
56        + SvpPPolBytesOf
57        + VecZnxDftBytesOf
58        + VecZnxDftAddScaledInplace<BE>
59        + VecZnxDftZero<BE>,
60    Scratch<BE>: ScratchTakeBasic,
61{
62}
63
64pub trait BivariateConvolution<BE: Backend>
65where
66    Self: Sized
67        + ModuleN
68        + SvpPPolAlloc<BE>
69        + SvpApplyDftToDft<BE>
70        + SvpPrepare<BE>
71        + SvpPPolBytesOf
72        + VecZnxDftBytesOf
73        + VecZnxDftAddScaledInplace<BE>
74        + VecZnxDftZero<BE>,
75    Scratch<BE>: ScratchTakeBasic,
76{
77    fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
78        self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
79    }
80
81    #[allow(clippy::too_many_arguments)]
82    /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
83    /// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K}
84    ///
85    /// # Example
86    /// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
87    ///     [a01, a11, a21, a31]
88    ///
89    /// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
90    ///     [b01, b11, b21, b31]
91    ///
92    /// If k = 0:
93    /// res = [  0,   0,   0,   0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
94    ///       [r01, r11, r21, r31]
95    ///       [r02, r12, r22, r32]
96    ///       [r03, r13, r23, r33]
97    ///       [r04, r14, r24, r34]
98    ///
99    /// If k = 1:
100    /// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
101    ///       [r02, r12, r22, r32]
102    ///       [r03, r13, r23, r33]
103    ///       [r04, r14, r24, r34]
104    ///       [r05, r15, r25, r35]
105    ///
106    /// If k = -1:
107    /// res = [  0,   0,   0,   0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
108    ///       [  0,   0,   0,   0]
109    ///       [r01, r11, r21, r31]
110    ///       [r02, r12, r22, r32]
111    ///       [r03, r13, r23, r33]
112    ///
113    /// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
114    fn bivariate_convolution_add<R, A, B>(
115        &self,
116        k: i64,
117        res: &mut R,
118        res_col: usize,
119        a: &A,
120        a_col: usize,
121        b: &B,
122        b_col: usize,
123        scratch: &mut Scratch<BE>,
124    ) where
125        R: VecZnxDftToMut<BE>,
126        A: VecZnxToRef,
127        B: VecZnxDftToRef<BE>,
128    {
129        let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
130        let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
131        let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
132
133        let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
134        let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, b.size());
135
136        for a_limb in 0..a.size() {
137            self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
138            self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
139            self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
140        }
141    }
142
143    #[allow(clippy::too_many_arguments)]
144    fn bivariate_convolution<R, A, B>(
145        &self,
146        k: i64,
147        res: &mut R,
148        res_col: usize,
149        a: &A,
150        a_col: usize,
151        b: &B,
152        b_col: usize,
153        scratch: &mut Scratch<BE>,
154    ) where
155        R: VecZnxDftToMut<BE>,
156        A: VecZnxToRef,
157        B: VecZnxDftToRef<BE>,
158    {
159        self.vec_znx_dft_zero(res, res_col);
160        self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
161    }
162}