use crate::*;
/// Bookkeeping record for one message carried in an outgoing packet.
///
/// A handle is stored per message so that, when the packet is acked, the
/// owning channel and (for fragments) the parent message can be notified.
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct MessageHandle {
    /// Id of this message (for a fragment, the fragment's own id).
    pub(crate) id: MessageId,
    /// `Some(index)` when this message is fragment number `index` of a larger
    /// message, `None` for an unfragmented message.
    pub(crate) frag_index: Option<u16>,
    /// Channel the message was enqueued on.
    pub(crate) channel: u8,
}
impl MessageHandle {
    /// Returns the id of this message (the fragment's own id for fragments).
    pub fn id(&self) -> MessageId {
        self.id
    }

    /// For a fragment, returns the id of the message it belongs to; `None` for
    /// an unfragmented message.
    ///
    /// Fragments are given consecutive ids starting at the parent id, so the
    /// parent is recovered as `id - frag_index` (with `u16` wrap-around).
    pub fn parent_id(&self) -> Option<MessageId> {
        let offset = self.frag_index?;
        Some(MessageId(self.id.0.wrapping_sub(offset)))
    }
}
/// Assigns message ids, fragments large payloads, tracks which messages rode
/// in which packet, and collects received messages and message acks per channel.
pub(crate) struct MessageDispatcher {
    /// Id handed out by the next call that allocates a message id.
    next_message_id: u16,
    /// Ack-tracking state for fragmented messages, keyed by parent id.
    sent_frag_map: SentFragMap,
    /// Message handles carried by each packet, keyed by packet id.
    messages_in_packets: SequenceBuffer<Vec<MessageHandle>>,
    /// Collects incoming fragments into complete messages.
    message_reassembler: MessageReassembler,
    /// Fully received messages waiting to be drained, per channel.
    message_inbox: smallmap::Map<u8, Vec<ReceivedMessage>>,
    /// Ids of messages whose delivery was acked, waiting to be drained, per channel.
    ack_inbox: smallmap::Map<u8, Vec<MessageId>>,
}
impl MessageDispatcher {
    /// Largest payload (in bytes) sent as a single, unfragmented message, and
    /// the size of every fragment except possibly the last of a larger payload.
    const FRAGMENT_SIZE: usize = 1024;

    /// Creates a dispatcher whose buffers are sized from `config`.
    pub(crate) fn new(config: &PicklebackConfig) -> Self {
        Self {
            next_message_id: 0,
            sent_frag_map: SentFragMap::with_capacity(config.sent_frag_map_size),
            // NOTE(review): this buffer is keyed by the ids of packets we send and later see
            // acked, yet it is sized from `received_packets_buffer_size` — confirm this is the
            // intended config field.
            messages_in_packets: SequenceBuffer::with_capacity(config.received_packets_buffer_size),
            message_reassembler: MessageReassembler::default(),
            message_inbox: smallmap::Map::default(),
            ack_inbox: smallmap::Map::default(),
        }
    }

    /// Handles a message read off the wire.
    ///
    /// Unfragmented messages go straight to their channel's inbox. Fragments are
    /// handed to the reassembler; a message is only pushed to the inbox when the
    /// reassembler returns one (presumably once all fragments have arrived).
    pub(crate) fn process_received_message(&mut self, msg: Message) {
        trace!("Dispatcher::process_received_message: {msg:?}");
        let received_msg = if msg.fragment().is_none() {
            Some(ReceivedMessage::new_single(msg))
        } else {
            self.message_reassembler.add_fragment(msg)
        };
        if let Some(msg) = received_msg {
            trace!("✅ Adding msg to inbox");
            self.message_inbox
                .entry(msg.channel())
                .or_default()
                .push(msg);
        }
    }

    /// Removes and yields every received message queued for `channel`.
    ///
    /// An empty inbox is created for `channel` if it has none yet.
    pub(crate) fn drain_received_messages(
        &mut self,
        channel: u8,
    ) -> std::vec::Drain<'_, ReceivedMessage> {
        self.message_inbox.entry(channel).or_default().drain(..)
    }

    /// Removes and yields the ids of every message on `channel` whose delivery
    /// has been acked. For fragmented messages only the parent id is reported,
    /// and only once all of its fragments are acked.
    ///
    /// An empty ack inbox is created for `channel` if it has none yet.
    pub(crate) fn drain_message_acks(&mut self, channel: u8) -> std::vec::Drain<'_, MessageId> {
        self.ack_inbox.entry(channel).or_default().drain(..)
    }

    /// Records which messages were written into packet `packet_handle`, so they
    /// can be acked when that packet is acked (see [`Self::acked_packet`]).
    ///
    /// # Errors
    /// Propagates the error from inserting into the packet sequence buffer.
    pub(crate) fn set_packet_message_handles(
        &mut self,
        packet_handle: PacketId,
        message_handles: Vec<MessageHandle>,
    ) -> Result<(), PicklebackError> {
        trace!(">>> {packet_handle:?} CONTAINS msg ids: {message_handles:?}");
        self.messages_in_packets
            .insert(packet_handle.0, message_handles)?;
        Ok(())
    }

    /// Processes the ack of packet `packet_handle`.
    ///
    /// Every message recorded for that packet is reported to its channel. An
    /// unfragmented message's id is pushed to the ack inbox immediately; a
    /// fragment's parent id is pushed only when it was the last outstanding
    /// fragment of that parent. Unknown packet ids are ignored.
    ///
    /// # Panics
    /// If a recorded message refers to a channel missing from `channel_list`.
    pub(crate) fn acked_packet(&mut self, packet_handle: PacketId, channel_list: &mut ChannelList) {
        let Some(msg_handles) = self.messages_in_packets.remove(packet_handle.0) else {
            return;
        };
        trace!("Acked packet: {packet_handle:?} --> acked msgs: {msg_handles:?}");
        for msg_handle in &msg_handles {
            channel_list
                .get_mut(msg_handle.channel)
                .expect("acked message refers to a channel missing from the channel list")
                .message_ack_received(msg_handle);
            // Id to report to the application, if this ack completes a message.
            let acked_id = match msg_handle.parent_id() {
                None => Some(msg_handle.id()),
                Some(parent_id) => {
                    if self
                        .sent_frag_map
                        .ack_fragment_message(parent_id, msg_handle.id())
                    {
                        Some(parent_id)
                    } else {
                        trace!("got fragment ack for parent {parent_id:?}, but not all yet {msg_handle:?} ");
                        None
                    }
                }
            };
            if let Some(id) = acked_id {
                self.ack_inbox
                    .entry(msg_handle.channel)
                    .or_default()
                    .push(id);
            }
        }
    }

    /// Enqueues `payload` on `channel` and returns the id the application will
    /// see in its acks.
    ///
    /// Payloads larger than [`Self::FRAGMENT_SIZE`] bytes are fragmented; the
    /// returned id is then the parent (first fragment) id.
    ///
    /// # Errors
    /// Propagates the error from recording a fragmented message in the sent
    /// fragment map.
    pub(crate) fn add_message_to_channel(
        &mut self,
        pool: &BufPool,
        channel: &mut Channel,
        payload: &[u8],
    ) -> Result<MessageId, PicklebackError> {
        if payload.len() <= Self::FRAGMENT_SIZE {
            let id = self.next_message_id();
            channel.enqueue_message(pool, id, payload, Fragmented::No);
            Ok(id)
        } else {
            self.add_large_message_to_channel(pool, channel, payload)
        }
    }

    /// Splits `payload` into fragments of at most [`Self::FRAGMENT_SIZE`] bytes
    /// and enqueues each one on `channel`.
    ///
    /// Fragments receive consecutive message ids starting at the parent id (the
    /// first fragment's id), which is what lets [`MessageHandle::parent_id`]
    /// recover the parent as `id - index`. All fragment ids are registered in
    /// the sent fragment map so the parent can be reported once all are acked.
    ///
    /// Sizes are computed in `usize`; the previous `u16` arithmetic overflowed
    /// for payloads of 64 KiB or more.
    ///
    /// # Panics
    /// If `payload` is not larger than [`Self::FRAGMENT_SIZE`], or needs more
    /// than `u16::MAX` fragments.
    ///
    /// # Errors
    /// Propagates the error from inserting into the sent fragment map.
    fn add_large_message_to_channel(
        &mut self,
        pool: &BufPool,
        channel: &mut Channel,
        payload: &[u8],
    ) -> Result<MessageId, PicklebackError> {
        assert!(payload.len() > Self::FRAGMENT_SIZE);
        let chunks = payload.chunks(Self::FRAGMENT_SIZE);
        let fragment_count = chunks.len();
        assert!(
            fragment_count <= u16::MAX as usize,
            "payload of {} bytes needs {fragment_count} fragments, more than u16::MAX",
            payload.len()
        );
        // Cannot truncate: bounded by the assert above.
        let num_fragments = fragment_count as u16;
        let mut frag_ids = Vec::with_capacity(fragment_count);
        let parent_id = self.next_message_id();
        for (index, frag_payload) in (0..num_fragments).zip(chunks) {
            let id = if index == 0 {
                parent_id
            } else {
                self.next_message_id()
            };
            frag_ids.push(id);
            trace!("Adding frag msg {id:?} frag:{index}/{num_fragments}");
            let fragment = Fragment {
                index,
                num_fragments,
                parent_id,
            };
            channel.enqueue_message(pool, id, frag_payload, Fragmented::Yes(fragment));
        }
        self.sent_frag_map
            .insert_fragmented_message(parent_id, frag_ids)?;
        Ok(parent_id)
    }

    /// Allocates the next message id; ids wrap around after `u16::MAX`.
    fn next_message_id(&mut self) -> MessageId {
        let ret = self.next_message_id;
        self.next_message_id = self.next_message_id.wrapping_add(1);
        MessageId(ret)
    }
}
/// Ack progress of a fragmented message that has been sent.
#[derive(Clone, PartialEq, Default)]
pub(crate) enum FragAckStatus {
    /// No fragmented message is recorded for this slot.
    #[default]
    Unknown,
    /// Every fragment has been acked.
    Complete,
    /// Ids of the fragments that have not been acked yet.
    Partial(Vec<MessageId>),
}
/// Tracks, per parent message id, which fragments of a sent fragmented
/// message are still waiting to be acked.
pub struct SentFragMap {
    /// Ack status keyed by the parent message id's raw value.
    m: SequenceBuffer<FragAckStatus>,
}
impl SentFragMap {
    /// Creates a map able to track `size` fragmented messages at once.
    pub(crate) fn with_capacity(size: usize) -> Self {
        Self {
            m: SequenceBuffer::with_capacity(size),
        }
    }

    /// Records that message `id` was sent as the fragments `fragment_ids`, all
    /// of which are initially unacked.
    ///
    /// # Errors
    /// Propagates the error from inserting into the underlying sequence buffer.
    pub(crate) fn insert_fragmented_message(
        &mut self,
        id: MessageId,
        fragment_ids: Vec<MessageId>,
    ) -> Result<(), PicklebackError> {
        self.m
            .insert(id.0, FragAckStatus::Partial(fragment_ids))
            .map(|_| ())
    }

    /// Marks fragment `fragment_id` of message `parent_id` as acked.
    ///
    /// Returns `true` exactly once per parent: when this ack removes the last
    /// outstanding fragment, at which point the entry becomes `Complete`.
    /// Returns `false` for untracked parents, already-completed parents, and
    /// acks that leave other fragments outstanding.
    pub fn ack_fragment_message(&mut self, parent_id: MessageId, fragment_id: MessageId) -> bool {
        let Some(entry) = self.m.get_mut(parent_id.0) else {
            return false;
        };
        let all_acked = match &mut *entry {
            FragAckStatus::Complete => {
                trace!("Message {parent_id:?} already completed arrived.");
                false
            }
            FragAckStatus::Unknown => {
                warn!("Message {parent_id:?} unknown to frag map");
                false
            }
            FragAckStatus::Partial(remaining) => {
                remaining.retain(|id| *id != fragment_id);
                // Per-fragment diagnostics: trace level, like the rest of the ack path.
                trace!("Remaining fragment ids for parent {parent_id:?}, fragment_id={fragment_id:?} = {remaining:?}");
                remaining.is_empty()
            }
        };
        if all_acked {
            // Update the slot in place instead of re-inserting, which needed a second
            // lookup and an `unwrap`.
            *entry = FragAckStatus::Complete;
            trace!("Message fully acked, all fragments accounted for {parent_id:?}");
        }
        all_acked
    }
}