1use core::marker::PhantomData;
48use core::ops::{Add, AddAssign, Neg, Sub, SubAssign};
49
50use p3_field::{Algebra, Field};
51
/// Arithmetic required of the left-hand operand / accumulator type of a convolution.
///
/// A `ConvolutionElt` is a cheap-to-copy value supporting addition, subtraction and
/// negation, both by value (`+`, `-`, unary `-`) and in place (`+=`, `-=`).
/// The trait has no items of its own; it only names this bundle of bounds.
pub trait ConvolutionElt:
    Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self> + AddAssign + SubAssign
{
}
59
60impl<T> ConvolutionElt for T where
61 T: Add<Output = T> + AddAssign + Copy + Neg<Output = T> + Sub<Output = T> + SubAssign
62{
63}
64
/// Arithmetic required of the right-hand operand type of a convolution.
///
/// Like `ConvolutionElt` but without the in-place operators: the right-hand side is
/// only ever combined by value (`+`, `-`, unary `-`) and copied.
/// The trait has no items of its own; it only names this bundle of bounds.
pub trait ConvolutionRhs:
    Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self>
{
}
72
73impl<T> ConvolutionRhs for T where T: Add<Output = T> + Copy + Neg<Output = T> + Sub<Output = T> {}
74
75pub trait Convolve<F, T: ConvolutionElt, U: ConvolutionRhs> {
96 const T_ZERO: T;
101
102 const U_ZERO: U;
107
108 fn halve(val: T) -> T;
113
114 fn read(input: F) -> T;
117
118 fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> T;
125
126 fn reduce(z: T) -> F;
129
130 #[inline(always)]
135 fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [T])>(
136 lhs: [F; N],
137 rhs: [U; N],
138 conv: C,
139 ) -> [F; N] {
140 let lhs = lhs.map(Self::read);
141 let mut output = [Self::T_ZERO; N];
142 conv(lhs, rhs, &mut output);
143 output.map(Self::reduce)
144 }
145
146 #[inline(always)]
147 fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
148 output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
149 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
150 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
151 }
152
153 #[inline(always)]
154 fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
155 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
156 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
157 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
158 }
159
160 #[inline(always)]
161 fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
162 let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
165 let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
166 let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
167 let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
168
169 output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
170 output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
171 output[2] = Self::parity_dot(u_p, v_p);
172 output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
173
174 output[0] += output[2];
175 output[1] += output[3];
176
177 output[0] = Self::halve(output[0]);
178 output[1] = Self::halve(output[1]);
179
180 output[2] -= output[0];
181 output[3] -= output[1];
182 }
183
184 #[inline(always)]
185 fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
186 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
187 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
188 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
189 output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
190 }
191
192 #[inline(always)]
195 fn conv_n_recursive<const N: usize, const HALF_N: usize, C, NC>(
196 lhs: [T; N],
197 rhs: [U; N],
198 output: &mut [T],
199 inner_conv: C,
200 inner_negacyclic_conv: NC,
201 ) where
202 C: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
203 NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
204 {
205 debug_assert_eq!(2 * HALF_N, N);
206 let mut lhs_pos = [Self::T_ZERO; HALF_N]; let mut lhs_neg = [Self::T_ZERO; HALF_N]; let mut rhs_pos = [Self::U_ZERO; HALF_N]; let mut rhs_neg = [Self::U_ZERO; HALF_N]; for i in 0..HALF_N {
212 let s = lhs[i];
213 let t = lhs[i + HALF_N];
214
215 lhs_pos[i] = s + t;
216 lhs_neg[i] = s - t;
217
218 let s = rhs[i];
219 let t = rhs[i + HALF_N];
220
221 rhs_pos[i] = s + t;
222 rhs_neg[i] = s - t;
223 }
224
225 let (left, right) = output.split_at_mut(HALF_N);
226
227 inner_negacyclic_conv(lhs_neg, rhs_neg, left);
229
230 inner_conv(lhs_pos, rhs_pos, right);
232
233 for i in 0..HALF_N {
234 left[i] += right[i]; left[i] = Self::halve(left[i]); right[i] -= left[i]; }
238 }
239
240 #[inline(always)]
243 fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, NC>(
244 lhs: [T; N],
245 rhs: [U; N],
246 output: &mut [T],
247 inner_negacyclic_conv: NC,
248 ) where
249 NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
250 {
251 debug_assert_eq!(2 * HALF_N, N);
252 let mut lhs_even = [Self::T_ZERO; HALF_N];
253 let mut lhs_odd = [Self::T_ZERO; HALF_N];
254 let mut lhs_sum = [Self::T_ZERO; HALF_N];
255 let mut rhs_even = [Self::U_ZERO; HALF_N];
256 let mut rhs_odd = [Self::U_ZERO; HALF_N];
257 let mut rhs_sum = [Self::U_ZERO; HALF_N];
258
259 for i in 0..HALF_N {
260 let s = lhs[2 * i];
261 let t = lhs[2 * i + 1];
262 lhs_even[i] = s;
263 lhs_odd[i] = t;
264 lhs_sum[i] = s + t;
265
266 let s = rhs[2 * i];
267 let t = rhs[2 * i + 1];
268 rhs_even[i] = s;
269 rhs_odd[i] = t;
270 rhs_sum[i] = s + t;
271 }
272
273 let mut even_s_conv = [Self::T_ZERO; HALF_N];
274 let (left, right) = output.split_at_mut(HALF_N);
275
276 inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
279 inner_negacyclic_conv(lhs_odd, rhs_odd, left);
280 inner_negacyclic_conv(lhs_sum, rhs_sum, right);
281
282 right[0] -= even_s_conv[0] + left[0];
285 even_s_conv[0] -= left[HALF_N - 1];
286
287 for i in 1..HALF_N {
288 right[i] -= even_s_conv[i] + left[i];
289 even_s_conv[i] += left[i - 1];
290 }
291
292 for i in 0..HALF_N {
294 output[2 * i] = even_s_conv[i];
295 output[2 * i + 1] = output[i + HALF_N];
296 }
297 }
298
299 #[inline(always)]
300 fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
301 Self::conv_n_recursive(lhs, rhs, output, Self::conv3, Self::negacyclic_conv3);
302 }
303
304 #[inline(always)]
305 fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
306 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv3);
307 }
308
309 #[inline(always)]
310 fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
311 Self::conv_n_recursive(lhs, rhs, output, Self::conv4, Self::negacyclic_conv4);
312 }
313
314 #[inline(always)]
315 fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
316 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv4);
317 }
318
319 #[inline(always)]
320 fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
321 Self::conv_n_recursive(lhs, rhs, output, Self::conv6, Self::negacyclic_conv6);
322 }
323
324 #[inline(always)]
325 fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
326 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv6);
327 }
328
329 #[inline(always)]
330 fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
331 Self::conv_n_recursive(lhs, rhs, output, Self::conv8, Self::negacyclic_conv8);
332 }
333
334 #[inline(always)]
335 fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
336 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv8);
337 }
338
339 #[inline(always)]
340 fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [T]) {
341 Self::conv_n_recursive(lhs, rhs, output, Self::conv12, Self::negacyclic_conv12);
342 }
343
344 #[inline(always)]
345 fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
346 Self::conv_n_recursive(lhs, rhs, output, Self::conv16, Self::negacyclic_conv16);
347 }
348
349 #[inline(always)]
350 fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
351 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv16);
352 }
353
354 #[inline(always)]
355 fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [T]) {
356 Self::conv_n_recursive(lhs, rhs, output, Self::conv32, Self::negacyclic_conv32);
357 }
358}
359
/// Zero-sized marker type carrying a field type `F` and an algebra type `A`.
///
/// It stores no data (only a `PhantomData`); it exists so that a `Convolve`
/// implementation can be selected by the pair `(F, A)`.
struct FieldConvolve<F, A>(PhantomData<(F, A)>);
365
366impl<F: Field, A: Algebra<F> + Copy> Convolve<A, A, F> for FieldConvolve<F, A> {
367 const T_ZERO: A = A::ZERO;
368 const U_ZERO: F = F::ZERO;
369
370 #[inline(always)]
371 fn halve(val: A) -> A {
372 val.halve()
373 }
374
375 #[inline(always)]
376 fn read(input: A) -> A {
377 input
378 }
379
380 #[inline(always)]
381 fn parity_dot<const N: usize>(lhs: [A; N], rhs: [F; N]) -> A {
382 A::mixed_dot_product(&lhs, &rhs)
383 }
384
385 #[inline(always)]
386 fn reduce(z: A) -> A {
387 z
388 }
389}
390
391#[inline]
393pub fn mds_circulant_karatsuba_8<F: Field, A: Algebra<F> + Copy>(state: &mut [A; 8], col: &[F; 8]) {
394 let input = *state;
395 FieldConvolve::<F, A>::conv8(input, *col, state.as_mut_slice());
396}
397
398#[inline]
400pub fn mds_circulant_karatsuba_12<F: Field, A: Algebra<F> + Copy>(
401 state: &mut [A; 12],
402 col: &[F; 12],
403) {
404 let input = *state;
405 FieldConvolve::<F, A>::conv12(input, *col, state.as_mut_slice());
406}
407
408#[inline]
410pub fn mds_circulant_karatsuba_16<F: Field, A: Algebra<F> + Copy>(
411 state: &mut [A; 16],
412 col: &[F; 16],
413) {
414 let input = *state;
415 FieldConvolve::<F, A>::conv16(input, *col, state.as_mut_slice());
416}
417
418#[inline]
420pub fn mds_circulant_karatsuba_24<F: Field, A: Algebra<F> + Copy>(
421 state: &mut [A; 24],
422 col: &[F; 24],
423) {
424 let input = *state;
425 FieldConvolve::<F, A>::conv24(input, *col, state.as_mut_slice());
426}
427
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// Strategy producing field elements from arbitrary `u32` values.
    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(F::from_u32)
    }

    /// Reference cyclic convolution: `out[i] = sum_j lhs[j] * rhs[(i - j) mod N]`.
    fn naive_cyclic_conv<const N: usize>(lhs: [F; N], rhs: [F; N]) -> [F; N] {
        core::array::from_fn(|i| {
            let mut acc = F::ZERO;
            for j in 0..N {
                // `N + i - j` keeps the index non-negative before the reduction mod N.
                acc += lhs[j] * rhs[(N + i - j) % N];
            }
            acc
        })
    }

    /// Reference negacyclic convolution: schoolbook product where every term that
    /// wraps past `x^N` is subtracted instead of added.
    fn naive_negacyclic_conv<const N: usize>(lhs: [F; N], rhs: [F; N]) -> [F; N] {
        let mut out = [F::ZERO; N];
        for (i, &l) in lhs.iter().enumerate() {
            for (j, &r) in rhs.iter().enumerate() {
                let k = i + j;
                if k < N {
                    out[k] += l * r;
                } else {
                    out[k - N] -= l * r;
                }
            }
        }
        out
    }

    /// Runs `conv_fn` into a zeroed buffer and compares against `naive_fn`.
    fn check_conv<const N: usize>(
        lhs: [F; N],
        rhs: [F; N],
        conv_fn: fn([F; N], [F; N], &mut [F]),
        naive_fn: fn([F; N], [F; N]) -> [F; N],
    ) {
        let expected = naive_fn(lhs, rhs);
        let mut output = [F::ZERO; N];
        conv_fn(lhs, rhs, &mut output);
        assert_eq!(output, expected, "convolution mismatch");
    }

    /// Generates a property test comparing one `Convolve` method with a naive reference.
    macro_rules! conv_test {
        ($name:ident, $n:expr, $conv:expr, $naive:expr, $arr:ident) => {
            proptest! {
                #[test]
                fn $name(
                    lhs in prop::array::$arr(arb_f()),
                    rhs in prop::array::$arr(arb_f()),
                ) {
                    check_conv::<$n>(lhs, rhs, $conv, $naive);
                }
            }
        };
    }

    /// Generates a property test comparing one public `mds_circulant_karatsuba_*`
    /// entry point with the naive cyclic convolution of `state` and `col`.
    macro_rules! karatsuba_test {
        ($name:ident, $mds:ident, $arr:ident) => {
            proptest! {
                #[test]
                fn $name(
                    col in prop::array::$arr(arb_f()),
                    state in prop::array::$arr(arb_f()),
                ) {
                    let expected = naive_cyclic_conv(state, col);
                    let mut actual = state;
                    $mds(&mut actual, &col);
                    prop_assert_eq!(actual, expected);
                }
            }
        };
    }

    // Size 3 base cases.
    conv_test!(
        conv3_matches_naive,
        3,
        FieldConvolve::<F, F>::conv3,
        naive_cyclic_conv,
        uniform3
    );
    conv_test!(
        negacyclic_conv3_matches_naive,
        3,
        FieldConvolve::<F, F>::negacyclic_conv3,
        naive_negacyclic_conv,
        uniform3
    );

    // Size 4 base cases.
    conv_test!(
        conv4_matches_naive,
        4,
        FieldConvolve::<F, F>::conv4,
        naive_cyclic_conv,
        uniform4
    );
    conv_test!(
        negacyclic_conv4_matches_naive,
        4,
        FieldConvolve::<F, F>::negacyclic_conv4,
        naive_negacyclic_conv,
        uniform4
    );

    // Size 6: one recursion level on top of size 3.
    conv_test!(
        conv6_matches_naive,
        6,
        FieldConvolve::<F, F>::conv6,
        naive_cyclic_conv,
        uniform6
    );
    conv_test!(
        negacyclic_conv6_matches_naive,
        6,
        FieldConvolve::<F, F>::negacyclic_conv6,
        naive_negacyclic_conv,
        uniform6
    );

    // Size 8: one recursion level on top of size 4.
    conv_test!(
        conv8_matches_naive,
        8,
        FieldConvolve::<F, F>::conv8,
        naive_cyclic_conv,
        uniform8
    );
    conv_test!(
        negacyclic_conv8_matches_naive,
        8,
        FieldConvolve::<F, F>::negacyclic_conv8,
        naive_negacyclic_conv,
        uniform8
    );

    // Size 12.
    conv_test!(
        conv12_matches_naive,
        12,
        FieldConvolve::<F, F>::conv12,
        naive_cyclic_conv,
        uniform12
    );
    conv_test!(
        negacyclic_conv12_matches_naive,
        12,
        FieldConvolve::<F, F>::negacyclic_conv12,
        naive_negacyclic_conv,
        uniform12
    );

    // Size 16.
    conv_test!(
        conv16_matches_naive,
        16,
        FieldConvolve::<F, F>::conv16,
        naive_cyclic_conv,
        uniform16
    );
    conv_test!(
        negacyclic_conv16_matches_naive,
        16,
        FieldConvolve::<F, F>::negacyclic_conv16,
        naive_negacyclic_conv,
        uniform16
    );

    // Size 24 (the trait has no negacyclic variant of this size).
    conv_test!(
        conv24_matches_naive,
        24,
        FieldConvolve::<F, F>::conv24,
        naive_cyclic_conv,
        uniform24
    );

    // Size 32.
    conv_test!(
        conv32_matches_naive,
        32,
        FieldConvolve::<F, F>::conv32,
        naive_cyclic_conv,
        uniform32
    );
    conv_test!(
        negacyclic_conv32_matches_naive,
        32,
        FieldConvolve::<F, F>::negacyclic_conv32,
        naive_negacyclic_conv,
        uniform32
    );

    // Size 64 is checked on fixed inputs rather than through a proptest array strategy.
    #[test]
    fn conv64_matches_naive_fixed() {
        let lhs: [F; 64] = core::array::from_fn(|i| F::from_u32(i as u32 + 1));
        let rhs: [F; 64] = core::array::from_fn(|i| F::from_u32(64 - i as u32));
        check_conv::<64>(lhs, rhs, FieldConvolve::<F, F>::conv64, naive_cyclic_conv);
    }

    #[test]
    fn conv64_all_ones() {
        let ones = [F::ONE; 64];
        let expected = naive_cyclic_conv(ones, ones);
        let mut output = [F::ZERO; 64];
        FieldConvolve::<F, F>::conv64(ones, ones, &mut output);
        assert_eq!(output, expected);
    }

    // Public in-place wrappers: every exported size is covered.
    karatsuba_test!(karatsuba_8_matches_naive, mds_circulant_karatsuba_8, uniform8);
    karatsuba_test!(karatsuba_12_matches_naive, mds_circulant_karatsuba_12, uniform12);
    karatsuba_test!(karatsuba_16_matches_naive, mds_circulant_karatsuba_16, uniform16);
    karatsuba_test!(karatsuba_24_matches_naive, mds_circulant_karatsuba_24, uniform24);

    // Algebraic sanity checks on the size-8 cyclic convolution.
    proptest! {
        /// Cyclic convolution is commutative.
        #[test]
        fn conv8_commutative(
            a in prop::array::uniform8(arb_f()),
            b in prop::array::uniform8(arb_f()),
        ) {
            let mut ab = [F::ZERO; 8];
            let mut ba = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, b, &mut ab);
            FieldConvolve::<F, F>::conv8(b, a, &mut ba);
            prop_assert_eq!(ab, ba);
        }

        /// Convolving with `[1, 0, ..., 0]` returns the input unchanged.
        #[test]
        fn conv8_identity(a in prop::array::uniform8(arb_f())) {
            let mut id = [F::ZERO; 8];
            id[0] = F::ONE;
            let mut out = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, id, &mut out);
            prop_assert_eq!(out, a);
        }

        /// Convolving with the zero vector gives the zero vector.
        #[test]
        fn conv8_zero(a in prop::array::uniform8(arb_f())) {
            let zeros = [F::ZERO; 8];
            let mut out = [F::ZERO; 8];
            FieldConvolve::<F, F>::conv8(a, zeros, &mut out);
            prop_assert_eq!(out, zeros);
        }
    }
}