Skip to main content

poulpy_hal/api/
convolution.rs

1use crate::layouts::{
2    Backend, CnvPVecL, CnvPVecLBackendMut, CnvPVecLBackendRef, CnvPVecR, CnvPVecRBackendMut, CnvPVecRBackendRef, ScratchArena,
3    VecZnxBackendRef, VecZnxBigBackendMut, VecZnxDftBackendMut,
4};
5
6/// Allocates prepared convolution operands ([`CnvPVecL`], [`CnvPVecR`]).
7pub trait CnvPVecAlloc<BE: Backend> {
8    fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<BE::OwnedBuf, BE>;
9    fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<BE::OwnedBuf, BE>;
10}
11
/// Returns the byte sizes for prepared convolution operands.
///
/// Both methods take a number of columns (`cols`) and a `size`, and return a
/// byte count as `usize`.
pub trait CnvPVecBytesOf {
    /// Returns the number of bytes of a prepared left operand with `cols` columns and the given `size`.
    fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize;

    /// Returns the number of bytes of a prepared right operand with `cols` columns and the given `size`.
    fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize;
}
17
18/// Bivariate convolution over `Z[X, Y] mod (X^N + 1)` where `Y = 2^{-K}`.
19///
20/// Provides methods to prepare left/right operands and apply the convolution.
21/// See method-level documentation for the mathematical formulation.
22pub trait Convolution<BE: Backend> {
23    /// Returns scratch bytes required for [`cnv_prepare_left`](Convolution::cnv_prepare_left).
24    fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
25    /// Prepares a coefficient-domain [`VecZnx`](crate::layouts::VecZnx) as the left
26    /// operand of a bivariate convolution.
27    fn cnv_prepare_left(
28        &self,
29        res: &mut CnvPVecLBackendMut<'_, BE>,
30        a: &VecZnxBackendRef<'_, BE>,
31        mask: i64,
32        scratch: &mut ScratchArena<'_, BE>,
33    );
34
35    /// Returns scratch bytes required for [`cnv_prepare_right`](Convolution::cnv_prepare_right).
36    fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
37    /// Prepares a coefficient-domain [`VecZnx`](crate::layouts::VecZnx) as the right
38    /// operand of a bivariate convolution.
39    fn cnv_prepare_right(
40        &self,
41        res: &mut CnvPVecRBackendMut<'_, BE>,
42        a: &VecZnxBackendRef<'_, BE>,
43        mask: i64,
44        scratch: &mut ScratchArena<'_, BE>,
45    );
46
47    /// Returns scratch bytes required for [`cnv_apply_dft`](Convolution::cnv_apply_dft).
48    fn cnv_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize;
49
50    /// Returns scratch bytes required for [`cnv_by_const_apply`](Convolution::cnv_by_const_apply).
51    fn cnv_by_const_apply_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize;
52
53    /// Evaluates a bivariate convolution over Z\[X, Y\] (x) Z\[Y\] mod (X^N + 1) where Y = 2^-K over the
54    /// selected columns and stores the result on the selected column, scaled by 2^{cnv_offset * Base2K}
55    ///
56    /// Behavior is identical to [Convolution::cnv_apply_dft] with `b` treated as a constant polynomial
57    /// in the X variable, for example:
58    ///```text
59    ///       1    X   X^2  X^3
60    /// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
61    ///     Y [a01, a11, a21, a31]
62    ///
63    /// b = 1 [b0] = (b00 + b01 * 2^-K)
64    ///     Y [b0]
65    /// ```
66    /// This method is intended to be used for multiplications by constants that are greater than the base2k.
67    #[allow(clippy::too_many_arguments)]
68    fn cnv_by_const_apply(
69        &self,
70        cnv_offset: usize,
71        res: &mut VecZnxBigBackendMut<'_, BE>,
72        res_col: usize,
73        a: &VecZnxBackendRef<'_, BE>,
74        a_col: usize,
75        b: &VecZnxBackendRef<'_, BE>,
76        b_col: usize,
77        b_coeff: usize,
78        scratch: &mut ScratchArena<'_, BE>,
79    );
80
81    #[allow(clippy::too_many_arguments)]
82    /// Evaluates a bivariate convolution over Z\[X, Y\] (x) Z\[X, Y\] mod (X^N + 1) where Y = 2^-K over the
83    /// selected columns and stores the result on the selected column, scaled by 2^{cnv_offset * Base2K}
84    ///
85    /// # Example
86    ///```text
87    ///       1    X   X^2  X^3
88    /// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
89    ///     Y [a01, a11, a21, a31]
90    ///
91    /// b = 1 [b00, b10, b20, b30] = (b00 + b01 * 2^-K) + (b10 + b11 * 2^-K) * X ...
92    ///     Y [b01, b11, b21, b31]
93    ///
94    /// If cnv_offset = 0:
95    ///
96    ///            1    X   X^2  X^3
97    /// res = 1  [r00, r10, r20, r30] = (r00 + r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K) + ... * X + ...
98    ///       Y  [r01, r11, r21, r31]
99    ///       Y^2[r02, r12, r22, r32]
100    ///       Y^3[r03, r13, r23, r33]
101    ///
102    /// If cnv_offset = 1:
103    ///
104    ///            1    X   X^2  X^3
105    /// res = 1  [r01, r11, r21, r31]  = (r01 + r02 * 2^-K + r03 * 2^-2K) + ... * X + ...
106    ///       Y  [r02, r12, r22, r32]
107    ///       Y^2[r03, r13, r23, r33]
108    ///       Y^3[  0,   0,   0 ,  0]
109    /// ```
110    /// If res.size() < a.size() + b.size() + k, result is truncated accordingly in the Y dimension.
111    fn cnv_apply_dft(
112        &self,
113        cnv_offset: usize,
114        res: &mut VecZnxDftBackendMut<'_, BE>,
115        res_col: usize,
116        a: &CnvPVecLBackendRef<'_, BE>,
117        a_col: usize,
118        b: &CnvPVecRBackendRef<'_, BE>,
119        b_col: usize,
120        scratch: &mut ScratchArena<'_, BE>,
121    );
122
123    /// Returns scratch bytes required for [`cnv_pairwise_apply_dft`](Convolution::cnv_pairwise_apply_dft).
124    fn cnv_pairwise_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize;
125
126    #[allow(clippy::too_many_arguments)]
127    /// Evaluates the bivariate pair-wise convolution res = (a\[i\] + a\[j\]) * (b\[i\] + b\[j\]).
128    /// If i == j then calls [Convolution::cnv_apply_dft], i.e. res = a\[i\] * b\[i\].
129    /// See [Convolution::cnv_apply_dft] for information about the bivariate convolution.
130    fn cnv_pairwise_apply_dft(
131        &self,
132        cnv_offset: usize,
133        res: &mut VecZnxDftBackendMut<'_, BE>,
134        res_col: usize,
135        a: &CnvPVecLBackendRef<'_, BE>,
136        b: &CnvPVecRBackendRef<'_, BE>,
137        i: usize,
138        j: usize,
139        scratch: &mut ScratchArena<'_, BE>,
140    );
141
142    /// Returns scratch bytes required for [`cnv_prepare_self`](Convolution::cnv_prepare_self).
143    fn cnv_prepare_self_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
144
145    /// Prepares both left and right convolution operands from the same input polynomial,
146    /// sharing the FFT/NTT computation. This is an optimization for self-convolution
147    /// (squaring) where both operands are the same polynomial.
148    fn cnv_prepare_self(
149        &self,
150        left: &mut CnvPVecLBackendMut<'_, BE>,
151        right: &mut CnvPVecRBackendMut<'_, BE>,
152        a: &VecZnxBackendRef<'_, BE>,
153        mask: i64,
154        scratch: &mut ScratchArena<'_, BE>,
155    );
156}