1use alloc::{vec, vec::Vec};
4use core::iter;
5
6use crate::{Incoming, MessageType, MsgId, PartyIndex};
7
8use super::MessagesStore;
9
/// Per-round message store that collects exactly one message from every party except the
/// local one (`i`), accepting only messages of a single [`MessageType`].
///
/// Messages are fed in through the `MessagesStore` implementation; once every remote party
/// has delivered its message, `output` turns the store into a [`RoundMsgs`].
#[derive(Debug, Clone)]
pub struct RoundInput<M> {
    /// Index of the local party. Messages whose sender equals `i` are not stored.
    i: PartyIndex,
    /// Total number of parties in the protocol (including the local one).
    n: u16,
    /// Id of the message stored in each slot (`0` while the slot is empty).
    ///
    /// Slots skip the local index: sender `j < i` maps to slot `j`, sender `j > i` maps to
    /// slot `j - 1`, so there are `n - 1` slots.
    messages_ids: Vec<MsgId>,
    /// Received message for each slot (same layout as `messages_ids`); `None` until received.
    messages: Vec<Option<M>>,
    /// Number of slots that are still empty; the round is complete when this reaches zero.
    left_messages: u16,
    /// The only message type this store accepts; anything else is rejected with
    /// `RoundInputError::MismatchedMessageType`.
    expected_msg_type: MessageType,
}
54
/// All messages received in a round from every party except the local one.
///
/// Produced by `RoundInput::output` once every remote party has delivered its message.
#[derive(Debug, Clone)]
pub struct RoundMsgs<M> {
    /// Index of the local party; the `*_including_me` methods insert the local party's own
    /// message at this position.
    i: PartyIndex,
    /// Message ids, in the same order as `messages`.
    ids: Vec<MsgId>,
    /// Received messages ordered by sender index, with the local index `i` skipped.
    messages: Vec<M>,
}
62
63impl<M> RoundInput<M> {
64 pub fn new(i: PartyIndex, n: u16, msg_type: MessageType) -> Self {
71 assert!(n >= 2);
72 assert!(i < n);
73
74 Self {
75 i,
76 n,
77 messages_ids: vec![0; usize::from(n) - 1],
78 messages: iter::repeat_with(|| None)
79 .take(usize::from(n) - 1)
80 .collect(),
81 left_messages: n - 1,
82 expected_msg_type: msg_type,
83 }
84 }
85
86 pub fn broadcast(i: PartyIndex, n: u16) -> Self {
90 Self::new(i, n, MessageType::Broadcast)
91 }
92
93 pub fn p2p(i: PartyIndex, n: u16) -> Self {
97 Self::new(i, n, MessageType::P2P)
98 }
99
100 fn is_expected_type_of_msg(&self, msg_type: MessageType) -> bool {
101 self.expected_msg_type == msg_type
102 }
103}
104
105impl<M> MessagesStore for RoundInput<M>
106where
107 M: 'static,
108{
109 type Msg = M;
110 type Output = RoundMsgs<M>;
111 type Error = RoundInputError;
112
113 fn add_message(&mut self, msg: Incoming<Self::Msg>) -> Result<(), Self::Error> {
114 if !self.is_expected_type_of_msg(msg.msg_type) {
115 return Err(RoundInputError::MismatchedMessageType {
116 msg_id: msg.id,
117 expected: self.expected_msg_type,
118 actual: msg.msg_type,
119 });
120 }
121 if msg.sender == self.i {
122 return Ok(());
124 }
125
126 let index = usize::from(if msg.sender < self.i {
127 msg.sender
128 } else {
129 msg.sender - 1
130 });
131
132 match self.messages.get_mut(index) {
133 Some(vacant @ None) => {
134 *vacant = Some(msg.msg);
135 self.messages_ids[index] = msg.id;
136 self.left_messages -= 1;
137 Ok(())
138 }
139 Some(Some(_)) => Err(RoundInputError::AttemptToOverwriteReceivedMsg {
140 msgs_ids: [self.messages_ids[index], msg.id],
141 sender: msg.sender,
142 }),
143 None => Err(RoundInputError::SenderIndexOutOfRange {
144 msg_id: msg.id,
145 sender: msg.sender,
146 n: self.n,
147 }),
148 }
149 }
150
151 fn wants_more(&self) -> bool {
152 self.left_messages > 0
153 }
154
155 fn output(self) -> Result<Self::Output, Self> {
156 if self.left_messages > 0 {
157 Err(self)
158 } else {
159 Ok(RoundMsgs {
160 i: self.i,
161 ids: self.messages_ids,
162 messages: self.messages.into_iter().flatten().collect(),
163 })
164 }
165 }
166}
167
168impl<M> RoundMsgs<M> {
169 pub fn into_vec_without_me(self) -> Vec<M> {
174 self.messages
175 }
176
177 pub fn into_vec_including_me(mut self, my_msg: M) -> Vec<M> {
182 self.messages.insert(usize::from(self.i), my_msg);
183 self.messages
184 }
185
186 pub fn iter(&self) -> impl Iterator<Item = &M> {
188 self.messages.iter()
189 }
190
191 pub fn iter_including_me<'m>(&'m self, my_msg: &'m M) -> impl Iterator<Item = &'m M> {
196 self.messages
197 .iter()
198 .take(usize::from(self.i))
199 .chain(iter::once(my_msg))
200 .chain(self.messages.iter().skip(usize::from(self.i)))
201 }
202
203 pub fn into_iter_including_me(self, my_msg: M) -> impl Iterator<Item = M> {
205 struct InsertsAfter<T, It> {
206 offset: usize,
207 inner: It,
208 item: Option<T>,
209 }
210 impl<T, It: Iterator<Item = T>> Iterator for InsertsAfter<T, It> {
211 type Item = T;
212 fn next(&mut self) -> Option<Self::Item> {
213 if self.offset == 0 {
214 match self.item.take() {
215 Some(x) => Some(x),
216 None => self.inner.next(),
217 }
218 } else {
219 self.offset -= 1;
220 self.inner.next()
221 }
222 }
223 }
224 InsertsAfter {
225 offset: usize::from(self.i),
226 inner: self.messages.into_iter(),
227 item: Some(my_msg),
228 }
229 }
230
231 pub fn into_iter_indexed(self) -> impl Iterator<Item = (PartyIndex, MsgId, M)> {
235 let parties_indexes = (0..self.i).chain(self.i + 1..);
236 parties_indexes
237 .zip(self.ids)
238 .zip(self.messages)
239 .map(|((party_ind, msg_id), msg)| (party_ind, msg_id, msg))
240 }
241
242 pub fn iter_indexed(&self) -> impl Iterator<Item = (PartyIndex, MsgId, &M)> {
246 let parties_indexes = (0..self.i).chain(self.i + 1..);
247 parties_indexes
248 .zip(&self.ids)
249 .zip(&self.messages)
250 .map(|((party_ind, msg_id), msg)| (party_ind, *msg_id, msg))
251 }
252}
253
/// Reasons why `RoundInput` rejects an incoming message.
#[derive(Debug, thiserror::Error)]
pub enum RoundInputError {
    /// A message from `sender` was already stored in this round. The new message is rejected
    /// and the previously stored one is kept.
    #[error("party {sender} tried to overwrite message")]
    AttemptToOverwriteReceivedMsg {
        /// `[id of the already-stored message, id of the rejected message]`.
        msgs_ids: [MsgId; 2],
        /// Index of the party that sent both messages.
        sender: PartyIndex,
    },
    /// The sender index does not name any of the `n` parties (`sender >= n`).
    #[error("sender index is out of range: sender={sender}, n={n}")]
    SenderIndexOutOfRange {
        /// Id of the rejected message.
        msg_id: MsgId,
        /// Offending sender index.
        sender: PartyIndex,
        /// Number of parties the store was created for.
        n: u16,
    },
    /// The message type (broadcast vs. P2P) differs from the one the store accepts.
    #[error("expected message {expected:?}, got {actual:?}")]
    MismatchedMessageType {
        /// Id of the rejected message.
        msg_id: MsgId,
        /// Message type the store was created for.
        expected: MessageType,
        /// Message type of the rejected message.
        actual: MessageType,
    },
}
294
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use matches::assert_matches;

    use crate::rounds_router::store::MessagesStore;
    use crate::{Incoming, MessageType};

    use super::{RoundInput, RoundInputError};

    /// Test payload; the wrapped number makes each message distinguishable.
    #[derive(Debug, Clone, PartialEq)]
    pub struct Msg(u16);

    // Party 3 of 5 receives one P2P message from every other party. The output must list
    // them in sender order, and the local party's message must be inserted at index 3.
    #[test]
    fn store_outputs_received_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        let msgs = (0..5)
            .map(|s| Incoming {
                id: s.into(),
                sender: s,
                msg_type: MessageType::P2P,
                msg: Msg(10 + s),
            })
            // Skip the local party's own message.
            .filter(|incoming| incoming.sender != 3)
            .collect::<Vec<_>>();

        for msg in &msgs {
            assert!(store.wants_more());
            store.add_message(msg.clone()).unwrap();
        }

        assert!(!store.wants_more());
        let received = store.output().unwrap();

        let msgs: Vec<_> = msgs.into_iter().map(|msg| msg.msg).collect();
        assert_eq!(received.clone().into_vec_without_me(), msgs);

        let received = received.into_vec_including_me(Msg(13));
        // Senders 0..3 come first, then our own message, then sender 4.
        assert_eq!(received[0..3], msgs[0..3]);
        assert_eq!(received[3], Msg(13));
        assert_eq!(received[4..5], msgs[3..4]);
    }

    // With n = 5, sender index 5 names no party and must be rejected.
    #[test]
    fn store_returns_error_if_sender_index_is_out_of_range() {
        let mut store = RoundInput::new(3, 5, MessageType::P2P);
        let error = store
            .add_message(Incoming {
                id: 0,
                sender: 5,
                msg_type: MessageType::P2P,
                msg: Msg(123),
            })
            .unwrap_err();
        assert_matches!(
            error,
            RoundInputError::SenderIndexOutOfRange { msg_id, sender, n } if msg_id == 0 && sender == 5 && n == 5
        );
    }

    // A second message from the same sender is rejected, the first one is kept, and the store
    // stays usable for the remaining senders.
    #[test]
    fn store_returns_error_if_incoming_msg_overwrites_already_received_one() {
        let mut store = RoundInput::new(0, 3, MessageType::P2P);
        store
            .add_message(Incoming {
                id: 0,
                sender: 1,
                msg_type: MessageType::P2P,
                msg: Msg(11),
            })
            .unwrap();
        let error = store
            .add_message(Incoming {
                id: 1,
                sender: 1,
                msg_type: MessageType::P2P,
                msg: Msg(112),
            })
            .unwrap_err();
        assert_matches!(error, RoundInputError::AttemptToOverwriteReceivedMsg { msgs_ids, sender } if msgs_ids[0] == 0 && msgs_ids[1] == 1 && sender == 1);
        store
            .add_message(Incoming {
                id: 2,
                sender: 2,
                msg_type: MessageType::P2P,
                msg: Msg(22),
            })
            .unwrap();

        // The original Msg(11) survived the rejected overwrite attempt.
        let output = store.output().unwrap();
        let output = output.into_vec_without_me();
        assert_eq!(output, [Msg(11), Msg(22)]);
    }

    // `output` hands the store back while messages are missing and succeeds only after the
    // last expected message has arrived.
    #[test]
    fn store_returns_error_if_tried_to_output_before_receiving_enough_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        let msgs = (0..5)
            .map(|s| Incoming {
                id: s.into(),
                sender: s,
                msg_type: MessageType::P2P,
                msg: Msg(10 + s),
            })
            .filter(|incoming| incoming.sender != 3);

        for msg in msgs {
            assert!(store.wants_more());
            store = store.output().unwrap_err();

            store.add_message(msg).unwrap();
        }

        let _ = store.output().unwrap();
    }

    // Stores reject messages of the wrong type (P2P vs. broadcast) and still accept messages
    // of the right type afterwards.
    #[test]
    fn store_returns_error_if_message_type_mismatched() {
        let mut store = RoundInput::<Msg>::p2p(3, 5);
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::Broadcast,
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::P2P,
                actual: MessageType::Broadcast
            }
        );

        let mut store = RoundInput::<Msg>::broadcast(3, 5);
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::P2P,
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast,
                actual: MessageType::P2P,
            }
        );
        // Every sender (including the local party 3, whose message is ignored) is accepted.
        for sender in 0u16..5 {
            store
                .add_message(Incoming {
                    id: 0,
                    sender,
                    msg_type: MessageType::Broadcast,
                    msg: Msg(1),
                })
                .unwrap();
        }

        let mut store = RoundInput::<Msg>::broadcast(3, 5);
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::P2P,
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast,
                actual,
            } if actual == MessageType::P2P
        );
        // The rejected message did not occupy sender 0's slot.
        store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::Broadcast,
                msg: Msg(1),
            })
            .unwrap();
    }

    // The local message must be inserted at position `i`: first, in the middle, and last.
    #[test]
    fn into_iter_including_me() {
        let me = -10_isize;
        let messages = alloc::vec![1, 2, 3];

        let me_first = super::RoundMsgs {
            i: 0,
            ids: alloc::vec![1, 2, 3],
            messages: messages.clone(),
        };
        let all = me_first.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [-10, 1, 2, 3]);

        let me_second = super::RoundMsgs {
            i: 1,
            ids: alloc::vec![0, 2, 3],
            messages: messages.clone(),
        };
        let all = me_second.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [1, -10, 2, 3]);

        let me_last = super::RoundMsgs {
            i: 3,
            ids: alloc::vec![0, 1, 2],
            messages: messages.clone(),
        };
        let all = me_last.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [1, 2, 3, -10]);
    }
}