1use crate::{HttpRequest, IsTls, TcpOrTlsIncoming, TcpOrTlsStream};
2use async_http_codec::internal::buffer_write::BufferWrite;
3use async_http_codec::{RequestHead, ResponseHead};
4use async_ws::connection::WsConfig;
5use async_ws::http::{is_upgrade_request, upgrade_response};
6use futures::prelude::*;
7use futures::stream::FusedStream;
8use http::{HeaderMap, Method, Request, Uri, Version};
9use std::io;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13pub type WsConnection<IO = TcpOrTlsStream> = async_ws::connection::WsConnection<IO>;
14pub type WsMessageKind = async_ws::message::WsMessageKind;
15pub type WsSend<IO = TcpOrTlsStream> = async_ws::connection::WsSend<IO>;
16pub type WsConnectionError = async_ws::connection::WsConnectionError;
17pub type WsMessageReader<IO = TcpOrTlsStream> = async_ws::connection::WsMessageReader<IO>;
18pub type WsMessageWriter<IO = TcpOrTlsStream> = async_ws::connection::WsMessageWriter<IO>;
19
20pub enum HttpOrWs<IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream> {
21 Http(HttpRequest<IO>),
22 Ws(WsUpgradeRequest<IO>),
23}
24
25impl<IO: AsyncRead + AsyncWrite + Unpin + IsTls> IsTls for HttpOrWs<IO> {
26 fn is_tls(&self) -> bool {
27 match self {
28 HttpOrWs::Http(http) => http.is_tls(),
29 HttpOrWs::Ws(ws) => ws.is_tls(),
30 }
31 }
32}
33
34pub struct HttpOrWsIncoming<
35 IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream,
36 T: Stream<Item = HttpRequest<IO>> + Unpin = TcpOrTlsIncoming,
37> {
38 incoming: Option<T>,
39}
40
41impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = HttpRequest<IO>> + Unpin>
42 HttpOrWsIncoming<IO, T>
43{
44 pub fn new(http_incoming: T) -> Self {
45 Self {
46 incoming: Some(http_incoming),
47 }
48 }
49}
50
51impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = HttpRequest<IO>> + Unpin> Stream
52 for HttpOrWsIncoming<IO, T>
53{
54 type Item = HttpOrWs<IO>;
55
56 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
57 let incoming = match &mut self.incoming {
58 None => return Poll::Ready(None),
59 Some(incoming) => incoming,
60 };
61
62 let request = match incoming.poll_next_unpin(cx) {
63 Poll::Pending => return Poll::Pending,
64 Poll::Ready(None) => {
65 drop(self.incoming.take());
66 return Poll::Ready(None);
67 }
68 Poll::Ready(Some(request)) => request,
69 };
70
71 let request = request.into_inner();
72 if !is_upgrade_request(&request) {
73 return Poll::Ready(Some(HttpOrWs::Http(HttpRequest::from_inner(request))));
74 }
75
76 let response = upgrade_response(&request).unwrap();
77 let (request_head, request_body) = request.into_parts();
78 let request_head = RequestHead::from(request_head);
79 let (_, transport) = request_body.into_inner();
80 let response_head = ResponseHead::from(response);
81 Poll::Ready(Some(HttpOrWs::Ws(WsUpgradeRequest {
82 request_head,
83 response_head,
84 transport,
85 })))
86 }
87}
88
89impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = HttpRequest<IO>> + Unpin> FusedStream
90 for HttpOrWsIncoming<IO, T>
91{
92 fn is_terminated(&self) -> bool {
93 self.incoming.is_none()
94 }
95}
96
97pub struct WsUpgradeRequest<IO: AsyncRead + AsyncWrite + Unpin> {
98 pub(crate) request_head: RequestHead<'static>,
99 pub(crate) response_head: ResponseHead<'static>,
100 pub(crate) transport: IO,
101}
102
103impl<IO: AsyncRead + AsyncWrite + Unpin + IsTls> IsTls for WsUpgradeRequest<IO> {
104 fn is_tls(&self) -> bool {
105 self.transport.is_tls()
106 }
107}
108
109impl<IO: AsyncRead + AsyncWrite + Unpin> WsUpgradeRequest<IO> {
110 pub fn into_inner(self) -> Request<IO> {
119 Request::from_parts(self.request_head.into(), self.transport)
120 }
121 pub fn request_headers(&self) -> &HeaderMap {
123 self.request_head.headers()
124 }
125 pub fn uri(&self) -> &Uri {
127 &self.request_head.uri()
128 }
129 pub fn method(&self) -> Method {
131 self.request_head.method().clone()
132 }
133 pub fn version(&self) -> Version {
135 self.request_head.version()
136 }
137 pub fn upgrade(self) -> WsAccept<IO> {
139 WsAccept {
140 response: self.response_head.encode(self.transport),
141 }
142 }
143}
144
145pub struct WsAccept<IO: AsyncRead + AsyncWrite + Unpin> {
146 response: BufferWrite<IO>,
147}
148
149impl<IO: AsyncRead + AsyncWrite + Unpin> Future for WsAccept<IO> {
150 type Output = io::Result<WsConnection<IO>>;
151
152 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
153 match self.response.poll_unpin(cx) {
154 Poll::Ready(Ok(transport)) => {
155 Poll::Ready(Ok(WsConnection::with_config(transport, WsConfig::server())))
156 }
157 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
158 Poll::Pending => Poll::Pending,
159 }
160 }
161}