// mpc_stark/fabric.rs

//! Defines an MPC fabric for the protocol
//!
//! The fabric essentially acts as a dependency injection layer. That is, the MpcFabric
//! creates and manages dependencies needed to allocate network values. This provides a
//! cleaner interface for consumers of the library; i.e. clients do not have to hold onto
//! references of the network layer or the beaver sources to allocate values.
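//!
//! A minimal usage sketch follows; `DummyNetwork` and `DummyBeaverSource` are placeholder
//! implementations of `MpcNetwork` and `SharedValueSource` assumed for illustration, not
//! types provided by this crate:
//!
//! ```ignore
//! // Build a fabric from a network implementation and a beaver source
//! let fabric = MpcFabric::new(DummyNetwork::new(), DummyBeaverSource::new());
//!
//! // Party 0 secret shares a value; both parties receive a handle to an
//! // authenticated share of that value
//! let shared = fabric.share_scalar(Scalar::one(), PARTY0);
//!
//! // Tear down the fabric once the computation is complete
//! fabric.shutdown();
//! ```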

mod executor;
mod network_sender;
mod result;

#[cfg(feature = "benchmarks")]
pub use executor::{Executor, ExecutorMessage};
#[cfg(not(feature = "benchmarks"))]
use executor::{Executor, ExecutorMessage};
use rand::thread_rng;
pub use result::{ResultHandle, ResultId, ResultValue};

use futures::executor::block_on;
use tracing::log;

use crossbeam::queue::SegQueue;
use kanal::Sender as KanalSender;
use std::{
    fmt::{Debug, Formatter, Result as FmtResult},
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
};
use tokio::sync::broadcast::{self, Sender as BroadcastSender};

use itertools::Itertools;

use crate::{
    algebra::{
        authenticated_scalar::AuthenticatedScalarResult,
        authenticated_stark_point::AuthenticatedStarkPointResult,
        mpc_scalar::MpcScalarResult,
        mpc_stark_point::MpcStarkPointResult,
        scalar::{BatchScalarResult, Scalar, ScalarResult},
        stark_curve::{BatchStarkPointResult, StarkPoint, StarkPointResult},
    },
    beaver::SharedValueSource,
    network::{MpcNetwork, NetworkOutbound, NetworkPayload, PartyId},
    PARTY0,
};

use self::{
    network_sender::NetworkSender,
    result::{OpResult, ResultWaiter},
};

/// The ID of the result hardcoded to zero
const RESULT_ZERO: ResultId = 0;
/// The ID of the result hardcoded to one
const RESULT_ONE: ResultId = 1;
/// The ID of the result hardcoded to the curve identity point
const RESULT_IDENTITY: ResultId = 2;

/// The number of constant results allocated in the fabric, i.e. those defined above
const N_CONSTANT_RESULTS: usize = 3;

/// The default size hint to give the fabric for buffer pre-allocation
const DEFAULT_SIZE_HINT: usize = 10_000;

/// A type alias for the identifier used for a gate
pub type OperationId = usize;

/// An operation within the network, describes the arguments and function to evaluate
/// once the arguments are ready
///
/// The `output_arity` field gives the number of results that this operation outputs
#[derive(Clone)]
pub struct Operation {
    /// The identifier of this operation
    id: OperationId,
    /// The result ID of the first result in the outputs
    result_id: ResultId,
    /// The number of outputs this operation produces
    output_arity: usize,
    /// The number of arguments that are still in-flight for this operation
    inflight_args: usize,
    /// The IDs of the inputs to this operation
    args: Vec<ResultId>,
    /// The type of the operation
    op_type: OperationType,
}

impl Operation {
    /// Get the result IDs for an operation
    pub fn result_ids(&self) -> Vec<ResultId> {
        (self.result_id..self.result_id + self.output_arity).collect_vec()
    }
}

impl Debug for Operation {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "Operation {}", self.id)
    }
}

/// Defines the different types of operations available in the computation graph
pub enum OperationType {
    /// A gate operation; may be evaluated locally given its ready inputs
    Gate {
        /// The function to apply to the inputs
        function: Box<dyn FnOnce(Vec<ResultValue>) -> ResultValue + Send + Sync>,
    },
    /// A gate operation that has output arity greater than one
    ///
    /// We separate this out to avoid vector allocation for result values of arity one
    GateBatch {
        /// The function to apply to the inputs
        function: Box<dyn FnOnce(Vec<ResultValue>) -> Vec<ResultValue> + Send + Sync>,
    },
    /// A network operation, requires that a value be sent over the network
    Network {
        /// The function to apply to the inputs to derive a Network payload
        function: Box<dyn FnOnce(Vec<ResultValue>) -> NetworkPayload + Send + Sync>,
    },
}

/// A clone implementation; never actually invoked, but required so that buffer space
/// for `Operation`s can be pre-allocated
impl Clone for OperationType {
    fn clone(&self) -> Self {
        panic!("cannot clone `OperationType`")
    }
}

impl Debug for OperationType {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            OperationType::Gate { .. } => write!(f, "Gate"),
            OperationType::GateBatch { .. } => write!(f, "GateBatch"),
            OperationType::Network { .. } => write!(f, "Network"),
        }
    }
}

/// A fabric for the MPC protocol, defines a dependency injection layer that dynamically schedules
/// circuit gate evaluations onto the network to be executed
///
/// The fabric does not block on gate evaluations, but instead returns a handle to a future result
/// that may be polled to obtain the materialized result. This allows the application layer to
/// continue using the fabric, scheduling more gates to be evaluated and maximally exploiting
/// gate-level parallelism within the circuit
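///
/// A sketch of the non-blocking interface follows. It assumes that `ResultHandle<T>`
/// implements `Future` yielding the materialized `T`, that `a` and `b` are previously
/// allocated `ScalarResult` handles whose raw result IDs are readable as `a.id` and
/// `b.id`, and that the usual `From<ResultValue>` conversion exists for `Scalar`:
///
/// ```ignore
/// // Schedule a gate that adds two public scalars; the call returns a handle
/// // immediately rather than waiting for the gate to be evaluated
/// let sum: ScalarResult = fabric.new_gate_op(vec![a.id, b.id], |mut args| {
///     let lhs: Scalar = args.remove(0).into();
///     let rhs: Scalar = args.remove(0).into();
///     ResultValue::Scalar(lhs + rhs)
/// });
///
/// // More gates may be scheduled while `sum` is in flight; awaiting (or blocking on)
/// // the handle yields the result once the executor evaluates the gate
/// let result: Scalar = block_on(sum);
/// ```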
#[derive(Clone)]
pub struct MpcFabric {
    /// The inner fabric
    #[cfg(not(feature = "benchmarks"))]
    inner: Arc<FabricInner>,
    /// The inner fabric, accessible publicly for benchmark mocking
    #[cfg(feature = "benchmarks")]
    pub inner: Arc<FabricInner>,
    /// The local party's share of the global MAC key
    ///
    /// The parties collectively hold an additive sharing of the global key
    ///
    /// We wrap in a reference counting structure to avoid recursive type issues
    #[cfg(not(feature = "benchmarks"))]
    mac_key: Option<Arc<MpcScalarResult>>,
    /// The MAC key, accessible publicly for benchmark mocking
    #[cfg(feature = "benchmarks")]
    pub mac_key: Option<Arc<MpcScalarResult>>,
    /// The channel on which shutdown messages are sent to blocking workers
    #[cfg(not(feature = "benchmarks"))]
    shutdown: BroadcastSender<()>,
    /// The shutdown channel, made publicly available for benchmark mocking
    #[cfg(feature = "benchmarks")]
    pub shutdown: BroadcastSender<()>,
}

impl Debug for MpcFabric {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "MpcFabric")
    }
}

/// The inner component of the fabric, allows the constructor to allocate executor and network
/// sender objects at the same level as the fabric
#[derive(Clone)]
pub struct FabricInner {
    /// The ID of the local party in the MPC execution
    party_id: u64,
    /// The next identifier to assign to a result
    next_result_id: Arc<AtomicUsize>,
    /// The next identifier to assign to an operation
    next_op_id: Arc<AtomicUsize>,
    /// A sender to the executor
    execution_queue: Arc<SegQueue<ExecutorMessage>>,
    /// The underlying queue to the network
    outbound_queue: KanalSender<NetworkOutbound>,
    /// The underlying shared randomness source
    beaver_source: Arc<Mutex<Box<dyn SharedValueSource>>>,
}

impl Debug for FabricInner {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "FabricInner")
    }
}

impl FabricInner {
    /// Constructor
    pub fn new<S: 'static + SharedValueSource>(
        party_id: u64,
        execution_queue: Arc<SegQueue<ExecutorMessage>>,
        outbound_queue: KanalSender<NetworkOutbound>,
        beaver_source: S,
    ) -> Self {
        // Allocate a zero and a one as well as the curve identity in the fabric to begin,
        // for convenience
        let zero = ResultValue::Scalar(Scalar::zero());
        let one = ResultValue::Scalar(Scalar::one());
        let identity = ResultValue::Point(StarkPoint::identity());

        for initial_result in vec![
            OpResult {
                id: RESULT_ZERO,
                value: zero,
            },
            OpResult {
                id: RESULT_ONE,
                value: one,
            },
            OpResult {
                id: RESULT_IDENTITY,
                value: identity,
            },
        ]
        .into_iter()
        {
            execution_queue.push(ExecutorMessage::Result(initial_result));
        }

        let next_result_id = Arc::new(AtomicUsize::new(N_CONSTANT_RESULTS));
        let next_op_id = Arc::new(AtomicUsize::new(0));

        Self {
            party_id,
            next_result_id,
            next_op_id,
            execution_queue,
            outbound_queue,
            beaver_source: Arc::new(Mutex::new(Box::new(beaver_source))),
        }
    }

    /// Register a waiter on a result
    pub(crate) fn register_waiter(&self, waiter: ResultWaiter) {
        self.execution_queue
            .push(ExecutorMessage::NewWaiter(waiter));
    }

    /// Shutdown the inner fabric, by sending a shutdown message to the executor
    pub(crate) fn shutdown(&self) {
        self.execution_queue.push(ExecutorMessage::Shutdown)
    }

    // -----------
    // | Getters |
    // -----------

    /// Increment the result counter and return the existing value
    fn new_result_id(&self) -> ResultId {
        self.next_result_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Increment the operation counter and return the existing value
    fn new_op_id(&self) -> OperationId {
        self.next_op_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Get the hardcoded zero value in the fabric
    pub(crate) fn zero(&self) -> ResultId {
        RESULT_ZERO
    }

    /// Get the hardcoded one value in the fabric
    pub(crate) fn one(&self) -> ResultId {
        RESULT_ONE
    }

    /// Get the hardcoded curve identity value in the fabric
    pub(crate) fn curve_identity(&self) -> ResultId {
        RESULT_IDENTITY
    }

    // ------------------------
    // | Low Level Allocation |
    // ------------------------

    /// Allocate a new plaintext value in the fabric
    pub(crate) fn allocate_value(&self, value: ResultValue) -> ResultId {
        // Forward the result to the executor
        let id = self.new_result_id();
        self.execution_queue
            .push(ExecutorMessage::Result(OpResult { id, value }));

        id
    }

    /// Allocate a secret shared value in the network
    pub(crate) fn allocate_shared_value(
        &self,
        my_share: ResultValue,
        their_share: ResultValue,
    ) -> ResultId {
        // Forward the local party's share to the executor
        let id = self.new_result_id();
        self.execution_queue.push(ExecutorMessage::Result(OpResult {
            id,
            value: my_share,
        }));

        // Send the counterparty their share
        if let Err(e) = self.outbound_queue.send(NetworkOutbound {
            result_id: id,
            payload: their_share.into(),
        }) {
            log::error!("error sending share to counterparty: {e:?}");
        }

        id
    }

    /// Receive a value from a network operation initiated by a peer
    ///
    /// The peer will already send the value with the corresponding ID, so all that is needed
    /// is to allocate a slot in the result buffer for the receipt
    pub(crate) fn receive_value(&self) -> ResultId {
        self.new_result_id()
    }

    // --------------
    // | Operations |
    // --------------

    /// Allocate a new in-flight gate operation in the fabric
    pub(crate) fn new_op(
        &self,
        args: Vec<ResultId>,
        output_arity: usize,
        op_type: OperationType,
    ) -> Vec<ResultId> {
        if matches!(op_type, OperationType::Gate { .. }) {
            assert_eq!(output_arity, 1, "gate operations must have arity 1");
        }

        // Allocate IDs for the results
        let ids = (0..output_arity)
            .map(|_| self.new_result_id())
            .collect_vec();

        // Build the operation
        let op = Operation {
            id: self.new_op_id(),
            result_id: ids[0],
            output_arity,
            args,
            inflight_args: 0,
            op_type,
        };

        // Forward the op to the executor
        self.execution_queue.push(ExecutorMessage::Op(op));
        ids
    }
}

impl MpcFabric {
    /// Constructor
    pub fn new<N: 'static + MpcNetwork, S: 'static + SharedValueSource>(
        network: N,
        beaver_source: S,
    ) -> Self {
        Self::new_with_size_hint(DEFAULT_SIZE_HINT, network, beaver_source)
    }

    /// Constructor that takes an additional size hint, indicating how much buffer space
    /// the fabric should allocate for results. The size is given in number of gates
    pub fn new_with_size_hint<N: 'static + MpcNetwork, S: 'static + SharedValueSource>(
        size_hint: usize,
        network: N,
        beaver_source: S,
    ) -> Self {
        // Build communication primitives
        let execution_queue = Arc::new(SegQueue::new());

        let (outbound_sender, outbound_receiver) = kanal::unbounded_async();
        let (shutdown_sender, shutdown_receiver) = broadcast::channel(1 /* capacity */);

        // Build a fabric
        let fabric = FabricInner::new(
            network.party_id(),
            execution_queue.clone(),
            outbound_sender.to_sync(),
            beaver_source,
        );

        // Start a network sender and an operation executor
        let network_sender = NetworkSender::new(
            outbound_receiver,
            execution_queue.clone(),
            network,
            shutdown_receiver,
        );
        tokio::task::spawn_blocking(move || block_on(network_sender.run()));

        let executor = Executor::new(size_hint, execution_queue, fabric.clone());
        tokio::task::spawn_blocking(move || executor.run());

        // Create the fabric and fill in the MAC key after
        let mut self_ = Self {
            inner: Arc::new(fabric.clone()),
            shutdown: shutdown_sender,
            mac_key: None,
        };

        // Sample a MAC key from the pre-shared values in the beaver source
        let mac_key_id = fabric.allocate_value(ResultValue::Scalar(
            fabric
                .beaver_source
                .lock()
                .expect("beaver source poisoned")
                .next_shared_value(),
        ));
        let mac_key = MpcScalarResult::new_shared(ResultHandle::new(mac_key_id, self_.clone()));

        // Set the MAC key
        self_.mac_key.replace(Arc::new(mac_key));

        self_
    }

    /// Get the party ID of the local party
    pub fn party_id(&self) -> PartyId {
        self.inner.party_id
    }

    /// Shutdown the fabric and the threads it has spawned
    pub fn shutdown(self) {
        log::debug!("shutting down fabric");
        self.inner.shutdown();
        self.shutdown
            .send(())
            .expect("error sending shutdown signal");
    }

    /// Register a waiter on a result
    pub fn register_waiter(&self, waiter: ResultWaiter) {
        self.inner.register_waiter(waiter);
    }

    /// Immutably borrow the MAC key
    pub(crate) fn borrow_mac_key(&self) -> &MpcScalarResult {
        // Unwrap is safe, the constructor sets the MAC key
        self.mac_key.as_ref().unwrap()
    }

    // ------------------------
    // | Constants Allocation |
    // ------------------------

    /// Get the hardcoded zero wire as a raw `ScalarResult`
    pub fn zero(&self) -> ScalarResult {
        ResultHandle::new(self.inner.zero(), self.clone())
    }

    /// Get the shared zero value as an `MpcScalarResult`
    fn zero_shared(&self) -> MpcScalarResult {
        MpcScalarResult::new_shared(self.zero())
    }

    /// Get the hardcoded zero wire as an `AuthenticatedScalarResult`
    ///
    /// Both parties hold the share 0 directly in this case
    pub fn zero_authenticated(&self) -> AuthenticatedScalarResult {
        let zero_value = self.zero();
        let share_value = self.zero_shared();
        let mac_value = self.zero_shared();

        AuthenticatedScalarResult {
            share: share_value,
            mac: mac_value,
            public_modifier: zero_value,
        }
    }

    /// Get a batch of references to the zero wire as an `AuthenticatedScalarResult`
    pub fn zeros_authenticated(&self, n: usize) -> Vec<AuthenticatedScalarResult> {
        let val = self.zero_authenticated();
        (0..n).map(|_| val.clone()).collect_vec()
    }

    /// Get the hardcoded one wire as a raw `ScalarResult`
    pub fn one(&self) -> ScalarResult {
        ResultHandle::new(self.inner.one(), self.clone())
    }

    /// Get the hardcoded shared one wire as an `MpcScalarResult`
    fn one_shared(&self) -> MpcScalarResult {
        MpcScalarResult::new_shared(self.one())
    }

    /// Get the hardcoded one wire as an `AuthenticatedScalarResult`
    ///
    /// Party 0 holds the value zero and party 1 holds the value one
    pub fn one_authenticated(&self) -> AuthenticatedScalarResult {
        if self.party_id() == PARTY0 {
            let zero_value = self.zero();
            let share_value = self.zero_shared();
            let mac_value = self.zero_shared();

            AuthenticatedScalarResult {
                share: share_value,
                mac: mac_value,
                public_modifier: zero_value,
            }
        } else {
            let zero_value = self.zero();
            let share_value = self.one_shared();
            let mac_value = self.borrow_mac_key().clone();

            AuthenticatedScalarResult {
                share: share_value,
                mac: mac_value,
                public_modifier: zero_value,
            }
        }
    }

    /// Get a batch of references to the one wire as an `AuthenticatedScalarResult`
    pub fn ones_authenticated(&self, n: usize) -> Vec<AuthenticatedScalarResult> {
        let val = self.one_authenticated();
        (0..n).map(|_| val.clone()).collect_vec()
    }

    /// Get the hardcoded curve identity wire as a raw `StarkPoint`
    pub fn curve_identity(&self) -> ResultHandle<StarkPoint> {
        ResultHandle::new(self.inner.curve_identity(), self.clone())
    }

    /// Get the hardcoded shared curve identity wire as an `MpcStarkPointResult`
    fn curve_identity_shared(&self) -> MpcStarkPointResult {
        MpcStarkPointResult::new_shared(self.curve_identity())
    }

    /// Get the hardcoded curve identity wire as an `AuthenticatedStarkPointResult`
    ///
    /// Both parties hold the identity point directly in this case
    pub fn curve_identity_authenticated(&self) -> AuthenticatedStarkPointResult {
        let identity_val = self.curve_identity();
        let share_value = self.curve_identity_shared();
        let mac_value = self.curve_identity_shared();

        AuthenticatedStarkPointResult {
            share: share_value,
            mac: mac_value,
            public_modifier: identity_val,
        }
    }

    // -------------------
    // | Wire Allocation |
    // -------------------

    /// Allocate a shared value in the fabric
    fn allocate_shared_value<T: From<ResultValue>>(
        &self,
        my_share: ResultValue,
        their_share: ResultValue,
    ) -> ResultHandle<T> {
        let id = self.inner.allocate_shared_value(my_share, their_share);
        ResultHandle::new(id, self.clone())
    }

    /// Share a `Scalar` value with the counterparty
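    ///
    /// A short usage sketch: both parties call the method with the same `sender`; the
    /// sending party's argument is the secret input, while the receiving party's
    /// argument is ignored and a receive slot is allocated instead:
    ///
    /// ```ignore
    /// let shared = fabric.share_scalar(Scalar::one(), PARTY0);
    /// ```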
    pub fn share_scalar<T: Into<Scalar>>(
        &self,
        val: T,
        sender: PartyId,
    ) -> AuthenticatedScalarResult {
        let scalar: ScalarResult = if self.party_id() == sender {
            let scalar_val = val.into();
            let mut rng = thread_rng();
            let random = Scalar::random(&mut rng);

            let (my_share, their_share) = (scalar_val - random, random);
            self.allocate_shared_value(
                ResultValue::Scalar(my_share),
                ResultValue::Scalar(their_share),
            )
        } else {
            self.receive_value()
        };

        AuthenticatedScalarResult::new_shared(scalar)
    }

    /// Share a batch of `Scalar` values with the counterparty
    pub fn batch_share_scalar<T: Into<Scalar>>(
        &self,
        vals: Vec<T>,
        sender: PartyId,
    ) -> Vec<AuthenticatedScalarResult> {
        let n = vals.len();
        let shares: BatchScalarResult = if self.party_id() == sender {
            let vals = vals.into_iter().map(|val| val.into()).collect_vec();
            let mut rng = thread_rng();

            let peer_shares = (0..vals.len())
                .map(|_| Scalar::random(&mut rng))
                .collect_vec();
            let my_shares = vals
                .iter()
                .zip(peer_shares.iter())
                .map(|(val, share)| val - share)
                .collect_vec();

            self.allocate_shared_value(
                ResultValue::ScalarBatch(my_shares),
                ResultValue::ScalarBatch(peer_shares),
            )
        } else {
            self.receive_value()
        };

        AuthenticatedScalarResult::new_shared_from_batch_result(shares, n)
    }

    /// Share a `StarkPoint` value with the counterparty
    pub fn share_point(&self, val: StarkPoint, sender: PartyId) -> AuthenticatedStarkPointResult {
        let point: StarkPointResult = if self.party_id() == sender {
            // As mentioned in https://eprint.iacr.org/2009/226.pdf
            // it is okay to sample a random point by sampling a random `Scalar` and multiplying
            // by the generator in the case that the discrete log of the output may be leaked with
            // respect to the generator. Leaking the discrete log (i.e. the random `Scalar`) is okay
            // when it is used to generate secret shares
            let mut rng = thread_rng();
            let random = Scalar::random(&mut rng);
            let random_point = random * StarkPoint::generator();

            let (my_share, their_share) = (val - random_point, random_point);
            self.allocate_shared_value(
                ResultValue::Point(my_share),
                ResultValue::Point(their_share),
            )
        } else {
            self.receive_value()
        };

        AuthenticatedStarkPointResult::new_shared(point)
    }

    /// Share a batch of `StarkPoint`s with the counterparty
    pub fn batch_share_point(
        &self,
        vals: Vec<StarkPoint>,
        sender: PartyId,
    ) -> Vec<AuthenticatedStarkPointResult> {
        let n = vals.len();
        let shares: BatchStarkPointResult = if self.party_id() == sender {
            let mut rng = thread_rng();
            let generator = StarkPoint::generator();
            let peer_shares = (0..vals.len())
                .map(|_| {
                    let discrete_log = Scalar::random(&mut rng);
                    discrete_log * generator
                })
                .collect_vec();
            let my_shares = vals
                .iter()
                .zip(peer_shares.iter())
                .map(|(val, share)| val - share)
                .collect_vec();

            self.allocate_shared_value(
                ResultValue::PointBatch(my_shares),
                ResultValue::PointBatch(peer_shares),
            )
        } else {
            self.receive_value()
        };

        AuthenticatedStarkPointResult::new_shared_from_batch_result(shares, n)
    }

    /// Allocate a public value in the fabric
    pub fn allocate_scalar<T: Into<Scalar>>(&self, value: T) -> ResultHandle<Scalar> {
        let id = self.inner.allocate_value(ResultValue::Scalar(value.into()));
        ResultHandle::new(id, self.clone())
    }

    /// Allocate a batch of scalars in the fabric
    pub fn allocate_scalars<T: Into<Scalar>>(&self, values: Vec<T>) -> Vec<ResultHandle<Scalar>> {
        values
            .into_iter()
            .map(|value| self.allocate_scalar(value))
            .collect_vec()
    }

    /// Allocate a scalar as a secret share of an already shared value
    pub fn allocate_preshared_scalar<T: Into<Scalar>>(
        &self,
        value: T,
    ) -> AuthenticatedScalarResult {
        let allocated = self.allocate_scalar(value);
        AuthenticatedScalarResult::new_shared(allocated)
    }

    /// Allocate a batch of scalars as secret shares of already shared values
    pub fn batch_allocate_preshared_scalar<T: Into<Scalar>>(
        &self,
        values: Vec<T>,
    ) -> Vec<AuthenticatedScalarResult> {
        let values = self.allocate_scalars(values);
        AuthenticatedScalarResult::new_shared_batch(&values)
    }

    /// Allocate a public curve point in the fabric
    pub fn allocate_point(&self, value: StarkPoint) -> ResultHandle<StarkPoint> {
        let id = self.inner.allocate_value(ResultValue::Point(value));
        ResultHandle::new(id, self.clone())
    }

    /// Allocate a batch of points in the fabric
    pub fn allocate_points(&self, values: Vec<StarkPoint>) -> Vec<ResultHandle<StarkPoint>> {
        values
            .into_iter()
            .map(|value| self.allocate_point(value))
            .collect_vec()
    }

    /// Send a value to the peer, placing the identity in the local result buffer at the send ID
    pub fn send_value<T: From<ResultValue> + Into<NetworkPayload>>(
        &self,
        value: ResultHandle<T>,
    ) -> ResultHandle<T> {
        self.new_network_op(vec![value.id], |mut args| args.remove(0).into())
    }

    /// Send a batch of values to the counterparty
    pub fn send_values<T>(&self, values: &[ResultHandle<T>]) -> ResultHandle<Vec<T>>
    where
        T: From<ResultValue>,
        Vec<T>: Into<NetworkPayload> + From<ResultValue>,
    {
        let ids = values.iter().map(|v| v.id).collect_vec();
        self.new_network_op(ids, |args| {
            let payload: Vec<T> = args.into_iter().map(|val| val.into()).collect();
            payload.into()
        })
    }

    /// Receive a value from the peer
    pub fn receive_value<T: From<ResultValue>>(&self) -> ResultHandle<T> {
        let id = self.inner.receive_value();
        ResultHandle::new(id, self.clone())
    }

    /// Exchange a value with the peer, i.e. send then receive or receive then send
    /// based on the party ID
    ///
    /// Returns a handle to the received value, which will be different for different parties
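    ///
    /// A sketch, assuming `my_share` is a previously allocated `ScalarResult` that each
    /// party wishes to reveal to the other:
    ///
    /// ```ignore
    /// // Each party sends its own value and receives the counterparty's; the returned
    /// // handle resolves to the peer's value on both sides
    /// let peer_share = fabric.exchange_value(my_share);
    /// ```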
    pub fn exchange_value<T: From<ResultValue> + Into<NetworkPayload>>(
        &self,
        value: ResultHandle<T>,
    ) -> ResultHandle<T> {
        if self.party_id() == PARTY0 {
            // Party 0 sends first then receives
            self.send_value(value);
            self.receive_value()
        } else {
            // Party 1 receives first then sends
            let handle = self.receive_value();
            self.send_value(value);
            handle
        }
    }

    /// Exchange a batch of values with the peer, i.e. send then receive or receive then send
    /// based on party ID
    pub fn exchange_values<T>(&self, values: &[ResultHandle<T>]) -> ResultHandle<Vec<T>>
    where
        T: From<ResultValue>,
        Vec<T>: From<ResultValue> + Into<NetworkPayload>,
    {
        if self.party_id() == PARTY0 {
            self.send_values(values);
            self.receive_value()
        } else {
            let handle = self.receive_value();
            self.send_values(values);
            handle
        }
    }

    /// Share a public value with the counterparty
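    ///
    /// A sketch of broadcasting a public scalar from party 0; both parties make the call
    /// with the same `sender` so that their result IDs remain aligned:
    ///
    /// ```ignore
    /// let public_val: ScalarResult = fabric.share_plaintext(Scalar::one(), PARTY0);
    /// ```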
    pub fn share_plaintext<T>(&self, value: T, sender: PartyId) -> ResultHandle<T>
    where
        T: 'static + From<ResultValue> + Into<NetworkPayload> + Send + Sync,
    {
        if self.party_id() == sender {
            self.new_network_op(vec![], move |_args| value.into())
        } else {
            self.receive_value()
        }
    }

    /// Share a batch of public values with the counterparty
    pub fn batch_share_plaintext<T>(&self, values: Vec<T>, sender: PartyId) -> ResultHandle<Vec<T>>
    where
        T: 'static + From<ResultValue> + Send + Sync,
        Vec<T>: Into<NetworkPayload> + From<ResultValue>,
    {
        self.share_plaintext(values, sender)
    }

    // -------------------
    // | Gate Definition |
    // -------------------

    /// Construct a new gate operation in the fabric, i.e. one that can be evaluated immediately
    /// given its inputs
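    ///
    /// For instance, a gate that doubles a public scalar might be scheduled as in the
    /// sketch below (assuming `x` is an allocated `ScalarResult` and the usual
    /// `From<ResultValue>` conversion for `Scalar`):
    ///
    /// ```ignore
    /// let doubled: ScalarResult = fabric.new_gate_op(vec![x.id], |mut args| {
    ///     let x: Scalar = args.remove(0).into();
    ///     ResultValue::Scalar(x + x)
    /// });
    /// ```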
    pub fn new_gate_op<F, T>(&self, args: Vec<ResultId>, function: F) -> ResultHandle<T>
    where
        F: 'static + FnOnce(Vec<ResultValue>) -> ResultValue + Send + Sync,
        T: From<ResultValue>,
    {
        let function = Box::new(function);
        let id = self.inner.new_op(
            args,
            1, /* output_arity */
            OperationType::Gate { function },
        )[0];
        ResultHandle::new(id, self.clone())
    }

    /// Construct a new batch gate operation in the fabric, i.e. one that can be evaluated to return
    /// an array of results
    ///
    /// The array must be sized so that the fabric knows how many results to allocate buffer space for
    /// ahead of execution
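    ///
    /// A sketch of a batch gate that adds one to each of two public scalars in a single
    /// operation (assuming `a` and `b` are allocated `ScalarResult`s); the `output_arity`
    /// argument must match the length of the vector returned by the closure:
    ///
    /// ```ignore
    /// let results: Vec<ScalarResult> = fabric.new_batch_gate_op(
    ///     vec![a.id, b.id],
    ///     2, /* output_arity */
    ///     |args| {
    ///         args.into_iter()
    ///             .map(|arg| {
    ///                 let x: Scalar = arg.into();
    ///                 ResultValue::Scalar(x + Scalar::one())
    ///             })
    ///             .collect()
    ///     },
    /// );
    /// ```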
    pub fn new_batch_gate_op<F, T>(
        &self,
        args: Vec<ResultId>,
        output_arity: usize,
        function: F,
    ) -> Vec<ResultHandle<T>>
    where
        F: 'static + FnOnce(Vec<ResultValue>) -> Vec<ResultValue> + Send + Sync,
        T: From<ResultValue>,
    {
        let function = Box::new(function);
        let ids = self
            .inner
            .new_op(args, output_arity, OperationType::GateBatch { function });
        ids.into_iter()
            .map(|id| ResultHandle::new(id, self.clone()))
            .collect_vec()
    }

    /// Construct a new network operation in the fabric, i.e. one that requires a value to be sent
    /// over the channel
    pub fn new_network_op<F, T>(&self, args: Vec<ResultId>, function: F) -> ResultHandle<T>
    where
        F: 'static + FnOnce(Vec<ResultValue>) -> NetworkPayload + Send + Sync,
        T: From<ResultValue>,
    {
        let function = Box::new(function);
        let id = self.inner.new_op(
            args,
            1, /* output_arity */
            OperationType::Network { function },
        )[0];
        ResultHandle::new(id, self.clone())
    }

    // -----------------
    // | Beaver Source |
    // -----------------

    /// Sample the next beaver triplet from the beaver source
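    ///
    /// A Beaver triple is a pre-shared triple (a, b, c) satisfying c = a * b. Such triples
    /// drive the standard shared-multiplication trick: to multiply shared values x and y,
    /// the parties open d = x - a and e = y - b, then locally derive a sharing of
    /// x * y = c + d * b + e * a + d * e without revealing x or y.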
    pub fn next_beaver_triple(&self) -> (MpcScalarResult, MpcScalarResult, MpcScalarResult) {
        // Sample the triple and allocate it in the fabric, the counterparty will do the same
        let (a, b, c) = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_triplet();

        let a_val = self.allocate_scalar(a);
        let b_val = self.allocate_scalar(b);
        let c_val = self.allocate_scalar(c);

        (
            MpcScalarResult::new_shared(a_val),
            MpcScalarResult::new_shared(b_val),
            MpcScalarResult::new_shared(c_val),
        )
    }

    /// Sample a batch of beaver triples
    pub fn next_beaver_triple_batch(
        &self,
        n: usize,
    ) -> (
        Vec<MpcScalarResult>,
        Vec<MpcScalarResult>,
        Vec<MpcScalarResult>,
    ) {
        let (a_vals, b_vals, c_vals) = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_triplet_batch(n);

        let a_vals = self
            .allocate_scalars(a_vals)
            .into_iter()
            .map(MpcScalarResult::new_shared)
            .collect_vec();
        let b_vals = self
            .allocate_scalars(b_vals)
            .into_iter()
            .map(MpcScalarResult::new_shared)
            .collect_vec();
        let c_vals = self
            .allocate_scalars(c_vals)
            .into_iter()
            .map(MpcScalarResult::new_shared)
            .collect_vec();

        (a_vals, b_vals, c_vals)
    }

    /// Sample the next beaver triplet with MACs from the beaver source
    ///
    /// TODO: Authenticate these values either here or in the pre-processing phase as per
    /// the SPDZ paper
    pub fn next_authenticated_triple(
        &self,
    ) -> (
        AuthenticatedScalarResult,
        AuthenticatedScalarResult,
        AuthenticatedScalarResult,
    ) {
        let (a, b, c) = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_triplet();

        let a_val = self.allocate_scalar(a);
        let b_val = self.allocate_scalar(b);
        let c_val = self.allocate_scalar(c);

        (
            AuthenticatedScalarResult::new_shared(a_val),
            AuthenticatedScalarResult::new_shared(b_val),
            AuthenticatedScalarResult::new_shared(c_val),
        )
    }

    /// Sample the next batch of beaver triples as `AuthenticatedScalar`s
    pub fn next_authenticated_triple_batch(
        &self,
        n: usize,
    ) -> (
        Vec<AuthenticatedScalarResult>,
        Vec<AuthenticatedScalarResult>,
        Vec<AuthenticatedScalarResult>,
    ) {
        let (a_vals, b_vals, c_vals) = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_triplet_batch(n);

        let a_allocated = self.allocate_scalars(a_vals);
        let b_allocated = self.allocate_scalars(b_vals);
        let c_allocated = self.allocate_scalars(c_vals);

        (
            AuthenticatedScalarResult::new_shared_batch(&a_allocated),
            AuthenticatedScalarResult::new_shared_batch(&b_allocated),
            AuthenticatedScalarResult::new_shared_batch(&c_allocated),
        )
    }

    /// Sample a batch of random shared values from the beaver source
    pub fn random_shared_scalars(&self, n: usize) -> Vec<ScalarResult> {
        let values_raw = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_shared_value_batch(n);

        // Wrap the values in a result handle
        values_raw
            .into_iter()
            .map(|value| self.allocate_scalar(value))
            .collect_vec()
    }

    /// Sample a batch of random shared values from the beaver source and allocate them as `AuthenticatedScalars`
    pub fn random_shared_scalars_authenticated(&self, n: usize) -> Vec<AuthenticatedScalarResult> {
        let values_raw = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_shared_value_batch(n);

        // Wrap the values in an authenticated wrapper
        values_raw
            .into_iter()
            .map(|value| {
                let value = self.allocate_scalar(value);
                AuthenticatedScalarResult::new_shared(value)
            })
            .collect_vec()
    }

    /// Sample a pair of values that are multiplicative inverses of one another
    pub fn random_inverse_pair(&self) -> (AuthenticatedScalarResult, AuthenticatedScalarResult) {
        let (l, r) = self
            .inner
            .beaver_source
            .lock()
            .unwrap()
            .next_shared_inverse_pair();
        (
            AuthenticatedScalarResult::new_shared(self.allocate_scalar(l)),
            AuthenticatedScalarResult::new_shared(self.allocate_scalar(r)),
        )
    }

    /// Sample a batch of pairs of values that are multiplicative inverses of one another
    pub fn random_inverse_pairs(
        &self,
        n: usize,
    ) -> (
        Vec<AuthenticatedScalarResult>,
        Vec<AuthenticatedScalarResult>,
    ) {
        let (left, right) = self
            .inner
            .beaver_source
            .lock()
            .unwrap()
            .next_shared_inverse_pair_batch(n);

        let left_right = left.into_iter().chain(right.into_iter()).collect_vec();
        let allocated_left_right = self.allocate_scalars(left_right);
        let authenticated_left_right =
            AuthenticatedScalarResult::new_shared_batch(&allocated_left_right);

        // Split left and right
        let (left, right) = authenticated_left_right.split_at(n);
        (left.to_vec(), right.to_vec())
    }

    /// Sample a random shared bit from the beaver source
    pub fn random_shared_bit(&self) -> AuthenticatedScalarResult {
        let bit = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_shared_bit();

        let bit = self.allocate_scalar(bit);
        AuthenticatedScalarResult::new_shared(bit)
    }

    /// Sample a batch of random shared bits from the beaver source
    pub fn random_shared_bits(&self, n: usize) -> Vec<AuthenticatedScalarResult> {
        let bits = self
            .inner
            .beaver_source
            .lock()
            .expect("beaver source poisoned")
            .next_shared_bit_batch(n);

        let bits = self.allocate_scalars(bits);
        AuthenticatedScalarResult::new_shared_batch(&bits)
    }
}