mpc_bench/
lib.rs

// Crate-level documentation is taken verbatim from the README.
#![doc = include_str!("../README.md")]
// Warn on undocumented public items and on unused imports.
#![warn(missing_docs, unused_imports)]
3
4use comm::{Channels, NetworkDescription};
5use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
6use std::fmt::Debug;
7
8use statistics::{AggregatedStats, Timings};
9
/// Communication module: lets parties send messages to and receive messages from each other.
pub mod comm;

/// Statistics module: lets parties track timings and bandwidth costs, and aggregates them over runs.
pub mod statistics;
15
16/// A `Party` that takes part in a protocol. The party will receive a unique `id` when it is running the protocol, as well as
17/// communication channels to and from all the other parties. A party keeps track of its own stats.
18pub trait Party {
19    /// The input type of this party. It must be the same for all parties in a given protocol (but it could be e.g. an enum or Option).
20    type Input: Send;
21    /// The output type of this party. It must be the same for all parties in a given protocol (but it could be e.g. an enum or Option)
22    type Output: Debug + Send;
23
24    /// Gets the name of this party. By default, this is 'Party {id}'.
25    fn get_name(&self, id: usize) -> String {
26        format!("Party {}", id)
27    }
28
29    /// Runs the code for this party in the given protocol. The `id` starts from 0.
30    fn run(
31        &mut self,
32        id: usize,
33        n_parties: usize,
34        input: &Self::Input,
35        channels: &mut Channels,
36        timings: &mut Timings,
37    ) -> Self::Output;
38}
39
40/// MPC protocols are described by the `Protocol` trait for a given `Party` type that can be sent accross threads. An implementation should hold the protocol-specific parameters.
41pub trait Protocol
42where
43    Self: Debug,
44{
45    /// The type of the parties participating in the Protocol.
46    type Party: Party + Send;
47
48    /// Sets up `n_parties` according to this parameterization of the Protocol.
49    fn setup_parties(&self, n_parties: usize) -> Vec<Self::Party>;
50
51    /// Generates each party's potentially random input for this parameterization of the Protocol.
52    fn generate_inputs(&self, n_parties: usize) -> Vec<<Self::Party as Party>::Input>;
53
54    /// Validates the outputs of one run of the Protocol. If false, `evaluate` will print a warning.
55    fn validate_outputs(
56        &self,
57        _inputs: &[<Self::Party as Party>::Input],
58        _outputs: &[<Self::Party as Party>::Output],
59    ) -> bool {
60        true
61    }
62
63    /// Evaluates multiple `repetitions` of the protocol with this parameterization of the Protocol.
64    fn evaluate<N: NetworkDescription>(
65        &self,
66        experiment_name: String,
67        n_parties: usize,
68        network_description: &N,
69        repetitions: usize,
70    ) -> AggregatedStats {
71        let mut parties = self.setup_parties(n_parties);
72        debug_assert_eq!(parties.len(), n_parties);
73
74        let mut stats = AggregatedStats::new(
75            experiment_name,
76            parties
77                .iter()
78                .enumerate()
79                .map(|(id, party)| party.get_name(id))
80                .collect(),
81        );
82
83        for _ in 0..repetitions {
84            let mut inputs = self.generate_inputs(n_parties);
85            debug_assert_eq!(inputs.len(), n_parties);
86
87            let mut channels = network_description.instantiate(n_parties);
88            debug_assert_eq!(channels.len(), n_parties);
89
90            let mut party_timings: Vec<Timings> = (0..n_parties).map(|_| Timings::new()).collect();
91
92            let outputs: Vec<_> = parties
93                .par_iter_mut()
94                .enumerate()
95                .zip(inputs.par_iter_mut())
96                .zip(channels.par_iter_mut())
97                .zip(party_timings.par_iter_mut())
98                .map(|((((id, party), input), channel), s)| {
99                    let total_timer = s.create_timer("Total");
100                    let output = party.run(id, n_parties, input, channel, s);
101                    s.stop_timer(total_timer);
102                    output
103                })
104                .collect();
105
106            if !self.validate_outputs(&inputs, &outputs) {
107                #[cfg(feature = "verbose")]
108                println!(
109                    "The outputs are invalid:\n{:?} ...for these parameters:\n{:?}",
110                    outputs, self
111                );
112                // TODO: Mark invalid in stats
113            }
114
115            // TODO: Incorporate communication costs
116            stats.incorporate_party_stats(party_timings);
117        }
118
119        stats
120    }
121}
122
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::{
        comm::{Channels, FullMesh},
        Party, Protocol, Timings,
    };

    /// Toy party: sends its id to every party with a higher id, prints the first element of the
    /// message received from every party with a lower id, and outputs `id + input`.
    struct ExampleParty;

    impl Party for ExampleParty {
        type Input = usize;
        type Output = usize;

        fn run(
            &mut self,
            id: usize,
            n_parties: usize,
            input: &Self::Input,
            channels: &mut Channels,
            timings: &mut Timings,
        ) -> Self::Output {
            println!("Hi! I am {}/{}", id, n_parties - 1);

            let sending_timer = timings.create_timer("Sending");
            for i in (id + 1)..n_parties {
                channels.send(&vec![id as u8], &i);
            }
            timings.stop_timer(sending_timer);

            let receiving_timer = timings.create_timer("Receiving");
            for j in 0..id {
                println!(
                    "I am {}/{} and I received a message from {}",
                    id,
                    n_parties - 1,
                    channels.receive(&j).collect::<Vec<_>>()[0]
                );
            }
            timings.stop_timer(receiving_timer);

            id + input
        }
    }

    /// Protocol made of `ExampleParty`s that all receive the input 10.
    #[derive(Debug)]
    struct ExampleProtocol;

    impl Protocol for ExampleProtocol {
        type Party = ExampleParty;

        fn setup_parties(&self, n_parties: usize) -> Vec<Self::Party> {
            (0..n_parties).map(|_| ExampleParty).collect()
        }

        fn generate_inputs(&self, n_parties: usize) -> Vec<usize> {
            (0..n_parties).map(|_| 10).collect()
        }

        /// Accepts a run iff there is one output per input and party `i` output `inputs[i] + i`.
        fn validate_outputs(
            &self,
            inputs: &[<Self::Party as Party>::Input],
            outputs: &[<Self::Party as Party>::Output],
        ) -> bool {
            outputs.len() == inputs.len()
                && outputs
                    .iter()
                    .zip(inputs)
                    .enumerate()
                    .all(|(id, (output, input))| *output == input + id)
        }
    }

    #[test]
    fn it_works() {
        let example = ExampleProtocol;
        let network = FullMesh::new();
        let stats = example.evaluate("Experiment".to_string(), 5, &network, 1);

        println!("stats: {:?}", stats);
        // FIXME: All rows are aggregated instead of party-by-party
        stats.summarize_timings().print();
    }

    /// Running the same protocol over a network with added overhead must take measurably longer.
    #[test]
    fn takes_longer() {
        let example = ExampleProtocol;

        let start = Instant::now();
        let network = FullMesh::new();
        let _ = example.evaluate("Experiment".to_string(), 5, &network, 1);
        let duration_1 = start.elapsed();

        let start = Instant::now();
        let network = FullMesh::new_with_overhead(Duration::from_secs(1), 1.);
        let stats = example.evaluate("Experiment (w/ overhead)".to_string(), 5, &network, 1);
        let duration_2 = start.elapsed();

        assert!(duration_2 > duration_1);
        assert!(duration_2 > Duration::from_secs(5));

        stats.summarize_timings().print();
    }
}