1use std::ops::{Add, Mul, Neg, Sub};
5
6use itertools::Itertools;
7
8use crate::{
9 fabric::{ResultHandle, ResultValue},
10 network::NetworkPayload,
11 MpcFabric, ResultId, PARTY0,
12};
13
14use super::{
15 macros::{impl_borrow_variants, impl_commutative},
16 mpc_scalar::MpcScalarResult,
17 scalar::{Scalar, ScalarResult},
18 stark_curve::{BatchStarkPointResult, StarkPoint, StarkPointResult},
19};
20
21#[derive(Clone, Debug)]
23pub struct MpcStarkPointResult {
24 pub(crate) share: StarkPointResult,
26}
27
28impl From<StarkPointResult> for MpcStarkPointResult {
29 fn from(value: StarkPointResult) -> Self {
30 Self { share: value }
31 }
32}
33
34impl MpcStarkPointResult {
36 pub fn new_shared(value: StarkPointResult) -> MpcStarkPointResult {
38 MpcStarkPointResult { share: value }
39 }
40
41 pub fn id(&self) -> ResultId {
43 self.share.id
44 }
45
46 pub fn fabric(&self) -> &MpcFabric {
48 self.share.fabric()
49 }
50
51 pub fn open(&self) -> ResultHandle<StarkPoint> {
53 let send_my_share =
54 |args: Vec<ResultValue>| NetworkPayload::Point(args[0].to_owned().into());
55
56 let (share0, share1): (StarkPointResult, StarkPointResult) =
58 if self.fabric().party_id() == PARTY0 {
59 let party0_value = self.fabric().new_network_op(vec![self.id()], send_my_share);
60 let party1_value = self.fabric().receive_value();
61
62 (party0_value, party1_value)
63 } else {
64 let party0_value = self.fabric().receive_value();
65 let party1_value = self.fabric().new_network_op(vec![self.id()], send_my_share);
66
67 (party0_value, party1_value)
68 };
69
70 share0 + share1
71 }
72
73 pub fn open_batch(values: &[MpcStarkPointResult]) -> Vec<StarkPointResult> {
75 if values.is_empty() {
76 return Vec::new();
77 }
78
79 let n = values.len();
80 let fabric = &values[0].fabric();
81 let all_ids = values.iter().map(|v| v.id()).collect_vec();
82 let send_my_shares = |args: Vec<ResultValue>| {
83 NetworkPayload::PointBatch(args.into_iter().map(|arg| arg.into()).collect_vec())
84 };
85
86 let (party0_values, party1_values): (BatchStarkPointResult, BatchStarkPointResult) =
88 if fabric.party_id() == PARTY0 {
89 let party0_values = fabric.new_network_op(all_ids, send_my_shares);
90 let party1_values = fabric.receive_value();
91
92 (party0_values, party1_values)
93 } else {
94 let party0_values = fabric.receive_value();
95 let party1_values = fabric.new_network_op(all_ids, send_my_shares);
96
97 (party0_values, party1_values)
98 };
99
100 fabric.new_batch_gate_op(
102 vec![party0_values.id(), party1_values.id()],
103 n, |mut args| {
105 let party0_values: Vec<StarkPoint> = args.remove(0).into();
106 let party1_values: Vec<StarkPoint> = args.remove(0).into();
107
108 party0_values
109 .into_iter()
110 .zip(party1_values.into_iter())
111 .map(|(x, y)| x + y)
112 .map(ResultValue::Point)
113 .collect_vec()
114 },
115 )
116 }
117}
118
119impl Add<&StarkPoint> for &MpcStarkPointResult {
126 type Output = MpcStarkPointResult;
127
128 fn add(self, rhs: &StarkPoint) -> Self::Output {
130 let rhs = *rhs;
131 let party_id = self.fabric().party_id();
132 self.fabric()
133 .new_gate_op(vec![self.id()], move |args| {
134 let lhs: StarkPoint = args[0].to_owned().into();
135
136 if party_id == PARTY0 {
137 ResultValue::Point(lhs + rhs)
138 } else {
139 ResultValue::Point(lhs)
140 }
141 })
142 .into()
143 }
144}
145impl_borrow_variants!(MpcStarkPointResult, Add, add, +, StarkPoint);
146impl_commutative!(MpcStarkPointResult, Add, add, +, StarkPoint);
147
148impl Add<&StarkPointResult> for &MpcStarkPointResult {
149 type Output = MpcStarkPointResult;
150
151 fn add(self, rhs: &StarkPointResult) -> Self::Output {
153 let party_id = self.fabric().party_id();
154 self.fabric()
155 .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
156 let lhs: StarkPoint = args.remove(0).into();
157 let rhs: StarkPoint = args.remove(0).into();
158
159 if party_id == PARTY0 {
160 ResultValue::Point(lhs + rhs)
161 } else {
162 ResultValue::Point(lhs)
163 }
164 })
165 .into()
166 }
167}
168impl_borrow_variants!(MpcStarkPointResult, Add, add, +, StarkPointResult);
169impl_commutative!(MpcStarkPointResult, Add, add, +, StarkPointResult);
170
171impl Add<&MpcStarkPointResult> for &MpcStarkPointResult {
172 type Output = MpcStarkPointResult;
173
174 fn add(self, rhs: &MpcStarkPointResult) -> Self::Output {
175 self.fabric()
176 .new_gate_op(vec![self.id(), rhs.id()], |args| {
177 let lhs: StarkPoint = args[0].to_owned().into();
178 let rhs: StarkPoint = args[1].to_owned().into();
179
180 ResultValue::Point(lhs + rhs)
181 })
182 .into()
183 }
184}
185impl_borrow_variants!(MpcStarkPointResult, Add, add, +, MpcStarkPointResult);
186
187impl MpcStarkPointResult {
188 pub fn batch_add(
190 a: &[MpcStarkPointResult],
191 b: &[MpcStarkPointResult],
192 ) -> Vec<MpcStarkPointResult> {
193 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
194 if a.is_empty() {
195 return Vec::new();
196 }
197
198 let n = a.len();
199 let fabric = a[0].fabric();
200 let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
201
202 fabric
204 .new_batch_gate_op(all_ids, n , move |args| {
205 let points = args.into_iter().map(StarkPoint::from).collect_vec();
206 let (a, b) = points.split_at(n);
207
208 a.iter()
209 .zip(b.iter())
210 .map(|(x, y)| x + y)
211 .map(ResultValue::Point)
212 .collect_vec()
213 })
214 .into_iter()
215 .map(MpcStarkPointResult::from)
216 .collect_vec()
217 }
218
219 pub fn batch_add_public(
221 a: &[MpcStarkPointResult],
222 b: &[StarkPointResult],
223 ) -> Vec<MpcStarkPointResult> {
224 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
225 if a.is_empty() {
226 return Vec::new();
227 }
228
229 let n = a.len();
230 let fabric = a[0].fabric();
231 let all_ids = a
232 .iter()
233 .map(|v| v.id())
234 .chain(b.iter().map(|b| b.id))
235 .collect_vec();
236
237 let party_id = fabric.party_id();
239 fabric
240 .new_batch_gate_op(all_ids, n , move |mut args| {
241 let lhs_points = args.drain(..n).map(StarkPoint::from).collect_vec();
242 let rhs_points = args.into_iter().map(StarkPoint::from).collect_vec();
243
244 lhs_points
245 .into_iter()
246 .zip(rhs_points.into_iter())
247 .map(|(x, y)| if party_id == PARTY0 { x + y } else { x })
248 .map(ResultValue::Point)
249 .collect_vec()
250 })
251 .into_iter()
252 .map(MpcStarkPointResult::from)
253 .collect_vec()
254 }
255}
256
257impl Sub<&StarkPoint> for &MpcStarkPointResult {
260 type Output = MpcStarkPointResult;
261
262 fn sub(self, rhs: &StarkPoint) -> Self::Output {
264 let rhs = *rhs;
265 let party_id = self.fabric().party_id();
266 self.fabric()
267 .new_gate_op(vec![self.id()], move |args| {
268 let lhs: StarkPoint = args[0].to_owned().into();
269
270 if party_id == PARTY0 {
271 ResultValue::Point(lhs - rhs)
272 } else {
273 ResultValue::Point(lhs)
274 }
275 })
276 .into()
277 }
278}
279impl_borrow_variants!(MpcStarkPointResult, Sub, sub, -, StarkPoint);
280
281impl Sub<&StarkPointResult> for &MpcStarkPointResult {
282 type Output = MpcStarkPointResult;
283
284 fn sub(self, rhs: &StarkPointResult) -> Self::Output {
285 let party_id = self.fabric().party_id();
286 self.fabric()
287 .new_gate_op(vec![self.id(), rhs.id()], move |mut args| {
288 let lhs: StarkPoint = args.remove(0).into();
289 let rhs: StarkPoint = args.remove(0).into();
290
291 if party_id == PARTY0 {
292 ResultValue::Point(lhs - rhs)
293 } else {
294 ResultValue::Point(lhs)
295 }
296 })
297 .into()
298 }
299}
300impl_borrow_variants!(MpcStarkPointResult, Sub, sub, -, StarkPointResult);
301
302impl Sub<&MpcStarkPointResult> for &MpcStarkPointResult {
303 type Output = MpcStarkPointResult;
304
305 fn sub(self, rhs: &MpcStarkPointResult) -> Self::Output {
306 self.fabric()
307 .new_gate_op(vec![self.id(), rhs.id()], |args| {
308 let lhs: StarkPoint = args[0].to_owned().into();
309 let rhs: StarkPoint = args[1].to_owned().into();
310
311 ResultValue::Point(lhs - rhs)
312 })
313 .into()
314 }
315}
316impl_borrow_variants!(MpcStarkPointResult, Sub, sub, -, MpcStarkPointResult);
317
318impl MpcStarkPointResult {
319 pub fn batch_sub(
321 a: &[MpcStarkPointResult],
322 b: &[MpcStarkPointResult],
323 ) -> Vec<MpcStarkPointResult> {
324 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
325 if a.is_empty() {
326 return Vec::new();
327 }
328
329 let n = a.len();
330 let fabric = a[0].fabric();
331 let all_ids = a.iter().chain(b.iter()).map(|v| v.id()).collect_vec();
332
333 fabric
335 .new_batch_gate_op(all_ids, n , move |args| {
336 let points = args.into_iter().map(StarkPoint::from).collect_vec();
337 let (a, b) = points.split_at(n);
338
339 a.iter()
340 .zip(b.iter())
341 .map(|(x, y)| x - y)
342 .map(ResultValue::Point)
343 .collect_vec()
344 })
345 .into_iter()
346 .map(MpcStarkPointResult::from)
347 .collect_vec()
348 }
349
350 pub fn batch_sub_public(
352 a: &[MpcStarkPointResult],
353 b: &[StarkPointResult],
354 ) -> Vec<MpcStarkPointResult> {
355 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
356 if a.is_empty() {
357 return Vec::new();
358 }
359
360 let n = a.len();
361 let fabric = a[0].fabric();
362 let all_ids = a
363 .iter()
364 .map(|v| v.id())
365 .chain(b.iter().map(|b| b.id))
366 .collect_vec();
367
368 let party_id = fabric.party_id();
370 fabric
371 .new_batch_gate_op(all_ids, n , move |mut args| {
372 let lhs_points = args.drain(..n).map(StarkPoint::from).collect_vec();
373 let rhs_points = args.into_iter().map(StarkPoint::from).collect_vec();
374
375 lhs_points
376 .into_iter()
377 .zip(rhs_points.into_iter())
378 .map(|(x, y)| if party_id == PARTY0 { x - y } else { x })
379 .map(ResultValue::Point)
380 .collect_vec()
381 })
382 .into_iter()
383 .map(MpcStarkPointResult::from)
384 .collect_vec()
385 }
386}
387
388impl Neg for &MpcStarkPointResult {
391 type Output = MpcStarkPointResult;
392
393 fn neg(self) -> Self::Output {
394 self.fabric()
395 .new_gate_op(vec![self.id()], |mut args| {
396 let mpc_val: StarkPoint = args.remove(0).into();
397 ResultValue::Point(-mpc_val)
398 })
399 .into()
400 }
401}
402impl_borrow_variants!(MpcStarkPointResult, Neg, neg, -);
403
404impl MpcStarkPointResult {
405 pub fn batch_neg(values: &[MpcStarkPointResult]) -> Vec<MpcStarkPointResult> {
407 if values.is_empty() {
408 return Vec::new();
409 }
410
411 let n = values.len();
412 let fabric = values[0].fabric();
413 let all_ids = values.iter().map(|v| v.id()).collect_vec();
414
415 fabric
417 .new_batch_gate_op(all_ids, n , move |args| {
418 let points = args.into_iter().map(StarkPoint::from).collect_vec();
419
420 points
421 .into_iter()
422 .map(|x| -x)
423 .map(ResultValue::Point)
424 .collect_vec()
425 })
426 .into_iter()
427 .map(MpcStarkPointResult::from)
428 .collect_vec()
429 }
430}
431
432impl Mul<&Scalar> for &MpcStarkPointResult {
435 type Output = MpcStarkPointResult;
436
437 fn mul(self, rhs: &Scalar) -> Self::Output {
438 let rhs = *rhs;
439 self.fabric()
440 .new_gate_op(vec![self.id()], move |args| {
441 let lhs: StarkPoint = args[0].to_owned().into();
442 ResultValue::Point(lhs * rhs)
443 })
444 .into()
445 }
446}
447impl_borrow_variants!(MpcStarkPointResult, Mul, mul, *, Scalar);
448impl_commutative!(MpcStarkPointResult, Mul, mul, *, Scalar);
449
450impl Mul<&ScalarResult> for &MpcStarkPointResult {
451 type Output = MpcStarkPointResult;
452
453 fn mul(self, rhs: &ScalarResult) -> Self::Output {
454 self.fabric()
455 .new_gate_op(vec![self.id(), rhs.id()], |mut args| {
456 let lhs: StarkPoint = args.remove(0).into();
457 let rhs: Scalar = args.remove(0).into();
458
459 ResultValue::Point(lhs * rhs)
460 })
461 .into()
462 }
463}
464impl_borrow_variants!(MpcStarkPointResult, Mul, mul, *, ScalarResult);
465impl_commutative!(MpcStarkPointResult, Mul, mul, *, ScalarResult);
466
467impl Mul<&MpcScalarResult> for &MpcStarkPointResult {
468 type Output = MpcStarkPointResult;
469
470 fn mul(self, rhs: &MpcScalarResult) -> Self::Output {
472 let generator = StarkPoint::generator();
473 let (a, b, c) = self.fabric().next_beaver_triple();
474
475 let masked_rhs = rhs - &a;
477 let masked_lhs = self - (&generator * &b);
478
479 #[allow(non_snake_case)]
480 let eG_open = masked_lhs.open();
481 let d_open = masked_rhs.open();
482
483 &d_open * &eG_open + &d_open * &(&generator * &b) + &a * eG_open + &c * generator
485 }
486}
487impl_borrow_variants!(MpcStarkPointResult, Mul, mul, *, MpcScalarResult);
488impl_commutative!(MpcStarkPointResult, Mul, mul, *, MpcScalarResult);
489
490impl MpcStarkPointResult {
491 #[allow(non_snake_case)]
493 pub fn batch_mul(a: &[MpcScalarResult], b: &[MpcStarkPointResult]) -> Vec<MpcStarkPointResult> {
494 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
495 if a.is_empty() {
496 return Vec::new();
497 }
498
499 let n = a.len();
500 let fabric = a[0].fabric();
501
502 let (beaver_a, beaver_b, beaver_c) = fabric.next_beaver_triple_batch(n);
504 let beaver_b_gen = MpcStarkPointResult::batch_mul_generator(&beaver_b);
505
506 let masked_rhs = MpcScalarResult::batch_sub(a, &beaver_a);
507 let masked_lhs = MpcStarkPointResult::batch_sub(b, &beaver_b_gen);
508
509 let eG_open = MpcStarkPointResult::open_batch(&masked_lhs);
510 let d_open = MpcScalarResult::open_batch(&masked_rhs);
511
512 let deG = StarkPointResult::batch_mul(&d_open, &eG_open);
514 let dbG = MpcStarkPointResult::batch_mul_public(&d_open, &beaver_b_gen);
515 let aeG = StarkPointResult::batch_mul_shared(&beaver_a, &eG_open);
516 let cG = MpcStarkPointResult::batch_mul_generator(&beaver_c);
517
518 let de_db_G = MpcStarkPointResult::batch_add_public(&dbG, &deG);
519 let ae_c_G = MpcStarkPointResult::batch_add(&aeG, &cG);
520
521 MpcStarkPointResult::batch_add(&de_db_G, &ae_c_G)
522 }
523
524 pub fn batch_mul_public(
526 a: &[ScalarResult],
527 b: &[MpcStarkPointResult],
528 ) -> Vec<MpcStarkPointResult> {
529 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
530 if a.is_empty() {
531 return Vec::new();
532 }
533
534 let n = a.len();
535 let fabric = a[0].fabric();
536 let all_ids = a
537 .iter()
538 .map(|v| v.id())
539 .chain(b.iter().map(|b| b.id()))
540 .collect_vec();
541
542 fabric
544 .new_batch_gate_op(all_ids, n , move |mut args| {
545 let scalars = args.drain(..n).map(Scalar::from).collect_vec();
546 let points = args.into_iter().map(StarkPoint::from).collect_vec();
547
548 scalars
549 .into_iter()
550 .zip(points.into_iter())
551 .map(|(x, y)| x * y)
552 .map(ResultValue::Point)
553 .collect_vec()
554 })
555 .into_iter()
556 .map(MpcStarkPointResult::from)
557 .collect_vec()
558 }
559
560 pub fn batch_mul_generator(a: &[MpcScalarResult]) -> Vec<MpcStarkPointResult> {
562 if a.is_empty() {
563 return Vec::new();
564 }
565
566 let n = a.len();
567 let fabric = a[0].fabric();
568 let all_ids = a.iter().map(|v| v.id()).collect_vec();
569
570 fabric
572 .new_batch_gate_op(all_ids, n , move |args| {
573 let scalars = args.into_iter().map(Scalar::from).collect_vec();
574 let generator = StarkPoint::generator();
575
576 scalars
577 .into_iter()
578 .map(|x| x * generator)
579 .map(ResultValue::Point)
580 .collect_vec()
581 })
582 .into_iter()
583 .map(MpcStarkPointResult::from)
584 .collect_vec()
585 }
586}