async_std_utp/
socket.rs

1use async_std::{
2    io,
3    net::{SocketAddr, ToSocketAddrs, UdpSocket},
4    task,
5};
6use futures::FutureExt;
7use futures::{future::BoxFuture, ready};
8use log::debug;
9use std::collections::VecDeque;
10use std::io::{ErrorKind, Result};
11use std::task::Poll;
12use std::time::{Duration, Instant};
13use std::{
14    cmp::{max, min},
15    sync::Arc,
16};
17
18use crate::error::SocketError;
19use crate::packet::*;
20use crate::time::*;
21use crate::util::*;
22
// For simplicity's sake, assume no packet ever exceeds the Ethernet maximum
// transfer unit of 1500 bytes.
pub(crate) const BUF_SIZE: usize = 1500;
const GAIN: f64 = 1.0;
const ALLOWED_INCREASE: u32 = 1;
const TARGET: f64 = 100.0 * 1000.0; // 100 milliseconds, expressed in microseconds
const MSS: u32 = 1400; // maximum segment size, in bytes
const MIN_CWND: u32 = 2; // minimum congestion window, in multiples of MSS
const INIT_CWND: u32 = 2; // initial congestion window, in multiples of MSS
const INITIAL_CONGESTION_TIMEOUT: u64 = 1000; // milliseconds (one second)
const MIN_CONGESTION_TIMEOUT: u64 = 500; // milliseconds
const MAX_CONGESTION_TIMEOUT: u64 = 60 * 1000; // milliseconds (one minute)
const BASE_HISTORY: usize = 10; // number of base delay samples kept
const MAX_SYN_RETRIES: u32 = 5; // SYN attempts before `connect` gives up
const MAX_RETRANSMISSION_RETRIES: u32 = 5; // default for `max_retransmission_retries`
const WINDOW_SIZE: u32 = 1 << 20; // local receive window, in bytes (1 MiB)

// Maximum time, in microseconds, to wait for incoming packets while the send
// window is full before sending anyway.
const PRE_SEND_TIMEOUT: u32 = 500 * 1000;
42
43// Maximum age of base delay sample (60 seconds)
44const MAX_BASE_DELAY_AGE: Delay = Delay(60_000_000);
45
/// Lifecycle state of a uTP socket.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SocketState {
    /// Freshly created; no connection attempt yet. Receives wait without a timeout.
    New,
    /// The connection is established.
    Connected,
    /// `connect` has sent a SYN and is waiting for the reply.
    SynSent,
    /// `close` has sent a FIN and is waiting for the connection to finish.
    FinSent,
    /// A reset packet was received; already buffered data can still be read.
    ResetReceived,
    /// The connection is closed; reads return 0 bytes and sends fail.
    Closed,
}
55
56#[derive(Debug, Clone)]
57struct DelayDifferenceSample {
58    received_at: Timestamp,
59    difference: Delay,
60}
61
62/// Returns the first valid address in a `ToSocketAddrs` iterator.
63async fn take_address<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
64    addr.to_socket_addrs()
65        .await
66        .and_then(|mut it| it.next().ok_or_else(|| SocketError::InvalidAddress.into()))
67}
68
69/// A structure that represents a uTP (Micro Transport Protocol) connection between a local socket
70/// and a remote socket.
71///
72/// The socket will be closed when the value is dropped (either explicitly or when it goes out of
73/// scope).
74///
75/// The default maximum retransmission retries is 5, which translates to about 16 seconds. It can be
76/// changed by assigning the desired maximum retransmission retries to a socket's
77/// `max_retransmission_retries` field. Notice that the initial congestion timeout is 500 ms and
78/// doubles with each timeout.
79///
80/// # Examples
81///
82/// ```no_run
83/// # fn main() { async_std::task::block_on(async {
84/// use async_std_utp::UtpSocket;
85///
86/// let mut socket = UtpSocket::bind("127.0.0.1:1234").await.expect("Error binding socket");
87///
88/// let mut buf = vec![0; 1000];
89/// let (amt, _src) = socket.recv_from(&mut buf).await.expect("Error receiving");
90///
91/// let mut buf = &mut buf[..amt];
92/// buf.reverse();
93/// let _ = socket.send_to(buf).await.expect("Error sending");
94///
95/// // Close the socket. You can either call `close` on the socket,
96/// // explicitly drop it or just let it go out of scope.
97/// socket.close().await;
98/// }); }
99/// ```
100#[derive(Debug)]
101pub struct UtpSocket {
102    /// The wrapped UDP socket
103    socket: UdpSocket,
104
105    /// Remote peer
106    connected_to: SocketAddr,
107
108    /// Sender connection identifier
109    sender_connection_id: u16,
110
111    /// Receiver connection identifier
112    receiver_connection_id: u16,
113
114    /// Sequence number for the next packet
115    seq_nr: u16,
116
117    /// Sequence number of the latest acknowledged packet sent by the remote peer
118    ack_nr: u16,
119
120    /// Socket state
121    state: SocketState,
122
123    /// Received but not acknowledged packets
124    incoming_buffer: Vec<Packet>,
125
126    /// Sent but not yet acknowledged packets
127    send_window: Vec<Packet>,
128
129    /// Packets not yet sent
130    unsent_queue: VecDeque<Packet>,
131
132    /// How many ACKs did the socket receive for packet with sequence number equal to `ack_nr`
133    duplicate_ack_count: u32,
134
135    /// Sequence number of the latest packet the remote peer acknowledged
136    last_acked: u16,
137
138    /// Timestamp of the latest packet the remote peer acknowledged
139    last_acked_timestamp: Timestamp,
140
141    /// Sequence number of the last packet removed from the incoming buffer
142    last_dropped: u16,
143
144    /// Round-trip time to remote peer
145    rtt: i32,
146
147    /// Variance of the round-trip time to the remote peer
148    rtt_variance: i32,
149
150    /// Data from the latest packet not yet returned in `recv_from`
151    pending_data: Vec<u8>,
152
153    /// Bytes in flight
154    curr_window: u32,
155
156    /// Window size of the remote peer
157    remote_wnd_size: u32,
158
159    /// Rolling window of packet delay to remote peer
160    base_delays: VecDeque<Delay>,
161
162    /// Rolling window of the difference between sending a packet and receiving its acknowledgement
163    current_delays: Vec<DelayDifferenceSample>,
164
165    /// Difference between timestamp of the latest packet received and time of reception
166    their_delay: Delay,
167
168    /// Start of the current minute for sampling purposes
169    last_rollover: Timestamp,
170
171    /// Current congestion timeout in milliseconds
172    congestion_timeout: u64,
173
174    /// Congestion window in bytes
175    cwnd: u32,
176
177    /// Maximum retransmission retries
178    pub max_retransmission_retries: u32,
179}
180
181impl UtpSocket {
182    /// Creates a new UTP socket from the given UDP socket and the remote peer's address.
183    ///
184    /// The connection identifier of the resulting socket is randomly generated.
185    fn from_raw_parts(s: UdpSocket, src: SocketAddr) -> UtpSocket {
186        let (receiver_id, sender_id) = generate_sequential_identifiers();
187
188        UtpSocket {
189            socket: s,
190            connected_to: src,
191            receiver_connection_id: receiver_id,
192            sender_connection_id: sender_id,
193            seq_nr: 1,
194            ack_nr: 0,
195            state: SocketState::New,
196            incoming_buffer: Vec::new(),
197            send_window: Vec::new(),
198            unsent_queue: VecDeque::new(),
199            duplicate_ack_count: 0,
200            last_acked: 0,
201            last_acked_timestamp: Timestamp::default(),
202            last_dropped: 0,
203            rtt: 0,
204            rtt_variance: 0,
205            pending_data: Vec::new(),
206            curr_window: 0,
207            remote_wnd_size: 0,
208            current_delays: Vec::new(),
209            base_delays: VecDeque::with_capacity(BASE_HISTORY),
210            their_delay: Delay::default(),
211            last_rollover: Timestamp::default(),
212            congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
213            cwnd: INIT_CWND * MSS,
214            max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
215        }
216    }
217
218    /// Creates a new UTP socket from the given address.
219    ///
220    /// The address type can be any implementer of the `ToSocketAddr` trait. See its documentation
221    /// for concrete examples.
222    ///
223    /// If more than one valid address is specified, only the first will be used.
224    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpSocket> {
225        let addr = take_address(addr).await?;
226        let socket = UdpSocket::bind(addr).await?;
227        Ok(UtpSocket::from_raw_parts(socket, addr))
228    }
229
230    /// Returns the socket address that this socket was created from.
231    pub fn local_addr(&self) -> Result<SocketAddr> {
232        self.socket.local_addr()
233    }
234
235    /// Returns the socket address of the remote peer of this UTP connection.
236    pub fn peer_addr(&self) -> Result<SocketAddr> {
237        if self.state == SocketState::Connected || self.state == SocketState::FinSent {
238            Ok(self.connected_to)
239        } else {
240            Err(SocketError::NotConnected.into())
241        }
242    }
243
244    /// Opens a connection to a remote host by hostname or IP address.
245    ///
246    /// The address type can be any implementer of the `ToSocketAddr` trait. See its documentation
247    /// for concrete examples.
248    ///
249    /// If more than one valid address is specified, only the first will be used.
250    pub async fn connect<A: ToSocketAddrs>(other: A) -> Result<UtpSocket> {
251        let addr = take_address(other).await?;
252        let my_addr = match addr {
253            SocketAddr::V4(_) => "0.0.0.0:0",
254            SocketAddr::V6(_) => "[::]:0",
255        };
256        let mut socket = UtpSocket::bind(my_addr).await?;
257        socket.connected_to = addr;
258
259        let mut packet = Packet::new();
260        packet.set_type(PacketType::Syn);
261        packet.set_connection_id(socket.receiver_connection_id);
262        packet.set_seq_nr(socket.seq_nr);
263
264        let mut len = 0;
265        let mut buf = vec![0u8; BUF_SIZE];
266
267        let mut syn_timeout = Duration::from_millis(socket.congestion_timeout);
268        for _ in 0..MAX_SYN_RETRIES {
269            packet.set_timestamp(now_microseconds());
270
271            // Send packet
272            debug!("Connecting to {}", socket.connected_to);
273            socket
274                .socket
275                .send_to(packet.as_ref(), socket.connected_to)
276                .await?;
277            socket.state = SocketState::SynSent;
278            debug!("sent {:?}", packet);
279
280            // Validate response
281            match io::timeout(syn_timeout, socket.socket.recv_from(&mut buf)).await {
282                Ok((read, src)) => {
283                    socket.connected_to = src;
284                    len = read;
285                    break;
286                }
287                Err(ref e)
288                    if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) =>
289                {
290                    debug!("Timed out, retrying");
291                    syn_timeout *= 2;
292                    continue;
293                }
294                Err(e) => return Err(e),
295            };
296        }
297
298        let addr = socket.connected_to;
299        let packet = Packet::try_from(&buf[..len])?;
300        debug!("received {:?}", packet);
301        socket.handle_packet(&packet, addr).await?;
302
303        debug!("connected to: {}", socket.connected_to);
304
305        Ok(socket)
306    }
307
308    /// Gracefully closes connection to peer.
309    ///
310    /// This method allows both peers to receive all packets still in
311    /// flight.
312    pub async fn close(&mut self) -> Result<()> {
313        // Nothing to do if the socket's already closed or not connected
314        if self.state == SocketState::Closed
315            || self.state == SocketState::New
316            || self.state == SocketState::SynSent
317        {
318            return Ok(());
319        }
320
321        // Flush unsent and unacknowledged packets
322        self.flush().await?;
323
324        let mut packet = Packet::new();
325        packet.set_connection_id(self.sender_connection_id);
326        packet.set_seq_nr(self.seq_nr);
327        packet.set_ack_nr(self.ack_nr);
328        packet.set_timestamp(now_microseconds());
329        packet.set_type(PacketType::Fin);
330
331        // Send FIN
332        self.socket
333            .send_to(packet.as_ref(), self.connected_to)
334            .await?;
335        debug!("sent {:?}", packet);
336        self.state = SocketState::FinSent;
337
338        // Receive JAKE
339        let mut buf = vec![0u8; BUF_SIZE];
340        while self.state != SocketState::Closed {
341            self.recv(&mut buf).await?;
342        }
343
344        Ok(())
345    }
346
347    /// Receives data from socket.
348    ///
349    /// On success, returns the number of bytes read and the sender's address.
350    /// Returns 0 bytes read after receiving a FIN packet when the remaining
351    /// in-flight packets are consumed.
352    pub async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
353        let read = self.flush_incoming_buffer(buf);
354
355        if read > 0 {
356            return Ok((read, self.connected_to));
357        }
358
359        // If the socket received a reset packet and all data has been flushed, then it can't
360        // receive anything else
361        if self.state == SocketState::ResetReceived {
362            return Err(SocketError::ConnectionReset.into());
363        }
364
365        loop {
366            // A closed socket with no pending data can only "read" 0 new bytes.
367            if self.state == SocketState::Closed {
368                return Ok((0, self.connected_to));
369            }
370
371            match self.recv(buf).await {
372                Ok((0, _src)) => continue,
373                Ok(x) => return Ok(x),
374                Err(e) => return Err(e),
375            }
376        }
377    }
378
379    async fn recv(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
380        let mut b = vec![0; BUF_SIZE + HEADER_SIZE];
381        let start = Instant::now();
382        let (read, src);
383        let mut retries = 0;
384
385        // Try to receive a packet and handle timeouts
386        loop {
387            // Abort loop if the current try exceeds the maximum number of retransmission retries.
388            if retries >= self.max_retransmission_retries {
389                self.state = SocketState::Closed;
390                return Err(SocketError::ConnectionTimedOut.into());
391            }
392
393            let timeout = if self.state != SocketState::New {
394                debug!("setting read timeout of {} ms", self.congestion_timeout);
395                Some(Duration::from_millis(self.congestion_timeout))
396            } else {
397                None
398            };
399
400            let response = match timeout {
401                Some(timeout) => io::timeout(timeout, self.socket.recv_from(&mut b)).await,
402                None => self.socket.recv_from(&mut b).await,
403            };
404
405            match response {
406                Ok((r, s)) => {
407                    read = r;
408                    src = s;
409                    break;
410                }
411                Err(ref e)
412                    if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) =>
413                {
414                    debug!("recv_from timed out");
415                    self.handle_receive_timeout().await?;
416                }
417                Err(e) => return Err(e),
418            };
419
420            let elapsed = start.elapsed();
421            let elapsed_ms = elapsed.as_secs() * 1000 + elapsed.subsec_millis() as u64;
422            debug!("{} ms elapsed", elapsed_ms);
423            retries += 1;
424        }
425
426        // Decode received data into a packet
427        let packet = match Packet::try_from(&b[..read]) {
428            Ok(packet) => packet,
429            Err(e) => {
430                debug!("{}", e);
431                debug!("Ignoring invalid packet");
432                return Ok((0, self.connected_to));
433            }
434        };
435        debug!("received {:?}", packet);
436
437        // Process packet, including sending a reply if necessary
438        if let Some(mut pkt) = self.handle_packet(&packet, src).await? {
439            pkt.set_wnd_size(WINDOW_SIZE);
440            self.socket.send_to(pkt.as_ref(), src).await?;
441            debug!("sent {:?}", pkt);
442        }
443
444        // Insert data packet into the incoming buffer if it isn't a duplicate of a previously
445        // discarded packet
446        if packet.get_type() == PacketType::Data
447            && packet.seq_nr().wrapping_sub(self.last_dropped) > 0
448        {
449            self.insert_into_buffer(packet);
450        }
451
452        // Flush incoming buffer if possible
453        let read = self.flush_incoming_buffer(buf);
454
455        Ok((read, src))
456    }
457
458    async fn handle_receive_timeout(&mut self) -> Result<()> {
459        self.congestion_timeout *= 2;
460        self.cwnd = MSS;
461
462        // There are three possible cases here:
463        //
464        // - If the socket is sending and waiting for acknowledgements (the send window is
465        //   not empty), resend the first unacknowledged packet;
466        //
467        // - If the socket is not sending and it hasn't sent a FIN yet, then it's waiting
468        //   for incoming packets: send a fast resend request;
469        //
470        // - If the socket sent a FIN previously, resend it.
471        debug!(
472            "self.send_window: {:?}",
473            self.send_window
474                .iter()
475                .map(Packet::seq_nr)
476                .collect::<Vec<u16>>()
477        );
478
479        if self.send_window.is_empty() {
480            // The socket is trying to close, all sent packets were acknowledged, and it has
481            // already sent a FIN: resend it.
482            if self.state == SocketState::FinSent {
483                let mut packet = Packet::new();
484                packet.set_connection_id(self.sender_connection_id);
485                packet.set_seq_nr(self.seq_nr);
486                packet.set_ack_nr(self.ack_nr);
487                packet.set_timestamp(now_microseconds());
488                packet.set_type(PacketType::Fin);
489
490                // Send FIN
491                self.socket
492                    .send_to(packet.as_ref(), self.connected_to)
493                    .await?;
494                debug!("resent FIN: {:?}", packet);
495            } else if self.state != SocketState::New {
496                // The socket is waiting for incoming packets but the remote peer is silent:
497                // send a fast resend request.
498                debug!("sending fast resend request");
499                self.send_fast_resend_request().await;
500            }
501        } else {
502            // The socket is sending data packets but there is no reply from the remote
503            // peer: resend the first unacknowledged packet with the current timestamp.
504            let packet = &mut self.send_window[0];
505            packet.set_timestamp(now_microseconds());
506            self.socket
507                .send_to(packet.as_ref(), self.connected_to)
508                .await?;
509            debug!("resent {:?}", packet);
510        }
511
512        Ok(())
513    }
514
515    fn prepare_reply(&self, original: &Packet, t: PacketType) -> Packet {
516        let mut resp = Packet::new();
517        resp.set_type(t);
518        let self_t_micro = now_microseconds();
519        let other_t_micro = original.timestamp();
520        let time_difference: Delay = abs_diff(self_t_micro, other_t_micro);
521        resp.set_timestamp(self_t_micro);
522        resp.set_timestamp_difference(time_difference);
523        resp.set_connection_id(self.sender_connection_id);
524        resp.set_seq_nr(self.seq_nr);
525        resp.set_ack_nr(self.ack_nr);
526
527        resp
528    }
529
530    /// Removes a packet in the incoming buffer and updates the current acknowledgement number.
531    fn advance_incoming_buffer(&mut self) -> Option<Packet> {
532        if !self.incoming_buffer.is_empty() {
533            let packet = self.incoming_buffer.remove(0);
534            debug!("Removed packet from incoming buffer: {:?}", packet);
535            self.ack_nr = packet.seq_nr();
536            self.last_dropped = self.ack_nr;
537            Some(packet)
538        } else {
539            None
540        }
541    }
542
543    /// Discards sequential, ordered packets in incoming buffer, starting from
544    /// the most recently acknowledged to the most recent, as long as there are
545    /// no missing packets. The discarded packets' payload is written to the
546    /// slice `buf`, starting in position `start`.
547    /// Returns the last written index.
548    fn flush_incoming_buffer(&mut self, buf: &mut [u8]) -> usize {
549        fn unsafe_copy(src: &[u8], dst: &mut [u8]) -> usize {
550            let max_len = min(src.len(), dst.len());
551            unsafe {
552                use std::ptr::copy;
553                copy(src.as_ptr(), dst.as_mut_ptr(), max_len);
554            }
555            max_len
556        }
557
558        // Return pending data from a partially read packet
559        if !self.pending_data.is_empty() {
560            let flushed = unsafe_copy(&self.pending_data[..], buf);
561
562            if flushed == self.pending_data.len() {
563                self.pending_data.clear();
564                self.advance_incoming_buffer();
565            } else {
566                self.pending_data = self.pending_data[flushed..].to_vec();
567            }
568
569            return flushed;
570        }
571
572        if !self.incoming_buffer.is_empty()
573            && (self.ack_nr == self.incoming_buffer[0].seq_nr()
574                || self.ack_nr + 1 == self.incoming_buffer[0].seq_nr())
575        {
576            let flushed = unsafe_copy(&self.incoming_buffer[0].payload(), buf);
577
578            if flushed == self.incoming_buffer[0].payload().len() {
579                self.advance_incoming_buffer();
580            } else {
581                self.pending_data = self.incoming_buffer[0].payload()[flushed..].to_vec();
582            }
583
584            return flushed;
585        }
586
587        0
588    }
589
590    /// Sends data on the socket to the remote peer. On success, returns the number of bytes
591    /// written.
592    //
593    // # Implementation details
594    //
595    // This method inserts packets into the send buffer and keeps trying to
596    // advance the send window until an ACK corresponding to the last packet is
597    // received.
598    //
599    // Note that the buffer passed to `send_to` might exceed the maximum packet
600    // size, which will result in the data being split over several packets.
601    pub async fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
602        if self.state == SocketState::Closed {
603            return Err(SocketError::ConnectionClosed.into());
604        }
605
606        let total_length = buf.len();
607
608        for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
609            let mut packet = Packet::with_payload(chunk);
610            packet.set_seq_nr(self.seq_nr);
611            packet.set_ack_nr(self.ack_nr);
612            packet.set_connection_id(self.sender_connection_id);
613
614            self.unsent_queue.push_back(packet);
615
616            // Intentionally wrap around sequence number
617            self.seq_nr = self.seq_nr.wrapping_add(1);
618        }
619
620        // Send every packet in the queue
621        self.send().await?;
622
623        Ok(total_length)
624    }
625
626    /// Consumes acknowledgements for every pending packet.
627    pub async fn flush(&mut self) -> Result<()> {
628        let mut buf = vec![0u8; BUF_SIZE];
629        while !self.send_window.is_empty() {
630            debug!("packets in send window: {}", self.send_window.len());
631            self.recv(&mut buf).await?;
632        }
633
634        Ok(())
635    }
636
637    /// Sends every packet in the unsent packet queue.
638    async fn send(&mut self) -> Result<()> {
639        while let Some(mut packet) = self.unsent_queue.pop_front() {
640            self.send_packet(&mut packet).await?;
641            self.curr_window += packet.len() as u32;
642            self.send_window.push(packet);
643        }
644        Ok(())
645    }
646
647    /// Send one packet.
648    async fn send_packet(&mut self, packet: &mut Packet) -> Result<()> {
649        debug!("current window: {}", self.send_window.len());
650        let max_inflight = min(self.cwnd, self.remote_wnd_size);
651        let max_inflight = max(MIN_CWND * MSS, max_inflight);
652        let now = now_microseconds();
653
654        // Wait until enough in-flight packets are acknowledged for rate control purposes, but don't
655        // wait more than 500 ms (PRE_SEND_TIMEOUT) before sending the packet.
656        while self.curr_window >= max_inflight && now_microseconds() - now < PRE_SEND_TIMEOUT.into()
657        {
658            debug!("self.curr_window: {}", self.curr_window);
659            debug!("max_inflight: {}", max_inflight);
660            debug!("self.duplicate_ack_count: {}", self.duplicate_ack_count);
661            debug!("now_microseconds() - now = {}", now_microseconds() - now);
662            let mut buf = vec![0u8; BUF_SIZE];
663            self.recv(&mut buf).await?;
664        }
665        debug!(
666            "out: now_microseconds() - now = {}",
667            now_microseconds() - now
668        );
669
670        // Check if it still makes sense to send packet, as we might be trying to resend a lost
671        // packet acknowledged in the receive loop above.
672        // If there were no wrapping around of sequence numbers, we'd simply check if the packet's
673        // sequence number is greater than `last_acked`.
674        let distance_a = packet.seq_nr().wrapping_sub(self.last_acked);
675        let distance_b = self.last_acked.wrapping_sub(packet.seq_nr());
676        if distance_a > distance_b {
677            debug!("Packet already acknowledged, skipping...");
678            return Ok(());
679        }
680
681        packet.set_timestamp(now_microseconds());
682        packet.set_timestamp_difference(self.their_delay);
683        self.socket
684            .send_to(packet.as_ref(), self.connected_to)
685            .await?;
686        debug!("sent {:?}", packet);
687
688        Ok(())
689    }
690
691    // Insert a new sample in the base delay list.
692    //
693    // The base delay list contains at most `BASE_HISTORY` samples, each sample is the minimum
694    // measured over a period of a minute (MAX_BASE_DELAY_AGE).
695    fn update_base_delay(&mut self, base_delay: Delay, now: Timestamp) {
696        if self.base_delays.is_empty() || now - self.last_rollover > MAX_BASE_DELAY_AGE {
697            // Update last rollover
698            self.last_rollover = now;
699
700            // Drop the oldest sample, if need be
701            if self.base_delays.len() == BASE_HISTORY {
702                self.base_delays.pop_front();
703            }
704
705            // Insert new sample
706            self.base_delays.push_back(base_delay);
707        } else {
708            // Replace sample for the current minute if the delay is lower
709            let last_idx = self.base_delays.len() - 1;
710            if base_delay < self.base_delays[last_idx] {
711                self.base_delays[last_idx] = base_delay;
712            }
713        }
714    }
715
716    /// Inserts a new sample in the current delay list after removing samples older than one RTT, as
717    /// specified in RFC6817.
718    fn update_current_delay(&mut self, v: Delay, now: Timestamp) {
719        // Remove samples more than one RTT old
720        let rtt = (self.rtt as i64 * 100).into();
721        while !self.current_delays.is_empty() && now - self.current_delays[0].received_at > rtt {
722            self.current_delays.remove(0);
723        }
724
725        // Insert new measurement
726        self.current_delays.push(DelayDifferenceSample {
727            received_at: now,
728            difference: v,
729        });
730    }
731
732    fn update_congestion_timeout(&mut self, current_delay: i32) {
733        let delta = self.rtt - current_delay;
734        self.rtt_variance += (delta.abs() - self.rtt_variance) / 4;
735        self.rtt += (current_delay - self.rtt) / 8;
736        self.congestion_timeout = max(
737            (self.rtt + self.rtt_variance * 4) as u64,
738            MIN_CONGESTION_TIMEOUT,
739        );
740        self.congestion_timeout = min(self.congestion_timeout, MAX_CONGESTION_TIMEOUT);
741
742        debug!("current_delay: {}", current_delay);
743        debug!("delta: {}", delta);
744        debug!("self.rtt_variance: {}", self.rtt_variance);
745        debug!("self.rtt: {}", self.rtt);
746        debug!("self.congestion_timeout: {}", self.congestion_timeout);
747    }
748
749    /// Calculates the filtered current delay in the current window.
750    ///
751    /// The current delay is calculated through application of the exponential
752    /// weighted moving average filter with smoothing factor 0.333 over the
753    /// current delays in the current window.
754    fn filtered_current_delay(&self) -> Delay {
755        let input = self.current_delays.iter().map(|delay| &delay.difference);
756        (ewma(input, 0.333) as i64).into()
757    }
758
759    /// Calculates the lowest base delay in the current window.
760    fn min_base_delay(&self) -> Delay {
761        self.base_delays.iter().min().cloned().unwrap_or_default()
762    }
763
764    /// Builds the selective acknowledgement extension data for usage in packets.
765    fn build_selective_ack(&self) -> Vec<u8> {
766        let stashed = self
767            .incoming_buffer
768            .iter()
769            .filter(|pkt| pkt.seq_nr() > self.ack_nr + 1)
770            .map(|pkt| (pkt.seq_nr() - self.ack_nr - 2) as usize)
771            .map(|diff| (diff / 8, diff % 8));
772
773        let mut sack = Vec::new();
774        for (byte, bit) in stashed {
775            // Make sure the amount of elements in the SACK vector is a
776            // multiple of 4 and enough to represent the lost packets
777            while byte >= sack.len() || sack.len() % 4 != 0 {
778                sack.push(0u8);
779            }
780
781            sack[byte] |= 1 << bit;
782        }
783
784        sack
785    }
786
787    /// Sends a fast resend request to the remote peer.
788    ///
789    /// A fast resend request consists of sending three State packets (acknowledging the last
790    /// received packet) in quick succession.
791    async fn send_fast_resend_request(&self) {
792        for _ in 0..3 {
793            let mut packet = Packet::new();
794            packet.set_type(PacketType::State);
795            let self_t_micro = now_microseconds();
796            packet.set_timestamp(self_t_micro);
797            packet.set_timestamp_difference(self.their_delay);
798            packet.set_connection_id(self.sender_connection_id);
799            packet.set_seq_nr(self.seq_nr);
800            packet.set_ack_nr(self.ack_nr);
801            let _ = self
802                .socket
803                .send_to(packet.as_ref(), self.connected_to)
804                .await;
805        }
806    }
807
808    async fn resend_lost_packet(&mut self, lost_packet_nr: u16) {
809        debug!("---> resend_lost_packet({}) <---", lost_packet_nr);
810        match self
811            .send_window
812            .iter()
813            .position(|pkt| pkt.seq_nr() == lost_packet_nr)
814        {
815            None => debug!("Packet {} not found", lost_packet_nr),
816            Some(position) => {
817                debug!("self.send_window.len(): {}", self.send_window.len());
818                debug!("position: {}", position);
819                let mut packet = self.send_window[position].clone();
820                // FIXME: Unchecked result
821                let _ = self.send_packet(&mut packet).await;
822
823                // We intentionally don't increase `curr_window` because otherwise a packet's length
824                // would be counted more than once
825            }
826        }
827        debug!("---> END resend_lost_packet <---");
828    }
829
830    /// Forgets sent packets that were acknowledged by the remote peer.
831    fn advance_send_window(&mut self) {
832        // The reason I'm not removing the first element in a loop while its sequence number is
833        // smaller than `last_acked` is because of wrapping sequence numbers, which would create the
834        // sequence [..., 65534, 65535, 0, 1, ...]. If `last_acked` is smaller than the first
835        // packet's sequence number because of wraparound (for instance, 1), no packets would be
836        // removed, as the condition `seq_nr < last_acked` would fail immediately.
837        //
838        // On the other hand, I can't keep removing the first packet in a loop until its sequence
839        // number matches `last_acked` because it might never match, and in that case no packets
840        // should be removed.
841        if let Some(position) = self
842            .send_window
843            .iter()
844            .position(|packet| packet.seq_nr() == self.last_acked)
845        {
846            for _ in 0..position + 1 {
847                let packet = self.send_window.remove(0);
848                self.curr_window -= packet.len() as u32;
849            }
850        }
851        debug!("self.curr_window: {}", self.curr_window);
852    }
853
854    /// Handles an incoming packet, updating socket state accordingly.
855    ///
856    /// Returns the appropriate reply packet, if needed.
857    async fn handle_packet(&mut self, packet: &Packet, src: SocketAddr) -> Result<Option<Packet>> {
858        debug!("({:?}, {:?})", self.state, packet.get_type());
859
860        // Acknowledge only if the packet strictly follows the previous one
861        if packet.seq_nr().wrapping_sub(self.ack_nr) == 1 {
862            self.ack_nr = packet.seq_nr();
863        }
864
865        // Reset connection if connection id doesn't match and this isn't a SYN
866        if packet.get_type() != PacketType::Syn
867            && self.state != SocketState::SynSent
868            && !(packet.connection_id() == self.sender_connection_id
869                || packet.connection_id() == self.receiver_connection_id)
870        {
871            return Ok(Some(self.prepare_reply(packet, PacketType::Reset)));
872        }
873
874        // Update remote window size
875        self.remote_wnd_size = packet.wnd_size();
876        debug!("self.remote_wnd_size: {}", self.remote_wnd_size);
877
878        // Update remote peer's delay between them sending the packet and us receiving it
879        let now = now_microseconds();
880        self.their_delay = abs_diff(now, packet.timestamp());
881        debug!("self.their_delay: {}", self.their_delay);
882
883        match (self.state, packet.get_type()) {
884            (SocketState::New, PacketType::Syn) => {
885                self.connected_to = src;
886                self.ack_nr = packet.seq_nr();
887                self.seq_nr = rand::random();
888                self.receiver_connection_id = packet.connection_id() + 1;
889                self.sender_connection_id = packet.connection_id();
890                self.state = SocketState::Connected;
891                self.last_dropped = self.ack_nr;
892
893                Ok(Some(self.prepare_reply(packet, PacketType::State)))
894            }
895            (_, PacketType::Syn) => Ok(Some(self.prepare_reply(packet, PacketType::Reset))),
896            (SocketState::SynSent, PacketType::State) => {
897                self.connected_to = src;
898                self.ack_nr = packet.seq_nr();
899                self.seq_nr += 1;
900                self.state = SocketState::Connected;
901                self.last_acked = packet.ack_nr();
902                self.last_acked_timestamp = now_microseconds();
903                Ok(None)
904            }
905            (SocketState::SynSent, _) => Err(SocketError::InvalidReply.into()),
906            (SocketState::Connected, PacketType::Data)
907            | (SocketState::FinSent, PacketType::Data) => Ok(self.handle_data_packet(packet)),
908            (SocketState::Connected, PacketType::State) => {
909                self.handle_state_packet(packet).await;
910                Ok(None)
911            }
912            (SocketState::Connected, PacketType::Fin) | (SocketState::FinSent, PacketType::Fin) => {
913                if packet.ack_nr() < self.seq_nr {
914                    debug!("FIN received but there are missing acknowledgements for sent packets");
915                }
916                let mut reply = self.prepare_reply(packet, PacketType::State);
917                if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
918                    debug!(
919                        "current ack_nr ({}) is behind received packet seq_nr ({})",
920                        self.ack_nr,
921                        packet.seq_nr()
922                    );
923
924                    // Set SACK extension payload if the packet is not in order
925                    let sack = self.build_selective_ack();
926
927                    if !sack.is_empty() {
928                        reply.set_sack(sack);
929                    }
930                }
931
932                // Give up, the remote peer might not care about our missing packets
933                self.state = SocketState::Closed;
934                Ok(Some(reply))
935            }
936            (SocketState::Closed, PacketType::Fin) => {
937                Ok(Some(self.prepare_reply(packet, PacketType::State)))
938            }
939            (SocketState::FinSent, PacketType::State) => {
940                if packet.ack_nr() == self.seq_nr {
941                    self.state = SocketState::Closed;
942                } else {
943                    self.handle_state_packet(packet);
944                }
945                Ok(None)
946            }
947            (_, PacketType::Reset) => {
948                self.state = SocketState::ResetReceived;
949                Err(SocketError::ConnectionReset.into())
950            }
951            (state, ty) => {
952                let message = format!("Unimplemented handling for ({:?},{:?})", state, ty);
953                debug!("{}", message);
954                Err(SocketError::Other(message).into())
955            }
956        }
957    }
958
959    fn handle_data_packet(&mut self, packet: &Packet) -> Option<Packet> {
960        // If a FIN was previously sent, reply with a FIN packet acknowledging the received packet.
961        let packet_type = if self.state == SocketState::FinSent {
962            PacketType::Fin
963        } else {
964            PacketType::State
965        };
966        let mut reply = self.prepare_reply(packet, packet_type);
967
968        if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
969            debug!(
970                "current ack_nr ({}) is behind received packet seq_nr ({})",
971                self.ack_nr,
972                packet.seq_nr()
973            );
974
975            // Set SACK extension payload if the packet is not in order
976            let sack = self.build_selective_ack();
977
978            if !sack.is_empty() {
979                reply.set_sack(sack);
980            }
981        }
982
983        Some(reply)
984    }
985
986    fn queuing_delay(&self) -> Delay {
987        let filtered_current_delay = self.filtered_current_delay();
988        let min_base_delay = self.min_base_delay();
989        let queuing_delay = filtered_current_delay - min_base_delay;
990
991        debug!("filtered_current_delay: {}", filtered_current_delay);
992        debug!("min_base_delay: {}", min_base_delay);
993        debug!("queuing_delay: {}", queuing_delay);
994
995        queuing_delay
996    }
997
998    /// Calculates the new congestion window size, increasing it or decreasing it.
999    ///
1000    /// This is the core of uTP, the [LEDBAT][ledbat_rfc] congestion algorithm. It depends on
1001    /// estimating the queuing delay between the two peers, and adjusting the congestion window
1002    /// accordingly.
1003    ///
1004    /// `off_target` is a normalized value representing the difference between the current queuing
1005    /// delay and a fixed target delay (`TARGET`). `off_target` ranges between -1.0 and 1.0. A
1006    /// positive value makes the congestion window increase, while a negative value makes the
1007    /// congestion window decrease.
1008    ///
1009    /// `bytes_newly_acked` is the number of bytes acknowledged by an inbound `State` packet. It may
1010    /// be the size of the packet explicitly acknowledged by the inbound packet (i.e., with sequence
1011    /// number equal to the inbound packet's acknowledgement number), or every packet implicitly
1012    /// acknowledged (every packet with sequence number between the previous inbound `State`
1013    /// packet's acknowledgement number and the current inbound `State` packet's acknowledgement
1014    /// number).
1015    ///
1016    ///[ledbat_rfc]: https://tools.ietf.org/html/rfc6817
1017    fn update_congestion_window(&mut self, off_target: f64, bytes_newly_acked: u32) {
1018        let flightsize = self.curr_window;
1019
1020        let cwnd_increase = GAIN * off_target * bytes_newly_acked as f64 * MSS as f64;
1021        let cwnd_increase = cwnd_increase / self.cwnd as f64;
1022        debug!("cwnd_increase: {}", cwnd_increase);
1023
1024        self.cwnd = (self.cwnd as f64 + cwnd_increase) as u32;
1025        let max_allowed_cwnd = flightsize + ALLOWED_INCREASE * MSS;
1026        self.cwnd = min(self.cwnd, max_allowed_cwnd);
1027        self.cwnd = max(self.cwnd, MIN_CWND * MSS);
1028
1029        debug!("cwnd: {}", self.cwnd);
1030        debug!("max_allowed_cwnd: {}", max_allowed_cwnd);
1031    }
1032
1033    #[async_recursion::async_recursion]
1034    async fn handle_state_packet(&mut self, packet: &Packet) {
1035        if packet.ack_nr() == self.last_acked {
1036            self.duplicate_ack_count += 1;
1037        } else {
1038            self.last_acked = packet.ack_nr();
1039            self.last_acked_timestamp = now_microseconds();
1040            self.duplicate_ack_count = 1;
1041        }
1042
1043        // Update congestion window size
1044        if let Some(index) = self
1045            .send_window
1046            .iter()
1047            .position(|p| packet.ack_nr() == p.seq_nr())
1048        {
1049            // Calculate the sum of the size of every packet implicitly and explicitly acknowledged
1050            // by the inbound packet (i.e., every packet whose sequence number precedes the inbound
1051            // packet's acknowledgement number, plus the packet whose sequence number matches)
1052            let bytes_newly_acked = self
1053                .send_window
1054                .iter()
1055                .take(index + 1)
1056                .fold(0, |acc, p| acc + p.len());
1057
1058            // Update base and current delay
1059            let now = now_microseconds();
1060            let our_delay = now - self.send_window[index].timestamp();
1061            debug!("our_delay: {}", our_delay);
1062            self.update_base_delay(our_delay, now);
1063            self.update_current_delay(our_delay, now);
1064
1065            let off_target: f64 = (TARGET - u32::from(self.queuing_delay()) as f64) / TARGET;
1066            debug!("off_target: {}", off_target);
1067
1068            self.update_congestion_window(off_target, bytes_newly_acked as u32);
1069
1070            // Update congestion timeout
1071            let rtt = u32::from(our_delay - self.queuing_delay()) / 1000; // in milliseconds
1072            self.update_congestion_timeout(rtt as i32);
1073        }
1074
1075        let mut packet_loss_detected: bool =
1076            !self.send_window.is_empty() && self.duplicate_ack_count == 3;
1077
1078        // Process extensions, if any
1079        for extension in packet.extensions() {
1080            if extension.get_type() == ExtensionType::SelectiveAck {
1081                // If three or more packets are acknowledged past the implicit missing one,
1082                // assume it was lost.
1083                if extension.iter().count_ones() >= 3 {
1084                    self.resend_lost_packet(packet.ack_nr() + 1).await;
1085                    packet_loss_detected = true;
1086                }
1087
1088                if let Some(last_seq_nr) = self.send_window.last().map(Packet::seq_nr) {
1089                    let lost_packets = extension
1090                        .iter()
1091                        .enumerate()
1092                        .filter(|&(_, received)| !received)
1093                        .map(|(idx, _)| packet.ack_nr() + 2 + idx as u16)
1094                        .take_while(|&seq_nr| seq_nr < last_seq_nr);
1095
1096                    for seq_nr in lost_packets {
1097                        debug!("SACK: packet {} lost", seq_nr);
1098                        self.resend_lost_packet(seq_nr).await;
1099                        packet_loss_detected = true;
1100                    }
1101                }
1102            } else {
1103                debug!("Unknown extension {:?}, ignoring", extension.get_type());
1104            }
1105        }
1106
1107        // Three duplicate ACKs mean a fast resend request. Resend the first unacknowledged packet
1108        // if the incoming packet doesn't have a SACK extension. If it does, the lost packets were
1109        // already resent.
1110        if !self.send_window.is_empty()
1111            && self.duplicate_ack_count == 3
1112            && !packet
1113                .extensions()
1114                .any(|ext| ext.get_type() == ExtensionType::SelectiveAck)
1115        {
1116            self.resend_lost_packet(packet.ack_nr() + 1).await;
1117        }
1118
1119        // Packet lost, halve the congestion window
1120        if packet_loss_detected {
1121            debug!("packet loss detected, halving congestion window");
1122            self.cwnd = max(self.cwnd / 2, MIN_CWND * MSS);
1123            debug!("cwnd: {}", self.cwnd);
1124        }
1125
1126        // Success, advance send window
1127        self.advance_send_window();
1128    }
1129
1130    /// Inserts a packet into the socket's buffer.
1131    ///
1132    /// The packet is inserted in such a way that the packets in the buffer are sorted according to
1133    /// their sequence number in ascending order. This allows storing packets that were received out
1134    /// of order.
1135    ///
1136    /// Trying to insert a duplicate of a packet will silently fail.
1137    /// it's more recent (larger timestamp).
1138    fn insert_into_buffer(&mut self, packet: Packet) {
1139        // Immediately push to the end if the packet's sequence number comes after the last
1140        // packet's.
1141        if self
1142            .incoming_buffer
1143            .last()
1144            .map_or(false, |p| packet.seq_nr() > p.seq_nr())
1145        {
1146            self.incoming_buffer.push(packet);
1147        } else {
1148            // Find index following the most recent packet before the one we wish to insert
1149            let i = self
1150                .incoming_buffer
1151                .iter()
1152                .filter(|p| p.seq_nr() < packet.seq_nr())
1153                .count();
1154
1155            if self
1156                .incoming_buffer
1157                .get(i)
1158                .map_or(true, |p| p.seq_nr() != packet.seq_nr())
1159            {
1160                self.incoming_buffer.insert(i, packet);
1161            }
1162        }
1163    }
1164}
1165
1166impl Drop for UtpSocket {
1167    fn drop(&mut self) {
1168        task::block_on(async {
1169            drop(self.close().await);
1170        });
1171    }
1172}
1173
1174/// A structure representing a socket server.
1175///
1176/// # Examples
1177///
1178/// ```no_run
1179/// use async_std_utp::{UtpListener, UtpSocket};
1180/// use async_std::{prelude::*, task};
1181///
1182/// async fn handle_client(socket: UtpSocket) {
1183///     // ...
1184/// }
1185///
1186/// # fn main() { async_std::task::block_on(async {
1187///     // Create a listener
1188///     let addr = "127.0.0.1:8080";
1189///     let listener = UtpListener::bind(addr).await.expect("Error binding socket");
1190///     let mut incoming = listener.incoming();
1191///     while let Some(connection) = incoming.next().await {
1192///         // Spawn a new handler for each new connection
1193///         if let Ok((socket, _src)) = connection {
1194///             task::spawn(async move { handle_client(socket) });
1195///         }
1196///     }
1197/// # }); }
1198/// ```
1199#[derive(Clone)]
1200pub struct UtpListener {
1201    /// The public facing UDP socket
1202    socket: Arc<UdpSocket>,
1203}
1204
1205impl UtpListener {
1206    /// Creates a new `UtpListener` bound to a specific address.
1207    ///
1208    /// The resulting listener is ready for accepting connections.
1209    ///
1210    /// The address type can be any implementer of the `ToSocketAddr` trait. See its documentation
1211    /// for concrete examples.
1212    ///
1213    /// If more than one valid address is specified, only the first will be used.
1214    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpListener> {
1215        let socket = UdpSocket::bind(addr).await?;
1216        Ok(UtpListener {
1217            socket: socket.into(),
1218        })
1219    }
1220
1221    /// Accepts a new incoming connection from this listener.
1222    ///
1223    /// This function will block the caller until a new uTP connection is established. When
1224    /// established, the corresponding `UtpSocket` and the peer's remote address will be returned.
1225    ///
1226    /// Notice that the resulting `UtpSocket` is bound to a different local port than the public
1227    /// listening port (which `UtpListener` holds). This may confuse the remote peer!
1228    pub async fn accept(&self) -> Result<(UtpSocket, SocketAddr)> {
1229        let mut buf = vec![0; BUF_SIZE];
1230
1231        let (nread, src) = self.socket.recv_from(&mut buf).await?;
1232        let packet = Packet::try_from(&buf[..nread])?;
1233
1234        // Ignore non-SYN packets
1235        if packet.get_type() != PacketType::Syn {
1236            let message = format!("Expected SYN packet, got {:?} instead", packet.get_type());
1237            return Err(SocketError::Other(message).into());
1238        }
1239
1240        // The address of the new socket will depend on the type of the listener.
1241        let local_addr = self.socket.local_addr()?;
1242        let inner_socket = match local_addr {
1243            SocketAddr::V4(_) => UdpSocket::bind("0.0.0.0:0"),
1244            SocketAddr::V6(_) => UdpSocket::bind("[::]:0"),
1245        }
1246        .await?;
1247
1248        let mut socket = UtpSocket::from_raw_parts(inner_socket, src);
1249
1250        // Establish connection with remote peer
1251        if let Ok(Some(reply)) = socket.handle_packet(&packet, src).await {
1252            socket
1253                .socket
1254                .send_to(reply.as_ref(), src)
1255                .await
1256                .and(Ok((socket, src)))
1257        } else {
1258            Err(SocketError::Other("Reached unreachable statement".to_owned()).into())
1259        }
1260    }
1261
1262    /// Returns an iterator over the connections being received by this listener.
1263    ///
1264    /// The returned iterator will never return `None`.
1265    pub fn incoming(&self) -> Incoming<'_> {
1266        Incoming {
1267            listener: self,
1268            accept: None,
1269        }
1270    }
1271
1272    /// Returns the local socket address of this listener.
1273    pub fn local_addr(&self) -> Result<SocketAddr> {
1274        self.socket.local_addr()
1275    }
1276}
1277
1278type AcceptFuture<'a> = Option<BoxFuture<'a, io::Result<(UtpSocket, SocketAddr)>>>;
1279
1280pub struct Incoming<'a> {
1281    listener: &'a UtpListener,
1282    accept: AcceptFuture<'a>,
1283}
1284
1285impl<'a> futures::Stream for Incoming<'a> {
1286    type Item = Result<(UtpSocket, SocketAddr)>;
1287
1288    fn poll_next(
1289        mut self: std::pin::Pin<&mut Self>,
1290        cx: &mut std::task::Context<'_>,
1291    ) -> std::task::Poll<Option<Self::Item>> {
1292        loop {
1293            if self.accept.is_none() {
1294                self.accept = Some(self.listener.accept().boxed());
1295            }
1296
1297            if let Some(f) = &mut self.accept {
1298                let res = ready!(f.as_mut().poll(cx));
1299                self.accept = None;
1300                return Poll::Ready(Some(res));
1301            }
1302        }
1303    }
1304}
1305
1306#[cfg(test)]
1307mod test {
1308    use crate::packet::*;
1309    use crate::socket::{take_address, SocketState, UtpListener, UtpSocket, BUF_SIZE};
1310    use crate::time::now_microseconds;
1311    use async_std::task;
1312    use rand;
1313    use std::io::ErrorKind;
1314    use std::net::ToSocketAddrs;
1315
1316    macro_rules! iotry {
1317        ($e:expr) => {
1318            match $e.await {
1319                Ok(e) => e,
1320                Err(e) => panic!("{:?}", e),
1321            }
1322        };
1323    }
1324
1325    fn next_test_port() -> u16 {
1326        use std::sync::atomic::{AtomicUsize, Ordering};
1327        static NEXT_OFFSET: AtomicUsize = AtomicUsize::new(0);
1328        const BASE_PORT: u16 = 9600;
1329        BASE_PORT + NEXT_OFFSET.fetch_add(1, Ordering::Relaxed) as u16
1330    }
1331
1332    fn next_test_ip4<'a>() -> (&'a str, u16) {
1333        ("127.0.0.1", next_test_port())
1334    }
1335
1336    fn next_test_ip6<'a>() -> (&'a str, u16) {
1337        ("::1", next_test_port())
1338    }
1339
1340    #[async_std::test]
1341    async fn test_socket_ipv4() {
1342        let server_addr = next_test_ip4();
1343
1344        let mut server = iotry!(UtpSocket::bind(server_addr));
1345        assert_eq!(server.state, SocketState::New);
1346
1347        let child = task::spawn(async move {
1348            let mut client = iotry!(UtpSocket::connect(server_addr));
1349            assert_eq!(client.state, SocketState::Connected);
1350            // Check proper difference in client's send connection id and receive connection id
1351            assert_eq!(
1352                client.sender_connection_id,
1353                client.receiver_connection_id + 1
1354            );
1355            assert_eq!(
1356                client.connected_to,
1357                server_addr.to_socket_addrs().unwrap().next().unwrap()
1358            );
1359            iotry!(client.close());
1360            drop(client);
1361        });
1362
1363        let mut buf = vec![0; BUF_SIZE];
1364        match server.recv_from(&mut buf).await {
1365            e => println!("{:?}", e),
1366        }
1367        // After establishing a new connection, the server's ids are a mirror of the client's.
1368        assert_eq!(
1369            server.receiver_connection_id,
1370            server.sender_connection_id + 1
1371        );
1372
1373        assert_eq!(server.state, SocketState::Closed);
1374        drop(server);
1375
1376        child.await;
1377    }
1378
1379    #[async_std::test]
1380    async fn test_socket_ipv6() {
1381        let server_addr = next_test_ip6();
1382
1383        let mut server = iotry!(UtpSocket::bind(server_addr));
1384        assert_eq!(server.state, SocketState::New);
1385
1386        let child = task::spawn(async move {
1387            let mut client = iotry!(UtpSocket::connect(server_addr));
1388            assert_eq!(client.state, SocketState::Connected);
1389            // Check proper difference in client's send connection id and receive connection id
1390            assert_eq!(
1391                client.sender_connection_id,
1392                client.receiver_connection_id + 1
1393            );
1394            assert_eq!(
1395                client.connected_to,
1396                server_addr.to_socket_addrs().unwrap().next().unwrap()
1397            );
1398            iotry!(client.close());
1399            drop(client);
1400        });
1401
1402        let mut buf = vec![0u8; BUF_SIZE];
1403        match server.recv_from(&mut buf).await {
1404            e => println!("{:?}", e),
1405        }
1406        // After establishing a new connection, the server's ids are a mirror of the client's.
1407        assert_eq!(
1408            server.receiver_connection_id,
1409            server.sender_connection_id + 1
1410        );
1411
1412        assert_eq!(server.state, SocketState::Closed);
1413        drop(server);
1414
1415        child.await;
1416    }
1417
1418    #[async_std::test]
1419    async fn test_recvfrom_on_closed_socket() {
1420        let server_addr = next_test_ip4();
1421
1422        let mut server = iotry!(UtpSocket::bind(server_addr));
1423        assert_eq!(server.state, SocketState::New);
1424
1425        let child = task::spawn(async move {
1426            let mut client = iotry!(UtpSocket::connect(server_addr));
1427            assert_eq!(client.state, SocketState::Connected);
1428            assert!(client.close().await.is_ok());
1429        });
1430
1431        // Make the server listen for incoming connections until the end of the input
1432        let mut buf = vec![0u8; BUF_SIZE];
1433        let _resp = server.recv_from(&mut buf).await;
1434        assert_eq!(server.state, SocketState::Closed);
1435
1436        // Trying to receive again returns `Ok(0)` (equivalent to the old `EndOfFile`)
1437        match server.recv_from(&mut buf).await {
1438            Ok((0, _src)) => {}
1439            e => panic!("Expected Ok(0), got {:?}", e),
1440        }
1441        assert_eq!(server.state, SocketState::Closed);
1442
1443        child.await;
1444    }
1445
1446    #[async_std::test]
1447    async fn test_sendto_on_closed_socket() {
1448        let server_addr = next_test_ip4();
1449
1450        let mut server = iotry!(UtpSocket::bind(server_addr));
1451        assert_eq!(server.state, SocketState::New);
1452
1453        let child = task::spawn(async move {
1454            let mut client = iotry!(UtpSocket::connect(server_addr));
1455            assert_eq!(client.state, SocketState::Connected);
1456            iotry!(client.close());
1457        });
1458
1459        // Make the server listen for incoming connections
1460        let mut buf = vec![0u8; BUF_SIZE];
1461        let (_read, _src) = iotry!(server.recv_from(&mut buf));
1462        assert_eq!(server.state, SocketState::Closed);
1463
1464        // Trying to send to the socket after closing it raises an error
1465        match server.send_to(&buf).await {
1466            Err(ref e) if e.kind() == ErrorKind::NotConnected => (),
1467            v => panic!("expected {:?}, got {:?}", ErrorKind::NotConnected, v),
1468        }
1469
1470        child.await;
1471    }
1472
1473    #[async_std::test]
1474    async fn test_acks_on_socket() {
1475        use std::sync::mpsc::channel;
1476        let server_addr = next_test_ip4();
1477        let (tx, rx) = channel();
1478
1479        let mut server = iotry!(UtpSocket::bind(server_addr));
1480
1481        let child = task::spawn(async move {
1482            // Make the server listen for incoming connections
1483            let mut buf = vec![0u8; BUF_SIZE];
1484            let _resp = server.recv(&mut buf).await;
1485            tx.send(server.seq_nr).unwrap();
1486
1487            // Close the connection
1488            iotry!(server.recv_from(&mut buf));
1489
1490            drop(server);
1491        });
1492
1493        let mut client = iotry!(UtpSocket::connect(server_addr));
1494        assert_eq!(client.state, SocketState::Connected);
1495        let sender_seq_nr = rx.recv().unwrap();
1496        let ack_nr = client.ack_nr;
1497        assert_eq!(ack_nr, sender_seq_nr);
1498        assert!(client.close().await.is_ok());
1499
1500        // The reply to both connect (SYN) and close (FIN) should be
1501        // STATE packets, which don't increase the sequence number
1502        // and, hence, the receiver's acknowledgement number.
1503        assert_eq!(client.ack_nr, ack_nr);
1504        drop(client);
1505
1506        child.await;
1507    }
1508
1509    #[async_std::test]
1510    async fn test_handle_packet() {
1511        //fn test_connection_setup() {
1512        let initial_connection_id: u16 = rand::random();
1513        let sender_connection_id = initial_connection_id + 1;
1514        let (server_addr, client_addr) = (
1515            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1516            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1517        );
1518        let mut socket = iotry!(UtpSocket::bind(server_addr));
1519
1520        let mut packet = Packet::new();
1521        packet.set_wnd_size(BUF_SIZE as u32);
1522        packet.set_type(PacketType::Syn);
1523        packet.set_connection_id(initial_connection_id);
1524
1525        // Do we have a response?
1526        let response = socket.handle_packet(&packet, client_addr).await;
1527        assert!(response.is_ok());
1528        let response = response.unwrap();
1529        assert!(response.is_some());
1530
1531        // Is is of the correct type?
1532        let response = response.unwrap();
1533        assert_eq!(response.get_type(), PacketType::State);
1534
1535        // Same connection id on both ends during connection establishment
1536        assert_eq!(response.connection_id(), packet.connection_id());
1537
1538        // Response acknowledges SYN
1539        assert_eq!(response.ack_nr(), packet.seq_nr());
1540
1541        // No payload?
1542        assert!(response.payload().is_empty());
1543        //}
1544
1545        // ---------------------------------
1546
1547        // fn test_connection_usage() {
1548        let old_packet = packet;
1549        let old_response = response;
1550
1551        let mut packet = Packet::new();
1552        packet.set_type(PacketType::Data);
1553        packet.set_connection_id(sender_connection_id);
1554        packet.set_seq_nr(old_packet.seq_nr() + 1);
1555        packet.set_ack_nr(old_response.seq_nr());
1556
1557        let response = socket.handle_packet(&packet, client_addr).await;
1558        assert!(response.is_ok());
1559        let response = response.unwrap();
1560        assert!(response.is_some());
1561
1562        let response = response.unwrap();
1563        assert_eq!(response.get_type(), PacketType::State);
1564
1565        // Sender (i.e., who the initiated connection and sent a SYN) has connection id equal to
1566        // initial connection id + 1
1567        // Receiver (i.e., who accepted connection) has connection id equal to initial connection id
1568        assert_eq!(response.connection_id(), initial_connection_id);
1569        assert_eq!(response.connection_id(), packet.connection_id() - 1);
1570
1571        // Previous packets should be ack'ed
1572        assert_eq!(response.ack_nr(), packet.seq_nr());
1573
1574        // Responses with no payload should not increase the sequence number
1575        assert!(response.payload().is_empty());
1576        assert_eq!(response.seq_nr(), old_response.seq_nr());
1577        // }
1578
1579        //fn test_connection_teardown() {
1580        let old_packet = packet;
1581        let old_response = response;
1582
1583        let mut packet = Packet::new();
1584        packet.set_type(PacketType::Fin);
1585        packet.set_connection_id(sender_connection_id);
1586        packet.set_seq_nr(old_packet.seq_nr() + 1);
1587        packet.set_ack_nr(old_response.seq_nr());
1588
1589        let response = socket.handle_packet(&packet, client_addr).await;
1590        assert!(response.is_ok());
1591        let response = response.unwrap();
1592        assert!(response.is_some());
1593
1594        let response = response.unwrap();
1595
1596        assert_eq!(response.get_type(), PacketType::State);
1597
1598        // FIN packets have no payload but the sequence number shouldn't increase
1599        assert_eq!(packet.seq_nr(), old_packet.seq_nr() + 1);
1600
1601        // Nor should the ACK packet's sequence number
1602        assert_eq!(response.seq_nr(), old_response.seq_nr());
1603
1604        // FIN should be acknowledged
1605        assert_eq!(response.ack_nr(), packet.seq_nr());
1606
1607        //}
1608    }
1609
1610    #[async_std::test]
1611    async fn test_response_to_keepalive_ack() {
1612        // Boilerplate test setup
1613        let initial_connection_id: u16 = rand::random();
1614        let (server_addr, client_addr) = (
1615            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1616            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1617        );
1618        let mut socket = iotry!(UtpSocket::bind(server_addr));
1619
1620        // Establish connection
1621        let mut packet = Packet::new();
1622        packet.set_wnd_size(BUF_SIZE as u32);
1623        packet.set_type(PacketType::Syn);
1624        packet.set_connection_id(initial_connection_id);
1625
1626        let response = socket.handle_packet(&packet, client_addr).await;
1627        assert!(response.is_ok());
1628        let response = response.unwrap();
1629        assert!(response.is_some());
1630        let response = response.unwrap();
1631        assert_eq!(response.get_type(), PacketType::State);
1632
1633        let old_packet = packet;
1634        let old_response = response;
1635
1636        // Now, send a keepalive packet
1637        let mut packet = Packet::new();
1638        packet.set_wnd_size(BUF_SIZE as u32);
1639        packet.set_type(PacketType::State);
1640        packet.set_connection_id(initial_connection_id);
1641        packet.set_seq_nr(old_packet.seq_nr() + 1);
1642        packet.set_ack_nr(old_response.seq_nr());
1643
1644        let response = socket.handle_packet(&packet, client_addr).await;
1645        assert!(response.is_ok());
1646        let response = response.unwrap();
1647        assert!(response.is_none());
1648
1649        // Send a second keepalive packet, identical to the previous one
1650        let response = socket.handle_packet(&packet, client_addr).await;
1651        assert!(response.is_ok());
1652        let response = response.unwrap();
1653        assert!(response.is_none());
1654
1655        // Mark socket as closed
1656        socket.state = SocketState::Closed;
1657    }
1658
1659    #[async_std::test]
1660    async fn test_response_to_wrong_connection_id() {
1661        // Boilerplate test setup
1662        let initial_connection_id: u16 = rand::random();
1663        let (server_addr, client_addr) = (
1664            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1665            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1666        );
1667        let mut socket = iotry!(UtpSocket::bind(server_addr));
1668
1669        // Establish connection
1670        let mut packet = Packet::new();
1671        packet.set_wnd_size(BUF_SIZE as u32);
1672        packet.set_type(PacketType::Syn);
1673        packet.set_connection_id(initial_connection_id);
1674
1675        let response = socket.handle_packet(&packet, client_addr).await;
1676        assert!(response.is_ok());
1677        let response = response.unwrap();
1678        assert!(response.is_some());
1679        assert_eq!(response.unwrap().get_type(), PacketType::State);
1680
1681        // Now, disrupt connection with a packet with an incorrect connection id
1682        let new_connection_id = initial_connection_id.wrapping_mul(2);
1683
1684        let mut packet = Packet::new();
1685        packet.set_wnd_size(BUF_SIZE as u32);
1686        packet.set_type(PacketType::State);
1687        packet.set_connection_id(new_connection_id);
1688
1689        let response = socket.handle_packet(&packet, client_addr).await;
1690        assert!(response.is_ok());
1691        let response = response.unwrap();
1692        assert!(response.is_some());
1693
1694        let response = response.unwrap();
1695        assert_eq!(response.get_type(), PacketType::Reset);
1696        assert_eq!(response.ack_nr(), packet.seq_nr());
1697
1698        // Mark socket as closed
1699        socket.state = SocketState::Closed;
1700    }
1701
1702    #[async_std::test]
1703    async fn test_unordered_packets() {
1704        // Boilerplate test setup
1705        let initial_connection_id: u16 = rand::random();
1706        let (server_addr, client_addr) = (
1707            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1708            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1709        );
1710        let mut socket = iotry!(UtpSocket::bind(server_addr));
1711
1712        // Establish connection
1713        let mut packet = Packet::new();
1714        packet.set_wnd_size(BUF_SIZE as u32);
1715        packet.set_type(PacketType::Syn);
1716        packet.set_connection_id(initial_connection_id);
1717
1718        let response = socket.handle_packet(&packet, client_addr).await;
1719        assert!(response.is_ok());
1720        let response = response.unwrap();
1721        assert!(response.is_some());
1722        let response = response.unwrap();
1723        assert_eq!(response.get_type(), PacketType::State);
1724
1725        let old_packet = packet;
1726        let old_response = response;
1727
1728        let mut window: Vec<Packet> = Vec::new();
1729
1730        // Now, send a keepalive packet
1731        let mut packet = Packet::with_payload(&[1, 2, 3]);
1732        packet.set_wnd_size(BUF_SIZE as u32);
1733        packet.set_connection_id(initial_connection_id);
1734        packet.set_seq_nr(old_packet.seq_nr() + 1);
1735        packet.set_ack_nr(old_response.seq_nr());
1736        window.push(packet);
1737
1738        let mut packet = Packet::with_payload(&[4, 5, 6]);
1739        packet.set_wnd_size(BUF_SIZE as u32);
1740        packet.set_connection_id(initial_connection_id);
1741        packet.set_seq_nr(old_packet.seq_nr() + 2);
1742        packet.set_ack_nr(old_response.seq_nr());
1743        window.push(packet);
1744
1745        // Send packets in reverse order
1746        let response = socket.handle_packet(&window[1], client_addr).await;
1747        assert!(response.is_ok());
1748        let response = response.unwrap();
1749        assert!(response.is_some());
1750        let response = response.unwrap();
1751        assert!(response.ack_nr() != window[1].seq_nr());
1752
1753        let response = socket.handle_packet(&window[0], client_addr).await;
1754        assert!(response.is_ok());
1755        let response = response.unwrap();
1756        assert!(response.is_some());
1757
1758        // Mark socket as closed
1759        socket.state = SocketState::Closed;
1760    }
1761
1762    #[async_std::test]
1763    async fn test_response_to_triple_ack() {
1764        let server_addr = next_test_ip4();
1765        let mut server = iotry!(UtpSocket::bind(server_addr));
1766
1767        // Fits in a packet
1768        const LEN: usize = 1024;
1769        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
1770        let d = data.clone();
1771        assert_eq!(LEN, data.len());
1772
1773        let child = task::spawn(async move {
1774            let mut client = iotry!(UtpSocket::connect(server_addr));
1775            iotry!(client.send_to(&d[..]));
1776            iotry!(client.close());
1777        });
1778
1779        let mut buf = vec![0u8; BUF_SIZE];
1780        // Expect SYN
1781        iotry!(server.recv(&mut buf));
1782
1783        // Receive data
1784        let data_packet = match server.socket.recv_from(&mut buf).await {
1785            Ok((read, _src)) => Packet::try_from(&buf[..read]).unwrap(),
1786            Err(e) => panic!("{}", e),
1787        };
1788        assert_eq!(data_packet.get_type(), PacketType::Data);
1789        assert_eq!(&data_packet.payload(), &data.as_slice());
1790        assert_eq!(data_packet.payload().len(), data.len());
1791
1792        // Send triple ACK
1793        let mut packet = Packet::new();
1794        packet.set_wnd_size(BUF_SIZE as u32);
1795        packet.set_type(PacketType::State);
1796        packet.set_seq_nr(server.seq_nr);
1797        packet.set_ack_nr(data_packet.seq_nr() - 1);
1798        packet.set_connection_id(server.sender_connection_id);
1799
1800        for _ in 0..3u8 {
1801            iotry!(server.socket.send_to(packet.as_ref(), server.connected_to));
1802        }
1803
1804        // Receive data again and check that it's the same we reported as missing
1805        let client_addr = server.connected_to;
1806        match server.socket.recv_from(&mut buf).await {
1807            Ok((0, _)) => panic!("Received 0 bytes from socket"),
1808            Ok((read, _src)) => {
1809                let packet = Packet::try_from(&buf[..read]).unwrap();
1810                assert_eq!(packet.get_type(), PacketType::Data);
1811                assert_eq!(packet.seq_nr(), data_packet.seq_nr());
1812                assert_eq!(packet.payload(), data_packet.payload());
1813                let response = server.handle_packet(&packet, client_addr).await;
1814                assert!(response.is_ok());
1815                let response = response.unwrap();
1816                assert!(response.is_some());
1817                let response = response.unwrap();
1818                iotry!(server
1819                    .socket
1820                    .send_to(response.as_ref(), server.connected_to));
1821            }
1822            Err(e) => panic!("{}", e),
1823        }
1824
1825        // Receive close
1826        iotry!(server.recv_from(&mut buf));
1827        child.await;
1828    }
1829
    // Simulates the loss of the client's data packet: the server reads it straight off the UDP
    // socket and discards it. The test then checks that a later `recv_from` on the server still
    // eventually returns data, relying on the socket's timeout handling.
    #[async_std::test]
    async fn test_socket_timeout_request() {
        let (server_addr, client_addr) = (
            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        );

        // This `client` is only inspected for its initial state and ids; the connection itself
        // is made by a separate socket created with `UtpSocket::connect` in the child task.
        let client = iotry!(UtpSocket::bind(client_addr));
        let mut server = iotry!(UtpSocket::bind(server_addr));
        const LEN: usize = 512;
        // `data` only serves to build the payload `d` that is moved into the child task.
        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
        let d = data.clone();

        assert_eq!(server.state, SocketState::New);
        assert_eq!(client.state, SocketState::New);

        // Check proper difference in client's send connection id and receive connection id
        assert_eq!(
            client.sender_connection_id,
            client.receiver_connection_id + 1
        );

        // Client: connect, send the payload, then drop the socket without an explicit close().
        let child = task::spawn(async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            assert_eq!(client.state, SocketState::Connected);
            assert_eq!(client.connected_to, server_addr);
            iotry!(client.send_to(&d[..]));
            drop(client);
        });

        let mut buf = vec![0u8; BUF_SIZE];
        // Accept the connection (consumes the SYN).
        server.recv(&mut buf).await.unwrap();
        // After establishing a new connection, the server's ids are a mirror of the client's.
        assert_eq!(
            server.receiver_connection_id,
            server.sender_connection_id + 1
        );

        assert_eq!(server.state, SocketState::Connected);

        // Purposefully read from UDP socket directly and discard it, in order
        // to behave as if the packet was lost and thus trigger the timeout
        // handling in the *next* call to `UtpSocket.recv_from`.
        iotry!(server.socket.recv_from(&mut buf));

        // Set a much smaller than usual timeout, for quicker test completion
        // (presumably milliseconds, matching INITIAL_CONGESTION_TIMEOUT's "one second" = 1000).
        server.congestion_timeout = 50;

        // Now wait for the previously discarded packet. Zero-byte results are retried; any
        // non-empty read ends the wait.
        loop {
            let response = server.recv_from(&mut buf).await;
            match response {
                Ok((0, _)) => continue,
                Ok(_) => break,
                Err(e) => panic!("{}", e),
            }
        }

        drop(server);
        child.await;
    }
1891
1892    #[async_std::test]
1893    async fn test_sorted_buffer_insertion() {
1894        let server_addr = next_test_ip4();
1895        let mut socket = iotry!(UtpSocket::bind(server_addr));
1896
1897        let mut packet = Packet::new();
1898        packet.set_seq_nr(1);
1899
1900        assert!(socket.incoming_buffer.is_empty());
1901
1902        socket.insert_into_buffer(packet.clone());
1903        assert_eq!(socket.incoming_buffer.len(), 1);
1904
1905        packet.set_seq_nr(2);
1906        packet.set_timestamp(128.into());
1907
1908        socket.insert_into_buffer(packet.clone());
1909        assert_eq!(socket.incoming_buffer.len(), 2);
1910        assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
1911        assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
1912
1913        packet.set_seq_nr(3);
1914        packet.set_timestamp(256.into());
1915
1916        socket.insert_into_buffer(packet.clone());
1917        assert_eq!(socket.incoming_buffer.len(), 3);
1918        assert_eq!(socket.incoming_buffer[2].seq_nr(), 3);
1919        assert_eq!(socket.incoming_buffer[2].timestamp(), 256.into());
1920
1921        // Replacing a packet with a more recent version doesn't work
1922        packet.set_seq_nr(2);
1923        packet.set_timestamp(456.into());
1924
1925        socket.insert_into_buffer(packet.clone());
1926        assert_eq!(socket.incoming_buffer.len(), 3);
1927        assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
1928        assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
1929    }
1930
1931    #[async_std::test]
1932    async fn test_duplicate_packet_handling() {
1933        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
1934
1935        let client = iotry!(UtpSocket::bind(client_addr));
1936        let mut server = iotry!(UtpSocket::bind(server_addr));
1937
1938        assert_eq!(server.state, SocketState::New);
1939        assert_eq!(client.state, SocketState::New);
1940
1941        // Check proper difference in client's send connection id and receive connection id
1942        assert_eq!(
1943            client.sender_connection_id,
1944            client.receiver_connection_id + 1
1945        );
1946
1947        let child = task::spawn(async move {
1948            let mut client = iotry!(UtpSocket::connect(server_addr));
1949            assert_eq!(client.state, SocketState::Connected);
1950
1951            let mut packet = Packet::with_payload(&[1, 2, 3]);
1952            packet.set_wnd_size(BUF_SIZE as u32);
1953            packet.set_connection_id(client.sender_connection_id);
1954            packet.set_seq_nr(client.seq_nr);
1955            packet.set_ack_nr(client.ack_nr);
1956
1957            // Send two copies of the packet, with different timestamps
1958            for _ in 0..2 {
1959                packet.set_timestamp(now_microseconds());
1960                iotry!(client.socket.send_to(packet.as_ref(), server_addr));
1961            }
1962            client.seq_nr += 1;
1963
1964            // Receive one ACK
1965            for _ in 0..1 {
1966                let mut buf = vec![0u8; BUF_SIZE];
1967                iotry!(client.socket.recv_from(&mut buf));
1968            }
1969
1970            iotry!(client.close());
1971        });
1972
1973        let mut buf = vec![0u8; BUF_SIZE];
1974        iotry!(server.recv(&mut buf));
1975        // After establishing a new connection, the server's ids are a mirror of the client's.
1976        assert_eq!(
1977            server.receiver_connection_id,
1978            server.sender_connection_id + 1
1979        );
1980
1981        assert_eq!(server.state, SocketState::Connected);
1982
1983        let expected: Vec<u8> = vec![1, 2, 3];
1984        let mut received: Vec<u8> = vec![];
1985        loop {
1986            match server.recv_from(&mut buf).await {
1987                Ok((0, _src)) => break,
1988                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
1989                Err(e) => panic!("{:?}", e),
1990            }
1991        }
1992        assert_eq!(received.len(), expected.len());
1993        assert_eq!(received, expected);
1994
1995        child.await;
1996    }
1997
1998    #[async_std::test]
1999    async fn test_correct_packet_loss() {
2000        let server_addr = next_test_ip4();
2001
2002        let mut server = iotry!(UtpSocket::bind(server_addr));
2003        const LEN: usize = 1024 * 10;
2004        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2005        let to_send = data.clone();
2006
2007        let child = task::spawn(async move {
2008            let mut client = iotry!(UtpSocket::connect(server_addr));
2009
2010            // Send everything except the odd chunks
2011            let chunks = to_send[..].chunks(BUF_SIZE);
2012            let dst = client.connected_to;
2013            for (index, chunk) in chunks.enumerate() {
2014                let mut packet = Packet::with_payload(chunk);
2015                packet.set_seq_nr(client.seq_nr);
2016                packet.set_ack_nr(client.ack_nr);
2017                packet.set_connection_id(client.sender_connection_id);
2018                packet.set_timestamp(now_microseconds());
2019
2020                if index % 2 == 0 {
2021                    iotry!(client.socket.send_to(packet.as_ref(), dst));
2022                }
2023
2024                client.curr_window += packet.len() as u32;
2025                client.send_window.push(packet);
2026                client.seq_nr += 1;
2027            }
2028
2029            iotry!(client.close());
2030        });
2031
2032        let mut buf = vec![0u8; BUF_SIZE];
2033        let mut received: Vec<u8> = vec![];
2034        loop {
2035            match server.recv_from(&mut buf).await {
2036                Ok((0, _src)) => break,
2037                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
2038                Err(e) => panic!("{}", e),
2039            }
2040        }
2041        assert_eq!(received.len(), data.len());
2042        assert_eq!(received, data);
2043
2044        child.await;
2045    }
2046
2047    #[async_std::test]
2048    async fn test_tolerance_to_small_buffers() {
2049        let server_addr = next_test_ip4();
2050        let mut server = iotry!(UtpSocket::bind(server_addr));
2051        const LEN: usize = 1024;
2052        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2053        let to_send = data.clone();
2054
2055        let child = task::spawn(async move {
2056            let mut client = iotry!(UtpSocket::connect(server_addr));
2057            iotry!(client.send_to(&to_send[..]));
2058            iotry!(client.close());
2059        });
2060
2061        let mut read = Vec::new();
2062        while server.state != SocketState::Closed {
2063            let mut small_buffer = vec![0; 512];
2064            match server.recv_from(&mut small_buffer).await {
2065                Ok((0, _src)) => break,
2066                Ok((len, _src)) => read.extend(small_buffer[..len].to_vec()),
2067                Err(e) => panic!("{}", e),
2068            }
2069        }
2070
2071        assert_eq!(read.len(), data.len());
2072        assert_eq!(read, data);
2073
2074        child.await;
2075    }
2076
2077    #[async_std::test]
2078    async fn test_sequence_number_rollover() {
2079        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
2080
2081        let mut server = iotry!(UtpSocket::bind(server_addr));
2082
2083        const LEN: usize = BUF_SIZE * 4;
2084        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2085        let to_send = data.clone();
2086
2087        let child = task::spawn(async move {
2088            let mut client = iotry!(UtpSocket::bind(client_addr));
2089
2090            // Advance socket's sequence number
2091            client.seq_nr = ::std::u16::MAX - (to_send.len() / (BUF_SIZE * 2)) as u16;
2092
2093            let mut client = iotry!(UtpSocket::connect(server_addr));
2094            // Send enough data to rollover
2095            iotry!(client.send_to(&to_send[..]));
2096            // Check that the sequence number did rollover
2097            assert!(client.seq_nr < 50);
2098            // Close connection
2099            iotry!(client.close());
2100        });
2101
2102        let mut buf = vec![0u8; BUF_SIZE];
2103        let mut received: Vec<u8> = vec![];
2104        loop {
2105            match server.recv_from(&mut buf).await {
2106                Ok((0, _src)) => break,
2107                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
2108                Err(e) => panic!("{}", e),
2109            }
2110        }
2111        assert_eq!(received.len(), data.len());
2112        assert_eq!(received, data);
2113
2114        child.await;
2115    }
2116
2117    #[async_std::test]
2118    async fn test_drop_unused_socket() {
2119        let server_addr = next_test_ip4();
2120        let server = iotry!(UtpSocket::bind(server_addr));
2121
2122        // Explicitly dropping socket. This test should not hang.
2123        drop(server);
2124    }
2125
2126    #[async_std::test]
2127    async fn test_invalid_packet_on_connect() {
2128        use async_std::net::UdpSocket;
2129        let server_addr = next_test_ip4();
2130        let server = iotry!(UdpSocket::bind(server_addr));
2131
2132        let child = task::spawn(async move {
2133            let mut buf = vec![0u8; BUF_SIZE];
2134            match server.recv_from(&mut buf).await {
2135                Ok((_len, client_addr)) => {
2136                    iotry!(server.send_to(&[], client_addr));
2137                }
2138                _ => panic!(),
2139            }
2140        });
2141
2142        match UtpSocket::connect(server_addr).await {
2143            Err(ref e) if e.kind() == ErrorKind::Other => (), // OK
2144            Err(e) => panic!("Expected ErrorKind::Other, got {:?}", e),
2145            Ok(_) => panic!("Expected Err, got Ok"),
2146        }
2147
2148        child.await;
2149    }
2150
2151    #[async_std::test]
2152    async fn test_receive_unexpected_reply_type_on_connect() {
2153        use async_std::net::UdpSocket;
2154        let server_addr = next_test_ip4();
2155        let server = iotry!(UdpSocket::bind(server_addr));
2156
2157        let child = task::spawn(async move {
2158            let mut buf = vec![0u8; BUF_SIZE];
2159            let mut packet = Packet::new();
2160            packet.set_type(PacketType::Data);
2161
2162            match server.recv_from(&mut buf).await {
2163                Ok((_len, client_addr)) => {
2164                    iotry!(server.send_to(packet.as_ref(), client_addr));
2165                }
2166                _ => panic!(),
2167            }
2168        });
2169
2170        match UtpSocket::connect(server_addr).await {
2171            Err(ref e) if e.kind() == ErrorKind::ConnectionRefused => (), // OK
2172            Err(e) => panic!("Expected ErrorKind::ConnectionRefused, got {:?}", e),
2173            Ok(_) => panic!("Expected Err, got Ok"),
2174        }
2175
2176        child.await;
2177    }
2178
2179    #[async_std::test]
2180    async fn test_receiving_syn_on_established_connection() {
2181        // Establish connection
2182        let server_addr = next_test_ip4();
2183        let mut server = iotry!(UtpSocket::bind(server_addr));
2184
2185        let child = task::spawn(async move {
2186            let mut buf = vec![0; BUF_SIZE];
2187            loop {
2188                match server.recv_from(&mut buf).await {
2189                    Ok((0, _src)) => break,
2190                    Ok(_) => (),
2191                    Err(e) => panic!("{:?}", e),
2192                }
2193            }
2194        });
2195
2196        let mut client = iotry!(UtpSocket::connect(server_addr));
2197        let mut packet = Packet::new();
2198        packet.set_wnd_size(BUF_SIZE as u32);
2199        packet.set_type(PacketType::Syn);
2200        packet.set_connection_id(client.sender_connection_id);
2201        packet.set_seq_nr(client.seq_nr);
2202        packet.set_ack_nr(client.ack_nr);
2203        iotry!(client.socket.send_to(packet.as_ref(), server_addr));
2204        let mut buf = vec![0u8; BUF_SIZE];
2205        match client.socket.recv_from(&mut buf).await {
2206            Ok((len, _src)) => {
2207                let reply = Packet::try_from(&buf[..len]).ok().unwrap();
2208                assert_eq!(reply.get_type(), PacketType::Reset);
2209            }
2210            Err(e) => panic!("{:?}", e),
2211        }
2212        iotry!(client.close());
2213
2214        child.await;
2215    }
2216
    // A Reset injected by the peer over an established connection must surface from the
    // server's `recv_from` as an `ErrorKind::ConnectionReset` error.
    #[async_std::test]
    async fn test_receiving_reset_on_established_connection() {
        // Establish connection
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));

        // Client side: connect, send a raw Reset built from the client's current connection id
        // and sequence/ack numbers, then wait for one datagram in return.
        let child = task::spawn(async move {
            let client = iotry!(UtpSocket::connect(server_addr));
            let mut packet = Packet::new();
            packet.set_wnd_size(BUF_SIZE as u32);
            packet.set_type(PacketType::Reset);
            packet.set_connection_id(client.sender_connection_id);
            packet.set_seq_nr(client.seq_nr);
            packet.set_ack_nr(client.ack_nr);
            iotry!(client.socket.send_to(packet.as_ref(), server_addr));
            let mut buf = vec![0u8; BUF_SIZE];
            match client.socket.recv_from(&mut buf).await {
                Ok((_len, _src)) => (),
                Err(e) => panic!("{:?}", e),
            }
        });

        let mut buf = vec![0u8; BUF_SIZE];
        loop {
            match server.recv_from(&mut buf).await {
                Ok((0, _src)) => break,
                Ok(_) => (),
                // Success. Returning here skips `child.await`, so the child task is never
                // joined. NOTE(review): presumably deliberate, since the child may still be
                // blocked in `recv_from` waiting for a reply; confirm.
                Err(ref e) if e.kind() == ErrorKind::ConnectionReset => return,
                Err(e) => panic!("{:?}", e),
            }
        }
        // The stream ended normally without the Reset ever being reported: that is a failure.
        child.await;
        panic!("Should have received Reset");
    }
2251
2252    #[cfg(not(windows))]
2253    #[async_std::test]
2254    async fn test_premature_fin() {
2255        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
2256        let mut server = iotry!(UtpSocket::bind(server_addr));
2257
2258        const LEN: usize = BUF_SIZE * 4;
2259        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2260        let to_send = data.clone();
2261
2262        let child = task::spawn(async move {
2263            let mut client = iotry!(UtpSocket::connect(server_addr));
2264            iotry!(client.send_to(&to_send[..]));
2265            iotry!(client.close());
2266        });
2267
2268        let mut buf = vec![0u8; BUF_SIZE];
2269
2270        // Accept connection
2271        iotry!(server.recv(&mut buf));
2272
2273        // Send FIN without acknowledging packets received
2274        let mut packet = Packet::new();
2275        packet.set_connection_id(server.sender_connection_id);
2276        packet.set_seq_nr(server.seq_nr);
2277        packet.set_ack_nr(server.ack_nr);
2278        packet.set_timestamp(now_microseconds());
2279        packet.set_type(PacketType::Fin);
2280        iotry!(server.socket.send_to(packet.as_ref(), client_addr));
2281
2282        // Receive until end
2283        let mut received: Vec<u8> = vec![];
2284        loop {
2285            match server.recv_from(&mut buf).await {
2286                Ok((0, _src)) => break,
2287                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
2288                Err(e) => panic!("{}", e),
2289            }
2290        }
2291        assert_eq!(received.len(), data.len());
2292        assert_eq!(received, data);
2293
2294        child.await;
2295    }
2296
2297    #[async_std::test]
2298    async fn test_base_delay_calculation() {
2299        let minute_in_microseconds = 60 * 10i64.pow(6);
2300        let samples = vec![
2301            (0, 10),
2302            (1, 8),
2303            (2, 12),
2304            (3, 7),
2305            (minute_in_microseconds + 1, 11),
2306            (minute_in_microseconds + 2, 19),
2307            (minute_in_microseconds + 3, 9),
2308        ];
2309        let addr = next_test_ip4();
2310        let mut socket = UtpSocket::bind(addr).await.unwrap();
2311
2312        for (timestamp, delay) in samples {
2313            socket.update_base_delay(delay.into(), ((timestamp + delay) as u32).into());
2314        }
2315
2316        let expected = vec![7i64, 9i64]
2317            .into_iter()
2318            .map(Into::into)
2319            .collect::<Vec<_>>();
2320        let actual = socket.base_delays.iter().cloned().collect::<Vec<_>>();
2321        assert_eq!(expected, actual);
2322        assert_eq!(
2323            socket.min_base_delay(),
2324            expected.iter().min().cloned().unwrap_or_default()
2325        );
2326    }
2327
2328    #[async_std::test]
2329    async fn test_local_addr() {
2330        let addr = next_test_ip4();
2331        let addr = addr.to_socket_addrs().unwrap().next().unwrap();
2332        let socket = UtpSocket::bind(addr).await.unwrap();
2333
2334        assert!(socket.local_addr().is_ok());
2335        assert_eq!(socket.local_addr().unwrap(), addr);
2336    }
2337
2338    #[async_std::test]
2339    async fn test_listener_local_addr() {
2340        let addr = next_test_ip4();
2341        let addr = addr.to_socket_addrs().unwrap().next().unwrap();
2342        let listener = UtpListener::bind(addr).await.unwrap();
2343
2344        assert!(listener.local_addr().is_ok());
2345        assert_eq!(listener.local_addr().unwrap(), addr);
2346    }
2347
2348    #[async_std::test]
2349    async fn test_listener_listener_clone() {
2350        let addr = next_test_ip4();
2351        let addr = addr.to_socket_addrs().unwrap().next().unwrap();
2352
2353        // setup listener and clone to be used on two tasks
2354        let listener1 = UtpListener::bind(addr).await.unwrap();
2355        let listener2 = listener1.clone();
2356
2357        task::spawn(async move { listener1.accept().await.unwrap() });
2358        task::spawn(async move { listener2.accept().await.unwrap() });
2359
2360        // Connect twice - to each listerner
2361        task::spawn(async move {
2362            UtpSocket::connect(addr).await.unwrap();
2363            UtpSocket::connect(addr).await.unwrap();
2364        })
2365        .await;
2366    }
2367
2368    #[async_std::test]
2369    async fn test_peer_addr() {
2370        use std::sync::mpsc::channel;
2371        let addr = next_test_ip4();
2372        let server_addr = addr.to_socket_addrs().unwrap().next().unwrap();
2373        let mut server = UtpSocket::bind(server_addr).await.unwrap();
2374        let (tx, rx) = channel();
2375
2376        // `peer_addr` should return an error because the socket isn't connected yet
2377        assert!(server.peer_addr().is_err());
2378
2379        let child = task::spawn(async move {
2380            let mut client = iotry!(UtpSocket::connect(server_addr));
2381            let mut buf = vec![0; 1024];
2382            tx.send(client.local_addr()).unwrap();
2383            iotry!(client.recv_from(&mut buf));
2384        });
2385
2386        // Wait for a connection to be established
2387        let mut buf = vec![0; 1024];
2388        iotry!(server.recv(&mut buf));
2389
2390        // `peer_addr` should succeed and be equal to the client's address
2391        assert!(server.peer_addr().is_ok());
2392        // The client is expected to be bound to "0.0.0.0", so we can only check if the port is
2393        // correct
2394        let client_addr = rx.recv().unwrap().unwrap();
2395        assert_eq!(server.peer_addr().unwrap().port(), client_addr.port());
2396
2397        // Close the connection
2398        iotry!(server.close());
2399
2400        // `peer_addr` should now return an error because the socket is closed
2401        assert!(server.peer_addr().is_err());
2402
2403        child.await;
2404    }
2405
2406    #[async_std::test]
2407    async fn test_take_address() {
2408        // Expected successes
2409        assert!(take_address("0.0.0.0:0").await.is_ok());
2410        assert!(take_address("[::]:0").await.is_ok());
2411        assert!(take_address(("0.0.0.0", 0)).await.is_ok());
2412        assert!(take_address(("::", 0)).await.is_ok());
2413        assert!(take_address(("1.2.3.4", 5)).await.is_ok());
2414
2415        // Expected failures
2416        assert!(take_address("999.0.0.0:0").await.is_err());
2417        assert!(take_address("1.2.3.4:70000").await.is_err());
2418        assert!(take_address("").await.is_err());
2419        assert!(take_address("this is not an address").await.is_err());
2420        assert!(take_address("no.dns.resolution.com").await.is_err());
2421    }
2422
2423    // Test reaction to connection loss when sending data packets
2424    #[async_std::test]
2425    async fn test_connection_loss_data() {
2426        let server_addr = next_test_ip4();
2427        let mut server = iotry!(UtpSocket::bind(server_addr));
2428        // Decrease timeouts for faster tests
2429        server.congestion_timeout = 1;
2430        let attempts = server.max_retransmission_retries;
2431
2432        let child = task::spawn(async move {
2433            let mut client = iotry!(UtpSocket::connect(server_addr));
2434            iotry!(client.send_to(&[0]));
2435            // Simulate connection loss by killing the socket.
2436            client.state = SocketState::Closed;
2437            let ref socket = client.socket;
2438            let mut buf = vec![0u8; BUF_SIZE];
2439            iotry!(socket.recv_from(&mut buf));
2440            for _ in 0..attempts {
2441                match socket.recv_from(&mut buf).await {
2442                    Ok((len, _src)) => assert_eq!(
2443                        Packet::try_from(&buf[..len]).unwrap().get_type(),
2444                        PacketType::Data
2445                    ),
2446                    Err(e) => panic!("{}", e),
2447                }
2448            }
2449        });
2450
2451        // Drain incoming packets
2452        let mut buf = vec![0u8; BUF_SIZE];
2453        iotry!(server.recv_from(&mut buf));
2454
2455        iotry!(server.send_to(&[0]));
2456
2457        // Try to receive ACKs, time out too many times on flush, and fail with `TimedOut`
2458        let mut buf = vec![0u8; BUF_SIZE];
2459        match server.recv(&mut buf).await {
2460            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
2461            x => panic!("Expected Err(TimedOut), got {:?}", x),
2462        }
2463
2464        child.await;
2465    }
2466
2467    // Test reaction to connection loss when sending FIN
2468    #[async_std::test]
2469    async fn test_connection_loss_fin() {
2470        let server_addr = next_test_ip4();
2471        let mut server = iotry!(UtpSocket::bind(server_addr));
2472        // Decrease timeouts for faster tests
2473        server.congestion_timeout = 1;
2474        let attempts = server.max_retransmission_retries;
2475
2476        let child = task::spawn(async move {
2477            let mut client = iotry!(UtpSocket::connect(server_addr));
2478            iotry!(client.send_to(&[0]));
2479            // Simulate connection loss by killing the socket.
2480            client.state = SocketState::Closed;
2481            let ref socket = client.socket;
2482            let mut buf = vec![0u8; BUF_SIZE];
2483            iotry!(socket.recv_from(&mut buf));
2484            for _ in 0..attempts {
2485                match socket.recv_from(&mut buf).await {
2486                    Ok((len, _src)) => assert_eq!(
2487                        Packet::try_from(&buf[..len]).unwrap().get_type(),
2488                        PacketType::Fin
2489                    ),
2490                    Err(e) => panic!("{}", e),
2491                }
2492            }
2493        });
2494
2495        // Drain incoming packets
2496        let mut buf = vec![0u8; BUF_SIZE];
2497        iotry!(server.recv_from(&mut buf));
2498
2499        // Send FIN, time out too many times, and fail with `TimedOut`
2500        match server.close().await {
2501            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
2502            x => panic!("Expected Err(TimedOut), got {:?}", x),
2503        }
2504        child.await;
2505    }
2506
2507    // Test reaction to connection loss when waiting for data packets
2508    #[async_std::test]
2509    async fn test_connection_loss_waiting() {
2510        let server_addr = next_test_ip4();
2511        let mut server = iotry!(UtpSocket::bind(server_addr));
2512        // Decrease timeouts for faster tests
2513        server.congestion_timeout = 1;
2514        let attempts = server.max_retransmission_retries;
2515
2516        task::spawn(async move {
2517            let mut client = iotry!(UtpSocket::connect(server_addr));
2518            iotry!(client.send_to(&[0]));
2519            // Simulate connection loss by killing the socket.
2520            client.state = SocketState::Closed;
2521            let ref socket = client.socket;
2522            let seq_nr = client.seq_nr;
2523            let mut buf = vec![0u8; BUF_SIZE];
2524            for _ in 0..(3 * attempts) {
2525                match socket.recv_from(&mut buf).await {
2526                    Ok((len, _src)) => {
2527                        let packet = Packet::try_from(&buf[..len]).unwrap();
2528                        assert_eq!(packet.get_type(), PacketType::State);
2529                        assert_eq!(packet.ack_nr(), seq_nr - 1);
2530                    }
2531                    Err(e) => panic!("{}", e),
2532                }
2533            }
2534        });
2535
2536        // Drain incoming packets
2537        let mut buf = vec![0; BUF_SIZE];
2538        iotry!(server.recv_from(&mut buf));
2539
2540        // Try to receive data, time out too many times, and fail with `TimedOut`
2541        let mut buf = vec![0; BUF_SIZE];
2542        match server.recv_from(&mut buf).await {
2543            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
2544            x => panic!("Expected Err(TimedOut), got {:?}", x),
2545        }
2546    }
2547}