use std::alloc::{Allocator, Global};
use std::ops::Deref;

use crate::ring::*;
use crate::seq::subvector::SubvectorView;
use crate::seq::*;

use karatsuba::*;

///
/// Contains an implementation of Karatsuba's algorithm for computing convolutions.
///
pub mod karatsuba;

///
/// Contains convolution algorithms based on floating-point FFTs.
///
pub mod fft;

///
/// Contains convolution algorithms based on number-theoretic transforms (NTTs).
///
pub mod ntt;

///
/// Contains convolution algorithms based on residue number system (RNS) decompositions.
///
pub mod rns;

///
/// Trait for algorithms that can compute the convolution of two sequences of
/// ring elements, i.e. add `sum_(i + j = k) lhs[i] * rhs[j]` to `dst[k]`.
///
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {

    ///
    /// Additional data associated to a single operand that can be used to speed
    /// up later convolutions involving that operand; `()` unless specialized.
    ///
    #[stability::unstable(feature = "enable")]
    type PreparedConvolutionOperand = ();

    ///
    /// Elementwise adds the convolution of `lhs` and `rhs` to `dst`; hence
    /// `dst` must have length at least `lhs.len() + rhs.len() - 1`.
    ///
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S);

    ///
    /// Returns whether this algorithm supports computing convolutions over
    /// the given ring.
    ///
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool;

    ///
    /// Performs precomputations on a single operand, which can later be used
    /// to speed up convolutions involving this operand.
    ///
    #[stability::unstable(feature = "enable")]
    fn prepare_convolution_operand<S, V>(&self, _val: V, _length_hint: Option<usize>, _ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        // by default `PreparedConvolutionOperand = ()`; this specialization trick
        // produces a `()` in exactly that case, and panics whenever the associated
        // type was specialized but this function was not
        struct ProduceUnitType;
        trait ProduceValue<T> {
            fn produce() -> T;
        }
        impl<T> ProduceValue<T> for ProduceUnitType {
            default fn produce() -> T {
                panic!("if you specialize ConvolutionAlgorithm::PreparedConvolutionOperand, you must also specialize ConvolutionAlgorithm::prepare_convolution_operand()")
            }
        }
        impl ProduceValue<()> for ProduceUnitType {
            fn produce() -> () {}
        }
        return <ProduceUnitType as ProduceValue<Self::PreparedConvolutionOperand>>::produce();
    }

    ///
    /// Same as [`ConvolutionAlgorithm::compute_convolution()`], but additionally
    /// accepts previously prepared operands to speed up the computation.
    ///
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, _lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, _rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
    {
        self.compute_convolution(lhs, rhs, dst, ring)
    }

    ///
    /// Computes multiple convolutions and elementwise adds the sum of all
    /// results to `dst`.
    ///
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            Self: 'a,
            R: 'a
    {
        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            self.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring)
        }
    }
}
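
// Usage sketch: all `compute_*` methods *add* the convolution to `dst` rather
// than overwriting it, so `dst` should be zeroed when only the product is
// wanted. For example, over the integers, for any algorithm `conv` that
// supports that ring (e.g. `STANDARD_CONVOLUTION` defined below):
//
//     let mut dst = vec![0; 4];
//     conv.compute_convolution(&[1, 2][..], &[3, 4][..], &mut dst, StaticRing::<i64>::RING);
//     assert_eq!(vec![3, 10, 8, 0], dst); // (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2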

///
/// Any dereferenceable value (e.g. `&C`, `Box<C>`, `Rc<C>`) whose target is a
/// [`ConvolutionAlgorithm`] is itself a [`ConvolutionAlgorithm`]; all methods
/// are forwarded to the target.
///
impl<R, C> ConvolutionAlgorithm<R> for C
    where R: ?Sized + RingBase,
        C: Deref,
        C::Target: ConvolutionAlgorithm<R>
{
    type PreparedConvolutionOperand = <C::Target as ConvolutionAlgorithm<R>>::PreparedConvolutionOperand;

    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
        (**self).compute_convolution(lhs, rhs, dst, ring)
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
        (**self).supports_ring(ring)
    }

    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        (**self).prepare_convolution_operand(val, len_hint, ring)
    }

    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
    {
        (**self).compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
    }

    fn compute_convolution_sum<'b, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: ExactSizeIterator<Item = (V1, Option<&'b Self::PreparedConvolutionOperand>, V2, Option<&'b Self::PreparedConvolutionOperand>)>,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            Self: 'b,
            R: 'b
    {
        (**self).compute_convolution_sum(values, dst, ring);
    }
}

///
/// A [`ConvolutionAlgorithm`] that computes convolutions using Karatsuba's
/// algorithm, falling back to naive multiplication for short inputs
/// (see [`KaratsubaHint`]).
///
#[derive(Clone, Copy, Debug)]
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    allocator: A
}

///
/// The default convolution algorithm: [`KaratsubaAlgorithm`] with the global allocator.
///
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);

impl<A: Allocator> KaratsubaAlgorithm<A> {

    ///
    /// Creates a new [`KaratsubaAlgorithm`] using the given allocator for
    /// temporary buffers.
    ///
    #[stability::unstable(feature = "enable")]
    pub const fn new(allocator: A) -> Self {
        Self { allocator }
    }
}

impl<R: ?Sized + RingBase, A: Allocator> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A> {

    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
        karatsuba(ring.get_ring().karatsuba_threshold(), dst, SubvectorView::new(&lhs), SubvectorView::new(&rhs), &ring, &self.allocator)
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }
}
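
// Illustrative example of the basic API: convolving (1 + 2x)(3 + 4x) over the
// integers gives 3 + 10x + 8x^2; `dst` starts zeroed, since the result is
// added into it.
#[test]
fn test_standard_convolution_example() {
    let mut dst = vec![0; 4];
    STANDARD_CONVOLUTION.compute_convolution(&[1, 2][..], &[3, 4][..], &mut dst, StaticRing::<i64>::RING);
    assert_eq!(vec![3, 10, 8, 0], dst);
}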

///
/// Trait that rings can specialize to control the threshold from which on
/// [`KaratsubaAlgorithm`] uses the Karatsuba algorithm instead of naive
/// multiplication.
///
#[stability::unstable(feature = "enable")]
pub trait KaratsubaHint: RingBase {

    ///
    /// Returns the base-2 logarithm of the input length below which
    /// [`karatsuba()`] falls back to naive multiplication; the default value
    /// is 0, i.e. use the Karatsuba algorithm whenever possible.
    ///
    fn karatsuba_threshold(&self) -> usize;
}

impl<R: RingBase + ?Sized> KaratsubaHint for R {

    default fn karatsuba_threshold(&self) -> usize {
        0
    }
}
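
// Sketch of how a ring can tune the threshold (`MyRingBase` is a hypothetical
// ring type, for illustration only): if element multiplication is cheap,
// recursing down to length-1 blocks is wasteful, so the ring can raise the
// threshold and have all inputs shorter than 2^4 = 16 multiplied naively:
//
//     impl KaratsubaHint for MyRingBase {
//         fn karatsuba_threshold(&self) -> usize { 4 }
//     }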

#[cfg(test)]
use test;
#[cfg(test)]
use crate::primitive_int::*;

#[bench]
fn bench_naive_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    let mut c: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        c.clear();
        c.resize(64, 0);
        // threshold 10: inputs shorter than 2^10 are multiplied naively, so
        // this benchmarks the naive algorithm on the size-32 inputs
        karatsuba(10, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        assert_eq!(c[62], 31 * 31);
    });
}

#[bench]
fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    let mut c: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        c.clear();
        c.resize(64, 0);
        // threshold 4: Karatsuba recursion is used down to blocks of length 2^4
        karatsuba(4, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        assert_eq!(c[62], 31 * 31);
    });
}

#[allow(missing_docs)]
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
    use std::cmp::min;
    use crate::homomorphism::*;
    use crate::ring::*;
    use super::*;

    ///
    /// Checks that `convolution` correctly computes convolutions over the given
    /// ring, on inputs whose entries are multiples of `scale`; panics if a
    /// computed convolution is incorrect.
    ///
    pub fn test_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                // closed form of `sum_(j + k = i) j * k` with `0 <= j < lhs_len`, `0 <= k < rhs_len`
                let expected = (0..(lhs_len + rhs_len)).map(|i|
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 -
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                ).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // since compute_convolution() adds to `dst`, initializing `dst`
                // with i^2 * scale^2 must shift the expected result by exactly that
                let expected = (0..(lhs_len + rhs_len)).map(|i|
                    i * i +
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 -
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                ).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();

                let mut actual = Vec::new();
                actual.extend((0..(lhs_len + rhs_len)).map(|i| ring.mul(ring.int_hom().map(i * i), ring.pow(ring.clone_el(&scale), 2))));
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
            }
        }
        test_prepared_convolution(convolution, ring, scale);
    }
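
    // Usage sketch for test_convolution() above: a downstream implementor of
    // `ConvolutionAlgorithm` would typically invoke it as, e.g. (for a
    // hypothetical `MyConvolution`),
    //
    //     generic_tests::test_convolution(MyConvolution::new(), StaticRing::<i64>::RING, 1);
    //
    // choosing `scale` so that the scaled test values do not overflow the ring.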

    fn test_prepared_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 14, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                // closed form of `sum_(j + k = i) j * k`, as in test_convolution()
                let expected = (0..(lhs_len + rhs_len)).map(|i|
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 -
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                ).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();

                // both operands prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual,
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // only the left operand prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    None,
                    &mut actual,
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // only the right operand prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    None,
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual,
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // sum of two convolutions, the first with prepared operands;
                // lhs * rhs + rhs * lhs must give twice the expected result
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (&lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)), &rhs[..], Some(convolution.prepare_convolution_operand(&rhs, None, &ring))),
                    (&rhs[..], None, &lhs[..], None)
                ];
                convolution.compute_convolution_sum(
                    data.as_fn().map_fn(|(l, l_prep, r, r_prep): &(_, _, _, _)| (l, l_prep.as_ref(), r, r_prep.as_ref())).iter(),
                    &mut actual,
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &ring.add_ref(&expected[i as usize], &expected[i as usize]), &actual[i as usize]);
                }

                // sum of two convolutions with mixed prepared operands
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (&lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)), &rhs[..], None),
                    (&rhs[..], None, &lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)))
                ];
                convolution.compute_convolution_sum(
                    data.as_fn().map_fn(|(l, l_prep, r, r_prep)| (l, l_prep.as_ref(), r, r_prep.as_ref())).iter(),
                    &mut actual,
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &ring.add_ref(&expected[i as usize], &expected[i as usize]), &actual[i as usize]);
                }
            }
        }
    }
}