use std::collections::VecDeque;
use std::collections::{BTreeMap, HashSet};
use std::time::Duration;
use bytes::Bytes;
use crossbeam_channel::Receiver;
use tracing::{info, trace};
use crate::channel::builder::ReliableSettings;
use crate::channel::senders::fragment_sender::FragmentSender;
use crate::channel::senders::ChannelSend;
use crate::packet::message::{FragmentData, MessageAck, MessageId, SingleData};
use crate::shared::ping::manager::PingManager;
use crate::shared::tick_manager::TickManager;
use crate::shared::time_manager::{TimeManager, WrappedTime};
/// Delivery-tracking state for a single fragment of a fragmented reliable message.
pub struct FragmentAck {
    /// The fragment itself; a clone of it is queued each time it is (re)sent.
    data: FragmentData,
    /// Time of the most recent transmission, or `None` if it has never been queued.
    last_sent: Option<WrappedTime>,
    /// Set once a delivery notification for this fragment has been received.
    acked: bool,
}
/// A buffered reliable message that has not been fully acknowledged yet.
pub enum UnackedMessage {
    /// A message small enough to be sent in one piece.
    Single {
        /// Time of the most recent transmission, or `None` if it has never been queued.
        last_sent: Option<WrappedTime>,
        /// The message payload, cloned into a `SingleData` on every (re)send.
        bytes: Bytes,
    },
    /// A message split into fragments; each fragment is tracked, acked and resent on its own.
    Fragmented(Vec<FragmentAck>),
}
/// Sender half of a reliable channel: keeps every message until it is acknowledged and
/// re-queues it whenever its resend delay has elapsed.
pub struct ReliableSender {
    /// Settings used to derive the resend delay from the current RTT.
    reliable_settings: ReliableSettings,
    /// Cached RTT estimate, refreshed by `update`.
    current_rtt: Duration,
    /// Cached current time, refreshed by `update`.
    current_time: WrappedTime,
    /// Id that will be assigned to the next buffered message.
    next_send_message_id: MessageId,
    /// Messages that are still waiting for (full) acknowledgement, keyed by id.
    unacked_messages: BTreeMap<MessageId, UnackedMessage>,
    /// Splits oversized messages into fragments.
    fragment_sender: FragmentSender,
    /// Unfragmented messages queued for the next packet.
    single_messages_to_send: VecDeque<SingleData>,
    /// Fragments queued for the next packet.
    fragmented_messages_to_send: VecDeque<FragmentData>,
    /// Messages/fragments already queued for the next packet, used to avoid duplicates.
    message_ids_to_send: HashSet<MessageAck>,
}
impl ReliableSender {
pub fn new(reliable_settings: ReliableSettings) -> Self {
Self {
reliable_settings,
unacked_messages: Default::default(),
next_send_message_id: MessageId(0),
single_messages_to_send: Default::default(),
fragmented_messages_to_send: Default::default(),
message_ids_to_send: Default::default(),
fragment_sender: FragmentSender::new(),
current_rtt: Duration::default(),
current_time: WrappedTime::default(),
}
}
}
impl ChannelSend for ReliableSender {
    /// Caches the current time and round-trip-time estimate.
    ///
    /// Both cached values are read by `collect_messages_to_send` to decide whether an
    /// unacked message is due to be (re)sent.
    fn update(&mut self, time_manager: &TimeManager, ping_manager: &PingManager, _: &TickManager) {
        self.current_time = time_manager.current_time();
        self.current_rtt = ping_manager.rtt();
    }

    /// Buffers `message` as unacked under the next message id and returns that id.
    ///
    /// Messages longer than the fragment sender's `fragment_size` are split into fragments
    /// that are tracked and acked individually. Nothing is queued for sending here; that
    /// happens in `collect_messages_to_send`. This implementation always returns `Some`.
    fn buffer_send(&mut self, message: Bytes) -> Option<MessageId> {
        let message_id = self.next_send_message_id;
        let unacked_message = if message.len() > self.fragment_sender.fragment_size {
            let fragments = self
                .fragment_sender
                .build_fragments(message_id, None, message);
            UnackedMessage::Fragmented(
                fragments
                    .into_iter()
                    .map(|fragment| FragmentAck {
                        data: fragment,
                        acked: false,
                        last_sent: None,
                    })
                    .collect(),
            )
        } else {
            UnackedMessage::Single {
                bytes: message,
                last_sent: None,
            }
        };
        // NOTE(review): if message ids wrap around while an old message with the same id is
        // still unacked, `insert` silently replaces it — confirm this cannot happen.
        self.unacked_messages.insert(message_id, unacked_message);
        self.next_send_message_id += 1;
        Some(message_id)
    }

    /// Hands over everything queued by `collect_messages_to_send`, leaving both queues empty.
    ///
    /// Also clears the duplicate-tracking set, so messages that are still unacked can be
    /// queued again once their resend delay has elapsed.
    fn send_packet(&mut self) -> (VecDeque<SingleData>, VecDeque<FragmentData>) {
        self.message_ids_to_send.clear();
        (
            std::mem::take(&mut self.single_messages_to_send),
            std::mem::take(&mut self.fragmented_messages_to_send),
        )
    }

    /// Queues every unacked message (or unacked fragment) that is due to be sent.
    ///
    /// An item is due if it has never been sent, or if more than the resend delay (derived
    /// from the cached RTT via `ReliableSettings::resend_delay`) has passed since it was
    /// last sent. Items already queued since the last `send_packet` are not queued again.
    /// Queued items have their `last_sent` set to the cached current time.
    fn collect_messages_to_send(&mut self) {
        let resend_delay =
            chrono::Duration::from_std(self.reliable_settings.resend_delay(self.current_rtt))
                .expect("resend delay does not fit in a chrono::Duration");
        trace!("resend_delay: {:?}", resend_delay);
        let should_send = |last_sent: &Option<WrappedTime>| -> bool {
            match last_sent {
                None => true,
                Some(last_sent) => self.current_time - *last_sent > resend_delay,
            }
        };
        for (message_id, unacked_message) in &mut self.unacked_messages {
            match unacked_message {
                // Match ergonomics already bind `bytes`/`last_sent` as `&mut`; an explicit
                // `ref mut` here is redundant and rejected by the 2024 edition.
                UnackedMessage::Single { bytes, last_sent } => {
                    if !should_send(last_sent) {
                        continue;
                    }
                    let message_info = MessageAck {
                        message_id: *message_id,
                        fragment_id: None,
                    };
                    // `insert` returns false when this message is already queued for the
                    // upcoming packet, in which case it must not be queued twice.
                    if self.message_ids_to_send.insert(message_info) {
                        self.single_messages_to_send
                            .push_back(SingleData::new(Some(*message_id), bytes.clone()));
                        *last_sent = Some(self.current_time);
                    }
                }
                UnackedMessage::Fragmented(fragment_acks) => {
                    for fragment_ack in fragment_acks.iter_mut() {
                        if fragment_ack.acked || !should_send(&fragment_ack.last_sent) {
                            continue;
                        }
                        let message_info = MessageAck {
                            message_id: *message_id,
                            fragment_id: Some(fragment_ack.data.fragment_id),
                        };
                        if self.message_ids_to_send.insert(message_info) {
                            self.fragmented_messages_to_send
                                .push_back(fragment_ack.data.clone());
                            fragment_ack.last_sent = Some(self.current_time);
                        }
                    }
                }
            }
        }
    }

    /// Records that `message_ack` was delivered.
    ///
    /// A single message is dropped from the unacked set immediately. A fragmented message
    /// is dropped once all of its fragments are acked. Acks for ids that are no longer
    /// tracked (e.g. duplicate acks) are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the ack's fragment-ness does not match the stored message, or if its
    /// fragment id is out of range for the stored fragments.
    fn notify_message_delivered(&mut self, message_ack: &MessageAck) {
        let Some(unacked_message) = self.unacked_messages.get_mut(&message_ack.message_id)
        else {
            return;
        };
        let fully_acked = match unacked_message {
            UnackedMessage::Single { .. } => {
                assert!(
                    message_ack.fragment_id.is_none(),
                    "Received a message ack for a fragment but message is a single message"
                );
                true
            }
            UnackedMessage::Fragmented(fragment_acks) => {
                let Some(fragment_id) = message_ack.fragment_id else {
                    panic!("Received a message ack for a single message but message is a fragmented message")
                };
                let num_fragments = fragment_acks.len();
                let Some(fragment_ack) = fragment_acks.get_mut(fragment_id as usize) else {
                    panic!(
                        "Received an ack for fragment {} but the message only has {} fragments",
                        fragment_id, num_fragments
                    )
                };
                if fragment_ack.acked {
                    // Duplicate ack: state is unchanged, and if every fragment were acked
                    // the message would already have been removed.
                    false
                } else {
                    fragment_ack.acked = true;
                    fragment_acks.iter().all(|f| f.acked)
                }
            }
        };
        if fully_acked {
            self.unacked_messages.remove(&message_ack.message_id);
        }
    }

    /// Returns true if `collect_messages_to_send` queued anything that `send_packet` has not
    /// handed over yet.
    fn has_messages_to_send(&self) -> bool {
        !self.single_messages_to_send.is_empty() || !self.fragmented_messages_to_send.is_empty()
    }

    /// Not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics (`todo!()`).
    fn subscribe_acks(&mut self) -> Receiver<MessageId> {
        todo!()
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bytes::Bytes;

    use crate::channel::builder::ReliableSettings;
    use crate::packet::message::SingleData;

    use super::*;

    #[test]
    fn test_reliable_sender_internals() {
        // With rtt = 100ms, factor 1.5 and a 100ms minimum, the resend delay is expected to be
        // at least 100ms and well below 300ms (assumes `resend_delay` scales rtt by the factor
        // with the minimum as a floor).
        let mut sender = ReliableSender::new(ReliableSettings {
            rtt_resend_factor: 1.5,
            rtt_resend_min_delay: Duration::from_millis(100),
        });
        sender.current_rtt = Duration::from_millis(100);
        sender.current_time = WrappedTime::new(0);

        // Buffering stores the message as unacked but does not queue it for sending yet.
        let message1 = Bytes::from("hello");
        assert_eq!(sender.buffer_send(message1.clone()), Some(MessageId(0)));
        assert_eq!(sender.unacked_messages.len(), 1);
        assert_eq!(sender.next_send_message_id, MessageId(1));
        assert!(!sender.has_messages_to_send());

        // The first collection queues the message.
        sender.collect_messages_to_send();
        assert_eq!(sender.single_messages_to_send.len(), 1);
        assert!(sender.has_messages_to_send());

        // Collecting again before the packet goes out must not queue a duplicate.
        sender.collect_messages_to_send();
        assert_eq!(sender.single_messages_to_send.len(), 1);

        // Sending drains the queues.
        let (singles, fragments) = sender.send_packet();
        assert_eq!(singles.len(), 1);
        assert!(fragments.is_empty());
        assert!(!sender.has_messages_to_send());

        // Below the resend threshold: nothing is re-queued.
        sender.current_time += Duration::from_millis(100);
        sender.collect_messages_to_send();
        assert!(sender.single_messages_to_send.is_empty());

        // Above the resend threshold: the still-unacked message is re-queued.
        sender.current_time += Duration::from_millis(200);
        sender.collect_messages_to_send();
        assert_eq!(sender.single_messages_to_send.len(), 1);
        assert_eq!(
            sender.single_messages_to_send.front().unwrap(),
            &SingleData::new(Some(MessageId(0)), message1.clone())
        );
        let _ = sender.send_packet();

        // Once acked, the message is forgotten and never resent.
        sender.notify_message_delivered(&MessageAck {
            message_id: MessageId(0),
            fragment_id: None,
        });
        assert_eq!(sender.unacked_messages.len(), 0);
        sender.current_time += Duration::from_millis(200);
        sender.collect_messages_to_send();
        assert!(sender.single_messages_to_send.is_empty());
        assert!(!sender.has_messages_to_send());
    }
}