// p3_mds/karatsuba_convolution.rs

//! Calculate the convolution of two vectors using a Karatsuba-style
//! decomposition and the CRT.
//!
//! This is not a new idea, but we did have the pleasure of
//! reinventing it independently. Some references:
//! - <https://cr.yp.to/lineartime/multapps-20080515.pdf>
//! - <https://2π.com/23/convolution/>
//!
//! Given a vector v \in F^N, let v(x) \in F[x] denote the polynomial
//! v_0 + v_1 x + ... + v_{N - 1} x^{N - 1}. Then w is equal to the
//! (cyclic) convolution v * u if and only if w(x) = v(x)u(x) mod (x^N - 1).
//! Additionally, define the negacyclic convolution by
//! w(x) = v(x)u(x) mod (x^N + 1). For even N we have
//! x^N - 1 = (x^{N/2} - 1)(x^{N/2} + 1), so by the Chinese remainder
//! theorem we can compute the convolution w(x) as
//!
//! ```text
//! w(x) = 1/2 (w_0(x) + w_1(x)) + x^{N/2}/2 (w_0(x) - w_1(x))
//! ```
//!
//! where
//!
//! ```text
//! w_0(x) = v(x)u(x) mod (x^{N/2} - 1)
//! w_1(x) = v(x)u(x) mod (x^{N/2} + 1)
//! ```
//!
//! To compute w_0 and w_1 we first reduce both inputs:
//!
//! ```text
//! v_0(x) = v(x) mod (x^{N/2} - 1)
//! v_1(x) = v(x) mod (x^{N/2} + 1)
//! u_0(x) = u(x) mod (x^{N/2} - 1)
//! u_1(x) = u(x) mod (x^{N/2} + 1)
//! ```
//!
//! Note that here v_0, v_1 (and u_0, u_1) denote reduced polynomials
//! rather than coefficients. Now w_0 is the convolution of v_0 and u_0,
//! which we can compute recursively. For w_1 we compute the negacyclic
//! convolution v_1(x)u_1(x) mod (x^{N/2} + 1) using Karatsuba.
//!
//! There are two possible ways to apply Karatsuba, mirroring the
//! DIT vs. DIF approaches to FFTs: the left/right decomposition
//! or the even/odd decomposition. The latter seems to have fewer
//! operations, so it is the one implemented below, though it does
//! require a bit more data manipulation. It works as follows.
//!
//! Define the even part v_e and odd part v_o so that
//! v(x) = v_e(x^2) + x v_o(x^2). Then
//!
//! ```text
//! v(x)u(x) = (v_e(x^2)u_e(x^2) + x^2 v_o(x^2)u_o(x^2))
//!            + x ((v_e(x^2) + v_o(x^2))(u_e(x^2) + u_o(x^2))
//!                 - (v_e(x^2)u_e(x^2) + v_o(x^2)u_o(x^2)))
//! ```
//!
//! Substituting y = x^2, the modulus x^N + 1 becomes y^{N/2} + 1, so this
//! reduces the problem to 3 negacyclic convolutions of size N/2 (of the
//! even parts, the odd parts, and their sums) which can be computed
//! recursively. The factor x^2 = y in front of the odd product is just a
//! negacyclic shift by one position.
//!
//! Of course, for small sizes we just explicitly write out the O(n^2)
//! approach.

47use core::marker::PhantomData;
48use core::ops::{Add, AddAssign, Neg, Sub, SubAssign};
49
50use p3_field::{Algebra, Field};
51
/// Operations required of the "wide" operand type in a Karatsuba
/// convolution, i.e. the left-hand input and the output.
///
/// This is a pure bound alias: any `Copy` type that supports `+`, binary
/// and unary `-`, `+=` and `-=` on itself implements it automatically.
pub trait ConvolutionElt:
    Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self> + AddAssign + SubAssign
{
}

impl<E> ConvolutionElt for E where
    E: Copy + Add<Output = E> + Sub<Output = E> + Neg<Output = E> + AddAssign + SubAssign
{
}
62
/// Operations required of the "narrow" operand type in a Karatsuba
/// convolution, i.e. the right-hand input.
///
/// Unlike `ConvolutionElt`, no compound assignment is required. This is a
/// pure bound alias: every qualifying type implements it automatically.
pub trait ConvolutionRhs: Copy + Add<Output = Self> + Sub<Output = Self> + Neg<Output = Self> {}

impl<R> ConvolutionRhs for R where
    R: Copy + Add<Output = R> + Sub<Output = R> + Neg<Output = R>
{
}
70
71/// Template function to perform convolution of vectors.
72///
73/// Roughly speaking, for a convolution of size `N`, it should be
74/// possible to add `N` elements of type `T` without overflowing, and
75/// similarly for `U`. Then multiplication via `Self::mul` should
76/// produce an element of type `T` which will not overflow after about
77/// `N` additions (this is an over-estimate).
78///
79/// For example usage, see `{mersenne-31,baby-bear,goldilocks}/src/mds.rs`.
80///
81/// NB: In practice, one of the parameters to the convolution will be
82/// constant (the MDS matrix). After inspecting Godbolt output, it
83/// seems that the compiler does indeed generate single constants as
84/// inputs to the multiplication, rather than doing all that
85/// arithmetic on the constant values every time. Note however that,
86/// for MDS matrices with large entries (N >= 24), these compile-time
87/// generated constants will be about N times bigger than they need to
88/// be in principle, which could be a potential avenue for some minor
89/// improvements.
90///
91/// NB: If primitive multiplications are still the bottleneck, a
92/// further possibility would be to find an MDS matrix some of whose
93/// entries are powers of 2. Then the multiplication can be replaced
94/// with a shift, which on most architectures has better throughput
95/// and latency, and is issued on different ports (1*p06) to
96/// multiplication (1*p1).
97pub trait Convolve<F, T: ConvolutionElt, U: ConvolutionRhs> {
98    /// Additive identity for the wide operand type `T`.
99    ///
100    /// Used to initialize output and scratch arrays before the convolution
101    /// fills them with computed values.
102    const T_ZERO: T;
103
104    /// Additive identity for the narrow operand type `U`.
105    ///
106    /// Used to initialize temporary arrays for the RHS decomposition
107    /// in the recursive CRT / Karatsuba steps.
108    const U_ZERO: U;
109
110    /// Divide an element of `T` by 2.
111    ///
112    /// - For integers (`i64`, `i128`): arithmetic right shift by 1.
113    /// - For field elements: multiplication by the multiplicative inverse of 2.
114    fn halve(val: T) -> T;
115
116    /// Given an input element, retrieve the corresponding internal
117    /// element that will be used in calculations.
118    fn read(input: F) -> T;
119
120    /// Given input vectors `lhs` and `rhs`, calculate their dot
121    /// product. The result can be reduced with respect to the modulus
122    /// (of `F`), but it must have the same lower 10 bits as the dot
123    /// product if all inputs are considered integers. See
124    /// `monty-31/src/mds.rs::barrett_red_monty31()` for an example
125    /// of how this can be implemented in practice.
126    fn parity_dot<const N: usize>(lhs: [T; N], rhs: [U; N]) -> T;
127
128    /// Convert an internal element of type `T` back into an external
129    /// element.
130    fn reduce(z: T) -> F;
131
132    /// Convolve `lhs` and `rhs`.
133    ///
134    /// The parameter `conv` should be the function in this trait that
135    /// corresponds to length `N`.
136    #[inline(always)]
137    fn apply<const N: usize, C: Fn([T; N], [U; N], &mut [T])>(
138        lhs: [F; N],
139        rhs: [U; N],
140        conv: C,
141    ) -> [F; N] {
142        let lhs = lhs.map(Self::read);
143        let mut output = [Self::T_ZERO; N];
144        conv(lhs, rhs, &mut output);
145        output.map(Self::reduce)
146    }
147
148    #[inline(always)]
149    fn conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
150        output[0] = Self::parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]);
151        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]);
152        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
153    }
154
155    #[inline(always)]
156    fn negacyclic_conv3(lhs: [T; 3], rhs: [U; 3], output: &mut [T]) {
157        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]);
158        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]);
159        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]);
160    }
161
162    #[inline(always)]
163    fn conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
164        // NB: This is just explicitly implementing
165        // conv_n_recursive::<4, 2, _, _>(lhs, rhs, output, Self::conv2, Self::negacyclic_conv2)
166        let u_p = [lhs[0] + lhs[2], lhs[1] + lhs[3]];
167        let u_m = [lhs[0] - lhs[2], lhs[1] - lhs[3]];
168        let v_p = [rhs[0] + rhs[2], rhs[1] + rhs[3]];
169        let v_m = [rhs[0] - rhs[2], rhs[1] - rhs[3]];
170
171        output[0] = Self::parity_dot(u_m, [v_m[0], -v_m[1]]);
172        output[1] = Self::parity_dot(u_m, [v_m[1], v_m[0]]);
173        output[2] = Self::parity_dot(u_p, v_p);
174        output[3] = Self::parity_dot(u_p, [v_p[1], v_p[0]]);
175
176        output[0] += output[2];
177        output[1] += output[3];
178
179        output[0] = Self::halve(output[0]);
180        output[1] = Self::halve(output[1]);
181
182        output[2] -= output[0];
183        output[3] -= output[1];
184    }
185
186    #[inline(always)]
187    fn negacyclic_conv4(lhs: [T; 4], rhs: [U; 4], output: &mut [T]) {
188        output[0] = Self::parity_dot(lhs, [rhs[0], -rhs[3], -rhs[2], -rhs[1]]);
189        output[1] = Self::parity_dot(lhs, [rhs[1], rhs[0], -rhs[3], -rhs[2]]);
190        output[2] = Self::parity_dot(lhs, [rhs[2], rhs[1], rhs[0], -rhs[3]]);
191        output[3] = Self::parity_dot(lhs, [rhs[3], rhs[2], rhs[1], rhs[0]]);
192    }
193
194    /// Compute output(x) = lhs(x)rhs(x) mod x^N - 1 recursively using
195    /// a convolution and negacyclic convolution of size HALF_N = N/2.
196    #[inline(always)]
197    fn conv_n_recursive<const N: usize, const HALF_N: usize, C, NC>(
198        lhs: [T; N],
199        rhs: [U; N],
200        output: &mut [T],
201        inner_conv: C,
202        inner_negacyclic_conv: NC,
203    ) where
204        C: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
205        NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
206    {
207        debug_assert_eq!(2 * HALF_N, N);
208        let mut lhs_pos = [Self::T_ZERO; HALF_N]; // lhs_pos = lhs(x) mod x^{N/2} - 1
209        let mut lhs_neg = [Self::T_ZERO; HALF_N]; // lhs_neg = lhs(x) mod x^{N/2} + 1
210        let mut rhs_pos = [Self::U_ZERO; HALF_N]; // rhs_pos = rhs(x) mod x^{N/2} - 1
211        let mut rhs_neg = [Self::U_ZERO; HALF_N]; // rhs_neg = rhs(x) mod x^{N/2} + 1
212
213        for i in 0..HALF_N {
214            let s = lhs[i];
215            let t = lhs[i + HALF_N];
216
217            lhs_pos[i] = s + t;
218            lhs_neg[i] = s - t;
219
220            let s = rhs[i];
221            let t = rhs[i + HALF_N];
222
223            rhs_pos[i] = s + t;
224            rhs_neg[i] = s - t;
225        }
226
227        let (left, right) = output.split_at_mut(HALF_N);
228
229        // left = w1 = lhs(x)rhs(x) mod x^{N/2} + 1
230        inner_negacyclic_conv(lhs_neg, rhs_neg, left);
231
232        // right = w0 = lhs(x)rhs(x) mod x^{N/2} - 1
233        inner_conv(lhs_pos, rhs_pos, right);
234
235        for i in 0..HALF_N {
236            left[i] += right[i]; // w_0 + w_1
237            left[i] = Self::halve(left[i]); // (w_0 + w_1)/2
238            right[i] -= left[i]; // (w_0 - w_1)/2
239        }
240    }
241
242    /// Compute output(x) = lhs(x)rhs(x) mod x^N + 1 recursively using
243    /// three negacyclic convolutions of size HALF_N = N/2.
244    #[inline(always)]
245    fn negacyclic_conv_n_recursive<const N: usize, const HALF_N: usize, NC>(
246        lhs: [T; N],
247        rhs: [U; N],
248        output: &mut [T],
249        inner_negacyclic_conv: NC,
250    ) where
251        NC: Fn([T; HALF_N], [U; HALF_N], &mut [T]),
252    {
253        debug_assert_eq!(2 * HALF_N, N);
254        let mut lhs_even = [Self::T_ZERO; HALF_N];
255        let mut lhs_odd = [Self::T_ZERO; HALF_N];
256        let mut lhs_sum = [Self::T_ZERO; HALF_N];
257        let mut rhs_even = [Self::U_ZERO; HALF_N];
258        let mut rhs_odd = [Self::U_ZERO; HALF_N];
259        let mut rhs_sum = [Self::U_ZERO; HALF_N];
260
261        for i in 0..HALF_N {
262            let s = lhs[2 * i];
263            let t = lhs[2 * i + 1];
264            lhs_even[i] = s;
265            lhs_odd[i] = t;
266            lhs_sum[i] = s + t;
267
268            let s = rhs[2 * i];
269            let t = rhs[2 * i + 1];
270            rhs_even[i] = s;
271            rhs_odd[i] = t;
272            rhs_sum[i] = s + t;
273        }
274
275        let mut even_s_conv = [Self::T_ZERO; HALF_N];
276        let (left, right) = output.split_at_mut(HALF_N);
277
278        // Recursively compute the size N/2 negacyclic convolutions of
279        // the even parts, odd parts, and sums.
280        inner_negacyclic_conv(lhs_even, rhs_even, &mut even_s_conv);
281        inner_negacyclic_conv(lhs_odd, rhs_odd, left);
282        inner_negacyclic_conv(lhs_sum, rhs_sum, right);
283
284        // Adjust so that the correct values are in right and
285        // even_s_conv respectively:
286        right[0] -= even_s_conv[0] + left[0];
287        even_s_conv[0] -= left[HALF_N - 1];
288
289        for i in 1..HALF_N {
290            right[i] -= even_s_conv[i] + left[i];
291            even_s_conv[i] += left[i - 1];
292        }
293
294        // Interleave even_s_conv and right in the output:
295        for i in 0..HALF_N {
296            output[2 * i] = even_s_conv[i];
297            output[2 * i + 1] = output[i + HALF_N];
298        }
299    }
300
301    #[inline(always)]
302    fn conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
303        Self::conv_n_recursive(lhs, rhs, output, Self::conv3, Self::negacyclic_conv3);
304    }
305
306    #[inline(always)]
307    fn negacyclic_conv6(lhs: [T; 6], rhs: [U; 6], output: &mut [T]) {
308        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv3);
309    }
310
311    #[inline(always)]
312    fn conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
313        Self::conv_n_recursive(lhs, rhs, output, Self::conv4, Self::negacyclic_conv4);
314    }
315
316    #[inline(always)]
317    fn negacyclic_conv8(lhs: [T; 8], rhs: [U; 8], output: &mut [T]) {
318        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv4);
319    }
320
321    #[inline(always)]
322    fn conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
323        Self::conv_n_recursive(lhs, rhs, output, Self::conv6, Self::negacyclic_conv6);
324    }
325
326    #[inline(always)]
327    fn negacyclic_conv12(lhs: [T; 12], rhs: [U; 12], output: &mut [T]) {
328        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv6);
329    }
330
331    #[inline(always)]
332    fn conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
333        Self::conv_n_recursive(lhs, rhs, output, Self::conv8, Self::negacyclic_conv8);
334    }
335
336    #[inline(always)]
337    fn negacyclic_conv16(lhs: [T; 16], rhs: [U; 16], output: &mut [T]) {
338        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv8);
339    }
340
341    #[inline(always)]
342    fn conv24(lhs: [T; 24], rhs: [U; 24], output: &mut [T]) {
343        Self::conv_n_recursive(lhs, rhs, output, Self::conv12, Self::negacyclic_conv12);
344    }
345
346    #[inline(always)]
347    fn conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
348        Self::conv_n_recursive(lhs, rhs, output, Self::conv16, Self::negacyclic_conv16);
349    }
350
351    #[inline(always)]
352    fn negacyclic_conv32(lhs: [T; 32], rhs: [U; 32], output: &mut [T]) {
353        Self::negacyclic_conv_n_recursive(lhs, rhs, output, Self::negacyclic_conv16);
354    }
355
356    #[inline(always)]
357    fn conv64(lhs: [T; 64], rhs: [U; 64], output: &mut [T]) {
358        Self::conv_n_recursive(lhs, rhs, output, Self::conv32, Self::negacyclic_conv32);
359    }
360}
361
362/// Convolve implementor for field elements.
363///
364/// All arithmetic stays in the field — no integer lifting.
365struct FieldConvolve<F, A>(PhantomData<(F, A)>);
366
367impl<F: Field, A: Algebra<F> + Copy> Convolve<A, A, F> for FieldConvolve<F, A> {
368    const T_ZERO: A = A::ZERO;
369    const U_ZERO: F = F::ZERO;
370
371    #[inline(always)]
372    fn halve(val: A) -> A {
373        val.halve()
374    }
375
376    #[inline(always)]
377    fn read(input: A) -> A {
378        input
379    }
380
381    #[inline(always)]
382    fn parity_dot<const N: usize>(lhs: [A; N], rhs: [F; N]) -> A {
383        A::mixed_dot_product(&lhs, &rhs)
384    }
385
386    #[inline(always)]
387    fn reduce(z: A) -> A {
388        z
389    }
390}
391
392/// Circulant matrix-vector multiply for width 16 via Karatsuba convolution.
393#[inline]
394pub fn mds_circulant_karatsuba_16<F: Field, A: Algebra<F> + Copy>(
395    state: &mut [A; 16],
396    col: &[F; 16],
397) {
398    let input = *state;
399    FieldConvolve::<F, A>::conv16(input, *col, state.as_mut_slice());
400}
401
402/// Circulant matrix-vector multiply for width 24 via Karatsuba convolution.
403#[inline]
404pub fn mds_circulant_karatsuba_24<F: Field, A: Algebra<F> + Copy>(
405    state: &mut [A; 24],
406    col: &[F; 24],
407) {
408    let input = *state;
409    FieldConvolve::<F, A>::conv24(input, *col, state.as_mut_slice());
410}
411
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;

    use super::*;

    type F = BabyBear;

    /// The field-level kernels instantiated with `A = F`, so that every
    /// `Convolve` method can be exercised directly.
    type Kernels = FieldConvolve<F, F>;

    /// Map an arbitrary `u32` into a field element.
    fn arb_f() -> impl Strategy<Value = F> {
        prop::num::u32::ANY.prop_map(F::from_u32)
    }

    /// Strategy for `N` arbitrary field elements, for widths not covered by
    /// the `prop::array::uniformN` helpers (e.g. 64).
    fn arb_array<const N: usize>() -> impl Strategy<Value = [F; N]> {
        prop::collection::vec(arb_f(), N)
            .prop_map(|v| <[F; N]>::try_from(v).expect("strategy yields exactly N elements"))
    }

    /// Naive O(N^2) circulant multiply used as the reference oracle.
    ///
    /// For each output index `i`, computes the dot product of the
    /// cyclically shifted column with the state vector:
    ///   r[i] = sum_j col[(i - j) mod N] * state[j]
    fn naive_circulant<const N: usize>(col: [F; N], state: [F; N]) -> [F; N] {
        core::array::from_fn(|i| {
            let mut acc = F::ZERO;
            for j in 0..N {
                acc += col[(N + i - j) % N] * state[j];
            }
            acc
        })
    }

    /// Naive O(N^2) negacyclic convolution used as the reference oracle:
    /// `lhs(x) * rhs(x) mod (x^N + 1)`. Products whose degree wraps past
    /// `N - 1` are subtracted rather than added.
    fn naive_negacyclic<const N: usize>(lhs: [F; N], rhs: [F; N]) -> [F; N] {
        let mut out = [F::ZERO; N];
        for (i, &l) in lhs.iter().enumerate() {
            for (j, &r) in rhs.iter().enumerate() {
                let k = i + j;
                if k < N {
                    out[k] += l * r;
                } else {
                    out[k - N] -= l * r;
                }
            }
        }
        out
    }

    /// The first `M` entries of `arr` (requires `M <= N`).
    fn prefix<const M: usize, const N: usize>(arr: [F; N]) -> [F; M] {
        core::array::from_fn(|i| arr[i])
    }

    /// Run a cyclic kernel on `lhs`, `rhs` and compare with the oracle.
    fn check_cyclic<const N: usize>(
        lhs: [F; N],
        rhs: [F; N],
        kernel: fn([F; N], [F; N], &mut [F]),
    ) -> Result<(), TestCaseError> {
        let mut out = [F::ZERO; N];
        kernel(lhs, rhs, &mut out);
        // Convolution is commutative, so the circulant oracle applies with
        // `lhs` in the role of the column.
        prop_assert_eq!(out, naive_circulant(lhs, rhs), "cyclic width {}", N);
        Ok(())
    }

    /// Run a negacyclic kernel on `lhs`, `rhs` and compare with the oracle.
    fn check_negacyclic<const N: usize>(
        lhs: [F; N],
        rhs: [F; N],
        kernel: fn([F; N], [F; N], &mut [F]),
    ) -> Result<(), TestCaseError> {
        let mut out = [F::ZERO; N];
        kernel(lhs, rhs, &mut out);
        prop_assert_eq!(out, naive_negacyclic(lhs, rhs), "negacyclic width {}", N);
        Ok(())
    }

    /// Fixed circulant column for width-16 tests.
    /// Uses small distinct integers for a reproducible test vector.
    fn col_16() -> [F; 16] {
        [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17].map(F::from_i64)
    }

    /// Fixed circulant column for width-24 tests.
    fn col_24() -> [F; 24] {
        [
            2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        ]
        .map(F::from_i64)
    }

    proptest! {
        /// Karatsuba width-16 must match the naive circulant multiply
        /// for every random state vector.
        #[test]
        fn karatsuba_16_matches_naive(state in prop::array::uniform16(arb_f())) {
            let col = col_16();

            // Compute the expected result via naive O(N^2) circulant multiply.
            let expected = naive_circulant(col, state);

            // Compute the actual result via Karatsuba convolution.
            let mut actual = state;
            mds_circulant_karatsuba_16(&mut actual, &col);

            prop_assert_eq!(actual, expected);
        }

        /// Karatsuba width-24 must match the naive circulant multiply
        /// for every random state vector.
        #[test]
        fn karatsuba_24_matches_naive(state in prop::array::uniform24(arb_f())) {
            let col = col_24();

            // Compute the expected result via naive O(N^2) circulant multiply.
            let expected = naive_circulant(col, state);

            // Compute the actual result via Karatsuba convolution.
            let mut actual = state;
            mds_circulant_karatsuba_24(&mut actual, &col);

            prop_assert_eq!(actual, expected);
        }

        /// Karatsuba width-16 with a random circulant column.
        /// Tests that the algorithm is correct beyond a single fixed matrix.
        #[test]
        fn karatsuba_16_random_col(
            col in prop::array::uniform16(arb_f()),
            state in prop::array::uniform16(arb_f()),
        ) {
            let expected = naive_circulant(col, state);

            let mut actual = state;
            mds_circulant_karatsuba_16(&mut actual, &col);

            prop_assert_eq!(actual, expected);
        }

        /// Karatsuba width-24 with a random circulant column.
        /// Tests that the algorithm is correct beyond a single fixed matrix.
        #[test]
        fn karatsuba_24_random_col(
            col in prop::array::uniform24(arb_f()),
            state in prop::array::uniform24(arb_f()),
        ) {
            let expected = naive_circulant(col, state);

            let mut actual = state;
            mds_circulant_karatsuba_24(&mut actual, &col);

            prop_assert_eq!(actual, expected);
        }

        /// Every cyclic and negacyclic kernel must agree with the naive
        /// oracles. Narrower widths reuse prefixes of the same random
        /// 64-element inputs. This covers kernels that the width-16/24
        /// wrappers never call at top level, such as `conv32`, `conv64` and
        /// the stand-alone negacyclic kernels.
        #[test]
        fn all_kernels_match_naive(lhs in arb_array::<64>(), rhs in arb_array::<64>()) {
            check_cyclic::<3>(prefix(lhs), prefix(rhs), Kernels::conv3)?;
            check_negacyclic::<3>(prefix(lhs), prefix(rhs), Kernels::negacyclic_conv3)?;
            check_cyclic::<4>(prefix(lhs), prefix(rhs), Kernels::conv4)?;
            check_negacyclic::<4>(prefix(lhs), prefix(rhs), Kernels::negacyclic_conv4)?;
            check_cyclic::<6>(prefix(lhs), prefix(rhs), Kernels::conv6)?;
            check_negacyclic::<6>(prefix(lhs), prefix(rhs), Kernels::negacyclic_conv6)?;
            check_cyclic::<8>(prefix(lhs), prefix(rhs), Kernels::conv8)?;
            check_negacyclic::<8>(prefix(lhs), prefix(rhs), Kernels::negacyclic_conv8)?;
            check_cyclic::<12>(prefix(lhs), prefix(rhs), Kernels::conv12)?;
            check_negacyclic::<12>(prefix(lhs), prefix(rhs), Kernels::negacyclic_conv12)?;
            check_cyclic::<16>(prefix(lhs), prefix(rhs), Kernels::conv16)?;
            check_negacyclic::<16>(prefix(lhs), prefix(rhs), Kernels::negacyclic_conv16)?;
            check_cyclic::<24>(prefix(lhs), prefix(rhs), Kernels::conv24)?;
            check_cyclic::<32>(prefix(lhs), prefix(rhs), Kernels::conv32)?;
            check_negacyclic::<32>(prefix(lhs), prefix(rhs), Kernels::negacyclic_conv32)?;
            check_cyclic::<64>(lhs, rhs, Kernels::conv64)?;
        }
    }
}