round_based/round/
simple_store.rs

1//! Simple implementation of [`RoundStore`]
2
3use alloc::{vec, vec::Vec};
4use core::iter;
5
6use crate::{Incoming, MessageType, MsgId, PartyIndex};
7
8use super::{RoundInfo, RoundStore};
9
10/// Simple implementation of [`RoundStore`] that waits for all parties to send a message
11///
12/// Round is considered complete when the store received a message from every party. Note that the
13/// store will ignore all the messages such as `msg.sender == local_party_index`.
14///
15/// Once round is complete, it outputs [`RoundMsgs`].
16///
17/// ## Example
18/// ```rust
19/// use round_based::{Incoming, MessageType};
20/// use round_based::round::{RoundStore, RoundInput};
21///
22/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
23/// let mut input = RoundInput::<&'static str>::broadcast(1, 3);
24/// input.add_message(Incoming{
25///     id: 0,
26///     sender: 0,
27///     msg_type: MessageType::Broadcast { reliable: false },
28///     msg: "first party message",
29/// })?;
30/// input.add_message(Incoming{
31///     id: 1,
32///     sender: 2,
33///     msg_type: MessageType::Broadcast { reliable: false },
34///     msg: "third party message",
35/// })?;
36/// assert!(!input.wants_more());
37///
38/// let output = input.output().unwrap();
39/// assert_eq!(
40///     output.clone().into_vec_without_me(),
41///     ["first party message", "third party message"]
42/// );
43/// assert_eq!(
44///     output.into_vec_including_me("my msg"),
45///     ["first party message", "my msg", "third party message"]
46/// );
47/// # Ok(()) }
48/// ```
49#[derive(Debug, Clone)]
50pub struct RoundInput<M> {
51    i: PartyIndex,
52    n: u16,
53    messages_ids: Vec<MsgId>,
54    messages: Vec<Option<M>>,
55    left_messages: u16,
56    expected_msg_type: MessageType,
57}
58
59/// List of received messages
60#[derive(Debug, Clone)]
61pub struct RoundMsgs<M> {
62    i: PartyIndex,
63    ids: Vec<MsgId>,
64    messages: Vec<M>,
65}
66
67impl<M> RoundInput<M> {
68    /// Constructs new messages store
69    ///
70    /// Takes index of local party `i` and amount of parties `n`
71    ///
72    /// ## Panics
73    /// Panics if `n` is less than 2 or `i` is not in the range `[0; n)`.
74    pub fn new(i: PartyIndex, n: u16, msg_type: MessageType) -> Self {
75        assert!(n >= 2);
76        assert!(i < n);
77
78        Self {
79            i,
80            n,
81            messages_ids: vec![0; usize::from(n) - 1],
82            messages: iter::repeat_with(|| None)
83                .take(usize::from(n) - 1)
84                .collect(),
85            left_messages: n - 1,
86            expected_msg_type: msg_type,
87        }
88    }
89
90    /// Construct a new store for broadcast messages
91    ///
92    /// The same as `RoundInput::new(i, n, MessageType::Broadcast { reliable: false })`
93    pub fn broadcast(i: PartyIndex, n: u16) -> Self {
94        Self::new(i, n, MessageType::Broadcast { reliable: false })
95    }
96
97    /// Construct a new store for reliable broadcast messages
98    ///
99    /// The same as `RoundInput::new(i, n, MessageType::Broadcast { reliable: true })`
100    pub fn reliable_broadcast(i: PartyIndex, n: u16) -> Self {
101        Self::new(i, n, MessageType::Broadcast { reliable: true })
102    }
103
104    /// Construct a new store for p2p messages
105    ///
106    /// The same as `RoundInput::new(i, n, MessageType::P2P)`
107    pub fn p2p(i: PartyIndex, n: u16) -> Self {
108        Self::new(i, n, MessageType::P2P)
109    }
110
111    fn is_expected_type_of_msg(&self, actual_msg_type: MessageType) -> bool {
112        matches!(
113            (self.expected_msg_type, actual_msg_type),
114            (MessageType::P2P, MessageType::P2P)
115                | (
116                    MessageType::Broadcast { reliable: false },
117                    MessageType::Broadcast { .. }
118                )
119                | (
120                    MessageType::Broadcast { reliable: true },
121                    MessageType::Broadcast { reliable: true },
122                )
123        )
124    }
125}
126
127impl<M> RoundInfo for RoundInput<M>
128where
129    M: 'static,
130{
131    type Msg = M;
132    type Output = RoundMsgs<M>;
133    type Error = RoundInputError;
134}
135impl<M> RoundStore for RoundInput<M>
136where
137    M: 'static,
138{
139    fn add_message(&mut self, msg: Incoming<Self::Msg>) -> Result<(), Self::Error> {
140        if !self.is_expected_type_of_msg(msg.msg_type) {
141            return Err(RoundInputError::MismatchedMessageType {
142                msg_id: msg.id,
143                expected: self.expected_msg_type,
144                actual: msg.msg_type,
145            });
146        }
147        if msg.sender == self.i {
148            // Ignore own messages
149            return Ok(());
150        }
151
152        let index = usize::from(if msg.sender < self.i {
153            msg.sender
154        } else {
155            msg.sender - 1
156        });
157
158        match self.messages.get_mut(index) {
159            Some(vacant @ None) => {
160                *vacant = Some(msg.msg);
161                self.messages_ids[index] = msg.id;
162                self.left_messages -= 1;
163                Ok(())
164            }
165            Some(Some(_)) => Err(RoundInputError::AttemptToOverwriteReceivedMsg {
166                msgs_ids: [self.messages_ids[index], msg.id],
167                sender: msg.sender,
168            }),
169            None => Err(RoundInputError::SenderIndexOutOfRange {
170                msg_id: msg.id,
171                sender: msg.sender,
172                n: self.n,
173            }),
174        }
175    }
176
177    fn wants_more(&self) -> bool {
178        self.left_messages > 0
179    }
180
181    fn output(self) -> Result<Self::Output, Self> {
182        if self.left_messages > 0 {
183            Err(self)
184        } else {
185            Ok(RoundMsgs {
186                i: self.i,
187                ids: self.messages_ids,
188                messages: self.messages.into_iter().flatten().collect(),
189            })
190        }
191    }
192
193    fn read_any_prop(&self, property: &mut dyn core::any::Any) {
194        if let Some(p) =
195            property.downcast_mut::<Option<crate::round::props::RequiresReliableBroadcast>>()
196        {
197            *p = Some(crate::round::props::RequiresReliableBroadcast(matches!(
198                self.expected_msg_type,
199                MessageType::Broadcast { reliable: true }
200            )));
201        }
202    }
203}
204
205impl<M> RoundMsgs<M> {
206    /// Returns vec of `n-1` received messages
207    ///
208    /// Messages appear in the list in ascending order of sender index. E.g. for n=4 and local party index i=2,
209    /// the list would look like: `[{msg from i=0}, {msg from i=1}, {msg from i=3}]`.
210    pub fn into_vec_without_me(self) -> Vec<M> {
211        self.messages
212    }
213
214    /// Returns vec of received messages plus party's own message
215    ///
216    /// Similar to `into_vec_without_me`, but inserts `my_msg` at position `i` in resulting list. Thus, i-th
217    /// message in the list was received from i-th party.
218    pub fn into_vec_including_me(mut self, my_msg: M) -> Vec<M> {
219        self.messages.insert(usize::from(self.i), my_msg);
220        self.messages
221    }
222
223    /// Returns iterator over messages
224    pub fn iter(&self) -> impl Iterator<Item = &M> {
225        self.messages.iter()
226    }
227
228    /// Returns iterator over received messages plus party's own message
229    ///
230    /// Similar to [`.iter()`](Self::iter), but inserts `my_msg` at position `i`. Thus, i-th message in the
231    /// iterator is the message received from party `i`.
232    pub fn iter_including_me<'m>(&'m self, my_msg: &'m M) -> impl Iterator<Item = &'m M> {
233        self.messages
234            .iter()
235            .take(usize::from(self.i))
236            .chain(iter::once(my_msg))
237            .chain(self.messages.iter().skip(usize::from(self.i)))
238    }
239
240    /// Returns iterator over received messages plus party's own message
241    pub fn into_iter_including_me(self, my_msg: M) -> impl Iterator<Item = M> {
242        struct InsertsAfter<T, It> {
243            offset: usize,
244            inner: It,
245            item: Option<T>,
246        }
247        impl<T, It: Iterator<Item = T>> Iterator for InsertsAfter<T, It> {
248            type Item = T;
249            fn next(&mut self) -> Option<Self::Item> {
250                if self.offset == 0 {
251                    match self.item.take() {
252                        Some(x) => Some(x),
253                        None => self.inner.next(),
254                    }
255                } else {
256                    self.offset -= 1;
257                    self.inner.next()
258                }
259            }
260        }
261        InsertsAfter {
262            offset: usize::from(self.i),
263            inner: self.messages.into_iter(),
264            item: Some(my_msg),
265        }
266    }
267
268    /// Returns iterator over messages with sender indexes
269    ///
270    /// Iterator yields `(sender_index, msg_id, message)`
271    pub fn into_iter_indexed(self) -> impl Iterator<Item = (PartyIndex, MsgId, M)> {
272        let parties_indexes = (0..self.i).chain(self.i + 1..);
273        parties_indexes
274            .zip(self.ids)
275            .zip(self.messages)
276            .map(|((party_ind, msg_id), msg)| (party_ind, msg_id, msg))
277    }
278
279    /// Returns iterator over messages with sender indexes
280    ///
281    /// Iterator yields `(sender_index, msg_id, &message)`
282    pub fn iter_indexed(&self) -> impl Iterator<Item = (PartyIndex, MsgId, &M)> {
283        let parties_indexes = (0..self.i).chain(self.i + 1..);
284        parties_indexes
285            .zip(&self.ids)
286            .zip(&self.messages)
287            .map(|((party_ind, msg_id), msg)| (party_ind, *msg_id, msg))
288    }
289}
290
291/// Error explaining why [`RoundInput`] wasn't able to process a message
292#[derive(Debug, thiserror::Error)]
293pub enum RoundInputError {
294    /// Party sent two messages in one round
295    ///
296    /// `msgs_ids` are ids of conflicting messages
297    #[error("party {sender} tried to overwrite message")]
298    AttemptToOverwriteReceivedMsg {
299        /// IDs of conflicting messages
300        msgs_ids: [MsgId; 2],
301        /// Index of party who sent two messages in one round
302        sender: PartyIndex,
303    },
304    /// Unknown sender
305    ///
306    /// This error is thrown when index of sender is not in `[0; n)` where `n` is number of
307    /// parties involved in the protocol (provided in [`RoundInput::new`])
308    #[error("sender index is out of range: sender={sender}, n={n}")]
309    SenderIndexOutOfRange {
310        /// Message ID
311        msg_id: MsgId,
312        /// Sender index
313        sender: PartyIndex,
314        /// Number of parties
315        n: u16,
316    },
317    /// Received message type doesn't match expectations
318    ///
319    /// For instance, this error is returned when it's expected to receive broadcast message,
320    /// but party sent p2p message instead (which is rough protocol violation).
321    #[error("expected message {expected:?}, got {actual:?}")]
322    MismatchedMessageType {
323        /// Message ID
324        msg_id: MsgId,
325        /// Expected type of message
326        expected: MessageType,
327        /// Actual type of message
328        actual: MessageType,
329    },
330}
331
332/// p2p round
333///
334/// Alias to [`RoundInput::p2p`]
335pub fn p2p<M>(i: u16, n: u16) -> RoundInput<M> {
336    RoundInput::p2p(i, n)
337}
338/// Broadcast round
339///
340/// Alias to [`RoundInput::broadcast`]
341pub fn broadcast<M>(i: u16, n: u16) -> RoundInput<M> {
342    RoundInput::broadcast(i, n)
343}
344/// Reliable broadcast round
345///
346/// Alias to [`RoundInput::broadcast`]
347pub fn reliable_broadcast<M>(i: u16, n: u16) -> RoundInput<M> {
348    RoundInput::reliable_broadcast(i, n)
349}
350
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use matches::assert_matches;

    use crate::round::RoundStore;
    use crate::{Incoming, MessageType};

    use super::{RoundInput, RoundInputError};

    #[derive(Debug, Clone, PartialEq)]
    pub struct Msg(u16);

    #[test]
    fn store_outputs_received_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        let msgs = (0..5)
            .map(|s| Incoming {
                id: s.into(),
                sender: s,
                msg_type: MessageType::P2P,
                msg: Msg(10 + s),
            })
            .filter(|incoming| incoming.sender != 3)
            .collect::<Vec<_>>();

        for msg in &msgs {
            assert!(store.wants_more());
            store.add_message(msg.clone()).unwrap();
        }

        assert!(!store.wants_more());
        let received = store.output().unwrap();

        // without me
        let msgs: Vec<_> = msgs.into_iter().map(|msg| msg.msg).collect();
        assert_eq!(received.clone().into_vec_without_me(), msgs);

        // including me
        let received = received.into_vec_including_me(Msg(13));
        assert_eq!(received[0..3], msgs[0..3]);
        assert_eq!(received[3], Msg(13));
        assert_eq!(received[4..5], msgs[3..4]);
    }

    #[test]
    fn store_ignores_own_messages() {
        let mut store = RoundInput::<Msg>::p2p(1, 3);
        store
            .add_message(Incoming {
                id: 0,
                sender: 1,
                msg_type: MessageType::P2P,
                msg: Msg(1),
            })
            .unwrap();
        // Own message must not count towards completing the round
        assert!(store.wants_more());

        for sender in [0, 2] {
            store
                .add_message(Incoming {
                    id: sender.into(),
                    sender,
                    msg_type: MessageType::P2P,
                    msg: Msg(10 + sender),
                })
                .unwrap();
        }
        assert!(!store.wants_more());

        let output = store.output().unwrap().into_vec_without_me();
        assert_eq!(output, [Msg(10), Msg(12)]);
    }

    #[test]
    fn store_returns_error_if_sender_index_is_out_of_range() {
        let mut store = RoundInput::new(3, 5, MessageType::P2P);
        let error = store
            .add_message(Incoming {
                id: 0,
                sender: 5,
                msg_type: MessageType::P2P,
                msg: Msg(123),
            })
            .unwrap_err();
        assert_matches!(
            error,
            RoundInputError::SenderIndexOutOfRange { msg_id, sender, n } if msg_id == 0 && sender == 5 && n == 5
        );
    }

    #[test]
    fn store_returns_error_if_incoming_msg_overwrites_already_received_one() {
        let mut store = RoundInput::new(0, 3, MessageType::P2P);
        store
            .add_message(Incoming {
                id: 0,
                sender: 1,
                msg_type: MessageType::P2P,
                msg: Msg(11),
            })
            .unwrap();
        let error = store
            .add_message(Incoming {
                id: 1,
                sender: 1,
                msg_type: MessageType::P2P,
                msg: Msg(112),
            })
            .unwrap_err();
        assert_matches!(error, RoundInputError::AttemptToOverwriteReceivedMsg { msgs_ids, sender } if msgs_ids[0] == 0 && msgs_ids[1] == 1 && sender == 1);
        store
            .add_message(Incoming {
                id: 2,
                sender: 2,
                msg_type: MessageType::P2P,
                msg: Msg(22),
            })
            .unwrap();

        let output = store.output().unwrap().into_vec_without_me();
        assert_eq!(output, [Msg(11), Msg(22)]);
    }

    #[test]
    fn store_returns_error_if_tried_to_output_before_receiving_enough_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        let msgs = (0..5)
            .map(|s| Incoming {
                id: s.into(),
                sender: s,
                msg_type: MessageType::P2P,
                msg: Msg(10 + s),
            })
            .filter(|incoming| incoming.sender != 3);

        for msg in msgs {
            assert!(store.wants_more());
            store = store.output().unwrap_err();

            store.add_message(msg).unwrap();
        }

        let _ = store.output().unwrap();
    }

    #[test]
    fn store_returns_error_if_message_type_mismatched() {
        let mut store = RoundInput::<Msg>::p2p(3, 5);
        for reliable in [true, false] {
            let err = store
                .add_message(Incoming {
                    id: 0,
                    sender: 0,
                    msg_type: MessageType::Broadcast { reliable },
                    msg: Msg(1),
                })
                .unwrap_err();
            assert_matches!(
                err,
                RoundInputError::MismatchedMessageType {
                    msg_id: 0,
                    expected: MessageType::P2P,
                    actual: MessageType::Broadcast { reliable: r }
                } if r == reliable
            );
        }

        let mut store = RoundInput::<Msg>::broadcast(3, 5);
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::P2P,
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast { reliable: false },
                actual: MessageType::P2P,
            }
        );

        let mut store = RoundInput::<Msg>::reliable_broadcast(3, 5);
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::P2P,
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast { reliable: true },
                actual: MessageType::P2P,
            }
        );
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::Broadcast { reliable: false },
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast { reliable: true },
                actual: MessageType::Broadcast { reliable: false },
            }
        );
    }

    #[test]
    fn non_reliable_broadcast_round_accepts_reliable_broadcast_messages() {
        let mut store = RoundInput::<Msg>::broadcast(3, 5);
        store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::Broadcast { reliable: true },
                msg: Msg(1),
            })
            .unwrap();
    }

    #[test]
    fn into_iter_including_me() {
        let me = -10_isize;
        let messages = alloc::vec![1, 2, 3];

        let me_first = super::RoundMsgs {
            i: 0,
            ids: alloc::vec![1, 2, 3],
            messages: messages.clone(),
        };
        let all = me_first.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [-10, 1, 2, 3]);

        let me_second = super::RoundMsgs {
            i: 1,
            ids: alloc::vec![0, 2, 3],
            messages: messages.clone(),
        };
        let all = me_second.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [1, -10, 2, 3]);

        let me_last = super::RoundMsgs {
            i: 3,
            ids: alloc::vec![0, 1, 2],
            messages: messages.clone(),
        };
        let all = me_last.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [1, 2, 3, -10]);
    }

    #[test]
    fn indexed_iterators_skip_local_party() {
        let msgs = super::RoundMsgs {
            i: 1,
            ids: alloc::vec![10, 12, 13],
            messages: alloc::vec!["a", "c", "d"],
        };

        let borrowed: Vec<_> = msgs.iter_indexed().collect();
        assert_eq!(borrowed, [(0, 10, &"a"), (2, 12, &"c"), (3, 13, &"d")]);

        let with_me: Vec<_> = msgs.iter_including_me(&"b").collect();
        assert_eq!(with_me, [&"a", &"b", &"c", &"d"]);

        let owned: Vec<_> = msgs.into_iter_indexed().collect();
        assert_eq!(owned, [(0, 10, "a"), (2, 12, "c"), (3, 13, "d")]);
    }
}