mpc_stark/fabric/
result.rs1use std::{
6 fmt::{Debug, Formatter, Result as FmtResult},
7 marker::PhantomData,
8 pin::Pin,
9 sync::{Arc, RwLock},
10 task::{Context, Poll, Waker},
11};
12
13use futures::Future;
14
15use crate::{
16 algebra::{scalar::Scalar, stark_curve::StarkPoint},
17 network::NetworkPayload,
18 Shared,
19};
20
21use super::MpcFabric;
22
/// Message passed to `expect` when acquiring a result buffer's lock fails because
/// the lock has been poisoned (a previous holder panicked while holding it).
pub(crate) const ERR_RESULT_BUFFER_POISONED: &str = "result buffer lock poisoned";
25
/// Identifier of a result in the fabric; used by handles and waiters to refer to
/// the value they are waiting on.
pub type ResultId = usize;
32
/// A result value paired with the id it is stored under.
#[derive(Clone, Debug)]
pub struct OpResult {
    /// The id of the result this value belongs to.
    pub id: ResultId,
    /// The value itself.
    pub value: ResultValue,
}
41
/// The set of value types a result can hold.
///
/// Mirrors the variants of `NetworkPayload` one-to-one (see the `From` impls below), so
/// values can be moved between the network layer and the result layer without loss.
#[derive(Clone)]
pub enum ResultValue {
    /// Raw bytes.
    Bytes(Vec<u8>),
    /// A single scalar.
    Scalar(Scalar),
    /// A batch of scalars.
    ScalarBatch(Vec<Scalar>),
    /// A single curve point.
    Point(StarkPoint),
    /// A batch of curve points.
    PointBatch(Vec<StarkPoint>),
}
56
57impl Debug for ResultValue {
58 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
59 match self {
60 ResultValue::Bytes(bytes) => f.debug_tuple("Bytes").field(bytes).finish(),
61 ResultValue::Scalar(scalar) => f.debug_tuple("Scalar").field(scalar).finish(),
62 ResultValue::ScalarBatch(scalars) => {
63 f.debug_tuple("ScalarBatch").field(scalars).finish()
64 }
65 ResultValue::Point(point) => f.debug_tuple("Point").field(point).finish(),
66 ResultValue::PointBatch(points) => f.debug_tuple("PointBatch").field(points).finish(),
67 }
68 }
69}
70
71impl From<NetworkPayload> for ResultValue {
72 fn from(value: NetworkPayload) -> Self {
73 match value {
74 NetworkPayload::Bytes(bytes) => ResultValue::Bytes(bytes),
75 NetworkPayload::Scalar(scalar) => ResultValue::Scalar(scalar),
76 NetworkPayload::ScalarBatch(scalars) => ResultValue::ScalarBatch(scalars),
77 NetworkPayload::Point(point) => ResultValue::Point(point),
78 NetworkPayload::PointBatch(points) => ResultValue::PointBatch(points),
79 }
80 }
81}
82
83impl From<ResultValue> for NetworkPayload {
84 fn from(value: ResultValue) -> Self {
85 match value {
86 ResultValue::Bytes(bytes) => NetworkPayload::Bytes(bytes),
87 ResultValue::Scalar(scalar) => NetworkPayload::Scalar(scalar),
88 ResultValue::ScalarBatch(scalars) => NetworkPayload::ScalarBatch(scalars),
89 ResultValue::Point(point) => NetworkPayload::Point(point),
90 ResultValue::PointBatch(points) => NetworkPayload::PointBatch(points),
91 }
92 }
93}
94
95impl From<ResultValue> for Vec<u8> {
97 fn from(value: ResultValue) -> Self {
98 match value {
99 ResultValue::Bytes(bytes) => bytes,
100 _ => panic!("Cannot cast {:?} to bytes", value),
101 }
102 }
103}
104
105impl From<ResultValue> for Scalar {
106 fn from(value: ResultValue) -> Self {
107 match value {
108 ResultValue::Scalar(scalar) => scalar,
109 _ => panic!("Cannot cast {:?} to scalar", value),
110 }
111 }
112}
113
114impl From<&ResultValue> for Scalar {
115 fn from(value: &ResultValue) -> Self {
116 match value {
117 ResultValue::Scalar(scalar) => *scalar,
118 _ => panic!("Cannot cast {:?} to scalar", value),
119 }
120 }
121}
122
123impl From<ResultValue> for Vec<Scalar> {
124 fn from(value: ResultValue) -> Self {
125 match value {
126 ResultValue::ScalarBatch(scalars) => scalars,
127 _ => panic!("Cannot cast {:?} to scalar batch", value),
128 }
129 }
130}
131
132impl From<ResultValue> for StarkPoint {
133 fn from(value: ResultValue) -> Self {
134 match value {
135 ResultValue::Point(point) => point,
136 _ => panic!("Cannot cast {:?} to point", value),
137 }
138 }
139}
140
141impl From<&ResultValue> for StarkPoint {
142 fn from(value: &ResultValue) -> Self {
143 match value {
144 ResultValue::Point(point) => *point,
145 _ => panic!("Cannot cast {:?} to point", value),
146 }
147 }
148}
149
150impl From<ResultValue> for Vec<StarkPoint> {
151 fn from(value: ResultValue) -> Self {
152 match value {
153 ResultValue::PointBatch(points) => points,
154 _ => panic!("Cannot cast {:?} to point batch", value),
155 }
156 }
157}
158
/// A handle to a result that may not be available yet.
///
/// The handle is a [`Future`] resolving to `T`. When polled it reads `result_buffer`.
/// If the buffer is empty, it registers a waiter with the fabric.
#[derive(Clone, Debug)]
pub struct ResultHandle<T: From<ResultValue>> {
    /// The id of the result this handle waits on.
    pub(crate) id: ResultId,
    /// Slot the result is read from; starts as `None` (see `new`). A clone of this
    /// `Arc` is handed to the fabric inside each `ResultWaiter`, presumably so the
    /// fabric can fill it in — TODO confirm against `register_waiter`.
    pub(crate) result_buffer: Shared<Option<ResultValue>>,
    /// The fabric waiters are registered with while the result is pending.
    pub(crate) fabric: MpcFabric,
    /// Records the output type `T` that the stored `ResultValue` is converted into.
    phantom: PhantomData<T>,
}
181
182impl<T: From<ResultValue>> ResultHandle<T> {
183 pub fn id(&self) -> ResultId {
185 self.id
186 }
187
188 pub fn fabric(&self) -> &MpcFabric {
190 &self.fabric
191 }
192}
193
194impl<T: From<ResultValue>> ResultHandle<T> {
195 pub(crate) fn new(id: ResultId, fabric: MpcFabric) -> Self {
197 Self {
198 id,
199 result_buffer: Arc::new(RwLock::new(None)),
200 fabric,
201 phantom: PhantomData,
202 }
203 }
204
205 pub fn op_ids(&self) -> Vec<ResultId> {
207 vec![self.id]
208 }
209}
210
/// A pending interest in a result, created by `ResultHandle::poll` and handed to the
/// fabric via `register_waiter`.
pub struct ResultWaiter {
    /// The id of the result being waited on.
    pub result_id: ResultId,
    /// A clone of the handle's shared result buffer.
    // NOTE(review): presumably the fabric stores the value here before waking — confirm.
    pub result_buffer: Shared<Option<ResultValue>>,
    /// The waker of the task that polled the handle.
    pub waker: Waker,
}
220
221impl Debug for ResultWaiter {
222 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
223 f.debug_struct("ResultWaiter")
224 .field("id", &self.result_id)
225 .finish()
226 }
227}
228
229impl<T: From<ResultValue> + Debug> Future for ResultHandle<T> {
230 type Output = T;
231
232 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233 let locked_result = self.result_buffer.read().expect(ERR_RESULT_BUFFER_POISONED);
235
236 match locked_result.clone() {
239 Some(res) => Poll::Ready(res.into()),
240 None => {
241 let waiter = ResultWaiter {
242 result_id: self.id,
243 result_buffer: self.result_buffer.clone(),
244 waker: cx.waker().clone(),
245 };
246
247 self.fabric.register_waiter(waiter);
248 Poll::Pending
249 }
250 }
251 }
252}