mpc_stark/fabric/
result.rs

//! Defines the abstractions over the result of an MPC operation; the operation may be a
//! network operation, a simple local computation, or a more complex operation such as a
//! Beaver multiplication
4
5use std::{
6    fmt::{Debug, Formatter, Result as FmtResult},
7    marker::PhantomData,
8    pin::Pin,
9    sync::{Arc, RwLock},
10    task::{Context, Poll, Waker},
11};
12
13use futures::Future;
14
15use crate::{
16    algebra::{scalar::Scalar, stark_curve::StarkPoint},
17    network::NetworkPayload,
18    Shared,
19};
20
21use super::MpcFabric;
22
/// Panic message used when the `RwLock` guarding a result buffer turns out to be poisoned,
/// i.e. another thread panicked while holding it
pub(crate) const ERR_RESULT_BUFFER_POISONED: &str = "result buffer lock poisoned";
25
26// ---------------------
27// | Result Value Type |
28// ---------------------
29
/// Numeric identifier used to refer to a single result
pub type ResultId = usize;
32
33/// The result of an MPC operation
34#[derive(Clone, Debug)]
35pub struct OpResult {
36    /// The ID of the result's output
37    pub id: ResultId,
38    /// The result's value
39    pub value: ResultValue,
40}
41
42/// The value of a result
43#[derive(Clone)]
44pub enum ResultValue {
45    /// A byte value
46    Bytes(Vec<u8>),
47    /// A scalar value
48    Scalar(Scalar),
49    /// A batch of scalars
50    ScalarBatch(Vec<Scalar>),
51    /// A point on the curve
52    Point(StarkPoint),
53    /// A batch of points on the curve
54    PointBatch(Vec<StarkPoint>),
55}
56
57impl Debug for ResultValue {
58    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
59        match self {
60            ResultValue::Bytes(bytes) => f.debug_tuple("Bytes").field(bytes).finish(),
61            ResultValue::Scalar(scalar) => f.debug_tuple("Scalar").field(scalar).finish(),
62            ResultValue::ScalarBatch(scalars) => {
63                f.debug_tuple("ScalarBatch").field(scalars).finish()
64            }
65            ResultValue::Point(point) => f.debug_tuple("Point").field(point).finish(),
66            ResultValue::PointBatch(points) => f.debug_tuple("PointBatch").field(points).finish(),
67        }
68    }
69}
70
71impl From<NetworkPayload> for ResultValue {
72    fn from(value: NetworkPayload) -> Self {
73        match value {
74            NetworkPayload::Bytes(bytes) => ResultValue::Bytes(bytes),
75            NetworkPayload::Scalar(scalar) => ResultValue::Scalar(scalar),
76            NetworkPayload::ScalarBatch(scalars) => ResultValue::ScalarBatch(scalars),
77            NetworkPayload::Point(point) => ResultValue::Point(point),
78            NetworkPayload::PointBatch(points) => ResultValue::PointBatch(points),
79        }
80    }
81}
82
83impl From<ResultValue> for NetworkPayload {
84    fn from(value: ResultValue) -> Self {
85        match value {
86            ResultValue::Bytes(bytes) => NetworkPayload::Bytes(bytes),
87            ResultValue::Scalar(scalar) => NetworkPayload::Scalar(scalar),
88            ResultValue::ScalarBatch(scalars) => NetworkPayload::ScalarBatch(scalars),
89            ResultValue::Point(point) => NetworkPayload::Point(point),
90            ResultValue::PointBatch(points) => NetworkPayload::PointBatch(points),
91        }
92    }
93}
94
95// -- Coercive Casts to Concrete Types -- //
96impl From<ResultValue> for Vec<u8> {
97    fn from(value: ResultValue) -> Self {
98        match value {
99            ResultValue::Bytes(bytes) => bytes,
100            _ => panic!("Cannot cast {:?} to bytes", value),
101        }
102    }
103}
104
105impl From<ResultValue> for Scalar {
106    fn from(value: ResultValue) -> Self {
107        match value {
108            ResultValue::Scalar(scalar) => scalar,
109            _ => panic!("Cannot cast {:?} to scalar", value),
110        }
111    }
112}
113
114impl From<&ResultValue> for Scalar {
115    fn from(value: &ResultValue) -> Self {
116        match value {
117            ResultValue::Scalar(scalar) => *scalar,
118            _ => panic!("Cannot cast {:?} to scalar", value),
119        }
120    }
121}
122
123impl From<ResultValue> for Vec<Scalar> {
124    fn from(value: ResultValue) -> Self {
125        match value {
126            ResultValue::ScalarBatch(scalars) => scalars,
127            _ => panic!("Cannot cast {:?} to scalar batch", value),
128        }
129    }
130}
131
132impl From<ResultValue> for StarkPoint {
133    fn from(value: ResultValue) -> Self {
134        match value {
135            ResultValue::Point(point) => point,
136            _ => panic!("Cannot cast {:?} to point", value),
137        }
138    }
139}
140
141impl From<&ResultValue> for StarkPoint {
142    fn from(value: &ResultValue) -> Self {
143        match value {
144            ResultValue::Point(point) => *point,
145            _ => panic!("Cannot cast {:?} to point", value),
146        }
147    }
148}
149
150impl From<ResultValue> for Vec<StarkPoint> {
151    fn from(value: ResultValue) -> Self {
152        match value {
153            ResultValue::PointBatch(points) => points,
154            _ => panic!("Cannot cast {:?} to point batch", value),
155        }
156    }
157}
158
159// ---------------
160// | Handle Type |
161// ---------------
162
163/// A handle to the result of the execution of an MPC computation graph
164///
165/// This handle acts as a pointer to a possible incomplete partial result, and
166/// `await`-ing it will block the task until the graph has evaluated up to that point
167///
168/// This allows for construction of the graph concurrently with execution, giving the
169/// fabric the opportunity to schedule all results onto the network optimistically
170#[derive(Clone, Debug)]
171pub struct ResultHandle<T: From<ResultValue>> {
172    /// The id of the result
173    pub(crate) id: ResultId,
174    /// The buffer that the result will be written to when it becomes available
175    pub(crate) result_buffer: Shared<Option<ResultValue>>,
176    /// The underlying fabric
177    pub(crate) fabric: MpcFabric,
178    /// A phantom for the type of the result
179    phantom: PhantomData<T>,
180}
181
182impl<T: From<ResultValue>> ResultHandle<T> {
183    /// Get the id of the result
184    pub fn id(&self) -> ResultId {
185        self.id
186    }
187
188    /// Borrow the fabric that this result is allocated within
189    pub fn fabric(&self) -> &MpcFabric {
190        &self.fabric
191    }
192}
193
194impl<T: From<ResultValue>> ResultHandle<T> {
195    /// Constructor
196    pub(crate) fn new(id: ResultId, fabric: MpcFabric) -> Self {
197        Self {
198            id,
199            result_buffer: Arc::new(RwLock::new(None)),
200            fabric,
201            phantom: PhantomData,
202        }
203    }
204
205    /// Get the ids that this result represents, awaiting these IDs is awaiting this result
206    pub fn op_ids(&self) -> Vec<ResultId> {
207        vec![self.id]
208    }
209}
210
211/// A struct describing an async task that is waiting on a result
212pub struct ResultWaiter {
213    /// The id of the result that the task is waiting on
214    pub result_id: ResultId,
215    /// The buffer that the result will be written to when it becomes available
216    pub result_buffer: Shared<Option<ResultValue>>,
217    /// The waker of the task
218    pub waker: Waker,
219}
220
221impl Debug for ResultWaiter {
222    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
223        f.debug_struct("ResultWaiter")
224            .field("id", &self.result_id)
225            .finish()
226    }
227}
228
229impl<T: From<ResultValue> + Debug> Future for ResultHandle<T> {
230    type Output = T;
231
232    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233        // Lock the result buffer
234        let locked_result = self.result_buffer.read().expect(ERR_RESULT_BUFFER_POISONED);
235
236        // If the result is ready, return it, otherwise register the current context's waker
237        // with the `Executor`
238        match locked_result.clone() {
239            Some(res) => Poll::Ready(res.into()),
240            None => {
241                let waiter = ResultWaiter {
242                    result_id: self.id,
243                    result_buffer: self.result_buffer.clone(),
244                    waker: cx.waker().clone(),
245                };
246
247                self.fabric.register_waiter(waiter);
248                Poll::Pending
249            }
250        }
251    }
252}