use std::cmp::{max, min};
use std::collections::VecDeque;
use std::future::Future;
use std::io::{ErrorKind, Result};
use std::iter::Iterator;
use std::mem;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use crate::error::SocketError;
use crate::packet::*;
use crate::time::*;
use crate::util::*;
use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
use tokio::net::{lookup_host, ToSocketAddrs, UdpSocket};
use tokio::sync::mpsc::{
unbounded_channel, UnboundedReceiver, UnboundedSender,
};
use tokio::sync::Mutex;
use tokio::time::{sleep, timeout, Instant as TokioInstant, Sleep};
use tracing::debug;
/// Size of the scratch buffers used when receiving datagrams.
const BUF_SIZE: usize = 1500;
/// Gain factor applied to the congestion-window adjustment.
const GAIN: f64 = 1.0;
/// Number of MSS-sized segments `cwnd` may exceed the bytes in flight by.
const ALLOWED_INCREASE: u32 = 1;
/// Target queuing delay, in the units `queuing_delay()` returns.
const TARGET: f64 = 100_000.0;
/// Maximum segment size in bytes (header included).
const MSS: u32 = 1400;
/// Lower bound of the congestion window, in segments.
const MIN_CWND: u32 = 2;
/// Initial congestion window, in segments.
const INIT_CWND: u32 = 2;
/// Initial retransmission timeout in milliseconds.
const INITIAL_CONGESTION_TIMEOUT: u64 = 1000;
/// Lower clamp for the computed retransmission timeout (ms).
const MIN_CONGESTION_TIMEOUT: u64 = 500;
/// Upper clamp for the computed retransmission timeout (ms).
const MAX_CONGESTION_TIMEOUT: u64 = 60_000;
/// Number of base-delay samples kept.
const BASE_HISTORY: usize = 10;
/// Number of SYN transmissions attempted before giving up.
const MAX_SYN_RETRIES: u32 = 5;
/// Default number of receive timeouts tolerated before a connection is closed.
const MAX_RETRANSMISSION_RETRIES: u32 = 5;
/// Receive window advertised to the peer, in bytes (1 MiB).
const WINDOW_SIZE: u32 = 1024 * 1024;
/// Age after which `update_base_delay` starts a new base-delay slot instead of
/// refining the current one. Compared against `now - last_rollover`, where
/// `now` comes from `now_microseconds()`; presumably 60 s if `Delay` counts
/// microseconds (TODO confirm against `crate::time`).
const MAX_BASE_DELAY_AGE: Delay = Delay(60_000_000);
/// Connection life-cycle of a `UtpSocket`.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
enum SocketState {
    /// Freshly created; no handshake has happened yet.
    New,
    /// Handshake completed; data may flow in both directions.
    Connected,
    /// A SYN was sent and the reply is still awaited.
    SynSent,
    /// We sent a FIN and are waiting for the connection to finish closing.
    FinSent,
    /// The peer answered with a Reset packet.
    ResetReceived,
    /// The connection is finished (FIN exchanged or timed out).
    Closed,
}
/// One delay measurement kept in `UtpSocket::current_delays`.
struct DelayDifferenceSample {
    /// Local time at which the sample was taken.
    received_at: Timestamp,
    /// Measured delay for the acknowledged packet.
    difference: Delay,
}
/// State of a single uTP connection on top of a UDP socket.
struct UtpSocket {
    /// Underlying UDP socket.
    socket: UdpSocket,
    /// Peer address; set during connect/accept.
    connected_to: SocketAddr,
    /// Connection id written into outgoing packets.
    sender_connection_id: u16,
    /// Connection id announced in our SYN; also accepted on incoming packets.
    receiver_connection_id: u16,
    /// Sequence number of the next outgoing data packet.
    seq_nr: u16,
    /// Last sequence number acknowledged to the peer.
    ack_nr: u16,
    /// Current connection state.
    state: SocketState,
    /// Received data packets ordered by sequence number, not yet fully read.
    incoming_buffer: Vec<Packet>,
    /// Sent data packets still waiting for an acknowledgement.
    send_window: Vec<Packet>,
    /// Packets queued for (re)transmission.
    unsent_queue: VecDeque<Packet>,
    /// How many times the peer repeated the same ack number.
    duplicate_ack_count: u32,
    /// Most recent ack number received from the peer.
    last_acked: u16,
    /// When `last_acked` last changed.
    last_acked_timestamp: Timestamp,
    /// Sequence number of the last packet removed from `incoming_buffer`.
    last_dropped: u16,
    /// Smoothed round-trip time estimate (presumably ms — TODO confirm).
    rtt: i32,
    /// Smoothed round-trip time variance.
    rtt_variance: i32,
    /// Unread remainder of the payload of `incoming_buffer[0]`.
    pending_data: Vec<u8>,
    /// Bytes currently in flight (sum of `len()` of tracked sent packets).
    curr_window: u32,
    /// Window size advertised by the peer.
    remote_wnd_size: u32,
    /// History of minimum one-way delays, at most `BASE_HISTORY` entries.
    base_delays: VecDeque<Delay>,
    /// Recent delay samples used for the filtered current delay.
    current_delays: Vec<DelayDifferenceSample>,
    /// Difference between our clock and the timestamp of the last packet received.
    their_delay: Delay,
    /// When the base-delay history last started a new slot.
    last_rollover: Timestamp,
    /// Retransmission timeout in milliseconds.
    congestion_timeout: u64,
    /// Congestion window in bytes.
    cwnd: u32,
    /// Receive timeouts tolerated by `recv` before the connection is closed.
    pub max_retransmission_retries: u32,
}
impl UtpSocket {
    /// Wraps an already-bound UDP socket in a fresh, unconnected `UtpSocket`.
    ///
    /// `src` becomes the initial `connected_to`; connection ids come from
    /// `generate_sequential_identifiers()` and congestion control starts at
    /// `INIT_CWND * MSS` with `INITIAL_CONGESTION_TIMEOUT`.
    fn from_raw_parts(s: UdpSocket, src: SocketAddr) -> UtpSocket {
        let (receiver_id, sender_id) = generate_sequential_identifiers();
        UtpSocket {
            socket: s,
            connected_to: src,
            receiver_connection_id: receiver_id,
            sender_connection_id: sender_id,
            seq_nr: 1,
            ack_nr: 0,
            state: SocketState::New,
            incoming_buffer: Vec::new(),
            send_window: Vec::new(),
            unsent_queue: VecDeque::new(),
            duplicate_ack_count: 0,
            last_acked: 0,
            last_acked_timestamp: Timestamp::default(),
            last_dropped: 0,
            rtt: 0,
            rtt_variance: 0,
            pending_data: Vec::new(),
            curr_window: 0,
            remote_wnd_size: 0,
            current_delays: Vec::new(),
            base_delays: VecDeque::with_capacity(BASE_HISTORY),
            their_delay: Delay::default(),
            last_rollover: Timestamp::default(),
            congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
            cwnd: INIT_CWND * MSS,
            max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
        }
    }

    /// Binds a UDP socket to `addr` and wraps it in an unconnected socket.
    ///
    /// # Errors
    /// `AddrNotAvailable` if `addr` resolves to nothing, or the bind error.
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpSocket> {
        let src = lookup_host(&addr)
            .await?
            .last()
            .ok_or(ErrorKind::AddrNotAvailable)?;
        let socket = UdpSocket::bind(addr).await?;
        Ok(UtpSocket::from_raw_parts(socket, src))
    }

    /// Local address of the underlying UDP socket.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.socket.local_addr()
    }

    /// Peer address, available only while `Connected` or `FinSent`.
    pub fn peer_addr(&self) -> Result<SocketAddr> {
        if self.state == SocketState::Connected
            || self.state == SocketState::FinSent
        {
            Ok(self.connected_to)
        } else {
            Err(SocketError::NotConnected.into())
        }
    }

    /// Opens a connection to `addr`.
    ///
    /// Binds an ephemeral local port of the same address family, then sends
    /// up to `MAX_SYN_RETRIES` SYNs, doubling the wait after each timeout.
    ///
    /// # Errors
    /// `ConnectionTimedOut` if no SYN is answered; otherwise I/O, parse or
    /// handshake errors.
    pub async fn connect(addr: SocketAddr) -> Result<UtpSocket> {
        // Bind locally to an unspecified address; binding `addr` itself would
        // try to claim the *remote* address on this host.
        let local = if addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
        let mut socket = UtpSocket::bind(local).await?;
        socket.connected_to = addr;
        let mut packet = Packet::new();
        packet.set_type(PacketType::Syn);
        packet.set_connection_id(socket.receiver_connection_id);
        packet.set_seq_nr(socket.seq_nr);
        let mut received = None;
        let mut buf = [0; BUF_SIZE];
        let mut syn_timeout = socket.congestion_timeout;
        for _ in 0..MAX_SYN_RETRIES {
            packet.set_timestamp(now_microseconds());
            debug!("connecting to {}", socket.connected_to);
            socket
                .socket
                .send_to(packet.as_ref(), socket.connected_to)
                .await?;
            socket.state = SocketState::SynSent;
            debug!("sent {:?}", packet);
            let to = Duration::from_millis(syn_timeout);
            match timeout(to, socket.socket.recv_from(&mut buf)).await {
                Ok(Ok((read, src))) => {
                    socket.connected_to = src;
                    received = Some(read);
                    break;
                }
                Ok(Err(e)) => return Err(e),
                Err(_) => {
                    debug!("timed out, retrying");
                    syn_timeout *= 2;
                }
            }
        }
        // Every SYN went unanswered: report a timeout instead of trying to
        // parse an empty buffer.
        let len = match received {
            Some(len) => len,
            None => return Err(SocketError::ConnectionTimedOut.into()),
        };
        let addr = socket.connected_to;
        let packet = Packet::try_from(&buf[..len])?;
        debug!("received {:?}", packet);
        socket.handle_packet(&packet, addr)?;
        debug!("connected to: {}", socket.connected_to);
        Ok(socket)
    }

    /// Gracefully closes the connection.
    ///
    /// No-op in the `Closed`, `New` and `SynSent` states. Otherwise waits for
    /// outstanding data to be acknowledged, sends a FIN, enters `FinSent` and
    /// keeps receiving until the state becomes `Closed`. Data received in this
    /// phase goes into a local scratch buffer and is discarded.
    pub async fn close(&mut self) -> Result<()> {
        if self.state == SocketState::Closed
            || self.state == SocketState::New
            || self.state == SocketState::SynSent
        {
            return Ok(());
        }
        let local = self.socket.local_addr()?;
        debug!("closing {} -> {}", local, self.connected_to);
        self.flush().await?;
        debug!("close flush completed");
        let mut packet = Packet::new();
        packet.set_connection_id(self.sender_connection_id);
        packet.set_seq_nr(self.seq_nr);
        packet.set_ack_nr(self.ack_nr);
        packet.set_timestamp(now_microseconds());
        packet.set_type(PacketType::Fin);
        self.socket
            .send_to(packet.as_ref(), self.connected_to)
            .await?;
        debug!("sent {:?}", packet);
        self.state = SocketState::FinSent;
        let mut jbuf = [0; BUF_SIZE];
        let mut buf: ReadBuf<'_> = ReadBuf::new(&mut jbuf);
        while self.state != SocketState::Closed {
            self.recv(&mut buf).await?;
        }
        debug!("closed {} -> {}", local, self.connected_to);
        Ok(())
    }

    /// Reads in-order payload into `buf`.
    ///
    /// Buffered data is returned first. With nothing buffered:
    /// `ConnectionReset` after a reset, `(0, peer)` once closed, otherwise
    /// waits until some data can be returned.
    pub async fn recv_from(
        &mut self,
        buf: &mut [u8],
    ) -> Result<(usize, SocketAddr)> {
        let mut buf = ReadBuf::new(buf);
        let read = self.flush_incoming_buffer(&mut buf);
        if read > 0 {
            Ok((read, self.connected_to))
        } else {
            if self.state == SocketState::ResetReceived {
                return Err(SocketError::ConnectionReset.into());
            }
            loop {
                if self.state == SocketState::Closed {
                    return Ok((0, self.connected_to));
                }
                match self.recv(&mut buf).await {
                    Ok((0, _src)) => continue,
                    Ok(x) => return Ok(x),
                    Err(e) => return Err(e),
                }
            }
        }
    }

    /// Receives and processes one datagram, then flushes readable data into `buf`.
    ///
    /// Outside `New`, each wait is bounded by `congestion_timeout`; every
    /// timeout runs `handle_receive_timeout`, and after
    /// `max_retransmission_retries` of them the socket is closed with
    /// `ConnectionTimedOut`. Malformed datagrams are ignored (`Ok((0, peer))`).
    /// Replies get `WINDOW_SIZE` as window size; data packets are buffered.
    async fn recv(
        &mut self,
        buf: &mut ReadBuf<'_>,
    ) -> Result<(usize, SocketAddr)> {
        let mut b = [0; BUF_SIZE + HEADER_SIZE];
        let start = Instant::now();
        let read;
        let src;
        let mut retries = 0;
        loop {
            if retries >= self.max_retransmission_retries {
                self.state = SocketState::Closed;
                return Err(SocketError::ConnectionTimedOut.into());
            }
            if self.state != SocketState::New {
                let to = Duration::from_millis(self.congestion_timeout);
                debug!(
                    "setting read timeout of {} ms",
                    self.congestion_timeout
                );
                match timeout(to, self.socket.recv_from(&mut b)).await {
                    Ok(Ok((r, s))) => {
                        read = r;
                        src = s;
                        break;
                    }
                    Ok(Err(e)) => return Err(e),
                    Err(_) => {
                        debug!("recv_from timed out");
                        self.handle_receive_timeout().await?;
                    }
                };
            } else {
                match self.socket.recv_from(&mut b).await {
                    Ok((r, s)) => {
                        read = r;
                        src = s;
                        break;
                    }
                    Err(e) => return Err(e),
                }
            };
            // `as_millis` already yields milliseconds; the previous
            // `subsec_millis() / 1_000_000` always evaluated to zero.
            let elapsed_ms = start.elapsed().as_millis();
            debug!("{} ms elapsed", elapsed_ms);
            retries += 1;
        }
        let packet = match Packet::try_from(&b[..read]) {
            Ok(packet) => packet,
            Err(e) => {
                debug!("{}", e);
                debug!("Ignoring invalid packet");
                return Ok((0, self.connected_to));
            }
        };
        debug!("received {:?}", packet);
        if let Some(mut pkt) = self.handle_packet(&packet, src)? {
            pkt.set_wnd_size(WINDOW_SIZE);
            self.socket.send_to(pkt.as_ref(), src).await?;
            debug!("sent {:?}", pkt);
        }
        if packet.get_type() == PacketType::Data
            && packet.seq_nr().wrapping_sub(self.last_dropped) > 0
        {
            self.insert_into_buffer(packet);
        }
        let read = self.flush_incoming_buffer(buf);
        Ok((read, src))
    }

    /// Reacts to a receive timeout.
    ///
    /// Doubles `congestion_timeout`, shrinks `cwnd` to one MSS, then resends
    /// the FIN (in `FinSent` with an empty window), queues a fast-resend
    /// request (other non-`New` states with an empty window), or resends the
    /// oldest unacknowledged packet.
    async fn handle_receive_timeout(&mut self) -> Result<()> {
        self.congestion_timeout *= 2;
        self.cwnd = MSS;
        debug!(
            "self.send_window: {:?}",
            self.send_window
                .iter()
                .map(Packet::seq_nr)
                .collect::<Vec<u16>>()
        );
        if self.send_window.is_empty() {
            if self.state == SocketState::FinSent {
                let mut packet = Packet::new();
                packet.set_connection_id(self.sender_connection_id);
                packet.set_seq_nr(self.seq_nr);
                packet.set_ack_nr(self.ack_nr);
                packet.set_timestamp(now_microseconds());
                packet.set_type(PacketType::Fin);
                self.socket
                    .send_to(packet.as_ref(), self.connected_to)
                    .await?;
                debug!("resent FIN: {:?}", packet);
            } else if self.state != SocketState::New {
                debug!("sending fast resend request");
                self.send_fast_resend_request();
            }
        } else {
            let packet = &mut self.send_window[0];
            packet.set_timestamp(now_microseconds());
            self.socket
                .send_to(packet.as_ref(), self.connected_to)
                .await?;
            debug!("resent {:?}", packet);
        }
        Ok(())
    }

    /// Builds a reply of type `t` to `original`, stamped with our time, the
    /// clock difference to `original`, and our connection id / seq / ack.
    fn prepare_reply(&self, original: &Packet, t: PacketType) -> Packet {
        let mut resp = Packet::new();
        resp.set_type(t);
        let self_t_micro = now_microseconds();
        let other_t_micro = original.timestamp();
        let time_difference: Delay = abs_diff(self_t_micro, other_t_micro);
        resp.set_timestamp(self_t_micro);
        resp.set_timestamp_difference(time_difference);
        resp.set_connection_id(self.sender_connection_id);
        resp.set_seq_nr(self.seq_nr);
        resp.set_ack_nr(self.ack_nr);
        resp
    }

    /// Removes the head of `incoming_buffer`, recording its sequence number in
    /// both `ack_nr` and `last_dropped`.
    fn advance_incoming_buffer(&mut self) -> Option<Packet> {
        if !self.incoming_buffer.is_empty() {
            let packet = self.incoming_buffer.remove(0);
            debug!("Removed packet from incoming buffer: {:?}", packet);
            self.ack_nr = packet.seq_nr();
            self.last_dropped = self.ack_nr;
            Some(packet)
        } else {
            None
        }
    }

    /// Copies as much in-order payload as fits into `buf`; returns bytes copied.
    ///
    /// Leftovers of a partially copied packet are kept in `pending_data` and
    /// served first on the next call. The head packet is only delivered when
    /// no sequence gap precedes it.
    fn flush_incoming_buffer(&mut self, buf: &mut ReadBuf) -> usize {
        fn copy(src: &[u8], dst: &mut ReadBuf) -> usize {
            // `remaining()` (unfilled space), not `capacity()`: `put_slice`
            // panics when given more than the unfilled space.
            let to_copy = min(src.len(), dst.remaining());
            dst.put_slice(&src[..to_copy]);
            to_copy
        }
        if !self.pending_data.is_empty() {
            let flushed = copy(&self.pending_data[..], buf);
            if flushed == self.pending_data.len() {
                self.pending_data.clear();
                self.advance_incoming_buffer();
            } else {
                self.pending_data.drain(..flushed);
            }
            return flushed;
        }
        let head_seq = match self.incoming_buffer.first() {
            Some(head) => head.seq_nr(),
            None => return 0,
        };
        // Distance (mod 2^16) from `ack_nr` to the head. 0 or 1: it is the
        // packet just acknowledged or the next one. Upper half of the space:
        // it is behind `ack_nr`, i.e. acknowledged on arrival before the
        // reader caught up. Anything else means a gap still has to be filled.
        let ahead = head_seq.wrapping_sub(self.ack_nr);
        if ahead > 1 && ahead < 0x8000 {
            debug!(
                "not flushing out of order data, acked={} != cached={}",
                self.ack_nr, head_seq
            );
            return 0;
        }
        let flushed = copy(&self.incoming_buffer[0].payload(), buf);
        if flushed == self.incoming_buffer[0].payload().len() {
            self.advance_incoming_buffer();
        } else {
            self.pending_data =
                self.incoming_buffer[0].payload()[flushed..].to_vec();
        }
        flushed
    }

    /// True when there is no buffered received data left to hand out.
    pub(crate) fn should_read(&self) -> bool {
        self.incoming_buffer.is_empty() && self.pending_data.is_empty()
    }

    /// Splits `buf` into data packets, queues and transmits them.
    ///
    /// Returns `buf.len()`; fails with `ConnectionClosed` on a closed socket.
    pub async fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
        if self.state == SocketState::Closed {
            return Err(SocketError::ConnectionClosed.into());
        }
        let total_length = buf.len();
        for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
            let mut packet = Packet::with_payload(chunk);
            packet.set_seq_nr(self.seq_nr);
            packet.set_ack_nr(self.ack_nr);
            packet.set_connection_id(self.sender_connection_id);
            self.unsent_queue.push_back(packet);
            self.seq_nr = self.seq_nr.wrapping_add(1);
        }
        self.send().await?;
        Ok(total_length)
    }

    /// Receives until every packet in `send_window` has been acknowledged.
    pub async fn flush(&mut self) -> Result<()> {
        let mut buf = [0u8; BUF_SIZE];
        let mut buf = ReadBuf::new(&mut buf);
        while !self.send_window.is_empty() {
            debug!("packets in send window: {}", self.send_window.len());
            self.recv(&mut buf).await?;
        }
        Ok(())
    }

    /// Transmits everything in `unsent_queue`.
    ///
    /// Only first transmissions of data packets are tracked in `send_window` /
    /// `curr_window`. The queue also carries retransmissions (clones of packets
    /// already tracked, see `resend_lost_packet`) and State packets
    /// (`send_fast_resend_request`). Tracking those again could leave entries
    /// that no later ACK removes, and it would double-count in-flight bytes.
    async fn send(&mut self) -> Result<()> {
        while let Some(mut packet) = self.unsent_queue.pop_front() {
            self.send_packet(&mut packet).await?;
            let already_tracked = self
                .send_window
                .iter()
                .any(|p| p.seq_nr() == packet.seq_nr());
            if packet.get_type() == PacketType::Data && !already_tracked {
                self.curr_window += packet.len() as u32;
                self.send_window.push(packet);
            }
        }
        Ok(())
    }

    /// Bytes allowed in flight: min(cwnd, peer window), at least `MIN_CWND * MSS`.
    fn max_inflight(&self) -> u32 {
        let max_inflight = min(self.cwnd, self.remote_wnd_size);
        max(MIN_CWND * MSS, max_inflight)
    }

    /// Stamps `packet` with the current time and `their_delay`, then sends it
    /// to the peer.
    #[inline]
    async fn send_packet(&mut self, packet: &mut Packet) -> Result<()> {
        debug!("current window: {}", self.send_window.len());
        packet.set_timestamp(now_microseconds());
        packet.set_timestamp_difference(self.their_delay);
        self.socket
            .send_to(packet.as_ref(), self.connected_to)
            .await?;
        debug!("sent {:?}", packet);
        Ok(())
    }

    /// Records `base_delay` in the base-delay history.
    ///
    /// A new slot is started (dropping the oldest beyond `BASE_HISTORY`) when
    /// the history is empty or `MAX_BASE_DELAY_AGE` has passed since the last
    /// rollover; otherwise the newest slot keeps the minimum.
    fn update_base_delay(&mut self, base_delay: Delay, now: Timestamp) {
        if self.base_delays.is_empty()
            || now - self.last_rollover > MAX_BASE_DELAY_AGE
        {
            self.last_rollover = now;
            if self.base_delays.len() == BASE_HISTORY {
                self.base_delays.pop_front();
            }
            self.base_delays.push_back(base_delay);
        } else {
            let last_idx = self.base_delays.len() - 1;
            if base_delay < self.base_delays[last_idx] {
                self.base_delays[last_idx] = base_delay;
            }
        }
    }

    /// Appends a delay sample after discarding samples older than `rtt * 100`.
    fn update_current_delay(&mut self, v: Delay, now: Timestamp) {
        let rtt: Delay = (self.rtt as i64 * 100).into();
        // Samples are appended in time order, so the expired ones form a
        // prefix; dropping it in one `drain` avoids repeated `remove(0)`.
        let expired = self
            .current_delays
            .iter()
            .take_while(|sample| now - sample.received_at > rtt)
            .count();
        self.current_delays.drain(..expired);
        self.current_delays.push(DelayDifferenceSample {
            received_at: now,
            difference: v,
        });
    }

    /// Updates the RTT estimates with `current_delay` (ms) and recomputes
    /// `congestion_timeout = rtt + 4 * rtt_variance`, clamped to
    /// [`MIN_CONGESTION_TIMEOUT`, `MAX_CONGESTION_TIMEOUT`].
    fn update_congestion_timeout(&mut self, current_delay: i32) {
        let delta = self.rtt - current_delay;
        self.rtt_variance += (delta.abs() - self.rtt_variance) / 4;
        self.rtt += (current_delay - self.rtt) / 8;
        // Widen to i64 so `4 * rtt_variance` cannot overflow, and map a
        // negative estimate to 0. A plain `as u64` would wrap it to a huge
        // value and pin the timeout at the maximum.
        let estimate =
            i64::from(self.rtt) + 4 * i64::from(self.rtt_variance);
        let estimate = if estimate < 0 { 0 } else { estimate as u64 };
        self.congestion_timeout = min(
            max(estimate, MIN_CONGESTION_TIMEOUT),
            MAX_CONGESTION_TIMEOUT,
        );
        debug!("current_delay: {}", current_delay);
        debug!("delta: {}", delta);
        debug!("self.rtt_variance: {}", self.rtt_variance);
        debug!("self.rtt: {}", self.rtt);
        debug!("self.congestion_timeout: {}", self.congestion_timeout);
    }

    /// Exponentially weighted average (alpha 0.333) of the recent delay samples.
    fn filtered_current_delay(&self) -> Delay {
        let input = self.current_delays.iter().map(|delay| &delay.difference);
        (ewma(input, 0.333) as i64).into()
    }

    /// Smallest recorded base delay, or `Delay::default()` with no history.
    fn min_base_delay(&self) -> Delay {
        self.base_delays.iter().min().cloned().unwrap_or_default()
    }

    /// Builds the selective-ACK bitmask for packets buffered beyond `ack_nr + 1`.
    ///
    /// Bit `k` stands for sequence number `ack_nr + 2 + k`; the mask is padded
    /// to a multiple of 4 bytes. Distances are taken modulo 2^16, so the mask
    /// stays correct across wrap-around and nothing overflows at
    /// `ack_nr == u16::MAX`. Packets outside the forward half of the sequence
    /// space are stale and left out.
    fn build_selective_ack(&self) -> Vec<u8> {
        let ack_nr = self.ack_nr;
        let stashed = self
            .incoming_buffer
            .iter()
            .map(|pkt| pkt.seq_nr().wrapping_sub(ack_nr))
            .filter(|&ahead| ahead >= 2 && ahead < 0x8000)
            .map(|ahead| usize::from(ahead - 2))
            .map(|diff| (diff / 8, diff % 8));
        let mut sack = Vec::new();
        for (byte, bit) in stashed {
            while byte >= sack.len() || sack.len() % 4 != 0 {
                sack.push(0u8);
            }
            sack[byte] |= 1 << bit;
        }
        sack
    }

    /// Queues three State packets repeating the current `ack_nr`.
    fn send_fast_resend_request(&mut self) {
        for _ in 0..3usize {
            let mut packet = Packet::new();
            packet.set_type(PacketType::State);
            let self_t_micro = now_microseconds();
            packet.set_timestamp(self_t_micro);
            packet.set_timestamp_difference(self.their_delay);
            packet.set_connection_id(self.sender_connection_id);
            packet.set_seq_nr(self.seq_nr);
            packet.set_ack_nr(self.ack_nr);
            self.unsent_queue.push_back(packet);
        }
    }

    /// Queues a copy of the in-flight packet `lost_packet_nr` for retransmission.
    fn resend_lost_packet(&mut self, lost_packet_nr: u16) {
        debug!("---> resend_lost_packet({}) <---", lost_packet_nr);
        match self
            .send_window
            .iter()
            .position(|pkt| pkt.seq_nr() == lost_packet_nr)
        {
            None => debug!("Packet {} not found", lost_packet_nr),
            Some(position) => {
                debug!("self.send_window.len(): {}", self.send_window.len());
                debug!("position: {}", position);
                let packet = self.send_window[position].clone();
                self.unsent_queue.push_back(packet);
            }
        }
        debug!("---> END resend_lost_packet <---");
    }

    /// Drops every packet up to and including `last_acked` from the send
    /// window and releases their bytes from `curr_window`.
    fn advance_send_window(&mut self) {
        let last_acked = self.last_acked;
        if let Some(position) = self
            .send_window
            .iter()
            .position(|packet| packet.seq_nr() == last_acked)
        {
            for packet in self.send_window.drain(..=position) {
                debug!("removing {} bytes from send window", packet.len());
                // Saturate: an accounting mismatch must not panic (debug) or
                // wrap to a huge window (release) on peer-driven input.
                self.curr_window =
                    self.curr_window.saturating_sub(packet.len() as u32);
            }
            debug!(
                "{} packets left in send window",
                self.send_window.len()
            );
        }
        debug!("self.curr_window: {}", self.curr_window);
    }

    /// Answers a FIN with a State packet (plus SACK if we are missing data)
    /// and marks the connection `Closed`.
    fn handle_fin_packet(
        &mut self,
        packet: &Packet,
        src: SocketAddr,
    ) -> Packet {
        if packet.ack_nr() < self.seq_nr {
            debug!("FIN received but there are missing acknowledgements for sent packets");
        }
        let mut reply = self.prepare_reply(packet, PacketType::State);
        if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
            debug!(
                "current ack_nr({}) is behind received packet seq_nr ({})",
                self.ack_nr,
                packet.seq_nr()
            );
            let sack = self.build_selective_ack();
            if !sack.is_empty() {
                debug!("sending SACK to peer");
                reply.set_sack(sack);
            }
        }
        debug!("received FIN from {}, connection is closed", src);
        self.state = SocketState::Closed;
        reply
    }

    /// Runs one incoming packet through the connection state machine.
    ///
    /// Returns the reply to send (if any). Fails with `ConnectionReset` on a
    /// Reset, `InvalidReply` for an unexpected answer to our SYN, and `Other`
    /// for unhandled state/type combinations.
    fn handle_packet(
        &mut self,
        packet: &Packet,
        src: SocketAddr,
    ) -> Result<Option<Packet>> {
        debug!("({:?}, {:?})", self.state, packet.get_type());
        if packet.seq_nr().wrapping_sub(self.ack_nr) == 1 {
            self.ack_nr = packet.seq_nr();
        }
        if packet.get_type() != PacketType::Syn
            && self.state != SocketState::SynSent
            && !(packet.connection_id() == self.sender_connection_id
                || packet.connection_id() == self.receiver_connection_id)
        {
            return Ok(Some(self.prepare_reply(packet, PacketType::Reset)));
        }
        self.remote_wnd_size = packet.wnd_size();
        debug!("self.remote_wnd_size: {}", self.remote_wnd_size);
        let now = now_microseconds();
        self.their_delay = abs_diff(now, packet.timestamp());
        debug!("self.their_delay: {}", self.their_delay);
        match (self.state, packet.get_type()) {
            (SocketState::New, PacketType::Syn) => {
                self.connected_to = src;
                self.ack_nr = packet.seq_nr();
                self.seq_nr = rand::random();
                // Peer-chosen id: `+ 1` would panic in debug builds for 0xFFFF.
                self.receiver_connection_id =
                    packet.connection_id().wrapping_add(1);
                self.sender_connection_id = packet.connection_id();
                self.state = SocketState::Connected;
                self.last_dropped = self.ack_nr;
                Ok(Some(self.prepare_reply(packet, PacketType::State)))
            }
            (_, PacketType::Syn) => {
                Ok(Some(self.prepare_reply(packet, PacketType::Reset)))
            }
            (SocketState::SynSent, PacketType::State) => {
                self.connected_to = src;
                self.ack_nr = packet.seq_nr();
                self.seq_nr = self.seq_nr.wrapping_add(1);
                self.state = SocketState::Connected;
                self.last_acked = packet.ack_nr();
                self.last_acked_timestamp = now_microseconds();
                Ok(None)
            }
            (SocketState::SynSent, _) => Err(SocketError::InvalidReply.into()),
            (SocketState::Connected, PacketType::Data)
            | (SocketState::FinSent, PacketType::Data) => {
                Ok(self.handle_data_packet(packet))
            }
            (SocketState::Connected, PacketType::State) => {
                self.handle_state_packet(packet);
                Ok(None)
            }
            (SocketState::Connected, PacketType::Fin)
            | (SocketState::FinSent, PacketType::Fin) => {
                Ok(Some(self.handle_fin_packet(packet, src)))
            }
            (SocketState::Closed, PacketType::Fin) => {
                Ok(Some(self.prepare_reply(packet, PacketType::State)))
            }
            (SocketState::FinSent, PacketType::State) => {
                if packet.ack_nr() == self.seq_nr {
                    debug!("connection closed succesfully");
                    self.state = SocketState::Closed;
                } else {
                    self.handle_state_packet(packet);
                }
                Ok(None)
            }
            (_, PacketType::Reset) => {
                self.state = SocketState::ResetReceived;
                Err(SocketError::ConnectionReset.into())
            }
            (state, ty) => {
                let message = format!(
                    "Unimplemented handling for ({:?},{:?})",
                    state, ty
                );
                debug!("{}", message);
                Err(SocketError::Other(message).into())
            }
        }
    }

    /// Builds the acknowledgement for a data packet.
    ///
    /// The reply is a FIN while we are in `FinSent`, otherwise a State packet.
    /// A SACK is attached when the packet is ahead of `ack_nr + 1`.
    fn handle_data_packet(&mut self, packet: &Packet) -> Option<Packet> {
        let packet_type = if self.state == SocketState::FinSent {
            PacketType::Fin
        } else {
            PacketType::State
        };
        let mut reply = self.prepare_reply(packet, packet_type);
        if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
            debug!(
                "current ack_nr ({}) is behind received packet seq_nr ({})",
                self.ack_nr,
                packet.seq_nr()
            );
            let sack = self.build_selective_ack();
            if !sack.is_empty() {
                debug!("sending SACK packet");
                reply.set_sack(sack);
            }
        }
        Some(reply)
    }

    /// Filtered current delay minus the minimum base delay.
    fn queuing_delay(&self) -> Delay {
        let filtered_current_delay = self.filtered_current_delay();
        let min_base_delay = self.min_base_delay();
        let queuing_delay = filtered_current_delay - min_base_delay;
        debug!("filtered_current_delay: {}", filtered_current_delay);
        debug!("min_base_delay: {}", min_base_delay);
        debug!("queuing_delay: {}", queuing_delay);
        queuing_delay
    }

    /// Scales `cwnd` by `off_target` for `bytes_newly_acked`, capped at
    /// `curr_window + ALLOWED_INCREASE * MSS` and floored at `MIN_CWND * MSS`.
    fn update_congestion_window(
        &mut self,
        off_target: f64,
        bytes_newly_acked: u32,
    ) {
        let flightsize = self.curr_window;
        let cwnd_increase =
            GAIN * off_target * bytes_newly_acked as f64 * MSS as f64;
        let cwnd_increase = cwnd_increase / self.cwnd as f64;
        debug!("cwnd_increase: {}", cwnd_increase);
        self.cwnd = (self.cwnd as f64 + cwnd_increase) as u32;
        let max_allowed_cwnd = flightsize + ALLOWED_INCREASE * MSS;
        self.cwnd = min(self.cwnd, max_allowed_cwnd);
        self.cwnd = max(self.cwnd, MIN_CWND * MSS);
        debug!("cwnd: {}", self.cwnd);
        debug!("max_allowed_cwnd: {}", max_allowed_cwnd);
    }

    /// Inspects selective-ACK extensions, queues retransmissions of packets
    /// reported missing, and sets `packet_loss_detected` when it does so.
    fn handle_packet_extension(
        &mut self,
        packet: &Packet,
        packet_loss_detected: &mut bool,
    ) {
        for extension in packet.extensions() {
            if extension.get_type() == ExtensionType::SelectiveAck {
                if extension.iter().count_ones() >= 3 {
                    self.resend_lost_packet(packet.ack_nr().wrapping_add(1));
                    *packet_loss_detected = true;
                }
                if let Some(last_seq_nr) =
                    self.send_window.last().map(Packet::seq_nr)
                {
                    let lost_packets = extension
                        .iter()
                        .enumerate()
                        .filter(|&(_, received)| !received)
                        .map(|(idx, _)| {
                            packet
                                .ack_nr()
                                .wrapping_add(2)
                                .wrapping_add(idx as u16)
                        })
                        .take_while(|&seq_nr| seq_nr < last_seq_nr);
                    for seq_nr in lost_packets {
                        debug!("SACK: packet {} lost", seq_nr);
                        self.resend_lost_packet(seq_nr);
                        *packet_loss_detected = true;
                    }
                }
            } else {
                debug!(
                    "Unknown extension {:?}, ignoring",
                    extension.get_type()
                );
            }
        }
    }

    /// Processes an acknowledgement.
    ///
    /// Updates duplicate-ACK tracking, the delay / congestion estimates for
    /// the acknowledged packet, and loss handling (SACK or triple duplicate
    /// ACK, halving `cwnd`). Finally advances the send window.
    fn handle_state_packet(&mut self, packet: &Packet) {
        if packet.ack_nr() == self.last_acked {
            self.duplicate_ack_count += 1;
        } else {
            self.last_acked = packet.ack_nr();
            self.last_acked_timestamp = now_microseconds();
            self.duplicate_ack_count = 1;
        }
        if let Some(index) = self
            .send_window
            .iter()
            .position(|p| packet.ack_nr() == p.seq_nr())
        {
            let bytes_newly_acked = self
                .send_window
                .iter()
                .take(index + 1)
                .fold(0, |acc, p| acc + p.len());
            let now = now_microseconds();
            let our_delay = now - self.send_window[index].timestamp();
            debug!("our_delay: {}", our_delay);
            self.update_base_delay(our_delay, now);
            self.update_current_delay(our_delay, now);
            let off_target: f64 =
                (TARGET - u32::from(self.queuing_delay()) as f64) / TARGET;
            debug!("off_target: {}", off_target);
            self.update_congestion_window(off_target, bytes_newly_acked as u32);
            let rtt = u32::from(our_delay - self.queuing_delay()) / 1000;
            self.update_congestion_timeout(rtt as i32);
        }
        let mut packet_loss_detected: bool =
            !self.send_window.is_empty() && self.duplicate_ack_count == 3;
        self.handle_packet_extension(packet, &mut packet_loss_detected);
        if !self.send_window.is_empty()
            && self.duplicate_ack_count == 3
            && !packet
                .extensions()
                .any(|ext| ext.get_type() == ExtensionType::SelectiveAck)
        {
            self.resend_lost_packet(packet.ack_nr().wrapping_add(1));
        }
        if packet_loss_detected {
            debug!("packet loss detected, halving congestion window");
            self.cwnd = max(self.cwnd / 2, MIN_CWND * MSS);
            debug!("cwnd: {}", self.cwnd);
        }
        self.advance_send_window();
    }

    /// Inserts a received data packet into `incoming_buffer`, keeping it
    /// sorted by sequence number and ignoring duplicates.
    fn insert_into_buffer(&mut self, packet: Packet) {
        if self
            .incoming_buffer
            .last()
            .map_or(false, |p| packet.seq_nr() > p.seq_nr())
        {
            self.incoming_buffer.push(packet);
        } else {
            let i = self
                .incoming_buffer
                .iter()
                .filter(|p| p.seq_nr() < packet.seq_nr())
                .count();
            if self
                .incoming_buffer
                .get(i)
                .map_or(true, |p| p.seq_nr() != packet.seq_nr())
            {
                self.incoming_buffer.insert(i, packet);
            }
        }
    }
}
/// Polls the future `$data` once in place.
///
/// On `Pending` it returns `Poll::Pending` from the *enclosing* function;
/// otherwise it evaluates to the future's output.
///
/// SAFETY: `Pin::new_unchecked` is sound only if `$data` is never moved after
/// this poll. Temporaries (e.g. `self.socket.lock()`) satisfy this, because
/// they are dropped in place at the end of the statement.
/// NOTE(review): when `$data` is a temporary, the future is dropped after a
/// `Pending` result. For wait-queue based futures such as a tokio mutex lock,
/// dropping may deregister the waker — confirm another source wakes the task.
macro_rules! ready_unpin {
    ($data:expr, $cx:expr) => {
        match unsafe { Pin::new_unchecked(&mut $data) }.poll($cx) {
            Poll::Ready(v) => v,
            Poll::Pending => return Poll::Pending,
        }
    };
}
/// Like `ready_unpin!` for futures yielding `Result`.
///
/// Evaluates to the `Ok` value. An `Err` is returned from the enclosing
/// function as `Poll::Ready(Err(e))`, and `Pending` is returned as well.
/// The same pinning caveats as `ready_unpin!` apply.
macro_rules! ready_try_unpin {
    ($data:expr, $cx:expr) => {
        match ready_unpin!($data, $cx) {
            Ok(v) => v,
            Err(e) => return Poll::Ready(Err(e)),
        }
    };
}
/// Polls the future `$data` once in place and evaluates to the raw `Poll`.
///
/// Never returns early from the caller. A temporary future is dropped at the
/// end of the inner `let`, so any borrows it held are released before the
/// caller inspects the result.
/// SAFETY: same contract as `ready_unpin!` — `$data` must not be moved after
/// being pinned.
macro_rules! poll_unpin {
    ($data:expr, $cx:expr) => {{
        // NOTE(review): this empty `#[allow()]` has no effect.
        #[allow()]
        let x = unsafe { Pin::new_unchecked(&mut $data) }.poll($cx);
        x
    }};
}
/// Unwraps an already computed `Poll<Result<T>>`.
///
/// Evaluates to `T` on `Ready(Ok)`. `Pending` and `Ready(Err)` are returned
/// from the enclosing function unchanged.
macro_rules! ready_try {
    ($data:expr) => {{
        match ($data) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Ok(v)) => v,
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
        }
    }};
}
/// Shareable handle to a bound, not yet connected `UtpSocket`, paired with its
/// local address. `connect`/`accept` turn it into a stream and its driver.
pub struct UtpSocketRef(Arc<Mutex<UtpSocket>>, SocketAddr);
impl UtpSocketRef {
fn new(socket: Arc<Mutex<UtpSocket>>, local: SocketAddr) -> Self {
Self(socket, local)
}
pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
let udp = UdpSocket::bind(addr).await?;
let resolved = udp.local_addr()?;
let socket = UtpSocket::from_raw_parts(udp, resolved);
let lock = Arc::new(Mutex::new(socket));
debug!("bound utp socket on {}", resolved);
Ok(Self::new(lock, resolved))
}
pub async fn connect(
self,
dst: SocketAddr,
) -> Result<(UtpStream, UtpStreamDriver)> {
let mut socket = self.0.lock().await;
socket.connected_to = dst;
let mut packet = Packet::new();
packet.set_type(PacketType::Syn);
packet.set_connection_id(socket.receiver_connection_id);
packet.set_seq_nr(socket.seq_nr);
let mut len = 0;
let mut buf = [0; BUF_SIZE];
let mut syn_timeout = socket.congestion_timeout;
for _ in 0..MAX_SYN_RETRIES {
packet.set_timestamp(now_microseconds());
debug!("connecting to {}", socket.connected_to);
let dst = socket.connected_to;
socket.socket.send_to(packet.as_ref(), dst).await?;
socket.state = SocketState::SynSent;
debug!("sent {:?}", packet);
let to = Duration::from_millis(syn_timeout);
match timeout(to, socket.socket.recv_from(&mut buf)).await {
Ok(Ok((read, src))) => {
socket.connected_to = src;
len = read;
break;
}
Ok(Err(e)) => return Err(e),
Err(_) => {
debug!("timed out, retrying");
syn_timeout *= 2;
continue;
}
};
}
let remote = socket.connected_to;
let packet = Packet::try_from(&buf[..len])?;
debug!("received {:?}", packet);
socket.handle_packet(&packet, remote)?;
debug!("connected to: {}", socket.connected_to);
let (tx, rx) = unbounded_channel();
let local = socket.local_addr()?;
mem::drop(socket);
let driver = UtpStreamDriver::new(self.0.clone(), tx);
let stream = UtpStream::new(self.0, rx, local, remote);
Ok((stream, driver))
}
pub async fn accept(self) -> Result<(UtpStream, UtpStreamDriver)> {
let (src, dst);
loop {
let mut socket = self.0.lock().await;
let mut buf = [0u8; BUF_SIZE];
let (read, remote) = socket.socket.recv_from(&mut buf).await?;
let packet = Packet::try_from(&buf[..read])?;
debug!("accept receive {:?}", packet);
if let Ok(Some(reply)) = socket.handle_packet(&packet, remote) {
src = socket.socket.local_addr()?;
dst = socket.connected_to;
socket.socket.send_to(reply.as_ref(), dst).await?;
debug!("sent {:?} to {}", reply, dst);
debug!("accepted connection {} -> {}", dst, src);
break;
}
}
let (tx, rx) = unbounded_channel();
let socket = self.0;
let stream = UtpStream::new(socket.clone(), rx, src, dst);
let driver = UtpStreamDriver::new(socket, tx);
Ok((stream, driver))
}
pub fn local_addr(&self) -> SocketAddr {
self.1
}
}
/// Connected uTP byte stream implementing `AsyncRead`/`AsyncWrite`.
pub struct UtpStream {
    /// Socket state shared with the `UtpStreamDriver`.
    socket: Arc<Mutex<UtpSocket>>,
    /// Notifications sent by the driver (`Ok(())` or an error).
    receiver: UnboundedReceiver<Result<()>>,
    /// Local address of the connection.
    local: SocketAddr,
    /// Peer address of the connection.
    remote: SocketAddr,
}
impl UtpStream {
fn new(
socket: Arc<Mutex<UtpSocket>>,
receiver: UnboundedReceiver<Result<()>>,
local: SocketAddr,
remote: SocketAddr,
) -> Self {
Self {
socket,
receiver,
local,
remote,
}
}
pub fn local_addr(&self) -> SocketAddr {
self.local
}
pub fn peer_addr(&self) -> SocketAddr {
self.remote
}
fn handle_driver_notification(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<Result<()>> {
match poll_unpin!(self.receiver.recv(), cx) {
Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
debug!("connection driver has died");
Poll::Ready(Ok(()))
}
Poll::Ready(Some(Ok(()))) => {
debug!("notification from driver");
self.poll_read(cx, buf)
}
Poll::Pending => {
debug!("waiting for notification from driver");
Poll::Pending
}
}
}
fn prepare_packet(socket: &mut UtpSocket, chunk: &[u8]) -> Packet {
let mut packet = Packet::with_payload(chunk);
packet.set_seq_nr(socket.seq_nr);
packet.set_ack_nr(socket.ack_nr);
packet.set_connection_id(socket.sender_connection_id);
packet
}
fn handle_driver_message(
msg: Poll<Option<Result<()>>>,
) -> Poll<Result<()>> {
match msg {
Poll::Ready(None) => {
debug!("driver is dead, closing success");
Poll::Ready(Ok(()))
}
Poll::Ready(Some(Ok(()))) => {
debug!("driver sent closing notice");
Poll::Ready(Ok(()))
}
Poll::Ready(Some(Err(e)))
if e.kind() == ErrorKind::NotConnected =>
{
debug!("connection closed by err");
Poll::Ready(Ok(()))
}
Poll::Ready(Some(Err(e))) => {
debug!("failed to close correctly");
Poll::Ready(Err(e))
}
Poll::Pending => {
debug!("waiting for driver to complete closing");
Poll::Pending
}
}
}
fn wait_acks(
socket: &mut UtpSocket,
cx: &mut Context<'_>,
) -> Poll<Result<()>> {
let mut buf = [0u8; BUF_SIZE + HEADER_SIZE];
debug!("waiting for ACKs for {} packets", socket.send_window.len());
while !socket.send_window.is_empty()
&& socket.state != SocketState::Closed
{
let (read, src) = {
match poll_unpin!(socket.socket.recv_from(&mut buf), cx) {
Poll::Ready(Ok((read, src))) => (read, src),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
};
let packet = Packet::try_from(&buf[..read])?;
if let Some(reply) = socket.handle_packet(&packet, src)? {
if poll_unpin!(socket.socket.send_to(reply.as_ref(), src), cx)
.is_pending()
{
socket.unsent_queue.push_back(reply);
return Poll::Pending;
}
}
}
Poll::Ready(Ok(()))
}
fn flush_unsent(
socket: &mut UtpSocket,
cx: &mut Context<'_>,
) -> Poll<Result<()>> {
while let Some(mut packet) = socket.unsent_queue.pop_front() {
if poll_unpin!(socket.send_packet(&mut packet), cx).is_pending() {
debug!("too many in flight packets, waiting for ack");
return Poll::Pending;
}
let result = {
let dst = socket.connected_to;
poll_unpin!(socket.socket.send_to(packet.as_ref(), dst), cx)
};
match result {
Poll::Pending => {
socket.unsent_queue.push_front(packet);
return Poll::Pending;
}
Poll::Ready(Ok(_)) => socket.send_window.push(packet),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
}
}
Poll::Ready(Ok(()))
}
}
impl AsyncRead for UtpStream
where
    Self: Unpin,
{
    /// Hands buffered, in-order payload to `buf`, or reports how the
    /// connection ended.
    ///
    /// Checks run in this order:
    /// - flushable data wins;
    /// - otherwise a `Closed` socket yields `Ok(())` with nothing written (end of stream);
    /// - `ResetReceived` yields `ConnectionReset`;
    /// - every other state waits on the driver's notification channel.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<Result<()>> {
        debug!("read poll for {} bytes", buf.capacity());
        let (flushed, state) = {
            let mut guard = ready_unpin!(self.socket.lock(), cx);
            let flushed = guard.flush_incoming_buffer(buf);
            (flushed, guard.state)
        };
        if flushed > 0 {
            debug!("flushed {} bytes of received data", flushed);
            return Poll::Ready(Ok(()));
        }
        match state {
            SocketState::Closed => {
                debug!("read on closed connection");
                Poll::Ready(Ok(()))
            }
            SocketState::ResetReceived => {
                debug!("read on reset connection");
                Poll::Ready(Err(SocketError::ConnectionReset.into()))
            }
            SocketState::New
            | SocketState::Connected
            | SocketState::SynSent
            | SocketState::FinSent => self.handle_driver_notification(cx, buf),
        }
    }
}
impl AsyncWrite for UtpStream
where
    Self: Unpin,
{
    /// Splits `buf` into MSS-sized data packets and sends as many as the
    /// congestion window and the OS allow.
    ///
    /// Returns the number of *payload* bytes of `buf` accepted. Once any
    /// chunk has been sent, a later obstacle (window full, socket not ready,
    /// send error) ends the call with the partial count rather than `Pending`
    /// or an error, so the caller never re-sends bytes already queued.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        let mut socket = ready_unpin!(self.socket.lock(), cx);
        if socket.state == SocketState::Closed {
            debug!("tried to write on closed connection");
            return Poll::Ready(Err(SocketError::ConnectionClosed.into()));
        }
        // Bytes of `buf` already sent. Header bytes are excluded because the
        // caller advances its buffer by this count.
        let mut sent: usize = 0;
        debug!("trying to send {} bytes", buf.len());
        for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
            if socket.curr_window >= socket.max_inflight() {
                debug!("send window is full, waiting for ACKs");
                mem::drop(socket);
                if sent > 0 {
                    return Poll::Ready(Ok(sent));
                }
                // Drain stale notifications so `cx` is registered for the
                // next one. A closed channel means the driver is gone and no
                // ACK will ever free the window; the previous
                // `while ... is_ready() {}` spun forever in that case.
                loop {
                    match poll_unpin!(self.receiver.recv(), cx) {
                        Poll::Ready(Some(_)) => continue,
                        Poll::Ready(None) => {
                            debug!("connection driver has died");
                            return Poll::Ready(Err(
                                SocketError::ConnectionClosed.into(),
                            ));
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }
            }
            debug!("attempting to send chunk of {} byte", chunk.len());
            let mut packet = Self::prepare_packet(&mut socket, chunk);
            match poll_unpin!(socket.send_packet(&mut packet), cx) {
                Poll::Pending if sent == 0 => {
                    debug!("socket send buffer is full, waiting..");
                    return Poll::Pending;
                }
                Poll::Ready(Err(e)) if sent == 0 => {
                    debug!("os error reading data: {}", e);
                    return Poll::Ready(Err(e));
                }
                Poll::Pending | Poll::Ready(Err(_)) => {
                    debug!("successfully sent {} bytes, sleeping...", sent);
                    return Poll::Ready(Ok(sent));
                }
                Poll::Ready(Ok(())) => {
                    let written = packet.len();
                    socket.curr_window += written as u32;
                    socket.send_window.push(packet);
                    sent += chunk.len();
                    socket.seq_nr = socket.seq_nr.wrapping_add(1);
                    debug!(
                        "poll_write sent seq {}, curr_window: {}",
                        socket.seq_nr.wrapping_sub(1),
                        socket.curr_window
                    );
                }
            }
        }
        Poll::Ready(Ok(buf.len()))
    }

    /// Sends everything queued and waits until all sent data is acknowledged.
    ///
    /// A driver error received first is returned. A closed driver channel
    /// counts as flushed, and a `Closed` socket yields `NotConnected`.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Result<()>> {
        debug!("attempting flush");
        match poll_unpin!(self.receiver.recv(), cx) {
            Poll::Ready(Some(Err(e))) => {
                debug!("driver signaled error over channel");
                return Poll::Ready(Err(e));
            }
            Poll::Ready(None) => {
                debug!("connection driver disconnected");
                return Poll::Ready(Ok(()));
            }
            _ => debug!("no message from driver"),
        }
        let mut socket = ready_unpin!(self.socket.lock(), cx);
        if socket.state == SocketState::Closed {
            return Poll::Ready(Err(SocketError::NotConnected.into()));
        }
        ready_try!(Self::flush_unsent(&mut socket, cx));
        ready_try!(Self::wait_acks(&mut socket, cx));
        debug!("sucessfully flushed");
        Poll::Ready(Ok(()))
    }

    /// Flushes, then closes the connection (FIN) unless a FIN was already
    /// sent, and finally interprets the driver's message via
    /// `handle_driver_message`.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Result<()>> {
        debug!("poll_shutdown connection...");
        {
            let socket = ready_unpin!(self.socket.lock(), cx);
            if socket.state == SocketState::Closed {
                debug!("socket closed by driver");
                return Poll::Ready(Ok(()));
            }
        }
        match self.as_mut().poll_flush(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => {
                {
                    let mut socket = ready_unpin!(self.socket.lock(), cx);
                    if socket.state != SocketState::FinSent {
                        if let Poll::Ready(Ok(())) =
                            poll_unpin!(socket.close(), cx)
                        {
                            return Poll::Ready(Ok(()));
                        } else {
                            mem::drop(socket);
                            ready_unpin!(self.receiver.recv(), cx);
                        }
                    }
                }
                let msg = poll_unpin!(self.receiver.recv(), cx);
                Self::handle_driver_message(msg)
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        }
    }
}
/// Background future that drives one uTP connection.
///
/// It must be spawned for the paired stream to make progress (see the
/// `#[must_use]` message). Its `Future` implementation reads datagrams from
/// the shared socket, hands them to `UtpSocket::handle_packet`, sends the
/// replies, and runs timeout handling when the socket has nothing to read.
#[must_use = "stream drivers must be spawned for the stream to work"]
pub struct UtpStreamDriver {
    // Socket state, shared behind an async mutex with the stream half.
    socket: Arc<Mutex<UtpSocket>>,
    // Notification channel to the stream: `Ok(())` after data is buffered,
    // `Err(_)` when the connection closes or times out.
    sender: UnboundedSender<Result<()>>,
    // Timer that triggers `check_timeout` handling once it elapses.
    timer: Pin<Box<Sleep>>,
    // Number of timeouts handled so far; the socket is closed once this
    // exceeds `MAX_RETRANSMISSION_RETRIES`. NOTE(review): nothing visible
    // here resets it when packets arrive, so timeouts accumulate over the
    // connection's lifetime — confirm that is intended.
    timeout_nr: u32,
}
impl UtpStreamDriver {
fn new(
socket: Arc<Mutex<UtpSocket>>,
sender: UnboundedSender<Result<()>>,
) -> Self {
Self {
socket,
sender,
timer: Box::pin(sleep(Duration::from_millis(
INITIAL_CONGESTION_TIMEOUT,
))),
timeout_nr: 0,
}
}
async fn handle_timeout(&mut self, next_timeout: u64) -> Result<()> {
self.timeout_nr += 1;
debug!(
"timed out {} times out of {} max, retrying in {} ms",
self.timeout_nr, MAX_RETRANSMISSION_RETRIES, next_timeout
);
if self.timeout_nr > MAX_RETRANSMISSION_RETRIES {
let mut socket = self.socket.lock().await;
socket.state = SocketState::Closed;
return Err(SocketError::ConnectionTimedOut.into());
}
let ret = {
let mut socket = self.socket.lock().await;
socket.handle_receive_timeout().await
};
self.reset_timer(Duration::from_millis(next_timeout));
ret
}
fn notify_close(&mut self) {
if self
.sender
.send(Err(SocketError::NotConnected.into()))
.is_err()
{
error!("failed to notify socket of termination");
} else {
debug!("notified socket of closing");
}
}
fn send_reply(
socket: &mut UtpSocket,
mut reply: Packet,
cx: &mut Context<'_>,
) -> Poll<Result<()>> {
match poll_unpin!(socket.send_packet(&mut reply), cx) {
Poll::Pending => {
socket.unsent_queue.push_back(reply);
Poll::Pending
}
Poll::Ready(Err(e)) => {
error!("driver failed to send packet: {}", e);
Poll::Ready(Err(e))
}
_ => Poll::Ready(Ok(())),
}
}
fn reset_timer(&mut self, next_timeout: Duration) {
let now = TokioInstant::from_std(Instant::now());
self.timer.as_mut().reset(now + next_timeout);
}
fn check_timeout(
&mut self,
cx: &mut Context<'_>,
next_timeout: u64,
) -> Poll<Result<()>> {
if self.timer.is_elapsed() {
debug!("receive timeout detected");
match poll_unpin!(self.handle_timeout(next_timeout), cx) {
Poll::Pending => todo!("socket buffer full"),
Poll::Ready(Ok(())) => {
self.reset_timer(Duration::from_millis(next_timeout));
ready_unpin!(self.timer, cx);
Poll::Pending
}
Poll::Ready(Err(e)) => {
debug!("remote peer timed out too many times");
self.sender
.send(Err(e.kind().into()))
.expect("failed to propagate");
Poll::Ready(Err(e))
}
}
} else {
ready_unpin!(self.timer, cx);
Poll::Pending
}
}
}
/// Drives the connection: processes incoming datagrams until the UDP socket
/// would block, then falls through to timeout handling.
impl Future for UtpStreamDriver {
    type Output = Result<()>;
    /// Completes with `Ok(())` when the socket is `Closed` or the stream half
    /// has been dropped. Completes with `Err` on receive or packet-handling
    /// errors, or when `check_timeout` gives up.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Cloned up front so notifications can be sent while `socket` (a guard
        // obtained through `self`) is still held.
        let sender = self.sender.clone();
        let mut socket = ready_unpin!(self.socket.lock(), cx);
        let mut buf = [0u8; BUF_SIZE + HEADER_SIZE];
        loop {
            debug!("stream driver poll attempt");
            if socket.state == SocketState::Closed {
                debug!("socket is closed when attempting poll, killing driver");
                // The guard must be released before `self` is borrowed mutably.
                mem::drop(socket);
                self.notify_close();
                return Poll::Ready(Ok(()));
            }
            match poll_unpin!(socket.socket.recv_from(&mut buf), cx) {
                Poll::Ready(Ok((read, src))) => {
                    // Datagrams that do not parse as packets are ignored.
                    if let Ok(packet) = Packet::try_from(&buf[..read]) {
                        debug!("received packet {:?}", packet);
                        match socket.handle_packet(&packet, src) {
                            Ok(Some(reply)) => {
                                if let PacketType::Data = packet.get_type() {
                                    socket.insert_into_buffer(packet);
                                    // Tell the stream half new data was buffered;
                                    // a closed channel means it was dropped.
                                    if sender.send(Ok(())).is_err() {
                                        debug!(
                                            "dropped socket, killing driver"
                                        );
                                        return Poll::Ready(Ok(()));
                                    }
                                }
                                // Only `Pending` stops the loop here: the reply was
                                // queued in `unsent_queue` by `send_reply`. Send
                                // errors are logged by `send_reply` and otherwise
                                // ignored.
                                if Self::send_reply(&mut socket, reply, cx)
                                    .is_pending()
                                {
                                    return Poll::Pending;
                                }
                            }
                            // NOTE(review): `ready_try_unpin!` presumably returns
                            // early on `Pending`/`Err` — confirm in util.
                            Ok(None) => ready_try_unpin!(socket.send(), cx),
                            Err(e) => return Poll::Ready(Err(e)),
                        }
                    }
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => {
                    // Nothing to read: back off to twice the congestion timeout.
                    // NOTE(review): no clamp to `MAX_CONGESTION_TIMEOUT` here;
                    // confirm `congestion_timeout` is bounded elsewhere.
                    let next_timeout = socket.congestion_timeout * 2;
                    mem::drop(socket);
                    return self.check_timeout(cx, next_timeout);
                }
            }
        }
    }
}
impl Drop for UtpSocket {
fn drop(&mut self) {
let _ = self.close();
}
}
/// Buffer capacity, in bytes, given to the `BufReader` inside `BufferedUtpStream`.
const MTU: usize = 1_500;
/// A `UtpStream` wrapped in a tokio `BufReader`.
///
/// `BufferedUtpStream::new` sizes the reader's buffer at `MTU` bytes.
pub struct BufferedUtpStream {
    // The wrapped stream; the `AsyncRead`/`AsyncWrite` impls reach it
    // through `get_stream`.
    stream: BufReader<UtpStream>,
}
impl BufferedUtpStream {
    /// Wraps `stream` in a `BufReader` with an `MTU`-byte buffer.
    pub fn new(stream: UtpStream) -> Self {
        Self {
            stream: BufReader::with_capacity(MTU, stream),
        }
    }

    /// Projects a pinned `BufferedUtpStream` onto its inner `BufReader`.
    ///
    /// `UtpStream` is `Unpin`: its own `poll_*` methods mutate fields through
    /// `Pin<&mut Self>`. Therefore `BufReader<UtpStream>`, and hence `Self`,
    /// are `Unpin` as well. The projection is a safe `Pin::new` instead of
    /// an unjustified `map_unchecked_mut`.
    fn get_stream(self: Pin<&mut Self>) -> Pin<&mut BufReader<UtpStream>> {
        Pin::new(&mut self.get_mut().stream)
    }

    /// Returns the local address reported by the wrapped stream.
    ///
    /// This never fails; the `Result` is kept for API compatibility.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        Ok(self.stream.get_ref().local_addr())
    }

    /// Returns the peer address reported by the wrapped stream.
    ///
    /// This never fails; the `Result` is kept for API compatibility.
    pub fn peer_addr(&self) -> Result<SocketAddr> {
        Ok(self.stream.get_ref().peer_addr())
    }
}
impl AsyncRead for BufferedUtpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<Result<()>> {
self.get_stream().get_pin_mut().poll_read(cx, buf)
}
}
impl AsyncWrite for BufferedUtpStream {
    /// Forwards the write to the inner `BufReader<UtpStream>`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        let inner = self.get_stream();
        inner.poll_write(cx, buf)
    }
    /// Forwards the flush to the inner `BufReader<UtpStream>`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
        let inner = self.get_stream();
        inner.poll_flush(cx)
    }
    /// Forwards the shutdown to the inner `BufReader<UtpStream>`.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Result<()>> {
        let inner = self.get_stream();
        inner.poll_shutdown(cx)
    }
}
#[cfg(test)]
mod test {
use std::env;
use std::io::ErrorKind;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::sync::atomic::Ordering;
use super::*;
use crate::socket::{SocketState, UtpSocket, BUF_SIZE};
use crate::time::now_microseconds;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::task;
use tokio::time::interval;
use tracing::debug_span;
use tracing_futures::Instrument;
use tracing_subscriber::FmtSubscriber;
/// Awaits `$e` and yields its `Ok` value, panicking with the `Debug`
/// rendering of the error otherwise.
macro_rules! iotry {
    ($e:expr) => {
        match $e.await {
            Err(err) => panic!("{:?}", err),
            Ok(value) => value,
        }
    };
}
/// Installs a global `tracing` subscriber when `RUST_LOG` is set.
///
/// The variable's value is parsed as the maximum level. If it does not
/// parse, `None` is what gets passed to `with_max_level`. Failure to install
/// the subscriber (for example, because another test already did) is ignored.
fn init_logger() {
    let parsed = match env::var("RUST_LOG") {
        Ok(raw) => Some(raw.parse().ok()),
        Err(_) => None,
    };
    if let Some(level) = parsed {
        let subscriber =
            FmtSubscriber::builder().with_max_level(level).finish();
        let _ = tracing::subscriber::set_global_default(subscriber);
    }
}
/// Returns a distinct port per call, counting up from 9600, so tests that
/// run concurrently bind different addresses.
fn next_test_port() -> u16 {
    use std::sync::atomic::AtomicUsize;
    const BASE_PORT: u16 = 9600;
    static NEXT_OFFSET: AtomicUsize = AtomicUsize::new(0);
    let offset = NEXT_OFFSET.fetch_add(1, Ordering::Relaxed) as u16;
    BASE_PORT + offset
}
/// Returns `127.0.0.1` paired with a fresh test port.
fn next_test_ip4() -> SocketAddr {
    let port = next_test_port();
    SocketAddr::from((Ipv4Addr::LOCALHOST, port))
}
/// Returns `::1` paired with a fresh test port.
fn next_test_ip6() -> SocketAddr {
    let port = next_test_port();
    SocketAddr::from((Ipv6Addr::LOCALHOST, port))
}
/// Binds `server_addr`, accepts one connection, spawns its driver (inside a
/// `stream_driver` tracing span) and returns the accepted stream.
async fn stream_accept(server_addr: SocketAddr) -> UtpStream {
    let listener = UtpSocketRef::bind(server_addr)
        .await
        .expect("failed to bind");
    let (stream, driver) = listener.accept().await.expect("failed to accept");
    task::spawn(driver.instrument(debug_span!("stream_driver")));
    stream
}
/// Binds `local`, connects to `peer`, spawns the connection driver (inside a
/// `stream_driver` tracing span) and returns the connected stream.
async fn stream_connect(local: SocketAddr, peer: SocketAddr) -> UtpStream {
    let (stream, driver) = UtpSocketRef::bind(local)
        .await
        .expect("failed to bind")
        .connect(peer)
        .await
        .expect("failed to connect");
    task::spawn(driver.instrument(debug_span!("stream_driver")));
    stream
}
/// The server writes `LEN` bytes and shuts down. Meanwhile the client,
/// before spawning its driver, receives directly from the locked socket with
/// `recv_from`, then spawns the driver and shuts down.
#[tokio::test]
async fn stream_fast_resend_active() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    const DATA: u8 = 2;
    const LEN: usize = 345;
    let socket =
        UtpSocketRef::bind(server_addr).await.expect("bind failed");
    let handle = task::spawn(async {
        let buf = [DATA; LEN];
        let (mut stream, driver) =
            socket.accept().await.expect("accept failed");
        task::spawn(driver);
        stream.write_all(&buf).await.expect("write failed");
        stream.shutdown().await.expect("shutdown failed");
    });
    let (mut stream, driver) = UtpSocketRef::bind(client_addr)
        .await
        .expect("bind failed")
        .connect(server_addr)
        .await
        .expect("connect failed");
    {
        // Receive while the client driver is not yet running; the lock is
        // released at the end of this scope.
        let mut lock = stream.socket.lock().await;
        let mut buf = [0u8; LEN];
        lock.recv_from(&mut buf).await.expect("read failed");
    }
    task::spawn(driver);
    stream.shutdown().await.expect("close failed");
    handle.await.expect("task failure");
}
/// Establishes a connection and immediately shuts down both ends.
#[tokio::test]
async fn stream_connect_disconnect() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    let handle = task::spawn(async move {
        let mut stream = stream_accept(server_addr).await;
        stream.shutdown().await.expect("failed to close");
    });
    let mut stream = stream_connect(client_addr, server_addr).await;
    stream.shutdown().await.expect("failed to close connection");
    handle.await.expect("task failure");
}
/// The client writes `LEN` (2000) bytes, more than one `MSS`-sized chunk, and
/// the server must read them back unaltered with `read_exact`.
#[tokio::test]
#[ignore]
async fn stream_packet_split() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    const LEN: usize = 2000;
    const DATA: u8 = 1;
    let handle = task::spawn(async move {
        let mut stream = stream_accept(server_addr)
            .instrument(debug_span!("server"))
            .await;
        let mut buf = [0u8; LEN];
        stream
            .read_exact(&mut buf)
            .instrument(debug_span!("server_read_exact"))
            .await
            .expect("read failed");
        for b in &buf[..] {
            assert_eq!(*b, DATA, "data was altered");
        }
        stream
            .shutdown()
            .instrument(debug_span!("server_shutdown"))
            .await
            .expect("flush failed");
    });
    let mut stream = stream_connect(client_addr, server_addr)
        .instrument(debug_span!("client"))
        .await;
    let buf = [DATA; LEN];
    stream
        .write_all(&buf)
        .instrument(debug_span!("client_write_all"))
        .await
        .expect("write failed");
    stream
        .shutdown()
        .instrument(debug_span!("client_shutdown"))
        .await
        .expect("close failed");
    handle.await.expect("task failure")
}
/// After a full exchange and shutdown on both sides, a further client write
/// and a further server read are both expected to fail.
#[tokio::test]
#[ignore]
async fn stream_closed_write() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    const LEN: usize = 1240;
    const DATA: u8 = 12;
    let handle = task::spawn(async move {
        let mut stream = stream_accept(server_addr)
            .instrument(debug_span!("server"))
            .await;
        let mut buf = [0u8; LEN];
        stream
            .read_exact(&mut buf)
            .instrument(debug_span!("server_read_exact"))
            .await
            .expect("read failed");
        stream
            .shutdown()
            .instrument(debug_span!("server_shutdown"))
            .await
            .expect("shutdown failed");
        stream
            .read_exact(&mut buf)
            .instrument(debug_span!("server_closed_read"))
            .await
            .expect_err("read on closed stream");
    });
    let mut stream = stream_connect(client_addr, server_addr)
        .instrument(debug_span!("client"))
        .await;
    let buf = [DATA; LEN];
    stream
        .write_all(&buf)
        .instrument(debug_span!("client_write_all"))
        .await
        .expect("write failed");
    stream
        .shutdown()
        .instrument(debug_span!("client_shutdown"))
        .await
        .expect("shutdown failed");
    stream
        .write_all(&buf)
        .instrument(debug_span!("client_closed_write"))
        .await
        .expect_err("wrote on closed stream");
    handle.await.expect("execution failure");
}
/// Keeps both ends of an established connection idle for a few seconds
/// (3 s on the server, 4 s on the client) before shutting them down.
///
/// The first tick of a `tokio::time::interval` completes immediately, so each
/// side ticks twice. The second tick is what actually waits for the period;
/// with a single tick the test never idled at all.
#[tokio::test]
async fn stream_fast_resend_idle() {
    init_logger();
    let server = next_test_ip4();
    let client = next_test_ip4();
    let handle = task::spawn(async move {
        let mut stream = stream_accept(server).await;
        let mut timer = interval(Duration::from_secs(3));
        timer.tick().await;
        timer.tick().await;
        stream.shutdown().await.expect("close failed");
    });
    let mut stream = stream_connect(client, server).await;
    let mut timer = interval(Duration::from_secs(4));
    timer.tick().await;
    timer.tick().await;
    stream.shutdown().await.expect("close failed");
    handle.await.expect("task failed");
}
/// The server writes `LEN` bytes and shuts down. The client reads them all
/// and then shuts down its own end; both shutdowns must succeed.
#[tokio::test]
async fn stream_clean_close() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    const DATA: u8 = 1;
    const LEN: usize = 1024;
    let handle = task::spawn(async move {
        let mut stream = stream_accept(server_addr)
            .instrument(debug_span!("stream_accept"))
            .await;
        let buf = [DATA; LEN];
        stream
            .write_all(&buf)
            .instrument(debug_span!("server_write"))
            .await
            .expect("write failed");
        stream
            .shutdown()
            .instrument(debug_span!("server_shutdown"))
            .await
            .expect("shutdown failed");
    });
    let mut socket = stream_connect(client_addr, server_addr)
        .instrument(debug_span!("stream_connect"))
        .await;
    let mut buf = [0u8; LEN];
    socket
        .read_exact(&mut buf)
        .instrument(debug_span!("client_read"))
        .await
        .expect("read failed");
    socket
        .shutdown()
        .instrument(debug_span!("client_shutdown"))
        .await
        .expect("shutdown failed");
    handle.await.expect("task panic");
}
/// Connecting to an address nobody listens on must fail.
///
/// The congestion timeout is lowered to 100 ms first, presumably to keep the
/// retries short.
#[tokio::test]
async fn stream_connect_timeout() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    let socket =
        UtpSocketRef::bind(client_addr).await.expect("bind failed");
    socket.0.lock().await.congestion_timeout = 100;
    assert!(
        socket.connect(server_addr).await.is_err(),
        "connected to void"
    );
}
/// The server accepts and immediately drops both the stream and its
/// (never spawned) driver. A client read on the connection is then expected
/// to fail; the congestion timeout is lowered to 100 ms first.
#[tokio::test]
async fn stream_read_timeout() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    let handle = task::spawn(async move {
        let sock =
            UtpSocketRef::bind(server_addr).await.expect("bind failed");
        let _ = sock.accept().await.expect("accept failed");
    });
    let mut socket = stream_connect(client_addr, server_addr).await;
    let mut buf = [0u8; 1024];
    socket.socket.lock().await.congestion_timeout = 100;
    socket
        .read_exact(&mut buf)
        .await
        .expect_err("read from non responding peer");
    handle.await.expect("task panic");
}
/// The client connects and immediately drops its stream and driver. On the
/// server, `write_all` must still succeed, but the following `flush` (which
/// waits for ACKs) must fail.
#[tokio::test]
async fn stream_write_timeout() {
    init_logger();
    let (server, client) = (next_test_ip4(), next_test_ip4());
    const DATA: u8 = 45;
    const LEN: usize = 123;
    let handle = task::spawn(async move {
        let mut stream = stream_accept(server).await;
        let buf = [DATA; LEN];
        stream.socket.lock().await.congestion_timeout = 100;
        stream
            .write_all(&buf)
            .await
            .expect("packets weren't buffered");
        stream
            .flush()
            .await
            .expect_err("flush succeeded without ack");
    });
    let sock = UtpSocketRef::bind(client).await.expect("bind failed");
    let _ = sock.connect(server).await.expect("connect failed");
    handle.await.expect("execution failure");
}
/// The client writes, flushes, and writes again. The server must read both
/// halves (`2 * LEN` bytes) intact before both sides shut down.
#[tokio::test]
#[ignore]
async fn stream_flush_then_send() {
    init_logger();
    let server_addr = next_test_ip4();
    let client_addr = next_test_ip4();
    const LEN: usize = 1240;
    const DATA: u8 = 25;
    let handle = task::spawn(async move {
        let mut stream = stream_accept(server_addr).await;
        let mut buf = [0u8; 2 * LEN];
        stream.read_exact(&mut buf).await.expect("failed to read");
        for b in buf.iter() {
            assert_eq!(*b, DATA, "data corrupted");
        }
        stream.flush().await.expect("flush failed");
        stream.shutdown().await.expect("shutdown failed");
    });
    let mut stream = stream_connect(client_addr, server_addr).await;
    let buf = [DATA; LEN];
    stream.write_all(&buf).await.expect("write failed");
    stream.flush().await.expect("flush failed");
    stream.write_all(&buf).await.expect("write failed");
    stream.shutdown().await.expect("shutdown failed");
    handle.await.expect("task failure");
}
/// Checks the IPv4 handshake between raw `UtpSocket`s:
/// * the client ends up `Connected` to the server address with
///   `sender_connection_id == receiver_connection_id + 1`;
/// * the server ends up `Closed` after the client closes, with the ids in the
///   opposite relation.
#[tokio::test]
async fn test_socket_ipv4() {
    let server_addr = next_test_ip4();
    let handle = task::spawn(async move {
        let mut server = iotry!(UtpSocket::bind(server_addr));
        assert_eq!(server.state, SocketState::New);
        let mut buf = [0u8; BUF_SIZE];
        // The receive result is only printed, not checked.
        match server.recv_from(&mut buf).await {
            e => println!("{:?}", e),
        }
        assert_eq!(
            server.receiver_connection_id,
            server.sender_connection_id + 1
        );
        assert_eq!(server.state, SocketState::Closed);
        drop(server);
    });
    let mut client = iotry!(UtpSocket::connect(server_addr));
    assert_eq!(client.state, SocketState::Connected);
    assert_eq!(
        client.sender_connection_id,
        client.receiver_connection_id + 1
    );
    assert_eq!(
        client.connected_to,
        server_addr.to_socket_addrs().unwrap().next().unwrap()
    );
    iotry!(client.close());
    handle.await.expect("task failure");
}
/// IPv6 variant of `test_socket_ipv4`, with the client running in the spawned
/// task and the server in the test body.
#[ignore]
#[tokio::test]
async fn test_socket_ipv6() {
    let server_addr = next_test_ip6();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    assert_eq!(server.state, SocketState::New);
    task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        assert_eq!(
            client.sender_connection_id,
            client.receiver_connection_id + 1
        );
        assert_eq!(
            client.connected_to,
            server_addr.to_socket_addrs().unwrap().next().unwrap()
        );
        iotry!(client.close());
        drop(client);
    });
    let mut buf = [0u8; BUF_SIZE];
    // The receive result is only printed, not checked.
    match server.recv_from(&mut buf).await {
        e => println!("{:?}", e),
    }
    assert_eq!(
        server.receiver_connection_id,
        server.sender_connection_id + 1
    );
    assert_eq!(server.state, SocketState::Closed);
    drop(server);
}
/// After the peer closes, the server socket is `Closed`, and a further
/// `recv_from` must return `Ok((0, _))` and leave it `Closed`.
#[tokio::test]
async fn test_recvfrom_on_closed_socket() {
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    assert_eq!(server.state, SocketState::New);
    let handle = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        assert!(client.close().await.is_ok());
    });
    let mut buf = [0u8; BUF_SIZE];
    let _resp = server.recv_from(&mut buf).await;
    assert_eq!(server.state, SocketState::Closed);
    match server.recv_from(&mut buf).await {
        Ok((0, _src)) => {}
        e => panic!("Expected Ok(0), got {:?}", e),
    }
    assert_eq!(server.state, SocketState::Closed);
    handle.await.expect("task failure");
}
/// After the peer closes, `send_to` on the closed server socket must fail
/// with `ErrorKind::NotConnected`.
#[tokio::test]
async fn test_sendto_on_closed_socket() {
    init_logger();
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    assert_eq!(server.state, SocketState::New);
    let handle = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        iotry!(client.close());
    });
    let mut buf = [0u8; BUF_SIZE];
    let (_read, _src) = iotry!(server.recv_from(&mut buf));
    assert_eq!(server.state, SocketState::Closed);
    match server.send_to(&buf).await {
        Err(ref e) if e.kind() == ErrorKind::NotConnected => (),
        v => panic!("expected {:?}, got {:?}", ErrorKind::NotConnected, v),
    }
    handle.await.expect("task failure");
}
/// After the handshake, the client's `ack_nr` must equal the server's
/// `seq_nr`, and `close()` must not change the client's `ack_nr`.
#[tokio::test]
async fn test_acks_on_socket() {
    use tokio::sync::mpsc::channel;
    init_logger();
    let server_addr = next_test_ip4();
    let (tx, mut rx) = channel(1);
    let mut server = iotry!(UtpSocket::bind(server_addr));
    let handle = task::spawn(async move {
        let mut buf = [0u8; BUF_SIZE];
        let mut buf = ReadBuf::new(&mut buf);
        let _resp = server.recv(&mut buf).await.unwrap();
        // Report the server's sequence number to the client side.
        tx.send(server.seq_nr).await.expect("channel closed");
        let mut buf = [0; 1500];
        iotry!(server.recv_from(&mut buf));
    });
    let mut client = iotry!(UtpSocket::connect(server_addr));
    assert_eq!(client.state, SocketState::Connected);
    let sender_seq_nr = rx.recv().await.expect("channel closed");
    let ack_nr = client.ack_nr;
    assert_eq!(ack_nr, sender_seq_nr);
    assert!(client.close().await.is_ok());
    assert_eq!(client.ack_nr, ack_nr);
    drop(client);
    handle.await.expect("task failure");
}
/// Feeds a SYN, a DATA and a FIN packet directly into `handle_packet` and
/// checks that each is answered by a `State` packet with the expected
/// connection id, sequence number and acknowledgement number.
///
/// Id and sequence-number arithmetic wraps, so the test itself cannot
/// overflow-panic when `rand::random()` returns `u16::MAX`.
/// NOTE(review): this only covers the test's own arithmetic; confirm the
/// socket's handshake code also wraps for that id.
#[tokio::test]
async fn test_handle_packet() {
    let initial_connection_id: u16 = rand::random();
    let sender_connection_id = initial_connection_id.wrapping_add(1);
    let (server_addr, client_addr) = (
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
    );
    let mut socket = iotry!(UtpSocket::bind(server_addr));
    // SYN: the reply echoes the SYN's connection id and acks its seq_nr.
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    assert_eq!(response.connection_id(), packet.connection_id());
    assert_eq!(response.ack_nr(), packet.seq_nr());
    assert!(response.payload().is_empty());
    let old_packet = packet;
    let old_response = response;
    // DATA from the peer's sender id (initial + 1).
    let mut packet = Packet::new();
    packet.set_type(PacketType::Data);
    packet.set_connection_id(sender_connection_id);
    packet.set_seq_nr(old_packet.seq_nr().wrapping_add(1));
    packet.set_ack_nr(old_response.seq_nr());
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    assert_eq!(response.connection_id(), initial_connection_id);
    assert_eq!(
        response.connection_id(),
        packet.connection_id().wrapping_sub(1)
    );
    assert_eq!(response.ack_nr(), packet.seq_nr());
    assert!(response.payload().is_empty());
    assert_eq!(response.seq_nr(), old_response.seq_nr());
    let old_packet = packet;
    let old_response = response;
    // FIN: acknowledged like data, without advancing the server's seq_nr.
    let mut packet = Packet::new();
    packet.set_type(PacketType::Fin);
    packet.set_connection_id(sender_connection_id);
    packet.set_seq_nr(old_packet.seq_nr().wrapping_add(1));
    packet.set_ack_nr(old_response.seq_nr());
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    assert_eq!(packet.seq_nr(), old_packet.seq_nr().wrapping_add(1));
    assert_eq!(response.seq_nr(), old_response.seq_nr());
    assert_eq!(response.ack_nr(), packet.seq_nr());
}
/// After the SYN handshake, a `State` packet (keepalive ACK) must produce no
/// reply, including when the same packet is handled a second time.
#[tokio::test]
async fn test_response_to_keepalive_ack() {
    let initial_connection_id: u16 = rand::random();
    let (server_addr, client_addr) = (
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
    );
    let mut socket = iotry!(UtpSocket::bind(server_addr));
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    let old_packet = packet;
    let old_response = response;
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::State);
    packet.set_connection_id(initial_connection_id);
    packet.set_seq_nr(old_packet.seq_nr() + 1);
    packet.set_ack_nr(old_response.seq_nr());
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_none());
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_none());
    // NOTE(review): the socket is marked closed before being dropped; the
    // reason is not visible here — confirm it is still needed.
    socket.state = SocketState::Closed;
}
/// After the SYN handshake, a packet carrying an unrelated connection id must
/// be answered with a `Reset` that acknowledges its sequence number.
#[tokio::test]
async fn test_response_to_wrong_connection_id() {
    let initial_connection_id: u16 = rand::random();
    let (server_addr, client_addr) = (
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
    );
    let mut socket = iotry!(UtpSocket::bind(server_addr));
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    assert_eq!(response.unwrap().get_type(), PacketType::State);
    // Offset by half the id space so the "wrong" id can never equal
    // `initial_connection_id` or its neighbours. `wrapping_mul(2)` produced
    // the same id for 0 and the valid sender id (`initial + 1`) for 1.
    let new_connection_id = initial_connection_id.wrapping_add(u16::MAX / 2);
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::State);
    packet.set_connection_id(new_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::Reset);
    assert_eq!(response.ack_nr(), packet.seq_nr());
    socket.state = SocketState::Closed;
}
/// Delivers two data packets out of order. The reply to the later packet must
/// not acknowledge it, and the earlier packet must still get a reply.
#[tokio::test]
async fn test_unordered_packets() {
    let initial_connection_id: u16 = rand::random();
    let (server_addr, client_addr) = (
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
    );
    let mut socket = iotry!(UtpSocket::bind(server_addr));
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert_eq!(response.get_type(), PacketType::State);
    let old_packet = packet;
    let old_response = response;
    let mut window: Vec<Packet> = Vec::new();
    let mut packet = Packet::with_payload(&[1, 2, 3]);
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_connection_id(initial_connection_id);
    packet.set_seq_nr(old_packet.seq_nr() + 1);
    packet.set_ack_nr(old_response.seq_nr());
    window.push(packet);
    let mut packet = Packet::with_payload(&[4, 5, 6]);
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_connection_id(initial_connection_id);
    packet.set_seq_nr(old_packet.seq_nr() + 2);
    packet.set_ack_nr(old_response.seq_nr());
    window.push(packet);
    // Second packet first: it must not be acknowledged yet.
    let response = socket.handle_packet(&window[1], client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert!(response.ack_nr() != window[1].seq_nr());
    let response = socket.handle_packet(&window[0], client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    socket.state = SocketState::Closed;
}
/// After the client sends one data packet, the server sends three duplicate
/// ACKs for the preceding sequence number. It then expects the client to
/// resend that same data packet.
#[tokio::test]
async fn test_response_to_triple_ack() {
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    const LEN: usize = 1024;
    let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
    let d = data.clone();
    assert_eq!(LEN, data.len());
    let handle = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&d[..]));
        iotry!(client.close());
    });
    let mut buf = [0; BUF_SIZE];
    let mut buf = ReadBuf::new(&mut buf);
    iotry!(server.recv(&mut buf));
    // NOTE(review): the byte count returned by `recv_from` is ignored and the
    // packet is parsed from `buf.filled()`, whose length was set by the
    // earlier `server.recv` — confirm this is intended rather than `[..read]`.
    let data_packet =
        match server.socket.recv_from(buf.initialized_mut()).await {
            Ok((_, _src)) => Packet::try_from(buf.filled()).unwrap(),
            Err(e) => panic!("{}", e),
        };
    assert_eq!(data_packet.get_type(), PacketType::Data);
    assert_eq!(&data_packet.payload(), &data.as_slice());
    assert_eq!(data_packet.payload().len(), data.len());
    // Build the duplicate ACK for the sequence number before the data packet.
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::State);
    packet.set_seq_nr(server.seq_nr);
    packet.set_ack_nr(data_packet.seq_nr() - 1);
    packet.set_connection_id(server.sender_connection_id);
    for _ in 0..3usize {
        iotry!(server.socket.send_to(packet.as_ref(), server.connected_to));
    }
    let client_addr = server.connected_to;
    let mut buf = [0; BUF_SIZE];
    match server.socket.recv_from(&mut buf).await {
        Ok((0, _)) => panic!("Received 0 bytes from socket"),
        Ok((read, _src)) => {
            let packet = Packet::try_from(&buf[..read]).unwrap();
            assert_eq!(packet.get_type(), PacketType::Data);
            assert_eq!(packet.seq_nr(), data_packet.seq_nr());
            assert_eq!(packet.payload(), data_packet.payload());
            let response = server.handle_packet(&packet, client_addr);
            assert!(response.is_ok());
            let response = response.unwrap();
            assert!(response.is_some());
            let response = response.unwrap();
            iotry!(server
                .socket
                .send_to(response.as_ref(), server.connected_to));
        }
        Err(e) => panic!("{}", e),
    }
    let mut buf = [0; 1500];
    iotry!(server.recv_from(&mut buf));
    handle.await.expect("task failure");
}
/// With the server's congestion timeout lowered to 50 ms, `recv_from` is
/// retried (on `Ok((0, _))`) until it returns data.
///
/// NOTE(review): the `client` bound at `client_addr` is only used for the
/// initial state assertions; the spawned task connects with a new socket.
#[ignore]
#[tokio::test]
async fn test_socket_timeout_request() {
    let (server_addr, client_addr) = (
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
    );
    let client = iotry!(UtpSocket::bind(client_addr));
    let mut server = iotry!(UtpSocket::bind(server_addr));
    const LEN: usize = 512;
    let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
    let d = data.clone();
    assert_eq!(server.state, SocketState::New);
    assert_eq!(client.state, SocketState::New);
    assert_eq!(
        client.sender_connection_id,
        client.receiver_connection_id + 1
    );
    let handle = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        assert_eq!(client.connected_to, server_addr);
        iotry!(client.send_to(&d[..]));
        drop(client);
    });
    let mut buf = [0u8; BUF_SIZE];
    let mut buf = ReadBuf::new(&mut buf);
    server.recv(&mut buf).await.unwrap();
    assert_eq!(
        server.receiver_connection_id,
        server.sender_connection_id + 1
    );
    assert_eq!(server.state, SocketState::Connected);
    let mut buf = [0; 1500];
    iotry!(server.socket.recv_from(&mut buf));
    server.congestion_timeout = 50;
    loop {
        match server.recv_from(&mut buf).await {
            Ok((0, _)) => continue,
            Ok(_) => break,
            Err(e) => panic!("{}", e),
        }
    }
    drop(server);
    handle.await.expect("task failure");
}
/// `insert_into_buffer` keeps packets ordered by sequence number, and
/// re-inserting an existing sequence number (2) leaves the original entry
/// (timestamp 128) in place.
#[tokio::test]
async fn test_sorted_buffer_insertion() {
    let server_addr = next_test_ip4();
    let mut socket = iotry!(UtpSocket::bind(server_addr));
    let mut packet = Packet::new();
    packet.set_seq_nr(1);
    assert!(socket.incoming_buffer.is_empty());
    socket.insert_into_buffer(packet.clone());
    assert_eq!(socket.incoming_buffer.len(), 1);
    packet.set_seq_nr(2);
    packet.set_timestamp(128.into());
    socket.insert_into_buffer(packet.clone());
    assert_eq!(socket.incoming_buffer.len(), 2);
    assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
    assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
    packet.set_seq_nr(3);
    packet.set_timestamp(256.into());
    socket.insert_into_buffer(packet.clone());
    assert_eq!(socket.incoming_buffer.len(), 3);
    assert_eq!(socket.incoming_buffer[2].seq_nr(), 3);
    assert_eq!(socket.incoming_buffer[2].timestamp(), 256.into());
    // Duplicate seq_nr 2 with a new timestamp: must not replace or grow.
    packet.set_seq_nr(2);
    packet.set_timestamp(456.into());
    socket.insert_into_buffer(packet);
    assert_eq!(socket.incoming_buffer.len(), 3);
    assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
    assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
}
/// The client sends the same data packet twice (same seq_nr, fresh
/// timestamps). The server must deliver the payload `[1, 2, 3]` exactly once.
#[tokio::test]
async fn test_duplicate_packet_handling() {
    let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
    let client = iotry!(UtpSocket::bind(client_addr));
    let mut server = iotry!(UtpSocket::bind(server_addr));
    assert_eq!(server.state, SocketState::New);
    assert_eq!(client.state, SocketState::New);
    assert_eq!(
        client.sender_connection_id,
        client.receiver_connection_id + 1
    );
    let handle = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        let mut packet = Packet::with_payload(&[1, 2, 3]);
        packet.set_wnd_size(BUF_SIZE as u32);
        packet.set_connection_id(client.sender_connection_id);
        packet.set_seq_nr(client.seq_nr);
        packet.set_ack_nr(client.ack_nr);
        // Send the identical packet twice, only the timestamp differs.
        for _ in 0..2usize {
            packet.set_timestamp(now_microseconds());
            iotry!(client.socket.send_to(packet.as_ref(), server_addr));
        }
        client.seq_nr += 1;
        for _ in 0..1usize {
            let mut buf = [0; BUF_SIZE];
            iotry!(client.socket.recv_from(&mut buf));
        }
        iotry!(client.close());
    });
    let mut buf = [0u8; BUF_SIZE];
    let mut buf = ReadBuf::new(&mut buf);
    iotry!(server.recv(&mut buf));
    assert_eq!(
        server.receiver_connection_id,
        server.sender_connection_id + 1
    );
    assert_eq!(server.state, SocketState::Connected);
    let expected: Vec<u8> = vec![1, 2, 3];
    let mut received: Vec<u8> = vec![];
    // NOTE(review): the length returned by `recv_from` is ignored;
    // `buf.filled()` still reflects what the earlier `server.recv` filled —
    // confirm this is what the test intends to collect.
    loop {
        match server.recv_from(buf.initialized_mut()).await {
            Ok((0, _src)) => break,
            Ok((_, _src)) => received.extend(buf.filled().to_vec()),
            Err(e) => panic!("{:?}", e),
        }
    }
    assert_eq!(received.len(), expected.len());
    assert_eq!(received, expected);
    handle.await.expect("task failure");
}
/// The client transmits only the even-indexed chunks but records every chunk
/// in its send window, then closes. The server must still receive all `LEN`
/// bytes in order, presumably via retransmission of the skipped chunks.
#[tokio::test]
async fn test_correct_packet_loss() {
    init_logger();
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    const LEN: usize = 1024 * 10;
    let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
    let to_send = data.clone();
    let handle = task::spawn(
        async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            let chunks = to_send[..].chunks(BUF_SIZE);
            let dst = client.connected_to;
            for (index, chunk) in chunks.enumerate() {
                let mut packet = Packet::with_payload(chunk);
                packet.set_seq_nr(client.seq_nr);
                packet.set_ack_nr(client.ack_nr);
                packet.set_connection_id(client.sender_connection_id);
                packet.set_timestamp(now_microseconds());
                // Simulated loss: odd-indexed chunks are never put on the wire.
                if index % 2 == 0 {
                    iotry!(client.socket.send_to(packet.as_ref(), dst));
                }
                client.curr_window += packet.len() as u32;
                client.send_window.push(packet);
                client.seq_nr += 1;
            }
            iotry!(client.close());
        }
        .instrument(debug_span!("sender")),
    );
    let mut buf = [0; BUF_SIZE];
    let mut received: Vec<u8> = vec![];
    loop {
        match server.recv_from(&mut buf).await {
            Ok((0, _src)) => break,
            Ok((len, _src)) => received.extend(buf[..len].to_vec()),
            Err(e) => panic!("{}", e),
        }
    }
    assert_eq!(
        received.len(),
        data.len(),
        "wrong number of bytes received"
    );
    assert_eq!(received, data, "incorrect data received");
    handle.await.expect("task failure");
}
/// Reads a 1 KiB transfer through a 512-byte buffer.
///
/// A receive buffer smaller than the payload must still yield every byte, in
/// order, across several reads.
#[tokio::test]
async fn test_tolerance_to_small_buffers() {
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));
    const LEN: usize = 1024;
    let data: Vec<u8> = (0..LEN).map(|idx| idx as u8).collect();
    let payload = data.clone();
    let handle = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&payload[..]));
        iotry!(client.close());
    });

    let mut read: Vec<u8> = Vec::new();
    while server.state != SocketState::Closed {
        let mut small_buffer = [0; 512];
        let len = match server.recv_from(&mut small_buffer).await {
            Ok((len, _src)) => len,
            Err(e) => panic!("{}", e),
        };
        // A zero-length read means the peer finished sending.
        if len == 0 {
            break;
        }
        read.extend_from_slice(&small_buffer[..len]);
    }

    assert_eq!(read.len(), data.len());
    assert_eq!(read, data);
    handle.await.expect("task failure");
}
/// Intended to check that the client's 16-bit sequence number wraps past
/// `u16::MAX` while a multi-packet payload is still delivered intact.
///
/// NOTE(review): the `client` bound to `client_addr` below is shadowed by the
/// `client` that the spawned task creates with `UtpSocket::connect`. The
/// `seq_nr` preset therefore never reaches the socket that actually sends.
/// Unless `connect` itself starts near `u16::MAX` (not visible here), the
/// `seq_nr < 50` assertion can pass without any rollover happening. Confirm
/// this, and if so, apply the preset to the connecting socket instead.
#[tokio::test]
async fn test_sequence_number_rollover() {
let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = BUF_SIZE * 4;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let to_send = data.clone();
// Socket that only carries the advanced sequence number (see NOTE above);
// nothing is sent through it.
let mut client = iotry!(UtpSocket::bind(client_addr));
// LEN / (BUF_SIZE * 2) == 2, so this sets seq_nr to u16::MAX - 2.
client.seq_nr =
::std::u16::MAX - (to_send.len() / (BUF_SIZE * 2)) as u16;
let handle = task::spawn(async move {
// Shadows the outer `client`; this socket's seq_nr comes from `connect`.
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&to_send[..]));
// Meant to prove the counter wrapped; see the NOTE(review) above.
assert!(client.seq_nr < 50);
iotry!(client.close());
});
let mut buf = [0; BUF_SIZE];
let mut received: Vec<u8> = vec![];
// Collect payload until a zero-length read ends the stream.
loop {
match server.recv_from(&mut buf).await {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{}", e),
}
}
assert_eq!(received.len(), data.len());
assert_eq!(received, data);
handle.await.expect("task failure");
}
#[tokio::test]
async fn test_drop_unused_socket() {
let server_addr = next_test_ip4();
let server = iotry!(UtpSocket::bind(server_addr));
drop(server);
}
#[tokio::test]
async fn test_invalid_packet_on_connect() {
use tokio::net::UdpSocket;
let server_addr = next_test_ip4();
let server = iotry!(UdpSocket::bind(server_addr));
let handle = task::spawn(async move {
match UtpSocket::connect(server_addr).await {
Err(ref e) if e.kind() == ErrorKind::Other => (), Err(e) => panic!("Expected ErrorKind::Other, got {:?}", e),
Ok(_) => panic!("Expected Err, got Ok"),
}
});
let mut buf = [0; BUF_SIZE];
match server.recv_from(&mut buf).await {
Ok((_len, client_addr)) => {
iotry!(server.send_to(&[], client_addr));
}
_ => panic!(),
}
handle.await.expect("task failure");
}
/// A server that answers the SYN with a Data packet must make
/// `UtpSocket::connect` fail with `ErrorKind::ConnectionRefused`.
#[tokio::test]
async fn test_receive_unexpected_reply_type_on_connect() {
    use tokio::net::UdpSocket;
    let server_addr = next_test_ip4();
    let server = iotry!(UdpSocket::bind(server_addr));
    let mut buf = [0; BUF_SIZE];
    let mut reply = Packet::new();
    reply.set_type(PacketType::Data);

    // Raw UDP "server": wait for the SYN and answer it with the Data packet.
    let handle = task::spawn(async move {
        let client_addr = match server.recv_from(&mut buf).await {
            Ok((_len, client_addr)) => client_addr,
            _ => panic!(),
        };
        iotry!(server.send_to(reply.as_ref(), client_addr));
    });

    match UtpSocket::connect(server_addr).await {
        Ok(_) => panic!("Expected Err, got Ok"),
        Err(ref e) if e.kind() == ErrorKind::ConnectionRefused => (),
        Err(e) => {
            panic!("Expected ErrorKind::ConnectionRefused, got {:?}", e)
        }
    }
    handle.await.expect("task failure");
}
/// A SYN sent over an already-established connection must be answered by the
/// server with a Reset packet.
#[tokio::test]
async fn test_receiving_syn_on_established_connection() {
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));

    // Server side: keep draining until a zero-length read ends the stream.
    let handle = task::spawn(async move {
        let mut buf = [0; BUF_SIZE];
        loop {
            let (len, _src) = server
                .recv_from(&mut buf)
                .await
                .unwrap_or_else(|e| panic!("{:?}", e));
            if len == 0 {
                break;
            }
        }
    });

    let mut client = iotry!(UtpSocket::connect(server_addr));

    // Hand-craft a second SYN that reuses the live connection's identifiers.
    let mut syn = Packet::new();
    syn.set_wnd_size(BUF_SIZE as u32);
    syn.set_type(PacketType::Syn);
    syn.set_connection_id(client.sender_connection_id);
    syn.set_seq_nr(client.seq_nr);
    syn.set_ack_nr(client.ack_nr);
    iotry!(client.socket.send_to(syn.as_ref(), server_addr));

    let mut buf = [0; BUF_SIZE];
    let (len, _) = client.socket.recv_from(&mut buf).await.expect("recv failed");
    let reply = Packet::try_from(&buf[..len]).ok().unwrap();
    assert_eq!(reply.get_type(), PacketType::Reset);

    iotry!(client.close());
    handle.await.expect("task failure");
}
/// A Reset packet received on an established connection must surface from the
/// server's `recv_from` as `ErrorKind::ConnectionReset`.
///
/// NOTE(review): after sending the Reset, the client task waits for one more
/// datagram. If the server never sends one, `handle.await` may not return,
/// which is possibly why this test is ignored. Confirm.
#[tokio::test]
#[ignore]
async fn test_receiving_reset_on_established_connection() {
    let server_addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(server_addr));

    let handle = task::spawn(async move {
        let client = iotry!(UtpSocket::connect(server_addr));
        let mut reset = Packet::new();
        reset.set_wnd_size(BUF_SIZE as u32);
        reset.set_type(PacketType::Reset);
        reset.set_connection_id(client.sender_connection_id);
        reset.set_seq_nr(client.seq_nr);
        reset.set_ack_nr(client.ack_nr);
        iotry!(client.socket.send_to(reset.as_ref(), server_addr));
        let mut buf = [0; BUF_SIZE];
        client.socket.recv_from(&mut buf).await.expect("recv failed");
    });

    let mut buf = [0; BUF_SIZE];
    loop {
        match server.recv_from(&mut buf).await {
            Ok((0, _src)) => break,
            Ok(_) => {}
            Err(ref e) if e.kind() == ErrorKind::ConnectionReset => {
                handle.await.expect("task failure");
                return;
            }
            Err(e) => panic!("{:?}", e),
        }
    }
    // Reaching here means the stream ended without the expected reset error.
    panic!("Should have received Reset");
}
/// Sends a FIN from the server early in the transfer (right after the first
/// `recv`). It then checks that the server still receives the client's entire
/// 4 * BUF_SIZE payload.
///
/// NOTE(review): the FIN is addressed to `client_addr`, but nothing in this test
/// binds that address. The client comes from `UtpSocket::connect(server_addr)`
/// and is never given `client_addr`, so the FIN presumably never reaches it.
/// `server.connected_to` may have been intended; confirm.
///
/// NOTE(review): the client task's `JoinHandle` is discarded, so a panic inside
/// that task does not fail this test.
#[cfg(not(windows))]
#[tokio::test]
async fn test_premature_fin() {
let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = BUF_SIZE * 4;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let to_send = data.clone();
// Detached client: send the whole payload, then close.
task::spawn(async move {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&to_send[..]));
iotry!(client.close());
});
// Accept the connection. Whatever this read yields is not added to
// `received`; presumably it carries no payload bytes — TODO confirm.
let mut buf = [0; BUF_SIZE];
let mut buf = ReadBuf::new(&mut buf);
iotry!(server.recv(&mut buf));
// Hand-craft a FIN from the server's current sequence/ack state.
let mut packet = Packet::new();
packet.set_connection_id(server.sender_connection_id);
packet.set_seq_nr(server.seq_nr);
packet.set_ack_nr(server.ack_nr);
packet.set_timestamp(now_microseconds());
packet.set_type(PacketType::Fin);
iotry!(server.socket.send_to(packet.as_ref(), client_addr));
let mut received: Vec<u8> = vec![];
// Collect payload until a zero-length read ends the stream.
loop {
let mut buf = [0; BUF_SIZE];
match server.recv_from(&mut buf).await {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{}", e),
}
}
assert_eq!(received.len(), data.len());
assert_eq!(received, data);
}
/// Feeds (timestamp, delay) samples spanning two one-minute windows into
/// `update_base_delay`.
///
/// It then checks that `base_delays` holds exactly `[7, 9]`, the minimum of each
/// window. It also checks that `min_base_delay` returns the smallest of those.
#[tokio::test]
async fn test_base_delay_calculation() {
    let minute_in_microseconds = 60 * 10i64.pow(6);
    let samples: Vec<(i64, i64)> = vec![
        (0, 10),
        (1, 8),
        (2, 12),
        (3, 7),
        (minute_in_microseconds + 1, 11),
        (minute_in_microseconds + 2, 19),
        (minute_in_microseconds + 3, 9),
    ];
    let mut socket = UtpSocket::bind(next_test_ip4()).await.unwrap();

    for (timestamp, delay) in samples {
        // Each sample is "received" at its send timestamp plus its delay.
        let received_at = (timestamp + delay) as u32;
        socket.update_base_delay(delay.into(), received_at.into());
    }

    let expected: Vec<_> = vec![7i64, 9i64].into_iter().map(Into::into).collect();
    let actual: Vec<_> = socket.base_delays.iter().cloned().collect();
    assert_eq!(expected, actual);
    assert_eq!(
        socket.min_base_delay(),
        expected.iter().min().cloned().unwrap_or_default()
    );
}
/// A freshly bound socket must report the exact address it was bound to.
#[tokio::test]
async fn test_local_addr() {
    let addr = next_test_ip4()
        .to_socket_addrs()
        .unwrap()
        .next()
        .unwrap();
    let socket = UtpSocket::bind(addr).await.unwrap();
    let local = socket.local_addr();
    assert!(local.is_ok());
    assert_eq!(local.unwrap(), addr);
}
/// Checks `peer_addr` across a connection's lifetime:
/// - it fails before any connection exists;
/// - it reports the client's port once the server has accepted the connection;
/// - it fails again after `close`.
///
/// The client runs in a spawned task and the server side runs in the test body,
/// so both make progress concurrently. The task is awaited at the end so that
/// client-side failures also fail the test.
#[tokio::test]
async fn test_peer_addr() {
    // Use tokio's channel rather than std's: a blocking
    // `std::sync::mpsc::Receiver::recv` would stall the runtime thread before
    // the client task had a chance to send.
    use tokio::sync::mpsc::unbounded_channel;
    let addr = next_test_ip4();
    let server_addr = addr.to_socket_addrs().unwrap().next().unwrap();
    let mut server = UtpSocket::bind(server_addr).await.unwrap();
    let (tx, mut rx) = unbounded_channel();

    // Not connected yet, so there is no peer.
    assert!(server.peer_addr().is_err());

    let handle = task::spawn(async move {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        let mut buf = [0; 1024];
        tx.send(client.local_addr())
            .expect("failed to send on channel");
        // Keep the connection open until the server closes it.
        iotry!(client.recv_from(&mut buf));
    });

    // Wait for the connection to be established.
    let mut buf = [0; 1024];
    let mut buf = ReadBuf::new(&mut buf);
    iotry!(server.recv(&mut buf));

    assert!(server.peer_addr().is_ok());
    // Only the port is compared; the client's local IP may be an unspecified
    // address.
    let client_addr = rx
        .recv()
        .await
        .expect("client task dropped the sender")
        .expect("client local_addr failed");
    assert_eq!(server.peer_addr().unwrap().port(), client_addr.port());

    iotry!(server.close());
    // Closed, so there is no peer any more.
    assert!(server.peer_addr().is_err());

    handle.await.expect("task failure");
}
#[ignore]
#[tokio::test]
async fn test_connection_loss_data() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
client.state = SocketState::Closed;
let mut buf = [0; BUF_SIZE];
iotry!(client.socket.recv_from(&mut buf));
for _ in 0..attempts {
match client.socket.recv_from(&mut buf).await {
Ok((len, _src)) => assert_eq!(
Packet::try_from(&buf[..len]).unwrap().get_type(),
PacketType::Data
),
Err(e) => panic!("{}", e),
}
}
let mut buf = [0; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
iotry!(server.send_to(&[0]));
let mut buf = [0; BUF_SIZE];
let mut buf = ReadBuf::new(&mut buf);
match server.recv(&mut buf).await {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
}
#[ignore]
#[tokio::test]
async fn test_connection_loss_fin() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
client.state = SocketState::Closed;
let mut buf = [0; BUF_SIZE];
iotry!(client.socket.recv_from(&mut buf));
for _ in 0..attempts {
match client.socket.recv_from(&mut buf).await {
Ok((len, _src)) => assert_eq!(
Packet::try_from(&buf[..len]).unwrap().get_type(),
PacketType::Fin
),
Err(e) => panic!("{}", e),
}
}
let mut buf = [0; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
match server.close().await {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
}
#[ignore]
#[tokio::test]
async fn test_connection_loss_waiting() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
client.state = SocketState::Closed;
let seq_nr = client.seq_nr;
let mut buf = [0; BUF_SIZE];
for _ in 0..(3 * attempts) {
match client.socket.recv_from(&mut buf).await {
Ok((len, _src)) => {
let packet = Packet::try_from(&buf[..len]).unwrap();
assert_eq!(packet.get_type(), PacketType::State);
assert_eq!(packet.ack_nr(), seq_nr - 1);
}
Err(e) => panic!("{}", e),
}
}
let mut buf = [0; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
let mut buf = [0; BUF_SIZE];
match server.recv_from(&mut buf).await {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
}
}