use std::{
collections::{BTreeMap, VecDeque},
time::Duration,
};
use bytes::Bytes;
use crate::{
channel::SliceConstructor,
error::ChannelError,
packet::{Packet, Slice, SLICE_SIZE},
};
/// Sending half of an unreliable channel.
///
/// Messages are queued by `send_message` and drained into packets by
/// `get_packets_to_send`; there is no retransmission or acknowledgement.
#[derive(Debug)]
pub struct SendChannelUnreliable {
/// Id stamped on every packet this channel produces.
channel_id: u8,
/// Messages queued by `send_message`, waiting for the next `get_packets_to_send`.
unreliable_messages: VecDeque<Bytes>,
/// `message_id` given to the next message that has to be sliced; incremented once per sliced message.
sliced_message_id: u64,
/// Upper bound on the total length of queued messages; `send_message` drops anything that would exceed it.
max_memory_usage_bytes: usize,
/// Sum of the lengths of the messages currently in `unreliable_messages`.
memory_usage_bytes: usize,
}
/// Receiving half of an unreliable channel.
///
/// Whole messages and reassembled sliced messages are queued in `messages`
/// until the application pulls them with `receive_message`.
#[derive(Debug)]
pub struct ReceiveChannelUnreliable {
/// Channel id; only used in log messages here.
channel_id: u8,
/// Complete messages waiting to be returned by `receive_message`.
messages: VecDeque<Bytes>,
/// In-progress reassembly buffers, keyed by the sender's `message_id`.
slices: BTreeMap<u64, SliceConstructor>,
/// Time the most recent slice arrived for each incomplete message; used to discard stale reassemblies.
slices_last_received: BTreeMap<u64, Duration>,
/// Memory budget in bytes; messages/slices that would exceed it are dropped.
max_memory_usage_bytes: usize,
/// Bytes of queued messages plus `num_slices * SLICE_SIZE` reserved for each in-progress reassembly.
memory_usage_bytes: usize,
}
impl SendChannelUnreliable {
    /// Creates an empty send channel with the given id and queued-bytes budget.
    pub fn new(channel_id: u8, max_memory_usage_bytes: usize) -> Self {
        Self {
            channel_id,
            unreliable_messages: VecDeque::new(),
            sliced_message_id: 0,
            max_memory_usage_bytes,
            memory_usage_bytes: 0,
        }
    }

    /// Returns `true` if a message of `size_bytes` fits in the remaining memory budget.
    pub fn can_send_message(&self, size_bytes: usize) -> bool {
        // Compare against the remaining budget rather than adding to the current
        // usage, so an arbitrarily large `size_bytes` cannot overflow.
        size_bytes <= self.available_memory()
    }

    /// Returns how many more bytes can be queued before messages start being dropped.
    pub fn available_memory(&self) -> usize {
        // `send_message` keeps usage <= max, so this never actually saturates.
        self.max_memory_usage_bytes.saturating_sub(self.memory_usage_bytes)
    }

    /// Drains every queued message into packets.
    ///
    /// Messages longer than `SLICE_SIZE` are split into `Packet::UnreliableSlice`
    /// packets that share one `message_id`. Shorter messages are batched into
    /// `Packet::SmallUnreliable` packets; a new batch is started whenever adding the
    /// next message (plus its varint length prefix) would push the batch past
    /// `SLICE_SIZE`. Every emitted packet takes the current `packet_sequence`, which
    /// is then incremented.
    ///
    /// Each sent message's length is deducted from `available_bytes`. A message
    /// longer than the remaining `available_bytes` is dropped, not re-queued.
    pub fn get_packets_to_send(&mut self, packet_sequence: &mut u64, available_bytes: &mut u64) -> Vec<Packet> {
        let mut packets: Vec<Packet> = vec![];
        let mut small_messages: Vec<Bytes> = vec![];
        let mut small_messages_bytes = 0;

        while let Some(message) = self.unreliable_messages.pop_front() {
            self.memory_usage_bytes -= message.len();

            if *available_bytes < message.len() as u64 {
                // No bandwidth left for this message: drop it (unreliable semantics).
                continue;
            }
            *available_bytes -= message.len() as u64;

            if message.len() > SLICE_SIZE {
                self.push_sliced_message(&mut packets, packet_sequence, &message);
                continue;
            }

            let serialized_size = message.len() + octets::varint_len(message.len() as u64);
            // Close the current batch if this message would overflow it. The emptiness
            // check prevents emitting an empty packet when a single message is already
            // close to SLICE_SIZE and its length prefix tips it over the limit.
            if !small_messages.is_empty() && small_messages_bytes + serialized_size > SLICE_SIZE {
                self.push_small_messages(&mut packets, packet_sequence, std::mem::take(&mut small_messages));
                small_messages_bytes = 0;
            }
            small_messages_bytes += serialized_size;
            small_messages.push(message);
        }

        if !small_messages.is_empty() {
            self.push_small_messages(&mut packets, packet_sequence, small_messages);
        }

        packets
    }

    /// Splits `message` into `SLICE_SIZE` chunks, appending one `UnreliableSlice`
    /// packet per chunk, and advances `sliced_message_id`.
    fn push_sliced_message(&mut self, packets: &mut Vec<Packet>, packet_sequence: &mut u64, message: &Bytes) {
        let num_slices = message.len().div_ceil(SLICE_SIZE);
        for slice_index in 0..num_slices {
            let start = slice_index * SLICE_SIZE;
            // Every chunk is SLICE_SIZE long except the last, which takes the remainder.
            let end = (start + SLICE_SIZE).min(message.len());
            let slice = Slice {
                message_id: self.sliced_message_id,
                slice_index,
                num_slices,
                payload: message.slice(start..end),
            };
            packets.push(Packet::UnreliableSlice {
                sequence: *packet_sequence,
                channel_id: self.channel_id,
                slice,
            });
            *packet_sequence += 1;
        }
        self.sliced_message_id += 1;
    }

    /// Appends one `SmallUnreliable` packet carrying `messages` and advances the sequence.
    fn push_small_messages(&self, packets: &mut Vec<Packet>, packet_sequence: &mut u64, messages: Vec<Bytes>) {
        packets.push(Packet::SmallUnreliable {
            sequence: *packet_sequence,
            channel_id: self.channel_id,
            messages,
        });
        *packet_sequence += 1;
    }

    /// Queues `message` for the next `get_packets_to_send`.
    ///
    /// The message is dropped, with a warning, if it does not fit in the remaining
    /// memory budget. A warning is also logged when the message will be split into
    /// more than 20 slices, since losing any one slice loses the whole message.
    pub fn send_message(&mut self, message: Bytes) {
        if !self.can_send_message(message.len()) {
            log::warn!(
                "dropped unreliable message sent because channel {} is memory limited",
                self.channel_id
            );
            return;
        }

        // Same rounding as `push_sliced_message`, so the logged count is the real slice count.
        let num_fragments = message.len().div_ceil(SLICE_SIZE);
        if num_fragments > 20 {
            log::warn!(
                "Sending an unreliable message with {num_fragments} fragments, messages with this many fragments are susceptible to packet loss. \
                Consider breaking your message into smaller ones or using a reliable channel");
        }

        self.memory_usage_bytes += message.len();
        self.unreliable_messages.push_back(message);
    }
}
impl ReceiveChannelUnreliable {
    /// Creates an empty receive channel with the given id and memory budget (in bytes).
    pub fn new(channel_id: u8, max_memory_usage_bytes: usize) -> Self {
        Self {
            channel_id,
            slices: BTreeMap::new(),
            slices_last_received: BTreeMap::new(),
            messages: VecDeque::new(),
            memory_usage_bytes: 0,
            max_memory_usage_bytes,
        }
    }

    /// Queues a complete (unsliced) message for `receive_message`.
    ///
    /// The message is dropped, with a warning, if queueing it would exceed the memory budget.
    pub fn process_message(&mut self, message: Bytes) {
        if self.memory_usage_bytes + message.len() > self.max_memory_usage_bytes {
            log::warn!(
                "dropped unreliable message received because channel {} is memory limited",
                self.channel_id
            );
            return;
        }
        self.memory_usage_bytes += message.len();
        self.messages.push_back(message);
    }

    /// Feeds one slice of a sliced message into its reassembly buffer.
    ///
    /// The first slice seen for a `message_id` reserves `num_slices * SLICE_SIZE`
    /// bytes of the budget. If that does not fit, the slice is dropped with a warning
    /// and `Ok(())` is returned. When the reassembly completes, that reservation is
    /// swapped for the reassembled message's length and the message is queued for
    /// `receive_message`.
    ///
    /// # Errors
    ///
    /// Propagates the `ChannelError` returned by `SliceConstructor::process_slice`.
    /// If the failing slice is the one that created the reassembly buffer, the buffer
    /// and its reservation are released before returning.
    pub fn process_slice(&mut self, slice: Slice, current_time: Duration) -> Result<(), ChannelError> {
        let is_new = !self.slices.contains_key(&slice.message_id);
        if is_new {
            // `num_slices` comes from the peer: saturate instead of overflowing so an
            // absurd value is simply rejected by the budget check.
            let message_len = slice.num_slices.saturating_mul(SLICE_SIZE);
            if self.memory_usage_bytes.saturating_add(message_len) > self.max_memory_usage_bytes {
                log::warn!(
                    "dropped unreliable slice message received because channel {} is memory limited",
                    self.channel_id
                );
                return Ok(());
            }
            self.memory_usage_bytes += message_len;
        }

        let slice_constructor = self
            .slices
            .entry(slice.message_id)
            .or_insert_with(|| SliceConstructor::new(slice.message_id, slice.num_slices));
        // Release exactly what was reserved when this buffer was created. Later slices
        // may carry a different (peer-controlled) `num_slices`, and using theirs could
        // underflow `memory_usage_bytes`.
        let reserved_bytes = slice_constructor.num_slices * SLICE_SIZE;

        let reassembled = match slice_constructor.process_slice(slice.slice_index, &slice.payload) {
            Ok(reassembled) => reassembled,
            Err(error) => {
                if is_new {
                    // A buffer created by this call is not yet tracked in
                    // `slices_last_received`, so it would never be discarded: roll it back.
                    self.slices.remove(&slice.message_id);
                    self.memory_usage_bytes -= reserved_bytes;
                }
                return Err(error);
            }
        };

        if let Some(message) = reassembled {
            self.slices.remove(&slice.message_id);
            self.slices_last_received.remove(&slice.message_id);
            self.memory_usage_bytes -= reserved_bytes;
            self.memory_usage_bytes += message.len();
            self.messages.push_back(message);
        } else {
            self.slices_last_received.insert(slice.message_id, current_time);
        }

        Ok(())
    }

    /// Drops every incomplete reassembly whose last slice arrived at least 3 seconds
    /// before `current_time`, releasing its reserved memory.
    pub fn discard_incomplete_old_slices(&mut self, current_time: Duration) {
        const DISCARD_AFTER: Duration = Duration::from_secs(3);

        // `slices_last_received` is ordered by message id, not by time, so every entry
        // must be checked; stopping at the first fresh one could leave stale buffers.
        // `saturating_sub` avoids a panic if `current_time` is earlier than a timestamp.
        let lost_messages: Vec<u64> = self
            .slices_last_received
            .iter()
            .filter(|&(_, &last_received)| current_time.saturating_sub(last_received) >= DISCARD_AFTER)
            .map(|(&message_id, _)| message_id)
            .collect();

        for message_id in lost_messages {
            self.slices_last_received.remove(&message_id);
            let slice = self.slices.remove(&message_id).expect("discarded slice should exist");
            self.memory_usage_bytes -= slice.num_slices * SLICE_SIZE;
        }
    }

    /// Pops the oldest complete message, if any, and releases its memory.
    pub fn receive_message(&mut self) -> Option<Bytes> {
        let message = self.messages.pop_front()?;
        self.memory_usage_bytes -= message.len();
        Some(message)
    }
}
#[cfg(test)]
mod tests {
    use octets::OctetsMut;

    use super::*;

    /// Two small messages round-trip through SmallUnreliable packets in order.
    #[test]
    fn small_packet() {
        let max_memory: usize = 10000;
        let mut available_bytes = u64::MAX;
        let mut sequence: u64 = 0;
        let mut recv = ReceiveChannelUnreliable::new(0, max_memory);
        let mut send = SendChannelUnreliable::new(0, max_memory);

        let message1 = vec![1, 2, 3];
        let message2 = vec![3, 4, 5];
        send.send_message(message1.clone().into());
        send.send_message(message2.clone().into());

        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        for packet in packets {
            let Packet::SmallUnreliable { messages, .. } = packet else {
                unreachable!();
            };
            for message in messages {
                recv.process_message(message);
            }
        }

        let new_message1 = recv.receive_message().unwrap();
        let new_message2 = recv.receive_message().unwrap();
        assert!(recv.receive_message().is_none());
        assert_eq!(message1, new_message1);
        assert_eq!(message2, new_message2);

        // The first call drained the queue.
        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert!(packets.is_empty());
    }

    /// A message spanning three slices is split and reassembled intact.
    #[test]
    fn slice_packet() {
        let max_memory: usize = 10000;
        let mut available_bytes = u64::MAX;
        let mut sequence: u64 = 0;
        let current_time = Duration::ZERO;
        let mut recv = ReceiveChannelUnreliable::new(0, max_memory);
        let mut send = SendChannelUnreliable::new(0, max_memory);

        let message = vec![5; SLICE_SIZE * 3];
        send.send_message(message.clone().into());

        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        for packet in packets {
            let Packet::UnreliableSlice { slice, .. } = packet else {
                unreachable!();
            };
            recv.process_slice(slice, current_time).unwrap();
        }

        let new_message = recv.receive_message().unwrap();
        assert!(recv.receive_message().is_none());
        assert_eq!(message, new_message);

        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert!(packets.is_empty());
    }

    /// Both halves drop messages that would exceed their memory budget.
    #[test]
    fn max_memory() {
        let mut sequence: u64 = 0;
        let mut available_bytes = u64::MAX;
        let mut recv = ReceiveChannelUnreliable::new(0, 50);
        let mut send = SendChannelUnreliable::new(0, 40);

        // Sender: each 50-byte message exceeds the 40-byte budget, so both are
        // dropped at send time and nothing is packetized.
        let message = vec![5; 50];
        send.send_message(message.clone().into());
        send.send_message(message.into());
        assert!(send.can_send_message(40));
        assert!(!send.can_send_message(41));
        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert!(packets.is_empty());
        assert_eq!(sequence, 0);

        // Receiver: the 50-byte budget fits one 30-byte message but not a second.
        let message: Bytes = vec![7; 30].into();
        recv.process_message(message.clone());
        recv.process_message(message.clone());
        assert_eq!(recv.receive_message(), Some(message));
        assert!(recv.receive_message().is_none());
    }

    /// Messages that exceed the remaining `available_bytes` are dropped, not retried.
    #[test]
    fn available_bytes() {
        let mut sequence: u64 = 0;
        let mut send = SendChannelUnreliable::new(0, usize::MAX);
        let message: Bytes = vec![0u8; 100].into();

        // 100-byte message with only 50 bytes available: dropped.
        send.send_message(message.clone());
        let mut available_bytes: u64 = 50;
        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert_eq!(packets.len(), 0);

        // The dropped message is gone even with unlimited bandwidth.
        let mut available_bytes: u64 = u64::MAX;
        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert_eq!(packets.len(), 0);

        // Budget for exactly one of two messages: the first is sent, the second dropped.
        send.send_message(message.clone());
        send.send_message(message);
        let mut available_bytes: u64 = 100;
        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert_eq!(packets.len(), 1);

        let mut available_bytes: u64 = u64::MAX;
        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert_eq!(packets.len(), 0);
    }

    /// Many tiny messages are split across two batches, each serializing under 1300 bytes.
    #[test]
    fn small_packet_max_size() {
        let mut sequence: u64 = 0;
        let mut available_bytes = u64::MAX;
        let mut send = SendChannelUnreliable::new(0, usize::MAX);
        let message: Bytes = vec![0, 1, 2, 3].into();
        for _ in 0..400 {
            send.send_message(message.clone());
        }

        let packets = send.get_packets_to_send(&mut sequence, &mut available_bytes);
        assert_eq!(packets.len(), 2);

        let mut buffer = [0u8; 1400];
        for packet in packets {
            let mut oct = OctetsMut::with_slice(&mut buffer);
            let len = packet.to_bytes(&mut oct).unwrap();
            assert!(len < 1300);
        }
    }
}