use alloc::{boxed::Box, collections::VecDeque, string::ToString, vec::Vec};
use core::future::Future;
use crate::{state_machine::ProceedResult, Incoming, MessageDestination, MessageType, Outgoing};
#[cfg(feature = "sim-async")]
pub mod async_env;
/// Outputs of a simulated protocol execution, one entry per party.
///
/// Dereferences to `[T]`, can be converted to and from `Vec<T>`, and offers the
/// `expect_ok` / `expect_eq` helpers for asserting on the outputs in tests.
#[derive(Debug, Clone)]
pub struct SimResult<T>(pub Vec<T>);
impl<T, E> SimResult<Result<T, E>>
where
    E: core::fmt::Debug,
{
    /// Asserts that every party succeeded and returns the unwrapped outputs.
    ///
    /// The returned `SimResult` keeps the original party order.
    ///
    /// # Panics
    /// Panics if at least one party returned `Err(_)`. The panic message reports how many
    /// parties succeeded and failed, and lists every error (formatted with `Debug`) together
    /// with the index of the party that returned it.
    pub fn expect_ok(self) -> SimResult<T> {
        let mut oks = Vec::with_capacity(self.0.len());
        let mut errs = Vec::new();
        // `enumerate` produces `usize` indices, so this cannot overflow however many results
        // there are. A `u16` counter (`0u16..`) overflows past 65535 and panics in debug
        // builds even when every result is `Ok`.
        for (i, res) in self.0.into_iter().enumerate() {
            match res {
                Ok(out) => oks.push(out),
                Err(err) => errs.push((i, err)),
            }
        }
        if !errs.is_empty() {
            let mut msg = alloc::format!(
                "Simulation output didn't match expectations.\n\
                Expected: all parties succeed\n\
                Actual : {success} parties succeeded, {failed} parties returned an error\n\
                Failures:\n",
                success = oks.len(),
                failed = errs.len(),
            );
            for (i, err) in errs {
                msg += &alloc::format!("- Party {i}: {err:?}\n");
            }
            panic!("{msg}");
        }
        SimResult(oks)
    }
}
impl<T> SimResult<T>
where
    T: PartialEq + core::fmt::Debug,
{
    /// Asserts that all parties produced the same output and returns that output.
    ///
    /// # Panics
    /// - if the simulation contained zero parties;
    /// - if some parties returned different outputs. The panic message groups party indices
    ///   by output (in order of first appearance) and prints each distinct output with `Debug`.
    pub fn expect_eq(mut self) -> T {
        let Some((first, rest)) = self.0.split_first() else {
            panic!("simulation contained zero parties");
        };
        let everyone_agrees = rest.iter().all(|other| other == first);

        if !everyone_agrees {
            // Each entry pairs a distinct output with the indices of the parties that
            // returned it.
            let mut groups: Vec<(&T, Vec<usize>)> = Vec::new();
            for (index, output) in self.0.iter().enumerate() {
                match groups.iter().position(|(seen, _)| *seen == output) {
                    Some(pos) => groups[pos].1.push(index),
                    None => groups.push((output, alloc::vec![index])),
                }
            }

            let mut msg = alloc::string::String::from(
                "Simulation output didn't match expectations.\n\
                Expected: all parties return the same output\n\
                Actual : some of the parties returned a different output\n\
                Outputs :\n",
            );
            for (output, members) in &groups {
                msg += if members.len() == 1 {
                    "- Party "
                } else {
                    "- Parties "
                };
                let indices: Vec<alloc::string::String> =
                    members.iter().map(|index| index.to_string()).collect();
                msg += &indices.join(", ");
                msg += &alloc::format!(": {output:?}\n");
            }
            panic!("{msg}")
        }

        self.0
            .pop()
            .expect("we checked that the list contains at least one element")
    }
}
impl<T> SimResult<T> {
pub fn into_vec(self) -> Vec<T> {
self.0
}
}
impl<T> IntoIterator for SimResult<T> {
type Item = T;
type IntoIter = alloc::vec::IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
/// Exposes the per-party outputs as a slice, so slice methods work directly on `SimResult`.
impl<T> core::ops::Deref for SimResult<T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}
impl<T> From<Vec<T>> for SimResult<T> {
fn from(list: Vec<T>) -> Self {
Self(list)
}
}
impl<T> From<SimResult<T>> for Vec<T> {
fn from(res: SimResult<T>) -> Self {
res.0
}
}
/// A simulated execution of a multiparty protocol.
///
/// Parties are registered with `add_party` / `add_async_party` (or all at once through
/// `from_fn` / `from_async_fn`) and then driven to completion by `run`, which passes
/// messages between them through in-memory queues.
pub struct Simulation<'a, O, M> {
    /// State of every party; a party's index is its position in this list.
    parties: Vec<Party<'a, O, M>>,
}

/// State of one simulated party.
enum Party<'a, O, M> {
    /// The party has not produced its output yet.
    Active {
        /// Set when the state machine reported that it needs one more incoming message before
        /// it can continue; cleared once such a message has been handed to it.
        wants_one_more_msg: bool,
        /// The party's protocol, driven as a state machine.
        party: Box<dyn crate::state_machine::StateMachine<Output = O, Msg = M> + 'a>,
    },
    /// The party completed and returned this output.
    Finished(O),
}
impl<'a, O, M> Simulation<'a, O, M>
where
M: Clone + 'static,
{
pub fn empty() -> Self {
Self {
parties: Vec::new(),
}
}
pub fn with_capacity(n: u16) -> Self {
Self {
parties: Vec::with_capacity(n.into()),
}
}
pub fn from_async_fn<F>(
n: u16,
mut init: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
) -> Self
where
F: core::future::Future<Output = O> + 'a,
{
let mut sim = Self::with_capacity(n);
for i in 0..n {
sim.add_async_party(|party| init(i, party))
}
sim
}
pub fn from_fn<S>(n: u16, mut init: impl FnMut(u16) -> S) -> Self
where
S: crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
{
let mut sim = Self::with_capacity(n);
for i in 0..n {
sim.add_party(init(i));
}
sim
}
pub fn add_party(
&mut self,
party: impl crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
) {
self.parties.push(Party::Active {
party: Box::new(party),
wants_one_more_msg: false,
})
}
pub fn add_async_party<F>(&mut self, party: impl FnOnce(crate::state_machine::MpcParty<M>) -> F)
where
F: core::future::Future<Output = O> + 'a,
{
self.parties.push(Party::Active {
party: Box::new(crate::state_machine::wrap_protocol(party)),
wants_one_more_msg: false,
})
}
pub fn parties_amount(&self) -> usize {
self.parties.len()
}
pub fn run(mut self) -> Result<SimResult<O>, SimError> {
let mut messages_queue = MessagesQueue::new(self.parties.len());
let mut parties_left = self.parties.len();
while parties_left > 0 {
'next_party: for (i, party_state) in (0..).zip(&mut self.parties) {
'this_party: loop {
let Party::Active {
party,
wants_one_more_msg,
} = party_state
else {
continue 'next_party;
};
if *wants_one_more_msg {
if let Some(message) = messages_queue.recv_next_msg(i) {
party
.received_msg(message)
.map_err(|_| Reason::SaveIncomingMsg)?;
*wants_one_more_msg = false;
} else {
continue 'next_party;
}
}
match party.proceed() {
ProceedResult::SendMsg(msg) => {
messages_queue.send_message(i, msg)?;
continue 'this_party;
}
ProceedResult::NeedsOneMoreMessage => {
*wants_one_more_msg = true;
continue 'this_party;
}
ProceedResult::Output(out) => {
*party_state = Party::Finished(out);
parties_left -= 1;
continue 'next_party;
}
ProceedResult::Yielded => {
continue 'this_party;
}
ProceedResult::Error(err) => {
return Err(Reason::ExecutionError(err).into());
}
}
}
}
}
Ok(SimResult(
self.parties
.into_iter()
.map(|party| match party {
Party::Active { .. } => {
unreachable!("there must be no active parties when `parties_left == 0`")
}
Party::Finished(out) => out,
})
.collect(),
))
}
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct SimError(#[from] Reason);
#[derive(Debug, thiserror::Error)]
enum Reason {
#[error("save incoming message")]
SaveIncomingMsg,
#[error("execution error")]
ExecutionError(#[source] crate::state_machine::ExecutionError),
#[error("party #{sender} tried to send a message to non existing party #{recipient}")]
UnknownRecipient { sender: u16, recipient: u16 },
}
/// In-memory message routing for the simulation: one FIFO inbox per party.
struct MessagesQueue<M> {
    /// `queue[i]` holds the messages delivered to party `i` that it has not received yet.
    queue: Vec<VecDeque<Incoming<M>>>,
    /// Id that will be assigned to the next delivered message.
    next_id: u64,
}

impl<M: Clone> MessagesQueue<M> {
    /// Creates `n` empty inboxes; message ids start from zero.
    fn new(n: usize) -> Self {
        Self {
            queue: core::iter::repeat_with(VecDeque::new).take(n).collect(),
            next_id: 0,
        }
    }

    /// Delivers `msg`, sent by party `sender`, to the inbox(es) of its recipient(s).
    ///
    /// A broadcast is cloned into the inbox of every party except the sender, and each
    /// delivered copy gets its own id. Ids increase across all deliveries.
    ///
    /// # Errors
    /// Returns an error if a p2p message is addressed to a party index without an inbox. The
    /// id reserved for that message is consumed even in that case.
    fn send_message(&mut self, sender: u16, msg: Outgoing<M>) -> Result<(), SimError> {
        match msg.recipient {
            MessageDestination::AllParties => {
                for (party_index, inbox) in (0..).zip(self.queue.iter_mut()) {
                    if party_index == sender {
                        // A broadcast is not echoed back to its sender.
                        continue;
                    }
                    inbox.push_back(Incoming {
                        id: self.next_id,
                        sender,
                        msg_type: MessageType::Broadcast,
                        msg: msg.msg.clone(),
                    });
                    self.next_id += 1;
                }
            }
            MessageDestination::OneParty(recipient) => {
                let id = self.next_id;
                self.next_id += 1;
                let inbox = self
                    .queue
                    .get_mut(usize::from(recipient))
                    .ok_or(Reason::UnknownRecipient { sender, recipient })?;
                inbox.push_back(Incoming {
                    id,
                    sender,
                    msg_type: MessageType::P2P,
                    msg: msg.msg,
                });
            }
        }
        Ok(())
    }

    /// Pops the oldest message still waiting in party `recipient`'s inbox, if there is one.
    ///
    /// # Panics
    /// Panics if `recipient` has no inbox.
    fn recv_next_msg(&mut self, recipient: u16) -> Option<Incoming<M>> {
        let inbox = &mut self.queue[usize::from(recipient)];
        inbox.pop_front()
    }
}
/// Simulates an async protocol executed by `n` parties and collects their outputs.
///
/// `party_start` receives each party's index and `MpcParty` handle and returns that party's
/// protocol future. This is `run_with_setup` with an empty `()` setup for every party.
///
/// # Errors
/// Fails with the same errors as `Simulation::run`.
pub fn run<M, F>(
    n: u16,
    mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
) -> Result<SimResult<F::Output>, SimError>
where
    M: Clone + 'static,
    F: Future,
{
    let empty_setups = (0..n).map(|_| ());
    run_with_setup(empty_setups, move |i, party, ()| party_start(i, party))
}
/// Simulates an async protocol in which every party is given its own setup value.
///
/// One party is created per item of `setups`. Party `i` is started by calling
/// `party_start(i, party, setup_i)`, where `party` is its `MpcParty` handle.
///
/// # Errors
/// Fails with the same errors as `Simulation::run`.
pub fn run_with_setup<S, M, F>(
    setups: impl IntoIterator<Item = S>,
    mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>, S) -> F,
) -> Result<SimResult<F::Output>, SimError>
where
    M: Clone + 'static,
    F: Future,
{
    let mut simulation = Simulation::empty();
    // The setup iterator is polled before the `u16` counter, so the counter is never advanced
    // past the last party.
    for (setup, party_index) in setups.into_iter().zip(0u16..) {
        let state_machine =
            crate::state_machine::wrap_protocol(|party| party_start(party_index, party, setup));
        simulation.add_party(state_machine);
    }
    simulation.run()
}
#[cfg(test)]
mod tests {
    mod expect_eq {
        use crate::sim::SimResult;

        #[test]
        fn all_eq() {
            let res = SimResult::from(alloc::vec!["same string", "same string", "same string"])
                .expect_eq();
            assert_eq!(res, "same string")
        }

        #[test]
        fn single_party() {
            let res = SimResult::from(alloc::vec![42]).expect_eq();
            assert_eq!(res, 42);
        }

        // `expected` makes sure the test fails on unrelated panics (overflow, indexing, ...).
        #[test]
        #[should_panic(expected = "simulation contained zero parties")]
        fn empty_res() {
            SimResult::from(alloc::vec![]).expect_eq()
        }

        #[test]
        #[should_panic(expected = "some of the parties returned a different output")]
        fn not_eq() {
            SimResult::from(alloc::vec![
                "one result",
                "one result",
                "another result",
                "one result",
                "and something else",
            ])
            .expect_eq();
        }
    }

    mod expect_ok {
        use crate::sim::SimResult;

        #[test]
        fn all_ok() {
            let res = SimResult::<Result<i32, core::convert::Infallible>>::from(alloc::vec![
                Ok(0),
                Ok(1),
                Ok(2)
            ])
            .expect_ok()
            .into_vec();
            assert_eq!(res, [0, 1, 2]);
        }

        #[test]
        #[should_panic(expected = "3 parties succeeded, 2 parties returned an error")]
        fn not_ok() {
            SimResult::from(alloc::vec![
                Ok(0),
                Err("i couldn't do what you asked :("),
                Ok(2),
                Ok(3),
                Err("sorry I was pooping, what did you want?")
            ])
            .expect_ok();
        }
    }

    mod conversions {
        use crate::sim::SimResult;

        #[test]
        fn vec_roundtrip_deref_and_iteration() {
            let res = SimResult::from(alloc::vec![1, 2, 3]);
            assert_eq!(res.len(), 3);
            assert_eq!(&*res, &[1, 2, 3]);
            let collected: alloc::vec::Vec<i32> = res.into_iter().collect();
            assert_eq!(collected, [1, 2, 3]);
            let back: alloc::vec::Vec<i32> = SimResult::from(alloc::vec![4, 5]).into();
            assert_eq!(back, [4, 5]);
        }
    }
}