#![doc = include_str!("../README.md")]
#![warn(missing_docs, unused_imports)]
use comm::{Channels, NetworkDescription};
use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use std::fmt::Debug;
use statistics::{AggregatedStats, Timings};
pub mod comm;
pub mod statistics;
/// One participant in a [`Protocol`].
///
/// [`Protocol::evaluate`] drives the parties: it asks each one for its name
/// once, then calls [`Party::run`] on all parties in parallel (via rayon)
/// for every repetition.
pub trait Party {
    /// The private input handed to [`Party::run`].
    type Input: Send;

    /// The value a run produces; the outputs of all parties are passed
    /// together to [`Protocol::validate_outputs`].
    type Output: Debug + Send;

    /// Returns a human-readable label for the party with index `id`.
    ///
    /// [`Protocol::evaluate`] passes these labels to the collected statistics.
    /// The default label is `"Party <id>"`.
    fn get_name(&self, id: usize) -> String {
        let mut label = String::from("Party ");
        label.push_str(&id.to_string());
        label
    }

    /// Executes this party's side of the protocol and returns its output.
    ///
    /// * `id` - index of this party, in `0..n_parties`.
    /// * `n_parties` - total number of participating parties.
    /// * `input` - this party's input for the current repetition.
    /// * `channels` - this party's endpoint on the simulated network.
    /// * `timings` - sink for any timers the party wants to record.
    fn run(
        &mut self,
        id: usize,
        n_parties: usize,
        input: &Self::Input,
        channels: &mut Channels,
        timings: &mut Timings,
    ) -> Self::Output;
}
/// A multi-party protocol that can be benchmarked over a simulated network.
///
/// Implementors describe how to create the parties, how to generate their
/// inputs and, optionally, how to validate their outputs. The provided
/// [`Protocol::evaluate`] method runs the parties in parallel and collects
/// their timings.
pub trait Protocol
where
    Self: Debug,
{
    /// The type of party that executes this protocol.
    type Party: Party + Send;

    /// Creates the parties taking part in the protocol.
    ///
    /// Must return exactly `n_parties` parties. [`Protocol::evaluate`] calls
    /// this once and reuses the same parties for every repetition.
    fn setup_parties(&self, n_parties: usize) -> Vec<Self::Party>;

    /// Generates the parties' inputs for one repetition.
    ///
    /// Must return exactly `n_parties` inputs; element `i` is given to party `i`.
    fn generate_inputs(&self, n_parties: usize) -> Vec<<Self::Party as Party>::Input>;

    /// Returns whether the outputs of one repetition are correct.
    ///
    /// `outputs[i]` is the output of party `i`, which received `inputs[i]`.
    /// The default implementation accepts any outputs.
    fn validate_outputs(
        &self,
        _inputs: &[<Self::Party as Party>::Input],
        _outputs: &[<Self::Party as Party>::Output],
    ) -> bool {
        true
    }

    /// Runs the protocol `repetitions` times with `n_parties` parties over the
    /// network described by `network_description`, and returns the aggregated
    /// per-party timings labelled `experiment_name`.
    ///
    /// Every repetition uses freshly generated inputs and a freshly
    /// instantiated network, and runs all parties in parallel (via rayon).
    /// Around each party's [`Party::run`] it records a `"Total"` timer, in
    /// addition to any timers the party creates itself. Outputs are checked
    /// with [`Protocol::validate_outputs`]. An invalid result does not abort
    /// the evaluation; it is only printed to stdout when the `verbose` feature
    /// is enabled.
    ///
    /// # Panics
    ///
    /// Panics if [`Protocol::setup_parties`], [`Protocol::generate_inputs`] or
    /// the network instantiation does not produce exactly `n_parties` elements.
    fn evaluate<N: NetworkDescription>(
        &self,
        experiment_name: String,
        n_parties: usize,
        network_description: &N,
        repetitions: usize,
    ) -> AggregatedStats {
        let mut parties = self.setup_parties(n_parties);
        // The length checks use `assert_eq!` rather than `debug_assert_eq!` so
        // they also run in release builds, where benchmarks are normally
        // measured. The parallel `zip` below silently truncates to the shortest
        // sequence, so a mismatch would otherwise skip parties and produce
        // misleading statistics instead of an error.
        assert_eq!(
            parties.len(),
            n_parties,
            "setup_parties must return exactly n_parties parties"
        );
        let mut stats = AggregatedStats::new(
            experiment_name,
            parties
                .iter()
                .enumerate()
                .map(|(id, party)| party.get_name(id))
                .collect(),
        );
        for _ in 0..repetitions {
            // Mutable only so it can be iterated with `par_iter_mut` below.
            let mut inputs = self.generate_inputs(n_parties);
            assert_eq!(
                inputs.len(),
                n_parties,
                "generate_inputs must return exactly n_parties inputs"
            );
            let mut channels = network_description.instantiate(n_parties);
            assert_eq!(
                channels.len(),
                n_parties,
                "the network must provide exactly n_parties channels"
            );
            let mut party_timings: Vec<Timings> = (0..n_parties).map(|_| Timings::new()).collect();
            // Parties only read their inputs, but `par_iter_mut` is used on
            // purpose. Sending `&mut Input` across threads only needs
            // `Input: Send`, whereas sharing `&Input` would need `Input: Sync`,
            // which `Party` does not require. The indexed parallel `collect`
            // preserves order, so `outputs[i]` belongs to party `i`.
            let outputs: Vec<_> = parties
                .par_iter_mut()
                .enumerate()
                .zip(inputs.par_iter_mut())
                .zip(channels.par_iter_mut())
                .zip(party_timings.par_iter_mut())
                .map(|((((id, party), input), channel), timings)| {
                    let total_timer = timings.create_timer("Total");
                    let output = party.run(id, n_parties, input, channel, timings);
                    timings.stop_timer(total_timer);
                    output
                })
                .collect();
            if !self.validate_outputs(&inputs, &outputs) {
                // Reporting only: an invalid run is still included in `stats`.
                #[cfg(feature = "verbose")]
                println!(
                    "The outputs are invalid:\n{:?} ...for these parameters:\n{:?}",
                    outputs, self
                );
            }
            stats.incorporate_party_stats(party_timings);
        }
        stats
    }
}
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::{
        comm::{Channels, FullMesh},
        Party, Protocol, Timings,
    };

    /// Toy party: sends its id to every higher-numbered party, receives one
    /// message from every lower-numbered party, and outputs `id + input`.
    struct ExampleParty;

    impl Party for ExampleParty {
        type Input = usize;
        type Output = usize;

        fn run(
            &mut self,
            id: usize,
            n_parties: usize,
            input: &Self::Input,
            channels: &mut Channels,
            stats: &mut Timings,
        ) -> Self::Output {
            println!("Hi! I am {}/{}", id, n_parties - 1);
            let sending_timer = stats.create_timer("Sending");
            for i in (id + 1)..n_parties {
                // `as u8` truncates for ids >= 256; fine for the small party
                // counts used in these tests.
                channels.send(&vec![id as u8], &i);
            }
            stats.stop_timer(sending_timer);
            let receiving_timer = stats.create_timer("Receiving");
            for j in 0..id {
                println!(
                    "I am {}/{} and I received a message from {}",
                    id,
                    n_parties - 1,
                    channels.receive(&j).collect::<Vec<_>>()[0]
                );
            }
            stats.stop_timer(receiving_timer);
            id + input
        }
    }

    /// Toy protocol running [`ExampleParty`] with every input set to 10.
    #[derive(Debug)]
    struct ExampleProtocol;

    impl Protocol for ExampleProtocol {
        type Party = ExampleParty;

        fn setup_parties(&self, n_parties: usize) -> Vec<Self::Party> {
            (0..n_parties).map(|_| ExampleParty).collect()
        }

        fn generate_inputs(&self, n_parties: usize) -> Vec<usize> {
            (0..n_parties).map(|_| 10).collect()
        }

        /// Accepts the outputs iff there is exactly one output per input and
        /// party `i` produced `inputs[i] + i`.
        fn validate_outputs(
            &self,
            inputs: &[<Self::Party as Party>::Input],
            outputs: &[<Self::Party as Party>::Output],
        ) -> bool {
            // Check lengths first: `zip` would otherwise ignore missing or
            // extra outputs instead of reporting them as invalid.
            outputs.len() == inputs.len()
                && inputs
                    .iter()
                    .zip(outputs)
                    .enumerate()
                    .all(|(i, (&input, &output))| output == input + i)
        }
    }

    #[test]
    fn validates_outputs() {
        let protocol = ExampleProtocol;
        assert!(protocol.validate_outputs(&[10, 10, 10], &[10, 11, 12]));
        assert!(!protocol.validate_outputs(&[10, 10, 10], &[10, 12, 12]));
        // A missing output is a failure, not a silent pass.
        assert!(!protocol.validate_outputs(&[10, 10, 10], &[10, 11]));
        // Extra outputs are rejected rather than indexing out of bounds.
        assert!(!protocol.validate_outputs(&[10, 10], &[10, 11, 12]));
        assert!(protocol.validate_outputs(&[], &[]));
    }

    #[test]
    fn it_works() {
        let example = ExampleProtocol;
        let network = FullMesh::new();
        let stats = example.evaluate("Experiment".to_string(), 5, &network, 1);
        println!("stats: {:?}", stats);
        stats.summarize_timings().print();
    }

    #[test]
    fn takes_longer() {
        let example = ExampleProtocol;
        let start = Instant::now();
        let network = FullMesh::new();
        let _ = example.evaluate("Experiment".to_string(), 5, &network, 1);
        let duration_1 = start.elapsed();
        let start = Instant::now();
        let network = FullMesh::new_with_overhead(Duration::from_secs(1), 1.);
        let stats = example.evaluate("Experiment (w/ overhead)".to_string(), 5, &network, 1);
        let duration_2 = start.elapsed();
        assert!(duration_2 > duration_1);
        assert!(duration_2 > Duration::from_secs(5));
        stats.summarize_timings().print();
    }
}