1use core::iter::{self, Sum};
2use core::ops::{Add, Deref, DerefMut, Index, Mul, Sub};
3use rand_core::RngCore;
4use zkstd::common::{FftField, PrimeField, Vec};
5
6#[derive(Debug, Clone, PartialEq, Eq, Default)]
9pub struct Coefficients<F: PrimeField>(pub Vec<F>);
10
11#[derive(Debug, Clone, PartialEq, Eq, Default)]
13pub struct PointsValue<F: PrimeField>(pub Vec<F>);
14
15impl<F: PrimeField> Deref for Coefficients<F> {
16 type Target = [F];
17
18 fn deref(&self) -> &[F] {
19 &self.0
20 }
21}
22
23impl<F: PrimeField> DerefMut for Coefficients<F> {
24 fn deref_mut(&mut self) -> &mut [F] {
25 &mut self.0
26 }
27}
28
29impl<F: PrimeField> Index<usize> for PointsValue<F> {
30 type Output = F;
31
32 fn index(&self, index: usize) -> &F {
33 &self.0[index]
34 }
35}
36
37impl<F: FftField> Sum for Coefficients<F> {
38 fn sum<I>(iter: I) -> Self
39 where
40 I: Iterator<Item = Self>,
41 {
42 iter.fold(Coefficients::default(), |res, val| &res + &val)
43 }
44}
45
46impl<F: FftField> PointsValue<F> {
47 pub fn new(coeffs: Vec<F>) -> Self {
48 Self(coeffs)
49 }
50
51 pub fn format_degree(mut self) -> Self {
52 while self.0.last().map_or(false, |c| c == &F::zero()) {
53 self.0.pop();
54 }
55 self
56 }
57}
58
59impl<F: FftField> Coefficients<F> {
60 pub fn new(coeffs: Vec<F>) -> Self {
61 Self(coeffs).format_degree()
62 }
63
64 pub fn setup(k: usize, rng: impl RngCore) -> (F, Vec<F>) {
67 let randomness = F::random(rng);
68 (
69 randomness,
70 (0..(1 << k))
71 .scan(F::one(), |w, _| {
72 let tw = *w;
73 *w *= randomness;
74 Some(tw)
75 })
76 .collect::<Vec<_>>(),
77 )
78 }
79
80 pub fn commit(&self, domain: &Vec<F>) -> F {
82 assert!(self.0.len() <= domain.len());
83 let diff = domain.len() - self.0.len();
84
85 self.0
86 .iter()
87 .zip(domain.iter().skip(diff))
88 .fold(F::zero(), |acc, (a, b)| acc + *a * *b)
89 }
90
91 pub fn evaluate(&self, at: &F) -> F {
93 self.0
94 .iter()
95 .rev()
96 .fold(F::zero(), |acc, coeff| acc * at + coeff)
97 }
98
99 pub fn divide(&self, at: &F) -> Self {
102 let mut coeffs = self
103 .0
104 .iter()
105 .rev()
106 .scan(F::zero(), |w, coeff| {
107 let tmp = *w + coeff;
108 *w = tmp * at;
109 Some(tmp)
110 })
111 .collect::<Vec<_>>();
112 coeffs.pop();
113 coeffs.reverse();
114 Self(coeffs)
115 }
116
117 pub fn t(n: u64, tau: F) -> F {
119 tau.pow(n) - F::one()
120 }
121
122 pub fn blind<R: RngCore>(&mut self, hiding_degree: usize, rng: &mut R) {
125 for i in 0..hiding_degree + 1 {
126 let blinding_scalar = F::random(&mut *rng);
127 self.0[i] -= blinding_scalar;
128 self.0.push(blinding_scalar);
129 }
130 }
131
132 pub fn format_degree(mut self) -> Self {
133 while self.0.last().map_or(false, |c| c == &F::zero()) {
134 self.0.pop();
135 }
136 self
137 }
138
139 pub fn degree(&self) -> usize {
141 if self.is_zero() {
142 return 0;
143 }
144 assert!(self.0.last().map_or(false, |coeff| coeff != &F::zero()));
145 self.0.len() - 1
146 }
147
148 pub(crate) fn is_zero(&self) -> bool {
149 self.0.is_empty() || self.0.iter().all(|coeff| coeff == &F::zero())
150 }
151}
152
153impl<F: FftField> Add for Coefficients<F> {
154 type Output = Coefficients<F>;
155
156 fn add(self, rhs: Self) -> Self::Output {
157 let zero = F::zero();
158 let (left, right) = if self.0.len() > rhs.0.len() {
159 (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
160 } else {
161 (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
162 };
163 Self::new(left.zip(right).map(|(a, b)| *a + *b).collect())
164 }
165}
166
167impl<'a, 'b, F: FftField> Sub<&'a PointsValue<F>> for &'b PointsValue<F> {
168 type Output = PointsValue<F>;
169
170 fn sub(self, rhs: &'a PointsValue<F>) -> Self::Output {
171 let zero = F::zero();
172 PointsValue::new(if self.0.len() > rhs.0.len() {
173 let (left, right) = (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)));
174 left.zip(right).map(|(a, b)| *a - *b).collect()
175 } else {
176 let (left, right) = (self.0.iter().chain(iter::repeat(&zero)), rhs.0.iter());
177 left.zip(right).map(|(a, b)| *a - *b).collect()
178 })
179 }
180}
181
182impl<'a, 'b, F: FftField> Mul<&'a PointsValue<F>> for &'b PointsValue<F> {
183 type Output = PointsValue<F>;
184
185 fn mul(self, rhs: &'a PointsValue<F>) -> Self::Output {
186 let zero = F::zero();
187 let (left, right) = if self.0.len() > rhs.0.len() {
188 (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
189 } else {
190 (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
191 };
192 PointsValue::new(left.zip(right).map(|(a, b)| *a * *b).collect())
193 }
194}
195
196impl<'a, 'b, F: FftField> Add<&'a Coefficients<F>> for &'b Coefficients<F> {
197 type Output = Coefficients<F>;
198
199 fn add(self, rhs: &'a Coefficients<F>) -> Self::Output {
200 let zero = F::zero();
201 let (left, right) = if self.0.len() > rhs.0.len() {
202 (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
203 } else {
204 (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
205 };
206 Coefficients::new(left.zip(right).map(|(a, b)| *a + *b).collect())
207 }
208}
209
210impl<F: FftField> Sub for Coefficients<F> {
211 type Output = Coefficients<F>;
212
213 fn sub(self, rhs: Self) -> Self::Output {
214 let zero = F::zero();
215 let (left, right) = if self.0.len() > rhs.0.len() {
216 (self.0.iter(), rhs.0.iter().chain(iter::repeat(&zero)))
217 } else {
218 (rhs.0.iter(), self.0.iter().chain(iter::repeat(&zero)))
219 };
220 Self::new(left.zip(right).map(|(a, b)| *a - *b).collect())
221 }
222}
223
224impl<'a, 'b, F: FftField> Mul<&'a F> for &'b Coefficients<F> {
225 type Output = Coefficients<F>;
226
227 fn mul(self, scalar: &'a F) -> Coefficients<F> {
228 Coefficients::new(self.0.iter().map(|coeff| *coeff * scalar).collect())
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::Coefficients;
    use crate::PointsValue;
    use bls_12_381::Fr;
    use core::iter;
    use rand_core::OsRng;
    use zkstd::common::{Group, PrimeField};

    /// A uniformly random scalar.
    fn arb_fr() -> Fr {
        Fr::random(OsRng)
    }

    /// `2^k` random scalars.
    fn random_elements(k: u32) -> Vec<Fr> {
        iter::repeat_with(arb_fr).take(1 << k).collect()
    }

    /// A random polynomial with `2^k` coefficients (not normalized).
    fn arb_poly(k: u32) -> Coefficients<Fr> {
        Coefficients(random_elements(k))
    }

    /// A random point-value vector of length `2^k`.
    fn arb_points(k: u32) -> PointsValue<Fr> {
        PointsValue(random_elements(k))
    }

    /// Schoolbook product of two ascending coefficient vectors; a reference
    /// implementation for checking `divide`.
    fn naive_multiply<F: PrimeField>(a: Vec<F>, b: Vec<F>) -> Vec<F> {
        let mut product = vec![F::zero(); a.len() + b.len() - 1];
        for (i, x) in a.iter().enumerate() {
            for (j, y) in b.iter().enumerate() {
                product[i + j] += *x * *y;
            }
        }
        product
    }

    #[test]
    fn polynomial_scalar() {
        let poly = arb_poly(10);
        let at = arb_fr();
        let expected = Coefficients(poly.0.iter().map(|c| *c * at).collect());
        assert_eq!(&poly * &at, expected);
    }

    #[test]
    fn polynomial_division_test() {
        let at = arb_fr();
        // (X - at) as an ascending coefficient vector.
        let linear = vec![-at, Fr::one()];
        let divisor = arb_poly(10);

        let dividend = Coefficients(naive_multiply(divisor.0, linear.clone()));
        let quotient = dividend.divide(&at);
        let rebuilt = naive_multiply(quotient.0, linear);

        assert_eq!(dividend.0, rebuilt);
    }

    #[test]
    fn polynomial_subtraction_test() {
        let a = arb_points(9);
        let b = arb_points(10);

        let difference = &a - &b;

        let zero = Fr::zero();
        let len = a.0.len().max(b.0.len());
        let expected: Vec<Fr> = a
            .0
            .iter()
            .chain(iter::repeat(&zero))
            .zip(b.0.iter().chain(iter::repeat(&zero)))
            .take(len)
            .map(|(x, y)| *x - *y)
            .collect();
        assert_eq!(difference.0, expected);
    }
}