1use ark_ec::CurveGroup;
5use itertools::Itertools;
6
7use crate::{algebra::Scalar, network::PartyId};
8
9pub trait SharedValueSource<C: CurveGroup>: Send + Sync {
15 fn next_shared_bit(&mut self) -> Scalar<C>;
17 fn next_shared_bit_batch(&mut self, num_values: usize) -> Vec<Scalar<C>> {
19 (0..num_values)
20 .map(|_| self.next_shared_bit())
21 .collect_vec()
22 }
23 fn next_shared_value(&mut self) -> Scalar<C>;
25 fn next_shared_value_batch(&mut self, num_values: usize) -> Vec<Scalar<C>> {
27 (0..num_values)
28 .map(|_| self.next_shared_value())
29 .collect_vec()
30 }
31 fn next_shared_inverse_pair(&mut self) -> (Scalar<C>, Scalar<C>);
34 fn next_shared_inverse_pair_batch(
36 &mut self,
37 num_pairs: usize,
38 ) -> (Vec<Scalar<C>>, Vec<Scalar<C>>) {
39 (0..num_pairs)
40 .map(|_| self.next_shared_inverse_pair())
41 .unzip()
42 }
43 fn next_triplet(&mut self) -> (Scalar<C>, Scalar<C>, Scalar<C>);
45 #[allow(clippy::type_complexity)]
47 fn next_triplet_batch(
48 &mut self,
49 num_triplets: usize,
50 ) -> (Vec<Scalar<C>>, Vec<Scalar<C>>, Vec<Scalar<C>>) {
51 let mut a_vals = Vec::with_capacity(num_triplets);
52 let mut b_vals = Vec::with_capacity(num_triplets);
53 let mut c_vals = Vec::with_capacity(num_triplets);
54
55 for _ in 0..num_triplets {
56 let (a, b, c) = self.next_triplet();
57 a_vals.push(a);
58 b_vals.push(b);
59 c_vals.push(c);
60 }
61
62 (a_vals, b_vals, c_vals)
63 }
64}
#[cfg(any(feature = "test_helpers", test))]
/// A test-only source whose shares are all derived from the local party's ID.
///
/// Only available with the `test_helpers` feature or in test builds.
#[derive(Debug, Clone, Default)]
pub struct PartyIDBeaverSource {
    /// The ID of the local party; every share this source produces is
    /// computed from it.
    party_id: u64,
}
73
#[cfg(any(feature = "test_helpers", test))]
impl PartyIDBeaverSource {
    /// Builds a source for the party identified by `party_id`.
    pub fn new(party_id: u64) -> Self {
        PartyIDBeaverSource { party_id }
    }
}
81
#[cfg(any(feature = "test_helpers", test))]
impl<C: CurveGroup> SharedValueSource<C> for PartyIDBeaverSource {
    /// Uses the party ID itself as this party's share of a bit.
    ///
    /// # Panics
    ///
    /// Panics if the party ID is neither 0 nor 1.
    fn next_shared_bit(&mut self) -> Scalar<C> {
        // Only ids 0 and 1 are accepted, so every share is itself a 0/1 value.
        assert!(self.party_id == 0 || self.party_id == 1);
        Scalar::from(self.party_id)
    }

    /// Returns a fixed triplet share: `(1, 3, 2)` for party 0 and
    /// `(1, 0, 4)` for any other party ID.
    fn next_triplet(&mut self) -> (Scalar<C>, Scalar<C>, Scalar<C>) {
        // Assuming additive sharing between parties 0 and 1, these reconstruct
        // to a = 2, b = 3, c = 6, which satisfies a * b == c.
        let (a, b, c): (u64, u64, u64) = match self.party_id {
            0 => (1, 3, 2),
            _ => (1, 0, 4),
        };
        (Scalar::from(a), Scalar::from(b), Scalar::from(c))
    }

    /// Uses the party ID as the share of both the value and its inverse.
    fn next_shared_inverse_pair(&mut self) -> (Scalar<C>, Scalar<C>) {
        // With parties 0 and 1 both components presumably reconstruct to 1,
        // which is its own inverse.
        let share = Scalar::from(self.party_id);
        (share, share)
    }

    /// Uses the party ID as this party's share of the value.
    fn next_shared_value(&mut self) -> Scalar<C> {
        Scalar::from(self.party_id)
    }
}
109
#[cfg(any(feature = "test_helpers", test))]
/// A test-only source that hands out zero shares for bits, values and
/// triplets.
///
/// Inverse pairs are the exception: they are derived from the party ID,
/// because zero has no multiplicative inverse.
///
/// Derives `Clone` and `Debug` to match `PartyIDBeaverSource`, so instances
/// can be duplicated across parties and printed in test failure output.
#[derive(Clone, Debug)]
pub struct ZeroBeaverSource {
    /// The ID of the local party; only used to build inverse-pair shares.
    party_id: PartyId,
}
116
#[cfg(any(feature = "test_helpers", test))]
impl ZeroBeaverSource {
    /// Builds a zero-share source for the party identified by `party_id`.
    pub fn new(party_id: PartyId) -> Self {
        ZeroBeaverSource { party_id }
    }
}
124
#[cfg(any(feature = "test_helpers", test))]
impl<C: CurveGroup> SharedValueSource<C> for ZeroBeaverSource {
    /// Every shared bit is a zero share.
    fn next_shared_bit(&mut self) -> Scalar<C> {
        Scalar::zero()
    }

    /// Every triplet is made of zero shares; since `0 * 0 == 0` the triplet
    /// remains consistent.
    fn next_triplet(&mut self) -> (Scalar<C>, Scalar<C>, Scalar<C>) {
        let zero = Scalar::zero();
        (zero, zero, zero)
    }

    /// Unlike the other values, the inverse pair cannot be all zeros, because
    /// zero has no multiplicative inverse.
    ///
    /// Instead each party uses its own ID as the share of both components.
    /// With parties 0 and 1 these presumably reconstruct to 1, which is its
    /// own inverse.
    fn next_shared_inverse_pair(&mut self) -> (Scalar<C>, Scalar<C>) {
        (Scalar::from(self.party_id), Scalar::from(self.party_id))
    }

    /// Every shared value is a zero share.
    fn next_shared_value(&mut self) -> Scalar<C> {
        Scalar::zero()
    }
}