utp/
socket.rs

1use std::cmp::{max, min};
2use std::collections::VecDeque;
3use std::future::Future;
4use std::io::{ErrorKind, Result};
5use std::iter::Iterator;
6use std::mem;
7use std::net::SocketAddr;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::time::{Duration, Instant};
12
13use crate::error::SocketError;
14use crate::packet::*;
15use crate::time::*;
16use crate::util::*;
17
18use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
19use tokio::net::{lookup_host, ToSocketAddrs, UdpSocket};
20use tokio::sync::mpsc::{
21    unbounded_channel, UnboundedReceiver, UnboundedSender,
22};
23use tokio::sync::Mutex;
24use tokio::time::{sleep, timeout, Instant as TokioInstant, Sleep};
25
26use tracing::debug;
27
// For simplicity's sake, assume no packet ever exceeds the Ethernet maximum
// transfer unit of 1500 bytes.
const BUF_SIZE: usize = 1_500;

// Delay-based congestion control parameters (names follow RFC 6817).
const GAIN: f64 = 1.0;
const ALLOWED_INCREASE: u32 = 1;
/// Target delay in microseconds: 100 milliseconds.
const TARGET: f64 = 100.0 * 1_000.0;

/// Maximum segment size, in bytes.
const MSS: u32 = 1_400;
/// Lower bound of the congestion window, in segments.
const MIN_CWND: u32 = 2;
/// Initial congestion window, in segments.
const INIT_CWND: u32 = 2;

/// Initial congestion timeout, in milliseconds (one second).
const INITIAL_CONGESTION_TIMEOUT: u64 = 1_000;
/// Smallest congestion timeout, in milliseconds (500 ms).
const MIN_CONGESTION_TIMEOUT: u64 = 500;
/// Largest congestion timeout, in milliseconds (one minute).
const MAX_CONGESTION_TIMEOUT: u64 = 60 * 1_000;

/// Number of base-delay samples kept in the history.
const BASE_HISTORY: usize = 10;
/// Maximum number of SYN transmissions when connecting.
const MAX_SYN_RETRIES: u32 = 5;
/// Default maximum number of consecutive receive timeouts.
const MAX_RETRANSMISSION_RETRIES: u32 = 5;
/// Receive window advertised to the peer, in bytes (1 MiB).
const WINDOW_SIZE: u32 = 1 << 20;
44
45/// Maximum age of base delay sample (60 seconds)
46const MAX_BASE_DELAY_AGE: Delay = Delay(60_000_000);
47
/// Life-cycle state of a `UtpSocket` connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SocketState {
    /// Freshly created; nothing sent or received yet.
    New,
    /// Connection established.
    Connected,
    /// SYN sent, waiting for the peer's reply.
    SynSent,
    /// This side sent a FIN and is waiting for the connection to close.
    FinSent,
    /// The peer reset the connection.
    ResetReceived,
    /// Connection closed.
    Closed,
}
57
58struct DelayDifferenceSample {
59    received_at: Timestamp,
60    difference: Delay,
61}
62
63/// A structure that represents a uTP (Micro Transport Protocol) connection
64/// between a local socket and a remote socket.
65///
66/// The socket will be closed when the value is dropped (either explicitly or
67/// when it goes out of scope).
68///
69/// The default maximum retransmission retries is 5, which translates to about
70/// 16 seconds. It can be changed by assigning the desired maximum
71/// retransmission retries to a socket's `max_retransmission_retries` field.
72/// Notice that the initial congestion timeout is 500 ms and doubles with each
73/// timeout.
74struct UtpSocket {
75    /// The wrapped UDP socket
76    socket: UdpSocket,
77
78    /// Remote peer
79    connected_to: SocketAddr,
80
81    /// Sender connection identifier
82    sender_connection_id: u16,
83
84    /// Receiver connection identifier
85    receiver_connection_id: u16,
86
87    /// Sequence number for the next packet
88    seq_nr: u16,
89
90    /// Sequence number of the latest acknowledged packet sent by the remote peer
91    ack_nr: u16,
92
93    /// Socket state
94    state: SocketState,
95
96    /// Received but not acknowledged packets
97    incoming_buffer: Vec<Packet>,
98
99    /// Sent but not yet acknowledged packets
100    send_window: Vec<Packet>,
101
102    /// Packets not yet sent
103    unsent_queue: VecDeque<Packet>,
104
105    /// How many ACKs did the socket receive for packet with sequence number
106    /// equal to `ack_nr`
107    duplicate_ack_count: u32,
108
109    /// Sequence number of the latest packet the remote peer acknowledged
110    last_acked: u16,
111
112    /// Timestamp of the latest packet the remote peer acknowledged
113    last_acked_timestamp: Timestamp,
114
115    /// Sequence number of the last packet removed from the incoming buffer
116    last_dropped: u16,
117
118    /// Round-trip time to remote peer
119    rtt: i32,
120
121    /// Variance of the round-trip time to the remote peer
122    rtt_variance: i32,
123
124    /// Data from the latest packet not yet returned in `recv_from`
125    pending_data: Vec<u8>,
126
127    /// Bytes in flight
128    curr_window: u32,
129
130    /// Window size of the remote peer
131    remote_wnd_size: u32,
132
133    /// Rolling window of packet delay to remote peer
134    base_delays: VecDeque<Delay>,
135
136    /// Rolling window of the difference between sending a packet and receiving
137    /// its acknowledgement
138    current_delays: Vec<DelayDifferenceSample>,
139
140    /// Difference between timestamp of the latest packet received and time of
141    /// reception
142    their_delay: Delay,
143
144    /// Start of the current minute for sampling purposes
145    last_rollover: Timestamp,
146
147    /// Current congestion timeout in milliseconds
148    congestion_timeout: u64,
149
150    /// Congestion window in bytes
151    cwnd: u32,
152
153    /// Maximum retransmission retries
154    pub max_retransmission_retries: u32,
155}
156
157impl UtpSocket {
158    /// Creates a new UTP socket from the given UDP socket and the remote peer's
159    /// address.
160    ///
161    /// The connection identifier of the resulting socket is randomly generated.
162    fn from_raw_parts(s: UdpSocket, src: SocketAddr) -> UtpSocket {
163        let (receiver_id, sender_id) = generate_sequential_identifiers();
164
165        UtpSocket {
166            socket: s,
167            connected_to: src,
168            receiver_connection_id: receiver_id,
169            sender_connection_id: sender_id,
170            seq_nr: 1,
171            ack_nr: 0,
172            state: SocketState::New,
173            incoming_buffer: Vec::new(),
174            send_window: Vec::new(),
175            unsent_queue: VecDeque::new(),
176            duplicate_ack_count: 0,
177            last_acked: 0,
178            last_acked_timestamp: Timestamp::default(),
179            last_dropped: 0,
180            rtt: 0,
181            rtt_variance: 0,
182            pending_data: Vec::new(),
183            curr_window: 0,
184            remote_wnd_size: 0,
185            current_delays: Vec::new(),
186            base_delays: VecDeque::with_capacity(BASE_HISTORY),
187            their_delay: Delay::default(),
188            last_rollover: Timestamp::default(),
189            congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
190            cwnd: INIT_CWND * MSS,
191            max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
192        }
193    }
194
195    /// Creates a new UTP socket from the given address.
196    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpSocket> {
197        let src = lookup_host(&addr)
198            .await?
199            .last()
200            .ok_or(ErrorKind::AddrNotAvailable)?;
201        let socket = UdpSocket::bind(addr).await?;
202
203        Ok(UtpSocket::from_raw_parts(socket, src))
204    }
205
206    /// Returns the socket address that this socket was created from.
207    pub fn local_addr(&self) -> Result<SocketAddr> {
208        self.socket.local_addr()
209    }
210
211    /// Returns the socket address of the remote peer of this UTP connection.
212    pub fn peer_addr(&self) -> Result<SocketAddr> {
213        if self.state == SocketState::Connected
214            || self.state == SocketState::FinSent
215        {
216            Ok(self.connected_to)
217        } else {
218            Err(SocketError::NotConnected.into())
219        }
220    }
221
222    /// Opens a connection to a remote host by hostname or IP address.
223    pub async fn connect(addr: SocketAddr) -> Result<UtpSocket> {
224        let mut socket = UtpSocket::bind(addr).await?;
225        socket.connected_to = addr;
226
227        let mut packet = Packet::new();
228        packet.set_type(PacketType::Syn);
229        packet.set_connection_id(socket.receiver_connection_id);
230        packet.set_seq_nr(socket.seq_nr);
231
232        let mut len = 0;
233        let mut buf = [0; BUF_SIZE];
234
235        let mut syn_timeout = socket.congestion_timeout;
236        for _ in 0..MAX_SYN_RETRIES {
237            packet.set_timestamp(now_microseconds());
238
239            // Send packet
240            debug!("connecting to {}", socket.connected_to);
241            socket
242                .socket
243                .send_to(packet.as_ref(), socket.connected_to)
244                .await?;
245            socket.state = SocketState::SynSent;
246            debug!("sent {:?}", packet);
247
248            // Validate response
249            let to = Duration::from_millis(syn_timeout);
250
251            match timeout(to, socket.socket.recv_from(&mut buf)).await {
252                Ok(Ok((read, src))) => {
253                    socket.connected_to = src;
254                    len = read;
255                    break;
256                }
257                Ok(Err(e)) => return Err(e),
258                Err(_) => {
259                    debug!("timed out, retrying");
260                    syn_timeout *= 2;
261                    continue;
262                }
263            };
264        }
265
266        let addr = socket.connected_to;
267        let packet = Packet::try_from(&buf[..len])?;
268        debug!("received {:?}", packet);
269        socket.handle_packet(&packet, addr)?;
270
271        debug!("connected to: {}", socket.connected_to);
272
273        Ok(socket)
274    }
275
276    /// Gracefully closes connection to peer.
277    ///
278    /// This method allows both peers to receive all packets still in
279    /// flight.
280    pub async fn close(&mut self) -> Result<()> {
281        // Nothing to do if the socket's already closed or not connected
282        if self.state == SocketState::Closed
283            || self.state == SocketState::New
284            || self.state == SocketState::SynSent
285        {
286            return Ok(());
287        }
288
289        let local = self.socket.local_addr()?;
290
291        debug!("closing {} -> {}", local, self.connected_to);
292
293        // Flush unsent and unacknowledged packets
294        self.flush().await?;
295
296        debug!("close flush completed");
297
298        let mut packet = Packet::new();
299        packet.set_connection_id(self.sender_connection_id);
300        packet.set_seq_nr(self.seq_nr);
301        packet.set_ack_nr(self.ack_nr);
302        packet.set_timestamp(now_microseconds());
303        packet.set_type(PacketType::Fin);
304
305        // Send FIN
306        self.socket
307            .send_to(packet.as_ref(), self.connected_to)
308            .await?;
309        debug!("sent {:?}", packet);
310        self.state = SocketState::FinSent;
311
312        // Receive JAKE
313        let mut jbuf = [0; BUF_SIZE];
314        let mut buf: ReadBuf<'_> = ReadBuf::new(&mut jbuf);
315
316        while self.state != SocketState::Closed {
317            self.recv(&mut buf).await?;
318        }
319
320        debug!("closed {} -> {}", local, self.connected_to);
321
322        Ok(())
323    }
324
325    /// Receives data from socket.
326    ///
327    /// On success, returns the number of bytes read and the sender's address.
328    /// Returns 0 bytes read after receiving a FIN packet when the remaining
329    /// in-flight packets are consumed.
330    pub async fn recv_from(
331        &mut self,
332        buf: &mut [u8],
333    ) -> Result<(usize, SocketAddr)> {
334        let mut buf = ReadBuf::new(buf);
335        let read = self.flush_incoming_buffer(&mut buf);
336
337        if read > 0 {
338            Ok((read, self.connected_to))
339        } else {
340            // If the socket received a reset packet and all data has been
341            // flushed, then it can't receive anything else
342            if self.state == SocketState::ResetReceived {
343                return Err(SocketError::ConnectionReset.into());
344            }
345
346            loop {
347                // A closed socket with no pending data can only "read" 0 new
348                // bytes.
349                if self.state == SocketState::Closed {
350                    return Ok((0, self.connected_to));
351                }
352
353                match self.recv(&mut buf).await {
354                    Ok((0, _src)) => continue,
355                    Ok(x) => return Ok(x),
356                    Err(e) => return Err(e),
357                }
358            }
359        }
360    }
361
362    async fn recv(
363        &mut self,
364        buf: &mut ReadBuf<'_>,
365    ) -> Result<(usize, SocketAddr)> {
366        let mut b = [0; BUF_SIZE + HEADER_SIZE];
367        let start = Instant::now();
368        let read;
369        let src;
370        let mut retries = 0;
371
372        // Try to receive a packet and handle timeouts
373        loop {
374            // Abort loop if the current try exceeds the maximum number of
375            // retransmission retries.
376            if retries >= self.max_retransmission_retries {
377                self.state = SocketState::Closed;
378                return Err(SocketError::ConnectionTimedOut.into());
379            }
380
381            if self.state != SocketState::New {
382                let to = Duration::from_millis(self.congestion_timeout);
383                debug!(
384                    "setting read timeout of {} ms",
385                    self.congestion_timeout
386                );
387
388                match timeout(to, self.socket.recv_from(&mut b)).await {
389                    Ok(Ok((r, s))) => {
390                        read = r;
391                        src = s;
392                        break;
393                    }
394                    Ok(Err(e)) => return Err(e),
395                    Err(_) => {
396                        debug!("recv_from timed out");
397                        self.handle_receive_timeout().await?;
398                    }
399                };
400            } else {
401                match self.socket.recv_from(&mut b).await {
402                    Ok((r, s)) => {
403                        read = r;
404                        src = s;
405                        break;
406                    }
407                    Err(e) => return Err(e),
408                }
409            };
410
411            let elapsed = start.elapsed();
412            let elapsed_ms = elapsed.as_secs() * 1000
413                + (elapsed.subsec_millis() / 1_000_000) as u64;
414            debug!("{} ms elapsed", elapsed_ms);
415            retries += 1;
416        }
417
418        // Decode received data into a packet
419        let packet = match Packet::try_from(&b[..read]) {
420            Ok(packet) => packet,
421            Err(e) => {
422                debug!("{}", e);
423                debug!("Ignoring invalid packet");
424                return Ok((0, self.connected_to));
425            }
426        };
427        debug!("received {:?}", packet);
428
429        // Process packet, including sending a reply if necessary
430        if let Some(mut pkt) = self.handle_packet(&packet, src)? {
431            pkt.set_wnd_size(WINDOW_SIZE);
432            self.socket.send_to(pkt.as_ref(), src).await?;
433            debug!("sent {:?}", pkt);
434        }
435
436        // Insert data packet into the incoming buffer if it isn't a duplicate
437        // of a previously discarded packet
438        if packet.get_type() == PacketType::Data
439            && packet.seq_nr().wrapping_sub(self.last_dropped) > 0
440        {
441            self.insert_into_buffer(packet);
442        }
443
444        // Flush incoming buffer if possible
445        let read = self.flush_incoming_buffer(buf);
446
447        Ok((read, src))
448    }
449
450    async fn handle_receive_timeout(&mut self) -> Result<()> {
451        self.congestion_timeout *= 2;
452        self.cwnd = MSS;
453
454        // There are three possible cases here:
455        //
456        // - If the socket is sending and waiting for acknowledgements (the send
457        //   window is not empty), resend the first unacknowledged packet;
458        //
459        // - If the socket is not sending and it hasn't sent a FIN yet, then
460        //   it's waiting for incoming packets: send a fast resend request;
461        //
462        // - If the socket sent a FIN previously, resend it.
463        debug!(
464            "self.send_window: {:?}",
465            self.send_window
466                .iter()
467                .map(Packet::seq_nr)
468                .collect::<Vec<u16>>()
469        );
470
471        if self.send_window.is_empty() {
472            // The socket is trying to close, all sent packets were acknowledged,
473            // and it has already sent a FIN: resend it.
474            if self.state == SocketState::FinSent {
475                let mut packet = Packet::new();
476                packet.set_connection_id(self.sender_connection_id);
477                packet.set_seq_nr(self.seq_nr);
478                packet.set_ack_nr(self.ack_nr);
479                packet.set_timestamp(now_microseconds());
480                packet.set_type(PacketType::Fin);
481
482                // Send FIN
483                self.socket
484                    .send_to(packet.as_ref(), self.connected_to)
485                    .await?;
486                debug!("resent FIN: {:?}", packet);
487            } else if self.state != SocketState::New {
488                // The socket is waiting for incoming packets but the remote
489                // peer is silent: send a fast resend request.
490                debug!("sending fast resend request");
491                self.send_fast_resend_request();
492            }
493        } else {
494            // The socket is sending data packets but there is no reply from the
495            // remote peer: resend the first unacknowledged packet with the
496            // current timestamp.
497            let packet = &mut self.send_window[0];
498            packet.set_timestamp(now_microseconds());
499            self.socket
500                .send_to(packet.as_ref(), self.connected_to)
501                .await?;
502            debug!("resent {:?}", packet);
503        }
504
505        Ok(())
506    }
507
508    fn prepare_reply(&self, original: &Packet, t: PacketType) -> Packet {
509        let mut resp = Packet::new();
510        resp.set_type(t);
511        let self_t_micro = now_microseconds();
512        let other_t_micro = original.timestamp();
513        let time_difference: Delay = abs_diff(self_t_micro, other_t_micro);
514        resp.set_timestamp(self_t_micro);
515        resp.set_timestamp_difference(time_difference);
516        resp.set_connection_id(self.sender_connection_id);
517        resp.set_seq_nr(self.seq_nr);
518        resp.set_ack_nr(self.ack_nr);
519
520        resp
521    }
522
523    /// Removes a packet in the incoming buffer and updates the current
524    /// acknowledgement number.
525    fn advance_incoming_buffer(&mut self) -> Option<Packet> {
526        if !self.incoming_buffer.is_empty() {
527            let packet = self.incoming_buffer.remove(0);
528            debug!("Removed packet from incoming buffer: {:?}", packet);
529            self.ack_nr = packet.seq_nr();
530            self.last_dropped = self.ack_nr;
531            Some(packet)
532        } else {
533            None
534        }
535    }
536
537    /// Discards sequential, ordered packets in incoming buffer, starting from
538    /// the most recently acknowledged to the most recent, as long as there are
539    /// no missing packets. The discarded packets' payload is written to the
540    /// slice `buf`, starting in position `start`.
541    /// Returns the last written index.
542    fn flush_incoming_buffer(&mut self, buf: &mut ReadBuf) -> usize {
543        fn copy(src: &[u8], dst: &mut ReadBuf) -> usize {
544            let to_copy = min(src.len(), dst.capacity());
545
546            dst.put_slice(&src[..to_copy]);
547
548            to_copy
549        }
550        // Return pending data from a partially read packet
551        if !self.pending_data.is_empty() {
552            let flushed = copy(&self.pending_data[..], buf);
553
554            if flushed == self.pending_data.len() {
555                self.pending_data.clear();
556                self.advance_incoming_buffer();
557            } else {
558                self.pending_data = self.pending_data[flushed..].to_vec();
559            }
560
561            return flushed;
562        }
563
564        // only flush data that we acked (e.g. the packets are in order in the buffer)
565        if !self.incoming_buffer.is_empty()
566            && (self.ack_nr == self.incoming_buffer[0].seq_nr()
567                || self.ack_nr.wrapping_sub(self.incoming_buffer[0].seq_nr())
568                    >= 1)
569        {
570            let flushed = copy(&self.incoming_buffer[0].payload(), buf);
571
572            if flushed == self.incoming_buffer[0].payload().len() {
573                self.advance_incoming_buffer();
574            } else {
575                self.pending_data =
576                    self.incoming_buffer[0].payload()[flushed..].to_vec();
577            }
578
579            return flushed;
580        } else if !self.incoming_buffer.is_empty() {
581            debug!(
582                "not flushing out of order data, acked={} != cached={}",
583                self.ack_nr,
584                self.incoming_buffer[0].seq_nr()
585            );
586        }
587
588        0
589    }
590
591    /// Checks if any pending data can be read without any syscalls
592    pub(crate) fn should_read(&self) -> bool {
593        self.incoming_buffer.is_empty() && self.pending_data.is_empty()
594    }
595
596    /// Sends data on the socket to the remote peer. On success, returns the
597    /// number of bytes written.
598    //
599    // # Implementation details
600    //
601    // This method inserts packets into the send buffer and keeps trying to
602    // advance the send window until an ACK corresponding to the last packet is
603    // received.
604    //
605    // Note that the buffer passed to `send_to` might exceed the maximum packet
606    // size, which will result in the data being split over several packets.
607    pub async fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
608        if self.state == SocketState::Closed {
609            return Err(SocketError::ConnectionClosed.into());
610        }
611
612        let total_length = buf.len();
613
614        for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
615            let mut packet = Packet::with_payload(chunk);
616            packet.set_seq_nr(self.seq_nr);
617            packet.set_ack_nr(self.ack_nr);
618            packet.set_connection_id(self.sender_connection_id);
619
620            self.unsent_queue.push_back(packet);
621
622            // Intentionally wrap around sequence number
623            self.seq_nr = self.seq_nr.wrapping_add(1);
624        }
625
626        // Send every packet in the queue
627        self.send().await?;
628
629        Ok(total_length)
630    }
631
632    /// Consumes acknowledgements for every pending packet.
633    pub async fn flush(&mut self) -> Result<()> {
634        let mut buf = [0u8; BUF_SIZE];
635        let mut buf = ReadBuf::new(&mut buf);
636
637        while !self.send_window.is_empty() {
638            debug!("packets in send window: {}", self.send_window.len());
639            self.recv(&mut buf).await?;
640        }
641
642        Ok(())
643    }
644
645    /// Sends every packet in the unsent packet queue.
646    async fn send(&mut self) -> Result<()> {
647        while let Some(mut packet) = self.unsent_queue.pop_front() {
648            self.send_packet(&mut packet).await?;
649            self.curr_window += packet.len() as u32;
650            self.send_window.push(packet);
651        }
652        Ok(())
653    }
654
655    fn max_inflight(&self) -> u32 {
656        let max_inflight = min(self.cwnd, self.remote_wnd_size);
657        max(MIN_CWND * MSS, max_inflight)
658    }
659
660    /// Send one packet.
661    #[inline]
662    async fn send_packet(&mut self, packet: &mut Packet) -> Result<()> {
663        debug!("current window: {}", self.send_window.len());
664
665        packet.set_timestamp(now_microseconds());
666        packet.set_timestamp_difference(self.their_delay);
667
668        self.socket
669            .send_to(packet.as_ref(), self.connected_to)
670            .await?;
671
672        debug!("sent {:?}", packet);
673
674        Ok(())
675    }
676
677    // Insert a new sample in the base delay list.
678    //
679    // The base delay list contains at most `BASE_HISTORY` samples, each sample
680    // is the minimum measured over a period of a minute (MAX_BASE_DELAY_AGE).
681    fn update_base_delay(&mut self, base_delay: Delay, now: Timestamp) {
682        if self.base_delays.is_empty()
683            || now - self.last_rollover > MAX_BASE_DELAY_AGE
684        {
685            // Update last rollover
686            self.last_rollover = now;
687
688            // Drop the oldest sample, if need be
689            if self.base_delays.len() == BASE_HISTORY {
690                self.base_delays.pop_front();
691            }
692
693            // Insert new sample
694            self.base_delays.push_back(base_delay);
695        } else {
696            // Replace sample for the current minute if the delay is lower
697            let last_idx = self.base_delays.len() - 1;
698            if base_delay < self.base_delays[last_idx] {
699                self.base_delays[last_idx] = base_delay;
700            }
701        }
702    }
703
704    /// Inserts a new sample in the current delay list after removing samples
705    /// older than one RTT, as specified in RFC6817.
706    fn update_current_delay(&mut self, v: Delay, now: Timestamp) {
707        // Remove samples more than one RTT old
708        let rtt = (self.rtt as i64 * 100).into();
709        while !self.current_delays.is_empty()
710            && now - self.current_delays[0].received_at > rtt
711        {
712            self.current_delays.remove(0);
713        }
714
715        // Insert new measurement
716        self.current_delays.push(DelayDifferenceSample {
717            received_at: now,
718            difference: v,
719        });
720    }
721
722    fn update_congestion_timeout(&mut self, current_delay: i32) {
723        let delta = self.rtt - current_delay;
724        self.rtt_variance += (delta.abs() - self.rtt_variance) / 4;
725        self.rtt += (current_delay - self.rtt) / 8;
726        self.congestion_timeout = max(
727            (self.rtt + self.rtt_variance * 4) as u64,
728            MIN_CONGESTION_TIMEOUT,
729        );
730        self.congestion_timeout =
731            min(self.congestion_timeout, MAX_CONGESTION_TIMEOUT);
732
733        debug!("current_delay: {}", current_delay);
734        debug!("delta: {}", delta);
735        debug!("self.rtt_variance: {}", self.rtt_variance);
736        debug!("self.rtt: {}", self.rtt);
737        debug!("self.congestion_timeout: {}", self.congestion_timeout);
738    }
739
740    /// Calculates the filtered current delay in the current window.
741    ///
742    /// The current delay is calculated through application of the exponential
743    /// weighted moving average filter with smoothing factor 0.333 over the
744    /// current delays in the current window.
745    fn filtered_current_delay(&self) -> Delay {
746        let input = self.current_delays.iter().map(|delay| &delay.difference);
747        (ewma(input, 0.333) as i64).into()
748    }
749
750    /// Calculates the lowest base delay in the current window.
751    fn min_base_delay(&self) -> Delay {
752        self.base_delays.iter().min().cloned().unwrap_or_default()
753    }
754
755    /// Builds the selective acknowledgement extension data for usage in packets.
756    fn build_selective_ack(&self) -> Vec<u8> {
757        let stashed = self
758            .incoming_buffer
759            .iter()
760            .filter(|pkt| pkt.seq_nr() > self.ack_nr + 1)
761            .map(|pkt| (pkt.seq_nr() - self.ack_nr - 2) as usize)
762            .map(|diff| (diff / 8, diff % 8));
763
764        let mut sack = Vec::new();
765        for (byte, bit) in stashed {
766            // Make sure the amount of elements in the SACK vector is a
767            // multiple of 4 and enough to represent the lost packets
768            while byte >= sack.len() || sack.len() % 4 != 0 {
769                sack.push(0u8);
770            }
771
772            sack[byte] |= 1 << bit;
773        }
774
775        sack
776    }
777
778    /// Sends a fast resend request to the remote peer.
779    ///
780    /// A fast resend request consists of sending three State packets
781    /// (acknowledging the last received packet) in quick succession.
782    fn send_fast_resend_request(&mut self) {
783        for _ in 0..3usize {
784            let mut packet = Packet::new();
785            packet.set_type(PacketType::State);
786            let self_t_micro = now_microseconds();
787            packet.set_timestamp(self_t_micro);
788            packet.set_timestamp_difference(self.their_delay);
789            packet.set_connection_id(self.sender_connection_id);
790            packet.set_seq_nr(self.seq_nr);
791            packet.set_ack_nr(self.ack_nr);
792            self.unsent_queue.push_back(packet);
793        }
794    }
795
796    fn resend_lost_packet(&mut self, lost_packet_nr: u16) {
797        debug!("---> resend_lost_packet({}) <---", lost_packet_nr);
798        match self
799            .send_window
800            .iter()
801            .position(|pkt| pkt.seq_nr() == lost_packet_nr)
802        {
803            None => debug!("Packet {} not found", lost_packet_nr),
804            Some(position) => {
805                debug!("self.send_window.len(): {}", self.send_window.len());
806                debug!("position: {}", position);
807                let packet = self.send_window[position].clone();
808
809                self.unsent_queue.push_back(packet);
810
811                // We intentionally don't increase `curr_window` because
812                // otherwise a packet's length would be counted more than once
813            }
814        }
815        debug!("---> END resend_lost_packet <---");
816    }
817
818    /// Forgets sent packets that were acknowledged by the remote peer.
819    fn advance_send_window(&mut self) {
820        // The reason I'm not removing the first element in a loop while its
821        // sequence number is smaller than `last_acked` is because of wrapping
822        // sequence numbers, which would create the sequence [..., 65534, 65535,
823        // 0, 1, ...]. If `last_acked` is smaller than the first packet's
824        // sequence number because of wraparound (for instance, 1), no packets
825        // would be removed, as the condition `seq_nr < last_acked` would fail
826        // immediately.
827        //
828        // On the other hand, I can't keep removing the first packet in a loop
829        // until its sequence number matches `last_acked` because it might never
830        // match, and in that case no packets should be removed.
831        if let Some(position) = self
832            .send_window
833            .iter()
834            .position(|packet| packet.seq_nr() == self.last_acked)
835        {
836            for _ in 0..=position {
837                let packet = self.send_window.remove(0);
838                debug!("removing {} bytes from send window", packet.len());
839                debug!(
840                    "{} packets left in send window",
841                    self.send_window.len()
842                );
843                self.curr_window -= packet.len() as u32;
844            }
845        }
846        debug!("self.curr_window: {}", self.curr_window);
847    }
848
849    fn handle_fin_packet(
850        &mut self,
851        packet: &Packet,
852        src: SocketAddr,
853    ) -> Packet {
854        if packet.ack_nr() < self.seq_nr {
855            debug!("FIN received but there are missing acknowledgements for sent packets");
856        }
857        let mut reply = self.prepare_reply(packet, PacketType::State);
858        if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
859            debug!(
860                "current ack_nr({}) is behind received packet seq_nr ({})",
861                self.ack_nr,
862                packet.seq_nr()
863            );
864
865            // Set SACK extension payload if the packet is not in order
866            let sack = self.build_selective_ack();
867
868            if !sack.is_empty() {
869                debug!("sending SACK to peer");
870                reply.set_sack(sack);
871            }
872        }
873
874        debug!("received FIN from {}, connection is closed", src);
875
876        // Give up, the remote peer might not care about our missing packets
877        self.state = SocketState::Closed;
878        reply
879    }
880
881    /// Handles an incoming packet, updating socket state accordingly.
882    ///
883    /// Returns the appropriate reply packet, if needed.
884    fn handle_packet(
885        &mut self,
886        packet: &Packet,
887        src: SocketAddr,
888    ) -> Result<Option<Packet>> {
889        debug!("({:?}, {:?})", self.state, packet.get_type());
890
891        // Acknowledge only if the packet strictly follows the previous one
892        if packet.seq_nr().wrapping_sub(self.ack_nr) == 1 {
893            self.ack_nr = packet.seq_nr();
894        }
895
896        // Reset connection if connection id doesn't match and this isn't a SYN
897        if packet.get_type() != PacketType::Syn
898            && self.state != SocketState::SynSent
899            && !(packet.connection_id() == self.sender_connection_id
900                || packet.connection_id() == self.receiver_connection_id)
901        {
902            return Ok(Some(self.prepare_reply(packet, PacketType::Reset)));
903        }
904
905        // Update remote window size
906        self.remote_wnd_size = packet.wnd_size();
907        debug!("self.remote_wnd_size: {}", self.remote_wnd_size);
908
909        // Update remote peer's delay between them sending the packet and us
910        // receiving it
911        let now = now_microseconds();
912        self.their_delay = abs_diff(now, packet.timestamp());
913        debug!("self.their_delay: {}", self.their_delay);
914
915        match (self.state, packet.get_type()) {
916            (SocketState::New, PacketType::Syn) => {
917                self.connected_to = src;
918                self.ack_nr = packet.seq_nr();
919                self.seq_nr = rand::random();
920                self.receiver_connection_id = packet.connection_id() + 1;
921                self.sender_connection_id = packet.connection_id();
922                self.state = SocketState::Connected;
923                self.last_dropped = self.ack_nr;
924
925                Ok(Some(self.prepare_reply(packet, PacketType::State)))
926            }
927            (_, PacketType::Syn) => {
928                Ok(Some(self.prepare_reply(packet, PacketType::Reset)))
929            }
930            (SocketState::SynSent, PacketType::State) => {
931                self.connected_to = src;
932                self.ack_nr = packet.seq_nr();
933                self.seq_nr += 1;
934                self.state = SocketState::Connected;
935                self.last_acked = packet.ack_nr();
936                self.last_acked_timestamp = now_microseconds();
937                Ok(None)
938            }
939            (SocketState::SynSent, _) => Err(SocketError::InvalidReply.into()),
940            (SocketState::Connected, PacketType::Data)
941            | (SocketState::FinSent, PacketType::Data) => {
942                Ok(self.handle_data_packet(packet))
943            }
944            (SocketState::Connected, PacketType::State) => {
945                self.handle_state_packet(packet);
946                Ok(None)
947            }
948            (SocketState::Connected, PacketType::Fin)
949            | (SocketState::FinSent, PacketType::Fin) => {
950                Ok(Some(self.handle_fin_packet(packet, src)))
951            }
952            (SocketState::Closed, PacketType::Fin) => {
953                Ok(Some(self.prepare_reply(packet, PacketType::State)))
954            }
955            (SocketState::FinSent, PacketType::State) => {
956                if packet.ack_nr() == self.seq_nr {
957                    debug!("connection closed succesfully");
958                    self.state = SocketState::Closed;
959                } else {
960                    self.handle_state_packet(packet);
961                }
962                Ok(None)
963            }
964            (_, PacketType::Reset) => {
965                self.state = SocketState::ResetReceived;
966                Err(SocketError::ConnectionReset.into())
967            }
968            (state, ty) => {
969                let message = format!(
970                    "Unimplemented handling for ({:?},{:?})",
971                    state, ty
972                );
973                debug!("{}", message);
974                Err(SocketError::Other(message).into())
975            }
976        }
977    }
978
979    fn handle_data_packet(&mut self, packet: &Packet) -> Option<Packet> {
980        // If a FIN was previously sent, reply with a FIN packet acknowledging
981        // the received packet.
982        let packet_type = if self.state == SocketState::FinSent {
983            PacketType::Fin
984        } else {
985            PacketType::State
986        };
987        let mut reply = self.prepare_reply(packet, packet_type);
988
989        if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
990            debug!(
991                "current ack_nr ({}) is behind received packet seq_nr ({})",
992                self.ack_nr,
993                packet.seq_nr()
994            );
995
996            // Set SACK extension payload if the packet is not in order
997            let sack = self.build_selective_ack();
998
999            if !sack.is_empty() {
1000                debug!("sending SACK packet");
1001                reply.set_sack(sack);
1002            }
1003        }
1004
1005        Some(reply)
1006    }
1007
1008    fn queuing_delay(&self) -> Delay {
1009        let filtered_current_delay = self.filtered_current_delay();
1010        let min_base_delay = self.min_base_delay();
1011        let queuing_delay = filtered_current_delay - min_base_delay;
1012
1013        debug!("filtered_current_delay: {}", filtered_current_delay);
1014        debug!("min_base_delay: {}", min_base_delay);
1015        debug!("queuing_delay: {}", queuing_delay);
1016
1017        queuing_delay
1018    }
1019
1020    /// Calculates the new congestion window size, increasing it or decreasing it.
1021    ///
1022    /// This is the core of uTP, the [LEDBAT][ledbat_rfc] congestion algorithm.
1023    /// It depends on estimating the queuing delay between the two peers, and
1024    /// adjusting the congestion window accordingly.
1025    ///
1026    /// `off_target` is a normalized value representing the difference between
1027    /// the current queuing delay and a fixed target delay (`TARGET`).
1028    /// `off_target` ranges between -1.0 and 1.0. A positive value makes the
1029    /// congestion window increase, while a negative value makes the congestion
1030    /// window decrease.
1031    ///
1032    /// `bytes_newly_acked` is the number of bytes acknowledged by an inbound
1033    /// `State` packet. It may be the size of the packet explicitly acknowledged
1034    /// by the inbound packet (i.e., with sequence number equal to the inbound
1035    /// packet's acknowledgement number), or every packet implicitly
1036    /// acknowledged (every packet with sequence number between the previous
1037    /// inbound `State` packet's acknowledgement number and the current inbound
1038    /// `State` packet's acknowledgement number).
1039    ///
1040    ///[ledbat_rfc]: https://tools.ietf.org/html/rfc6817
1041    fn update_congestion_window(
1042        &mut self,
1043        off_target: f64,
1044        bytes_newly_acked: u32,
1045    ) {
1046        let flightsize = self.curr_window;
1047
1048        let cwnd_increase =
1049            GAIN * off_target * bytes_newly_acked as f64 * MSS as f64;
1050        let cwnd_increase = cwnd_increase / self.cwnd as f64;
1051        debug!("cwnd_increase: {}", cwnd_increase);
1052
1053        self.cwnd = (self.cwnd as f64 + cwnd_increase) as u32;
1054        let max_allowed_cwnd = flightsize + ALLOWED_INCREASE * MSS;
1055        self.cwnd = min(self.cwnd, max_allowed_cwnd);
1056        self.cwnd = max(self.cwnd, MIN_CWND * MSS);
1057
1058        debug!("cwnd: {}", self.cwnd);
1059        debug!("max_allowed_cwnd: {}", max_allowed_cwnd);
1060    }
1061
1062    fn handle_packet_extension(
1063        &mut self,
1064        packet: &Packet,
1065        packet_loss_detected: &mut bool,
1066    ) {
1067        // Process extensions, if any
1068        for extension in packet.extensions() {
1069            if extension.get_type() == ExtensionType::SelectiveAck {
1070                // If three or more packets are acknowledged past the implicit missing one,
1071                // assume it was lost.
1072                if extension.iter().count_ones() >= 3 {
1073                    self.resend_lost_packet(packet.ack_nr() + 1);
1074                    *packet_loss_detected = true;
1075                }
1076
1077                if let Some(last_seq_nr) =
1078                    self.send_window.last().map(Packet::seq_nr)
1079                {
1080                    let lost_packets = extension
1081                        .iter()
1082                        .enumerate()
1083                        .filter(|&(_, received)| !received)
1084                        .map(|(idx, _)| packet.ack_nr() + 2 + idx as u16)
1085                        .take_while(|&seq_nr| seq_nr < last_seq_nr);
1086
1087                    for seq_nr in lost_packets {
1088                        debug!("SACK: packet {} lost", seq_nr);
1089                        self.resend_lost_packet(seq_nr);
1090                        *packet_loss_detected = true;
1091                    }
1092                }
1093            } else {
1094                debug!(
1095                    "Unknown extension {:?}, ignoring",
1096                    extension.get_type()
1097                );
1098            }
1099        }
1100    }
1101
1102    fn handle_state_packet(&mut self, packet: &Packet) {
1103        if packet.ack_nr() == self.last_acked {
1104            self.duplicate_ack_count += 1;
1105        } else {
1106            self.last_acked = packet.ack_nr();
1107            self.last_acked_timestamp = now_microseconds();
1108            self.duplicate_ack_count = 1;
1109        }
1110
1111        // Update congestion window size
1112        if let Some(index) = self
1113            .send_window
1114            .iter()
1115            .position(|p| packet.ack_nr() == p.seq_nr())
1116        {
1117            // Calculate the sum of the size of every packet implicitly and
1118            // explicitly acknowledged by the inbound packet (i.e., every packet
1119            // whose sequence number precedes the inbound packet's
1120            // acknowledgement number, plus the packet whose sequence number
1121            // matches)
1122            let bytes_newly_acked = self
1123                .send_window
1124                .iter()
1125                .take(index + 1)
1126                .fold(0, |acc, p| acc + p.len());
1127
1128            // Update base and current delay
1129            let now = now_microseconds();
1130            let our_delay = now - self.send_window[index].timestamp();
1131            debug!("our_delay: {}", our_delay);
1132            self.update_base_delay(our_delay, now);
1133            self.update_current_delay(our_delay, now);
1134
1135            let off_target: f64 =
1136                (TARGET - u32::from(self.queuing_delay()) as f64) / TARGET;
1137            debug!("off_target: {}", off_target);
1138
1139            self.update_congestion_window(off_target, bytes_newly_acked as u32);
1140
1141            // Update congestion timeout in milliseconds
1142            let rtt = u32::from(our_delay - self.queuing_delay()) / 1000;
1143            self.update_congestion_timeout(rtt as i32);
1144        }
1145
1146        let mut packet_loss_detected: bool =
1147            !self.send_window.is_empty() && self.duplicate_ack_count == 3;
1148
1149        self.handle_packet_extension(packet, &mut packet_loss_detected);
1150
1151        // Three duplicate ACKs mean a fast resend request. Resend the first
1152        // unacknowledged packet if the incoming packet doesn't have a SACK
1153        // extension. If it does, the lost packets were already resent.
1154        if !self.send_window.is_empty()
1155            && self.duplicate_ack_count == 3
1156            && !packet
1157                .extensions()
1158                .any(|ext| ext.get_type() == ExtensionType::SelectiveAck)
1159        {
1160            self.resend_lost_packet(packet.ack_nr() + 1);
1161        }
1162
1163        // Packet lost, halve the congestion window
1164        if packet_loss_detected {
1165            debug!("packet loss detected, halving congestion window");
1166            self.cwnd = max(self.cwnd / 2, MIN_CWND * MSS);
1167            debug!("cwnd: {}", self.cwnd);
1168        }
1169
1170        // Success, advance send window
1171        self.advance_send_window();
1172    }
1173
1174    /// Inserts a packet into the socket's buffer.
1175    ///
1176    /// The packet is inserted in such a way that the packets in the buffer are
1177    /// sorted according to their sequence number in ascending order. This
1178    /// allows storing packets that were received out of order.
1179    ///
1180    /// Trying to insert a duplicate of a packet will silently fail.
1181    /// it's more recent (larger timestamp).
1182    fn insert_into_buffer(&mut self, packet: Packet) {
1183        // Immediately push to the end if the packet's sequence number comes
1184        // after the last packet's.
1185        if self
1186            .incoming_buffer
1187            .last()
1188            .map_or(false, |p| packet.seq_nr() > p.seq_nr())
1189        {
1190            self.incoming_buffer.push(packet);
1191        } else {
1192            // Find index following the most recent packet before the one we
1193            // wish to insert
1194            let i = self
1195                .incoming_buffer
1196                .iter()
1197                .filter(|p| p.seq_nr() < packet.seq_nr())
1198                .count();
1199
1200            if self
1201                .incoming_buffer
1202                .get(i)
1203                .map_or(true, |p| p.seq_nr() != packet.seq_nr())
1204            {
1205                self.incoming_buffer.insert(i, packet);
1206            }
1207        }
1208    }
1209}
1210
/// Polls a `Future` and returns from current function unless the future is
/// `Ready`
///
/// Evaluates to the future's output when it is `Ready`; otherwise the
/// enclosing `poll_*` function returns `Poll::Pending`.
///
/// `$data` is pinned in place with `Pin::new_unchecked`. That is only sound
/// if `$data` is never moved after being polled. Every call site visible in
/// this file passes a temporary expression (e.g. `self.socket.lock()`), which
/// is dropped in place at the end of the enclosing statement.
///
/// NOTE(review): the future is dropped right after this single poll. For
/// futures that deregister their waker when dropped (tokio's `Mutex::lock`
/// may be one; confirm), a `Pending` result here may not be followed by a
/// wake-up.
macro_rules! ready_unpin {
    ($data:expr, $cx:expr) => {
        // SAFETY: relies on `$data` not being moved after this poll (see the
        // macro documentation); temporaries satisfy this.
        match unsafe { Pin::new_unchecked(&mut $data) }.poll($cx) {
            Poll::Ready(v) => v,
            Poll::Pending => return Poll::Pending,
        }
    };
}
1221
/// Polls a `Future` whose output is a `Result`, returning from the current
/// function unless the future is `Ready` with an `Ok` value.
///
/// Evaluates to the `Ok` value. `Pending` is propagated as `Poll::Pending`,
/// and an `Err(e)` as `Poll::Ready(Err(e))`. Pinning follows the same rules
/// as `ready_unpin!`, which this macro builds on.
macro_rules! ready_try_unpin {
    ($data:expr, $cx:expr) => {
        match ready_unpin!($data, $cx) {
            Err(err) => return Poll::Ready(Err(err)),
            Ok(value) => value,
        }
    };
}
1232
/// Polls a `Future` while ensuring pinning
///
/// Polls `$data` once and evaluates to the resulting `Poll`, without
/// returning early from the caller.
///
/// The result is bound to a local (`let x`) on purpose. The temporary future
/// is then dropped at the end of that `let` statement, so any borrow it holds
/// (e.g. `&mut socket` in `socket.send_packet(..)`) ends before the caller
/// inspects the result. If the poll expression were the block's tail
/// expression, the temporary could outlive the block.
///
/// Same pinning contract as `ready_unpin!`: `$data` must not be moved after
/// being polled; the call sites visible in this file pass temporaries.
macro_rules! poll_unpin {
    ($data:expr, $cx:expr) => {{
        // `#[allow()]` lists no lints and has no effect.
        #[allow()]
        // SAFETY: `$data` is not moved after this poll (see above).
        let x = unsafe { Pin::new_unchecked(&mut $data) }.poll($cx);
        x
    }};
}
1241
/// Unwraps an already-computed `Poll<Result<T, E>>`.
///
/// Evaluates to `T` for `Poll::Ready(Ok(_))`. Otherwise it returns early from
/// the enclosing function: `Poll::Pending` is propagated as-is, and an error
/// as `Poll::Ready(Err(e))` via `?`.
macro_rules! ready_try {
    ($data:expr) => {{
        match ($data) {
            Poll::Ready(result) => result?,
            Poll::Pending => return Poll::Pending,
        }
    }};
}
1251
1252/// A reference to an existing `UtpSocket` that can be shared amongst multiple
1253/// tasks. This can't function unless the corresponding `UtpSocketDriver` is
1254/// scheduled to run on the same runtime.
1255pub struct UtpSocketRef(Arc<Mutex<UtpSocket>>, SocketAddr);
1256
1257impl UtpSocketRef {
1258    fn new(socket: Arc<Mutex<UtpSocket>>, local: SocketAddr) -> Self {
1259        Self(socket, local)
1260    }
1261
1262    /// Bind an unconnected `UtpSocket` on the given address.
1263    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
1264        let udp = UdpSocket::bind(addr).await?;
1265        let resolved = udp.local_addr()?;
1266        let socket = UtpSocket::from_raw_parts(udp, resolved);
1267        let lock = Arc::new(Mutex::new(socket));
1268
1269        debug!("bound utp socket on {}", resolved);
1270
1271        Ok(Self::new(lock, resolved))
1272    }
1273
1274    /// Connect to a remote host using this `UtpSocket`
1275    pub async fn connect(
1276        self,
1277        dst: SocketAddr,
1278    ) -> Result<(UtpStream, UtpStreamDriver)> {
1279        let mut socket = self.0.lock().await;
1280
1281        socket.connected_to = dst;
1282
1283        let mut packet = Packet::new();
1284        packet.set_type(PacketType::Syn);
1285        packet.set_connection_id(socket.receiver_connection_id);
1286        packet.set_seq_nr(socket.seq_nr);
1287
1288        let mut len = 0;
1289        let mut buf = [0; BUF_SIZE];
1290
1291        let mut syn_timeout = socket.congestion_timeout;
1292        for _ in 0..MAX_SYN_RETRIES {
1293            packet.set_timestamp(now_microseconds());
1294
1295            debug!("connecting to {}", socket.connected_to);
1296            let dst = socket.connected_to;
1297
1298            socket.socket.send_to(packet.as_ref(), dst).await?;
1299            socket.state = SocketState::SynSent;
1300            debug!("sent {:?}", packet);
1301
1302            let to = Duration::from_millis(syn_timeout);
1303
1304            match timeout(to, socket.socket.recv_from(&mut buf)).await {
1305                Ok(Ok((read, src))) => {
1306                    socket.connected_to = src;
1307                    len = read;
1308                    break;
1309                }
1310                Ok(Err(e)) => return Err(e),
1311                Err(_) => {
1312                    debug!("timed out, retrying");
1313                    syn_timeout *= 2;
1314                    continue;
1315                }
1316            };
1317        }
1318
1319        let remote = socket.connected_to;
1320        let packet = Packet::try_from(&buf[..len])?;
1321        debug!("received {:?}", packet);
1322        socket.handle_packet(&packet, remote)?;
1323
1324        debug!("connected to: {}", socket.connected_to);
1325
1326        let (tx, rx) = unbounded_channel();
1327
1328        let local = socket.local_addr()?;
1329
1330        mem::drop(socket);
1331
1332        let driver = UtpStreamDriver::new(self.0.clone(), tx);
1333        let stream = UtpStream::new(self.0, rx, local, remote);
1334
1335        Ok((stream, driver))
1336    }
1337
1338    /// Accept an incoming connection using this `UtpSocket`. This also
1339    /// returns a `UtpStreamDriver` that must be scheduled on a runtime
1340    /// in order for the associated `UtpStream` to work properly.
1341    /// Accepting a new connection will consume this listener.
1342    pub async fn accept(self) -> Result<(UtpStream, UtpStreamDriver)> {
1343        let (src, dst);
1344
1345        loop {
1346            let mut socket = self.0.lock().await;
1347            let mut buf = [0u8; BUF_SIZE];
1348
1349            let (read, remote) = socket.socket.recv_from(&mut buf).await?;
1350
1351            let packet = Packet::try_from(&buf[..read])?;
1352
1353            debug!("accept receive {:?}", packet);
1354
1355            if let Ok(Some(reply)) = socket.handle_packet(&packet, remote) {
1356                src = socket.socket.local_addr()?;
1357                dst = socket.connected_to;
1358
1359                socket.socket.send_to(reply.as_ref(), dst).await?;
1360
1361                debug!("sent {:?} to {}", reply, dst);
1362                debug!("accepted connection {} -> {}", dst, src);
1363                break;
1364            }
1365        }
1366
1367        let (tx, rx) = unbounded_channel();
1368        let socket = self.0;
1369        let stream = UtpStream::new(socket.clone(), rx, src, dst);
1370        let driver = UtpStreamDriver::new(socket, tx);
1371
1372        Ok((stream, driver))
1373    }
1374
1375    /// Get the local address for this `UtpSocket`
1376    pub fn local_addr(&self) -> SocketAddr {
1377        self.1
1378    }
1379}
1380
1381/// A `UtpStream` that can be used to read and write in a more convenient
1382/// fashion with the `AsyncRead` and `AsyncWrite` traits.
1383pub struct UtpStream {
1384    socket: Arc<Mutex<UtpSocket>>,
1385    receiver: UnboundedReceiver<Result<()>>,
1386    local: SocketAddr,
1387    remote: SocketAddr,
1388}
1389
1390impl UtpStream {
1391    fn new(
1392        socket: Arc<Mutex<UtpSocket>>,
1393        receiver: UnboundedReceiver<Result<()>>,
1394        local: SocketAddr,
1395        remote: SocketAddr,
1396    ) -> Self {
1397        Self {
1398            socket,
1399            receiver,
1400            local,
1401            remote,
1402        }
1403    }
1404
1405    /// Get the local address used by this `UtpStream`
1406    pub fn local_addr(&self) -> SocketAddr {
1407        self.local
1408    }
1409
1410    /// Get the address of the remote end of this `UtpStream`
1411    pub fn peer_addr(&self) -> SocketAddr {
1412        self.remote
1413    }
1414
1415    fn handle_driver_notification(
1416        mut self: Pin<&mut Self>,
1417        cx: &mut Context<'_>,
1418        buf: &mut ReadBuf,
1419    ) -> Poll<Result<()>> {
1420        match poll_unpin!(self.receiver.recv(), cx) {
1421            // either driver sender was dropped or disconnection notice
1422            Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
1423                debug!("connection driver has died");
1424                Poll::Ready(Ok(()))
1425            }
1426            Poll::Ready(Some(Ok(()))) => {
1427                debug!("notification from driver");
1428                self.poll_read(cx, buf)
1429            }
1430            Poll::Pending => {
1431                debug!("waiting for notification from driver");
1432                Poll::Pending
1433            }
1434        }
1435    }
1436
1437    fn prepare_packet(socket: &mut UtpSocket, chunk: &[u8]) -> Packet {
1438        let mut packet = Packet::with_payload(chunk);
1439
1440        packet.set_seq_nr(socket.seq_nr);
1441        packet.set_ack_nr(socket.ack_nr);
1442        packet.set_connection_id(socket.sender_connection_id);
1443
1444        packet
1445    }
1446
1447    fn handle_driver_message(
1448        msg: Poll<Option<Result<()>>>,
1449    ) -> Poll<Result<()>> {
1450        match msg {
1451            Poll::Ready(None) => {
1452                debug!("driver is dead, closing success");
1453                Poll::Ready(Ok(()))
1454            }
1455            Poll::Ready(Some(Ok(()))) => {
1456                debug!("driver sent closing notice");
1457                Poll::Ready(Ok(()))
1458            }
1459            Poll::Ready(Some(Err(e)))
1460                if e.kind() == ErrorKind::NotConnected =>
1461            {
1462                debug!("connection closed by err");
1463                Poll::Ready(Ok(()))
1464            }
1465            Poll::Ready(Some(Err(e))) => {
1466                debug!("failed to close correctly");
1467                Poll::Ready(Err(e))
1468            }
1469            Poll::Pending => {
1470                debug!("waiting for driver to complete closing");
1471                Poll::Pending
1472            }
1473        }
1474    }
1475
1476    fn wait_acks(
1477        socket: &mut UtpSocket,
1478        cx: &mut Context<'_>,
1479    ) -> Poll<Result<()>> {
1480        let mut buf = [0u8; BUF_SIZE + HEADER_SIZE];
1481
1482        debug!("waiting for ACKs for {} packets", socket.send_window.len());
1483
1484        while !socket.send_window.is_empty()
1485            && socket.state != SocketState::Closed
1486        {
1487            let (read, src) = {
1488                match poll_unpin!(socket.socket.recv_from(&mut buf), cx) {
1489                    Poll::Ready(Ok((read, src))) => (read, src),
1490                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
1491                    Poll::Pending => return Poll::Pending,
1492                }
1493            };
1494
1495            let packet = Packet::try_from(&buf[..read])?;
1496
1497            if let Some(reply) = socket.handle_packet(&packet, src)? {
1498                if poll_unpin!(socket.socket.send_to(reply.as_ref(), src), cx)
1499                    .is_pending()
1500                {
1501                    socket.unsent_queue.push_back(reply);
1502                    return Poll::Pending;
1503                }
1504            }
1505        }
1506
1507        Poll::Ready(Ok(()))
1508    }
1509
1510    fn flush_unsent(
1511        socket: &mut UtpSocket,
1512        cx: &mut Context<'_>,
1513    ) -> Poll<Result<()>> {
1514        while let Some(mut packet) = socket.unsent_queue.pop_front() {
1515            if poll_unpin!(socket.send_packet(&mut packet), cx).is_pending() {
1516                debug!("too many in flight packets, waiting for ack");
1517                return Poll::Pending;
1518            }
1519
1520            let result = {
1521                let dst = socket.connected_to;
1522                poll_unpin!(socket.socket.send_to(packet.as_ref(), dst), cx)
1523            };
1524
1525            match result {
1526                Poll::Pending => {
1527                    socket.unsent_queue.push_front(packet);
1528                    return Poll::Pending;
1529                }
1530                Poll::Ready(Ok(_)) => socket.send_window.push(packet),
1531                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
1532            }
1533        }
1534
1535        Poll::Ready(Ok(()))
1536    }
1537}
1538
1539impl AsyncRead for UtpStream
1540where
1541    Self: Unpin,
1542{
1543    fn poll_read(
1544        self: Pin<&mut Self>,
1545        cx: &mut Context,
1546        buf: &mut ReadBuf,
1547    ) -> Poll<Result<()>> {
1548        debug!("read poll for {} bytes", buf.capacity());
1549
1550        let (read, state) = {
1551            let mut socket = ready_unpin!(self.socket.lock(), cx);
1552
1553            (socket.flush_incoming_buffer(buf), socket.state)
1554        };
1555
1556        if read > 0 {
1557            debug!("flushed {} bytes of received data", read);
1558            Poll::Ready(Ok(()))
1559        } else if state == SocketState::Closed {
1560            debug!("read on closed connection");
1561            Poll::Ready(Ok(()))
1562        } else if state == SocketState::ResetReceived {
1563            debug!("read on reset connection");
1564            Poll::Ready(Err(SocketError::ConnectionReset.into()))
1565        } else {
1566            self.handle_driver_notification(cx, buf)
1567        }
1568    }
1569}
1570
1571impl AsyncWrite for UtpStream
1572where
1573    Self: Unpin,
1574{
1575    fn poll_write(
1576        mut self: Pin<&mut Self>,
1577        cx: &mut Context,
1578        buf: &[u8],
1579    ) -> Poll<Result<usize>> {
1580        let mut socket = ready_unpin!(self.socket.lock(), cx);
1581
1582        if socket.state == SocketState::Closed {
1583            debug!("tried to write on closed connection");
1584            return Poll::Ready(Err(SocketError::ConnectionClosed.into()));
1585        }
1586
1587        let mut sent: usize = 0;
1588
1589        debug!("trying to send {} bytes", buf.len());
1590
1591        for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
1592            if socket.curr_window >= socket.max_inflight() {
1593                debug!("send window is full, waiting for ACKs");
1594                mem::drop(socket);
1595
1596                while poll_unpin!(self.receiver.recv(), cx).is_ready() {}
1597
1598                return Poll::Pending;
1599            }
1600
1601            debug!("attempting to send chunk of {} byte", chunk.len());
1602
1603            let mut packet = Self::prepare_packet(&mut socket, chunk);
1604
1605            match poll_unpin!(socket.send_packet(&mut packet), cx) {
1606                Poll::Pending if sent == 0 => {
1607                    debug!("socket send buffer is full, waiting..");
1608                    return Poll::Pending;
1609                }
1610                Poll::Ready(Err(e)) if sent == 0 => {
1611                    debug!("os error reading data: {}", e);
1612                    return Poll::Ready(Err(e));
1613                }
1614                Poll::Pending | Poll::Ready(Err(_)) => {
1615                    debug!("successfully sent {} bytes, sleeping...", sent);
1616                    return Poll::Ready(Ok(sent));
1617                }
1618
1619                Poll::Ready(Ok(())) => {
1620                    let written = packet.len();
1621
1622                    socket.curr_window += written as u32;
1623                    socket.send_window.push(packet);
1624
1625                    sent += written;
1626                    socket.seq_nr = socket.seq_nr.wrapping_add(1);
1627
1628                    debug!(
1629                        "poll_write sent seq {}, curr_window: {}",
1630                        socket.seq_nr - 1,
1631                        socket.curr_window
1632                    );
1633                }
1634            }
1635        }
1636
1637        Poll::Ready(Ok(buf.len()))
1638    }
1639
1640    fn poll_flush(
1641        mut self: Pin<&mut Self>,
1642        cx: &mut Context,
1643    ) -> Poll<Result<()>> {
1644        debug!("attempting flush");
1645
1646        match poll_unpin!(self.receiver.recv(), cx) {
1647            Poll::Ready(Some(Err(e))) => {
1648                debug!("driver signaled error over channel");
1649                return Poll::Ready(Err(e));
1650            }
1651            Poll::Ready(None) => {
1652                debug!("connection driver disconnected");
1653                return Poll::Ready(Ok(()));
1654            }
1655            _ => debug!("no message from driver"),
1656        }
1657
1658        let mut socket = ready_unpin!(self.socket.lock(), cx);
1659
1660        if socket.state == SocketState::Closed {
1661            return Poll::Ready(Err(SocketError::NotConnected.into()));
1662        }
1663
1664        ready_try!(Self::flush_unsent(&mut socket, cx));
1665
1666        ready_try!(Self::wait_acks(&mut socket, cx));
1667
1668        debug!("sucessfully flushed");
1669
1670        Poll::Ready(Ok(()))
1671    }
1672
    /// Flush outstanding data and close the connection.
    ///
    /// Completes with `Ok(())` right away if the socket is already `Closed`.
    /// Otherwise `poll_flush` is driven first. Once it succeeds and no FIN has
    /// been sent yet (`state != FinSent`), `UtpSocket::close` is polled, and
    /// `Ok(())` is returned if it completes successfully. In every other case
    /// the task waits on the driver channel and the result is produced by
    /// `handle_driver_message`.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Result<()>> {
        debug!("poll_shutdown connection...");

        // Scoped so the lock is released before `poll_flush` locks again.
        {
            let socket = ready_unpin!(self.socket.lock(), cx);

            if socket.state == SocketState::Closed {
                debug!("socket closed by driver");
                return Poll::Ready(Ok(()));
            }
        }

        match self.as_mut().poll_flush(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => {
                {
                    let mut socket = ready_unpin!(self.socket.lock(), cx);

                    if socket.state != SocketState::FinSent {
                        // NOTE(review): a `Ready(Err(_))` from `close()` also
                        // takes the `else` branch, so the error is discarded
                        // and the driver channel is awaited instead. A new
                        // `close()` future is also created on every poll, so
                        // work it does before suspending is repeated on each
                        // wakeup. Confirm both are intended.
                        if let Poll::Ready(Ok(())) =
                            poll_unpin!(socket.close(), cx)
                        {
                            return Poll::Ready(Ok(()));
                        } else {
                            // Release the lock before waiting: the driver
                            // locks the same socket to make progress.
                            mem::drop(socket);
                            ready_unpin!(self.receiver.recv(), cx);
                        }
                    }
                }

                // NOTE(review): a message consumed by the `ready_unpin!`
                // above is discarded; only this next one is interpreted.
                let msg = poll_unpin!(self.receiver.recv(), cx);

                Self::handle_driver_message(msg)
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        }
    }
1713}
1714
#[must_use = "stream drivers must be spawned for the stream to work"]
/// A `Future` that handles all events related to a `UtpStream`: it receives
/// and answers packets and reacts to receive timeouts. The `UtpStream` will
/// neither send nor receive any data until this driver is spawned as a tokio
/// task.
pub struct UtpStreamDriver {
    /// Socket state shared with the `UtpStream`.
    socket: Arc<Mutex<UtpSocket>>,
    /// Channel to the `UtpStream`: `Ok(())` announces newly buffered data,
    /// `Err(_)` reports that the connection failed or was closed.
    sender: UnboundedSender<Result<()>>,
    /// Receive timer; its expiry is handled as a receive timeout.
    timer: Pin<Box<Sleep>>,
    /// Number of receive timeouts counted so far.
    timeout_nr: u32,
}
1725
1726impl UtpStreamDriver {
1727    fn new(
1728        socket: Arc<Mutex<UtpSocket>>,
1729        sender: UnboundedSender<Result<()>>,
1730    ) -> Self {
1731        Self {
1732            socket,
1733            sender,
1734            timer: Box::pin(sleep(Duration::from_millis(
1735                INITIAL_CONGESTION_TIMEOUT,
1736            ))),
1737            timeout_nr: 0,
1738        }
1739    }
1740
1741    async fn handle_timeout(&mut self, next_timeout: u64) -> Result<()> {
1742        self.timeout_nr += 1;
1743        debug!(
1744            "timed out {} times out of {} max, retrying in {} ms",
1745            self.timeout_nr, MAX_RETRANSMISSION_RETRIES, next_timeout
1746        );
1747
1748        if self.timeout_nr > MAX_RETRANSMISSION_RETRIES {
1749            let mut socket = self.socket.lock().await;
1750            socket.state = SocketState::Closed;
1751
1752            return Err(SocketError::ConnectionTimedOut.into());
1753        }
1754
1755        let ret = {
1756            let mut socket = self.socket.lock().await;
1757            socket.handle_receive_timeout().await
1758        };
1759
1760        self.reset_timer(Duration::from_millis(next_timeout));
1761
1762        ret
1763    }
1764
1765    fn notify_close(&mut self) {
1766        if self
1767            .sender
1768            .send(Err(SocketError::NotConnected.into()))
1769            .is_err()
1770        {
1771            error!("failed to notify socket of termination");
1772        } else {
1773            debug!("notified socket of closing");
1774        }
1775    }
1776
1777    fn send_reply(
1778        socket: &mut UtpSocket,
1779        mut reply: Packet,
1780        cx: &mut Context<'_>,
1781    ) -> Poll<Result<()>> {
1782        match poll_unpin!(socket.send_packet(&mut reply), cx) {
1783            Poll::Pending => {
1784                socket.unsent_queue.push_back(reply);
1785                Poll::Pending
1786            }
1787            Poll::Ready(Err(e)) => {
1788                error!("driver failed to send packet: {}", e);
1789                Poll::Ready(Err(e))
1790            }
1791            _ => Poll::Ready(Ok(())),
1792        }
1793    }
1794
1795    fn reset_timer(&mut self, next_timeout: Duration) {
1796        let now = TokioInstant::from_std(Instant::now());
1797
1798        self.timer.as_mut().reset(now + next_timeout);
1799    }
1800
1801    fn check_timeout(
1802        &mut self,
1803        cx: &mut Context<'_>,
1804        next_timeout: u64,
1805    ) -> Poll<Result<()>> {
1806        if self.timer.is_elapsed() {
1807            debug!("receive timeout detected");
1808
1809            match poll_unpin!(self.handle_timeout(next_timeout), cx) {
1810                Poll::Pending => todo!("socket buffer full"),
1811                Poll::Ready(Ok(())) => {
1812                    self.reset_timer(Duration::from_millis(next_timeout));
1813
1814                    ready_unpin!(self.timer, cx);
1815
1816                    Poll::Pending
1817                }
1818                Poll::Ready(Err(e)) => {
1819                    debug!("remote peer timed out too many times");
1820                    self.sender
1821                        .send(Err(e.kind().into()))
1822                        .expect("failed to propagate");
1823
1824                    Poll::Ready(Err(e))
1825                }
1826            }
1827        } else {
1828            ready_unpin!(self.timer, cx);
1829            Poll::Pending
1830        }
1831    }
1832}
1833
1834impl Future for UtpStreamDriver {
1835    type Output = Result<()>;
1836
1837    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
1838        let sender = self.sender.clone();
1839        let mut socket = ready_unpin!(self.socket.lock(), cx);
1840        let mut buf = [0u8; BUF_SIZE + HEADER_SIZE];
1841
1842        loop {
1843            debug!("stream driver poll attempt");
1844
1845            if socket.state == SocketState::Closed {
1846                debug!("socket is closed when attempting poll, killing driver");
1847
1848                mem::drop(socket);
1849
1850                self.notify_close();
1851
1852                return Poll::Ready(Ok(()));
1853            }
1854
1855            match poll_unpin!(socket.socket.recv_from(&mut buf), cx) {
1856                Poll::Ready(Ok((read, src))) => {
1857                    if let Ok(packet) = Packet::try_from(&buf[..read]) {
1858                        debug!("received packet {:?}", packet);
1859
1860                        match socket.handle_packet(&packet, src) {
1861                            Ok(Some(reply)) => {
1862                                if let PacketType::Data = packet.get_type() {
1863                                    socket.insert_into_buffer(packet);
1864
1865                                    // notify socket that data is available
1866                                    if sender.send(Ok(())).is_err() {
1867                                        debug!(
1868                                            "dropped socket, killing driver"
1869                                        );
1870                                        return Poll::Ready(Ok(()));
1871                                    }
1872                                }
1873
1874                                if Self::send_reply(&mut socket, reply, cx)
1875                                    .is_pending()
1876                                {
1877                                    return Poll::Pending;
1878                                }
1879                            }
1880                            Ok(None) => ready_try_unpin!(socket.send(), cx),
1881                            Err(e) => return Poll::Ready(Err(e)),
1882                        }
1883                    }
1884                }
1885                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
1886                Poll::Pending => {
1887                    let next_timeout = socket.congestion_timeout * 2;
1888
1889                    mem::drop(socket);
1890
1891                    return self.check_timeout(cx, next_timeout);
1892                }
1893            }
1894        }
1895    }
1896}
1897
impl Drop for UtpSocket {
    /// Intended as a best-effort close when the socket is dropped.
    ///
    /// NOTE(review): elsewhere in this file `close()` is polled via
    /// `poll_unpin!` and `.await`ed, i.e. it returns a future. Here that
    /// future is dropped without ever being polled, so unless `close` does
    /// work before returning its future, nothing is sent to the peer. A real
    /// close-on-drop needs a synchronous send or a spawned task — confirm.
    fn drop(&mut self) {
        let _ = self.close();
    }
}
1903
/// Buffer capacity, in bytes, given to `BufferedUtpStream` (the Ethernet MTU).
const MTU: usize = 1_500;
1905
/// A buffered `UtpStream` meant to avoid making too many system calls when
/// sending small amounts of data through uTP.
///
/// NOTE(review): the wrapped `BufReader` is bypassed for reads (see the
/// `AsyncRead` impl), and per the tokio docs a `BufReader` forwards writes
/// straight to the inner stream. As written, this type therefore buffers in
/// neither direction; coalescing small writes would need a `BufWriter`.
pub struct BufferedUtpStream {
    stream: BufReader<UtpStream>,
}
1911
1912impl BufferedUtpStream {
1913    /// Create a new `BufferedUtpStream` from an open `UtpStream`
1914    pub fn new(stream: UtpStream) -> Self {
1915        Self {
1916            stream: BufReader::with_capacity(MTU, stream),
1917        }
1918    }
1919
1920    fn get_stream(self: Pin<&mut Self>) -> Pin<&mut BufReader<UtpStream>> {
1921        unsafe { self.map_unchecked_mut(|s| &mut s.stream) }
1922    }
1923
1924    /// Get the local address for this `BufferedUtpStream`
1925    pub fn local_addr(&self) -> Result<SocketAddr> {
1926        Ok(self.stream.get_ref().local_addr())
1927    }
1928
1929    /// Get the peer address for this `BufferedUtpStream`
1930    pub fn peer_addr(&self) -> Result<SocketAddr> {
1931        Ok(self.stream.get_ref().peer_addr())
1932    }
1933}
1934
1935impl AsyncRead for BufferedUtpStream {
1936    fn poll_read(
1937        self: Pin<&mut Self>,
1938        cx: &mut Context,
1939        buf: &mut ReadBuf,
1940    ) -> Poll<Result<()>> {
1941        // bypass buffering when reading since UtpStream already buffers
1942        self.get_stream().get_pin_mut().poll_read(cx, buf)
1943    }
1944}
1945
1946impl AsyncWrite for BufferedUtpStream {
1947    fn poll_write(
1948        self: Pin<&mut Self>,
1949        cx: &mut Context,
1950        buf: &[u8],
1951    ) -> Poll<Result<usize>> {
1952        self.get_stream().poll_write(cx, buf)
1953    }
1954
1955    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
1956        self.get_stream().poll_flush(cx)
1957    }
1958
1959    fn poll_shutdown(
1960        self: Pin<&mut Self>,
1961        cx: &mut Context,
1962    ) -> Poll<Result<()>> {
1963        self.get_stream().poll_shutdown(cx)
1964    }
1965}
1966
1967#[cfg(test)]
1968mod test {
1969    use std::env;
1970    use std::io::ErrorKind;
1971    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
1972
1973    use std::sync::atomic::Ordering;
1974
1975    use super::*;
1976    use crate::socket::{SocketState, UtpSocket, BUF_SIZE};
1977    use crate::time::now_microseconds;
1978
1979    use tokio::io::{AsyncReadExt, AsyncWriteExt};
1980    use tokio::task;
1981    use tokio::time::interval;
1982
1983    use tracing::debug_span;
1984    use tracing_futures::Instrument;
1985    use tracing_subscriber::FmtSubscriber;
1986
1987    macro_rules! iotry {
1988        ($e:expr) => {
1989            match $e.await {
1990                Ok(e) => e,
1991                Err(e) => panic!("{:?}", e),
1992            }
1993        };
1994    }
1995
1996    fn init_logger() {
1997        if let Some(level) = env::var("RUST_LOG").ok().map(|x| x.parse().ok()) {
1998            let subscriber =
1999                FmtSubscriber::builder().with_max_level(level).finish();
2000
2001            let _ = tracing::subscriber::set_global_default(subscriber);
2002        }
2003    }
2004
2005    fn next_test_port() -> u16 {
2006        use std::sync::atomic::AtomicUsize;
2007        static NEXT_OFFSET: AtomicUsize = AtomicUsize::new(0);
2008        const BASE_PORT: u16 = 9600;
2009        BASE_PORT + NEXT_OFFSET.fetch_add(1, Ordering::Relaxed) as u16
2010    }
2011
2012    fn next_test_ip4() -> SocketAddr {
2013        ("127.0.0.1".parse::<Ipv4Addr>().unwrap(), next_test_port()).into()
2014    }
2015
2016    fn next_test_ip6() -> SocketAddr {
2017        ("::1".parse::<Ipv6Addr>().unwrap(), next_test_port()).into()
2018    }
2019
2020    async fn stream_accept(server_addr: SocketAddr) -> UtpStream {
2021        let (stream, driver) = UtpSocketRef::bind(server_addr)
2022            .await
2023            .expect("failed to bind")
2024            .accept()
2025            .await
2026            .expect("failed to accept");
2027
2028        task::spawn(driver.instrument(debug_span!("stream_driver")));
2029
2030        stream
2031    }
2032
2033    async fn stream_connect(local: SocketAddr, peer: SocketAddr) -> UtpStream {
2034        let socket = UtpSocketRef::bind(local).await.expect("failed to bind");
2035        let (stream, driver) =
2036            socket.connect(peer).await.expect("failed to connect");
2037
2038        task::spawn(driver.instrument(debug_span!("stream_driver")));
2039
2040        stream
2041    }
2042
    /// The server writes `LEN` bytes and shuts down. Before spawning its
    /// driver, the client calls `UtpSocket::recv_from` directly and discards
    /// the result, which (per the inline comment) drops the received packet
    /// so that fast resend is triggered. Both shutdowns must then succeed.
    #[tokio::test]
    async fn stream_fast_resend_active() {
        init_logger();
        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();
        const DATA: u8 = 2;
        const LEN: usize = 345;

        let socket =
            UtpSocketRef::bind(server_addr).await.expect("bind failed");

        let handle = task::spawn(async {
            let buf = [DATA; LEN];
            let (mut stream, driver) =
                socket.accept().await.expect("accept failed");

            task::spawn(driver);

            stream.write_all(&buf).await.expect("write failed");
            stream.shutdown().await.expect("shutdown failed");
        });

        let (mut stream, driver) = UtpSocketRef::bind(client_addr)
            .await
            .expect("bind failed")
            .connect(server_addr)
            .await
            .expect("connect failed");

        {
            let mut lock = stream.socket.lock().await;
            let mut buf = [0u8; LEN];

            // intentionally drop the received packet to trigger fast_resend
            lock.recv_from(&mut buf).await.expect("read failed");
        }

        task::spawn(driver);

        stream.shutdown().await.expect("close failed");

        handle.await.expect("task failure");
    }
2086
    /// Connect and immediately shut down both ends; both shutdowns must
    /// succeed.
    #[tokio::test]
    async fn stream_connect_disconnect() {
        init_logger();
        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();

        let handle = task::spawn(async move {
            let mut stream = stream_accept(server_addr).await;

            stream.shutdown().await.expect("failed to close");
        });

        let mut stream = stream_connect(client_addr, server_addr).await;

        stream.shutdown().await.expect("failed to close connection");

        handle.await.expect("task failure");
    }
2105
    /// The client writes `LEN` (2000) bytes, more than `BUF_SIZE`, so the
    /// data presumably spans several packets. The server reads them back and
    /// checks that every byte is unchanged. Currently `#[ignore]`d.
    #[tokio::test]
    #[ignore]
    async fn stream_packet_split() {
        init_logger();

        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();
        const LEN: usize = 2000;
        const DATA: u8 = 1;

        let handle = task::spawn(async move {
            let mut stream = stream_accept(server_addr)
                .instrument(debug_span!("server"))
                .await;

            let mut buf = [0u8; LEN];

            stream
                .read_exact(&mut buf)
                .instrument(debug_span!("server_read_exact"))
                .await
                .expect("read failed");

            for b in &buf[..] {
                assert_eq!(*b, DATA, "data was altered");
            }

            stream
                .shutdown()
                .instrument(debug_span!("server_shutdown"))
                .await
                .expect("flush failed");
        });

        let mut stream = stream_connect(client_addr, server_addr)
            .instrument(debug_span!("client"))
            .await;
        let buf = [DATA; LEN];

        stream
            .write_all(&buf)
            .instrument(debug_span!("client_write_all"))
            .await
            .expect("write failed");

        stream
            .shutdown()
            .instrument(debug_span!("client_shutdown"))
            .await
            .expect("close failed");

        handle.await.expect("task failure")
    }
2159
    /// After both sides have exchanged `LEN` bytes and shut down, a further
    /// read on the server and a further write on the client must both fail.
    /// Currently `#[ignore]`d.
    #[tokio::test]
    #[ignore]
    async fn stream_closed_write() {
        init_logger();
        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();

        const LEN: usize = 1240;
        const DATA: u8 = 12;

        let handle = task::spawn(async move {
            let mut stream = stream_accept(server_addr)
                .instrument(debug_span!("server"))
                .await;
            let mut buf = [0u8; LEN];

            stream
                .read_exact(&mut buf)
                .instrument(debug_span!("server_read_exact"))
                .await
                .expect("read failed");

            stream
                .shutdown()
                .instrument(debug_span!("server_shutdown"))
                .await
                .expect("shutdown failed");

            stream
                .read_exact(&mut buf)
                .instrument(debug_span!("server_closed_read"))
                .await
                .expect_err("read on closed stream");
        });

        let mut stream = stream_connect(client_addr, server_addr)
            .instrument(debug_span!("client"))
            .await;
        let buf = [DATA; LEN];

        stream
            .write_all(&buf)
            .instrument(debug_span!("client_write_all"))
            .await
            .expect("write failed");
        stream
            .shutdown()
            .instrument(debug_span!("client_shutdown"))
            .await
            .expect("shutdown failed");

        stream
            .write_all(&buf)
            .instrument(debug_span!("client_closed_write"))
            .await
            .expect_err("wrote on closed stream");

        handle.await.expect("execution failure");
    }
2219
2220    #[tokio::test]
2221    async fn stream_fast_resend_idle() {
2222        init_logger();
2223        let server = next_test_ip4();
2224        let client = next_test_ip4();
2225
2226        let handle = task::spawn(async move {
2227            let mut stream = stream_accept(server).await;
2228
2229            let mut timer = interval(Duration::from_secs(3));
2230
2231            timer.tick().await;
2232
2233            stream.shutdown().await.expect("close failed");
2234        });
2235
2236        let mut stream = stream_connect(client, server).await;
2237        let mut timer = interval(Duration::from_secs(4));
2238
2239        timer.tick().await;
2240
2241        stream.shutdown().await.expect("close failed");
2242
2243        handle.await.expect("task failed");
2244    }
2245
    /// The server writes `LEN` bytes and shuts down. The client reads exactly
    /// `LEN` bytes and shuts down. Every step must succeed.
    #[tokio::test]
    async fn stream_clean_close() {
        init_logger();
        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();

        const DATA: u8 = 1;
        const LEN: usize = 1024;

        let handle = task::spawn(async move {
            let mut stream = stream_accept(server_addr)
                .instrument(debug_span!("stream_accept"))
                .await;
            let buf = [DATA; LEN];

            stream
                .write_all(&buf)
                .instrument(debug_span!("server_write"))
                .await
                .expect("write failed");

            stream
                .shutdown()
                .instrument(debug_span!("server_shutdown"))
                .await
                .expect("shutdown failed");
        });

        let mut socket = stream_connect(client_addr, server_addr)
            .instrument(debug_span!("stream_connect"))
            .await;
        let mut buf = [0u8; LEN];

        socket
            .read_exact(&mut buf)
            .instrument(debug_span!("client_read"))
            .await
            .expect("read failed");

        socket
            .shutdown()
            .instrument(debug_span!("client_shutdown"))
            .await
            .expect("shutdown failed");

        handle.await.expect("task panic");
    }
2293
    /// Connecting to an address where nothing is listening must fail. The
    /// congestion timeout is lowered to 100, presumably so the failure comes
    /// quickly.
    #[tokio::test]
    async fn stream_connect_timeout() {
        init_logger();
        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();

        let socket =
            UtpSocketRef::bind(client_addr).await.expect("bind failed");

        socket.0.lock().await.congestion_timeout = 100;

        assert!(
            socket.connect(server_addr).await.is_err(),
            "connected to void"
        );
    }
2310
    /// The server accepts but discards its stream and driver, so it never
    /// answers. A read on the client (with the congestion timeout lowered to
    /// 100) must eventually fail.
    #[tokio::test]
    async fn stream_read_timeout() {
        init_logger();
        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();

        let handle = task::spawn(async move {
            let sock =
                UtpSocketRef::bind(server_addr).await.expect("bind failed");

            // ignore driver so that this stream doesn't answer to packets
            let _ = sock.accept().await.expect("accept failed");
        });

        let mut socket = stream_connect(client_addr, server_addr).await;
        let mut buf = [0u8; 1024];

        socket.socket.lock().await.congestion_timeout = 100;

        socket
            .read_exact(&mut buf)
            .await
            .expect_err("read from non responding peer");

        handle.await.expect("task panic");
    }
2337
    /// The client connects and immediately discards its stream and driver, so
    /// nothing acknowledges the server's data. The server's `write_all` must
    /// still succeed (the data is only buffered), but `flush` must fail.
    #[tokio::test]
    async fn stream_write_timeout() {
        init_logger();

        let (server, client) = (next_test_ip4(), next_test_ip4());
        const DATA: u8 = 45;
        const LEN: usize = 123;

        let handle = task::spawn(async move {
            let mut stream = stream_accept(server).await;
            let buf = [DATA; LEN];

            stream.socket.lock().await.congestion_timeout = 100;

            stream
                .write_all(&buf)
                .await
                .expect("packets weren't buffered");
            stream
                .flush()
                .await
                .expect_err("flush succeeded without ack");
        });

        let sock = UtpSocketRef::bind(client).await.expect("bind failed");
        let _ = sock.connect(server).await.expect("connect failed");

        handle.await.expect("execution failure");
    }
2367
    /// The client writes `LEN` bytes, flushes, then writes `LEN` more. The
    /// server must receive all `2 * LEN` bytes unchanged. Currently
    /// `#[ignore]`d.
    #[tokio::test]
    #[ignore]
    async fn stream_flush_then_send() {
        init_logger();
        let server_addr = next_test_ip4();
        let client_addr = next_test_ip4();

        const LEN: usize = 1240;
        const DATA: u8 = 25;

        let handle = task::spawn(async move {
            let mut stream = stream_accept(server_addr).await;
            let mut buf = [0u8; 2 * LEN];

            stream.read_exact(&mut buf).await.expect("failed to read");

            for b in buf.iter() {
                assert_eq!(*b, DATA, "data corrupted");
            }

            stream.flush().await.expect("flush failed");
            stream.shutdown().await.expect("shutdown failed");
        });

        let mut stream = stream_connect(client_addr, server_addr).await;
        let buf = [DATA; LEN];

        stream.write_all(&buf).await.expect("write failed");

        stream.flush().await.expect("flush failed");

        stream.write_all(&buf).await.expect("write failed");
        stream.shutdown().await.expect("shutdown failed");

        handle.await.expect("task failure");
    }
2404
    /// Raw `UtpSocket` connect/close over IPv4. Checks the states on both
    /// sides, the relation between sender and receiver connection ids, and
    /// the client's recorded peer address. The server's `recv_from` result is
    /// only printed, not asserted.
    #[tokio::test]
    async fn test_socket_ipv4() {
        let server_addr = next_test_ip4();

        let handle = task::spawn(async move {
            let mut server = iotry!(UtpSocket::bind(server_addr));
            assert_eq!(server.state, SocketState::New);

            let mut buf = [0u8; BUF_SIZE];
            match server.recv_from(&mut buf).await {
                e => println!("{:?}", e),
            }
            // After establishing a new connection, the server's ids are a
            // mirror of the client's.
            assert_eq!(
                server.receiver_connection_id,
                server.sender_connection_id + 1
            );

            assert_eq!(server.state, SocketState::Closed);
            drop(server);
        });

        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        // Check proper difference in client's send connection id and receive
        // connection id
        assert_eq!(
            client.sender_connection_id,
            client.receiver_connection_id + 1
        );
        assert_eq!(
            client.connected_to,
            server_addr.to_socket_addrs().unwrap().next().unwrap()
        );
        iotry!(client.close());

        handle.await.expect("task failure");
    }
2444
    /// IPv6 counterpart of `test_socket_ipv4`, with the client running in the
    /// spawned task. Currently `#[ignore]`d.
    #[ignore]
    #[tokio::test]
    async fn test_socket_ipv6() {
        let server_addr = next_test_ip6();

        let mut server = iotry!(UtpSocket::bind(server_addr));
        assert_eq!(server.state, SocketState::New);

        task::spawn(async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            assert_eq!(client.state, SocketState::Connected);
            // Check proper difference in client's send connection id and
            // receive connection id
            assert_eq!(
                client.sender_connection_id,
                client.receiver_connection_id + 1
            );
            assert_eq!(
                client.connected_to,
                server_addr.to_socket_addrs().unwrap().next().unwrap()
            );
            iotry!(client.close());
            drop(client);
        });

        let mut buf = [0u8; BUF_SIZE];
        match server.recv_from(&mut buf).await {
            e => println!("{:?}", e),
        }
        // After establishing a new connection, the server's ids are a mirror of
        // the client's.
        assert_eq!(
            server.receiver_connection_id,
            server.sender_connection_id + 1
        );

        assert_eq!(server.state, SocketState::Closed);
        drop(server);
    }
2484
    /// Once the client has closed the connection, the server is `Closed`, and
    /// every further `recv_from` returns `Ok((0, _))` and leaves it `Closed`.
    #[tokio::test]
    async fn test_recvfrom_on_closed_socket() {
        let server_addr = next_test_ip4();

        let mut server = iotry!(UtpSocket::bind(server_addr));
        assert_eq!(server.state, SocketState::New);

        let handle = task::spawn(async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            assert_eq!(client.state, SocketState::Connected);
            assert!(client.close().await.is_ok());
        });

        // Make the server listen for incoming connections until the end of the
        // input
        let mut buf = [0u8; BUF_SIZE];
        let _resp = server.recv_from(&mut buf).await;
        assert_eq!(server.state, SocketState::Closed);

        // Trying to receive again returns `Ok(0)` (equivalent to the old
        // `EndOfFile`)
        match server.recv_from(&mut buf).await {
            Ok((0, _src)) => {}
            e => panic!("Expected Ok(0), got {:?}", e),
        }
        assert_eq!(server.state, SocketState::Closed);

        handle.await.expect("task failure");
    }
2514
    /// Once the client has closed the connection, `send_to` on the `Closed`
    /// server must fail with `ErrorKind::NotConnected`.
    #[tokio::test]
    async fn test_sendto_on_closed_socket() {
        init_logger();
        let server_addr = next_test_ip4();

        let mut server = iotry!(UtpSocket::bind(server_addr));
        assert_eq!(server.state, SocketState::New);

        let handle = task::spawn(async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            assert_eq!(client.state, SocketState::Connected);
            iotry!(client.close());
        });

        // Make the server listen for incoming connections
        let mut buf = [0u8; BUF_SIZE];
        let (_read, _src) = iotry!(server.recv_from(&mut buf));
        assert_eq!(server.state, SocketState::Closed);

        // Trying to send to the socket after closing it raises an error
        match server.send_to(&buf).await {
            Err(ref e) if e.kind() == ErrorKind::NotConnected => (),
            v => panic!("expected {:?}, got {:?}", ErrorKind::NotConnected, v),
        }

        handle.await.expect("task failure");
    }
2542
    /// After connecting, the client's `ack_nr` must equal the server's
    /// `seq_nr`, which the server sends back over a channel. Closing must
    /// leave the client's `ack_nr` unchanged.
    #[tokio::test]
    async fn test_acks_on_socket() {
        use tokio::sync::mpsc::channel;

        init_logger();
        let server_addr = next_test_ip4();
        let (tx, mut rx) = channel(1);

        let mut server = iotry!(UtpSocket::bind(server_addr));

        let handle = task::spawn(async move {
            // Make the server listen for incoming connections
            let mut buf = [0u8; BUF_SIZE];
            let mut buf = ReadBuf::new(&mut buf);
            let _resp = server.recv(&mut buf).await.unwrap();
            tx.send(server.seq_nr).await.expect("channel closed");

            // Close the connection
            let mut buf = [0; 1500];
            iotry!(server.recv_from(&mut buf));
        });

        let mut client = iotry!(UtpSocket::connect(server_addr));
        assert_eq!(client.state, SocketState::Connected);
        let sender_seq_nr = rx.recv().await.expect("channel closed");
        let ack_nr = client.ack_nr;
        assert_eq!(ack_nr, sender_seq_nr);
        assert!(client.close().await.is_ok());

        // The reply to both connect (SYN) and close (FIN) should be
        // STATE packets, which don't increase the sequence number
        // and, hence, the receiver's acknowledgement number.
        assert_eq!(client.ack_nr, ack_nr);
        drop(client);

        handle.await.expect("task failure");
    }
2580
2581    #[tokio::test]
2582    async fn test_handle_packet() {
2583        //fn test_connection_setup() {
2584        let initial_connection_id: u16 = rand::random();
2585        let sender_connection_id = initial_connection_id + 1;
2586        let (server_addr, client_addr) = (
2587            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2588            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2589        );
2590        let mut socket = iotry!(UtpSocket::bind(server_addr));
2591
2592        let mut packet = Packet::new();
2593        packet.set_wnd_size(BUF_SIZE as u32);
2594        packet.set_type(PacketType::Syn);
2595        packet.set_connection_id(initial_connection_id);
2596
2597        // Do we have a response?
2598        let response = socket.handle_packet(&packet, client_addr);
2599        assert!(response.is_ok());
2600        let response = response.unwrap();
2601        assert!(response.is_some());
2602
2603        // Is is of the correct type?
2604        let response = response.unwrap();
2605        assert_eq!(response.get_type(), PacketType::State);
2606
2607        // Same connection id on both ends during connection establishment
2608        assert_eq!(response.connection_id(), packet.connection_id());
2609
2610        // Response acknowledges SYN
2611        assert_eq!(response.ack_nr(), packet.seq_nr());
2612
2613        // No payload?
2614        assert!(response.payload().is_empty());
2615        //}
2616
2617        // ---------------------------------
2618
2619        // fn test_connection_usage() {
2620        let old_packet = packet;
2621        let old_response = response;
2622
2623        let mut packet = Packet::new();
2624        packet.set_type(PacketType::Data);
2625        packet.set_connection_id(sender_connection_id);
2626        packet.set_seq_nr(old_packet.seq_nr() + 1);
2627        packet.set_ack_nr(old_response.seq_nr());
2628
2629        let response = socket.handle_packet(&packet, client_addr);
2630        assert!(response.is_ok());
2631        let response = response.unwrap();
2632        assert!(response.is_some());
2633
2634        let response = response.unwrap();
2635        assert_eq!(response.get_type(), PacketType::State);
2636
2637        // Sender (i.e., who the initiated connection and sent a SYN) has
2638        // connection id equal to initial connection id + 1
2639        // Receiver (i.e., who accepted connection) has connection id equal to
2640        // initial connection id
2641        assert_eq!(response.connection_id(), initial_connection_id);
2642        assert_eq!(response.connection_id(), packet.connection_id() - 1);
2643
2644        // Previous packets should be ack'ed
2645        assert_eq!(response.ack_nr(), packet.seq_nr());
2646
2647        // Responses with no payload should not increase the sequence number
2648        assert!(response.payload().is_empty());
2649        assert_eq!(response.seq_nr(), old_response.seq_nr());
2650        // }
2651
2652        //fn test_connection_teardown() {
2653        let old_packet = packet;
2654        let old_response = response;
2655
2656        let mut packet = Packet::new();
2657        packet.set_type(PacketType::Fin);
2658        packet.set_connection_id(sender_connection_id);
2659        packet.set_seq_nr(old_packet.seq_nr() + 1);
2660        packet.set_ack_nr(old_response.seq_nr());
2661
2662        let response = socket.handle_packet(&packet, client_addr);
2663        assert!(response.is_ok());
2664        let response = response.unwrap();
2665        assert!(response.is_some());
2666
2667        let response = response.unwrap();
2668
2669        assert_eq!(response.get_type(), PacketType::State);
2670
2671        // FIN packets have no payload but the sequence number shouldn't increase
2672        assert_eq!(packet.seq_nr(), old_packet.seq_nr() + 1);
2673
2674        // Nor should the ACK packet's sequence number
2675        assert_eq!(response.seq_nr(), old_response.seq_nr());
2676
2677        // FIN should be acknowledged
2678        assert_eq!(response.ack_nr(), packet.seq_nr());
2679    }
2680
2681    #[tokio::test]
2682    async fn test_response_to_keepalive_ack() {
2683        // Boilerplate test setup
2684        let initial_connection_id: u16 = rand::random();
2685        let (server_addr, client_addr) = (
2686            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2687            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2688        );
2689        let mut socket = iotry!(UtpSocket::bind(server_addr));
2690
2691        // Establish connection
2692        let mut packet = Packet::new();
2693        packet.set_wnd_size(BUF_SIZE as u32);
2694        packet.set_type(PacketType::Syn);
2695        packet.set_connection_id(initial_connection_id);
2696
2697        let response = socket.handle_packet(&packet, client_addr);
2698        assert!(response.is_ok());
2699        let response = response.unwrap();
2700        assert!(response.is_some());
2701        let response = response.unwrap();
2702        assert_eq!(response.get_type(), PacketType::State);
2703
2704        let old_packet = packet;
2705        let old_response = response;
2706
2707        // Now, send a keepalive packet
2708        let mut packet = Packet::new();
2709        packet.set_wnd_size(BUF_SIZE as u32);
2710        packet.set_type(PacketType::State);
2711        packet.set_connection_id(initial_connection_id);
2712        packet.set_seq_nr(old_packet.seq_nr() + 1);
2713        packet.set_ack_nr(old_response.seq_nr());
2714
2715        let response = socket.handle_packet(&packet, client_addr);
2716        assert!(response.is_ok());
2717        let response = response.unwrap();
2718        assert!(response.is_none());
2719
2720        // Send a second keepalive packet, identical to the previous one
2721        let response = socket.handle_packet(&packet, client_addr);
2722        assert!(response.is_ok());
2723        let response = response.unwrap();
2724        assert!(response.is_none());
2725
2726        // Mark socket as closed
2727        socket.state = SocketState::Closed;
2728    }
2729
2730    #[tokio::test]
2731    async fn test_response_to_wrong_connection_id() {
2732        // Boilerplate test setup
2733        let initial_connection_id: u16 = rand::random();
2734        let (server_addr, client_addr) = (
2735            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2736            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2737        );
2738        let mut socket = iotry!(UtpSocket::bind(server_addr));
2739
2740        // Establish connection
2741        let mut packet = Packet::new();
2742        packet.set_wnd_size(BUF_SIZE as u32);
2743        packet.set_type(PacketType::Syn);
2744        packet.set_connection_id(initial_connection_id);
2745
2746        let response = socket.handle_packet(&packet, client_addr);
2747        assert!(response.is_ok());
2748        let response = response.unwrap();
2749        assert!(response.is_some());
2750        assert_eq!(response.unwrap().get_type(), PacketType::State);
2751
2752        // Now, disrupt connection with a packet with an incorrect connection id
2753        let new_connection_id = initial_connection_id.wrapping_mul(2);
2754
2755        let mut packet = Packet::new();
2756        packet.set_wnd_size(BUF_SIZE as u32);
2757        packet.set_type(PacketType::State);
2758        packet.set_connection_id(new_connection_id);
2759
2760        let response = socket.handle_packet(&packet, client_addr);
2761        assert!(response.is_ok());
2762        let response = response.unwrap();
2763        assert!(response.is_some());
2764
2765        let response = response.unwrap();
2766        assert_eq!(response.get_type(), PacketType::Reset);
2767        assert_eq!(response.ack_nr(), packet.seq_nr());
2768
2769        // Mark socket as closed
2770        socket.state = SocketState::Closed;
2771    }
2772
2773    #[tokio::test]
2774    async fn test_unordered_packets() {
2775        // Boilerplate test setup
2776        let initial_connection_id: u16 = rand::random();
2777        let (server_addr, client_addr) = (
2778            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2779            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2780        );
2781        let mut socket = iotry!(UtpSocket::bind(server_addr));
2782
2783        // Establish connection
2784        let mut packet = Packet::new();
2785        packet.set_wnd_size(BUF_SIZE as u32);
2786        packet.set_type(PacketType::Syn);
2787        packet.set_connection_id(initial_connection_id);
2788
2789        let response = socket.handle_packet(&packet, client_addr);
2790        assert!(response.is_ok());
2791        let response = response.unwrap();
2792        assert!(response.is_some());
2793        let response = response.unwrap();
2794        assert_eq!(response.get_type(), PacketType::State);
2795
2796        let old_packet = packet;
2797        let old_response = response;
2798
2799        let mut window: Vec<Packet> = Vec::new();
2800
2801        // Now, send a keepalive packet
2802        let mut packet = Packet::with_payload(&[1, 2, 3]);
2803        packet.set_wnd_size(BUF_SIZE as u32);
2804        packet.set_connection_id(initial_connection_id);
2805        packet.set_seq_nr(old_packet.seq_nr() + 1);
2806        packet.set_ack_nr(old_response.seq_nr());
2807        window.push(packet);
2808
2809        let mut packet = Packet::with_payload(&[4, 5, 6]);
2810        packet.set_wnd_size(BUF_SIZE as u32);
2811        packet.set_connection_id(initial_connection_id);
2812        packet.set_seq_nr(old_packet.seq_nr() + 2);
2813        packet.set_ack_nr(old_response.seq_nr());
2814        window.push(packet);
2815
2816        // Send packets in reverse order
2817        let response = socket.handle_packet(&window[1], client_addr);
2818        assert!(response.is_ok());
2819        let response = response.unwrap();
2820        assert!(response.is_some());
2821        let response = response.unwrap();
2822        assert!(response.ack_nr() != window[1].seq_nr());
2823
2824        let response = socket.handle_packet(&window[0], client_addr);
2825        assert!(response.is_ok());
2826        let response = response.unwrap();
2827        assert!(response.is_some());
2828
2829        // Mark socket as closed
2830        socket.state = SocketState::Closed;
2831    }
2832
2833    #[tokio::test]
2834    async fn test_response_to_triple_ack() {
2835        let server_addr = next_test_ip4();
2836        let mut server = iotry!(UtpSocket::bind(server_addr));
2837
2838        // Fits in a packet
2839        const LEN: usize = 1024;
2840        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2841        let d = data.clone();
2842        assert_eq!(LEN, data.len());
2843
2844        let handle = task::spawn(async move {
2845            let mut client = iotry!(UtpSocket::connect(server_addr));
2846            iotry!(client.send_to(&d[..]));
2847            iotry!(client.close());
2848        });
2849
2850        let mut buf = [0; BUF_SIZE];
2851        let mut buf = ReadBuf::new(&mut buf);
2852        // Expect SYN
2853        iotry!(server.recv(&mut buf));
2854
2855        // Receive data
2856        let data_packet =
2857            match server.socket.recv_from(buf.initialized_mut()).await {
2858                Ok((_, _src)) => Packet::try_from(buf.filled()).unwrap(),
2859                Err(e) => panic!("{}", e),
2860            };
2861        assert_eq!(data_packet.get_type(), PacketType::Data);
2862        assert_eq!(&data_packet.payload(), &data.as_slice());
2863        assert_eq!(data_packet.payload().len(), data.len());
2864
2865        // Send triple ACK
2866        let mut packet = Packet::new();
2867        packet.set_wnd_size(BUF_SIZE as u32);
2868        packet.set_type(PacketType::State);
2869        packet.set_seq_nr(server.seq_nr);
2870        packet.set_ack_nr(data_packet.seq_nr() - 1);
2871        packet.set_connection_id(server.sender_connection_id);
2872
2873        for _ in 0..3usize {
2874            iotry!(server.socket.send_to(packet.as_ref(), server.connected_to));
2875        }
2876
2877        // Receive data again and check that it's the same we reported as missing
2878        let client_addr = server.connected_to;
2879
2880        let mut buf = [0; BUF_SIZE];
2881
2882        match server.socket.recv_from(&mut buf).await {
2883            Ok((0, _)) => panic!("Received 0 bytes from socket"),
2884            Ok((read, _src)) => {
2885                let packet = Packet::try_from(&buf[..read]).unwrap();
2886                assert_eq!(packet.get_type(), PacketType::Data);
2887                assert_eq!(packet.seq_nr(), data_packet.seq_nr());
2888                assert_eq!(packet.payload(), data_packet.payload());
2889                let response = server.handle_packet(&packet, client_addr);
2890                assert!(response.is_ok());
2891                let response = response.unwrap();
2892                assert!(response.is_some());
2893                let response = response.unwrap();
2894                iotry!(server
2895                    .socket
2896                    .send_to(response.as_ref(), server.connected_to));
2897            }
2898            Err(e) => panic!("{}", e),
2899        }
2900
2901        // Receive close
2902        let mut buf = [0; 1500];
2903        iotry!(server.recv_from(&mut buf));
2904
2905        handle.await.expect("task failure");
2906    }
2907
2908    #[ignore]
2909    #[tokio::test]
2910    async fn test_socket_timeout_request() {
2911        let (server_addr, client_addr) = (
2912            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2913            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2914        );
2915
2916        let client = iotry!(UtpSocket::bind(client_addr));
2917        let mut server = iotry!(UtpSocket::bind(server_addr));
2918        const LEN: usize = 512;
2919        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2920        let d = data.clone();
2921
2922        assert_eq!(server.state, SocketState::New);
2923        assert_eq!(client.state, SocketState::New);
2924
2925        // Check proper difference in client's send connection id and receive
2926        // connection id
2927        assert_eq!(
2928            client.sender_connection_id,
2929            client.receiver_connection_id + 1
2930        );
2931
2932        let handle = task::spawn(async move {
2933            let mut client = iotry!(UtpSocket::connect(server_addr));
2934            assert_eq!(client.state, SocketState::Connected);
2935            assert_eq!(client.connected_to, server_addr);
2936            iotry!(client.send_to(&d[..]));
2937            drop(client);
2938        });
2939
2940        let mut buf = [0u8; BUF_SIZE];
2941        let mut buf = ReadBuf::new(&mut buf);
2942        server.recv(&mut buf).await.unwrap();
2943        // After establishing a new connection, the server's ids are a mirror of
2944        // the client's.
2945        assert_eq!(
2946            server.receiver_connection_id,
2947            server.sender_connection_id + 1
2948        );
2949
2950        assert_eq!(server.state, SocketState::Connected);
2951
2952        // Purposefully read from UDP socket directly and discard it, in order
2953        // to behave as if the packet was lost and thus trigger the timeout
2954        // handling in the *next* call to `UtpSocket.recv_from`.
2955        let mut buf = [0; 1500];
2956        iotry!(server.socket.recv_from(&mut buf));
2957
2958        // Set a much smaller than usual timeout, for quicker test completion
2959        server.congestion_timeout = 50;
2960
2961        // Now wait for the previously discarded packet
2962        loop {
2963            match server.recv_from(&mut buf).await {
2964                Ok((0, _)) => continue,
2965                Ok(_) => break,
2966                Err(e) => panic!("{}", e),
2967            }
2968        }
2969
2970        drop(server);
2971
2972        handle.await.expect("task failure");
2973    }
2974
2975    #[tokio::test]
2976    async fn test_sorted_buffer_insertion() {
2977        let server_addr = next_test_ip4();
2978        let mut socket = iotry!(UtpSocket::bind(server_addr));
2979
2980        let mut packet = Packet::new();
2981        packet.set_seq_nr(1);
2982
2983        assert!(socket.incoming_buffer.is_empty());
2984
2985        socket.insert_into_buffer(packet.clone());
2986        assert_eq!(socket.incoming_buffer.len(), 1);
2987
2988        packet.set_seq_nr(2);
2989        packet.set_timestamp(128.into());
2990
2991        socket.insert_into_buffer(packet.clone());
2992        assert_eq!(socket.incoming_buffer.len(), 2);
2993        assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
2994        assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
2995
2996        packet.set_seq_nr(3);
2997        packet.set_timestamp(256.into());
2998
2999        socket.insert_into_buffer(packet.clone());
3000        assert_eq!(socket.incoming_buffer.len(), 3);
3001        assert_eq!(socket.incoming_buffer[2].seq_nr(), 3);
3002        assert_eq!(socket.incoming_buffer[2].timestamp(), 256.into());
3003
3004        // Replacing a packet with a more recent version doesn't work
3005        packet.set_seq_nr(2);
3006        packet.set_timestamp(456.into());
3007
3008        socket.insert_into_buffer(packet);
3009        assert_eq!(socket.incoming_buffer.len(), 3);
3010        assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
3011        assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
3012    }
3013
3014    #[tokio::test]
3015    async fn test_duplicate_packet_handling() {
3016        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
3017
3018        let client = iotry!(UtpSocket::bind(client_addr));
3019        let mut server = iotry!(UtpSocket::bind(server_addr));
3020
3021        assert_eq!(server.state, SocketState::New);
3022        assert_eq!(client.state, SocketState::New);
3023
3024        // Check proper difference in client's send connection id and receive
3025        // connection id
3026        assert_eq!(
3027            client.sender_connection_id,
3028            client.receiver_connection_id + 1
3029        );
3030
3031        let handle = task::spawn(async move {
3032            let mut client = iotry!(UtpSocket::connect(server_addr));
3033            assert_eq!(client.state, SocketState::Connected);
3034
3035            let mut packet = Packet::with_payload(&[1, 2, 3]);
3036            packet.set_wnd_size(BUF_SIZE as u32);
3037            packet.set_connection_id(client.sender_connection_id);
3038            packet.set_seq_nr(client.seq_nr);
3039            packet.set_ack_nr(client.ack_nr);
3040
3041            // Send two copies of the packet, with different timestamps
3042            for _ in 0..2usize {
3043                packet.set_timestamp(now_microseconds());
3044                iotry!(client.socket.send_to(packet.as_ref(), server_addr));
3045            }
3046            client.seq_nr += 1;
3047
3048            // Receive one ACK
3049            for _ in 0..1usize {
3050                let mut buf = [0; BUF_SIZE];
3051                iotry!(client.socket.recv_from(&mut buf));
3052            }
3053
3054            iotry!(client.close());
3055        });
3056        let mut buf = [0u8; BUF_SIZE];
3057        let mut buf = ReadBuf::new(&mut buf);
3058        iotry!(server.recv(&mut buf));
3059        // After establishing a new connection, the server's ids are a mirror of
3060        // the client's.
3061        assert_eq!(
3062            server.receiver_connection_id,
3063            server.sender_connection_id + 1
3064        );
3065
3066        assert_eq!(server.state, SocketState::Connected);
3067
3068        let expected: Vec<u8> = vec![1, 2, 3];
3069        let mut received: Vec<u8> = vec![];
3070        loop {
3071            match server.recv_from(buf.initialized_mut()).await {
3072                Ok((0, _src)) => break,
3073                Ok((_, _src)) => received.extend(buf.filled().to_vec()),
3074                Err(e) => panic!("{:?}", e),
3075            }
3076        }
3077        assert_eq!(received.len(), expected.len());
3078        assert_eq!(received, expected);
3079
3080        handle.await.expect("task failure");
3081    }
3082
3083    #[tokio::test]
3084    async fn test_correct_packet_loss() {
3085        init_logger();
3086        let server_addr = next_test_ip4();
3087
3088        let mut server = iotry!(UtpSocket::bind(server_addr));
3089        const LEN: usize = 1024 * 10;
3090        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
3091        let to_send = data.clone();
3092
3093        let handle = task::spawn(
3094            async move {
3095                let mut client = iotry!(UtpSocket::connect(server_addr));
3096
3097                // Send everything except the odd chunks
3098                let chunks = to_send[..].chunks(BUF_SIZE);
3099                let dst = client.connected_to;
3100                for (index, chunk) in chunks.enumerate() {
3101                    let mut packet = Packet::with_payload(chunk);
3102                    packet.set_seq_nr(client.seq_nr);
3103                    packet.set_ack_nr(client.ack_nr);
3104                    packet.set_connection_id(client.sender_connection_id);
3105                    packet.set_timestamp(now_microseconds());
3106
3107                    if index % 2 == 0 {
3108                        iotry!(client.socket.send_to(packet.as_ref(), dst));
3109                    }
3110
3111                    client.curr_window += packet.len() as u32;
3112                    client.send_window.push(packet);
3113                    client.seq_nr += 1;
3114                }
3115
3116                iotry!(client.close());
3117            }
3118            .instrument(debug_span!("sender")),
3119        );
3120
3121        let mut buf = [0; BUF_SIZE];
3122        let mut received: Vec<u8> = vec![];
3123        loop {
3124            match server.recv_from(&mut buf).await {
3125                Ok((0, _src)) => break,
3126                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
3127                Err(e) => panic!("{}", e),
3128            }
3129        }
3130        assert_eq!(
3131            received.len(),
3132            data.len(),
3133            "wrong number of bytes received"
3134        );
3135        assert_eq!(received, data, "incorrect data received");
3136        handle.await.expect("task failure");
3137    }
3138
3139    #[tokio::test]
3140    async fn test_tolerance_to_small_buffers() {
3141        let server_addr = next_test_ip4();
3142        let mut server = iotry!(UtpSocket::bind(server_addr));
3143        const LEN: usize = 1024;
3144        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
3145        let to_send = data.clone();
3146
3147        let handle = task::spawn(async move {
3148            let mut client = iotry!(UtpSocket::connect(server_addr));
3149            iotry!(client.send_to(&to_send[..]));
3150            iotry!(client.close());
3151        });
3152
3153        let mut read = Vec::new();
3154        while server.state != SocketState::Closed {
3155            let mut small_buffer = [0; 512];
3156            match server.recv_from(&mut small_buffer).await {
3157                Ok((0, _src)) => break,
3158                Ok((len, _src)) => read.extend(small_buffer[..len].to_vec()),
3159                Err(e) => panic!("{}", e),
3160            }
3161        }
3162
3163        assert_eq!(read.len(), data.len());
3164        assert_eq!(read, data);
3165        handle.await.expect("task failure");
3166    }
3167
3168    #[tokio::test]
3169    async fn test_sequence_number_rollover() {
3170        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
3171
3172        let mut server = iotry!(UtpSocket::bind(server_addr));
3173
3174        const LEN: usize = BUF_SIZE * 4;
3175        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
3176        let to_send = data.clone();
3177
3178        let mut client = iotry!(UtpSocket::bind(client_addr));
3179
3180        // Advance socket's sequence number
3181        client.seq_nr =
3182            ::std::u16::MAX - (to_send.len() / (BUF_SIZE * 2)) as u16;
3183
3184        let handle = task::spawn(async move {
3185            let mut client = iotry!(UtpSocket::connect(server_addr));
3186            // Send enough data to rollover
3187            iotry!(client.send_to(&to_send[..]));
3188            // Check that the sequence number did rollover
3189            assert!(client.seq_nr < 50);
3190            // Close connection
3191            iotry!(client.close());
3192        });
3193
3194        let mut buf = [0; BUF_SIZE];
3195        let mut received: Vec<u8> = vec![];
3196        loop {
3197            match server.recv_from(&mut buf).await {
3198                Ok((0, _src)) => break,
3199                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
3200                Err(e) => panic!("{}", e),
3201            }
3202        }
3203        assert_eq!(received.len(), data.len());
3204        assert_eq!(received, data);
3205        handle.await.expect("task failure");
3206    }
3207
3208    #[tokio::test]
3209    async fn test_drop_unused_socket() {
3210        let server_addr = next_test_ip4();
3211        let server = iotry!(UtpSocket::bind(server_addr));
3212
3213        // Explicitly dropping socket. This test should not hang.
3214        drop(server);
3215    }
3216
3217    #[tokio::test]
3218    async fn test_invalid_packet_on_connect() {
3219        use tokio::net::UdpSocket;
3220        let server_addr = next_test_ip4();
3221        let server = iotry!(UdpSocket::bind(server_addr));
3222
3223        let handle = task::spawn(async move {
3224            match UtpSocket::connect(server_addr).await {
3225                Err(ref e) if e.kind() == ErrorKind::Other => (), // OK
3226                Err(e) => panic!("Expected ErrorKind::Other, got {:?}", e),
3227                Ok(_) => panic!("Expected Err, got Ok"),
3228            }
3229        });
3230
3231        let mut buf = [0; BUF_SIZE];
3232        match server.recv_from(&mut buf).await {
3233            Ok((_len, client_addr)) => {
3234                iotry!(server.send_to(&[], client_addr));
3235            }
3236            _ => panic!(),
3237        }
3238
3239        handle.await.expect("task failure");
3240    }
3241
3242    #[tokio::test]
3243    async fn test_receive_unexpected_reply_type_on_connect() {
3244        use tokio::net::UdpSocket;
3245        let server_addr = next_test_ip4();
3246        let server = iotry!(UdpSocket::bind(server_addr));
3247
3248        let mut buf = [0; BUF_SIZE];
3249        let mut packet = Packet::new();
3250        packet.set_type(PacketType::Data);
3251
3252        let handle = task::spawn(async move {
3253            match server.recv_from(&mut buf).await {
3254                Ok((_len, client_addr)) => {
3255                    iotry!(server.send_to(packet.as_ref(), client_addr));
3256                }
3257                _ => panic!(),
3258            }
3259        });
3260
3261        match UtpSocket::connect(server_addr).await {
3262            Err(ref e) if e.kind() == ErrorKind::ConnectionRefused => (), // OK
3263            Err(e) => {
3264                panic!("Expected ErrorKind::ConnectionRefused, got {:?}", e)
3265            }
3266            Ok(_) => panic!("Expected Err, got Ok"),
3267        }
3268
3269        handle.await.expect("task failure");
3270    }
3271
3272    #[tokio::test]
3273    async fn test_receiving_syn_on_established_connection() {
3274        // Establish connection
3275        let server_addr = next_test_ip4();
3276        let mut server = iotry!(UtpSocket::bind(server_addr));
3277
3278        let handle = task::spawn(async move {
3279            let mut buf = [0; BUF_SIZE];
3280            loop {
3281                match server.recv_from(&mut buf).await {
3282                    Ok((0, _src)) => break,
3283                    Ok(_) => (),
3284                    Err(e) => panic!("{:?}", e),
3285                }
3286            }
3287        });
3288
3289        let mut client = iotry!(UtpSocket::connect(server_addr));
3290        let mut packet = Packet::new();
3291        packet.set_wnd_size(BUF_SIZE as u32);
3292        packet.set_type(PacketType::Syn);
3293        packet.set_connection_id(client.sender_connection_id);
3294        packet.set_seq_nr(client.seq_nr);
3295        packet.set_ack_nr(client.ack_nr);
3296        iotry!(client.socket.send_to(packet.as_ref(), server_addr));
3297        let mut buf = [0; BUF_SIZE];
3298
3299        let (len, _) = client
3300            .socket
3301            .recv_from(&mut buf)
3302            .await
3303            .expect("recv failed");
3304        let reply = Packet::try_from(&buf[..len]).ok().unwrap();
3305        assert_eq!(reply.get_type(), PacketType::Reset);
3306        iotry!(client.close());
3307        handle.await.expect("task failure");
3308    }
3309
3310    #[tokio::test]
3311    #[ignore]
3312    async fn test_receiving_reset_on_established_connection() {
3313        // Establish connection
3314        let server_addr = next_test_ip4();
3315        let mut server = iotry!(UtpSocket::bind(server_addr));
3316
3317        let handle = task::spawn(async move {
3318            let client = iotry!(UtpSocket::connect(server_addr));
3319            let mut packet = Packet::new();
3320            packet.set_wnd_size(BUF_SIZE as u32);
3321            packet.set_type(PacketType::Reset);
3322            packet.set_connection_id(client.sender_connection_id);
3323            packet.set_seq_nr(client.seq_nr);
3324            packet.set_ack_nr(client.ack_nr);
3325            iotry!(client.socket.send_to(packet.as_ref(), server_addr));
3326
3327            let mut buf = [0; BUF_SIZE];
3328
3329            client
3330                .socket
3331                .recv_from(&mut buf)
3332                .await
3333                .expect("recv failed");
3334        });
3335
3336        let mut buf = [0; BUF_SIZE];
3337
3338        loop {
3339            match server.recv_from(&mut buf).await {
3340                Ok((0, _src)) => break,
3341                Ok(_) => (),
3342                Err(ref e) if e.kind() == ErrorKind::ConnectionReset => {
3343                    handle.await.expect("task failure");
3344                    return;
3345                }
3346                Err(e) => panic!("{:?}", e),
3347            }
3348        }
3349        panic!("Should have received Reset");
3350    }
3351
3352    #[cfg(not(windows))]
3353    #[tokio::test]
3354    async fn test_premature_fin() {
3355        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
3356        let mut server = iotry!(UtpSocket::bind(server_addr));
3357
3358        const LEN: usize = BUF_SIZE * 4;
3359        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
3360        let to_send = data.clone();
3361
3362        task::spawn(async move {
3363            let mut client = iotry!(UtpSocket::connect(server_addr));
3364            iotry!(client.send_to(&to_send[..]));
3365            iotry!(client.close());
3366        });
3367
3368        let mut buf = [0; BUF_SIZE];
3369        let mut buf = ReadBuf::new(&mut buf);
3370
3371        // Accept connection
3372        iotry!(server.recv(&mut buf));
3373
3374        // Send FIN without acknowledging packets received
3375        let mut packet = Packet::new();
3376        packet.set_connection_id(server.sender_connection_id);
3377        packet.set_seq_nr(server.seq_nr);
3378        packet.set_ack_nr(server.ack_nr);
3379        packet.set_timestamp(now_microseconds());
3380        packet.set_type(PacketType::Fin);
3381        iotry!(server.socket.send_to(packet.as_ref(), client_addr));
3382
3383        // Receive until end
3384        let mut received: Vec<u8> = vec![];
3385        loop {
3386            let mut buf = [0; BUF_SIZE];
3387
3388            match server.recv_from(&mut buf).await {
3389                Ok((0, _src)) => break,
3390                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
3391                Err(e) => panic!("{}", e),
3392            }
3393        }
3394        assert_eq!(received.len(), data.len());
3395        assert_eq!(received, data);
3396    }
3397
3398    #[tokio::test]
3399    async fn test_base_delay_calculation() {
3400        let minute_in_microseconds = 60 * 10i64.pow(6);
3401        let samples = vec![
3402            (0, 10),
3403            (1, 8),
3404            (2, 12),
3405            (3, 7),
3406            (minute_in_microseconds + 1, 11),
3407            (minute_in_microseconds + 2, 19),
3408            (minute_in_microseconds + 3, 9),
3409        ];
3410        let addr = next_test_ip4();
3411        let mut socket = UtpSocket::bind(addr).await.unwrap();
3412
3413        for (timestamp, delay) in samples {
3414            socket.update_base_delay(
3415                delay.into(),
3416                ((timestamp + delay) as u32).into(),
3417            );
3418        }
3419
3420        let expected = vec![7i64, 9i64]
3421            .into_iter()
3422            .map(Into::into)
3423            .collect::<Vec<_>>();
3424        let actual = socket.base_delays.iter().cloned().collect::<Vec<_>>();
3425        assert_eq!(expected, actual);
3426        assert_eq!(
3427            socket.min_base_delay(),
3428            expected.iter().min().cloned().unwrap_or_default()
3429        );
3430    }
3431
3432    #[tokio::test]
3433    async fn test_local_addr() {
3434        let addr = next_test_ip4();
3435        let addr = addr.to_socket_addrs().unwrap().next().unwrap();
3436        let socket = UtpSocket::bind(addr).await.unwrap();
3437
3438        assert!(socket.local_addr().is_ok());
3439        assert_eq!(socket.local_addr().unwrap(), addr);
3440    }
3441
3442    #[tokio::test]
3443    async fn test_peer_addr() {
3444        use std::sync::mpsc::channel;
3445        let addr = next_test_ip4();
3446        let server_addr = addr.to_socket_addrs().unwrap().next().unwrap();
3447        let mut server = UtpSocket::bind(server_addr).await.unwrap();
3448        let (tx, rx) = channel();
3449
3450        // `peer_addr` should return an error because the socket isn't connected
3451        // yet
3452        assert!(server.peer_addr().is_err());
3453
3454        task::spawn(async move {
3455            let mut client = iotry!(UtpSocket::connect(server_addr));
3456            let mut buf = [0; 1024];
3457            tx.send(client.local_addr())
3458                .expect("failed to send on channel");
3459            iotry!(client.recv_from(&mut buf));
3460
3461            // Wait for a connection to be established
3462            let mut buf = [0; 1024];
3463            let mut buf = ReadBuf::new(&mut buf);
3464            iotry!(server.recv(&mut buf));
3465
3466            // `peer_addr` should succeed and be equal to the client's address
3467            assert!(server.peer_addr().is_ok());
3468            // The client is expected to be bound to "0.0.0.0", so we can only check if the port is
3469            // correct
3470            let client_addr = rx.recv().unwrap().unwrap();
3471            assert_eq!(server.peer_addr().unwrap().port(), client_addr.port());
3472
3473            // Close the connection
3474            iotry!(server.close());
3475
3476            // `peer_addr` should now return an error because the socket is closed
3477            assert!(server.peer_addr().is_err());
3478        });
3479    }
3480
    // Test reaction to connection loss when sending data packets.
    //
    // Intended scenario: the client "dies" right after sending one byte. The
    // raw datagrams that reach the client's UDP socket are expected to be
    // `attempts` Data packets. A subsequent `server.recv` is expected to give up
    // with `ErrorKind::TimedOut`.
    //
    // NOTE(review): this test is `#[ignore]`d. As written, the Data
    // retransmissions are read from the client's socket before the server has
    // sent any data (`server.send_to` comes after the loop). Also,
    // `UtpSocket::connect` is awaited before the server reads anything. Whether
    // either can work depends on machinery not visible here; confirm the
    // ordering before re-enabling.
    #[ignore]
    #[tokio::test]
    async fn test_connection_loss_data() {
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));
        // Decrease timeouts for faster tests (presumably milliseconds, matching
        // INITIAL_CONGESTION_TIMEOUT = 1000 for one second)
        server.congestion_timeout = 1;
        let attempts = server.max_retransmission_retries;

        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&[0]));
        // Simulate connection loss by killing the socket.
        client.state = SocketState::Closed;

        // Read raw datagrams straight from the client's underlying UDP socket,
        // bypassing the uTP layer. The first one is discarded unchecked.
        let mut buf = [0; BUF_SIZE];
        iotry!(client.socket.recv_from(&mut buf));

        // Each following datagram must parse as a Data packet.
        for _ in 0..attempts {
            match client.socket.recv_from(&mut buf).await {
                Ok((len, _src)) => assert_eq!(
                    Packet::try_from(&buf[..len]).unwrap().get_type(),
                    PacketType::Data
                ),
                Err(e) => panic!("{}", e),
            }
        }

        // Drain incoming packets
        let mut buf = [0; BUF_SIZE];
        iotry!(server.recv_from(&mut buf));

        iotry!(server.send_to(&[0]));

        // Try to receive ACKs, time out too many times on flush, and fail with
        // `TimedOut`
        let mut buf = [0; BUF_SIZE];
        let mut buf = ReadBuf::new(&mut buf);
        match server.recv(&mut buf).await {
            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
            x => panic!("Expected Err(TimedOut), got {:?}", x),
        }
    }
3524
    // Test reaction to connection loss when sending FIN.
    //
    // Intended scenario: the client "dies" after sending one byte. The raw
    // datagrams reaching the client's UDP socket are expected to be `attempts`
    // FIN packets. `server.close()` is expected to give up with
    // `ErrorKind::TimedOut`.
    //
    // NOTE(review): this test is `#[ignore]`d. The FIN packets are read from the
    // client's socket before `server.close()` (the only visible FIN sender) is
    // called, and `connect` is awaited before the server reads anything.
    // Confirm the intended ordering before re-enabling.
    #[ignore]
    #[tokio::test]
    async fn test_connection_loss_fin() {
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));
        // Decrease timeouts for faster tests (presumably milliseconds, matching
        // INITIAL_CONGESTION_TIMEOUT = 1000 for one second)
        server.congestion_timeout = 1;
        let attempts = server.max_retransmission_retries;

        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&[0]));
        // Simulate connection loss by killing the socket.
        client.state = SocketState::Closed;
        // Raw reads from the client's underlying UDP socket; the first datagram
        // is discarded unchecked, each following one must be a FIN.
        let mut buf = [0; BUF_SIZE];
        iotry!(client.socket.recv_from(&mut buf));
        for _ in 0..attempts {
            match client.socket.recv_from(&mut buf).await {
                Ok((len, _src)) => assert_eq!(
                    Packet::try_from(&buf[..len]).unwrap().get_type(),
                    PacketType::Fin
                ),
                Err(e) => panic!("{}", e),
            }
        }

        // Drain incoming packets
        let mut buf = [0; BUF_SIZE];
        iotry!(server.recv_from(&mut buf));

        // Send FIN, time out too many times, and fail with `TimedOut`
        match server.close().await {
            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
            x => panic!("Expected Err(TimedOut), got {:?}", x),
        }
    }
3561
    // Test reaction to connection loss when waiting for data packets.
    //
    // Intended scenario: the client "dies" after sending one byte. Every raw
    // datagram reaching the client's UDP socket must be a State packet whose
    // ack_nr is `seq_nr - 1`, presumably acknowledging the client's last sent
    // packet. A later `server.recv_from` is expected to give up with
    // `ErrorKind::TimedOut`.
    //
    // NOTE(review): this test is `#[ignore]`d. The State packets are read
    // before the server has received anything (`server.recv_from` comes after
    // the loop), and `connect` is awaited before the server reads anything. The
    // reason for reading `3 * attempts` packets is not visible here. Confirm
    // before re-enabling.
    #[ignore]
    #[tokio::test]
    async fn test_connection_loss_waiting() {
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));
        // Decrease timeouts for faster tests (presumably milliseconds, matching
        // INITIAL_CONGESTION_TIMEOUT = 1000 for one second)
        server.congestion_timeout = 1;
        let attempts = server.max_retransmission_retries;

        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&[0]));
        // Simulate connection loss by killing the socket.
        client.state = SocketState::Closed;
        // Captured after the send so `seq_nr - 1` refers to the packet just
        // sent (assuming `seq_nr` is the next sequence number to use).
        let seq_nr = client.seq_nr;
        let mut buf = [0; BUF_SIZE];
        // Raw reads from the client's underlying UDP socket, bypassing uTP.
        for _ in 0..(3 * attempts) {
            match client.socket.recv_from(&mut buf).await {
                Ok((len, _src)) => {
                    let packet = Packet::try_from(&buf[..len]).unwrap();
                    assert_eq!(packet.get_type(), PacketType::State);
                    assert_eq!(packet.ack_nr(), seq_nr - 1);
                }
                Err(e) => panic!("{}", e),
            }
        }

        // Drain incoming packets
        let mut buf = [0; BUF_SIZE];
        iotry!(server.recv_from(&mut buf));

        // Try to receive data, time out too many times, and fail with `TimedOut`
        let mut buf = [0; BUF_SIZE];
        match server.recv_from(&mut buf).await {
            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
            x => panic!("Expected Err(TimedOut), got {:?}", x),
        }
    }
3600}