async_std_utp/
stream.rs

1use async_std::{io, net::ToSocketAddrs, sync::RwLock};
2use futures::{future::BoxFuture, ready, AsyncRead, AsyncWrite, FutureExt};
3use std::{fmt::Debug, io::Result, net::SocketAddr, sync::Arc, task::Poll};
4
5use crate::socket::UtpSocket;
6
7/// A structure that represents a uTP (Micro Transport Protocol) stream between a local socket and a
8/// remote socket.
9///
10/// The connection will be closed when the value is dropped (either explicitly or when it goes out
11/// of scope).
12///
13/// The default maximum retransmission retries is 5, which translates to about 16 seconds. It can be
14/// changed by calling `set_max_retransmission_retries`. Notice that the initial congestion timeout
15/// is 500 ms and doubles with each timeout.
16///
17/// # Examples
18///
19/// ```no_run
20/// # fn main() { async_std::task::block_on(async {
21/// use async_std_utp::UtpStream;
22/// use async_std::prelude::*;
23///
24/// let mut stream = UtpStream::bind("127.0.0.1:1234").await.expect("Error binding stream");
25/// let _ = stream.write(&[1]).await;
26/// let _ = stream.read(&mut [0; 1000]).await;
27/// # }); }
28/// ```
29#[derive(Clone, Debug)]
30pub struct UtpStream {
31    socket: Arc<RwLock<UtpSocket>>,
32    futures: Arc<UtpStreamFutures>,
33}
34
35unsafe impl Send for UtpStream {}
36type OptionIoFuture<T> = RwLock<Option<BoxFuture<'static, io::Result<T>>>>;
37
38#[derive(Default)]
39struct UtpStreamFutures {
40    read: OptionIoFuture<(Vec<u8>, usize)>,
41    write: OptionIoFuture<usize>,
42    flush: OptionIoFuture<()>,
43    close: OptionIoFuture<()>,
44}
45
46impl std::fmt::Debug for UtpStreamFutures {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        write!(f, "UtpStreamFutures state")
49    }
50}
51
52impl UtpStream {
53    /// Creates a uTP stream listening on the given address.
54    ///
55    /// The address type can be any implementer of the `ToSocketAddr` trait. See its documentation
56    /// for concrete examples.
57    ///
58    /// If more than one valid address is specified, only the first will be used.
59    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpStream> {
60        let socket = UtpSocket::bind(addr).await?;
61        Ok(UtpStream {
62            socket: Arc::new(RwLock::new(socket)),
63            futures: UtpStreamFutures::default().into(),
64        })
65    }
66
67    /// Opens a uTP connection to a remote host by hostname or IP address.
68    ///
69    /// The address type can be any implementer of the `ToSocketAddr` trait. See its documentation
70    /// for concrete examples.
71    ///
72    /// If more than one valid address is specified, only the first will be used.
73    pub async fn connect<A: ToSocketAddrs>(dst: A) -> Result<UtpStream> {
74        // Port 0 means the operating system gets to choose it
75        let socket = UtpSocket::connect(dst).await?;
76        Ok(UtpStream {
77            socket: Arc::new(RwLock::new(socket)),
78            futures: UtpStreamFutures::default().into(),
79        })
80    }
81
82    /// Gracefully closes connection to peer.
83    ///
84    /// This method allows both peers to receive all packets still in
85    /// flight.
86    pub async fn close(&mut self) -> Result<()> {
87        self.socket.write().await.close().await
88    }
89
90    /// Returns the socket address of the local half of this uTP connection.
91    pub fn local_addr(&self) -> Result<SocketAddr> {
92        self.socket.try_read().unwrap().local_addr()
93    }
94
95    /// Returns the socket address of the remote half of this uTP connection.
96    pub fn peer_addr(&self) -> Result<SocketAddr> {
97        self.socket.try_read().unwrap().peer_addr()
98    }
99
100    /// Changes the maximum number of retransmission retries on the underlying socket.
101    pub async fn set_max_retransmission_retries(&mut self, n: u32) {
102        self.socket.write().await.max_retransmission_retries = n;
103    }
104}
105
106impl AsyncRead for UtpStream {
107    fn poll_read(
108        self: std::pin::Pin<&mut Self>,
109        cx: &mut std::task::Context<'_>,
110        buf: &mut [u8],
111    ) -> std::task::Poll<Result<usize>> {
112        if self.futures.read.try_read().unwrap().is_none() {
113            let socket = self.socket.clone();
114            let mut vec = Vec::from(&buf[..]);
115            *self.futures.read.try_write().unwrap() = async move {
116                let (nread, _) = socket.write().await.recv_from(&mut vec).await?;
117                Ok((vec, nread))
118            }
119            .boxed()
120            .into();
121        }
122
123        let (bytes, nread) = {
124            let mut fut = self.futures.read.try_write().unwrap();
125            ready!(fut.as_mut().unwrap().poll_unpin(cx))?
126        };
127        buf.copy_from_slice(&bytes);
128        *self.futures.read.try_write().unwrap() = None;
129        Poll::Ready(Ok(nread))
130    }
131}
132
133impl AsyncWrite for UtpStream {
134    fn poll_write(
135        self: std::pin::Pin<&mut Self>,
136        cx: &mut std::task::Context<'_>,
137        buf: &[u8],
138    ) -> std::task::Poll<Result<usize>> {
139        if self.futures.write.try_read().unwrap().is_none() {
140            let socket = self.socket.clone();
141            let vec = Vec::from(buf);
142            *self.futures.write.try_write().unwrap() = async move {
143                let nread = socket.write().await.send_to(&vec).await?;
144                Ok(nread)
145            }
146            .boxed()
147            .into();
148        }
149
150        let nread = {
151            let mut fut = self.futures.write.try_write().unwrap();
152            ready!(fut.as_mut().unwrap().poll_unpin(cx))?
153        };
154        *self.futures.write.try_write().unwrap() = None;
155        Poll::Ready(Ok(nread))
156    }
157
158    fn poll_flush(
159        self: std::pin::Pin<&mut Self>,
160        cx: &mut std::task::Context<'_>,
161    ) -> std::task::Poll<Result<()>> {
162        if self.futures.flush.try_read().unwrap().is_none() {
163            let socket = self.socket.clone();
164            *self.futures.flush.try_write().unwrap() =
165                async move { socket.write().await.flush().await }
166                    .boxed()
167                    .into();
168        }
169
170        let result = {
171            let mut fut = self.futures.flush.try_write().unwrap();
172            ready!(fut.as_mut().unwrap().poll_unpin(cx))
173        };
174        *self.futures.flush.try_write().unwrap() = None;
175        Poll::Ready(result)
176    }
177
178    fn poll_close(
179        self: std::pin::Pin<&mut Self>,
180        cx: &mut std::task::Context<'_>,
181    ) -> std::task::Poll<Result<()>> {
182        if self.futures.close.try_read().is_none() {
183            let socket = self.socket.clone();
184            *self.futures.close.try_write().unwrap() =
185                async move { socket.write().await.flush().await }
186                    .boxed()
187                    .into();
188        }
189
190        let result = {
191            let mut fut = self.futures.close.try_write().unwrap();
192            ready!(fut.as_mut().unwrap().poll_unpin(cx))
193        };
194        *self.futures.close.try_write().unwrap() = None;
195        Poll::Ready(result)
196    }
197}
198
199impl From<UtpSocket> for UtpStream {
200    fn from(socket: UtpSocket) -> Self {
201        UtpStream {
202            socket: Arc::new(RwLock::new(socket)),
203            futures: UtpStreamFutures::default().into(),
204        }
205    }
206}