use crate::channel::reliable::{ReceiveChannelReliable, SendChannelReliable};
use crate::channel::unreliable::{ReceiveChannelUnreliable, SendChannelUnreliable};
use crate::channel::{ChannelConfig, DefaultChannel, SendType};
use crate::connection_stats::ConnectionStats;
use crate::error::DisconnectReason;
use crate::packet::{Packet, Payload};
use bytes::Bytes;
use octets::OctetsMut;
use std::collections::BTreeMap;
use std::ops::Range;
use std::time::Duration;
/// Channel layout and per-tick byte budget used to build both ends of a connection.
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Byte budget handed to the send channels on every `RenetClient::get_packets_to_send` call.
    pub available_bytes_per_tick: u64,
    /// Channels the server sends on (and therefore the client receives on).
    pub server_channels_config: Vec<ChannelConfig>,
    /// Channels the client sends on (and therefore the server receives on).
    pub client_channels_config: Vec<ChannelConfig>,
}
impl ConnectionConfig {
    /// Per-tick byte budget used by [`ConnectionConfig::from_channels`].
    pub const DEFAULT_AVAILABLE_BYTES_PER_TICK: u64 = 60_000;

    /// Builds a config with separate server-side and client-side channel lists and the
    /// default per-tick byte budget ([`Self::DEFAULT_AVAILABLE_BYTES_PER_TICK`]).
    pub fn from_channels(server: Vec<ChannelConfig>, client: Vec<ChannelConfig>) -> Self {
        Self {
            available_bytes_per_tick: Self::DEFAULT_AVAILABLE_BYTES_PER_TICK,
            server_channels_config: server,
            client_channels_config: client,
        }
    }

    /// Builds a config in which server and client use the same channel list.
    pub fn from_shared_channels(channels: Vec<ChannelConfig>) -> Self {
        Self::from_channels(channels.clone(), channels)
    }

    /// Config that uses `DefaultChannel::config()` in both directions.
    pub fn test() -> Self {
        Self::from_shared_channels(DefaultChannel::config())
    }

    /// Rewrites every channel (server and client side) whose send type is not already
    /// `SendType::Unreliable` into `SendType::Unreliable { ordered_reliable_substrate: true }`.
    ///
    /// Channels that are already unreliable keep their existing `ordered_reliable_substrate`
    /// value. `RenetClient` applies this when it is built on top of a reliable socket.
    pub fn downgrade_to_unreliable(&mut self) {
        Self::downgrade_channels(&mut self.server_channels_config);
        Self::downgrade_channels(&mut self.client_channels_config);
    }

    /// Applies the downgrade described on `downgrade_to_unreliable` to a single channel list.
    fn downgrade_channels(channels: &mut [ChannelConfig]) {
        for channel in channels.iter_mut() {
            if !matches!(channel.send_type, SendType::Unreliable { .. }) {
                channel.send_type = SendType::Unreliable {
                    ordered_reliable_substrate: true,
                };
            }
        }
    }
}
/// Bookkeeping for a packet this side sent that has not yet been acknowledged or discarded.
#[derive(Debug, Clone)]
struct PacketSent {
    /// Connection time (the client's accumulated `current_time`) when the packet was produced.
    sent_at: Duration,
    /// What the packet carried, i.e. what must be notified once it is acknowledged.
    info: PacketSentInfo,
}
/// Content summary of an outstanding sent packet, used to dispatch its acknowledgement.
#[derive(Debug, Clone)]
enum PacketSentInfo {
    /// Nothing to notify on ack (unreliable traffic); still tracked for RTT and loss statistics.
    None,
    /// Whole reliable messages; each id is acked on the reliable send channel `channel_id`.
    ReliableMessages {
        channel_id: u8,
        message_ids: Vec<u64>,
    },
    /// A single slice (`slice_index`) of the sliced reliable message `message_id`.
    ReliableSliceMessage {
        channel_id: u8,
        message_id: u64,
        slice_index: usize,
    },
    /// An ack packet; once it is acked, pending ack ranges up to `largest_acked_packet` are dropped.
    Ack {
        largest_acked_packet: u64,
    },
}
/// One entry of the send order: a send-channel id tagged with the kind of channel stored there.
#[derive(Debug)]
enum ChannelOrder {
    Reliable(u8),
    Unreliable(u8),
}
/// Slot in the send-channel table, which is indexed directly by channel id.
#[derive(Debug)]
enum SendChannel {
    /// No send channel is configured with this id.
    Empty,
    Unreliable(SendChannelUnreliable),
    Reliable(SendChannelReliable),
}
/// Slot in the receive-channel table, which is indexed directly by channel id.
#[derive(Debug)]
enum ReceiveChannel {
    /// No receive channel is configured with this id.
    Empty,
    Unreliable(ReceiveChannelUnreliable),
    Reliable(ReceiveChannelReliable),
}
/// Snapshot of a connection's network statistics, as returned by `RenetClient::network_info`.
///
/// Plain `f64` data, so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NetworkInfo {
    /// Smoothed round-trip time, in seconds.
    pub rtt: f64,
    /// Packet-loss estimate, as reported by the connection statistics.
    pub packet_loss: f64,
    /// Outgoing throughput, in bytes per second.
    pub bytes_sent_per_second: f64,
    /// Incoming throughput, in bytes per second.
    pub bytes_received_per_second: f64,
}
/// Lifecycle state of a connection. `RenetClient` never leaves `Disconnected` once it is entered.
#[derive(Debug)]
pub enum RenetConnectionStatus {
    Connected,
    /// Initial state of a newly constructed `RenetClient`.
    Connecting,
    /// The connection has ended; `reason` records why.
    Disconnected { reason: DisconnectReason },
}
/// One endpoint of a connection: owns the send/receive channels, assigns packet sequence
/// numbers, tracks sent packets until they are acked or time out, and collects the sequences
/// of received packets that still need acknowledging.
#[derive(Debug)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::resource::Resource))]
pub struct RenetClient {
    /// Set at construction; when true the channel configs were downgraded to unreliable.
    has_reliable_socket: bool,
    /// Next sequence number for an outgoing packet (also advanced by the send channels).
    packet_sequence: u64,
    /// Connection-local clock, advanced only by `update`.
    current_time: Duration,
    /// Outgoing packets awaiting acknowledgement, keyed by sequence number.
    sent_packets: BTreeMap<u64, PacketSent>,
    /// Sorted, disjoint half-open ranges of received sequences not yet known to be acked by the peer.
    pending_acks: Vec<Range<u64>>,
    /// Order in which send channels are polled for packets (the configured order).
    channel_send_order: Vec<ChannelOrder>,
    /// Send channels indexed by channel id.
    send_channels: Vec<SendChannel>,
    /// Receive channels indexed by channel id.
    receive_channels: Vec<ReceiveChannel>,
    stats: ConnectionStats,
    /// Byte budget handed to the send channels on each `get_packets_to_send`.
    available_bytes_per_tick: u64,
    connection_status: RenetConnectionStatus,
    /// Smoothed round-trip time in seconds; 0.0 until the first acknowledgement.
    rtt: f64,
}
impl RenetClient {
    /// Upper bound on the number of disjoint ranges kept in `pending_acks`. When it is exceeded
    /// the lowest (oldest) ranges are dropped, so the outgoing ack packet stays bounded even when
    /// incoming sequences arrive badly out of order.
    const MAX_PENDING_ACK_RANGES: usize = 64;

    /// Sent packets still unacknowledged after this long are forgotten by `update`.
    const SENT_PACKET_TIMEOUT: Duration = Duration::from_secs(3);

    /// Creates the client end of a connection: it sends on `config.client_channels_config`
    /// and receives on `config.server_channels_config`.
    ///
    /// When `has_reliable_socket` is true, the channel configs are first rewritten with
    /// [`ConnectionConfig::downgrade_to_unreliable`].
    ///
    /// # Panics
    ///
    /// Panics if the same channel id is configured twice in either channel list.
    pub fn new(mut config: ConnectionConfig, has_reliable_socket: bool) -> Self {
        if has_reliable_socket {
            config.downgrade_to_unreliable();
        }
        Self::from_channels(
            has_reliable_socket,
            config.available_bytes_per_tick,
            config.client_channels_config,
            config.server_channels_config,
        )
    }

    /// Creates the server end of a connection, the mirror image of [`RenetClient::new`]: it
    /// sends on `config.server_channels_config` and receives on `config.client_channels_config`.
    ///
    /// # Panics
    ///
    /// Panics if the same channel id is configured twice in either channel list.
    pub(crate) fn new_from_server(mut config: ConnectionConfig, has_reliable_socket: bool) -> Self {
        if has_reliable_socket {
            config.downgrade_to_unreliable();
        }
        Self::from_channels(
            has_reliable_socket,
            config.available_bytes_per_tick,
            config.server_channels_config,
            config.client_channels_config,
        )
    }

    /// Builds the channel tables and an initial `Connecting` connection.
    ///
    /// Each table is indexed directly by channel id and sized to the largest configured id + 1;
    /// ids without a config stay `Empty`.
    fn from_channels(
        has_reliable_socket: bool,
        available_bytes_per_tick: u64,
        send_channels_config: Vec<ChannelConfig>,
        receive_channels_config: Vec<ChannelConfig>,
    ) -> Self {
        let max_send_channel = send_channels_config.iter().map(|c| c.channel_id).max().unwrap_or_default();
        let max_receive_channel = receive_channels_config.iter().map(|c| c.channel_id).max().unwrap_or_default();

        let mut send_channels = Vec::new();
        send_channels.resize_with(max_send_channel as usize + 1, || SendChannel::Empty);
        // Channels are polled for outgoing packets in the order they were configured.
        let mut channel_send_order: Vec<ChannelOrder> = Vec::with_capacity(send_channels_config.len());
        for channel_config in send_channels_config.iter() {
            let send_channel = &mut send_channels[channel_config.channel_id as usize];
            assert!(
                matches!(send_channel, SendChannel::Empty),
                "already exists send channel {}",
                channel_config.channel_id
            );
            match channel_config.send_type {
                SendType::Unreliable {
                    ordered_reliable_substrate,
                } => {
                    channel_send_order.push(ChannelOrder::Unreliable(channel_config.channel_id));
                    let channel = SendChannelUnreliable::new(
                        channel_config.channel_id,
                        channel_config.max_memory_usage_bytes,
                        ordered_reliable_substrate,
                    );
                    *send_channel = SendChannel::Unreliable(channel);
                }
                SendType::ReliableOrdered { resend_time } | SendType::ReliableUnordered { resend_time } => {
                    channel_send_order.push(ChannelOrder::Reliable(channel_config.channel_id));
                    let channel = SendChannelReliable::new(channel_config.channel_id, resend_time, channel_config.max_memory_usage_bytes);
                    *send_channel = SendChannel::Reliable(channel);
                }
            }
        }

        let mut receive_channels = Vec::new();
        receive_channels.resize_with(max_receive_channel as usize + 1, || ReceiveChannel::Empty);
        for channel_config in receive_channels_config.iter() {
            let receive_channel = &mut receive_channels[channel_config.channel_id as usize];
            assert!(
                matches!(receive_channel, ReceiveChannel::Empty),
                "already exists receive channel {}",
                channel_config.channel_id
            );
            match channel_config.send_type {
                SendType::Unreliable { .. } => {
                    let channel = ReceiveChannelUnreliable::new(channel_config.channel_id, channel_config.max_memory_usage_bytes);
                    *receive_channel = ReceiveChannel::Unreliable(channel);
                }
                SendType::ReliableOrdered { .. } => {
                    let channel = ReceiveChannelReliable::new(channel_config.max_memory_usage_bytes, true);
                    *receive_channel = ReceiveChannel::Reliable(channel);
                }
                SendType::ReliableUnordered { .. } => {
                    let channel = ReceiveChannelReliable::new(channel_config.max_memory_usage_bytes, false);
                    *receive_channel = ReceiveChannel::Reliable(channel);
                }
            }
        }

        Self {
            has_reliable_socket,
            packet_sequence: 0,
            current_time: Duration::ZERO,
            sent_packets: BTreeMap::new(),
            pending_acks: Vec::new(),
            channel_send_order,
            send_channels,
            receive_channels,
            stats: ConnectionStats::new(),
            rtt: 0.0,
            available_bytes_per_tick,
            connection_status: RenetConnectionStatus::Connecting,
        }
    }

    /// Whether this connection was created on top of a reliable socket.
    pub fn has_reliable_socket(&self) -> bool {
        self.has_reliable_socket
    }

    /// Smoothed round-trip time in seconds (0.0 until the first acknowledgement arrives).
    pub fn rtt(&self) -> f64 {
        self.rtt
    }

    /// Packet-loss estimate from the connection statistics.
    pub fn packet_loss(&self) -> f64 {
        self.stats.packet_loss()
    }

    /// Outgoing throughput in bytes per second, evaluated at the current connection time.
    pub fn bytes_sent_per_sec(&self) -> f64 {
        self.stats.bytes_sent_per_second(self.current_time)
    }

    /// Incoming throughput in bytes per second, evaluated at the current connection time.
    pub fn bytes_received_per_sec(&self) -> f64 {
        self.stats.bytes_received_per_second(self.current_time)
    }

    /// RTT, packet loss and throughput bundled into one snapshot.
    pub fn network_info(&self) -> NetworkInfo {
        NetworkInfo {
            rtt: self.rtt,
            packet_loss: self.stats.packet_loss(),
            bytes_sent_per_second: self.stats.bytes_sent_per_second(self.current_time),
            bytes_received_per_second: self.stats.bytes_received_per_second(self.current_time),
        }
    }

    /// Whether the connection is in the `Connected` state.
    #[inline]
    pub fn is_connected(&self) -> bool {
        matches!(self.connection_status, RenetConnectionStatus::Connected)
    }

    /// Whether the connection is in the `Connecting` state.
    #[inline]
    pub fn is_connecting(&self) -> bool {
        matches!(self.connection_status, RenetConnectionStatus::Connecting)
    }

    /// Whether the connection has been disconnected (a terminal state).
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        matches!(self.connection_status, RenetConnectionStatus::Disconnected { .. })
    }

    /// The reason for disconnection, or `None` while not disconnected.
    pub fn disconnect_reason(&self) -> Option<DisconnectReason> {
        if let RenetConnectionStatus::Disconnected { reason } = self.connection_status {
            Some(reason)
        } else {
            None
        }
    }

    /// Marks the connection as connected. Has no effect once disconnected.
    pub fn set_connected(&mut self) {
        if !self.is_disconnected() {
            self.connection_status = RenetConnectionStatus::Connected;
        }
    }

    /// Marks the connection as connecting. Has no effect once disconnected.
    pub fn set_connecting(&mut self) {
        if !self.is_disconnected() {
            self.connection_status = RenetConnectionStatus::Connecting;
        }
    }

    /// Disconnects with `DisconnectReason::DisconnectedByClient` (no-op if already disconnected).
    pub fn disconnect(&mut self) {
        self.disconnect_with_reason(DisconnectReason::DisconnectedByClient);
    }

    /// Disconnects with `DisconnectReason::Transport` (no-op if already disconnected).
    pub fn disconnect_due_to_transport(&mut self) {
        self.disconnect_with_reason(DisconnectReason::Transport);
    }

    /// Remaining memory of the send channel `channel_id`, as reported by that channel.
    ///
    /// # Panics
    ///
    /// Panics if `channel_id` is not a configured send channel.
    pub fn channel_available_memory<I: Into<u8>>(&self, channel_id: I) -> usize {
        let channel_id = channel_id.into();
        match self.send_channels.get(channel_id as usize) {
            None | Some(SendChannel::Empty) => {
                panic!("Called 'channel_available_memory' with invalid channel {channel_id}");
            }
            Some(SendChannel::Reliable(reliable_channel)) => reliable_channel.available_memory(),
            Some(SendChannel::Unreliable(unreliable_channel)) => unreliable_channel.available_memory(),
        }
    }

    /// Whether the send channel `channel_id` can currently accept a message of `size_bytes`.
    ///
    /// # Panics
    ///
    /// Panics if `channel_id` is not a configured send channel.
    pub fn can_send_message<I: Into<u8>>(&self, channel_id: I, size_bytes: usize) -> bool {
        let channel_id = channel_id.into();
        match self.send_channels.get(channel_id as usize) {
            None | Some(SendChannel::Empty) => {
                panic!("Called 'can_send_message' with invalid channel {channel_id}");
            }
            Some(SendChannel::Reliable(reliable_channel)) => reliable_channel.can_send_message(size_bytes),
            Some(SendChannel::Unreliable(unreliable_channel)) => unreliable_channel.can_send_message(size_bytes),
        }
    }

    /// Queues `message` on the send channel `channel_id`.
    ///
    /// Silently ignored once disconnected. If a reliable channel rejects the message, the
    /// connection is disconnected with `DisconnectReason::SendChannelError`.
    ///
    /// # Panics
    ///
    /// Panics if `channel_id` is not a configured send channel.
    pub fn send_message<I: Into<u8>, B: Into<Bytes>>(&mut self, channel_id: I, message: B) {
        if self.is_disconnected() {
            return;
        }
        let channel_id = channel_id.into();
        match self.send_channels.get_mut(channel_id as usize) {
            None | Some(SendChannel::Empty) => {
                panic!("Called 'send_message' with invalid channel {channel_id}");
            }
            Some(SendChannel::Reliable(reliable_channel)) => {
                if let Err(error) = reliable_channel.send_message(message.into()) {
                    self.disconnect_with_reason(DisconnectReason::SendChannelError { channel_id, error });
                }
            }
            Some(SendChannel::Unreliable(unreliable_channel)) => {
                unreliable_channel.send_message(message.into());
            }
        }
    }

    /// Pops the next received message from the receive channel `channel_id`, if any.
    /// Always returns `None` once disconnected.
    ///
    /// # Panics
    ///
    /// Panics if `channel_id` is not a configured receive channel.
    pub fn receive_message<I: Into<u8>>(&mut self, channel_id: I) -> Option<Bytes> {
        if self.is_disconnected() {
            return None;
        }
        let channel_id = channel_id.into();
        match self.receive_channels.get_mut(channel_id as usize) {
            None | Some(ReceiveChannel::Empty) => {
                panic!("Called 'receive_message' with invalid channel {channel_id}");
            }
            Some(ReceiveChannel::Reliable(reliable_channel)) => reliable_channel.receive_message(),
            Some(ReceiveChannel::Unreliable(unreliable_channel)) => unreliable_channel.receive_message(),
        }
    }

    /// Advances the connection clock by `duration`. It updates statistics, lets unreliable
    /// receive channels discard stale incomplete slices, and forgets sent packets that have
    /// gone unacknowledged for longer than `SENT_PACKET_TIMEOUT`.
    pub fn update(&mut self, duration: Duration) {
        self.current_time += duration;
        self.stats.update(self.current_time);

        for receive_channel in self.receive_channels.iter_mut() {
            if let ReceiveChannel::Unreliable(channel) = receive_channel {
                channel.discard_incomplete_old_slices(self.current_time);
            }
        }

        // `current_time` only moves forward and sequences are handed out in send order
        // (assuming the channels advance `packet_sequence` monotonically), so the expired
        // packets form a prefix of the sequence-ordered map.
        let current_time = self.current_time;
        let expired: Vec<u64> = self
            .sent_packets
            .iter()
            .take_while(|(_, sent_packet)| current_time - sent_packet.sent_at >= Self::SENT_PACKET_TIMEOUT)
            .map(|(&sequence, _)| sequence)
            .collect();
        for sequence in expired {
            self.sent_packets.remove(&sequence);
        }
    }

    /// Decodes one incoming datagram and routes its contents.
    ///
    /// Ignored once disconnected. A packet that fails to deserialize, targets an unknown or
    /// mismatched channel, or is rejected by its channel disconnects the connection with the
    /// corresponding `DisconnectReason`.
    pub fn process_packet(&mut self, packet: &[u8]) {
        if self.is_disconnected() {
            return;
        }
        self.stats.received_packet(packet.len() as u64);
        let mut octets = octets::Octets::with_slice(packet);
        let packet = match Packet::from_bytes(&mut octets) {
            Err(err) => {
                self.disconnect_with_reason(DisconnectReason::PacketDeserialization(err));
                return;
            }
            Ok(packet) => packet,
        };

        // Every successfully decoded packet gets acknowledged, whatever it carries.
        self.add_pending_ack(packet.sequence());

        match packet {
            Packet::SmallReliable { channel_id, messages, .. } => {
                let Some(ReceiveChannel::Reliable(channel)) = self.receive_channels.get_mut(channel_id as usize) else {
                    self.disconnect_with_reason(DisconnectReason::ReceivedInvalidChannelId(channel_id));
                    return;
                };
                for (message_id, message) in messages {
                    if let Err(error) = channel.process_message(message, message_id) {
                        self.disconnect_with_reason(DisconnectReason::ReceiveChannelError { channel_id, error });
                        return;
                    }
                }
            }
            Packet::SmallUnreliable { channel_id, messages, .. } => {
                let Some(ReceiveChannel::Unreliable(channel)) = self.receive_channels.get_mut(channel_id as usize) else {
                    self.disconnect_with_reason(DisconnectReason::ReceivedInvalidChannelId(channel_id));
                    return;
                };
                for message in messages {
                    channel.process_message(message);
                }
            }
            Packet::ReliableSlice { channel_id, slice, .. } => {
                let Some(ReceiveChannel::Reliable(channel)) = self.receive_channels.get_mut(channel_id as usize) else {
                    self.disconnect_with_reason(DisconnectReason::ReceivedInvalidChannelId(channel_id));
                    return;
                };
                if let Err(error) = channel.process_slice(slice) {
                    self.disconnect_with_reason(DisconnectReason::ReceiveChannelError { channel_id, error });
                }
            }
            Packet::UnreliableSlice { channel_id, slice, .. } => {
                let Some(ReceiveChannel::Unreliable(channel)) = self.receive_channels.get_mut(channel_id as usize) else {
                    self.disconnect_with_reason(DisconnectReason::ReceivedInvalidChannelId(channel_id));
                    return;
                };
                if let Err(error) = channel.process_slice(slice, self.current_time) {
                    self.disconnect_with_reason(DisconnectReason::ReceiveChannelError { channel_id, error });
                }
            }
            Packet::Ack { ack_ranges, .. } => self.process_ack_ranges(ack_ranges),
        }
    }

    /// Handles the ranges of an incoming ack packet. Every still-tracked sent packet inside
    /// them is removed, feeds the RTT/loss statistics, and notifies whatever it carried.
    fn process_ack_ranges(&mut self, ack_ranges: Vec<Range<u64>>) {
        // Collect first: `sent_packets` cannot be mutated while one of its ranges is borrowed.
        let mut newly_acked: Vec<u64> = Vec::new();
        for range in ack_ranges {
            // The ranges come from the peer and `BTreeMap::range` panics when start > end,
            // so an inverted (or empty) range is treated as acknowledging nothing.
            if range.is_empty() {
                continue;
            }
            newly_acked.extend(self.sent_packets.range(range).map(|(&sequence, _)| sequence));
        }

        for packet_sequence in newly_acked {
            // Overlapping ranges may name a sequence twice; only the first removal counts.
            let Some(sent_packet) = self.sent_packets.remove(&packet_sequence) else {
                continue;
            };
            self.stats.acked_packet(sent_packet.sent_at, self.current_time);

            // Smoothed RTT: the first sample seeds it, later samples are blended in with weight 1/8.
            let rtt = (self.current_time - sent_packet.sent_at).as_secs_f64();
            if self.rtt < f64::EPSILON {
                self.rtt = rtt;
            } else {
                self.rtt = self.rtt * 0.875 + rtt * 0.125;
            }

            match sent_packet.info {
                PacketSentInfo::ReliableMessages { channel_id, message_ids } => {
                    let Some(SendChannel::Reliable(channel)) = self.send_channels.get_mut(channel_id as usize) else {
                        panic!("Acked packet has invalid channel {channel_id}");
                    };
                    for message_id in message_ids {
                        channel.process_message_ack(message_id);
                    }
                }
                PacketSentInfo::ReliableSliceMessage {
                    channel_id,
                    message_id,
                    slice_index,
                } => {
                    let Some(SendChannel::Reliable(channel)) = self.send_channels.get_mut(channel_id as usize) else {
                        panic!("Acked packet has invalid channel {channel_id}");
                    };
                    channel.process_slice_message_ack(message_id, slice_index);
                }
                PacketSentInfo::Ack { largest_acked_packet } => {
                    self.acked_largest(largest_acked_packet);
                }
                PacketSentInfo::None => {}
            }
        }
    }

    /// Collects this tick's outgoing packets and returns them serialized.
    ///
    /// Channels are polled in configured order and share one `available_bytes_per_tick`
    /// budget. If any sequences are pending acknowledgement, an ack packet is appended.
    /// Every produced packet is recorded in `sent_packets`. Returns an empty list when
    /// disconnected; a serialization failure disconnects with `PacketSerialization` and
    /// also returns an empty list.
    ///
    /// # Panics
    ///
    /// Panics if the internal channel tables are inconsistent with the send order.
    pub fn get_packets_to_send(&mut self) -> Vec<Payload> {
        if self.is_disconnected() {
            return vec![];
        }

        let mut packets: Vec<Packet> = vec![];
        let mut available_bytes = self.available_bytes_per_tick;
        for order in self.channel_send_order.iter() {
            match order {
                ChannelOrder::Reliable(channel_id) => {
                    let Some(SendChannel::Reliable(channel)) = self.send_channels.get_mut(*channel_id as usize) else {
                        panic!("Packet to send has invalid channel {channel_id}");
                    };
                    packets.append(&mut channel.get_packets_to_send(&mut self.packet_sequence, &mut available_bytes, self.current_time));
                }
                ChannelOrder::Unreliable(channel_id) => {
                    let Some(SendChannel::Unreliable(channel)) = self.send_channels.get_mut(*channel_id as usize) else {
                        panic!("Packet to send has invalid channel {channel_id}");
                    };
                    packets.append(&mut channel.get_packets_to_send(&mut self.packet_sequence, &mut available_bytes));
                }
            }
        }

        if !self.pending_acks.is_empty() {
            packets.push(Packet::Ack {
                sequence: self.packet_sequence,
                ack_ranges: self.pending_acks.clone(),
            });
            self.packet_sequence += 1;
        }

        let sent_at = self.current_time;
        for packet in packets.iter() {
            let info = match packet {
                Packet::SmallReliable { channel_id, messages, .. } => PacketSentInfo::ReliableMessages {
                    channel_id: *channel_id,
                    message_ids: messages.iter().map(|(id, _)| *id).collect(),
                },
                Packet::ReliableSlice { channel_id, slice, .. } => PacketSentInfo::ReliableSliceMessage {
                    channel_id: *channel_id,
                    message_id: slice.message_id,
                    slice_index: slice.slice_index,
                },
                Packet::SmallUnreliable { .. } | Packet::UnreliableSlice { .. } => PacketSentInfo::None,
                Packet::Ack { ack_ranges, .. } => {
                    // `pending_acks` is kept sorted and an ack packet is only built from a
                    // non-empty list, so the last range ends just past the largest acked sequence.
                    let last_range = ack_ranges.last().expect("ack packets carry at least one range");
                    PacketSentInfo::Ack {
                        largest_acked_packet: last_range.end - 1,
                    }
                }
            };
            self.sent_packets.insert(packet.sequence(), PacketSent { sent_at, info });
        }

        // One scratch buffer is reused for every packet; a packet that does not fit presumably
        // makes `to_bytes` fail, which is treated as a fatal serialization error.
        let mut buffer = [0u8; 1400];
        let mut serialized_packets = Vec::with_capacity(packets.len());
        let mut bytes_sent: u64 = 0;
        for packet in packets {
            let mut oct = OctetsMut::with_slice(&mut buffer);
            let len = match packet.to_bytes(&mut oct) {
                Err(err) => {
                    self.disconnect_with_reason(DisconnectReason::PacketSerialization(err));
                    return vec![];
                }
                Ok(len) => len,
            };
            bytes_sent += len as u64;
            serialized_packets.push(buffer[..len].to_vec());
        }
        self.stats.sent_packets(serialized_packets.len() as u64, bytes_sent);
        serialized_packets
    }

    /// Records `sequence` as received, keeping `pending_acks` a sorted list of disjoint
    /// half-open ranges. Adjacent ranges are merged, and the list is capped at
    /// `MAX_PENDING_ACK_RANGES` on every insertion path.
    fn add_pending_ack(&mut self, sequence: u64) {
        // `sequence` comes from the peer; `u64::MAX` cannot be stored as a half-open range,
        // so it is ignored instead of overflowing.
        let Some(sequence_end) = sequence.checked_add(1) else {
            return;
        };

        for index in 0..self.pending_acks.len() {
            let range = &mut self.pending_acks[index];
            if range.contains(&sequence) {
                return;
            }
            if range.start == sequence_end {
                // Directly before this range: grow it downwards. The previous range cannot end
                // at `sequence`, or the previous iteration would already have extended it.
                range.start = sequence;
                return;
            }
            if range.end == sequence {
                // Directly after this range: grow it upwards, then merge with the next range
                // if that closed the gap between them.
                range.end = sequence_end;
                let next_index = index + 1;
                if next_index < self.pending_acks.len() && self.pending_acks[index].end == self.pending_acks[next_index].start {
                    self.pending_acks[index].end = self.pending_acks[next_index].end;
                    self.pending_acks.remove(next_index);
                }
                return;
            }
            if range.start > sequence_end {
                // Strictly before this range with a gap: new range in sorted position.
                self.pending_acks.insert(index, sequence..sequence_end);
                self.trim_pending_acks();
                return;
            }
        }

        // After every existing range (or the list was empty).
        self.pending_acks.push(sequence..sequence_end);
        self.trim_pending_acks();
    }

    /// Drops the lowest (oldest) ranges so that at most `MAX_PENDING_ACK_RANGES` remain.
    fn trim_pending_acks(&mut self) {
        let excess = self.pending_acks.len().saturating_sub(Self::MAX_PENDING_ACK_RANGES);
        self.pending_acks.drain(..excess);
    }

    /// Called when the peer acknowledged one of our ack packets whose highest acked sequence
    /// was `largest_ack`. Drops every pending range, or range prefix, up to and including
    /// `largest_ack`, because the peer already knows about those sequences.
    fn acked_largest(&mut self, largest_ack: u64) {
        while !self.pending_acks.is_empty() {
            let range: &mut Range<u64> = &mut self.pending_acks[0];
            if largest_ack < range.start {
                return;
            }
            if range.end <= largest_ack {
                self.pending_acks.remove(0);
                continue;
            }
            range.start = largest_ack + 1;
            if range.is_empty() {
                self.pending_acks.remove(0);
            }
            return;
        }
    }

    /// Moves to `Disconnected { reason }` unless already disconnected (the first reason wins).
    pub(crate) fn disconnect_with_reason(&mut self, reason: DisconnectReason) {
        if !self.is_disconnected() {
            self.connection_status = RenetConnectionStatus::Disconnected { reason };
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pending acks stay sorted and disjoint. A new sequence either extends a neighbouring
    /// range, bridges two ranges into one, or opens a new range in sorted position.
    #[test]
    fn pending_acks() {
        let mut connection = RenetClient::new(ConnectionConfig::test(), false);
        connection.add_pending_ack(3);
        assert_eq!(connection.pending_acks, vec![3..4]);
        connection.add_pending_ack(4);
        assert_eq!(connection.pending_acks, vec![3..5]);
        connection.add_pending_ack(2);
        assert_eq!(connection.pending_acks, vec![2..5]);
        connection.add_pending_ack(0);
        assert_eq!(connection.pending_acks, vec![0..1, 2..5]);
        connection.add_pending_ack(7);
        assert_eq!(connection.pending_acks, vec![0..1, 2..5, 7..8]);
        // 1 closes the gap between 0..1 and 2..5, merging them.
        connection.add_pending_ack(1);
        assert_eq!(connection.pending_acks, vec![0..5, 7..8]);
        connection.add_pending_ack(5);
        assert_eq!(connection.pending_acks, vec![0..6, 7..8]);
        // 6 closes the last gap.
        connection.add_pending_ack(6);
        assert_eq!(connection.pending_acks, vec![0..8]);
    }

    /// `acked_largest` drops every pending range, or range prefix, up to and including the
    /// acknowledged sequence.
    #[test]
    fn ack_pending_acks() {
        let mut connection = RenetClient::new(ConnectionConfig::test(), false);
        for i in 0..10 {
            connection.add_pending_ack(i);
        }
        assert_eq!(connection.pending_acks, vec![0..10]);
        connection.acked_largest(0);
        assert_eq!(connection.pending_acks, vec![1..10]);
        connection.acked_largest(3);
        assert_eq!(connection.pending_acks, vec![4..10]);
        connection.add_pending_ack(0);
        assert_eq!(connection.pending_acks, vec![0..1, 4..10]);
        // Removes 0..1 entirely and trims 4..10 down to 6..10.
        connection.acked_largest(5);
        assert_eq!(connection.pending_acks, vec![6..10]);
        connection.add_pending_ack(0);
        assert_eq!(connection.pending_acks, vec![0..1, 6..10]);
        // Everything below 11 is acknowledged, so nothing remains pending.
        connection.acked_largest(10);
        assert_eq!(connection.pending_acks, vec![]);
    }

    /// `update` forgets sent packets that stay unacknowledged for too long. Here the packet is
    /// still tracked after 1s and gone after 5s in total.
    #[test]
    fn discard_old_packets() {
        let mut connection = RenetClient::new(ConnectionConfig::test(), false);
        let message: Bytes = vec![5; 5].into();
        connection.send_message(0, message);
        connection.get_packets_to_send();
        assert_eq!(connection.sent_packets.len(), 1);
        connection.update(Duration::from_secs(1));
        assert_eq!(connection.sent_packets.len(), 1);
        connection.update(Duration::from_secs(4));
        assert_eq!(connection.sent_packets.len(), 0);
    }
}