1#![doc = include_str!("../README.md")]
2#![warn(missing_docs, unused_imports)]
3
4use comm::{Channels, NetworkDescription};
5use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
6use std::fmt::Debug;
7
8use statistics::{AggregatedStats, Timings};
9
/// Communication layer: the `Channels` each party uses to send and receive
/// messages, and the `NetworkDescription`s (such as `FullMesh`) that
/// instantiate one set of channels per party.
pub mod comm;

/// Performance measurement: the per-party `Timings` recorded during a run and
/// the `AggregatedStats` that combine them across repetitions.
pub mod statistics;
15
16pub trait Party {
19 type Input: Send;
21 type Output: Debug + Send;
23
24 fn get_name(&self, id: usize) -> String {
26 format!("Party {}", id)
27 }
28
29 fn run(
31 &mut self,
32 id: usize,
33 n_parties: usize,
34 input: &Self::Input,
35 channels: &mut Channels,
36 timings: &mut Timings,
37 ) -> Self::Output;
38}
39
40pub trait Protocol
42where
43 Self: Debug,
44{
45 type Party: Party + Send;
47
48 fn setup_parties(&self, n_parties: usize) -> Vec<Self::Party>;
50
51 fn generate_inputs(&self, n_parties: usize) -> Vec<<Self::Party as Party>::Input>;
53
54 fn validate_outputs(
56 &self,
57 _inputs: &[<Self::Party as Party>::Input],
58 _outputs: &[<Self::Party as Party>::Output],
59 ) -> bool {
60 true
61 }
62
63 fn evaluate<N: NetworkDescription>(
65 &self,
66 experiment_name: String,
67 n_parties: usize,
68 network_description: &N,
69 repetitions: usize,
70 ) -> AggregatedStats {
71 let mut parties = self.setup_parties(n_parties);
72 debug_assert_eq!(parties.len(), n_parties);
73
74 let mut stats = AggregatedStats::new(
75 experiment_name,
76 parties
77 .iter()
78 .enumerate()
79 .map(|(id, party)| party.get_name(id))
80 .collect(),
81 );
82
83 for _ in 0..repetitions {
84 let mut inputs = self.generate_inputs(n_parties);
85 debug_assert_eq!(inputs.len(), n_parties);
86
87 let mut channels = network_description.instantiate(n_parties);
88 debug_assert_eq!(channels.len(), n_parties);
89
90 let mut party_timings: Vec<Timings> = (0..n_parties).map(|_| Timings::new()).collect();
91
92 let outputs: Vec<_> = parties
93 .par_iter_mut()
94 .enumerate()
95 .zip(inputs.par_iter_mut())
96 .zip(channels.par_iter_mut())
97 .zip(party_timings.par_iter_mut())
98 .map(|((((id, party), input), channel), s)| {
99 let total_timer = s.create_timer("Total");
100 let output = party.run(id, n_parties, input, channel, s);
101 s.stop_timer(total_timer);
102 output
103 })
104 .collect();
105
106 if !self.validate_outputs(&inputs, &outputs) {
107 #[cfg(feature = "verbose")]
108 println!(
109 "The outputs are invalid:\n{:?} ...for these parameters:\n{:?}",
110 outputs, self
111 );
112 }
114
115 stats.incorporate_party_stats(party_timings);
117 }
118
119 stats
120 }
121}
122
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::{
        comm::{Channels, FullMesh},
        Party, Protocol, Timings,
    };

    /// Toy party: sends its id to every party with a higher id, prints the
    /// message received from every party with a lower id, and outputs
    /// `id + input`.
    struct ExampleParty;

    impl Party for ExampleParty {
        type Input = usize;
        type Output = usize;

        fn run(
            &mut self,
            id: usize,
            n_parties: usize,
            input: &Self::Input,
            channels: &mut Channels,
            stats: &mut Timings,
        ) -> Self::Output {
            println!("Hi! I am {}/{}", id, n_parties - 1);

            let sending_timer = stats.create_timer("Sending");
            for i in (id + 1)..n_parties {
                channels.send(&vec![id as u8], &i);
            }
            stats.stop_timer(sending_timer);

            let receiving_timer = stats.create_timer("Receiving");
            for j in 0..id {
                println!(
                    "I am {}/{} and I received a message from {}",
                    id,
                    n_parties - 1,
                    channels.receive(&j).collect::<Vec<_>>()[0]
                );
            }
            stats.stop_timer(receiving_timer);

            id + input
        }
    }

    /// Toy protocol in which every party receives the input `10`.
    #[derive(Debug)]
    struct ExampleProtocol;

    impl Protocol for ExampleProtocol {
        type Party = ExampleParty;

        fn setup_parties(&self, n_parties: usize) -> Vec<Self::Party> {
            (0..n_parties).map(|_| ExampleParty).collect()
        }

        fn generate_inputs(&self, n_parties: usize) -> Vec<usize> {
            (0..n_parties).map(|_| 10).collect()
        }

        /// Accepts the outputs iff there is one per input and party `i`
        /// returned `inputs[i] + i`.
        fn validate_outputs(
            &self,
            inputs: &[<Self::Party as Party>::Input],
            outputs: &[<Self::Party as Party>::Output],
        ) -> bool {
            inputs.len() == outputs.len()
                && outputs
                    .iter()
                    .zip(inputs)
                    .enumerate()
                    .all(|(id, (output, input))| *output == input + id)
        }
    }

    /// Smoke test: runs the example protocol once with five parties and
    /// prints the collected statistics.
    #[test]
    fn it_works() {
        let example = ExampleProtocol;
        let network = FullMesh::new();
        let stats = example.evaluate("Experiment".to_string(), 5, &network, 1);

        println!("stats: {:?}", stats);
        stats.summarize_timings().print();
    }

    /// The example validator accepts exactly the outputs `inputs[i] + i`.
    #[test]
    fn validate_outputs_checks_every_party() {
        let example = ExampleProtocol;
        assert!(example.validate_outputs(&[10, 10, 10], &[10, 11, 12]));
        assert!(!example.validate_outputs(&[10, 10, 10], &[10, 11, 13]));
        assert!(!example.validate_outputs(&[10, 10, 10], &[10, 11]));
        assert!(example.validate_outputs(&[], &[]));
    }

    /// Adding a one-second overhead to the network must make the evaluation
    /// measurably slower than on the plain network.
    #[test]
    fn takes_longer() {
        let example = ExampleProtocol;

        let start = Instant::now();
        let network = FullMesh::new();
        let _ = example.evaluate("Experiment".to_string(), 5, &network, 1);
        let duration_1 = start.elapsed();

        let start = Instant::now();
        let network = FullMesh::new_with_overhead(Duration::from_secs(1), 1.);
        let stats = example.evaluate("Experiment (w/ overhead)".to_string(), 5, &network, 1);
        let duration_2 = start.elapsed();

        assert!(duration_2 > duration_1);
        assert!(duration_2 > Duration::from_secs(5));

        stats.summarize_timings().print();
    }
}