1use std::io::Result;
2
3use std::net::{SocketAddr};
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use tokio::net::UdpSocket;
9use tokio::time::{sleep, Sleep, Instant};
10use tokio::io::{ReadBuf, AsyncRead, AsyncWrite};
11
12use crate::{get_timeout, new_udp_socket};
13
14pub struct UdpStreamLocal {
20 socket: UdpSocket,
21 timeout: Pin<Box<Sleep>>,
22}
23
24impl UdpStreamLocal {
25 #[inline]
27 pub(crate) async fn new(
28 local_addr: SocketAddr,
29 remote_addr: SocketAddr,
30 ) -> std::io::Result<Self> {
31 let socket = new_udp_socket(local_addr)?;
32 socket.connect(remote_addr).await?;
33 Ok(Self {
34 socket,
35 timeout: Box::pin(sleep(get_timeout())),
36 })
37 }
38
39 #[inline]
41 pub fn peer_addr(&self) -> SocketAddr { self.socket.peer_addr().unwrap() }
42
43 #[inline]
45 pub fn local_addr(&self) -> SocketAddr { self.socket.local_addr().unwrap() }
46
47 #[inline]
49 pub const fn inner_socket(&self) -> &UdpSocket { &self.socket }
50}
51
52impl AsyncRead for UdpStreamLocal {
53 fn poll_read(
54 self: Pin<&mut Self>,
55 cx: &mut Context<'_>,
56 buf: &mut ReadBuf<'_>,
57 ) -> Poll<Result<()>> {
58 let this = self.get_mut();
59
60 if let Poll::Ready(result) = this.socket.poll_recv(cx, buf) {
61 this.timeout.as_mut().reset(Instant::now() + get_timeout());
63
64 return match result {
65 Ok(_) => Poll::Ready(Ok(())),
66 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending,
67 Err(e) => Poll::Ready(Err(e)),
68 };
69 }
70
71 if this.timeout.as_mut().poll(cx).is_ready() {
73 buf.clear();
74 return Poll::Ready(Ok(()));
75 }
76
77 Poll::Pending
78 }
79}
80
81impl AsyncWrite for UdpStreamLocal {
82 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
83 let this = self.get_mut();
84 this.socket.poll_send(cx, buf)
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
88 Poll::Ready(Ok(()))
89 }
90
91 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
92 Poll::Ready(Ok(()))
93 }
94}