use std::{borrow::Borrow, time::Duration};
use crate::{
message::{MessageTag, MsgHdr, MsgId},
relay::{MessageSendError, Relay},
};
use super::traits::ProtocolParticipant;
/// How the messages expected in a round are addressed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoundMode {
    /// Messages carry no specific receiver.
    Broadcast,
    /// Messages are addressed to a specific receiver (this participant).
    P2P,
}

impl RoundMode {
    /// Receiver index to use when deriving the expected message IDs.
    ///
    /// Yields `None` for [`RoundMode::Broadcast`] and `Some(me)` for
    /// [`RoundMode::P2P`].
    fn receiver(self, me: usize) -> Option<usize> {
        let addressed_to_me = match self {
            RoundMode::P2P => true,
            RoundMode::Broadcast => false,
        };

        if addressed_to_me {
            Some(me)
        } else {
            None
        }
    }
}
/// Bookkeeping for the messages a protocol round is still waiting for.
#[derive(Debug, Clone, Default)]
pub struct MessageRound {
    /// Messages not yet received, kept sorted by raw message-ID bytes so they
    /// can be looked up with a binary search.
    pending: Vec<PendingMessage>,
    /// Time-to-live handed to the relay when asking for pending messages.
    ttl: Duration,
}
/// A message that is expected in the current round but has not arrived yet.
#[derive(Debug, Clone)]
struct PendingMessage {
    /// ID of the expected message.
    id: MsgId,
    /// Participant index of the party expected to send it.
    sender: usize,
}
impl MessageRound {
    /// Creates a round expecting one broadcast message with `tag` from every
    /// other party. Equivalent to [`MessageRound::broadcast`].
    pub fn new<P: ProtocolParticipant>(setup: &P, tag: MessageTag) -> Self {
        Self::broadcast(setup, tag)
    }

    /// Creates a round expecting one broadcast message (no specific receiver)
    /// with `tag` from each party in `setup.all_other_parties()`.
    pub fn broadcast<P: ProtocolParticipant>(
        setup: &P,
        tag: MessageTag,
    ) -> Self {
        Self::from_parties(
            setup,
            tag,
            setup.all_other_parties(),
            RoundMode::Broadcast,
        )
    }

    /// Creates a round expecting one point-to-point message with `tag`,
    /// addressed to this participant, from each party in
    /// `setup.all_other_parties()`.
    pub fn p2p<P: ProtocolParticipant>(setup: &P, tag: MessageTag) -> Self {
        Self::from_parties(
            setup,
            tag,
            setup.all_other_parties(),
            RoundMode::P2P,
        )
    }

    /// Creates a round expecting one message with `tag` from every party
    /// index yielded by `parties`.
    ///
    /// * Our own index (`setup.participant_index()`) is skipped, so callers
    ///   may pass a list that includes it.
    /// * Each expected ID is `setup.msg_id_from(sender, receiver, tag)`. The
    ///   `receiver` comes from `mode`: `None` for [`RoundMode::Broadcast`],
    ///   our own index for [`RoundMode::P2P`].
    /// * The pending list is sorted by raw ID bytes so the lookup helpers can
    ///   binary-search it. Entries with identical IDs are collapsed into one.
    /// * The TTL used by [`MessageRound::ask_pending`] is `setup.message_ttl()`.
    pub fn from_parties<P, I, T>(
        setup: &P,
        tag: MessageTag,
        parties: I,
        mode: RoundMode,
    ) -> Self
    where
        P: ProtocolParticipant,
        I: IntoIterator<Item = T>,
        T: Borrow<usize>,
    {
        let my_party_index = setup.participant_index();
        let receiver = mode.receiver(my_party_index);

        let mut pending = parties
            .into_iter()
            .map(|sender_index| *sender_index.borrow())
            .filter(|sender_index| *sender_index != my_party_index)
            .map(|sender| PendingMessage {
                id: setup.msg_id_from(sender, receiver, tag),
                sender,
            })
            .collect::<Vec<_>>();

        // This ordering must match the key used by `pending_position`.
        pending.sort_unstable_by(|a, b| a.id.as_slice().cmp(b.id.as_slice()));
        // `dedup_by` only drops *consecutive* duplicates, hence the sort first.
        pending.dedup_by(|a, b| a.id == b.id);

        Self {
            pending,
            ttl: setup.message_ttl(),
        }
    }

    /// Binary-searches the sorted pending list for `id`.
    ///
    /// Returns `Ok(index)` when `id` is pending. Otherwise it returns
    /// `Err(insertion_point)`, following the `slice::binary_search_by`
    /// convention.
    fn pending_position(&self, id: &MsgId) -> Result<usize, usize> {
        self.pending.binary_search_by(|pending| {
            pending.id.as_slice().cmp(id.as_slice())
        })
    }

    /// Returns `true` if a message with `id` is still expected.
    pub fn is_pending(&self, id: &MsgId) -> bool {
        self.pending_position(id).is_ok()
    }

    /// Returns the expected sender of the pending message `id`, without
    /// marking it as received. Returns `None` if `id` is not pending.
    pub fn pending_sender(&self, id: &MsgId) -> Option<usize> {
        self.pending_position(id)
            .ok()
            .map(|idx| self.pending[idx].sender)
    }

    /// Like [`MessageRound::pending_sender`], but reads the ID from the
    /// `MsgHdr` at the start of the raw message `msg`.
    ///
    /// Returns `None` if `msg` cannot be parsed as a `MsgHdr`, or if its ID
    /// is not pending.
    pub fn pending_sender_message(&self, msg: &[u8]) -> Option<usize> {
        <&MsgHdr>::try_from(msg)
            .ok()
            .and_then(|hdr| self.pending_sender(hdr.id()))
    }

    /// Marks the message `id` as received. Returns `true` if it was pending.
    pub fn mark_received(&mut self, id: &MsgId) -> bool {
        self.mark_received_with_sender(id).is_some()
    }

    /// Removes `id` from the pending list and returns its expected sender.
    /// Returns `None` (and changes nothing) if `id` is not pending.
    pub fn mark_received_with_sender(&mut self, id: &MsgId) -> Option<usize> {
        // `Vec::remove` shifts later elements down and keeps them in order.
        // That preserves the sort invariant the binary search relies on.
        self.pending_position(id)
            .ok()
            .map(|idx| self.pending.remove(idx).sender)
    }

    /// Marks the raw message `msg` as received. Returns `true` if its ID was
    /// pending.
    pub fn mark_received_message(&mut self, msg: &[u8]) -> bool {
        self.mark_received_message_with_sender(msg).is_some()
    }

    /// Marks the raw message `msg` as received and returns its expected
    /// sender.
    ///
    /// The ID is taken from the `MsgHdr` at the start of `msg`, exactly as in
    /// [`MessageRound::pending_sender_message`]. Both helpers therefore
    /// accept and reject the same buffers.
    ///
    /// Returns `None` if `msg` cannot be parsed as a `MsgHdr`, or if its ID
    /// is not pending.
    pub fn mark_received_message_with_sender(
        &mut self,
        msg: &[u8],
    ) -> Option<usize> {
        <&MsgHdr>::try_from(msg)
            .ok()
            .and_then(|hdr| self.mark_received_with_sender(hdr.id()))
    }

    /// Number of messages still expected.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }

    /// Returns `true` once every expected message has been marked received.
    pub fn is_complete(&self) -> bool {
        self.pending.is_empty()
    }

    /// Calls `relay.ask(id, ttl)` for every still-pending message ID and
    /// returns how many IDs were asked for. IDs are visited in their sorted
    /// order, each call is awaited before the next, and `ttl` is this
    /// round's TTL.
    ///
    /// # Errors
    ///
    /// Returns the first [`MessageSendError`] reported by the relay. Asks
    /// issued before the failure have already been sent, and the remaining
    /// IDs are not asked for.
    pub async fn ask_pending<R: Relay>(
        &self,
        relay: &R,
    ) -> Result<usize, MessageSendError> {
        let count = self.pending.len();
        for pending in &self.pending {
            relay.ask(&pending.id, self.ttl).await?;
        }
        Ok(count)
    }
}