Skip to main content

ark_poly/polynomial/multivariate/mod.rs

1//! Work with sparse multivariate polynomials.
2use ark_ff::Field;
3use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
4use ark_std::{
5    cfg_into_iter,
6    cmp::Ordering,
7    fmt::{Debug, Error, Formatter},
8    hash::Hash,
9    ops::Deref,
10    vec::*,
11};
12
13#[cfg(feature = "parallel")]
14use rayon::prelude::*;
15
16mod sparse;
17pub use sparse::SparsePolynomial;
18
19/// Describes the interface for a term (monomial) of a multivariate polynomial.
20pub trait Term:
21    Clone
22    + PartialOrd
23    + Ord
24    + PartialEq
25    + Eq
26    + Hash
27    + Default
28    + Debug
29    + Deref<Target = [(usize, usize)]>
30    + Send
31    + Sync
32    + CanonicalSerialize
33    + CanonicalDeserialize
34{
35    /// Create a new `Term` from a list of tuples of the form `(variable, power)`
36    fn new(term: Vec<(usize, usize)>) -> Self;
37
38    /// Returns the total degree of `self`. This is the sum of all variable
39    /// powers in `self`
40    fn degree(&self) -> usize;
41
42    /// Returns a list of variables in `self`
43    fn vars(&self) -> Vec<usize>;
44
45    /// Returns a list of the powers of each variable in `self`
46    fn powers(&self) -> Vec<usize>;
47
48    /// Returns whether `self` is a constant
49    fn is_constant(&self) -> bool;
50
51    /// Evaluates `self` at the point `p`.
52    fn evaluate<F: Field>(&self, p: &[F]) -> F;
53}
54
55/// Stores a term (monomial) in a multivariate polynomial.
56/// Each element is of the form `(variable, power)`.
57#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
58pub struct SparseTerm(Vec<(usize, usize)>);
59
60impl SparseTerm {
61    /// Sums the powers of any duplicated variables. Assumes `term` is sorted.
62    fn combine(term: &[(usize, usize)]) -> Vec<(usize, usize)> {
63        let mut term_dedup: Vec<(usize, usize)> = Vec::new();
64        for (var, pow) in term {
65            if let Some(prev) = term_dedup.last_mut() {
66                if prev.0 == *var {
67                    prev.1 += pow;
68                    continue;
69                }
70            }
71            term_dedup.push((*var, *pow));
72        }
73        term_dedup
74    }
75}
76
77impl Term for SparseTerm {
78    /// Create a new `Term` from a list of tuples of the form `(variable, power)`
79    fn new(mut term: Vec<(usize, usize)>) -> Self {
80        // Remove any terms with power 0
81        term.retain(|(_, pow)| *pow != 0);
82        // If there are more than one variables, make sure they are
83        // in order and combine any duplicates
84        if term.len() > 1 {
85            term.sort_by(|(v1, _), (v2, _)| v1.cmp(v2));
86            term = Self::combine(&term);
87        }
88        Self(term)
89    }
90
91    /// Returns the sum of all variable powers in `self`
92    fn degree(&self) -> usize {
93        self.iter().fold(0, |sum, acc| sum + acc.1)
94    }
95
96    /// Returns a list of variables in `self`
97    fn vars(&self) -> Vec<usize> {
98        self.iter().map(|(v, _)| *v).collect()
99    }
100
101    /// Returns a list of variable powers in `self`
102    fn powers(&self) -> Vec<usize> {
103        self.iter().map(|(_, p)| *p).collect()
104    }
105
106    /// Returns whether `self` is a constant
107    fn is_constant(&self) -> bool {
108        self.is_empty() || self.degree() == 0
109    }
110
111    /// Evaluates `self` at the given `point` in the field.
112    fn evaluate<F: Field>(&self, point: &[F]) -> F {
113        cfg_into_iter!(self)
114            .map(|(var, power)| point[*var].pow([*power as u64]))
115            .product()
116    }
117}
118
119impl Debug for SparseTerm {
120    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
121        for variable in self.iter() {
122            if variable.1 == 1 {
123                write!(f, " * x_{}", variable.0)?;
124            } else {
125                write!(f, " * x_{}^{}", variable.0, variable.1)?;
126            }
127        }
128        Ok(())
129    }
130}
131
132impl Deref for SparseTerm {
133    type Target = [(usize, usize)];
134
135    fn deref(&self) -> &[(usize, usize)] {
136        &self.0
137    }
138}
139
140impl PartialOrd for SparseTerm {
141    /// Sort by total degree. If total degree is equal then ordering
142    /// is given by exponent weight in lower-numbered variables
143    /// ie. `x_1 > x_2`, `x_1^2 > x_1 * x_2`, etc.
144    #[allow(clippy::non_canonical_partial_ord_impl)]
145    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
146        if self.degree() == other.degree() {
147            // Iterate through all variables and return the corresponding ordering
148            // if they differ in variable numbering or power
149            for ((cur_variable, cur_power), (other_variable, other_power)) in
150                self.iter().zip(other.iter())
151            {
152                if other_variable == cur_variable {
153                    if cur_power != other_power {
154                        return Some(cur_power.cmp(other_power));
155                    }
156                } else {
157                    return Some(other_variable.cmp(cur_variable));
158                }
159            }
160            Some(Ordering::Equal)
161        } else {
162            Some(self.degree().cmp(&other.degree()))
163        }
164    }
165}
166
167impl Ord for SparseTerm {
168    fn cmp(&self, other: &Self) -> Ordering {
169        self.partial_cmp(other).unwrap()
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use ark_ff::{Fp64, MontBackend, MontConfig};
    use ark_std::vec;

    // Tiny prime field F_5: enough to check evaluation results by hand.
    #[derive(MontConfig)]
    #[modulus = "5"]
    #[generator = "2"]
    pub(crate) struct F5Config;

    pub(crate) type F5 = Fp64<MontBackend<F5Config, 1>>;

    /// Terms shared by the ordering tests, in the order `term1..=term7`.
    fn ordering_fixtures() -> [SparseTerm; 7] {
        [
            SparseTerm::new(vec![(1, 2), (2, 3)]), // x_1^2 * x_2^3 (degree 5)
            SparseTerm::new(vec![(1, 2), (2, 2)]), // x_1^2 * x_2^2 (degree 4)
            SparseTerm::new(vec![(2, 3), (1, 2)]), // term1 given in unsorted order
            SparseTerm::new(vec![(1, 2)]),         // x_1^2 (degree 2)
            SparseTerm::new(vec![(2, 2)]),         // x_2^2 (degree 2)
            SparseTerm::new(vec![(1, 0), (2, 0)]), // all powers zero: constant
            SparseTerm::new(vec![]),               // empty: constant
        ]
    }

    #[test]
    fn test_sparse_term_combine() {
        let term = vec![(1, 2), (1, 3), (2, 1)];
        let combined = SparseTerm::combine(&term);
        assert_eq!(combined, vec![(1, 5), (2, 1)]);
    }

    #[test]
    fn test_sparse_term_new() {
        let term = vec![(2, 1), (1, 2), (1, 3), (3, 0)];
        let sparse_term = SparseTerm::new(term);
        // The result must be sorted by variable, have the powers of repeated
        // variables summed, and contain no zero powers.
        assert_eq!(sparse_term, SparseTerm(vec![(1, 5), (2, 1)]));
    }

    #[test]
    fn test_sparse_term_degree() {
        let term = SparseTerm::new(vec![(1, 2), (2, 3)]);
        assert_eq!(term.degree(), 5); // 2 + 3
        assert_eq!(SparseTerm::new(vec![]).degree(), 0);
    }

    #[test]
    fn test_sparse_term_vars() {
        let term = SparseTerm::new(vec![(1, 1), (1, 2), (2, 3)]);
        assert_eq!(term.vars(), vec![1, 2]);
    }

    #[test]
    fn test_sparse_term_powers() {
        let term = SparseTerm::new(vec![(1, 2), (1, 3), (2, 3)]);
        assert_eq!(term.powers(), vec![5, 3]);
    }

    #[test]
    fn test_sparse_term_is_constant() {
        assert!(SparseTerm::new(vec![]).is_constant());
        // Zero powers are stripped, leaving a constant.
        assert!(SparseTerm::new(vec![(3, 0)]).is_constant());
        assert!(!SparseTerm::new(vec![(1, 2)]).is_constant());
    }

    #[test]
    fn test_evaluate() {
        let term = SparseTerm::new(vec![(0, 2), (1, 3)]);
        let point = vec![F5::from(1u64), F5::from(2u64)];
        // 1^2 * 2^3 = 8, i.e. 3 in F_5.
        assert_eq!(term.evaluate::<F5>(&point), F5::from(8u64));

        // A constant term is the empty product, i.e. one.
        let constant = SparseTerm::new(vec![]);
        assert_eq!(constant.evaluate::<F5>(&point), F5::from(1u64));
    }

    #[test]
    fn test_partial_cmp() {
        let [term1, term2, term3, term4, term5, term6, term7] = ordering_fixtures();

        // Different total degrees: the higher degree is greater.
        assert_eq!(term1.partial_cmp(&term2), Some(Ordering::Greater)); // 5 vs 4
        assert_eq!(term2.partial_cmp(&term3), Some(Ordering::Less)); // 4 vs 5
        assert_eq!(term1.partial_cmp(&term4), Some(Ordering::Greater)); // 5 vs 2
        assert_eq!(term4.partial_cmp(&term6), Some(Ordering::Greater)); // 2 vs 0
        assert_eq!(term7.partial_cmp(&term1), Some(Ordering::Less)); // 0 vs 5

        // The same term given in a different input order.
        assert_eq!(term1.partial_cmp(&term3), Some(Ordering::Equal));

        // Equal degree: lower-numbered variables weigh more (x_1 > x_2).
        assert_eq!(term4.partial_cmp(&term5), Some(Ordering::Greater));

        // Both constants normalize to the empty term.
        assert_eq!(term6.partial_cmp(&term7), Some(Ordering::Equal));
    }

    #[test]
    fn test_cmp() {
        let [term1, term2, term3, term4, term5, term6, term7] = ordering_fixtures();

        // Different total degrees: the higher degree is greater.
        assert_eq!(term1.cmp(&term2), Ordering::Greater); // 5 vs 4
        assert_eq!(term2.cmp(&term3), Ordering::Less); // 4 vs 5
        assert_eq!(term1.cmp(&term4), Ordering::Greater); // 5 vs 2
        assert_eq!(term4.cmp(&term6), Ordering::Greater); // 2 vs 0
        assert_eq!(term7.cmp(&term1), Ordering::Less); // 0 vs 5

        // The same term given in a different input order.
        assert_eq!(term1.cmp(&term3), Ordering::Equal);

        // Equal degree: lower-numbered variables weigh more (x_1 > x_2).
        assert_eq!(term4.cmp(&term5), Ordering::Greater);

        // Both constants normalize to the empty term.
        assert_eq!(term6.cmp(&term7), Ordering::Equal);
    }

    #[test]
    fn test_cmp_equal_degree() {
        let x1_sq = SparseTerm::new(vec![(1, 2)]);
        let x1_x2 = SparseTerm::new(vec![(1, 1), (2, 1)]);
        let x2_sq = SparseTerm::new(vec![(2, 2)]);
        assert_eq!(x1_sq.cmp(&x1_x2), Ordering::Greater); // x_1^2 > x_1 * x_2
        assert_eq!(x1_x2.cmp(&x2_sq), Ordering::Greater); // x_1 * x_2 > x_2^2

        let x1_x4 = SparseTerm::new(vec![(1, 1), (4, 1)]);
        let x2_x3 = SparseTerm::new(vec![(2, 1), (3, 1)]);
        assert_eq!(x1_x4.cmp(&x2_x3), Ordering::Greater);

        let mut terms = vec![
            x1_sq.clone(),
            SparseTerm::new(vec![]),
            x2_sq.clone(),
            x1_x2.clone(),
        ];
        terms.sort();
        assert_eq!(terms, vec![SparseTerm::new(vec![]), x2_sq, x1_x2, x1_sq]);
    }

    #[test]
    fn test_partial_cmp_agrees_with_cmp() {
        let terms = ordering_fixtures();
        for a in terms.iter() {
            for b in terms.iter() {
                assert_eq!(a.partial_cmp(b), Some(a.cmp(b)));
                // For normalized terms, ordering-equality must coincide with `==`.
                assert_eq!(a.cmp(b) == Ordering::Equal, a == b);
            }
        }
    }
}