1use core::marker::PhantomData;
48use core::ops::{Add, AddAssign, Neg, Sub, SubAssign};
49
50use p3_field::{Algebra, Field};
51
/// Arithmetic required of the left-hand operand / accumulator type `T` of a [`Convolve`]
/// implementation.
///
/// This is a marker trait: it declares no items of its own and only bundles `Copy` with
/// the operators `+`, `-`, unary `-`, `+=` and `-=` (all closed over `Self`). A blanket
/// implementation elsewhere in this file covers every type that meets these bounds.
pub trait ConvolutionElt:
    Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self> + AddAssign + SubAssign
{
}
57
58impl<T> ConvolutionElt for T where
59 T: Add<Output = T> + AddAssign + Copy + Neg<Output = T> + Sub<Output = T> + SubAssign
60{
61}
62
/// Arithmetic required of the right-hand operand type `U` of a [`Convolve`]
/// implementation.
///
/// This is a marker trait: it declares no items of its own and only bundles `Copy` with
/// the operators `+`, `-` and unary `-` (all closed over `Self`). Unlike
/// [`ConvolutionElt`] it does not ask for the compound-assignment operators. A blanket
/// implementation elsewhere in this file covers every type that meets these bounds.
pub trait ConvolutionRhs:
    Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self>
{
}
68
69impl<T> ConvolutionRhs for T where T: Add<Output = T> + Copy + Neg<Output = T> + Sub<Output = T> {}
70
71pub trait Convolve<F, T: ConvolutionElt, U: ConvolutionRhs> {
98 const T_ZERO: T;
103
104 const U_ZERO: U;
109
110 fn halve(val: T) -> T;
115
116 fn read(input: F) -> T;
119
120 fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> T;
127
128 fn reduce(z: T) -> F;
131
132 #[inline(always)]
137 fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [T])>(
138 lhs: [F; N],
139 rhs: [U; N],
140 conv: C,
141 ) -> [F; N] {
142 let lhs = lhs.map(Self::read);
143 let mut output = [Self::T_ZERO; N];
144 conv(lhs, rhs, &mut output);
145 output.map(Self::reduce)
146 }
147
148 #[inline(always)]
149 fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
150 output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
151 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
152 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
153 }
154
155 #[inline(always)]
156 fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
157 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
158 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
159 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
160 }
161
162 #[inline(always)]
163 fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
164 let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
167 let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
168 let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
169 let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
170
171 output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
172 output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
173 output[2] = Self::parity_dot(u_p, v_p);
174 output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
175
176 output[0] += output[2];
177 output[1] += output[3];
178
179 output[0] = Self::halve(output[0]);
180 output[1] = Self::halve(output[1]);
181
182 output[2] -= output[0];
183 output[3] -= output[1];
184 }
185
186 #[inline(always)]
187 fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
188 output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
189 output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
190 output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
191 output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
192 }
193
194 #[inline(always)]
197 fn conv_n_recursive<const N: usize, const HALF_N: usize, C, NC>(
198 lhs: [T; N],
199 rhs: [U; N],
200 output: &mut [T],
201 inner_conv: C,
202 inner_negacyclic_conv: NC,
203 ) where
204 C: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
205 NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
206 {
207 debug_assert_eq!(2 * HALF_N, N);
208 let mut lhs_pos = [Self::T_ZERO; HALF_N]; let mut lhs_neg = [Self::T_ZERO; HALF_N]; let mut rhs_pos = [Self::U_ZERO; HALF_N]; let mut rhs_neg = [Self::U_ZERO; HALF_N]; for i in 0..HALF_N {
214 let s = lhs[i];
215 let t = lhs[i + HALF_N];
216
217 lhs_pos[i] = s + t;
218 lhs_neg[i] = s - t;
219
220 let s = rhs[i];
221 let t = rhs[i + HALF_N];
222
223 rhs_pos[i] = s + t;
224 rhs_neg[i] = s - t;
225 }
226
227 let (left, right) = output.split_at_mut(HALF_N);
228
229 inner_negacyclic_conv(lhs_neg, rhs_neg, left);
231
232 inner_conv(lhs_pos, rhs_pos, right);
234
235 for i in 0..HALF_N {
236 left[i] += right[i]; left[i] = Self::halve(left[i]); right[i] -= left[i]; }
240 }
241
242 #[inline(always)]
245 fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, NC>(
246 lhs: [T; N],
247 rhs: [U; N],
248 output: &mut [T],
249 inner_negacyclic_conv: NC,
250 ) where
251 NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
252 {
253 debug_assert_eq!(2 * HALF_N, N);
254 let mut lhs_even = [Self::T_ZERO; HALF_N];
255 let mut lhs_odd = [Self::T_ZERO; HALF_N];
256 let mut lhs_sum = [Self::T_ZERO; HALF_N];
257 let mut rhs_even = [Self::U_ZERO; HALF_N];
258 let mut rhs_odd = [Self::U_ZERO; HALF_N];
259 let mut rhs_sum = [Self::U_ZERO; HALF_N];
260
261 for i in 0..HALF_N {
262 let s = lhs[2 * i];
263 let t = lhs[2 * i + 1];
264 lhs_even[i] = s;
265 lhs_odd[i] = t;
266 lhs_sum[i] = s + t;
267
268 let s = rhs[2 * i];
269 let t = rhs[2 * i + 1];
270 rhs_even[i] = s;
271 rhs_odd[i] = t;
272 rhs_sum[i] = s + t;
273 }
274
275 let mut even_s_conv = [Self::T_ZERO; HALF_N];
276 let (left, right) = output.split_at_mut(HALF_N);
277
278 inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
281 inner_negacyclic_conv(lhs_odd, rhs_odd, left);
282 inner_negacyclic_conv(lhs_sum, rhs_sum, right);
283
284 right[0] -= even_s_conv[0] + left[0];
287 even_s_conv[0] -= left[HALF_N - 1];
288
289 for i in 1..HALF_N {
290 right[i] -= even_s_conv[i] + left[i];
291 even_s_conv[i] += left[i - 1];
292 }
293
294 for i in 0..HALF_N {
296 output[2 * i] = even_s_conv[i];
297 output[2 * i + 1] = output[i + HALF_N];
298 }
299 }
300
301 #[inline(always)]
302 fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
303 Self::conv_n_recursive(lhs, rhs, output, Self::conv3, Self::negacyclic_conv3);
304 }
305
306 #[inline(always)]
307 fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
308 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv3);
309 }
310
311 #[inline(always)]
312 fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
313 Self::conv_n_recursive(lhs, rhs, output, Self::conv4, Self::negacyclic_conv4);
314 }
315
316 #[inline(always)]
317 fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
318 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv4);
319 }
320
321 #[inline(always)]
322 fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
323 Self::conv_n_recursive(lhs, rhs, output, Self::conv6, Self::negacyclic_conv6);
324 }
325
326 #[inline(always)]
327 fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
328 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv6);
329 }
330
331 #[inline(always)]
332 fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
333 Self::conv_n_recursive(lhs, rhs, output, Self::conv8, Self::negacyclic_conv8);
334 }
335
336 #[inline(always)]
337 fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
338 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv8);
339 }
340
341 #[inline(always)]
342 fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [T]) {
343 Self::conv_n_recursive(lhs, rhs, output, Self::conv12, Self::negacyclic_conv12);
344 }
345
346 #[inline(always)]
347 fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
348 Self::conv_n_recursive(lhs, rhs, output, Self::conv16, Self::negacyclic_conv16);
349 }
350
351 #[inline(always)]
352 fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
353 Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv16);
354 }
355
356 #[inline(always)]
357 fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [T]) {
358 Self::conv_n_recursive(lhs, rhs, output, Self::conv32, Self::negacyclic_conv32);
359 }
360}
361
/// Zero-sized carrier type for the [`Convolve`] implementation over a field `F` and an
/// `F`-algebra `A`.
///
/// It is never instantiated in this file; only its associated functions are called. The
/// `PhantomData` field exists solely so that both type parameters count as used.
struct FieldConvolve<F, A>(PhantomData<(F, A)>);
366
367impl<F: Field, A: Algebra<F> + Copy> Convolve<A, A, F> for FieldConvolve<F, A> {
368 const T_ZERO: A = A::ZERO;
369 const U_ZERO: F = F::ZERO;
370
371 #[inline(always)]
372 fn halve(val: A) -> A {
373 val.halve()
374 }
375
376 #[inline(always)]
377 fn read(input: A) -> A {
378 input
379 }
380
381 #[inline(always)]
382 fn parity_dot<const N: usize>(lhs: [A; N], rhs: [F; N]) -> A {
383 A::mixed_dot_product(&lhs, &rhs)
384 }
385
386 #[inline(always)]
387 fn reduce(z: A) -> A {
388 z
389 }
390}
391
392#[inline]
394pub fn mds_circulant_karatsuba_16<F: Field, A: Algebra<F> + Copy>(
395 state: &mut [A; 16],
396 col: &[F; 16],
397) {
398 let input = *state;
399 FieldConvolve::<F, A>::conv16(input, *col, state.as_mut_slice());
400}
401
402#[inline]
404pub fn mds_circulant_karatsuba_24<F: Field, A: Algebra<F> + Copy>(
405 state: &mut [A; 24],
406 col: &[F; 24],
407) {
408 let input = *state;
409 FieldConvolve::<F, A>::conv24(input, *col, state.as_mut_slice());
410}
411
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// Strategy yielding field elements built from arbitrary `u32` values.
    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(F::from_u32)
    }

    /// Reference implementation: schoolbook product of the circulant matrix with first
    /// column `col` and the vector `state`, i.e.
    /// `out[i] = sum_j col[(i - j) mod N] * state[j]`.
    fn naive_circulant<const N: usize>(col: [F; N], state: [F; N]) -> [F; N] {
        let mut out = [F::ZERO; N];
        for (i, slot) in out.iter_mut().enumerate() {
            for (j, &entry) in state.iter().enumerate() {
                // `N + i - j` keeps the index non-negative before the reduction mod N.
                *slot += col[(N + i - j) % N] * entry;
            }
        }
        out
    }

    /// Fixed first column `2, 3, ..., 17`.
    fn col_16() -> [F; 16] {
        core::array::from_fn(|i| F::from_i64(i as i64 + 2))
    }

    /// Fixed first column `2, 3, ..., 25`.
    fn col_24() -> [F; 24] {
        core::array::from_fn(|i| F::from_i64(i as i64 + 2))
    }

    /// Runs the in-place width-16 routine on a copy of `state` and returns the result.
    fn karatsuba_16(col: &[F; 16], state: [F; 16]) -> [F; 16] {
        let mut out = state;
        mds_circulant_karatsuba_16(&mut out, col);
        out
    }

    /// Runs the in-place width-24 routine on a copy of `state` and returns the result.
    fn karatsuba_24(col: &[F; 24], state: [F; 24]) -> [F; 24] {
        let mut out = state;
        mds_circulant_karatsuba_24(&mut out, col);
        out
    }

    proptest! {
        #[test]
        fn karatsuba_16_matches_naive(state in prop::array::uniform16(arb_f())) {
            let col = col_16();
            let expected = naive_circulant(col, state);
            let actual = karatsuba_16(&col, state);
            prop_assert_eq!(actual, expected);
        }

        #[test]
        fn karatsuba_24_matches_naive(state in prop::array::uniform24(arb_f())) {
            let col = col_24();
            let expected = naive_circulant(col, state);
            let actual = karatsuba_24(&col, state);
            prop_assert_eq!(actual, expected);
        }

        #[test]
        fn karatsuba_16_random_col(
            col in prop::array::uniform16(arb_f()),
            state in prop::array::uniform16(arb_f()),
        ) {
            let expected = naive_circulant(col, state);
            let actual = karatsuba_16(&col, state);
            prop_assert_eq!(actual, expected);
        }

        #[test]
        fn karatsuba_24_random_col(
            col in prop::array::uniform24(arb_f()),
            state in prop::array::uniform24(arb_f()),
        ) {
            let expected = naive_circulant(col, state);
            let actual = karatsuba_24(&col, state);
            prop_assert_eq!(actual, expected);
        }
    }
}