1use std::cmp::{max, min};
2use std::collections::VecDeque;
3use std::future::Future;
4use std::io::{ErrorKind, Result};
5use std::iter::Iterator;
6use std::mem;
7use std::net::SocketAddr;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::time::{Duration, Instant};
12
13use crate::error::SocketError;
14use crate::packet::*;
15use crate::time::*;
16use crate::util::*;
17
18use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
19use tokio::net::{lookup_host, ToSocketAddrs, UdpSocket};
20use tokio::sync::mpsc::{
21 unbounded_channel, UnboundedReceiver, UnboundedSender,
22};
23use tokio::sync::Mutex;
24use tokio::time::{sleep, timeout, Instant as TokioInstant, Sleep};
25
26use tracing::debug;
27
28const BUF_SIZE: usize = 1500;
31const GAIN: f64 = 1.0;
32const ALLOWED_INCREASE: u32 = 1;
33const TARGET: f64 = 100_000.0; const MSS: u32 = 1400;
35const MIN_CWND: u32 = 2;
36const INIT_CWND: u32 = 2;
37const INITIAL_CONGESTION_TIMEOUT: u64 = 1000; const MIN_CONGESTION_TIMEOUT: u64 = 500; const MAX_CONGESTION_TIMEOUT: u64 = 60_000; const BASE_HISTORY: usize = 10; const MAX_SYN_RETRIES: u32 = 5; const MAX_RETRANSMISSION_RETRIES: u32 = 5; const WINDOW_SIZE: u32 = 1024 * 1024; const MAX_BASE_DELAY_AGE: Delay = Delay(60_000_000);
47
/// Lifecycle states of a uTP connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SocketState {
    /// Freshly created; no handshake has taken place.
    New,
    /// Handshake finished; data may flow in both directions.
    Connected,
    /// Our SYN has been sent and we are waiting for the reply.
    SynSent,
    /// Our FIN has been sent and we are waiting for it to be acknowledged.
    FinSent,
    /// The peer answered with a RESET packet.
    ResetReceived,
    /// The connection is fully closed.
    Closed,
}
57
58struct DelayDifferenceSample {
59 received_at: Timestamp,
60 difference: Delay,
61}
62
63struct UtpSocket {
75 socket: UdpSocket,
77
78 connected_to: SocketAddr,
80
81 sender_connection_id: u16,
83
84 receiver_connection_id: u16,
86
87 seq_nr: u16,
89
90 ack_nr: u16,
92
93 state: SocketState,
95
96 incoming_buffer: Vec<Packet>,
98
99 send_window: Vec<Packet>,
101
102 unsent_queue: VecDeque<Packet>,
104
105 duplicate_ack_count: u32,
108
109 last_acked: u16,
111
112 last_acked_timestamp: Timestamp,
114
115 last_dropped: u16,
117
118 rtt: i32,
120
121 rtt_variance: i32,
123
124 pending_data: Vec<u8>,
126
127 curr_window: u32,
129
130 remote_wnd_size: u32,
132
133 base_delays: VecDeque<Delay>,
135
136 current_delays: Vec<DelayDifferenceSample>,
139
140 their_delay: Delay,
143
144 last_rollover: Timestamp,
146
147 congestion_timeout: u64,
149
150 cwnd: u32,
152
153 pub max_retransmission_retries: u32,
155}
156
157impl UtpSocket {
158 fn from_raw_parts(s: UdpSocket, src: SocketAddr) -> UtpSocket {
163 let (receiver_id, sender_id) = generate_sequential_identifiers();
164
165 UtpSocket {
166 socket: s,
167 connected_to: src,
168 receiver_connection_id: receiver_id,
169 sender_connection_id: sender_id,
170 seq_nr: 1,
171 ack_nr: 0,
172 state: SocketState::New,
173 incoming_buffer: Vec::new(),
174 send_window: Vec::new(),
175 unsent_queue: VecDeque::new(),
176 duplicate_ack_count: 0,
177 last_acked: 0,
178 last_acked_timestamp: Timestamp::default(),
179 last_dropped: 0,
180 rtt: 0,
181 rtt_variance: 0,
182 pending_data: Vec::new(),
183 curr_window: 0,
184 remote_wnd_size: 0,
185 current_delays: Vec::new(),
186 base_delays: VecDeque::with_capacity(BASE_HISTORY),
187 their_delay: Delay::default(),
188 last_rollover: Timestamp::default(),
189 congestion_timeout: INITIAL_CONGESTION_TIMEOUT,
190 cwnd: INIT_CWND * MSS,
191 max_retransmission_retries: MAX_RETRANSMISSION_RETRIES,
192 }
193 }
194
195 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpSocket> {
197 let src = lookup_host(&addr)
198 .await?
199 .last()
200 .ok_or(ErrorKind::AddrNotAvailable)?;
201 let socket = UdpSocket::bind(addr).await?;
202
203 Ok(UtpSocket::from_raw_parts(socket, src))
204 }
205
206 pub fn local_addr(&self) -> Result<SocketAddr> {
208 self.socket.local_addr()
209 }
210
211 pub fn peer_addr(&self) -> Result<SocketAddr> {
213 if self.state == SocketState::Connected
214 || self.state == SocketState::FinSent
215 {
216 Ok(self.connected_to)
217 } else {
218 Err(SocketError::NotConnected.into())
219 }
220 }
221
222 pub async fn connect(addr: SocketAddr) -> Result<UtpSocket> {
224 let mut socket = UtpSocket::bind(addr).await?;
225 socket.connected_to = addr;
226
227 let mut packet = Packet::new();
228 packet.set_type(PacketType::Syn);
229 packet.set_connection_id(socket.receiver_connection_id);
230 packet.set_seq_nr(socket.seq_nr);
231
232 let mut len = 0;
233 let mut buf = [0; BUF_SIZE];
234
235 let mut syn_timeout = socket.congestion_timeout;
236 for _ in 0..MAX_SYN_RETRIES {
237 packet.set_timestamp(now_microseconds());
238
239 debug!("connecting to {}", socket.connected_to);
241 socket
242 .socket
243 .send_to(packet.as_ref(), socket.connected_to)
244 .await?;
245 socket.state = SocketState::SynSent;
246 debug!("sent {:?}", packet);
247
248 let to = Duration::from_millis(syn_timeout);
250
251 match timeout(to, socket.socket.recv_from(&mut buf)).await {
252 Ok(Ok((read, src))) => {
253 socket.connected_to = src;
254 len = read;
255 break;
256 }
257 Ok(Err(e)) => return Err(e),
258 Err(_) => {
259 debug!("timed out, retrying");
260 syn_timeout *= 2;
261 continue;
262 }
263 };
264 }
265
266 let addr = socket.connected_to;
267 let packet = Packet::try_from(&buf[..len])?;
268 debug!("received {:?}", packet);
269 socket.handle_packet(&packet, addr)?;
270
271 debug!("connected to: {}", socket.connected_to);
272
273 Ok(socket)
274 }
275
276 pub async fn close(&mut self) -> Result<()> {
281 if self.state == SocketState::Closed
283 || self.state == SocketState::New
284 || self.state == SocketState::SynSent
285 {
286 return Ok(());
287 }
288
289 let local = self.socket.local_addr()?;
290
291 debug!("closing {} -> {}", local, self.connected_to);
292
293 self.flush().await?;
295
296 debug!("close flush completed");
297
298 let mut packet = Packet::new();
299 packet.set_connection_id(self.sender_connection_id);
300 packet.set_seq_nr(self.seq_nr);
301 packet.set_ack_nr(self.ack_nr);
302 packet.set_timestamp(now_microseconds());
303 packet.set_type(PacketType::Fin);
304
305 self.socket
307 .send_to(packet.as_ref(), self.connected_to)
308 .await?;
309 debug!("sent {:?}", packet);
310 self.state = SocketState::FinSent;
311
312 let mut jbuf = [0; BUF_SIZE];
314 let mut buf: ReadBuf<'_> = ReadBuf::new(&mut jbuf);
315
316 while self.state != SocketState::Closed {
317 self.recv(&mut buf).await?;
318 }
319
320 debug!("closed {} -> {}", local, self.connected_to);
321
322 Ok(())
323 }
324
325 pub async fn recv_from(
331 &mut self,
332 buf: &mut [u8],
333 ) -> Result<(usize, SocketAddr)> {
334 let mut buf = ReadBuf::new(buf);
335 let read = self.flush_incoming_buffer(&mut buf);
336
337 if read > 0 {
338 Ok((read, self.connected_to))
339 } else {
340 if self.state == SocketState::ResetReceived {
343 return Err(SocketError::ConnectionReset.into());
344 }
345
346 loop {
347 if self.state == SocketState::Closed {
350 return Ok((0, self.connected_to));
351 }
352
353 match self.recv(&mut buf).await {
354 Ok((0, _src)) => continue,
355 Ok(x) => return Ok(x),
356 Err(e) => return Err(e),
357 }
358 }
359 }
360 }
361
362 async fn recv(
363 &mut self,
364 buf: &mut ReadBuf<'_>,
365 ) -> Result<(usize, SocketAddr)> {
366 let mut b = [0; BUF_SIZE + HEADER_SIZE];
367 let start = Instant::now();
368 let read;
369 let src;
370 let mut retries = 0;
371
372 loop {
374 if retries >= self.max_retransmission_retries {
377 self.state = SocketState::Closed;
378 return Err(SocketError::ConnectionTimedOut.into());
379 }
380
381 if self.state != SocketState::New {
382 let to = Duration::from_millis(self.congestion_timeout);
383 debug!(
384 "setting read timeout of {} ms",
385 self.congestion_timeout
386 );
387
388 match timeout(to, self.socket.recv_from(&mut b)).await {
389 Ok(Ok((r, s))) => {
390 read = r;
391 src = s;
392 break;
393 }
394 Ok(Err(e)) => return Err(e),
395 Err(_) => {
396 debug!("recv_from timed out");
397 self.handle_receive_timeout().await?;
398 }
399 };
400 } else {
401 match self.socket.recv_from(&mut b).await {
402 Ok((r, s)) => {
403 read = r;
404 src = s;
405 break;
406 }
407 Err(e) => return Err(e),
408 }
409 };
410
411 let elapsed = start.elapsed();
412 let elapsed_ms = elapsed.as_secs() * 1000
413 + (elapsed.subsec_millis() / 1_000_000) as u64;
414 debug!("{} ms elapsed", elapsed_ms);
415 retries += 1;
416 }
417
418 let packet = match Packet::try_from(&b[..read]) {
420 Ok(packet) => packet,
421 Err(e) => {
422 debug!("{}", e);
423 debug!("Ignoring invalid packet");
424 return Ok((0, self.connected_to));
425 }
426 };
427 debug!("received {:?}", packet);
428
429 if let Some(mut pkt) = self.handle_packet(&packet, src)? {
431 pkt.set_wnd_size(WINDOW_SIZE);
432 self.socket.send_to(pkt.as_ref(), src).await?;
433 debug!("sent {:?}", pkt);
434 }
435
436 if packet.get_type() == PacketType::Data
439 && packet.seq_nr().wrapping_sub(self.last_dropped) > 0
440 {
441 self.insert_into_buffer(packet);
442 }
443
444 let read = self.flush_incoming_buffer(buf);
446
447 Ok((read, src))
448 }
449
450 async fn handle_receive_timeout(&mut self) -> Result<()> {
451 self.congestion_timeout *= 2;
452 self.cwnd = MSS;
453
454 debug!(
464 "self.send_window: {:?}",
465 self.send_window
466 .iter()
467 .map(Packet::seq_nr)
468 .collect::<Vec<u16>>()
469 );
470
471 if self.send_window.is_empty() {
472 if self.state == SocketState::FinSent {
475 let mut packet = Packet::new();
476 packet.set_connection_id(self.sender_connection_id);
477 packet.set_seq_nr(self.seq_nr);
478 packet.set_ack_nr(self.ack_nr);
479 packet.set_timestamp(now_microseconds());
480 packet.set_type(PacketType::Fin);
481
482 self.socket
484 .send_to(packet.as_ref(), self.connected_to)
485 .await?;
486 debug!("resent FIN: {:?}", packet);
487 } else if self.state != SocketState::New {
488 debug!("sending fast resend request");
491 self.send_fast_resend_request();
492 }
493 } else {
494 let packet = &mut self.send_window[0];
498 packet.set_timestamp(now_microseconds());
499 self.socket
500 .send_to(packet.as_ref(), self.connected_to)
501 .await?;
502 debug!("resent {:?}", packet);
503 }
504
505 Ok(())
506 }
507
508 fn prepare_reply(&self, original: &Packet, t: PacketType) -> Packet {
509 let mut resp = Packet::new();
510 resp.set_type(t);
511 let self_t_micro = now_microseconds();
512 let other_t_micro = original.timestamp();
513 let time_difference: Delay = abs_diff(self_t_micro, other_t_micro);
514 resp.set_timestamp(self_t_micro);
515 resp.set_timestamp_difference(time_difference);
516 resp.set_connection_id(self.sender_connection_id);
517 resp.set_seq_nr(self.seq_nr);
518 resp.set_ack_nr(self.ack_nr);
519
520 resp
521 }
522
523 fn advance_incoming_buffer(&mut self) -> Option<Packet> {
526 if !self.incoming_buffer.is_empty() {
527 let packet = self.incoming_buffer.remove(0);
528 debug!("Removed packet from incoming buffer: {:?}", packet);
529 self.ack_nr = packet.seq_nr();
530 self.last_dropped = self.ack_nr;
531 Some(packet)
532 } else {
533 None
534 }
535 }
536
537 fn flush_incoming_buffer(&mut self, buf: &mut ReadBuf) -> usize {
543 fn copy(src: &[u8], dst: &mut ReadBuf) -> usize {
544 let to_copy = min(src.len(), dst.capacity());
545
546 dst.put_slice(&src[..to_copy]);
547
548 to_copy
549 }
550 if !self.pending_data.is_empty() {
552 let flushed = copy(&self.pending_data[..], buf);
553
554 if flushed == self.pending_data.len() {
555 self.pending_data.clear();
556 self.advance_incoming_buffer();
557 } else {
558 self.pending_data = self.pending_data[flushed..].to_vec();
559 }
560
561 return flushed;
562 }
563
564 if !self.incoming_buffer.is_empty()
566 && (self.ack_nr == self.incoming_buffer[0].seq_nr()
567 || self.ack_nr.wrapping_sub(self.incoming_buffer[0].seq_nr())
568 >= 1)
569 {
570 let flushed = copy(&self.incoming_buffer[0].payload(), buf);
571
572 if flushed == self.incoming_buffer[0].payload().len() {
573 self.advance_incoming_buffer();
574 } else {
575 self.pending_data =
576 self.incoming_buffer[0].payload()[flushed..].to_vec();
577 }
578
579 return flushed;
580 } else if !self.incoming_buffer.is_empty() {
581 debug!(
582 "not flushing out of order data, acked={} != cached={}",
583 self.ack_nr,
584 self.incoming_buffer[0].seq_nr()
585 );
586 }
587
588 0
589 }
590
591 pub(crate) fn should_read(&self) -> bool {
593 self.incoming_buffer.is_empty() && self.pending_data.is_empty()
594 }
595
596 pub async fn send_to(&mut self, buf: &[u8]) -> Result<usize> {
608 if self.state == SocketState::Closed {
609 return Err(SocketError::ConnectionClosed.into());
610 }
611
612 let total_length = buf.len();
613
614 for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
615 let mut packet = Packet::with_payload(chunk);
616 packet.set_seq_nr(self.seq_nr);
617 packet.set_ack_nr(self.ack_nr);
618 packet.set_connection_id(self.sender_connection_id);
619
620 self.unsent_queue.push_back(packet);
621
622 self.seq_nr = self.seq_nr.wrapping_add(1);
624 }
625
626 self.send().await?;
628
629 Ok(total_length)
630 }
631
632 pub async fn flush(&mut self) -> Result<()> {
634 let mut buf = [0u8; BUF_SIZE];
635 let mut buf = ReadBuf::new(&mut buf);
636
637 while !self.send_window.is_empty() {
638 debug!("packets in send window: {}", self.send_window.len());
639 self.recv(&mut buf).await?;
640 }
641
642 Ok(())
643 }
644
645 async fn send(&mut self) -> Result<()> {
647 while let Some(mut packet) = self.unsent_queue.pop_front() {
648 self.send_packet(&mut packet).await?;
649 self.curr_window += packet.len() as u32;
650 self.send_window.push(packet);
651 }
652 Ok(())
653 }
654
655 fn max_inflight(&self) -> u32 {
656 let max_inflight = min(self.cwnd, self.remote_wnd_size);
657 max(MIN_CWND * MSS, max_inflight)
658 }
659
660 #[inline]
662 async fn send_packet(&mut self, packet: &mut Packet) -> Result<()> {
663 debug!("current window: {}", self.send_window.len());
664
665 packet.set_timestamp(now_microseconds());
666 packet.set_timestamp_difference(self.their_delay);
667
668 self.socket
669 .send_to(packet.as_ref(), self.connected_to)
670 .await?;
671
672 debug!("sent {:?}", packet);
673
674 Ok(())
675 }
676
677 fn update_base_delay(&mut self, base_delay: Delay, now: Timestamp) {
682 if self.base_delays.is_empty()
683 || now - self.last_rollover > MAX_BASE_DELAY_AGE
684 {
685 self.last_rollover = now;
687
688 if self.base_delays.len() == BASE_HISTORY {
690 self.base_delays.pop_front();
691 }
692
693 self.base_delays.push_back(base_delay);
695 } else {
696 let last_idx = self.base_delays.len() - 1;
698 if base_delay < self.base_delays[last_idx] {
699 self.base_delays[last_idx] = base_delay;
700 }
701 }
702 }
703
704 fn update_current_delay(&mut self, v: Delay, now: Timestamp) {
707 let rtt = (self.rtt as i64 * 100).into();
709 while !self.current_delays.is_empty()
710 && now - self.current_delays[0].received_at > rtt
711 {
712 self.current_delays.remove(0);
713 }
714
715 self.current_delays.push(DelayDifferenceSample {
717 received_at: now,
718 difference: v,
719 });
720 }
721
722 fn update_congestion_timeout(&mut self, current_delay: i32) {
723 let delta = self.rtt - current_delay;
724 self.rtt_variance += (delta.abs() - self.rtt_variance) / 4;
725 self.rtt += (current_delay - self.rtt) / 8;
726 self.congestion_timeout = max(
727 (self.rtt + self.rtt_variance * 4) as u64,
728 MIN_CONGESTION_TIMEOUT,
729 );
730 self.congestion_timeout =
731 min(self.congestion_timeout, MAX_CONGESTION_TIMEOUT);
732
733 debug!("current_delay: {}", current_delay);
734 debug!("delta: {}", delta);
735 debug!("self.rtt_variance: {}", self.rtt_variance);
736 debug!("self.rtt: {}", self.rtt);
737 debug!("self.congestion_timeout: {}", self.congestion_timeout);
738 }
739
740 fn filtered_current_delay(&self) -> Delay {
746 let input = self.current_delays.iter().map(|delay| &delay.difference);
747 (ewma(input, 0.333) as i64).into()
748 }
749
750 fn min_base_delay(&self) -> Delay {
752 self.base_delays.iter().min().cloned().unwrap_or_default()
753 }
754
755 fn build_selective_ack(&self) -> Vec<u8> {
757 let stashed = self
758 .incoming_buffer
759 .iter()
760 .filter(|pkt| pkt.seq_nr() > self.ack_nr + 1)
761 .map(|pkt| (pkt.seq_nr() - self.ack_nr - 2) as usize)
762 .map(|diff| (diff / 8, diff % 8));
763
764 let mut sack = Vec::new();
765 for (byte, bit) in stashed {
766 while byte >= sack.len() || sack.len() % 4 != 0 {
769 sack.push(0u8);
770 }
771
772 sack[byte] |= 1 << bit;
773 }
774
775 sack
776 }
777
778 fn send_fast_resend_request(&mut self) {
783 for _ in 0..3usize {
784 let mut packet = Packet::new();
785 packet.set_type(PacketType::State);
786 let self_t_micro = now_microseconds();
787 packet.set_timestamp(self_t_micro);
788 packet.set_timestamp_difference(self.their_delay);
789 packet.set_connection_id(self.sender_connection_id);
790 packet.set_seq_nr(self.seq_nr);
791 packet.set_ack_nr(self.ack_nr);
792 self.unsent_queue.push_back(packet);
793 }
794 }
795
796 fn resend_lost_packet(&mut self, lost_packet_nr: u16) {
797 debug!("---> resend_lost_packet({}) <---", lost_packet_nr);
798 match self
799 .send_window
800 .iter()
801 .position(|pkt| pkt.seq_nr() == lost_packet_nr)
802 {
803 None => debug!("Packet {} not found", lost_packet_nr),
804 Some(position) => {
805 debug!("self.send_window.len(): {}", self.send_window.len());
806 debug!("position: {}", position);
807 let packet = self.send_window[position].clone();
808
809 self.unsent_queue.push_back(packet);
810
811 }
814 }
815 debug!("---> END resend_lost_packet <---");
816 }
817
818 fn advance_send_window(&mut self) {
820 if let Some(position) = self
832 .send_window
833 .iter()
834 .position(|packet| packet.seq_nr() == self.last_acked)
835 {
836 for _ in 0..=position {
837 let packet = self.send_window.remove(0);
838 debug!("removing {} bytes from send window", packet.len());
839 debug!(
840 "{} packets left in send window",
841 self.send_window.len()
842 );
843 self.curr_window -= packet.len() as u32;
844 }
845 }
846 debug!("self.curr_window: {}", self.curr_window);
847 }
848
849 fn handle_fin_packet(
850 &mut self,
851 packet: &Packet,
852 src: SocketAddr,
853 ) -> Packet {
854 if packet.ack_nr() < self.seq_nr {
855 debug!("FIN received but there are missing acknowledgements for sent packets");
856 }
857 let mut reply = self.prepare_reply(packet, PacketType::State);
858 if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
859 debug!(
860 "current ack_nr({}) is behind received packet seq_nr ({})",
861 self.ack_nr,
862 packet.seq_nr()
863 );
864
865 let sack = self.build_selective_ack();
867
868 if !sack.is_empty() {
869 debug!("sending SACK to peer");
870 reply.set_sack(sack);
871 }
872 }
873
874 debug!("received FIN from {}, connection is closed", src);
875
876 self.state = SocketState::Closed;
878 reply
879 }
880
881 fn handle_packet(
885 &mut self,
886 packet: &Packet,
887 src: SocketAddr,
888 ) -> Result<Option<Packet>> {
889 debug!("({:?}, {:?})", self.state, packet.get_type());
890
891 if packet.seq_nr().wrapping_sub(self.ack_nr) == 1 {
893 self.ack_nr = packet.seq_nr();
894 }
895
896 if packet.get_type() != PacketType::Syn
898 && self.state != SocketState::SynSent
899 && !(packet.connection_id() == self.sender_connection_id
900 || packet.connection_id() == self.receiver_connection_id)
901 {
902 return Ok(Some(self.prepare_reply(packet, PacketType::Reset)));
903 }
904
905 self.remote_wnd_size = packet.wnd_size();
907 debug!("self.remote_wnd_size: {}", self.remote_wnd_size);
908
909 let now = now_microseconds();
912 self.their_delay = abs_diff(now, packet.timestamp());
913 debug!("self.their_delay: {}", self.their_delay);
914
915 match (self.state, packet.get_type()) {
916 (SocketState::New, PacketType::Syn) => {
917 self.connected_to = src;
918 self.ack_nr = packet.seq_nr();
919 self.seq_nr = rand::random();
920 self.receiver_connection_id = packet.connection_id() + 1;
921 self.sender_connection_id = packet.connection_id();
922 self.state = SocketState::Connected;
923 self.last_dropped = self.ack_nr;
924
925 Ok(Some(self.prepare_reply(packet, PacketType::State)))
926 }
927 (_, PacketType::Syn) => {
928 Ok(Some(self.prepare_reply(packet, PacketType::Reset)))
929 }
930 (SocketState::SynSent, PacketType::State) => {
931 self.connected_to = src;
932 self.ack_nr = packet.seq_nr();
933 self.seq_nr += 1;
934 self.state = SocketState::Connected;
935 self.last_acked = packet.ack_nr();
936 self.last_acked_timestamp = now_microseconds();
937 Ok(None)
938 }
939 (SocketState::SynSent, _) => Err(SocketError::InvalidReply.into()),
940 (SocketState::Connected, PacketType::Data)
941 | (SocketState::FinSent, PacketType::Data) => {
942 Ok(self.handle_data_packet(packet))
943 }
944 (SocketState::Connected, PacketType::State) => {
945 self.handle_state_packet(packet);
946 Ok(None)
947 }
948 (SocketState::Connected, PacketType::Fin)
949 | (SocketState::FinSent, PacketType::Fin) => {
950 Ok(Some(self.handle_fin_packet(packet, src)))
951 }
952 (SocketState::Closed, PacketType::Fin) => {
953 Ok(Some(self.prepare_reply(packet, PacketType::State)))
954 }
955 (SocketState::FinSent, PacketType::State) => {
956 if packet.ack_nr() == self.seq_nr {
957 debug!("connection closed succesfully");
958 self.state = SocketState::Closed;
959 } else {
960 self.handle_state_packet(packet);
961 }
962 Ok(None)
963 }
964 (_, PacketType::Reset) => {
965 self.state = SocketState::ResetReceived;
966 Err(SocketError::ConnectionReset.into())
967 }
968 (state, ty) => {
969 let message = format!(
970 "Unimplemented handling for ({:?},{:?})",
971 state, ty
972 );
973 debug!("{}", message);
974 Err(SocketError::Other(message).into())
975 }
976 }
977 }
978
979 fn handle_data_packet(&mut self, packet: &Packet) -> Option<Packet> {
980 let packet_type = if self.state == SocketState::FinSent {
983 PacketType::Fin
984 } else {
985 PacketType::State
986 };
987 let mut reply = self.prepare_reply(packet, packet_type);
988
989 if packet.seq_nr().wrapping_sub(self.ack_nr) > 1 {
990 debug!(
991 "current ack_nr ({}) is behind received packet seq_nr ({})",
992 self.ack_nr,
993 packet.seq_nr()
994 );
995
996 let sack = self.build_selective_ack();
998
999 if !sack.is_empty() {
1000 debug!("sending SACK packet");
1001 reply.set_sack(sack);
1002 }
1003 }
1004
1005 Some(reply)
1006 }
1007
1008 fn queuing_delay(&self) -> Delay {
1009 let filtered_current_delay = self.filtered_current_delay();
1010 let min_base_delay = self.min_base_delay();
1011 let queuing_delay = filtered_current_delay - min_base_delay;
1012
1013 debug!("filtered_current_delay: {}", filtered_current_delay);
1014 debug!("min_base_delay: {}", min_base_delay);
1015 debug!("queuing_delay: {}", queuing_delay);
1016
1017 queuing_delay
1018 }
1019
1020 fn update_congestion_window(
1042 &mut self,
1043 off_target: f64,
1044 bytes_newly_acked: u32,
1045 ) {
1046 let flightsize = self.curr_window;
1047
1048 let cwnd_increase =
1049 GAIN * off_target * bytes_newly_acked as f64 * MSS as f64;
1050 let cwnd_increase = cwnd_increase / self.cwnd as f64;
1051 debug!("cwnd_increase: {}", cwnd_increase);
1052
1053 self.cwnd = (self.cwnd as f64 + cwnd_increase) as u32;
1054 let max_allowed_cwnd = flightsize + ALLOWED_INCREASE * MSS;
1055 self.cwnd = min(self.cwnd, max_allowed_cwnd);
1056 self.cwnd = max(self.cwnd, MIN_CWND * MSS);
1057
1058 debug!("cwnd: {}", self.cwnd);
1059 debug!("max_allowed_cwnd: {}", max_allowed_cwnd);
1060 }
1061
1062 fn handle_packet_extension(
1063 &mut self,
1064 packet: &Packet,
1065 packet_loss_detected: &mut bool,
1066 ) {
1067 for extension in packet.extensions() {
1069 if extension.get_type() == ExtensionType::SelectiveAck {
1070 if extension.iter().count_ones() >= 3 {
1073 self.resend_lost_packet(packet.ack_nr() + 1);
1074 *packet_loss_detected = true;
1075 }
1076
1077 if let Some(last_seq_nr) =
1078 self.send_window.last().map(Packet::seq_nr)
1079 {
1080 let lost_packets = extension
1081 .iter()
1082 .enumerate()
1083 .filter(|&(_, received)| !received)
1084 .map(|(idx, _)| packet.ack_nr() + 2 + idx as u16)
1085 .take_while(|&seq_nr| seq_nr < last_seq_nr);
1086
1087 for seq_nr in lost_packets {
1088 debug!("SACK: packet {} lost", seq_nr);
1089 self.resend_lost_packet(seq_nr);
1090 *packet_loss_detected = true;
1091 }
1092 }
1093 } else {
1094 debug!(
1095 "Unknown extension {:?}, ignoring",
1096 extension.get_type()
1097 );
1098 }
1099 }
1100 }
1101
1102 fn handle_state_packet(&mut self, packet: &Packet) {
1103 if packet.ack_nr() == self.last_acked {
1104 self.duplicate_ack_count += 1;
1105 } else {
1106 self.last_acked = packet.ack_nr();
1107 self.last_acked_timestamp = now_microseconds();
1108 self.duplicate_ack_count = 1;
1109 }
1110
1111 if let Some(index) = self
1113 .send_window
1114 .iter()
1115 .position(|p| packet.ack_nr() == p.seq_nr())
1116 {
1117 let bytes_newly_acked = self
1123 .send_window
1124 .iter()
1125 .take(index + 1)
1126 .fold(0, |acc, p| acc + p.len());
1127
1128 let now = now_microseconds();
1130 let our_delay = now - self.send_window[index].timestamp();
1131 debug!("our_delay: {}", our_delay);
1132 self.update_base_delay(our_delay, now);
1133 self.update_current_delay(our_delay, now);
1134
1135 let off_target: f64 =
1136 (TARGET - u32::from(self.queuing_delay()) as f64) / TARGET;
1137 debug!("off_target: {}", off_target);
1138
1139 self.update_congestion_window(off_target, bytes_newly_acked as u32);
1140
1141 let rtt = u32::from(our_delay - self.queuing_delay()) / 1000;
1143 self.update_congestion_timeout(rtt as i32);
1144 }
1145
1146 let mut packet_loss_detected: bool =
1147 !self.send_window.is_empty() && self.duplicate_ack_count == 3;
1148
1149 self.handle_packet_extension(packet, &mut packet_loss_detected);
1150
1151 if !self.send_window.is_empty()
1155 && self.duplicate_ack_count == 3
1156 && !packet
1157 .extensions()
1158 .any(|ext| ext.get_type() == ExtensionType::SelectiveAck)
1159 {
1160 self.resend_lost_packet(packet.ack_nr() + 1);
1161 }
1162
1163 if packet_loss_detected {
1165 debug!("packet loss detected, halving congestion window");
1166 self.cwnd = max(self.cwnd / 2, MIN_CWND * MSS);
1167 debug!("cwnd: {}", self.cwnd);
1168 }
1169
1170 self.advance_send_window();
1172 }
1173
1174 fn insert_into_buffer(&mut self, packet: Packet) {
1183 if self
1186 .incoming_buffer
1187 .last()
1188 .map_or(false, |p| packet.seq_nr() > p.seq_nr())
1189 {
1190 self.incoming_buffer.push(packet);
1191 } else {
1192 let i = self
1195 .incoming_buffer
1196 .iter()
1197 .filter(|p| p.seq_nr() < packet.seq_nr())
1198 .count();
1199
1200 if self
1201 .incoming_buffer
1202 .get(i)
1203 .map_or(true, |p| p.seq_nr() != packet.seq_nr())
1204 {
1205 self.incoming_buffer.insert(i, packet);
1206 }
1207 }
1208 }
1209}
1210
/// Polls the future produced by `$data` once with context `$cx`.
///
/// On `Poll::Pending` it returns `Poll::Pending` from the *enclosing*
/// function. On `Poll::Ready(v)` it evaluates to `v`.
///
/// NOTE(review): the future is created fresh at each expansion and dropped
/// after this single poll, so any progress or waker registration it held is
/// discarded. Confirm this is acceptable at every call site.
macro_rules! ready_unpin {
    ($data:expr, $cx:expr) => {
        // SAFETY (assumed): `$data` is a temporary future expression. It is
        // pinned in place, polled, and dropped when the match finishes,
        // without ever being moved. Passing a named variable here would
        // break that assumption.
        match unsafe { Pin::new_unchecked(&mut $data) }.poll($cx) {
            Poll::Ready(v) => v,
            Poll::Pending => return Poll::Pending,
        }
    };
}
1221
/// Like `ready_unpin!`, but for futures yielding a `Result`.
///
/// It returns `Poll::Pending` from the enclosing function when pending, and
/// returns `Poll::Ready(Err(e))` on error. Otherwise it evaluates to the `Ok`
/// value.
macro_rules! ready_try_unpin {
    ($data:expr, $cx:expr) => {
        match ready_unpin!($data, $cx) {
            Ok(v) => v,
            Err(e) => return Poll::Ready(Err(e)),
        }
    };
}
1232
/// Polls the future produced by `$data` once and evaluates to the raw
/// `Poll` result, without returning early.
///
/// NOTE(review): the future is dropped right after this poll, so a
/// `Pending` result does not keep the operation alive. `#[allow()]` below
/// lists no lints and has no effect.
macro_rules! poll_unpin {
    ($data:expr, $cx:expr) => {{
        #[allow()]
        // SAFETY (assumed): `$data` is a temporary future that is pinned in
        // place, polled once, and dropped at the end of this `let` statement
        // without being moved.
        let x = unsafe { Pin::new_unchecked(&mut $data) }.poll($cx);
        x
    }};
}
1241
/// Unwraps an already computed `Poll<Result<T, E>>`.
///
/// It returns `Poll::Pending` or `Poll::Ready(Err(e))` from the enclosing
/// function, and otherwise evaluates to the `Ok` value.
macro_rules! ready_try {
    ($data:expr) => {{
        match ($data) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Ok(v)) => v,
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
        }
    }};
}
1251
1252pub struct UtpSocketRef(Arc<Mutex<UtpSocket>>, SocketAddr);
1256
1257impl UtpSocketRef {
1258 fn new(socket: Arc<Mutex<UtpSocket>>, local: SocketAddr) -> Self {
1259 Self(socket, local)
1260 }
1261
1262 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
1264 let udp = UdpSocket::bind(addr).await?;
1265 let resolved = udp.local_addr()?;
1266 let socket = UtpSocket::from_raw_parts(udp, resolved);
1267 let lock = Arc::new(Mutex::new(socket));
1268
1269 debug!("bound utp socket on {}", resolved);
1270
1271 Ok(Self::new(lock, resolved))
1272 }
1273
1274 pub async fn connect(
1276 self,
1277 dst: SocketAddr,
1278 ) -> Result<(UtpStream, UtpStreamDriver)> {
1279 let mut socket = self.0.lock().await;
1280
1281 socket.connected_to = dst;
1282
1283 let mut packet = Packet::new();
1284 packet.set_type(PacketType::Syn);
1285 packet.set_connection_id(socket.receiver_connection_id);
1286 packet.set_seq_nr(socket.seq_nr);
1287
1288 let mut len = 0;
1289 let mut buf = [0; BUF_SIZE];
1290
1291 let mut syn_timeout = socket.congestion_timeout;
1292 for _ in 0..MAX_SYN_RETRIES {
1293 packet.set_timestamp(now_microseconds());
1294
1295 debug!("connecting to {}", socket.connected_to);
1296 let dst = socket.connected_to;
1297
1298 socket.socket.send_to(packet.as_ref(), dst).await?;
1299 socket.state = SocketState::SynSent;
1300 debug!("sent {:?}", packet);
1301
1302 let to = Duration::from_millis(syn_timeout);
1303
1304 match timeout(to, socket.socket.recv_from(&mut buf)).await {
1305 Ok(Ok((read, src))) => {
1306 socket.connected_to = src;
1307 len = read;
1308 break;
1309 }
1310 Ok(Err(e)) => return Err(e),
1311 Err(_) => {
1312 debug!("timed out, retrying");
1313 syn_timeout *= 2;
1314 continue;
1315 }
1316 };
1317 }
1318
1319 let remote = socket.connected_to;
1320 let packet = Packet::try_from(&buf[..len])?;
1321 debug!("received {:?}", packet);
1322 socket.handle_packet(&packet, remote)?;
1323
1324 debug!("connected to: {}", socket.connected_to);
1325
1326 let (tx, rx) = unbounded_channel();
1327
1328 let local = socket.local_addr()?;
1329
1330 mem::drop(socket);
1331
1332 let driver = UtpStreamDriver::new(self.0.clone(), tx);
1333 let stream = UtpStream::new(self.0, rx, local, remote);
1334
1335 Ok((stream, driver))
1336 }
1337
1338 pub async fn accept(self) -> Result<(UtpStream, UtpStreamDriver)> {
1343 let (src, dst);
1344
1345 loop {
1346 let mut socket = self.0.lock().await;
1347 let mut buf = [0u8; BUF_SIZE];
1348
1349 let (read, remote) = socket.socket.recv_from(&mut buf).await?;
1350
1351 let packet = Packet::try_from(&buf[..read])?;
1352
1353 debug!("accept receive {:?}", packet);
1354
1355 if let Ok(Some(reply)) = socket.handle_packet(&packet, remote) {
1356 src = socket.socket.local_addr()?;
1357 dst = socket.connected_to;
1358
1359 socket.socket.send_to(reply.as_ref(), dst).await?;
1360
1361 debug!("sent {:?} to {}", reply, dst);
1362 debug!("accepted connection {} -> {}", dst, src);
1363 break;
1364 }
1365 }
1366
1367 let (tx, rx) = unbounded_channel();
1368 let socket = self.0;
1369 let stream = UtpStream::new(socket.clone(), rx, src, dst);
1370 let driver = UtpStreamDriver::new(socket, tx);
1371
1372 Ok((stream, driver))
1373 }
1374
1375 pub fn local_addr(&self) -> SocketAddr {
1377 self.1
1378 }
1379}
1380
1381pub struct UtpStream {
1384 socket: Arc<Mutex<UtpSocket>>,
1385 receiver: UnboundedReceiver<Result<()>>,
1386 local: SocketAddr,
1387 remote: SocketAddr,
1388}
1389
1390impl UtpStream {
1391 fn new(
1392 socket: Arc<Mutex<UtpSocket>>,
1393 receiver: UnboundedReceiver<Result<()>>,
1394 local: SocketAddr,
1395 remote: SocketAddr,
1396 ) -> Self {
1397 Self {
1398 socket,
1399 receiver,
1400 local,
1401 remote,
1402 }
1403 }
1404
1405 pub fn local_addr(&self) -> SocketAddr {
1407 self.local
1408 }
1409
1410 pub fn peer_addr(&self) -> SocketAddr {
1412 self.remote
1413 }
1414
1415 fn handle_driver_notification(
1416 mut self: Pin<&mut Self>,
1417 cx: &mut Context<'_>,
1418 buf: &mut ReadBuf,
1419 ) -> Poll<Result<()>> {
1420 match poll_unpin!(self.receiver.recv(), cx) {
1421 Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
1423 debug!("connection driver has died");
1424 Poll::Ready(Ok(()))
1425 }
1426 Poll::Ready(Some(Ok(()))) => {
1427 debug!("notification from driver");
1428 self.poll_read(cx, buf)
1429 }
1430 Poll::Pending => {
1431 debug!("waiting for notification from driver");
1432 Poll::Pending
1433 }
1434 }
1435 }
1436
1437 fn prepare_packet(socket: &mut UtpSocket, chunk: &[u8]) -> Packet {
1438 let mut packet = Packet::with_payload(chunk);
1439
1440 packet.set_seq_nr(socket.seq_nr);
1441 packet.set_ack_nr(socket.ack_nr);
1442 packet.set_connection_id(socket.sender_connection_id);
1443
1444 packet
1445 }
1446
1447 fn handle_driver_message(
1448 msg: Poll<Option<Result<()>>>,
1449 ) -> Poll<Result<()>> {
1450 match msg {
1451 Poll::Ready(None) => {
1452 debug!("driver is dead, closing success");
1453 Poll::Ready(Ok(()))
1454 }
1455 Poll::Ready(Some(Ok(()))) => {
1456 debug!("driver sent closing notice");
1457 Poll::Ready(Ok(()))
1458 }
1459 Poll::Ready(Some(Err(e)))
1460 if e.kind() == ErrorKind::NotConnected =>
1461 {
1462 debug!("connection closed by err");
1463 Poll::Ready(Ok(()))
1464 }
1465 Poll::Ready(Some(Err(e))) => {
1466 debug!("failed to close correctly");
1467 Poll::Ready(Err(e))
1468 }
1469 Poll::Pending => {
1470 debug!("waiting for driver to complete closing");
1471 Poll::Pending
1472 }
1473 }
1474 }
1475
1476 fn wait_acks(
1477 socket: &mut UtpSocket,
1478 cx: &mut Context<'_>,
1479 ) -> Poll<Result<()>> {
1480 let mut buf = [0u8; BUF_SIZE + HEADER_SIZE];
1481
1482 debug!("waiting for ACKs for {} packets", socket.send_window.len());
1483
1484 while !socket.send_window.is_empty()
1485 && socket.state != SocketState::Closed
1486 {
1487 let (read, src) = {
1488 match poll_unpin!(socket.socket.recv_from(&mut buf), cx) {
1489 Poll::Ready(Ok((read, src))) => (read, src),
1490 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
1491 Poll::Pending => return Poll::Pending,
1492 }
1493 };
1494
1495 let packet = Packet::try_from(&buf[..read])?;
1496
1497 if let Some(reply) = socket.handle_packet(&packet, src)? {
1498 if poll_unpin!(socket.socket.send_to(reply.as_ref(), src), cx)
1499 .is_pending()
1500 {
1501 socket.unsent_queue.push_back(reply);
1502 return Poll::Pending;
1503 }
1504 }
1505 }
1506
1507 Poll::Ready(Ok(()))
1508 }
1509
1510 fn flush_unsent(
1511 socket: &mut UtpSocket,
1512 cx: &mut Context<'_>,
1513 ) -> Poll<Result<()>> {
1514 while let Some(mut packet) = socket.unsent_queue.pop_front() {
1515 if poll_unpin!(socket.send_packet(&mut packet), cx).is_pending() {
1516 debug!("too many in flight packets, waiting for ack");
1517 return Poll::Pending;
1518 }
1519
1520 let result = {
1521 let dst = socket.connected_to;
1522 poll_unpin!(socket.socket.send_to(packet.as_ref(), dst), cx)
1523 };
1524
1525 match result {
1526 Poll::Pending => {
1527 socket.unsent_queue.push_front(packet);
1528 return Poll::Pending;
1529 }
1530 Poll::Ready(Ok(_)) => socket.send_window.push(packet),
1531 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
1532 }
1533 }
1534
1535 Poll::Ready(Ok(()))
1536 }
1537}
1538
1539impl AsyncRead for UtpStream
1540where
1541 Self: Unpin,
1542{
1543 fn poll_read(
1544 self: Pin<&mut Self>,
1545 cx: &mut Context,
1546 buf: &mut ReadBuf,
1547 ) -> Poll<Result<()>> {
1548 debug!("read poll for {} bytes", buf.capacity());
1549
1550 let (read, state) = {
1551 let mut socket = ready_unpin!(self.socket.lock(), cx);
1552
1553 (socket.flush_incoming_buffer(buf), socket.state)
1554 };
1555
1556 if read > 0 {
1557 debug!("flushed {} bytes of received data", read);
1558 Poll::Ready(Ok(()))
1559 } else if state == SocketState::Closed {
1560 debug!("read on closed connection");
1561 Poll::Ready(Ok(()))
1562 } else if state == SocketState::ResetReceived {
1563 debug!("read on reset connection");
1564 Poll::Ready(Err(SocketError::ConnectionReset.into()))
1565 } else {
1566 self.handle_driver_notification(cx, buf)
1567 }
1568 }
1569}
1570
1571impl AsyncWrite for UtpStream
1572where
1573 Self: Unpin,
1574{
1575 fn poll_write(
1576 mut self: Pin<&mut Self>,
1577 cx: &mut Context,
1578 buf: &[u8],
1579 ) -> Poll<Result<usize>> {
1580 let mut socket = ready_unpin!(self.socket.lock(), cx);
1581
1582 if socket.state == SocketState::Closed {
1583 debug!("tried to write on closed connection");
1584 return Poll::Ready(Err(SocketError::ConnectionClosed.into()));
1585 }
1586
1587 let mut sent: usize = 0;
1588
1589 debug!("trying to send {} bytes", buf.len());
1590
1591 for chunk in buf.chunks(MSS as usize - HEADER_SIZE) {
1592 if socket.curr_window >= socket.max_inflight() {
1593 debug!("send window is full, waiting for ACKs");
1594 mem::drop(socket);
1595
1596 while poll_unpin!(self.receiver.recv(), cx).is_ready() {}
1597
1598 return Poll::Pending;
1599 }
1600
1601 debug!("attempting to send chunk of {} byte", chunk.len());
1602
1603 let mut packet = Self::prepare_packet(&mut socket, chunk);
1604
1605 match poll_unpin!(socket.send_packet(&mut packet), cx) {
1606 Poll::Pending if sent == 0 => {
1607 debug!("socket send buffer is full, waiting..");
1608 return Poll::Pending;
1609 }
1610 Poll::Ready(Err(e)) if sent == 0 => {
1611 debug!("os error reading data: {}", e);
1612 return Poll::Ready(Err(e));
1613 }
1614 Poll::Pending | Poll::Ready(Err(_)) => {
1615 debug!("successfully sent {} bytes, sleeping...", sent);
1616 return Poll::Ready(Ok(sent));
1617 }
1618
1619 Poll::Ready(Ok(())) => {
1620 let written = packet.len();
1621
1622 socket.curr_window += written as u32;
1623 socket.send_window.push(packet);
1624
1625 sent += written;
1626 socket.seq_nr = socket.seq_nr.wrapping_add(1);
1627
1628 debug!(
1629 "poll_write sent seq {}, curr_window: {}",
1630 socket.seq_nr - 1,
1631 socket.curr_window
1632 );
1633 }
1634 }
1635 }
1636
1637 Poll::Ready(Ok(buf.len()))
1638 }
1639
1640 fn poll_flush(
1641 mut self: Pin<&mut Self>,
1642 cx: &mut Context,
1643 ) -> Poll<Result<()>> {
1644 debug!("attempting flush");
1645
1646 match poll_unpin!(self.receiver.recv(), cx) {
1647 Poll::Ready(Some(Err(e))) => {
1648 debug!("driver signaled error over channel");
1649 return Poll::Ready(Err(e));
1650 }
1651 Poll::Ready(None) => {
1652 debug!("connection driver disconnected");
1653 return Poll::Ready(Ok(()));
1654 }
1655 _ => debug!("no message from driver"),
1656 }
1657
1658 let mut socket = ready_unpin!(self.socket.lock(), cx);
1659
1660 if socket.state == SocketState::Closed {
1661 return Poll::Ready(Err(SocketError::NotConnected.into()));
1662 }
1663
1664 ready_try!(Self::flush_unsent(&mut socket, cx));
1665
1666 ready_try!(Self::wait_acks(&mut socket, cx));
1667
1668 debug!("sucessfully flushed");
1669
1670 Poll::Ready(Ok(()))
1671 }
1672
1673 fn poll_shutdown(
1674 mut self: Pin<&mut Self>,
1675 cx: &mut Context,
1676 ) -> Poll<Result<()>> {
1677 debug!("poll_shutdown connection...");
1678
1679 {
1680 let socket = ready_unpin!(self.socket.lock(), cx);
1681
1682 if socket.state == SocketState::Closed {
1683 debug!("socket closed by driver");
1684 return Poll::Ready(Ok(()));
1685 }
1686 }
1687
1688 match self.as_mut().poll_flush(cx) {
1689 Poll::Pending => Poll::Pending,
1690 Poll::Ready(Ok(())) => {
1691 {
1692 let mut socket = ready_unpin!(self.socket.lock(), cx);
1693
1694 if socket.state != SocketState::FinSent {
1695 if let Poll::Ready(Ok(())) =
1696 poll_unpin!(socket.close(), cx)
1697 {
1698 return Poll::Ready(Ok(()));
1699 } else {
1700 mem::drop(socket);
1701 ready_unpin!(self.receiver.recv(), cx);
1702 }
1703 }
1704 }
1705
1706 let msg = poll_unpin!(self.receiver.recv(), cx);
1707
1708 Self::handle_driver_message(msg)
1709 }
1710 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1711 }
1712 }
1713}
1714
1715#[must_use = "stream drivers must be spawned for the stream to work"]
1716pub struct UtpStreamDriver {
1720 socket: Arc<Mutex<UtpSocket>>,
1721 sender: UnboundedSender<Result<()>>,
1722 timer: Pin<Box<Sleep>>,
1723 timeout_nr: u32,
1724}
1725
1726impl UtpStreamDriver {
1727 fn new(
1728 socket: Arc<Mutex<UtpSocket>>,
1729 sender: UnboundedSender<Result<()>>,
1730 ) -> Self {
1731 Self {
1732 socket,
1733 sender,
1734 timer: Box::pin(sleep(Duration::from_millis(
1735 INITIAL_CONGESTION_TIMEOUT,
1736 ))),
1737 timeout_nr: 0,
1738 }
1739 }
1740
1741 async fn handle_timeout(&mut self, next_timeout: u64) -> Result<()> {
1742 self.timeout_nr += 1;
1743 debug!(
1744 "timed out {} times out of {} max, retrying in {} ms",
1745 self.timeout_nr, MAX_RETRANSMISSION_RETRIES, next_timeout
1746 );
1747
1748 if self.timeout_nr > MAX_RETRANSMISSION_RETRIES {
1749 let mut socket = self.socket.lock().await;
1750 socket.state = SocketState::Closed;
1751
1752 return Err(SocketError::ConnectionTimedOut.into());
1753 }
1754
1755 let ret = {
1756 let mut socket = self.socket.lock().await;
1757 socket.handle_receive_timeout().await
1758 };
1759
1760 self.reset_timer(Duration::from_millis(next_timeout));
1761
1762 ret
1763 }
1764
1765 fn notify_close(&mut self) {
1766 if self
1767 .sender
1768 .send(Err(SocketError::NotConnected.into()))
1769 .is_err()
1770 {
1771 error!("failed to notify socket of termination");
1772 } else {
1773 debug!("notified socket of closing");
1774 }
1775 }
1776
1777 fn send_reply(
1778 socket: &mut UtpSocket,
1779 mut reply: Packet,
1780 cx: &mut Context<'_>,
1781 ) -> Poll<Result<()>> {
1782 match poll_unpin!(socket.send_packet(&mut reply), cx) {
1783 Poll::Pending => {
1784 socket.unsent_queue.push_back(reply);
1785 Poll::Pending
1786 }
1787 Poll::Ready(Err(e)) => {
1788 error!("driver failed to send packet: {}", e);
1789 Poll::Ready(Err(e))
1790 }
1791 _ => Poll::Ready(Ok(())),
1792 }
1793 }
1794
1795 fn reset_timer(&mut self, next_timeout: Duration) {
1796 let now = TokioInstant::from_std(Instant::now());
1797
1798 self.timer.as_mut().reset(now + next_timeout);
1799 }
1800
1801 fn check_timeout(
1802 &mut self,
1803 cx: &mut Context<'_>,
1804 next_timeout: u64,
1805 ) -> Poll<Result<()>> {
1806 if self.timer.is_elapsed() {
1807 debug!("receive timeout detected");
1808
1809 match poll_unpin!(self.handle_timeout(next_timeout), cx) {
1810 Poll::Pending => todo!("socket buffer full"),
1811 Poll::Ready(Ok(())) => {
1812 self.reset_timer(Duration::from_millis(next_timeout));
1813
1814 ready_unpin!(self.timer, cx);
1815
1816 Poll::Pending
1817 }
1818 Poll::Ready(Err(e)) => {
1819 debug!("remote peer timed out too many times");
1820 self.sender
1821 .send(Err(e.kind().into()))
1822 .expect("failed to propagate");
1823
1824 Poll::Ready(Err(e))
1825 }
1826 }
1827 } else {
1828 ready_unpin!(self.timer, cx);
1829 Poll::Pending
1830 }
1831 }
1832}
1833
impl Future for UtpStreamDriver {
    type Output = Result<()>;

    /// Drives the connection. It receives datagrams, feeds them to the socket's
    /// packet handler, sends replies and checks the receive timer when no
    /// datagram is ready.
    ///
    /// Resolves with `Ok(())` when the socket is closed (after notifying the
    /// stream) or when the stream half has been dropped. Resolves with `Err`
    /// on a receive error, an error from `handle_packet`, or too many timeouts.
    /// Datagrams that fail to parse are ignored.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Cloned up front, presumably because `self` stays borrowed by the
        // socket guard for the rest of the loop.
        let sender = self.sender.clone();
        let mut socket = ready_unpin!(self.socket.lock(), cx);
        let mut buf = [0u8; BUF_SIZE + HEADER_SIZE];

        loop {
            debug!("stream driver poll attempt");

            if socket.state == SocketState::Closed {
                debug!("socket is closed when attempting poll, killing driver");

                mem::drop(socket);

                self.notify_close();

                return Poll::Ready(Ok(()));
            }

            match poll_unpin!(socket.socket.recv_from(&mut buf), cx) {
                Poll::Ready(Ok((read, src))) => {
                    if let Ok(packet) = Packet::try_from(&buf[..read]) {
                        debug!("received packet {:?}", packet);

                        match socket.handle_packet(&packet, src) {
                            Ok(Some(reply)) => {
                                // Data payloads are buffered for the reader, and
                                // the stream is told that new data is available.
                                if let PacketType::Data = packet.get_type() {
                                    socket.insert_into_buffer(packet);

                                    if sender.send(Ok(())).is_err() {
                                        // The receiving half is gone.
                                        debug!(
                                            "dropped socket, killing driver"
                                        );
                                        return Poll::Ready(Ok(()));
                                    }
                                }

                                if Self::send_reply(&mut socket, reply, cx)
                                    .is_pending()
                                {
                                    return Poll::Pending;
                                }
                            }
                            // No reply needed: give the socket a chance to send
                            // queued data.
                            Ok(None) => ready_try_unpin!(socket.send(), cx),
                            Err(e) => return Poll::Ready(Err(e)),
                        }
                    }
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => {
                    // No datagram ready: check the retransmission timer,
                    // doubling the current congestion timeout for the next arm.
                    let next_timeout = socket.congestion_timeout * 2;

                    // Release the lock; timeout handling re-acquires it.
                    mem::drop(socket);

                    return self.check_timeout(cx, next_timeout);
                }
            }
        }
    }
}
1897
impl Drop for UtpSocket {
    /// Intended to close the connection when the socket state is dropped.
    fn drop(&mut self) {
        // NOTE(review): `close()` is `.await`ed at its other call sites, so it
        // appears to be an `async fn`. If so, the future created here is dropped
        // without ever being polled and this call has no effect. Confirm, and
        // consider a synchronous best-effort close instead.
        let _ = self.close();
    }
}
1903
/// Read-buffer capacity, in bytes, used by `BufferedUtpStream`.
const MTU: usize = 1_500;
1905
1906pub struct BufferedUtpStream {
1909 stream: BufReader<UtpStream>,
1910}
1911
1912impl BufferedUtpStream {
1913 pub fn new(stream: UtpStream) -> Self {
1915 Self {
1916 stream: BufReader::with_capacity(MTU, stream),
1917 }
1918 }
1919
1920 fn get_stream(self: Pin<&mut Self>) -> Pin<&mut BufReader<UtpStream>> {
1921 unsafe { self.map_unchecked_mut(|s| &mut s.stream) }
1922 }
1923
1924 pub fn local_addr(&self) -> Result<SocketAddr> {
1926 Ok(self.stream.get_ref().local_addr())
1927 }
1928
1929 pub fn peer_addr(&self) -> Result<SocketAddr> {
1931 Ok(self.stream.get_ref().peer_addr())
1932 }
1933}
1934
1935impl AsyncRead for BufferedUtpStream {
1936 fn poll_read(
1937 self: Pin<&mut Self>,
1938 cx: &mut Context,
1939 buf: &mut ReadBuf,
1940 ) -> Poll<Result<()>> {
1941 self.get_stream().get_pin_mut().poll_read(cx, buf)
1943 }
1944}
1945
1946impl AsyncWrite for BufferedUtpStream {
1947 fn poll_write(
1948 self: Pin<&mut Self>,
1949 cx: &mut Context,
1950 buf: &[u8],
1951 ) -> Poll<Result<usize>> {
1952 self.get_stream().poll_write(cx, buf)
1953 }
1954
1955 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
1956 self.get_stream().poll_flush(cx)
1957 }
1958
1959 fn poll_shutdown(
1960 self: Pin<&mut Self>,
1961 cx: &mut Context,
1962 ) -> Poll<Result<()>> {
1963 self.get_stream().poll_shutdown(cx)
1964 }
1965}
1966
1967#[cfg(test)]
1968mod test {
1969 use std::env;
1970 use std::io::ErrorKind;
1971 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
1972
1973 use std::sync::atomic::Ordering;
1974
1975 use super::*;
1976 use crate::socket::{SocketState, UtpSocket, BUF_SIZE};
1977 use crate::time::now_microseconds;
1978
1979 use tokio::io::{AsyncReadExt, AsyncWriteExt};
1980 use tokio::task;
1981 use tokio::time::interval;
1982
1983 use tracing::debug_span;
1984 use tracing_futures::Instrument;
1985 use tracing_subscriber::FmtSubscriber;
1986
1987 macro_rules! iotry {
1988 ($e:expr) => {
1989 match $e.await {
1990 Ok(e) => e,
1991 Err(e) => panic!("{:?}", e),
1992 }
1993 };
1994 }
1995
1996 fn init_logger() {
1997 if let Some(level) = env::var("RUST_LOG").ok().map(|x| x.parse().ok()) {
1998 let subscriber =
1999 FmtSubscriber::builder().with_max_level(level).finish();
2000
2001 let _ = tracing::subscriber::set_global_default(subscriber);
2002 }
2003 }
2004
2005 fn next_test_port() -> u16 {
2006 use std::sync::atomic::AtomicUsize;
2007 static NEXT_OFFSET: AtomicUsize = AtomicUsize::new(0);
2008 const BASE_PORT: u16 = 9600;
2009 BASE_PORT + NEXT_OFFSET.fetch_add(1, Ordering::Relaxed) as u16
2010 }
2011
2012 fn next_test_ip4() -> SocketAddr {
2013 ("127.0.0.1".parse::<Ipv4Addr>().unwrap(), next_test_port()).into()
2014 }
2015
2016 fn next_test_ip6() -> SocketAddr {
2017 ("::1".parse::<Ipv6Addr>().unwrap(), next_test_port()).into()
2018 }
2019
2020 async fn stream_accept(server_addr: SocketAddr) -> UtpStream {
2021 let (stream, driver) = UtpSocketRef::bind(server_addr)
2022 .await
2023 .expect("failed to bind")
2024 .accept()
2025 .await
2026 .expect("failed to accept");
2027
2028 task::spawn(driver.instrument(debug_span!("stream_driver")));
2029
2030 stream
2031 }
2032
2033 async fn stream_connect(local: SocketAddr, peer: SocketAddr) -> UtpStream {
2034 let socket = UtpSocketRef::bind(local).await.expect("failed to bind");
2035 let (stream, driver) =
2036 socket.connect(peer).await.expect("failed to connect");
2037
2038 task::spawn(driver.instrument(debug_span!("stream_driver")));
2039
2040 stream
2041 }
2042
2043 #[tokio::test]
2044 async fn stream_fast_resend_active() {
2045 init_logger();
2046 let server_addr = next_test_ip4();
2047 let client_addr = next_test_ip4();
2048 const DATA: u8 = 2;
2049 const LEN: usize = 345;
2050
2051 let socket =
2052 UtpSocketRef::bind(server_addr).await.expect("bind failed");
2053
2054 let handle = task::spawn(async {
2055 let buf = [DATA; LEN];
2056 let (mut stream, driver) =
2057 socket.accept().await.expect("accept failed");
2058
2059 task::spawn(driver);
2060
2061 stream.write_all(&buf).await.expect("write failed");
2062 stream.shutdown().await.expect("shutdown failed");
2063 });
2064
2065 let (mut stream, driver) = UtpSocketRef::bind(client_addr)
2066 .await
2067 .expect("bind failed")
2068 .connect(server_addr)
2069 .await
2070 .expect("connect failed");
2071
2072 {
2073 let mut lock = stream.socket.lock().await;
2074 let mut buf = [0u8; LEN];
2075
2076 lock.recv_from(&mut buf).await.expect("read failed");
2078 }
2079
2080 task::spawn(driver);
2081
2082 stream.shutdown().await.expect("close failed");
2083
2084 handle.await.expect("task failure");
2085 }
2086
2087 #[tokio::test]
2088 async fn stream_connect_disconnect() {
2089 init_logger();
2090 let server_addr = next_test_ip4();
2091 let client_addr = next_test_ip4();
2092
2093 let handle = task::spawn(async move {
2094 let mut stream = stream_accept(server_addr).await;
2095
2096 stream.shutdown().await.expect("failed to close");
2097 });
2098
2099 let mut stream = stream_connect(client_addr, server_addr).await;
2100
2101 stream.shutdown().await.expect("failed to close connection");
2102
2103 handle.await.expect("task failure");
2104 }
2105
2106 #[tokio::test]
2107 #[ignore]
2108 async fn stream_packet_split() {
2109 init_logger();
2110
2111 let server_addr = next_test_ip4();
2112 let client_addr = next_test_ip4();
2113 const LEN: usize = 2000;
2114 const DATA: u8 = 1;
2115
2116 let handle = task::spawn(async move {
2117 let mut stream = stream_accept(server_addr)
2118 .instrument(debug_span!("server"))
2119 .await;
2120
2121 let mut buf = [0u8; LEN];
2122
2123 stream
2124 .read_exact(&mut buf)
2125 .instrument(debug_span!("server_read_exact"))
2126 .await
2127 .expect("read failed");
2128
2129 for b in &buf[..] {
2130 assert_eq!(*b, DATA, "data was altered");
2131 }
2132
2133 stream
2134 .shutdown()
2135 .instrument(debug_span!("server_shutdown"))
2136 .await
2137 .expect("flush failed");
2138 });
2139
2140 let mut stream = stream_connect(client_addr, server_addr)
2141 .instrument(debug_span!("client"))
2142 .await;
2143 let buf = [DATA; LEN];
2144
2145 stream
2146 .write_all(&buf)
2147 .instrument(debug_span!("client_write_all"))
2148 .await
2149 .expect("write failed");
2150
2151 stream
2152 .shutdown()
2153 .instrument(debug_span!("client_shutdown"))
2154 .await
2155 .expect("close failed");
2156
2157 handle.await.expect("task failure")
2158 }
2159
2160 #[tokio::test]
2161 #[ignore]
2162 async fn stream_closed_write() {
2163 init_logger();
2164 let server_addr = next_test_ip4();
2165 let client_addr = next_test_ip4();
2166
2167 const LEN: usize = 1240;
2168 const DATA: u8 = 12;
2169
2170 let handle = task::spawn(async move {
2171 let mut stream = stream_accept(server_addr)
2172 .instrument(debug_span!("server"))
2173 .await;
2174 let mut buf = [0u8; LEN];
2175
2176 stream
2177 .read_exact(&mut buf)
2178 .instrument(debug_span!("server_read_exact"))
2179 .await
2180 .expect("read failed");
2181
2182 stream
2183 .shutdown()
2184 .instrument(debug_span!("server_shutdown"))
2185 .await
2186 .expect("shutdown failed");
2187
2188 stream
2189 .read_exact(&mut buf)
2190 .instrument(debug_span!("server_closed_read"))
2191 .await
2192 .expect_err("read on closed stream");
2193 });
2194
2195 let mut stream = stream_connect(client_addr, server_addr)
2196 .instrument(debug_span!("client"))
2197 .await;
2198 let buf = [DATA; LEN];
2199
2200 stream
2201 .write_all(&buf)
2202 .instrument(debug_span!("client_write_all"))
2203 .await
2204 .expect("write failed");
2205 stream
2206 .shutdown()
2207 .instrument(debug_span!("client_shutdown"))
2208 .await
2209 .expect("shutdown failed");
2210
2211 stream
2212 .write_all(&buf)
2213 .instrument(debug_span!("client_closed_write"))
2214 .await
2215 .expect_err("wrote on closed stream");
2216
2217 handle.await.expect("execution failure");
2218 }
2219
2220 #[tokio::test]
2221 async fn stream_fast_resend_idle() {
2222 init_logger();
2223 let server = next_test_ip4();
2224 let client = next_test_ip4();
2225
2226 let handle = task::spawn(async move {
2227 let mut stream = stream_accept(server).await;
2228
2229 let mut timer = interval(Duration::from_secs(3));
2230
2231 timer.tick().await;
2232
2233 stream.shutdown().await.expect("close failed");
2234 });
2235
2236 let mut stream = stream_connect(client, server).await;
2237 let mut timer = interval(Duration::from_secs(4));
2238
2239 timer.tick().await;
2240
2241 stream.shutdown().await.expect("close failed");
2242
2243 handle.await.expect("task failed");
2244 }
2245
2246 #[tokio::test]
2247 async fn stream_clean_close() {
2248 init_logger();
2249 let server_addr = next_test_ip4();
2250 let client_addr = next_test_ip4();
2251
2252 const DATA: u8 = 1;
2253 const LEN: usize = 1024;
2254
2255 let handle = task::spawn(async move {
2256 let mut stream = stream_accept(server_addr)
2257 .instrument(debug_span!("stream_accept"))
2258 .await;
2259 let buf = [DATA; LEN];
2260
2261 stream
2262 .write_all(&buf)
2263 .instrument(debug_span!("server_write"))
2264 .await
2265 .expect("write failed");
2266
2267 stream
2268 .shutdown()
2269 .instrument(debug_span!("server_shutdown"))
2270 .await
2271 .expect("shutdown failed");
2272 });
2273
2274 let mut socket = stream_connect(client_addr, server_addr)
2275 .instrument(debug_span!("stream_connect"))
2276 .await;
2277 let mut buf = [0u8; LEN];
2278
2279 socket
2280 .read_exact(&mut buf)
2281 .instrument(debug_span!("client_read"))
2282 .await
2283 .expect("read failed");
2284
2285 socket
2286 .shutdown()
2287 .instrument(debug_span!("client_shutdown"))
2288 .await
2289 .expect("shutdown failed");
2290
2291 handle.await.expect("task panic");
2292 }
2293
2294 #[tokio::test]
2295 async fn stream_connect_timeout() {
2296 init_logger();
2297 let server_addr = next_test_ip4();
2298 let client_addr = next_test_ip4();
2299
2300 let socket =
2301 UtpSocketRef::bind(client_addr).await.expect("bind failed");
2302
2303 socket.0.lock().await.congestion_timeout = 100;
2304
2305 assert!(
2306 socket.connect(server_addr).await.is_err(),
2307 "connected to void"
2308 );
2309 }
2310
2311 #[tokio::test]
2312 async fn stream_read_timeout() {
2313 init_logger();
2314 let server_addr = next_test_ip4();
2315 let client_addr = next_test_ip4();
2316
2317 let handle = task::spawn(async move {
2318 let sock =
2319 UtpSocketRef::bind(server_addr).await.expect("bind failed");
2320
2321 let _ = sock.accept().await.expect("accept failed");
2323 });
2324
2325 let mut socket = stream_connect(client_addr, server_addr).await;
2326 let mut buf = [0u8; 1024];
2327
2328 socket.socket.lock().await.congestion_timeout = 100;
2329
2330 socket
2331 .read_exact(&mut buf)
2332 .await
2333 .expect_err("read from non responding peer");
2334
2335 handle.await.expect("task panic");
2336 }
2337
2338 #[tokio::test]
2339 async fn stream_write_timeout() {
2340 init_logger();
2341
2342 let (server, client) = (next_test_ip4(), next_test_ip4());
2343 const DATA: u8 = 45;
2344 const LEN: usize = 123;
2345
2346 let handle = task::spawn(async move {
2347 let mut stream = stream_accept(server).await;
2348 let buf = [DATA; LEN];
2349
2350 stream.socket.lock().await.congestion_timeout = 100;
2351
2352 stream
2353 .write_all(&buf)
2354 .await
2355 .expect("packets weren't buffered");
2356 stream
2357 .flush()
2358 .await
2359 .expect_err("flush succeeded without ack");
2360 });
2361
2362 let sock = UtpSocketRef::bind(client).await.expect("bind failed");
2363 let _ = sock.connect(server).await.expect("connect failed");
2364
2365 handle.await.expect("execution failure");
2366 }
2367
2368 #[tokio::test]
2369 #[ignore]
2370 async fn stream_flush_then_send() {
2371 init_logger();
2372 let server_addr = next_test_ip4();
2373 let client_addr = next_test_ip4();
2374
2375 const LEN: usize = 1240;
2376 const DATA: u8 = 25;
2377
2378 let handle = task::spawn(async move {
2379 let mut stream = stream_accept(server_addr).await;
2380 let mut buf = [0u8; 2 * LEN];
2381
2382 stream.read_exact(&mut buf).await.expect("failed to read");
2383
2384 for b in buf.iter() {
2385 assert_eq!(*b, DATA, "data corrupted");
2386 }
2387
2388 stream.flush().await.expect("flush failed");
2389 stream.shutdown().await.expect("shutdown failed");
2390 });
2391
2392 let mut stream = stream_connect(client_addr, server_addr).await;
2393 let buf = [DATA; LEN];
2394
2395 stream.write_all(&buf).await.expect("write failed");
2396
2397 stream.flush().await.expect("flush failed");
2398
2399 stream.write_all(&buf).await.expect("write failed");
2400 stream.shutdown().await.expect("shutdown failed");
2401
2402 handle.await.expect("task failure");
2403 }
2404
2405 #[tokio::test]
2406 async fn test_socket_ipv4() {
2407 let server_addr = next_test_ip4();
2408
2409 let handle = task::spawn(async move {
2410 let mut server = iotry!(UtpSocket::bind(server_addr));
2411 assert_eq!(server.state, SocketState::New);
2412
2413 let mut buf = [0u8; BUF_SIZE];
2414 match server.recv_from(&mut buf).await {
2415 e => println!("{:?}", e),
2416 }
2417 assert_eq!(
2420 server.receiver_connection_id,
2421 server.sender_connection_id + 1
2422 );
2423
2424 assert_eq!(server.state, SocketState::Closed);
2425 drop(server);
2426 });
2427
2428 let mut client = iotry!(UtpSocket::connect(server_addr));
2429 assert_eq!(client.state, SocketState::Connected);
2430 assert_eq!(
2433 client.sender_connection_id,
2434 client.receiver_connection_id + 1
2435 );
2436 assert_eq!(
2437 client.connected_to,
2438 server_addr.to_socket_addrs().unwrap().next().unwrap()
2439 );
2440 iotry!(client.close());
2441
2442 handle.await.expect("task failure");
2443 }
2444
2445 #[ignore]
2446 #[tokio::test]
2447 async fn test_socket_ipv6() {
2448 let server_addr = next_test_ip6();
2449
2450 let mut server = iotry!(UtpSocket::bind(server_addr));
2451 assert_eq!(server.state, SocketState::New);
2452
2453 task::spawn(async move {
2454 let mut client = iotry!(UtpSocket::connect(server_addr));
2455 assert_eq!(client.state, SocketState::Connected);
2456 assert_eq!(
2459 client.sender_connection_id,
2460 client.receiver_connection_id + 1
2461 );
2462 assert_eq!(
2463 client.connected_to,
2464 server_addr.to_socket_addrs().unwrap().next().unwrap()
2465 );
2466 iotry!(client.close());
2467 drop(client);
2468 });
2469
2470 let mut buf = [0u8; BUF_SIZE];
2471 match server.recv_from(&mut buf).await {
2472 e => println!("{:?}", e),
2473 }
2474 assert_eq!(
2477 server.receiver_connection_id,
2478 server.sender_connection_id + 1
2479 );
2480
2481 assert_eq!(server.state, SocketState::Closed);
2482 drop(server);
2483 }
2484
2485 #[tokio::test]
2486 async fn test_recvfrom_on_closed_socket() {
2487 let server_addr = next_test_ip4();
2488
2489 let mut server = iotry!(UtpSocket::bind(server_addr));
2490 assert_eq!(server.state, SocketState::New);
2491
2492 let handle = task::spawn(async move {
2493 let mut client = iotry!(UtpSocket::connect(server_addr));
2494 assert_eq!(client.state, SocketState::Connected);
2495 assert!(client.close().await.is_ok());
2496 });
2497
2498 let mut buf = [0u8; BUF_SIZE];
2501 let _resp = server.recv_from(&mut buf).await;
2502 assert_eq!(server.state, SocketState::Closed);
2503
2504 match server.recv_from(&mut buf).await {
2507 Ok((0, _src)) => {}
2508 e => panic!("Expected Ok(0), got {:?}", e),
2509 }
2510 assert_eq!(server.state, SocketState::Closed);
2511
2512 handle.await.expect("task failure");
2513 }
2514
2515 #[tokio::test]
2516 async fn test_sendto_on_closed_socket() {
2517 init_logger();
2518 let server_addr = next_test_ip4();
2519
2520 let mut server = iotry!(UtpSocket::bind(server_addr));
2521 assert_eq!(server.state, SocketState::New);
2522
2523 let handle = task::spawn(async move {
2524 let mut client = iotry!(UtpSocket::connect(server_addr));
2525 assert_eq!(client.state, SocketState::Connected);
2526 iotry!(client.close());
2527 });
2528
2529 let mut buf = [0u8; BUF_SIZE];
2531 let (_read, _src) = iotry!(server.recv_from(&mut buf));
2532 assert_eq!(server.state, SocketState::Closed);
2533
2534 match server.send_to(&buf).await {
2536 Err(ref e) if e.kind() == ErrorKind::NotConnected => (),
2537 v => panic!("expected {:?}, got {:?}", ErrorKind::NotConnected, v),
2538 }
2539
2540 handle.await.expect("task failure");
2541 }
2542
2543 #[tokio::test]
2544 async fn test_acks_on_socket() {
2545 use tokio::sync::mpsc::channel;
2546
2547 init_logger();
2548 let server_addr = next_test_ip4();
2549 let (tx, mut rx) = channel(1);
2550
2551 let mut server = iotry!(UtpSocket::bind(server_addr));
2552
2553 let handle = task::spawn(async move {
2554 let mut buf = [0u8; BUF_SIZE];
2556 let mut buf = ReadBuf::new(&mut buf);
2557 let _resp = server.recv(&mut buf).await.unwrap();
2558 tx.send(server.seq_nr).await.expect("channel closed");
2559
2560 let mut buf = [0; 1500];
2562 iotry!(server.recv_from(&mut buf));
2563 });
2564
2565 let mut client = iotry!(UtpSocket::connect(server_addr));
2566 assert_eq!(client.state, SocketState::Connected);
2567 let sender_seq_nr = rx.recv().await.expect("channel closed");
2568 let ack_nr = client.ack_nr;
2569 assert_eq!(ack_nr, sender_seq_nr);
2570 assert!(client.close().await.is_ok());
2571
2572 assert_eq!(client.ack_nr, ack_nr);
2576 drop(client);
2577
2578 handle.await.expect("task failure");
2579 }
2580
2581 #[tokio::test]
2582 async fn test_handle_packet() {
2583 let initial_connection_id: u16 = rand::random();
2585 let sender_connection_id = initial_connection_id + 1;
2586 let (server_addr, client_addr) = (
2587 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2588 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2589 );
2590 let mut socket = iotry!(UtpSocket::bind(server_addr));
2591
2592 let mut packet = Packet::new();
2593 packet.set_wnd_size(BUF_SIZE as u32);
2594 packet.set_type(PacketType::Syn);
2595 packet.set_connection_id(initial_connection_id);
2596
2597 let response = socket.handle_packet(&packet, client_addr);
2599 assert!(response.is_ok());
2600 let response = response.unwrap();
2601 assert!(response.is_some());
2602
2603 let response = response.unwrap();
2605 assert_eq!(response.get_type(), PacketType::State);
2606
2607 assert_eq!(response.connection_id(), packet.connection_id());
2609
2610 assert_eq!(response.ack_nr(), packet.seq_nr());
2612
2613 assert!(response.payload().is_empty());
2615 let old_packet = packet;
2621 let old_response = response;
2622
2623 let mut packet = Packet::new();
2624 packet.set_type(PacketType::Data);
2625 packet.set_connection_id(sender_connection_id);
2626 packet.set_seq_nr(old_packet.seq_nr() + 1);
2627 packet.set_ack_nr(old_response.seq_nr());
2628
2629 let response = socket.handle_packet(&packet, client_addr);
2630 assert!(response.is_ok());
2631 let response = response.unwrap();
2632 assert!(response.is_some());
2633
2634 let response = response.unwrap();
2635 assert_eq!(response.get_type(), PacketType::State);
2636
2637 assert_eq!(response.connection_id(), initial_connection_id);
2642 assert_eq!(response.connection_id(), packet.connection_id() - 1);
2643
2644 assert_eq!(response.ack_nr(), packet.seq_nr());
2646
2647 assert!(response.payload().is_empty());
2649 assert_eq!(response.seq_nr(), old_response.seq_nr());
2650 let old_packet = packet;
2654 let old_response = response;
2655
2656 let mut packet = Packet::new();
2657 packet.set_type(PacketType::Fin);
2658 packet.set_connection_id(sender_connection_id);
2659 packet.set_seq_nr(old_packet.seq_nr() + 1);
2660 packet.set_ack_nr(old_response.seq_nr());
2661
2662 let response = socket.handle_packet(&packet, client_addr);
2663 assert!(response.is_ok());
2664 let response = response.unwrap();
2665 assert!(response.is_some());
2666
2667 let response = response.unwrap();
2668
2669 assert_eq!(response.get_type(), PacketType::State);
2670
2671 assert_eq!(packet.seq_nr(), old_packet.seq_nr() + 1);
2673
2674 assert_eq!(response.seq_nr(), old_response.seq_nr());
2676
2677 assert_eq!(response.ack_nr(), packet.seq_nr());
2679 }
2680
2681 #[tokio::test]
2682 async fn test_response_to_keepalive_ack() {
2683 let initial_connection_id: u16 = rand::random();
2685 let (server_addr, client_addr) = (
2686 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2687 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2688 );
2689 let mut socket = iotry!(UtpSocket::bind(server_addr));
2690
2691 let mut packet = Packet::new();
2693 packet.set_wnd_size(BUF_SIZE as u32);
2694 packet.set_type(PacketType::Syn);
2695 packet.set_connection_id(initial_connection_id);
2696
2697 let response = socket.handle_packet(&packet, client_addr);
2698 assert!(response.is_ok());
2699 let response = response.unwrap();
2700 assert!(response.is_some());
2701 let response = response.unwrap();
2702 assert_eq!(response.get_type(), PacketType::State);
2703
2704 let old_packet = packet;
2705 let old_response = response;
2706
2707 let mut packet = Packet::new();
2709 packet.set_wnd_size(BUF_SIZE as u32);
2710 packet.set_type(PacketType::State);
2711 packet.set_connection_id(initial_connection_id);
2712 packet.set_seq_nr(old_packet.seq_nr() + 1);
2713 packet.set_ack_nr(old_response.seq_nr());
2714
2715 let response = socket.handle_packet(&packet, client_addr);
2716 assert!(response.is_ok());
2717 let response = response.unwrap();
2718 assert!(response.is_none());
2719
2720 let response = socket.handle_packet(&packet, client_addr);
2722 assert!(response.is_ok());
2723 let response = response.unwrap();
2724 assert!(response.is_none());
2725
2726 socket.state = SocketState::Closed;
2728 }
2729
2730 #[tokio::test]
2731 async fn test_response_to_wrong_connection_id() {
2732 let initial_connection_id: u16 = rand::random();
2734 let (server_addr, client_addr) = (
2735 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2736 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2737 );
2738 let mut socket = iotry!(UtpSocket::bind(server_addr));
2739
2740 let mut packet = Packet::new();
2742 packet.set_wnd_size(BUF_SIZE as u32);
2743 packet.set_type(PacketType::Syn);
2744 packet.set_connection_id(initial_connection_id);
2745
2746 let response = socket.handle_packet(&packet, client_addr);
2747 assert!(response.is_ok());
2748 let response = response.unwrap();
2749 assert!(response.is_some());
2750 assert_eq!(response.unwrap().get_type(), PacketType::State);
2751
2752 let new_connection_id = initial_connection_id.wrapping_mul(2);
2754
2755 let mut packet = Packet::new();
2756 packet.set_wnd_size(BUF_SIZE as u32);
2757 packet.set_type(PacketType::State);
2758 packet.set_connection_id(new_connection_id);
2759
2760 let response = socket.handle_packet(&packet, client_addr);
2761 assert!(response.is_ok());
2762 let response = response.unwrap();
2763 assert!(response.is_some());
2764
2765 let response = response.unwrap();
2766 assert_eq!(response.get_type(), PacketType::Reset);
2767 assert_eq!(response.ack_nr(), packet.seq_nr());
2768
2769 socket.state = SocketState::Closed;
2771 }
2772
2773 #[tokio::test]
2774 async fn test_unordered_packets() {
2775 let initial_connection_id: u16 = rand::random();
2777 let (server_addr, client_addr) = (
2778 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2779 next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
2780 );
2781 let mut socket = iotry!(UtpSocket::bind(server_addr));
2782
2783 let mut packet = Packet::new();
2785 packet.set_wnd_size(BUF_SIZE as u32);
2786 packet.set_type(PacketType::Syn);
2787 packet.set_connection_id(initial_connection_id);
2788
2789 let response = socket.handle_packet(&packet, client_addr);
2790 assert!(response.is_ok());
2791 let response = response.unwrap();
2792 assert!(response.is_some());
2793 let response = response.unwrap();
2794 assert_eq!(response.get_type(), PacketType::State);
2795
2796 let old_packet = packet;
2797 let old_response = response;
2798
2799 let mut window: Vec<Packet> = Vec::new();
2800
2801 let mut packet = Packet::with_payload(&[1, 2, 3]);
2803 packet.set_wnd_size(BUF_SIZE as u32);
2804 packet.set_connection_id(initial_connection_id);
2805 packet.set_seq_nr(old_packet.seq_nr() + 1);
2806 packet.set_ack_nr(old_response.seq_nr());
2807 window.push(packet);
2808
2809 let mut packet = Packet::with_payload(&[4, 5, 6]);
2810 packet.set_wnd_size(BUF_SIZE as u32);
2811 packet.set_connection_id(initial_connection_id);
2812 packet.set_seq_nr(old_packet.seq_nr() + 2);
2813 packet.set_ack_nr(old_response.seq_nr());
2814 window.push(packet);
2815
2816 let response = socket.handle_packet(&window[1], client_addr);
2818 assert!(response.is_ok());
2819 let response = response.unwrap();
2820 assert!(response.is_some());
2821 let response = response.unwrap();
2822 assert!(response.ack_nr() != window[1].seq_nr());
2823
2824 let response = socket.handle_packet(&window[0], client_addr);
2825 assert!(response.is_ok());
2826 let response = response.unwrap();
2827 assert!(response.is_some());
2828
2829 socket.state = SocketState::Closed;
2831 }
2832
2833 #[tokio::test]
2834 async fn test_response_to_triple_ack() {
2835 let server_addr = next_test_ip4();
2836 let mut server = iotry!(UtpSocket::bind(server_addr));
2837
2838 const LEN: usize = 1024;
2840 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
2841 let d = data.clone();
2842 assert_eq!(LEN, data.len());
2843
2844 let handle = task::spawn(async move {
2845 let mut client = iotry!(UtpSocket::connect(server_addr));
2846 iotry!(client.send_to(&d[..]));
2847 iotry!(client.close());
2848 });
2849
2850 let mut buf = [0; BUF_SIZE];
2851 let mut buf = ReadBuf::new(&mut buf);
2852 iotry!(server.recv(&mut buf));
2854
2855 let data_packet =
2857 match server.socket.recv_from(buf.initialized_mut()).await {
2858 Ok((_, _src)) => Packet::try_from(buf.filled()).unwrap(),
2859 Err(e) => panic!("{}", e),
2860 };
2861 assert_eq!(data_packet.get_type(), PacketType::Data);
2862 assert_eq!(&data_packet.payload(), &data.as_slice());
2863 assert_eq!(data_packet.payload().len(), data.len());
2864
2865 let mut packet = Packet::new();
2867 packet.set_wnd_size(BUF_SIZE as u32);
2868 packet.set_type(PacketType::State);
2869 packet.set_seq_nr(server.seq_nr);
2870 packet.set_ack_nr(data_packet.seq_nr() - 1);
2871 packet.set_connection_id(server.sender_connection_id);
2872
2873 for _ in 0..3usize {
2874 iotry!(server.socket.send_to(packet.as_ref(), server.connected_to));
2875 }
2876
2877 let client_addr = server.connected_to;
2879
2880 let mut buf = [0; BUF_SIZE];
2881
2882 match server.socket.recv_from(&mut buf).await {
2883 Ok((0, _)) => panic!("Received 0 bytes from socket"),
2884 Ok((read, _src)) => {
2885 let packet = Packet::try_from(&buf[..read]).unwrap();
2886 assert_eq!(packet.get_type(), PacketType::Data);
2887 assert_eq!(packet.seq_nr(), data_packet.seq_nr());
2888 assert_eq!(packet.payload(), data_packet.payload());
2889 let response = server.handle_packet(&packet, client_addr);
2890 assert!(response.is_ok());
2891 let response = response.unwrap();
2892 assert!(response.is_some());
2893 let response = response.unwrap();
2894 iotry!(server
2895 .socket
2896 .send_to(response.as_ref(), server.connected_to));
2897 }
2898 Err(e) => panic!("{}", e),
2899 }
2900
2901 let mut buf = [0; 1500];
2903 iotry!(server.recv_from(&mut buf));
2904
2905 handle.await.expect("task failure");
2906 }
2907
    /// Leaves one DATA packet unacknowledged by reading it directly off the
    /// server's UDP socket, lowers the congestion timeout, then waits for the
    /// uTP layer to deliver a non-empty payload.
    ///
    /// Marked `#[ignore]`, so it is not run by default.
    #[ignore]
    #[tokio::test]
    async fn test_socket_timeout_request() {
        let (server_addr, client_addr) = (
            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
            next_test_ip4().to_socket_addrs().unwrap().next().unwrap(),
        );

        // NOTE(review): this `client` is only used for the state/id checks
        // below; the spawned task connects with a separate socket.
        let client = iotry!(UtpSocket::bind(client_addr));
        let mut server = iotry!(UtpSocket::bind(server_addr));
        const LEN: usize = 512;
        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
        let d = data.clone();

        assert_eq!(server.state, SocketState::New);
        assert_eq!(client.state, SocketState::New);

        // Expect a freshly bound socket's sender id to be its receiver id + 1.
        assert_eq!(
            client.sender_connection_id,
            client.receiver_connection_id + 1
        );

        // Client: connect, send the payload, and drop the socket.
        let handle = task::spawn(async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            assert_eq!(client.state, SocketState::Connected);
            assert_eq!(client.connected_to, server_addr);
            iotry!(client.send_to(&d[..]));
            drop(client);
        });

        // Accept the connection.
        let mut buf = [0u8; BUF_SIZE];
        let mut buf = ReadBuf::new(&mut buf);
        server.recv(&mut buf).await.unwrap();
        // After accepting, the server's ids are mirrored: receiver = sender + 1.
        assert_eq!(
            server.receiver_connection_id,
            server.sender_connection_id + 1
        );

        assert_eq!(server.state, SocketState::Connected);

        // Read one datagram from the underlying UDP socket (`server.socket`),
        // so the uTP layer never sees or acknowledges it.
        let mut buf = [0; 1500];
        iotry!(server.socket.recv_from(&mut buf));

        // Lower the congestion timeout to 50 (units presumably milliseconds —
        // confirm), presumably so the timeout path is reached quickly.
        server.congestion_timeout = 50;

        // Keep reading until the uTP layer returns a non-empty payload.
        // NOTE(review): if `recv_from` kept returning 0 bytes this would spin
        // forever.
        loop {
            match server.recv_from(&mut buf).await {
                Ok((0, _)) => continue,
                Ok(_) => break,
                Err(e) => panic!("{}", e),
            }
        }

        drop(server);

        handle.await.expect("task failure");
    }
2974
2975 #[tokio::test]
2976 async fn test_sorted_buffer_insertion() {
2977 let server_addr = next_test_ip4();
2978 let mut socket = iotry!(UtpSocket::bind(server_addr));
2979
2980 let mut packet = Packet::new();
2981 packet.set_seq_nr(1);
2982
2983 assert!(socket.incoming_buffer.is_empty());
2984
2985 socket.insert_into_buffer(packet.clone());
2986 assert_eq!(socket.incoming_buffer.len(), 1);
2987
2988 packet.set_seq_nr(2);
2989 packet.set_timestamp(128.into());
2990
2991 socket.insert_into_buffer(packet.clone());
2992 assert_eq!(socket.incoming_buffer.len(), 2);
2993 assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
2994 assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
2995
2996 packet.set_seq_nr(3);
2997 packet.set_timestamp(256.into());
2998
2999 socket.insert_into_buffer(packet.clone());
3000 assert_eq!(socket.incoming_buffer.len(), 3);
3001 assert_eq!(socket.incoming_buffer[2].seq_nr(), 3);
3002 assert_eq!(socket.incoming_buffer[2].timestamp(), 256.into());
3003
3004 packet.set_seq_nr(2);
3006 packet.set_timestamp(456.into());
3007
3008 socket.insert_into_buffer(packet);
3009 assert_eq!(socket.incoming_buffer.len(), 3);
3010 assert_eq!(socket.incoming_buffer[1].seq_nr(), 2);
3011 assert_eq!(socket.incoming_buffer[1].timestamp(), 128.into());
3012 }
3013
3014 #[tokio::test]
3015 async fn test_duplicate_packet_handling() {
3016 let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
3017
3018 let client = iotry!(UtpSocket::bind(client_addr));
3019 let mut server = iotry!(UtpSocket::bind(server_addr));
3020
3021 assert_eq!(server.state, SocketState::New);
3022 assert_eq!(client.state, SocketState::New);
3023
3024 assert_eq!(
3027 client.sender_connection_id,
3028 client.receiver_connection_id + 1
3029 );
3030
3031 let handle = task::spawn(async move {
3032 let mut client = iotry!(UtpSocket::connect(server_addr));
3033 assert_eq!(client.state, SocketState::Connected);
3034
3035 let mut packet = Packet::with_payload(&[1, 2, 3]);
3036 packet.set_wnd_size(BUF_SIZE as u32);
3037 packet.set_connection_id(client.sender_connection_id);
3038 packet.set_seq_nr(client.seq_nr);
3039 packet.set_ack_nr(client.ack_nr);
3040
3041 for _ in 0..2usize {
3043 packet.set_timestamp(now_microseconds());
3044 iotry!(client.socket.send_to(packet.as_ref(), server_addr));
3045 }
3046 client.seq_nr += 1;
3047
3048 for _ in 0..1usize {
3050 let mut buf = [0; BUF_SIZE];
3051 iotry!(client.socket.recv_from(&mut buf));
3052 }
3053
3054 iotry!(client.close());
3055 });
3056 let mut buf = [0u8; BUF_SIZE];
3057 let mut buf = ReadBuf::new(&mut buf);
3058 iotry!(server.recv(&mut buf));
3059 assert_eq!(
3062 server.receiver_connection_id,
3063 server.sender_connection_id + 1
3064 );
3065
3066 assert_eq!(server.state, SocketState::Connected);
3067
3068 let expected: Vec<u8> = vec![1, 2, 3];
3069 let mut received: Vec<u8> = vec![];
3070 loop {
3071 match server.recv_from(buf.initialized_mut()).await {
3072 Ok((0, _src)) => break,
3073 Ok((_, _src)) => received.extend(buf.filled().to_vec()),
3074 Err(e) => panic!("{:?}", e),
3075 }
3076 }
3077 assert_eq!(received.len(), expected.len());
3078 assert_eq!(received, expected);
3079
3080 handle.await.expect("task failure");
3081 }
3082
3083 #[tokio::test]
3084 async fn test_correct_packet_loss() {
3085 init_logger();
3086 let server_addr = next_test_ip4();
3087
3088 let mut server = iotry!(UtpSocket::bind(server_addr));
3089 const LEN: usize = 1024 * 10;
3090 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
3091 let to_send = data.clone();
3092
3093 let handle = task::spawn(
3094 async move {
3095 let mut client = iotry!(UtpSocket::connect(server_addr));
3096
3097 let chunks = to_send[..].chunks(BUF_SIZE);
3099 let dst = client.connected_to;
3100 for (index, chunk) in chunks.enumerate() {
3101 let mut packet = Packet::with_payload(chunk);
3102 packet.set_seq_nr(client.seq_nr);
3103 packet.set_ack_nr(client.ack_nr);
3104 packet.set_connection_id(client.sender_connection_id);
3105 packet.set_timestamp(now_microseconds());
3106
3107 if index % 2 == 0 {
3108 iotry!(client.socket.send_to(packet.as_ref(), dst));
3109 }
3110
3111 client.curr_window += packet.len() as u32;
3112 client.send_window.push(packet);
3113 client.seq_nr += 1;
3114 }
3115
3116 iotry!(client.close());
3117 }
3118 .instrument(debug_span!("sender")),
3119 );
3120
3121 let mut buf = [0; BUF_SIZE];
3122 let mut received: Vec<u8> = vec![];
3123 loop {
3124 match server.recv_from(&mut buf).await {
3125 Ok((0, _src)) => break,
3126 Ok((len, _src)) => received.extend(buf[..len].to_vec()),
3127 Err(e) => panic!("{}", e),
3128 }
3129 }
3130 assert_eq!(
3131 received.len(),
3132 data.len(),
3133 "wrong number of bytes received"
3134 );
3135 assert_eq!(received, data, "incorrect data received");
3136 handle.await.expect("task failure");
3137 }
3138
3139 #[tokio::test]
3140 async fn test_tolerance_to_small_buffers() {
3141 let server_addr = next_test_ip4();
3142 let mut server = iotry!(UtpSocket::bind(server_addr));
3143 const LEN: usize = 1024;
3144 let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
3145 let to_send = data.clone();
3146
3147 let handle = task::spawn(async move {
3148 let mut client = iotry!(UtpSocket::connect(server_addr));
3149 iotry!(client.send_to(&to_send[..]));
3150 iotry!(client.close());
3151 });
3152
3153 let mut read = Vec::new();
3154 while server.state != SocketState::Closed {
3155 let mut small_buffer = [0; 512];
3156 match server.recv_from(&mut small_buffer).await {
3157 Ok((0, _src)) => break,
3158 Ok((len, _src)) => read.extend(small_buffer[..len].to_vec()),
3159 Err(e) => panic!("{}", e),
3160 }
3161 }
3162
3163 assert_eq!(read.len(), data.len());
3164 assert_eq!(read, data);
3165 handle.await.expect("task failure");
3166 }
3167
    /// Intended to send enough data that the sender's 16-bit sequence number
    /// wraps past `u16::MAX`, and to check that the transfer still arrives intact.
    #[tokio::test]
    async fn test_sequence_number_rollover() {
        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());

        let mut server = iotry!(UtpSocket::bind(server_addr));

        const LEN: usize = BUF_SIZE * 4;
        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
        let to_send = data.clone();

        let mut client = iotry!(UtpSocket::bind(client_addr));

        // Move the sequence number close to u16::MAX so the transfer would wrap it.
        // NOTE(review): this `client` is never used after this assignment. The
        // spawned task creates its own socket with `UtpSocket::connect`, which
        // shadows this binding, so the advanced `seq_nr` does not apply to the
        // connection that actually sends the data. The rollover may therefore
        // not be exercised, and the `seq_nr < 50` check below may pass
        // trivially. Confirm, and consider connecting from a pre-configured
        // socket instead.
        client.seq_nr =
            ::std::u16::MAX - (to_send.len() / (BUF_SIZE * 2)) as u16;

        let handle = task::spawn(async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            iotry!(client.send_to(&to_send[..]));
            // Expect the sequence number to have wrapped around to a small value.
            assert!(client.seq_nr < 50);
            iotry!(client.close());
        });

        // Collect everything delivered until the peer closes.
        let mut buf = [0; BUF_SIZE];
        let mut received: Vec<u8> = vec![];
        loop {
            match server.recv_from(&mut buf).await {
                Ok((0, _src)) => break,
                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
                Err(e) => panic!("{}", e),
            }
        }
        assert_eq!(received.len(), data.len());
        assert_eq!(received, data);
        handle.await.expect("task failure");
    }
3207
3208 #[tokio::test]
3209 async fn test_drop_unused_socket() {
3210 let server_addr = next_test_ip4();
3211 let server = iotry!(UtpSocket::bind(server_addr));
3212
3213 drop(server);
3215 }
3216
3217 #[tokio::test]
3218 async fn test_invalid_packet_on_connect() {
3219 use tokio::net::UdpSocket;
3220 let server_addr = next_test_ip4();
3221 let server = iotry!(UdpSocket::bind(server_addr));
3222
3223 let handle = task::spawn(async move {
3224 match UtpSocket::connect(server_addr).await {
3225 Err(ref e) if e.kind() == ErrorKind::Other => (), Err(e) => panic!("Expected ErrorKind::Other, got {:?}", e),
3227 Ok(_) => panic!("Expected Err, got Ok"),
3228 }
3229 });
3230
3231 let mut buf = [0; BUF_SIZE];
3232 match server.recv_from(&mut buf).await {
3233 Ok((_len, client_addr)) => {
3234 iotry!(server.send_to(&[], client_addr));
3235 }
3236 _ => panic!(),
3237 }
3238
3239 handle.await.expect("task failure");
3240 }
3241
3242 #[tokio::test]
3243 async fn test_receive_unexpected_reply_type_on_connect() {
3244 use tokio::net::UdpSocket;
3245 let server_addr = next_test_ip4();
3246 let server = iotry!(UdpSocket::bind(server_addr));
3247
3248 let mut buf = [0; BUF_SIZE];
3249 let mut packet = Packet::new();
3250 packet.set_type(PacketType::Data);
3251
3252 let handle = task::spawn(async move {
3253 match server.recv_from(&mut buf).await {
3254 Ok((_len, client_addr)) => {
3255 iotry!(server.send_to(packet.as_ref(), client_addr));
3256 }
3257 _ => panic!(),
3258 }
3259 });
3260
3261 match UtpSocket::connect(server_addr).await {
3262 Err(ref e) if e.kind() == ErrorKind::ConnectionRefused => (), Err(e) => {
3264 panic!("Expected ErrorKind::ConnectionRefused, got {:?}", e)
3265 }
3266 Ok(_) => panic!("Expected Err, got Ok"),
3267 }
3268
3269 handle.await.expect("task failure");
3270 }
3271
3272 #[tokio::test]
3273 async fn test_receiving_syn_on_established_connection() {
3274 let server_addr = next_test_ip4();
3276 let mut server = iotry!(UtpSocket::bind(server_addr));
3277
3278 let handle = task::spawn(async move {
3279 let mut buf = [0; BUF_SIZE];
3280 loop {
3281 match server.recv_from(&mut buf).await {
3282 Ok((0, _src)) => break,
3283 Ok(_) => (),
3284 Err(e) => panic!("{:?}", e),
3285 }
3286 }
3287 });
3288
3289 let mut client = iotry!(UtpSocket::connect(server_addr));
3290 let mut packet = Packet::new();
3291 packet.set_wnd_size(BUF_SIZE as u32);
3292 packet.set_type(PacketType::Syn);
3293 packet.set_connection_id(client.sender_connection_id);
3294 packet.set_seq_nr(client.seq_nr);
3295 packet.set_ack_nr(client.ack_nr);
3296 iotry!(client.socket.send_to(packet.as_ref(), server_addr));
3297 let mut buf = [0; BUF_SIZE];
3298
3299 let (len, _) = client
3300 .socket
3301 .recv_from(&mut buf)
3302 .await
3303 .expect("recv failed");
3304 let reply = Packet::try_from(&buf[..len]).ok().unwrap();
3305 assert_eq!(reply.get_type(), PacketType::Reset);
3306 iotry!(client.close());
3307 handle.await.expect("task failure");
3308 }
3309
    /// Appears intended to check that a RESET from the peer on an established
    /// connection surfaces on the server as `ErrorKind::ConnectionReset`.
    ///
    /// Marked `#[ignore]`, so it is not run by default.
    #[tokio::test]
    #[ignore]
    async fn test_receiving_reset_on_established_connection() {
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));

        // Client: connect, send a raw RESET carrying the connection's current
        // ids and sequence numbers, then wait for one reply datagram.
        let handle = task::spawn(async move {
            let client = iotry!(UtpSocket::connect(server_addr));
            let mut packet = Packet::new();
            packet.set_wnd_size(BUF_SIZE as u32);
            packet.set_type(PacketType::Reset);
            packet.set_connection_id(client.sender_connection_id);
            packet.set_seq_nr(client.seq_nr);
            packet.set_ack_nr(client.ack_nr);
            iotry!(client.socket.send_to(packet.as_ref(), server_addr));

            let mut buf = [0; BUF_SIZE];

            // NOTE(review): the task blocks here until some datagram arrives.
            // Nothing in this test shows the server replying to a RESET, so
            // the `handle.await` below may hang — confirm.
            client
                .socket
                .recv_from(&mut buf)
                .await
                .expect("recv failed");
        });

        let mut buf = [0; BUF_SIZE];

        // Server: receive until the RESET surfaces as ConnectionReset. A clean
        // end of stream (0 bytes) counts as a failure.
        loop {
            match server.recv_from(&mut buf).await {
                Ok((0, _src)) => break,
                Ok(_) => (),
                Err(ref e) if e.kind() == ErrorKind::ConnectionReset => {
                    handle.await.expect("task failure");
                    return;
                }
                Err(e) => panic!("{:?}", e),
            }
        }
        panic!("Should have received Reset");
    }
3351
    /// The server sends a FIN right after accepting, before reading any of the
    /// client's data. All of the data must still be received before the
    /// stream ends.
    // Not compiled on Windows. NOTE(review): the reason is not stated here;
    // possibly because a datagram sent to a port nobody listens on can make
    // later receives on the same socket fail there — confirm.
    #[cfg(not(windows))]
    #[tokio::test]
    async fn test_premature_fin() {
        let (server_addr, client_addr) = (next_test_ip4(), next_test_ip4());
        let mut server = iotry!(UtpSocket::bind(server_addr));

        const LEN: usize = BUF_SIZE * 4;
        let data = (0..LEN).map(|idx| idx as u8).collect::<Vec<u8>>();
        let to_send = data.clone();

        // NOTE(review): the task handle is discarded, so a panic inside the
        // client task would not fail this test.
        task::spawn(async move {
            let mut client = iotry!(UtpSocket::connect(server_addr));
            iotry!(client.send_to(&to_send[..]));
            iotry!(client.close());
        });

        let mut buf = [0; BUF_SIZE];
        let mut buf = ReadBuf::new(&mut buf);

        // Accept the connection.
        iotry!(server.recv(&mut buf));

        // Build a FIN by hand, without acknowledging any data yet.
        let mut packet = Packet::new();
        packet.set_connection_id(server.sender_connection_id);
        packet.set_seq_nr(server.seq_nr);
        packet.set_ack_nr(server.ack_nr);
        packet.set_timestamp(now_microseconds());
        packet.set_type(PacketType::Fin);
        // NOTE(review): `client_addr` is never handed to the client, because
        // `UtpSocket::connect` above only receives `server_addr`. Unless
        // connect happens to bind there, this FIN is not delivered to the
        // client at all. `server.connected_to` is likely the intended
        // destination — confirm.
        iotry!(server.socket.send_to(packet.as_ref(), client_addr));

        // Collect everything delivered until the peer closes.
        let mut received: Vec<u8> = vec![];
        loop {
            let mut buf = [0; BUF_SIZE];

            match server.recv_from(&mut buf).await {
                Ok((0, _src)) => break,
                Ok((len, _src)) => received.extend(buf[..len].to_vec()),
                Err(e) => panic!("{}", e),
            }
        }
        assert_eq!(received.len(), data.len());
        assert_eq!(received, data);
    }
3397
3398 #[tokio::test]
3399 async fn test_base_delay_calculation() {
3400 let minute_in_microseconds = 60 * 10i64.pow(6);
3401 let samples = vec![
3402 (0, 10),
3403 (1, 8),
3404 (2, 12),
3405 (3, 7),
3406 (minute_in_microseconds + 1, 11),
3407 (minute_in_microseconds + 2, 19),
3408 (minute_in_microseconds + 3, 9),
3409 ];
3410 let addr = next_test_ip4();
3411 let mut socket = UtpSocket::bind(addr).await.unwrap();
3412
3413 for (timestamp, delay) in samples {
3414 socket.update_base_delay(
3415 delay.into(),
3416 ((timestamp + delay) as u32).into(),
3417 );
3418 }
3419
3420 let expected = vec![7i64, 9i64]
3421 .into_iter()
3422 .map(Into::into)
3423 .collect::<Vec<_>>();
3424 let actual = socket.base_delays.iter().cloned().collect::<Vec<_>>();
3425 assert_eq!(expected, actual);
3426 assert_eq!(
3427 socket.min_base_delay(),
3428 expected.iter().min().cloned().unwrap_or_default()
3429 );
3430 }
3431
3432 #[tokio::test]
3433 async fn test_local_addr() {
3434 let addr = next_test_ip4();
3435 let addr = addr.to_socket_addrs().unwrap().next().unwrap();
3436 let socket = UtpSocket::bind(addr).await.unwrap();
3437
3438 assert!(socket.local_addr().is_ok());
3439 assert_eq!(socket.local_addr().unwrap(), addr);
3440 }
3441
3442 #[tokio::test]
3443 async fn test_peer_addr() {
3444 use std::sync::mpsc::channel;
3445 let addr = next_test_ip4();
3446 let server_addr = addr.to_socket_addrs().unwrap().next().unwrap();
3447 let mut server = UtpSocket::bind(server_addr).await.unwrap();
3448 let (tx, rx) = channel();
3449
3450 assert!(server.peer_addr().is_err());
3453
3454 task::spawn(async move {
3455 let mut client = iotry!(UtpSocket::connect(server_addr));
3456 let mut buf = [0; 1024];
3457 tx.send(client.local_addr())
3458 .expect("failed to send on channel");
3459 iotry!(client.recv_from(&mut buf));
3460
3461 let mut buf = [0; 1024];
3463 let mut buf = ReadBuf::new(&mut buf);
3464 iotry!(server.recv(&mut buf));
3465
3466 assert!(server.peer_addr().is_ok());
3468 let client_addr = rx.recv().unwrap().unwrap();
3471 assert_eq!(server.peer_addr().unwrap().port(), client_addr.port());
3472
3473 iotry!(server.close());
3475
3476 assert!(server.peer_addr().is_err());
3478 });
3479 }
3480
    /// Appears intended to check that a server whose peer has gone silent
    /// eventually reports `ErrorKind::TimedOut`, after the peer has seen the
    /// DATA packet resent `max_retransmission_retries` times.
    ///
    /// Marked `#[ignore]`, so it is not run by default.
    #[ignore]
    #[tokio::test]
    async fn test_connection_loss_data() {
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));
        // Make the server's congestion timeout as small as possible.
        server.congestion_timeout = 1;
        // The loop below expects one DATA datagram per allowed retry.
        let attempts = server.max_retransmission_retries;

        // NOTE(review): `connect` is awaited on the same task that owns
        // `server`, which does not read anything until later in this function.
        // This can presumably only complete if something else answers the
        // SYN — confirm.
        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&[0]));
        // Flip the client's state directly; from here on the test reads the
        // client's raw UDP socket instead of using the uTP API.
        client.state = SocketState::Closed;

        let mut buf = [0; BUF_SIZE];
        iotry!(client.socket.recv_from(&mut buf));

        // NOTE(review): these reads happen before the server sends any data
        // (`server.send_to` comes later in this same task) — confirm how they
        // are expected to be satisfied.
        for _ in 0..attempts {
            match client.socket.recv_from(&mut buf).await {
                Ok((len, _src)) => assert_eq!(
                    Packet::try_from(&buf[..len]).unwrap().get_type(),
                    PacketType::Data
                ),
                Err(e) => panic!("{}", e),
            }
        }

        let mut buf = [0; BUF_SIZE];
        iotry!(server.recv_from(&mut buf));

        iotry!(server.send_to(&[0]));

        // With the client silent, the server's next receive must time out.
        let mut buf = [0; BUF_SIZE];
        let mut buf = ReadBuf::new(&mut buf);
        match server.recv(&mut buf).await {
            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
            x => panic!("Expected Err(TimedOut), got {:?}", x),
        }
    }
3524
    /// Appears intended to check that `close` on a server whose peer has gone
    /// silent fails with `ErrorKind::TimedOut`, after the peer has seen the
    /// FIN resent `max_retransmission_retries` times.
    ///
    /// Marked `#[ignore]`, so it is not run by default.
    #[ignore]
    #[tokio::test]
    async fn test_connection_loss_fin() {
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));
        // Make the server's congestion timeout as small as possible.
        server.congestion_timeout = 1;
        // The loop below expects one FIN datagram per allowed retry.
        let attempts = server.max_retransmission_retries;

        // NOTE(review): as in `test_connection_loss_data`, `connect` is awaited
        // before this task's server reads anything — confirm how it completes.
        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&[0]));
        // Flip the client's state directly; from here on the test reads the
        // client's raw UDP socket instead of using the uTP API.
        client.state = SocketState::Closed;
        let mut buf = [0; BUF_SIZE];
        iotry!(client.socket.recv_from(&mut buf));
        // NOTE(review): `server.close()` (the FIN sender) is only called after
        // this loop in the same task — confirm how these reads are satisfied.
        for _ in 0..attempts {
            match client.socket.recv_from(&mut buf).await {
                Ok((len, _src)) => assert_eq!(
                    Packet::try_from(&buf[..len]).unwrap().get_type(),
                    PacketType::Fin
                ),
                Err(e) => panic!("{}", e),
            }
        }

        let mut buf = [0; BUF_SIZE];
        iotry!(server.recv_from(&mut buf));

        // Closing must give up with TimedOut because the FIN is never acknowledged.
        match server.close().await {
            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
            x => panic!("Expected Err(TimedOut), got {:?}", x),
        }
    }
3561
    /// Appears intended to check that a server waiting on a silent peer keeps
    /// sending STATE packets and finally reports `ErrorKind::TimedOut` from
    /// `recv_from`.
    ///
    /// Marked `#[ignore]`, so it is not run by default.
    #[ignore]
    #[tokio::test]
    async fn test_connection_loss_waiting() {
        let server_addr = next_test_ip4();
        let mut server = iotry!(UtpSocket::bind(server_addr));
        // Make the server's congestion timeout as small as possible.
        server.congestion_timeout = 1;
        let attempts = server.max_retransmission_retries;

        // NOTE(review): as in the other connection-loss tests, `connect` is
        // awaited before this task's server reads anything — confirm.
        let mut client = iotry!(UtpSocket::connect(server_addr));
        iotry!(client.send_to(&[0]));
        // Flip the client's state directly; from here on the test reads the
        // client's raw UDP socket instead of using the uTP API.
        client.state = SocketState::Closed;
        let seq_nr = client.seq_nr;
        let mut buf = [0; BUF_SIZE];
        // Expect 3 * attempts STATE packets, each acknowledging `seq_nr - 1`
        // (presumably the last sequence number the client used).
        // NOTE(review): `seq_nr - 1` would underflow (and panic in debug
        // builds) if `seq_nr` were 0; `wrapping_sub(1)` would be safer.
        for _ in 0..(3 * attempts) {
            match client.socket.recv_from(&mut buf).await {
                Ok((len, _src)) => {
                    let packet = Packet::try_from(&buf[..len]).unwrap();
                    assert_eq!(packet.get_type(), PacketType::State);
                    assert_eq!(packet.ack_nr(), seq_nr - 1);
                }
                Err(e) => panic!("{}", e),
            }
        }

        let mut buf = [0; BUF_SIZE];
        iotry!(server.recv_from(&mut buf));

        // The peer never answers, so the next receive must time out.
        let mut buf = [0; BUF_SIZE];
        match server.recv_from(&mut buf).await {
            Err(ref e) if e.kind() == ErrorKind::TimedOut => (),
            x => panic!("Expected Err(TimedOut), got {:?}", x),
        }
    }
3600}