1use alloc::{vec, vec::Vec};
4use core::iter;
5
6use crate::{Incoming, MessageType, MsgId, PartyIndex};
7
8use super::{RoundInfo, RoundStore};
9
/// Collects the messages received in a single round: one message from every party other
/// than the local one.
///
/// Create it with [`RoundInput::p2p`], [`RoundInput::broadcast`] or
/// [`RoundInput::reliable_broadcast`]. Feed incoming messages through `add_message`. Once every
/// slot is filled, `output` turns the store into a [`RoundMsgs`].
#[derive(Debug, Clone)]
pub struct RoundInput<M> {
    /// Index of the local party
    i: PartyIndex,
    /// Total number of parties, including the local one
    n: u16,
    /// `messages_ids[k]` is the ID of the message stored in `messages[k]` (`0` while the slot is empty)
    messages_ids: Vec<MsgId>,
    /// `n - 1` slots, one per other party, ordered by sender index with the local index skipped
    messages: Vec<Option<M>>,
    /// Number of slots in `messages` that are still empty
    left_messages: u16,
    /// Message type this round accepts
    expected_msg_type: MessageType,
}
58
/// Messages received from every other party in one round, produced by a completed
/// [`RoundInput`].
#[derive(Debug, Clone)]
pub struct RoundMsgs<M> {
    /// Index of the local party; its own message is not stored here
    i: PartyIndex,
    /// `ids[k]` is the ID of `messages[k]`
    ids: Vec<MsgId>,
    /// Received messages, ordered by sender index with the local party skipped
    messages: Vec<M>,
}
66
67impl<M> RoundInput<M> {
68 pub fn new(i: PartyIndex, n: u16, msg_type: MessageType) -> Self {
75 assert!(n >= 2);
76 assert!(i < n);
77
78 Self {
79 i,
80 n,
81 messages_ids: vec![0; usize::from(n) - 1],
82 messages: iter::repeat_with(|| None)
83 .take(usize::from(n) - 1)
84 .collect(),
85 left_messages: n - 1,
86 expected_msg_type: msg_type,
87 }
88 }
89
90 pub fn broadcast(i: PartyIndex, n: u16) -> Self {
94 Self::new(i, n, MessageType::Broadcast { reliable: false })
95 }
96
97 pub fn reliable_broadcast(i: PartyIndex, n: u16) -> Self {
101 Self::new(i, n, MessageType::Broadcast { reliable: true })
102 }
103
104 pub fn p2p(i: PartyIndex, n: u16) -> Self {
108 Self::new(i, n, MessageType::P2P)
109 }
110
111 fn is_expected_type_of_msg(&self, actual_msg_type: MessageType) -> bool {
112 matches!(
113 (self.expected_msg_type, actual_msg_type),
114 (MessageType::P2P, MessageType::P2P)
115 | (
116 MessageType::Broadcast { reliable: false },
117 MessageType::Broadcast { .. }
118 )
119 | (
120 MessageType::Broadcast { reliable: true },
121 MessageType::Broadcast { reliable: true },
122 )
123 )
124 }
125}
126
127impl<M> RoundInfo for RoundInput<M>
128where
129 M: 'static,
130{
131 type Msg = M;
132 type Output = RoundMsgs<M>;
133 type Error = RoundInputError;
134}
135impl<M> RoundStore for RoundInput<M>
136where
137 M: 'static,
138{
139 fn add_message(&mut self, msg: Incoming<Self::Msg>) -> Result<(), Self::Error> {
140 if !self.is_expected_type_of_msg(msg.msg_type) {
141 return Err(RoundInputError::MismatchedMessageType {
142 msg_id: msg.id,
143 expected: self.expected_msg_type,
144 actual: msg.msg_type,
145 });
146 }
147 if msg.sender == self.i {
148 return Ok(());
150 }
151
152 let index = usize::from(if msg.sender < self.i {
153 msg.sender
154 } else {
155 msg.sender - 1
156 });
157
158 match self.messages.get_mut(index) {
159 Some(vacant @ None) => {
160 *vacant = Some(msg.msg);
161 self.messages_ids[index] = msg.id;
162 self.left_messages -= 1;
163 Ok(())
164 }
165 Some(Some(_)) => Err(RoundInputError::AttemptToOverwriteReceivedMsg {
166 msgs_ids: [self.messages_ids[index], msg.id],
167 sender: msg.sender,
168 }),
169 None => Err(RoundInputError::SenderIndexOutOfRange {
170 msg_id: msg.id,
171 sender: msg.sender,
172 n: self.n,
173 }),
174 }
175 }
176
177 fn wants_more(&self) -> bool {
178 self.left_messages > 0
179 }
180
181 fn output(self) -> Result<Self::Output, Self> {
182 if self.left_messages > 0 {
183 Err(self)
184 } else {
185 Ok(RoundMsgs {
186 i: self.i,
187 ids: self.messages_ids,
188 messages: self.messages.into_iter().flatten().collect(),
189 })
190 }
191 }
192
193 fn read_any_prop(&self, property: &mut dyn core::any::Any) {
194 if let Some(p) =
195 property.downcast_mut::<Option<crate::round::props::RequiresReliableBroadcast>>()
196 {
197 *p = Some(crate::round::props::RequiresReliableBroadcast(matches!(
198 self.expected_msg_type,
199 MessageType::Broadcast { reliable: true }
200 )));
201 }
202 }
203}
204
205impl<M> RoundMsgs<M> {
206 pub fn into_vec_without_me(self) -> Vec<M> {
211 self.messages
212 }
213
214 pub fn into_vec_including_me(mut self, my_msg: M) -> Vec<M> {
219 self.messages.insert(usize::from(self.i), my_msg);
220 self.messages
221 }
222
223 pub fn iter(&self) -> impl Iterator<Item = &M> {
225 self.messages.iter()
226 }
227
228 pub fn iter_including_me<'m>(&'m self, my_msg: &'m M) -> impl Iterator<Item = &'m M> {
233 self.messages
234 .iter()
235 .take(usize::from(self.i))
236 .chain(iter::once(my_msg))
237 .chain(self.messages.iter().skip(usize::from(self.i)))
238 }
239
240 pub fn into_iter_including_me(self, my_msg: M) -> impl Iterator<Item = M> {
242 struct InsertsAfter<T, It> {
243 offset: usize,
244 inner: It,
245 item: Option<T>,
246 }
247 impl<T, It: Iterator<Item = T>> Iterator for InsertsAfter<T, It> {
248 type Item = T;
249 fn next(&mut self) -> Option<Self::Item> {
250 if self.offset == 0 {
251 match self.item.take() {
252 Some(x) => Some(x),
253 None => self.inner.next(),
254 }
255 } else {
256 self.offset -= 1;
257 self.inner.next()
258 }
259 }
260 }
261 InsertsAfter {
262 offset: usize::from(self.i),
263 inner: self.messages.into_iter(),
264 item: Some(my_msg),
265 }
266 }
267
268 pub fn into_iter_indexed(self) -> impl Iterator<Item = (PartyIndex, MsgId, M)> {
272 let parties_indexes = (0..self.i).chain(self.i + 1..);
273 parties_indexes
274 .zip(self.ids)
275 .zip(self.messages)
276 .map(|((party_ind, msg_id), msg)| (party_ind, msg_id, msg))
277 }
278
279 pub fn iter_indexed(&self) -> impl Iterator<Item = (PartyIndex, MsgId, &M)> {
283 let parties_indexes = (0..self.i).chain(self.i + 1..);
284 parties_indexes
285 .zip(&self.ids)
286 .zip(&self.messages)
287 .map(|((party_ind, msg_id), msg)| (party_ind, *msg_id, msg))
288 }
289}
290
/// Error returned when [`RoundInput`] rejects an incoming message.
#[derive(Debug, thiserror::Error)]
pub enum RoundInputError {
    /// A party sent a message for a slot that already holds a message from that party.
    #[error("party {sender} tried to overwrite message")]
    AttemptToOverwriteReceivedMsg {
        /// IDs of the conflicting messages: `[previously stored message, rejected message]`
        msgs_ids: [MsgId; 2],
        /// Index of the party that sent both messages
        sender: PartyIndex,
    },
    /// The message's sender index is not smaller than the number of parties `n`.
    #[error("sender index is out of range: sender={sender}, n={n}")]
    SenderIndexOutOfRange {
        /// ID of the rejected message
        msg_id: MsgId,
        /// Sender index carried by the rejected message
        sender: PartyIndex,
        /// Number of parties the round was created for
        n: u16,
    },
    /// The message was sent with a type this round does not accept.
    ///
    /// Note that a non-reliable broadcast round still accepts reliable broadcast messages.
    #[error("expected message {expected:?}, got {actual:?}")]
    MismatchedMessageType {
        /// ID of the rejected message
        msg_id: MsgId,
        /// Message type the round was created for
        expected: MessageType,
        /// Message type the rejected message was sent with
        actual: MessageType,
    },
}
331
332pub fn p2p<M>(i: u16, n: u16) -> RoundInput<M> {
336 RoundInput::p2p(i, n)
337}
338pub fn broadcast<M>(i: u16, n: u16) -> RoundInput<M> {
342 RoundInput::broadcast(i, n)
343}
344pub fn reliable_broadcast<M>(i: u16, n: u16) -> RoundInput<M> {
348 RoundInput::reliable_broadcast(i, n)
349}
350
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use matches::assert_matches;

    use crate::round::RoundStore;
    use crate::{Incoming, MessageType};

    use super::{RoundInput, RoundInputError};

    /// Minimal message payload used by these tests.
    #[derive(Debug, Clone, PartialEq)]
    pub struct Msg(u16);

    /// Party 3 of 5 receives one P2P message from each other party. The output keeps sender
    /// order, and `into_vec_including_me` places our own message at index 3.
    #[test]
    fn store_outputs_received_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        // Message id and payload are derived from the sender index; sender 3 is the local
        // party, so its message is filtered out.
        let msgs = (0..5)
            .map(|s| Incoming {
                id: s.into(),
                sender: s,
                msg_type: MessageType::P2P,
                msg: Msg(10 + s),
            })
            .filter(|incoming| incoming.sender != 3)
            .collect::<Vec<_>>();

        for msg in &msgs {
            assert!(store.wants_more());
            store.add_message(msg.clone()).unwrap();
        }

        assert!(!store.wants_more());
        let received = store.output().unwrap();

        // Output without our own message equals the inputs in sender order.
        let msgs: Vec<_> = msgs.into_iter().map(|msg| msg.msg).collect();
        assert_eq!(received.clone().into_vec_without_me(), msgs);

        // Including our own message inserts it at our index (3).
        let received = received.into_vec_including_me(Msg(13));
        assert_eq!(received[0..3], msgs[0..3]);
        assert_eq!(received[3], Msg(13));
        assert_eq!(received[4..5], msgs[3..4]);
    }

    /// A sender index equal to `n` is reported with the offending id, sender and `n`.
    #[test]
    fn store_returns_error_if_sender_index_is_out_of_range() {
        let mut store = RoundInput::new(3, 5, MessageType::P2P);
        let error = store
            .add_message(Incoming {
                id: 0,
                sender: 5,
                msg_type: MessageType::P2P,
                msg: Msg(123),
            })
            .unwrap_err();
        assert_matches!(
            error,
            RoundInputError::SenderIndexOutOfRange { msg_id, sender, n } if msg_id == 0 && sender == 5 && n == 5
        );
    }

    /// A second message from the same sender is rejected, the first one is kept, and the
    /// store keeps accepting messages from other parties.
    #[test]
    fn store_returns_error_if_incoming_msg_overwrites_already_received_one() {
        let mut store = RoundInput::new(0, 3, MessageType::P2P);
        store
            .add_message(Incoming {
                id: 0,
                sender: 1,
                msg_type: MessageType::P2P,
                msg: Msg(11),
            })
            .unwrap();
        let error = store
            .add_message(Incoming {
                id: 1,
                sender: 1,
                msg_type: MessageType::P2P,
                msg: Msg(112),
            })
            .unwrap_err();
        // msgs_ids = [stored message id, rejected message id]
        assert_matches!(error, RoundInputError::AttemptToOverwriteReceivedMsg { msgs_ids, sender } if msgs_ids[0] == 0 && msgs_ids[1] == 1 && sender == 1);
        store
            .add_message(Incoming {
                id: 2,
                sender: 2,
                msg_type: MessageType::P2P,
                msg: Msg(22),
            })
            .unwrap();

        // The rejected Msg(112) must not have replaced Msg(11).
        let output = store.output().unwrap().into_vec_without_me();
        assert_eq!(output, [Msg(11), Msg(22)]);
    }

    /// `output` hands the store back in `Err` until all messages have arrived.
    #[test]
    fn store_returns_error_if_tried_to_output_before_receiving_enough_messages() {
        let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);

        let msgs = (0..5)
            .map(|s| Incoming {
                id: s.into(),
                sender: s,
                msg_type: MessageType::P2P,
                msg: Msg(10 + s),
            })
            .filter(|incoming| incoming.sender != 3);

        for msg in msgs {
            assert!(store.wants_more());
            // Still incomplete: recover the store from the `Err` variant.
            store = store.output().unwrap_err();

            store.add_message(msg).unwrap();
        }

        let _ = store.output().unwrap();
    }

    /// Message-type checks: P2P rounds reject broadcasts, and broadcast rounds reject P2P.
    /// Reliable broadcast rounds also reject non-reliable broadcasts.
    #[test]
    fn store_returns_error_if_message_type_mismatched() {
        let mut store = RoundInput::<Msg>::p2p(3, 5);
        for reliable in [true, false] {
            let err = store
                .add_message(Incoming {
                    id: 0,
                    sender: 0,
                    msg_type: MessageType::Broadcast { reliable },
                    msg: Msg(1),
                })
                .unwrap_err();
            assert_matches!(
                err,
                RoundInputError::MismatchedMessageType {
                    msg_id: 0,
                    expected: MessageType::P2P,
                    actual: MessageType::Broadcast { reliable: r }
                } if r == reliable
            );
        }

        let mut store = RoundInput::<Msg>::broadcast(3, 5);
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::P2P,
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast { reliable: false },
                actual: MessageType::P2P,
            }
        );

        let mut store = RoundInput::<Msg>::reliable_broadcast(3, 5);
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::P2P,
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast { reliable: true },
                actual: MessageType::P2P,
            }
        );
        let err = store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::Broadcast { reliable: false },
                msg: Msg(1),
            })
            .unwrap_err();
        assert_matches!(
            err,
            RoundInputError::MismatchedMessageType {
                msg_id: 0,
                expected: MessageType::Broadcast { reliable: true },
                actual: MessageType::Broadcast { reliable: false },
            }
        );
    }

    /// A non-reliable broadcast round accepts messages sent via reliable broadcast.
    #[test]
    fn non_reliable_broadcast_round_accepts_reliable_broadcast_messages() {
        let mut store = RoundInput::<Msg>::broadcast(3, 5);
        store
            .add_message(Incoming {
                id: 0,
                sender: 0,
                msg_type: MessageType::Broadcast { reliable: true },
                msg: Msg(1),
            })
            .unwrap();
    }

    /// Our own message is inserted at our index: first, in the middle, and last.
    #[test]
    fn into_iter_including_me() {
        let me = -10_isize;
        let messages = alloc::vec![1, 2, 3];

        let me_first = super::RoundMsgs {
            i: 0,
            ids: alloc::vec![1, 2, 3],
            messages: messages.clone(),
        };
        let all = me_first.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [-10, 1, 2, 3]);

        let me_second = super::RoundMsgs {
            i: 1,
            ids: alloc::vec![0, 2, 3],
            messages: messages.clone(),
        };
        let all = me_second.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [1, -10, 2, 3]);

        let me_last = super::RoundMsgs {
            i: 3,
            ids: alloc::vec![0, 1, 2],
            messages: messages.clone(),
        };
        let all = me_last.into_iter_including_me(me).collect::<Vec<_>>();
        assert_eq!(all, [1, 2, 3, -10]);
    }
}