use std::{self, collections::HashMap, fmt::Debug, io::Result, net::SocketAddr, time::Instant};
use crossbeam_channel::{self, unbounded, Receiver, Sender};
use log::error;
use crate::{
config::Config, net::Connection, net::ConnectionEventAddress, net::ConnectionMessenger,
};
/// A datagram transport that a `ConnectionManager` can be built on.
///
/// Every operation reports failures as `std::io::Error`.
pub trait DatagramSocket: Debug {
    /// Sends `payload` to `addr`; the `usize` is presumably the number of bytes sent.
    fn send_packet(&mut self, addr: &SocketAddr, payload: &[u8]) -> std::io::Result<usize>;

    /// Reads one datagram into `buffer`.
    ///
    /// The returned slice borrows from `buffer` (presumably the filled prefix) and is paired
    /// with the sender's address.
    fn receive_packet<'buf>(
        &mut self,
        buffer: &'buf mut [u8],
    ) -> std::io::Result<(&'buf [u8], SocketAddr)>;

    /// Returns the local address of this socket.
    fn local_addr(&self) -> std::io::Result<SocketAddr>;

    /// When `true`, `ConnectionManager::manual_poll` performs a single read per poll instead
    /// of draining the socket.
    fn is_blocking_mode(&self) -> bool;
}
/// Bundles the socket, the configuration and the event channel so that connections can be
/// handed a single [`ConnectionMessenger`].
#[derive(Debug)]
struct SocketEventSenderAndConfig<TSocket, ReceiveEvent>
where
    TSocket: DatagramSocket,
    ReceiveEvent: Debug,
{
    /// Configuration returned by `ConnectionMessenger::config`.
    config: Config,
    /// Transport used for every outgoing packet.
    socket: TSocket,
    /// Channel on which connection events are published.
    event_sender: Sender<ReceiveEvent>,
}
impl<TSocket: DatagramSocket, ReceiveEvent: Debug>
SocketEventSenderAndConfig<TSocket, ReceiveEvent>
{
fn new(config: Config, socket: TSocket, event_sender: Sender<ReceiveEvent>) -> Self {
Self {
config,
socket,
event_sender,
}
}
}
impl<TSocket: DatagramSocket, ReceiveEvent: Debug> ConnectionMessenger<ReceiveEvent>
for SocketEventSenderAndConfig<TSocket, ReceiveEvent>
{
fn config(&self) -> &Config {
&self.config
}
fn send_event(&mut self, _address: &SocketAddr, event: ReceiveEvent) {
self.event_sender.send(event).expect("Receiver must exists");
}
fn send_packet(&mut self, address: &SocketAddr, payload: &[u8]) {
if let Err(err) = self.socket.send_packet(address, payload) {
error!("Error occured sending a packet (to {}): {}", address, err)
}
}
}
/// Owns every connection reachable through one socket and routes inbound packets and
/// user-submitted events to them; it is advanced by calling `manual_poll`.
#[derive(Debug)]
pub struct ConnectionManager<TSocket, TConnection>
where
    TSocket: DatagramSocket,
    TConnection: Connection,
{
    /// Tracked connections, keyed by remote address.
    connections: HashMap<SocketAddr, TConnection>,
    /// Scratch buffer each incoming datagram is read into.
    receive_buffer: Vec<u8>,
    /// Receiving half of the channel users push outgoing events into.
    user_event_receiver: Receiver<TConnection::SendEvent>,
    /// Socket, configuration and event sender handed to connections.
    messenger: SocketEventSenderAndConfig<TSocket, TConnection::ReceiveEvent>,
    /// Receiving half of the channel connections publish events on; exposed to users.
    event_receiver: Receiver<TConnection::ReceiveEvent>,
    /// Sending half of the user event channel; exposed to users.
    user_event_sender: Sender<TConnection::SendEvent>,
    /// Maximum number of unestablished connections inbound packets may create.
    max_unestablished_connections: u16,
}
impl<TSocket: DatagramSocket, TConnection: Connection> ConnectionManager<TSocket, TConnection> {
pub fn new(socket: TSocket, config: Config) -> Self {
let (event_sender, event_receiver) = unbounded();
let (user_event_sender, user_event_receiver) = unbounded();
let max_unestablished_connections = config.max_unestablished_connections;
ConnectionManager {
receive_buffer: vec![0; config.receive_buffer_max_size],
connections: Default::default(),
user_event_receiver,
messenger: SocketEventSenderAndConfig::new(config, socket, event_sender),
user_event_sender,
event_receiver,
max_unestablished_connections,
}
}
pub fn manual_poll(&mut self, time: Instant) {
let mut unestablished_connections = self.unestablished_connection_count();
let messenger = &mut self.messenger;
loop {
match messenger
.socket
.receive_packet(self.receive_buffer.as_mut())
{
Ok((payload, address)) => {
if let Some(conn) = self.connections.get_mut(&address) {
let was_est = conn.is_established();
conn.process_packet(messenger, payload, time);
if !was_est && conn.is_established() {
unestablished_connections -= 1;
}
} else {
let mut conn = TConnection::create_connection(messenger, address, time);
conn.process_packet(messenger, payload, time);
if unestablished_connections < self.max_unestablished_connections as usize {
self.connections.insert(address, conn);
unestablished_connections += 1;
}
}
}
Err(e) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
error!("Encountered an error receiving data: {:?}", e);
}
break;
}
}
if messenger.socket.is_blocking_mode() {
break;
}
}
while let Ok(event) = self.user_event_receiver.try_recv() {
let conn = self.connections.entry(event.address()).or_insert_with(|| {
TConnection::create_connection(messenger, event.address(), time)
});
let was_est = conn.is_established();
conn.process_event(messenger, event, time);
if !was_est && conn.is_established() {
unestablished_connections -= 1;
}
}
for conn in self.connections.values_mut() {
conn.update(messenger, time);
}
self.connections
.retain(|_, conn| !conn.should_drop(messenger, time));
}
pub fn event_sender(&self) -> &Sender<TConnection::SendEvent> {
&self.user_event_sender
}
pub fn event_receiver(&self) -> &Receiver<TConnection::ReceiveEvent> {
&self.event_receiver
}
pub fn socket(&self) -> &TSocket {
&self.messenger.socket
}
fn unestablished_connection_count(&self) -> usize {
self.connections
.iter()
.filter(|c| !c.1.is_established())
.count()
}
#[allow(dead_code)]
pub fn socket_mut(&mut self) -> &mut TSocket {
&mut self.messenger.socket
}
#[cfg(test)]
pub fn connections_count(&self) -> usize {
self.connections.len()
}
}
// Tests for `ConnectionManager`, driven through `FakeSocket`/`NetworkEmulator` from
// `crate::test_utils`.
#[cfg(test)]
mod tests {
use std::{
collections::HashSet,
net::{SocketAddr, SocketAddrV4},
time::{Duration, Instant},
};
use crate::net::LinkConditioner;
use crate::test_utils::*;
use crate::{Config, Packet, SocketEvent};
// Loopback addresses for the emulated server and client sockets.
const SERVER_ADDR: &str = "127.0.0.1:10001";
const CLIENT_ADDR: &str = "127.0.0.1:10002";
fn client_address() -> SocketAddr {
CLIENT_ADDR.parse().unwrap()
}
// Client address on port 10002 + n; `client_address_n(0)` equals `client_address()`.
fn client_address_n(n: u16) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new("127.0.0.1".parse().unwrap(), 10002 + n))
}
fn server_address() -> SocketAddr {
SERVER_ADDR.parse().unwrap()
}
// Binds a server and a client with default config on a fresh emulator; the emulator is
// returned so tests can clear in-flight packets.
fn create_server_client_network() -> (FakeSocket, FakeSocket, NetworkEmulator) {
let network = NetworkEmulator::default();
let server = FakeSocket::bind(&network, server_address(), Config::default()).unwrap();
let client = FakeSocket::bind(&network, client_address(), Config::default()).unwrap();
(server, client, network)
}
// Binds a server and a client sharing `config`; the emulator handle is not returned.
fn create_server_client(config: Config) -> (FakeSocket, FakeSocket) {
let network = NetworkEmulator::default();
let server = FakeSocket::bind(&network, server_address(), config.clone()).unwrap();
let client = FakeSocket::bind(&network, client_address(), config).unwrap();
(server, client)
}
// A packet pushed through the client's packet sender arrives as `SocketEvent::Packet` on the
// server's event receiver.
#[test]
fn using_sender_and_receiver() {
let (mut server, mut client, _) = create_server_client_network();
let sender = client.get_packet_sender();
let receiver = server.get_event_receiver();
sender
.send(Packet::reliable_unordered(
server_address(),
b"Hello world!".to_vec(),
))
.unwrap();
let time = Instant::now();
client.manual_poll(time);
server.manual_poll(time);
if let SocketEvent::Packet(packet) = receiver.recv().unwrap() {
assert_eq![b"Hello world!", packet.payload()];
} else {
panic!["Did not receive a packet when it should"];
}
}
// The first reliable-unordered packet is cleared from the emulated network before the server
// reads it; it must still reach the server within 255 rounds of two-way traffic.
#[test]
fn initial_packet_is_resent() {
let (mut server, mut client, network) = create_server_client_network();
let time = Instant::now();
client
.send(Packet::reliable_unordered(
server_address(),
b"Do not arrive".to_vec(),
))
.unwrap();
client.manual_poll(time);
network.clear_packets(server_address());
for id in 0..u8::max_value() {
client
.send(Packet::reliable_unordered(server_address(), vec![id]))
.unwrap();
server
.send(Packet::reliable_unordered(client_address(), vec![id]))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
while let Some(SocketEvent::Packet(pkt)) = server.recv() {
if pkt.payload() == b"Do not arrive" {
return;
}
}
while client.recv().is_some() {}
}
panic!["Did not receive the ignored packet"];
}
// Three unknown senders hit a server capped at two unestablished connections: all three
// packets yield events, but only two connections are kept, and a user send to the first
// client (already tracked) leaves the count at two.
#[test]
fn receiving_does_not_allow_denial_of_service() {
let time = Instant::now();
let network = NetworkEmulator::default();
let mut server = FakeSocket::bind(
&network,
server_address(),
Config {
max_unestablished_connections: 2,
..Default::default()
},
)
.unwrap();
for i in 0..3 {
let mut client =
FakeSocket::bind(&network, client_address_n(i), Config::default()).unwrap();
client
.send(Packet::unreliable(
server_address(),
vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
))
.unwrap();
client.manual_poll(time);
}
server.manual_poll(time);
for _ in 0..3 {
assert![server.recv().is_some()];
}
assert![server.recv().is_none()];
assert_eq![2, server.connection_count()];
server
.send(Packet::unreliable(client_address(), vec![1]))
.unwrap();
server.manual_poll(time);
assert_eq![2, server.connection_count()];
}
// Despite its name, this asserts the dropped reliable-sequenced packet is NEVER delivered
// during the following 36 rounds.
#[test]
fn initial_sequenced_is_resent() {
let (mut server, mut client, network) = create_server_client_network();
let time = Instant::now();
client
.send(Packet::reliable_sequenced(
server_address(),
b"Do not arrive".to_vec(),
None,
))
.unwrap();
client.manual_poll(time);
network.clear_packets(server_address());
for id in 0..36 {
client
.send(Packet::reliable_sequenced(server_address(), vec![id], None))
.unwrap();
server
.send(Packet::reliable_sequenced(client_address(), vec![id], None))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
while let Some(SocketEvent::Packet(pkt)) = server.recv() {
if pkt.payload() == b"Do not arrive" {
panic!["Sequenced packet arrived while it should not"];
}
}
while client.recv().is_some() {}
}
}
// The dropped first reliable-ordered packet must reach the server within 35 rounds.
#[test]
fn initial_ordered_is_resent() {
let (mut server, mut client, network) = create_server_client_network();
let time = Instant::now();
client
.send(Packet::reliable_ordered(
server_address(),
b"Do not arrive".to_vec(),
None,
))
.unwrap();
client.manual_poll(time);
network.clear_packets(server_address());
for id in 0..35 {
client
.send(Packet::reliable_ordered(server_address(), vec![id], None))
.unwrap();
server
.send(Packet::reliable_ordered(client_address(), vec![id], None))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
while let Some(SocketEvent::Packet(pkt)) = server.recv() {
if pkt.payload() == b"Do not arrive" {
return;
}
}
while client.recv().is_some() {}
}
panic!["Did not receive the ignored packet"];
}
// Each of 100 reliable-sequenced payloads is delivered to the server exactly once.
#[test]
fn do_not_duplicate_sequenced_packets_when_received() {
let (mut server, mut client, _) = create_server_client_network();
let time = Instant::now();
for id in 0..100 {
client
.send(Packet::reliable_sequenced(server_address(), vec![id], None))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
}
let mut seen = HashSet::new();
while let Some(message) = server.recv() {
match message {
SocketEvent::Connect(_) => {}
SocketEvent::Packet(packet) => {
let byte = packet.payload()[0];
assert![!seen.contains(&byte)];
seen.insert(byte);
}
SocketEvent::Timeout(_) | SocketEvent::Disconnect(_) => {
panic!["This should not happen, as we've not advanced time"];
}
}
}
assert_eq![100, seen.len()];
}
// Sends 65536 + 100 unreliable-sequenced packets (presumably enough to wrap a 16-bit
// sequence number) and expects every one to be delivered.
#[test]
fn more_than_65536_sequenced_packets() {
let (mut server, mut client, _) = create_server_client_network();
server
.send(Packet::unreliable(client_address(), vec![0]))
.unwrap();
let time = Instant::now();
for id in 0..65536 + 100 {
client
.send(Packet::unreliable_sequenced(
server_address(),
id.to_string().as_bytes().to_vec(),
None,
))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
}
let mut cnt = 0;
while let Some(message) = server.recv() {
match message {
SocketEvent::Connect(_) => {}
SocketEvent::Packet(_) => {
cnt += 1;
}
SocketEvent::Timeout(_) | SocketEvent::Disconnect(_) => {
panic!["This should not happen, as we've not advanced time"];
}
}
}
assert_eq![65536 + 100, cnt];
}
// With `max_packets_in_flight` = 100 and the server socket discarded, the 101st
// reliable-sequenced send (id 100) must produce a `Timeout` for the server address, with no
// other event beforehand.
#[test]
fn sequenced_packets_pathological_case() {
let config = Config {
max_packets_in_flight: 100,
..Default::default()
};
let (_, mut client) = create_server_client(config);
let time = Instant::now();
for id in 0..101 {
client
.send(Packet::reliable_sequenced(
server_address(),
id.to_string().as_bytes().to_vec(),
None,
))
.unwrap();
client.manual_poll(time);
while let Some(event) = client.recv() {
match event {
SocketEvent::Timeout(remote_addr) => {
assert_eq![100, id];
assert_eq![remote_addr, server_address()];
return;
}
_ => {
panic!["No other event possible"];
}
}
}
}
panic!["Should have received a timeout event"];
}
// Three unreliable packets sent before one client poll are all received in one server poll.
#[test]
fn manual_polling_socket() {
let (mut server, mut client, _) = create_server_client_network();
for _ in 0..3 {
client
.send(Packet::unreliable(
server_address(),
vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
))
.unwrap();
}
let time = Instant::now();
client.manual_poll(time);
server.manual_poll(time);
assert!(server.recv().is_some());
assert!(server.recv().is_some());
assert!(server.recv().is_some());
}
// Same scenario as `manual_polling_socket`.
#[test]
fn can_send_and_receive() {
let (mut server, mut client, _) = create_server_client_network();
for _ in 0..3 {
client
.send(Packet::unreliable(
server_address(),
vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
))
.unwrap();
}
let now = Instant::now();
client.manual_poll(now);
server.manual_poll(now);
assert!(server.recv().is_some());
assert!(server.recv().is_some());
assert!(server.recv().is_some());
}
// After packets flow both ways, the server reports the client's packet followed by
// `Connect(client)`.
#[test]
fn connect_event_occurs() {
let (mut server, mut client, _) = create_server_client_network();
client
.send(Packet::unreliable(server_address(), vec![0, 1, 2]))
.unwrap();
server
.send(Packet::unreliable(client_address(), vec![2, 1, 0]))
.unwrap();
let now = Instant::now();
client.manual_poll(now);
server.manual_poll(now);
assert!(matches!(server.recv().unwrap(), SocketEvent::Packet(_)));
assert_eq!(
server.recv().unwrap(),
SocketEvent::Connect(client_address())
);
}
// With a 1 ms idle timeout: polling 1 ms earlier yields no events; polling at the timeout
// makes both sides report `Timeout` then `Disconnect`.
#[test]
fn disconnect_event_occurs() {
let config = Config {
idle_connection_timeout: Duration::from_millis(1),
..Default::default()
};
let (mut server, mut client) = create_server_client(config.clone());
client
.send(Packet::unreliable(server_address(), vec![0, 1, 2]))
.unwrap();
let now = Instant::now();
client.manual_poll(now);
server.manual_poll(now);
assert_eq!(
server.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(client_address(), vec![0, 1, 2]))
);
server
.send(Packet::unreliable(client_address(), vec![]))
.unwrap();
server.manual_poll(now);
client.manual_poll(now);
assert_eq!(
server.recv().unwrap(),
SocketEvent::Connect(client_address())
);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Connect(server_address())
);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(server_address(), vec![]))
);
server.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1));
client.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1));
assert_eq!(server.recv(), None);
assert_eq!(client.recv(), None);
server.manual_poll(now + config.idle_connection_timeout);
client.manual_poll(now + config.idle_connection_timeout);
assert_eq!(
server.recv().unwrap(),
SocketEvent::Timeout(client_address())
);
assert_eq!(
server.recv().unwrap(),
SocketEvent::Disconnect(client_address())
);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Timeout(server_address())
);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Disconnect(server_address())
);
}
// With a 4 ms heartbeat and a 10 ms idle timeout, polling at the heartbeat interval keeps
// both sides from timing out when polled again at the idle-timeout mark.
#[test]
fn heartbeats_work() {
let config = Config {
idle_connection_timeout: Duration::from_millis(10),
heartbeat_interval: Some(Duration::from_millis(4)),
..Default::default()
};
let (mut server, mut client) = create_server_client(config.clone());
client
.send(Packet::unreliable(server_address(), vec![0, 1, 2]))
.unwrap();
let now = Instant::now();
client.manual_poll(now);
server.manual_poll(now);
assert_eq!(
server.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(client_address(), vec![0, 1, 2]))
);
server
.send(Packet::unreliable(client_address(), vec![]))
.unwrap();
server.manual_poll(now);
client.manual_poll(now);
assert_eq!(
server.recv().unwrap(),
SocketEvent::Connect(client_address())
);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Connect(server_address())
);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(server_address(), vec![]))
);
client.manual_poll(now + config.heartbeat_interval.unwrap());
server.manual_poll(now + config.heartbeat_interval.unwrap());
client.manual_poll(now + config.idle_connection_timeout);
server.manual_poll(now + config.idle_connection_timeout);
assert_eq!(client.recv(), None);
assert_eq!(server.recv(), None);
}
// 35 unreliable packets yield 35 server events before the connection is established. After
// the server replies and the client sees it, the next event the server delivers must be the
// packet carrying 35.
#[test]
fn multiple_sends_should_start_sending_dropped() {
let (mut server, mut client, _) = create_server_client_network();
let now = Instant::now();
for i in 0..35 {
client
.send(Packet::unreliable(server_address(), vec![i]))
.unwrap();
client.manual_poll(now);
}
let mut events = Vec::new();
loop {
server.manual_poll(now);
if let Some(event) = server.recv() {
events.push(event);
} else {
break;
}
}
assert_eq!(events.len(), 35);
server
.send(Packet::unreliable(client_address(), vec![0]))
.unwrap();
server.manual_poll(now);
assert_eq!(
server.recv().unwrap(),
SocketEvent::Connect(client_address())
);
loop {
client.manual_poll(now);
if client.recv().is_some() {
break;
}
}
events.clear();
client
.send(Packet::unreliable(server_address(), vec![35]))
.unwrap();
client.manual_poll(now);
loop {
server.manual_poll(now);
if let Some(event) = server.recv() {
events.push(event);
break;
}
}
let sent_events: Vec<u8> = events
.iter()
.flat_map(|e| match e {
SocketEvent::Packet(p) => Some(p.payload()[0]),
_ => None,
})
.collect();
assert_eq!(sent_events, vec![35]);
}
// With 90% packet loss on both sides, repeated rounds of reliable traffic must eventually
// deliver all 101 distinct first bytes: payloads 0..100 plus the dummy 255.
#[test]
fn really_bad_network_keeps_chugging_along() {
let (mut server, mut client, _) = create_server_client_network();
let time = Instant::now();
let link_conditioner = {
let mut lc = LinkConditioner::new();
lc.set_packet_loss(0.9);
Some(lc)
};
client.set_link_conditioner(link_conditioner.clone());
server.set_link_conditioner(link_conditioner);
let mut set = HashSet::new();
// Runs 100 send/poll rounds (payload `id`, or `dummy` when given) and returns how many
// distinct first bytes the server has seen so far.
let mut send_many_packets = |dummy: Option<u8>| {
for id in 0..100 {
client
.send(Packet::reliable_unordered(
server_address(),
vec![dummy.unwrap_or(id)],
))
.unwrap();
server
.send(Packet::reliable_unordered(client_address(), vec![255]))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
while client.recv().is_some() {}
while let Some(event) = server.recv() {
match event {
SocketEvent::Packet(pkt) => {
set.insert(pkt.payload()[0]);
}
SocketEvent::Timeout(_) | SocketEvent::Disconnect(_) => {
panic!["Unable to time out, time has not advanced"]
}
SocketEvent::Connect(_) => {}
}
}
}
set.len()
};
send_many_packets(None);
send_many_packets(Some(255));
send_many_packets(Some(255));
send_many_packets(Some(255));
assert_eq![101, send_many_packets(Some(255))];
}
// With a 10-byte fragment size, a 17-byte reliable-ordered payload (presumably fragmented)
// is sent. Afterwards the server must deliver each of 34 reliable-ordered dummy packets on
// the round it is sent.
#[test]
fn fragmented_ordered_gets_acked() {
let config = Config {
fragment_size: 10,
..Default::default()
};
let (mut server, mut client) = create_server_client(config);
let time = Instant::now();
let dummy = vec![0];
client
.send(Packet::unreliable(server_address(), dummy.clone()))
.unwrap();
client.manual_poll(time);
server
.send(Packet::unreliable(client_address(), dummy.clone()))
.unwrap();
server.manual_poll(time);
let exceeds = b"Fragmented string".to_vec();
client
.send(Packet::reliable_ordered(server_address(), exceeds, None))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
server.manual_poll(time);
server
.send(Packet::reliable_ordered(
client_address(),
dummy.clone(),
None,
))
.unwrap();
client
.send(Packet::unreliable(server_address(), dummy.clone()))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
for _ in 0..4 {
assert![server.recv().is_some()];
}
assert![server.recv().is_none()];
for _ in 0..34 {
client
.send(Packet::reliable_ordered(
server_address(),
dummy.clone(),
None,
))
.unwrap();
client.manual_poll(time);
server
.send(Packet::reliable_ordered(
client_address(),
dummy.clone(),
None,
))
.unwrap();
server.manual_poll(time);
assert![client.recv().is_some()];
match server.recv() {
Some(SocketEvent::Packet(pkt)) => {
assert_eq![dummy, pkt.payload()];
}
_ => {
panic!["Did not receive expected dummy packet"];
}
}
}
}
// Property test: arbitrary bytes sent from a raw emulated socket must not make the server's
// poll panic.
#[quickcheck_macros::quickcheck]
fn do_not_panic_on_arbitrary_packets(bytes: Vec<u8>) {
use crate::net::DatagramSocket;
let network = NetworkEmulator::default();
let mut server = FakeSocket::bind(&network, server_address(), Config::default()).unwrap();
let mut client_socket = network.new_socket(client_address()).unwrap();
client_socket
.send_packet(&server_address(), &bytes)
.unwrap();
let time = Instant::now();
server.manual_poll(time);
}
}