feanor_math/algorithms/convolution/mod.rs
use std::alloc::{Allocator, Global};
use std::ops::Deref;

use karatsuba::*;

use crate::ring::*;
use crate::seq::subvector::SubvectorView;
use crate::seq::*;

/// Contains an optimized implementation of Karatsuba's algorithm for computing convolutions.
pub mod karatsuba;

/// Contains an implementation of computing convolutions using complex floating-point FFTs.
pub mod fft;

/// Contains an implementation of computing convolutions using NTTs, i.e. FFTs over
/// a finite field that has suitable roots of unity.
pub mod ntt;

/// Contains an implementation of computing convolutions by considering them modulo
/// various primes that are either smaller or allow for suitable roots of unity.
pub mod rns;

/// Trait for objects that can compute a convolution over some ring.
///
/// # Example
/// ```rust
/// # use std::cmp::{min, max};
/// # use feanor_math::ring::*;
/// # use feanor_math::primitive_int::*;
/// # use feanor_math::seq::*;
/// # use feanor_math::algorithms::convolution::*;
/// struct NaiveConvolution;
/// // we support all rings!
/// impl<R: ?Sized + RingBase> ConvolutionAlgorithm<R> for NaiveConvolution {
///     fn compute_convolution<
///         S: RingStore<Type = R>,
///         V1: VectorView<R::Element>,
///         V2: VectorView<R::Element>,
///     >(
///         &self,
///         lhs: V1,
///         rhs: V2,
///         dst: &mut [R::Element],
///         ring: S,
///     ) {
///         for i in 0..(lhs.len() + rhs.len() - 1) {
///             for j in max(0, i as isize - rhs.len() as isize + 1)
///                 ..min(lhs.len() as isize, i as isize + 1)
///             {
///                 ring.add_assign(
///                     &mut dst[i],
///                     ring.mul_ref(lhs.at(j as usize), rhs.at(i - j as usize)),
///                 );
///             }
///         }
///     }
///     fn supports_ring<S: RingStore<Type = R>>(&self, _: S) -> bool
///     where
///         S: Copy,
///     {
///         true
///     }
/// }
/// let lhs = [1, 2, 3, 4, 5];
/// let rhs = [2, 3, 4, 5, 6];
/// let mut expected = [0; 10];
/// let mut actual = [0; 10];
/// STANDARD_CONVOLUTION.compute_convolution(lhs, rhs, &mut expected, StaticRing::<i64>::RING);
/// NaiveConvolution.compute_convolution(lhs, rhs, &mut actual, StaticRing::<i64>::RING);
/// assert_eq!(expected, actual);
/// ```
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {

    /// Additional data associated to a list of ring elements, which can be used to
    /// compute a convolution faster when this list is one of the operands.
    ///
    /// For more details, see [`ConvolutionAlgorithm::prepare_convolution_operand()`].
    /// Note that a `PreparedConvolutionOperand` can only be used for convolutions
    /// with the same list of values it was created for.
    #[stability::unstable(feature = "enable")]
    type PreparedConvolutionOperand = ();

    /// Elementwise adds the convolution of `lhs` and `rhs` to `dst`.
    ///
    /// In other words, computes `dst[i] += sum_j lhs[j] * rhs[i - j]` for all `i`, where
    /// `j` runs through all non-negative integers for which `lhs[j]` and `rhs[i - j]` are
    /// defined, i.e. not out-of-bounds.
    ///
    /// Mathematically, `dst.len() >= lhs.len() + rhs.len() - 1` would suffice to store the
    /// result. However, to allow for more efficient implementations, it is instead required
    /// that `dst.len() >= lhs.len() + rhs.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `dst` is shorter than `lhs.len() + rhs.len() - 1`. May panic if `dst` is shorter
    /// than `lhs.len() + rhs.len()`.
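    ///
    /// # Example
    ///
    /// A small worked instance (a minimal sketch using [`STANDARD_CONVOLUTION`]; any
    /// [`ConvolutionAlgorithm`] behaves the same way):
    /// ```rust
    /// # use feanor_math::ring::*;
    /// # use feanor_math::primitive_int::*;
    /// # use feanor_math::algorithms::convolution::*;
    /// let mut dst = [0; 4];
    /// // `dst` has length lhs.len() + rhs.len() = 4, although the result
    /// // only occupies the first 3 entries
    /// STANDARD_CONVOLUTION.compute_convolution([1, 2], [3, 4], &mut dst, StaticRing::<i64>::RING);
    /// assert_eq!([1 * 3, 1 * 4 + 2 * 3, 2 * 4, 0], dst);
    /// ```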
    ///
    /// TODO: On next breaking release, just take slices instead of [`VectorView`]s.
    /// This might require the user to copy the data once, but so far most algorithms copy
    /// it anyway, because this will make subsequent memory accesses more predictable and
    /// better optimized.
    ///
    /// Maybe also consider taking the ring by `&RingBase`, since this would then allow
    /// for dynamic dispatch.
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [R::Element],
        ring: S,
    );

    /// Returns whether this convolution algorithm supports computations over
    /// the given ring.
    ///
    /// Note that most algorithms will support all rings of type `R`. However, in some cases,
    /// e.g. for finite fields, required data might only be precomputed for some moduli,
    /// and thus only these will be supported.
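    ///
    /// For example, the Karatsuba-based [`STANDARD_CONVOLUTION`] works over every ring
    /// (a minimal sketch):
    /// ```rust
    /// # use feanor_math::ring::*;
    /// # use feanor_math::primitive_int::*;
    /// # use feanor_math::algorithms::convolution::*;
    /// assert!(STANDARD_CONVOLUTION.supports_ring(StaticRing::<i64>::RING));
    /// ```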
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool;

    /// Takes an input list of values and computes an opaque
    /// [`ConvolutionAlgorithm::PreparedConvolutionOperand`], which can be used to compute
    /// future convolutions with this list of values faster.
    ///
    /// Although the [`ConvolutionAlgorithm::PreparedConvolutionOperand`] does not have any explicit
    /// reference to the list of values it was created for, passing it to
    /// [`ConvolutionAlgorithm::compute_convolution_prepared()`] with another list of values
    /// will give erroneous results.
    ///
    /// # Length-dependence when preparing a convolution
    ///
    /// For some algorithms, different data is required to speed up the convolution with an operand,
    /// depending on the length of the other operand. For example, for FFT-based convolutions,
    /// the prepared data would consist of the Fourier transform of the list of values,
    /// zero-padded to a length that can store the complete result of the (future) convolution.
    ///
    /// To handle this, implementations can make use of the `length_hint`, which - if given - should
    /// be an upper bound on the length of the output of any future convolution that uses the
    /// given operand. Alternatively, implementations are encouraged to not compute any data
    /// during [`ConvolutionAlgorithm::prepare_convolution_operand()`], but to initialize an object
    /// with interior mutability, and use it to cache data computed during
    /// [`ConvolutionAlgorithm::compute_convolution_prepared()`].
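    ///
    /// A hypothetical sketch of the latter approach (the type `CachedFFTOperand` and its
    /// field are made up for illustration, and not part of this crate):
    /// ```rust,ignore
    /// struct CachedFFTOperand {
    ///     // left empty by prepare_convolution_operand(); filled on first use by
    ///     // compute_convolution_prepared(), once the required output length is known
    ///     fft_table: std::sync::OnceLock<Vec<Complex64>>,
    /// }
    /// ```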
    ///
    /// TODO: At next breaking release, remove the default implementation.
    ///
    /// TODO: On next breaking release, just take slices instead of [`VectorView`]s.
    /// This might require the user to copy the data once, but so far most algorithms copy
    /// it anyway, because this will make subsequent memory accesses more predictable and
    /// better optimized.
    ///
    /// Maybe also consider taking the ring by `&RingBase`, since this would then allow
    /// for dynamic dispatch.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use feanor_math::ring::*;
    /// # use feanor_math::algorithms::convolution::*;
    /// # use feanor_math::algorithms::convolution::ntt::*;
    /// # use feanor_math::rings::zn::*;
    /// # use feanor_math::rings::zn::zn_64::*;
    /// # use feanor_math::rings::finite::*;
    /// let ring = Zn::new(65537);
    /// let convolution = NTTConvolution::new(ring);
    /// let lhs = ring.elements().take(10).collect::<Vec<_>>();
    /// let rhs = ring.elements().take(10).collect::<Vec<_>>();
    /// // "standard" use
    /// let mut expected = (0..20).map(|_| ring.zero()).collect::<Vec<_>>();
    /// convolution.compute_convolution(&lhs, &rhs, &mut expected, ring);
    ///
    /// // "prepared" variant
    /// let lhs_prep = convolution.prepare_convolution_operand(&lhs, None, ring);
    /// let rhs_prep = convolution.prepare_convolution_operand(&rhs, None, ring);
    /// let mut actual = (0..20).map(|_| ring.zero()).collect::<Vec<_>>();
    /// // this will now be faster than `convolution.compute_convolution()`
    /// convolution.compute_convolution_prepared(
    ///     &lhs,
    ///     Some(&lhs_prep),
    ///     &rhs,
    ///     Some(&rhs_prep),
    ///     &mut actual,
    ///     ring,
    /// );
    /// assert!(
    ///     expected
    ///         .iter()
    ///         .zip(actual.iter())
    ///         .all(|(l, r)| ring.eq_el(l, r))
    /// );
    /// ```
    #[stability::unstable(feature = "enable")]
    fn prepare_convolution_operand<S, V>(
        &self,
        _val: V,
        _length_hint: Option<usize>,
        _ring: S,
    ) -> Self::PreparedConvolutionOperand
    where
        S: RingStore<Type = R> + Copy,
        V: VectorView<R::Element>,
    {
        struct ProduceUnitType;
        trait ProduceValue<T> {
            fn produce() -> T;
        }
        impl<T> ProduceValue<T> for ProduceUnitType {
            default fn produce() -> T {
                panic!(
                    "if you specialize ConvolutionAlgorithm::PreparedConvolutionOperand, you must also specialize ConvolutionAlgorithm::prepare_convolution_operand()"
                )
            }
        }
        impl ProduceValue<()> for ProduceUnitType {
            fn produce() {}
        }
        return <ProduceUnitType as ProduceValue<Self::PreparedConvolutionOperand>>::produce();
    }

    /// Elementwise adds the convolution of `lhs` and `rhs` to `dst`. If provided, the given
    /// prepared convolution operands are used for a faster computation.
    ///
    /// When called with `None` for both prepared convolution operands, this is exactly
    /// equivalent to [`ConvolutionAlgorithm::compute_convolution()`].
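    ///
    /// For example (a minimal sketch; the prepared operands would normally be created via
    /// [`ConvolutionAlgorithm::prepare_convolution_operand()`]):
    /// ```rust
    /// # use feanor_math::ring::*;
    /// # use feanor_math::primitive_int::*;
    /// # use feanor_math::algorithms::convolution::*;
    /// let (mut a, mut b) = ([0; 4], [0; 4]);
    /// STANDARD_CONVOLUTION.compute_convolution([2, 3], [4, 5], &mut a, StaticRing::<i64>::RING);
    /// STANDARD_CONVOLUTION.compute_convolution_prepared([2, 3], None, [4, 5], None, &mut b, StaticRing::<i64>::RING);
    /// assert_eq!(a, b);
    /// ```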
    ///
    /// TODO: On next breaking release, just take slices instead of [`VectorView`]s.
    /// This might require the user to copy the data once, but so far most algorithms copy
    /// it anyway, because this will make subsequent memory accesses more predictable and
    /// better optimized.
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_prepared<S, V1, V2>(
        &self,
        lhs: V1,
        _lhs_prep: Option<&Self::PreparedConvolutionOperand>,
        rhs: V2,
        _rhs_prep: Option<&Self::PreparedConvolutionOperand>,
        dst: &mut [R::Element],
        ring: S,
    ) where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
    {
        self.compute_convolution(lhs, rhs, dst, ring)
    }

    /// Computes a convolution for each tuple in the given sequence, and adds the result of each
    /// convolution to `dst`.
    ///
    /// In other words, writing `values[l] = (lhs_l, _, rhs_l, _)`, this computes
    /// `dst[k] += sum_l sum_(i + j = k) lhs_l[i] * rhs_l[j]`.
    /// It can be faster than calling [`ConvolutionAlgorithm::compute_convolution_prepared()`]
    /// once for each tuple and summing the results.
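    ///
    /// For example (a minimal sketch without prepared operands):
    /// ```rust
    /// # use feanor_math::ring::*;
    /// # use feanor_math::primitive_int::*;
    /// # use feanor_math::algorithms::convolution::*;
    /// let mut dst = [0; 4];
    /// STANDARD_CONVOLUTION.compute_convolution_sum(
    ///     [(&[1, 2][..], None, &[3, 4][..], None), (&[5, 0][..], None, &[6, 0][..], None)].into_iter(),
    ///     &mut dst,
    ///     StaticRing::<i64>::RING,
    /// );
    /// // [3, 10, 8, 0] + [30, 0, 0, 0]
    /// assert_eq!([33, 10, 8, 0], dst);
    /// ```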
    ///
    /// TODO: On next breaking release, just take slices instead of [`VectorView`]s.
    /// This might require the user to copy the data once, but so far most algorithms copy
    /// it anyway, because this will make subsequent memory accesses more predictable and
    /// better optimized.
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        I: ExactSizeIterator<
            Item = (
                V1,
                Option<&'a Self::PreparedConvolutionOperand>,
                V2,
                Option<&'a Self::PreparedConvolutionOperand>,
            ),
        >,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        Self: 'a,
        R: 'a,
    {
        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            self.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring)
        }
    }
}

impl<R, C> ConvolutionAlgorithm<R> for C
where
    R: ?Sized + RingBase,
    C: Deref,
    C::Target: ConvolutionAlgorithm<R>,
{
    type PreparedConvolutionOperand = <C::Target as ConvolutionAlgorithm<R>>::PreparedConvolutionOperand;

    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [R::Element],
        ring: S,
    ) {
        (**self).compute_convolution(lhs, rhs, dst, ring)
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool { (**self).supports_ring(ring) }

    fn prepare_convolution_operand<S, V>(
        &self,
        val: V,
        len_hint: Option<usize>,
        ring: S,
    ) -> Self::PreparedConvolutionOperand
    where
        S: RingStore<Type = R> + Copy,
        V: VectorView<R::Element>,
    {
        (**self).prepare_convolution_operand(val, len_hint, ring)
    }

    fn compute_convolution_prepared<S, V1, V2>(
        &self,
        lhs: V1,
        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
        rhs: V2,
        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
        dst: &mut [R::Element],
        ring: S,
    ) where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
    {
        (**self).compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
    }

    fn compute_convolution_sum<'b, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        I: ExactSizeIterator<
            Item = (
                V1,
                Option<&'b Self::PreparedConvolutionOperand>,
                V2,
                Option<&'b Self::PreparedConvolutionOperand>,
            ),
        >,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        Self: 'b,
        R: 'b,
    {
        (**self).compute_convolution_sum(values, dst, ring);
    }
}

/// Implementation of convolutions that uses Karatsuba's algorithm
/// with a threshold defined by [`KaratsubaHint`].
#[derive(Clone, Copy, Debug)]
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    allocator: A,
}

/// Good default algorithm for computing convolutions, using Karatsuba's algorithm
/// with a threshold defined by [`KaratsubaHint`].
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);

impl<A: Allocator> KaratsubaAlgorithm<A> {
    #[stability::unstable(feature = "enable")]
    pub const fn new(allocator: A) -> Self { Self { allocator } }
}

impl<R: ?Sized + RingBase, A: Allocator> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A> {
    fn compute_convolution<
        S: RingStore<Type = R>,
        V1: VectorView<<R as RingBase>::Element>,
        V2: VectorView<<R as RingBase>::Element>,
    >(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [<R as RingBase>::Element],
        ring: S,
    ) {
        karatsuba(
            ring.get_ring().karatsuba_threshold(),
            dst,
            SubvectorView::new(&lhs),
            SubvectorView::new(&rhs),
            &ring,
            &self.allocator,
        )
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }
}

/// Very simple schoolbook convolution algorithm.
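///
/// For example (a minimal sketch):
/// ```rust
/// # use feanor_math::ring::*;
/// # use feanor_math::primitive_int::*;
/// # use feanor_math::algorithms::convolution::*;
/// let mut dst = [0; 5];
/// SchoolbookConvolution.compute_convolution([1, 2, 3], [4, 5], &mut dst, StaticRing::<i64>::RING);
/// assert_eq!([4, 13, 22, 15, 0], dst);
/// ```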
pub struct SchoolbookConvolution;

impl<R: ?Sized + RingBase> ConvolutionAlgorithm<R> for SchoolbookConvolution {
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }

    fn compute_convolution<S, V1, V2>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<<R as RingBase>::Element>,
        V2: VectorView<<R as RingBase>::Element>,
    {
        naive_assign_mul::<_, _, _, _, true>(dst, lhs, rhs, ring)
    }
}

/// Trait to allow rings to customize the parameters with which [`KaratsubaAlgorithm`] will
/// compute convolutions over the ring.
#[stability::unstable(feature = "enable")]
pub trait KaratsubaHint: RingBase {
    /// Defines the threshold from which on [`KaratsubaAlgorithm`] will use the Karatsuba algorithm.
    ///
    /// Concretely, when this returns `k`, [`KaratsubaAlgorithm`] will recursively reduce the
    /// convolution down to convolutions on slices of length `2^k`, and compute those naively. The
    /// default value is `0`, but if the considered rings have fast multiplication (compared to
    /// addition), then setting it higher may result in a performance gain.
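    ///
    /// A hypothetical override (the ring type `MyFastMulRing` is made up for illustration):
    /// ```rust,ignore
    /// impl KaratsubaHint for MyFastMulRing {
    ///     // multiplication in this ring is cheap, so compute all convolutions
    ///     // on slices of length up to 2^4 = 16 naively
    ///     fn karatsuba_threshold(&self) -> usize { 4 }
    /// }
    /// ```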
    fn karatsuba_threshold(&self) -> usize;
}

impl<R: RingBase + ?Sized> KaratsubaHint for R {
    default fn karatsuba_threshold(&self) -> usize { 0 }
}

#[cfg(test)]
use test;

#[cfg(test)]
use crate::primitive_int::*;

#[bench]
fn bench_naive_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    let mut c: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        c.clear();
        c.resize(64, 0);
        // threshold 10 means slices of length up to 2^10 are multiplied naively,
        // so for length-32 inputs this benchmarks the naive algorithm
        karatsuba(10, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        assert_eq!(c[62], 31 * 31);
    });
}

#[bench]
fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    let mut c: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        c.clear();
        c.resize(64, 0);
        // threshold 4 means recursion proceeds down to slices of length 2^4 = 16,
        // so for length-32 inputs this actually exercises the Karatsuba recursion
        karatsuba(4, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        assert_eq!(c[62], 31 * 31);
    });
}

#[test]
fn test_schoolbook_convolution() { generic_tests::test_convolution(SchoolbookConvolution, StaticRing::<i64>::RING, 1); }

#[allow(missing_docs)]
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
    use std::cmp::min;

    use super::*;
    use crate::homomorphism::*;

    pub fn test_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
    where
        C: ConvolutionAlgorithm<R::Type>,
        R: RingStore,
    {
        for lhs_len in [2, 3, 4, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len)
                    .map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
                    .collect::<Vec<_>>();
                let rhs = (0..rhs_len)
                    .map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
                    .collect::<Vec<_>>();
                // lhs[i] = i * scale and rhs[j] = j * scale, so the convolution at index i
                // is scale^2 * sum_j j * (i - j), summed over the valid range of j
                let expected = (0..(lhs_len + rhs_len))
                    .map(|i| {
                        if i < lhs_len + rhs_len - 1 {
                            min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6
                                - (i - 1 - min(i, rhs_len - 1))
                                    * (i - min(i, rhs_len - 1))
                                    * (i + 2 * min(i, rhs_len - 1) + 1)
                                    / 6
                        } else {
                            // the last entry of `dst` is not touched by the convolution
                            0
                        }
                    })
                    .map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2)))
                    .collect::<Vec<_>>();

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // same convolution, but now `dst` is initialized to i^2 * scale^2 instead of
                // zero; since compute_convolution() adds to `dst`, the initial values remain
                let expected = (0..(lhs_len + rhs_len))
                    .map(|i| {
                        if i < lhs_len + rhs_len - 1 {
                            i * i
                                + min(i, lhs_len - 1)
                                    * (min(i, lhs_len - 1) + 1)
                                    * (3 * i - 2 * min(i, lhs_len - 1) - 1)
                                    / 6
                                - (i - 1 - min(i, rhs_len - 1))
                                    * (i - min(i, rhs_len - 1))
                                    * (i + 2 * min(i, rhs_len - 1) + 1)
                                    / 6
                        } else {
                            i * i
                        }
                    })
                    .map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2)))
                    .collect::<Vec<_>>();

                let mut actual = Vec::new();
                actual.extend(
                    (0..(lhs_len + rhs_len))
                        .map(|i| ring.mul(ring.int_hom().map(i * i), ring.pow(ring.clone_el(&scale), 2))),
                );
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
            }
        }
        test_prepared_convolution(convolution, ring, scale);
    }

    fn test_prepared_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
    where
        C: ConvolutionAlgorithm<R::Type>,
        R: RingStore,
    {
        for lhs_len in [2, 3, 4, 14, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len)
                    .map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
                    .collect::<Vec<_>>();
                let rhs = (0..rhs_len)
                    .map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
                    .collect::<Vec<_>>();
                let expected = (0..(lhs_len + rhs_len))
                    .map(|i| {
                        if i < lhs_len + rhs_len - 1 {
                            min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6
                                - (i - 1 - min(i, rhs_len - 1))
                                    * (i - min(i, rhs_len - 1))
                                    * (i + 2 * min(i, rhs_len - 1) + 1)
                                    / 6
                        } else {
                            0
                        }
                    })
                    .map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2)))
                    .collect::<Vec<_>>();

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual,
                    &ring,
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    None,
                    &mut actual,
                    &ring,
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    None,
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual,
                    &ring,
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (
                        &lhs[..],
                        Some(convolution.prepare_convolution_operand(&lhs, None, &ring)),
                        &rhs[..],
                        Some(convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    ),
                    (&rhs[..], None, &lhs[..], None),
                ];
                convolution.compute_convolution_sum(
                    data.as_fn()
                        .map_fn(|(l, l_prep, r, r_prep): &(_, _, _, _)| (l, l_prep.as_ref(), r, r_prep.as_ref()))
                        .iter(),
                    &mut actual,
                    &ring,
                );
                // the sum contains lhs * rhs twice, since convolutions are commutative
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(
                        &ring,
                        &ring.add_ref(&expected[i as usize], &expected[i as usize]),
                        &actual[i as usize]
                    );
                }

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (
                        &lhs[..],
                        Some(convolution.prepare_convolution_operand(&lhs, None, &ring)),
                        &rhs[..],
                        None,
                    ),
                    (
                        &rhs[..],
                        None,
                        &lhs[..],
                        Some(convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    ),
                ];
                convolution.compute_convolution_sum(
                    data.as_fn()
                        .map_fn(|(l, l_prep, r, r_prep)| (l, l_prep.as_ref(), r, r_prep.as_ref()))
                        .iter(),
                    &mut actual,
                    &ring,
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(
                        &ring,
                        &ring.add_ref(&expected[i as usize], &expected[i as usize]),
                        &actual[i as usize]
                    );
                }
            }
        }
    }
}