1use async_std::{
2 io,
3 net::{SocketAddr, ToSocketAddrs, UdpSocket},
4 task,
5};
6use futures::FutureExt;
7use futures::{future::BoxFuture, ready};
8use log::debug;
9use std::collections::VecDeque;
10use std::io::{ErrorKind, Result};
11use std::task::Poll;
12use std::time::{Duration, Instant};
13use std::{
14 cmp::{max, min},
15 sync::Arc,
16};
17
18use crate::error::SocketError;
19use crate::packet::*;
20use crate::time::*;
21use crate::util::*;
22
/// Size, in bytes, of the scratch buffers used when receiving datagrams.
pub(crate) const BUF_SIZE: usize = 1500;

// ---- Congestion control ----

/// Gain applied to the delay-based congestion-window adjustment.
const GAIN: f64 = 1.0;
/// How many MSS-sized segments the congestion window may exceed the flight size by.
const ALLOWED_INCREASE: u32 = 1;
/// Target queuing delay, compared against `queuing_delay()` (presumably microseconds).
const TARGET: f64 = 100_000.0;
/// Maximum segment size in bytes: an outgoing data packet (header + payload) is at most this long.
const MSS: u32 = 1400;
/// Lower bound of the congestion window, in segments.
const MIN_CWND: u32 = 2;
/// Initial congestion window, in segments.
const INIT_CWND: u32 = 2;

// ---- Timeouts and retries ----

/// Congestion timeout before any RTT sample exists, in milliseconds.
const INITIAL_CONGESTION_TIMEOUT: u64 = 1000;
/// Lower clamp of the computed congestion timeout, in milliseconds.
const MIN_CONGESTION_TIMEOUT: u64 = 500;
/// Upper clamp of the computed congestion timeout, in milliseconds.
const MAX_CONGESTION_TIMEOUT: u64 = 60_000;
/// Number of base-delay slots kept in the delay history.
const BASE_HISTORY: usize = 10;
/// Number of SYN transmissions `connect` attempts before giving up.
const MAX_SYN_RETRIES: u32 = 5;
/// Default number of consecutive receive timeouts tolerated before the connection times out.
const MAX_RETRANSMISSION_RETRIES: u32 = 5;
/// Receive window advertised to the peer, in bytes.
const WINDOW_SIZE: u32 = 1024 * 1024;
/// Longest time `send_packet` waits for the in-flight window to drain, in microseconds.
const PRE_SEND_TIMEOUT: u32 = 500_000;
42
43const MAX_BASE_DELAY_AGE: Delay = Delay(60_000_000);
45
/// Lifecycle state of a `UtpSocket`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SocketState {
    /// Freshly bound; no handshake has taken place yet.
    New,
    /// Handshake completed; data may flow.
    Connected,
    /// `connect` sent a SYN and is waiting for the reply.
    SynSent,
    /// `close` sent a FIN and is waiting for it to be acknowledged.
    FinSent,
    /// The peer sent a RESET packet.
    ResetReceived,
    /// The connection is over (closed by either side or timed out).
    Closed,
}
55
56#[derive(Debug, Clone)]
57struct DelayDifferenceSample {
58 received_at: Timestamp,
59 difference: Delay,
60}
61
62async fn take_address<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
64 addr.to_socket_addrs()
65 .await
66 .and_then(|mut it| it.next().ok_or_else(|| SocketError::InvalidAddress.into()))
67}
68
69#[derive(Debug)]
101pub struct UtpSocket {
102 socket: UdpSocket,
104
105 connected_to: SocketAddr,
107
108 sender_connection_id: u16,
110
111 receiver_connection_id: u16,
113
114 seq_nr: u16,
116
117 ack_nr: u16,
119
120 state: SocketState,
122
123 incoming_buffer: Vec<Packet>,
125
126 send_window: Vec<Packet>,
128
129 unsent_queue: VecDeque<Packet>,
131
132 duplicate_ack_count: u32,
134
135 last_acked: u16,
137
138 last_acked_timestamp: Timestamp,
140
141 last_dropped: u16,
143
144 rtt: i32,
146
147 rtt_variance: i32,
149
150 pending_data: Vec<u8>,
152
153 curr_window: u32,
155
156 remote_wnd_size: u32,
158
159 base_delays: VecDeque<Delay>,
161
162 current_delays: Vec<DelayDifferenceSample>,
164
165 their_delay: Delay,
167
168 last_rollover: Timestamp,
170
171 congestion_timeout: u64,
173
174 cwnd: u32,
176
177 pub max_retransmission_retries: u32,
179}
180
181impl UtpSocket {
182 fn from_raw_parts(s: UdpSocket, src: SocketAddr) -> UtpSocket {
186 let (receiver_id, sender_id) = generate_sequential_identifiers();
187
188 UtpSocket {
189 socket: s,
190 connected_to: src,
191 receiver_connection_id: receiver_id,
192 sender_connection_id: sender_id,
193 seq_nr: 1,
194 ack_nr: 0,
195 state: SocketState::New,
196 incoming_buffer: Vec::new(),
197 send_window: Vec::new(),
198 unsent_queue: VecDeque::new(),
199 duplicate_ack_count: 0,
200 last_acked: 0,
201 last_acked_timestamp: Timestamp::default(),
202 last_dropped: 0,
203 rtt: 0,
204 rtt_variance: 0,
205 pending_data: Vec::new(),
206 curr_window: 0,
207 remote_wnd_size: 0,
208 current_delays: Vec::new(),
209 base_delays: VecDeque::with_capacity(BASE_HISTORY),
210 their_delay: Delay::default(),
211 last_rollover: Timestamp::default(),
212 congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
213 cwnd: INIT_CWND * MSS,
214 max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
215 }
216 }
217
218 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpSocket> {
225 let addr = take_address(addr).await?;
226 let socket = UdpSocket::bind(addr).await?;
227 Ok(UtpSocket::from_raw_parts(socket, addr))
228 }
229
230 pub fn local_addr(&self) -> Result<SocketAddr> {
232 self.socket.local_addr()
233 }
234
235 pub fn peer_addr(&self) -> Result<SocketAddr> {
237 if self.state == SocketState::Connected || self.state == SocketState::FinSent {
238 Ok(self.connected_to)
239 } else {
240 Err(SocketError::NotConnected.into())
241 }
242 }
243
244 pub async fn connect<A: ToSocketAddrs>(other: A) -> Result<UtpSocket> {
251 let addr = take_address(other).await?;
252 let my_addr = match addr {
253 SocketAddr::V4(_) => "0.0.0.0:0",
254 SocketAddr::V6(_) => "[::]:0",
255 };
256 let mut socket = UtpSocket::bind(my_addr).await?;
257 socket.connected_to = addr;
258
259 let mut packet = Packet::new();
260 packet.set_type(PacketType::Syn);
261 packet.set_connection_id(socket.receiver_connection_id);
262 packet.set_seq_nr(socket.seq_nr);
263
264 let mut len = 0;
265 let mut buf = vec![0u8; BUF_SIZE];
266
267 let mut syn_timeout = Duration::from_millis(socket.congestion_timeout);
268 for _ in 0..MAX_SYN_RETRIES {
269 packet.set_timestamp(now_microseconds());
270
271 debug!("Connecting to {}", socket.connected_to);
273 socket
274 .socket
275 .send_to(packet.as_ref(), socket.connected_to)
276 .await?;
277 socket.state = SocketState::SynSent;
278 debug!("sent {:?}", packet);
279
280 match io::timeout(syn_timeout, socket.socket.recv_from(&mut buf)).await {
282 Ok((read, src)) => {
283 socket.connected_to = src;
284 len = read;
285 break;
286 }
287 Err(ref e)
288 if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) =>
289 {
290 debug!("Timed out, retrying");
291 syn_timeout *= 2;
292 continue;
293 }
294 Err(e) => return Err(e),
295 };
296 }
297
298 let addr = socket.connected_to;
299 let packet = Packet::try_from(&buf[..len])?;
300 debug!("received {:?}", packet);
301 socket.handle_packet(&packet, addr).await?;
302
303 debug!("connected to: {}", socket.connected_to);
304
305 Ok(socket)
306 }
307
308 pub async fn close(&mut self) -> Result<()> {
313 if self.state == SocketState::Closed
315 || self.state == SocketState::New
316 || self.state == SocketState::SynSent
317 {
318 return Ok(());
319 }
320
321 self.flush().await?;
323
324 let mut packet = Packet::new();
325 packet.set_connection_id(self.sender_connection_id);
326 packet.set_seq_nr(self.seq_nr);
327 packet.set_ack_nr(self.ack_nr);
328 packet.set_timestamp(now_microseconds());
329 packet.set_type(PacketType::Fin);
330
331 self.socket
333 .send_to(packet.as_ref(), self.connected_to)
334 .await?;
335 debug!("sent {:?}", packet);
336 self.state = SocketState::FinSent;
337
338 let mut buf = vec![0u8; BUF_SIZE];
340 while self.state != SocketState::Closed {
341 self.recv(&mut buf).await?;
342 }
343
344 Ok(())
345 }
346
347 pub async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
353 let read = self.flush_incoming_buffer(buf);
354
355 if read > 0 {
356 return Ok((read, self.connected_to));
357 }
358
359 if self.state == SocketState::ResetReceived {
362 return Err(SocketError::ConnectionReset.into());
363 }
364
365 loop {
366 if self.state == SocketState::Closed {
368 return Ok((0, self.connected_to));
369 }
370
371 match self.recv(buf).await {
372 Ok((0, _src)) => continue,
373 Ok(x) => return Ok(x),
374 Err(e) => return Err(e),
375 }
376 }
377 }
378
379 async fn recv(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
380 let mut b = vec![0; BUF_SIZE + HEADER_SIZE];
381 let start = Instant::now();
382 let (read, src);
383 let mut retries = 0;
384
385 loop {
387 if retries >= self.max_retransmission_retries {
389 self.state = SocketState::Closed;
390 return Err(SocketError::ConnectionTimedOut.into());
391 }
392
393 let timeout = if self.state != SocketState::New {
394 debug!("setting read timeout of {} ms", self.congestion_timeout);
395 Some(Duration::from_millis(self.congestion_timeout))
396 } else {
397 None
398 };
399
400 let response = match timeout {
401 Some(timeout) => io::timeout(timeout, self.socket.recv_from(&mut b)).await,
402 None => self.socket.recv_from(&mut b).await,
403 };
404
405 match response {
406 Ok((r, s)) => {
407 read = r;
408 src = s;
409 break;
410 }
411 Err(ref e)
412 if (e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut) =>
413 {
414 debug!("recv_from timed out");
415 self.handle_receive_timeout().await?;
416 }
417 Err(e) => return Err(e),
418 };
419
420 let elapsed = start.elapsed();
421 let elapsed_ms = elapsed.as_secs() * 1000 + elapsed.subsec_millis() as u64;
422 debug!("{} ms elapsed", elapsed_ms);
423 retries += 1;
424 }
425
426 let packet = match Packet::try_from(&b[..read]) {
428 Ok(packet) => packet,
429 Err(e) => {
430 debug!("{}", e);
431 debug!("Ignoring invalid packet");
432 return Ok((0, self.connected_to));
433 }
434 };
435 debug!("received {:?}", packet);
436
437 if let Some(mut pkt) = self.handle_packet(&packet, src).await? {
439 pkt.set_wnd_size(WINDOW_SIZE);
440 self.socket.send_to(pkt.as_ref(), src).await?;
441 debug!("sent {:?}", pkt);
442 }
443
444 if packet.get_type() == PacketType::Data
447 && packet.seq_nr().wrapping_sub(self.last_dropped) > 0
448 {
449 self.insert_into_buffer(packet);
450 }
451
452 let read = self.flush_incoming_buffer(buf);
454
455 Ok((read, src))
456 }
457
458 async fn handle_receive_timeout(&mut self) -> Result<()> {
459 self.congestion_timeout *= 2;
460 self.cwnd = MSS;
461
462 debug!(
472 "self.send_window: {:?}",
473 self.send_window
474 .iter()
475 .map(Packet::seq_nr)
476 .collect::<Vec<u16>>()
477 );
478
479 if self.send_window.is_empty() {
480 if self.state == SocketState::FinSent {
483 let mut packet = Packet::new();
484 packet.set_connection_id(self.sender_connection_id);
485 packet.set_seq_nr(self.seq_nr);
486 packet.set_ack_nr(self.ack_nr);
487 packet.set_timestamp(now_microseconds());
488 packet.set_type(PacketType::Fin);
489
490 self.socket
492 .send_to(packet.as_ref(), self.connected_to)
493 .await?;
494 debug!("resent FIN: {:?}", packet);
495 } else if self.state != SocketState::New {
496 debug!("sending fast resend request");
499 self.send_fast_resend_request().await;
500 }
501 } else {
502 let packet = &mut self.send_window[0];
505 packet.set_timestamp(now_microseconds());
506 self.socket
507 .send_to(packet.as_ref(), self.connected_to)
508 .await?;
509 debug!("resent {:?}", packet);
510 }
511
512 Ok(())
513 }
514
515 fn prepare_reply(&self, original: &Packet, t: PacketType) -> Packet {
516 let mut resp = Packet::new();
517 resp.set_type(t);
518 let self_t_micro = now_microseconds();
519 let other_t_micro = original.timestamp();
520 let time_difference: Delay = abs_diff(self_t_micro, other_t_micro);
521 resp.set_timestamp(self_t_micro);
522 resp.set_timestamp_difference(time_difference);
523 resp.set_connection_id(self.sender_connection_id);
524 resp.set_seq_nr(self.seq_nr);
525 resp.set_ack_nr(self.ack_nr);
526
527 resp
528 }
529
530 fn advance_incoming_buffer(&mut self) -> Option<Packet> {
532 if !self.incoming_buffer.is_empty() {
533 let packet = self.incoming_buffer.remove(0);
534 debug!("Removed packet from incoming buffer: {:?}", packet);
535 self.ack_nr = packet.seq_nr();
536 self.last_dropped = self.ack_nr;
537 Some(packet)
538 } else {
539 None
540 }
541 }
542
543 fn flush_incoming_buffer(&mut self, buf: &mut [u8]) -> usize {
549 fn unsafe_copy(src: &[u8], dst: &mut [u8]) -> usize {
550 let max_len = min(src.len(), dst.len());
551 unsafe {
552 use std::ptr::copy;
553 copy(src.as_ptr(), dst.as_mut_ptr(), max_len);
554 }
555 max_len
556 }
557
558 if !self.pending_data.is_empty() {
560 let flushed = unsafe_copy(&self.pending_data[..], buf);
561
562 if flushed == self.pending_data.len() {
563 self.pending_data.clear();
564 self.advance_incoming_buffer();
565 } else {
566 self.pending_data = self.pending_data[flushed..].to_vec();
567 }
568
569 return flushed;
570 }
571
572 if !self.incoming_buffer.is_empty()
573 && (self.ack_nr == self.incoming_buffer[0].seq_nr()
574 || self.ack_nr + 1 == self.incoming_buffer[0].seq_nr())
575 {
576 let flushed = unsafe_copy(&self.incoming_buffer[0].payload(), buf);
577
578 if flushed == self.incoming_buffer[0].payload().len() {
579 self.advance_incoming_buffer();
580 } else {
581 self.pending_data = self.incoming_buffer[0].payload()[flushed..].to_vec();
582 }
583
584 return flushed;
585 }
586
587 0
588 }
589
590 pub async fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
602 if self.state == SocketState::Closed {
603 return Err(SocketError::ConnectionClosed.into());
604 }
605
606 let total_length = buf.len();
607
608 for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
609 let mut packet = Packet::with_payload(chunk);
610 packet.set_seq_nr(self.seq_nr);
611 packet.set_ack_nr(self.ack_nr);
612 packet.set_connection_id(self.sender_connection_id);
613
614 self.unsent_queue.push_back(packet);
615
616 self.seq_nr = self.seq_nr.wrapping_add(1);
618 }
619
620 self.send().await?;
622
623 Ok(total_length)
624 }
625
626 pub async fn flush(&mut self) -> Result<()> {
628 let mut buf = vec![0u8; BUF_SIZE];
629 while !self.send_window.is_empty() {
630 debug!("packets in send window: {}", self.send_window.len());
631 self.recv(&mut buf).await?;
632 }
633
634 Ok(())
635 }
636
637 async fn send(&mut self) -> Result<()> {
639 while let Some(mut packet) = self.unsent_queue.pop_front() {
640 self.send_packet(&mut packet).await?;
641 self.curr_window += packet.len() as u32;
642 self.send_window.push(packet);
643 }
644 Ok(())
645 }
646
647 async fn send_packet(&mut self, packet: &mut Packet) -> Result<()> {
649 debug!("current window: {}", self.send_window.len());
650 let max_inflight = min(self.cwnd, self.remote_wnd_size);
651 let max_inflight = max(MIN_CWND * MSS, max_inflight);
652 let now = now_microseconds();
653
654 while self.curr_window >= max_inflight && now_microseconds() - now < PRE_SEND_TIMEOUT.into()
657 {
658 debug!("self.curr_window: {}", self.curr_window);
659 debug!("max_inflight: {}", max_inflight);
660 debug!("self.duplicate_ack_count: {}", self.duplicate_ack_count);
661 debug!("now_microseconds() - now = {}", now_microseconds() - now);
662 let mut buf = vec![0u8; BUF_SIZE];
663 self.recv(&mut buf).await?;
664 }
665 debug!(
666 "out: now_microseconds() - now = {}",
667 now_microseconds() - now
668 );
669
670 let distance_a = packet.seq_nr().wrapping_sub(self.last_acked);
675 let distance_b = self.last_acked.wrapping_sub(packet.seq_nr());
676 if distance_a > distance_b {
677 debug!("Packet already acknowledged, skipping...");
678 return Ok(());
679 }
680
681 packet.set_timestamp(now_microseconds());
682 packet.set_timestamp_difference(self.their_delay);
683 self.socket
684 .send_to(packet.as_ref(), self.connected_to)
685 .await?;
686 debug!("sent {:?}", packet);
687
688 Ok(())
689 }
690
691 fn update_base_delay(&mut self, base_delay: Delay, now: Timestamp) {
696 if self.base_delays.is_empty() || now - self.last_rollover > MAX_BASE_DELAY_AGE {
697 self.last_rollover = now;
699
700 if self.base_delays.len() == BASE_HISTORY {
702 self.base_delays.pop_front();
703 }
704
705 self.base_delays.push_back(base_delay);
707 } else {
708 let last_idx = self.base_delays.len() - 1;
710 if base_delay < self.base_delays[last_idx] {
711 self.base_delays[last_idx] = base_delay;
712 }
713 }
714 }
715
716 fn update_current_delay(&mut self, v: Delay, now: Timestamp) {
719 let rtt = (self.rtt as i64 * 100).into();
721 while !self.current_delays.is_empty() && now - self.current_delays[0].received_at > rtt {
722 self.current_delays.remove(0);
723 }
724
725 self.current_delays.push(DelayDifferenceSample {
727 received_at: now,
728 difference: v,
729 });
730 }
731
732 fn update_congestion_timeout(&mut self, current_delay: i32) {
733 let delta = self.rtt - current_delay;
734 self.rtt_variance += (delta.abs() - self.rtt_variance) / 4;
735 self.rtt += (current_delay - self.rtt) / 8;
736 self.congestion_timeout = max(
737 (self.rtt + self.rtt_variance * 4) as u64,
738 MIN_CONGESTION_TIMEOUT,
739 );
740 self.congestion_timeout = min(self.congestion_timeout, MAX_CONGESTION_TIMEOUT);
741
742 debug!("current_delay: {}", current_delay);
743 debug!("delta: {}", delta);
744 debug!("self.rtt_variance: {}", self.rtt_variance);
745 debug!("self.rtt: {}", self.rtt);
746 debug!("self.congestion_timeout: {}", self.congestion_timeout);
747 }
748
749 fn filtered_current_delay(&self) -> Delay {
755 let input = self.current_delays.iter().map(|delay| &delay.difference);
756 (ewma(input, 0.333) as i64).into()
757 }
758
759 fn min_base_delay(&self) -> Delay {
761 self.base_delays.iter().min().cloned().unwrap_or_default()
762 }
763
764 fn build_selective_ack(&self) -> Vec<u8> {
766 let stashed = self
767 .incoming_buffer
768 .iter()
769 .filter(|pkt| pkt.seq_nr() > self.ack_nr + 1)
770 .map(|pkt| (pkt.seq_nr() - self.ack_nr - 2) as usize)
771 .map(|diff| (diff / 8, diff % 8));
772
773 let mut sack = Vec::new();
774 for (byte, bit) in stashed {
775 while byte >= sack.len() || sack.len() % 4 != 0 {
778 sack.push(0u8);
779 }
780
781 sack[byte] |= 1 << bit;
782 }
783
784 sack
785 }
786
787 async fn send_fast_resend_request(&self) {
792 for _ in 0..3 {
793 let mut packet = Packet::new();
794 packet.set_type(PacketType::State);
795 let self_t_micro = now_microseconds();
796 packet.set_timestamp(self_t_micro);
797 packet.set_timestamp_difference(self.their_delay);
798 packet.set_connection_id(self.sender_connection_id);
799 packet.set_seq_nr(self.seq_nr);
800 packet.set_ack_nr(self.ack_nr);
801 let _ = self
802 .socket
803 .send_to(packet.as_ref(), self.connected_to)
804 .await;
805 }
806 }
807
808 async fn resend_lost_packet(&mut self, lost_packet_nr: u16) {
809 debug!("---> resend_lost_packet({}) <---", lost_packet_nr);
810 match self
811 .send_window
812 .iter()
813 .position(|pkt| pkt.seq_nr() == lost_packet_nr)
814 {
815 None => debug!("Packet {} not found", lost_packet_nr),
816 Some(position) => {
817 debug!("self.send_window.len(): {}", self.send_window.len());
818 debug!("position: {}", position);
819 let mut packet = self.send_window[position].clone();
820 let _ = self.send_packet(&mut packet).await;
822
823 }
826 }
827 debug!("---> END resend_lost_packet <---");
828 }
829
830 fn advance_send_window(&mut self) {
832 if let Some(position) = self
842 .send_window
843 .iter()
844 .position(|packet| packet.seq_nr() == self.last_acked)
845 {
846 for _ in 0..position + 1 {
847 let packet = self.send_window.remove(0);
848 self.curr_window -= packet.len() as u32;
849 }
850 }
851 debug!("self.curr_window: {}", self.curr_window);
852 }
853
854 async fn handle_packet(&mut self, packet: &Packet, src: SocketAddr) -> Result<Option<Packet>> {
858 debug!("({:?}, {:?})", self.state, packet.get_type());
859
860 if packet.seq_nr().wrapping_sub(self.ack_nr) == 1 {
862 self.ack_nr = packet.seq_nr();
863 }
864
865 if packet.get_type() != PacketType::Syn
867 && self.state != SocketState::SynSent
868 && !(packet.connection_id() == self.sender_connection_id
869 || packet.connection_id() == self.receiver_connection_id)
870 {
871 return Ok(Some(self.prepare_reply(packet, PacketType::Reset)));
872 }
873
874 self.remote_wnd_size = packet.wnd_size();
876 debug!("self.remote_wnd_size: {}", self.remote_wnd_size);
877
878 let now = now_microseconds();
880 self.their_delay = abs_diff(now, packet.timestamp());
881 debug!("self.their_delay: {}", self.their_delay);
882
883 match (self.state, packet.get_type()) {
884 (SocketState::New, PacketType::Syn) => {
885 self.connected_to = src;
886 self.ack_nr = packet.seq_nr();
887 self.seq_nr = rand::random();
888 self.receiver_connection_id = packet.connection_id() + 1;
889 self.sender_connection_id = packet.connection_id();
890 self.state = SocketState::Connected;
891 self.last_dropped = self.ack_nr;
892
893 Ok(Some(self.prepare_reply(packet, PacketType::State)))
894 }
895 (_, PacketType::Syn) => Ok(Some(self.prepare_reply(packet, PacketType::Reset))),
896 (SocketState::SynSent, PacketType::State) => {
897 self.connected_to = src;
898 self.ack_nr = packet.seq_nr();
899 self.seq_nr += 1;
900 self.state = SocketState::Connected;
901 self.last_acked = packet.ack_nr();
902 self.last_acked_timestamp = now_microseconds();
903 Ok(None)
904 }
905 (SocketState::SynSent, _) => Err(SocketError::InvalidReply.into()),
906 (SocketState::Connected, PacketType::Data)
907 | (SocketState::FinSent, PacketType::Data) => Ok(self.handle_data_packet(packet)),
908 (SocketState::Connected, PacketType::State) => {
909 self.handle_state_packet(packet).await;
910 Ok(None)
911 }
912 (SocketState::Connected, PacketType::Fin) | (SocketState::FinSent, PacketType::Fin) => {
913 if packet.ack_nr() < self.seq_nr {
914 debug!("FIN received but there are missing acknowledgements for sent packets");
915 }
916 let mut reply = self.prepare_reply(packet, PacketType::State);
917 if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
918 debug!(
919 "current ack_nr ({}) is behind received packet seq_nr ({})",
920 self.ack_nr,
921 packet.seq_nr()
922 );
923
924 let sack = self.build_selective_ack();
926
927 if !sack.is_empty() {
928 reply.set_sack(sack);
929 }
930 }
931
932 self.state = SocketState::Closed;
934 Ok(Some(reply))
935 }
936 (SocketState::Closed, PacketType::Fin) => {
937 Ok(Some(self.prepare_reply(packet, PacketType::State)))
938 }
939 (SocketState::FinSent, PacketType::State) => {
940 if packet.ack_nr() == self.seq_nr {
941 self.state = SocketState::Closed;
942 } else {
943 self.handle_state_packet(packet);
944 }
945 Ok(None)
946 }
947 (_, PacketType::Reset) => {
948 self.state = SocketState::ResetReceived;
949 Err(SocketError::ConnectionReset.into())
950 }
951 (state, ty) => {
952 let message = format!("Unimplemented handling for ({:?},{:?})", state, ty);
953 debug!("{}", message);
954 Err(SocketError::Other(message).into())
955 }
956 }
957 }
958
959 fn handle_data_packet(&mut self, packet: &Packet) -> Option<Packet> {
960 let packet_type = if self.state == SocketState::FinSent {
962 PacketType::Fin
963 } else {
964 PacketType::State
965 };
966 let mut reply = self.prepare_reply(packet, packet_type);
967
968 if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
969 debug!(
970 "current ack_nr ({}) is behind received packet seq_nr ({})",
971 self.ack_nr,
972 packet.seq_nr()
973 );
974
975 let sack = self.build_selective_ack();
977
978 if !sack.is_empty() {
979 reply.set_sack(sack);
980 }
981 }
982
983 Some(reply)
984 }
985
986 fn queuing_delay(&self) -> Delay {
987 let filtered_current_delay = self.filtered_current_delay();
988 let min_base_delay = self.min_base_delay();
989 let queuing_delay = filtered_current_delay - min_base_delay;
990
991 debug!("filtered_current_delay: {}", filtered_current_delay);
992 debug!("min_base_delay: {}", min_base_delay);
993 debug!("queuing_delay: {}", queuing_delay);
994
995 queuing_delay
996 }
997
998 fn update_congestion_window(&mut self, off_target: f64, bytes_newly_acked: u32) {
1018 let flightsize = self.curr_window;
1019
1020 let cwnd_increase = GAIN * off_target * bytes_newly_acked as f64 * MSS as f64;
1021 let cwnd_increase = cwnd_increase / self.cwnd as f64;
1022 debug!("cwnd_increase: {}", cwnd_increase);
1023
1024 self.cwnd = (self.cwnd as f64 + cwnd_increase) as u32;
1025 let max_allowed_cwnd = flightsize + ALLOWED_INCREASE * MSS;
1026 self.cwnd = min(self.cwnd, max_allowed_cwnd);
1027 self.cwnd = max(self.cwnd, MIN_CWND * MSS);
1028
1029 debug!("cwnd: {}", self.cwnd);
1030 debug!("max_allowed_cwnd: {}", max_allowed_cwnd);
1031 }
1032
1033 #[async_recursion::async_recursion]
1034 async fn handle_state_packet(&mut self, packet: &Packet) {
1035 if packet.ack_nr() == self.last_acked {
1036 self.duplicate_ack_count += 1;
1037 } else {
1038 self.last_acked = packet.ack_nr();
1039 self.last_acked_timestamp = now_microseconds();
1040 self.duplicate_ack_count = 1;
1041 }
1042
1043 if let Some(index) = self
1045 .send_window
1046 .iter()
1047 .position(|p| packet.ack_nr() == p.seq_nr())
1048 {
1049 let bytes_newly_acked = self
1053 .send_window
1054 .iter()
1055 .take(index + 1)
1056 .fold(0, |acc, p| acc + p.len());
1057
1058 let now = now_microseconds();
1060 let our_delay = now - self.send_window[index].timestamp();
1061 debug!("our_delay: {}", our_delay);
1062 self.update_base_delay(our_delay, now);
1063 self.update_current_delay(our_delay, now);
1064
1065 let off_target: f64 = (TARGET - u32::from(self.queuing_delay()) as f64) / TARGET;
1066 debug!("off_target: {}", off_target);
1067
1068 self.update_congestion_window(off_target, bytes_newly_acked as u32);
1069
1070 let rtt = u32::from(our_delay - self.queuing_delay()) / 1000; self.update_congestion_timeout(rtt as i32);
1073 }
1074
1075 let mut packet_loss_detected: bool =
1076 !self.send_window.is_empty() && self.duplicate_ack_count == 3;
1077
1078 for extension in packet.extensions() {
1080 if extension.get_type() == ExtensionType::SelectiveAck {
1081 if extension.iter().count_ones() >= 3 {
1084 self.resend_lost_packet(packet.ack_nr() + 1).await;
1085 packet_loss_detected = true;
1086 }
1087
1088 if let Some(last_seq_nr) = self.send_window.last().map(Packet::seq_nr) {
1089 let lost_packets = extension
1090 .iter()
1091 .enumerate()
1092 .filter(|&(_, received)| !received)
1093 .map(|(idx, _)| packet.ack_nr() + 2 + idx as u16)
1094 .take_while(|&seq_nr| seq_nr < last_seq_nr);
1095
1096 for seq_nr in lost_packets {
1097 debug!("SACK: packet {} lost", seq_nr);
1098 self.resend_lost_packet(seq_nr).await;
1099 packet_loss_detected = true;
1100 }
1101 }
1102 } else {
1103 debug!("Unknown extension {:?}, ignoring", extension.get_type());
1104 }
1105 }
1106
1107 if !self.send_window.is_empty()
1111 && self.duplicate_ack_count == 3
1112 && !packet
1113 .extensions()
1114 .any(|ext| ext.get_type() == ExtensionType::SelectiveAck)
1115 {
1116 self.resend_lost_packet(packet.ack_nr() + 1).await;
1117 }
1118
1119 if packet_loss_detected {
1121 debug!("packet loss detected, halving congestion window");
1122 self.cwnd = max(self.cwnd / 2, MIN_CWND * MSS);
1123 debug!("cwnd: {}", self.cwnd);
1124 }
1125
1126 self.advance_send_window();
1128 }
1129
1130 fn insert_into_buffer(&mut self, packet: Packet) {
1139 if self
1142 .incoming_buffer
1143 .last()
1144 .map_or(false, |p| packet.seq_nr() > p.seq_nr())
1145 {
1146 self.incoming_buffer.push(packet);
1147 } else {
1148 let i = self
1150 .incoming_buffer
1151 .iter()
1152 .filter(|p| p.seq_nr() < packet.seq_nr())
1153 .count();
1154
1155 if self
1156 .incoming_buffer
1157 .get(i)
1158 .map_or(true, |p| p.seq_nr() != packet.seq_nr())
1159 {
1160 self.incoming_buffer.insert(i, packet);
1161 }
1162 }
1163 }
1164}
1165
1166impl Drop for UtpSocket {
1167 fn drop(&mut self) {
1168 task::block_on(async {
1169 drop(self.close().await);
1170 });
1171 }
1172}
1173
1174#[derive(Clone)]
1200pub struct UtpListener {
1201 socket: Arc<UdpSocket>,
1203}
1204
1205impl UtpListener {
1206 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpListener> {
1215 let socket = UdpSocket::bind(addr).await?;
1216 Ok(UtpListener {
1217 socket: socket.into(),
1218 })
1219 }
1220
1221 pub async fn accept(&self) -> Result<(UtpSocket, SocketAddr)> {
1229 let mut buf = vec![0; BUF_SIZE];
1230
1231 let (nread, src) = self.socket.recv_from(&mut buf).await?;
1232 let packet = Packet::try_from(&buf[..nread])?;
1233
1234 if packet.get_type() != PacketType::Syn {
1236 let message = format!("Expected SYN packet, got {:?} instead", packet.get_type());
1237 return Err(SocketError::Other(message).into());
1238 }
1239
1240 let local_addr = self.socket.local_addr()?;
1242 let inner_socket = match local_addr {
1243 SocketAddr::V4(_) => UdpSocket::bind("0.0.0.0:0"),
1244 SocketAddr::V6(_) => UdpSocket::bind("[::]:0"),
1245 }
1246 .await?;
1247
1248 let mut socket = UtpSocket::from_raw_parts(inner_socket, src);
1249
1250 if let Ok(Some(reply)) = socket.handle_packet(&packet, src).await {
1252 socket
1253 .socket
1254 .send_to(reply.as_ref(), src)
1255 .await
1256 .and(Ok((socket, src)))
1257 } else {
1258 Err(SocketError::Other("Reached unreachable statement".to_owned()).into())
1259 }
1260 }
1261
1262 pub fn incoming(&self) -> Incoming<'_> {
1266 Incoming {
1267 listener: self,
1268 accept: None,
1269 }
1270 }
1271
1272 pub fn local_addr(&self) -> Result<SocketAddr> {
1274 self.socket.local_addr()
1275 }
1276}
1277
1278type AcceptFuture<'a> = Option<BoxFuture<'a, io::Result<(UtpSocket, SocketAddr)>>>;
1279
1280pub struct Incoming<'a> {
1281 listener: &'a UtpListener,
1282 accept: AcceptFuture<'a>,
1283}
1284
1285impl<'a> futures::Stream for Incoming<'a> {
1286 type Item = Result<(UtpSocket, SocketAddr)>;
1287
1288 fn poll_next(
1289 mut self: std::pin::Pin<&mut Self>,
1290 cx: &mut std::task::Context<'_>,
1291 ) -> std::task::Poll<Option<Self::Item>> {
1292 loop {
1293 if self.accept.is_none() {
1294 self.accept = Some(self.listener.accept().boxed());
1295 }
1296
1297 if let Some(f) = &mut self.accept {
1298 let res = ready!(f.as_mut().poll(cx));
1299 self.accept = None;
1300 return Poll::Ready(Some(res));
1301 }
1302 }
1303 }
1304}
1305
1306#[cfg(test)]
1307mod test {
1308 use crate::packet::*;
1309 use crate::socket::{take_address, SocketState, UtpListener, UtpSocket, BUF_SIZE};
1310 use crate::time::now_microseconds;
1311 use async_std::task;
1312 use rand;
1313 use std::io::ErrorKind;
1314 use std::net::ToSocketAddrs;
1315
1316 macro_rules! iotry {
1317 ($e:expr) => {
1318 match $e.await {
1319 Ok(e) => e,
1320 Err(e) => panic!("{:?}", e),
1321 }
1322 };
1323 }
1324
1325 fn next_test_port() -> u16 {
1326 use std::sync::atomic::{AtomicUsize, Ordering};
1327 static NEXT_OFFSET: AtomicUsize = AtomicUsize::new(0);
1328 const BASE_PORT: u16 = 9600;
1329 BASE_PORT + NEXT_OFFSET.fetch_add(1, Ordering::Relaxed) as u16
1330 }
1331
1332 fn next_test_ip4<'a>() -> (&'a str, u16) {
1333 ("127.0.0.1", next_test_port())
1334 }
1335
1336 fn next_test_ip6<'a>() -> (&'a str, u16) {
1337 ("::1", next_test_port())
1338 }
1339
1340 #[async_std::test]
1341 async fn test_socket_ipv4() {
1342 let server_addr = next_test_ip4();
1343
1344 let mut server = iotry!(UtpSocket::bind(server_addr));
1345 assert_eq!(server.state, SocketState::New);
1346
1347 let child = task::spawn(async move {
1348 let mut client = iotry!(UtpSocket::connect(server_addr));
1349 assert_eq!(client.state, SocketState::Connected);
1350 assert_eq!(
1352 client.sender_connection_id,
1353 client.receiver_connection_id + 1
1354 );
1355 assert_eq!(
1356 client.connected_to,
1357 server_addr.to_socket_addrs().unwrap().next().unwrap()
1358 );
1359 iotry!(client.close());
1360 drop(client);
1361 });
1362
1363 let mut buf = vec![0; BUF_SIZE];
1364 match server.recv_from(&mut buf).await {
1365 e => println!("{:?}", e),
1366 }
1367 assert_eq!(
1369 server.receiver_connection_id,
1370 server.sender_connection_id + 1
1371 );
1372
1373 assert_eq!(server.state, SocketState::Closed);
1374 drop(server);
1375
1376 child.await;
1377 }
1378
1379 #[async_std::test]
1380 async fn test_socket_ipv6() {
1381 let server_addr = next_test_ip6();
1382
1383 let mut server = iotry!(UtpSocket::bind(server_addr));
1384 assert_eq!(server.state, SocketState::New);
1385
1386 let child = task::spawn(async move {
1387 let mut client = iotry!(UtpSocket::connect(server_addr));
1388 assert_eq!(client.state, SocketState::Connected);
1389 assert_eq!(
1391 client.sender_connection_id,
1392 client.receiver_connection_id + 1
1393 );
1394 assert_eq!(
1395 client.connected_to,
1396 server_addr.to_socket_addrs().unwrap().next().unwrap()
1397 );
1398 iotry!(client.close());
1399 drop(client);
1400 });
1401
1402 let mut buf = vec![0u8; BUF_SIZE];
1403 match server.recv_from(&mut buf).await {
1404 e => println!("{:?}", e),
1405 }
1406 assert_eq!(
1408 server.receiver_connection_id,
1409 server.sender_connection_id + 1
1410 );
1411
1412 assert_eq!(server.state, SocketState::Closed);
1413 drop(server);
1414
1415 child.await;
1416 }
1417
1418 #[async_std::test]
1419 async fn test_recvfrom_on_closed_socket() {
1420 let server_addr = next_test_ip4();
1421
1422 let mut server = iotry!(UtpSocket::bind(server_addr));
1423 assert_eq!(server.state, SocketState::New);
1424
1425 let child = task::spawn(async move {
1426 let mut client = iotry!(UtpSocket::connect(server_addr));
1427 assert_eq!(client.state, SocketState::Connected);
1428 assert!(client.close().await.is_ok());
1429 });
1430
1431 let mut buf = vec![0u8; BUF_SIZE];
1433 let _resp = server.recv_from(&mut buf).await;
1434 assert_eq!(server.state, SocketState::Closed);
1435
1436 match server.recv_from(&mut buf).await {
1438 Ok((0, _src)) => {}
1439 e => panic!("Expected Ok(0), got {:?}", e),
1440 }
1441 assert_eq!(server.state, SocketState::Closed);
1442
1443 child.await;
1444 }
1445
1446 #[async_std::test]
1447 async fn test_sendto_on_closed_socket() {
1448 let server_addr = next_test_ip4();
1449
1450 let mut server = iotry!(UtpSocket::bind(server_addr));
1451 assert_eq!(server.state, SocketState::New);
1452
1453 let child = task::spawn(async move {
1454 let mut client = iotry!(UtpSocket::connect(server_addr));
1455 assert_eq!(client.state, SocketState::Connected);
1456 iotry!(client.close());
1457 });
1458
1459 let mut buf = vec![0u8; BUF_SIZE];
1461 let (_read, _src) = iotry!(server.recv_from(&mut buf));
1462 assert_eq!(server.state, SocketState::Closed);
1463
1464 match server.send_to(&buf).await {
1466 Err(ref e) if e.kind() == ErrorKind::NotConnected => (),
1467 v => panic!("expected {:?}, got {:?}", ErrorKind::NotConnected, v),
1468 }
1469
1470 child.await;
1471 }
1472
1473 #[async_std::test]
1474 async fn test_acks_on_socket() {
1475 use std::sync::mpsc::channel;
1476 let server_addr = next_test_ip4();
1477 let (tx, rx) = channel();
1478
1479 let mut server = iotry!(UtpSocket::bind(server_addr));
1480
1481 let child = task::spawn(async move {
1482 let mut buf = vec![0u8; BUF_SIZE];
1484 let _resp = server.recv(&mut buf).await;
1485 tx.send(server.seq_nr).unwrap();
1486
1487 iotry!(server.recv_from(&mut buf));
1489
1490 drop(server);
1491 });
1492
1493 let mut client = iotry!(UtpSocket::connect(server_addr));
1494 assert_eq!(client.state, SocketState::Connected);
1495 let sender_seq_nr = rx.recv().unwrap();
1496 let ack_nr = client.ack_nr;
1497 assert_eq!(ack_nr, sender_seq_nr);
1498 assert!(client.close().await.is_ok());
1499
1500 assert_eq!(client.ack_nr, ack_nr);
1504 drop(client);
1505
1506 child.await;
1507 }
1508
1509 #[async_std::test]
1510 async fn test_handle_packet() {
1511 let initial_connection_id: u16 = rand::random();
1513 let sender_connection_id = initial_connection_id + 1;
1514 let (server_addr, client_addr) = (
1515 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1516 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1517 );
1518 let mut socket = iotry!(UtpSocket::bind(server_addr));
1519
1520 let mut packet = Packet::new();
1521 packet.set_wnd_size(BUF_SIZE as u32);
1522 packet.set_type(PacketType::Syn);
1523 packet.set_connection_id(initial_connection_id);
1524
1525 let response = socket.handle_packet(&packet, client_addr).await;
1527 assert!(response.is_ok());
1528 let response = response.unwrap();
1529 assert!(response.is_some());
1530
1531 let response = response.unwrap();
1533 assert_eq!(response.get_type(), PacketType::State);
1534
1535 assert_eq!(response.connection_id(), packet.connection_id());
1537
1538 assert_eq!(response.ack_nr(), packet.seq_nr());
1540
1541 assert!(response.payload().is_empty());
1543 let old_packet = packet;
1549 let old_response = response;
1550
1551 let mut packet = Packet::new();
1552 packet.set_type(PacketType::Data);
1553 packet.set_connection_id(sender_connection_id);
1554 packet.set_seq_nr(old_packet.seq_nr() + 1);
1555 packet.set_ack_nr(old_response.seq_nr());
1556
1557 let response = socket.handle_packet(&packet, client_addr).await;
1558 assert!(response.is_ok());
1559 let response = response.unwrap();
1560 assert!(response.is_some());
1561
1562 let response = response.unwrap();
1563 assert_eq!(response.get_type(), PacketType::State);
1564
1565 assert_eq!(response.connection_id(), initial_connection_id);
1569 assert_eq!(response.connection_id(), packet.connection_id() - 1);
1570
1571 assert_eq!(response.ack_nr(), packet.seq_nr());
1573
1574 assert!(response.payload().is_empty());
1576 assert_eq!(response.seq_nr(), old_response.seq_nr());
1577 let old_packet = packet;
1581 let old_response = response;
1582
1583 let mut packet = Packet::new();
1584 packet.set_type(PacketType::Fin);
1585 packet.set_connection_id(sender_connection_id);
1586 packet.set_seq_nr(old_packet.seq_nr() + 1);
1587 packet.set_ack_nr(old_response.seq_nr());
1588
1589 let response = socket.handle_packet(&packet, client_addr).await;
1590 assert!(response.is_ok());
1591 let response = response.unwrap();
1592 assert!(response.is_some());
1593
1594 let response = response.unwrap();
1595
1596 assert_eq!(response.get_type(), PacketType::State);
1597
1598 assert_eq!(packet.seq_nr(), old_packet.seq_nr() + 1);
1600
1601 assert_eq!(response.seq_nr(), old_response.seq_nr());
1603
1604 assert_eq!(response.ack_nr(), packet.seq_nr());
1606
1607 }
1609
1610 #[async_std::test]
1611 async fn test_response_to_keepalive_ack() {
1612 let initial_connection_id: u16 = rand::random();
1614 let (server_addr, client_addr) = (
1615 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1616 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1617 );
1618 let mut socket = iotry!(UtpSocket::bind(server_addr));
1619
1620 let mut packet = Packet::new();
1622 packet.set_wnd_size(BUF_SIZE as u32);
1623 packet.set_type(PacketType::Syn);
1624 packet.set_connection_id(initial_connection_id);
1625
1626 let response = socket.handle_packet(&packet, client_addr).await;
1627 assert!(response.is_ok());
1628 let response = response.unwrap();
1629 assert!(response.is_some());
1630 let response = response.unwrap();
1631 assert_eq!(response.get_type(), PacketType::State);
1632
1633 let old_packet = packet;
1634 let old_response = response;
1635
1636 let mut packet = Packet::new();
1638 packet.set_wnd_size(BUF_SIZE as u32);
1639 packet.set_type(PacketType::State);
1640 packet.set_connection_id(initial_connection_id);
1641 packet.set_seq_nr(old_packet.seq_nr() + 1);
1642 packet.set_ack_nr(old_response.seq_nr());
1643
1644 let response = socket.handle_packet(&packet, client_addr).await;
1645 assert!(response.is_ok());
1646 let response = response.unwrap();
1647 assert!(response.is_none());
1648
1649 let response = socket.handle_packet(&packet, client_addr).await;
1651 assert!(response.is_ok());
1652 let response = response.unwrap();
1653 assert!(response.is_none());
1654
1655 socket.state = SocketState::Closed;
1657 }
1658
1659 #[async_std::test]
1660 async fn test_response_to_wrong_connection_id() {
1661 let initial_connection_id: u16 = rand::random();
1663 let (server_addr, client_addr) = (
1664 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1665 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1666 );
1667 let mut socket = iotry!(UtpSocket::bind(server_addr));
1668
1669 let mut packet = Packet::new();
1671 packet.set_wnd_size(BUF_SIZE as u32);
1672 packet.set_type(PacketType::Syn);
1673 packet.set_connection_id(initial_connection_id);
1674
1675 let response = socket.handle_packet(&packet, client_addr).await;
1676 assert!(response.is_ok());
1677 let response = response.unwrap();
1678 assert!(response.is_some());
1679 assert_eq!(response.unwrap().get_type(), PacketType::State);
1680
1681 let new_connection_id = initial_connection_id.wrapping_mul(2);
1683
1684 let mut packet = Packet::new();
1685 packet.set_wnd_size(BUF_SIZE as u32);
1686 packet.set_type(PacketType::State);
1687 packet.set_connection_id(new_connection_id);
1688
1689 let response = socket.handle_packet(&packet, client_addr).await;
1690 assert!(response.is_ok());
1691 let response = response.unwrap();
1692 assert!(response.is_some());
1693
1694 let response = response.unwrap();
1695 assert_eq!(response.get_type(), PacketType::Reset);
1696 assert_eq!(response.ack_nr(), packet.seq_nr());
1697
1698 socket.state = SocketState::Closed;
1700 }
1701
1702 #[async_std::test]
1703 async fn test_unordered_packets() {
1704 let initial_connection_id: u16 = rand::random();
1706 let (server_addr, client_addr) = (
1707 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1708 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1709 );
1710 let mut socket = iotry!(UtpSocket::bind(server_addr));
1711
1712 let mut packet = Packet::new();
1714 packet.set_wnd_size(BUF_SIZE as u32);
1715 packet.set_type(PacketType::Syn);
1716 packet.set_connection_id(initial_connection_id);
1717
1718 let response = socket.handle_packet(&packet, client_addr).await;
1719 assert!(response.is_ok());
1720 let response = response.unwrap();
1721 assert!(response.is_some());
1722 let response = response.unwrap();
1723 assert_eq!(response.get_type(), PacketType::State);
1724
1725 let old_packet = packet;
1726 let old_response = response;
1727
1728 let mut window: Vec<Packet> = Vec::new();
1729
1730 let mut packet = Packet::with_payload(&[1, 2, 3]);
1732 packet.set_wnd_size(BUF_SIZE as u32);
1733 packet.set_connection_id(initial_connection_id);
1734 packet.set_seq_nr(old_packet.seq_nr() + 1);
1735 packet.set_ack_nr(old_response.seq_nr());
1736 window.push(packet);
1737
1738 let mut packet = Packet::with_payload(&[4, 5, 6]);
1739 packet.set_wnd_size(BUF_SIZE as u32);
1740 packet.set_connection_id(initial_connection_id);
1741 packet.set_seq_nr(old_packet.seq_nr() + 2);
1742 packet.set_ack_nr(old_response.seq_nr());
1743 window.push(packet);
1744
1745 let response = socket.handle_packet(&window[1], client_addr).await;
1747 assert!(response.is_ok());
1748 let response = response.unwrap();
1749 assert!(response.is_some());
1750 let response = response.unwrap();
1751 assert!(response.ack_nr() != window[1].seq_nr());
1752
1753 let response = socket.handle_packet(&window[0], client_addr).await;
1754 assert!(response.is_ok());
1755 let response = response.unwrap();
1756 assert!(response.is_some());
1757
1758 socket.state = SocketState::Closed;
1760 }
1761
1762 #[async_std::test]
1763 async fn test_response_to_triple_ack() {
1764 let server_addr = next_test_ip4();
1765 let mut server = iotry!(UtpSocket::bind(server_addr));
1766
1767 const LEN: usize = 1024;
1769 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
1770 let d = data.clone();
1771 assert_eq!(LEN, data.len());
1772
1773 let child = task::spawn(async move {
1774 let mut client = iotry!(UtpSocket::connect(server_addr));
1775 iotry!(client.send_to(&d[..]));
1776 iotry!(client.close());
1777 });
1778
1779 let mut buf = vec![0u8; BUF_SIZE];
1780 iotry!(server.recv(&mut buf));
1782
1783 let data_packet = match server.socket.recv_from(&mut buf).await {
1785 Ok((read, _src)) => Packet::try_from(&buf[..read]).unwrap(),
1786 Err(e) => panic!("{}", e),
1787 };
1788 assert_eq!(data_packet.get_type(), PacketType::Data);
1789 assert_eq!(&data_packet.payload(), &data.as_slice());
1790 assert_eq!(data_packet.payload().len(), data.len());
1791
1792 let mut packet = Packet::new();
1794 packet.set_wnd_size(BUF_SIZE as u32);
1795 packet.set_type(PacketType::State);
1796 packet.set_seq_nr(server.seq_nr);
1797 packet.set_ack_nr(data_packet.seq_nr() - 1);
1798 packet.set_connection_id(server.sender_connection_id);
1799
1800 for _ in 0..3u8 {
1801 iotry!(server.socket.send_to(packet.as_ref(), server.connected_to));
1802 }
1803
1804 let client_addr = server.connected_to;
1806 match server.socket.recv_from(&mut buf).await {
1807 Ok((0, _)) => panic!("Received 0 bytes from socket"),
1808 Ok((read, _src)) => {
1809 let packet = Packet::try_from(&buf[..read]).unwrap();
1810 assert_eq!(packet.get_type(), PacketType::Data);
1811 assert_eq!(packet.seq_nr(), data_packet.seq_nr());
1812 assert_eq!(packet.payload(), data_packet.payload());
1813 let response = server.handle_packet(&packet, client_addr).await;
1814 assert!(response.is_ok());
1815 let response = response.unwrap();
1816 assert!(response.is_some());
1817 let response = response.unwrap();
1818 iotry!(server
1819 .socket
1820 .send_to(response.as_ref(), server.connected_to));
1821 }
1822 Err(e) => panic!("{}", e),
1823 }
1824
1825 iotry!(server.recv_from(&mut buf));
1827 child.await;
1828 }
1829
1830 #[async_std::test]
1831 async fn test_socket_timeout_request() {
1832 let (server_addr, client_addr) = (
1833 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1834 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
1835 );
1836
1837 let client = iotry!(UtpSocket::bind(client_addr));
1838 let mut server = iotry!(UtpSocket::bind(server_addr));
1839 const LEN: usize = 512;
1840 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
1841 let d = data.clone();
1842
1843 assert_eq!(server.state, SocketState::New);
1844 assert_eq!(client.state, SocketState::New);
1845
1846 assert_eq!(
1848 client.sender_connection_id,
1849 client.receiver_connection_id + 1
1850 );
1851
1852 let child = task::spawn(async move {
1853 let mut client = iotry!(UtpSocket::connect(server_addr));
1854 assert_eq!(client.state, SocketState::Connected);
1855 assert_eq!(client.connected_to, server_addr);
1856 iotry!(client.send_to(&d[..]));
1857 drop(client);
1858 });
1859
1860 let mut buf = vec![0u8; BUF_SIZE];
1861 server.recv(&mut buf).await.unwrap();
1862 assert_eq!(
1864 server.receiver_connection_id,
1865 server.sender_connection_id + 1
1866 );
1867
1868 assert_eq!(server.state, SocketState::Connected);
1869
1870 iotry!(server.socket.recv_from(&mut buf));
1874
1875 server.congestion_timeout = 50;
1877
1878 loop {
1880 let response = server.recv_from(&mut buf).await;
1881 match response {
1882 Ok((0, _)) => continue,
1883 Ok(_) => break,
1884 Err(e) => panic!("{}", e),
1885 }
1886 }
1887
1888 drop(server);
1889 child.await;
1890 }
1891
1892 #[async_std::test]
1893 async fn test_sorted_buffer_insertion() {
1894 let server_addr = next_test_ip4();
1895 let mut socket = iotry!(UtpSocket::bind(server_addr));
1896
1897 let mut packet = Packet::new();
1898 packet.set_seq_nr(1);
1899
1900 assert!(socket.incoming_buffer.is_empty());
1901
1902 socket.insert_into_buffer(packet.clone());
1903 assert_eq!(socket.incoming_buffer.len(), 1);
1904
1905 packet.set_seq_nr(2);
1906 packet.set_timestamp(128.into());
1907
1908 socket.insert_into_buffer(packet.clone());
1909 assert_eq!(socket.incoming_buffer.len(), 2);
1910 assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
1911 assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
1912
1913 packet.set_seq_nr(3);
1914 packet.set_timestamp(256.into());
1915
1916 socket.insert_into_buffer(packet.clone());
1917 assert_eq!(socket.incoming_buffer.len(), 3);
1918 assert_eq!(socket.incoming_buffer[2].seq_nr(), 3);
1919 assert_eq!(socket.incoming_buffer[2].timestamp(), 256.into());
1920
1921 packet.set_seq_nr(2);
1923 packet.set_timestamp(456.into());
1924
1925 socket.insert_into_buffer(packet.clone());
1926 assert_eq!(socket.incoming_buffer.len(), 3);
1927 assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
1928 assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
1929 }
1930
1931 #[async_std::test]
1932 async fn test_duplicate_packet_handling() {
1933 let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
1934
1935 let client = iotry!(UtpSocket::bind(client_addr));
1936 let mut server = iotry!(UtpSocket::bind(server_addr));
1937
1938 assert_eq!(server.state, SocketState::New);
1939 assert_eq!(client.state, SocketState::New);
1940
1941 assert_eq!(
1943 client.sender_connection_id,
1944 client.receiver_connection_id + 1
1945 );
1946
1947 let child = task::spawn(async move {
1948 let mut client = iotry!(UtpSocket::connect(server_addr));
1949 assert_eq!(client.state, SocketState::Connected);
1950
1951 let mut packet = Packet::with_payload(&[1, 2, 3]);
1952 packet.set_wnd_size(BUF_SIZE as u32);
1953 packet.set_connection_id(client.sender_connection_id);
1954 packet.set_seq_nr(client.seq_nr);
1955 packet.set_ack_nr(client.ack_nr);
1956
1957 for _ in 0..2 {
1959 packet.set_timestamp(now_microseconds());
1960 iotry!(client.socket.send_to(packet.as_ref(), server_addr));
1961 }
1962 client.seq_nr += 1;
1963
1964 for _ in 0..1 {
1966 let mut buf = vec![0u8; BUF_SIZE];
1967 iotry!(client.socket.recv_from(&mut buf));
1968 }
1969
1970 iotry!(client.close());
1971 });
1972
1973 let mut buf = vec![0u8; BUF_SIZE];
1974 iotry!(server.recv(&mut buf));
1975 assert_eq!(
1977 server.receiver_connection_id,
1978 server.sender_connection_id + 1
1979 );
1980
1981 assert_eq!(server.state, SocketState::Connected);
1982
1983 let expected: Vec<u8> = vec![1, 2, 3];
1984 let mut received: Vec<u8> = vec![];
1985 loop {
1986 match server.recv_from(&mut buf).await {
1987 Ok((0, _src)) => break,
1988 Ok((len, _src)) => received.extend(buf[..len].to_vec()),
1989 Err(e) => panic!("{:?}", e),
1990 }
1991 }
1992 assert_eq!(received.len(), expected.len());
1993 assert_eq!(received, expected);
1994
1995 child.await;
1996 }
1997
1998 #[async_std::test]
1999 async fn test_correct_packet_loss() {
2000 let server_addr = next_test_ip4();
2001
2002 let mut server = iotry!(UtpSocket::bind(server_addr));
2003 const LEN: usize = 1024 * 10;
2004 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2005 let to_send = data.clone();
2006
2007 let child = task::spawn(async move {
2008 let mut client = iotry!(UtpSocket::connect(server_addr));
2009
2010 let chunks = to_send[..].chunks(BUF_SIZE);
2012 let dst = client.connected_to;
2013 for (index, chunk) in chunks.enumerate() {
2014 let mut packet = Packet::with_payload(chunk);
2015 packet.set_seq_nr(client.seq_nr);
2016 packet.set_ack_nr(client.ack_nr);
2017 packet.set_connection_id(client.sender_connection_id);
2018 packet.set_timestamp(now_microseconds());
2019
2020 if index % 2 == 0 {
2021 iotry!(client.socket.send_to(packet.as_ref(), dst));
2022 }
2023
2024 client.curr_window += packet.len() as u32;
2025 client.send_window.push(packet);
2026 client.seq_nr += 1;
2027 }
2028
2029 iotry!(client.close());
2030 });
2031
2032 let mut buf = vec![0u8; BUF_SIZE];
2033 let mut received: Vec<u8> = vec![];
2034 loop {
2035 match server.recv_from(&mut buf).await {
2036 Ok((0, _src)) => break,
2037 Ok((len, _src)) => received.extend(buf[..len].to_vec()),
2038 Err(e) => panic!("{}", e),
2039 }
2040 }
2041 assert_eq!(received.len(), data.len());
2042 assert_eq!(received, data);
2043
2044 child.await;
2045 }
2046
2047 #[async_std::test]
2048 async fn test_tolerance_to_small_buffers() {
2049 let server_addr = next_test_ip4();
2050 let mut server = iotry!(UtpSocket::bind(server_addr));
2051 const LEN: usize = 1024;
2052 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2053 let to_send = data.clone();
2054
2055 let child = task::spawn(async move {
2056 let mut client = iotry!(UtpSocket::connect(server_addr));
2057 iotry!(client.send_to(&to_send[..]));
2058 iotry!(client.close());
2059 });
2060
2061 let mut read = Vec::new();
2062 while server.state != SocketState::Closed {
2063 let mut small_buffer = vec![0; 512];
2064 match server.recv_from(&mut small_buffer).await {
2065 Ok((0, _src)) => break,
2066 Ok((len, _src)) => read.extend(small_buffer[..len].to_vec()),
2067 Err(e) => panic!("{}", e),
2068 }
2069 }
2070
2071 assert_eq!(read.len(), data.len());
2072 assert_eq!(read, data);
2073
2074 child.await;
2075 }
2076
2077 #[async_std::test]
2078 async fn test_sequence_number_rollover() {
2079 let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
2080
2081 let mut server = iotry!(UtpSocket::bind(server_addr));
2082
2083 const LEN: usize = BUF_SIZE * 4;
2084 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2085 let to_send = data.clone();
2086
2087 let child = task::spawn(async move {
2088 let mut client = iotry!(UtpSocket::bind(client_addr));
2089
2090 client.seq_nr = ::std::u16::MAX - (to_send.len() / (BUF_SIZE * 2)) as u16;
2092
2093 let mut client = iotry!(UtpSocket::connect(server_addr));
2094 iotry!(client.send_to(&to_send[..]));
2096 assert!(client.seq_nr < 50);
2098 iotry!(client.close());
2100 });
2101
2102 let mut buf = vec![0u8; BUF_SIZE];
2103 let mut received: Vec<u8> = vec![];
2104 loop {
2105 match server.recv_from(&mut buf).await {
2106 Ok((0, _src)) => break,
2107 Ok((len, _src)) => received.extend(buf[..len].to_vec()),
2108 Err(e) => panic!("{}", e),
2109 }
2110 }
2111 assert_eq!(received.len(), data.len());
2112 assert_eq!(received, data);
2113
2114 child.await;
2115 }
2116
2117 #[async_std::test]
2118 async fn test_drop_unused_socket() {
2119 let server_addr = next_test_ip4();
2120 let server = iotry!(UtpSocket::bind(server_addr));
2121
2122 drop(server);
2124 }
2125
2126 #[async_std::test]
2127 async fn test_invalid_packet_on_connect() {
2128 use async_std::net::UdpSocket;
2129 let server_addr = next_test_ip4();
2130 let server = iotry!(UdpSocket::bind(server_addr));
2131
2132 let child = task::spawn(async move {
2133 let mut buf = vec![0u8; BUF_SIZE];
2134 match server.recv_from(&mut buf).await {
2135 Ok((_len, client_addr)) => {
2136 iotry!(server.send_to(&[], client_addr));
2137 }
2138 _ => panic!(),
2139 }
2140 });
2141
2142 match UtpSocket::connect(server_addr).await {
2143 Err(ref e) if e.kind() == ErrorKind::Other => (), Err(e) => panic!("Expected ErrorKind::Other, got {:?}", e),
2145 Ok(_) => panic!("Expected Err, got Ok"),
2146 }
2147
2148 child.await;
2149 }
2150
2151 #[async_std::test]
2152 async fn test_receive_unexpected_reply_type_on_connect() {
2153 use async_std::net::UdpSocket;
2154 let server_addr = next_test_ip4();
2155 let server = iotry!(UdpSocket::bind(server_addr));
2156
2157 let child = task::spawn(async move {
2158 let mut buf = vec![0u8; BUF_SIZE];
2159 let mut packet = Packet::new();
2160 packet.set_type(PacketType::Data);
2161
2162 match server.recv_from(&mut buf).await {
2163 Ok((_len, client_addr)) => {
2164 iotry!(server.send_to(packet.as_ref(), client_addr));
2165 }
2166 _ => panic!(),
2167 }
2168 });
2169
2170 match UtpSocket::connect(server_addr).await {
2171 Err(ref e) if e.kind() == ErrorKind::ConnectionRefused => (), Err(e) => panic!("Expected ErrorKind::ConnectionRefused, got {:?}", e),
2173 Ok(_) => panic!("Expected Err, got Ok"),
2174 }
2175
2176 child.await;
2177 }
2178
2179 #[async_std::test]
2180 async fn test_receiving_syn_on_established_connection() {
2181 let server_addr = next_test_ip4();
2183 let mut server = iotry!(UtpSocket::bind(server_addr));
2184
2185 let child = task::spawn(async move {
2186 let mut buf = vec![0; BUF_SIZE];
2187 loop {
2188 match server.recv_from(&mut buf).await {
2189 Ok((0, _src)) => break,
2190 Ok(_) => (),
2191 Err(e) => panic!("{:?}", e),
2192 }
2193 }
2194 });
2195
2196 let mut client = iotry!(UtpSocket::connect(server_addr));
2197 let mut packet = Packet::new();
2198 packet.set_wnd_size(BUF_SIZE as u32);
2199 packet.set_type(PacketType::Syn);
2200 packet.set_connection_id(client.sender_connection_id);
2201 packet.set_seq_nr(client.seq_nr);
2202 packet.set_ack_nr(client.ack_nr);
2203 iotry!(client.socket.send_to(packet.as_ref(), server_addr));
2204 let mut buf = vec![0u8; BUF_SIZE];
2205 match client.socket.recv_from(&mut buf).await {
2206 Ok((len, _src)) => {
2207 let reply = Packet::try_from(&buf[..len]).ok().unwrap();
2208 assert_eq!(reply.get_type(), PacketType::Reset);
2209 }
2210 Err(e) => panic!("{:?}", e),
2211 }
2212 iotry!(client.close());
2213
2214 child.await;
2215 }
2216
2217 #[async_std::test]
2218 async fn test_receiving_reset_on_established_connection() {
2219 let server_addr = next_test_ip4();
2221 let mut server = iotry!(UtpSocket::bind(server_addr));
2222
2223 let child = task::spawn(async move {
2224 let client = iotry!(UtpSocket::connect(server_addr));
2225 let mut packet = Packet::new();
2226 packet.set_wnd_size(BUF_SIZE as u32);
2227 packet.set_type(PacketType::Reset);
2228 packet.set_connection_id(client.sender_connection_id);
2229 packet.set_seq_nr(client.seq_nr);
2230 packet.set_ack_nr(client.ack_nr);
2231 iotry!(client.socket.send_to(packet.as_ref(), server_addr));
2232 let mut buf = vec![0u8; BUF_SIZE];
2233 match client.socket.recv_from(&mut buf).await {
2234 Ok((_len, _src)) => (),
2235 Err(e) => panic!("{:?}", e),
2236 }
2237 });
2238
2239 let mut buf = vec![0u8; BUF_SIZE];
2240 loop {
2241 match server.recv_from(&mut buf).await {
2242 Ok((0, _src)) => break,
2243 Ok(_) => (),
2244 Err(ref e) if e.kind() == ErrorKind::ConnectionReset => return,
2245 Err(e) => panic!("{:?}", e),
2246 }
2247 }
2248 child.await;
2249 panic!("Should have received Reset");
2250 }
2251
2252 #[cfg(not(windows))]
2253 #[async_std::test]
2254 async fn test_premature_fin() {
2255 let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
2256 let mut server = iotry!(UtpSocket::bind(server_addr));
2257
2258 const LEN: usize = BUF_SIZE * 4;
2259 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2260 let to_send = data.clone();
2261
2262 let child = task::spawn(async move {
2263 let mut client = iotry!(UtpSocket::connect(server_addr));
2264 iotry!(client.send_to(&to_send[..]));
2265 iotry!(client.close());
2266 });
2267
2268 let mut buf = vec![0u8; BUF_SIZE];
2269
2270 iotry!(server.recv(&mut buf));
2272
2273 let mut packet = Packet::new();
2275 packet.set_connection_id(server.sender_connection_id);
2276 packet.set_seq_nr(server.seq_nr);
2277 packet.set_ack_nr(server.ack_nr);
2278 packet.set_timestamp(now_microseconds());
2279 packet.set_type(PacketType::Fin);
2280 iotry!(server.socket.send_to(packet.as_ref(), client_addr));
2281
2282 let mut received: Vec<u8> = vec![];
2284 loop {
2285 match server.recv_from(&mut buf).await {
2286 Ok((0, _src)) => break,
2287 Ok((len, _src)) => received.extend(buf[..len].to_vec()),
2288 Err(e) => panic!("{}", e),
2289 }
2290 }
2291 assert_eq!(received.len(), data.len());
2292 assert_eq!(received, data);
2293
2294 child.await;
2295 }
2296
2297 #[async_std::test]
2298 async fn test_base_delay_calculation() {
2299 let minute_in_microseconds = 60 * 10i64.pow(6);
2300 let samples = vec![
2301 (0, 10),
2302 (1, 8),
2303 (2, 12),
2304 (3, 7),
2305 (minute_in_microseconds + 1, 11),
2306 (minute_in_microseconds + 2, 19),
2307 (minute_in_microseconds + 3, 9),
2308 ];
2309 let addr = next_test_ip4();
2310 let mut socket = UtpSocket::bind(addr).await.unwrap();
2311
2312 for (timestamp, delay) in samples {
2313 socket.update_base_delay(delay.into(), ((timestamp + delay) as u32).into());
2314 }
2315
2316 let expected = vec![7i64, 9i64]
2317 .into_iter()
2318 .map(Into::into)
2319 .collect::<Vec<_>>();
2320 let actual = socket.base_delays.iter().cloned().collect::<Vec<_>>();
2321 assert_eq!(expected, actual);
2322 assert_eq!(
2323 socket.min_base_delay(),
2324 expected.iter().min().cloned().unwrap_or_default()
2325 );
2326 }
2327
2328 #[async_std::test]
2329 async fn test_local_addr() {
2330 let addr = next_test_ip4();
2331 let addr = addr.to_socket_addrs().unwrap().next().unwrap();
2332 let socket = UtpSocket::bind(addr).await.unwrap();
2333
2334 assert!(socket.local_addr().is_ok());
2335 assert_eq!(socket.local_addr().unwrap(), addr);
2336 }
2337
2338 #[async_std::test]
2339 async fn test_listener_local_addr() {
2340 let addr = next_test_ip4();
2341 let addr = addr.to_socket_addrs().unwrap().next().unwrap();
2342 let listener = UtpListener::bind(addr).await.unwrap();
2343
2344 assert!(listener.local_addr().is_ok());
2345 assert_eq!(listener.local_addr().unwrap(), addr);
2346 }
2347
2348 #[async_std::test]
2349 async fn test_listener_listener_clone() {
2350 let addr = next_test_ip4();
2351 let addr = addr.to_socket_addrs().unwrap().next().unwrap();
2352
2353 let listener1 = UtpListener::bind(addr).await.unwrap();
2355 let listener2 = listener1.clone();
2356
2357 task::spawn(async move { listener1.accept().await.unwrap() });
2358 task::spawn(async move { listener2.accept().await.unwrap() });
2359
2360 task::spawn(async move {
2362 UtpSocket::connect(addr).await.unwrap();
2363 UtpSocket::connect(addr).await.unwrap();
2364 })
2365 .await;
2366 }
2367
2368 #[async_std::test]
2369 async fn test_peer_addr() {
2370 use std::sync::mpsc::channel;
2371 let addr = next_test_ip4();
2372 let server_addr = addr.to_socket_addrs().unwrap().next().unwrap();
2373 let mut server = UtpSocket::bind(server_addr).await.unwrap();
2374 let (tx, rx) = channel();
2375
2376 assert!(server.peer_addr().is_err());
2378
2379 let child = task::spawn(async move {
2380 let mut client = iotry!(UtpSocket::connect(server_addr));
2381 let mut buf = vec![0; 1024];
2382 tx.send(client.local_addr()).unwrap();
2383 iotry!(client.recv_from(&mut buf));
2384 });
2385
2386 let mut buf = vec![0; 1024];
2388 iotry!(server.recv(&mut buf));
2389
2390 assert!(server.peer_addr().is_ok());
2392 let client_addr = rx.recv().unwrap().unwrap();
2395 assert_eq!(server.peer_addr().unwrap().port(), client_addr.port());
2396
2397 iotry!(server.close());
2399
2400 assert!(server.peer_addr().is_err());
2402
2403 child.await;
2404 }
2405
2406 #[async_std::test]
2407 async fn test_take_address() {
2408 assert!(take_address("0.0.0.0:0").await.is_ok());
2410 assert!(take_address("[::]:0").await.is_ok());
2411 assert!(take_address(("0.0.0.0", 0)).await.is_ok());
2412 assert!(take_address(("::", 0)).await.is_ok());
2413 assert!(take_address(("1.2.3.4", 5)).await.is_ok());
2414
2415 assert!(take_address("999.0.0.0:0").await.is_err());
2417 assert!(take_address("1.2.3.4:70000").await.is_err());
2418 assert!(take_address("").await.is_err());
2419 assert!(take_address("this is not an address").await.is_err());
2420 assert!(take_address("no.dns.resolution.com").await.is_err());
2421 }
2422
2423 #[async_std::test]
2425 async fn test_connection_loss_data() {
2426 let server_addr = next_test_ip4();
2427 let mut server = iotry!(UtpSocket::bind(server_addr));
2428 server.congestion_timeout = 1;
2430 let attempts = server.max_retransmission_retries;
2431
2432 let child = task::spawn(async move {
2433 let mut client = iotry!(UtpSocket::connect(server_addr));
2434 iotry!(client.send_to(&[0]));
2435 client.state = SocketState::Closed;
2437 let ref socket = client.socket;
2438 let mut buf = vec![0u8; BUF_SIZE];
2439 iotry!(socket.recv_from(&mut buf));
2440 for _ in 0..attempts {
2441 match socket.recv_from(&mut buf).await {
2442 Ok((len, _src)) => assert_eq!(
2443 Packet::try_from(&buf[..len]).unwrap().get_type(),
2444 PacketType::Data
2445 ),
2446 Err(e) => panic!("{}", e),
2447 }
2448 }
2449 });
2450
2451 let mut buf = vec![0u8; BUF_SIZE];
2453 iotry!(server.recv_from(&mut buf));
2454
2455 iotry!(server.send_to(&[0]));
2456
2457 let mut buf = vec![0u8; BUF_SIZE];
2459 match server.recv(&mut buf).await {
2460 Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
2461 x => panic!("Expected Err(TimedOut), got {:?}", x),
2462 }
2463
2464 child.await;
2465 }
2466
2467 #[async_std::test]
2469 async fn test_connection_loss_fin() {
2470 let server_addr = next_test_ip4();
2471 let mut server = iotry!(UtpSocket::bind(server_addr));
2472 server.congestion_timeout = 1;
2474 let attempts = server.max_retransmission_retries;
2475
2476 let child = task::spawn(async move {
2477 let mut client = iotry!(UtpSocket::connect(server_addr));
2478 iotry!(client.send_to(&[0]));
2479 client.state = SocketState::Closed;
2481 let ref socket = client.socket;
2482 let mut buf = vec![0u8; BUF_SIZE];
2483 iotry!(socket.recv_from(&mut buf));
2484 for _ in 0..attempts {
2485 match socket.recv_from(&mut buf).await {
2486 Ok((len, _src)) => assert_eq!(
2487 Packet::try_from(&buf[..len]).unwrap().get_type(),
2488 PacketType::Fin
2489 ),
2490 Err(e) => panic!("{}", e),
2491 }
2492 }
2493 });
2494
2495 let mut buf = vec![0u8; BUF_SIZE];
2497 iotry!(server.recv_from(&mut buf));
2498
2499 match server.close().await {
2501 Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
2502 x => panic!("Expected Err(TimedOut), got {:?}", x),
2503 }
2504 child.await;
2505 }
2506
2507 #[async_std::test]
2509 async fn test_connection_loss_waiting() {
2510 let server_addr = next_test_ip4();
2511 let mut server = iotry!(UtpSocket::bind(server_addr));
2512 server.congestion_timeout = 1;
2514 let attempts = server.max_retransmission_retries;
2515
2516 task::spawn(async move {
2517 let mut client = iotry!(UtpSocket::connect(server_addr));
2518 iotry!(client.send_to(&[0]));
2519 client.state = SocketState::Closed;
2521 let ref socket = client.socket;
2522 let seq_nr = client.seq_nr;
2523 let mut buf = vec![0u8; BUF_SIZE];
2524 for _ in 0..(3 * attempts) {
2525 match socket.recv_from(&mut buf).await {
2526 Ok((len, _src)) => {
2527 let packet = Packet::try_from(&buf[..len]).unwrap();
2528 assert_eq!(packet.get_type(), PacketType::State);
2529 assert_eq!(packet.ack_nr(), seq_nr - 1);
2530 }
2531 Err(e) => panic!("{}", e),
2532 }
2533 }
2534 });
2535
2536 let mut buf = vec![0; BUF_SIZE];
2538 iotry!(server.recv_from(&mut buf));
2539
2540 let mut buf = vec![0; BUF_SIZE];
2542 match server.recv_from(&mut buf).await {
2543 Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
2544 x => panic!("Expected Err(TimedOut), got {:?}", x),
2545 }
2546 }
2547}