1use std::alloc::{Allocator, Global};
2
3use crate::cow::*;
4use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
5use crate::lazy::LazyVec;
6use crate::homomorphism::*;
7use crate::primitive_int::StaticRing;
8use crate::ring::*;
9use crate::rings::zn::*;
10use crate::integer::*;
11use crate::seq::VectorView;
12
13use super::ConvolutionAlgorithm;
14
15#[stability::unstable(feature = "enable")]
21pub struct NTTConvolution<R_main, R_twiddle, H, A = Global>
22 where R_main: ?Sized + ZnRing,
23 R_twiddle: ?Sized + ZnRing,
24 H: Homomorphism<R_twiddle, R_main> + Clone,
25 A: Allocator + Clone
26{
27 hom: H,
28 fft_algos: LazyVec<CooleyTuckeyFFT<R_main, R_twiddle, H>>,
29 allocator: A
30}
31
32#[stability::unstable(feature = "enable")]
36pub struct PreparedConvolutionOperand<R, A = Global>
37 where R: ?Sized + ZnRing,
38 A: Allocator + Clone
39{
40 significant_entries: usize,
41 ntt_data: Vec<R::Element, A>
42}
43
44impl<R> NTTConvolution<R::Type, R::Type, Identity<R>>
45 where R: RingStore + Clone,
46 R::Type: ZnRing
47{
48 #[stability::unstable(feature = "enable")]
56 pub fn new(ring: R) -> Self {
57 Self::new_with_hom(ring.into_identity(), Global)
58 }
59}
60
61impl<R_main, R_twiddle, H, A> NTTConvolution<R_main, R_twiddle, H, A>
62 where R_main: ?Sized + ZnRing,
63 R_twiddle: ?Sized + ZnRing,
64 H: Homomorphism<R_twiddle, R_main> + Clone,
65 A: Allocator + Clone
66{
67 #[stability::unstable(feature = "enable")]
79 pub fn new_with_hom(hom: H, allocator: A) -> Self {
80 Self {
81 fft_algos: LazyVec::new(),
82 hom: hom,
83 allocator: allocator
84 }
85 }
86
87 #[stability::unstable(feature = "enable")]
91 pub fn ring(&self) -> RingRef<'_, R_main> {
92 RingRef::new(self.hom.codomain().get_ring())
93 }
94
95 fn get_ntt_table<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R_main, R_twiddle, H> {
96 self.fft_algos.get_or_init(log2_n, || CooleyTuckeyFFT::for_zn_with_hom(self.hom.clone(), log2_n).expect("NTTConvolution was instantiated with parameters that don't support this length"))
97 }
98
99 fn get_ntt_data<'a, V>(
100 &self,
101 data: V,
102 data_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
103 significant_entries: usize,
104 ) -> MyCow<'a, Vec<R_main::Element, A>>
105 where V: VectorView<R_main::Element>
106 {
107 assert!(data.len() <= significant_entries);
108 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();
109
110 let compute_result = || {
111 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
112 result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
113 result.resize_with(1 << log2_len, || self.ring().zero());
114 self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);
115 return result;
116 };
117
118 return if let Some(data_prep) = data_prep {
119 assert!(data_prep.significant_entries >= significant_entries);
120 MyCow::Borrowed(&data_prep.ntt_data)
121 } else {
122 MyCow::Owned(compute_result())
123 }
124 }
125
126 fn compute_convolution_ntt<'a, V1, V2>(&self,
127 lhs: V1,
128 mut lhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
129 rhs: V2,
130 mut rhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
131 len: usize
132 ) -> MyCow<'a, Vec<R_main::Element, A>>
133 where V1: VectorView<R_main::Element>,
134 V2: VectorView<R_main::Element>
135 {
136 if lhs.len() == 0 || rhs.len() == 0 {
137 return MyCow::Owned(Vec::new_in(self.allocator.clone()));
138 }
139 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
140
141 if lhs_prep.is_some() && (lhs_prep.unwrap().significant_entries < len || lhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
142 lhs_prep = None;
143 }
144 if rhs_prep.is_some() && (rhs_prep.unwrap().significant_entries < len || rhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
145 rhs_prep = None;
146 }
147
148 let mut lhs_ntt = self.get_ntt_data(lhs, lhs_prep, len);
149 let mut rhs_ntt = self.get_ntt_data(rhs, rhs_prep, len);
150 if rhs_ntt.is_owned() {
151 std::mem::swap(&mut lhs_ntt, &mut rhs_ntt);
152 }
153 let lhs_ntt_data = lhs_ntt.to_mut_with(|data| {
154 let mut copied_data = Vec::with_capacity_in(data.len(), self.allocator.clone());
155 copied_data.extend(data.iter().map(|x| self.ring().clone_el(x)));
156 copied_data
157 });
158
159 for i in 0..len {
160 self.ring().mul_assign_ref(&mut lhs_ntt_data[i], &rhs_ntt[i]);
161 }
162 return lhs_ntt;
163 }
164
165 fn prepare_convolution_impl<V>(
166 &self,
167 data: V,
168 len_hint: Option<usize>
169 ) -> PreparedConvolutionOperand<R_main, A>
170 where V: VectorView<R_main::Element>
171 {
172 let significant_entries = if let Some(out_len) = len_hint {
173 assert!(data.len() <= out_len);
174 out_len
175 } else {
176 2 * data.len()
177 };
178 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();
179
180 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
181 result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
182 result.resize_with(1 << log2_len, || self.ring().zero());
183 self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);
184
185 return PreparedConvolutionOperand {
186 ntt_data: result,
187 significant_entries: significant_entries
188 };
189 }
190
191 fn compute_convolution_impl<V1, V2>(
192 &self,
193 lhs: V1,
194 lhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
195 rhs: V2,
196 rhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
197 dst: &mut [R_main::Element]
198 )
199 where V1: VectorView<R_main::Element>,
200 V2: VectorView<R_main::Element>
201 {
202 assert!(lhs.len() + rhs.len() - 1 <= dst.len());
203 let len = lhs.len() + rhs.len() - 1;
204 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
205
206 let mut lhs_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
207 let lhs_ntt = lhs_ntt.to_mut_with(|_| unreachable!());
208 self.get_ntt_table(log2_len).unordered_truncated_fft_inv(&mut lhs_ntt[..], len);
209 for (i, x) in lhs_ntt.drain(..).enumerate().take(len) {
210 self.ring().add_assign(&mut dst[i], x);
211 }
212 }
213
214 fn compute_convolution_sum_impl<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
215 where S: RingStore<Type = R_main> + Copy,
216 I: ExactSizeIterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R_main, A>>, V2, Option<&'a PreparedConvolutionOperand<R_main, A>>)>,
217 V1: VectorView<R_main::Element>,
218 V2: VectorView<R_main::Element>,
219 Self: 'a,
220 R_main: 'a
221 {
222 let len = dst.len();
223 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
224
225 let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
226 buffer.resize_with(1 << log2_len, || ring.zero());
227
228 for (lhs, lhs_prep, rhs, rhs_prep) in values {
229 assert!(lhs.len() + rhs.len() - 1 <= len);
230
231 let res_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
232 for i in 0..len {
233 self.ring().add_assign_ref(&mut buffer[i], &res_ntt[i]);
234 }
235 }
236 self.get_ntt_table(log2_len).unordered_truncated_fft_inv(&mut buffer, len);
237 for (i, x) in buffer.drain(..).enumerate().take(len) {
238 self.ring().add_assign(&mut dst[i], x);
239 }
240 }
241}
242
243impl<R_main, R_twiddle, H, A> ConvolutionAlgorithm<R_main> for NTTConvolution<R_main, R_twiddle, H, A>
244 where R_main: ?Sized + ZnRing,
245 R_twiddle: ?Sized + ZnRing,
246 H: Homomorphism<R_twiddle, R_main> + Clone,
247 A: Allocator + Clone
248{
249 type PreparedConvolutionOperand = PreparedConvolutionOperand<R_main, A>;
250
251 fn supports_ring<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> bool {
252 ring.get_ring() == self.ring().get_ring()
253 }
254
255 fn compute_convolution<S: RingStore<Type = R_main> + Copy, V1: VectorView<<R_main as RingBase>::Element>, V2: VectorView<<R_main as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R_main::Element], ring: S) {
256 assert!(self.supports_ring(ring));
257 self.compute_convolution_impl(
258 lhs,
259 None,
260 rhs,
261 None,
262 dst
263 )
264 }
265
266 fn prepare_convolution_operand<S, V>(&self, val: V, length_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
267 where S: RingStore<Type = R_main> + Copy, V: VectorView<R_main::Element>
268 {
269 assert!(self.supports_ring(ring));
270 self.prepare_convolution_impl(
271 val,
272 length_hint
273 )
274 }
275
276 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R_main::Element], ring: S)
277 where S: RingStore<Type = R_main> + Copy, V1: VectorView<R_main::Element>, V2: VectorView<R_main::Element>
278 {
279 assert!(self.supports_ring(ring));
280 self.compute_convolution_impl(
281 lhs,
282 lhs_prep,
283 rhs,
284 rhs_prep,
285 dst
286 )
287 }
288
289 fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
290 where S: RingStore<Type = R_main> + Copy,
291 I: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
292 V1: VectorView<R_main::Element>,
293 V2: VectorView<R_main::Element>,
294 Self: 'a,
295 R_main: 'a
296 {
297 assert!(self.supports_ring(ring));
298 self.compute_convolution_sum_impl(
299 values,
300 dst,
301 ring
302 )
303 }
304}
305
306#[cfg(test)]
307use test::Bencher;
308#[cfg(test)]
309use crate::algorithms::convolution::STANDARD_CONVOLUTION;
310#[cfg(test)]
311use crate::rings::zn::zn_64::{Zn, ZnBase, ZnEl};
312
#[test]
fn test_convolution() {
    // 65537 = 2^16 + 1 is prime, so its unit group (order 2^16) contains power-of-two roots of unity
    let field = Zn::new(65537);
    let ntt = NTTConvolution::new(field);
    super::generic_tests::test_convolution(&ntt, &field, field.one());
}
319
#[cfg(test)]
/// Shared driver for the convolution-sum benchmarks.
///
/// In iteration `i` (starting at 1), `f` receives 256 pairs `(v_j, None, v_j, None)` with
/// `v_j[k] = i * j * k`. It is expected to add `sum_j conv(v_j, v_j)` to a zeroed buffer of
/// 511 entries. Since `conv(v_j, v_j) = (i * j)^2 * conv(base, base)` for `base[k] = k`, the
/// result must equal `conv(base, base)` scaled by `i^2 * sum_{j < 256} j^2`. That sum is
/// `i^2 * 128 * 511 * 85`.
fn run_benchmark<F>(ring: Zn, bencher: &mut Bencher, mut f: F)
    where F: for<'a> FnMut(&mut dyn ExactSizeIterator<Item = (Vec<ZnEl>, Option<&'a PreparedConvolutionOperand<ZnBase>>, Vec<ZnEl>, Option<&'a PreparedConvolutionOperand<ZnBase>>)>, &mut [ZnEl], Zn)
{
    let base = (0..256).map(|k| ring.int_hom().map(k)).collect::<Vec<_>>();
    let mut reference = (0..512).map(|_| ring.zero()).collect::<Vec<_>>();
    STANDARD_CONVOLUTION.compute_convolution(&base, &base, &mut reference, ring);

    let from_int = ring.can_hom(&StaticRing::<i64>::RING).unwrap();
    let mut iteration: i64 = 1;
    let mut output = Vec::with_capacity(511);
    bencher.iter(|| {
        output.clear();
        output.resize_with(511, || ring.zero());
        let scaled = |scale: i64| (0..256).map(|k: i64| from_int.map(scale * k)).collect::<Vec<_>>();
        f(
            &mut (0..256).map(|j: i32| (scaled(iteration * i64::from(j)), None, scaled(iteration * i64::from(j)), None)),
            &mut output,
            ring
        );
        let factor = from_int.map(iteration * iteration * 128 * 511 * 85);
        for (expected, actual) in reference.iter().zip(output.iter()) {
            assert_el_eq!(ring, ring.mul_ref(expected, &factor), actual);
        }
        iteration += 1;
    });
}
356
357#[bench]
358fn bench_convolution_sum(bencher: &mut Bencher) {
359 let ring = zn_64::Zn::new(65537);
360 let convolution = NTTConvolution::new(ring);
361
362 run_benchmark(ring, bencher, |values, dst, ring| convolution.compute_convolution_sum_impl(values, dst, ring));
363}
364
365#[bench]
366fn bench_convolution_sum_default(bencher: &mut Bencher) {
367 let ring = zn_64::Zn::new(65537);
368 let convolution = NTTConvolution::new(ring);
369
370 run_benchmark(ring, bencher, |values, dst, ring| {
371 for (lhs, lhs_prep, rhs, rhs_prep) in values {
372 convolution.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
373 }
374 });
375}