use std::{
fmt::{Debug, Formatter, Result as FmtResult},
marker::PhantomData,
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll, Waker},
};
use ark_ec::CurveGroup;
use futures::Future;
use crate::{
algebra::{curve::CurvePoint, scalar::Scalar},
network::NetworkPayload,
Shared,
};
use super::MpcFabric;
/// Identifier of an operation result, shared by `OpResult`, `ResultHandle` and `ResultWaiter`
pub type ResultId = usize;
/// Panic message used when the `RwLock` guarding a result buffer turns out to be poisoned
pub(crate) const ERR_RESULT_BUFFER_POISONED: &str = "result buffer lock poisoned";
/// A `ResultValue` tagged with the `ResultId` it belongs to
#[derive(Debug, Clone)]
pub struct OpResult<C>
where
    C: CurveGroup,
{
    /// The id of the result this value belongs to
    pub id: ResultId,
    /// The value itself
    pub value: ResultValue<C>,
}
/// The set of value shapes an operation result can hold
///
/// The variants mirror `NetworkPayload` one-to-one (see the `From` impls in both directions)
#[derive(Clone)]
pub enum ResultValue<C>
where
    C: CurveGroup,
{
    /// Raw bytes
    Bytes(Vec<u8>),
    /// A single scalar
    Scalar(Scalar<C>),
    /// A batch of scalars
    ScalarBatch(Vec<Scalar<C>>),
    /// A single curve point
    Point(CurvePoint<C>),
    /// A batch of curve points
    PointBatch(Vec<CurvePoint<C>>),
}
// NOTE(review): presumably written by hand instead of derived so that `C` itself need not
// implement `Debug` — confirm before switching to `#[derive(Debug)]`
impl<C: CurveGroup> Debug for ResultValue<C> {
    /// Formats the value as `Variant(payload)`, delegating to the payload's own `Debug` impl
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Pick the variant name and its payload first, then format through a single path
        let (variant, payload): (&str, &dyn Debug) = match self {
            Self::Bytes(bytes) => ("Bytes", bytes),
            Self::Scalar(scalar) => ("Scalar", scalar),
            Self::ScalarBatch(scalars) => ("ScalarBatch", scalars),
            Self::Point(point) => ("Point", point),
            Self::PointBatch(points) => ("PointBatch", points),
        };

        f.debug_tuple(variant).field(payload).finish()
    }
}
impl<C: CurveGroup> From<NetworkPayload<C>> for ResultValue<C> {
    /// Maps each network payload variant onto the result value variant of the same name
    fn from(payload: NetworkPayload<C>) -> Self {
        match payload {
            NetworkPayload::Bytes(inner) => Self::Bytes(inner),
            NetworkPayload::Scalar(inner) => Self::Scalar(inner),
            NetworkPayload::ScalarBatch(inner) => Self::ScalarBatch(inner),
            NetworkPayload::Point(inner) => Self::Point(inner),
            NetworkPayload::PointBatch(inner) => Self::PointBatch(inner),
        }
    }
}
impl<C: CurveGroup> From<ResultValue<C>> for NetworkPayload<C> {
    /// Maps each result value variant onto the network payload variant of the same name
    fn from(result: ResultValue<C>) -> Self {
        match result {
            ResultValue::Bytes(inner) => Self::Bytes(inner),
            ResultValue::Scalar(inner) => Self::Scalar(inner),
            ResultValue::ScalarBatch(inner) => Self::ScalarBatch(inner),
            ResultValue::Point(inner) => Self::Point(inner),
            ResultValue::PointBatch(inner) => Self::PointBatch(inner),
        }
    }
}
impl<C: CurveGroup> From<ResultValue<C>> for Vec<u8> {
    /// Unwraps a `ResultValue::Bytes` into its byte vector
    ///
    /// # Panics
    /// Panics if `value` holds any other variant
    fn from(value: ResultValue<C>) -> Self {
        if let ResultValue::Bytes(bytes) = value {
            bytes
        } else {
            panic!("Cannot cast {:?} to bytes", value)
        }
    }
}
impl<C: CurveGroup> From<ResultValue<C>> for Scalar<C> {
    /// Unwraps a `ResultValue::Scalar` into its scalar
    ///
    /// # Panics
    /// Panics if `value` holds any other variant
    fn from(value: ResultValue<C>) -> Self {
        if let ResultValue::Scalar(scalar) = value {
            scalar
        } else {
            panic!("Cannot cast {:?} to scalar", value)
        }
    }
}
impl<C: CurveGroup> From<&ResultValue<C>> for Scalar<C> {
    /// Copies the scalar out of a borrowed `ResultValue::Scalar`
    ///
    /// # Panics
    /// Panics if `value` holds any other variant
    fn from(value: &ResultValue<C>) -> Self {
        // `Scalar<C>` is `Copy`, so the binding copies it out from behind the reference
        if let &ResultValue::Scalar(scalar) = value {
            scalar
        } else {
            panic!("Cannot cast {:?} to scalar", value)
        }
    }
}
impl<C: CurveGroup> From<ResultValue<C>> for Vec<Scalar<C>> {
    /// Unwraps a `ResultValue::ScalarBatch` into its vector of scalars
    ///
    /// # Panics
    /// Panics if `value` holds any other variant
    fn from(value: ResultValue<C>) -> Self {
        if let ResultValue::ScalarBatch(scalars) = value {
            scalars
        } else {
            panic!("Cannot cast {:?} to scalar batch", value)
        }
    }
}
impl<C: CurveGroup> From<ResultValue<C>> for CurvePoint<C> {
    /// Unwraps a `ResultValue::Point` into its curve point
    ///
    /// # Panics
    /// Panics if `value` holds any other variant
    fn from(value: ResultValue<C>) -> Self {
        if let ResultValue::Point(point) = value {
            point
        } else {
            panic!("Cannot cast {:?} to point", value)
        }
    }
}
impl<C: CurveGroup> From<&ResultValue<C>> for CurvePoint<C> {
    /// Copies the curve point out of a borrowed `ResultValue::Point`
    ///
    /// # Panics
    /// Panics if `value` holds any other variant
    fn from(value: &ResultValue<C>) -> Self {
        // `CurvePoint<C>` is `Copy`, so the binding copies it out from behind the reference
        if let &ResultValue::Point(point) = value {
            point
        } else {
            panic!("Cannot cast {:?} to point", value)
        }
    }
}
impl<C: CurveGroup> From<ResultValue<C>> for Vec<CurvePoint<C>> {
    /// Unwraps a `ResultValue::PointBatch` into its vector of curve points
    ///
    /// # Panics
    /// Panics if `value` holds any other variant
    fn from(value: ResultValue<C>) -> Self {
        if let ResultValue::PointBatch(points) = value {
            points
        } else {
            panic!("Cannot cast {:?} to point batch", value)
        }
    }
}
/// A handle to an operation result that can be awaited (it implements `Future`) to obtain the
/// result converted into `T`
///
/// Cloning a handle clones the `Arc` around the buffer, so clones share one result slot.
#[derive(Debug, Clone)]
pub struct ResultHandle<C, T>
where
    C: CurveGroup,
    T: From<ResultValue<C>>,
{
    /// The id of the result this handle refers to
    pub(crate) id: ResultId,
    /// Shared slot holding the raw result once it is available; `None` until then
    pub(crate) result_buffer: Shared<Option<ResultValue<C>>>,
    /// The fabric this handle belongs to; used to register waiters while the result is pending
    pub(crate) fabric: MpcFabric<C>,
    /// Records the output type `T` without storing a value of it
    phantom: PhantomData<T>,
}
impl<C, T> ResultHandle<C, T>
where
    C: CurveGroup,
    T: From<ResultValue<C>>,
{
    /// Returns the id of the result this handle refers to
    pub fn id(&self) -> ResultId {
        self.id
    }

    /// Returns a reference to the fabric this handle belongs to
    pub fn fabric(&self) -> &MpcFabric<C> {
        &self.fabric
    }
}
impl<C, T> ResultHandle<C, T>
where
    C: CurveGroup,
    T: From<ResultValue<C>>,
{
    /// Creates a handle for result `id` on `fabric`, starting with an empty result buffer
    pub(crate) fn new(id: ResultId, fabric: MpcFabric<C>) -> Self {
        // No value is known yet; the buffer is filled in later
        let result_buffer = Arc::new(RwLock::new(None));

        Self {
            id,
            result_buffer,
            fabric,
            phantom: PhantomData,
        }
    }

    /// Returns the result ids represented by this handle; a single handle has only its own id
    pub fn op_ids(&self) -> Vec<ResultId> {
        Vec::from([self.id])
    }
}
/// A pending-result registration that `ResultHandle::poll` hands to the fabric via
/// `register_waiter`
///
/// NOTE(review): presumably the fabric writes the result into `result_buffer` and then wakes
/// `waker`; that logic lives outside this module — confirm against `register_waiter`.
pub struct ResultWaiter<C>
where
    C: CurveGroup,
{
    /// The id of the result being waited on
    pub result_id: ResultId,
    /// The buffer shared with the waiting `ResultHandle`
    pub result_buffer: Shared<Option<ResultValue<C>>>,
    /// The waker of the task that polled the handle
    pub waker: Waker,
}
impl<C: CurveGroup> Debug for ResultWaiter<C> {
    /// Formats only the waited-on result id (under the field name `id`); the buffer and waker
    /// are omitted
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let mut out = f.debug_struct("ResultWaiter");
        out.field("id", &self.result_id);
        out.finish()
    }
}
impl<C: CurveGroup, T: From<ResultValue<C>> + Debug> Future for ResultHandle<C, T> {
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let locked_result = self.result_buffer.read().expect(ERR_RESULT_BUFFER_POISONED);
match locked_result.clone() {
Some(res) => Poll::Ready(res.into()),
None => {
let waiter = ResultWaiter {
result_id: self.id,
result_buffer: self.result_buffer.clone(),
waker: cx.waker().clone(),
};
self.fabric.register_waiter(waiter);
Poll::Pending
}
}
}
}