use crate::error::SocketError;
use crate::packet::*;
use crate::time::*;
use crate::util::*;
use async_std::task;
use async_std::{
io,
net::{SocketAddr, ToSocketAddrs, UdpSocket},
};
use futures::future::BoxFuture;
use futures::ready;
use log::debug;
use std::cmp::{max, min};
use std::collections::VecDeque;
use std::io::{ErrorKind, Result};
use std::task::Poll;
use std::time::{Duration, Instant};
/// Size in bytes of the scratch receive buffers used in this module.
pub(crate) const BUF_SIZE: usize = 1500;

/// Gain applied to the window update in `update_congestion_window`.
const GAIN: f64 = 1.0;
/// How many MSS-sized segments the congestion window may exceed the bytes in flight by.
const ALLOWED_INCREASE: u32 = 1;
/// Target queuing delay in microseconds; `off_target` is measured relative to it.
const TARGET: f64 = 100_000.0;
/// Maximum packet size in bytes, header included (`send_to` splits payloads by `MSS - HEADER_SIZE`).
const MSS: u32 = 1400;
/// Lower bound of the congestion window, in segments of `MSS` bytes.
const MIN_CWND: u32 = 2;
/// Initial congestion window, in segments of `MSS` bytes.
const INIT_CWND: u32 = 2;

/// Receive timeout before any RTT sample exists, in milliseconds.
const INITIAL_CONGESTION_TIMEOUT: u64 = 1000;
/// Lower bound of the RTT-derived timeout, in milliseconds.
const MIN_CONGESTION_TIMEOUT: u64 = 500;
/// Upper bound of the RTT-derived timeout, in milliseconds.
const MAX_CONGESTION_TIMEOUT: u64 = 60_000;
/// Number of base-delay periods remembered in `base_delays`.
const BASE_HISTORY: usize = 10;
/// Number of SYN transmissions attempted by `connect`.
const MAX_SYN_RETRIES: u32 = 5;
/// Default for `UtpSocket::max_retransmission_retries`.
const MAX_RETRANSMISSION_RETRIES: u32 = 5;
/// Receive window advertised in every reply sent from `recv`, in bytes (1 MiB).
const WINDOW_SIZE: u32 = 1024 * 1024;

/// Longest time `send_packet` waits for window space before sending anyway, in microseconds.
const PRE_SEND_TIMEOUT: u32 = 500_000;
/// Length of one base-delay period (60 s, expressed in microseconds).
const MAX_BASE_DELAY_AGE: Delay = Delay(60_000_000);
/// Lifecycle state of a `UtpSocket`.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
enum SocketState {
/// Freshly created; no handshake has happened yet.
New,
/// Handshake completed; data can be exchanged.
Connected,
/// `connect` has sent a SYN and is waiting for the reply.
SynSent,
/// `close` has sent a FIN and is waiting for it to be acknowledged.
FinSent,
/// The peer sent a RESET packet.
ResetReceived,
/// The connection is finished (FIN exchange completed or retries exhausted).
Closed,
}
/// One delay sample kept in `UtpSocket::current_delays`.
#[derive(Debug)]
struct DelayDifferenceSample {
/// Local time (from `now_microseconds`) at which the sample was recorded.
received_at: Timestamp,
/// Time elapsed between sending an acknowledged packet and receiving its acknowledgement.
difference: Delay,
}
/// Resolves `addr` and returns the first socket address it yields.
///
/// # Errors
/// Propagates resolution errors and returns `SocketError::InvalidAddress` when resolution
/// produces no address at all.
async fn take_address<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
    let mut resolved = addr.to_socket_addrs().await?;
    match resolved.next() {
        Some(first) => Ok(first),
        None => Err(SocketError::InvalidAddress.into()),
    }
}
/// A uTP connection endpoint running over a single UDP socket.
#[derive(Debug)]
pub struct UtpSocket {
/// UDP socket carrying all traffic of this connection.
socket: UdpSocket,
/// Peer address; initially the address given to `from_raw_parts` (the bind address for
/// `bind`) and updated during the handshake.
connected_to: SocketAddr,
/// Connection id written into outgoing packets.
sender_connection_id: u16,
/// The connection's other id (sent in our SYN); incoming packets carrying either id are accepted.
receiver_connection_id: u16,
/// Sequence number given to the next outgoing data packet.
seq_nr: u16,
/// Sequence number placed in the ack field of outgoing packets.
ack_nr: u16,
/// Current connection state.
state: SocketState,
/// Received data packets waiting to be delivered, ordered by `seq_nr`.
incoming_buffer: Vec<Packet>,
/// Packets sent but not yet acknowledged, oldest first.
send_window: Vec<Packet>,
/// Packets created by `send_to` that have not been transmitted yet.
unsent_queue: VecDeque<Packet>,
/// How many consecutive State packets acknowledged `last_acked`.
duplicate_ack_count: u32,
/// Latest ack number received from the peer.
last_acked: u16,
/// Local time at which `last_acked` last changed.
last_acked_timestamp: Timestamp,
/// `seq_nr` of the last packet removed from `incoming_buffer` (also set when a SYN is accepted).
last_dropped: u16,
/// Smoothed round-trip time in milliseconds (see `update_congestion_timeout`).
rtt: i32,
/// Smoothed round-trip time variation in milliseconds.
rtt_variance: i32,
/// Bytes of the head incoming packet that did not fit in the caller's buffer yet.
pending_data: Vec<u8>,
/// Bytes in flight: total length of the packets in `send_window`.
curr_window: u32,
/// Window size advertised by the peer in its last packet.
remote_wnd_size: u32,
/// Minimum delay observed in each of the last (at most `BASE_HISTORY`) periods.
base_delays: VecDeque<Delay>,
/// Recent delay samples used by `filtered_current_delay`.
current_delays: Vec<DelayDifferenceSample>,
/// Absolute difference between local time and the timestamp of the last received packet;
/// echoed as the timestamp difference of outgoing data packets.
their_delay: Delay,
/// Start time of the newest `base_delays` period.
last_rollover: Timestamp,
/// Receive / retransmission timeout in milliseconds.
congestion_timeout: u64,
/// Congestion window in bytes.
cwnd: u32,
/// Consecutive receive timeouts one internal receive tolerates before the connection is
/// closed with `SocketError::ConnectionTimedOut` (defaults to `MAX_RETRANSMISSION_RETRIES`).
pub max_retransmission_retries: u32,
}
impl UtpSocket {
/// Wraps an already-bound UDP socket in a fresh `UtpSocket` in state `New`.
///
/// `src` becomes the provisional peer address; all counters and estimators start from
/// their initial values.
fn from_raw_parts(s: UdpSocket, src: SocketAddr) -> UtpSocket {
    // `generate_sequential_identifiers` yields the (receiver, sender) id pair.
    let (receiver_connection_id, sender_connection_id) = generate_sequential_identifiers();
    UtpSocket {
        socket: s,
        connected_to: src,
        sender_connection_id,
        receiver_connection_id,
        state: SocketState::New,
        seq_nr: 1,
        ack_nr: 0,
        last_acked: 0,
        last_acked_timestamp: Timestamp::default(),
        last_dropped: 0,
        duplicate_ack_count: 0,
        incoming_buffer: Vec::new(),
        send_window: Vec::new(),
        unsent_queue: VecDeque::new(),
        pending_data: Vec::new(),
        rtt: 0,
        rtt_variance: 0,
        curr_window: 0,
        remote_wnd_size: 0,
        cwnd: INIT_CWND * MSS,
        congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
        base_delays: VecDeque::with_capacity(BASE_HISTORY),
        current_delays: Vec::new(),
        their_delay: Delay::default(),
        last_rollover: Timestamp::default(),
        max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
    }
}
/// Binds a new, unconnected socket to the first address `addr` resolves to.
///
/// # Errors
/// Fails when `addr` resolves to nothing or the UDP bind fails.
pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpSocket> {
    let local = take_address(addr).await?;
    UdpSocket::bind(local)
        .await
        .map(|udp| UtpSocket::from_raw_parts(udp, local))
}
/// Returns the local address of the underlying UDP socket.
pub fn local_addr(&self) -> Result<SocketAddr> {
    let udp = &self.socket;
    udp.local_addr()
}
/// Returns the peer address while the connection is established or closing.
///
/// # Errors
/// `SocketError::NotConnected` in any state other than `Connected` or `FinSent`.
pub fn peer_addr(&self) -> Result<SocketAddr> {
    match self.state {
        SocketState::Connected | SocketState::FinSent => Ok(self.connected_to),
        _ => Err(SocketError::NotConnected.into()),
    }
}
pub async fn connect<A: ToSocketAddrs>(other: A) -> Result<UtpSocket> {
let addr = take_address(other).await?;
let my_addr = match addr {
SocketAddr::V4(_) => "0.0.0.0:0",
SocketAddr::V6(_) => "[::]:0",
};
let mut socket = UtpSocket::bind(my_addr).await?;
socket.connected_to = addr;
let mut packet = Packet::new();
packet.set_type(PacketType::Syn);
packet.set_connection_id(socket.receiver_connection_id);
packet.set_seq_nr(socket.seq_nr);
let mut len = 0;
let mut buf = [0u8; BUF_SIZE];
let mut syn_timeout = Duration::from_millis(socket.congestion_timeout);
for _ in 0..MAX_SYN_RETRIES {
packet.set_timestamp(now_microseconds());
debug!("Connecting to {}", socket.connected_to);
socket
.socket
.send_to(packet.as_ref(), socket.connected_to)
.await?;
socket.state = SocketState::SynSent;
debug!("sent {:?}", packet);
match io::timeout(syn_timeout, socket.socket.recv_from(&mut buf)).await {
Ok((read, src)) => {
socket.connected_to = src;
len = read;
break;
}
Err(ref e)
if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) =>
{
debug!("Timed out, retrying");
syn_timeout *= 2;
continue;
}
Err(e) => return Err(e),
};
}
let addr = socket.connected_to;
let packet = Packet::try_from(&buf[..len])?;
debug!("received {:?}", packet);
socket.handle_packet(&packet, addr).await?;
debug!("connected to: {}", socket.connected_to);
Ok(socket)
}
/// Gracefully closes the connection.
///
/// Does nothing in these states:
/// - `Closed`.
/// - `New` or `SynSent`: the connection was never established.
/// - `ResetReceived`: a reset peer will not answer a FIN, so waiting for one would only run
///   the receive timeouts (this method also runs from `Drop`).
///
/// Otherwise it waits until the send window is acknowledged, sends a FIN and keeps
/// receiving until the socket reaches `Closed`.
///
/// # Errors
/// Propagates I/O errors and receive timeouts from flushing or from waiting for the FIN
/// acknowledgement.
pub async fn close(&mut self) -> Result<()> {
    match self.state {
        SocketState::Closed
        | SocketState::New
        | SocketState::SynSent
        | SocketState::ResetReceived => return Ok(()),
        SocketState::Connected | SocketState::FinSent => {}
    }
    self.flush().await?;
    let mut packet = Packet::new();
    packet.set_connection_id(self.sender_connection_id);
    packet.set_seq_nr(self.seq_nr);
    packet.set_ack_nr(self.ack_nr);
    packet.set_timestamp(now_microseconds());
    packet.set_type(PacketType::Fin);
    self.socket
        .send_to(packet.as_ref(), self.connected_to)
        .await?;
    debug!("sent {:?}", packet);
    self.state = SocketState::FinSent;
    let mut buf = [0u8; BUF_SIZE];
    while self.state != SocketState::Closed {
        self.recv(&mut buf).await?;
    }
    Ok(())
}
/// Receives data from the peer into `buf` and returns the byte count with the peer address.
///
/// Already-buffered bytes are returned first. Returns `Ok((0, _))` once the connection is
/// `Closed`.
///
/// # Errors
/// `SocketError::ConnectionReset` after the peer reset the connection, plus any error from
/// the internal receive (including timeouts).
pub async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
    let buffered = self.flush_incoming_buffer(buf);
    if buffered > 0 {
        return Ok((buffered, self.connected_to));
    }
    if self.state == SocketState::ResetReceived {
        return Err(SocketError::ConnectionReset.into());
    }
    while self.state != SocketState::Closed {
        let (read, src) = self.recv(buf).await?;
        if read != 0 {
            return Ok((read, src));
        }
    }
    Ok((0, self.connected_to))
}
async fn recv(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
let mut b = [0; BUF_SIZE + HEADER_SIZE];
let start = Instant::now();
let (read, src);
let mut retries = 0;
loop {
if retries >= self.max_retransmission_retries {
self.state = SocketState::Closed;
return Err(SocketError::ConnectionTimedOut.into());
}
let timeout = if self.state != SocketState::New {
debug!("setting read timeout of {} ms", self.congestion_timeout);
Some(Duration::from_millis(self.congestion_timeout))
} else {
None
};
let response = match timeout {
Some(timeout) => io::timeout(timeout, self.socket.recv_from(&mut b)).await,
None => self.socket.recv_from(&mut b).await,
};
match response {
Ok((r, s)) => {
read = r;
src = s;
break;
}
Err(ref e)
if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) =>
{
debug!("recv_from timed out");
self.handle_receive_timeout().await?;
}
Err(e) => return Err(e),
};
let elapsed = start.elapsed();
let elapsed_ms = elapsed.as_secs() * 1000 + elapsed.subsec_millis() as u64;
debug!("{} ms elapsed", elapsed_ms);
retries += 1;
}
let packet = match Packet::try_from(&b[..read]) {
Ok(packet) => packet,
Err(e) => {
debug!("{}", e);
debug!("Ignoring invalid packet");
return Ok((0, self.connected_to));
}
};
debug!("received {:?}", packet);
if let Some(mut pkt) = self.handle_packet(&packet, src).await? {
pkt.set_wnd_size(WINDOW_SIZE);
self.socket.send_to(pkt.as_ref(), src).await?;
debug!("sent {:?}", pkt);
}
if packet.get_type() == PacketType::Data
&& packet.seq_nr().wrapping_sub(self.last_dropped) > 0
{
self.insert_into_buffer(packet);
}
let read = self.flush_incoming_buffer(buf);
Ok((read, src))
}
async fn handle_receive_timeout(&mut self) -> Result<()> {
self.congestion_timeout *= 2;
self.cwnd = MSS;
debug!(
"self.send_window: {:?}",
self.send_window
.iter()
.map(Packet::seq_nr)
.collect::<Vec<u16>>()
);
if self.send_window.is_empty() {
if self.state == SocketState::FinSent {
let mut packet = Packet::new();
packet.set_connection_id(self.sender_connection_id);
packet.set_seq_nr(self.seq_nr);
packet.set_ack_nr(self.ack_nr);
packet.set_timestamp(now_microseconds());
packet.set_type(PacketType::Fin);
self.socket
.send_to(packet.as_ref(), self.connected_to)
.await?;
debug!("resent FIN: {:?}", packet);
} else if self.state != SocketState::New {
debug!("sending fast resend request");
self.send_fast_resend_request().await;
}
} else {
let packet = &mut self.send_window[0];
packet.set_timestamp(now_microseconds());
self.socket
.send_to(packet.as_ref(), self.connected_to)
.await?;
debug!("resent {:?}", packet);
}
Ok(())
}
/// Builds a reply of type `t` to `original`.
///
/// The reply carries:
/// - the current time as its timestamp;
/// - the absolute difference from `original`'s timestamp;
/// - this socket's sender connection id, `seq_nr` and `ack_nr`.
fn prepare_reply(&self, original: &Packet, t: PacketType) -> Packet {
    let now = now_microseconds();
    let difference: Delay = abs_diff(now, original.timestamp());
    let mut reply = Packet::new();
    reply.set_type(t);
    reply.set_timestamp(now);
    reply.set_timestamp_difference(difference);
    reply.set_connection_id(self.sender_connection_id);
    reply.set_seq_nr(self.seq_nr);
    reply.set_ack_nr(self.ack_nr);
    reply
}
/// Removes and returns the head of `incoming_buffer`, or `None` if it is empty.
///
/// The removed packet's `seq_nr` becomes both `ack_nr` and `last_dropped`.
fn advance_incoming_buffer(&mut self) -> Option<Packet> {
    if self.incoming_buffer.is_empty() {
        return None;
    }
    let head = self.incoming_buffer.remove(0);
    debug!("Removed packet from incoming buffer: {:?}", head);
    self.ack_nr = head.seq_nr();
    self.last_dropped = self.ack_nr;
    Some(head)
}
/// Copies the next deliverable payload bytes into `buf` and returns how many were copied.
///
/// Leftover bytes of a partially delivered payload (`pending_data`) are delivered first.
/// Otherwise the head of `incoming_buffer` is delivered if its `seq_nr` equals `ack_nr` or
/// `ack_nr + 1` (mod 2^16). A packet is removed from the buffer once its whole payload has
/// been handed out; any remainder is kept in `pending_data`.
fn flush_incoming_buffer(&mut self, buf: &mut [u8]) -> usize {
    /// Copies as many leading bytes of `src` as fit into `dst`; returns the count copied.
    fn copy_prefix(src: &[u8], dst: &mut [u8]) -> usize {
        let len = min(src.len(), dst.len());
        dst[..len].copy_from_slice(&src[..len]);
        len
    }

    if !self.pending_data.is_empty() {
        let flushed = copy_prefix(&self.pending_data, buf);
        if flushed == self.pending_data.len() {
            self.pending_data.clear();
            self.advance_incoming_buffer();
        } else {
            self.pending_data.drain(..flushed);
        }
        return flushed;
    }

    // Sequence numbers wrap: `ack_nr + 1` would overflow (and panic in debug builds) at u16::MAX.
    let ack_nr = self.ack_nr;
    let next_expected = ack_nr.wrapping_add(1);
    let head_deliverable = self
        .incoming_buffer
        .first()
        .map_or(false, |head| head.seq_nr() == ack_nr || head.seq_nr() == next_expected);
    if !head_deliverable {
        return 0;
    }

    let (flushed, payload_len) = {
        let payload = self.incoming_buffer[0].payload();
        (copy_prefix(&payload[..], buf), payload.len())
    };
    if flushed == payload_len {
        self.advance_incoming_buffer();
    } else {
        self.pending_data = self.incoming_buffer[0].payload()[flushed..].to_vec();
    }
    flushed
}
/// Queues `buf` in packets of at most `MSS` bytes (header included), transmits the queue
/// and returns `buf.len()`.
///
/// # Errors
/// `SocketError::ConnectionClosed` if the socket is `Closed`; otherwise any error from
/// transmitting.
pub async fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
    if self.state == SocketState::Closed {
        return Err(SocketError::ConnectionClosed.into());
    }
    let payload_size = MSS as usize - HEADER_SIZE;
    for chunk in buf.chunks(payload_size) {
        let mut packet = Packet::with_payload(chunk);
        packet.set_seq_nr(self.seq_nr);
        packet.set_ack_nr(self.ack_nr);
        packet.set_connection_id(self.sender_connection_id);
        self.seq_nr = self.seq_nr.wrapping_add(1);
        self.unsent_queue.push_back(packet);
    }
    self.send().await.map(|()| buf.len())
}
/// Processes incoming packets until every packet in the send window has been acknowledged.
///
/// NOTE(review): `recv` copies deliverable data into the local scratch buffer here, so data
/// the peer sends while we flush is consumed and discarded. Confirm this is acceptable.
///
/// # Errors
/// Propagates errors from `recv`, including `SocketError::ConnectionTimedOut`.
pub async fn flush(&mut self) -> Result<()> {
    let mut scratch = [0u8; BUF_SIZE];
    loop {
        if self.send_window.is_empty() {
            return Ok(());
        }
        debug!("packets in send window: {}", self.send_window.len());
        self.recv(&mut scratch).await?;
    }
}
/// Drains `unsent_queue`.
///
/// Each packet is transmitted, moved into the send window, and its length is added to
/// `curr_window`.
async fn send(&mut self) -> Result<()> {
    loop {
        let mut packet = match self.unsent_queue.pop_front() {
            Some(next) => next,
            None => return Ok(()),
        };
        self.send_packet(&mut packet).await?;
        self.curr_window += packet.len() as u32;
        self.send_window.push(packet);
    }
}
/// Transmits `packet`, first waiting (by processing incoming packets) while the bytes in
/// flight are at or above the allowed maximum, for at most `PRE_SEND_TIMEOUT` microseconds.
///
/// The allowed maximum is `min(cwnd, remote_wnd_size)`, but never less than
/// `MIN_CWND * MSS`. A packet the peer already acknowledged (judged by wrapping distance
/// from `last_acked`) is skipped. Otherwise its timestamp and timestamp difference are
/// refreshed before sending.
async fn send_packet(&mut self, packet: &mut Packet) -> Result<()> {
    debug!("current window: {}", self.send_window.len());
    let max_inflight = max(MIN_CWND * MSS, min(self.cwnd, self.remote_wnd_size));
    let started = now_microseconds();
    while self.curr_window >= max_inflight
        && now_microseconds() - started < PRE_SEND_TIMEOUT.into()
    {
        debug!("self.curr_window: {}", self.curr_window);
        debug!("max_inflight: {}", max_inflight);
        debug!("self.duplicate_ack_count: {}", self.duplicate_ack_count);
        debug!("now_microseconds() - now = {}", now_microseconds() - started);
        let mut scratch = [0u8; BUF_SIZE];
        self.recv(&mut scratch).await?;
    }
    debug!(
        "out: now_microseconds() - now = {}",
        now_microseconds() - started
    );

    let ahead = packet.seq_nr().wrapping_sub(self.last_acked);
    let behind = self.last_acked.wrapping_sub(packet.seq_nr());
    if ahead > behind {
        debug!("Packet already acknowledged, skipping...");
        return Ok(());
    }
    packet.set_timestamp(now_microseconds());
    packet.set_timestamp_difference(self.their_delay);
    self.socket
        .send_to(packet.as_ref(), self.connected_to)
        .await?;
    debug!("sent {:?}", packet);
    Ok(())
}
/// Records a delay sample in the base-delay history.
///
/// A new period starts when the history is empty or the current period is older than
/// `MAX_BASE_DELAY_AGE`. A new period evicts the oldest entry once `BASE_HISTORY` entries
/// exist. Otherwise the newest entry keeps the minimum delay seen in its period.
fn update_base_delay(&mut self, base_delay: Delay, now: Timestamp) {
    let start_new_period =
        self.base_delays.is_empty() || now - self.last_rollover > MAX_BASE_DELAY_AGE;
    if start_new_period {
        self.last_rollover = now;
        if self.base_delays.len() == BASE_HISTORY {
            self.base_delays.pop_front();
        }
        self.base_delays.push_back(base_delay);
        return;
    }
    let newest = self.base_delays.len() - 1;
    if base_delay < self.base_delays[newest] {
        self.base_delays[newest] = base_delay;
    }
}
/// Appends a delay sample taken at `now`.
///
/// Samples at the front whose age exceeds `rtt * 100` (in timestamp units) are discarded
/// first.
fn update_current_delay(&mut self, v: Delay, now: Timestamp) {
    // NOTE(review): `rtt` is tracked in milliseconds and timestamps in microseconds, so this
    // keeps roughly rtt/10 worth of samples. Confirm whether `* 1000` (one RTT) was intended.
    let max_age: Delay = (self.rtt as i64 * 100).into();
    let stale = self
        .current_delays
        .iter()
        .position(|sample| !(now - sample.received_at > max_age))
        .unwrap_or(self.current_delays.len());
    self.current_delays.drain(..stale);
    self.current_delays.push(DelayDifferenceSample {
        received_at: now,
        difference: v,
    });
}
/// Folds an RTT sample (milliseconds) into the smoothed RTT and variance, then recomputes
/// `congestion_timeout`.
///
/// The new timeout is `rtt + 4 * rtt_variance`, clamped to
/// `[MIN_CONGESTION_TIMEOUT, MAX_CONGESTION_TIMEOUT]`. This is the RFC 6298 scheme
/// (gains 1/4 and 1/8) in integer arithmetic.
fn update_congestion_timeout(&mut self, current_delay: i32) {
    let delta = self.rtt - current_delay;
    self.rtt_variance += (delta.abs() - self.rtt_variance) / 4;
    self.rtt += (current_delay - self.rtt) / 8;
    let estimate = (self.rtt + self.rtt_variance * 4) as u64;
    self.congestion_timeout = min(
        max(estimate, MIN_CONGESTION_TIMEOUT),
        MAX_CONGESTION_TIMEOUT,
    );
    debug!("current_delay: {}", current_delay);
    debug!("delta: {}", delta);
    debug!("self.rtt_variance: {}", self.rtt_variance);
    debug!("self.rtt: {}", self.rtt);
    debug!("self.congestion_timeout: {}", self.congestion_timeout);
}
fn filtered_current_delay(&self) -> Delay {
let input = self.current_delays.iter().map(|delay| &delay.difference);
(ewma(input, 0.333) as i64).into()
}
/// Smallest delay in the base-delay history, or `Delay::default()` when there is none.
fn min_base_delay(&self) -> Delay {
    self.base_delays
        .iter()
        .copied()
        .min()
        .unwrap_or_default()
}
/// Builds the selective-ACK bitmask for buffered packets beyond `ack_nr + 1`.
///
/// Bit `i` (least-significant bit first within each byte) marks packet `ack_nr + 2 + i` as
/// received. The mask is zero-padded to a multiple of 4 bytes. Returns an empty vector when
/// nothing is buffered ahead.
fn build_selective_ack(&self) -> Vec<u8> {
    // Distances use wrapping arithmetic so sequence wrap-around neither overflows nor
    // panics. Distances of half the sequence space or more belong to packets *behind*
    // `ack_nr` and are ignored.
    let ack_nr = self.ack_nr;
    let stashed = self
        .incoming_buffer
        .iter()
        .map(|pkt| pkt.seq_nr().wrapping_sub(ack_nr))
        .filter(|&ahead| ahead >= 2 && ahead < 0x8000)
        .map(|ahead| usize::from(ahead - 2))
        .map(|diff| (diff / 8, diff % 8));
    let mut sack = Vec::new();
    for (byte, bit) in stashed {
        while byte >= sack.len() || sack.len() % 4 != 0 {
            sack.push(0u8);
        }
        sack[byte] |= 1 << bit;
    }
    sack
}
/// Sends three State packets with the current `seq_nr`/`ack_nr`, each freshly timestamped.
///
/// `handle_state_packet` treats three duplicate acks as a loss signal. Send errors are
/// deliberately ignored.
async fn send_fast_resend_request(&self) {
    let mut packet = Packet::new();
    packet.set_type(PacketType::State);
    packet.set_timestamp_difference(self.their_delay);
    packet.set_connection_id(self.sender_connection_id);
    packet.set_seq_nr(self.seq_nr);
    packet.set_ack_nr(self.ack_nr);
    for _ in 0..3 {
        packet.set_timestamp(now_microseconds());
        let _ = self
            .socket
            .send_to(packet.as_ref(), self.connected_to)
            .await;
    }
}
/// Retransmits the packet with sequence number `lost_packet_nr` if it is still in the send
/// window.
///
/// Retransmission is best-effort: the packet stays in the send window either way, and a
/// failure is logged rather than returned.
async fn resend_lost_packet(&mut self, lost_packet_nr: u16) {
    debug!("---> resend_lost_packet({}) <---", lost_packet_nr);
    match self
        .send_window
        .iter()
        .position(|pkt| pkt.seq_nr() == lost_packet_nr)
    {
        None => debug!("Packet {} not found", lost_packet_nr),
        Some(position) => {
            debug!("self.send_window.len(): {}", self.send_window.len());
            debug!("position: {}", position);
            let mut packet = self.send_window[position].clone();
            if let Err(e) = self.send_packet(&mut packet).await {
                debug!("failed to resend packet {}: {}", lost_packet_nr, e);
            }
        }
    }
    debug!("---> END resend_lost_packet <---");
}
/// Drops every packet up to and including `last_acked` from the send window and subtracts
/// their lengths from `curr_window`.
///
/// Does nothing if `last_acked` is not in the window.
fn advance_send_window(&mut self) {
    let acked_through = self
        .send_window
        .iter()
        .position(|packet| packet.seq_nr() == self.last_acked);
    if let Some(last) = acked_through {
        let released: u32 = self
            .send_window
            .drain(..=last)
            .map(|packet| packet.len() as u32)
            .sum();
        self.curr_window -= released;
    }
    debug!("self.curr_window: {}", self.curr_window);
}
async fn handle_packet(&mut self, packet: &Packet, src: SocketAddr) -> Result<Option<Packet>> {
debug!("({:?}, {:?})", self.state, packet.get_type());
if packet.seq_nr().wrapping_sub(self.ack_nr) == 1 {
self.ack_nr = packet.seq_nr();
}
if packet.get_type() != PacketType::Syn
&& self.state != SocketState::SynSent
&& !(packet.connection_id() == self.sender_connection_id
|| packet.connection_id() == self.receiver_connection_id)
{
return Ok(Some(self.prepare_reply(packet, PacketType::Reset)));
}
self.remote_wnd_size = packet.wnd_size();
debug!("self.remote_wnd_size: {}", self.remote_wnd_size);
let now = now_microseconds();
self.their_delay = abs_diff(now, packet.timestamp());
debug!("self.their_delay: {}", self.their_delay);
match (self.state, packet.get_type()) {
(SocketState::New, PacketType::Syn) => {
self.connected_to = src;
self.ack_nr = packet.seq_nr();
self.seq_nr = rand::random();
self.receiver_connection_id = packet.connection_id() + 1;
self.sender_connection_id = packet.connection_id();
self.state = SocketState::Connected;
self.last_dropped = self.ack_nr;
Ok(Some(self.prepare_reply(packet, PacketType::State)))
}
(_, PacketType::Syn) => Ok(Some(self.prepare_reply(packet, PacketType::Reset))),
(SocketState::SynSent, PacketType::State) => {
self.connected_to = src;
self.ack_nr = packet.seq_nr();
self.seq_nr += 1;
self.state = SocketState::Connected;
self.last_acked = packet.ack_nr();
self.last_acked_timestamp = now_microseconds();
Ok(None)
}
(SocketState::SynSent, _) => Err(SocketError::InvalidReply.into()),
(SocketState::Connected, PacketType::Data)
| (SocketState::FinSent, PacketType::Data) => Ok(self.handle_data_packet(packet)),
(SocketState::Connected, PacketType::State) => {
self.handle_state_packet(packet).await;
Ok(None)
}
(SocketState::Connected, PacketType::Fin) | (SocketState::FinSent, PacketType::Fin) => {
if packet.ack_nr() < self.seq_nr {
debug!("FIN received but there are missing acknowledgements for sent packets");
}
let mut reply = self.prepare_reply(packet, PacketType::State);
if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
debug!(
"current ack_nr ({}) is behind received packet seq_nr ({})",
self.ack_nr,
packet.seq_nr()
);
let sack = self.build_selective_ack();
if !sack.is_empty() {
reply.set_sack(sack);
}
}
self.state = SocketState::Closed;
Ok(Some(reply))
}
(SocketState::Closed, PacketType::Fin) => {
Ok(Some(self.prepare_reply(packet, PacketType::State)))
}
(SocketState::FinSent, PacketType::State) => {
if packet.ack_nr() == self.seq_nr {
self.state = SocketState::Closed;
} else {
self.handle_state_packet(packet);
}
Ok(None)
}
(_, PacketType::Reset) => {
self.state = SocketState::ResetReceived;
Err(SocketError::ConnectionReset.into())
}
(state, ty) => {
let message = format!("Unimplemented handling for ({:?},{:?})", state, ty);
debug!("{}", message);
Err(SocketError::Other(message).into())
}
}
}
/// Builds the acknowledgement for an incoming data packet.
///
/// The reply is a FIN while our own FIN is outstanding and a State packet otherwise. When
/// the packet is ahead of `ack_nr + 1`, a selective ACK is attached if there is anything to
/// report.
fn handle_data_packet(&mut self, packet: &Packet) -> Option<Packet> {
    let reply_type = match self.state {
        SocketState::FinSent => PacketType::Fin,
        _ => PacketType::State,
    };
    let mut reply = self.prepare_reply(packet, reply_type);
    let gap = packet.seq_nr().wrapping_sub(self.ack_nr);
    if gap > 1 {
        debug!(
            "current ack_nr ({}) is behind received packet seq_nr ({})",
            self.ack_nr,
            packet.seq_nr()
        );
        let sack = self.build_selective_ack();
        if !sack.is_empty() {
            reply.set_sack(sack);
        }
    }
    Some(reply)
}
/// Estimated queuing delay: filtered current delay minus the minimum base delay.
fn queuing_delay(&self) -> Delay {
    let (current, base) = (self.filtered_current_delay(), self.min_base_delay());
    let queuing = current - base;
    debug!("filtered_current_delay: {}", current);
    debug!("min_base_delay: {}", base);
    debug!("queuing_delay: {}", queuing);
    queuing
}
/// Adjusts `cwnd` by `GAIN * off_target * bytes_newly_acked * MSS / cwnd`.
///
/// The result is clamped between `MIN_CWND * MSS` and the bytes in flight plus
/// `ALLOWED_INCREASE * MSS`.
fn update_congestion_window(&mut self, off_target: f64, bytes_newly_acked: u32) {
    let flightsize = self.curr_window;
    let scaled = GAIN * off_target * f64::from(bytes_newly_acked) * f64::from(MSS);
    let cwnd_increase = scaled / f64::from(self.cwnd);
    debug!("cwnd_increase: {}", cwnd_increase);
    let grown = (f64::from(self.cwnd) + cwnd_increase) as u32;
    let max_allowed_cwnd = flightsize + ALLOWED_INCREASE * MSS;
    self.cwnd = max(min(grown, max_allowed_cwnd), MIN_CWND * MSS);
    debug!("cwnd: {}", self.cwnd);
    debug!("max_allowed_cwnd: {}", max_allowed_cwnd);
}
/// Processes an acknowledgement (State packet).
///
/// It performs these steps:
/// 1. Tracks duplicate acks.
/// 2. If the acked packet is in the send window, updates the delay estimators, the
///    congestion window and the timeout.
/// 3. Retransmits packets reported lost by a selective ACK or by three duplicate acks, and
///    halves the congestion window when loss is detected.
/// 4. Finally, advances the send window.
#[async_recursion::async_recursion]
async fn handle_state_packet(&mut self, packet: &Packet) {
    if packet.ack_nr() == self.last_acked {
        self.duplicate_ack_count += 1;
    } else {
        self.last_acked = packet.ack_nr();
        self.last_acked_timestamp = now_microseconds();
        self.duplicate_ack_count = 1;
    }
    // Sequence numbers are u16 and wrap; every arithmetic step on them below wraps too
    // (plain `+` overflows and panics in debug builds at u16::MAX).
    let next_unacked = packet.ack_nr().wrapping_add(1);

    if let Some(index) = self
        .send_window
        .iter()
        .position(|p| packet.ack_nr() == p.seq_nr())
    {
        let bytes_newly_acked = self
            .send_window
            .iter()
            .take(index + 1)
            .fold(0, |acc, p| acc + p.len());
        let now = now_microseconds();
        let our_delay = now - self.send_window[index].timestamp();
        debug!("our_delay: {}", our_delay);
        self.update_base_delay(our_delay, now);
        self.update_current_delay(our_delay, now);
        let off_target: f64 = (TARGET - u32::from(self.queuing_delay()) as f64) / TARGET;
        debug!("off_target: {}", off_target);
        self.update_congestion_window(off_target, bytes_newly_acked as u32);
        // RTT sample in milliseconds: measured delay minus the estimated queuing delay.
        let rtt = u32::from(our_delay - self.queuing_delay()) / 1000;
        self.update_congestion_timeout(rtt as i32);
    }

    let mut packet_loss_detected: bool =
        !self.send_window.is_empty() && self.duplicate_ack_count == 3;
    for extension in packet.extensions() {
        if extension.get_type() == ExtensionType::SelectiveAck {
            if extension.iter().count_ones() >= 3 {
                self.resend_lost_packet(next_unacked).await;
                packet_loss_detected = true;
            }
            if let Some(last_seq_nr) = self.send_window.last().map(Packet::seq_nr) {
                // Bit 0 of the mask refers to ack_nr + 2.
                let first_in_mask = packet.ack_nr().wrapping_add(2);
                let lost_packets = extension
                    .iter()
                    .enumerate()
                    .filter(|&(_, received)| !received)
                    .map(|(idx, _)| first_in_mask.wrapping_add(idx as u16))
                    // Wrap-aware "seq_nr comes before last_seq_nr".
                    .take_while(|&seq_nr| (seq_nr.wrapping_sub(last_seq_nr) as i16) < 0);
                for seq_nr in lost_packets {
                    debug!("SACK: packet {} lost", seq_nr);
                    self.resend_lost_packet(seq_nr).await;
                    packet_loss_detected = true;
                }
            }
        } else {
            debug!("Unknown extension {:?}, ignoring", extension.get_type());
        }
    }

    if !self.send_window.is_empty()
        && self.duplicate_ack_count == 3
        && !packet
            .extensions()
            .any(|ext| ext.get_type() == ExtensionType::SelectiveAck)
    {
        self.resend_lost_packet(next_unacked).await;
    }
    if packet_loss_detected {
        debug!("packet loss detected, halving congestion window");
        self.cwnd = max(self.cwnd / 2, MIN_CWND * MSS);
        debug!("cwnd: {}", self.cwnd);
    }
    self.advance_send_window();
}
fn insert_into_buffer(&mut self, packet: Packet) {
if self
.incoming_buffer
.last()
.map_or(false, |p| packet.seq_nr() > p.seq_nr())
{
self.incoming_buffer.push(packet);
} else {
let i = self
.incoming_buffer
.iter()
.filter(|p| p.seq_nr() < packet.seq_nr())
.count();
if self
.incoming_buffer
.get(i)
.map_or(true, |p| p.seq_nr() != packet.seq_nr())
{
self.incoming_buffer.insert(i, packet);
}
}
}
}
impl Drop for UtpSocket {
    /// Closes the connection, blocking the current thread until `close` finishes.
    ///
    /// This can take as long as `close` does, including its receive timeouts. Any error is
    /// discarded.
    fn drop(&mut self) {
        let _ = task::block_on(self.close());
    }
}
/// Accepts incoming uTP connections on a bound UDP socket.
pub struct UtpListener {
/// Socket on which SYN packets are received.
socket: UdpSocket,
}
impl UtpListener {
    /// Creates a listener bound to `addr`.
    ///
    /// # Errors
    /// Propagates the UDP bind error.
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpListener> {
        let socket = UdpSocket::bind(addr).await?;
        Ok(UtpListener { socket })
    }

    /// Waits for one datagram, which must be a SYN, and returns a connected socket for it
    /// together with the peer address.
    ///
    /// The connection gets its own ephemeral UDP socket of the listener's address family,
    /// and the handshake reply is sent from that socket.
    ///
    /// # Errors
    /// Fails if any of the following happens:
    /// - the datagram cannot be parsed or is not a SYN;
    /// - processing the handshake fails (the underlying error is returned);
    /// - any I/O operation fails.
    pub async fn accept(&self) -> Result<(UtpSocket, SocketAddr)> {
        let mut buf = [0; BUF_SIZE];
        let (nread, src) = self.socket.recv_from(&mut buf).await?;
        let packet = Packet::try_from(&buf[..nread])?;
        if packet.get_type() != PacketType::Syn {
            let message = format!("Expected SYN packet, got {:?} instead", packet.get_type());
            return Err(SocketError::Other(message).into());
        }
        let local_addr = self.socket.local_addr()?;
        let inner_socket = match local_addr {
            SocketAddr::V4(_) => UdpSocket::bind("0.0.0.0:0"),
            SocketAddr::V6(_) => UdpSocket::bind("[::]:0"),
        }
        .await?;
        let mut socket = UtpSocket::from_raw_parts(inner_socket, src);
        // Propagate handshake errors instead of masking them with a generic message.
        match socket.handle_packet(&packet, src).await? {
            Some(reply) => {
                socket.socket.send_to(reply.as_ref(), src).await?;
                Ok((socket, src))
            }
            None => Err(SocketError::Other("Reached unreachable statement".to_owned()).into()),
        }
    }

    /// Returns a stream that yields the result of `accept` repeatedly.
    pub fn incoming(&self) -> Incoming<'_> {
        Incoming {
            listener: self,
            accept: None,
        }
    }

    /// Returns the local address the listener is bound to.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.socket.local_addr()
    }
}
/// In-progress `UtpListener::accept` call, if any, held by `Incoming` between polls.
type AcceptFuture<'a> = Option<BoxFuture<'a, io::Result<(UtpSocket, SocketAddr)>>>;
/// Stream of accepted connections returned by `UtpListener::incoming`.
pub struct Incoming<'a> {
/// Listener whose `accept` is driven.
listener: &'a UtpListener,
/// Pending accept future; `None` between results.
accept: AcceptFuture<'a>,
}
impl<'a> futures::Stream for Incoming<'a> {
type Item = Result<(UtpSocket, SocketAddr)>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
loop {
if self.accept.is_none() {
self.accept = Some(Box::pin(self.listener.accept()));
}
if let Some(f) = &mut self.accept {
let res = ready!(f.as_mut().poll(cx));
self.accept = None;
return Poll::Ready(Some(res));
}
}
}
}
#[cfg(test)]
mod test {
use crate::packet::*;
use crate::socket::{take_address, SocketState, UtpListener, UtpSocket, BUF_SIZE};
use crate::time::now_microseconds;
use async_std::task;
use rand;
use std::io::ErrorKind;
use std::net::ToSocketAddrs;
/// Awaits a `Result`-returning future and unwraps it.
///
/// Panics with the error's `Debug` representation on failure.
macro_rules! iotry {
    ($e:expr) => {
        $e.await.unwrap_or_else(|e| panic!("{:?}", e))
    };
}
/// Hands out a distinct port per call, counting up from 9600, so tests don't collide.
fn next_test_port() -> u16 {
    use std::sync::atomic::{AtomicUsize, Ordering};
    const BASE_PORT: u16 = 9600;
    static NEXT_OFFSET: AtomicUsize = AtomicUsize::new(0);
    let offset = NEXT_OFFSET.fetch_add(1, Ordering::Relaxed) as u16;
    BASE_PORT + offset
}
/// IPv4 loopback host paired with a fresh test port.
fn next_test_ip4<'a>() -> (&'a str, u16) {
    let port = next_test_port();
    ("127.0.0.1", port)
}
/// IPv6 loopback host paired with a fresh test port.
fn next_test_ip6<'a>() -> (&'a str, u16) {
    let port = next_test_port();
    ("::1", port)
}
/// IPv4 handshake and teardown.
///
/// A client task connects to a bound server and closes. Meanwhile the server's `recv_from`
/// runs until the connection is closed.
#[async_std::test]
async fn test_socket_ipv4() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
assert_eq!(server.state, SocketState::New);
let child = task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert_eq!(client.state, SocketState::Connected);
// NOTE(review): relies on `generate_sequential_identifiers` returning ids that differ by one — confirm.
assert_eq!(
client.sender_connection_id,
client.receiver_connection_id + 1
);
assert_eq!(
client.connected_to,
server_addr.to_socket_addrs().unwrap().next().unwrap()
);
iotry!(client.close());
drop(client);
});
let mut buf = [0; BUF_SIZE];
// The result is only printed; the state assertions below are what the test checks.
match server.recv_from(&mut buf).await {
e => println!("{:?}", e),
}
// On accepting a SYN the server takes the client's id as sender id and id + 1 as receiver id.
assert_eq!(
server.receiver_connection_id,
server.sender_connection_id + 1
);
assert_eq!(server.state, SocketState::Closed);
drop(server);
child.await;
}
/// IPv6 variant of the handshake-and-teardown test, using the `::1` loopback.
#[async_std::test]
async fn test_socket_ipv6() {
let server_addr = next_test_ip6();
let mut server = iotry!(UtpSocket::bind(server_addr));
assert_eq!(server.state, SocketState::New);
let child = task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert_eq!(client.state, SocketState::Connected);
// NOTE(review): relies on `generate_sequential_identifiers` returning ids that differ by one — confirm.
assert_eq!(
client.sender_connection_id,
client.receiver_connection_id + 1
);
assert_eq!(
client.connected_to,
server_addr.to_socket_addrs().unwrap().next().unwrap()
);
iotry!(client.close());
drop(client);
});
let mut buf = [0u8; BUF_SIZE];
// The result is only printed; the state assertions below are what the test checks.
match server.recv_from(&mut buf).await {
e => println!("{:?}", e),
}
assert_eq!(
server.receiver_connection_id,
server.sender_connection_id + 1
);
assert_eq!(server.state, SocketState::Closed);
drop(server);
child.await;
}
/// Once the peer has closed the connection, `recv_from` keeps returning `Ok((0, _))`.
#[async_std::test]
async fn test_recvfrom_on_closed_socket() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
assert_eq!(server.state, SocketState::New);
let child = task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert_eq!(client.state, SocketState::Connected);
assert!(client.close().await.is_ok());
});
let mut buf = [0u8; BUF_SIZE];
// First call drives the handshake and the client's FIN until the socket is closed.
let _resp = server.recv_from(&mut buf).await;
assert_eq!(server.state, SocketState::Closed);
match server.recv_from(&mut buf).await {
Ok((0, _src)) => {}
e => panic!("Expected Ok(0), got {:?}", e),
}
assert_eq!(server.state, SocketState::Closed);
child.await;
}
/// Sending on a closed socket fails.
///
/// The assertion expects `SocketError::ConnectionClosed` to surface as
/// `ErrorKind::NotConnected`.
#[async_std::test]
async fn test_sendto_on_closed_socket() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
assert_eq!(server.state, SocketState::New);
let child = task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert_eq!(client.state, SocketState::Connected);
iotry!(client.close());
});
let mut buf = [0u8; BUF_SIZE];
let (_read, _src) = iotry!(server.recv_from(&mut buf));
assert_eq!(server.state, SocketState::Closed);
match server.send_to(&buf).await {
Err(ref e) if e.kind() == ErrorKind::NotConnected => (),
v => panic!("expected {:?}, got {:?}", ErrorKind::NotConnected, v),
}
child.await;
}
/// After the handshake, the client's `ack_nr` equals the server's `seq_nr`, and closing does
/// not change it.
#[async_std::test]
async fn test_acks_on_socket() {
// std mpsc channel: `rx.recv()` below blocks the current thread until the server task sends.
use std::sync::mpsc::channel;
let server_addr = next_test_ip4();
let (tx, rx) = channel();
let mut server = iotry!(UtpSocket::bind(server_addr));
let child = task::spawn(async move {
let mut buf = [0u8; BUF_SIZE];
// Processes the client's SYN.
let _resp = server.recv(&mut buf).await;
tx.send(server.seq_nr).unwrap();
iotry!(server.recv_from(&mut buf));
drop(server);
});
let mut client = iotry!(UtpSocket::connect(server_addr));
assert_eq!(client.state, SocketState::Connected);
let sender_seq_nr = rx.recv().unwrap();
let ack_nr = client.ack_nr;
assert_eq!(ack_nr, sender_seq_nr);
assert!(client.close().await.is_ok());
assert_eq!(client.ack_nr, ack_nr);
drop(client);
child.await;
}
/// Drives `handle_packet` through SYN, DATA and FIN directly and checks each reply.
#[async_std::test]
async fn test_handle_packet() {
    let initial_connection_id: u16 = rand::random();
    // The random id may be u16::MAX, so use wrapping arithmetic on connection ids.
    let sender_connection_id = initial_connection_id.wrapping_add(1);
    let (server_addr, client_addr) = (
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
    );
    let mut socket = iotry!(UtpSocket::bind(server_addr));

    // SYN -> State reply acknowledging the SYN.
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr).await;
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    assert_eq!(response.connection_id(), packet.connection_id());
    assert_eq!(response.ack_nr(), packet.seq_nr());
    assert!(response.payload().is_empty());

    // DATA -> State reply acknowledging the data packet.
    let old_packet = packet;
    let old_response = response;
    let mut packet = Packet::new();
    packet.set_type(PacketType::Data);
    packet.set_connection_id(sender_connection_id);
    packet.set_seq_nr(old_packet.seq_nr() + 1);
    packet.set_ack_nr(old_response.seq_nr());
    let response = socket.handle_packet(&packet, client_addr).await;
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    assert_eq!(response.connection_id(), initial_connection_id);
    assert_eq!(response.connection_id(), packet.connection_id().wrapping_sub(1));
    assert_eq!(response.ack_nr(), packet.seq_nr());
    assert!(response.payload().is_empty());
    assert_eq!(response.seq_nr(), old_response.seq_nr());

    // FIN -> State reply acknowledging the FIN.
    let old_packet = packet;
    let old_response = response;
    let mut packet = Packet::new();
    packet.set_type(PacketType::Fin);
    packet.set_connection_id(sender_connection_id);
    packet.set_seq_nr(old_packet.seq_nr() + 1);
    packet.set_ack_nr(old_response.seq_nr());
    let response = socket.handle_packet(&packet, client_addr).await;
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    assert_eq!(packet.seq_nr(), old_packet.seq_nr() + 1);
    assert_eq!(response.seq_nr(), old_response.seq_nr());
    assert_eq!(response.ack_nr(), packet.seq_nr());
}
/// A bare State packet after the handshake acts as a keep-alive ack.
///
/// Neither the first nor a duplicate copy of it may produce a reply.
#[async_std::test]
async fn test_response_to_keepalive_ack() {
    let initial_connection_id: u16 = rand::random();
    let server_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let client_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let mut socket = iotry!(UtpSocket::bind(server_addr));

    let mut syn = Packet::new();
    syn.set_wnd_size(BUF_SIZE as u32);
    syn.set_type(PacketType::Syn);
    syn.set_connection_id(initial_connection_id);
    let syn_reply = socket
        .handle_packet(&syn, client_addr)
        .await
        .expect("SYN should be accepted")
        .expect("SYN should be answered");
    assert_eq!(syn_reply.get_type(), PacketType::State);

    let mut keepalive = Packet::new();
    keepalive.set_wnd_size(BUF_SIZE as u32);
    keepalive.set_type(PacketType::State);
    keepalive.set_connection_id(initial_connection_id);
    keepalive.set_seq_nr(syn.seq_nr() + 1);
    keepalive.set_ack_nr(syn_reply.seq_nr());
    for _ in 0..2 {
        let reply = socket.handle_packet(&keepalive, client_addr).await;
        assert!(matches!(reply, Ok(None)));
    }

    // Marked closed so that dropping the socket skips the close handshake.
    socket.state = SocketState::Closed;
}
#[async_std::test]
async fn test_response_to_wrong_connection_id() {
    // After a SYN with connection id X, a packet whose connection id belongs to
    // no side of this connection must be answered with a Reset that
    // acknowledges the offending packet.
    let initial_connection_id: u16 = rand::random();
    let (server_addr, client_addr) = (
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
    );
    let mut socket = iotry!(UtpSocket::bind(server_addr));
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr).await;
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    assert_eq!(response.unwrap().get_type(), PacketType::State);

    // The socket's receiver and sender ids are consecutive, so after this SYN
    // the connection is identified by X and X + 1 (mod 2^16). A derived id such
    // as X * 2 collides with one of them when X is 0 or 1, which made this test
    // fail at random. X + 2 (wrapping) differs from both for every X.
    let new_connection_id = initial_connection_id.wrapping_add(2);
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::State);
    packet.set_connection_id(new_connection_id);
    let response = socket.handle_packet(&packet, client_addr).await;
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::Reset);
    assert_eq!(response.ack_nr(), packet.seq_nr());
    socket.state = SocketState::Closed;
}
#[async_std::test]
async fn test_unordered_packets() {
let initial_connection_id: u16 = rand::random();
let (server_addr, client_addr) = (
next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
);
let mut socket = iotry!(UtpSocket::bind(server_addr));
let mut packet = Packet::new();
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_type(PacketType::Syn);
packet.set_connection_id(initial_connection_id);
let response = socket.handle_packet(&packet, client_addr).await;
assert!(response.is_ok());
let response = response.unwrap();
assert!(response.is_some());
let response = response.unwrap();
assert_eq!(response.get_type(), PacketType::State);
let old_packet = packet;
let old_response = response;
let mut window: Vec<Packet> = Vec::new();
let mut packet = Packet::with_payload(&[1, 2, 3]);
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_connection_id(initial_connection_id);
packet.set_seq_nr(old_packet.seq_nr() + 1);
packet.set_ack_nr(old_response.seq_nr());
window.push(packet);
let mut packet = Packet::with_payload(&[4, 5, 6]);
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_connection_id(initial_connection_id);
packet.set_seq_nr(old_packet.seq_nr() + 2);
packet.set_ack_nr(old_response.seq_nr());
window.push(packet);
let response = socket.handle_packet(&window[1], client_addr).await;
assert!(response.is_ok());
let response = response.unwrap();
assert!(response.is_some());
let response = response.unwrap();
assert!(response.ack_nr() != window[1].seq_nr());
let response = socket.handle_packet(&window[0], client_addr).await;
assert!(response.is_ok());
let response = response.unwrap();
assert!(response.is_some());
socket.state = SocketState::Closed;
}
#[async_std::test]
async fn test_response_to_triple_ack() {
    // Three duplicate acks for the packet preceding a data packet must cause
    // the sender to transmit that same data packet again.
    const LEN: usize = 1024;
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    let data: Vec<u8> = (0..LEN).map(|i| i as u8).collect();
    let payload = data.clone();
    assert_eq!(LEN, data.len());

    let child = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&payload[..]));
        iotry!(client.close());
    });

    let mut buf = [0u8; BUF_SIZE];
    // Accept the connection, then read the first data packet straight off the
    // UDP socket, bypassing the uTP layer.
    iotry!(server.recv(&mut buf));
    let data_packet = match server.socket.recv_from(&mut buf).await {
        Err(e) => panic!("{}", e),
        Ok((read, _src)) => Packet::try_from(&buf[..read]).unwrap(),
    };
    assert_eq!(data_packet.get_type(), PacketType::Data);
    assert_eq!(&data_packet.payload(), &data.as_slice());
    assert_eq!(data_packet.payload().len(), data.len());

    // Forge a State packet acking the packet just before the data packet and
    // send it three times.
    let mut dup_ack = Packet::new();
    dup_ack.set_wnd_size(BUF_SIZE as u32);
    dup_ack.set_type(PacketType::State);
    dup_ack.set_seq_nr(server.seq_nr);
    dup_ack.set_ack_nr(data_packet.seq_nr() - 1);
    dup_ack.set_connection_id(server.sender_connection_id);
    let client_addr = server.connected_to;
    for _ in 0..3u8 {
        iotry!(server.socket.send_to(dup_ack.as_ref(), client_addr));
    }

    // The next datagram must be the same data packet; acknowledge it properly.
    let resent = match server.socket.recv_from(&mut buf).await {
        Ok((0, _)) => panic!("Received 0 bytes from socket"),
        Ok((read, _src)) => Packet::try_from(&buf[..read]).unwrap(),
        Err(e) => panic!("{}", e),
    };
    assert_eq!(resent.get_type(), PacketType::Data);
    assert_eq!(resent.seq_nr(), data_packet.seq_nr());
    assert_eq!(resent.payload(), data_packet.payload());
    let ack = server
        .handle_packet(&resent, client_addr)
        .await
        .unwrap()
        .unwrap();
    iotry!(server.socket.send_to(ack.as_ref(), server.connected_to));

    iotry!(server.recv_from(&mut buf));
    child.await;
}
#[async_std::test]
async fn test_socket_timeout_request() {
    let server_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let client_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let client = iotry!(UtpSocket::bind(client_addr));
    let mut server = iotry!(UtpSocket::bind(server_addr));
    const LEN: usize = 512;
    let payload: Vec<u8> = (0..LEN).map(|i| i as u8).collect();

    // Freshly bound sockets start in New with consecutive connection ids.
    assert_eq!(server.state, SocketState::New);
    assert_eq!(client.state, SocketState::New);
    assert_eq!(client.sender_connection_id, client.receiver_connection_id + 1);

    let child = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        assert_eq!(client.connected_to, server_addr);
        iotry!(client.send_to(&payload[..]));
        drop(client);
    });

    let mut buf = [0u8; BUF_SIZE];
    server.recv(&mut buf).await.unwrap();
    assert_eq!(server.receiver_connection_id, server.sender_connection_id + 1);
    assert_eq!(server.state, SocketState::Connected);

    // Pull the next datagram directly off the UDP socket (bypassing uTP),
    // shrink the congestion timeout, then read until a non-empty result.
    iotry!(server.socket.recv_from(&mut buf));
    server.congestion_timeout = 50;
    loop {
        match server.recv_from(&mut buf).await {
            Err(e) => panic!("{}", e),
            Ok((0, _)) => {}
            Ok(_) => break,
        }
    }

    drop(server);
    child.await;
}
#[async_std::test]
async fn test_sorted_buffer_insertion() {
    // insert_into_buffer keeps packets ordered by seq_nr, and inserting a
    // seq_nr that is already buffered leaves the existing entry unchanged.
    let server_addr = next_test_ip4();
    let mut socket = iotry!(UtpSocket::bind(server_addr));
    assert!(socket.incoming_buffer.is_empty());

    let mut probe = Packet::new();
    probe.set_seq_nr(1);
    socket.insert_into_buffer(probe.clone());
    assert_eq!(socket.incoming_buffer.len(), 1);

    // In-order inserts: each lands at index seq_nr - 1 with its own timestamp.
    for &(seq_nr, stamp) in &[(2, 128), (3, 256)] {
        probe.set_seq_nr(seq_nr);
        probe.set_timestamp(stamp.into());
        socket.insert_into_buffer(probe.clone());
        let slot = seq_nr as usize - 1;
        assert_eq!(socket.incoming_buffer.len(), slot + 1);
        assert_eq!(socket.incoming_buffer[slot].seq_nr(), seq_nr);
        assert_eq!(socket.incoming_buffer[slot].timestamp(), stamp.into());
    }

    // Duplicate seq_nr 2 with a new timestamp: the original entry must survive.
    probe.set_seq_nr(2);
    probe.set_timestamp(456.into());
    socket.insert_into_buffer(probe);
    assert_eq!(socket.incoming_buffer.len(), 3);
    assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
    assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
}
#[async_std::test]
async fn test_duplicate_packet_handling() {
    // A data packet delivered twice must be surfaced to the reader only once.
    let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
    let client = iotry!(UtpSocket::bind(client_addr));
    let mut server = iotry!(UtpSocket::bind(server_addr));
    assert_eq!(server.state, SocketState::New);
    assert_eq!(client.state, SocketState::New);
    assert_eq!(client.sender_connection_id, client.receiver_connection_id + 1);

    let child = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);

        let mut dup = Packet::with_payload(&[1, 2, 3]);
        dup.set_wnd_size(BUF_SIZE as u32);
        dup.set_connection_id(client.sender_connection_id);
        dup.set_seq_nr(client.seq_nr);
        dup.set_ack_nr(client.ack_nr);
        // Same packet twice over raw UDP, each copy freshly timestamped.
        for _ in 0..2 {
            dup.set_timestamp(now_microseconds());
            iotry!(client.socket.send_to(dup.as_ref(), server_addr));
        }
        client.seq_nr += 1;

        // Consume one reply from the server, then close.
        let mut reply = [0u8; BUF_SIZE];
        iotry!(client.socket.recv_from(&mut reply));
        iotry!(client.close());
    });

    let mut buf = [0u8; BUF_SIZE];
    iotry!(server.recv(&mut buf));
    assert_eq!(server.receiver_connection_id, server.sender_connection_id + 1);
    assert_eq!(server.state, SocketState::Connected);

    let expected: Vec<u8> = vec![1, 2, 3];
    let mut received: Vec<u8> = Vec::new();
    loop {
        let (len, _src) = match server.recv_from(&mut buf).await {
            Ok(read) => read,
            Err(e) => panic!("{:?}", e),
        };
        if len == 0 {
            break;
        }
        received.extend_from_slice(&buf[..len]);
    }
    assert_eq!(received.len(), expected.len());
    assert_eq!(received, expected);
    child.await;
}
#[async_std::test]
async fn test_correct_packet_loss() {
    // The client puts every chunk in its send window but only transmits the
    // even-indexed ones; the server must still end up with the whole stream.
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    const LEN: usize = 1024 * 10;
    let data: Vec<u8> = (0..LEN).map(|i| i as u8).collect();
    let to_send = data.clone();

    let child = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        let dst = client.connected_to;
        for (index, chunk) in to_send.chunks(BUF_SIZE).enumerate() {
            let mut packet = Packet::with_payload(chunk);
            packet.set_seq_nr(client.seq_nr);
            packet.set_ack_nr(client.ack_nr);
            packet.set_connection_id(client.sender_connection_id);
            packet.set_timestamp(now_microseconds());
            // Odd-indexed packets are "lost": queued but never sent here.
            let lost = index % 2 != 0;
            if !lost {
                iotry!(client.socket.send_to(packet.as_ref(), dst));
            }
            client.curr_window += packet.len() as u32;
            client.send_window.push(packet);
            client.seq_nr += 1;
        }
        iotry!(client.close());
    });

    let mut buf = [0u8; BUF_SIZE];
    let mut received: Vec<u8> = Vec::new();
    loop {
        match server.recv_from(&mut buf).await {
            Err(e) => panic!("{}", e),
            Ok((0, _src)) => break,
            Ok((len, _src)) => received.extend_from_slice(&buf[..len]),
        }
    }
    assert_eq!(received.len(), data.len());
    assert_eq!(received, data);
    child.await;
}
#[async_std::test]
async fn test_tolerance_to_small_buffers() {
    // Reading through a 512-byte buffer must still deliver every byte sent.
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    const LEN: usize = 1024;
    let data: Vec<u8> = (0..LEN).map(|i| i as u8).collect();
    let to_send = data.clone();
    let child = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&to_send[..]));
        iotry!(client.close());
    });

    let mut read: Vec<u8> = Vec::new();
    loop {
        if server.state == SocketState::Closed {
            break;
        }
        let mut small_buffer = [0u8; 512];
        match server.recv_from(&mut small_buffer).await {
            Err(e) => panic!("{}", e),
            Ok((0, _src)) => break,
            Ok((len, _src)) => read.extend_from_slice(&small_buffer[..len]),
        }
    }
    assert_eq!(read.len(), data.len());
    assert_eq!(read, data);
    child.await;
}
#[async_std::test]
async fn test_sequence_number_rollover() {
// Intent: stream 4 * BUF_SIZE bytes from a client whose sequence number starts
// just below u16::MAX (MAX - 2 here), so it wraps mid-transfer, and check that
// the server still reassembles the data exactly.
let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = BUF_SIZE * 4;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let to_send = data.clone();
let child = task::spawn(async move {
// NOTE(review): this pre-bound socket and the near-u16::MAX seq_nr assigned to
// it are discarded. The next `let mut client = ...connect(...)` shadows it with
// a brand-new socket, and `connect` only receives `server_addr`, so it cannot
// inherit this seq_nr. As written the sequence number presumably never wraps
// and the `seq_nr < 50` assertion below holds trivially.
// TODO: start the *connected* client near u16::MAX so rollover is exercised.
let mut client = iotry!(UtpSocket::bind(client_addr));
client.seq_nr = ::std::u16::MAX - (to_send.len() / (BUF_SIZE * 2)) as u16;
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&to_send[..]));
// Meant to confirm the wrap-around happened (a small seq_nr after rollover).
assert!(client.seq_nr < 50);
iotry!(client.close());
});
// Drain the stream until recv_from reports 0 bytes, then compare with the input.
let mut buf = [0u8; BUF_SIZE];
let mut received: Vec<u8> = vec![];
loop {
match server.recv_from(&mut buf).await {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{}", e),
}
}
assert_eq!(received.len(), data.len());
assert_eq!(received, data);
child.await;
}
#[async_std::test]
async fn test_drop_unused_socket() {
let server_addr = next_test_ip4();
let server = iotry!(UtpSocket::bind(server_addr));
drop(server);
}
#[async_std::test]
async fn test_invalid_packet_on_connect() {
use async_std::net::UdpSocket;
let server_addr = next_test_ip4();
let server = iotry!(UdpSocket::bind(server_addr));
let child = task::spawn(async move {
let mut buf = [0u8; BUF_SIZE];
match server.recv_from(&mut buf).await {
Ok((_len, client_addr)) => {
iotry!(server.send_to(&[], client_addr));
}
_ => panic!(),
}
});
match UtpSocket::connect(server_addr).await {
Err(ref e) if e.kind() == ErrorKind::Other => (), Err(e) => panic!("Expected ErrorKind::Other, got {:?}", e),
Ok(_) => panic!("Expected Err, got Ok"),
}
child.await;
}
#[async_std::test]
async fn test_receive_unexpected_reply_type_on_connect() {
use async_std::net::UdpSocket;
let server_addr = next_test_ip4();
let server = iotry!(UdpSocket::bind(server_addr));
let child = task::spawn(async move {
let mut buf = [0u8; BUF_SIZE];
let mut packet = Packet::new();
packet.set_type(PacketType::Data);
match server.recv_from(&mut buf).await {
Ok((_len, client_addr)) => {
iotry!(server.send_to(packet.as_ref(), client_addr));
}
_ => panic!(),
}
});
match UtpSocket::connect(server_addr).await {
Err(ref e) if e.kind() == ErrorKind::ConnectionRefused => (), Err(e) => panic!("Expected ErrorKind::ConnectionRefused, got {:?}", e),
Ok(_) => panic!("Expected Err, got Ok"),
}
child.await;
}
#[async_std::test]
async fn test_receiving_syn_on_established_connection() {
    // A second SYN on an already-established connection must be answered with Reset.
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    let child = task::spawn(async move {
        let mut buf = [0; BUF_SIZE];
        loop {
            match server.recv_from(&mut buf).await {
                Err(e) => panic!("{:?}", e),
                Ok((0, _src)) => break,
                Ok(_) => {}
            }
        }
    });

    let mut client = iotry!(UtpSocket::connect(server_addr));
    let mut syn = Packet::new();
    syn.set_wnd_size(BUF_SIZE as u32);
    syn.set_type(PacketType::Syn);
    syn.set_connection_id(client.sender_connection_id);
    syn.set_seq_nr(client.seq_nr);
    syn.set_ack_nr(client.ack_nr);
    iotry!(client.socket.send_to(syn.as_ref(), server_addr));

    let mut buf = [0u8; BUF_SIZE];
    let len = match client.socket.recv_from(&mut buf).await {
        Ok((len, _src)) => len,
        Err(e) => panic!("{:?}", e),
    };
    let reply = Packet::try_from(&buf[..len]).ok().unwrap();
    assert_eq!(reply.get_type(), PacketType::Reset);

    iotry!(client.close());
    child.await;
}
#[async_std::test]
async fn test_receiving_reset_on_established_connection() {
    // A Reset sent by the peer on an established connection must surface from
    // the server's `recv_from` as ErrorKind::ConnectionReset.
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    let child = task::spawn(async move {
        let client = iotry!(UtpSocket::connect(server_addr));
        let mut packet = Packet::new();
        packet.set_wnd_size(BUF_SIZE as u32);
        packet.set_type(PacketType::Reset);
        packet.set_connection_id(client.sender_connection_id);
        packet.set_seq_nr(client.seq_nr);
        packet.set_ack_nr(client.ack_nr);
        iotry!(client.socket.send_to(packet.as_ref(), server_addr));
        // Park on the raw socket; nothing in this test sends a reply after the Reset.
        let mut buf = [0u8; BUF_SIZE];
        match client.socket.recv_from(&mut buf).await {
            Ok((_len, _src)) => (),
            Err(e) => panic!("{:?}", e),
        }
    });

    let mut buf = [0u8; BUF_SIZE];
    let reset_received = loop {
        match server.recv_from(&mut buf).await {
            Ok((0, _src)) => break false,
            Ok(_) => (),
            Err(ref e) if e.kind() == ErrorKind::ConnectionReset => break true,
            Err(e) => panic!("{:?}", e),
        }
    };

    // The client task may still be blocked in `recv_from` waiting for a reply
    // that never comes. Awaiting it on the failure path would turn a failing
    // test into a hang, so detach it on both paths before asserting.
    drop(child);
    assert!(reset_received, "Should have received Reset");
}
#[cfg(not(windows))]
#[async_std::test]
async fn test_premature_fin() {
// Intent: after accepting the connection, the server sends a FIN without first
// acknowledging the client's data; the server must still read the full stream.
let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = BUF_SIZE * 4;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let to_send = data.clone();
let child = task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&to_send[..]));
iotry!(client.close());
});
let mut buf = [0u8; BUF_SIZE];
// Accept the connection.
iotry!(server.recv(&mut buf));
// Hand-build a FIN carrying the server's current seq/ack numbers.
let mut packet = Packet::new();
packet.set_connection_id(server.sender_connection_id);
packet.set_seq_nr(server.seq_nr);
packet.set_ack_nr(server.ack_nr);
packet.set_timestamp(now_microseconds());
packet.set_type(PacketType::Fin);
// NOTE(review): `client_addr` is never bound in this test. The child gets its
// socket from `UtpSocket::connect`, which only receives the server address, so
// this FIN is presumably sent to an address nobody listens on and never reaches
// the client. `server.connected_to` may be what was meant. Verify before
// changing, since the Windows exclusion on this test may depend on the current
// behaviour.
iotry!(server.socket.send_to(packet.as_ref(), client_addr));
// Drain the stream until recv_from reports 0 bytes, then compare with the input.
let mut received: Vec<u8> = vec![];
loop {
match server.recv_from(&mut buf).await {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{}", e),
}
}
assert_eq!(received.len(), data.len());
assert_eq!(received, data);
child.await;
}
#[async_std::test]
async fn test_base_delay_calculation() {
    // Samples span two one-minute windows; the expected base delays are the
    // minimum delay observed in each window (7, then 9).
    let minute_in_microseconds = 60 * 10i64.pow(6);
    let samples = [
        (0, 10),
        (1, 8),
        (2, 12),
        (3, 7),
        (minute_in_microseconds + 1, 11),
        (minute_in_microseconds + 2, 19),
        (minute_in_microseconds + 3, 9),
    ];
    let mut socket = UtpSocket::bind(next_test_ip4()).await.unwrap();
    for &(timestamp, delay) in samples.iter() {
        socket.update_base_delay(delay.into(), ((timestamp + delay) as u32).into());
    }

    let expected = [7i64, 9i64]
        .iter()
        .copied()
        .map(Into::into)
        .collect::<Vec<_>>();
    let actual: Vec<_> = socket.base_delays.iter().cloned().collect();
    assert_eq!(expected, actual);
    assert_eq!(
        socket.min_base_delay(),
        expected.iter().min().cloned().unwrap_or_default()
    );
}
#[async_std::test]
async fn test_local_addr() {
    // local_addr must report exactly the address the socket was bound to.
    let bound_to = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let socket = UtpSocket::bind(bound_to).await.unwrap();
    let reported = socket.local_addr();
    assert!(reported.is_ok());
    assert_eq!(reported.unwrap(), bound_to);
}
#[async_std::test]
async fn test_listener_local_addr() {
    // A listener's local_addr must report exactly the address it was bound to.
    let bound_to = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let listener = UtpListener::bind(bound_to).await.unwrap();
    let reported = listener.local_addr();
    assert!(reported.is_ok());
    assert_eq!(reported.unwrap(), bound_to);
}
#[async_std::test]
async fn test_peer_addr() {
    // peer_addr is an error before any peer is known, reports the client's
    // address once connected, and is an error again after close.
    let server_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let mut server = UtpSocket::bind(server_addr).await.unwrap();
    assert!(server.peer_addr().is_err());

    // The client hands back its local address as the task's output rather
    // than through a std::sync::mpsc channel, whose blocking `recv` would
    // stall the async test thread.
    let child = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        let local = client.local_addr();
        let mut buf = [0; 1024];
        iotry!(client.recv_from(&mut buf));
        local
    });

    let mut buf = [0; 1024];
    iotry!(server.recv(&mut buf));
    // Capture the peer address while connected; it is compared once the
    // client task has reported its own address.
    let peer = server.peer_addr();
    assert!(peer.is_ok());

    iotry!(server.close());
    assert!(server.peer_addr().is_err());

    let client_addr = child.await.unwrap();
    assert_eq!(peer.unwrap().port(), client_addr.port());
}
#[async_std::test]
async fn test_take_address() {
    // String forms that resolve to at least one socket address.
    for valid in &["0.0.0.0:0", "[::]:0"] {
        assert!(take_address(*valid).await.is_ok());
    }
    // (host, port) tuple forms that resolve.
    for &(host, port) in &[("0.0.0.0", 0u16), ("::", 0), ("1.2.3.4", 5)] {
        assert!(take_address((host, port)).await.is_ok());
    }
    // Malformed addresses, out-of-range ports, empty input and port-less names.
    let invalid = [
        "999.0.0.0:0",
        "1.2.3.4:70000",
        "",
        "this is not an address",
        "no.dns.resolution.com",
    ];
    for bad in &invalid {
        assert!(take_address(*bad).await.is_err());
    }
}
#[async_std::test]
async fn test_connection_loss_data() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let child = task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
client.state = SocketState::Closed;
let ref socket = client.socket;
let mut buf = [0u8; BUF_SIZE];
iotry!(socket.recv_from(&mut buf));
for _ in 0..attempts {
match socket.recv_from(&mut buf).await {
Ok((len, _src)) => assert_eq!(
Packet::try_from(&buf[..len]).unwrap().get_type(),
PacketType::Data
),
Err(e) => panic!("{}", e),
}
}
});
let mut buf = [0u8; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
iotry!(server.send_to(&[0]));
let mut buf = [0u8; BUF_SIZE];
match server.recv(&mut buf).await {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
child.await;
}
#[async_std::test]
async fn test_connection_loss_fin() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let child = task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
client.state = SocketState::Closed;
let ref socket = client.socket;
let mut buf = [0u8; BUF_SIZE];
iotry!(socket.recv_from(&mut buf));
for _ in 0..attempts {
match socket.recv_from(&mut buf).await {
Ok((len, _src)) => assert_eq!(
Packet::try_from(&buf[..len]).unwrap().get_type(),
PacketType::Fin
),
Err(e) => panic!("{}", e),
}
}
});
let mut buf = [0u8; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
match server.close().await {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
child.await;
}
#[async_std::test]
async fn test_connection_loss_waiting() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
client.state = SocketState::Closed;
let ref socket = client.socket;
let seq_nr = client.seq_nr;
let mut buf = [0u8; BUF_SIZE];
for _ in 0..(3 * attempts) {
match socket.recv_from(&mut buf).await {
Ok((len, _src)) => {
let packet = Packet::try_from(&buf[..len]).unwrap();
assert_eq!(packet.get_type(), PacketType::State);
assert_eq!(packet.ack_nr(), seq_nr - 1);
}
Err(e) => panic!("{}", e),
}
}
});
let mut buf = [0; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
let mut buf = [0; BUF_SIZE];
match server.recv_from(&mut buf).await {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
}
}