// mpc_stark/algebra/mpc_scalar.rs

//! Defines an unauthenticated shared scalar type which forms the basis of the
//! authenticated scalar type.

use std::ops::{Add, Mul, Neg, Sub};

use itertools::Itertools;

use crate::{
    algebra::scalar::BatchScalarResult,
    fabric::{MpcFabric, ResultHandle, ResultValue},
    network::NetworkPayload,
    PARTY0,
};

use super::{
    macros::{impl_borrow_variants, impl_commutative},
    mpc_stark_point::MpcStarkPointResult,
    scalar::{Scalar, ScalarResult},
    stark_curve::{StarkPoint, StarkPointResult},
};

/// Defines a secret shared type over the `Scalar` field
///
/// Each party holds one additive share of the underlying value; the plaintext
/// is recovered by summing both parties' shares (see `open`).
#[derive(Clone, Debug)]
pub struct MpcScalarResult {
    /// The underlying value held by the local party, i.e. this party's share
    pub(crate) share: ScalarResult,
}

29impl From<ScalarResult> for MpcScalarResult {
30    fn from(share: ScalarResult) -> Self {
31        Self { share }
32    }
33}
34
35/// Defines the result handle type that represents a future result of an `MpcScalarResult`
36impl MpcScalarResult {
37    /// Creates an MPC scalar from a given underlying scalar assumed to be a secret share
38    pub fn new_shared(value: ScalarResult) -> MpcScalarResult {
39        value.into()
40    }
41
42    /// Get the op-id of the underlying share
43    pub fn id(&self) -> usize {
44        self.share.id
45    }
46
47    /// Borrow the fabric that the result is allocated in
48    pub fn fabric(&self) -> &MpcFabric {
49        self.share.fabric()
50    }
51
52    /// Open the value; both parties send their shares to the counterparty
53    pub fn open(&self) -> ResultHandle<Scalar> {
54        // Party zero sends first then receives
55        let (val0, val1) = if self.fabric().party_id() == PARTY0 {
56            let party0_value: ResultHandle<Scalar> =
57                self.fabric().new_network_op(vec![self.id()], |args| {
58                    let share: Scalar = args[0].to_owned().into();
59                    NetworkPayload::Scalar(share)
60                });
61            let party1_value: ResultHandle<Scalar> = self.fabric().receive_value();
62
63            (party0_value, party1_value)
64        } else {
65            let party0_value: ResultHandle<Scalar> = self.fabric().receive_value();
66            let party1_value: ResultHandle<Scalar> =
67                self.fabric().new_network_op(vec![self.id()], |args| {
68                    let share = args[0].to_owned().into();
69                    NetworkPayload::Scalar(share)
70                });
71
72            (party0_value, party1_value)
73        };
74
75        // Create the new value by combining the additive shares
76        &val0 + &val1
77    }
78
79    /// Open a batch of values
80    pub fn open_batch(values: &[MpcScalarResult]) -> Vec<ScalarResult> {
81        if values.is_empty() {
82            return vec![];
83        }
84
85        let n = values.len();
86        let fabric = &values[0].fabric();
87        let my_results = values.iter().map(|v| v.id()).collect_vec();
88        let send_shares_fn = |args: Vec<ResultValue>| {
89            let shares: Vec<Scalar> = args.into_iter().map(Scalar::from).collect();
90            NetworkPayload::ScalarBatch(shares)
91        };
92
93        // Party zero sends first then receives
94        let (party0_vals, party1_vals) = if values[0].fabric().party_id() == PARTY0 {
95            // Send the local shares
96            let party0_vals: BatchScalarResult = fabric.new_network_op(my_results, send_shares_fn);
97            let party1_vals: BatchScalarResult = fabric.receive_value();
98
99            (party0_vals, party1_vals)
100        } else {
101            let party0_vals: BatchScalarResult = fabric.receive_value();
102            let party1_vals: BatchScalarResult = fabric.new_network_op(my_results, send_shares_fn);
103
104            (party0_vals, party1_vals)
105        };
106
107        // Create the new values by combining the additive shares
108        fabric.new_batch_gate_op(vec![party0_vals.id, party1_vals.id], n, move |args| {
109            let party0_vals: Vec<Scalar> = args[0].to_owned().into();
110            let party1_vals: Vec<Scalar> = args[1].to_owned().into();
111
112            let mut results = Vec::with_capacity(n);
113            for i in 0..n {
114                results.push(ResultValue::Scalar(party0_vals[i] + party1_vals[i]));
115            }
116
117            results
118        })
119    }
120
121    /// Convert the underlying value to a `Scalar`
122    pub fn to_scalar(&self) -> ScalarResult {
123        self.share.clone()
124    }
125}
126
127// --------------
128// | Arithmetic |
129// --------------
130
131// === Addition === //
132
133impl Add<&Scalar> for &MpcScalarResult {
134    type Output = MpcScalarResult;
135
136    // Only party 0 adds the plaintext value as we do not secret share it
137    fn add(self, rhs: &Scalar) -> Self::Output {
138        let rhs = *rhs;
139        let party_id = self.fabric().party_id();
140
141        self.fabric()
142            .new_gate_op(vec![self.id()], move |args| {
143                // Cast the args
144                let lhs_share: Scalar = args[0].to_owned().into();
145                if party_id == PARTY0 {
146                    ResultValue::Scalar(lhs_share + rhs)
147                } else {
148                    ResultValue::Scalar(lhs_share)
149                }
150            })
151            .into()
152    }
153}
154impl_borrow_variants!(MpcScalarResult, Add, add, +, Scalar);
155impl_commutative!(MpcScalarResult, Add, add, +, Scalar);
156
157impl Add<&ScalarResult> for &MpcScalarResult {
158    type Output = MpcScalarResult;
159
160    // Only party 0 adds the plaintext value as we do not secret share it
161    fn add(self, rhs: &ScalarResult) -> Self::Output {
162        let party_id = self.fabric().party_id();
163        self.fabric()
164            .new_gate_op(vec![self.id(), rhs.id], move |mut args| {
165                // Cast the args
166                let lhs: Scalar = args.remove(0).into();
167                let rhs: Scalar = args.remove(0).into();
168
169                if party_id == PARTY0 {
170                    ResultValue::Scalar(lhs + rhs)
171                } else {
172                    ResultValue::Scalar(lhs)
173                }
174            })
175            .into()
176    }
177}
178impl_borrow_variants!(MpcScalarResult, Add, add, +, ScalarResult);
179impl_commutative!(MpcScalarResult, Add, add, +, ScalarResult);
180
181impl Add<&MpcScalarResult> for &MpcScalarResult {
182    type Output = MpcScalarResult;
183
184    fn add(self, rhs: &MpcScalarResult) -> Self::Output {
185        self.fabric()
186            .new_gate_op(vec![self.id(), rhs.id()], |args| {
187                // Cast the args
188                let lhs: Scalar = args[0].to_owned().into();
189                let rhs: Scalar = args[1].to_owned().into();
190
191                ResultValue::Scalar(lhs + rhs)
192            })
193            .into()
194    }
195}
196impl_borrow_variants!(MpcScalarResult, Add, add, +, MpcScalarResult);
197
198impl MpcScalarResult {
199    /// Add two batches of `MpcScalarResult`s using a single batched gate
200    pub fn batch_add(a: &[MpcScalarResult], b: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
201        assert_eq!(
202            a.len(),
203            b.len(),
204            "batch_add: a and b must be the same length"
205        );
206
207        let n = a.len();
208        let fabric = a[0].fabric();
209        let ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
210
211        let scalars = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
212            // Split the args
213            let scalars = args.into_iter().map(Scalar::from).collect_vec();
214            let (a_res, b_res) = scalars.split_at(n);
215
216            // Add the values
217            a_res
218                .iter()
219                .zip(b_res.iter())
220                .map(|(a, b)| ResultValue::Scalar(a + b))
221                .collect_vec()
222        });
223
224        scalars.into_iter().map(|s| s.into()).collect_vec()
225    }
226
227    /// Add a batch of `MpcScalarResult`s to a batch of public `ScalarResult`s
228    pub fn batch_add_public(a: &[MpcScalarResult], b: &[ScalarResult]) -> Vec<MpcScalarResult> {
229        assert_eq!(
230            a.len(),
231            b.len(),
232            "batch_add_public: a and b must be the same length"
233        );
234
235        let n = a.len();
236        let fabric = a[0].fabric();
237        let ids = a
238            .iter()
239            .map(|v| v.id())
240            .chain(b.iter().map(|v| v.id()))
241            .collect_vec();
242
243        let party_id = fabric.party_id();
244        let scalars: Vec<ScalarResult> =
245            fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
246                if party_id == PARTY0 {
247                    let mut res: Vec<ResultValue> = Vec::with_capacity(n);
248
249                    for i in 0..n {
250                        let lhs: Scalar = args[i].to_owned().into();
251                        let rhs: Scalar = args[i + n].to_owned().into();
252
253                        res.push(ResultValue::Scalar(lhs + rhs));
254                    }
255
256                    res
257                } else {
258                    args[..n].to_vec()
259                }
260            });
261
262        scalars.into_iter().map(|s| s.into()).collect_vec()
263    }
264}
265
266// === Subtraction === //
267
268impl Sub<&Scalar> for &MpcScalarResult {
269    type Output = MpcScalarResult;
270
271    // Only party 0 subtracts the plaintext value as we do not secret share it
272    fn sub(self, rhs: &Scalar) -> Self::Output {
273        let rhs = *rhs;
274        let party_id = self.fabric().party_id();
275
276        if party_id == PARTY0 {
277            &self.share - rhs
278        } else {
279            // Party 1 must perform an operation to keep the result queues in sync
280            &self.share - Scalar::zero()
281        }
282        .into()
283    }
284}
285impl_borrow_variants!(MpcScalarResult, Sub, sub, -, Scalar);
286
287impl Sub<&MpcScalarResult> for &Scalar {
288    type Output = MpcScalarResult;
289
290    // Only party 0 subtracts the plaintext value as we do not secret share it
291    fn sub(self, rhs: &MpcScalarResult) -> Self::Output {
292        let party_id = rhs.fabric().party_id();
293
294        if party_id == PARTY0 {
295            self - &rhs.share
296        } else {
297            // Party 1 must perform an operation to keep the result queues in sync
298            Scalar::zero() - &rhs.share
299        }
300        .into()
301    }
302}
303
304impl Sub<&ScalarResult> for &MpcScalarResult {
305    type Output = MpcScalarResult;
306
307    // Only party 0 subtracts the plaintext value as we do not secret share it
308    fn sub(self, rhs: &ScalarResult) -> Self::Output {
309        let party_id = self.fabric().party_id();
310
311        if party_id == PARTY0 {
312            &self.share - rhs
313        } else {
314            // Party 1 must perform an operation to keep the result queues in sync
315            self.share.clone() + Scalar::zero()
316        }
317        .into()
318    }
319}
320impl_borrow_variants!(MpcScalarResult, Sub, sub, -, ScalarResult);
321
322impl Sub<&MpcScalarResult> for &ScalarResult {
323    type Output = MpcScalarResult;
324
325    // Only party 0 subtracts the plaintext value as we do not secret share it
326    fn sub(self, rhs: &MpcScalarResult) -> Self::Output {
327        let party_id = rhs.fabric().party_id();
328
329        if party_id == PARTY0 {
330            self - &rhs.share
331        } else {
332            // Party 1 must perform an operation to keep the result queues in sync
333            Scalar::zero() - rhs.share.clone()
334        }
335        .into()
336    }
337}
338impl_borrow_variants!(ScalarResult, Sub, sub, -, MpcScalarResult, Output=MpcScalarResult);
339
340impl Sub<&MpcScalarResult> for &MpcScalarResult {
341    type Output = MpcScalarResult;
342
343    fn sub(self, rhs: &MpcScalarResult) -> Self::Output {
344        self.fabric()
345            .new_gate_op(vec![self.id(), rhs.id()], |args| {
346                // Cast the args
347                let lhs: Scalar = args[0].to_owned().into();
348                let rhs: Scalar = args[1].to_owned().into();
349
350                ResultValue::Scalar(lhs - rhs)
351            })
352            .into()
353    }
354}
355impl_borrow_variants!(MpcScalarResult, Sub, sub, -, MpcScalarResult);
356
357impl MpcScalarResult {
358    /// Subtract two batches of `MpcScalarResult`s using a single batched gate
359    pub fn batch_sub(a: &[MpcScalarResult], b: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
360        assert_eq!(
361            a.len(),
362            b.len(),
363            "batch_sub: a and b must be the same length"
364        );
365
366        let n = a.len();
367        let fabric = a[0].fabric();
368        let ids = a
369            .iter()
370            .map(|v| v.id())
371            .chain(b.iter().map(|v| v.id()))
372            .collect_vec();
373
374        let scalars: Vec<ScalarResult> =
375            fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
376                // Split the args
377                let scalars = args.into_iter().map(Scalar::from).collect_vec();
378                let (a_res, b_res) = scalars.split_at(n);
379
380                // Add the values
381                a_res
382                    .iter()
383                    .zip(b_res.iter())
384                    .map(|(a, b)| ResultValue::Scalar(a - b))
385                    .collect_vec()
386            });
387
388        scalars.into_iter().map(|s| s.into()).collect_vec()
389    }
390
391    /// Subtract a batch of `MpcScalarResult`s from a batch of public `ScalarResult`s
392    pub fn batch_sub_public(a: &[MpcScalarResult], b: &[ScalarResult]) -> Vec<MpcScalarResult> {
393        assert_eq!(
394            a.len(),
395            b.len(),
396            "batch_sub_public: a and b must be the same length"
397        );
398
399        let n = a.len();
400        let fabric = a[0].fabric();
401        let ids = a
402            .iter()
403            .map(|v| v.id())
404            .chain(b.iter().map(|v| v.id()))
405            .collect_vec();
406
407        let party_id = fabric.party_id();
408        let scalars = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
409            if party_id == PARTY0 {
410                let mut res: Vec<ResultValue> = Vec::with_capacity(n);
411
412                for i in 0..n {
413                    let lhs: Scalar = args[i].to_owned().into();
414                    let rhs: Scalar = args[i + n].to_owned().into();
415
416                    res.push(ResultValue::Scalar(lhs - rhs));
417                }
418
419                res
420            } else {
421                args[..n].to_vec()
422            }
423        });
424
425        scalars.into_iter().map(|s| s.into()).collect_vec()
426    }
427}
428
429// === Negation === //
430
431impl Neg for &MpcScalarResult {
432    type Output = MpcScalarResult;
433
434    fn neg(self) -> Self::Output {
435        self.fabric()
436            .new_gate_op(vec![self.id()], |args| {
437                // Cast the args
438                let lhs: Scalar = args[0].to_owned().into();
439                ResultValue::Scalar(-lhs)
440            })
441            .into()
442    }
443}
444impl_borrow_variants!(MpcScalarResult, Neg, neg, -);
445
446impl MpcScalarResult {
447    /// Negate a batch of `MpcScalarResult`s using a single batched gate
448    pub fn batch_neg(values: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
449        if values.is_empty() {
450            return vec![];
451        }
452
453        let n = values.len();
454        let fabric = values[0].fabric();
455        let ids = values.iter().map(|v| v.id()).collect_vec();
456
457        let scalars = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
458            // Split the args
459            let scalars = args.into_iter().map(Scalar::from).collect_vec();
460
461            // Add the values
462            scalars
463                .iter()
464                .map(|a| ResultValue::Scalar(-a))
465                .collect_vec()
466        });
467
468        scalars.into_iter().map(|s| s.into()).collect_vec()
469    }
470}
471
472// === Multiplication === //
473
474impl Mul<&Scalar> for &MpcScalarResult {
475    type Output = MpcScalarResult;
476
477    fn mul(self, rhs: &Scalar) -> Self::Output {
478        let rhs = *rhs;
479        self.fabric()
480            .new_gate_op(vec![self.id()], move |args| {
481                // Cast the args
482                let lhs: Scalar = args[0].to_owned().into();
483                ResultValue::Scalar(lhs * rhs)
484            })
485            .into()
486    }
487}
488impl_borrow_variants!(MpcScalarResult, Mul, mul, *, Scalar);
489impl_commutative!(MpcScalarResult, Mul, mul, *, Scalar);
490
491impl Mul<&ScalarResult> for &MpcScalarResult {
492    type Output = MpcScalarResult;
493
494    fn mul(self, rhs: &ScalarResult) -> Self::Output {
495        self.fabric()
496            .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
497                // Cast the args
498                let lhs: Scalar = args.remove(0).into();
499                let rhs: Scalar = args.remove(0).into();
500
501                ResultValue::Scalar(lhs * rhs)
502            })
503            .into()
504    }
505}
506impl_borrow_variants!(MpcScalarResult, Mul, mul, *, ScalarResult);
507impl_commutative!(MpcScalarResult, Mul, mul, *, ScalarResult);
508
509/// Use the beaver trick if both values are shared
510impl Mul<&MpcScalarResult> for &MpcScalarResult {
511    type Output = MpcScalarResult;
512
513    fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
514        // Sample a beaver triplet
515        let (a, b, c) = self.fabric().next_beaver_triple();
516
517        // Open the values d = [lhs - a] and e = [rhs - b]
518        let masked_lhs = self - &a;
519        let masked_rhs = rhs - &b;
520
521        let d_open = masked_lhs.open();
522        let e_open = masked_rhs.open();
523
524        // Identity: [x * y] = de + d[b] + e[a] + [c]
525        &d_open * &b + &e_open * &a + c + &d_open * &e_open
526    }
527}
528impl_borrow_variants!(MpcScalarResult, Mul, mul, *, MpcScalarResult);
529
530impl MpcScalarResult {
531    /// Multiply a batch of `MpcScalarResults` over a single network op
532    pub fn batch_mul(a: &[MpcScalarResult], b: &[MpcScalarResult]) -> Vec<MpcScalarResult> {
533        let n = a.len();
534        assert_eq!(
535            a.len(),
536            b.len(),
537            "batch_mul: a and b must be the same length"
538        );
539
540        // Sample a beaver triplet for each multiplication
541        let fabric = &a[0].fabric();
542        let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
543
544        // Open the values d = [lhs - a] and e = [rhs - b]
545        let masked_lhs = MpcScalarResult::batch_sub(a, &beaver_a);
546        let masked_rhs = MpcScalarResult::batch_sub(b, &beaver_b);
547
548        let all_masks = [masked_lhs, masked_rhs].concat();
549        let opened_values = MpcScalarResult::open_batch(&all_masks);
550        let (d_open, e_open) = opened_values.split_at(n);
551
552        // Identity: [x * y] = de + d[b] + e[a] + [c]
553        let de = ScalarResult::batch_mul(d_open, e_open);
554        let db = MpcScalarResult::batch_mul_public(&beaver_b, d_open);
555        let ea = MpcScalarResult::batch_mul_public(&beaver_a, e_open);
556
557        // Add the terms
558        let de_plus_db = MpcScalarResult::batch_add_public(&db, &de);
559        let ea_plus_c = MpcScalarResult::batch_add(&ea, &beaver_c);
560        MpcScalarResult::batch_add(&de_plus_db, &ea_plus_c)
561    }
562
563    /// Multiply a batch of `MpcScalarResult`s by a batch of public `ScalarResult`s
564    pub fn batch_mul_public(a: &[MpcScalarResult], b: &[ScalarResult]) -> Vec<MpcScalarResult> {
565        assert_eq!(
566            a.len(),
567            b.len(),
568            "batch_mul_public: a and b must be the same length"
569        );
570
571        let n = a.len();
572        let fabric = a[0].fabric();
573        let ids = a
574            .iter()
575            .map(|v| v.id())
576            .chain(b.iter().map(|v| v.id))
577            .collect_vec();
578
579        let scalars: Vec<ScalarResult> =
580            fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
581                let mut res: Vec<ResultValue> = Vec::with_capacity(n);
582                for i in 0..n {
583                    let lhs: Scalar = args[i].to_owned().into();
584                    let rhs: Scalar = args[i + n].to_owned().into();
585
586                    res.push(ResultValue::Scalar(lhs * rhs));
587                }
588
589                res
590            });
591
592        scalars.into_iter().map(|s| s.into()).collect_vec()
593    }
594}
595
596// === Curve Scalar Multiplication === //
597
598impl Mul<&MpcScalarResult> for &StarkPoint {
599    type Output = MpcStarkPointResult;
600
601    fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
602        let self_owned = *self;
603        rhs.fabric()
604            .new_gate_op(vec![rhs.id()], move |mut args| {
605                let rhs: Scalar = args.remove(0).into();
606
607                ResultValue::Point(self_owned * rhs)
608            })
609            .into()
610    }
611}
612impl_commutative!(StarkPoint, Mul, mul, *, MpcScalarResult, Output=MpcStarkPointResult);
613
614impl Mul<&MpcScalarResult> for &StarkPointResult {
615    type Output = MpcStarkPointResult;
616
617    fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
618        self.fabric
619            .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
620                let lhs: StarkPoint = args.remove(0).into();
621                let rhs: Scalar = args.remove(0).into();
622
623                ResultValue::Point(lhs * rhs)
624            })
625            .into()
626    }
627}
628impl_borrow_variants!(StarkPointResult, Mul, mul, *, MpcScalarResult, Output=MpcStarkPointResult);
629impl_commutative!(StarkPointResult, Mul, mul, *, MpcScalarResult, Output=MpcStarkPointResult);
630
#[cfg(test)]
mod test {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Test subtraction with a non-commutative pair of types
    ///
    /// Checks both `[shared] - public` and `public - [shared]` against the
    /// plaintext result after opening.
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let shared_plain = Scalar::random(&mut rng);
        let public_plain = Scalar::random(&mut rng);

        let (outcome, _) = execute_mock_mpc(|fabric| async move {
            // Party 0 secret-shares its value; the second value is allocated publicly
            let shared = fabric.share_scalar(shared_plain, PARTY0).mpc_share();
            let public = fabric.allocate_scalar(public_plain);

            let shared_minus_public = (&shared - &public).open().await;
            let public_minus_shared = (&public - &shared).open().await;

            (
                shared_minus_public == shared_plain - public_plain,
                public_minus_shared == public_plain - shared_plain,
            )
        })
        .await;

        let (first_ok, second_ok) = outcome;
        assert!(first_ok);
        assert!(second_ok)
    }
}