1use async_std::{io, net::ToSocketAddrs, sync::RwLock};
2use futures::{future::BoxFuture, ready, AsyncRead, AsyncWrite, FutureExt};
3use std::{fmt::Debug, io::Result, net::SocketAddr, sync::Arc, task::Poll};
4
5use crate::socket::UtpSocket;
6
7#[derive(Clone, Debug)]
30pub struct UtpStream {
31 socket: Arc<RwLock<UtpSocket>>,
32 futures: Arc<UtpStreamFutures>,
33}
34
// SAFETY: NOTE(review): this overrides the compiler's auto-trait check rather
// than proving anything. `UtpStreamFutures` stores `BoxFuture`s, which are
// `Send` but not `Sync`, so `Arc<UtpStreamFutures>` is presumably not `Send` on
// its own and this impl is what makes `UtpStream` movable across threads.
// Soundness relies on every access to a future slot going through its
// `RwLock`. TODO confirm that `UtpSocket` is `Send + Sync` and that shared
// (`try_read`) access to a slot never does more than inspect `is_none()`.
unsafe impl Send for UtpStream {}
/// A slot holding a boxed, `'static` I/O future between `poll_*` calls.
/// `None` means no operation of that kind is currently in flight.
type OptionIoFuture<T> = RwLock<Option<BoxFuture<'static, io::Result<T>>>>;
37
38#[derive(Default)]
39struct UtpStreamFutures {
40 read: OptionIoFuture<(Vec<u8>, usize)>,
41 write: OptionIoFuture<usize>,
42 flush: OptionIoFuture<()>,
43 close: OptionIoFuture<()>,
44}
45
46impl std::fmt::Debug for UtpStreamFutures {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "UtpStreamFutures state")
49 }
50}
51
52impl UtpStream {
53 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UtpStream> {
60 let socket = UtpSocket::bind(addr).await?;
61 Ok(UtpStream {
62 socket: Arc::new(RwLock::new(socket)),
63 futures: UtpStreamFutures::default().into(),
64 })
65 }
66
67 pub async fn connect<A: ToSocketAddrs>(dst: A) -> Result<UtpStream> {
74 let socket = UtpSocket::connect(dst).await?;
76 Ok(UtpStream {
77 socket: Arc::new(RwLock::new(socket)),
78 futures: UtpStreamFutures::default().into(),
79 })
80 }
81
82 pub async fn close(&mut self) -> Result<()> {
87 self.socket.write().await.close().await
88 }
89
90 pub fn local_addr(&self) -> Result<SocketAddr> {
92 self.socket.try_read().unwrap().local_addr()
93 }
94
95 pub fn peer_addr(&self) -> Result<SocketAddr> {
97 self.socket.try_read().unwrap().peer_addr()
98 }
99
100 pub async fn set_max_retransmission_retries(&mut self, n: u32) {
102 self.socket.write().await.max_retransmission_retries = n;
103 }
104}
105
106impl AsyncRead for UtpStream {
107 fn poll_read(
108 self: std::pin::Pin<&mut Self>,
109 cx: &mut std::task::Context<'_>,
110 buf: &mut [u8],
111 ) -> std::task::Poll<Result<usize>> {
112 if self.futures.read.try_read().unwrap().is_none() {
113 let socket = self.socket.clone();
114 let mut vec = Vec::from(&buf[..]);
115 *self.futures.read.try_write().unwrap() = async move {
116 let (nread, _) = socket.write().await.recv_from(&mut vec).await?;
117 Ok((vec, nread))
118 }
119 .boxed()
120 .into();
121 }
122
123 let (bytes, nread) = {
124 let mut fut = self.futures.read.try_write().unwrap();
125 ready!(fut.as_mut().unwrap().poll_unpin(cx))?
126 };
127 buf.copy_from_slice(&bytes);
128 *self.futures.read.try_write().unwrap() = None;
129 Poll::Ready(Ok(nread))
130 }
131}
132
133impl AsyncWrite for UtpStream {
134 fn poll_write(
135 self: std::pin::Pin<&mut Self>,
136 cx: &mut std::task::Context<'_>,
137 buf: &[u8],
138 ) -> std::task::Poll<Result<usize>> {
139 if self.futures.write.try_read().unwrap().is_none() {
140 let socket = self.socket.clone();
141 let vec = Vec::from(buf);
142 *self.futures.write.try_write().unwrap() = async move {
143 let nread = socket.write().await.send_to(&vec).await?;
144 Ok(nread)
145 }
146 .boxed()
147 .into();
148 }
149
150 let nread = {
151 let mut fut = self.futures.write.try_write().unwrap();
152 ready!(fut.as_mut().unwrap().poll_unpin(cx))?
153 };
154 *self.futures.write.try_write().unwrap() = None;
155 Poll::Ready(Ok(nread))
156 }
157
158 fn poll_flush(
159 self: std::pin::Pin<&mut Self>,
160 cx: &mut std::task::Context<'_>,
161 ) -> std::task::Poll<Result<()>> {
162 if self.futures.flush.try_read().unwrap().is_none() {
163 let socket = self.socket.clone();
164 *self.futures.flush.try_write().unwrap() =
165 async move { socket.write().await.flush().await }
166 .boxed()
167 .into();
168 }
169
170 let result = {
171 let mut fut = self.futures.flush.try_write().unwrap();
172 ready!(fut.as_mut().unwrap().poll_unpin(cx))
173 };
174 *self.futures.flush.try_write().unwrap() = None;
175 Poll::Ready(result)
176 }
177
178 fn poll_close(
179 self: std::pin::Pin<&mut Self>,
180 cx: &mut std::task::Context<'_>,
181 ) -> std::task::Poll<Result<()>> {
182 if self.futures.close.try_read().is_none() {
183 let socket = self.socket.clone();
184 *self.futures.close.try_write().unwrap() =
185 async move { socket.write().await.flush().await }
186 .boxed()
187 .into();
188 }
189
190 let result = {
191 let mut fut = self.futures.close.try_write().unwrap();
192 ready!(fut.as_mut().unwrap().poll_unpin(cx))
193 };
194 *self.futures.close.try_write().unwrap() = None;
195 Poll::Ready(result)
196 }
197}
198
199impl From<UtpSocket> for UtpStream {
200 fn from(socket: UtpSocket) -> Self {
201 UtpStream {
202 socket: Arc::new(RwLock::new(socket)),
203 futures: UtpStreamFutures::default().into(),
204 }
205 }
206}