ark_mpc/algebra/curve/mpc_curve.rs

//! Defines an unauthenticated shared curve point type which forms the basis
//! of the authenticated curve point type
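//!
//! A minimal usage sketch (not compiled; obtaining the local share is elided because
//! share allocation is an `MpcFabric` concern, and the allocation step below is only
//! a hypothetical placeholder):
//!
//! ```ignore
//! // Assume `my_share: CurvePointResult<C>` holds this party's additive share of a
//! // point, allocated through the fabric
//! let shared = MpcPointResult::new_shared(my_share);
//!
//! // Arithmetic composes lazily as gates in the fabric
//! let doubled = &shared + &shared;
//!
//! // Opening exchanges shares with the counterparty and sums them into a plaintext
//! // point result
//! let opened: CurvePointResult<C> = doubled.open();
//! ```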

use std::ops::{Add, Mul, Neg, Sub};

use ark_ec::CurveGroup;
use itertools::Itertools;

use crate::{
    algebra::macros::*, algebra::scalar::*, fabric::ResultValue, network::NetworkPayload,
    MpcFabric, ResultId, PARTY0,
};

use super::curve::{BatchCurvePointResult, CurvePoint, CurvePointResult};

/// Defines a secret shared type of a curve point
#[derive(Clone, Debug)]
pub struct MpcPointResult<C: CurveGroup> {
    /// The underlying value held by the local party
    pub(crate) share: CurvePointResult<C>,
}

impl<C: CurveGroup> From<CurvePointResult<C>> for MpcPointResult<C> {
    fn from(value: CurvePointResult<C>) -> Self {
        Self { share: value }
    }
}

/// Defines the result handle type that represents a future result of an
/// `MpcPoint`
impl<C: CurveGroup> MpcPointResult<C> {
    /// Creates an `MpcPoint` from a given underlying point assumed to be a
    /// secret share
    pub fn new_shared(value: CurvePointResult<C>) -> MpcPointResult<C> {
        MpcPointResult { share: value }
    }

    /// Get the ID of the underlying share's result
    pub fn id(&self) -> ResultId {
        self.share.id
    }

    /// Borrow the fabric that this result is allocated in
    pub fn fabric(&self) -> &MpcFabric<C> {
        self.share.fabric()
    }

    /// Open the value; both parties send their shares to the counterparty
    pub fn open(&self) -> CurvePointResult<C> {
        let send_my_share =
            |args: Vec<ResultValue<C>>| NetworkPayload::Point(args[0].to_owned().into());

        // Party zero sends first then receives
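        // while party one receives first then sends; fixing this order on both sides
        // keeps the two parties' sends and receives paired up consistently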
        let (share0, share1): (CurvePointResult<C>, CurvePointResult<C>) =
            if self.fabric().party_id() == PARTY0 {
                let party0_value = self.fabric().new_network_op(vec![self.id()], send_my_share);
                let party1_value = self.fabric().receive_value();

                (party0_value, party1_value)
            } else {
                let party0_value = self.fabric().receive_value();
                let party1_value = self.fabric().new_network_op(vec![self.id()], send_my_share);

                (party0_value, party1_value)
            };

        share0 + share1
    }

    /// Open a batch of values
    pub fn open_batch(values: &[MpcPointResult<C>]) -> Vec<CurvePointResult<C>> {
        if values.is_empty() {
            return Vec::new();
        }

        let n = values.len();
        let fabric = values[0].fabric();
        let all_ids = values.iter().map(|v| v.id()).collect_vec();
        let send_my_shares = |args: Vec<ResultValue<C>>| {
            NetworkPayload::PointBatch(args.into_iter().map(|arg| arg.into()).collect_vec())
        };

        // Party zero sends first then receives
        let (party0_values, party1_values): (BatchCurvePointResult<C>, BatchCurvePointResult<C>) =
            if fabric.party_id() == PARTY0 {
                let party0_values = fabric.new_network_op(all_ids, send_my_shares);
                let party1_values = fabric.receive_value();

                (party0_values, party1_values)
            } else {
                let party0_values = fabric.receive_value();
                let party1_values = fabric.new_network_op(all_ids, send_my_shares);

                (party0_values, party1_values)
            };

        // Create a gate to component-wise add the shares
        fabric.new_batch_gate_op(
            vec![party0_values.id(), party1_values.id()],
            n, // output_arity
            |mut args| {
                let party0_values: Vec<CurvePoint<C>> = args.remove(0).into();
                let party1_values: Vec<CurvePoint<C>> = args.remove(0).into();

                party0_values
                    .into_iter()
                    .zip(party1_values)
                    .map(|(x, y)| x + y)
                    .map(ResultValue::Point)
                    .collect_vec()
            },
        )
    }
}

// --------------
// | Arithmetic |
// --------------

// === Addition === //

impl<C: CurveGroup> Add<&CurvePoint<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    // Only party 0 adds the plaintext value to its share
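    // so that the sum of the two parties' shares changes by exactly the public value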
    fn add(self, rhs: &CurvePoint<C>) -> Self::Output {
        let rhs = *rhs;
        let party_id = self.fabric().party_id();
        self.fabric()
            .new_gate_op(vec![self.id()], move |args| {
                let lhs: CurvePoint<C> = args[0].to_owned().into();

                if party_id == PARTY0 {
                    ResultValue::Point(lhs + rhs)
                } else {
                    ResultValue::Point(lhs)
                }
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
impl_commutative!(MpcPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);

impl<C: CurveGroup> Add<&CurvePointResult<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    // Only party 0 adds the plaintext value to its share
    fn add(self, rhs: &CurvePointResult<C>) -> Self::Output {
        let party_id = self.fabric().party_id();
        self.fabric()
            .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
                let lhs: CurvePoint<C> = args.remove(0).into();
                let rhs: CurvePoint<C> = args.remove(0).into();

                if party_id == PARTY0 {
                    ResultValue::Point(lhs + rhs)
                } else {
                    ResultValue::Point(lhs)
                }
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
impl_commutative!(MpcPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);

impl<C: CurveGroup> Add<&MpcPointResult<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    fn add(self, rhs: &MpcPointResult<C>) -> Self::Output {
        self.fabric()
            .new_gate_op(vec![self.id(), rhs.id()], |args| {
                let lhs: CurvePoint<C> = args[0].to_owned().into();
                let rhs: CurvePoint<C> = args[1].to_owned().into();

                ResultValue::Point(lhs + rhs)
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Add, add, +, MpcPointResult<C>, C: CurveGroup);

impl<C: CurveGroup> MpcPointResult<C> {
    /// Add two batches of values
    pub fn batch_add(a: &[MpcPointResult<C>], b: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();

        // Create a gate to component-wise add the shares
        fabric
            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
                let points = args.into_iter().map(CurvePoint::from).collect_vec();
                let (a, b) = points.split_at(n);

                a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| x + y)
                    .map(ResultValue::Point)
                    .collect_vec()
            })
            .into_iter()
            .map(MpcPointResult::from)
            .collect_vec()
    }

    /// Add a batch of `MpcPointResult`s to a batch of `CurvePointResult`s
    pub fn batch_add_public(
        a: &[MpcPointResult<C>],
        b: &[CurvePointResult<C>],
    ) -> Vec<MpcPointResult<C>> {
        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a
            .iter()
            .map(|v| v.id())
            .chain(b.iter().map(|b| b.id))
            .collect_vec();

        // Add the shares in a batch gate
        let party_id = fabric.party_id();
        fabric
            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
                let lhs_points = args.drain(..n).map(CurvePoint::from).collect_vec();
                let rhs_points = args.into_iter().map(CurvePoint::from).collect_vec();

                lhs_points
                    .into_iter()
                    .zip(rhs_points)
                    .map(|(x, y)| if party_id == PARTY0 { x + y } else { x })
                    .map(ResultValue::Point)
                    .collect_vec()
            })
            .into_iter()
            .map(MpcPointResult::from)
            .collect_vec()
    }
}

// === Subtraction === //

impl<C: CurveGroup> Sub<&CurvePoint<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    // Only party 0 subtracts the plaintext value
    fn sub(self, rhs: &CurvePoint<C>) -> Self::Output {
        let rhs = *rhs;
        let party_id = self.fabric().party_id();
        self.fabric()
            .new_gate_op(vec![self.id()], move |args| {
                let lhs: CurvePoint<C> = args[0].to_owned().into();

                if party_id == PARTY0 {
                    ResultValue::Point(lhs - rhs)
                } else {
                    ResultValue::Point(lhs)
                }
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);

impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    fn sub(self, rhs: &CurvePointResult<C>) -> Self::Output {
        let party_id = self.fabric().party_id();
        self.fabric()
            .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
                let lhs: CurvePoint<C> = args.remove(0).into();
                let rhs: CurvePoint<C> = args.remove(0).into();

                if party_id == PARTY0 {
                    ResultValue::Point(lhs - rhs)
                } else {
                    ResultValue::Point(lhs)
                }
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);

impl<C: CurveGroup> Sub<&MpcPointResult<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    fn sub(self, rhs: &MpcPointResult<C>) -> Self::Output {
        self.fabric()
            .new_gate_op(vec![self.id(), rhs.id()], |args| {
                let lhs: CurvePoint<C> = args[0].to_owned().into();
                let rhs: CurvePoint<C> = args[1].to_owned().into();

                ResultValue::Point(lhs - rhs)
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Sub, sub, -, MpcPointResult<C>, C: CurveGroup);

impl<C: CurveGroup> MpcPointResult<C> {
    /// Subtract two batches of values
    pub fn batch_sub(a: &[MpcPointResult<C>], b: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
        assert_eq!(a.len(), b.len(), "Batch sub requires equal length inputs");
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();

        // Create a gate to component-wise subtract the shares
        fabric
            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
                let points = args.into_iter().map(CurvePoint::from).collect_vec();
                let (a, b) = points.split_at(n);

                a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| x - y)
                    .map(ResultValue::Point)
                    .collect_vec()
            })
            .into_iter()
            .map(MpcPointResult::from)
            .collect_vec()
    }

    /// Subtract a batch of `CurvePointResult`s from a batch of `MpcPointResult`s
    pub fn batch_sub_public(
        a: &[MpcPointResult<C>],
        b: &[CurvePointResult<C>],
    ) -> Vec<MpcPointResult<C>> {
        assert_eq!(a.len(), b.len(), "Batch sub requires equal length inputs");
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a
            .iter()
            .map(|v| v.id())
            .chain(b.iter().map(|b| b.id))
            .collect_vec();

        // Subtract the public values from the shares in a batch gate
        let party_id = fabric.party_id();
        fabric
            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
                let lhs_points = args.drain(..n).map(CurvePoint::from).collect_vec();
                let rhs_points = args.into_iter().map(CurvePoint::from).collect_vec();

                lhs_points
                    .into_iter()
                    .zip(rhs_points)
                    .map(|(x, y)| if party_id == PARTY0 { x - y } else { x })
                    .map(ResultValue::Point)
                    .collect_vec()
            })
            .into_iter()
            .map(MpcPointResult::from)
            .collect_vec()
    }
}

// === Negation === //

impl<C: CurveGroup> Neg for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    fn neg(self) -> Self::Output {
        self.fabric()
            .new_gate_op(vec![self.id()], |mut args| {
                let mpc_val: CurvePoint<C> = args.remove(0).into();
                ResultValue::Point(-mpc_val)
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Neg, neg, -, C: CurveGroup);

impl<C: CurveGroup> MpcPointResult<C> {
    /// Negate a batch of values
    pub fn batch_neg(values: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
        if values.is_empty() {
            return Vec::new();
        }

        let n = values.len();
        let fabric = values[0].fabric();
        let all_ids = values.iter().map(|v| v.id()).collect_vec();

        // Create a gate to negate the shares component-wise
        fabric
            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
                let points = args.into_iter().map(CurvePoint::from).collect_vec();

                points
                    .into_iter()
                    .map(|x| -x)
                    .map(ResultValue::Point)
                    .collect_vec()
            })
            .into_iter()
            .map(MpcPointResult::from)
            .collect_vec()
    }
}

// === Scalar Multiplication === //

impl<C: CurveGroup> Mul<&Scalar<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
        let rhs = *rhs;
        self.fabric()
            .new_gate_op(vec![self.id()], move |args| {
                let lhs: CurvePoint<C> = args[0].to_owned().into();
                ResultValue::Point(lhs * rhs)
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
impl_commutative!(MpcPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);

impl<C: CurveGroup> Mul<&ScalarResult<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
        self.fabric()
            .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
                let lhs: CurvePoint<C> = args.remove(0).into();
                let rhs: Scalar<C> = args.remove(0).into();

                ResultValue::Point(lhs * rhs)
            })
            .into()
    }
}
impl_borrow_variants!(MpcPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
impl_commutative!(MpcPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);

impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    // Use the Beaver trick as in the scalar case
    fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
        let generator = CurvePoint::generator();
        let (a, b, c) = self.fabric().next_beaver_triple();

        // Open the values d = [rhs - a] and eG = [lhs - bG] for curve group generator G
        let masked_rhs = rhs - &a;
        let masked_lhs = self - (&generator * &b);

        #[allow(non_snake_case)]
        let eG_open = masked_lhs.open();
        let d_open = masked_rhs.open();

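        // With d and eG public, the product decomposes as
        //   x * yG = (d + a)(e + b)G = deG + d(bG) + a(eG) + (ab)G
        // and c = ab from the triple supplies the final term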
        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
        &d_open * &eG_open + &d_open * &(&generator * &b) + &a * eG_open + &c * generator
    }
}
impl_borrow_variants!(MpcPointResult<C>, Mul, mul, *, MpcScalarResult<C>, C: CurveGroup);
impl_commutative!(MpcPointResult<C>, Mul, mul, *, MpcScalarResult<C>, C: CurveGroup);

impl<C: CurveGroup> MpcPointResult<C> {
    /// Multiply a batch of `MpcPointResult`s with a batch of `MpcScalarResult`s
    #[allow(non_snake_case)]
    pub fn batch_mul(a: &[MpcScalarResult<C>], b: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
        assert_eq!(a.len(), b.len(), "Batch mul requires equal length inputs");
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = a[0].fabric();

        // Sample a set of Beaver triples for the multiplications
        let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
        let beaver_b_gen = MpcPointResult::batch_mul_generator(&beaver_b);

        let masked_rhs = MpcScalarResult::batch_sub(a, &beaver_a);
        let masked_lhs = MpcPointResult::batch_sub(b, &beaver_b_gen);

        let eG_open = MpcPointResult::open_batch(&masked_lhs);
        let d_open = MpcScalarResult::open_batch(&masked_rhs);

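        // The single-product identity is applied element-wise across the batch; each
        // helper below computes one term of the identity for all n products at once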
        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
        let deG = CurvePointResult::batch_mul(&d_open, &eG_open);
        let dbG = MpcPointResult::batch_mul_public(&d_open, &beaver_b_gen);
        let aeG = CurvePointResult::batch_mul_shared(&beaver_a, &eG_open);
        let cG = MpcPointResult::batch_mul_generator(&beaver_c);

        let de_db_G = MpcPointResult::batch_add_public(&dbG, &deG);
        let ae_c_G = MpcPointResult::batch_add(&aeG, &cG);

        MpcPointResult::batch_add(&de_db_G, &ae_c_G)
    }

    /// Multiply a batch of `MpcPointResult`s with a batch of `ScalarResult`s
    pub fn batch_mul_public(
        a: &[ScalarResult<C>],
        b: &[MpcPointResult<C>],
    ) -> Vec<MpcPointResult<C>> {
        assert_eq!(a.len(), b.len(), "Batch mul requires equal length inputs");
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a
            .iter()
            .map(|v| v.id())
            .chain(b.iter().map(|b| b.id()))
            .collect_vec();

        // Multiply the shares in a batch gate
        fabric
            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
                let scalars = args.drain(..n).map(Scalar::from).collect_vec();
                let points = args.into_iter().map(CurvePoint::from).collect_vec();

                scalars
                    .into_iter()
                    .zip(points)
                    .map(|(x, y)| x * y)
                    .map(ResultValue::Point)
                    .collect_vec()
            })
            .into_iter()
            .map(MpcPointResult::from)
            .collect_vec()
    }

    /// Multiply a batch of `MpcScalarResult`s by the generator
    pub fn batch_mul_generator(a: &[MpcScalarResult<C>]) -> Vec<MpcPointResult<C>> {
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a.iter().map(|v| v.id()).collect_vec();

        // Multiply the shares in a batch gate
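        // scalar multiplication by G is linear, so multiplying each party's additive
        // share by G yields additive shares of x * G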
        fabric
            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
                let scalars = args.into_iter().map(Scalar::from).collect_vec();
                let generator = CurvePoint::generator();

                scalars
                    .into_iter()
                    .map(|x| x * generator)
                    .map(ResultValue::Point)
                    .collect_vec()
            })
            .into_iter()
            .map(MpcPointResult::from)
            .collect_vec()
    }
}