poulpy_hal/api/convolution.rs
1use crate::layouts::{
2 Backend, CnvPVecL, CnvPVecLBackendMut, CnvPVecLBackendRef, CnvPVecR, CnvPVecRBackendMut, CnvPVecRBackendRef, ScratchArena,
3 VecZnxBackendRef, VecZnxBigBackendMut, VecZnxDftBackendMut,
4};
5
6/// Allocates prepared convolution operands ([`CnvPVecL`], [`CnvPVecR`]).
7pub trait CnvPVecAlloc<BE: Backend> {
8 fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<BE::OwnedBuf, BE>;
9 fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<BE::OwnedBuf, BE>;
10}
11
/// Reports the size, in bytes, of prepared convolution operands
/// ([`CnvPVecL`] and [`CnvPVecR`]) for a given shape.
pub trait CnvPVecBytesOf {
    /// Byte size of a prepared left operand ([`CnvPVecL`]) with `cols` columns and size `size`.
    fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize;

    /// Byte size of a prepared right operand ([`CnvPVecR`]) with `cols` columns and size `size`.
    fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize;
}
17
18/// Bivariate convolution over `Z[X, Y] mod (X^N + 1)` where `Y = 2^{-K}`.
19///
20/// Provides methods to prepare left/right operands and apply the convolution.
21/// See method-level documentation for the mathematical formulation.
22pub trait Convolution<BE: Backend> {
23 /// Returns scratch bytes required for [`cnv_prepare_left`](Convolution::cnv_prepare_left).
24 fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
25 /// Prepares a coefficient-domain [`VecZnx`](crate::layouts::VecZnx) as the left
26 /// operand of a bivariate convolution.
27 fn cnv_prepare_left(
28 &self,
29 res: &mut CnvPVecLBackendMut<'_, BE>,
30 a: &VecZnxBackendRef<'_, BE>,
31 mask: i64,
32 scratch: &mut ScratchArena<'_, BE>,
33 );
34
35 /// Returns scratch bytes required for [`cnv_prepare_right`](Convolution::cnv_prepare_right).
36 fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
37 /// Prepares a coefficient-domain [`VecZnx`](crate::layouts::VecZnx) as the right
38 /// operand of a bivariate convolution.
39 fn cnv_prepare_right(
40 &self,
41 res: &mut CnvPVecRBackendMut<'_, BE>,
42 a: &VecZnxBackendRef<'_, BE>,
43 mask: i64,
44 scratch: &mut ScratchArena<'_, BE>,
45 );
46
47 /// Returns scratch bytes required for [`cnv_apply_dft`](Convolution::cnv_apply_dft).
48 fn cnv_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize;
49
50 /// Returns scratch bytes required for [`cnv_by_const_apply`](Convolution::cnv_by_const_apply).
51 fn cnv_by_const_apply_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize;
52
53 /// Evaluates a bivariate convolution over Z\[X, Y\] (x) Z\[Y\] mod (X^N + 1) where Y = 2^-K over the
54 /// selected columns and stores the result on the selected column, scaled by 2^{cnv_offset * Base2K}
55 ///
56 /// Behavior is identical to [Convolution::cnv_apply_dft] with `b` treated as a constant polynomial
57 /// in the X variable, for example:
58 ///```text
59 /// 1 X X^2 X^3
60 /// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
61 /// Y [a01, a11, a21, a31]
62 ///
63 /// b = 1 [b0] = (b00 + b01 * 2^-K)
64 /// Y [b0]
65 /// ```
66 /// This method is intended to be used for multiplications by constants that are greater than the base2k.
67 #[allow(clippy::too_many_arguments)]
68 fn cnv_by_const_apply(
69 &self,
70 cnv_offset: usize,
71 res: &mut VecZnxBigBackendMut<'_, BE>,
72 res_col: usize,
73 a: &VecZnxBackendRef<'_, BE>,
74 a_col: usize,
75 b: &VecZnxBackendRef<'_, BE>,
76 b_col: usize,
77 b_coeff: usize,
78 scratch: &mut ScratchArena<'_, BE>,
79 );
80
81 #[allow(clippy::too_many_arguments)]
82 /// Evaluates a bivariate convolution over Z\[X, Y\] (x) Z\[X, Y\] mod (X^N + 1) where Y = 2^-K over the
83 /// selected columns and stores the result on the selected column, scaled by 2^{cnv_offset * Base2K}
84 ///
85 /// # Example
86 ///```text
87 /// 1 X X^2 X^3
88 /// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
89 /// Y [a01, a11, a21, a31]
90 ///
91 /// b = 1 [b00, b10, b20, b30] = (b00 + b01 * 2^-K) + (b10 + b11 * 2^-K) * X ...
92 /// Y [b01, b11, b21, b31]
93 ///
94 /// If cnv_offset = 0:
95 ///
96 /// 1 X X^2 X^3
97 /// res = 1 [r00, r10, r20, r30] = (r00 + r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K) + ... * X + ...
98 /// Y [r01, r11, r21, r31]
99 /// Y^2[r02, r12, r22, r32]
100 /// Y^3[r03, r13, r23, r33]
101 ///
102 /// If cnv_offset = 1:
103 ///
104 /// 1 X X^2 X^3
105 /// res = 1 [r01, r11, r21, r31] = (r01 + r02 * 2^-K + r03 * 2^-2K) + ... * X + ...
106 /// Y [r02, r12, r22, r32]
107 /// Y^2[r03, r13, r23, r33]
108 /// Y^3[ 0, 0, 0 , 0]
109 /// ```
110 /// If res.size() < a.size() + b.size() + k, result is truncated accordingly in the Y dimension.
111 fn cnv_apply_dft(
112 &self,
113 cnv_offset: usize,
114 res: &mut VecZnxDftBackendMut<'_, BE>,
115 res_col: usize,
116 a: &CnvPVecLBackendRef<'_, BE>,
117 a_col: usize,
118 b: &CnvPVecRBackendRef<'_, BE>,
119 b_col: usize,
120 scratch: &mut ScratchArena<'_, BE>,
121 );
122
123 /// Returns scratch bytes required for [`cnv_pairwise_apply_dft`](Convolution::cnv_pairwise_apply_dft).
124 fn cnv_pairwise_apply_dft_tmp_bytes(&self, cnv_offset: usize, res_size: usize, a_size: usize, b_size: usize) -> usize;
125
126 #[allow(clippy::too_many_arguments)]
127 /// Evaluates the bivariate pair-wise convolution res = (a\[i\] + a\[j\]) * (b\[i\] + b\[j\]).
128 /// If i == j then calls [Convolution::cnv_apply_dft], i.e. res = a\[i\] * b\[i\].
129 /// See [Convolution::cnv_apply_dft] for information about the bivariate convolution.
130 fn cnv_pairwise_apply_dft(
131 &self,
132 cnv_offset: usize,
133 res: &mut VecZnxDftBackendMut<'_, BE>,
134 res_col: usize,
135 a: &CnvPVecLBackendRef<'_, BE>,
136 b: &CnvPVecRBackendRef<'_, BE>,
137 i: usize,
138 j: usize,
139 scratch: &mut ScratchArena<'_, BE>,
140 );
141
142 /// Returns scratch bytes required for [`cnv_prepare_self`](Convolution::cnv_prepare_self).
143 fn cnv_prepare_self_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
144
145 /// Prepares both left and right convolution operands from the same input polynomial,
146 /// sharing the FFT/NTT computation. This is an optimization for self-convolution
147 /// (squaring) where both operands are the same polynomial.
148 fn cnv_prepare_self(
149 &self,
150 left: &mut CnvPVecLBackendMut<'_, BE>,
151 right: &mut CnvPVecRBackendMut<'_, BE>,
152 a: &VecZnxBackendRef<'_, BE>,
153 mask: i64,
154 scratch: &mut ScratchArena<'_, BE>,
155 );
156}