cait_sith/
math.rs

1use std::ops::{Add, AddAssign, Index, Mul, MulAssign};
2
3use elliptic_curve::{Field, Group};
4use rand_core::CryptoRngCore;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    compat::CSCurve,
9    serde::{deserialize_projective_points, serialize_projective_points},
10};
11
12/// Represents a polynomial with coefficients in the scalar field of the curve.
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct Polynomial<C: CSCurve> {
15    /// The coefficients of our polynomial, from 0..size-1.
16    coefficients: Vec<C::Scalar>,
17}
18
19impl<C: CSCurve> Polynomial<C> {
20    /// Generate a random polynomial with a certain number of coefficients.
21    pub fn random(rng: &mut impl CryptoRngCore, size: usize) -> Self {
22        let coefficients = (0..size).map(|_| C::Scalar::random(&mut *rng)).collect();
23        Self { coefficients }
24    }
25
26    /// Extend a constant to a random polynomial of a certain size.
27    ///
28    /// This is useful if you want the polynomial to have a certain value, but
29    /// otherwise be random.
30    pub fn extend_random(rng: &mut impl CryptoRngCore, size: usize, constant: &C::Scalar) -> Self {
31        let mut coefficients = Vec::with_capacity(size);
32        coefficients.push(*constant);
33        for _ in 1..size {
34            coefficients.push(C::Scalar::random(&mut *rng));
35        }
36        Self { coefficients }
37    }
38
39    /// Modify this polynomial by adding another polynomial.
40    pub fn add_mut(&mut self, other: &Self) {
41        let new_len = self.coefficients.len().max(other.coefficients.len());
42        self.coefficients.resize(new_len, C::Scalar::ZERO);
43        self.coefficients
44            .iter_mut()
45            .zip(other.coefficients.iter())
46            .for_each(|(a, b)| *a += b);
47    }
48
49    /// Return the addition of this polynomial with another.
50    pub fn add(&self, other: &Self) -> Self {
51        let mut out = self.clone();
52        out.add_mut(other);
53        out
54    }
55
56    /// Scale this polynomial in place by a field element.
57    pub fn scale_mut(&mut self, scale: &C::Scalar) {
58        self.coefficients.iter_mut().for_each(|a| *a *= scale);
59    }
60
61    /// Return the result of scaling this polynomial by a field element.
62    pub fn scale(&self, scale: &C::Scalar) -> Self {
63        let mut out = self.clone();
64        out.scale_mut(scale);
65        out
66    }
67
68    /// Evaluate this polynomial at 0.
69    ///
70    /// This is much more efficient than evaluating at other points.
71    pub fn evaluate_zero(&self) -> C::Scalar {
72        self.coefficients.get(0).cloned().unwrap_or_default()
73    }
74
75    /// Set the zero value of this polynomial to a new scalar
76    pub fn set_zero(&mut self, v: C::Scalar) {
77        if self.coefficients.is_empty() {
78            self.coefficients.push(v)
79        } else {
80            self.coefficients[0] = v
81        }
82    }
83
84    /// Evaluate this polynomial at a specific point.
85    pub fn evaluate(&self, x: &C::Scalar) -> C::Scalar {
86        let mut out = C::Scalar::ZERO;
87        for c in self.coefficients.iter().rev() {
88            out = out * x + c;
89        }
90        out
91    }
92
93    /// Commit to this polynomial by acting on the generator
94    pub fn commit(&self) -> GroupPolynomial<C> {
95        let coefficients = self
96            .coefficients
97            .iter()
98            .map(|x| C::ProjectivePoint::generator() * x)
99            .collect();
100        GroupPolynomial { coefficients }
101    }
102
103    /// Return the length of this polynomial.
104    pub fn len(&self) -> usize {
105        self.coefficients.len()
106    }
107}
108
109impl<C: CSCurve> Index<usize> for Polynomial<C> {
110    type Output = C::Scalar;
111
112    fn index(&self, i: usize) -> &Self::Output {
113        &self.coefficients[i]
114    }
115}
116
117impl<C: CSCurve> Add for &Polynomial<C> {
118    type Output = Polynomial<C>;
119
120    fn add(self, rhs: Self) -> Self::Output {
121        self.add(rhs)
122    }
123}
124
125impl<C: CSCurve> AddAssign<&Self> for Polynomial<C> {
126    fn add_assign(&mut self, rhs: &Self) {
127        self.add_mut(rhs)
128    }
129}
130
131impl<C: CSCurve> Mul<&C::Scalar> for &Polynomial<C> {
132    type Output = Polynomial<C>;
133
134    fn mul(self, rhs: &C::Scalar) -> Self::Output {
135        self.scale(rhs)
136    }
137}
138
139impl<C: CSCurve> MulAssign<&C::Scalar> for Polynomial<C> {
140    fn mul_assign(&mut self, rhs: &C::Scalar) {
141        self.scale_mut(rhs)
142    }
143}
144
145/// A polynomial with group coefficients.
146#[derive(Debug, Clone, Deserialize, Serialize)]
147pub struct GroupPolynomial<C: CSCurve> {
148    #[serde(
149        serialize_with = "serialize_projective_points::<C, _>",
150        deserialize_with = "deserialize_projective_points::<C, _>"
151    )]
152    coefficients: Vec<C::ProjectivePoint>,
153}
154
155impl<C: CSCurve> GroupPolynomial<C> {
156    /// Modify this polynomial by adding another one.
157    pub fn add_mut(&mut self, other: &Self) {
158        self.coefficients
159            .iter_mut()
160            .zip(other.coefficients.iter())
161            .for_each(|(a, b)| *a += b)
162    }
163
164    /// The result of adding this polynomial with another.
165    pub fn add(&self, other: &Self) -> Self {
166        let coefficients = self
167            .coefficients
168            .iter()
169            .zip(other.coefficients.iter())
170            .map(|(a, b)| *a + *b)
171            .collect();
172        Self { coefficients }
173    }
174
175    /// Evaluate this polynomial at 0.
176    ///
177    /// This is more efficient than evaluating at an arbitrary point.
178    pub fn evaluate_zero(&self) -> C::ProjectivePoint {
179        self.coefficients.get(0).cloned().unwrap_or_default()
180    }
181
182    /// Evaluate this polynomial at a specific value.
183    pub fn evaluate(&self, x: &C::Scalar) -> C::ProjectivePoint {
184        let mut out = C::ProjectivePoint::identity();
185        for c in self.coefficients.iter().rev() {
186            out = out * x + c;
187        }
188        out
189    }
190
191    /// Set the zero value of this polynomial to a new group value.
192    pub fn set_zero(&mut self, v: C::ProjectivePoint) {
193        if self.coefficients.is_empty() {
194            self.coefficients.push(v)
195        } else {
196            self.coefficients[0] = v
197        }
198    }
199
200    /// Return the length of this polynomial.
201    pub fn len(&self) -> usize {
202        self.coefficients.len()
203    }
204}
205
206impl<C: CSCurve> Add for &GroupPolynomial<C> {
207    type Output = GroupPolynomial<C>;
208
209    fn add(self, rhs: Self) -> Self::Output {
210        self.add(rhs)
211    }
212}
213
214impl<C: CSCurve> AddAssign<&Self> for GroupPolynomial<C> {
215    fn add_assign(&mut self, rhs: &Self) {
216        self.add_mut(rhs)
217    }
218}
219
#[cfg(test)]
mod test {
    use super::*;
    use k256::{Scalar, Secp256k1};

    /// Build a secp256k1 polynomial from small integer coefficients,
    /// listed from the constant term upwards.
    fn poly(coefficients: &[u32]) -> Polynomial<Secp256k1> {
        Polynomial {
            coefficients: coefficients.iter().copied().map(Scalar::from).collect(),
        }
    }

    /// Adding polynomials of different lengths works both via `+` and `+=`.
    #[test]
    fn test_addition() {
        let mut f = poly(&[1, 2]);
        let g = poly(&[1, 2, 3]);
        let expected = poly(&[2, 4, 3]);

        assert_eq!(&f + &g, expected);
        f += &g;
        assert_eq!(f, expected);
    }

    /// Scaling multiplies every coefficient, both via `*` and `*=`.
    #[test]
    fn test_scaling() {
        let factor = Scalar::from(2u32);
        let mut f = poly(&[1, 2]);
        let expected = poly(&[2, 4]);

        assert_eq!(&f * &factor, expected);
        f *= &factor;
        assert_eq!(f, expected);
    }

    /// Evaluating 1 + 2x at a couple of points gives the expected values.
    #[test]
    fn test_evaluation() {
        let f = poly(&[1, 2]);

        assert_eq!(f.evaluate(&Scalar::from(1u32)), Scalar::from(3u32));
        assert_eq!(f.evaluate(&Scalar::from(2u32)), Scalar::from(5u32));
    }
}
263}