poly_commit/
poly.rs

1use core::iter::{self, Sum};
2use core::ops::{Add, Deref, DerefMut, Index, Mul, Sub};
3use rand_core::RngCore;
4use zkstd::common::{FftField, PrimeField, Vec};
5
6/// polynomial coefficients form expression
7/// a_n-1 , a_n-2, ... , a_0
8#[derive(Debug, Clone, PartialEq, Eq, Default)]
9pub struct Coefficients<F: PrimeField>(pub Vec<F>);
10
11/// polynomial points-value form expression
12#[derive(Debug, Clone, PartialEq, Eq, Default)]
13pub struct PointsValue<F: PrimeField>(pub Vec<F>);
14
15impl<F: PrimeField> Deref for Coefficients<F> {
16    type Target = [F];
17
18    fn deref(&self) -> &[F] {
19        &self.0
20    }
21}
22
23impl<F: PrimeField> DerefMut for Coefficients<F> {
24    fn deref_mut(&mut self) -> &mut [F] {
25        &mut self.0
26    }
27}
28
29impl<F: PrimeField> Index<usize> for PointsValue<F> {
30    type Output = F;
31
32    fn index(&self, index: usize) -> &F {
33        &self.0[index]
34    }
35}
36
37impl<F: FftField> Sum for Coefficients<F> {
38    fn sum<I>(iter: I) -> Self
39    where
40        I: Iterator<Item = Self>,
41    {
42        iter.fold(Coefficients::default(), |res, val| &res + &val)
43    }
44}
45
46impl<F: FftField> PointsValue<F> {
47    pub fn new(coeffs: Vec<F>) -> Self {
48        Self(coeffs)
49    }
50
51    pub fn format_degree(mut self) -> Self {
52        while self.0.last().map_or(false, |c| c == &F::zero()) {
53            self.0.pop();
54        }
55        self
56    }
57}
58
59impl<F: FftField> Coefficients<F> {
60    pub fn new(coeffs: Vec<F>) -> Self {
61        Self(coeffs).format_degree()
62    }
63
64    // polynomial evaluation domain
65    // r^0, r^1, r^2, ..., r^n
66    pub fn setup(k: usize, rng: impl RngCore) -> (F, Vec<F>) {
67        let randomness = F::random(rng);
68        (
69            randomness,
70            (0..(1 << k))
71                .scan(F::one(), |w, _| {
72                    let tw = *w;
73                    *w *= randomness;
74                    Some(tw)
75                })
76                .collect::<Vec<_>>(),
77        )
78    }
79
80    // commit polynomial to domain
81    pub fn commit(&self, domain: &Vec<F>) -> F {
82        assert!(self.0.len() <= domain.len());
83        let diff = domain.len() - self.0.len();
84
85        self.0
86            .iter()
87            .zip(domain.iter().skip(diff))
88            .fold(F::zero(), |acc, (a, b)| acc + *a * *b)
89    }
90
91    // evaluate polynomial at
92    pub fn evaluate(&self, at: &F) -> F {
93        self.0
94            .iter()
95            .rev()
96            .fold(F::zero(), |acc, coeff| acc * at + coeff)
97    }
98
99    // no remainder polynomial division with at
100    // f(x) - f(at) / x - at
101    pub fn divide(&self, at: &F) -> Self {
102        let mut coeffs = self
103            .0
104            .iter()
105            .rev()
106            .scan(F::zero(), |w, coeff| {
107                let tmp = *w + coeff;
108                *w = tmp * at;
109                Some(tmp)
110            })
111            .collect::<Vec<_>>();
112        coeffs.pop();
113        coeffs.reverse();
114        Self(coeffs)
115    }
116
117    /// σ^n - 1
118    pub fn t(n: u64, tau: F) -> F {
119        tau.pow(n) - F::one()
120    }
121
122    /// if hiding degree = 1: (b2*X^(n+1) + b1*X^n - b2*X - b1) + witnesses
123    /// if hiding degree = 2: (b3*X^(n+2) + b2*X^(n+1) + b1*X^n - b3*X^2 - b2*X
124    pub fn blind<R: RngCore>(&mut self, hiding_degree: usize, rng: &mut R) {
125        for i in 0..hiding_degree + 1 {
126            let blinding_scalar = F::random(&mut *rng);
127            self.0[i] -= blinding_scalar;
128            self.0.push(blinding_scalar);
129        }
130    }
131
132    pub fn format_degree(mut self) -> Self {
133        while self.0.last().map_or(false, |c| c == &F::zero()) {
134            self.0.pop();
135        }
136        self
137    }
138
139    /// Returns the degree of the [`Coefficients`].
140    pub fn degree(&self) -> usize {
141        if self.is_zero() {
142            return 0;
143        }
144        assert!(self.0.last().map_or(false, |coeff| coeff != &F::zero()));
145        self.0.len() - 1
146    }
147
148    pub(crate) fn is_zero(&self) -> bool {
149        self.0.is_empty() || self.0.iter().all(|coeff| coeff == &F::zero())
150    }
151}
152
153impl<F: FftField> Add for Coefficients<F> {
154    type Output = Coefficients<F>;
155
156    fn add(self, rhs: Self) -> Self::Output {
157        let zero = F::zero();
158        let (left, right) = if self.0.len() > rhs.0.len() {
159            (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
160        } else {
161            (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
162        };
163        Self::new(left.zip(right).map(|(a, b)| *a + *b).collect())
164    }
165}
166
167impl<'a, 'b, F: FftField> Sub<&'a PointsValue<F>> for &'b PointsValue<F> {
168    type Output = PointsValue<F>;
169
170    fn sub(self, rhs: &'a PointsValue<F>) -> Self::Output {
171        let zero = F::zero();
172        PointsValue::new(if self.0.len() > rhs.0.len() {
173            let (left, right) = (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)));
174            left.zip(right).map(|(a, b)| *a - *b).collect()
175        } else {
176            let (left, right) = (self.0.iter().chain(iter::repeat(&zero)), rhs.0.iter());
177            left.zip(right).map(|(a, b)| *a - *b).collect()
178        })
179    }
180}
181
182impl<'a, 'b, F: FftField> Mul<&'a PointsValue<F>> for &'b PointsValue<F> {
183    type Output = PointsValue<F>;
184
185    fn mul(self, rhs: &'a PointsValue<F>) -> Self::Output {
186        let zero = F::zero();
187        let (left, right) = if self.0.len() > rhs.0.len() {
188            (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
189        } else {
190            (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
191        };
192        PointsValue::new(left.zip(right).map(|(a, b)| *a * *b).collect())
193    }
194}
195
196impl<'a, 'b, F: FftField> Add<&'a Coefficients<F>> for &'b Coefficients<F> {
197    type Output = Coefficients<F>;
198
199    fn add(self, rhs: &'a Coefficients<F>) -> Self::Output {
200        let zero = F::zero();
201        let (left, right) = if self.0.len() > rhs.0.len() {
202            (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
203        } else {
204            (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
205        };
206        Coefficients::new(left.zip(right).map(|(a, b)| *a + *b).collect())
207    }
208}
209
210impl<F: FftField> Sub for Coefficients<F> {
211    type Output = Coefficients<F>;
212
213    fn sub(self, rhs: Self) -> Self::Output {
214        let zero = F::zero();
215        let (left, right) = if self.0.len() > rhs.0.len() {
216            (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
217        } else {
218            (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
219        };
220        Self::new(left.zip(right).map(|(a, b)| *a - *b).collect())
221    }
222}
223
224impl<'a, 'b, F: FftField> Mul<&'a F> for &'b Coefficients<F> {
225    type Output = Coefficients<F>;
226
227    fn mul(self, scalar: &'a F) -> Coefficients<F> {
228        Coefficients::new(self.0.iter().map(|coeff| *coeff * scalar).collect())
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::Coefficients;
    use crate::PointsValue;
    use bls_12_381::Fr;
    use core::iter;
    use rand_core::OsRng;
    use zkstd::common::{Group, PrimeField};

    /// One random scalar.
    fn arb_fr() -> Fr {
        Fr::random(OsRng)
    }

    /// Random polynomial with `2^k` coefficients.
    fn arb_poly(k: u32) -> Coefficients<Fr> {
        let coeffs: Vec<Fr> = iter::repeat_with(|| Fr::random(OsRng))
            .take(1 << k)
            .collect();
        Coefficients(coeffs)
    }

    /// Random point-value vector with `2^k` entries.
    fn arb_points(k: u32) -> PointsValue<Fr> {
        let values: Vec<Fr> = iter::repeat_with(|| Fr::random(OsRng))
            .take(1 << k)
            .collect();
        PointsValue(values)
    }

    /// Schoolbook product of two coefficient vectors.
    fn naive_multiply<F: PrimeField>(a: Vec<F>, b: Vec<F>) -> Vec<F> {
        let mut product = vec![F::zero(); a.len() + b.len() - 1];
        for (i, x) in a.iter().enumerate() {
            for (j, y) in b.iter().enumerate() {
                product[i + j] += *x * *y;
            }
        }
        product
    }

    #[test]
    fn polynomial_scalar() {
        let poly = arb_poly(10);
        let scalar = arb_fr();

        let actual = &poly * &scalar;
        let expected = Coefficients(poly.0.iter().map(|coeff| *coeff * scalar).collect());

        assert_eq!(actual, expected);
    }

    #[test]
    fn polynomial_division_test() {
        let root = arb_fr();
        let cofactor = arb_poly(10);
        // the linear factor (x - root), lowest degree first
        let linear = vec![-root, Fr::one()];

        // dividend = cofactor * (x - root)
        let dividend = Coefficients(naive_multiply(cofactor.0, linear.clone()));

        // quotient = dividend / (x - root)
        let quotient = dividend.divide(&root);

        // quotient * (x - root) must give the dividend back
        let rebuilt = Coefficients(naive_multiply(quotient.0, linear));

        assert_eq!(dividend.0, rebuilt.0);
    }

    #[test]
    fn polynomial_subtraction_test() {
        let a = arb_points(9);
        let b = arb_points(10);

        let sub = &a - &b;

        // Reference result: the shorter operand is padded with zeros.
        let zero = Fr::zero();
        let width = a.0.len().max(b.0.len());
        let expected: Vec<Fr> = (0..width)
            .map(|i| {
                let lhs = a.0.get(i).copied().unwrap_or(zero);
                let rhs = b.0.get(i).copied().unwrap_or(zero);
                lhs - rhs
            })
            .collect();

        assert_eq!(sub.0, expected);
    }
}