async_web_server/
ws.rs

1use crate::{HttpRequest, IsTls, TcpOrTlsIncoming, TcpOrTlsStream};
2use async_http_codec::internal::buffer_write::BufferWrite;
3use async_http_codec::{RequestHead, ResponseHead};
4use async_ws::connection::WsConfig;
5use async_ws::http::{is_upgrade_request, upgrade_response};
6use futures::prelude::*;
7use futures::stream::FusedStream;
8use http::{HeaderMap, Method, Request, Uri, Version};
9use std::io;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13pub type WsConnection<IO = TcpOrTlsStream> = async_ws::connection::WsConnection<IO>;
14pub type WsMessageKind = async_ws::message::WsMessageKind;
15pub type WsSend<IO = TcpOrTlsStream> = async_ws::connection::WsSend<IO>;
16pub type WsConnectionError = async_ws::connection::WsConnectionError;
17pub type WsMessageReader<IO = TcpOrTlsStream> = async_ws::connection::WsMessageReader<IO>;
18pub type WsMessageWriter<IO = TcpOrTlsStream> = async_ws::connection::WsMessageWriter<IO>;
19
20pub enum HttpOrWs<IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream> {
21    Http(HttpRequest<IO>),
22    Ws(WsUpgradeRequest<IO>),
23}
24
25impl<IO: AsyncRead + AsyncWrite + Unpin + IsTls> IsTls for HttpOrWs<IO> {
26    fn is_tls(&self) -> bool {
27        match self {
28            HttpOrWs::Http(http) => http.is_tls(),
29            HttpOrWs::Ws(ws) => ws.is_tls(),
30        }
31    }
32}
33
34pub struct HttpOrWsIncoming<
35    IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream,
36    T: Stream<Item = HttpRequest<IO>> + Unpin = TcpOrTlsIncoming,
37> {
38    incoming: Option<T>,
39}
40
41impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = HttpRequest<IO>> + Unpin>
42    HttpOrWsIncoming<IO, T>
43{
44    pub fn new(http_incoming: T) -> Self {
45        Self {
46            incoming: Some(http_incoming),
47        }
48    }
49}
50
51impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = HttpRequest<IO>> + Unpin> Stream
52    for HttpOrWsIncoming<IO, T>
53{
54    type Item = HttpOrWs<IO>;
55
56    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
57        let incoming = match &mut self.incoming {
58            None => return Poll::Ready(None),
59            Some(incoming) => incoming,
60        };
61
62        let request = match incoming.poll_next_unpin(cx) {
63            Poll::Pending => return Poll::Pending,
64            Poll::Ready(None) => {
65                drop(self.incoming.take());
66                return Poll::Ready(None);
67            }
68            Poll::Ready(Some(request)) => request,
69        };
70
71        let request = request.into_inner();
72        if !is_upgrade_request(&request) {
73            return Poll::Ready(Some(HttpOrWs::Http(HttpRequest::from_inner(request))));
74        }
75
76        let response = upgrade_response(&request).unwrap();
77        let (request_head, request_body) = request.into_parts();
78        let request_head = RequestHead::from(request_head);
79        let (_, transport) = request_body.into_inner();
80        let response_head = ResponseHead::from(response);
81        Poll::Ready(Some(HttpOrWs::Ws(WsUpgradeRequest {
82            request_head,
83            response_head,
84            transport,
85        })))
86    }
87}
88
89impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = HttpRequest<IO>> + Unpin> FusedStream
90    for HttpOrWsIncoming<IO, T>
91{
92    fn is_terminated(&self) -> bool {
93        self.incoming.is_none()
94    }
95}
96
97pub struct WsUpgradeRequest<IO: AsyncRead + AsyncWrite + Unpin> {
98    pub(crate) request_head: RequestHead<'static>,
99    pub(crate) response_head: ResponseHead<'static>,
100    pub(crate) transport: IO,
101}
102
103impl<IO: AsyncRead + AsyncWrite + Unpin + IsTls> IsTls for WsUpgradeRequest<IO> {
104    fn is_tls(&self) -> bool {
105        self.transport.is_tls()
106    }
107}
108
109impl<IO: AsyncRead + AsyncWrite + Unpin> WsUpgradeRequest<IO> {
110    /// Direct access to the request as [http::Request] and underlying transport.
111    /// The transport may be extracted using
112    /// ```no_run
113    /// # use futures::io::Cursor;
114    /// # use async_web_server::WsUpgradeRequest;
115    /// # let request: WsUpgradeRequest<Cursor<&mut [u8]>> = todo!();
116    /// let transport = request.into_inner();
117    /// ```
118    pub fn into_inner(self) -> Request<IO> {
119        Request::from_parts(self.request_head.into(), self.transport)
120    }
121    /// Access the original requests headers as [http::HeaderMap].
122    pub fn request_headers(&self) -> &HeaderMap {
123        self.request_head.headers()
124    }
125    /// Access the original requests URI as [http::Uri].
126    pub fn uri(&self) -> &Uri {
127        &self.request_head.uri()
128    }
129    /// Return the original requests method as [http::Method].
130    pub fn method(&self) -> Method {
131        self.request_head.method().clone()
132    }
133    /// Return the HTTP version as [http::Version].
134    pub fn version(&self) -> Version {
135        self.request_head.version()
136    }
137    /// Upgrade to a websocket connection.
138    pub fn upgrade(self) -> WsAccept<IO> {
139        WsAccept {
140            response: self.response_head.encode(self.transport),
141        }
142    }
143}
144
145pub struct WsAccept<IO: AsyncRead + AsyncWrite + Unpin> {
146    response: BufferWrite<IO>,
147}
148
149impl<IO: AsyncRead + AsyncWrite + Unpin> Future for WsAccept<IO> {
150    type Output = io::Result<WsConnection<IO>>;
151
152    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
153        match self.response.poll_unpin(cx) {
154            Poll::Ready(Ok(transport)) => {
155                Poll::Ready(Ok(WsConnection::with_config(transport, WsConfig::server())))
156            }
157            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
158            Poll::Pending => Poll::Pending,
159        }
160    }
161}