Skip to main content

primitives/algebra/
multivariate_ring.rs

1use std::{
2    iter::Sum,
3    marker::PhantomData,
4    ops::{Add, AddAssign, Mul},
5};
6
7use ff::Field;
8use rayon::prelude::*;
9use typenum::{PartialDiv, PowerOfTwo};
10
// This module defines the multivariate polynomial ring
// $$ F_q[X_0, ..., X_{N-1}] / (X_0^2 - 1, ..., X_{N-1}^2 - 1) $$ together with its basic operations.
13use crate::{
14    algebra::field::{FieldExtension, SubfieldElement},
15    random::{CryptoRngCore, Random},
16    types::{heap_array::SubfieldElements, Positive},
17};
18
/// Smallest run of ring elements handed to a single rayon task. Element-wise loops over fewer
/// elements than this execute effectively sequentially.
const PAR_MIN_LEN: usize = 8 * 1024;
22
/// Type-level tag selecting the coefficient representation of a multivariate polynomial.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Coefficient;

/// Type-level tag selecting the evaluation representation of a multivariate polynomial.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Evaluation;
28
29/// An element `P(X_0, ..., X_{N-1})` of the multivariate polynomial ring with `N` variables.
30///
31/// The polynomial can be either in coefficient or in evaluation form, given by `Form`
32/// parameter. The number of coefficients/evaluations is `M = 2 ^ N`.
33///
34/// In _coefficient_ representation `data` are the coefficients of the polynomial:
35///     `\sum_{i=0..2^N-1} c_i * \Prod_{j=0..N-1} X_j ^ bin(i)_j`
36/// `bin(i)` is the  binary-decomposition of `i` (LSB first).
37/// For example for `N=2` the 2-variate polynomial is `c_0 + c_1 . X_0 + c_2 . X_1 + c_3 . X_0 .
38/// X_1` and `data = [c_0, c_1, c_2, c_3]`.
39///
40/// In _evaluation_ representation `data` are the evaluations:
41///     `e_i = P(bin_signed(i))`
42/// `bin_signed(i)` is the signed binary-decomposition of `i` (LSB first).
43/// For example for `N=2` the evaluations are:
44///     - `e_0 = P(-1, -1)`
45///     - `e_1 = P(-1,  1)`
46///     - `e_2 = P( 1, -1)`
47///     - `e_3 = P( 1,  1)`
48/// and `data = [e_0, e_1, e_2, e_3]`
49#[derive(Clone, Default, Debug, PartialEq)]
50pub struct MultivariateRing<F: FieldExtension, M: Positive, Form> {
51    data: SubfieldElements<F, M>,
52    _form: PhantomData<Form>,
53}
54
55pub type MultivariateRingCoefForm<F, M> = MultivariateRing<F, M, Coefficient>;
56pub type MultivariateRingEvalForm<F, M> = MultivariateRing<F, M, Evaluation>;
57
58impl<F: FieldExtension, M: Positive + PowerOfTwo, Form> MultivariateRing<F, M, Form> {
59    pub fn new(data: SubfieldElements<F, M>) -> Self {
60        Self {
61            data,
62            _form: PhantomData,
63        }
64    }
65
66    pub fn random(rng: impl CryptoRngCore) -> Self {
67        Self::new(SubfieldElements::<F, M>::random(rng))
68    }
69
70    pub fn data(&self) -> &SubfieldElements<F, M> {
71        &self.data
72    }
73
74    pub fn into_data(self) -> SubfieldElements<F, M> {
75        self.data
76    }
77}
78
79impl<F: FieldExtension, M: Positive + PowerOfTwo> MultivariateRing<F, M, Coefficient> {
80    /// Transform polynomial from coefficient representation to evaluation representation using the
81    /// Walsh-Hadamard transform.
82    pub fn to_eval_repr(self) -> MultivariateRing<F, M, Evaluation> {
83        let mut data: SubfieldElements<F, M> = self.data;
84        walsh_hadamard_inplace(&mut data);
85        MultivariateRing::<_, _, Evaluation>::new(data)
86    }
87
88    pub fn nb_coefs(&self) -> usize {
89        self.data.len()
90    }
91}
92
93impl<F: FieldExtension, M: Positive + PowerOfTwo> MultivariateRing<F, M, Evaluation> {
94    /// Realize a `T`-sparse coefficient vector directly in evaluation form.
95    ///
96    /// The Walsh transform of `Σ_k c_k·X^{i_k}` is evaluated point-wise — the
97    /// monomial `X^{i}` at the signed point `bin_signed(x)` is
98    /// `(−1)^{popcount(i & !x)}` — costing `O(M·T)`.
99    ///
100    /// # Crossover with the dense path
101    ///
102    /// Densifying the coefficients and running the full Walsh–Hadamard transform
103    /// ([`Self::from_coeffs`], i.e. `to_eval_repr` on the coefficient form) costs `O(M·log M) =
104    /// O(M·N)`, where `N = log2(M)` is the number of variables. This sparse path is therefore
105    /// only cheaper while `T ≲ N`; for larger `T` it is strictly worse (per element it does `T`
106    /// work versus the dense path's `N`), so callers with `T > N` should densify and use the
107    /// dense path instead.
108    pub fn new_from_sparse_coef(
109        coefs: impl IntoIterator<Item = (usize, SubfieldElement<F>)>,
110    ) -> Self {
111        let m = M::to_usize();
112        let coefs: Vec<(usize, SubfieldElement<F>)> = coefs
113            .into_iter()
114            .inspect(|(i, _)| debug_assert!(*i < m))
115            .map(|(i, c_i)| (i & (m - 1), c_i)) // Reduce `i` modulo `M`
116            .collect();
117
118        let mut data = SubfieldElements::<F, M>::default();
119        data.par_iter_mut()
120            .enumerate()
121            .with_min_len(PAR_MIN_LEN)
122            .for_each(|(x, e)| {
123                let not_x = !x & (m - 1);
124                for (i, c_i) in &coefs {
125                    if (i & not_x).count_ones() & 1 == 1 {
126                        *e -= *c_i;
127                    } else {
128                        *e += *c_i;
129                    }
130                }
131            });
132        Self::new(data)
133    }
134
135    pub fn from_coeffs(coefs: SubfieldElements<F, M>) -> Self {
136        MultivariateRing::new(coefs).to_eval_repr()
137    }
138
139    /// Complete a transform whose `BlockSize`-aligned blocks are already in the Walsh domain: only
140    /// the cross-block butterfly levels are applied (the transform factors per bit, so levels
141    /// commute). `BlockSize == M` is a pure relabel; `BlockSize == 1` is the full transform.
142    pub fn from_blockwise_walsh<BlockSize>(mut data: SubfieldElements<F, M>) -> Self
143    where
144        BlockSize: Positive + PowerOfTwo,
145        M: PartialDiv<BlockSize>,
146    {
147        walsh_hadamard::walsh_hadamard_from_step(&mut data, BlockSize::USIZE);
148        Self::new(data)
149    }
150
151    pub fn square(&self) -> Self {
152        Self {
153            data: self.data.iter().map(SubfieldElement::<F>::square).collect(),
154            _form: PhantomData,
155        }
156    }
157
158    pub fn nb_evals(&self) -> usize {
159        self.data.len()
160    }
161}
162
163#[macros::op_variants(owned, borrowed, flipped)]
164impl<F: FieldExtension, M: Positive + PowerOfTwo> Mul<&MultivariateRing<F, M, Evaluation>>
165    for MultivariateRing<F, M, Evaluation>
166{
167    type Output = MultivariateRing<F, M, Evaluation>;
168
169    fn mul(mut self, rhs: &MultivariateRing<F, M, Evaluation>) -> Self::Output {
170        self.data
171            .par_iter_mut()
172            .zip(rhs.data.par_iter())
173            .with_min_len(PAR_MIN_LEN)
174            .for_each(|(a, b)| *a *= b);
175        self
176    }
177}
178
179impl<F: FieldExtension, M: Positive> AddAssign for MultivariateRing<F, M, Evaluation> {
180    fn add_assign(&mut self, rhs: Self) {
181        self.data
182            .par_iter_mut()
183            .zip(rhs.data.par_iter())
184            .with_min_len(PAR_MIN_LEN)
185            .for_each(|(a, b)| *a += b);
186    }
187}
188
189impl<F: FieldExtension, M: Positive + PowerOfTwo> Add for MultivariateRing<F, M, Evaluation> {
190    type Output = Self;
191
192    fn add(mut self, rhs: Self) -> Self::Output {
193        self += rhs;
194        self
195    }
196}
197
198impl<F: FieldExtension, M: Positive + PowerOfTwo> Sum for MultivariateRing<F, M, Evaluation> {
199    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
200        iter.fold(Self::default(), |acc, x| acc + x)
201    }
202}
203
204/// In-place Walsh-Hadamard transform: the coefficient → evaluation map of
205/// [`MultivariateRing`]. Linear, an involution up to scaling by `M`.
206pub fn walsh_hadamard_inplace<F: FieldExtension, M: Positive + PowerOfTwo>(
207    data: &mut SubfieldElements<F, M>,
208) {
209    walsh_hadamard::walsh_hadamard_from_step(data, 1);
210}
211
212mod walsh_hadamard {
213    use rayon::prelude::*;
214    use typenum::PowerOfTwo;
215
216    use super::PAR_MIN_LEN;
217    use crate::{
218        algebra::field::FieldExtension,
219        types::{heap_array::SubfieldElements, Positive},
220    };
221
222    /// In-place Walsh-Hadamard butterfly levels with step ≥ `first_step`.
223    ///
224    /// The transform factors into one butterfly level per index bit and the
225    /// levels commute, so starting at `first_step = B` completes a transform
226    /// whose `B`-aligned blocks are already transformed; `first_step = 1` is
227    /// the full transform. `first_step` must be a power of two.
228    pub(super) fn walsh_hadamard_from_step<F: FieldExtension, M: Positive + PowerOfTwo>(
229        data: &mut SubfieldElements<F, M>,
230        first_step: usize,
231    ) {
232        let m = M::to_usize();
233        debug_assert!(first_step.is_power_of_two() && first_step <= m);
234        let mut step = first_step;
235        while step < m {
236            let step2 = step << 1;
237            if step < PAR_MIN_LEN {
238                // Many small blocks: parallelize across blocks.
239                data.par_chunks_mut(step2)
240                    .with_min_len((PAR_MIN_LEN / step2).max(1))
241                    .for_each(|chunk| {
242                        let (lo, hi) = chunk.split_at_mut(step);
243                        for (a, b) in lo.iter_mut().zip(hi) {
244                            let t = *a;
245                            *a = t - *b;
246                            *b = t + *b;
247                        }
248                    });
249            } else {
250                // Few large blocks: parallelize the butterflies inside each.
251                for chunk in data.chunks_mut(step2) {
252                    let (lo, hi) = chunk.split_at_mut(step);
253                    lo.par_iter_mut()
254                        .zip(hi.par_iter_mut())
255                        .with_min_len(PAR_MIN_LEN)
256                        .for_each(|(a, b)| {
257                            let t = *a;
258                            *a = t - *b;
259                            *b = t + *b;
260                        });
261                }
262            }
263            step = step2;
264        }
265    }
266}
267
#[cfg(test)]
mod tests {
    use itertools::izip;
    use itybity::IntoBits;
    use typenum::{PartialDiv, PowerOfTwo, Shleft, Unsigned};

    use super::{walsh_hadamard_inplace, Coefficient, Evaluation, MultivariateRing};
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto as C, ScalarField},
            field::{mersenne::Mersenne107, FieldExtension, SubfieldElement},
        },
        izip_eq,
        random::{test_rng, Random},
        types::{heap_array::SubfieldElements, Positive},
    };

    // Second test field (alongside `Mersenne107`): the scalar field of `Curve25519Ristretto`.
    type Fq = ScalarField<C>;

    // Test-only helpers on the coefficient form, used as reference oracles.
    impl<F: FieldExtension, M: Positive + PowerOfTwo> MultivariateRing<F, M, Coefficient> {
        /// Evaluate the polynomial at `x_vals = (X_0, ..., X_{N-1})` directly from its
        /// coefficients: `c_i` is multiplied by `X_j` for every set bit `j` of `i` (LSB first).
        fn eval(&self, x_vals: Vec<SubfieldElement<F>>) -> SubfieldElement<F> {
            use num_traits::identities::One;
            debug_assert_eq!(1usize << x_vals.len(), self.data.len());
            self.data
                .iter()
                .enumerate()
                .map(|(i, c_i)| {
                    let monomial_i = izip!(x_vals.iter(), i.into_iter_lsb0())
                        .map(|(x_i, b_i)| {
                            if b_i {
                                *x_i
                            } else {
                                SubfieldElement::<F>::one()
                            }
                        })
                        .product::<SubfieldElement<F>>();
                    *c_i * monomial_i
                })
                .sum()
        }

        /// Multiply 2 polynomials in coefficient representation.
        ///
        /// Since `X_j^2 = 1` in this ring, `X^i · X^j = X^(i XOR j)`.
        ///
        /// __Remark:__ Use only in tests as it has O(M^2) complexity
        fn mul_slow(&self, other: &Self) -> Self {
            let mut out = Self::default();

            let n = self.nb_coefs();
            for i in 0..n {
                for j in 0..n {
                    let k = i ^ j;
                    out.data[k] += self.data[i] * other.data[j];
                }
            }

            out
        }
    }

    /// The `n` low bits of `x`, LSB first, mapped to `+1` (bit set) or `-1` (bit clear): the
    /// evaluation point for index `x` in evaluation representation.
    fn signed_binary_decomposition<F: FieldExtension>(
        x: usize,
        n: usize,
    ) -> Vec<SubfieldElement<F>> {
        use num_traits::identities::One;
        x.into_iter_lsb0()
            .take(n)
            .map(|b_i| {
                if b_i {
                    SubfieldElement::<F>::one()
                } else {
                    -SubfieldElement::<F>::one()
                }
            })
            .collect()
    }

    /// Apply `walsh_hadamard` to a random coefficient vector and compare every output entry
    /// against a direct evaluation of the polynomial at the matching signed point.
    fn test_multivariate_ring_walsh_hadamard_impl<F: FieldExtension, M: Positive + PowerOfTwo>(
        walsh_hadamard: fn(&mut SubfieldElements<F, M>),
    ) {
        let rng = test_rng();
        let poly_coef = MultivariateRing::<F, M, Coefficient>::random(rng);
        let mut data = poly_coef.data.clone();
        walsh_hadamard(&mut data);
        let poly_eval = MultivariateRing::<_, _, Evaluation>::new(data);

        // Check each evaluation values by manually evaluating the input polynomial at each
        // point
        let log2m = M::to_usize().ilog2() as usize;
        for i in 0..poly_eval.nb_evals() {
            // Signed binary-decomposition of `i`
            let x_vals: Vec<_> = signed_binary_decomposition::<F>(i, log2m);
            let e_i = poly_coef.eval(x_vals);
            assert_eq!(e_i, poly_eval.data[i])
        }
    }

    #[test]
    fn test_multivariate_ring_walsh_hadamard() {
        // M = 2^8, i.e. 8 variables.
        type M = Shleft<typenum::U1, typenum::U8>;

        test_multivariate_ring_walsh_hadamard_impl::<Mersenne107, M>(walsh_hadamard_inplace);
        test_multivariate_ring_walsh_hadamard_impl::<Fq, M>(walsh_hadamard_inplace);
    }

    /// Build a random `T`-sparse polynomial with `new_from_sparse_coef` and check each
    /// evaluation against a direct sum over the sparse terms. Indices are drawn at random, so
    /// repeated indices (whose coefficients must add up) may occur.
    fn test_multivariate_ring_new_from_sparse_coef_impl<
        F: FieldExtension,
        M: Positive + PowerOfTwo,
        T: Positive,
    >() {
        use num_traits::identities::One;

        let mut rng = test_rng();
        let coefs = SubfieldElements::<F, T>::random(&mut rng);
        let coefs_sparse: Vec<_> = izip!(
            (0..T::to_usize()).map(|_| usize::random(&mut rng) % M::to_usize()),
            coefs
        )
        .collect();

        let poly_eval_exp =
            MultivariateRing::<_, M, Evaluation>::new_from_sparse_coef(coefs_sparse.clone());

        // Pre-compute each sparse index's low `log2m` bits (LSB first) for the oracle below.
        let log2m = M::to_usize().ilog2() as usize;
        let coef_sparse_bin: Vec<(Vec<_>, SubfieldElement<F>)> = coefs_sparse
            .into_iter()
            .map(|(i, c_i)| (i.into_iter_lsb0().take(log2m).collect(), c_i))
            .collect();

        for i in 0..poly_eval_exp.nb_evals() {
            let x_vals: Vec<_> = signed_binary_decomposition::<F>(i, log2m);
            let e_i = coef_sparse_bin
                .iter()
                .map(|(i_bin, c_i)| {
                    izip_eq!(&x_vals, i_bin)
                        .map(|(x_j, i_bin_j)| {
                            if *i_bin_j {
                                *x_j
                            } else {
                                SubfieldElement::<F>::one()
                            }
                        })
                        .product::<SubfieldElement<F>>()
                        * *c_i
                })
                .sum::<SubfieldElement<F>>();

            assert_eq!(e_i, poly_eval_exp.data[i]);
        }
    }

    #[test]
    fn test_multivariate_ring_new_from_sparse_coef() {
        // M = 2^8; sparsity T = 1 and T = 13 in both test fields.
        type M = Shleft<typenum::U1, typenum::U8>;
        test_multivariate_ring_new_from_sparse_coef_impl::<Mersenne107, M, typenum::U1>();
        test_multivariate_ring_new_from_sparse_coef_impl::<Mersenne107, M, typenum::U13>();
        test_multivariate_ring_new_from_sparse_coef_impl::<Fq, M, typenum::U1>();
        test_multivariate_ring_new_from_sparse_coef_impl::<Fq, M, typenum::U13>();
    }

    /// Point-wise multiplication in evaluation form must match the transform of the
    /// schoolbook product computed in coefficient form.
    fn test_multivariate_ring_mul_impl<F: FieldExtension, M: Positive + PowerOfTwo>() {
        let mut rng = test_rng();
        let a = MultivariateRing::<F, M, Coefficient>::random(&mut rng);
        let b = MultivariateRing::<F, M, Coefficient>::random(&mut rng);

        // Transform to evaluation representation
        let a_eval = a.clone().to_eval_repr();
        let b_eval = b.clone().to_eval_repr();

        // Multiplication in the evaluation domain
        let c_eval_exp = a_eval * b_eval;

        // Schoolbook multiplication in the coefficient representation
        let c = a.mul_slow(&b);
        let c_eval_act = c.to_eval_repr();

        assert_eq!(c_eval_act, c_eval_exp);
    }

    #[test]
    fn test_multivariate_ring_mul() {
        type M = Shleft<typenum::U1, typenum::U8>;

        test_multivariate_ring_mul_impl::<Mersenne107, M>();
        test_multivariate_ring_mul_impl::<Fq, M>();
    }

    /// Transforming `BlockSize`-aligned blocks first and completing with
    /// `from_blockwise_walsh` reproduces the full transform, for every block
    /// size (1 = nothing pre-transformed, M = pure relabel).
    #[test]
    fn test_multivariate_ring_from_blockwise_walsh() {
        type F = Mersenne107;
        type M = Shleft<typenum::U1, typenum::U8>;

        // Pre-transform `BlockSize`-aligned blocks of `coeffs`, complete with
        // `from_blockwise_walsh::<BlockSize>`, and check it matches the full transform.
        fn check<BlockSize>(
            coeffs: &SubfieldElements<F, M>,
            expected: &MultivariateRing<F, M, Evaluation>,
        ) where
            BlockSize: Positive + PowerOfTwo,
            M: PartialDiv<BlockSize>,
        {
            let block_size = BlockSize::USIZE;
            let mut data = coeffs.clone();
            for block in data.chunks_mut(block_size) {
                // Blockwise transform via the same butterflies, on a copy.
                let mut step = 1;
                while step < block_size {
                    let step2 = step << 1;
                    for pair in block.chunks_mut(step2) {
                        let (lo, hi) = pair.split_at_mut(step);
                        for (a, b) in lo.iter_mut().zip(hi) {
                            let t = *a;
                            *a = t - *b;
                            *b = t + *b;
                        }
                    }
                    step = step2;
                }
            }
            let act = MultivariateRing::<F, M, Evaluation>::from_blockwise_walsh::<BlockSize>(data);
            assert_eq!(&act, expected, "BlockSize={block_size}");
        }

        let mut rng = test_rng();
        let coeffs = SubfieldElements::<F, M>::random(&mut rng);
        let expected = MultivariateRing::<F, M, Evaluation>::from_coeffs(coeffs.clone());

        // M = 2^8: every power-of-two block size from 1 to M.
        check::<typenum::U1>(&coeffs, &expected);
        check::<typenum::U2>(&coeffs, &expected);
        check::<typenum::U4>(&coeffs, &expected);
        check::<typenum::U8>(&coeffs, &expected);
        check::<typenum::U16>(&coeffs, &expected);
        check::<typenum::U32>(&coeffs, &expected);
        check::<typenum::U64>(&coeffs, &expected);
        check::<typenum::U128>(&coeffs, &expected);
        check::<typenum::U256>(&coeffs, &expected);
    }

    /// Ad-hoc timing of `walsh_hadamard_inplace` at size 2^25 (ignored by default; run with
    /// `--ignored`, preferably in release mode). It only prints the elapsed time and asserts
    /// nothing.
    #[test]
    #[ignore]
    fn test_multivariate_ring_to_eval_repr_large() {
        type F = Mersenne107;
        type N = typenum::U25;
        type M = Shleft<typenum::U1, N>;
        const NB_ITER: usize = 1;

        let data_orig = SubfieldElements::<F, M>::random(test_rng());

        let mut data = data_orig.clone();
        let start = std::time::Instant::now();
        for _ in 0..NB_ITER {
            walsh_hadamard_inplace(&mut data);
            // Keep the optimizer from discarding the transform.
            std::hint::black_box(&mut data);
        }
        let duration = start.elapsed();
        println!(
            "Walsh-Hadamard transformation of size 2^{} execution time: {:?} sec",
            N::to_usize(),
            duration.as_secs_f32() / NB_ITER as f32
        );
    }
}