1use std::ops::{Add, Mul, Neg, Sub};
5
6use ark_ec::CurveGroup;
7use itertools::Itertools;
8
9use crate::{
10 algebra::macros::*, algebra::scalar::*, fabric::ResultValue, network::NetworkPayload,
11 MpcFabric, ResultId, PARTY0,
12};
13
14use super::curve::{BatchCurvePointResult, CurvePoint, CurvePointResult};
15
16#[derive(Clone, Debug)]
18pub struct MpcPointResult<C: CurveGroup> {
19 pub(crate) share: CurvePointResult<C>,
21}
22
23impl<C: CurveGroup> From<CurvePointResult<C>> for MpcPointResult<C> {
24 fn from(value: CurvePointResult<C>) -> Self {
25 Self { share: value }
26 }
27}
28
29impl<C: CurveGroup> MpcPointResult<C> {
32 pub fn new_shared(value: CurvePointResult<C>) -> MpcPointResult<C> {
35 MpcPointResult { share: value }
36 }
37
38 pub fn id(&self) -> ResultId {
40 self.share.id
41 }
42
43 pub fn fabric(&self) -> &MpcFabric<C> {
45 self.share.fabric()
46 }
47
48 pub fn open(&self) -> CurvePointResult<C> {
50 let send_my_share =
51 |args: Vec<ResultValue<C>>| NetworkPayload::Point(args[0].to_owned().into());
52
53 let (share0, share1): (CurvePointResult<C>, CurvePointResult<C>) =
55 if self.fabric().party_id() == PARTY0 {
56 let party0_value = self.fabric().new_network_op(vec![self.id()], send_my_share);
57 let party1_value = self.fabric().receive_value();
58
59 (party0_value, party1_value)
60 } else {
61 let party0_value = self.fabric().receive_value();
62 let party1_value = self.fabric().new_network_op(vec![self.id()], send_my_share);
63
64 (party0_value, party1_value)
65 };
66
67 share0 + share1
68 }
69
70 pub fn open_batch(values: &[MpcPointResult<C>]) -> Vec<CurvePointResult<C>> {
72 if values.is_empty() {
73 return Vec::new();
74 }
75
76 let n = values.len();
77 let fabric = &values[0].fabric();
78 let all_ids = values.iter().map(|v| v.id()).collect_vec();
79 let send_my_shares = |args: Vec<ResultValue<C>>| {
80 NetworkPayload::PointBatch(args.into_iter().map(|arg| arg.into()).collect_vec())
81 };
82
83 let (party0_values, party1_values): (BatchCurvePointResult<C>, BatchCurvePointResult<C>) =
85 if fabric.party_id() == PARTY0 {
86 let party0_values = fabric.new_network_op(all_ids, send_my_shares);
87 let party1_values = fabric.receive_value();
88
89 (party0_values, party1_values)
90 } else {
91 let party0_values = fabric.receive_value();
92 let party1_values = fabric.new_network_op(all_ids, send_my_shares);
93
94 (party0_values, party1_values)
95 };
96
97 fabric.new_batch_gate_op(
99 vec![party0_values.id(), party1_values.id()],
100 n, |mut args| {
102 let party0_values: Vec<CurvePoint<C>> = args.remove(0).into();
103 let party1_values: Vec<CurvePoint<C>> = args.remove(0).into();
104
105 party0_values
106 .into_iter()
107 .zip(party1_values)
108 .map(|(x, y)| x + y)
109 .map(ResultValue::Point)
110 .collect_vec()
111 },
112 )
113 }
114}
115
116impl<C: CurveGroup> Add<&CurvePoint<C>> for &MpcPointResult<C> {
123 type Output = MpcPointResult<C>;
124
125 fn add(self, rhs: &CurvePoint<C>) -> Self::Output {
127 let rhs = *rhs;
128 let party_id = self.fabric().party_id();
129 self.fabric()
130 .new_gate_op(vec![self.id()], move |args| {
131 let lhs: CurvePoint<C> = args[0].to_owned().into();
132
133 if party_id == PARTY0 {
134 ResultValue::Point(lhs + rhs)
135 } else {
136 ResultValue::Point(lhs)
137 }
138 })
139 .into()
140 }
141}
142impl_borrow_variants!(MpcPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
143impl_commutative!(MpcPointResult<C>, Add, add, +, CurvePoint<C>, C: CurveGroup);
144
145impl<C: CurveGroup> Add<&CurvePointResult<C>> for &MpcPointResult<C> {
146 type Output = MpcPointResult<C>;
147
148 fn add(self, rhs: &CurvePointResult<C>) -> Self::Output {
150 let party_id = self.fabric().party_id();
151 self.fabric()
152 .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
153 let lhs: CurvePoint<C> = args.remove(0).into();
154 let rhs: CurvePoint<C> = args.remove(0).into();
155
156 if party_id == PARTY0 {
157 ResultValue::Point(lhs + rhs)
158 } else {
159 ResultValue::Point(lhs)
160 }
161 })
162 .into()
163 }
164}
165impl_borrow_variants!(MpcPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
166impl_commutative!(MpcPointResult<C>, Add, add, +, CurvePointResult<C>, C: CurveGroup);
167
168impl<C: CurveGroup> Add<&MpcPointResult<C>> for &MpcPointResult<C> {
169 type Output = MpcPointResult<C>;
170
171 fn add(self, rhs: &MpcPointResult<C>) -> Self::Output {
172 self.fabric()
173 .new_gate_op(vec![self.id(), rhs.id()], |args| {
174 let lhs: CurvePoint<C> = args[0].to_owned().into();
175 let rhs: CurvePoint<C> = args[1].to_owned().into();
176
177 ResultValue::Point(lhs + rhs)
178 })
179 .into()
180 }
181}
182impl_borrow_variants!(MpcPointResult<C>, Add, add, +, MpcPointResult<C>, C: CurveGroup);
183
184impl<C: CurveGroup> MpcPointResult<C> {
185 pub fn batch_add(a: &[MpcPointResult<C>], b: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
187 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
188 if a.is_empty() {
189 return Vec::new();
190 }
191
192 let n = a.len();
193 let fabric = a[0].fabric();
194 let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
195
196 fabric
198 .new_batch_gate_op(all_ids, n , move |args| {
199 let points = args.into_iter().map(CurvePoint::from).collect_vec();
200 let (a, b) = points.split_at(n);
201
202 a.iter()
203 .zip(b.iter())
204 .map(|(x, y)| x + y)
205 .map(ResultValue::Point)
206 .collect_vec()
207 })
208 .into_iter()
209 .map(MpcPointResult::from)
210 .collect_vec()
211 }
212
213 pub fn batch_add_public(
215 a: &[MpcPointResult<C>],
216 b: &[CurvePointResult<C>],
217 ) -> Vec<MpcPointResult<C>> {
218 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
219 if a.is_empty() {
220 return Vec::new();
221 }
222
223 let n = a.len();
224 let fabric = a[0].fabric();
225 let all_ids = a
226 .iter()
227 .map(|v| v.id())
228 .chain(b.iter().map(|b| b.id))
229 .collect_vec();
230
231 let party_id = fabric.party_id();
233 fabric
234 .new_batch_gate_op(all_ids, n , move |mut args| {
235 let lhs_points = args.drain(..n).map(CurvePoint::from).collect_vec();
236 let rhs_points = args.into_iter().map(CurvePoint::from).collect_vec();
237
238 lhs_points
239 .into_iter()
240 .zip(rhs_points)
241 .map(|(x, y)| if party_id == PARTY0 { x + y } else { x })
242 .map(ResultValue::Point)
243 .collect_vec()
244 })
245 .into_iter()
246 .map(MpcPointResult::from)
247 .collect_vec()
248 }
249}
250
251impl<C: CurveGroup> Sub<&CurvePoint<C>> for &MpcPointResult<C> {
254 type Output = MpcPointResult<C>;
255
256 fn sub(self, rhs: &CurvePoint<C>) -> Self::Output {
258 let rhs = *rhs;
259 let party_id = self.fabric().party_id();
260 self.fabric()
261 .new_gate_op(vec![self.id()], move |args| {
262 let lhs: CurvePoint<C> = args[0].to_owned().into();
263
264 if party_id == PARTY0 {
265 ResultValue::Point(lhs - rhs)
266 } else {
267 ResultValue::Point(lhs)
268 }
269 })
270 .into()
271 }
272}
273impl_borrow_variants!(MpcPointResult<C>, Sub, sub, -, CurvePoint<C>, C: CurveGroup);
274
275impl<C: CurveGroup> Sub<&CurvePointResult<C>> for &MpcPointResult<C> {
276 type Output = MpcPointResult<C>;
277
278 fn sub(self, rhs: &CurvePointResult<C>) -> Self::Output {
279 let party_id = self.fabric().party_id();
280 self.fabric()
281 .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
282 let lhs: CurvePoint<C> = args.remove(0).into();
283 let rhs: CurvePoint<C> = args.remove(0).into();
284
285 if party_id == PARTY0 {
286 ResultValue::Point(lhs - rhs)
287 } else {
288 ResultValue::Point(lhs)
289 }
290 })
291 .into()
292 }
293}
294impl_borrow_variants!(MpcPointResult<C>, Sub, sub, -, CurvePointResult<C>, C: CurveGroup);
295
296impl<C: CurveGroup> Sub<&MpcPointResult<C>> for &MpcPointResult<C> {
297 type Output = MpcPointResult<C>;
298
299 fn sub(self, rhs: &MpcPointResult<C>) -> Self::Output {
300 self.fabric()
301 .new_gate_op(vec![self.id(), rhs.id()], |args| {
302 let lhs: CurvePoint<C> = args[0].to_owned().into();
303 let rhs: CurvePoint<C> = args[1].to_owned().into();
304
305 ResultValue::Point(lhs - rhs)
306 })
307 .into()
308 }
309}
310impl_borrow_variants!(MpcPointResult<C>, Sub, sub, -, MpcPointResult<C>, C: CurveGroup);
311
312impl<C: CurveGroup> MpcPointResult<C> {
313 pub fn batch_sub(a: &[MpcPointResult<C>], b: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
315 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
316 if a.is_empty() {
317 return Vec::new();
318 }
319
320 let n = a.len();
321 let fabric = a[0].fabric();
322 let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
323
324 fabric
326 .new_batch_gate_op(all_ids, n , move |args| {
327 let points = args.into_iter().map(CurvePoint::from).collect_vec();
328 let (a, b) = points.split_at(n);
329
330 a.iter()
331 .zip(b.iter())
332 .map(|(x, y)| x - y)
333 .map(ResultValue::Point)
334 .collect_vec()
335 })
336 .into_iter()
337 .map(MpcPointResult::from)
338 .collect_vec()
339 }
340
341 pub fn batch_sub_public(
343 a: &[MpcPointResult<C>],
344 b: &[CurvePointResult<C>],
345 ) -> Vec<MpcPointResult<C>> {
346 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
347 if a.is_empty() {
348 return Vec::new();
349 }
350
351 let n = a.len();
352 let fabric = a[0].fabric();
353 let all_ids = a
354 .iter()
355 .map(|v| v.id())
356 .chain(b.iter().map(|b| b.id))
357 .collect_vec();
358
359 let party_id = fabric.party_id();
361 fabric
362 .new_batch_gate_op(all_ids, n , move |mut args| {
363 let lhs_points = args.drain(..n).map(CurvePoint::from).collect_vec();
364 let rhs_points = args.into_iter().map(CurvePoint::from).collect_vec();
365
366 lhs_points
367 .into_iter()
368 .zip(rhs_points)
369 .map(|(x, y)| if party_id == PARTY0 { x - y } else { x })
370 .map(ResultValue::Point)
371 .collect_vec()
372 })
373 .into_iter()
374 .map(MpcPointResult::from)
375 .collect_vec()
376 }
377}
378
379impl<C: CurveGroup> Neg for &MpcPointResult<C> {
382 type Output = MpcPointResult<C>;
383
384 fn neg(self) -> Self::Output {
385 self.fabric()
386 .new_gate_op(vec![self.id()], |mut args| {
387 let mpc_val: CurvePoint<C> = args.remove(0).into();
388 ResultValue::Point(-mpc_val)
389 })
390 .into()
391 }
392}
393impl_borrow_variants!(MpcPointResult<C>, Neg, neg, -, C: CurveGroup);
394
395impl<C: CurveGroup> MpcPointResult<C> {
396 pub fn batch_neg(values: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
398 if values.is_empty() {
399 return Vec::new();
400 }
401
402 let n = values.len();
403 let fabric = values[0].fabric();
404 let all_ids = values.iter().map(|v| v.id()).collect_vec();
405
406 fabric
408 .new_batch_gate_op(all_ids, n , move |args| {
409 let points = args.into_iter().map(CurvePoint::from).collect_vec();
410
411 points
412 .into_iter()
413 .map(|x| -x)
414 .map(ResultValue::Point)
415 .collect_vec()
416 })
417 .into_iter()
418 .map(MpcPointResult::from)
419 .collect_vec()
420 }
421}
422
423impl<C: CurveGroup> Mul<&Scalar<C>> for &MpcPointResult<C> {
426 type Output = MpcPointResult<C>;
427
428 fn mul(self, rhs: &Scalar<C>) -> Self::Output {
429 let rhs = *rhs;
430 self.fabric()
431 .new_gate_op(vec![self.id()], move |args| {
432 let lhs: CurvePoint<C> = args[0].to_owned().into();
433 ResultValue::Point(lhs * rhs)
434 })
435 .into()
436 }
437}
438impl_borrow_variants!(MpcPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
439impl_commutative!(MpcPointResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
440
441impl<C: CurveGroup> Mul<&ScalarResult<C>> for &MpcPointResult<C> {
442 type Output = MpcPointResult<C>;
443
444 fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
445 self.fabric()
446 .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
447 let lhs: CurvePoint<C> = args.remove(0).into();
448 let rhs: Scalar<C> = args.remove(0).into();
449
450 ResultValue::Point(lhs * rhs)
451 })
452 .into()
453 }
454}
455impl_borrow_variants!(MpcPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
456impl_commutative!(MpcPointResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
457
impl<C: CurveGroup> Mul<&MpcScalarResult<C>> for &MpcPointResult<C> {
    type Output = MpcPointResult<C>;

    /// Multiply a shared point by a shared scalar using a Beaver triple.
    ///
    /// Let `x` be the shared scalar (`rhs`), `P` the shared point (`self`) and
    /// `G` the curve generator. The method takes one triple `(a, b, c)` from the
    /// fabric and opens two masked values:
    ///   - `d  = x - a`      (a scalar)
    ///   - `eG = P - b * G`  (a point)
    /// It returns the sharing of `d * eG + d * [bG] + [a] * eG + [c] * G`.
    /// Expanding `(d + a) * (eG + bG)` shows this equals `x * P`, provided
    /// `c = a * b`.
    /// NOTE(review): `c = a * b` is assumed to be guaranteed by
    /// `next_beaver_triple`; confirm in the fabric.
    ///
    /// Both masked values are opened with `open`, so this costs two exchanges
    /// with the counterparty.
    fn mul(self, rhs: &MpcScalarResult<C>) -> Self::Output {
        let generator = CurvePoint::generator();
        let (a, b, c) = self.fabric().next_beaver_triple();

        // Mask the scalar with `a` and the point with `b * G`
        let masked_rhs = rhs - &a;
        let masked_lhs = self - (&generator * &b);

        // `eG` mirrors the math notation for the opened masked point
        #[allow(non_snake_case)]
        let eG_open = masked_lhs.open();
        let d_open = masked_rhs.open();

        // [x * P] = d * eG + d * [bG] + [a] * eG + [c] * G
        // (public term first; the `+` with shared terms adds it on PARTY0 only)
        &d_open * &eG_open + &d_open * &(&generator * &b) + &a * eG_open + &c * generator
    }
}
impl_borrow_variants!(MpcPointResult<C>, Mul, mul, *, MpcScalarResult<C>, C: CurveGroup);
impl_commutative!(MpcPointResult<C>, Mul, mul, *, MpcScalarResult<C>, C:CurveGroup);
480
481impl<C: CurveGroup> MpcPointResult<C> {
482 #[allow(non_snake_case)]
484 pub fn batch_mul(a: &[MpcScalarResult<C>], b: &[MpcPointResult<C>]) -> Vec<MpcPointResult<C>> {
485 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
486 if a.is_empty() {
487 return Vec::new();
488 }
489
490 let n = a.len();
491 let fabric = a[0].fabric();
492
493 let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
495 let beaver_b_gen = MpcPointResult::batch_mul_generator(&beaver_b);
496
497 let masked_rhs = MpcScalarResult::batch_sub(a, &beaver_a);
498 let masked_lhs = MpcPointResult::batch_sub(b, &beaver_b_gen);
499
500 let eG_open = MpcPointResult::open_batch(&masked_lhs);
501 let d_open = MpcScalarResult::open_batch(&masked_rhs);
502
503 let deG = CurvePointResult::batch_mul(&d_open, &eG_open);
505 let dbG = MpcPointResult::batch_mul_public(&d_open, &beaver_b_gen);
506 let aeG = CurvePointResult::batch_mul_shared(&beaver_a, &eG_open);
507 let cG = MpcPointResult::batch_mul_generator(&beaver_c);
508
509 let de_db_G = MpcPointResult::batch_add_public(&dbG, &deG);
510 let ae_c_G = MpcPointResult::batch_add(&aeG, &cG);
511
512 MpcPointResult::batch_add(&de_db_G, &ae_c_G)
513 }
514
515 pub fn batch_mul_public(
517 a: &[ScalarResult<C>],
518 b: &[MpcPointResult<C>],
519 ) -> Vec<MpcPointResult<C>> {
520 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
521 if a.is_empty() {
522 return Vec::new();
523 }
524
525 let n = a.len();
526 let fabric = a[0].fabric();
527 let all_ids = a
528 .iter()
529 .map(|v| v.id())
530 .chain(b.iter().map(|b| b.id()))
531 .collect_vec();
532
533 fabric
535 .new_batch_gate_op(all_ids, n , move |mut args| {
536 let scalars = args.drain(..n).map(Scalar::from).collect_vec();
537 let points = args.into_iter().map(CurvePoint::from).collect_vec();
538
539 scalars
540 .into_iter()
541 .zip(points)
542 .map(|(x, y)| x * y)
543 .map(ResultValue::Point)
544 .collect_vec()
545 })
546 .into_iter()
547 .map(MpcPointResult::from)
548 .collect_vec()
549 }
550
551 pub fn batch_mul_generator(a: &[MpcScalarResult<C>]) -> Vec<MpcPointResult<C>> {
553 if a.is_empty() {
554 return Vec::new();
555 }
556
557 let n = a.len();
558 let fabric = a[0].fabric();
559 let all_ids = a.iter().map(|v| v.id()).collect_vec();
560
561 fabric
563 .new_batch_gate_op(all_ids, n , move |args| {
564 let scalars = args.into_iter().map(Scalar::from).collect_vec();
565 let generator = CurvePoint::generator();
566
567 scalars
568 .into_iter()
569 .map(|x| x * generator)
570 .map(ResultValue::Point)
571 .collect_vec()
572 })
573 .into_iter()
574 .map(MpcPointResult::from)
575 .collect_vec()
576 }
577}