ark_mpc/fabric/
result.rs

1//! Defines the abstractions over the result of an MPC operation, this can be a
2//! network operation, a simple local computation, or a more complex operation
3//! like a Beaver multiplication
4
5use std::{
6    fmt::{Debug, Formatter, Result as FmtResult},
7    marker::PhantomData,
8    pin::Pin,
9    sync::{Arc, RwLock},
10    task::{Context, Poll, Waker},
11};
12
13use ark_ec::CurveGroup;
14use futures::Future;
15
16use crate::{
17    algebra::{CurvePoint, Scalar},
18    network::NetworkPayload,
19    Shared,
20};
21
22use super::MpcFabric;
23
/// Message passed to `expect` when acquiring a result buffer's lock fails
/// because the lock has been poisoned
pub(crate) const ERR_RESULT_BUFFER_POISONED: &'static str = "result buffer lock poisoned";
26
27// ---------------------
28// | Result Value Type |
29// ---------------------
30
/// Numeric identifier of a result; handles and waiters refer to the result
/// they track by this id
pub type ResultId = usize;
33
34/// The result of an MPC operation
35#[derive(Clone, Debug)]
36pub struct OpResult<C: CurveGroup> {
37    /// The ID of the result's output
38    pub id: ResultId,
39    /// The result's value
40    pub value: ResultValue<C>,
41}
42
43/// The value of a result
44#[derive(Clone)]
45pub enum ResultValue<C: CurveGroup> {
46    /// A byte value
47    Bytes(Vec<u8>),
48    /// A scalar value
49    Scalar(Scalar<C>),
50    /// A batch of scalars
51    ScalarBatch(Vec<Scalar<C>>),
52    /// A point on the curve
53    Point(CurvePoint<C>),
54    /// A batch of points on the curve
55    PointBatch(Vec<CurvePoint<C>>),
56}
57
58impl<C: CurveGroup> Debug for ResultValue<C> {
59    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
60        match self {
61            ResultValue::Bytes(bytes) => f.debug_tuple("Bytes").field(bytes).finish(),
62            ResultValue::Scalar(scalar) => f.debug_tuple("Scalar").field(scalar).finish(),
63            ResultValue::ScalarBatch(scalars) => {
64                f.debug_tuple("ScalarBatch").field(scalars).finish()
65            },
66            ResultValue::Point(point) => f.debug_tuple("Point").field(point).finish(),
67            ResultValue::PointBatch(points) => f.debug_tuple("PointBatch").field(points).finish(),
68        }
69    }
70}
71
72impl<C: CurveGroup> From<NetworkPayload<C>> for ResultValue<C> {
73    fn from(value: NetworkPayload<C>) -> Self {
74        match value {
75            NetworkPayload::Bytes(bytes) => ResultValue::Bytes(bytes),
76            NetworkPayload::Scalar(scalar) => ResultValue::Scalar(scalar),
77            NetworkPayload::ScalarBatch(scalars) => ResultValue::ScalarBatch(scalars),
78            NetworkPayload::Point(point) => ResultValue::Point(point),
79            NetworkPayload::PointBatch(points) => ResultValue::PointBatch(points),
80        }
81    }
82}
83
84impl<C: CurveGroup> From<ResultValue<C>> for NetworkPayload<C> {
85    fn from(value: ResultValue<C>) -> Self {
86        match value {
87            ResultValue::Bytes(bytes) => NetworkPayload::Bytes(bytes),
88            ResultValue::Scalar(scalar) => NetworkPayload::Scalar(scalar),
89            ResultValue::ScalarBatch(scalars) => NetworkPayload::ScalarBatch(scalars),
90            ResultValue::Point(point) => NetworkPayload::Point(point),
91            ResultValue::PointBatch(points) => NetworkPayload::PointBatch(points),
92        }
93    }
94}
95
96// -- Coercive Casts to Concrete Types -- //
97impl<C: CurveGroup> From<ResultValue<C>> for Vec<u8> {
98    fn from(value: ResultValue<C>) -> Self {
99        match value {
100            ResultValue::Bytes(bytes) => bytes,
101            _ => panic!("Cannot cast {:?} to bytes", value),
102        }
103    }
104}
105
106impl<C: CurveGroup> From<ResultValue<C>> for Scalar<C> {
107    fn from(value: ResultValue<C>) -> Self {
108        match value {
109            ResultValue::Scalar(scalar) => scalar,
110            _ => panic!("Cannot cast {:?} to scalar", value),
111        }
112    }
113}
114
115impl<C: CurveGroup> From<&ResultValue<C>> for Scalar<C> {
116    fn from(value: &ResultValue<C>) -> Self {
117        match value {
118            ResultValue::Scalar(scalar) => *scalar,
119            _ => panic!("Cannot cast {:?} to scalar", value),
120        }
121    }
122}
123
124impl<C: CurveGroup> From<ResultValue<C>> for Vec<Scalar<C>> {
125    fn from(value: ResultValue<C>) -> Self {
126        match value {
127            ResultValue::ScalarBatch(scalars) => scalars,
128            _ => panic!("Cannot cast {:?} to scalar batch", value),
129        }
130    }
131}
132
133impl<C: CurveGroup> From<ResultValue<C>> for CurvePoint<C> {
134    fn from(value: ResultValue<C>) -> Self {
135        match value {
136            ResultValue::Point(point) => point,
137            _ => panic!("Cannot cast {:?} to point", value),
138        }
139    }
140}
141
142impl<C: CurveGroup> From<&ResultValue<C>> for CurvePoint<C> {
143    fn from(value: &ResultValue<C>) -> Self {
144        match value {
145            ResultValue::Point(point) => *point,
146            _ => panic!("Cannot cast {:?} to point", value),
147        }
148    }
149}
150
151impl<C: CurveGroup> From<ResultValue<C>> for Vec<CurvePoint<C>> {
152    fn from(value: ResultValue<C>) -> Self {
153        match value {
154            ResultValue::PointBatch(points) => points,
155            _ => panic!("Cannot cast {:?} to point batch", value),
156        }
157    }
158}
159
160// ---------------
161// | Handle Type |
162// ---------------
163
164/// A handle to the result of the execution of an MPC computation graph
165///
166/// This handle acts as a pointer to a possible incomplete partial result, and
167/// `await`-ing it will block the task until the graph has evaluated up to that
168/// point
169///
170/// This allows for construction of the graph concurrently with execution,
171/// giving the fabric the opportunity to schedule all results onto the network
172/// optimistically
173#[derive(Clone, Debug)]
174pub struct ResultHandle<C: CurveGroup, T: From<ResultValue<C>>> {
175    /// The id of the result
176    pub(crate) id: ResultId,
177    /// The buffer that the result will be written to when it becomes available
178    pub(crate) result_buffer: Shared<Option<ResultValue<C>>>,
179    /// The underlying fabric
180    pub(crate) fabric: MpcFabric<C>,
181    /// A phantom for the type of the result
182    phantom: PhantomData<T>,
183}
184
185impl<C: CurveGroup, T: From<ResultValue<C>>> ResultHandle<C, T> {
186    /// Get the id of the result
187    pub fn id(&self) -> ResultId {
188        self.id
189    }
190
191    /// Borrow the fabric that this result is allocated within
192    pub fn fabric(&self) -> &MpcFabric<C> {
193        &self.fabric
194    }
195}
196
197impl<C: CurveGroup, T: From<ResultValue<C>>> ResultHandle<C, T> {
198    /// Constructor
199    pub(crate) fn new(id: ResultId, fabric: MpcFabric<C>) -> Self {
200        Self {
201            id,
202            result_buffer: Arc::new(RwLock::new(None)),
203            fabric,
204            phantom: PhantomData,
205        }
206    }
207
208    /// Get the ids that this result represents, awaiting these IDs is awaiting
209    /// this result
210    pub fn op_ids(&self) -> Vec<ResultId> {
211        vec![self.id]
212    }
213}
214
215/// A struct describing an async task that is waiting on a result
216pub struct ResultWaiter<C: CurveGroup> {
217    /// The id of the result that the task is waiting on
218    pub result_id: ResultId,
219    /// The buffer that the result will be written to when it becomes available
220    pub result_buffer: Shared<Option<ResultValue<C>>>,
221    /// The waker of the task
222    pub waker: Waker,
223}
224
225impl<C: CurveGroup> Debug for ResultWaiter<C> {
226    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
227        f.debug_struct("ResultWaiter")
228            .field("id", &self.result_id)
229            .finish()
230    }
231}
232
233impl<C: CurveGroup, T: From<ResultValue<C>> + Debug> Future for ResultHandle<C, T> {
234    type Output = T;
235
236    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        // Lock the result buffer
238        let locked_result = self.result_buffer.read().expect(ERR_RESULT_BUFFER_POISONED);
239
240        // If the result is ready, return it, otherwise register the current context's
241        // waker with the `Executor`
242        match locked_result.clone() {
243            Some(res) => Poll::Ready(res.into()),
244            None => {
245                let waiter = ResultWaiter {
246                    result_id: self.id,
247                    result_buffer: self.result_buffer.clone(),
248                    waker: cx.waker().clone(),
249                };
250
251                self.fabric.register_waiter(waiter);
252                Poll::Pending
253            },
254        }
255    }
256}