use std::alloc::{Allocator, Global};
use std::marker::PhantomData;
use std::ops::Deref;

use crate::ring::*;
use crate::seq::subvector::SubvectorView;
use crate::seq::VectorView;

use karatsuba::*;

pub mod karatsuba;

pub mod fft;

pub mod ntt;

pub mod rns;

///
/// Trait for algorithms that compute the convolution of two sequences of ring
/// elements, i.e. `dst[k] = sum_(i + j = k) lhs[i] * rhs[j]`. Note that
/// implementations add the result to `dst` instead of overwriting it.
///
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {

    ///
    /// Elementwise adds the convolution of `lhs` and `rhs` to `dst`.
    ///
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S);

    ///
    /// Returns whether this convolution algorithm can be used for the given ring.
    ///
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool;

    #[stability::unstable(feature = "enable")]
    ///
    /// If `Self` implements [`PreparedConvolutionAlgorithm`], executes the given
    /// [`PreparedConvolutionOperation`] with this additional bound in scope;
    /// otherwise, returns the operation unchanged as `Err()`.
    ///
    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        Err(function)
    }
}
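
// A minimal usage sketch (as a test): convolving two short integer sequences with
// the default algorithm [`STANDARD_CONVOLUTION`] defined further down in this file.
// Since `compute_convolution()` adds to `dst`, the destination is zero-initialized;
// the expected values are just the coefficients of (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2.
#[cfg(test)]
#[test]
fn test_convolution_algorithm_usage_example() {
    let lhs = [1, 2];
    let rhs = [3, 4];
    let mut dst = [0; 4];
    STANDARD_CONVOLUTION.compute_convolution(&lhs[..], &rhs[..], &mut dst, StaticRing::<i64>::RING);
    assert_eq!([3, 10, 8, 0], dst);
}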

#[stability::unstable(feature = "enable")]
///
/// A "generic closure" over convolution algorithms, to be executed via
/// [`ConvolutionAlgorithm::specialize_prepared_convolution`]. When it is executed,
/// the additional bound `C: PreparedConvolutionAlgorithm<R>` is in scope.
///
pub trait PreparedConvolutionOperation<C: ?Sized, R: ?Sized + RingBase> {

    type Output;

    fn execute(self) -> Self::Output
        where C: PreparedConvolutionAlgorithm<R>;
}
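
// A hedged sketch of the specialization hook: the operation below just reports
// whether specialization succeeded. `IsPrepared` is an illustrative name, not part
// of this crate's API; since [`KaratsubaAlgorithm`] does not override the default
// `specialize_prepared_convolution()`, the call returns `Err()`.
#[cfg(test)]
#[test]
fn test_specialize_prepared_convolution_default_is_err() {
    struct IsPrepared;
    impl<C: ?Sized, R: ?Sized + RingBase> PreparedConvolutionOperation<C, R> for IsPrepared {
        type Output = bool;
        fn execute(self) -> bool
            where C: PreparedConvolutionAlgorithm<R>
        {
            true
        }
    }
    assert!(<KaratsubaAlgorithm as ConvolutionAlgorithm<StaticRingBase<i64>>>::specialize_prepared_convolution(IsPrepared).is_err());
}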

#[stability::unstable(feature = "enable")]
///
/// A [`ConvolutionAlgorithm`] that additionally supports "preparing" operands:
/// a prepared operand stores precomputed data (e.g. a Fourier transform of the
/// operand), which can be reused across many convolutions involving that operand.
///
pub trait PreparedConvolutionAlgorithm<R: ?Sized + RingBase>: ConvolutionAlgorithm<R> {

    type PreparedConvolutionOperand;

    ///
    /// Preprocesses `val` into a [`PreparedConvolutionAlgorithm::PreparedConvolutionOperand`]
    /// that can be used in place of `val` in later convolutions.
    ///
    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>;

    ///
    /// As [`ConvolutionAlgorithm::compute_convolution()`], but with a prepared left-hand side.
    ///
    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>;

    ///
    /// As [`ConvolutionAlgorithm::compute_convolution()`], but with both operands prepared.
    ///
    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy;

    ///
    /// As [`ConvolutionAlgorithm::compute_convolution()`], but with a prepared right-hand side.
    /// The default implementation swaps the operands, hence requires the ring to be commutative.
    ///
    fn compute_convolution_rhs_prepared<S, V>(&self, lhs: V, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(ring.is_commutative());
        self.compute_convolution_lhs_prepared(rhs, lhs, dst, ring);
    }

    ///
    /// Elementwise adds the sum of the convolutions of the given pairs to `dst`.
    ///
    fn compute_convolution_inner_product_lhs_prepared<'a, S, I, V>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self: 'a,
            R: 'a,
            Self::PreparedConvolutionOperand: 'a
    {
        for (lhs, rhs) in values {
            self.compute_convolution_lhs_prepared(lhs, rhs, dst, ring)
        }
    }

    ///
    /// Elementwise adds the sum of the convolutions of the given pairs to `dst`.
    ///
    fn compute_convolution_inner_product_prepared<'a, S, I>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self::PreparedConvolutionOperand: 'a,
            Self: 'a,
            R: 'a
    {
        for (lhs, rhs) in values {
            self.compute_convolution_prepared(lhs, rhs, dst, ring)
        }
    }
}
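
// A hedged sketch of the intended use of prepared operands: `convolve_one_with_many`
// is a hypothetical helper (not part of this crate's API) that prepares `lhs` once
// and then reuses the prepared operand for every right-hand side, accumulating all
// results in `dst`. For FFT-like algorithms, this amortizes the cost of transforming `lhs`.
#[cfg(test)]
#[allow(unused)]
fn convolve_one_with_many<R, C, S, V>(convolution: &C, lhs: V, rhs_list: &[V], dst: &mut [R::Element], ring: S)
    where R: ?Sized + RingBase,
        C: PreparedConvolutionAlgorithm<R>,
        S: RingStore<Type = R> + Copy,
        V: VectorView<R::Element>
{
    // prepare `lhs` once ...
    let lhs_prepared = convolution.prepare_convolution_operand(&lhs, ring);
    // ... and reuse the prepared operand for every rhs; each call adds its result to `dst`
    for rhs in rhs_list {
        convolution.compute_convolution_lhs_prepared(&lhs_prepared, rhs, dst, ring);
    }
}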

///
/// Any pointer-like type (e.g. `&C` or `Box<C>`) to a convolution algorithm can
/// itself be used as a [`ConvolutionAlgorithm`].
///
impl<R, C> ConvolutionAlgorithm<R> for C
    where R: ?Sized + RingBase,
        C: Deref,
        C::Target: ConvolutionAlgorithm<R>
{
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
        (**self).compute_convolution(lhs, rhs, dst, ring)
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
        (**self).supports_ring(ring)
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        struct CallFunction<F, C, R>
            where R: ?Sized + RingBase,
                C: Deref,
                C::Target: ConvolutionAlgorithm<R>,
                F: PreparedConvolutionOperation<C, R>
        {
            convolution: PhantomData<Box<C>>,
            ring: PhantomData<Box<R>>,
            function: F
        }
        impl<F, C, R> PreparedConvolutionOperation<C::Target, R> for CallFunction<F, C, R>
            where R: ?Sized + RingBase,
                C: Deref,
                C::Target: ConvolutionAlgorithm<R>,
                F: PreparedConvolutionOperation<C, R>
        {
            type Output = F::Output;

            fn execute(self) -> Self::Output
                where C::Target: PreparedConvolutionAlgorithm<R>
            {
                self.function.execute()
            }
        }
        return <C::Target as ConvolutionAlgorithm<R>>::specialize_prepared_convolution::<CallFunction<F, C, R>>(CallFunction {
            function,
            ring: PhantomData,
            convolution: PhantomData
        }).map_err(|f| f.function);
    }
}
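
// A short demonstration (as a test) of the blanket impl above: a reference to a
// convolution algorithm is itself a [`ConvolutionAlgorithm`], so algorithms can be
// passed by reference to generic code. `convolve_with` is an illustrative local helper.
#[cfg(test)]
#[test]
fn test_convolution_algorithm_by_ref() {
    fn convolve_with<C: ConvolutionAlgorithm<StaticRingBase<i64>>>(convolution: C, lhs: &[i64], rhs: &[i64], dst: &mut [i64]) {
        convolution.compute_convolution(lhs, rhs, dst, StaticRing::<i64>::RING);
    }
    let mut dst = [0; 4];
    convolve_with(&STANDARD_CONVOLUTION, &[1, 2], &[3, 4], &mut dst);
    assert_eq!([3, 10, 8, 0], dst);
}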

///
/// Any pointer-like type (e.g. `&C` or `Box<C>`) to a [`PreparedConvolutionAlgorithm`]
/// is itself a [`PreparedConvolutionAlgorithm`].
///
impl<R, C> PreparedConvolutionAlgorithm<R> for C
    where R: ?Sized + RingBase,
        C: Deref,
        C::Target: PreparedConvolutionAlgorithm<R>
{
    type PreparedConvolutionOperand = <C::Target as PreparedConvolutionAlgorithm<R>>::PreparedConvolutionOperand;

    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        (**self).prepare_convolution_operand(val, ring)
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        (**self).compute_convolution_lhs_prepared(lhs, rhs, dst, ring)
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        (**self).compute_convolution_prepared(lhs, rhs, dst, ring)
    }

    fn compute_convolution_rhs_prepared<S, V>(&self, lhs: V, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        (**self).compute_convolution_rhs_prepared(lhs, rhs, dst, ring)
    }

    fn compute_convolution_inner_product_lhs_prepared<'b, S, I, V>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: Iterator<Item = (&'b Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self: 'b,
            R: 'b,
            Self::PreparedConvolutionOperand: 'b
    {
        (**self).compute_convolution_inner_product_lhs_prepared(values, dst, ring)
    }

    fn compute_convolution_inner_product_prepared<'b, S, I>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: Iterator<Item = (&'b Self::PreparedConvolutionOperand, &'b Self::PreparedConvolutionOperand)>,
            Self: 'b,
            R: 'b,
            Self::PreparedConvolutionOperand: 'b
    {
        (**self).compute_convolution_inner_product_prepared(values, dst, ring)
    }
}

#[derive(Clone, Copy, Debug)]
///
/// A [`ConvolutionAlgorithm`] that computes convolutions via Karatsuba's algorithm,
/// falling back to the naive quadratic algorithm for short inputs (see [`KaratsubaHint`]).
///
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    allocator: A
}

///
/// A default instance of [`KaratsubaAlgorithm`], using the global allocator.
///
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);

impl<A: Allocator> KaratsubaAlgorithm<A> {

    #[stability::unstable(feature = "enable")]
    pub const fn new(allocator: A) -> Self {
        Self { allocator }
    }
}

impl<R: ?Sized + RingBase, A: Allocator> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A> {

    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
        karatsuba(ring.get_ring().karatsuba_threshold(), dst, SubvectorView::new(&lhs), SubvectorView::new(&rhs), &ring, &self.allocator)
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }
}

#[stability::unstable(feature = "enable")]
///
/// Trait that rings can implement to customize the threshold at which
/// [`KaratsubaAlgorithm`] switches from Karatsuba recursion to the naive
/// quadratic convolution.
///
pub trait KaratsubaHint: RingBase {

    ///
    /// Returns the Karatsuba threshold for this ring: blocks of length below
    /// `2^threshold` are convolved naively, larger blocks via Karatsuba recursion.
    /// The default (provided via specialization) is `0`.
    ///
    fn karatsuba_threshold(&self) -> usize;
}

impl<R: RingBase + ?Sized> KaratsubaHint for R {

    default fn karatsuba_threshold(&self) -> usize {
        0
    }
}
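
// A small usage demo for the hint above: every ring gets a threshold through the
// specialization-based blanket impl, and a ring can opt in by providing its own
// `KaratsubaHint` implementation for its `RingBase` type. The concrete value for
// any particular ring is an implementation detail, so this test only exercises the call.
#[cfg(test)]
#[test]
fn test_karatsuba_threshold_available_for_any_ring() {
    let _threshold: usize = StaticRing::<i64>::RING.get_ring().karatsuba_threshold();
}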

#[cfg(test)]
use test;
#[cfg(test)]
use crate::primitive_int::*;

#[bench]
fn bench_naive_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    let mut c: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        c.clear();
        c.resize(64, 0);
        // threshold 10 means blocks shorter than 2^10 are multiplied naively,
        // so this benchmarks the naive quadratic convolution
        karatsuba(10, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
        // c[31] = sum_(i=0)^(31) i * (31 - i), in closed form:
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        // c[62] = 31 * 31 is the top nonzero coefficient
        assert_eq!(c[62], 31 * 31);
    });
}

#[bench]
fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    let mut c: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        c.clear();
        c.resize(64, 0);
        // threshold 4 means Karatsuba recursion is applied down to blocks of length 2^4
        karatsuba(4, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        assert_eq!(c[62], 31 * 31);
    });
}

#[allow(missing_docs)]
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
    use std::cmp::min;
    use crate::homomorphism::*;
    use crate::ring::*;
    use super::*;

    pub fn test_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                // with lhs[i] = i * scale and rhs[j] = j * scale, the convolution is
                // expected[k] = sum_(i + j = k) i * j * scale^2; the inner sum is given
                // in closed form below, and evaluates to 0 at the top index
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let expected = (0..(lhs_len + rhs_len)).map(|i|
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 -
                        (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                ).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();

                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // compute_convolution() must add to dst, so this time initialize dst
                // with i^2 * scale^2 and include that term in the expected values
                let expected = (0..(lhs_len + rhs_len)).map(|i|
                    i * i +
                        min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 -
                        (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                ).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();

                let mut actual = Vec::new();
                actual.extend((0..(lhs_len + rhs_len)).map(|i| ring.mul(ring.int_hom().map(i * i), ring.pow(ring.clone_el(&scale), 2))));
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
            }
        }
    }
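
    // Usage example: run the generic test suite against the Karatsuba-based default
    // algorithm over the integers (scale 1 keeps all intermediate values small)
    #[cfg(test)]
    #[test]
    fn test_karatsuba_convolution_satisfies_contract() {
        test_convolution(STANDARD_CONVOLUTION, crate::primitive_int::StaticRing::<i64>::RING, 1);
    }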

    #[stability::unstable(feature = "enable")]
    pub fn test_prepared_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: PreparedConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 14, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                // same closed form for expected[k] = sum_(i + j = k) i * j * scale^2
                // as in test_convolution()
                let expected = (0..(lhs_len + rhs_len)).map(|i|
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 -
                        (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                ).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();

                // both operands prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &convolution.prepare_convolution_operand(&lhs, &ring),
                    &convolution.prepare_convolution_operand(&rhs, &ring),
                    &mut actual,
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // only the left-hand side prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_lhs_prepared(
                    &convolution.prepare_convolution_operand(&lhs, &ring),
                    &rhs,
                    &mut actual,
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // inner product of prepared pairs; the extra pair ([1], [1]) contributes
                // an additional 1 to the coefficient at index 0
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (convolution.prepare_convolution_operand(&lhs, &ring), convolution.prepare_convolution_operand(&rhs, &ring)),
                    (convolution.prepare_convolution_operand(&[ring.one()], &ring), convolution.prepare_convolution_operand(&[ring.one()], &ring))
                ];
                convolution.compute_convolution_inner_product_prepared(
                    data.iter().map(|(l, r)| (l, r)),
                    &mut actual,
                    &ring
                );
                assert_el_eq!(&ring, ring.add_ref_fst(&expected[0], ring.one()), &actual[0]);
                for i in 1..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // inner product with only the left-hand sides prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (convolution.prepare_convolution_operand(&lhs, &ring), rhs),
                    (convolution.prepare_convolution_operand(&[ring.one()], &ring), vec![ring.one()])
                ];
                convolution.compute_convolution_inner_product_lhs_prepared(
                    data.iter().map(|(l, r)| (l, r)),
                    &mut actual,
                    &ring
                );
                assert_el_eq!(&ring, ring.add_ref_fst(&expected[0], ring.one()), &actual[0]);
                for i in 1..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
            }
        }
    }
}