round_based/rounds_router/
simple_store.rs

1//! Simple implementation of `MessagesStore`
2
3use alloc::{vec, vec::Vec};
4use core::iter;
5
6use crate::{Incoming, MessageType, MsgId, PartyIndex};
7
8use super::MessagesStore;
9
10/// Simple implementation of [MessagesStore] that waits for all parties to send a message
11///
12/// Round is considered complete when the store received a message from every party. Note that the
13/// store will ignore all the messages such as `msg.sender == local_party_index`.
14///
15/// Once round is complete, it outputs [`RoundMsgs`].
16///
17/// ## Example
18/// ```rust
19/// # use round_based::rounds_router::{MessagesStore, simple_store::RoundInput};
20/// # use round_based::{Incoming, MessageType};
21/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
22/// let mut input = RoundInput::<&'static str>::broadcast(1, 3);
23/// input.add_message(Incoming{
24///     id: 0,
25///     sender: 0,
26///     msg_type: MessageType::Broadcast,
27///     msg: "first party message",
28/// })?;
29/// input.add_message(Incoming{
30///     id: 1,
31///     sender: 2,
32///     msg_type: MessageType::Broadcast,
33///     msg: "third party message",
34/// })?;
35/// assert!(!input.wants_more());
36///
37/// let output = input.output().unwrap();
38/// assert_eq!(output.clone().into_vec_without_me(), ["first party message", "third party message"]);
39/// assert_eq!(
40///     output.clone().into_vec_including_me("my msg"),
41///     ["first party message", "my msg", "third party message"]
42/// );
43/// # Ok(()) }
44/// ```
45#[derive(Debug, Clone)]
46pub struct RoundInput<M> {
47    i: PartyIndex,
48    n: u16,
49    messages_ids: Vec<MsgId>,
50    messages: Vec<Option<M>>,
51    left_messages: u16,
52    expected_msg_type: MessageType,
53}
54
55/// List of received messages
56#[derive(Debug, Clone)]
57pub struct RoundMsgs<M> {
58    i: PartyIndex,
59    ids: Vec<MsgId>,
60    messages: Vec<M>,
61}
62
63impl<M> RoundInput<M> {
64    /// Constructs new messages store
65    ///
66    /// Takes index of local party `i` and amount of parties `n`
67    ///
68    /// ## Panics
69    /// Panics if `n` is less than 2 or `i` is not in the range `[0; n)`.
70    pub fn new(i: PartyIndex, n: u16, msg_type: MessageType) -> Self {
71        assert!(n >= 2);
72        assert!(i < n);
73
74        Self {
75            i,
76            n,
77            messages_ids: vec![0; usize::from(n) - 1],
78            messages: iter::repeat_with(|| None)
79                .take(usize::from(n) - 1)
80                .collect(),
81            left_messages: n - 1,
82            expected_msg_type: msg_type,
83        }
84    }
85
86    /// Construct a new store for broadcast messages
87    ///
88    /// The same as `RoundInput::new(i, n, MessageType::Broadcast)`
89    pub fn broadcast(i: PartyIndex, n: u16) -> Self {
90        Self::new(i, n, MessageType::Broadcast)
91    }
92
93    /// Construct a new store for p2p messages
94    ///
95    /// The same as `RoundInput::new(i, n, MessageType::P2P)`
96    pub fn p2p(i: PartyIndex, n: u16) -> Self {
97        Self::new(i, n, MessageType::P2P)
98    }
99
100    fn is_expected_type_of_msg(&self, msg_type: MessageType) -> bool {
101        self.expected_msg_type == msg_type
102    }
103}
104
105impl<M> MessagesStore for RoundInput<M>
106where
107    M: 'static,
108{
109    type Msg = M;
110    type Output = RoundMsgs<M>;
111    type Error = RoundInputError;
112
113    fn add_message(&mut self, msg: Incoming<Self::Msg>) -> Result<(), Self::Error> {
114        if !self.is_expected_type_of_msg(msg.msg_type) {
115            return Err(RoundInputError::MismatchedMessageType {
116                msg_id: msg.id,
117                expected: self.expected_msg_type,
118                actual: msg.msg_type,
119            });
120        }
121        if msg.sender == self.i {
122            // Ignore own messages
123            return Ok(());
124        }
125
126        let index = usize::from(if msg.sender < self.i {
127            msg.sender
128        } else {
129            msg.sender - 1
130        });
131
132        match self.messages.get_mut(index) {
133            Some(vacant @ None) => {
134                *vacant = Some(msg.msg);
135                self.messages_ids[index] = msg.id;
136                self.left_messages -= 1;
137                Ok(())
138            }
139            Some(Some(_)) => Err(RoundInputError::AttemptToOverwriteReceivedMsg {
140                msgs_ids: [self.messages_ids[index], msg.id],
141                sender: msg.sender,
142            }),
143            None => Err(RoundInputError::SenderIndexOutOfRange {
144                msg_id: msg.id,
145                sender: msg.sender,
146                n: self.n,
147            }),
148        }
149    }
150
151    fn wants_more(&self) -> bool {
152        self.left_messages > 0
153    }
154
155    fn output(self) -> Result<Self::Output, Self> {
156        if self.left_messages > 0 {
157            Err(self)
158        } else {
159            Ok(RoundMsgs {
160                i: self.i,
161                ids: self.messages_ids,
162                messages: self.messages.into_iter().flatten().collect(),
163            })
164        }
165    }
166}
167
168impl<M> RoundMsgs<M> {
169    /// Returns vec of `n-1` received messages
170    ///
171    /// Messages appear in the list in ascending order of sender index. E.g. for n=4 and local party index i=2,
172    /// the list would look like: `[{msg from i=0}, {msg from i=1}, {msg from i=3}]`.
173    pub fn into_vec_without_me(self) -> Vec<M> {
174        self.messages
175    }
176
177    /// Returns vec of received messages plus party's own message
178    ///
179    /// Similar to `into_vec_without_me`, but inserts `my_msg` at position `i` in resulting list. Thus, i-th
180    /// message in the list was received from i-th party.
181    pub fn into_vec_including_me(mut self, my_msg: M) -> Vec<M> {
182        self.messages.insert(usize::from(self.i), my_msg);
183        self.messages
184    }
185
186    /// Returns iterator over messages
187    pub fn iter(&self) -> impl Iterator<Item = &M> {
188        self.messages.iter()
189    }
190
191    /// Returns iterator over received messages plus party's own message
192    ///
193    /// Similar to [`.iter()`](Self::iter), but inserts `my_msg` at position `i`. Thus, i-th message in the
194    /// iterator is the message received from party `i`.
195    pub fn iter_including_me<'m>(&'m self, my_msg: &'m M) -> impl Iterator<Item = &'m M> {
196        self.messages
197            .iter()
198            .take(usize::from(self.i))
199            .chain(iter::once(my_msg))
200            .chain(self.messages.iter().skip(usize::from(self.i)))
201    }
202
203    /// Returns iterator over received messages plus party's own message
204    pub fn into_iter_including_me(self, my_msg: M) -> impl Iterator<Item = M> {
205        struct InsertsAfter<T, It> {
206            offset: usize,
207            inner: It,
208            item: Option<T>,
209        }
210        impl<T, It: Iterator<Item = T>> Iterator for InsertsAfter<T, It> {
211            type Item = T;
212            fn next(&mut self) -> Option<Self::Item> {
213                if self.offset == 0 {
214                    match self.item.take() {
215                        Some(x) => Some(x),
216                        None => self.inner.next(),
217                    }
218                } else {
219                    self.offset -= 1;
220                    self.inner.next()
221                }
222            }
223        }
224        InsertsAfter {
225            offset: usize::from(self.i),
226            inner: self.messages.into_iter(),
227            item: Some(my_msg),
228        }
229    }
230
231    /// Returns iterator over messages with sender indexes
232    ///
233    /// Iterator yields `(sender_index, msg_id, message)`
234    pub fn into_iter_indexed(self) -> impl Iterator<Item = (PartyIndex, MsgId, M)> {
235        let parties_indexes = (0..self.i).chain(self.i + 1..);
236        parties_indexes
237            .zip(self.ids)
238            .zip(self.messages)
239            .map(|((party_ind, msg_id), msg)| (party_ind, msg_id, msg))
240    }
241
242    /// Returns iterator over messages with sender indexes
243    ///
244    /// Iterator yields `(sender_index, msg_id, &message)`
245    pub fn iter_indexed(&self) -> impl Iterator<Item = (PartyIndex, MsgId, &M)> {
246        let parties_indexes = (0..self.i).chain(self.i + 1..);
247        parties_indexes
248            .zip(&self.ids)
249            .zip(&self.messages)
250            .map(|((party_ind, msg_id), msg)| (party_ind, *msg_id, msg))
251    }
252}
253
254/// Error explaining why `RoundInput` wasn't able to process a message
255#[derive(Debug, thiserror::Error)]
256pub enum RoundInputError {
257    /// Party sent two messages in one round
258    ///
259    /// `msgs_ids` are ids of conflicting messages
260    #[error("party {sender} tried to overwrite message")]
261    AttemptToOverwriteReceivedMsg {
262        /// IDs of conflicting messages
263        msgs_ids: [MsgId; 2],
264        /// Index of party who sent two messages in one round
265        sender: PartyIndex,
266    },
267    /// Unknown sender
268    ///
269    /// This error is thrown when index of sender is not in `[0; n)` where `n` is number of
270    /// parties involved in the protocol (provided in [`RoundInput::new`])
271    #[error("sender index is out of range: sender={sender}, n={n}")]
272    SenderIndexOutOfRange {
273        /// Message ID
274        msg_id: MsgId,
275        /// Sender index
276        sender: PartyIndex,
277        /// Number of parties
278        n: u16,
279    },
280    /// Received message type doesn't match expectations
281    ///
282    /// For instance, this error is returned when it's expected to receive broadcast message,
283    /// but party sent p2p message instead (which is rough protocol violation).
284    #[error("expected message {expected:?}, got {actual:?}")]
285    MismatchedMessageType {
286        /// Message ID
287        msg_id: MsgId,
288        /// Expected type of message
289        expected: MessageType,
290        /// Actual type of message
291        actual: MessageType,
292    },
293}
294
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use matches::assert_matches;

    use crate::rounds_router::store::MessagesStore;
    use crate::{Incoming, MessageType, MsgId, PartyIndex};

    use super::{RoundInput, RoundInputError, RoundMsgs};

    #[derive(Debug, Clone, PartialEq)]
    pub struct Msg(u16);

    /// Builds an incoming message out of its four fields
    fn incoming(id: MsgId, sender: PartyIndex, msg_type: MessageType, msg: Msg) -> Incoming<Msg> {
        Incoming {
            id,
            sender,
            msg_type,
            msg,
        }
    }

    #[test]
    fn store_outputs_received_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        // Messages from every party but the local one (index 3)
        let sent: Vec<_> = (0u16..5)
            .filter(|&sender| sender != 3)
            .map(|sender| incoming(sender.into(), sender, MessageType::P2P, Msg(10 + sender)))
            .collect();

        for msg in sent.iter().cloned() {
            assert!(store.wants_more());
            store.add_message(msg).unwrap();
        }

        assert!(!store.wants_more());
        let received = store.output().unwrap();

        // without me
        let payloads: Vec<Msg> = sent.into_iter().map(|m| m.msg).collect();
        assert_eq!(received.clone().into_vec_without_me(), payloads);

        // including me
        let with_me = received.into_vec_including_me(Msg(13));
        assert_eq!(with_me[0..3], payloads[0..3]);
        assert_eq!(with_me[3], Msg(13));
        assert_eq!(with_me[4..5], payloads[3..4]);
    }

    #[test]
    fn store_returns_error_if_sender_index_is_out_of_range() {
        let mut store = RoundInput::new(3, 5, MessageType::P2P);
        let error = store
            .add_message(incoming(0, 5, MessageType::P2P, Msg(123)))
            .unwrap_err();
        assert_matches!(
            error,
            RoundInputError::SenderIndexOutOfRange {
                msg_id: 0,
                sender: 5,
                n: 5
            }
        );
    }

    #[test]
    fn store_returns_error_if_incoming_msg_overwrites_already_received_one() {
        let mut store = RoundInput::new(0, 3, MessageType::P2P);
        store
            .add_message(incoming(0, 1, MessageType::P2P, Msg(11)))
            .unwrap();

        // Second message from the same party must be rejected...
        let error = store
            .add_message(incoming(1, 1, MessageType::P2P, Msg(112)))
            .unwrap_err();
        assert_matches!(
            error,
            RoundInputError::AttemptToOverwriteReceivedMsg {
                msgs_ids: [0, 1],
                sender: 1
            }
        );

        // ...and must not affect the message that is already stored
        store
            .add_message(incoming(2, 2, MessageType::P2P, Msg(22)))
            .unwrap();

        let output = store.output().unwrap().into_vec_without_me();
        assert_eq!(output, [Msg(11), Msg(22)]);
    }

    #[test]
    fn store_returns_error_if_tried_to_output_before_receiving_enough_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        for sender in (0u16..5).filter(|&sender| sender != 3) {
            assert!(store.wants_more());
            // Incomplete store hands itself back
            store = store.output().unwrap_err();

            store
                .add_message(incoming(
                    sender.into(),
                    sender,
                    MessageType::P2P,
                    Msg(10 + sender),
                ))
                .unwrap();
        }

        let _ = store.output().unwrap();
    }

    #[test]
    fn store_returns_error_if_message_type_mismatched() {
        // p2p store rejects a broadcast message
        let mut p2p_store = RoundInput::<Msg>::p2p(3, 5);
        let err = p2p_store
            .add_message(incoming(0, 0, MessageType::Broadcast, Msg(1)))
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::P2P,
                actual: MessageType::Broadcast
            }
        );

        // broadcast store rejects a p2p message, then still accepts broadcasts from everyone
        let mut broadcast_store = RoundInput::<Msg>::broadcast(3, 5);
        let err = broadcast_store
            .add_message(incoming(0, 0, MessageType::P2P, Msg(1)))
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast,
                actual: MessageType::P2P
            }
        );
        for sender in 0u16..5 {
            broadcast_store
                .add_message(incoming(0, sender, MessageType::Broadcast, Msg(1)))
                .unwrap();
        }

        // rejected message doesn't occupy the sender's slot
        let mut broadcast_store = RoundInput::<Msg>::broadcast(3, 5);
        let err = broadcast_store
            .add_message(incoming(0, 0, MessageType::P2P, Msg(1)))
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast,
                actual: MessageType::P2P
            }
        );
        broadcast_store
            .add_message(incoming(0, 0, MessageType::Broadcast, Msg(1)))
            .unwrap();
    }

    #[test]
    fn into_iter_including_me() {
        let me = -10_isize;
        // Three received messages; only the local index and the message ids vary
        let received = |i: PartyIndex, ids: Vec<MsgId>| RoundMsgs {
            i,
            ids,
            messages: alloc::vec![1_isize, 2, 3],
        };

        let me_first = received(0, alloc::vec![1, 2, 3]);
        let all: Vec<_> = me_first.into_iter_including_me(me).collect();
        assert_eq!(all, [-10, 1, 2, 3]);

        let me_second = received(1, alloc::vec![0, 2, 3]);
        let all: Vec<_> = me_second.into_iter_including_me(me).collect();
        assert_eq!(all, [1, -10, 2, 3]);

        let me_last = received(3, alloc::vec![0, 1, 2]);
        let all: Vec<_> = me_last.into_iter_including_me(me).collect();
        assert_eq!(all, [1, 2, 3, -10]);
    }
}