use alloc::{vec, vec::Vec};
use core::iter;
use crate::{Incoming, MessageType, MsgId, PartyIndex};
use super::MessagesStore;
#[derive(Debug, Clone)]
pub struct RoundInput<M> {
i: PartyIndex,
n: u16,
messages_ids: Vec<MsgId>,
messages: Vec<Option<M>>,
left_messages: u16,
expected_msg_type: MessageType,
}
#[derive(Debug, Clone)]
pub struct RoundMsgs<M> {
i: PartyIndex,
ids: Vec<MsgId>,
messages: Vec<M>,
}
impl<M> RoundInput<M> {
pub fn new(i: PartyIndex, n: u16, msg_type: MessageType) -> Self {
assert!(n >= 2);
assert!(i < n);
Self {
i,
n,
messages_ids: vec![0; usize::from(n) - 1],
messages: iter::repeat_with(|| None)
.take(usize::from(n) - 1)
.collect(),
left_messages: n - 1,
expected_msg_type: msg_type,
}
}
pub fn broadcast(i: PartyIndex, n: u16) -> Self {
Self::new(i, n, MessageType::Broadcast)
}
pub fn p2p(i: PartyIndex, n: u16) -> Self {
Self::new(i, n, MessageType::P2P)
}
fn is_expected_type_of_msg(&self, msg_type: MessageType) -> bool {
self.expected_msg_type == msg_type
}
}
impl<M> MessagesStore for RoundInput<M>
where
M: 'static,
{
type Msg = M;
type Output = RoundMsgs<M>;
type Error = RoundInputError;
fn add_message(&mut self, msg: Incoming<Self::Msg>) -> Result<(), Self::Error> {
if !self.is_expected_type_of_msg(msg.msg_type) {
return Err(RoundInputError::MismatchedMessageType {
msg_id: msg.id,
expected: self.expected_msg_type,
actual: msg.msg_type,
});
}
if msg.sender == self.i {
return Ok(());
}
let index = usize::from(if msg.sender < self.i {
msg.sender
} else {
msg.sender - 1
});
match self.messages.get_mut(index) {
Some(vacant @ None) => {
*vacant = Some(msg.msg);
self.messages_ids[index] = msg.id;
self.left_messages -= 1;
Ok(())
}
Some(Some(_)) => Err(RoundInputError::AttemptToOverwriteReceivedMsg {
msgs_ids: [self.messages_ids[index], msg.id],
sender: msg.sender,
}),
None => Err(RoundInputError::SenderIndexOutOfRange {
msg_id: msg.id,
sender: msg.sender,
n: self.n,
}),
}
}
fn wants_more(&self) -> bool {
self.left_messages > 0
}
fn output(self) -> Result<Self::Output, Self> {
if self.left_messages > 0 {
Err(self)
} else {
Ok(RoundMsgs {
i: self.i,
ids: self.messages_ids,
messages: self.messages.into_iter().flatten().collect(),
})
}
}
}
impl<M> RoundMsgs<M> {
pub fn into_vec_without_me(self) -> Vec<M> {
self.messages
}
pub fn into_vec_including_me(mut self, my_msg: M) -> Vec<M> {
self.messages.insert(usize::from(self.i), my_msg);
self.messages
}
pub fn iter(&self) -> impl Iterator<Item = &M> {
self.messages.iter()
}
pub fn iter_including_me<'m>(&'m self, my_msg: &'m M) -> impl Iterator<Item = &'m M> {
self.messages
.iter()
.take(usize::from(self.i))
.chain(iter::once(my_msg))
.chain(self.messages.iter().skip(usize::from(self.i)))
}
pub fn into_iter_including_me(self, my_msg: M) -> impl Iterator<Item = M> {
struct InsertsAfter<T, It> {
offset: usize,
inner: It,
item: Option<T>,
}
impl<T, It: Iterator<Item = T>> Iterator for InsertsAfter<T, It> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
if self.offset == 0 {
match self.item.take() {
Some(x) => Some(x),
None => self.inner.next(),
}
} else {
self.offset -= 1;
self.inner.next()
}
}
}
InsertsAfter {
offset: usize::from(self.i),
inner: self.messages.into_iter(),
item: Some(my_msg),
}
}
pub fn into_iter_indexed(self) -> impl Iterator<Item = (PartyIndex, MsgId, M)> {
let parties_indexes = (0..self.i).chain(self.i + 1..);
parties_indexes
.zip(self.ids)
.zip(self.messages)
.map(|((party_ind, msg_id), msg)| (party_ind, msg_id, msg))
}
pub fn iter_indexed(&self) -> impl Iterator<Item = (PartyIndex, MsgId, &M)> {
let parties_indexes = (0..self.i).chain(self.i + 1..);
parties_indexes
.zip(&self.ids)
.zip(&self.messages)
.map(|((party_ind, msg_id), msg)| (party_ind, *msg_id, msg))
}
}
#[derive(Debug, thiserror::Error)]
pub enum RoundInputError {
#[error("party {sender} tried to overwrite message")]
AttemptToOverwriteReceivedMsg {
msgs_ids: [MsgId; 2],
sender: PartyIndex,
},
#[error("sender index is out of range: sender={sender}, n={n}")]
SenderIndexOutOfRange {
msg_id: MsgId,
sender: PartyIndex,
n: u16,
},
#[error("expected message {expected:?}, got {actual:?}")]
MismatchedMessageType {
msg_id: MsgId,
expected: MessageType,
actual: MessageType,
},
}
#[cfg(test)]
mod tests {
use alloc::vec::Vec;
use matches::assert_matches;
use crate::rounds_router::store::MessagesStore;
use crate::{Incoming, MessageType};
use super::{RoundInput, RoundInputError};
#[derive(Debug, Clone, PartialEq)]
pub struct Msg(u16);
#[test]
fn store_outputs_received_messages() {
let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);
let msgs = (0..5)
.map(|s| Incoming {
id: s.into(),
sender: s,
msg_type: MessageType::P2P,
msg: Msg(10 + s),
})
.filter(|incoming| incoming.sender != 3)
.collect::<Vec<_>>();
for msg in &msgs {
assert!(store.wants_more());
store.add_message(msg.clone()).unwrap();
}
assert!(!store.wants_more());
let received = store.output().unwrap();
let msgs: Vec<_> = msgs.into_iter().map(|msg| msg.msg).collect();
assert_eq!(received.clone().into_vec_without_me(), msgs);
let received = received.into_vec_including_me(Msg(13));
assert_eq!(received[0..3], msgs[0..3]);
assert_eq!(received[3], Msg(13));
assert_eq!(received[4..5], msgs[3..4]);
}
#[test]
fn store_returns_error_if_sender_index_is_out_of_range() {
let mut store = RoundInput::new(3, 5, MessageType::P2P);
let error = store
.add_message(Incoming {
id: 0,
sender: 5,
msg_type: MessageType::P2P,
msg: Msg(123),
})
.unwrap_err();
assert_matches!(
error,
RoundInputError::SenderIndexOutOfRange { msg_id, sender, n } if msg_id == 0 && sender == 5 && n == 5
);
}
#[test]
fn store_returns_error_if_incoming_msg_overwrites_already_received_one() {
let mut store = RoundInput::new(0, 3, MessageType::P2P);
store
.add_message(Incoming {
id: 0,
sender: 1,
msg_type: MessageType::P2P,
msg: Msg(11),
})
.unwrap();
let error = store
.add_message(Incoming {
id: 1,
sender: 1,
msg_type: MessageType::P2P,
msg: Msg(112),
})
.unwrap_err();
assert_matches!(error, RoundInputError::AttemptToOverwriteReceivedMsg { msgs_ids, sender } if msgs_ids[0] == 0 && msgs_ids[1] == 1 && sender == 1);
store
.add_message(Incoming {
id: 2,
sender: 2,
msg_type: MessageType::P2P,
msg: Msg(22),
})
.unwrap();
let output = store.output().unwrap().into_vec_without_me();
assert_eq!(output, [Msg(11), Msg(22)]);
}
#[test]
fn store_returns_error_if_tried_to_output_before_receiving_enough_messages() {
let mut store = RoundInput::<Msg>::new(3, 5, MessageType::P2P);
let msgs = (0..5)
.map(|s| Incoming {
id: s.into(),
sender: s,
msg_type: MessageType::P2P,
msg: Msg(10 + s),
})
.filter(|incoming| incoming.sender != 3);
for msg in msgs {
assert!(store.wants_more());
store = store.output().unwrap_err();
store.add_message(msg).unwrap();
}
let _ = store.output().unwrap();
}
#[test]
fn store_returns_error_if_message_type_mismatched() {
let mut store = RoundInput::<Msg>::p2p(3, 5);
let err = store
.add_message(Incoming {
id: 0,
sender: 0,
msg_type: MessageType::Broadcast,
msg: Msg(1),
})
.unwrap_err();
assert_matches!(
err,
RoundInputError::MismatchedMessageType {
msg_id: 0,
expected: MessageType::P2P,
actual: MessageType::Broadcast
}
);
let mut store = RoundInput::<Msg>::broadcast(3, 5);
let err = store
.add_message(Incoming {
id: 0,
sender: 0,
msg_type: MessageType::P2P,
msg: Msg(1),
})
.unwrap_err();
assert_matches!(
err,
RoundInputError::MismatchedMessageType {
msg_id: 0,
expected: MessageType::Broadcast,
actual: MessageType::P2P,
}
);
for sender in 0u16..5 {
store
.add_message(Incoming {
id: 0,
sender,
msg_type: MessageType::Broadcast,
msg: Msg(1),
})
.unwrap();
}
let mut store = RoundInput::<Msg>::broadcast(3, 5);
let err = store
.add_message(Incoming {
id: 0,
sender: 0,
msg_type: MessageType::P2P,
msg: Msg(1),
})
.unwrap_err();
assert_matches!(
err,
RoundInputError::MismatchedMessageType {
msg_id: 0,
expected: MessageType::Broadcast,
actual,
} if actual == MessageType::P2P
);
store
.add_message(Incoming {
id: 0,
sender: 0,
msg_type: MessageType::Broadcast,
msg: Msg(1),
})
.unwrap();
}
#[test]
fn into_iter_including_me() {
let me = -10_isize;
let messages = alloc::vec![1, 2, 3];
let me_first = super::RoundMsgs {
i: 0,
ids: alloc::vec![1, 2, 3],
messages: messages.clone(),
};
let all = me_first.into_iter_including_me(me).collect::<Vec<_>>();
assert_eq!(all, [-10, 1, 2, 3]);
let me_second = super::RoundMsgs {
i: 1,
ids: alloc::vec![0, 2, 3],
messages: messages.clone(),
};
let all = me_second.into_iter_including_me(me).collect::<Vec<_>>();
assert_eq!(all, [1, -10, 2, 3]);
let me_last = super::RoundMsgs {
i: 3,
ids: alloc::vec![0, 1, 2],
messages: messages.clone(),
};
let all = me_last.into_iter_including_me(me).collect::<Vec<_>>();
assert_eq!(all, [1, 2, 3, -10]);
}
}