use crate::either::Either::{Left, Right};
use crate::{
config::Config,
error::{ErrorKind, Result},
net::{connection::ActiveConnections, events::SocketEvent, link_conditioner::LinkConditioner},
packet::{DeliveryGuarantee, Outgoing, Packet},
};
use crossbeam_channel::{self, unbounded, Receiver, SendError, Sender, TryRecvError};
use log::error;
use std::{
self, io,
net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket},
thread::{sleep, yield_now},
time::{Duration, Instant},
};
/// A UDP socket together with the per-remote connection state and the channels used to hand
/// packets and events between the socket and the rest of the program.
#[derive(Debug)]
pub struct Socket {
// The underlying UDP socket that datagrams are read from and written to.
socket: UdpSocket,
// Settings supplied at bind time (blocking mode, buffer size, timeouts, heartbeat interval).
config: Config,
// State for every remote address this socket currently tracks.
connections: ActiveConnections,
// Scratch buffer handed to `UdpSocket::recv_from`; sized from `config.receive_buffer_max_size`.
recv_buffer: Vec<u8>,
// Optional link conditioner consulted before each outgoing datagram is sent.
link_conditioner: Option<LinkConditioner>,
// Producer side of the event channel; `Connect`/`Timeout` events and incoming packets go here.
event_sender: Sender<SocketEvent>,
// Consumer side of the outgoing-packet channel; drained by `manual_poll`.
packet_receiver: Receiver<Packet>,
// Consumer side of the event channel; read by `recv` and cloned by `get_event_receiver`.
receiver: Receiver<SocketEvent>,
// Producer side of the outgoing-packet channel; used by `send` and cloned by `get_packet_sender`.
sender: Sender<Packet>,
}
/// Outcome of one receive attempt, telling the poll loop whether it should try again.
enum UdpSocketState {
// Stop reading for now: the read would block, or the socket is in blocking mode.
MaybeEmpty,
// A datagram was read from a non-blocking socket, so another one may be waiting.
MaybeMore,
}
impl Socket {
pub fn bind<A: ToSocketAddrs>(addresses: A) -> Result<Self> {
Socket::bind_with_config(addresses, Config::default())
}
/// Binds a socket to the loopback address on any available port, using the default [`Config`].
pub fn bind_any() -> Result<Self> {
    let default_config = Config::default();
    Self::bind_any_with_config(default_config)
}
/// Binds a socket to `127.0.0.1` with port `0` (letting the OS choose the port) using `config`.
pub fn bind_any_with_config(config: Config) -> Result<Self> {
    let any_loopback_port = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0);
    Self::bind_internal(UdpSocket::bind(any_loopback_port)?, config)
}
pub fn bind_with_config<A: ToSocketAddrs>(addresses: A, config: Config) -> Result<Self> {
let socket = UdpSocket::bind(addresses)?;
Self::bind_internal(socket, config)
}
/// Finishes construction: applies the blocking mode from `config` to `socket`, creates the
/// packet and event channels, and assembles the `Socket`.
fn bind_internal(socket: UdpSocket, config: Config) -> Result<Self> {
    let nonblocking = !config.blocking_mode;
    socket.set_nonblocking(nonblocking)?;

    let (sender, packet_receiver) = unbounded();
    let (event_sender, receiver) = unbounded();
    // Size the buffer before `config` is moved into the struct.
    let recv_buffer = vec![0; config.receive_buffer_max_size];

    Ok(Socket {
        socket,
        config,
        connections: ActiveConnections::new(),
        recv_buffer,
        link_conditioner: None,
        event_sender,
        packet_receiver,
        receiver,
        sender,
    })
}
/// Returns a clone of the sender half of the outgoing-packet channel.
///
/// Packets pushed through the returned sender are picked up by the next `manual_poll` call.
/// Only a shared borrow is required: cloning a channel handle does not mutate the socket, so
/// callers holding `&Socket` can obtain a sender too.
pub fn get_packet_sender(&self) -> Sender<Packet> {
    self.sender.clone()
}
/// Returns a clone of the receiver half of the event channel.
///
/// Only a shared borrow is required: cloning a channel handle does not mutate the socket, so
/// callers holding `&Socket` can obtain a receiver too.
pub fn get_event_receiver(&self) -> Receiver<SocketEvent> {
    self.receiver.clone()
}
pub fn send(&mut self, packet: Packet) -> Result<()> {
match self.sender.send(packet) {
Ok(_) => Ok(()),
Err(error) => Err(ErrorKind::SendError(SendError(SocketEvent::Packet(
error.0,
)))),
}
}
/// Returns the next pending event without blocking, or `None` if no event is queued.
///
/// # Panics
///
/// Panics if the event channel reports that it is disconnected.
pub fn recv(&mut self) -> Option<SocketEvent> {
    let attempt = self.receiver.try_recv();
    if let Err(TryRecvError::Disconnected) = attempt {
        panic!["This can never happen"];
    }
    // Only `Ok` and `Empty` remain here; `Empty` maps to `None`.
    attempt.ok()
}
/// Polls the socket forever, sleeping one millisecond between polls. Never returns.
pub fn start_polling(&mut self) {
    let default_pause = Duration::from_millis(1);
    self.start_polling_with_duration(Some(default_pause))
}
/// Polls the socket forever. Never returns.
///
/// After each poll the thread sleeps for `sleep_duration` if one is given, and otherwise
/// yields its time slice.
pub fn start_polling_with_duration(&mut self, sleep_duration: Option<Duration>) {
    loop {
        let now = Instant::now();
        self.manual_poll(now);
        if let Some(pause) = sleep_duration {
            sleep(pause);
        } else {
            yield_now();
        }
    }
}
pub fn manual_poll(&mut self, time: Instant) {
loop {
match self.recv_from(time) {
Ok(UdpSocketState::MaybeMore) => continue,
Ok(UdpSocketState::MaybeEmpty) => break,
Err(e) => error!("Encountered an error receiving data: {:?}", e),
}
}
while let Ok(p) = self.packet_receiver.try_recv() {
if let Err(e) = self.send_to(p, time) {
match e {
ErrorKind::IOError(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
_ => error!("There was an error sending packet: {:?}", e),
}
}
}
if let Err(e) = self.handle_idle_clients(time) {
error!("Encountered an error when sending TimeoutEvent: {:?}", e);
}
self.handle_dead_clients().expect("Internal laminar error");
if let Some(heartbeat_interval) = self.config.heartbeat_interval {
if let Err(e) = self.send_heartbeat_packets(heartbeat_interval, time) {
match e {
ErrorKind::IOError(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
_ => error!("There was an error sending a heartbeat packet: {:?}", e),
}
}
}
}
/// Installs, replaces, or (with `None`) removes the link conditioner that decides whether
/// each outgoing datagram is actually sent.
pub fn set_link_conditioner(&mut self, link_conditioner: Option<LinkConditioner>) {
self.link_conditioner = link_conditioner;
}
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.socket.local_addr()?)
}
/// Removes every connection reported as dead and emits a `Timeout` event for each.
///
/// # Errors
///
/// Returns an error if an event cannot be pushed onto the event channel.
fn handle_dead_clients(&mut self) -> Result<()> {
    for dead_peer in self.connections.dead_connections() {
        self.connections.remove_connection(&dead_peer);
        let timeout = SocketEvent::Timeout(dead_peer);
        self.event_sender.send(timeout)?;
    }
    Ok(())
}
/// Removes every connection that has been idle for `config.idle_connection_timeout` as of
/// `time`, and emits a `Timeout` event for each.
///
/// # Errors
///
/// Returns an error if an event cannot be pushed onto the event channel.
fn handle_idle_clients(&mut self, time: Instant) -> Result<()> {
    let timeout_after = self.config.idle_connection_timeout;
    for idle_peer in self.connections.idle_connections(timeout_after, time) {
        self.connections.remove_connection(&idle_peer);
        let timeout = SocketEvent::Timeout(idle_peer);
        self.event_sender.send(timeout)?;
    }
    Ok(())
}
/// Sends a heartbeat to every connection that requires one for `heartbeat_interval` at `time`.
///
/// Returns the total number of bytes written. Heartbeats rejected by the link conditioner
/// are skipped.
///
/// # Errors
///
/// Returns the first send error; remaining heartbeats are not sent.
fn send_heartbeat_packets(
    &mut self,
    heartbeat_interval: Duration,
    time: Instant,
) -> Result<usize> {
    // Build every heartbeat up front so the borrow of `self.connections` ends before sending.
    let pending: Vec<_> = self
        .connections
        .heartbeat_required_connections(heartbeat_interval, time)
        .map(|connection| {
            let heartbeat = connection.create_and_process_heartbeat(time);
            (heartbeat, connection.remote_address)
        })
        .collect();

    let mut total_bytes = 0;
    for (heartbeat, remote) in pending {
        if !self.should_send_packet() {
            continue;
        }
        total_bytes += self.send_packet(&remote, &heartbeat.contents())?;
    }
    Ok(total_bytes)
}
/// Processes `packet` through its connection and writes the resulting datagram(s) to the
/// socket, along with any previously dropped packets the connection hands back.
///
/// Returns the total number of bytes written. Datagrams rejected by the link conditioner are
/// skipped and do not count towards the total.
///
/// # Errors
///
/// Returns an error if `packet` itself cannot be processed, or on the first failed socket
/// write (remaining datagrams are then not sent).
fn send_to(&mut self, packet: Packet, time: Instant) -> Result<usize> {
// Look up the connection for the destination, inserting one if needed.
let connection =
self.connections
.get_or_insert_connection(packet.addr(), &self.config, time);
// Re-process packets the connection reports as dropped; they are sent ahead of `packet`.
let dropped = connection.gather_dropped_packets();
let mut processed_packets: Vec<Outgoing> = dropped
.iter()
// NOTE(review): `flat_map` over a `Result` silently discards any dropped packet whose
// re-processing fails — confirm this best-effort behaviour is intended.
.flat_map(|waiting_packet| {
connection.process_outgoing(
&waiting_packet.payload,
DeliveryGuarantee::Reliable,
waiting_packet.ordering_guarantee,
waiting_packet.item_identifier,
time,
)
})
.collect();
// Unlike the dropped packets above, a failure to process the new packet is propagated.
let processed_packet = connection.process_outgoing(
packet.payload(),
packet.delivery_guarantee(),
packet.order_guarantee(),
None,
time,
)?;
processed_packets.push(processed_packet);
let mut bytes_sent = 0;
for processed_packet in processed_packets {
// The link conditioner is consulted once per processed packet, not once per fragment.
if self.should_send_packet() {
match processed_packet {
Outgoing::Packet(outgoing) => {
bytes_sent += self.send_packet(&packet.addr(), &outgoing.contents())?;
}
Outgoing::Fragments(packets) => {
for outgoing in packets {
bytes_sent += self.send_packet(&packet.addr(), &outgoing.contents())?;
}
}
}
}
}
Ok(bytes_sent)
}
fn recv_from(&mut self, time: Instant) -> Result<UdpSocketState> {
match self.socket.recv_from(&mut self.recv_buffer) {
Ok((recv_len, address)) => {
if recv_len == 0 {
return Err(ErrorKind::ReceivedDataToShort)?;
}
let received_payload = &self.recv_buffer[..recv_len];
if !self.connections.exists(&address) {
self.event_sender.send(SocketEvent::Connect(address))?;
}
let connection =
self.connections
.get_or_create_connection(address, &self.config, time);
match connection {
Left(existing) => {
existing.process_incoming(received_payload, &self.event_sender, time)?;
}
Right(mut anonymous) => {
anonymous.process_incoming(received_payload, &self.event_sender, time)?;
}
}
}
Err(e) => {
if e.kind() != io::ErrorKind::WouldBlock {
error!("Encountered an error receiving data: {:?}", e);
return Err(e.into());
} else {
return Ok(UdpSocketState::MaybeEmpty);
}
}
}
if self.config.blocking_mode {
Ok(UdpSocketState::MaybeEmpty)
} else {
Ok(UdpSocketState::MaybeMore)
}
}
fn send_packet(&self, addr: &SocketAddr, payload: &[u8]) -> Result<usize> {
let bytes_sent = self.socket.send_to(payload, addr)?;
Ok(bytes_sent)
}
/// Asks the link conditioner, if one is installed, whether the next datagram should be sent.
/// Without a link conditioner every datagram is sent.
fn should_send_packet(&mut self) -> bool {
    match self.link_conditioner.as_mut() {
        Some(conditioner) => conditioner.should_send(),
        None => true,
    }
}
/// Test-only helper: returns the count reported by the active-connection table.
#[cfg(test)]
fn connection_count(&self) -> usize {
self.connections.count()
}
/// Test-only helper: waits 100 ms, then reads and discards every datagram currently queued on
/// the socket without processing it, simulating packet loss.
///
/// The socket is switched to non-blocking mode while draining and restored to the configured
/// blocking mode afterwards.
///
/// # Panics
///
/// Panics on a zero-length datagram, on any I/O error other than `WouldBlock`, or if the
/// blocking mode of the socket cannot be changed.
#[cfg(test)]
fn forget_all_incoming_packets(&mut self) {
    std::thread::sleep(std::time::Duration::from_millis(100));
    self.socket
        .set_nonblocking(true)
        .expect("failed to switch the socket to non-blocking mode");
    loop {
        match self.socket.recv_from(&mut self.recv_buffer) {
            Ok((0, _address)) => panic!("Received data too short"),
            // Datagram read; drop it on the floor and keep draining.
            Ok(_) => {}
            // Nothing left to read.
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
            Err(e) => panic!("Encountered an error receiving data: {:?}", e),
        }
    }
    self.socket
        .set_nonblocking(!self.config.blocking_mode)
        .expect("failed to restore the configured blocking mode");
}
}
#[cfg(test)]
mod tests {
use crate::{
net::constants::{ACKED_PACKET_HEADER, FRAGMENT_HEADER_SIZE, STANDARD_HEADER_SIZE},
Config, LinkConditioner, Packet, Socket, SocketEvent,
};
use std::collections::HashSet;
use std::net::{SocketAddr, UdpSocket};
use std::time::{Duration, Instant};
/// Both `bind_any` constructors succeed.
#[test]
fn binding_to_any() {
    let with_defaults = Socket::bind_any();
    assert!(with_defaults.is_ok());
    let with_config = Socket::bind_any_with_config(Config::default());
    assert!(with_config.is_ok());
}
/// A non-blocking client sends one unreliable packet to a blocking-mode server; the server
/// must report a `Connect` event followed by the packet.
#[test]
fn blocking_sender_and_receiver() {
let cfg = Config::default();
let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap();
// Only the server runs in blocking mode.
let mut server = Socket::bind_any_with_config(Config {
blocking_mode: true,
..cfg
})
.unwrap();
let server_addr = server.local_addr().unwrap();
let client_addr = client.local_addr().unwrap();
let time = Instant::now();
client
.send(Packet::unreliable(
server_addr,
b"Hello world!".iter().cloned().collect::<Vec<_>>(),
))
.unwrap();
// Client flushes the packet, then the server reads it.
client.manual_poll(time);
server.manual_poll(time);
assert_eq![SocketEvent::Connect(client_addr), server.recv().unwrap()];
if let SocketEvent::Packet(packet) = server.recv().unwrap() {
assert_eq![b"Hello world!", packet.payload()];
} else {
panic!["Did not receive a packet when it should"];
}
}
/// Sends through the cloned packet sender and receives through the cloned event receiver
/// instead of calling `Socket::send` / `Socket::recv` directly.
#[test]
fn using_sender_and_receiver() {
let server_addr = "127.0.0.1:12310".parse::<SocketAddr>().unwrap();
let client_addr = "127.0.0.1:12311".parse::<SocketAddr>().unwrap();
let mut server = Socket::bind(server_addr).unwrap();
let mut client = Socket::bind(client_addr).unwrap();
let time = Instant::now();
// Channel handles obtained from the sockets.
let sender = client.get_packet_sender();
let receiver = server.get_event_receiver();
sender
.send(Packet::reliable_unordered(
server_addr,
b"Hello world!".iter().cloned().collect::<Vec<_>>(),
))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
// The first event is the connect notification, the second is the packet itself.
assert_eq![Ok(SocketEvent::Connect(client_addr)), receiver.recv()];
if let SocketEvent::Packet(packet) = receiver.recv().unwrap() {
assert_eq![b"Hello world!", packet.payload()];
} else {
panic!["Did not receive a packet when it should"];
}
}
/// A reliable-unordered packet that the server never reads must eventually be delivered
/// again while both sides keep exchanging other reliable packets.
#[test]
fn initial_packet_is_resent() {
let mut server = Socket::bind("127.0.0.1:12335".parse::<SocketAddr>().unwrap()).unwrap();
let mut client = Socket::bind("127.0.0.1:12336".parse::<SocketAddr>().unwrap()).unwrap();
let time = Instant::now();
client
.send(Packet::reliable_unordered(
"127.0.0.1:12335".parse::<SocketAddr>().unwrap(),
b"Do not arrive".iter().cloned().collect::<Vec<_>>(),
))
.unwrap();
client.manual_poll(time);
// Simulate loss: discard whatever reached the server's socket.
server.forget_all_incoming_packets();
for id in 0..u8::max_value() {
client
.send(create_test_packet(id, "127.0.0.1:12335"))
.unwrap();
server
.send(create_test_packet(id, "127.0.0.1:12336"))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
// Success as soon as the originally lost payload shows up on the server.
while let Some(SocketEvent::Packet(pkt)) = server.recv() {
if pkt.payload() == b"Do not arrive" {
return;
}
}
// Drain the client's events so they do not pile up.
while let Some(_) = client.recv() {}
}
panic!["Did not receive the ignored packet"];
}
/// Receiving packets from a peer must not register a connection on its own; a connection is
/// only counted once the server sends something back.
#[test]
fn receiving_does_not_allow_denial_of_service() {
let mut server = Socket::bind("127.0.0.1:12337".parse::<SocketAddr>().unwrap()).unwrap();
let mut client = Socket::bind("127.0.0.1:12338".parse::<SocketAddr>().unwrap()).unwrap();
for _ in 0..3 {
client
.send(Packet::unreliable(
"127.0.0.1:12337".parse::<SocketAddr>().unwrap(),
vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
))
.unwrap();
}
let time = Instant::now();
client.manual_poll(time);
server.manual_poll(time);
// Three packets yield six events — presumably a Connect plus a Packet for each datagram,
// because no connection is stored between them.
for _ in 0..6 {
assert![server.recv().is_some()];
}
assert![server.recv().is_none()];
// Nothing was registered by merely receiving.
assert_eq![0, server.connection_count()];
server
.send(Packet::unreliable(
"127.0.0.1:12338".parse::<SocketAddr>().unwrap(),
vec![1],
))
.unwrap();
server.manual_poll(time);
// Sending to the peer is what creates the connection entry.
assert_eq![1, server.connection_count()];
}
/// A reliable-sequenced packet that the server never reads must NOT be delivered later, once
/// newer sequenced packets have been exchanged.
///
/// Note: despite the test name, the assertion is that the lost payload never arrives.
#[test]
fn initial_sequenced_is_resent() {
let mut server = Socket::bind("127.0.0.1:12329".parse::<SocketAddr>().unwrap()).unwrap();
let mut client = Socket::bind("127.0.0.1:12330".parse::<SocketAddr>().unwrap()).unwrap();
let time = Instant::now();
client
.send(Packet::reliable_sequenced(
"127.0.0.1:12329".parse::<SocketAddr>().unwrap(),
b"Do not arrive".iter().cloned().collect::<Vec<_>>(),
None,
))
.unwrap();
client.manual_poll(time);
// Simulate loss: discard whatever reached the server's socket.
server.forget_all_incoming_packets();
for id in 0..36 {
client
.send(create_sequenced_packet(id, "127.0.0.1:12329"))
.unwrap();
server
.send(create_sequenced_packet(id, "127.0.0.1:12330"))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
// The stale payload must never surface on the server.
while let Some(SocketEvent::Packet(pkt)) = server.recv() {
if pkt.payload() == b"Do not arrive" {
panic!["Sequenced packet arrived while it should not"];
}
}
// Drain the client's events so they do not pile up.
while let Some(_) = client.recv() {}
}
}
/// A reliable-ordered packet that the server never reads must eventually be delivered again
/// while both sides keep exchanging other ordered packets.
#[test]
fn initial_ordered_is_resent() {
let mut server = Socket::bind("127.0.0.1:12333".parse::<SocketAddr>().unwrap()).unwrap();
let mut client = Socket::bind("127.0.0.1:12334".parse::<SocketAddr>().unwrap()).unwrap();
let time = Instant::now();
client
.send(Packet::reliable_ordered(
"127.0.0.1:12333".parse::<SocketAddr>().unwrap(),
b"Do not arrive".iter().cloned().collect::<Vec<_>>(),
None,
))
.unwrap();
client.manual_poll(time);
// Simulate loss: discard whatever reached the server's socket.
server.forget_all_incoming_packets();
for id in 0..35 {
client
.send(create_ordered_packet(id, "127.0.0.1:12333"))
.unwrap();
server
.send(create_ordered_packet(id, "127.0.0.1:12334"))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
// Success as soon as the originally lost payload shows up on the server.
while let Some(SocketEvent::Packet(pkt)) = server.recv() {
if pkt.payload() == b"Do not arrive" {
return;
}
}
// Drain the client's events so they do not pile up.
while let Some(_) = client.recv() {}
}
panic!["Did not receive the ignored packet"];
}
/// Sends 100 distinct reliable-sequenced packets to a blocking-mode server and checks that
/// each payload is delivered exactly once.
#[test]
fn do_not_duplicate_sequenced_packets_when_received() {
    let mut config = Config::default();
    let mut client = Socket::bind_any_with_config(config.clone()).unwrap();
    config.blocking_mode = true;
    let mut server = Socket::bind_any_with_config(config).unwrap();
    let server_addr = server.local_addr().unwrap();
    let time = Instant::now();
    for id in 0..100 {
        client
            .send(Packet::reliable_sequenced(server_addr, vec![id], None))
            .unwrap();
        client.manual_poll(time);
        server.manual_poll(time);
    }
    let mut seen = HashSet::new();
    while let Some(message) = server.recv() {
        match message {
            SocketEvent::Connect(_) => {}
            SocketEvent::Packet(packet) => {
                let byte = packet.payload()[0];
                // `insert` returns false when the value is already present, i.e. a duplicate.
                assert![seen.insert(byte)];
            }
            SocketEvent::Timeout(_) => {
                panic!["This should not happen, as we've not advanced time"];
            }
        }
    }
    assert_eq![100, seen.len()];
}
/// Sends more than 65536 unreliable-sequenced packets and checks that every one of them is
/// delivered to the blocking-mode server.
#[test]
fn more_than_65536_sequenced_packets() {
    let mut config = Config::default();
    let mut client = Socket::bind_any_with_config(config.clone()).unwrap();
    config.blocking_mode = true;
    let mut server = Socket::bind_any_with_config(config).unwrap();
    let server_addr = server.local_addr().unwrap();
    let client_addr = client.local_addr().unwrap();
    // Queue one packet towards the client before the exchange starts.
    server
        .send(Packet::unreliable(client_addr, vec![0]))
        .unwrap();
    let time = Instant::now();
    for id in 0..65536 + 100 {
        client
            .send(Packet::unreliable_sequenced(
                server_addr,
                id.to_string().as_bytes().to_vec(),
                None,
            ))
            .unwrap();
        client.manual_poll(time);
        server.manual_poll(time);
    }
    let mut cnt = 0;
    while let Some(message) = server.recv() {
        match message {
            SocketEvent::Connect(_) => {}
            // Only the number of delivered packets matters, not their contents.
            SocketEvent::Packet(_) => {
                cnt += 1;
            }
            SocketEvent::Timeout(_) => {
                panic!["This should not happen, as we've not advanced time"];
            }
        }
    }
    assert_eq![65536 + 100, cnt];
}
/// With `max_packets_in_flight` set to 100 and a server that never polls, the client must
/// report a `Timeout` for the server exactly when the 101st reliable packet (id 100) is sent.
#[test]
fn sequenced_packets_pathological_case() {
let mut config = Config::default();
config.max_packets_in_flight = 100;
let mut client = Socket::bind_any_with_config(config.clone()).unwrap();
config.blocking_mode = true;
let mut server = Socket::bind_any_with_config(config).unwrap();
let server_addr = server.local_addr().unwrap();
let time = Instant::now();
// The server is never polled, so nothing the client sends is ever acknowledged.
for id in 0..101 {
client
.send(Packet::reliable_sequenced(
server_addr,
id.to_string().as_bytes().to_vec(),
None,
))
.unwrap();
client.manual_poll(time);
while let Some(event) = client.recv() {
match event {
SocketEvent::Timeout(remote_addr) => {
// The timeout may only appear on the final iteration.
assert_eq![100, id];
assert_eq![remote_addr, server_addr];
return;
}
_ => {
panic!["No other event possible"];
}
}
}
}
panic!["Should have received a timeout event"];
}
/// Three unreliable packets sent by the client produce at least three events on the server
/// after one manual poll on each side.
#[test]
fn manual_polling_socket() {
    let server_addr = "127.0.0.1:12339".parse::<SocketAddr>().unwrap();
    let client_addr = "127.0.0.1:12340".parse::<SocketAddr>().unwrap();
    let mut server = Socket::bind(server_addr).unwrap();
    let mut client = Socket::bind(client_addr).unwrap();
    for _ in 0..3 {
        let packet = Packet::unreliable(server_addr, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        client.send(packet).unwrap();
    }
    let time = Instant::now();
    client.manual_poll(time);
    server.manual_poll(time);
    for _ in 0..3 {
        assert!(server.recv().is_some());
    }
}
/// Packets queued with `send` reach the peer: three sends produce at least three server events.
#[test]
fn can_send_and_receive() {
    let server_addr = "127.0.0.1:12342".parse::<SocketAddr>().unwrap();
    let client_addr = "127.0.0.1:12341".parse::<SocketAddr>().unwrap();
    let mut server = Socket::bind(server_addr).unwrap();
    let mut client = Socket::bind(client_addr).unwrap();
    for _ in 0..3 {
        let packet = Packet::unreliable(server_addr, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        client.send(packet).unwrap();
    }
    let now = Instant::now();
    client.manual_poll(now);
    server.manual_poll(now);
    for _ in 0..3 {
        assert!(server.recv().is_some());
    }
}
/// An unreliable packet with a 5000-byte payload is rejected by `send_to`.
#[test]
fn sending_large_unreliable_packet_should_fail() {
    let mut server = Socket::bind("127.0.0.1:12370".parse::<SocketAddr>().unwrap()).unwrap();
    let oversized = Packet::unreliable("127.0.0.1:12360".parse().unwrap(), vec![1; 5000]);
    let outcome = server.send_to(oversized, Instant::now());
    assert!(outcome.is_err());
}
/// `send_to` reports payload length plus the standard header for an unfragmented packet.
#[test]
fn send_returns_right_size() {
    let mut server = Socket::bind("127.0.0.1:12371".parse::<SocketAddr>().unwrap()).unwrap();
    let packet = Packet::unreliable("127.0.0.1:12361".parse().unwrap(), vec![1; 1024]);
    let bytes_sent = server.send_to(packet, Instant::now()).unwrap();
    assert_eq!(bytes_sent, 1024 + STANDARD_HEADER_SIZE as usize);
}
/// `send_to` reports the payload length plus the header overhead the test expects for a
/// fragmented 4000-byte reliable packet: four fragment headers and one acked-packet header.
#[test]
fn fragmentation_send_returns_right_size() {
    let mut server = Socket::bind("127.0.0.1:12372".parse::<SocketAddr>().unwrap()).unwrap();
    let fragment_packet_size = STANDARD_HEADER_SIZE + FRAGMENT_HEADER_SIZE;
    let packet = Packet::reliable_unordered("127.0.0.1:12362".parse().unwrap(), vec![1; 4000]);
    let bytes_sent = server.send_to(packet, Instant::now()).unwrap();
    assert_eq!(
        bytes_sent,
        4000 + (fragment_packet_size * 4 + ACKED_PACKET_HEADER) as usize
    );
}
/// The first event a server sees from a previously unknown peer is `Connect`.
#[test]
fn connect_event_occurs() {
    let mut server = Socket::bind("127.0.0.1:12345".parse::<SocketAddr>().unwrap()).unwrap();
    let mut client = Socket::bind("127.0.0.1:12344".parse::<SocketAddr>().unwrap()).unwrap();
    let greeting = Packet::unreliable("127.0.0.1:12345".parse().unwrap(), vec![0, 1, 2]);
    client.send(greeting).unwrap();
    let now = Instant::now();
    client.manual_poll(now);
    server.manual_poll(now);
    let first_event = server.recv().unwrap();
    assert_eq!(
        first_event,
        SocketEvent::Connect("127.0.0.1:12344".parse().unwrap())
    );
}
/// With a 1 ms idle timeout, both peers report `Timeout` once a poll is made with a time that
/// is a full timeout past the last exchange, and not before.
#[test]
fn disconnect_event_occurs() {
let mut config = Config::default();
config.idle_connection_timeout = Duration::from_millis(1);
let server_addr = "127.0.0.1:12347".parse::<SocketAddr>().unwrap();
let client_addr = "127.0.0.1:12346".parse::<SocketAddr>().unwrap();
let mut server = Socket::bind_with_config(server_addr, config.clone()).unwrap();
let mut client = Socket::bind_with_config(client_addr, config.clone()).unwrap();
client
.send(Packet::unreliable(server_addr, vec![0, 1, 2]))
.unwrap();
let now = Instant::now();
client.manual_poll(now);
server.manual_poll(now);
assert_eq!(server.recv().unwrap(), SocketEvent::Connect(client_addr));
assert_eq!(
server.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(client_addr, vec![0, 1, 2]))
);
// The server answers so that both sides have exchanged traffic.
server
.send(Packet::unreliable(client_addr, vec![]))
.unwrap();
server.manual_poll(now);
client.manual_poll(now);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(server_addr, vec![]))
);
// One millisecond short of the timeout: nothing may be reported yet.
server.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1));
client.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1));
assert_eq!(server.recv(), None);
assert_eq!(client.recv(), None);
// Exactly at the timeout: both sides must report the peer as timed out.
server.manual_poll(now + config.idle_connection_timeout);
client.manual_poll(now + config.idle_connection_timeout);
assert_eq!(server.recv().unwrap(), SocketEvent::Timeout(client_addr));
assert_eq!(client.recv().unwrap(), SocketEvent::Timeout(server_addr));
}
/// With a 4 ms heartbeat interval and a 10 ms idle timeout, polling at the heartbeat interval
/// and then at the idle timeout must not produce any event (in particular no `Timeout`).
#[test]
fn heartbeats_work() {
let mut config = Config::default();
config.idle_connection_timeout = Duration::from_millis(10);
config.heartbeat_interval = Some(Duration::from_millis(4));
let server_addr = "127.0.0.1:12351".parse::<SocketAddr>().unwrap();
let client_addr = "127.0.0.1:12352".parse::<SocketAddr>().unwrap();
let mut server = Socket::bind_with_config(server_addr, config.clone()).unwrap();
let mut client = Socket::bind_with_config(client_addr, config.clone()).unwrap();
client
.send(Packet::unreliable(server_addr, vec![0, 1, 2]))
.unwrap();
let now = Instant::now();
client.manual_poll(now);
server.manual_poll(now);
assert_eq!(server.recv().unwrap(), SocketEvent::Connect(client_addr));
assert_eq!(
server.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(client_addr, vec![0, 1, 2]))
);
// The server answers so that both sides have exchanged traffic.
server
.send(Packet::unreliable(client_addr, vec![]))
.unwrap();
server.manual_poll(now);
client.manual_poll(now);
assert_eq!(
client.recv().unwrap(),
SocketEvent::Packet(Packet::unreliable(server_addr, vec![]))
);
// Poll once at the heartbeat interval, then once at the idle timeout.
client.manual_poll(now + config.heartbeat_interval.unwrap());
server.manual_poll(now + config.heartbeat_interval.unwrap());
client.manual_poll(now + config.idle_connection_timeout);
server.manual_poll(now + config.idle_connection_timeout);
// Neither side may have queued an event.
assert_eq!(client.recv(), None);
assert_eq!(server.recv(), None);
}
fn create_test_packet(id: u8, addr: &str) -> Packet {
let payload = vec![id];
Packet::reliable_unordered(addr.parse().unwrap(), payload)
}
fn create_ordered_packet(id: u8, addr: &str) -> Packet {
let payload = vec![id];
Packet::reliable_ordered(addr.parse().unwrap(), payload, None)
}
fn create_sequenced_packet(id: u8, addr: &str) -> Packet {
let payload = vec![id];
Packet::reliable_sequenced(addr.parse().unwrap(), payload, None)
}
/// After 35 reliable packets have been delivered and the server has replied once, a further
/// packet from the client must arrive as the next server event carrying only its own payload.
#[test]
fn multiple_sends_should_start_sending_dropped() {
const LOCAL_ADDR: &str = "127.0.0.1:13000";
const REMOTE_ADDR: &str = "127.0.0.1:14000";
let mut server = Socket::bind(REMOTE_ADDR.parse::<SocketAddr>().unwrap()).unwrap();
let mut client = Socket::bind(LOCAL_ADDR.parse::<SocketAddr>().unwrap()).unwrap();
let now = Instant::now();
// Send 35 reliable packets without the server answering in between.
for i in 0..35 {
client.send(create_test_packet(i, REMOTE_ADDR)).unwrap();
client.manual_poll(now);
}
// Collect server events until a poll produces none.
let mut events = Vec::new();
loop {
server.manual_poll(now);
if let Some(event) = server.recv() {
events.push(event);
} else {
break;
}
}
// Presumably one Connect plus one Packet event per datagram, since the server has not yet
// replied to the client — TODO confirm.
assert_eq!(events.len(), 70);
// The server replies once; spin until the client has received it.
server.send(create_test_packet(0, LOCAL_ADDR)).unwrap();
server.manual_poll(now);
loop {
client.manual_poll(now);
if client.recv().is_some() {
break;
}
}
events.clear();
// One more packet from the client; wait for the first resulting server event.
client.send(create_test_packet(35, REMOTE_ADDR)).unwrap();
client.manual_poll(now);
loop {
server.manual_poll(now);
if let Some(event) = server.recv() {
events.push(event);
break;
}
}
// Extract the first payload byte of every Packet event that was collected.
let sent_events: Vec<u8> = events
.iter()
.flat_map(|e| match e {
SocketEvent::Packet(p) => Some(p.payload()[0]),
_ => None,
})
.collect();
assert_eq!(sent_events, vec![35]);
}
/// Property test: polling a socket that received arbitrary bytes from a raw UDP socket must
/// not panic.
#[quickcheck_macros::quickcheck]
fn do_not_panic_on_arbitrary_packets(bytes: Vec<u8>) {
    let server_addr = "127.0.0.1:12332".parse::<SocketAddr>().unwrap();
    let raw_addr = "127.0.0.1:12331".parse::<SocketAddr>().unwrap();
    let mut server = Socket::bind(server_addr).unwrap();
    // A plain UDP socket bypasses all packet framing done by `Socket`.
    let raw_client = UdpSocket::bind(raw_addr).unwrap();
    raw_client.send_to(&bytes, server_addr).unwrap();
    server.manual_poll(Instant::now());
}
/// With a link conditioner configured for a packet loss of 0.9 on both sides, repeated rounds
/// of reliable traffic must eventually deliver every distinct payload to the server.
#[test]
fn really_bad_network_keeps_chugging_along() {
let server_addr = "127.0.0.1:12320".parse::<SocketAddr>().unwrap();
let client_addr = "127.0.0.1:12321".parse::<SocketAddr>().unwrap();
let mut server = Socket::bind(server_addr).unwrap();
let mut client = Socket::bind(client_addr).unwrap();
let time = Instant::now();
let link_conditioner = {
let mut lc = LinkConditioner::new();
lc.set_packet_loss(0.9);
Some(lc)
};
client.set_link_conditioner(link_conditioner.clone());
server.set_link_conditioner(link_conditioner);
// Distinct first payload bytes the server has received so far.
let mut set = HashSet::new();
// Sends 100 packets each way and returns how many distinct payload bytes the server has seen.
// `dummy`, when given, replaces the per-iteration id as the client's payload byte.
let mut send_many_packets = |dummy: Option<u8>| {
for id in 0..100 {
client
.send(Packet::reliable_unordered(
server_addr,
vec![dummy.unwrap_or(id)],
))
.unwrap();
server
.send(Packet::reliable_unordered(client_addr, vec![255]))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
while let Some(_) = client.recv() {}
while let Some(event) = server.recv() {
match event {
SocketEvent::Packet(pkt) => {
set.insert(pkt.payload()[0]);
}
SocketEvent::Timeout(_) => {
panic!["Unable to time out, time has not advanced"]
}
SocketEvent::Connect(_) => {}
}
}
}
return set.len();
};
// First round sends ids 0..100; later rounds only send 255 to give resends a chance.
send_many_packets(None);
send_many_packets(Some(255));
send_many_packets(Some(255));
send_many_packets(Some(255));
// 100 distinct ids plus the dummy value 255.
assert_eq![101, send_many_packets(Some(255))];
}
/// `local_addr` reports the port the socket was bound to.
#[test]
fn local_addr() {
    let port = 40000;
    let bind_addr = format!("127.0.0.1:{}", port).parse::<SocketAddr>().unwrap();
    let socket = Socket::bind(bind_addr).unwrap();
    let reported = socket.local_addr().unwrap();
    assert_eq!(port, reported.port());
}
/// Exchanges 100,000 reliable-ordered packets (more than a 16-bit counter can hold) and checks
/// that the last payload the server received is the final one sent.
#[test]
fn ordered_16_bit_overflow() {
let mut cfg = Config::default();
let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap();
let client_addr = client.local_addr().unwrap();
// NOTE(review): other tests in this module set `blocking_mode = true` for the server; here it
// is set to `false` — confirm this is intentional.
cfg.blocking_mode = false;
let mut server = Socket::bind_any_with_config(cfg).unwrap();
let server_addr = server.local_addr().unwrap();
let time = Instant::now();
let mut last_payload = String::new();
for idx in 0..100_000u64 {
// The payload is the decimal text of the loop index.
client
.send(Packet::reliable_ordered(
server_addr,
idx.to_string().as_bytes().to_vec(),
None,
))
.unwrap();
client.manual_poll(time);
while let Some(_) = client.recv() {}
// The server sends a packet back every round.
server
.send(Packet::reliable_ordered(client_addr, vec![123], None))
.unwrap();
server.manual_poll(time);
while let Some(msg) = server.recv() {
match msg {
SocketEvent::Packet(pkt) => {
last_payload = std::str::from_utf8(pkt.payload()).unwrap().to_string();
}
_ => {}
}
}
}
assert_eq!["99999", last_payload];
}
/// With a fragment size of 10, sends a reliable-ordered payload that exceeds one fragment and
/// then checks that 34 further ordered round trips each deliver the expected packet.
#[test]
fn fragmented_ordered_gets_acked() {
let mut cfg = Config::default();
cfg.fragment_size = 10;
let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap();
let client_addr = client.local_addr().unwrap();
cfg.blocking_mode = true;
let mut server = Socket::bind_any_with_config(cfg).unwrap();
let server_addr = server.local_addr().unwrap();
let time = Instant::now();
let dummy = vec![0];
// Initial unreliable exchange in both directions.
client
.send(Packet::unreliable(server_addr, dummy.clone()))
.unwrap();
client.manual_poll(time);
server
.send(Packet::unreliable(client_addr, dummy.clone()))
.unwrap();
server.manual_poll(time);
// 17 bytes with a fragment size of 10: the payload does not fit into a single fragment.
let exceeds = b"Fragmented string".to_vec();
client
.send(Packet::reliable_ordered(server_addr, exceeds, None))
.unwrap();
client.manual_poll(time);
// The blocking-mode server reads one datagram per poll, so it is polled twice here.
server.manual_poll(time);
server.manual_poll(time);
server
.send(Packet::reliable_ordered(client_addr, dummy.clone(), None))
.unwrap();
client
.send(Packet::unreliable(server_addr, dummy.clone()))
.unwrap();
client.manual_poll(time);
server.manual_poll(time);
// Exactly four events must be queued on the server at this point.
for _ in 0..4 {
assert![server.recv().is_some()];
}
assert![server.recv().is_none()];
// Every further ordered round trip must deliver one event to the client and the dummy
// packet to the server.
for _ in 0..34 {
client
.send(Packet::reliable_ordered(server_addr, dummy.clone(), None))
.unwrap();
client.manual_poll(time);
server
.send(Packet::reliable_ordered(client_addr, dummy.clone(), None))
.unwrap();
server.manual_poll(time);
assert![client.recv().is_some()];
match server.recv() {
Some(SocketEvent::Packet(pkt)) => {
assert_eq![dummy, pkt.payload()];
}
_ => {
panic!["Did not receive expected dummy packet"];
}
}
}
}
}