ark_mpc/algebra/scalar/
mpc_scalar.rs

//! Defines `MpcScalarResult`, an unauthenticated, additively secret-shared
//! scalar type which forms the basis of the authenticated scalar type
3
4use std::ops::{Add, Mul, Neg, Sub};
5
6use ark_ec::CurveGroup;
7use itertools::Itertools;
8
9use crate::{
10    algebra::macros::*,
11    algebra::BatchScalarResult,
12    algebra::{CurvePoint, CurvePointResult, MpcPointResult},
13    fabric::{MpcFabric, ResultValue},
14    network::NetworkPayload,
15    PARTY0,
16};
17
18use super::scalar::{Scalar, ScalarResult};
19
20/// Defines a secret shared type over the `Scalar` field
21#[derive(Clone, Debug)]
22pub struct MpcScalarResult<C: CurveGroup> {
23    /// The underlying value held by the local party
24    pub(crate) share: ScalarResult<C>,
25}
26
27impl<C: CurveGroup> From<ScalarResult<C>> for MpcScalarResult<C> {
28    fn from(share: ScalarResult<C>) -> Self {
29        Self { share }
30    }
31}
32
33/// Defines the result handle type that represents a future result of an
34/// `MpcScalar`
35impl<C: CurveGroup> MpcScalarResult<C> {
36    /// Creates an MPC scalar from a given underlying scalar assumed to be a
37    /// secret share
38    pub fn new_shared(value: ScalarResult<C>) -> MpcScalarResult<C> {
39        value.into()
40    }
41
42    /// Get the op-id of the underlying share
43    pub fn id(&self) -> usize {
44        self.share.id
45    }
46
47    /// Borrow the fabric that the result is allocated in
48    pub fn fabric(&self) -> &MpcFabric<C> {
49        self.share.fabric()
50    }
51
52    /// Open the value; both parties send their shares to the counterparty
53    pub fn open(&self) -> ScalarResult<C> {
54        // Party zero sends first then receives
55        let (val0, val1) = if self.fabric().party_id() == PARTY0 {
56            let party0_value: ScalarResult<C> =
57                self.fabric().new_network_op(vec![self.id()], |args| {
58                    let share: Scalar<C> = args[0].to_owned().into();
59                    NetworkPayload::Scalar(share)
60                });
61            let party1_value: ScalarResult<C> = self.fabric().receive_value();
62
63            (party0_value, party1_value)
64        } else {
65            let party0_value: ScalarResult<C> = self.fabric().receive_value();
66            let party1_value: ScalarResult<C> =
67                self.fabric().new_network_op(vec![self.id()], |args| {
68                    let share = args[0].to_owned().into();
69                    NetworkPayload::Scalar(share)
70                });
71
72            (party0_value, party1_value)
73        };
74
75        // Create the new value by combining the additive shares
76        &val0 + &val1
77    }
78
79    /// Open a batch of values
80    pub fn open_batch(values: &[MpcScalarResult<C>]) -> Vec<ScalarResult<C>> {
81        if values.is_empty() {
82            return vec![];
83        }
84
85        let n = values.len();
86        let fabric = &values[0].fabric();
87        let my_results = values.iter().map(|v| v.id()).collect_vec();
88        let send_shares_fn = |args: Vec<ResultValue<C>>| {
89            let shares: Vec<Scalar<C>> = args.into_iter().map(Scalar::from).collect();
90            NetworkPayload::ScalarBatch(shares)
91        };
92
93        // Party zero sends first then receives
94        let (party0_vals, party1_vals) = if values[0].fabric().party_id() == PARTY0 {
95            // Send the local shares
96            let party0_vals: BatchScalarResult<C> =
97                fabric.new_network_op(my_results, send_shares_fn);
98            let party1_vals: BatchScalarResult<C> = fabric.receive_value();
99
100            (party0_vals, party1_vals)
101        } else {
102            let party0_vals: BatchScalarResult<C> = fabric.receive_value();
103            let party1_vals: BatchScalarResult<C> =
104                fabric.new_network_op(my_results, send_shares_fn);
105
106            (party0_vals, party1_vals)
107        };
108
109        // Create the new values by combining the additive shares
110        fabric.new_batch_gate_op(vec![party0_vals.id, party1_vals.id], n, move |args| {
111            let party0_vals: Vec<Scalar<C>> = args[0].to_owned().into();
112            let party1_vals: Vec<Scalar<C>> = args[1].to_owned().into();
113
114            let mut results = Vec::with_capacity(n);
115            for i in 0..n {
116                results.push(ResultValue::Scalar(party0_vals[i] + party1_vals[i]));
117            }
118
119            results
120        })
121    }
122
123    /// Convert the underlying value to a `Scalar`
124    pub fn to_scalar(&self) -> ScalarResult<C> {
125        self.share.clone()
126    }
127}
128
129// --------------
130// | Arithmetic |
131// --------------
132
133// === Addition === //
134
135impl<C: CurveGroup> Add<&Scalar<C>> for &MpcScalarResult<C> {
136    type Output = MpcScalarResult<C>;
137
138    // Only party 0 adds the plaintext value as we do not secret share it
139    fn add(self, rhs: &Scalar<C>) -> Self::Output {
140        let rhs = *rhs;
141        let party_id = self.fabric().party_id();
142
143        self.fabric()
144            .new_gate_op(vec![self.id()], move |args| {
145                // Cast the args
146                let lhs_share: Scalar<C> = args[0].to_owned().into();
147                if party_id == PARTY0 {
148                    ResultValue::Scalar(lhs_share + rhs)
149                } else {
150                    ResultValue::Scalar(lhs_share)
151                }
152            })
153            .into()
154    }
155}
156impl_borrow_variants!(MpcScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
157impl_commutative!(MpcScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
158
159impl<C: CurveGroup> Add<&ScalarResult<C>> for &MpcScalarResult<C> {
160    type Output = MpcScalarResult<C>;
161
162    // Only party 0 adds the plaintext value as we do not secret share it
163    fn add(self, rhs: &ScalarResult<C>) -> Self::Output {
164        let party_id = self.fabric().party_id();
165        self.fabric()
166            .new_gate_op(vec![self.id(), rhs.id], move |mut args| {
167                // Cast the args
168                let lhs: Scalar<C> = args.remove(0).into();
169                let rhs: Scalar<C> = args.remove(0).into();
170
171                if party_id == PARTY0 {
172                    ResultValue::Scalar(lhs + rhs)
173                } else {
174                    ResultValue::Scalar(lhs)
175                }
176            })
177            .into()
178    }
179}
180impl_borrow_variants!(MpcScalarResult<C>, Add, add, +, ScalarResult<C>, C: CurveGroup);
181impl_commutative!(MpcScalarResult<C>, Add, add, +, ScalarResult<C>, C: CurveGroup);
182
183impl<C: CurveGroup> Add<&MpcScalarResult<C>> for &MpcScalarResult<C> {
184    type Output = MpcScalarResult<C>;
185
186    fn add(self, rhs: &MpcScalarResult<C>) -> Self::Output {
187        self.fabric()
188            .new_gate_op(vec![self.id(), rhs.id()], |args| {
189                // Cast the args
190                let lhs: Scalar<C> = args[0].to_owned().into();
191                let rhs: Scalar<C> = args[1].to_owned().into();
192
193                ResultValue::Scalar(lhs + rhs)
194            })
195            .into()
196    }
197}
198impl_borrow_variants!(MpcScalarResult<C>, Add, add, +, MpcScalarResult<C>, C: CurveGroup);
199
200impl<C: CurveGroup> MpcScalarResult<C> {
201    /// Add two batches of `MpcScalarResult`s using a single batched gate
202    pub fn batch_add(
203        a: &[MpcScalarResult<C>],
204        b: &[MpcScalarResult<C>],
205    ) -> Vec<MpcScalarResult<C>> {
206        assert_eq!(
207            a.len(),
208            b.len(),
209            "batch_add: a and b must be the same length"
210        );
211
212        let n = a.len();
213        let fabric = a[0].fabric();
214        let ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
215
216        let scalars = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
217            // Split the args
218            let scalars = args.into_iter().map(Scalar::from).collect_vec();
219            let (a_res, b_res) = scalars.split_at(n);
220
221            // Add the values
222            a_res
223                .iter()
224                .zip(b_res.iter())
225                .map(|(a, b)| ResultValue::Scalar(a + b))
226                .collect_vec()
227        });
228
229        scalars.into_iter().map(|s| s.into()).collect_vec()
230    }
231
232    /// Add a batch of `MpcScalarResult`s to a batch of public `ScalarResult`s
233    pub fn batch_add_public(
234        a: &[MpcScalarResult<C>],
235        b: &[ScalarResult<C>],
236    ) -> Vec<MpcScalarResult<C>> {
237        assert_eq!(
238            a.len(),
239            b.len(),
240            "batch_add_public: a and b must be the same length"
241        );
242
243        let n = a.len();
244        let fabric = a[0].fabric();
245        let ids = a
246            .iter()
247            .map(|v| v.id())
248            .chain(b.iter().map(|v| v.id()))
249            .collect_vec();
250
251        let party_id = fabric.party_id();
252        let scalars: Vec<ScalarResult<C>> =
253            fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
254                if party_id == PARTY0 {
255                    let mut res: Vec<ResultValue<C>> = Vec::with_capacity(n);
256
257                    for i in 0..n {
258                        let lhs: Scalar<C> = args[i].to_owned().into();
259                        let rhs: Scalar<C> = args[i + n].to_owned().into();
260
261                        res.push(ResultValue::Scalar(lhs + rhs));
262                    }
263
264                    res
265                } else {
266                    args[..n].to_vec()
267                }
268            });
269
270        scalars.into_iter().map(|s| s.into()).collect_vec()
271    }
272}
273
274// === Subtraction === //
275
276impl<C: CurveGroup> Sub<&Scalar<C>> for &MpcScalarResult<C> {
277    type Output = MpcScalarResult<C>;
278
279    // Only party 0 subtracts the plaintext value as we do not secret share it
280    fn sub(self, rhs: &Scalar<C>) -> Self::Output {
281        let rhs = *rhs;
282        let party_id = self.fabric().party_id();
283
284        if party_id == PARTY0 {
285            &self.share - rhs
286        } else {
287            // Party 1 must perform an operation to keep the result queues in sync
288            &self.share - Scalar::zero()
289        }
290        .into()
291    }
292}
293impl_borrow_variants!(MpcScalarResult<C>, Sub, sub, -, Scalar<C>, C: CurveGroup);
294
295impl<C: CurveGroup> Sub<&MpcScalarResult<C>> for &Scalar<C> {
296    type Output = MpcScalarResult<C>;
297
298    // Only party 0 subtracts the plaintext value as we do not secret share it
299    fn sub(self, rhs: &MpcScalarResult<C>) -> Self::Output {
300        let party_id = rhs.fabric().party_id();
301
302        if party_id == PARTY0 {
303            self - &rhs.share
304        } else {
305            // Party 1 must perform an operation to keep the result queues in sync
306            Scalar::zero() - &rhs.share
307        }
308        .into()
309    }
310}
311
312impl<C: CurveGroup> Sub<&ScalarResult<C>> for &MpcScalarResult<C> {
313    type Output = MpcScalarResult<C>;
314
315    // Only party 0 subtracts the plaintext value as we do not secret share it
316    fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
317        let party_id = self.fabric().party_id();
318
319        if party_id == PARTY0 {
320            &self.share - rhs
321        } else {
322            // Party 1 must perform an operation to keep the result queues in sync
323            self.share.clone() + Scalar::zero()
324        }
325        .into()
326    }
327}
328impl_borrow_variants!(MpcScalarResult<C>, Sub, sub, -, ScalarResult<C>, C: CurveGroup);
329
330impl<C: CurveGroup> Sub<&MpcScalarResult<C>> for &ScalarResult<C> {
331    type Output = MpcScalarResult<C>;
332
333    // Only party 0 subtracts the plaintext value as we do not secret share it
334    fn sub(self, rhs: &MpcScalarResult<C>) -> Self::Output {
335        let party_id = rhs.fabric().party_id();
336
337        if party_id == PARTY0 {
338            self - &rhs.share
339        } else {
340            // Party 1 must perform an operation to keep the result queues in sync
341            Scalar::zero() - rhs.share.clone()
342        }
343        .into()
344    }
345}
346impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, MpcScalarResult<C>, Output=MpcScalarResult<C>, C: CurveGroup);
347
348impl<C: CurveGroup> Sub<&MpcScalarResult<C>> for &MpcScalarResult<C> {
349    type Output = MpcScalarResult<C>;
350
351    fn sub(self, rhs: &MpcScalarResult<C>) -> Self::Output {
352        self.fabric()
353            .new_gate_op(vec![self.id(), rhs.id()], |args| {
354                // Cast the args
355                let lhs: Scalar<C> = args[0].to_owned().into();
356                let rhs: Scalar<C> = args[1].to_owned().into();
357
358                ResultValue::Scalar(lhs - rhs)
359            })
360            .into()
361    }
362}
363impl_borrow_variants!(MpcScalarResult<C>, Sub, sub, -, MpcScalarResult<C>, C: CurveGroup);
364
365impl<C: CurveGroup> MpcScalarResult<C> {
366    /// Subtract two batches of `MpcScalarResult`s using a single batched gate
367    pub fn batch_sub(
368        a: &[MpcScalarResult<C>],
369        b: &[MpcScalarResult<C>],
370    ) -> Vec<MpcScalarResult<C>> {
371        assert_eq!(
372            a.len(),
373            b.len(),
374            "batch_sub: a and b must be the same length"
375        );
376
377        let n = a.len();
378        let fabric = a[0].fabric();
379        let ids = a
380            .iter()
381            .map(|v| v.id())
382            .chain(b.iter().map(|v| v.id()))
383            .collect_vec();
384
385        let scalars: Vec<ScalarResult<C>> =
386            fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
387                // Split the args
388                let scalars = args.into_iter().map(Scalar::from).collect_vec();
389                let (a_res, b_res) = scalars.split_at(n);
390
391                // Add the values
392                a_res
393                    .iter()
394                    .zip(b_res.iter())
395                    .map(|(a, b)| ResultValue::Scalar(a - b))
396                    .collect_vec()
397            });
398
399        scalars.into_iter().map(|s| s.into()).collect_vec()
400    }
401
402    /// Subtract a batch of `MpcScalarResult`s from a batch of public
403    /// `ScalarResult`s
404    pub fn batch_sub_public(
405        a: &[MpcScalarResult<C>],
406        b: &[ScalarResult<C>],
407    ) -> Vec<MpcScalarResult<C>> {
408        assert_eq!(
409            a.len(),
410            b.len(),
411            "batch_sub_public: a and b must be the same length"
412        );
413
414        let n = a.len();
415        let fabric = a[0].fabric();
416        let ids = a
417            .iter()
418            .map(|v| v.id())
419            .chain(b.iter().map(|v| v.id()))
420            .collect_vec();
421
422        let party_id = fabric.party_id();
423        let scalars = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
424            if party_id == PARTY0 {
425                let mut res: Vec<ResultValue<C>> = Vec::with_capacity(n);
426
427                for i in 0..n {
428                    let lhs: Scalar<C> = args[i].to_owned().into();
429                    let rhs: Scalar<C> = args[i + n].to_owned().into();
430
431                    res.push(ResultValue::Scalar(lhs - rhs));
432                }
433
434                res
435            } else {
436                args[..n].to_vec()
437            }
438        });
439
440        scalars.into_iter().map(|s| s.into()).collect_vec()
441    }
442}
443
444// === Negation === //
445
446impl<C: CurveGroup> Neg for &MpcScalarResult<C> {
447    type Output = MpcScalarResult<C>;
448
449    fn neg(self) -> Self::Output {
450        self.fabric()
451            .new_gate_op(vec![self.id()], |args| {
452                // Cast the args
453                let lhs: Scalar<C> = args[0].to_owned().into();
454                ResultValue::Scalar(-lhs)
455            })
456            .into()
457    }
458}
459impl_borrow_variants!(MpcScalarResult<C>, Neg, neg, -, C: CurveGroup);
460
461impl<C: CurveGroup> MpcScalarResult<C> {
462    /// Negate a batch of `MpcScalarResult<C>`s using a single batched gate
463    pub fn batch_neg(values: &[MpcScalarResult<C>]) -> Vec<MpcScalarResult<C>> {
464        if values.is_empty() {
465            return vec![];
466        }
467
468        let n = values.len();
469        let fabric = values[0].fabric();
470        let ids = values.iter().map(|v| v.id()).collect_vec();
471
472        let scalars = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
473            // Split the args
474            let scalars = args.into_iter().map(Scalar::from).collect_vec();
475
476            // Add the values
477            scalars
478                .iter()
479                .map(|a| ResultValue::Scalar(-a))
480                .collect_vec()
481        });
482
483        scalars.into_iter().map(|s| s.into()).collect_vec()
484    }
485}
486
487// === Multiplication === //
488
489impl<C: CurveGroup> Mul<&Scalar<C>> for &MpcScalarResult<C> {
490    type Output = MpcScalarResult<C>;
491
492    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
493        let rhs = *rhs;
494        self.fabric()
495            .new_gate_op(vec![self.id()], move |args| {
496                // Cast the args
497                let lhs: Scalar<C> = args[0].to_owned().into();
498                ResultValue::Scalar(lhs * rhs)
499            })
500            .into()
501    }
502}
503impl_borrow_variants!(MpcScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
504impl_commutative!(MpcScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
505
506impl<C: CurveGroup> Mul<&ScalarResult<C>> for &MpcScalarResult<C> {
507    type Output = MpcScalarResult<C>;
508
509    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
510        self.fabric()
511            .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
512                // Cast the args
513                let lhs: Scalar<C> = args.remove(0).into();
514                let rhs: Scalar<C> = args.remove(0).into();
515
516                ResultValue::Scalar(lhs * rhs)
517            })
518            .into()
519    }
520}
521impl_borrow_variants!(MpcScalarResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
522impl_commutative!(MpcScalarResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
523
524/// Use the beaver trick if both values are shared
525impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &MpcScalarResult<C> {
526    type Output = MpcScalarResult<C>;
527
528    fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
529        // Sample a beaver triplet
530        let (a, b, c) = self.fabric().next_beaver_triple();
531
532        // Open the values d = [lhs - a] and e = [rhs - b]
533        let masked_lhs = self - &a;
534        let masked_rhs = rhs - &b;
535
536        let d_open = masked_lhs.open();
537        let e_open = masked_rhs.open();
538
539        // Identity: [x * y] = de + d[b] + e[a] + [c]
540        &d_open * &b + &e_open * &a + c + &d_open * &e_open
541    }
542}
543impl_borrow_variants!(MpcScalarResult<C>, Mul, mul, *, MpcScalarResult<C>, C: CurveGroup);
544
545impl<C: CurveGroup> MpcScalarResult<C> {
546    /// Multiply a batch of `MpcScalarResult`s over a single network op
547    pub fn batch_mul(
548        a: &[MpcScalarResult<C>],
549        b: &[MpcScalarResult<C>],
550    ) -> Vec<MpcScalarResult<C>> {
551        let n = a.len();
552        assert_eq!(
553            a.len(),
554            b.len(),
555            "batch_mul: a and b must be the same length"
556        );
557
558        // Sample a beaver triplet for each multiplication
559        let fabric = &a[0].fabric();
560        let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
561
562        // Open the values d = [lhs - a] and e = [rhs - b]
563        let masked_lhs = MpcScalarResult::batch_sub(a, &beaver_a);
564        let masked_rhs = MpcScalarResult::batch_sub(b, &beaver_b);
565
566        let all_masks = [masked_lhs, masked_rhs].concat();
567        let opened_values = MpcScalarResult::open_batch(&all_masks);
568        let (d_open, e_open) = opened_values.split_at(n);
569
570        // Identity: [x * y] = de + d[b] + e[a] + [c]
571        let de = ScalarResult::batch_mul(d_open, e_open);
572        let db = MpcScalarResult::batch_mul_public(&beaver_b, d_open);
573        let ea = MpcScalarResult::batch_mul_public(&beaver_a, e_open);
574
575        // Add the terms
576        let de_plus_db = MpcScalarResult::batch_add_public(&db, &de);
577        let ea_plus_c = MpcScalarResult::batch_add(&ea, &beaver_c);
578        MpcScalarResult::batch_add(&de_plus_db, &ea_plus_c)
579    }
580
581    /// Multiply a batch of `MpcScalarResult`s by a batch of public
582    /// `ScalarResult`s
583    pub fn batch_mul_public(
584        a: &[MpcScalarResult<C>],
585        b: &[ScalarResult<C>],
586    ) -> Vec<MpcScalarResult<C>> {
587        assert_eq!(
588            a.len(),
589            b.len(),
590            "batch_mul_public: a and b must be the same length"
591        );
592
593        let n = a.len();
594        let fabric = a[0].fabric();
595        let ids = a
596            .iter()
597            .map(|v| v.id())
598            .chain(b.iter().map(|v| v.id))
599            .collect_vec();
600
601        let scalars: Vec<ScalarResult<C>> =
602            fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
603                let mut res: Vec<ResultValue<C>> = Vec::with_capacity(n);
604                for i in 0..n {
605                    let lhs: Scalar<C> = args[i].to_owned().into();
606                    let rhs: Scalar<C> = args[i + n].to_owned().into();
607
608                    res.push(ResultValue::Scalar(lhs * rhs));
609                }
610
611                res
612            });
613
614        scalars.into_iter().map(|s| s.into()).collect_vec()
615    }
616}
617
618// === Curve Scalar Multiplication === //
619
620impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &CurvePoint<C> {
621    type Output = MpcPointResult<C>;
622
623    fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
624        let self_owned = *self;
625        rhs.fabric()
626            .new_gate_op(vec![rhs.id()], move |mut args| {
627                let rhs: Scalar<C> = args.remove(0).into();
628
629                ResultValue::Point(self_owned * rhs)
630            })
631            .into()
632    }
633}
634impl_commutative!(CurvePoint<C>, Mul, mul, *, MpcScalarResult<C>, Output=MpcPointResult<C>, C: CurveGroup);
635
636impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &CurvePointResult<C> {
637    type Output = MpcPointResult<C>;
638
639    fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
640        self.fabric
641            .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
642                let lhs: CurvePoint<C> = args.remove(0).into();
643                let rhs: Scalar<C> = args.remove(0).into();
644
645                ResultValue::Point(lhs * rhs)
646            })
647            .into()
648    }
649}
650impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, MpcScalarResult<C>, Output=MpcPointResult<C>, C: CurveGroup);
651impl_commutative!(CurvePointResult<C>, Mul, mul, *, MpcScalarResult<C>, Output=MpcPointResult<C>, C: CurveGroup);
652
#[cfg(test)]
mod test {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Test subtraction with a non-commutative pair of types
    ///
    /// Exercises both `shared - public` and `public - shared`, comparing each
    /// opened result against the plaintext difference
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let shared_plain = Scalar::random(&mut rng);
        let public_plain = Scalar::random(&mut rng);

        let (checks, _) = execute_mock_mpc(|fabric| async move {
            // Party 0 secret shares the first value; the second is allocated publicly
            let shared = fabric.share_scalar(shared_plain, PARTY0).mpc_share();
            let public = fabric.allocate_scalar(public_plain);

            // shared - public
            let shared_minus_public = &shared - &public;
            let opened1 = shared_minus_public.open().await;

            // public - shared
            let public_minus_shared = &public - &shared;
            let opened2 = public_minus_shared.open().await;

            (
                opened1 == shared_plain - public_plain,
                opened2 == public_plain - shared_plain,
            )
        })
        .await;

        let (shared_minus_public_ok, public_minus_shared_ok) = checks;
        assert!(shared_minus_public_ok);
        assert!(public_minus_shared_ok);
    }
}