ark_mpc/algebra/scalar/
scalar.rs

1//! Defines the scalar types that form the basis of the MPC algebra
2
3// ----------------------------
4// | Scalar Field Definitions |
5// ----------------------------
6
7use std::{
8    fmt::{Display, Formatter, Result as FmtResult},
9    iter::{Product, Sum},
10    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
11};
12
13use ark_ec::CurveGroup;
14use ark_ff::{batch_inversion, FftField, Field, PrimeField};
15use ark_poly::EvaluationDomain;
16use ark_std::UniformRand;
17use itertools::Itertools;
18use num_bigint::BigUint;
19use rand::{CryptoRng, RngCore};
20use serde::{Deserialize, Serialize};
21
22use crate::algebra::macros::*;
23use crate::fabric::{ResultHandle, ResultValue};
24
25// -----------
26// | Helpers |
27// -----------
28
29/// Computes the number of bytes needed to represent  field element
30#[inline]
31pub const fn n_bytes_field<F: PrimeField>() -> usize {
32    // We add 7 and divide by 8 to emulate a ceiling operation considering that u32
33    // division is a floor
34    let n_bits = F::MODULUS_BIT_SIZE as usize;
35    (n_bits + 7) / 8
36}
37
38// ---------------------
39// | Scalar Definition |
40// ---------------------
41
42#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
43/// A wrapper around the inner scalar that allows us to implement foreign traits
44/// for the `Scalar`
45pub struct Scalar<C: CurveGroup>(pub(crate) C::ScalarField);
46
47impl<C: CurveGroup> Scalar<C> {
48    /// The underlying field that the scalar wraps
49    pub type Field = C::ScalarField;
50
51    /// Construct a scalar from an inner field element
52    pub fn new(inner: C::ScalarField) -> Self {
53        Scalar(inner)
54    }
55
56    /// The scalar field's additive identity
57    pub fn zero() -> Self {
58        Scalar(C::ScalarField::from(0u8))
59    }
60
61    /// The scalar field's multiplicative identity
62    pub fn one() -> Self {
63        Scalar(C::ScalarField::from(1u8))
64    }
65
66    /// Get the inner value of the scalar
67    pub fn inner(&self) -> C::ScalarField {
68        self.0
69    }
70
71    /// Sample a random field element
72    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
73        Self(C::ScalarField::rand(rng))
74    }
75
76    /// Compute the multiplicative inverse of the scalar in its field
77    pub fn inverse(&self) -> Self {
78        Scalar(self.0.inverse().unwrap())
79    }
80
81    /// Compute the batch inversion of a list of Scalars
82    pub fn batch_inverse(vals: &mut [Self]) {
83        let mut values = vals.iter().map(|x| x.0).collect_vec();
84        batch_inversion(&mut values);
85
86        for (i, val) in vals.iter_mut().enumerate() {
87            *val = Scalar(values[i]);
88        }
89    }
90
91    /// Compute the exponentiation of the given scalar
92    pub fn pow(&self, exp: u64) -> Self {
93        Scalar::new(self.0.pow([exp]))
94    }
95
96    /// Construct a scalar from the given bytes and reduce modulo the field's
97    /// modulus
98    pub fn from_be_bytes_mod_order(bytes: &[u8]) -> Self {
99        let inner = C::ScalarField::from_be_bytes_mod_order(bytes);
100        Scalar(inner)
101    }
102
103    /// Convert to big endian bytes
104    ///
105    /// Pad to the maximum amount of bytes needed so that the resulting bytes
106    /// are of predictable length
107    pub fn to_bytes_be(&self) -> Vec<u8> {
108        let val_biguint = self.to_biguint();
109        let mut bytes = val_biguint.to_bytes_be();
110
111        let n_bytes = n_bytes_field::<C::ScalarField>();
112        let mut padding = vec![0u8; n_bytes - bytes.len()];
113        padding.append(&mut bytes);
114
115        padding
116    }
117
118    /// Convert the underlying value to a BigUint
119    pub fn to_biguint(&self) -> BigUint {
120        self.0.into()
121    }
122
123    /// Convert from a `BigUint`
124    pub fn from_biguint(val: &BigUint) -> Self {
125        let le_bytes = val.to_bytes_le();
126        let inner = C::ScalarField::from_le_bytes_mod_order(&le_bytes);
127        Scalar(inner)
128    }
129}
130
131impl<C: CurveGroup> Display for Scalar<C> {
132    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
133        write!(f, "{}", self.to_biguint())
134    }
135}
136
137impl<C: CurveGroup> Serialize for Scalar<C> {
138    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
139        let bytes = self.to_bytes_be();
140        bytes.serialize(serializer)
141    }
142}
143
144impl<'de, C: CurveGroup> Deserialize<'de> for Scalar<C> {
145    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
146        let bytes = <Vec<u8>>::deserialize(deserializer)?;
147        let scalar = Scalar::from_be_bytes_mod_order(&bytes);
148        Ok(scalar)
149    }
150}
151
152/// A type alias for a result that resolves to a `Scalar`
153pub type ScalarResult<C> = ResultHandle<C, Scalar<C>>;
154/// A type alias for a result that resolves to a batch of `Scalar`s
155pub type BatchScalarResult<C> = ResultHandle<C, Vec<Scalar<C>>>;
156
157impl<C: CurveGroup> ScalarResult<C> {
158    /// Exponentiation
159    pub fn pow(&self, exp: u64) -> Self {
160        self.fabric().new_gate_op(vec![self.id()], move |mut args| {
161            let base: Scalar<C> = args.pop().unwrap().into();
162            let res = base.inner().pow([exp]);
163
164            ResultValue::Scalar(Scalar::new(res))
165        })
166    }
167}
168
169// --------------
170// | Arithmetic |
171// --------------
172
173// === Addition === //
174
175impl<C: CurveGroup> ScalarResult<C> {
176    /// Compute the multiplicative inverse of the scalar in its field
177    pub fn inverse(&self) -> ScalarResult<C> {
178        self.fabric.new_gate_op(vec![self.id], |mut args| {
179            let val: Scalar<C> = args.remove(0).into();
180            ResultValue::Scalar(Scalar(val.0.inverse().unwrap()))
181        })
182    }
183}
184
185impl<C: CurveGroup> Add<&Scalar<C>> for &Scalar<C> {
186    type Output = Scalar<C>;
187
188    fn add(self, rhs: &Scalar<C>) -> Self::Output {
189        let rhs = *rhs;
190        Scalar(self.0 + rhs.0)
191    }
192}
193impl_borrow_variants!(Scalar<C>, Add, add, +, Scalar<C>, C: CurveGroup);
194
195impl<C: CurveGroup> Add<&Scalar<C>> for &ScalarResult<C> {
196    type Output = ScalarResult<C>;
197
198    fn add(self, rhs: &Scalar<C>) -> Self::Output {
199        let rhs = *rhs;
200        self.fabric.new_gate_op(vec![self.id], move |args| {
201            let lhs: Scalar<C> = args[0].to_owned().into();
202            ResultValue::Scalar(Scalar(lhs.0 + rhs.0))
203        })
204    }
205}
206impl_borrow_variants!(ScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
207impl_commutative!(ScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
208
209impl<C: CurveGroup> Add<&ScalarResult<C>> for &ScalarResult<C> {
210    type Output = ScalarResult<C>;
211
212    fn add(self, rhs: &ScalarResult<C>) -> Self::Output {
213        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
214            let lhs: Scalar<C> = args[0].to_owned().into();
215            let rhs: Scalar<C> = args[1].to_owned().into();
216            ResultValue::Scalar(Scalar(lhs.0 + rhs.0))
217        })
218    }
219}
220impl_borrow_variants!(ScalarResult<C>, Add, add, +, ScalarResult<C>, C: CurveGroup);
221
222impl<C: CurveGroup> ScalarResult<C> {
223    /// Add two batches of `ScalarResult<C>`s
224    pub fn batch_add(a: &[ScalarResult<C>], b: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
225        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
226
227        let n = a.len();
228        let fabric = &a[0].fabric;
229        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
230        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
231            let mut res = Vec::with_capacity(n);
232            for i in 0..n {
233                let lhs: Scalar<C> = args[i].to_owned().into();
234                let rhs: Scalar<C> = args[i + n].to_owned().into();
235                res.push(ResultValue::Scalar(Scalar(lhs.0 + rhs.0)));
236            }
237
238            res
239        })
240    }
241}
242
243// === AddAssign === //
244
245impl<C: CurveGroup> AddAssign for Scalar<C> {
246    fn add_assign(&mut self, rhs: Scalar<C>) {
247        *self = *self + rhs;
248    }
249}
250
251// === Subtraction === //
252
253impl<C: CurveGroup> Sub<&Scalar<C>> for &Scalar<C> {
254    type Output = Scalar<C>;
255
256    fn sub(self, rhs: &Scalar<C>) -> Self::Output {
257        let rhs = *rhs;
258        Scalar(self.0 - rhs.0)
259    }
260}
261impl_borrow_variants!(Scalar<C>, Sub, sub, -, Scalar<C>, C: CurveGroup);
262
263impl<C: CurveGroup> Sub<&Scalar<C>> for &ScalarResult<C> {
264    type Output = ScalarResult<C>;
265
266    fn sub(self, rhs: &Scalar<C>) -> Self::Output {
267        let rhs = *rhs;
268        self.fabric.new_gate_op(vec![self.id], move |args| {
269            let lhs: Scalar<C> = args[0].to_owned().into();
270            ResultValue::Scalar(Scalar(lhs.0 - rhs.0))
271        })
272    }
273}
274impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, Scalar<C>, C: CurveGroup);
275
276impl<C: CurveGroup> Sub<&ScalarResult<C>> for &Scalar<C> {
277    type Output = ScalarResult<C>;
278
279    fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
280        let lhs = *self;
281        rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
282            let rhs: Scalar<C> = args[0].to_owned().into();
283            ResultValue::Scalar(lhs - rhs)
284        })
285    }
286}
287impl_borrow_variants!(Scalar<C>, Sub, sub, -, ScalarResult<C>, Output=ScalarResult<C>, C: CurveGroup);
288
289impl<C: CurveGroup> Sub<&ScalarResult<C>> for &ScalarResult<C> {
290    type Output = ScalarResult<C>;
291
292    fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
293        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
294            let lhs: Scalar<C> = args[0].to_owned().into();
295            let rhs: Scalar<C> = args[1].to_owned().into();
296            ResultValue::Scalar(Scalar(lhs.0 - rhs.0))
297        })
298    }
299}
300impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, ScalarResult<C>, C: CurveGroup);
301
302impl<C: CurveGroup> ScalarResult<C> {
303    /// Subtract two batches of `ScalarResult`s
304    pub fn batch_sub(a: &[ScalarResult<C>], b: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
305        assert_eq!(a.len(), b.len(), "Batch sub requires equal length inputs");
306
307        let n = a.len();
308        let fabric = &a[0].fabric;
309        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
310        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
311            let mut res = Vec::with_capacity(n);
312            for i in 0..n {
313                let lhs: Scalar<C> = args[i].to_owned().into();
314                let rhs: Scalar<C> = args[i + n].to_owned().into();
315                res.push(ResultValue::Scalar(Scalar(lhs.0 - rhs.0)));
316            }
317
318            res
319        })
320    }
321}
322
323// === SubAssign === //
324
325impl<C: CurveGroup> SubAssign for Scalar<C> {
326    fn sub_assign(&mut self, rhs: Scalar<C>) {
327        *self = *self - rhs;
328    }
329}
330
331// === Multiplication === //
332
333impl<C: CurveGroup> Mul<&Scalar<C>> for &Scalar<C> {
334    type Output = Scalar<C>;
335
336    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
337        let rhs = *rhs;
338        Scalar(self.0 * rhs.0)
339    }
340}
341impl_borrow_variants!(Scalar<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
342
343impl<C: CurveGroup> Mul<&Scalar<C>> for &ScalarResult<C> {
344    type Output = ScalarResult<C>;
345
346    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
347        let rhs = *rhs;
348        self.fabric.new_gate_op(vec![self.id], move |args| {
349            let lhs: Scalar<C> = args[0].to_owned().into();
350            ResultValue::Scalar(Scalar(lhs.0 * rhs.0))
351        })
352    }
353}
354impl_borrow_variants!(ScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
355impl_commutative!(ScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
356
357impl<C: CurveGroup> Mul<&ScalarResult<C>> for &ScalarResult<C> {
358    type Output = ScalarResult<C>;
359
360    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
361        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
362            let lhs: Scalar<C> = args[0].to_owned().into();
363            let rhs: Scalar<C> = args[1].to_owned().into();
364            ResultValue::Scalar(Scalar(lhs.0 * rhs.0))
365        })
366    }
367}
368impl_borrow_variants!(ScalarResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
369
370impl<C: CurveGroup> ScalarResult<C> {
371    /// Multiply two batches of `ScalarResult`s
372    pub fn batch_mul(a: &[ScalarResult<C>], b: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
373        assert_eq!(a.len(), b.len(), "Batch mul requires equal length inputs");
374
375        let n = a.len();
376        let fabric = &a[0].fabric;
377        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
378        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
379            let mut res = Vec::with_capacity(n);
380            for i in 0..n {
381                let lhs: Scalar<C> = args[i].to_owned().into();
382                let rhs: Scalar<C> = args[i + n].to_owned().into();
383                res.push(ResultValue::Scalar(Scalar(lhs.0 * rhs.0)));
384            }
385
386            res
387        })
388    }
389}
390
391impl<C: CurveGroup> Neg for &Scalar<C> {
392    type Output = Scalar<C>;
393
394    fn neg(self) -> Self::Output {
395        Scalar(-self.0)
396    }
397}
398impl_borrow_variants!(Scalar<C>, Neg, neg, -, C: CurveGroup);
399
400impl<C: CurveGroup> Neg for &ScalarResult<C> {
401    type Output = ScalarResult<C>;
402
403    fn neg(self) -> Self::Output {
404        self.fabric.new_gate_op(vec![self.id], |args| {
405            let lhs: Scalar<C> = args[0].to_owned().into();
406            ResultValue::Scalar(Scalar(-lhs.0))
407        })
408    }
409}
410impl_borrow_variants!(ScalarResult<C>, Neg, neg, -, C: CurveGroup);
411
412impl<C: CurveGroup> ScalarResult<C> {
413    /// Negate a batch of `ScalarResult`s
414    pub fn batch_neg(a: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
415        let n = a.len();
416        let fabric = &a[0].fabric;
417        let ids = a.iter().map(|v| v.id).collect_vec();
418        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
419            args.into_iter()
420                .map(Scalar::from)
421                .map(|x| -x)
422                .map(ResultValue::Scalar)
423                .collect_vec()
424        })
425    }
426}
427
428// === MulAssign === //
429
430impl<C: CurveGroup> MulAssign for Scalar<C> {
431    fn mul_assign(&mut self, rhs: Scalar<C>) {
432        *self = *self * rhs;
433    }
434}
435
436// === Division === //
437impl<C: CurveGroup> Div<&Scalar<C>> for &Scalar<C> {
438    type Output = Scalar<C>;
439
440    fn div(self, rhs: &Scalar<C>) -> Self::Output {
441        let rhs = *rhs;
442        Scalar(self.0 / rhs.0)
443    }
444}
445impl_borrow_variants!(Scalar<C>, Div, div, /, Scalar<C>, C: CurveGroup);
446
447// === FFT and IFFT === //
448impl<C: CurveGroup> ScalarResult<C>
449where
450    C::ScalarField: FftField,
451{
452    /// Compute the fft of a sequence of `ScalarResult`s
453    pub fn fft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
454        x: &[ScalarResult<C>],
455    ) -> Vec<ScalarResult<C>> {
456        Self::fft_with_domain(x, D::new(x.len()).unwrap())
457    }
458
459    /// Compute the fft of a sequence of `ScalarResult`s with the given domain
460    pub fn fft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
461        x: &[ScalarResult<C>],
462        domain: D,
463    ) -> Vec<ScalarResult<C>> {
464        assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
465        let n = domain.size();
466
467        let fabric = x[0].fabric();
468        let ids = x.iter().map(|v| v.id).collect_vec();
469
470        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
471            let scalars = args
472                .into_iter()
473                .map(Scalar::from)
474                .map(|x| x.0)
475                .collect_vec();
476
477            domain
478                .fft(&scalars)
479                .into_iter()
480                .map(|x| ResultValue::Scalar(Scalar::new(x)))
481                .collect_vec()
482        })
483    }
484
485    /// Compute the ifft of a sequence of `ScalarResult`s
486    pub fn ifft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
487        x: &[ScalarResult<C>],
488    ) -> Vec<ScalarResult<C>> {
489        Self::ifft_with_domain(x, D::new(x.len()).unwrap())
490    }
491
492    /// Compute the ifft of a sequence of `ScalarResult`s with the given domain
493    pub fn ifft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
494        x: &[ScalarResult<C>],
495        domain: D,
496    ) -> Vec<ScalarResult<C>> {
497        assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
498        let n = domain.size();
499
500        let fabric = x[0].fabric();
501        let ids = x.iter().map(|v| v.id).collect_vec();
502
503        fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
504            let scalars = args
505                .into_iter()
506                .map(Scalar::from)
507                .map(|x| x.0)
508                .collect_vec();
509
510            domain
511                .ifft(&scalars)
512                .into_iter()
513                .map(|x| ResultValue::Scalar(Scalar::new(x)))
514                .collect_vec()
515        })
516    }
517}
518
519// ---------------
520// | Conversions |
521// ---------------
522
523impl<C: CurveGroup> From<bool> for Scalar<C> {
524    fn from(value: bool) -> Self {
525        Scalar(C::ScalarField::from(value))
526    }
527}
528
529impl<C: CurveGroup> From<u8> for Scalar<C> {
530    fn from(value: u8) -> Self {
531        Scalar(C::ScalarField::from(value))
532    }
533}
534
535impl<C: CurveGroup> From<u16> for Scalar<C> {
536    fn from(value: u16) -> Self {
537        Scalar(C::ScalarField::from(value))
538    }
539}
540
541impl<C: CurveGroup> From<u32> for Scalar<C> {
542    fn from(value: u32) -> Self {
543        Scalar(C::ScalarField::from(value))
544    }
545}
546
547impl<C: CurveGroup> From<u64> for Scalar<C> {
548    fn from(value: u64) -> Self {
549        Scalar(C::ScalarField::from(value))
550    }
551}
552
553impl<C: CurveGroup> From<u128> for Scalar<C> {
554    fn from(value: u128) -> Self {
555        Scalar(C::ScalarField::from(value))
556    }
557}
558
559impl<C: CurveGroup> From<usize> for Scalar<C> {
560    fn from(value: usize) -> Self {
561        Scalar(C::ScalarField::from(value as u64))
562    }
563}
564
565// -------------------
566// | Iterator Traits |
567// -------------------
568
569impl<C: CurveGroup> Sum for Scalar<C> {
570    fn sum<I: Iterator<Item = Scalar<C>>>(iter: I) -> Self {
571        iter.fold(Scalar::zero(), |acc, x| acc + x)
572    }
573}
574
575impl<C: CurveGroup> Product for Scalar<C> {
576    fn product<I: Iterator<Item = Scalar<C>>>(iter: I) -> Self {
577        iter.fold(Scalar::one(), |acc, x| acc * x)
578    }
579}
580
#[cfg(test)]
mod test {
    use crate::{
        algebra::{poly_test_helpers::TestPolyField, scalar::Scalar, ScalarResult},
        test_helpers::{execute_mock_mpc, mock_fabric, TestCurve},
    };
    use ark_ff::Field;
    use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
    use futures::future;
    use itertools::Itertools;
    use rand::{thread_rng, Rng, RngCore};

    /// Tests addition of raw scalars in a circuit
    #[tokio::test]
    async fn test_scalar_add() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a + b;

        // Allocate the scalars in a fabric and add them together
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = &a_alloc + &b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests subtraction of raw scalars in the circuit
    #[tokio::test]
    async fn test_scalar_sub() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a - b;

        // Allocate the scalars in a fabric and subtract them
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = a_alloc - b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests negation of raw scalars in a circuit
    #[tokio::test]
    async fn test_scalar_neg() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);

        let expected_res = -a;

        // Allocate the scalar in a fabric and negate it
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);

        let res = -a_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests multiplication of raw scalars in a circuit
    #[tokio::test]
    async fn test_scalar_mul() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a * b;

        // Allocate the scalars in a fabric and multiply them together
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = a_alloc * b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests exponentiation of raw scalars in a circuit
    #[tokio::test]
    async fn test_exp() {
        let mut rng = thread_rng();
        let base = Scalar::<TestCurve>::random(&mut rng);
        let exp = rng.next_u64();

        // `exp` is already a `u64`, so it can be passed to `pow` directly
        let expected_res = base.inner().pow([exp]);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            let base_allocated = fabric.allocate_scalar(base);
            let res = base_allocated.pow(exp);
            res.await
        })
        .await;

        assert_eq!(res, Scalar::new(expected_res));
    }

    /// Tests fft of scalars allocated in a circuit
    #[tokio::test]
    async fn test_circuit_fft() {
        let mut rng = thread_rng();
        let n: usize = rng.gen_range(1..=100);
        let domain_size = rng.gen_range(n..10 * n);

        let seq = (0..n)
            .map(|_| Scalar::<TestCurve>::random(&mut rng))
            .collect_vec();

        let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
        let fft_res = domain.fft(&seq.iter().map(|s| s.inner()).collect_vec());
        let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec();

        let (res, _) = execute_mock_mpc(|fabric| {
            let seq = seq.clone();
            async move {
                let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec();

                let res = ScalarResult::fft_with_domain::<Radix2EvaluationDomain<TestPolyField>>(
                    &seq_alloc, domain,
                );
                future::join_all(res.into_iter()).await
            }
        })
        .await;

        assert_eq!(res.len(), expected_res.len());
        assert_eq!(res, expected_res);
    }

    /// Tests the ifft of scalars allocated in a circuit
    #[tokio::test]
    async fn test_circuit_ifft() {
        let mut rng = thread_rng();
        let n: usize = rng.gen_range(1..=100);
        let domain_size = rng.gen_range(n..10 * n);

        let seq = (0..n)
            .map(|_| Scalar::<TestCurve>::random(&mut rng))
            .collect_vec();

        let domain = Radix2EvaluationDomain::<TestPolyField>::new(domain_size).unwrap();
        let ifft_res = domain.ifft(&seq.iter().map(|s| s.inner()).collect_vec());
        let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec();

        let (res, _) = execute_mock_mpc(|fabric| {
            let seq = seq.clone();
            async move {
                let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec();

                let res = ScalarResult::ifft_with_domain::<Radix2EvaluationDomain<TestPolyField>>(
                    &seq_alloc, domain,
                );
                future::join_all(res.into_iter()).await
            }
        })
        .await;

        assert_eq!(res.len(), expected_res.len());
        assert_eq!(res, expected_res);
    }
}