1use std::{
6 fmt::{Debug, Formatter, Result as FmtResult},
7 marker::PhantomData,
8 pin::Pin,
9 sync::{Arc, RwLock},
10 task::{Context, Poll, Waker},
11};
12
13use ark_ec::CurveGroup;
14use futures::Future;
15
16use crate::{
17 algebra::{CurvePoint, Scalar},
18 network::NetworkPayload,
19 Shared,
20};
21
22use super::MpcFabric;
23
/// Panic message used when acquiring the result buffer lock fails because the
/// lock was poisoned
pub(crate) const ERR_RESULT_BUFFER_POISONED: &str = "result buffer lock poisoned";
26
/// Identifier of a result; a plain `usize` under the hood
pub type ResultId = usize;
33
/// A result value paired with the identifier it is stored under
#[derive(Clone, Debug)]
pub struct OpResult<C: CurveGroup> {
    /// The identifier of this result
    pub id: ResultId,
    /// The value carried by this result
    pub value: ResultValue<C>,
}
42
/// The set of value shapes a result may take
///
/// The variants mirror those of `NetworkPayload`; the `From` impls in this
/// file convert between the two in both directions.
#[derive(Clone)]
pub enum ResultValue<C: CurveGroup> {
    /// A buffer of raw bytes
    Bytes(Vec<u8>),
    /// A single scalar
    Scalar(Scalar<C>),
    /// A list of scalars
    ScalarBatch(Vec<Scalar<C>>),
    /// A single curve point
    Point(CurvePoint<C>),
    /// A list of curve points
    PointBatch(Vec<CurvePoint<C>>),
}
57
58impl<C: CurveGroup> Debug for ResultValue<C> {
59 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
60 match self {
61 ResultValue::Bytes(bytes) => f.debug_tuple("Bytes").field(bytes).finish(),
62 ResultValue::Scalar(scalar) => f.debug_tuple("Scalar").field(scalar).finish(),
63 ResultValue::ScalarBatch(scalars) => {
64 f.debug_tuple("ScalarBatch").field(scalars).finish()
65 },
66 ResultValue::Point(point) => f.debug_tuple("Point").field(point).finish(),
67 ResultValue::PointBatch(points) => f.debug_tuple("PointBatch").field(points).finish(),
68 }
69 }
70}
71
72impl<C: CurveGroup> From<NetworkPayload<C>> for ResultValue<C> {
73 fn from(value: NetworkPayload<C>) -> Self {
74 match value {
75 NetworkPayload::Bytes(bytes) => ResultValue::Bytes(bytes),
76 NetworkPayload::Scalar(scalar) => ResultValue::Scalar(scalar),
77 NetworkPayload::ScalarBatch(scalars) => ResultValue::ScalarBatch(scalars),
78 NetworkPayload::Point(point) => ResultValue::Point(point),
79 NetworkPayload::PointBatch(points) => ResultValue::PointBatch(points),
80 }
81 }
82}
83
84impl<C: CurveGroup> From<ResultValue<C>> for NetworkPayload<C> {
85 fn from(value: ResultValue<C>) -> Self {
86 match value {
87 ResultValue::Bytes(bytes) => NetworkPayload::Bytes(bytes),
88 ResultValue::Scalar(scalar) => NetworkPayload::Scalar(scalar),
89 ResultValue::ScalarBatch(scalars) => NetworkPayload::ScalarBatch(scalars),
90 ResultValue::Point(point) => NetworkPayload::Point(point),
91 ResultValue::PointBatch(points) => NetworkPayload::PointBatch(points),
92 }
93 }
94}
95
96impl<C: CurveGroup> From<ResultValue<C>> for Vec<u8> {
98 fn from(value: ResultValue<C>) -> Self {
99 match value {
100 ResultValue::Bytes(bytes) => bytes,
101 _ => panic!("Cannot cast {:?} to bytes", value),
102 }
103 }
104}
105
106impl<C: CurveGroup> From<ResultValue<C>> for Scalar<C> {
107 fn from(value: ResultValue<C>) -> Self {
108 match value {
109 ResultValue::Scalar(scalar) => scalar,
110 _ => panic!("Cannot cast {:?} to scalar", value),
111 }
112 }
113}
114
115impl<C: CurveGroup> From<&ResultValue<C>> for Scalar<C> {
116 fn from(value: &ResultValue<C>) -> Self {
117 match value {
118 ResultValue::Scalar(scalar) => *scalar,
119 _ => panic!("Cannot cast {:?} to scalar", value),
120 }
121 }
122}
123
124impl<C: CurveGroup> From<ResultValue<C>> for Vec<Scalar<C>> {
125 fn from(value: ResultValue<C>) -> Self {
126 match value {
127 ResultValue::ScalarBatch(scalars) => scalars,
128 _ => panic!("Cannot cast {:?} to scalar batch", value),
129 }
130 }
131}
132
133impl<C: CurveGroup> From<ResultValue<C>> for CurvePoint<C> {
134 fn from(value: ResultValue<C>) -> Self {
135 match value {
136 ResultValue::Point(point) => point,
137 _ => panic!("Cannot cast {:?} to point", value),
138 }
139 }
140}
141
142impl<C: CurveGroup> From<&ResultValue<C>> for CurvePoint<C> {
143 fn from(value: &ResultValue<C>) -> Self {
144 match value {
145 ResultValue::Point(point) => *point,
146 _ => panic!("Cannot cast {:?} to point", value),
147 }
148 }
149}
150
151impl<C: CurveGroup> From<ResultValue<C>> for Vec<CurvePoint<C>> {
152 fn from(value: ResultValue<C>) -> Self {
153 match value {
154 ResultValue::PointBatch(points) => points,
155 _ => panic!("Cannot cast {:?} to point batch", value),
156 }
157 }
158}
159
/// A typed handle to a result identified by a `ResultId`
///
/// The handle implements `Future` (see the impl in this file): polling it
/// yields the buffered `ResultValue` converted into `T` once the shared buffer
/// holds a value.
#[derive(Clone, Debug)]
pub struct ResultHandle<C: CurveGroup, T: From<ResultValue<C>>> {
    /// The identifier of the result this handle refers to
    pub(crate) id: ResultId,
    /// Shared slot for the result's value; `None` until a value has been
    /// stored
    pub(crate) result_buffer: Shared<Option<ResultValue<C>>>,
    /// The fabric this handle belongs to; waiters are registered with it when
    /// the handle is polled before a value is available
    pub(crate) fabric: MpcFabric<C>,
    /// Marker tying the handle to `T`, the type the value is converted into
    phantom: PhantomData<T>,
}
184
185impl<C: CurveGroup, T: From<ResultValue<C>>> ResultHandle<C, T> {
186 pub fn id(&self) -> ResultId {
188 self.id
189 }
190
191 pub fn fabric(&self) -> &MpcFabric<C> {
193 &self.fabric
194 }
195}
196
197impl<C: CurveGroup, T: From<ResultValue<C>>> ResultHandle<C, T> {
198 pub(crate) fn new(id: ResultId, fabric: MpcFabric<C>) -> Self {
200 Self {
201 id,
202 result_buffer: Arc::new(RwLock::new(None)),
203 fabric,
204 phantom: PhantomData,
205 }
206 }
207
208 pub fn op_ids(&self) -> Vec<ResultId> {
211 vec![self.id]
212 }
213}
214
/// A record of a task waiting on a result
///
/// Built when a `ResultHandle` is polled before its value is available and
/// handed to the fabric via `register_waiter`.
// NOTE(review): presumably the fabric fills `result_buffer` and then calls
// `waker.wake()` once the result exists — confirm against `MpcFabric`.
pub struct ResultWaiter<C: CurveGroup> {
    /// The identifier of the result being waited on
    pub result_id: ResultId,
    /// The shared buffer of the handle that is waiting
    pub result_buffer: Shared<Option<ResultValue<C>>>,
    /// The waker of the task that polled the handle
    pub waker: Waker,
}
224
225impl<C: CurveGroup> Debug for ResultWaiter<C> {
226 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
227 f.debug_struct("ResultWaiter")
228 .field("id", &self.result_id)
229 .finish()
230 }
231}
232
233impl<C: CurveGroup, T: From<ResultValue<C>> + Debug> Future for ResultHandle<C, T> {
234 type Output = T;
235
236 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 let locked_result = self.result_buffer.read().expect(ERR_RESULT_BUFFER_POISONED);
239
240 match locked_result.clone() {
243 Some(res) => Poll::Ready(res.into()),
244 None => {
245 let waiter = ResultWaiter {
246 result_id: self.id,
247 result_buffer: self.result_buffer.clone(),
248 waker: cx.waker().clone(),
249 };
250
251 self.fabric.register_waiter(waiter);
252 Poll::Pending
253 },
254 }
255 }
256}