// mpc_stark/algebra/mpc_stark_point.rs

//! Defines an unauthenticated shared curve point type which forms the basis
//! of the authenticated curve point type.

4use std::ops::{Add, Mul, Neg, Sub};
5
6use itertools::Itertools;
7
8use crate::{
9    fabric::{ResultHandle, ResultValue},
10    network::NetworkPayload,
11    MpcFabric, ResultId, PARTY0,
12};
13
14use super::{
15    macros::{impl_borrow_variants, impl_commutative},
16    mpc_scalar::MpcScalarResult,
17    scalar::{Scalar, ScalarResult},
18    stark_curve::{BatchStarkPointResult, StarkPoint, StarkPointResult},
19};
20
21/// Defines a secret shared type of a curve point
22#[derive(Clone, Debug)]
23pub struct MpcStarkPointResult {
24    /// The underlying value held by the local party
25    pub(crate) share: StarkPointResult,
26}
27
28impl From<StarkPointResult> for MpcStarkPointResult {
29    fn from(value: StarkPointResult) -> Self {
30        Self { share: value }
31    }
32}
33
34/// Defines the result handle type that represents a future result of an `MpcStarkPoint`
35impl MpcStarkPointResult {
36    /// Creates an `MpcStarkPoint` from a given underlying point assumed to be a secret share
37    pub fn new_shared(value: StarkPointResult) -> MpcStarkPointResult {
38        MpcStarkPointResult { share: value }
39    }
40
41    /// Get the ID of the underlying share's result
42    pub fn id(&self) -> ResultId {
43        self.share.id
44    }
45
46    /// Borrow the fabric that this result is allocated in
47    pub fn fabric(&self) -> &MpcFabric {
48        self.share.fabric()
49    }
50
51    /// Open the value; both parties send their shares to the counterparty
52    pub fn open(&self) -> ResultHandle<StarkPoint> {
53        let send_my_share =
54            |args: Vec<ResultValue>| NetworkPayload::Point(args[0].to_owned().into());
55
56        // Party zero sends first then receives
57        let (share0, share1): (StarkPointResult, StarkPointResult) =
58            if self.fabric().party_id() == PARTY0 {
59                let party0_value = self.fabric().new_network_op(vec![self.id()], send_my_share);
60                let party1_value = self.fabric().receive_value();
61
62                (party0_value, party1_value)
63            } else {
64                let party0_value = self.fabric().receive_value();
65                let party1_value = self.fabric().new_network_op(vec![self.id()], send_my_share);
66
67                (party0_value, party1_value)
68            };
69
70        share0 + share1
71    }
72
73    /// Open a batch of values
74    pub fn open_batch(values: &[MpcStarkPointResult]) -> Vec<StarkPointResult> {
75        if values.is_empty() {
76            return Vec::new();
77        }
78
79        let n = values.len();
80        let fabric = &values[0].fabric();
81        let all_ids = values.iter().map(|v| v.id()).collect_vec();
82        let send_my_shares = |args: Vec<ResultValue>| {
83            NetworkPayload::PointBatch(args.into_iter().map(|arg| arg.into()).collect_vec())
84        };
85
86        // Party zero sends first then receives
87        let (party0_values, party1_values): (BatchStarkPointResult, BatchStarkPointResult) =
88            if fabric.party_id() == PARTY0 {
89                let party0_values = fabric.new_network_op(all_ids, send_my_shares);
90                let party1_values = fabric.receive_value();
91
92                (party0_values, party1_values)
93            } else {
94                let party0_values = fabric.receive_value();
95                let party1_values = fabric.new_network_op(all_ids, send_my_shares);
96
97                (party0_values, party1_values)
98            };
99
100        // Create a gate to component-wise add the shares
101        fabric.new_batch_gate_op(
102            vec![party0_values.id(), party1_values.id()],
103            n, /* output_arity */
104            |mut args| {
105                let party0_values: Vec<StarkPoint> = args.remove(0).into();
106                let party1_values: Vec<StarkPoint> = args.remove(0).into();
107
108                party0_values
109                    .into_iter()
110                    .zip(party1_values.into_iter())
111                    .map(|(x, y)| x + y)
112                    .map(ResultValue::Point)
113                    .collect_vec()
114            },
115        )
116    }
117}
118
119// --------------
120// | Arithmetic |
121// --------------
122
123// === Addition === //
124
125impl Add<&StarkPoint> for &MpcStarkPointResult {
126    type Output = MpcStarkPointResult;
127
128    // Only party 0 adds the plaintext value to its share
129    fn add(self, rhs: &StarkPoint) -> Self::Output {
130        let rhs = *rhs;
131        let party_id = self.fabric().party_id();
132        self.fabric()
133            .new_gate_op(vec![self.id()], move |args| {
134                let lhs: StarkPoint = args[0].to_owned().into();
135
136                if party_id == PARTY0 {
137                    ResultValue::Point(lhs + rhs)
138                } else {
139                    ResultValue::Point(lhs)
140                }
141            })
142            .into()
143    }
144}
145impl_borrow_variants!(MpcStarkPointResult, Add, add, +, StarkPoint);
146impl_commutative!(MpcStarkPointResult, Add, add, +, StarkPoint);
147
148impl Add<&StarkPointResult> for &MpcStarkPointResult {
149    type Output = MpcStarkPointResult;
150
151    // Only party 0 adds the plaintext value to its share
152    fn add(self, rhs: &StarkPointResult) -> Self::Output {
153        let party_id = self.fabric().party_id();
154        self.fabric()
155            .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
156                let lhs: StarkPoint = args.remove(0).into();
157                let rhs: StarkPoint = args.remove(0).into();
158
159                if party_id == PARTY0 {
160                    ResultValue::Point(lhs + rhs)
161                } else {
162                    ResultValue::Point(lhs)
163                }
164            })
165            .into()
166    }
167}
168impl_borrow_variants!(MpcStarkPointResult, Add, add, +, StarkPointResult);
169impl_commutative!(MpcStarkPointResult, Add, add, +, StarkPointResult);
170
171impl Add<&MpcStarkPointResult> for &MpcStarkPointResult {
172    type Output = MpcStarkPointResult;
173
174    fn add(self, rhs: &MpcStarkPointResult) -> Self::Output {
175        self.fabric()
176            .new_gate_op(vec![self.id(), rhs.id()], |args| {
177                let lhs: StarkPoint = args[0].to_owned().into();
178                let rhs: StarkPoint = args[1].to_owned().into();
179
180                ResultValue::Point(lhs + rhs)
181            })
182            .into()
183    }
184}
185impl_borrow_variants!(MpcStarkPointResult, Add, add, +, MpcStarkPointResult);
186
187impl MpcStarkPointResult {
188    /// Add two batches of values
189    pub fn batch_add(
190        a: &[MpcStarkPointResult],
191        b: &[MpcStarkPointResult],
192    ) -> Vec<MpcStarkPointResult> {
193        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
194        if a.is_empty() {
195            return Vec::new();
196        }
197
198        let n = a.len();
199        let fabric = a[0].fabric();
200        let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
201
202        // Create a gate to component-wise add the shares
203        fabric
204            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
205                let points = args.into_iter().map(StarkPoint::from).collect_vec();
206                let (a, b) = points.split_at(n);
207
208                a.iter()
209                    .zip(b.iter())
210                    .map(|(x, y)| x + y)
211                    .map(ResultValue::Point)
212                    .collect_vec()
213            })
214            .into_iter()
215            .map(MpcStarkPointResult::from)
216            .collect_vec()
217    }
218
219    /// Add a batch of `MpcStarkPointResults` to a batch of `StarkPointResult`s
220    pub fn batch_add_public(
221        a: &[MpcStarkPointResult],
222        b: &[StarkPointResult],
223    ) -> Vec<MpcStarkPointResult> {
224        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
225        if a.is_empty() {
226            return Vec::new();
227        }
228
229        let n = a.len();
230        let fabric = a[0].fabric();
231        let all_ids = a
232            .iter()
233            .map(|v| v.id())
234            .chain(b.iter().map(|b| b.id))
235            .collect_vec();
236
237        // Add the shares in a batch gate
238        let party_id = fabric.party_id();
239        fabric
240            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
241                let lhs_points = args.drain(..n).map(StarkPoint::from).collect_vec();
242                let rhs_points = args.into_iter().map(StarkPoint::from).collect_vec();
243
244                lhs_points
245                    .into_iter()
246                    .zip(rhs_points.into_iter())
247                    .map(|(x, y)| if party_id == PARTY0 { x + y } else { x })
248                    .map(ResultValue::Point)
249                    .collect_vec()
250            })
251            .into_iter()
252            .map(MpcStarkPointResult::from)
253            .collect_vec()
254    }
255}
256
257// === Subtraction === //
258
259impl Sub<&StarkPoint> for &MpcStarkPointResult {
260    type Output = MpcStarkPointResult;
261
262    // Only party 0 subtracts the plaintext value
263    fn sub(self, rhs: &StarkPoint) -> Self::Output {
264        let rhs = *rhs;
265        let party_id = self.fabric().party_id();
266        self.fabric()
267            .new_gate_op(vec![self.id()], move |args| {
268                let lhs: StarkPoint = args[0].to_owned().into();
269
270                if party_id == PARTY0 {
271                    ResultValue::Point(lhs - rhs)
272                } else {
273                    ResultValue::Point(lhs)
274                }
275            })
276            .into()
277    }
278}
279impl_borrow_variants!(MpcStarkPointResult, Sub, sub, -, StarkPoint);
280
281impl Sub<&StarkPointResult> for &MpcStarkPointResult {
282    type Output = MpcStarkPointResult;
283
284    fn sub(self, rhs: &StarkPointResult) -> Self::Output {
285        let party_id = self.fabric().party_id();
286        self.fabric()
287            .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
288                let lhs: StarkPoint = args.remove(0).into();
289                let rhs: StarkPoint = args.remove(0).into();
290
291                if party_id == PARTY0 {
292                    ResultValue::Point(lhs - rhs)
293                } else {
294                    ResultValue::Point(lhs)
295                }
296            })
297            .into()
298    }
299}
300impl_borrow_variants!(MpcStarkPointResult, Sub, sub, -, StarkPointResult);
301
302impl Sub<&MpcStarkPointResult> for &MpcStarkPointResult {
303    type Output = MpcStarkPointResult;
304
305    fn sub(self, rhs: &MpcStarkPointResult) -> Self::Output {
306        self.fabric()
307            .new_gate_op(vec![self.id(), rhs.id()], |args| {
308                let lhs: StarkPoint = args[0].to_owned().into();
309                let rhs: StarkPoint = args[1].to_owned().into();
310
311                ResultValue::Point(lhs - rhs)
312            })
313            .into()
314    }
315}
316impl_borrow_variants!(MpcStarkPointResult, Sub, sub, -, MpcStarkPointResult);
317
318impl MpcStarkPointResult {
319    /// Subtract two batches of values
320    pub fn batch_sub(
321        a: &[MpcStarkPointResult],
322        b: &[MpcStarkPointResult],
323    ) -> Vec<MpcStarkPointResult> {
324        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
325        if a.is_empty() {
326            return Vec::new();
327        }
328
329        let n = a.len();
330        let fabric = a[0].fabric();
331        let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
332
333        // Create a gate to component-wise add the shares
334        fabric
335            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
336                let points = args.into_iter().map(StarkPoint::from).collect_vec();
337                let (a, b) = points.split_at(n);
338
339                a.iter()
340                    .zip(b.iter())
341                    .map(|(x, y)| x - y)
342                    .map(ResultValue::Point)
343                    .collect_vec()
344            })
345            .into_iter()
346            .map(MpcStarkPointResult::from)
347            .collect_vec()
348    }
349
350    /// Subtract a batch of `MpcStarkPointResults` to a batch of `StarkPointResult`s
351    pub fn batch_sub_public(
352        a: &[MpcStarkPointResult],
353        b: &[StarkPointResult],
354    ) -> Vec<MpcStarkPointResult> {
355        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
356        if a.is_empty() {
357            return Vec::new();
358        }
359
360        let n = a.len();
361        let fabric = a[0].fabric();
362        let all_ids = a
363            .iter()
364            .map(|v| v.id())
365            .chain(b.iter().map(|b| b.id))
366            .collect_vec();
367
368        // Add the shares in a batch gate
369        let party_id = fabric.party_id();
370        fabric
371            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
372                let lhs_points = args.drain(..n).map(StarkPoint::from).collect_vec();
373                let rhs_points = args.into_iter().map(StarkPoint::from).collect_vec();
374
375                lhs_points
376                    .into_iter()
377                    .zip(rhs_points.into_iter())
378                    .map(|(x, y)| if party_id == PARTY0 { x - y } else { x })
379                    .map(ResultValue::Point)
380                    .collect_vec()
381            })
382            .into_iter()
383            .map(MpcStarkPointResult::from)
384            .collect_vec()
385    }
386}
387
388// === Negation === //
389
390impl Neg for &MpcStarkPointResult {
391    type Output = MpcStarkPointResult;
392
393    fn neg(self) -> Self::Output {
394        self.fabric()
395            .new_gate_op(vec![self.id()], |mut args| {
396                let mpc_val: StarkPoint = args.remove(0).into();
397                ResultValue::Point(-mpc_val)
398            })
399            .into()
400    }
401}
402impl_borrow_variants!(MpcStarkPointResult, Neg, neg, -);
403
404impl MpcStarkPointResult {
405    /// Negate a batch of values
406    pub fn batch_neg(values: &[MpcStarkPointResult]) -> Vec<MpcStarkPointResult> {
407        if values.is_empty() {
408            return Vec::new();
409        }
410
411        let n = values.len();
412        let fabric = values[0].fabric();
413        let all_ids = values.iter().map(|v| v.id()).collect_vec();
414
415        // Create a gate to component-wise add the shares
416        fabric
417            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
418                let points = args.into_iter().map(StarkPoint::from).collect_vec();
419
420                points
421                    .into_iter()
422                    .map(|x| -x)
423                    .map(ResultValue::Point)
424                    .collect_vec()
425            })
426            .into_iter()
427            .map(MpcStarkPointResult::from)
428            .collect_vec()
429    }
430}
431
432// === Scalar Multiplication === //
433
434impl Mul<&Scalar> for &MpcStarkPointResult {
435    type Output = MpcStarkPointResult;
436
437    fn mul(self, rhs: &Scalar) -> Self::Output {
438        let rhs = *rhs;
439        self.fabric()
440            .new_gate_op(vec![self.id()], move |args| {
441                let lhs: StarkPoint = args[0].to_owned().into();
442                ResultValue::Point(lhs * rhs)
443            })
444            .into()
445    }
446}
447impl_borrow_variants!(MpcStarkPointResult, Mul, mul, *, Scalar);
448impl_commutative!(MpcStarkPointResult, Mul, mul, *, Scalar);
449
450impl Mul<&ScalarResult> for &MpcStarkPointResult {
451    type Output = MpcStarkPointResult;
452
453    fn mul(self, rhs: &ScalarResult) -> Self::Output {
454        self.fabric()
455            .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
456                let lhs: StarkPoint = args.remove(0).into();
457                let rhs: Scalar = args.remove(0).into();
458
459                ResultValue::Point(lhs * rhs)
460            })
461            .into()
462    }
463}
464impl_borrow_variants!(MpcStarkPointResult, Mul, mul, *, ScalarResult);
465impl_commutative!(MpcStarkPointResult, Mul, mul, *, ScalarResult);
466
467impl Mul<&MpcScalarResult> for &MpcStarkPointResult {
468    type Output = MpcStarkPointResult;
469
470    // Use the beaver trick as in the scalar case
471    fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
472        let generator = StarkPoint::generator();
473        let (a, b, c) = self.fabric().next_beaver_triple();
474
475        // Open the values d = [rhs - a] and e = [lhs - bG] for curve group generator G
476        let masked_rhs = rhs - &a;
477        let masked_lhs = self - (&generator * &b);
478
479        #[allow(non_snake_case)]
480        let eG_open = masked_lhs.open();
481        let d_open = masked_rhs.open();
482
483        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
484        &d_open * &eG_open + &d_open * &(&generator * &b) + &a * eG_open + &c * generator
485    }
486}
487impl_borrow_variants!(MpcStarkPointResult, Mul, mul, *, MpcScalarResult);
488impl_commutative!(MpcStarkPointResult, Mul, mul, *, MpcScalarResult);
489
490impl MpcStarkPointResult {
491    /// Multiply a batch of `MpcStarkPointResult`s with a batch of `MpcScalarResult`s
492    #[allow(non_snake_case)]
493    pub fn batch_mul(a: &[MpcScalarResult], b: &[MpcStarkPointResult]) -> Vec<MpcStarkPointResult> {
494        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
495        if a.is_empty() {
496            return Vec::new();
497        }
498
499        let n = a.len();
500        let fabric = a[0].fabric();
501
502        // Sample a set of beaver triples for the multiplications
503        let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
504        let beaver_b_gen = MpcStarkPointResult::batch_mul_generator(&beaver_b);
505
506        let masked_rhs = MpcScalarResult::batch_sub(a, &beaver_a);
507        let masked_lhs = MpcStarkPointResult::batch_sub(b, &beaver_b_gen);
508
509        let eG_open = MpcStarkPointResult::open_batch(&masked_lhs);
510        let d_open = MpcScalarResult::open_batch(&masked_rhs);
511
512        // Identity [x * yG] = deG + d[bG] + [a]eG + [c]G
513        let deG = StarkPointResult::batch_mul(&d_open, &eG_open);
514        let dbG = MpcStarkPointResult::batch_mul_public(&d_open, &beaver_b_gen);
515        let aeG = StarkPointResult::batch_mul_shared(&beaver_a, &eG_open);
516        let cG = MpcStarkPointResult::batch_mul_generator(&beaver_c);
517
518        let de_db_G = MpcStarkPointResult::batch_add_public(&dbG, &deG);
519        let ae_c_G = MpcStarkPointResult::batch_add(&aeG, &cG);
520
521        MpcStarkPointResult::batch_add(&de_db_G, &ae_c_G)
522    }
523
524    /// Multiply a batch of `MpcStarkPointResult`s with a batch of `ScalarResult`s
525    pub fn batch_mul_public(
526        a: &[ScalarResult],
527        b: &[MpcStarkPointResult],
528    ) -> Vec<MpcStarkPointResult> {
529        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
530        if a.is_empty() {
531            return Vec::new();
532        }
533
534        let n = a.len();
535        let fabric = a[0].fabric();
536        let all_ids = a
537            .iter()
538            .map(|v| v.id())
539            .chain(b.iter().map(|b| b.id()))
540            .collect_vec();
541
542        // Multiply the shares in a batch gate
543        fabric
544            .new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
545                let scalars = args.drain(..n).map(Scalar::from).collect_vec();
546                let points = args.into_iter().map(StarkPoint::from).collect_vec();
547
548                scalars
549                    .into_iter()
550                    .zip(points.into_iter())
551                    .map(|(x, y)| x * y)
552                    .map(ResultValue::Point)
553                    .collect_vec()
554            })
555            .into_iter()
556            .map(MpcStarkPointResult::from)
557            .collect_vec()
558    }
559
560    /// Multiply a batch of `MpcScalarResult`s by the generator
561    pub fn batch_mul_generator(a: &[MpcScalarResult]) -> Vec<MpcStarkPointResult> {
562        if a.is_empty() {
563            return Vec::new();
564        }
565
566        let n = a.len();
567        let fabric = a[0].fabric();
568        let all_ids = a.iter().map(|v| v.id()).collect_vec();
569
570        // Multiply the shares in a batch gate
571        fabric
572            .new_batch_gate_op(all_ids, n /* output_arity */, move |args| {
573                let scalars = args.into_iter().map(Scalar::from).collect_vec();
574                let generator = StarkPoint::generator();
575
576                scalars
577                    .into_iter()
578                    .map(|x| x * generator)
579                    .map(ResultValue::Point)
580                    .collect_vec()
581            })
582            .into_iter()
583            .map(MpcStarkPointResult::from)
584            .collect_vec()
585    }
586}