ark_mpc/
fabric.rs

1//! Defines an MPC fabric for the protocol
2//!
3//! The fabric essentially acts as a dependency injection layer. That is, the
4//! MpcFabric creates and manages dependencies needed to allocate network
5//! values. This provides a cleaner interface for consumers of the library; i.e.
6//! clients do not have to hold onto references of the network layer or the
7//! beaver sources to allocate values.
8
9mod executor;
10mod network_sender;
11mod result;
12
13use ark_ec::CurveGroup;
14#[cfg(feature = "benchmarks")]
15pub use executor::{Executor, ExecutorMessage};
16#[cfg(not(feature = "benchmarks"))]
17use executor::{Executor, ExecutorMessage};
18use rand::thread_rng;
19pub use result::{ResultHandle, ResultId, ResultValue};
20
21use futures::executor::block_on;
22use tracing::log;
23
24use crossbeam::queue::SegQueue;
25use kanal::Sender as KanalSender;
26use std::{
27    fmt::{Debug, Formatter, Result as FmtResult},
28    sync::{
29        atomic::{AtomicUsize, Ordering},
30        Arc, Mutex,
31    },
32};
33use tokio::sync::broadcast::{self, Sender as BroadcastSender};
34
35use itertools::Itertools;
36
37use crate::{
38    algebra::{
39        AuthenticatedPointResult, AuthenticatedScalarResult, BatchCurvePointResult,
40        BatchScalarResult, CurvePoint, CurvePointResult, MpcPointResult, MpcScalarResult, Scalar,
41        ScalarResult,
42    },
43    beaver::SharedValueSource,
44    network::{MpcNetwork, NetworkOutbound, NetworkPayload, PartyId},
45    PARTY0,
46};
47
48use self::{
49    network_sender::NetworkSender,
50    result::{OpResult, ResultWaiter},
51};
52
53/// The result id that is hardcoded to zero
54const RESULT_ZERO: ResultId = 0;
55/// The result id that is hardcoded to one
56const RESULT_ONE: ResultId = 1;
57/// The result id that is hardcoded to the curve identity point
58const RESULT_IDENTITY: ResultId = 2;
59
60/// The number of constant results allocated in the fabric, i.e. those defined
61/// above
62const N_CONSTANT_RESULTS: usize = 3;
63
64/// The default size hint to give the fabric for buffer pre-allocation
65const DEFAULT_SIZE_HINT: usize = 10_000;
66
67/// A type alias for the identifier used for a gate
68pub type OperationId = usize;
69
70/// An operation within the network, describes the arguments and function to
71/// evaluate once the arguments are ready
72///
73/// `N` represents the number of results that this operation outputs
74#[derive(Clone)]
75pub struct Operation<C: CurveGroup> {
76    /// Identifier of the result that this operation emits
77    id: OperationId,
78    /// The result ID of the first result in the outputs
79    result_id: ResultId,
80    /// The number of outputs this operation produces
81    output_arity: usize,
82    /// The number of arguments that are still in-flight for this operation
83    inflight_args: usize,
84    /// The IDs of the inputs to this operation
85    args: Vec<ResultId>,
86    /// The type of the operation
87    op_type: OperationType<C>,
88}
89
90impl<C: CurveGroup> Operation<C> {
91    /// Get the result IDs for an operation
92    pub fn result_ids(&self) -> Vec<ResultId> {
93        (self.result_id..self.result_id + self.output_arity).collect_vec()
94    }
95}
96
97impl<C: CurveGroup> Debug for Operation<C> {
98    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
99        write!(f, "Operation {}", self.id)
100    }
101}
102
103/// Defines the different types of operations available in the computation graph
104pub enum OperationType<C: CurveGroup> {
105    /// A gate operation; may be evaluated locally given its ready inputs
106    Gate {
107        /// The function to apply to the inputs
108        function: Box<dyn FnOnce(Vec<ResultValue<C>>) -> ResultValue<C> + Send + Sync>,
109    },
110    /// A gate operation that has output arity greater than one
111    ///
112    /// We separate this out to avoid vector allocation for result values of
113    /// arity one
114    GateBatch {
115        /// The function to apply to the inputs
116        #[allow(clippy::type_complexity)]
117        function: Box<dyn FnOnce(Vec<ResultValue<C>>) -> Vec<ResultValue<C>> + Send + Sync>,
118    },
119    /// A network operation, requires that a value be sent over the network
120    Network {
121        /// The function to apply to the inputs to derive a Network payload
122        function: Box<dyn FnOnce(Vec<ResultValue<C>>) -> NetworkPayload<C> + Send + Sync>,
123    },
124}
125
126/// A clone implementation, never concretely called but used as a Marker type to
127/// allow pre-allocating buffer space for `Operation`s
128impl<C: CurveGroup> Clone for OperationType<C> {
129    fn clone(&self) -> Self {
130        panic!("cannot clone `OperationType`")
131    }
132}
133
134impl<C: CurveGroup> Debug for OperationType<C> {
135    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
136        match self {
137            OperationType::Gate { .. } => write!(f, "Gate"),
138            OperationType::GateBatch { .. } => write!(f, "GateBatch"),
139            OperationType::Network { .. } => write!(f, "Network"),
140        }
141    }
142}
143
144/// A fabric for the MPC protocol, defines a dependency injection layer that
145/// dynamically schedules circuit gate evaluations onto the network to be
146/// executed
147///
148/// The fabric does not block on gate evaluations, but instead returns a handle
149/// to a future result that may be polled to obtain the materialized result.
150/// This allows the application layer to continue using the fabric, scheduling
151/// more gates to be evaluated and maximally exploiting gate-level parallelism
152/// within the circuit
153#[derive(Clone)]
154pub struct MpcFabric<C: CurveGroup> {
155    /// The inner fabric
156    #[cfg(not(feature = "benchmarks"))]
157    inner: Arc<FabricInner<C>>,
158    /// The inner fabric, accessible publicly for benchmark mocking
159    #[cfg(feature = "benchmarks")]
160    pub inner: Arc<FabricInner<C>>,
161    /// The local party's share of the global MAC key
162    ///
163    /// The parties collectively hold an additive sharing of the global key
164    ///
165    /// We wrap in a reference counting structure to avoid recursive type issues
166    #[cfg(not(feature = "benchmarks"))]
167    mac_key: Option<Arc<MpcScalarResult<C>>>,
168    /// The MAC key, accessible publicly for benchmark mocking
169    #[cfg(feature = "benchmarks")]
170    pub mac_key: Option<Arc<MpcScalarResult<C>>>,
171    /// The channel on which shutdown messages are sent to blocking workers
172    #[cfg(not(feature = "benchmarks"))]
173    shutdown: BroadcastSender<()>,
174    /// The shutdown channel, made publicly available for benchmark mocking
175    #[cfg(feature = "benchmarks")]
176    pub shutdown: BroadcastSender<()>,
177}
178
179impl<C: CurveGroup> Debug for MpcFabric<C> {
180    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
181        write!(f, "MpcFabric")
182    }
183}
184
185/// The inner component of the fabric, allows the constructor to allocate
186/// executor and network sender objects at the same level as the fabric
187#[derive(Clone)]
188pub struct FabricInner<C: CurveGroup> {
189    /// The ID of the local party in the MPC execution
190    party_id: u64,
191    /// The next identifier to assign to a result
192    next_result_id: Arc<AtomicUsize>,
193    /// The next identifier to assign to an operation
194    next_op_id: Arc<AtomicUsize>,
195    /// A sender to the executor
196    execution_queue: Arc<SegQueue<ExecutorMessage<C>>>,
197    /// The underlying queue to the network
198    outbound_queue: KanalSender<NetworkOutbound<C>>,
199    /// The underlying shared randomness source
200    beaver_source: Arc<Mutex<Box<dyn SharedValueSource<C>>>>,
201}
202
203impl<C: CurveGroup> Debug for FabricInner<C> {
204    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
205        write!(f, "FabricInner")
206    }
207}
208
209impl<C: CurveGroup> FabricInner<C> {
210    /// Constructor
211    pub fn new<S: 'static + SharedValueSource<C>>(
212        party_id: u64,
213        execution_queue: Arc<SegQueue<ExecutorMessage<C>>>,
214        outbound_queue: KanalSender<NetworkOutbound<C>>,
215        beaver_source: S,
216    ) -> Self {
217        // Allocate a zero and a one as well as the curve identity in the fabric to
218        // begin, for convenience
219        let zero = ResultValue::Scalar(Scalar::zero());
220        let one = ResultValue::Scalar(Scalar::one());
221        let identity = ResultValue::Point(CurvePoint::identity());
222
223        for initial_result in vec![
224            OpResult {
225                id: RESULT_ZERO,
226                value: zero,
227            },
228            OpResult {
229                id: RESULT_ONE,
230                value: one,
231            },
232            OpResult {
233                id: RESULT_IDENTITY,
234                value: identity,
235            },
236        ]
237        .into_iter()
238        {
239            execution_queue.push(ExecutorMessage::Result(initial_result));
240        }
241
242        let next_result_id = Arc::new(AtomicUsize::new(N_CONSTANT_RESULTS));
243        let next_op_id = Arc::new(AtomicUsize::new(0));
244
245        Self {
246            party_id,
247            next_result_id,
248            next_op_id,
249            execution_queue,
250            outbound_queue,
251            beaver_source: Arc::new(Mutex::new(Box::new(beaver_source))),
252        }
253    }
254
255    /// Register a waiter on a result    
256    pub(crate) fn register_waiter(&self, waiter: ResultWaiter<C>) {
257        self.execution_queue
258            .push(ExecutorMessage::NewWaiter(waiter));
259    }
260
261    /// Shutdown the inner fabric, by sending a shutdown message to the executor
262    pub(crate) fn shutdown(&self) {
263        self.execution_queue.push(ExecutorMessage::Shutdown)
264    }
265
266    /// -----------
267    /// | Getters |
268    /// -----------
269
270    /// Increment the operation counter and return the existing value
271    fn new_result_id(&self) -> ResultId {
272        self.next_result_id.fetch_add(1, Ordering::Relaxed)
273    }
274
275    /// Increment the operation counter and return the existing value
276    fn new_op_id(&self) -> OperationId {
277        self.next_op_id.fetch_add(1, Ordering::Relaxed)
278    }
279
280    /// Get the hardcoded zero value in the fabric
281    pub(crate) fn zero(&self) -> ResultId {
282        RESULT_ZERO
283    }
284
285    /// Get the hardcoded one value in the fabric
286    pub(crate) fn one(&self) -> ResultId {
287        RESULT_ONE
288    }
289
290    /// Get the hardcoded curve identity value in the fabric
291    pub(crate) fn curve_identity(&self) -> ResultId {
292        RESULT_IDENTITY
293    }
294
295    // ------------------------
296    // | Low Level Allocation |
297    // ------------------------
298
299    /// Allocate a new plaintext value in the fabric
300    pub(crate) fn allocate_value(&self, value: ResultValue<C>) -> ResultId {
301        // Forward the result to the executor
302        let id = self.new_result_id();
303        self.execution_queue
304            .push(ExecutorMessage::Result(OpResult { id, value }));
305
306        id
307    }
308
309    /// Allocate a secret shared value in the network
310    pub(crate) fn allocate_shared_value(
311        &self,
312        my_share: ResultValue<C>,
313        their_share: ResultValue<C>,
314    ) -> ResultId {
315        // Forward the local party's share to the executor
316        let id = self.new_result_id();
317        self.execution_queue.push(ExecutorMessage::Result(OpResult {
318            id,
319            value: my_share,
320        }));
321
322        // Send the counterparty their share
323        if let Err(e) = self.outbound_queue.send(NetworkOutbound {
324            result_id: id,
325            payload: their_share.into(),
326        }) {
327            log::error!("error sending share to counterparty: {e:?}");
328        }
329
330        id
331    }
332
333    /// Receive a value from a network operation initiated by a peer
334    ///
335    /// The peer will already send the value with the corresponding ID, so all
336    /// that is needed is to allocate a slot in the result buffer for the
337    /// receipt
338    pub(crate) fn receive_value(&self) -> ResultId {
339        self.new_result_id()
340    }
341
342    // --------------
343    // | Operations |
344    // --------------
345
346    /// Allocate a new in-flight gate operation in the fabric
347    pub(crate) fn new_op(
348        &self,
349        args: Vec<ResultId>,
350        output_arity: usize,
351        op_type: OperationType<C>,
352    ) -> Vec<ResultId> {
353        if matches!(op_type, OperationType::Gate { .. }) {
354            assert_eq!(output_arity, 1, "gate operations must have arity 1");
355        }
356
357        // Allocate IDs for the results
358        let ids = (0..output_arity)
359            .map(|_| self.new_result_id())
360            .collect_vec();
361
362        // Build the operation
363        let op = Operation {
364            id: self.new_op_id(),
365            result_id: ids[0],
366            output_arity,
367            args,
368            inflight_args: 0,
369            op_type,
370        };
371
372        // Forward the op to the executor
373        self.execution_queue.push(ExecutorMessage::Op(op));
374        ids
375    }
376}
377
378impl<C: CurveGroup> MpcFabric<C> {
379    /// Constructor
380    pub fn new<N: 'static + MpcNetwork<C>, S: 'static + SharedValueSource<C>>(
381        network: N,
382        beaver_source: S,
383    ) -> Self {
384        Self::new_with_size_hint(DEFAULT_SIZE_HINT, network, beaver_source)
385    }
386
387    /// Constructor that takes an additional size hint, indicating how much
388    /// buffer space the fabric should allocate for results. The size is
389    /// given in number of gates
390    pub fn new_with_size_hint<N: 'static + MpcNetwork<C>, S: 'static + SharedValueSource<C>>(
391        size_hint: usize,
392        network: N,
393        beaver_source: S,
394    ) -> Self {
395        // Build communication primitives
396        let execution_queue = Arc::new(SegQueue::new());
397
398        let (outbound_sender, outbound_receiver) = kanal::unbounded_async();
399        let (shutdown_sender, shutdown_receiver) = broadcast::channel(1 /* capacity */);
400
401        // Build a fabric
402        let fabric = FabricInner::new(
403            network.party_id(),
404            execution_queue.clone(),
405            outbound_sender.to_sync(),
406            beaver_source,
407        );
408
409        // Start a network sender and operator executor
410        let network_sender = NetworkSender::new(
411            outbound_receiver,
412            execution_queue.clone(),
413            network,
414            shutdown_receiver,
415        );
416        tokio::task::spawn_blocking(move || block_on(network_sender.run()));
417
418        let executor = Executor::new(size_hint, execution_queue, fabric.clone());
419        tokio::task::spawn_blocking(move || executor.run());
420
421        // Create the fabric and fill in the MAC key after
422        let mut self_ = Self {
423            inner: Arc::new(fabric.clone()),
424            shutdown: shutdown_sender,
425            mac_key: None,
426        };
427
428        // Sample a MAC key from the pre-shared values in the beaver source
429        let mac_key_id = fabric.allocate_value(ResultValue::Scalar(
430            fabric
431                .beaver_source
432                .lock()
433                .expect("beaver source poisoned")
434                .next_shared_value(),
435        ));
436        let mac_key = MpcScalarResult::new_shared(ResultHandle::new(mac_key_id, self_.clone()));
437
438        // Set the MAC key
439        self_.mac_key.replace(Arc::new(mac_key));
440
441        self_
442    }
443
444    /// Get the party ID of the local party
445    pub fn party_id(&self) -> PartyId {
446        self.inner.party_id
447    }
448
449    /// Shutdown the fabric and the threads it has spawned
450    pub fn shutdown(self) {
451        log::debug!("shutting down fabric");
452        self.inner.shutdown();
453        self.shutdown
454            .send(())
455            .expect("error sending shutdown signal");
456    }
457
458    /// Register a waiter on a result
459    pub fn register_waiter(&self, waiter: ResultWaiter<C>) {
460        self.inner.register_waiter(waiter);
461    }
462
463    /// Immutably borrow the MAC key
464    pub(crate) fn borrow_mac_key(&self) -> &MpcScalarResult<C> {
465        // Unwrap is safe, the constructor sets the MAC key
466        self.mac_key.as_ref().unwrap()
467    }
468
469    // ------------------------
470    // | Constants Allocation |
471    // ------------------------
472
473    /// Get the hardcoded zero wire as a raw `ScalarResult`
474    pub fn zero(&self) -> ScalarResult<C> {
475        ResultHandle::new(self.inner.zero(), self.clone())
476    }
477
478    /// Get the shared zero value as an `MpcScalarResult`
479    fn zero_shared(&self) -> MpcScalarResult<C> {
480        MpcScalarResult::new_shared(self.zero())
481    }
482
483    /// Get the hardcoded zero wire as an `AuthenticatedScalarResult`
484    ///
485    /// Both parties hold the share 0 directly in this case
486    pub fn zero_authenticated(&self) -> AuthenticatedScalarResult<C> {
487        let zero_value = self.zero();
488        let share_value = self.zero_shared();
489        let mac_value = self.zero_shared();
490
491        AuthenticatedScalarResult {
492            share: share_value,
493            mac: mac_value,
494            public_modifier: zero_value,
495        }
496    }
497
498    /// Get a batch of references to the zero wire as an
499    /// `AuthenticatedScalarResult`
500    pub fn zeros_authenticated(&self, n: usize) -> Vec<AuthenticatedScalarResult<C>> {
501        let val = self.zero_authenticated();
502        (0..n).map(|_| val.clone()).collect_vec()
503    }
504
505    /// Get the hardcoded one wire as a raw `ScalarResult`
506    pub fn one(&self) -> ScalarResult<C> {
507        ResultHandle::new(self.inner.one(), self.clone())
508    }
509
510    /// Get the hardcoded shared one wire as an `MpcScalarResult`
511    fn one_shared(&self) -> MpcScalarResult<C> {
512        MpcScalarResult::new_shared(self.one())
513    }
514
515    /// Get the hardcoded one wire as an `AuthenticatedScalarResult`
516    ///
517    /// Party 0 holds the value zero and party 1 holds the value one
518    pub fn one_authenticated(&self) -> AuthenticatedScalarResult<C> {
519        if self.party_id() == PARTY0 {
520            let zero_value = self.zero();
521            let share_value = self.zero_shared();
522            let mac_value = self.zero_shared();
523
524            AuthenticatedScalarResult {
525                share: share_value,
526                mac: mac_value,
527                public_modifier: zero_value,
528            }
529        } else {
530            let zero_value = self.zero();
531            let share_value = self.one_shared();
532            let mac_value = self.borrow_mac_key().clone();
533
534            AuthenticatedScalarResult {
535                share: share_value,
536                mac: mac_value,
537                public_modifier: zero_value,
538            }
539        }
540    }
541
542    /// Get a batch of references to the one wire as an
543    /// `AuthenticatedScalarResult`
544    pub fn ones_authenticated(&self, n: usize) -> Vec<AuthenticatedScalarResult<C>> {
545        let val = self.one_authenticated();
546        (0..n).map(|_| val.clone()).collect_vec()
547    }
548
549    /// Get the hardcoded curve identity wire as a raw `CurvePointResult`
550    pub fn curve_identity(&self) -> CurvePointResult<C> {
551        ResultHandle::new(self.inner.curve_identity(), self.clone())
552    }
553
554    /// Get the hardcoded shared curve identity wire as an `MpcPointResult`
555    fn curve_identity_shared(&self) -> MpcPointResult<C> {
556        MpcPointResult::new_shared(self.curve_identity())
557    }
558
559    /// Get the hardcoded curve identity wire as an `AuthenticatedPointResult`
560    ///
561    /// Both parties hold the identity point directly in this case
562    pub fn curve_identity_authenticated(&self) -> AuthenticatedPointResult<C> {
563        let identity_val = self.curve_identity();
564        let share_value = self.curve_identity_shared();
565        let mac_value = self.curve_identity_shared();
566
567        AuthenticatedPointResult {
568            share: share_value,
569            mac: mac_value,
570            public_modifier: identity_val,
571        }
572    }
573
574    // -------------------
575    // | Wire Allocation |
576    // -------------------
577
578    /// Allocate a shared value in the fabric
579    fn allocate_shared_value<T: From<ResultValue<C>>>(
580        &self,
581        my_share: ResultValue<C>,
582        their_share: ResultValue<C>,
583    ) -> ResultHandle<C, T> {
584        let id = self.inner.allocate_shared_value(my_share, their_share);
585        ResultHandle::new(id, self.clone())
586    }
587
588    /// Share a `Scalar` value with the counterparty
589    pub fn share_scalar<T: Into<Scalar<C>>>(
590        &self,
591        val: T,
592        sender: PartyId,
593    ) -> AuthenticatedScalarResult<C> {
594        let scalar: ScalarResult<C> = if self.party_id() == sender {
595            let scalar_val = val.into();
596            let mut rng = thread_rng();
597            let random = Scalar::random(&mut rng);
598
599            let (my_share, their_share) = (scalar_val - random, random);
600            self.allocate_shared_value(
601                ResultValue::Scalar(my_share),
602                ResultValue::Scalar(their_share),
603            )
604        } else {
605            self.receive_value()
606        };
607
608        AuthenticatedScalarResult::new_shared(scalar)
609    }
610
611    /// Share a batch of `Scalar` values with the counterparty
612    pub fn batch_share_scalar<T: Into<Scalar<C>>>(
613        &self,
614        vals: Vec<T>,
615        sender: PartyId,
616    ) -> Vec<AuthenticatedScalarResult<C>> {
617        let n = vals.len();
618        let shares: BatchScalarResult<C> = if self.party_id() == sender {
619            let vals = vals.into_iter().map(|val| val.into()).collect_vec();
620            let mut rng = thread_rng();
621
622            let peer_shares = (0..vals.len())
623                .map(|_| Scalar::random(&mut rng))
624                .collect_vec();
625            let my_shares = vals
626                .iter()
627                .zip(peer_shares.iter())
628                .map(|(val, share)| val - share)
629                .collect_vec();
630
631            self.allocate_shared_value(
632                ResultValue::ScalarBatch(my_shares),
633                ResultValue::ScalarBatch(peer_shares),
634            )
635        } else {
636            self.receive_value()
637        };
638
639        AuthenticatedScalarResult::new_shared_from_batch_result(shares, n)
640    }
641
642    /// Share a `CurvePoint` value with the counterparty
643    pub fn share_point(&self, val: CurvePoint<C>, sender: PartyId) -> AuthenticatedPointResult<C> {
644        let point: CurvePointResult<C> = if self.party_id() == sender {
645            // As mentioned in https://eprint.iacr.org/2009/226.pdf
646            // it is okay to sample a random point by sampling a random `Scalar` and
647            // multiplying by the generator in the case that the discrete log of
648            // the output may be leaked with respect to the generator. Leaking
649            // the discrete log (i.e. the random `Scalar`) is okay when it is
650            // used to generate secret shares
651            let mut rng = thread_rng();
652            let random = Scalar::random(&mut rng);
653            let random_point = random * CurvePoint::generator();
654
655            let (my_share, their_share) = (val - random_point, random_point);
656            self.allocate_shared_value(
657                ResultValue::Point(my_share),
658                ResultValue::Point(their_share),
659            )
660        } else {
661            self.receive_value()
662        };
663
664        AuthenticatedPointResult::new_shared(point)
665    }
666
667    /// Share a batch of `CurvePoint`s with the counterparty
668    pub fn batch_share_point(
669        &self,
670        vals: Vec<CurvePoint<C>>,
671        sender: PartyId,
672    ) -> Vec<AuthenticatedPointResult<C>> {
673        let n = vals.len();
674        let shares: BatchCurvePointResult<C> = if self.party_id() == sender {
675            let mut rng = thread_rng();
676            let generator = CurvePoint::generator();
677            let peer_shares = (0..vals.len())
678                .map(|_| {
679                    let discrete_log = Scalar::random(&mut rng);
680                    discrete_log * generator
681                })
682                .collect_vec();
683            let my_shares = vals
684                .iter()
685                .zip(peer_shares.iter())
686                .map(|(val, share)| val - share)
687                .collect_vec();
688
689            self.allocate_shared_value(
690                ResultValue::PointBatch(my_shares),
691                ResultValue::PointBatch(peer_shares),
692            )
693        } else {
694            self.receive_value()
695        };
696
697        AuthenticatedPointResult::new_shared_from_batch_result(shares, n)
698    }
699
700    /// Allocate a public value in the fabric
701    pub fn allocate_scalar<T: Into<Scalar<C>>>(&self, value: T) -> ScalarResult<C> {
702        let id = self.inner.allocate_value(ResultValue::Scalar(value.into()));
703        ResultHandle::new(id, self.clone())
704    }
705
706    /// Allocate a batch of scalars in the fabric
707    pub fn allocate_scalars<T: Into<Scalar<C>>>(&self, values: Vec<T>) -> Vec<ScalarResult<C>> {
708        values
709            .into_iter()
710            .map(|value| self.allocate_scalar(value))
711            .collect_vec()
712    }
713
714    /// Allocate a scalar as a secret share of an already shared value
715    pub fn allocate_preshared_scalar<T: Into<Scalar<C>>>(
716        &self,
717        value: T,
718    ) -> AuthenticatedScalarResult<C> {
719        let allocated = self.allocate_scalar(value);
720        AuthenticatedScalarResult::new_shared(allocated)
721    }
722
723    /// Allocate a batch of scalars as secret shares of already shared values
724    pub fn batch_allocate_preshared_scalar<T: Into<Scalar<C>>>(
725        &self,
726        values: Vec<T>,
727    ) -> Vec<AuthenticatedScalarResult<C>> {
728        let values = self.allocate_scalars(values);
729        AuthenticatedScalarResult::new_shared_batch(&values)
730    }
731
732    /// Allocate a public curve point in the fabric
733    pub fn allocate_point(&self, value: CurvePoint<C>) -> CurvePointResult<C> {
734        let id = self.inner.allocate_value(ResultValue::Point(value));
735        ResultHandle::new(id, self.clone())
736    }
737
738    /// Allocate a batch of points in the fabric
739    pub fn allocate_points(&self, values: Vec<CurvePoint<C>>) -> Vec<CurvePointResult<C>> {
740        values
741            .into_iter()
742            .map(|value| self.allocate_point(value))
743            .collect_vec()
744    }
745
746    /// Send a value to the peer, placing the identity in the local result
747    /// buffer at the send ID
748    pub fn send_value<T: From<ResultValue<C>> + Into<NetworkPayload<C>>>(
749        &self,
750        value: ResultHandle<C, T>,
751    ) -> ResultHandle<C, T> {
752        self.new_network_op(vec![value.id], |mut args| args.remove(0).into())
753    }
754
755    /// Send a batch of values to the counterparty
756    pub fn send_values<T>(&self, values: &[ResultHandle<C, T>]) -> ResultHandle<C, Vec<T>>
757    where
758        T: From<ResultValue<C>>,
759        Vec<T>: Into<NetworkPayload<C>> + From<ResultValue<C>>,
760    {
761        let ids = values.iter().map(|v| v.id).collect_vec();
762        self.new_network_op(ids, |args| {
763            let payload: Vec<T> = args.into_iter().map(|val| val.into()).collect();
764            payload.into()
765        })
766    }
767
768    /// Receive a value from the peer
769    pub fn receive_value<T: From<ResultValue<C>>>(&self) -> ResultHandle<C, T> {
770        let id = self.inner.receive_value();
771        ResultHandle::new(id, self.clone())
772    }
773
774    /// Exchange a value with the peer, i.e. send then receive or receive then
775    /// send based on the party ID
776    ///
777    /// Returns a handle to the received value, which will be different for
778    /// different parties
779    pub fn exchange_value<T: From<ResultValue<C>> + Into<NetworkPayload<C>>>(
780        &self,
781        value: ResultHandle<C, T>,
782    ) -> ResultHandle<C, T> {
783        if self.party_id() == PARTY0 {
784            // Party 0 sends first then receives
785            self.send_value(value);
786            self.receive_value()
787        } else {
788            // Party 1 receives first then sends
789            let handle = self.receive_value();
790            self.send_value(value);
791            handle
792        }
793    }
794
795    /// Exchange a batch of values with the peer, i.e. send then receive or
796    /// receive then send based on party ID
797    pub fn exchange_values<T>(&self, values: &[ResultHandle<C, T>]) -> ResultHandle<C, Vec<T>>
798    where
799        T: From<ResultValue<C>>,
800        Vec<T>: From<ResultValue<C>> + Into<NetworkPayload<C>>,
801    {
802        if self.party_id() == PARTY0 {
803            self.send_values(values);
804            self.receive_value()
805        } else {
806            let handle = self.receive_value();
807            self.send_values(values);
808            handle
809        }
810    }
811
812    /// Share a public value with the counterparty
813    pub fn share_plaintext<T>(&self, value: T, sender: PartyId) -> ResultHandle<C, T>
814    where
815        T: 'static + From<ResultValue<C>> + Into<NetworkPayload<C>> + Send + Sync,
816    {
817        if self.party_id() == sender {
818            self.new_network_op(vec![], move |_args| value.into())
819        } else {
820            self.receive_value()
821        }
822    }
823
824    /// Share a batch of public values with the counterparty
825    pub fn batch_share_plaintext<T>(
826        &self,
827        values: Vec<T>,
828        sender: PartyId,
829    ) -> ResultHandle<C, Vec<T>>
830    where
831        T: 'static + From<ResultValue<C>> + Send + Sync,
832        Vec<T>: Into<NetworkPayload<C>> + From<ResultValue<C>>,
833    {
834        self.share_plaintext(values, sender)
835    }
836
837    // -------------------
838    // | Gate Definition |
839    // -------------------
840
841    /// Construct a new gate operation in the fabric, i.e. one that can be
842    /// evaluated immediate given its inputs
843    pub fn new_gate_op<F, T>(&self, args: Vec<ResultId>, function: F) -> ResultHandle<C, T>
844    where
845        F: 'static + FnOnce(Vec<ResultValue<C>>) -> ResultValue<C> + Send + Sync,
846        T: From<ResultValue<C>>,
847    {
848        let function = Box::new(function);
849        let id = self.inner.new_op(
850            args,
851            1, // output_arity
852            OperationType::Gate { function },
853        )[0];
854        ResultHandle::new(id, self.clone())
855    }
856
857    /// Construct a new batch gate operation in the fabric, i.e. one that can be
858    /// evaluated to return an array of results
859    ///
860    /// The array must be sized so that the fabric knows how many results to
861    /// allocate buffer space for ahead of execution
862    pub fn new_batch_gate_op<F, T>(
863        &self,
864        args: Vec<ResultId>,
865        output_arity: usize,
866        function: F,
867    ) -> Vec<ResultHandle<C, T>>
868    where
869        F: 'static + FnOnce(Vec<ResultValue<C>>) -> Vec<ResultValue<C>> + Send + Sync,
870        T: From<ResultValue<C>>,
871    {
872        let function = Box::new(function);
873        let ids = self
874            .inner
875            .new_op(args, output_arity, OperationType::GateBatch { function });
876        ids.into_iter()
877            .map(|id| ResultHandle::new(id, self.clone()))
878            .collect_vec()
879    }
880
881    /// Construct a new network operation in the fabric, i.e. one that requires
882    /// a value to be sent over the channel
883    pub fn new_network_op<F, T>(&self, args: Vec<ResultId>, function: F) -> ResultHandle<C, T>
884    where
885        F: 'static + FnOnce(Vec<ResultValue<C>>) -> NetworkPayload<C> + Send + Sync,
886        T: From<ResultValue<C>>,
887    {
888        let function = Box::new(function);
889        let id = self.inner.new_op(
890            args,
891            1, // output_arity
892            OperationType::Network { function },
893        )[0];
894        ResultHandle::new(id, self.clone())
895    }
896
897    // -----------------
898    // | Beaver Source |
899    // -----------------
900
901    /// Sample the next beaver triplet from the beaver source
902    pub fn next_beaver_triple(
903        &self,
904    ) -> (MpcScalarResult<C>, MpcScalarResult<C>, MpcScalarResult<C>) {
905        // Sample the triple and allocate it in the fabric, the counterparty will do the
906        // same
907        let (a, b, c) = self
908            .inner
909            .beaver_source
910            .lock()
911            .expect("beaver source poisoned")
912            .next_triplet();
913
914        let a_val = self.allocate_scalar(a);
915        let b_val = self.allocate_scalar(b);
916        let c_val = self.allocate_scalar(c);
917
918        (
919            MpcScalarResult::new_shared(a_val),
920            MpcScalarResult::new_shared(b_val),
921            MpcScalarResult::new_shared(c_val),
922        )
923    }
924
925    /// Sample a batch of beaver triples
926    #[allow(clippy::type_complexity)]
927    pub fn next_beaver_triple_batch(
928        &self,
929        n: usize,
930    ) -> (
931        Vec<MpcScalarResult<C>>,
932        Vec<MpcScalarResult<C>>,
933        Vec<MpcScalarResult<C>>,
934    ) {
935        let (a_vals, b_vals, c_vals) = self
936            .inner
937            .beaver_source
938            .lock()
939            .expect("beaver source poisoned")
940            .next_triplet_batch(n);
941
942        let a_vals = self
943            .allocate_scalars(a_vals)
944            .into_iter()
945            .map(MpcScalarResult::new_shared)
946            .collect_vec();
947        let b_vals = self
948            .allocate_scalars(b_vals)
949            .into_iter()
950            .map(MpcScalarResult::new_shared)
951            .collect_vec();
952        let c_vals = self
953            .allocate_scalars(c_vals)
954            .into_iter()
955            .map(MpcScalarResult::new_shared)
956            .collect_vec();
957
958        (a_vals, b_vals, c_vals)
959    }
960
961    /// Sample the next beaver triplet with MACs from the beaver source
962    ///
963    /// TODO: Authenticate these values either here or in the pre-processing
964    /// phase as per the SPDZ paper
965    pub fn next_authenticated_triple(
966        &self,
967    ) -> (
968        AuthenticatedScalarResult<C>,
969        AuthenticatedScalarResult<C>,
970        AuthenticatedScalarResult<C>,
971    ) {
972        let (a, b, c) = self
973            .inner
974            .beaver_source
975            .lock()
976            .expect("beaver source poisoned")
977            .next_triplet();
978
979        let a_val = self.allocate_scalar(a);
980        let b_val = self.allocate_scalar(b);
981        let c_val = self.allocate_scalar(c);
982
983        (
984            AuthenticatedScalarResult::new_shared(a_val),
985            AuthenticatedScalarResult::new_shared(b_val),
986            AuthenticatedScalarResult::new_shared(c_val),
987        )
988    }
989
990    /// Sample the next batch of beaver triples as `AuthenticatedScalar`s
991    #[allow(clippy::type_complexity)]
992    pub fn next_authenticated_triple_batch(
993        &self,
994        n: usize,
995    ) -> (
996        Vec<AuthenticatedScalarResult<C>>,
997        Vec<AuthenticatedScalarResult<C>>,
998        Vec<AuthenticatedScalarResult<C>>,
999    ) {
1000        let (a_vals, b_vals, c_vals) = self
1001            .inner
1002            .beaver_source
1003            .lock()
1004            .expect("beaver source poisoned")
1005            .next_triplet_batch(n);
1006
1007        let a_allocated = self.allocate_scalars(a_vals);
1008        let b_allocated = self.allocate_scalars(b_vals);
1009        let c_allocated = self.allocate_scalars(c_vals);
1010
1011        (
1012            AuthenticatedScalarResult::new_shared_batch(&a_allocated),
1013            AuthenticatedScalarResult::new_shared_batch(&b_allocated),
1014            AuthenticatedScalarResult::new_shared_batch(&c_allocated),
1015        )
1016    }
1017
1018    /// Sample a batch of random shared values from the beaver source
1019    pub fn random_shared_scalars(&self, n: usize) -> Vec<ScalarResult<C>> {
1020        let values_raw = self
1021            .inner
1022            .beaver_source
1023            .lock()
1024            .expect("beaver source poisoned")
1025            .next_shared_value_batch(n);
1026
1027        // Wrap the values in a result handle
1028        values_raw
1029            .into_iter()
1030            .map(|value| self.allocate_scalar(value))
1031            .collect_vec()
1032    }
1033
1034    /// Sample a batch of random shared values from the beaver source and
1035    /// allocate them as `AuthenticatedScalars`
1036    pub fn random_shared_scalars_authenticated(
1037        &self,
1038        n: usize,
1039    ) -> Vec<AuthenticatedScalarResult<C>> {
1040        let values_raw = self
1041            .inner
1042            .beaver_source
1043            .lock()
1044            .expect("beaver source poisoned")
1045            .next_shared_value_batch(n);
1046
1047        // Wrap the values in an authenticated wrapper
1048        values_raw
1049            .into_iter()
1050            .map(|value| {
1051                let value = self.allocate_scalar(value);
1052                AuthenticatedScalarResult::new_shared(value)
1053            })
1054            .collect_vec()
1055    }
1056
1057    /// Sample a pair of values that are multiplicative inverses of one another
1058    pub fn random_inverse_pair(
1059        &self,
1060    ) -> (AuthenticatedScalarResult<C>, AuthenticatedScalarResult<C>) {
1061        let (l, r) = self
1062            .inner
1063            .beaver_source
1064            .lock()
1065            .unwrap()
1066            .next_shared_inverse_pair();
1067        (
1068            AuthenticatedScalarResult::new_shared(self.allocate_scalar(l)),
1069            AuthenticatedScalarResult::new_shared(self.allocate_scalar(r)),
1070        )
1071    }
1072
1073    /// Sample a batch of values that are multiplicative inverses of one another
1074    pub fn random_inverse_pairs(
1075        &self,
1076        n: usize,
1077    ) -> (
1078        Vec<AuthenticatedScalarResult<C>>,
1079        Vec<AuthenticatedScalarResult<C>>,
1080    ) {
1081        let (left, right) = self
1082            .inner
1083            .beaver_source
1084            .lock()
1085            .unwrap()
1086            .next_shared_inverse_pair_batch(n);
1087
1088        let left_right = left.into_iter().chain(right).collect_vec();
1089        let allocated_left_right = self.allocate_scalars(left_right);
1090        let authenticated_left_right =
1091            AuthenticatedScalarResult::new_shared_batch(&allocated_left_right);
1092
1093        // Split left and right
1094        let (left, right) = authenticated_left_right.split_at(n);
1095        (left.to_vec(), right.to_vec())
1096    }
1097
1098    /// Sample a random shared bit from the beaver source
1099    pub fn random_shared_bit(&self) -> AuthenticatedScalarResult<C> {
1100        let bit = self
1101            .inner
1102            .beaver_source
1103            .lock()
1104            .expect("beaver source poisoned")
1105            .next_shared_bit();
1106
1107        let bit = self.allocate_scalar(bit);
1108        AuthenticatedScalarResult::new_shared(bit)
1109    }
1110
1111    /// Sample a batch of random shared bits from the beaver source
1112    pub fn random_shared_bits(&self, n: usize) -> Vec<AuthenticatedScalarResult<C>> {
1113        let bits = self
1114            .inner
1115            .beaver_source
1116            .lock()
1117            .expect("beaver source poisoned")
1118            .next_shared_bit_batch(n);
1119
1120        let bits = self.allocate_scalars(bits);
1121        AuthenticatedScalarResult::new_shared_batch(&bits)
1122    }
1123}
1124
#[cfg(test)]
mod test {
    // Tests for the MPC fabric, run against the mock MPC harness
    use crate::{algebra::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Tests a linear circuit of very large depth
    ///
    /// Starts from the plaintext `1` shared by `PARTY0`, then builds a chain of
    /// `DEPTH` additions of the constant one, each depending on the previous
    /// result. The final value must be `DEPTH + 1`.
    /// NOTE(review): presumably this guards against the executor handling very
    /// deep dependency chains poorly (e.g. recursion depth); confirm intent.
    #[tokio::test]
    async fn test_deep_circuit() {
        const DEPTH: usize = 1_000_000;
        // Only the first element of the harness output is checked; presumably
        // the tuple holds one result per party
        let (res, _) = execute_mock_mpc(|fabric| async move {
            // Perform an operation that takes time, so that further operations will enqueue
            // behind it
            let mut res = fabric.share_plaintext(Scalar::from(1u8), PARTY0);
            for _ in 0..DEPTH {
                res = res + fabric.one();
            }

            res.await
        })
        .await;

        assert_eq!(res, Scalar::from(DEPTH + 1));
    }
}