1use std::{
2 io::{self, Read, Write},
3 pin::Pin,
4 task::{Context, Poll, ready},
5};
6
7use rama_core::stream::Stream;
8use rama_core::{
9 error::OpaqueError,
10 extensions::{Extensions, ExtensionsMut, ExtensionsRef},
11 futures::{self, SinkExt, StreamExt},
12 telemetry::tracing::{debug, trace},
13};
14use rama_http::io::upgrade;
15
16use crate::{
17 Message, ProtocolError,
18 protocol::{CloseFrame, Role, WebSocket, WebSocketConfig},
19 runtime::{
20 compat::{self, AllowStd, ContextWaker},
21 handshake::without_handshake,
22 },
23};
24
/// An asynchronous WebSocket connection layered over a byte stream `S`.
///
/// Wraps a synchronous [`WebSocket`] that performs its I/O through the
/// [`AllowStd`] compatibility adapter. It exposes that socket as a
/// [`futures::Stream`] of incoming [`Message`]s and a [`futures::Sink`] for
/// outgoing ones.
///
/// When no stream type is given, `S` defaults to [`upgrade::Upgraded`].
#[derive(Debug)]
pub struct AsyncWebSocket<S = upgrade::Upgraded> {
    /// The protocol state machine; it reads and writes through `AllowStd<S>`.
    inner: WebSocket<AllowStd<S>>,
    /// Set when `poll_close` gets `WouldBlock`. Later `poll_close` calls then
    /// only flush instead of issuing `close` again.
    closing: bool,
    /// Set once a read returned an error (including a closed connection).
    /// From then on `poll_next` always yields `Poll::Ready(None)`.
    ended: bool,
    /// `false` after `start_send` got `WouldBlock`. While it is `false`,
    /// `poll_ready` flushes before reporting readiness again.
    ready: bool,
}
43
44impl<S> AsyncWebSocket<S> {
45 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
48 where
49 S: Stream + Unpin + ExtensionsMut,
50 {
51 without_handshake(stream, move |allow_std| {
52 WebSocket::from_raw_socket(allow_std, role, config)
53 })
54 .await
55 }
56
57 pub async fn from_partially_read(
60 stream: S,
61 part: Vec<u8>,
62 role: Role,
63 config: Option<WebSocketConfig>,
64 ) -> Self
65 where
66 S: Stream + Unpin + ExtensionsMut,
67 {
68 without_handshake(stream, move |allow_std| {
69 WebSocket::from_partially_read(allow_std, part, role, config)
70 })
71 .await
72 }
73
74 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
75 Self {
76 inner: ws,
77 closing: false,
78 ended: false,
79 ready: true,
80 }
81 }
82
83 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
84 where
85 S: Unpin,
86 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
87 AllowStd<S>: Read + Write,
88 {
89 trace!("AsyncWebSocket.with_context");
90 if let Some((kind, ctx)) = ctx {
91 self.inner.get_mut().set_waker(kind, ctx.waker());
92 }
93 f(&mut self.inner)
94 }
95
96 pub fn into_inner(self) -> S {
98 self.inner.into_inner().into_inner()
99 }
100
101 pub fn get_ref(&self) -> &S
103 where
104 S: Stream + Unpin,
105 {
106 self.inner.get_ref().get_ref()
107 }
108
109 pub fn get_mut(&mut self) -> &mut S
111 where
112 S: Stream + Unpin,
113 {
114 self.inner.get_mut().get_mut()
115 }
116
117 pub fn get_config(&self) -> &WebSocketConfig {
119 self.inner.get_config()
120 }
121
122 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), ProtocolError>
124 where
125 S: Stream + Unpin,
126 {
127 self.send(Message::Close(msg)).await
128 }
129}
130
131impl<S: ExtensionsRef> ExtensionsRef for AsyncWebSocket<S> {
132 fn extensions(&self) -> &Extensions {
133 self.inner.extensions()
134 }
135}
136
137impl<S: ExtensionsMut> ExtensionsMut for AsyncWebSocket<S> {
138 fn extensions_mut(&mut self) -> &mut Extensions {
139 self.inner.extensions_mut()
140 }
141}
142
143impl<S: Stream + Unpin> AsyncWebSocket<S> {
144 #[inline]
145 pub fn send_message(
147 &mut self,
148 msg: Message,
149 ) -> impl Future<Output = Result<(), ProtocolError>> + Send + '_ {
150 self.send(msg)
151 }
152
153 pub async fn recv_message(&mut self) -> Result<Message, ProtocolError> {
154 self.next().await.ok_or_else(|| {
155 ProtocolError::Io(io::Error::new(
156 io::ErrorKind::ConnectionAborted,
157 OpaqueError::from_display(
158 "Connection closed: no messages to be received any longer",
159 ),
160 ))
161 })?
162 }
163}
164
165impl<T> futures::Stream for AsyncWebSocket<T>
166where
167 T: Stream + Unpin,
168{
169 type Item = Result<Message, ProtocolError>;
170
171 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172 trace!("Stream.poll_next");
173
174 if self.ended {
178 return Poll::Ready(None);
179 }
180
181 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
182 trace!("Stream.with_context poll_next -> read()");
183 compat::cvt(s.read())
184 })) {
185 Ok(v) => Poll::Ready(Some(Ok(v))),
186 Err(e) => {
187 self.ended = true;
188 if e.is_connection_error() {
189 Poll::Ready(None)
190 } else {
191 Poll::Ready(Some(Err(e)))
192 }
193 }
194 }
195 }
196}
197
198impl<T> futures::stream::FusedStream for AsyncWebSocket<T>
199where
200 T: Stream + Unpin,
201{
202 fn is_terminated(&self) -> bool {
203 self.ended
204 }
205}
206
207impl<T> futures::Sink<Message> for AsyncWebSocket<T>
208where
209 T: Stream + Unpin,
210{
211 type Error = ProtocolError;
212
213 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214 if self.ready {
215 Poll::Ready(Ok(()))
216 } else {
217 (*self)
219 .with_context(Some((ContextWaker::Write, cx)), |s| compat::cvt(s.flush()))
220 .map(|r| {
221 self.ready = true;
222 r
223 })
224 }
225 }
226
227 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
228 match (*self).with_context(None, |s| s.write(item)) {
229 Ok(()) => {
230 self.ready = true;
231 Ok(())
232 }
233 Err(ProtocolError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
234 self.ready = false;
237 Ok(())
238 }
239 Err(e) => {
240 self.ready = true;
241 debug!("websocket start_send error: {e}");
242 Err(e)
243 }
244 }
245 }
246
247 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248 (*self)
249 .with_context(Some((ContextWaker::Write, cx)), |s| compat::cvt(s.flush()))
250 .map(|r| {
251 self.ready = true;
252 match r {
253 Err(err) if err.is_connection_error() => {
254 Ok(())
256 }
257 other => other,
258 }
259 })
260 }
261
262 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263 self.ready = true;
264 let res = if self.closing {
265 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
267 } else {
268 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
269 };
270
271 match res {
272 Ok(()) => Poll::Ready(Ok(())),
273 Err(ProtocolError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
274 trace!("WouldBlock");
275 self.closing = true;
276 Poll::Pending
277 }
278 Err(err) => {
279 if err.is_connection_error() {
280 Poll::Ready(Ok(()))
281 } else {
282 debug!("websocket close error: {}", err);
283 Poll::Ready(Err(err))
284 }
285 }
286 }
287 }
288}
289
#[cfg(test)]
mod tests {
    use crate::runtime::{AsyncWebSocket, compat::AllowStd};
    use std::io::{Read, Write};

    type Tcp = tokio::net::TcpStream;

    // Compile-time trait checks: each call only type-checks if `T` has the trait.
    fn assert_read<T: Read>() {}
    fn assert_write<T: Write>() {}
    fn assert_unpin<T: Unpin>() {}

    #[test]
    fn web_socket_stream_has_traits() {
        // The sync adapter must offer std I/O for the protocol layer.
        assert_read::<AllowStd<Tcp>>();
        assert_write::<AllowStd<Tcp>>();
        // The Stream/Sink impls reach `&mut Self` through `Pin`, which needs `Unpin`.
        assert_unpin::<AsyncWebSocket<Tcp>>();
    }
}