mpc_stark/algebra/
scalar.rs

1//! Defines the scalar types that form the basis of the Starknet algebra
2
3// ----------------------------
4// | Scalar Field Definitions |
5// ----------------------------
6
7use std::{
8    fmt::{Display, Formatter, Result as FmtResult},
9    iter::{Product, Sum},
10    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
11};
12
13use ark_ff::{batch_inversion, Field, Fp256, MontBackend, MontConfig, PrimeField};
14use itertools::Itertools;
15use num_bigint::BigUint;
16use rand::{CryptoRng, Rng, RngCore};
17use serde::{Deserialize, Serialize};
18
19use crate::fabric::{ResultHandle, ResultValue};
20
21use super::macros::{impl_borrow_variants, impl_commutative};
22
/// Number of bytes needed to represent an element of the Starknet base field
///
/// Base field elements are 256-bit (`Fp256`) values: 256 bits / 8 = 32 bytes
pub const BASE_FIELD_BYTES: usize = 256 / 8;
/// Number of bytes in the fixed-width encoding of a `Scalar`
///
/// `Scalar::to_bytes_be` left-pads its output to exactly this many bytes
pub const SCALAR_BYTES: usize = 256 / 8;
27
/// Montgomery-arithmetic config for the base field `Fq` that the Starknet curve is defined over
///
/// The modulus is the Starknet prime `p = 2^251 + 17 * 2^192 + 1`; `generator` supplies the
/// multiplicative generator that arkworks' `MontConfig` derive requires
#[derive(MontConfig)]
#[modulus = "3618502788666131213697322783095070105623107215331596699973092056135872020481"]
#[generator = "3"]
pub struct StarknetFqConfig;
/// The finite field that the Starknet curve is defined over
///
/// Elements are stored in Montgomery form across four 64-bit limbs
pub type StarknetBaseFelt = Fp256<MontBackend<StarknetFqConfig, 4>>;
35
/// Montgomery-arithmetic config for the scalar field `Fr` of the Starknet curve
///
/// The modulus is the order of the Starknet curve's group
#[derive(MontConfig)]
#[modulus = "3618502788666131213697322783095070105526743751716087489154079457884512865583"]
#[generator = "3"]
pub struct StarknetFrConfig;
/// The scalar field of the Starknet curve: integers modulo the order of the curve's group
///
/// Note that this is not the field that the curve is defined over (that is `StarknetBaseFelt`),
/// but the field of integers modulo the order of the curve's group; see
/// [here](https://crypto.stackexchange.com/questions/98124/is-the-stark-curve-a-safecurve)
/// for more information
pub(crate) type ScalarInner = Fp256<MontBackend<StarknetFrConfig, 4>>;
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
/// A newtype wrapper around `ScalarInner`
///
/// Wrapping the foreign arkworks type lets this crate implement foreign traits (e.g. serde's
/// `Serialize`/`Deserialize`) and its own operator overloads for `Scalar`
pub struct Scalar(pub(crate) ScalarInner);
50
51// -------------------
52// | Implementations |
53// -------------------
54
55impl Scalar {
56    /// The underlying field that the scalar wraps
57    pub type Field = ScalarInner;
58
59    /// The scalar field's additive identity
60    pub fn zero() -> Scalar {
61        Scalar(ScalarInner::from(0))
62    }
63
64    /// The scalar field's multiplicative identity
65    pub fn one() -> Scalar {
66        Scalar(ScalarInner::from(1))
67    }
68
69    /// Get the inner value of the scalar
70    pub fn inner(&self) -> ScalarInner {
71        self.0
72    }
73
74    /// Generate a random scalar
75    ///
76    /// n.b. The `rand::random` method uses `ThreadRng` type which implements
77    /// the `CryptoRng` traits
78    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Scalar {
79        let inner: ScalarInner = rng.sample(rand::distributions::Standard);
80        Scalar(inner)
81    }
82
83    /// Compute the multiplicative inverse of the scalar in its field
84    pub fn inverse(&self) -> Scalar {
85        Scalar(self.0.inverse().unwrap())
86    }
87
88    /// Compute the batch inversion of a list of Scalars
89    pub fn batch_inverse(vals: &mut [Scalar]) {
90        let mut values = vals.iter().map(|x| x.0).collect_vec();
91        batch_inversion(&mut values);
92
93        for (i, val) in vals.iter_mut().enumerate() {
94            *val = Scalar(values[i]);
95        }
96    }
97
98    /// Construct a scalar from the given bytes and reduce modulo the field's modulus
99    pub fn from_be_bytes_mod_order(bytes: &[u8]) -> Scalar {
100        let inner = ScalarInner::from_be_bytes_mod_order(bytes);
101        Scalar(inner)
102    }
103
104    /// Convert to big endian bytes
105    ///
106    /// Pad to the maximum amount of bytes needed so that the resulting bytes are
107    /// of predictable length
108    pub fn to_bytes_be(&self) -> Vec<u8> {
109        let val_biguint = self.to_biguint();
110        let mut bytes = val_biguint.to_bytes_be();
111
112        let mut padding = vec![0u8; SCALAR_BYTES - bytes.len()];
113        padding.append(&mut bytes);
114
115        padding
116    }
117
118    /// Convert the underlying value to a BigUint
119    pub fn to_biguint(&self) -> BigUint {
120        self.0.into()
121    }
122
123    /// Convert from a `BigUint`
124    pub fn from_biguint(val: &BigUint) -> Scalar {
125        let le_bytes = val.to_bytes_le();
126        let inner = ScalarInner::from_le_bytes_mod_order(&le_bytes);
127        Scalar(inner)
128    }
129}
130
131impl Display for Scalar {
132    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
133        write!(f, "{}", self.to_biguint())
134    }
135}
136
137impl Serialize for Scalar {
138    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
139        let bytes = self.to_bytes_be();
140        bytes.serialize(serializer)
141    }
142}
143
144impl<'de> Deserialize<'de> for Scalar {
145    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
146        let bytes = <Vec<u8>>::deserialize(deserializer)?;
147        let scalar = Scalar::from_be_bytes_mod_order(&bytes);
148        Ok(scalar)
149    }
150}
151
152// --------------
153// | Arithmetic |
154// --------------
155
156// === Addition === //
157
158/// A type alias for a result that resolves to a `Scalar`
159pub type ScalarResult = ResultHandle<Scalar>;
160/// A type alias for a result that resolves to a batch of `Scalar`s
161pub type BatchScalarResult = ResultHandle<Vec<Scalar>>;
162impl ScalarResult {
163    /// Compute the multiplicative inverse of the scalar in its field
164    pub fn inverse(&self) -> ScalarResult {
165        self.fabric.new_gate_op(vec![self.id], |mut args| {
166            let val: Scalar = args.remove(0).into();
167            ResultValue::Scalar(Scalar(val.0.inverse().unwrap()))
168        })
169    }
170}
171
172impl Add<&Scalar> for &Scalar {
173    type Output = Scalar;
174
175    fn add(self, rhs: &Scalar) -> Self::Output {
176        let rhs = *rhs;
177        Scalar(self.0 + rhs.0)
178    }
179}
180impl_borrow_variants!(Scalar, Add, add, +, Scalar);
181
182impl Add<&Scalar> for &ScalarResult {
183    type Output = ScalarResult;
184
185    fn add(self, rhs: &Scalar) -> Self::Output {
186        let rhs = *rhs;
187        self.fabric.new_gate_op(vec![self.id], move |args| {
188            let lhs: Scalar = args[0].to_owned().into();
189            ResultValue::Scalar(Scalar(lhs.0 + rhs.0))
190        })
191    }
192}
193impl_borrow_variants!(ScalarResult, Add, add, +, Scalar);
194impl_commutative!(ScalarResult, Add, add, +, Scalar);
195
196impl Add<&ScalarResult> for &ScalarResult {
197    type Output = ScalarResult;
198
199    fn add(self, rhs: &ScalarResult) -> Self::Output {
200        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
201            let lhs: Scalar = args[0].to_owned().into();
202            let rhs: Scalar = args[1].to_owned().into();
203            ResultValue::Scalar(Scalar(lhs.0 + rhs.0))
204        })
205    }
206}
207impl_borrow_variants!(ScalarResult, Add, add, +, ScalarResult);
208
209impl ScalarResult {
210    /// Add two batches of `ScalarResult`s
211    pub fn batch_add(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
212        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
213
214        let n = a.len();
215        let fabric = &a[0].fabric;
216        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
217        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
218            let mut res = Vec::with_capacity(n);
219            for i in 0..n {
220                let lhs: Scalar = args[i].to_owned().into();
221                let rhs: Scalar = args[i + n].to_owned().into();
222                res.push(ResultValue::Scalar(Scalar(lhs.0 + rhs.0)));
223            }
224
225            res
226        })
227    }
228}
229
230// === AddAssign === //
231
232impl AddAssign for Scalar {
233    fn add_assign(&mut self, rhs: Scalar) {
234        *self = *self + rhs;
235    }
236}
237
238// === Subtraction === //
239
240impl Sub<&Scalar> for &Scalar {
241    type Output = Scalar;
242
243    fn sub(self, rhs: &Scalar) -> Self::Output {
244        let rhs = *rhs;
245        Scalar(self.0 - rhs.0)
246    }
247}
248impl_borrow_variants!(Scalar, Sub, sub, -, Scalar);
249
250impl Sub<&Scalar> for &ScalarResult {
251    type Output = ScalarResult;
252
253    fn sub(self, rhs: &Scalar) -> Self::Output {
254        let rhs = *rhs;
255        self.fabric.new_gate_op(vec![self.id], move |args| {
256            let lhs: Scalar = args[0].to_owned().into();
257            ResultValue::Scalar(Scalar(lhs.0 - rhs.0))
258        })
259    }
260}
261impl_borrow_variants!(ScalarResult, Sub, sub, -, Scalar);
262
263impl Sub<&ScalarResult> for &Scalar {
264    type Output = ScalarResult;
265
266    fn sub(self, rhs: &ScalarResult) -> Self::Output {
267        let lhs = *self;
268        rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
269            let rhs: Scalar = args[0].to_owned().into();
270            ResultValue::Scalar(lhs - rhs)
271        })
272    }
273}
274impl_borrow_variants!(Scalar, Sub, sub, -, ScalarResult, Output=ScalarResult);
275
276impl Sub<&ScalarResult> for &ScalarResult {
277    type Output = ScalarResult;
278
279    fn sub(self, rhs: &ScalarResult) -> Self::Output {
280        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
281            let lhs: Scalar = args[0].to_owned().into();
282            let rhs: Scalar = args[1].to_owned().into();
283            ResultValue::Scalar(Scalar(lhs.0 - rhs.0))
284        })
285    }
286}
287impl_borrow_variants!(ScalarResult, Sub, sub, -, ScalarResult);
288
289impl ScalarResult {
290    /// Subtract two batches of `ScalarResult`s
291    pub fn batch_sub(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
292        assert_eq!(a.len(), b.len(), "Batch sub requires equal length inputs");
293
294        let n = a.len();
295        let fabric = &a[0].fabric;
296        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
297        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
298            let mut res = Vec::with_capacity(n);
299            for i in 0..n {
300                let lhs: Scalar = args[i].to_owned().into();
301                let rhs: Scalar = args[i + n].to_owned().into();
302                res.push(ResultValue::Scalar(Scalar(lhs.0 - rhs.0)));
303            }
304
305            res
306        })
307    }
308}
309
310// === SubAssign === //
311
312impl SubAssign for Scalar {
313    fn sub_assign(&mut self, rhs: Scalar) {
314        *self = *self - rhs;
315    }
316}
317
318// === Multiplication === //
319
320impl Mul<&Scalar> for &Scalar {
321    type Output = Scalar;
322
323    fn mul(self, rhs: &Scalar) -> Self::Output {
324        let rhs = *rhs;
325        Scalar(self.0 * rhs.0)
326    }
327}
328impl_borrow_variants!(Scalar, Mul, mul, *, Scalar);
329
330impl Mul<&Scalar> for &ScalarResult {
331    type Output = ScalarResult;
332
333    fn mul(self, rhs: &Scalar) -> Self::Output {
334        let rhs = *rhs;
335        self.fabric.new_gate_op(vec![self.id], move |args| {
336            let lhs: Scalar = args[0].to_owned().into();
337            ResultValue::Scalar(Scalar(lhs.0 * rhs.0))
338        })
339    }
340}
341impl_borrow_variants!(ScalarResult, Mul, mul, *, Scalar);
342impl_commutative!(ScalarResult, Mul, mul, *, Scalar);
343
344impl Mul<&ScalarResult> for &ScalarResult {
345    type Output = ScalarResult;
346
347    fn mul(self, rhs: &ScalarResult) -> Self::Output {
348        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
349            let lhs: Scalar = args[0].to_owned().into();
350            let rhs: Scalar = args[1].to_owned().into();
351            ResultValue::Scalar(Scalar(lhs.0 * rhs.0))
352        })
353    }
354}
355impl_borrow_variants!(ScalarResult, Mul, mul, *, ScalarResult);
356
357impl ScalarResult {
358    /// Multiply two batches of `ScalarResult`s
359    pub fn batch_mul(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
360        assert_eq!(a.len(), b.len(), "Batch mul requires equal length inputs");
361
362        let n = a.len();
363        let fabric = &a[0].fabric;
364        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
365        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
366            let mut res = Vec::with_capacity(n);
367            for i in 0..n {
368                let lhs: Scalar = args[i].to_owned().into();
369                let rhs: Scalar = args[i + n].to_owned().into();
370                res.push(ResultValue::Scalar(Scalar(lhs.0 * rhs.0)));
371            }
372
373            res
374        })
375    }
376}
377
378impl Neg for &Scalar {
379    type Output = Scalar;
380
381    fn neg(self) -> Self::Output {
382        Scalar(-self.0)
383    }
384}
385impl_borrow_variants!(Scalar, Neg, neg, -);
386
387impl Neg for &ScalarResult {
388    type Output = ScalarResult;
389
390    fn neg(self) -> Self::Output {
391        self.fabric.new_gate_op(vec![self.id], |args| {
392            let lhs: Scalar = args[0].to_owned().into();
393            ResultValue::Scalar(Scalar(-lhs.0))
394        })
395    }
396}
397impl_borrow_variants!(ScalarResult, Neg, neg, -);
398
399impl ScalarResult {
400    /// Negate a batch of `ScalarResult`s
401    pub fn batch_neg(a: &[ScalarResult]) -> Vec<ScalarResult> {
402        let n = a.len();
403        let fabric = &a[0].fabric;
404        let ids = a.iter().map(|v| v.id).collect_vec();
405        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
406            args.into_iter()
407                .map(Scalar::from)
408                .map(|x| -x)
409                .map(ResultValue::Scalar)
410                .collect_vec()
411        })
412    }
413}
414
415// === MulAssign === //
416
417impl MulAssign for Scalar {
418    fn mul_assign(&mut self, rhs: Scalar) {
419        *self = *self * rhs;
420    }
421}
422
423// ---------------
424// | Conversions |
425// ---------------
426
427impl<T: Into<ScalarInner>> From<T> for Scalar {
428    fn from(val: T) -> Self {
429        Scalar(val.into())
430    }
431}
432
433// -------------------
434// | Iterator Traits |
435// -------------------
436
437impl Sum for Scalar {
438    fn sum<I: Iterator<Item = Scalar>>(iter: I) -> Self {
439        iter.fold(Scalar::zero(), |acc, x| acc + x)
440    }
441}
442
443impl Product for Scalar {
444    fn product<I: Iterator<Item = Scalar>>(iter: I) -> Self {
445        iter.fold(Scalar::one(), |acc, x| acc * x)
446    }
447}
448
#[cfg(test)]
mod test {
    use crate::{
        algebra::scalar::{Scalar, SCALAR_BYTES},
        test_helpers::mock_fabric,
    };
    use rand::thread_rng;

    /// Tests serializing and deserializing a scalar
    #[test]
    fn test_scalar_serialize() {
        // Sample a random scalar and convert it to bytes
        let mut rng = thread_rng();
        let scalar = Scalar::random(&mut rng);
        let bytes = scalar.to_bytes_be();

        assert_eq!(bytes.len(), SCALAR_BYTES);

        // Deserialize and validate the scalar
        let scalar_deserialized = Scalar::from_be_bytes_mod_order(&bytes);
        assert_eq!(scalar, scalar_deserialized);
    }

    /// Tests that zero, whose minimal encoding is shorter than `SCALAR_BYTES`, is
    /// left-padded to the full width
    #[test]
    fn test_zero_serialize_padding() {
        let bytes = Scalar::zero().to_bytes_be();
        assert_eq!(bytes, vec![0u8; SCALAR_BYTES]);
    }

    /// Tests that converting to and from a `BigUint` round-trips
    #[test]
    fn test_biguint_round_trip() {
        let mut rng = thread_rng();
        let scalar = Scalar::random(&mut rng);

        assert_eq!(Scalar::from_biguint(&scalar.to_biguint()), scalar);
    }

    /// Tests that a scalar multiplied by its inverse is one
    #[test]
    fn test_scalar_inverse() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);

        assert_eq!(a * a.inverse(), Scalar::one());
    }

    /// Tests that batch inversion agrees with element-wise inversion
    #[test]
    fn test_scalar_batch_inverse() {
        let mut rng = thread_rng();
        let values: Vec<Scalar> = (0..10).map(|_| Scalar::random(&mut rng)).collect();

        let mut inverted = values.clone();
        Scalar::batch_inverse(&mut inverted);

        for (value, inverse) in values.iter().zip(inverted.iter()) {
            assert_eq!(*inverse, value.inverse());
        }
    }

    /// Tests addition of raw scalars in a circuit
    #[tokio::test]
    async fn test_scalar_add() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a + b;

        // Allocate the scalars in a fabric and add them together
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = &a_alloc + &b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests subtraction of raw scalars in the circuit
    #[tokio::test]
    async fn test_scalar_sub() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a - b;

        // Allocate the scalars in a fabric and subtract them
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = a_alloc - b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests negation of raw scalars in a circuit
    #[tokio::test]
    async fn test_scalar_neg() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);

        let expected_res = -a;

        // Allocate the scalar in a fabric and negate it
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);

        let res = -a_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests multiplication of raw scalars in a circuit
    #[tokio::test]
    async fn test_scalar_mul() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a * b;

        // Allocate the scalars in a fabric and multiply them together
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = a_alloc * b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }
}