use std::cmp::{min, max};
use std::collections::VecDeque;
use std::net::{ToSocketAddrs, SocketAddr, UdpSocket};
use std::io::{Result, Error, ErrorKind};
use util::{now_microseconds, ewma, Sequence};
use packet::{Packet, PacketType, Encodable, Decodable, ExtensionType, HEADER_SIZE};
use rand::{self, Rng};
use time::SteadyTime;
use time;
use std::time::Duration;
// Size in bytes of the scratch buffers used when reading datagrams; also the
// window size advertised in replies sent from `recv`.
const BUF_SIZE: usize = 1500;
// Multiplier applied to the congestion window adjustment in `update_congestion_window`.
const GAIN: f64 = 1.0;
// Number of MSS-sized segments by which the congestion window may exceed the
// bytes currently in flight.
const ALLOWED_INCREASE: u32 = 1;
// TARGET: queuing delay target; it is compared against delays derived from
//         `now_microseconds()` readings.
// MSS:    segment size in bytes; outgoing payloads are chunked to `MSS - HEADER_SIZE`.
const TARGET: i64 = 100_000; const MSS: u32 = 1400;
// Lower bound of the congestion window, in multiples of MSS.
const MIN_CWND: u32 = 2;
// Initial congestion window, in multiples of MSS.
const INIT_CWND: u32 = 2;
// INITIAL/MIN/MAX_CONGESTION_TIMEOUT: congestion timeout start value and bounds,
//     in milliseconds (the value is passed to `Duration::from_millis`).
// BASE_HISTORY: maximum number of entries kept in `base_delays`.
// MAX_SYN_RETRIES: attempts made while connecting before giving up.
// MAX_RETRANSMISSION_RETRIES: default for `max_retransmission_retries`.
const INITIAL_CONGESTION_TIMEOUT: u64 = 1000; const MIN_CONGESTION_TIMEOUT: u64 = 500; const MAX_CONGESTION_TIMEOUT: u64 = 60_000; const BASE_HISTORY: usize = 10; const MAX_SYN_RETRIES: u32 = 5; const MAX_RETRANSMISSION_RETRIES: u32 = 5;
// Longest time `send_packet` keeps waiting for the window to open, measured
// with `now_microseconds()`.
const PRE_SEND_TIMEOUT: u32 = 500_000;
// Age (in the unit of the `now` argument of `update_base_delay`) after which a
// new base-delay slot is started.
const MAX_BASE_DELAY_AGE: i64 = 60_000_000;
/// Errors raised by the socket layer; each variant converts into a
/// `std::io::Error` through the `From` implementation below.
#[derive(Debug)]
pub enum SocketError {
/// The socket has already been closed.
ConnectionClosed,
/// The remote peer reset the connection.
ConnectionReset,
/// The connection timed out (connect attempts or retransmissions were exhausted).
ConnectionTimedOut,
/// The read timeout configured by the user elapsed.
UserTimedOut,
/// No socket address could be obtained from the given input.
InvalidAddress,
/// A received datagram could not be parsed as a packet.
InvalidPacket,
/// The remote peer answered with an unexpected packet.
InvalidReply,
/// The socket is not connected to a peer.
NotConnected,
}
impl From<SocketError> for Error {
fn from(error: SocketError) -> Error {
use self::SocketError::*;
let (kind, message) = match error {
ConnectionClosed => (ErrorKind::NotConnected, "The socket is closed"),
ConnectionReset => {
(ErrorKind::ConnectionReset,
"Connection reset by remote peer")
}
ConnectionTimedOut | UserTimedOut => (ErrorKind::TimedOut, "Connection timed out"),
InvalidAddress => (ErrorKind::InvalidInput, "Invalid address"),
InvalidPacket => (ErrorKind::Other, "Error parsing packet"),
InvalidReply => {
(ErrorKind::ConnectionRefused,
"The remote peer sent an invalid reply")
}
NotConnected => (ErrorKind::NotConnected, "The socket is not connected"),
};
Error::new(kind, message)
}
}
/// Connection state of a `UtpSocket`.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
enum SocketState {
/// Freshly created socket; no handshake packet has been sent or handled yet.
New,
/// Handshake completed.
Connected,
/// A SYN has been sent and the reply is still awaited.
SynSent,
/// A FIN has been sent by `close`.
FinSent,
/// A Reset packet was received from the peer.
ResetReceived,
/// The connection is closed.
Closed,
}
/// One delay measurement recorded by `update_current_delay`.
struct DelayDifferenceSample {
/// The `now` value passed in when the sample was recorded.
received_at: i64,
/// The measured delay value.
difference: i64,
}
fn take_address<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
addr.to_socket_addrs()
.and_then(|mut it| it.next().ok_or(From::from(SocketError::InvalidAddress)))
}
/// Copies as many leading bytes of `src` as fit into the start of `dst` and
/// returns how many bytes were copied (`min(src.len(), dst.len())`).
///
/// The name is kept for the existing callers, but the body is safe code:
/// `src` and `dst` cannot alias (shared vs. exclusive borrow), so a bounded
/// `copy_from_slice` does the same work as the former raw-pointer copy.
fn unsafe_copy(src: &[u8], dst: &mut [u8]) -> usize {
    let max_len = min(src.len(), dst.len());
    dst[..max_len].copy_from_slice(&src[..max_len]);
    max_len
}
/// A connection-oriented socket layered on top of a `UdpSocket`.
pub struct UtpSocket {
/// Underlying UDP socket used for all traffic.
socket: UdpSocket,
/// Address packets are sent to. `bind` initialises it with the local address;
/// it is replaced with the peer's address once one is known.
connected_to: SocketAddr,
/// Connection id written into outgoing packets other than the SYN.
sender_connection_id: u16,
/// Connection id written into the outgoing SYN.
receiver_connection_id: u16,
/// Sequence number assigned to the next outgoing packet.
seq_nr: u16,
/// Sequence number acknowledged in outgoing packets.
ack_nr: u16,
/// Current connection state.
state: SocketState,
/// Received data packets that have not been handed to the caller yet.
incoming_buffer: Vec<Packet>,
/// Packets that were sent and are kept until acknowledged.
send_window: Vec<Packet>,
/// Packets queued by `send_to` that have not been sent yet.
unsent_queue: VecDeque<Packet>,
/// Number of consecutive State packets carrying the same `ack_nr`.
duplicate_ack_count: u32,
/// `ack_nr` of the most recently processed acknowledgement.
last_acked: u16,
/// `now_microseconds()` reading taken when `last_acked` last changed.
last_acked_timestamp: u32,
/// Sequence number of the last packet removed from `incoming_buffer`;
/// `recv` only buffers data packets that come after it.
last_dropped: u16,
/// Smoothed round-trip estimate maintained by `update_congestion_timeout`.
rtt: i32,
/// Smoothed round-trip variance maintained by `update_congestion_timeout`.
rtt_variance: i32,
/// Remainder of a partially delivered packet payload.
pending_data: VecDeque<u8>,
/// Payload bytes received while `send_packet` was waiting; served first by `recv_from`.
read_ready_data: VecDeque<u8>,
/// Sum of the lengths of the packets in `send_window`.
curr_window: u32,
/// Window size taken from the most recently handled packet.
remote_wnd_size: u32,
/// History of base delay values (at most `BASE_HISTORY` entries).
base_delays: VecDeque<i64>,
/// Recent delay samples used to compute the filtered current delay.
current_delays: Vec<DelayDifferenceSample>,
/// Local clock reading minus the timestamp of the last handled packet;
/// echoed back as the timestamp difference of outgoing packets.
their_delay: u32,
/// `now` value at which the newest `base_delays` entry was started.
last_rollover: i64,
/// Receive timeout in milliseconds used while waiting for packets.
congestion_timeout: u64,
/// Congestion window (initialised to `INIT_CWND * MSS`).
cwnd: u32,
/// Number of consecutive receive timeouts tolerated before `recv` closes the socket.
pub max_retransmission_retries: u32,
/// User-configured read timeout in milliseconds; 0 disables it.
user_read_timeout: u64,
/// Time of the last congestion timeout handling or successful receive.
last_congestion_update: SteadyTime,
/// Receive timeouts handled since the last successful receive.
retries: u32,
/// State reply produced for the initial SYN, re-sent when the SYN is repeated.
state_packet: Option<Packet>,
/// Time at which the last packet was sent.
last_msg_sent_timestamp: SteadyTime,
}
impl UtpSocket {
/// Builds a socket in the `New` state around an existing UDP socket.
///
/// A random receiver connection id is drawn (re-drawn while it equals
/// `u16::MAX`, so that `id + 1` cannot overflow) and the sender id is set to
/// the receiver id plus one.
fn from_raw_parts(s: UdpSocket, src: SocketAddr) -> UtpSocket {
    let mut rng = rand::thread_rng();
    let mut receiver_id = rng.gen::<u16>();
    while receiver_id == ::std::u16::MAX {
        receiver_id = rng.gen::<u16>();
    }
    let sender_id = receiver_id + 1;

    UtpSocket {
        socket: s,
        connected_to: src,
        sender_connection_id: sender_id,
        receiver_connection_id: receiver_id,
        seq_nr: 1,
        ack_nr: 0,
        state: SocketState::New,
        incoming_buffer: Vec::new(),
        send_window: Vec::new(),
        unsent_queue: VecDeque::new(),
        duplicate_ack_count: 0,
        last_acked: 0,
        last_acked_timestamp: 0,
        last_dropped: 0,
        rtt: 0,
        rtt_variance: 0,
        pending_data: VecDeque::new(),
        read_ready_data: VecDeque::new(),
        curr_window: 0,
        remote_wnd_size: 0,
        base_delays: VecDeque::with_capacity(BASE_HISTORY),
        current_delays: Vec::new(),
        their_delay: 0,
        last_rollover: 0,
        congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
        cwnd: INIT_CWND * MSS,
        max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
        user_read_timeout: 0,
        last_congestion_update: SteadyTime::now(),
        retries: 0,
        state_packet: None,
        last_msg_sent_timestamp: SteadyTime::now(),
    }
}
/// Wraps an already bound `UdpSocket` in a new, unconnected `UtpSocket`.
///
/// Fails if the local address of `socket` cannot be queried.
pub fn bind_with_udp_socket(socket: UdpSocket) -> Result<UtpSocket> {
    match socket.local_addr() {
        Ok(local) => Ok(UtpSocket::from_raw_parts(socket, local)),
        Err(e) => Err(e),
    }
}
/// Creates a new, unconnected socket bound to the first address `addr` resolves to.
pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpSocket> {
    let resolved = match take_address(addr) {
        Ok(resolved) => resolved,
        Err(e) => return Err(e),
    };
    match UdpSocket::bind(resolved) {
        Ok(udp) => Ok(UtpSocket::from_raw_parts(udp, resolved)),
        Err(e) => Err(e),
    }
}
/// Returns the local address of the underlying UDP socket.
pub fn local_addr(&self) -> Result<SocketAddr> {
    let udp = &self.socket;
    udp.local_addr()
}
pub fn peer_addr(&self) -> Result<SocketAddr> {
if self.state == SocketState::Connected || self.state == SocketState::FinSent {
Ok(self.connected_to)
} else {
Err(Error::from(SocketError::NotConnected))
}
}
pub fn connect<A: ToSocketAddrs>(other: A) -> Result<UtpSocket> {
let addr = try!(take_address(other));
let my_addr = match addr {
SocketAddr::V4(_) => "0.0.0.0:0",
SocketAddr::V6(_) => ":::0",
};
let mut socket = try!(UtpSocket::bind(my_addr));
socket.connected_to = addr;
let mut packet = Packet::new();
packet.set_type(PacketType::Syn);
packet.set_connection_id(socket.receiver_connection_id);
packet.set_seq_nr(socket.seq_nr);
let mut buf = [0; BUF_SIZE];
let mut syn_timeout = socket.congestion_timeout;
let mut syn_retries = 0;
while syn_retries < MAX_SYN_RETRIES {
packet.set_timestamp_microseconds(now_microseconds());
debug!("Connecting to {}", socket.connected_to);
try!(socket.socket.send_to(&packet.to_bytes()[..], socket.connected_to));
socket.state = SocketState::SynSent;
debug!("sent {:?}", packet);
socket.socket
.set_read_timeout(Some(Duration::from_millis(syn_timeout)))
.expect("Error setting read timeout");
match socket.socket.recv_from(&mut buf) {
Ok((read, addr)) => {
let packet = try!(Packet::from_bytes(&buf[..read]).or(Err(SocketError::InvalidPacket)));
socket.connected_to = addr;
if packet.get_type() != PacketType::State {
syn_retries += 1;
continue;
}
try!(socket.handle_packet(&packet, addr));
return Ok(socket);
},
Err(ref e) if (e.kind() == ErrorKind::WouldBlock ||
e.kind() == ErrorKind::TimedOut) => {
debug!("Timed out, retrying");
syn_timeout *= 2;
syn_retries += 1;
continue;
}
Err(e) => return Err(e),
};
}
Err(Error::from(SocketError::ConnectionTimedOut))
}
/// Connects to `other` using the given, already bound UDP socket, with both
/// sides expected to run the rendezvous handshake at the same time.
pub fn rendezvous_connect<A: ToSocketAddrs>(udp_socket: UdpSocket,
                                            other: A)
                                            -> Result<UtpSocket> {
    let addr = try!(take_address(other));
    let mut socket = try!(UtpSocket::bind_with_udp_socket(udp_socket));
    try!(socket.rendezvous_connect_to(addr));
    Ok(socket)
}
/// Runs the rendezvous handshake with `addr`.
///
/// A SYN is (re)sent on every loop iteration. The function succeeds once both
/// a SYN and a State packet have been received from the other side; the State
/// packet is then passed to `handle_packet`. Only receive timeouts count
/// towards `MAX_SYN_RETRIES`; after that many timeouts the call fails with
/// `SocketError::ConnectionTimedOut`.
fn rendezvous_connect_to(&mut self, addr: SocketAddr) -> Result<()> {
self.connected_to = addr;
let mut packet = Packet::new();
packet.set_type(PacketType::Syn);
packet.set_connection_id(self.receiver_connection_id);
packet.set_seq_nr(self.seq_nr);
let mut buf = [0; BUF_SIZE];
let mut syn_timeout = self.congestion_timeout;
let mut retry_count = 0;
// Handshake packets received so far from the other side.
let mut rx_syn: Option<Packet> = None;
let mut rx_state: Option<Packet> = None;
while retry_count < MAX_SYN_RETRIES {
packet.set_timestamp_microseconds(now_microseconds());
debug!("Connecting to {}", self.connected_to);
try!(self.socket.send_to(&packet.to_bytes()[..], self.connected_to));
self.last_msg_sent_timestamp = SteadyTime::now();
self.state = SocketState::SynSent;
debug!("sent {:?}", packet);
try!(self.socket.set_read_timeout(Some(Duration::from_millis(syn_timeout))));
match self.socket.recv_from(&mut buf) {
Ok((read, src)) => {
let mut packet = match Packet::from_bytes(&buf[..read]) {
Ok(packet) => packet,
Err(_) => {
// Unparseable datagram: ignore it and send the SYN again.
continue;
}
};
// Both sides settle on the smaller of the two connection ids.
let cid = min(self.receiver_connection_id, packet.connection_id());
packet.set_connection_id(cid);
match packet.get_type() {
PacketType::Syn => {
self.receiver_connection_id = cid;
self.sender_connection_id = cid + 1;
// Answer the peer's SYN with a State packet.
let reply = self.prepare_reply(&packet, PacketType::State);
try!(self.socket.send_to(&reply.to_bytes()[..], self.connected_to));
self.last_msg_sent_timestamp = SteadyTime::now();
rx_syn = Some(packet);
}
PacketType::State => {
self.receiver_connection_id = cid;
self.sender_connection_id = cid + 1;
rx_state = Some(packet);
}
_ => continue,
}
// The handshake is complete only when both packet kinds have arrived.
match (&rx_syn, &rx_state) {
(&Some(ref _syn), &Some(ref state)) => {
try!(self.handle_packet(state, src));
return Ok(());
}
_ => continue,
}
}
Err(ref e) if (e.kind() == ErrorKind::WouldBlock ||
e.kind() == ErrorKind::TimedOut) => {
debug!("Timed out, retrying");
// Back off: double the wait and use up one attempt.
syn_timeout *= 2;
retry_count += 1;
continue;
}
Err(e) => return Err(e),
};
}
Err(Error::from(SocketError::ConnectionTimedOut))
}
/// Gracefully closes the connection.
///
/// Does nothing for sockets in the `Closed`, `New` or `SynSent` state.
/// Otherwise waits for the send window to drain, sends a FIN and keeps
/// receiving until the socket reaches the `Closed` state.
pub fn close(&mut self) -> Result<()> {
    match self.state {
        SocketState::Closed | SocketState::New | SocketState::SynSent => return Ok(()),
        _ => {}
    }

    // Everything already sent must be acknowledged before the FIN goes out.
    try!(self.flush());

    let mut fin = Packet::new();
    fin.set_type(PacketType::Fin);
    fin.set_connection_id(self.sender_connection_id);
    fin.set_seq_nr(self.seq_nr);
    fin.set_ack_nr(self.ack_nr);
    fin.set_timestamp_microseconds(now_microseconds());

    try!(self.socket.send_to(&fin.to_bytes()[..], self.connected_to));
    self.last_msg_sent_timestamp = SteadyTime::now();
    debug!("sent {:?}", fin);
    self.state = SocketState::FinSent;

    let mut scratch = [0; BUF_SIZE];
    loop {
        if self.state == SocketState::Closed {
            return Ok(());
        }
        try!(self.recv(&mut scratch, false));
    }
}
/// Reads data into `buf`, returning the number of bytes read and the address
/// they are attributed to.
///
/// Data is served in this order: bytes stashed in `read_ready_data`, then
/// data deliverable from the incoming buffer, then freshly received packets.
/// Returns `Ok((0, _))` once the socket is closed and
/// `SocketError::ConnectionReset` if a reset was received earlier.
pub fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
    let buffered = {
        let (front, back) = self.read_ready_data.as_slices();
        let mut copied = 0;
        if !front.is_empty() {
            copied = unsafe_copy(front, buf);
        }
        if !back.is_empty() {
            copied += unsafe_copy(back, &mut buf[copied..]);
        }
        copied
    };
    if buffered > 0 {
        self.read_ready_data.drain(..buffered);
        return Ok((buffered, self.connected_to));
    }

    let flushed = self.flush_incoming_buffer(buf);
    if flushed > 0 {
        return Ok((flushed, self.connected_to));
    }

    if self.state == SocketState::ResetReceived {
        return Err(Error::from(SocketError::ConnectionReset));
    }

    // Keep receiving until a packet yields data, an error occurs or the
    // connection is closed.
    while self.state != SocketState::Closed {
        match self.recv(buf, true) {
            Ok((0, _src)) => {}
            outcome => return outcome,
        }
    }
    Ok((0, self.connected_to))
}
/// Sets the user read timeout in milliseconds.
///
/// `None` and `Some(0)` both disable the timeout (stored as 0).
pub fn set_read_timeout(&mut self, user_timeout: Option<u64>) {
    self.user_read_timeout = user_timeout.unwrap_or(0);
}
/// Tells whether a UDP receive error should be ignored.
///
/// On Windows this is the case for the OS error codes `WSAECONNRESET` and
/// `WSAEMSGSIZE`.
#[cfg(windows)]
fn ignore_udp_error(e: &Error) -> bool {
    const WSAECONNRESET: i32 = 10054;
    const WSAEMSGSIZE: i32 = 10040;
    e.raw_os_error()
        .map(|code| code == WSAECONNRESET || code == WSAEMSGSIZE)
        .unwrap_or(false)
}
/// Tells whether a UDP receive error should be ignored.
///
/// Outside Windows no error is ignored.
#[cfg(not(windows))]
fn ignore_udp_error(_error: &Error) -> bool {
    false
}
/// Receives a single datagram, processes it and copies any deliverable
/// payload into `buf`.
///
/// `use_user_timeout` selects whether the user read timeout applies; a stored
/// timeout of 0 counts as "none". While waiting, receive timeouts trigger
/// `handle_receive_timeout` and count as retries; once `retries` reaches
/// `max_retransmission_retries` the socket is marked `Closed`.
///
/// Returns the number of bytes placed in `buf` (possibly 0) and the source
/// address of the datagram. An unparseable datagram yields
/// `Ok((0, self.connected_to))`.
///
/// # Errors
///
/// * `SocketError::ConnectionTimedOut` when the retry budget is exhausted;
/// * `SocketError::UserTimedOut` when the user timeout elapsed;
/// * any I/O error from the underlying socket, including a failure to set
///   the read timeout (reported instead of panicking);
/// * any error returned by `handle_packet`.
fn recv(&mut self, buf: &mut [u8], use_user_timeout: bool) -> Result<(usize, SocketAddr)> {
    let mut b = [0; BUF_SIZE + HEADER_SIZE];
    let now = SteadyTime::now();
    let (read, src);
    let user_timeout = if use_user_timeout {
        self.user_read_timeout
    } else {
        0
    };
    // A user timeout of zero means there is no user timeout.
    let use_user_timeout = user_timeout != 0;
    loop {
        if self.retries >= self.max_retransmission_retries {
            debug!("exceeds max_retransmission_retries : {} ; current connect state is : {:?}",
                   self.max_retransmission_retries,
                   self.state);
            self.state = SocketState::Closed;
            debug!("socket marked as closed from {:?} to {:?}",
                   self.local_addr(),
                   self.connected_to);
            return Err(Error::from(SocketError::ConnectionTimedOut));
        }

        // The congestion timeout applies in every state except `New`.
        let timeout;
        let congestion_timeout = if self.state != SocketState::New {
            debug!("setting read timeout of {} ms", self.congestion_timeout);
            Some(Duration::from_millis(self.congestion_timeout))
        } else {
            None
        };
        {
            // Wait for the shorter of the two timeouts when both apply.
            let user_timeout = Duration::from_millis(user_timeout);
            timeout = if use_user_timeout {
                match congestion_timeout {
                    Some(congestion_timeout) => {
                        use std::cmp::min;
                        Some(min(congestion_timeout, user_timeout))
                    }
                    None => Some(user_timeout),
                }
            } else {
                congestion_timeout
            };
        }

        // Give up once the total time spent in this call exceeds the user timeout.
        if use_user_timeout {
            let user_timeout = time::Duration::milliseconds(user_timeout as i64);
            if (SteadyTime::now() - now) >= user_timeout {
                return Err(Error::from(SocketError::UserTimedOut));
            }
        }

        // Report a failure to configure the socket to the caller rather than
        // aborting the process.
        try!(self.socket.set_read_timeout(timeout));
        match self.socket.recv_from(&mut b) {
            Ok((r, s)) => {
                read = r;
                src = s;
                break;
            }
            Err(ref e) if e.kind() == ErrorKind::WouldBlock ||
                          e.kind() == ErrorKind::TimedOut => {
                debug!("recv_from timed out");
                let now = SteadyTime::now();
                let congestion_timeout = {
                    time::Duration::milliseconds(self.congestion_timeout as i64)
                };
                // With a user timeout the wait may have been shorter than the
                // congestion timeout, so only react once that has really elapsed.
                if !use_user_timeout ||
                   ((now - self.last_congestion_update) >= congestion_timeout) {
                    self.last_congestion_update = now;
                    try!(self.handle_receive_timeout());
                    self.retries += 1;
                }
            }
            Err(ref e) if Self::ignore_udp_error(e) => (),
            Err(e) => return Err(e),
        };

        let elapsed = (SteadyTime::now() - now).num_milliseconds();
        debug!("{} ms elapsed", elapsed);
    }

    // A datagram arrived: reset the timeout bookkeeping.
    self.last_congestion_update = SteadyTime::now();
    self.retries = 0;

    let packet = match Packet::from_bytes(&b[..read]) {
        Ok(packet) => packet,
        Err(e) => {
            debug!("{}", e);
            debug!("Ignoring invalid packet");
            return Ok((0, self.connected_to));
        }
    };
    debug!("received {:?}", packet);

    // Send whatever reply the packet calls for.
    if let Some(mut pkt) = try!(self.handle_packet(&packet, src)) {
        pkt.set_wnd_size(BUF_SIZE as u32);
        try!(self.socket.send_to(&pkt.to_bytes()[..], src));
        self.last_msg_sent_timestamp = SteadyTime::now();
        debug!("sent {:?}", pkt);
    }

    // Buffer data packets that come after the last delivered one.
    if packet.get_type() == PacketType::Data {
        if Sequence::less(self.last_dropped, packet.seq_nr()) {
            self.insert_into_buffer(packet);
        }
    }

    let read = self.flush_incoming_buffer(buf);
    Ok((read, src))
}
/// Reacts to a receive timeout.
///
/// Doubles the congestion timeout and shrinks the congestion window to one
/// MSS. Then, if unacknowledged packets exist, the oldest one is re-sent with
/// a fresh timestamp; otherwise a FIN is re-sent in the `FinSent` state, and
/// in any state other than `New` a fast resend request is sent.
fn handle_receive_timeout(&mut self) -> Result<()> {
    self.congestion_timeout *= 2;
    self.cwnd = MSS;

    debug!("self.send_window: {:?}",
           self.send_window
               .iter()
               .map(Packet::seq_nr)
               .collect::<Vec<u16>>());

    if !self.send_window.is_empty() {
        let oldest = &mut self.send_window[0];
        oldest.set_timestamp_microseconds(now_microseconds());
        try!(self.socket.send_to(&oldest.to_bytes()[..], self.connected_to));
        self.last_msg_sent_timestamp = SteadyTime::now();
        debug!("resent {:?}", oldest);
        return Ok(());
    }

    match self.state {
        SocketState::FinSent => {
            let mut fin = Packet::new();
            fin.set_connection_id(self.sender_connection_id);
            fin.set_seq_nr(self.seq_nr);
            fin.set_ack_nr(self.ack_nr);
            fin.set_timestamp_microseconds(now_microseconds());
            fin.set_type(PacketType::Fin);
            try!(self.socket.send_to(&fin.to_bytes()[..], self.connected_to));
            self.last_msg_sent_timestamp = SteadyTime::now();
            debug!("resent FIN: {:?}", fin);
        }
        SocketState::New => {}
        _ => {
            debug!("sending fast resend request");
            self.send_fast_resend_request();
        }
    }
    Ok(())
}
/// Builds a reply of type `t` to `original`.
///
/// The reply carries the current timestamp, the difference between that
/// timestamp and the one in `original` (wrapping), and this socket's sender
/// connection id, sequence number and acknowledgement number.
fn prepare_reply(&self, original: &Packet, t: PacketType) -> Packet {
    let mut reply = Packet::new();
    reply.set_type(t);

    let local_ts: u32 = now_microseconds();
    let remote_ts: u32 = original.timestamp_microseconds();
    let ts_difference = local_ts.wrapping_sub(remote_ts);

    reply.set_timestamp_microseconds(local_ts);
    reply.set_timestamp_difference_microseconds(ts_difference);
    reply.set_connection_id(self.sender_connection_id);
    reply.set_seq_nr(self.seq_nr);
    reply.set_ack_nr(self.ack_nr);
    reply
}
/// Removes and returns the first packet of the incoming buffer, if any,
/// recording its sequence number in both `ack_nr` and `last_dropped`.
fn advance_incoming_buffer(&mut self) -> Option<Packet> {
    if self.incoming_buffer.is_empty() {
        return None;
    }
    let head = self.incoming_buffer.remove(0);
    debug!("Removed packet from incoming buffer: {:?}", head);
    self.ack_nr = head.seq_nr();
    self.last_dropped = self.ack_nr;
    Some(head)
}
/// Copies deliverable payload bytes into `buf` and returns how many were copied.
///
/// Leftover bytes of a partially delivered packet (`pending_data`) are served
/// first. Otherwise the first buffered packet is delivered if its sequence
/// number equals `ack_nr` or `ack_nr + 1`; whatever does not fit into `buf`
/// is moved to `pending_data`. A packet is removed from the buffer once its
/// payload has been handed over completely.
fn flush_incoming_buffer(&mut self, buf: &mut [u8]) -> usize {
    if !self.pending_data.is_empty() {
        let copied = {
            let (front, back) = self.pending_data.as_slices();
            let mut copied = 0;
            if !front.is_empty() {
                copied += unsafe_copy(front, buf);
            }
            if !back.is_empty() {
                copied += unsafe_copy(back, &mut buf[copied..]);
            }
            copied
        };
        if copied == self.pending_data.len() {
            self.pending_data.clear();
            self.advance_incoming_buffer();
        } else {
            self.pending_data.drain(..copied);
        }
        return copied;
    }

    let deliverable = match self.incoming_buffer.first() {
        Some(head) => {
            self.ack_nr == head.seq_nr() || self.ack_nr.wrapping_add(1) == head.seq_nr()
        }
        None => false,
    };
    if !deliverable {
        return 0;
    }

    let copied = unsafe_copy(&self.incoming_buffer[0].payload[..], buf);
    if copied == self.incoming_buffer[0].payload.len() {
        self.advance_incoming_buffer();
    } else {
        // Keep the part that did not fit for the next call.
        let leftover = self.incoming_buffer[0].payload.drain(copied..);
        self.pending_data.extend(leftover);
    }
    copied
}
/// Queues `buf` for sending, split into packets of at most
/// `MSS - HEADER_SIZE` payload bytes, and sends the queue.
///
/// Returns `buf.len()` on success and `SocketError::ConnectionClosed` if the
/// socket is closed.
pub fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
    if self.state == SocketState::Closed {
        return Err(Error::from(SocketError::ConnectionClosed));
    }

    let payload_size = MSS as usize - HEADER_SIZE;
    for piece in buf.chunks(payload_size) {
        let mut packet = Packet::with_payload(piece);
        packet.set_seq_nr(self.seq_nr);
        packet.set_ack_nr(self.ack_nr);
        packet.set_connection_id(self.sender_connection_id);
        self.unsent_queue.push_back(packet);

        // Sequence numbers wrap around after u16::MAX.
        self.seq_nr = self.seq_nr.wrapping_add(1);
    }

    try!(self.send());
    Ok(buf.len())
}
/// Keeps receiving until the send window is empty, i.e. until every sent
/// packet has been acknowledged.
pub fn flush(&mut self) -> Result<()> {
    let mut scratch = [0u8; BUF_SIZE];
    loop {
        if self.send_window.is_empty() {
            return Ok(());
        }
        debug!("packets in send window: {}", self.send_window.len());
        try!(self.recv(&mut scratch, false));
    }
}
/// Sends a State packet carrying the current sequence and acknowledgement
/// numbers. Send errors are deliberately ignored.
fn send_state(&mut self) {
    let timestamp: u32 = now_microseconds();

    let mut state = Packet::new();
    state.set_type(PacketType::State);
    state.set_timestamp_microseconds(timestamp);
    state.set_timestamp_difference_microseconds(self.their_delay);
    state.set_connection_id(self.sender_connection_id);
    state.set_seq_nr(self.seq_nr);
    state.set_ack_nr(self.ack_nr);

    let _ = self.socket.send_to(&state.to_bytes()[..], self.connected_to);
    self.last_msg_sent_timestamp = SteadyTime::now();
}
/// Sends a State packet if nothing has been sent for at least 14 seconds.
pub fn send_keepalive(&mut self) {
    let idle = SteadyTime::now() - self.last_msg_sent_timestamp;
    if idle.num_milliseconds() < 14_000 {
        return;
    }
    self.send_state();
}
/// Sends every queued packet in order, moving each into the send window and
/// adding its length to the current window.
fn send(&mut self) -> Result<()> {
    loop {
        let mut packet = match self.unsent_queue.pop_front() {
            Some(packet) => packet,
            None => return Ok(()),
        };
        try!(self.send_packet(&mut packet));
        self.curr_window += packet.len() as u32;
        self.send_window.push(packet);
    }
}
/// Sends `packet` once the window allows it.
///
/// While the bytes in flight are at or above the allowed maximum (the smaller
/// of the congestion window and the remote window, but at least
/// `MIN_CWND * MSS`), incoming packets are processed for up to
/// `PRE_SEND_TIMEOUT` microseconds; payload received in the meantime is
/// stashed in `read_ready_data`. A packet that has been acknowledged in the
/// meantime is skipped.
///
/// Elapsed time is computed with `wrapping_sub` because `now_microseconds()`
/// is a `u32` that wraps around; a plain subtraction would overflow when the
/// clock wraps during the wait.
#[inline]
fn send_packet(&mut self, packet: &mut Packet) -> Result<()> {
    debug!("current window: {}", self.send_window.len());
    let max_inflight = min(self.cwnd, self.remote_wnd_size);
    let max_inflight = max(MIN_CWND * MSS, max_inflight);
    let now = now_microseconds();

    // Wait until enough in-flight packets have been acknowledged.
    while self.curr_window >= max_inflight &&
          now_microseconds().wrapping_sub(now) < PRE_SEND_TIMEOUT {
        debug!("self.curr_window: {}", self.curr_window);
        debug!("max_inflight: {}", max_inflight);
        debug!("self.duplicate_ack_count: {}", self.duplicate_ack_count);
        debug!("now_microseconds() - now = {}",
               now_microseconds().wrapping_sub(now));
        let mut buf = [0; BUF_SIZE];
        let (read, _) = try!(self.recv(&mut buf, false));
        // Keep any payload that arrived so the next read can return it.
        self.read_ready_data.extend(&buf[..read]);
    }
    debug!("out: now_microseconds() - now = {}",
           now_microseconds().wrapping_sub(now));

    // Compare wrapping distances to tell whether the packet lies at or
    // before `last_acked`.
    let distance_a = packet.seq_nr().wrapping_sub(self.last_acked);
    let distance_b = self.last_acked.wrapping_sub(packet.seq_nr());
    if distance_a > distance_b {
        debug!("Packet already acknowledged, skipping...");
        return Ok(());
    }

    packet.set_timestamp_microseconds(now_microseconds());
    packet.set_timestamp_difference_microseconds(self.their_delay);
    try!(self.socket.send_to(&packet.to_bytes()[..], self.connected_to));
    self.last_msg_sent_timestamp = SteadyTime::now();
    debug!("sent {:?}", packet);

    Ok(())
}
/// Records a base delay measurement.
///
/// A new history slot is started when the history is empty or the newest
/// slot is older than `MAX_BASE_DELAY_AGE` (dropping the oldest slot when
/// the history is full). Otherwise the newest slot keeps the smaller of its
/// current value and `base_delay`.
fn update_base_delay(&mut self, base_delay: i64, now: i64) {
    let start_new_slot = self.base_delays.is_empty() ||
                         now - self.last_rollover > MAX_BASE_DELAY_AGE;
    if start_new_slot {
        self.last_rollover = now;
        if self.base_delays.len() == BASE_HISTORY {
            self.base_delays.pop_front();
        }
        self.base_delays.push_back(base_delay);
        return;
    }

    if let Some(newest) = self.base_delays.back_mut() {
        if base_delay < *newest {
            *newest = base_delay;
        }
    }
}
/// Appends a delay sample after discarding the leading samples whose age
/// exceeds `rtt * 100`.
fn update_current_delay(&mut self, v: i64, now: i64) {
    let max_age = self.rtt as i64 * 100;
    let stale = self.current_delays
        .iter()
        .take_while(|sample| now - sample.received_at > max_age)
        .count();
    self.current_delays.drain(..stale);

    let sample = DelayDifferenceSample {
        received_at: now,
        difference: v,
    };
    self.current_delays.push(sample);
}
/// Folds a new delay measurement into the smoothed RTT and RTT variance and
/// derives the congestion timeout from them, clamped to
/// `[MIN_CONGESTION_TIMEOUT, MAX_CONGESTION_TIMEOUT]`.
fn update_congestion_timeout(&mut self, current_delay: i32) {
    let delta = self.rtt - current_delay;
    let variance = self.rtt_variance + (delta.abs() - self.rtt_variance) / 4;
    let smoothed = self.rtt + (current_delay - self.rtt) / 8;
    self.rtt_variance = variance;
    self.rtt = smoothed;

    let unclamped = (smoothed + variance * 4) as u64;
    self.congestion_timeout = min(max(unclamped, MIN_CONGESTION_TIMEOUT),
                                  MAX_CONGESTION_TIMEOUT);

    debug!("current_delay: {}", current_delay);
    debug!("delta: {}", delta);
    debug!("self.rtt_variance: {}", self.rtt_variance);
    debug!("self.rtt: {}", self.rtt);
    debug!("self.congestion_timeout: {}", self.congestion_timeout);
}
/// Returns the exponentially weighted moving average (weight 0.333) of the
/// stored delay samples, truncated to an integer.
fn filtered_current_delay(&self) -> i64 {
    let average = ewma(self.current_delays.iter().map(|sample| sample.difference), 0.333);
    average as i64
}
/// Returns the smallest value in the base delay history, or 0 if it is empty.
fn min_base_delay(&self) -> i64 {
    match self.base_delays.iter().min() {
        Some(&lowest) => lowest,
        None => 0,
    }
}
/// Builds the selective acknowledgement bitmask for the buffered packets.
///
/// Bit `n` (counting from the least significant bit of the first byte)
/// stands for sequence number `ack_nr + 2 + n`. The mask length is always a
/// multiple of four bytes; an empty vector means there is nothing to report.
///
/// Distances are computed with wrapping arithmetic so that an `ack_nr` close
/// to `u16::MAX` neither overflows nor underflows. Only packets in the half
/// of the sequence space ahead of `ack_nr` are reported.
fn build_selective_ack(&self) -> Vec<u8> {
    let ack_nr = self.ack_nr;
    let stashed = self.incoming_buffer
        .iter()
        .map(|pkt| pkt.seq_nr().wrapping_sub(ack_nr))
        // Distance 0 is `ack_nr` itself and 1 is the next expected packet;
        // distances of 0x8000 and above are behind `ack_nr`.
        .filter(|&distance| distance >= 2 && distance < 0x8000)
        .map(|distance| (distance - 2) as usize)
        .map(|diff| (diff / 8, diff % 8));

    let mut sack = Vec::new();
    for (byte, bit) in stashed {
        // Grow the mask in four-byte steps until it covers `byte`.
        while byte >= sack.len() || sack.len() % 4 != 0 {
            sack.push(0u8);
        }
        sack[byte] |= 1 << bit;
    }
    sack
}
/// Sends three State packets in a row, so that the peer sees the same
/// acknowledgement number three times.
fn send_fast_resend_request(&mut self) {
    self.send_state();
    self.send_state();
    self.send_state();
}
/// Re-sends a copy of the packet with sequence number `lost_packet_nr` if it
/// is still in the send window. Send errors are ignored.
fn resend_lost_packet(&mut self, lost_packet_nr: u16) {
    debug!("---> resend_lost_packet({}) <---", lost_packet_nr);
    let found = self.send_window.iter().position(|pkt| pkt.seq_nr() == lost_packet_nr);
    if let Some(position) = found {
        debug!("self.send_window.len(): {}", self.send_window.len());
        debug!("position: {}", position);
        let mut copy = self.send_window[position].clone();
        let _ = self.send_packet(&mut copy);
    } else {
        debug!("Packet {} not found", lost_packet_nr);
    }
    debug!("---> END resend_lost_packet <---");
}
/// Drops every packet up to and including the one numbered `last_acked` from
/// the send window, subtracting their lengths from the current window.
/// Nothing happens if no such packet is in the window.
fn advance_send_window(&mut self) {
    let last_acked = self.last_acked;
    let acked_upto = self.send_window.iter().position(|pkt| pkt.seq_nr() == last_acked);
    if let Some(position) = acked_upto {
        for packet in self.send_window.drain(..position + 1) {
            self.curr_window -= packet.len() as u32;
        }
    }
    debug!("self.curr_window: {}", self.curr_window);
}
fn handle_packet(&mut self, packet: &Packet, src: SocketAddr) -> Result<Option<Packet>> {
debug!("({:?}, {:?})", self.state, packet.get_type());
let is_data_or_fin = packet.get_type() == PacketType::Data
|| packet.get_type() == PacketType::Fin;
if is_data_or_fin && packet.seq_nr().wrapping_sub(self.ack_nr) == 1 {
self.ack_nr = packet.seq_nr();
}
if packet.get_type() != PacketType::Syn && self.state != SocketState::SynSent &&
!(packet.connection_id() == self.sender_connection_id ||
packet.connection_id() == self.receiver_connection_id) {
return Ok(Some(self.prepare_reply(packet, PacketType::Reset)));
}
self.remote_wnd_size = packet.wnd_size();
debug!("self.remote_wnd_size: {}", self.remote_wnd_size);
let now = now_microseconds();
self.their_delay = now.wrapping_sub(packet.timestamp_microseconds());
debug!("self.their_delay: {}", self.their_delay);
match (self.state, packet.get_type()) {
(SocketState::New, PacketType::Syn) => {
self.connected_to = src;
self.ack_nr = packet.seq_nr();
self.seq_nr = rand::random();
self.last_acked = self.seq_nr.wrapping_sub(1);
self.receiver_connection_id = packet.connection_id() + 1;
self.sender_connection_id = packet.connection_id();
self.state = SocketState::Connected;
self.last_dropped = self.ack_nr;
self.state_packet = Some(self.prepare_reply(packet, PacketType::State));
self.seq_nr = self.seq_nr.wrapping_add(1);
Ok(self.state_packet.clone())
}
(SocketState::Connected, PacketType::Syn) if self.connected_to == src => {
Ok(self.state_packet.clone())
}
(_, PacketType::Syn) => {
Ok(Some(self.prepare_reply(packet, PacketType::Reset)))
}
(SocketState::SynSent, PacketType::State) => {
self.connected_to = src;
self.ack_nr = packet.seq_nr();
self.seq_nr += 1;
self.state = SocketState::Connected;
self.last_acked = packet.ack_nr();
self.last_dropped = packet.seq_nr();
self.last_acked_timestamp = now_microseconds();
Ok(None)
}
(SocketState::SynSent, _) => Err(Error::from(SocketError::InvalidReply)),
(SocketState::Connected, PacketType::Data) |
(SocketState::FinSent, PacketType::Data) => Ok(self.handle_data_packet(packet)),
(SocketState::Connected, PacketType::State) => {
self.handle_state_packet(packet);
Ok(None)
}
(SocketState::Connected, PacketType::Fin) |
(SocketState::FinSent, PacketType::Fin) => {
if packet.ack_nr() < self.seq_nr {
debug!("FIN received but there are missing acknowledgements for sent packets");
}
let mut reply = self.prepare_reply(packet, PacketType::State);
if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
debug!("current ack_nr ({}) is behind received packet seq_nr ({})",
self.ack_nr,
packet.seq_nr());
let sack = self.build_selective_ack();
if sack.len() > 0 {
reply.set_sack(sack);
}
}
self.state = SocketState::Closed;
Ok(Some(reply))
}
(SocketState::FinSent, PacketType::State) => {
if packet.ack_nr() == self.seq_nr {
self.state = SocketState::Closed;
} else {
self.handle_state_packet(packet);
}
Ok(None)
}
(_, PacketType::Reset) => {
self.state = SocketState::ResetReceived;
Err(Error::from(SocketError::ConnectionReset))
}
(state, ty) => {
let message = format!("Unimplemented handling for ({:?},{:?})", state, ty);
debug!("{}", message);
Err(Error::new(ErrorKind::Other, message))
}
}
}
/// Builds the reply to a data packet: a FIN while in the `FinSent` state, a
/// State packet otherwise. A selective acknowledgement is attached when the
/// packet is more than one sequence number ahead of `ack_nr` and there are
/// buffered packets to report.
fn handle_data_packet(&mut self, packet: &Packet) -> Option<Packet> {
    let reply_type = match self.state {
        SocketState::FinSent => PacketType::Fin,
        _ => PacketType::State,
    };
    let mut reply = self.prepare_reply(packet, reply_type);

    let gap = packet.seq_nr().wrapping_sub(self.ack_nr);
    if gap > 1 {
        debug!("current ack_nr ({}) is behind received packet seq_nr ({})",
               self.ack_nr,
               packet.seq_nr());
        let sack = self.build_selective_ack();
        if !sack.is_empty() {
            reply.set_sack(sack);
        }
    }
    Some(reply)
}
/// Returns the filtered current delay minus the minimum base delay.
fn queuing_delay(&self) -> i64 {
    let current = self.filtered_current_delay();
    let base = self.min_base_delay();
    let delay = current - base;

    debug!("filtered_current_delay: {}", current);
    debug!("min_base_delay: {}", base);
    debug!("queuing_delay: {}", delay);

    delay
}
/// Adjusts the congestion window by
/// `GAIN * off_target * bytes_newly_acked * MSS / cwnd`, then limits it to at
/// most `curr_window + ALLOWED_INCREASE * MSS` and at least `MIN_CWND * MSS`.
fn update_congestion_window(&mut self, off_target: f64, bytes_newly_acked: u32) {
    let current = self.cwnd as f64;
    let scaled = GAIN * off_target * bytes_newly_acked as f64 * MSS as f64;
    let cwnd_increase = scaled / current;
    debug!("cwnd_increase: {}", cwnd_increase);

    let max_allowed_cwnd = self.curr_window + ALLOWED_INCREASE * MSS;
    let grown = (current + cwnd_increase) as u32;
    self.cwnd = max(min(grown, max_allowed_cwnd), MIN_CWND * MSS);

    debug!("cwnd: {}", self.cwnd);
    debug!("max_allowed_cwnd: {}", max_allowed_cwnd);
}
/// Processes an acknowledgement (State) packet.
///
/// Tracks duplicate acknowledgements, updates the delay statistics,
/// congestion window and congestion timeout when the acknowledged packet is
/// in the send window, re-sends packets reported or presumed lost, halves the
/// congestion window on loss and finally drops acknowledged packets from the
/// send window.
///
/// Sequence numbers derived from the peer-supplied `ack_nr` use wrapping
/// addition so that values near `u16::MAX` cannot overflow.
fn handle_state_packet(&mut self, packet: &Packet) {
    if packet.ack_nr() == self.last_acked {
        self.duplicate_ack_count += 1;
    } else {
        self.last_acked = packet.ack_nr();
        self.last_acked_timestamp = now_microseconds();
        self.duplicate_ack_count = 1;
    }

    // Update congestion state if the acknowledged packet is still in flight.
    if let Some(index) = self.send_window.iter().position(|p| packet.ack_nr() == p.seq_nr()) {
        let bytes_newly_acked = self.send_window
            .iter()
            .take(index + 1)
            .fold(0, |acc, p| acc + p.len());

        let now = now_microseconds() as i64;
        let our_delay = now - self.send_window[index].timestamp_microseconds() as i64;
        debug!("our_delay: {}", our_delay);
        self.update_base_delay(our_delay, now);
        self.update_current_delay(our_delay, now);

        let off_target: f64 = (TARGET as f64 - self.queuing_delay() as f64) / TARGET as f64;
        debug!("off_target: {}", off_target);

        self.update_congestion_window(off_target, bytes_newly_acked as u32);

        let rtt = (TARGET - off_target as i64) / 1000;
        self.update_congestion_timeout(rtt as i32);
    }

    // Three acknowledgements of the same packet indicate a loss.
    let mut packet_loss_detected: bool = !self.send_window.is_empty() &&
                                         self.duplicate_ack_count == 3;

    for extension in packet.extensions.iter() {
        if extension.get_type() == ExtensionType::SelectiveAck {
            // Three or more packets received past `ack_nr + 1` mean that
            // `ack_nr + 1` itself is treated as lost.
            if extension.iter().count_ones() >= 3 {
                self.resend_lost_packet(packet.ack_nr().wrapping_add(1));
                packet_loss_detected = true;
            }

            // Re-send every packet the mask marks as missing, up to the
            // newest packet in the send window.
            if let Some(last_seq_nr) = self.send_window.last().map(Packet::seq_nr) {
                for seq_nr in extension.iter()
                    .enumerate()
                    .filter(|&(_idx, received)| !received)
                    .map(|(idx, _received)| {
                        packet.ack_nr().wrapping_add(2).wrapping_add(idx as u16)
                    })
                    .take_while(|&seq_nr| seq_nr < last_seq_nr) {
                    debug!("SACK: packet {} lost", seq_nr);
                    self.resend_lost_packet(seq_nr);
                    packet_loss_detected = true;
                }
            }
        } else {
            debug!("Unknown extension {:?}, ignoring", extension.get_type());
        }
    }

    // Without selective acknowledgement, a triple duplicate acknowledgement
    // triggers a resend of the packet following the acknowledged one.
    if !self.send_window.is_empty() && self.duplicate_ack_count == 3 &&
       !packet.extensions.iter().any(|ext| ext.get_type() == ExtensionType::SelectiveAck) {
        self.resend_lost_packet(packet.ack_nr().wrapping_add(1));
    }

    if packet_loss_detected {
        debug!("packet loss detected, halving congestion window");
        self.cwnd = max(self.cwnd / 2, MIN_CWND * MSS);
        debug!("cwnd: {}", self.cwnd);
    }

    self.advance_send_window();
}
/// Inserts `packet` into the incoming buffer, keeping the buffer ordered by
/// sequence number and skipping a packet whose sequence number is already
/// present at the insertion point.
fn insert_into_buffer(&mut self, packet: Packet) {
    let seq = packet.seq_nr();

    // Common case: the packet belongs after everything buffered so far.
    let goes_last = match self.incoming_buffer.last() {
        Some(tail) => seq > tail.seq_nr(),
        None => false,
    };
    if goes_last {
        self.incoming_buffer.push(packet);
        return;
    }

    let slot = self.incoming_buffer.iter().filter(|p| p.seq_nr() < seq).count();
    let duplicate = match self.incoming_buffer.get(slot) {
        Some(existing) => existing.seq_nr() == seq,
        None => false,
    };
    if !duplicate {
        self.incoming_buffer.insert(slot, packet);
    }
}
}
impl Drop for UtpSocket {
    /// Closes the connection when the socket goes out of scope; a failure to
    /// close is ignored.
    fn drop(&mut self) {
        self.close().unwrap_or(());
    }
}
/// Listens for incoming connection requests on a UDP socket.
pub struct UtpListener {
/// Socket on which the initial SYN packets are received.
socket: UdpSocket,
}
impl UtpListener {
pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpListener> {
UdpSocket::bind(addr).and_then(|s| Ok(UtpListener { socket: s }))
}
pub fn accept(&self) -> Result<(UtpSocket, SocketAddr)> {
let mut buf = [0; BUF_SIZE];
match self.socket.recv_from(&mut buf) {
Ok((nread, src)) => {
let packet = try!(Packet::from_bytes(&buf[..nread])
.or(Err(SocketError::InvalidPacket)));
if packet.get_type() != PacketType::Syn {
return Err(Error::from(SocketError::InvalidPacket));
}
let inner_socket = self.socket.local_addr().and_then(|addr| {
match addr {
SocketAddr::V4(_) => UdpSocket::bind("0.0.0.0:0"),
SocketAddr::V6(_) => UdpSocket::bind(":::0"),
}
});
let mut socket = try!(inner_socket.map(|s| UtpSocket::from_raw_parts(s, src)));
match socket.handle_packet(&packet, src) {
Ok(Some(reply)) => try!(socket.socket.send_to(&reply.to_bytes()[..], src)),
Ok(None) => return Err(Error::from(SocketError::InvalidPacket)),
Err(e) => return Err(e),
};
Ok((socket, src))
}
Err(e) => Err(e),
}
}
pub fn incoming(&self) -> Incoming {
Incoming { listener: self }
}
pub fn local_addr(&self) -> Result<SocketAddr> {
self.socket.local_addr()
}
}
/// Iterator over the connections accepted by a `UtpListener`.
pub struct Incoming<'a> {
/// Listener whose `accept` is called for each item.
listener: &'a UtpListener,
}
impl<'a> Iterator for Incoming<'a> {
    type Item = Result<(UtpSocket, SocketAddr)>;

    /// Accepts the next connection. The iterator never ends: every call
    /// yields `Some`, wrapping either the new connection or the error.
    fn next(&mut self) -> Option<Self::Item> {
        let accepted = self.listener.accept();
        Some(accepted)
    }
}
#[cfg(test)]
mod test {
use std::thread;
use std::net::ToSocketAddrs;
use std::io::ErrorKind;
use super::{UtpSocket, UtpListener, SocketState, BUF_SIZE, take_address};
use packet::{Packet, PacketType, Encodable, Decodable};
use util::now_microseconds;
use rand;
// Unwraps a `Result`, panicking with the debug representation of the error.
macro_rules! iotry {
($e:expr) => (match $e { Ok(e) => e, Err(e) => panic!("{:?}", e) })
}
/// Returns a port number for a test: 9600 plus the value of a process-wide
/// counter that is incremented on every call.
fn next_test_port() -> u16 {
    use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering};
    const BASE_PORT: u16 = 9600;
    static NEXT_OFFSET: AtomicUsize = ATOMIC_USIZE_INIT;
    let offset = NEXT_OFFSET.fetch_add(1, Ordering::Relaxed);
    BASE_PORT + offset as u16
}
/// Returns the IPv4 loopback address paired with a fresh test port.
fn next_test_ip4<'a>() -> (&'a str, u16) {
    let port = next_test_port();
    ("127.0.0.1", port)
}
/// Returns the IPv6 loopback address paired with a fresh test port.
fn next_test_ip6<'a>() -> (&'a str, u16) {
    let port = next_test_port();
    ("::1", port)
}
// Connects a client to a server over IPv4 loopback, closes the client and
// checks the connection ids and final states on both sides.
#[test]
fn test_socket_ipv4() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
assert!(server.state == SocketState::New);
// The client runs in its own thread because the server blocks in `recv_from`.
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert!(client.state == SocketState::Connected);
// On the connecting side the sender id is the receiver id plus one.
assert_eq!(client.sender_connection_id,
client.receiver_connection_id + 1);
assert_eq!(client.connected_to,
server_addr.to_socket_addrs().unwrap().next().unwrap());
iotry!(client.close());
drop(client);
});
let mut buf = [0u8; BUF_SIZE];
// The result is only printed; the assertions below check the outcome.
match server.recv_from(&mut buf) {
e => println!("{:?}", e),
}
// On the accepting side the ids are the other way round.
assert_eq!(server.receiver_connection_id,
server.sender_connection_id + 1);
assert!(server.state == SocketState::Closed);
drop(server);
assert!(child.join().is_ok());
}
// Same scenario as the IPv4 socket test, run over the IPv6 loopback address.
#[test]
fn test_socket_ipv6() {
let server_addr = next_test_ip6();
let mut server = iotry!(UtpSocket::bind(server_addr));
assert!(server.state == SocketState::New);
// The client runs in its own thread because the server blocks in `recv_from`.
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert!(client.state == SocketState::Connected);
// On the connecting side the sender id is the receiver id plus one.
assert_eq!(client.sender_connection_id,
client.receiver_connection_id + 1);
assert_eq!(client.connected_to,
server_addr.to_socket_addrs().unwrap().next().unwrap());
iotry!(client.close());
drop(client);
});
let mut buf = [0u8; BUF_SIZE];
// The result is only printed; the assertions below check the outcome.
match server.recv_from(&mut buf) {
e => println!("{:?}", e),
}
// On the accepting side the ids are the other way round.
assert_eq!(server.receiver_connection_id,
server.sender_connection_id + 1);
assert!(server.state == SocketState::Closed);
drop(server);
assert!(child.join().is_ok());
}
/// Two peers rendezvous-connect to each other over pre-bound UDP sockets, one
/// streams a 16 MiB random buffer in randomly sized chunks, and the other
/// checks that it receives exactly the same bytes.
#[test]
fn test_rendezvous_connect() {
    use std::net::{UdpSocket, Ipv4Addr, SocketAddrV4};

    // Both peers start from plain UDP sockets on OS-assigned ports.
    let udp_a = iotry!(UdpSocket::bind("0.0.0.0:0"));
    let udp_b = iotry!(UdpSocket::bind("0.0.0.0:0"));
    let port_a = iotry!(udp_a.local_addr()).port();
    let addr_a = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port_a);
    let port_b = iotry!(udp_b.local_addr()).port();
    let addr_b = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port_b);

    const BUF_LEN: u32 = 16777216;
    let tx_buffer: Vec<u8> = (0..BUF_LEN).map(|_| rand::random::<u8>()).collect();

    // Sending side: push the whole buffer, flush, close, and hand the buffer
    // back so the main thread can compare against it.
    let sender = thread::spawn(move || {
        let mut peer = iotry!(UtpSocket::rendezvous_connect(udp_a, addr_b));
        let mut offset = 0;
        while offset < tx_buffer.len() {
            let chunk_len = rand::random::<u16>() as usize + 1;
            let end = ::std::cmp::min(tx_buffer.len(), offset + chunk_len);
            let written = peer.send_to(&tx_buffer[offset..end]).unwrap();
            offset += written;
        }
        peer.flush().unwrap();
        let _ = peer.close();
        tx_buffer
    });

    // Receiving side: fill a zeroed buffer of the same length.
    let mut peer = iotry!(UtpSocket::rendezvous_connect(udp_b, addr_a));
    let mut rx_buffer: Vec<u8> = vec![0u8; BUF_LEN as usize];
    let mut filled = 0;
    while filled < rx_buffer.len() {
        let chunk_len = rand::random::<u16>() as usize + 1;
        ::std::thread::sleep(::std::time::Duration::from_millis(1));
        let end = ::std::cmp::min(rx_buffer.len(), filled + chunk_len);
        let (got, _) = peer.recv_from(&mut rx_buffer[filled..end]).unwrap();
        filled += got;
    }

    let sent_buffer = sender.join().unwrap();
    assert_eq!(sent_buffer, rx_buffer);
    let _ = peer.close();
}
/// After the peer has closed the connection, a further `recv_from` on the
/// closed socket must return `Ok` with zero bytes and leave it closed.
#[test]
fn test_recvfrom_on_closed_socket() {
    let addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(addr));
    assert!(server.state == SocketState::New);

    let client_thread = thread::spawn(move || {
        let mut peer = iotry!(UtpSocket::connect(addr));
        assert!(peer.state == SocketState::Connected);
        assert!(peer.close().is_ok());
    });

    // The result of the first receive is irrelevant; only the resulting
    // socket state is checked.
    let mut scratch = [0u8; BUF_SIZE];
    let _ = server.recv_from(&mut scratch);
    assert!(server.state == SocketState::Closed);

    // Receiving again on the closed socket reports zero bytes.
    match server.recv_from(&mut scratch) {
        Ok((0, _)) => {}
        other => panic!("Expected Ok(0), got {:?}", other),
    }
    assert_eq!(server.state, SocketState::Closed);
    assert!(client_thread.join().is_ok());
}
/// After the peer has closed the connection, `send_to` on the closed socket
/// must fail with `ErrorKind::NotConnected`.
#[test]
fn test_sendto_on_closed_socket() {
    let addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(addr));
    assert!(server.state == SocketState::New);

    let client_thread = thread::spawn(move || {
        let mut peer = iotry!(UtpSocket::connect(addr));
        assert!(peer.state == SocketState::Connected);
        iotry!(peer.close());
    });

    let mut scratch = [0u8; BUF_SIZE];
    iotry!(server.recv_from(&mut scratch));
    assert_eq!(server.state, SocketState::Closed);

    match server.send_to(&scratch) {
        Err(ref err) if err.kind() == ErrorKind::NotConnected => (),
        other => panic!("expected {:?}, got {:?}", ErrorKind::NotConnected, other),
    }
    assert!(client_thread.join().is_ok());
}
// Checks that, after connecting, the client's `ack_nr` is one less than the
// server's `seq_nr`, and that closing the connection leaves the client's
// `ack_nr` unchanged.
#[test]
fn test_acks_on_socket() {
use std::sync::mpsc::channel;
let server_addr = next_test_ip4();
// Channel used to hand the server's sequence number to the main thread.
let (tx, rx) = channel();
let mut server = iotry!(UtpSocket::bind(server_addr));
let child = thread::spawn(move || {
let mut buf = [0u8; BUF_SIZE];
// The outcome of the first receive is deliberately ignored.
let _resp = server.recv(&mut buf, false);
// Report the server's sequence number as it stands after that receive.
tx.send(server.seq_nr).unwrap();
// Keep receiving so the client's close can be processed.
iotry!(server.recv_from(&mut buf));
drop(server);
});
let mut client = iotry!(UtpSocket::connect(server_addr));
assert!(client.state == SocketState::Connected);
let sender_seq_nr = rx.recv().unwrap();
let ack_nr = client.ack_nr;
assert!(ack_nr != 0);
// Wrapping add because sequence numbers may roll over.
assert!(ack_nr.wrapping_add(1) == sender_seq_nr);
assert!(client.close().is_ok());
// Closing must not have advanced the acknowledgement number.
assert!(client.ack_nr == ack_nr);
drop(client);
assert!(child.join().is_ok());
}
/// Feeds a SYN, a DATA and a FIN packet directly into `handle_packet` and
/// checks the type, connection id, sequence number and acknowledgement number
/// of each reply.
#[test]
fn test_handle_packet() {
    let initial_connection_id: u16 = rand::random();
    // The id is a random u16, so its neighbours are derived with wrapping
    // arithmetic: a plain `+ 1` panics in debug builds when the id is
    // u16::MAX.
    // NOTE(review): confirm the socket derives its own ids the same way.
    let sender_connection_id = initial_connection_id.wrapping_add(1);
    let (server_addr, client_addr) = (next_test_ip4()
                                          .to_socket_addrs()
                                          .unwrap()
                                          .next()
                                          .unwrap(),
                                      next_test_ip4()
                                          .to_socket_addrs()
                                          .unwrap()
                                          .next()
                                          .unwrap());
    let mut socket = iotry!(UtpSocket::bind(server_addr));

    // Stage 1: connection setup. A SYN must be answered with a STATE packet
    // that echoes the connection id and acknowledges the SYN.
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert!(response.get_type() == PacketType::State);
    assert!(response.connection_id() == packet.connection_id());
    assert!(response.ack_nr() == packet.seq_nr());
    assert!(response.payload.is_empty());

    // Stage 2: a DATA packet on the sender connection id must be acknowledged
    // with a STATE packet whose sequence number has advanced by one.
    let old_packet = packet;
    let old_response = response;
    let mut packet = Packet::new();
    packet.set_type(PacketType::Data);
    packet.set_connection_id(sender_connection_id);
    packet.set_seq_nr(old_packet.seq_nr().wrapping_add(1));
    packet.set_ack_nr(old_response.seq_nr());
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert!(response.get_type() == PacketType::State);
    assert!(response.connection_id() == initial_connection_id);
    assert!(response.connection_id() == packet.connection_id().wrapping_sub(1));
    assert!(response.ack_nr() == packet.seq_nr());
    assert!(response.payload.is_empty());
    assert!(response.seq_nr() == old_response.seq_nr().wrapping_add(1));

    // Stage 3: a FIN must be acknowledged with a STATE packet that keeps the
    // previous sequence number.
    let old_packet = packet;
    let old_response = response;
    let mut packet = Packet::new();
    packet.set_type(PacketType::Fin);
    packet.set_connection_id(sender_connection_id);
    packet.set_seq_nr(old_packet.seq_nr().wrapping_add(1));
    packet.set_ack_nr(old_response.seq_nr());
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert!(response.get_type() == PacketType::State);
    assert!(packet.seq_nr() == old_packet.seq_nr().wrapping_add(1));
    assert!(response.seq_nr() == old_response.seq_nr());
    assert!(response.ack_nr() == packet.seq_nr());
}
/// After connection setup, a STATE packet (keepalive acknowledgement) fed to
/// `handle_packet` must produce no reply, even when delivered twice.
#[test]
fn test_response_to_keepalive_ack() {
    let initial_connection_id: u16 = rand::random();
    let server_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let client_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let mut socket = iotry!(UtpSocket::bind(server_addr));

    // Establish the connection with a SYN.
    let mut syn = Packet::new();
    syn.set_wnd_size(BUF_SIZE as u32);
    syn.set_type(PacketType::Syn);
    syn.set_connection_id(initial_connection_id);
    let syn_reply = socket.handle_packet(&syn, client_addr);
    assert!(syn_reply.is_ok());
    let syn_reply = syn_reply.unwrap();
    assert!(syn_reply.is_some());
    let syn_reply = syn_reply.unwrap();
    assert!(syn_reply.get_type() == PacketType::State);

    // Build the keepalive acknowledgement.
    let mut keepalive = Packet::new();
    keepalive.set_wnd_size(BUF_SIZE as u32);
    keepalive.set_type(PacketType::State);
    keepalive.set_connection_id(initial_connection_id);
    keepalive.set_seq_nr(syn.seq_nr() + 1);
    keepalive.set_ack_nr(syn_reply.seq_nr());

    // Neither the first delivery nor the duplicate may trigger a reply.
    for _ in 0..2 {
        let reply = socket.handle_packet(&keepalive, client_addr);
        assert!(reply.is_ok());
        let reply = reply.unwrap();
        assert!(reply.is_none());
    }

    // Mark the socket closed by hand before it goes out of scope.
    socket.state = SocketState::Closed;
}
/// After connection setup, a packet carrying a connection id that does not
/// belong to the connection must be answered with a RESET that acknowledges
/// the offending packet's sequence number.
#[test]
fn test_response_to_wrong_connection_id() {
    let initial_connection_id: u16 = rand::random();
    let (server_addr, client_addr) = (next_test_ip4()
                                          .to_socket_addrs()
                                          .unwrap()
                                          .next()
                                          .unwrap(),
                                      next_test_ip4()
                                          .to_socket_addrs()
                                          .unwrap()
                                          .next()
                                          .unwrap());
    let mut socket = iotry!(UtpSocket::bind(server_addr));

    // Establish the connection with a SYN.
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::Syn);
    packet.set_connection_id(initial_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    assert!(response.unwrap().get_type() == PacketType::State);

    // Pick an id that can never coincide with the ids this connection uses.
    // Doubling the random id is not enough: it maps 0 to itself and 1 to
    // `id + 1` (the id test_handle_packet uses for data packets), which made
    // the test fail spuriously for those two values. An offset of two always
    // differs from both `id` and `id + 1`.
    let new_connection_id = initial_connection_id.wrapping_add(2);
    let mut packet = Packet::new();
    packet.set_wnd_size(BUF_SIZE as u32);
    packet.set_type(PacketType::State);
    packet.set_connection_id(new_connection_id);
    let response = socket.handle_packet(&packet, client_addr);
    assert!(response.is_ok());
    let response = response.unwrap();
    assert!(response.is_some());
    let response = response.unwrap();
    assert!(response.get_type() == PacketType::Reset);
    assert!(response.ack_nr() == packet.seq_nr());

    // Mark the socket closed by hand before it goes out of scope.
    socket.state = SocketState::Closed;
}
/// Delivers two DATA packets to `handle_packet` in reverse order and checks
/// that the early packet is not acknowledged by its own sequence number while
/// both deliveries still produce a reply.
#[test]
fn test_unordered_packets() {
    let initial_connection_id: u16 = rand::random();
    let server_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let client_addr = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let mut socket = iotry!(UtpSocket::bind(server_addr));

    // Establish the connection with a SYN.
    let mut syn = Packet::new();
    syn.set_wnd_size(BUF_SIZE as u32);
    syn.set_type(PacketType::Syn);
    syn.set_connection_id(initial_connection_id);
    let syn_reply = socket.handle_packet(&syn, client_addr);
    assert!(syn_reply.is_ok());
    let syn_reply = syn_reply.unwrap();
    assert!(syn_reply.is_some());
    let syn_reply = syn_reply.unwrap();
    assert!(syn_reply.get_type() == PacketType::State);

    // The packet that directly follows the SYN.
    let mut in_order = Packet::new();
    in_order.set_wnd_size(BUF_SIZE as u32);
    in_order.set_type(PacketType::Data);
    in_order.set_connection_id(initial_connection_id);
    in_order.set_seq_nr(syn.seq_nr() + 1);
    in_order.set_ack_nr(syn_reply.seq_nr());
    in_order.payload = vec![1, 2, 3];

    // The packet one step further ahead in the sequence.
    let mut ahead = Packet::new();
    ahead.set_wnd_size(BUF_SIZE as u32);
    ahead.set_type(PacketType::Data);
    ahead.set_connection_id(initial_connection_id);
    ahead.set_seq_nr(syn.seq_nr() + 2);
    ahead.set_ack_nr(syn_reply.seq_nr());
    ahead.payload = vec![4, 5, 6];

    // Deliver the later packet first: the reply must not acknowledge it.
    let early_reply = socket.handle_packet(&ahead, client_addr);
    assert!(early_reply.is_ok());
    let early_reply = early_reply.unwrap();
    assert!(early_reply.is_some());
    let early_reply = early_reply.unwrap();
    assert!(early_reply.ack_nr() != ahead.seq_nr());

    // Now deliver the packet that was missing.
    let late_reply = socket.handle_packet(&in_order, client_addr);
    assert!(late_reply.is_ok());
    let late_reply = late_reply.unwrap();
    assert!(late_reply.is_some());

    // Mark the socket closed by hand before it goes out of scope.
    socket.state = SocketState::Closed;
}
// End-to-end check that data packets arriving out of order over a real socket
// are delivered to the reader in sequence order.
#[test]
fn test_socket_unordered_packets() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
assert!(server.state == SocketState::New);
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert!(client.state == SocketState::Connected);
assert_eq!(client.sender_connection_id,
client.receiver_connection_id + 1);
// Clone the UDP socket so hand-built packets can be sent in a chosen order.
let s = client.socket.try_clone().ok().expect("Error cloning internal UDP socket");
let mut window: Vec<Packet> = Vec::new();
// Build four data packets carrying the bytes 1..=12, three bytes each, and
// record each in the client's send window as well as in `window`.
for data in (1..13u8).collect::<Vec<u8>>()[..].chunks(3) {
let mut packet = Packet::new();
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_type(PacketType::Data);
packet.set_connection_id(client.sender_connection_id);
packet.set_seq_nr(client.seq_nr);
packet.set_ack_nr(client.ack_nr);
packet.payload = data.to_vec();
window.push(packet.clone());
client.send_window.push(packet.clone());
client.seq_nr += 1;
client.curr_window += packet.len() as u32;
}
// A FIN follows the data; it is stored at index 4 of `window`.
let mut packet = Packet::new();
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_type(PacketType::Fin);
packet.set_connection_id(client.sender_connection_id);
packet.set_seq_nr(client.seq_nr);
packet.set_ack_nr(client.ack_nr);
window.push(packet);
client.seq_nr += 1;
// Send the data packets in reverse order, then the FIN.
iotry!(s.send_to(&window[3].to_bytes()[..], server_addr));
iotry!(s.send_to(&window[2].to_bytes()[..], server_addr));
iotry!(s.send_to(&window[1].to_bytes()[..], server_addr));
iotry!(s.send_to(&window[0].to_bytes()[..], server_addr));
iotry!(s.send_to(&window[4].to_bytes()[..], server_addr));
// Read two datagrams sent back by the server; their contents are not checked.
for _ in 0u8..2 {
let mut buf = [0; BUF_SIZE];
iotry!(s.recv_from(&mut buf));
}
});
let mut buf = [0; BUF_SIZE];
let expected: Vec<u8> = (1..13u8).collect();
let mut received: Vec<u8> = vec![];
// Read until a zero-length read is returned.
loop {
match server.recv_from(&mut buf) {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{:?}", e),
}
}
assert_eq!(server.receiver_connection_id,
server.sender_connection_id + 1);
assert_eq!(server.state, SocketState::Closed);
// The bytes must come out in their original order despite the reordering.
assert_eq!(received.len(), expected.len());
assert_eq!(received, expected);
assert!(child.join().is_ok());
}
// Checks that a sender retransmits a data packet after receiving the same
// acknowledgement three times in a row.
#[test]
fn test_response_to_triple_ack() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = 1024;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let d = data.clone();
assert_eq!(LEN, data.len());
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&d[..]));
iotry!(client.close());
});
let mut buf = [0; BUF_SIZE];
// First receive goes through the uTP socket.
iotry!(server.recv(&mut buf, false));
// Read the data packet straight off the underlying UDP socket so that the
// uTP layer does not process (and acknowledge) it.
let data_packet = match server.socket.recv_from(&mut buf) {
Ok((read, _src)) => iotry!(Packet::from_bytes(&buf[..read])),
Err(e) => panic!("{}", e),
};
assert_eq!(data_packet.get_type(), PacketType::Data);
assert_eq!(data_packet.payload, data);
assert_eq!(data_packet.payload.len(), data.len());
// Hand-craft a STATE packet acknowledging the packet *before* the data packet.
let mut packet = Packet::new();
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_type(PacketType::State);
packet.set_seq_nr(server.seq_nr);
packet.set_ack_nr(data_packet.seq_nr() - 1);
packet.set_connection_id(server.sender_connection_id);
// Send that acknowledgement three times.
for _ in 0u8..3 {
iotry!(server.socket.send_to(&packet.to_bytes()[..], server.connected_to));
}
let client_addr = server.connected_to;
// The next datagram from the client must be the same data packet again.
match server.socket.recv_from(&mut buf) {
Ok((0, _)) => panic!("Received 0 bytes from socket"),
Ok((read, _src)) => {
let packet = iotry!(Packet::from_bytes(&buf[..read]));
assert_eq!(packet.get_type(), PacketType::Data);
assert_eq!(packet.seq_nr(), data_packet.seq_nr());
assert!(packet.payload == data_packet.payload);
// Let the socket process the retransmission and forward its reply by hand.
let response = server.handle_packet(&packet, client_addr);
assert!(response.is_ok());
let response = response.unwrap();
assert!(response.is_some());
let response = response.unwrap();
iotry!(server.socket.send_to(&response.to_bytes()[..], server.connected_to));
}
Err(e) => panic!("{}", e),
}
// Receive once more through the uTP socket before joining the client.
iotry!(server.recv_from(&mut buf));
assert!(child.join().is_ok());
}
// Drops a data packet on the receiving side and checks that the data still
// arrives through the uTP socket once the congestion timeout is shortened.
#[test]
fn test_socket_timeout_request() {
let (server_addr, client_addr) = (next_test_ip4()
.to_socket_addrs()
.unwrap()
.next()
.unwrap(),
next_test_ip4()
.to_socket_addrs()
.unwrap()
.next()
.unwrap());
// This outer `client` is only used for the initial-state assertions below;
// the spawned thread creates its own connected socket.
let client = iotry!(UtpSocket::bind(client_addr));
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = 512;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let d = data.clone();
assert!(server.state == SocketState::New);
assert!(client.state == SocketState::New);
assert_eq!(client.sender_connection_id,
client.receiver_connection_id + 1);
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert!(client.state == SocketState::Connected);
assert_eq!(client.connected_to, server_addr);
iotry!(client.send_to(&d[..]));
drop(client);
});
let mut buf = [0u8; BUF_SIZE];
server.recv(&mut buf, false).unwrap();
assert_eq!(server.receiver_connection_id,
server.sender_connection_id + 1);
assert!(server.state == SocketState::Connected);
// Swallow one datagram at the UDP level so the uTP layer never sees it.
iotry!(server.socket.recv_from(&mut buf));
// Shorten the congestion timeout (unit not visible in this chunk).
server.congestion_timeout = 50;
// Keep receiving until a non-empty read arrives; zero-length reads are retried.
loop {
match server.recv_from(&mut buf) {
Ok((0, _)) => continue,
Ok(_) => break,
Err(e) => panic!("{}", e),
}
}
drop(server);
assert!(child.join().is_ok());
}
/// Inserts packets into the socket's incoming buffer and checks that they are
/// stored by sequence number and that a duplicate sequence number does not
/// replace the packet already buffered.
#[test]
fn test_sorted_buffer_insertion() {
    let addr = next_test_ip4();
    let mut sock = iotry!(UtpSocket::bind(addr));
    let mut pkt = Packet::new();

    // First packet: sequence number 1.
    pkt.set_seq_nr(1);
    assert!(sock.incoming_buffer.is_empty());
    sock.insert_into_buffer(pkt.clone());
    assert_eq!(sock.incoming_buffer.len(), 1);

    // Second packet: sequence number 2, timestamp 128.
    pkt.set_seq_nr(2);
    pkt.set_timestamp_microseconds(128);
    sock.insert_into_buffer(pkt.clone());
    assert_eq!(sock.incoming_buffer.len(), 2);
    assert_eq!(sock.incoming_buffer[1].seq_nr(), 2);
    assert_eq!(sock.incoming_buffer[1].timestamp_microseconds(), 128);

    // Third packet: sequence number 3, timestamp 256.
    pkt.set_seq_nr(3);
    pkt.set_timestamp_microseconds(256);
    sock.insert_into_buffer(pkt.clone());
    assert_eq!(sock.incoming_buffer.len(), 3);
    assert_eq!(sock.incoming_buffer[2].seq_nr(), 3);
    assert_eq!(sock.incoming_buffer[2].timestamp_microseconds(), 256);

    // A second packet with sequence number 2 must be ignored: the buffer
    // keeps its size and the slot keeps the original timestamp.
    pkt.set_seq_nr(2);
    pkt.set_timestamp_microseconds(456);
    sock.insert_into_buffer(pkt.clone());
    assert_eq!(sock.incoming_buffer.len(), 3);
    assert_eq!(sock.incoming_buffer[1].seq_nr(), 2);
    assert_eq!(sock.incoming_buffer[1].timestamp_microseconds(), 128);
}
// Sends the same data packet twice and checks that the receiver delivers its
// payload to the reader only once.
#[test]
fn test_duplicate_packet_handling() {
let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
// This outer `client` is only used for the initial-state assertions below;
// the spawned thread creates its own connected socket.
let client = iotry!(UtpSocket::bind(client_addr));
let mut server = iotry!(UtpSocket::bind(server_addr));
assert!(server.state == SocketState::New);
assert!(client.state == SocketState::New);
assert_eq!(client.sender_connection_id,
client.receiver_connection_id + 1);
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
assert!(client.state == SocketState::Connected);
// Hand-build one data packet with a three-byte payload.
let mut packet = Packet::new();
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_type(PacketType::Data);
packet.set_connection_id(client.sender_connection_id);
packet.set_seq_nr(client.seq_nr);
packet.set_ack_nr(client.ack_nr);
packet.payload = vec![1, 2, 3];
// Send it twice over the raw UDP socket; only the timestamp differs.
for _ in 0u8..2 {
packet.set_timestamp_microseconds(now_microseconds());
iotry!(client.socket.send_to(&packet.to_bytes()[..], server_addr));
}
// Account for the hand-sent packet in the client's sequence number.
client.seq_nr += 1;
// Read one datagram sent back by the server; its contents are not checked.
for _ in 0u8..1 {
let mut buf = [0; BUF_SIZE];
iotry!(client.socket.recv_from(&mut buf));
}
iotry!(client.close());
});
let mut buf = [0u8; BUF_SIZE];
iotry!(server.recv(&mut buf, false));
assert_eq!(server.receiver_connection_id,
server.sender_connection_id + 1);
assert!(server.state == SocketState::Connected);
let expected: Vec<u8> = vec![1, 2, 3];
let mut received: Vec<u8> = vec![];
// Read until a zero-length read is returned.
loop {
match server.recv_from(&mut buf) {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{:?}", e),
}
}
// The payload must appear exactly once.
assert_eq!(received.len(), expected.len());
assert_eq!(received, expected);
assert!(child.join().is_ok());
}
// Simulates loss of every other data packet and checks that the receiver
// still ends up with the complete data.
#[test]
fn test_correct_packet_loss() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = 1024 * 10;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let to_send = data.clone();
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
let chunks = to_send[..].chunks(BUF_SIZE);
let dst = client.connected_to;
for (index, chunk) in chunks.enumerate() {
let mut packet = Packet::new();
packet.set_seq_nr(client.seq_nr);
packet.set_ack_nr(client.ack_nr);
packet.set_connection_id(client.sender_connection_id);
packet.set_timestamp_microseconds(now_microseconds());
packet.payload = chunk.to_vec();
packet.set_type(PacketType::Data);
// Only even-indexed packets are actually put on the wire.
if index % 2 == 0 {
iotry!(client.socket.send_to(&packet.to_bytes()[..], dst));
}
// Every packet, sent or not, is recorded in the client's send window and
// counted in its current window and sequence number.
client.curr_window += packet.len() as u32;
client.send_window.push(packet);
client.seq_nr += 1;
}
iotry!(client.close());
});
let mut buf = [0; BUF_SIZE];
let mut received: Vec<u8> = vec![];
// Read until a zero-length read is returned.
loop {
match server.recv_from(&mut buf) {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{}", e),
}
}
assert_eq!(received.len(), data.len());
assert_eq!(received, data);
assert!(child.join().is_ok());
}
/// Receives 1024 bytes through a 512-byte read buffer and checks that no data
/// is lost when the caller's buffer is smaller than what was sent.
#[test]
fn test_tolerance_to_small_buffers() {
    let addr = next_test_ip4();
    let mut server = iotry!(UtpSocket::bind(addr));
    const LEN: usize = 1024;
    let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
    let outgoing = data.clone();

    let client_thread = thread::spawn(move || {
        let mut peer = iotry!(UtpSocket::connect(addr));
        iotry!(peer.send_to(&outgoing[..]));
        iotry!(peer.close());
    });

    // Drain the connection in 512-byte pieces until it closes or a
    // zero-length read is returned.
    let mut collected = Vec::new();
    while server.state != SocketState::Closed {
        let mut piece = [0; 512];
        let outcome = server.recv_from(&mut piece);
        match outcome {
            Err(e) => panic!("{}", e),
            Ok((0, _)) => break,
            Ok((len, _)) => collected.extend(piece[..len].iter().cloned()),
        }
    }

    assert_eq!(collected.len(), data.len());
    assert_eq!(collected, data);
    assert!(client_thread.join().is_ok());
}
// Intended to check that a transfer survives the sender's sequence number
// wrapping around past u16::MAX.
#[test]
fn test_sequence_number_rollover() {
let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = BUF_SIZE * 4;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let to_send = data.clone();
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::bind(client_addr));
// Move the sequence number close to u16::MAX so the transfer would wrap it.
client.seq_nr = ::std::u16::MAX - (to_send.len() / (BUF_SIZE * 2)) as u16;
// NOTE(review): this `let` shadows the socket whose `seq_nr` was just
// advanced, so the connection below appears to start from whatever sequence
// number `connect` chooses — confirm whether rollover is really exercised.
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&to_send[..]));
// A small value here is taken as evidence that the counter wrapped.
assert!(client.seq_nr < 50);
iotry!(client.close());
});
let mut buf = [0; BUF_SIZE];
let mut received: Vec<u8> = vec![];
// Read until a zero-length read is returned.
loop {
match server.recv_from(&mut buf) {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{}", e),
}
}
assert_eq!(received.len(), data.len());
assert_eq!(received, data);
assert!(child.join().is_ok());
}
/// Binding a socket and dropping it without ever using it must not panic.
#[test]
fn test_drop_unused_socket() {
    let addr = next_test_ip4();
    let unused = iotry!(UtpSocket::bind(addr));
    // Explicit drop: this is the operation under test.
    drop(unused);
}
/// A peer that answers the connection attempt with an empty datagram must
/// make `connect` fail with `ErrorKind::Other`.
#[test]
fn test_invalid_packet_on_connect() {
    use std::net::UdpSocket;

    let addr = next_test_ip4();
    let fake_server = iotry!(UdpSocket::bind(addr));

    // The fake server waits for one datagram and replies with zero bytes.
    let server_thread = thread::spawn(move || {
        let mut scratch = [0; BUF_SIZE];
        if let Ok((_len, client_addr)) = fake_server.recv_from(&mut scratch) {
            iotry!(fake_server.send_to(&[], client_addr));
        } else {
            panic!()
        }
    });

    match UtpSocket::connect(addr) {
        Ok(_) => panic!("Expected Err, got Ok"),
        Err(ref e) if e.kind() == ErrorKind::Other => (),
        Err(e) => panic!("Expected ErrorKind::Other, got {:?}", e),
    }
    assert!(server_thread.join().is_ok());
}
/// A peer that answers the connection attempt with a DATA packet must make
/// `connect` fail with `ErrorKind::TimedOut`.
#[test]
fn test_receive_unexpected_reply_type_on_connect() {
    use std::net::UdpSocket;

    let addr = next_test_ip4();
    let fake_server = iotry!(UdpSocket::bind(addr));

    // The fake server waits for one datagram and replies with a DATA packet.
    let server_thread = thread::spawn(move || {
        let mut scratch = [0; BUF_SIZE];
        let mut reply = Packet::new();
        reply.set_type(PacketType::Data);
        if let Ok((_len, client_addr)) = fake_server.recv_from(&mut scratch) {
            iotry!(fake_server.send_to(&reply.to_bytes()[..], client_addr));
        } else {
            panic!()
        }
    });

    match UtpSocket::connect(addr) {
        Ok(_) => panic!("Expected Err, got Ok"),
        Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
        Err(e) => panic!("Expected ErrorKind::TimedOut, got {:?}", e),
    }
    assert!(server_thread.join().is_ok());
}
// Checks that a SYN arriving from another socket while a connection is
// already established is answered with a RESET.
#[test]
fn test_receiving_syn_on_established_connection() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
// The server just keeps receiving until a zero-length read is returned.
let child = thread::spawn(move || {
let mut buf = [0; BUF_SIZE];
loop {
match server.recv_from(&mut buf) {
Ok((0, _src)) => break,
Ok(_) => (),
Err(e) => panic!("{:?}", e),
}
}
});
let mut client = iotry!(UtpSocket::connect(server_addr));
// Build a SYN that reuses the established connection's identifiers.
let mut packet = Packet::new();
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_type(PacketType::Syn);
packet.set_connection_id(client.sender_connection_id);
packet.set_seq_nr(client.seq_nr);
packet.set_ack_nr(client.ack_nr);
// Send it from a separate UDP socket, i.e. from a different source port.
let other_socket = iotry!(::std::net::UdpSocket::bind("0.0.0.0:0"));
iotry!(other_socket.send_to(&packet.to_bytes()[..], server_addr));
let mut buf = [0; BUF_SIZE];
// The reply delivered to that socket must be a RESET.
match other_socket.recv_from(&mut buf) {
Ok((len, _src)) => {
let reply = Packet::from_bytes(&buf[..len]).ok().unwrap();
assert_eq!(reply.get_type(), PacketType::Reset);
}
Err(e) => panic!("{:?}", e),
}
// Close the real connection so the server loop can finish.
iotry!(client.close());
assert!(child.join().is_ok());
}
// Checks that a RESET received on an established connection surfaces as an
// `ErrorKind::ConnectionReset` error from `recv_from`.
#[test]
fn test_receiving_reset_on_established_connection() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
let child = thread::spawn(move || {
let client = iotry!(UtpSocket::connect(server_addr));
// Hand-build a RESET carrying the connection's identifiers and send it over
// the raw UDP socket.
let mut packet = Packet::new();
packet.set_wnd_size(BUF_SIZE as u32);
packet.set_type(PacketType::Reset);
packet.set_connection_id(client.sender_connection_id);
packet.set_seq_nr(client.seq_nr);
packet.set_ack_nr(client.ack_nr);
iotry!(client.socket.send_to(&packet.to_bytes()[..], server_addr));
// Wait for one datagram on the raw socket; its contents are not checked.
let mut buf = [0; BUF_SIZE];
match client.socket.recv_from(&mut buf) {
Ok((_len, _src)) => (),
Err(e) => panic!("{:?}", e),
}
});
let mut buf = [0; BUF_SIZE];
loop {
match server.recv_from(&mut buf) {
Ok((0, _src)) => break,
Ok(_) => (),
// Success path: the test returns here without joining the child thread.
Err(ref e) if e.kind() == ErrorKind::ConnectionReset => return,
Err(e) => panic!("{:?}", e),
}
}
// Reaching this point means the connection ended without a reset error.
assert!(child.join().is_ok());
panic!("Should have received Reset");
}
// Sends a FIN from the server side while the client is still transferring
// data and checks that the server nevertheless receives all of the data.
#[test]
fn test_premature_fin() {
let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
let mut server = iotry!(UtpSocket::bind(server_addr));
const LEN: usize = BUF_SIZE * 4;
let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
let to_send = data.clone();
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&to_send[..]));
iotry!(client.close());
});
let mut buf = [0; BUF_SIZE];
iotry!(server.recv(&mut buf, false));
// Hand-build the premature FIN from the server's current identifiers.
let mut packet = Packet::new();
packet.set_connection_id(server.sender_connection_id);
packet.set_seq_nr(server.seq_nr);
packet.set_ack_nr(server.ack_nr);
packet.set_timestamp_microseconds(now_microseconds());
packet.set_type(PacketType::Fin);
// NOTE(review): `client_addr` comes from next_test_ip4(), but the client is
// created with `connect(server_addr)` and never binds to `client_addr`, so
// this FIN may not reach the client — confirm whether `server.connected_to`
// was intended.
iotry!(server.socket.send_to(&packet.to_bytes()[..], client_addr));
let mut received: Vec<u8> = vec![];
// Read until a zero-length read is returned.
loop {
match server.recv_from(&mut buf) {
Ok((0, _src)) => break,
Ok((len, _src)) => received.extend(buf[..len].to_vec()),
Err(e) => panic!("{}", e),
}
}
assert_eq!(received.len(), data.len());
assert_eq!(received, data);
assert!(child.join().is_ok());
}
/// Feeds delay samples spanning two one-minute windows into
/// `update_base_delay` and checks that one minimum per window is kept and
/// that `min_base_delay` reports the smallest of them.
#[test]
fn test_base_delay_calculation() {
    let one_minute_us = 60 * 10i64.pow(6);
    // Each entry is (timestamp, delay). The first four fall at the very start
    // of the timeline, the last three just over a minute later.
    let samples = vec![(0, 10),
                       (1, 8),
                       (2, 12),
                       (3, 7),
                       (one_minute_us + 1, 11),
                       (one_minute_us + 2, 19),
                       (one_minute_us + 3, 9)];

    let mut sock = UtpSocket::bind(next_test_ip4()).unwrap();
    for (sent_at, delay) in samples {
        let received_at = sent_at + delay;
        sock.update_base_delay(delay, received_at);
    }

    let expected = vec![7, 9];
    let actual = sock.base_delays.iter().cloned().collect::<Vec<_>>();
    assert_eq!(expected, actual);
    assert_eq!(sock.min_base_delay(), 7);
}
/// A bound socket must report the address it was bound to.
#[test]
fn test_local_addr() {
    let bound_to = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let sock = UtpSocket::bind(bound_to).unwrap();
    assert!(sock.local_addr().is_ok());
    assert_eq!(sock.local_addr().unwrap(), bound_to);
}
/// A bound listener must report the address it was bound to.
#[test]
fn test_listener_local_addr() {
    let bound_to = next_test_ip4().to_socket_addrs().unwrap().next().unwrap();
    let acceptor = UtpListener::bind(bound_to).unwrap();
    assert!(acceptor.local_addr().is_ok());
    assert_eq!(acceptor.local_addr().unwrap(), bound_to);
}
// Checks that `peer_addr` fails before a connection exists, reports the
// client's port while connected, and fails again after the socket is closed.
#[test]
fn test_peer_addr() {
use std::sync::mpsc::channel;
let addr = next_test_ip4();
let server_addr = addr.to_socket_addrs().unwrap().next().unwrap();
let mut server = UtpSocket::bind(server_addr).unwrap();
// Channel used to hand the client's local address to the main thread.
let (tx, rx) = channel();
// Not connected yet, so there is no peer.
assert!(server.peer_addr().is_err());
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
let mut buf = [0; 1024];
iotry!(tx.send(client.local_addr()));
// Keep receiving so the server's close can be processed.
iotry!(client.recv_from(&mut buf));
});
let mut buf = [0; 1024];
iotry!(server.recv(&mut buf, false));
assert!(server.peer_addr().is_ok());
let client_addr = rx.recv().unwrap().unwrap();
// Only the ports are compared, not the full addresses.
assert_eq!(server.peer_addr().unwrap().port(), client_addr.port());
iotry!(server.close());
assert!(server.peer_addr().is_err());
assert!(child.join().is_ok());
}
/// `take_address` must accept well-formed address strings and host/port
/// pairs and reject malformed or port-less inputs.
#[test]
fn test_take_address() {
    // "host:port" strings that must be accepted.
    for &text in ["0.0.0.0:0", ":::0"].iter() {
        assert!(take_address(text).is_ok());
    }

    // (host, port) pairs that must be accepted.
    for &pair in [("0.0.0.0", 0u16), ("::", 0u16), ("1.2.3.4", 5u16)].iter() {
        assert!(take_address(pair).is_ok());
    }

    // Inputs that must be rejected.
    let rejected = ["999.0.0.0:0",
                    "1.2.3.4:70000",
                    "",
                    "this is not an address",
                    "no.dns.resolution.com"];
    for &text in rejected.iter() {
        assert!(take_address(text).is_err());
    }
}
// Checks that a sender whose data packet is never acknowledged retransmits it
// `max_retransmission_retries` times and then reports a timeout.
#[test]
fn test_connection_loss_data() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
// Use a very small congestion timeout so the retries happen quickly.
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
// Force the client's state to Closed by hand.
// NOTE(review): presumably so the uTP socket takes no further part in the
// exchange (e.g. on drop) — confirm.
client.state = SocketState::Closed;
// From here on only the raw UDP socket is used, so nothing is acknowledged.
let socket = client.socket.try_clone().unwrap();
let mut buf = [0; BUF_SIZE];
// First datagram from the server; its contents are not checked.
iotry!(socket.recv_from(&mut buf));
// The next `attempts` datagrams must all be DATA packets.
for _ in 0..attempts {
match socket.recv_from(&mut buf) {
Ok((len, _src)) => {
assert_eq!(Packet::from_bytes(&buf[..len]).unwrap().get_type(),
PacketType::Data)
}
Err(e) => panic!("{}", e),
}
}
});
let mut buf = [0; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
// Send one byte that the silent client will never acknowledge.
iotry!(server.send_to(&[0]));
let mut buf = [0; BUF_SIZE];
// The receive must end with a timeout error.
match server.recv(&mut buf, false) {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
assert!(child.join().is_ok());
}
// Checks that closing a connection whose FIN is never acknowledged
// retransmits the FIN `max_retransmission_retries` times and then reports a
// timeout.
#[test]
fn test_connection_loss_fin() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
// Use a very small congestion timeout so the retries happen quickly.
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
// Force the client's state to Closed by hand.
// NOTE(review): presumably so the uTP socket takes no further part in the
// exchange (e.g. on drop) — confirm.
client.state = SocketState::Closed;
// From here on only the raw UDP socket is used, so nothing is acknowledged.
let socket = client.socket.try_clone().unwrap();
let mut buf = [0; BUF_SIZE];
// First datagram from the server; its contents are not checked.
iotry!(socket.recv_from(&mut buf));
// The next `attempts` datagrams must all be FIN packets.
for _ in 0..attempts {
match socket.recv_from(&mut buf) {
Ok((len, _src)) => {
assert_eq!(Packet::from_bytes(&buf[..len]).unwrap().get_type(),
PacketType::Fin)
}
Err(e) => panic!("{}", e),
}
}
});
let mut buf = [0; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
// Closing must end with a timeout error because the FIN is never acknowledged.
match server.close() {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
assert!(child.join().is_ok());
}
// Checks that a socket waiting for data from a silent peer keeps sending
// STATE packets and finally reports a timeout.
#[test]
fn test_connection_loss_waiting() {
let server_addr = next_test_ip4();
let mut server = iotry!(UtpSocket::bind(server_addr));
// Use a very small congestion timeout so the timeouts happen quickly.
server.congestion_timeout = 1;
let attempts = server.max_retransmission_retries;
let child = thread::spawn(move || {
let mut client = iotry!(UtpSocket::connect(server_addr));
iotry!(client.send_to(&[0]));
// Force the client's state to Closed by hand.
// NOTE(review): presumably so the uTP socket takes no further part in the
// exchange (e.g. on drop) — confirm.
client.state = SocketState::Closed;
// From here on only the raw UDP socket is used.
let socket = client.socket.try_clone().unwrap();
let seq_nr = client.seq_nr;
let mut buf = [0; BUF_SIZE];
// Expect 3 * attempts STATE packets, each acknowledging the last packet the
// client sent (its current sequence number minus one).
for _ in 0..(3 * attempts) {
match socket.recv_from(&mut buf) {
Ok((len, _src)) => {
let packet = iotry!(Packet::from_bytes(&buf[..len]));
assert_eq!(packet.get_type(), PacketType::State);
assert_eq!(packet.ack_nr(), seq_nr - 1);
}
Err(e) => panic!("{}", e),
}
}
});
let mut buf = [0; BUF_SIZE];
iotry!(server.recv_from(&mut buf));
let mut buf = [0; BUF_SIZE];
// The second receive never gets data and must end with a timeout error.
match server.recv_from(&mut buf) {
Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
x => panic!("Expected Err(TimedOut), got {:?}", x),
}
assert!(child.join().is_ok());
}
// Number of nodes created by `test_network`.
const NETWORK_NODE_COUNT: usize = 20;
// Message count used by the exchange functions of the network tests.
const NETWORK_MSG_COUNT: usize = 5;
// Builds a fully connected network of NODE_COUNT listeners on loopback and
// runs `exchange` on both ends of every connection.
//
// Each node connects to every node that comes after it in creation order and
// accepts connections from every node that comes before it, so each pair of
// nodes shares exactly one connection. Panics if any thread fails.
fn test_network(exchange: fn(&mut UtpSocket) -> ()) {
use std::net::SocketAddr;
use std::thread::{JoinHandle, spawn};
const NODE_COUNT: usize = NETWORK_NODE_COUNT;
// A participant in the network; owns the listener it accepts peers on.
struct Node {
listener: UtpListener,
}
impl Node {
// Binds a listener on loopback with an OS-assigned port.
fn new() -> Node {
Node { listener: iotry!(UtpListener::bind("127.0.0.1:0")) }
}
// Connects to every address in `peer_addrs`, accepts the remaining
// connections, and runs `exchange` on each resulting socket in its own thread.
fn run(&mut self, exchange: fn(&mut UtpSocket) -> (), peer_addrs: Vec<SocketAddr>) {
let connect_cnt = peer_addrs.len();
// Outgoing connections are made from a separate thread so accepting can
// proceed at the same time.
let connect_join_handle = spawn(move || {
let mut send_jhs = Vec::<JoinHandle<()>>::new();
for peer_addr in peer_addrs {
send_jhs.push(spawn(move || {
let mut socket = iotry!(UtpSocket::connect(peer_addr));
exchange(&mut socket);
}));
}
for jh in send_jhs {
iotry!(jh.join());
}
});
let mut recv_jhs = Vec::<JoinHandle<()>>::new();
// Every other node is either connected to or accepted from, so the number of
// connections to accept is the number of peers minus the outgoing ones.
for _ in 0..NODE_COUNT-1-connect_cnt {
let mut socket = iotry!(self.listener.accept()).0;
recv_jhs.push(spawn(move || {
exchange(&mut socket);
}));
}
for jh in recv_jhs {
iotry!(jh.join());
}
iotry!(connect_join_handle.join());
}
}
let mut nodes = Vec::<Node>::new();
for _ in 0..NODE_COUNT {
nodes.push(Node::new());
}
// Collect all listening addresses before the nodes are moved into threads.
let listening_addrs = nodes.iter()
.map(|n| iotry!(n.listener.local_addr()))
.collect::<Vec<_>>();
let mut join_handles = Vec::<JoinHandle<()>>::new();
let mut ni: usize = 0;
for mut node in nodes {
// Node `ni` connects only to nodes with a higher index.
let mut addrs = Vec::<SocketAddr>::new();
for ai in 0..listening_addrs.len() {
if ai <= ni { continue }
addrs.push(listening_addrs[ai].clone());
}
join_handles.push(spawn(move || {
node.run(exchange, addrs);
}));
ni += 1;
}
for handle in join_handles {
iotry!(handle.join());
}
}
// Runs the full-mesh network test 100 times with a strictly alternating
// send/receive exchange on every connection and no read timeout configured.
#[test]
fn test_network_no_timeout() {
    static MSG_COUNT: usize = NETWORK_MSG_COUNT;

    // Produces the 10-byte payload for message `i`: byte `j` holds `i + j`.
    fn make_buf(i: usize) -> [u8; 10] {
        let mut payload = [0; 10];
        for (offset, byte) in payload.iter_mut().enumerate() {
            *byte = (i + offset) as u8;
        }
        payload
    }

    // Sends message `step`, then blocks until the peer's message `step` has
    // been received, for every step in `0..MSG_COUNT`.
    fn sequential_exchange(socket: &mut UtpSocket) {
        // Port numbers are only used to make panic messages easier to trace.
        let from = match socket.socket.local_addr() {
            Ok(addr) => addr.port(),
            Err(_) => 0,
        };
        let to = socket.connected_to.port();

        for step in 0..MSG_COUNT {
            let outgoing = make_buf(step);
            assert_eq!(iotry!(socket.send_to(&outgoing)), outgoing.len());

            let mut incoming = [0; 10];
            let received = match socket.recv_from(&mut incoming) {
                Ok((cnt, _)) => cnt,
                Err(err) => panic!("Recv error {:?}; from {:?} to {:?}", err, from, to),
            };

            // An empty read on a socket that is no longer connected gets a
            // more specific message than the length assertion below.
            if received == 0 && socket.state != SocketState::Connected {
                panic!("socket is in an invalid state \"{:?}\" from {:?} to {:?}",
                       socket.state, from, to);
            }
            assert_eq!(received, 10);

            let expected = make_buf(step);
            if incoming != expected {
                panic!("expected {:?} but received {:?} in recv step {}",
                       expected,
                       incoming,
                       step);
            }
        }
    }

    for round in 0..100 {
        println!("------ Testing Network iteration {}", round);
        test_network(sequential_exchange);
    }
}
// Runs the full-mesh network test 100 times with a read timeout set on every
// socket. Sends and receives are interleaved and a `TimedOut` error on either
// side is simply retried on the next pass of the loop.
#[test]
fn test_network_with_timeout() {
    static MSG_COUNT: usize = NETWORK_MSG_COUNT;

    // Produces the 10-byte payload for message `i`: byte `j` holds `i + j`.
    fn make_buf(i: usize) -> [u8; 10] {
        let mut payload = [0; 10];
        for (offset, byte) in payload.iter_mut().enumerate() {
            *byte = (i + offset) as u8;
        }
        payload
    }

    // Keeps attempting one send and one receive per pass until `MSG_COUNT`
    // messages have gone in each direction.
    fn timeout_exchange(socket: &mut UtpSocket) {
        socket.set_read_timeout(Some(50));

        // Port numbers are only used to make panic messages easier to trace.
        let from = match socket.socket.local_addr() {
            Ok(addr) => addr.port(),
            Err(_) => 0,
        };
        let to = socket.connected_to.port();

        let mut sent = 0;
        let mut received = 0;
        while sent < MSG_COUNT || received < MSG_COUNT {
            if sent < MSG_COUNT {
                let outgoing = make_buf(sent);
                match socket.send_to(&outgoing) {
                    Ok(cnt) => {
                        assert_eq!(cnt, outgoing.len());
                        sent += 1;
                    }
                    Err(e) => {
                        // A timed-out send is retried on the next pass.
                        if e.kind() != ErrorKind::TimedOut {
                            panic!("{:?}", e);
                        }
                    }
                }
            }

            if received < MSG_COUNT {
                let expected = make_buf(received);
                let mut incoming = [0; 10];
                match socket.recv_from(&mut incoming) {
                    Ok((0, _)) => {
                        // An empty read is tolerated only while connected.
                        if socket.state != SocketState::Connected {
                            panic!("socket is in an invalid state \"{:?}\" \
                                    from {:?} to {:?} in receive #{}",
                                   socket.state, from, to, received);
                        }
                    }
                    Ok((cnt, _)) => {
                        assert_eq!(cnt, expected.len());
                        assert_eq!(incoming, expected);
                        received += 1;
                    }
                    Err(e) => {
                        // A timed-out receive is retried on the next pass.
                        if e.kind() != ErrorKind::TimedOut {
                            panic!("{:?} recv_cnt={} send_cnt={}", e, received, sent);
                        }
                    }
                }
            }
        }
    }

    for round in 0..100 {
        println!("------ Testing Network iteration {}", round);
        test_network(timeout_exchange);
    }
}
// One-way transfer: the connecting side sends a 10-byte payload and the
// accepting side must receive exactly those bytes.
#[test]
fn test_send_client_to_server() {
    static TX_BUF: [u8; 10] = [0,1,2,3,4,5,6,7,8,9];

    let listener = iotry!(UtpListener::bind("127.0.0.1:0"));
    let server_addr = iotry!(listener.local_addr());

    let client_thread = thread::spawn(move || {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        let sent = iotry!(client.send_to(&TX_BUF));
        assert_eq!(sent, TX_BUF.len());
    });

    let mut server = iotry!(listener.accept()).0;
    let mut received = [0; 10];
    iotry!(server.recv_from(&mut received));
    assert_eq!(received, TX_BUF);

    assert!(client_thread.join().is_ok());
}
// One-way transfer in the opposite direction: the accepting side sends a
// 10-byte payload, flushes, and the connecting side must receive those bytes.
#[test]
fn test_send_server_to_client() {
    static TX_BUF: [u8; 10] = [0,1,2,3,4,5,6,7,8,9];

    let listener = iotry!(UtpListener::bind("127.0.0.1:0"));
    let server_addr = iotry!(listener.local_addr());

    let client_thread = thread::spawn(move || {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        let mut received = [0; 10];
        iotry!(client.recv_from(&mut received));
        assert_eq!(received, TX_BUF);
    });

    let mut server = iotry!(listener.accept()).0;
    let sent = iotry!(server.send_to(&TX_BUF));
    assert_eq!(sent, TX_BUF.len());
    assert!(server.flush().is_ok());

    assert!(client_thread.join().is_ok());
}
// Two-way transfer over uTP: both ends send the same 10-byte payload first and
// then each must receive the other's copy.
#[test]
fn test_data_exchange_utp() {
    static TX_BUF: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];

    let listener = iotry!(UtpListener::bind("127.0.0.1:0"));
    let server_addr = iotry!(listener.local_addr());

    let client_thread = thread::spawn(move || {
        let mut client = iotry!(UtpSocket::connect(server_addr));
        let sent = iotry!(client.send_to(&TX_BUF));
        assert_eq!(sent, TX_BUF.len());

        let mut received = [0; 10];
        iotry!(client.recv_from(&mut received));
        assert_eq!(received, TX_BUF);
    });

    let mut server = iotry!(listener.accept()).0;
    let sent = iotry!(server.send_to(&TX_BUF));
    assert_eq!(sent, TX_BUF.len());

    let mut received = [0; 10];
    iotry!(server.recv_from(&mut received));
    assert_eq!(received, TX_BUF);

    // The flush result is deliberately ignored here.
    let _ = server.flush();
    assert!(client_thread.join().is_ok());
}
// Two-way transfer over plain TCP, mirroring the uTP data-exchange test: both
// ends send the same 10-byte payload first and then each must receive the
// other's copy.
//
// `Read::read` and `Write::write` are allowed to transfer fewer bytes than
// requested, so the test uses `read_exact` / `write_all`. Otherwise a short
// read would leave part of the buffer zeroed and fail the comparison even
// though nothing is wrong with the connection.
#[test]
fn test_data_exchange_tcp() {
    use std::net::{TcpListener, TcpStream};
    use std::io::{Read, Write};

    static TX_BUF: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];

    let listener = iotry!(TcpListener::bind("127.0.0.1:0"));
    let server_addr = iotry!(listener.local_addr());

    let client_t = thread::spawn(move || {
        let mut client = iotry!(TcpStream::connect(server_addr));
        iotry!(client.write_all(&TX_BUF));

        let mut buf = [0; 10];
        // Blocks until all 10 bytes have arrived or the stream fails.
        iotry!(client.read_exact(&mut buf));
        assert_eq!(buf, TX_BUF);
    });

    let mut server = iotry!(listener.accept()).0;
    iotry!(server.write_all(&TX_BUF));

    let mut buf = [0; 10];
    // Blocks until all 10 bytes have arrived or the stream fails.
    iotry!(server.read_exact(&mut buf));
    assert_eq!(buf, TX_BUF);

    assert!(client_t.join().is_ok());
}
}