1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
use std::cmp::Ordering;
use std::ops;

use crate::sm::Msg;

use super::store_err::StoreErr;
use super::traits::{MessageContainer, MessageStore};

/// Received broadcast messages from every protocol participant
///
/// Holds the `n-1` messages sent by the other parties, ordered by ascending sender
/// index; this party's own message is not stored.
#[derive(Debug)]
pub struct BroadcastMsgs<B> {
    /// Index of this party (1-based)
    my_ind: u16,
    /// Messages of the other parties: `msgs[i]` belongs to party `i+1` when
    /// `i < my_ind - 1`, otherwise to party `i+2` (this party's index is skipped)
    msgs: Vec<B>,
}

impl<B> BroadcastMsgs<B>
where
    B: 'static,
{
    /// Consumes the container and pairs each message with its sender's party index
    /// (1-based, `1 <= i <= n`; this party's own index never appears).
    pub fn into_iter_indexed(self) -> impl Iterator<Item = (u16, B)> {
        let my_ind = usize::from(self.my_ind);
        self.msgs.into_iter().enumerate().map(move |(pos, msg)| {
            // Slots before our own index are simply shifted to 1-based numbering;
            // slots after it additionally jump over our own index.
            let offset: u16 = if pos < my_ind - 1 { 1 } else { 2 };
            (pos as u16 + offset, msg)
        })
    }

    /// Consumes the container, returning the `n-1` received messages in ascending
    /// sender order.
    pub fn into_vec(self) -> Vec<B> {
        let Self { msgs, .. } = self;
        msgs
    }

    /// Consumes the container, returning all `n` messages with `me` placed at
    /// index `my_ind - 1` so that party `i`'s message sits at index `i - 1`.
    pub fn into_vec_including_me(self, me: B) -> Vec<B> {
        let my_pos = self.my_ind as usize - 1;
        let mut everyone = self.msgs;
        everyone.insert(my_pos, me);
        everyone
    }
}

impl<B> ops::Index<u16> for BroadcastMsgs<B> {
    type Output = B;

    /// Takes party index i and returns received message (1 <= i <= n)
    ///
    /// ## Panics
    /// Panics if there's no party with index i (or it's your party index)
    fn index(&self, index: u16) -> &Self::Output {
        // `msgs` stores the other parties in ascending order: parties `1..my_ind` live in
        // slots `0..my_ind-1`, and parties after `my_ind` are shifted down one more slot
        // because our own index is skipped. Comparing against `my_ind` (not `my_ind - 1`)
        // keeps this consistent with the 1-based indexes used everywhere else.
        let slot = match Ord::cmp(&index, &self.my_ind) {
            Ordering::Equal => panic!("accessing own broadcasted msg"),
            // `checked_sub` rejects index 0, which is not a valid party index.
            Ordering::Less => usize::from(index).checked_sub(1),
            Ordering::Greater => usize::from(index).checked_sub(2),
        };
        match slot.and_then(|slot| self.msgs.get(slot)) {
            Some(msg) => msg,
            None => panic!("there's no party with index {}", index),
        }
    }
}

impl<B> IntoIterator for BroadcastMsgs<B> {
    type Item = B;
    type IntoIter = <Vec<B> as IntoIterator>::IntoIter;

    /// Returns messages in ascending party's index order
    fn into_iter(self) -> Self::IntoIter {
        self.msgs.into_iter()
    }
}

/// Broadcast messages are collected through a [`BroadcastMsgsStore`].
impl<T> MessageContainer for BroadcastMsgs<T> {
    type Store = BroadcastMsgsStore<T>;
}

/// Receives broadcast messages from every protocol participant
///
/// Collects one message from each of the other `n-1` parties. Once every slot is
/// filled, `finish` turns the store into a `BroadcastMsgs`.
pub struct BroadcastMsgsStore<M> {
    /// Index of this party (1-based); messages claiming to come from it are rejected
    party_i: u16,
    /// One slot per other party in ascending sender order; `None` until that
    /// party's message arrives
    msgs: Vec<Option<M>>,
    /// Number of slots in `msgs` that are still `None`
    msgs_left: usize,
}

impl<M> BroadcastMsgsStore<M> {
    /// Creates an empty store for party `party_i` (1-based) out of `parties_n`
    /// participants, with one empty slot for each of the other `parties_n - 1` parties.
    ///
    /// `parties_n` is expected to be at least 1; `0` underflows `parties_n - 1`.
    pub fn new(party_i: u16, parties_n: u16) -> Self {
        let others = usize::from(parties_n) - 1;
        let msgs = (0..others).map(|_| None).collect();
        Self {
            party_i,
            msgs,
            msgs_left: others,
        }
    }

    /// Number of messages stored so far
    pub fn messages_received(&self) -> usize {
        self.messages_total() - self.msgs_left
    }
    /// Number of messages the store waits for in total (`n-1`)
    pub fn messages_total(&self) -> usize {
        self.msgs.len()
    }
}

/// Maps a sender's 1-based party index to its slot in the `msgs` vector of a
/// `BroadcastMsgsStore` owned by party `party_i`.
///
/// Returns `None` for sender `0` and for `party_i` itself. The upper bound is not
/// checked here; callers use `Vec::get`/`get_mut` for that.
fn broadcast_slot_index(party_i: u16, sender: u16) -> Option<usize> {
    match Ord::cmp(&sender, &party_i) {
        // Parties before us map straight to 0-based slots; `checked_sub` rejects sender 0.
        Ordering::Less => usize::from(sender).checked_sub(1),
        // Parties after us are shifted down one more slot because our own index is skipped.
        Ordering::Greater => usize::from(sender).checked_sub(2),
        Ordering::Equal => None,
    }
}

impl<M> MessageStore for BroadcastMsgsStore<M> {
    type M = M;
    type Err = StoreErr;
    type Output = BroadcastMsgs<M>;

    /// Stores broadcast message `msg` in its sender's slot.
    ///
    /// ## Errors
    /// * `UnknownSender` — sender is `0` or has no slot in this store
    ///   (index beyond the number of parties)
    /// * `ExpectedBroadcast` — the message is addressed to a specific receiver
    /// * `ItsFromMe` — the message claims to come from this party
    /// * `MsgOverwrite` — a message from this sender was already stored
    fn push_msg(&mut self, msg: Msg<Self::M>) -> Result<(), Self::Err> {
        if msg.sender == 0 {
            return Err(StoreErr::UnknownSender { sender: msg.sender });
        }
        if msg.receiver.is_some() {
            return Err(StoreErr::ExpectedBroadcast);
        }
        if msg.sender == self.party_i {
            return Err(StoreErr::ItsFromMe);
        }
        let slot = broadcast_slot_index(self.party_i, msg.sender)
            .and_then(|j| self.msgs.get_mut(j))
            .ok_or(StoreErr::UnknownSender { sender: msg.sender })?;
        if slot.is_some() {
            return Err(StoreErr::MsgOverwrite);
        }
        *slot = Some(msg.body);
        self.msgs_left -= 1;

        Ok(())
    }

    /// Tells whether a message from party `sender` has already been stored.
    ///
    /// Returns `false` for this party's own index and for indexes outside `1..=n`
    /// (including `0`, which maps to no slot instead of underflowing).
    fn contains_msg_from(&self, sender: u16) -> bool {
        let slot = broadcast_slot_index(self.party_i, sender).and_then(|j| self.msgs.get(j));
        matches!(slot, Some(Some(_)))
    }

    /// `true` while at least one other party's message is missing
    fn wants_more(&self) -> bool {
        self.msgs_left > 0
    }

    /// Turns the store into `BroadcastMsgs` once every other party's message arrived.
    ///
    /// ## Errors
    /// `WantsMoreMessages` if any slot is still empty.
    fn finish(self) -> Result<Self::Output, Self::Err> {
        if self.msgs_left > 0 {
            return Err(StoreErr::WantsMoreMessages);
        }
        Ok(BroadcastMsgs {
            my_ind: self.party_i,
            msgs: self
                .msgs
                .into_iter()
                .map(|m| m.expect("msgs_left == 0 implies every slot is filled"))
                .collect(),
        })
    }

    /// Reports who hasn't sent their message yet.
    ///
    /// Returns `(number of missing messages, 1-based indexes of the parties whose
    /// message is missing)`.
    fn blame(&self) -> (u16, Vec<u16>) {
        // Inverse of `broadcast_slot_index`: slot `i` belongs to party `i+1` before our
        // own index and to party `i+2` after it.
        let ind = |i: u16| -> u16 {
            if i < self.party_i - 1 {
                i + 1
            } else {
                i + 2
            }
        };
        let guilty_parties = self
            .msgs
            .iter()
            .enumerate()
            .filter(|(_, slot)| slot.is_none())
            .map(|(i, _)| ind(i as u16))
            .collect();
        (self.msgs_left as u16, guilty_parties)
    }
}