1use std::alloc::{Allocator, Global};
2
3use super::ConvolutionAlgorithm;
4use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
5use crate::cow::*;
6use crate::homomorphism::*;
7use crate::integer::*;
8use crate::lazy::LazyVec;
9use crate::primitive_int::StaticRing;
10use crate::ring::*;
11use crate::rings::zn::*;
12use crate::seq::VectorView;
13
/// Computes convolutions of sequences of `Z/nZ`-elements by pointwise
/// multiplication of power-of-two-length NTTs (number-theoretic transforms,
/// i.e. Fourier transforms over `Z/nZ`).
///
/// NTT tables are created lazily, one per required power-of-two length, and
/// cached for reuse across convolutions.
#[stability::unstable(feature = "enable")]
pub struct NTTConvolution<R_main, R_twiddle, H, A = Global>
where
    R_main: ?Sized + ZnRing,
    R_twiddle: ?Sized + ZnRing,
    H: Homomorphism<R_twiddle, R_main> + Clone,
    A: Allocator + Clone,
{
    // homomorphism from the ring containing the twiddle factors into the
    // main ring the convolved elements live in
    hom: H,
    // lazily-initialized FFT tables; index `i` holds the table of length `2^i`
    fft_algos: LazyVec<CooleyTuckeyFFT<R_main, R_twiddle, H>>,
    allocator: A,
}
29
/// The precomputed NTT of a convolution operand, so that it can be reused
/// across several convolutions with [`NTTConvolution`].
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
where
    R: ?Sized + ZnRing,
    A: Allocator + Clone,
{
    // number of entries of the truncated NTT that were actually computed;
    // `ntt_data` has the next-larger (or equal) power-of-two length
    significant_entries: usize,
    // unordered, truncated NTT values of the (zero-padded) operand
    ntt_data: Vec<R::Element, A>,
}
40
41impl<R> NTTConvolution<R::Type, R::Type, Identity<R>>
42where
43 R: RingStore + Clone,
44 R::Type: ZnRing,
45{
46 #[stability::unstable(feature = "enable")]
52 pub fn new(ring: R) -> Self { Self::new_with_hom(ring.into_identity(), Global) }
53}
54
55impl<R_main, R_twiddle, H, A> NTTConvolution<R_main, R_twiddle, H, A>
56where
57 R_main: ?Sized + ZnRing,
58 R_twiddle: ?Sized + ZnRing,
59 H: Homomorphism<R_twiddle, R_main> + Clone,
60 A: Allocator + Clone,
61{
62 #[stability::unstable(feature = "enable")]
72 pub fn new_with_hom(hom: H, allocator: A) -> Self {
73 Self {
74 fft_algos: LazyVec::new(),
75 hom,
76 allocator,
77 }
78 }
79
80 #[stability::unstable(feature = "enable")]
82 pub fn ring(&self) -> RingRef<'_, R_main> { RingRef::new(self.hom.codomain().get_ring()) }
83
84 fn get_ntt_table(&self, log2_n: usize) -> &CooleyTuckeyFFT<R_main, R_twiddle, H> {
85 self.fft_algos.get_or_init(log2_n, || {
86 CooleyTuckeyFFT::for_zn_with_hom(self.hom.clone(), log2_n)
87 .expect("NTTConvolution was instantiated with parameters that don't support this length")
88 })
89 }
90
91 fn get_ntt_data<'a, V>(
92 &self,
93 data: V,
94 data_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
95 significant_entries: usize,
96 ) -> MyCow<'a, Vec<R_main::Element, A>>
97 where
98 V: VectorView<R_main::Element>,
99 {
100 assert!(data.len() <= significant_entries);
101 let log2_len = StaticRing::<i64>::RING
102 .abs_log2_ceil(&significant_entries.try_into().unwrap())
103 .unwrap();
104
105 let compute_result = || {
106 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
107 result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
108 result.resize_with(1 << log2_len, || self.ring().zero());
109 self.get_ntt_table(log2_len)
110 .unordered_truncated_fft(&mut result, significant_entries);
111 return result;
112 };
113
114 return if let Some(data_prep) = data_prep {
115 assert!(data_prep.significant_entries >= significant_entries);
116 MyCow::Borrowed(&data_prep.ntt_data)
117 } else {
118 MyCow::Owned(compute_result())
119 };
120 }
121
122 fn compute_convolution_ntt<'a, V1, V2>(
123 &self,
124 lhs: V1,
125 mut lhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
126 rhs: V2,
127 mut rhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
128 len: usize,
129 ) -> MyCow<'a, Vec<R_main::Element, A>>
130 where
131 V1: VectorView<R_main::Element>,
132 V2: VectorView<R_main::Element>,
133 {
134 if lhs.len() == 0 || rhs.len() == 0 {
135 return MyCow::Owned(Vec::new_in(self.allocator.clone()));
136 }
137 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
138
139 if lhs_prep.is_some()
140 && (lhs_prep.unwrap().significant_entries < len || lhs_prep.unwrap().ntt_data.len() != 1 << log2_len)
141 {
142 lhs_prep = None;
143 }
144 if rhs_prep.is_some()
145 && (rhs_prep.unwrap().significant_entries < len || rhs_prep.unwrap().ntt_data.len() != 1 << log2_len)
146 {
147 rhs_prep = None;
148 }
149
150 let mut lhs_ntt = self.get_ntt_data(lhs, lhs_prep, len);
151 let mut rhs_ntt = self.get_ntt_data(rhs, rhs_prep, len);
152 if rhs_ntt.is_owned() {
153 std::mem::swap(&mut lhs_ntt, &mut rhs_ntt);
154 }
155 let lhs_ntt_data = lhs_ntt.to_mut_with(|data| {
156 let mut copied_data = Vec::with_capacity_in(data.len(), self.allocator.clone());
157 copied_data.extend(data.iter().map(|x| self.ring().clone_el(x)));
158 copied_data
159 });
160
161 for i in 0..len {
162 self.ring().mul_assign_ref(&mut lhs_ntt_data[i], &rhs_ntt[i]);
163 }
164 return lhs_ntt;
165 }
166
167 fn prepare_convolution_impl<V>(&self, data: V, len_hint: Option<usize>) -> PreparedConvolutionOperand<R_main, A>
168 where
169 V: VectorView<R_main::Element>,
170 {
171 let significant_entries = if let Some(out_len) = len_hint {
172 assert!(data.len() <= out_len);
173 out_len
174 } else {
175 2 * data.len()
176 };
177 let log2_len = StaticRing::<i64>::RING
178 .abs_log2_ceil(&significant_entries.try_into().unwrap())
179 .unwrap();
180
181 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
182 result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
183 result.resize_with(1 << log2_len, || self.ring().zero());
184 self.get_ntt_table(log2_len)
185 .unordered_truncated_fft(&mut result, significant_entries);
186
187 return PreparedConvolutionOperand {
188 ntt_data: result,
189 significant_entries,
190 };
191 }
192
193 fn compute_convolution_impl<V1, V2>(
194 &self,
195 lhs: V1,
196 lhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
197 rhs: V2,
198 rhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
199 dst: &mut [R_main::Element],
200 ) where
201 V1: VectorView<R_main::Element>,
202 V2: VectorView<R_main::Element>,
203 {
204 assert!(lhs.len() + rhs.len() - 1 <= dst.len());
205 let len = lhs.len() + rhs.len() - 1;
206 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
207
208 let mut lhs_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
209 let lhs_ntt = lhs_ntt.to_mut_with(|_| unreachable!());
210 self.get_ntt_table(log2_len)
211 .unordered_truncated_fft_inv(&mut lhs_ntt[..], len);
212 for (i, x) in lhs_ntt.drain(..).enumerate().take(len) {
213 self.ring().add_assign(&mut dst[i], x);
214 }
215 }
216
217 fn compute_convolution_sum_impl<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
218 where
219 S: RingStore<Type = R_main> + Copy,
220 I: ExactSizeIterator<
221 Item = (
222 V1,
223 Option<&'a PreparedConvolutionOperand<R_main, A>>,
224 V2,
225 Option<&'a PreparedConvolutionOperand<R_main, A>>,
226 ),
227 >,
228 V1: VectorView<R_main::Element>,
229 V2: VectorView<R_main::Element>,
230 Self: 'a,
231 R_main: 'a,
232 {
233 let len = dst.len();
234 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
235
236 let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
237 buffer.resize_with(1 << log2_len, || ring.zero());
238
239 for (lhs, lhs_prep, rhs, rhs_prep) in values {
240 assert!(lhs.len() + rhs.len() - 1 <= len);
241
242 let res_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
243 for i in 0..len {
244 self.ring().add_assign_ref(&mut buffer[i], &res_ntt[i]);
245 }
246 }
247 self.get_ntt_table(log2_len)
248 .unordered_truncated_fft_inv(&mut buffer, len);
249 for (i, x) in buffer.drain(..).enumerate().take(len) {
250 self.ring().add_assign(&mut dst[i], x);
251 }
252 }
253}
254
impl<R_main, R_twiddle, H, A> ConvolutionAlgorithm<R_main> for NTTConvolution<R_main, R_twiddle, H, A>
where
    R_main: ?Sized + ZnRing,
    R_twiddle: ?Sized + ZnRing,
    H: Homomorphism<R_twiddle, R_main> + Clone,
    A: Allocator + Clone,
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R_main, A>;

    // only the exact ring this convolution object was created for is supported
    fn supports_ring<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> bool {
        ring.get_ring() == self.ring().get_ring()
    }

    // delegates to the internal implementation without prepared operands
    fn compute_convolution<
        S: RingStore<Type = R_main> + Copy,
        V1: VectorView<<R_main as RingBase>::Element>,
        V2: VectorView<<R_main as RingBase>::Element>,
    >(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [R_main::Element],
        ring: S,
    ) {
        assert!(self.supports_ring(ring));
        self.compute_convolution_impl(lhs, None, rhs, None, dst)
    }

    // precomputes the NTT of `val` for reuse in later convolutions
    fn prepare_convolution_operand<S, V>(
        &self,
        val: V,
        length_hint: Option<usize>,
        ring: S,
    ) -> Self::PreparedConvolutionOperand
    where
        S: RingStore<Type = R_main> + Copy,
        V: VectorView<R_main::Element>,
    {
        assert!(self.supports_ring(ring));
        self.prepare_convolution_impl(val, length_hint)
    }

    // delegates to the internal implementation, passing prepared operands through
    fn compute_convolution_prepared<S, V1, V2>(
        &self,
        lhs: V1,
        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
        rhs: V2,
        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
        dst: &mut [R_main::Element],
        ring: S,
    ) where
        S: RingStore<Type = R_main> + Copy,
        V1: VectorView<R_main::Element>,
        V2: VectorView<R_main::Element>,
    {
        assert!(self.supports_ring(ring));
        self.compute_convolution_impl(lhs, lhs_prep, rhs, rhs_prep, dst)
    }

    // delegates to the specialized sum implementation, which accumulates
    // in NTT domain
    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
    where
        S: RingStore<Type = R_main> + Copy,
        I: ExactSizeIterator<
            Item = (
                V1,
                Option<&'a Self::PreparedConvolutionOperand>,
                V2,
                Option<&'a Self::PreparedConvolutionOperand>,
            ),
        >,
        V1: VectorView<R_main::Element>,
        V2: VectorView<R_main::Element>,
        Self: 'a,
        R_main: 'a,
    {
        assert!(self.supports_ring(ring));
        self.compute_convolution_sum_impl(values, dst, ring)
    }
}
334
335#[cfg(test)]
336use test::Bencher;
337
338#[cfg(test)]
339use crate::algorithms::convolution::STANDARD_CONVOLUTION;
340#[cfg(test)]
341use crate::rings::zn::zn_64::{Zn, ZnBase, ZnEl};
342
#[test]
fn test_convolution() {
    // run the generic convolution test suite over `Z/65537Z`
    let base_ring = zn_64::Zn::new(65537);
    let conv = NTTConvolution::new(base_ring);
    super::generic_tests::test_convolution(&conv, &base_ring, base_ring.one());
}
349
#[cfg(test)]
fn run_benchmark<F>(ring: Zn, bencher: &mut Bencher, mut f: F)
where
    F: for<'a> FnMut(
        &mut dyn ExactSizeIterator<
            Item = (
                Vec<ZnEl>,
                Option<&'a PreparedConvolutionOperand<ZnBase>>,
                Vec<ZnEl>,
                Option<&'a PreparedConvolutionOperand<ZnBase>>,
            ),
        >,
        &mut [ZnEl],
        Zn,
    ),
{
    // reference: the square of the polynomial `0 + 1 X + ... + 255 X^255`,
    // computed by the naive convolution algorithm
    let value = (0..256).map(|i| ring.int_hom().map(i)).collect::<Vec<_>>();
    let mut expected = (0..512).map(|_| ring.zero()).collect::<Vec<_>>();
    STANDARD_CONVOLUTION.compute_convolution(&value, &value, &mut expected, ring);

    let hom = ring.can_hom(&StaticRing::<i64>::RING).unwrap();
    let mut round = 1;
    let mut actual = Vec::with_capacity(511);
    bencher.iter(|| {
        actual.clear();
        actual.resize_with(511, || ring.zero());
        // each summand is a scaled copy of the reference operand, so the total
        // sum is a known multiple of the reference result
        let mut summands = (0..256).map(|j| {
            let operand = (0..256).map(|k| hom.map(round * j as i64 * k)).collect::<Vec<_>>();
            (operand.clone(), None, operand, None)
        });
        f(&mut summands, &mut actual, ring);
        let factor = hom.map(round * round * 128 * 511 * 85);
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert_el_eq!(ring, ring.mul_ref(want, &factor), got);
        }
        round += 1;
    });
}
395
396#[bench]
397fn bench_convolution_sum(bencher: &mut Bencher) {
398 let ring = zn_64::Zn::new(65537);
399 let convolution = NTTConvolution::new(ring);
400
401 run_benchmark(ring, bencher, |values, dst, ring| {
402 convolution.compute_convolution_sum_impl(values, dst, ring)
403 });
404}
405
406#[bench]
407fn bench_convolution_sum_default(bencher: &mut Bencher) {
408 let ring = zn_64::Zn::new(65537);
409 let convolution = NTTConvolution::new(ring);
410
411 run_benchmark(ring, bencher, |values, dst, ring| {
412 for (lhs, lhs_prep, rhs, rhs_prep) in values {
413 convolution.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
414 }
415 });
416}