use core::fmt;
use std::{collections::HashMap, error};
use ::serde::Serialize;
use crate::compat::CSCurve;
/// Error reported while running a [`Protocol`].
#[derive(Debug)]
pub enum ProtocolError {
    /// A failed assertion; the payload is appended after "assertion failed" when displayed.
    AssertionFailed(String),
    /// Any other error, carried as a boxed trait object and displayed verbatim.
    Other(Box<dyn error::Error + Send + Sync>),
}

impl fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AssertionFailed(what) => write!(f, "assertion failed {}", what),
            Self::Other(inner) => write!(f, "{}", inner),
        }
    }
}

impl error::Error for ProtocolError {}

impl From<Box<dyn error::Error + Send + Sync>> for ProtocolError {
    /// Wrap an arbitrary boxed error as [`ProtocolError::Other`].
    fn from(err: Box<dyn error::Error + Send + Sync>) -> Self {
        ProtocolError::Other(err)
    }
}
/// Error reported when initialization is given parameters it cannot accept.
#[derive(Debug)]
pub enum InitializationError {
    /// The supplied parameters were rejected; the payload explains why.
    BadParameters(String),
}

impl fmt::Display for InitializationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let Self::BadParameters(reason) = self;
        write!(f, "bad parameters: {}", reason)
    }
}

impl error::Error for InitializationError {}
/// Identifier of a single participant in a protocol.
///
/// A thin wrapper around a `u32`; convert in either direction with `From`/`Into`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Hash)]
pub struct Participant(u32);

impl Participant {
    /// The identifier encoded as 4 little-endian bytes.
    pub fn bytes(&self) -> [u8; 4] {
        u32::to_le_bytes(self.0)
    }

    /// The identifier mapped into the curve's scalar type as `id + 1`.
    ///
    /// NOTE(review): the `+ 1` offset presumably keeps participant 0 away from the zero
    /// scalar (e.g. when ids are used as evaluation points) — confirm with callers.
    pub fn scalar<C: CSCurve>(&self) -> C::Scalar {
        // Widening u32 -> u64 is lossless, so the increment cannot overflow.
        let shifted = u64::from(self.0) + 1;
        C::Scalar::from(shifted)
    }
}

impl From<Participant> for u32 {
    fn from(participant: Participant) -> Self {
        let Participant(id) = participant;
        id
    }
}

impl From<u32> for Participant {
    fn from(id: u32) -> Self {
        Self(id)
    }
}
/// Raw bytes of a message exchanged between participants.
pub type MessageData = Vec<u8>;

/// What a protocol asks its driver to do next, as returned by `Protocol::poke`.
#[derive(Debug, Clone)]
pub enum Action<T> {
    /// Nothing to do right now; the drivers in this module stop poking this participant
    /// until its next turn.
    Wait,
    /// Deliver this message to every other participant.
    SendMany(MessageData),
    /// Deliver this message only to the given participant.
    SendPrivate(Participant, MessageData),
    /// The protocol has produced its output.
    Return(T),
}
/// One participant's side of an interactive protocol, driven by polling.
///
/// A driver (such as `run_protocol` in this module) repeatedly calls [`Protocol::poke`] to
/// obtain the next [`Action`] and calls [`Protocol::message`] to hand over messages sent by
/// other participants.
pub trait Protocol {
    /// Value carried by [`Action::Return`] when the protocol completes.
    type Output;

    /// Ask the protocol for its next action.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] if the protocol fails.
    fn poke(&mut self) -> Result<Action<Self::Output>, ProtocolError>;

    /// Hand the protocol a message `data` sent by participant `from`.
    fn message(&mut self, from: Participant, data: MessageData);
}
pub fn run_protocol<T>(
mut ps: Vec<(Participant, Box<dyn Protocol<Output = T>>)>,
) -> Result<Vec<(Participant, T)>, ProtocolError> {
let indices: HashMap<Participant, usize> =
ps.iter().enumerate().map(|(i, (p, _))| (*p, i)).collect();
let size = ps.len();
let mut out = Vec::with_capacity(size);
while out.len() < size {
for i in 0..size {
while {
let action = ps[i].1.poke()?;
match action {
Action::Wait => false,
Action::SendMany(m) => {
for j in 0..size {
if i == j {
continue;
}
let from = ps[i].0;
ps[j].1.message(from, m.clone());
}
true
}
Action::SendPrivate(to, m) => {
let from = ps[i].0;
ps[indices[&to]].1.message(from, m);
true
}
Action::Return(r) => {
out.push((ps[i].0, r));
false
}
}
} {}
}
}
Ok(out)
}
pub(crate) fn run_two_party_protocol<T0: std::fmt::Debug, T1: std::fmt::Debug>(
p0: Participant,
p1: Participant,
prot0: &mut dyn Protocol<Output = T0>,
prot1: &mut dyn Protocol<Output = T1>,
) -> Result<(T0, T1), ProtocolError> {
let mut active0 = true;
let mut out0 = None;
let mut out1 = None;
while out0.is_none() || out1.is_none() {
if active0 {
let action = prot0.poke()?;
match action {
Action::Wait => active0 = false,
Action::SendMany(m) => prot1.message(p0, m),
Action::SendPrivate(to, m) if to == p1 => {
prot1.message(p0, m);
}
Action::Return(out) => out0 = Some(out),
_ => {}
}
} else {
let action = prot1.poke()?;
match action {
Action::Wait => active0 = true,
Action::SendMany(m) => prot0.message(p1, m),
Action::SendPrivate(to, m) if to == p0 => {
prot0.message(p1, m);
}
Action::Return(out) => out1 = Some(out),
_ => {}
}
}
}
Ok((out0.unwrap(), out1.unwrap()))
}
pub(crate) mod internal;