product_os_router/
dual_protocol.rs1use std::prelude::v1::*;
6
7use std::fmt::{self, Debug, Formatter};
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use product_os_http_body::Bytes;
13use product_os_http::header::{HOST, LOCATION, UPGRADE};
14use product_os_http::uri::{Authority, Scheme};
15use product_os_http::{HeaderValue, Request, Response, StatusCode, Uri};
16use product_os_http_body::{Either, Empty};
17use pin_project::pin_project;
18use tower_layer::Layer;
19use tower_service::Service as TowerService;
20
21
22
/// Transport a request was received over.
///
/// `UpgradeHttp` reads this value from the request's extensions to decide whether the
/// request is handed to the wrapped service or answered directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// The request arrived over a TLS-encrypted connection.
    Tls,
    /// The request arrived over an unencrypted connection.
    Plain,
}
33
34
35#[derive(Clone, Copy, Debug)]
50pub struct UpgradeHttpLayer;
51
52impl<Service> Layer<Service> for UpgradeHttpLayer {
53 type Service = UpgradeHttp<Service>;
54
55 fn layer(&self, inner: Service) -> Self::Service {
56 UpgradeHttp::new(inner)
57 }
58}
59
60
/// Middleware that serves TLS requests and answers plain-text ones itself.
///
/// Requests whose extensions hold [`Protocol::Tls`] are passed on to the wrapped service.
/// Requests tagged [`Protocol::Plain`] get a `301 Moved Permanently` to the `https`/`wss`
/// equivalent, or `400 Bad Request` when the target cannot be determined. Requests without
/// a [`Protocol`] get `500 Internal Server Error`. Neither of the latter two reaches the
/// wrapped service.
#[derive(Clone, Debug)]
pub struct UpgradeHttp<Service> {
    /// The wrapped service that receives TLS requests.
    service: Service,
}

impl<Service> UpgradeHttp<Service> {
    /// Wraps `service`.
    pub const fn new(service: Service) -> Self {
        UpgradeHttp { service }
    }

    /// Unwraps and returns the inner service.
    pub fn into_inner(self) -> Service {
        let Self { service } = self;
        service
    }

    /// Borrows the inner service.
    pub const fn get_ref(&self) -> &Service {
        let inner = &self.service;
        inner
    }

    /// Mutably borrows the inner service.
    pub fn get_mut(&mut self) -> &mut Service {
        let Self { service } = self;
        service
    }
}
97
98impl<Service, RequestBody, ResponseBody> TowerService<Request<RequestBody>> for UpgradeHttp<Service>
99where
100 Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
101{
102 type Response = Response<Either<ResponseBody, Empty<Bytes>>>;
103 type Error = Service::Error;
104 type Future = UpgradeHttpFuture<Service, Request<RequestBody>>;
105
106 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107 self.service.poll_ready(cx)
108 }
109
110 fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
111 let protocol = req.extensions().get::<Protocol>().copied();
112
113 match protocol {
114 None => {
115 let response = Response::builder()
119 .status(StatusCode::INTERNAL_SERVER_ERROR)
120 .body(Empty::new())
121 .expect("building empty 500 response");
122 UpgradeHttpFuture::new_upgrade(response)
123 }
124 Some(Protocol::Tls) => UpgradeHttpFuture::new_service(self.service.call(req)),
125 Some(Protocol::Plain) => {
126 let response = Response::builder();
127
128 let response = if let Some((authority, scheme)) =
129 extract_authority(&req).and_then(|authority| {
130 let uri = req.uri();
131
132 if uri.scheme_str() == Some("ws")
137 || req.headers().get(UPGRADE)
138 == Some(&HeaderValue::from_static("websocket"))
139 {
140 Some((
141 authority,
142 Scheme::try_from("wss").expect("ASCII string is valid"),
143 ))
144 }
145 else if uri.scheme() == Some(&Scheme::HTTP) || uri.scheme_str().is_none()
147 {
148 Some((authority, Scheme::HTTPS))
149 }
150 else {
152 None
153 }
154 }) {
155 let mut uri = Uri::builder().scheme(scheme).authority(authority);
157
158 if let Some(path_and_query) = req.uri().path_and_query() {
159 uri = uri.path_and_query(path_and_query.clone());
160 }
161
162 let uri = uri.build().expect("invalid path and query");
163
164 response
165 .status(StatusCode::MOVED_PERMANENTLY)
166 .header(LOCATION, uri.to_string())
167 } else {
168 response.status(StatusCode::BAD_REQUEST)
171 }
172 .body(Empty::new())
173 .expect("invalid header or body");
174
175 UpgradeHttpFuture::new_upgrade(response)
176 }
177 }
178 }
179}
180
/// Future returned by `UpgradeHttp`'s `call`.
///
/// It either polls the inner service's future (TLS requests) or yields a response that was
/// built without calling the inner service (plain or untagged requests).
// `#[pin_project]` allows the inner state to be pin-projected when polled; the state is
// marked `#[pin]` because it may hold the inner service's future.
#[pin_project]
pub struct UpgradeHttpFuture<Service, Request>(#[pin] FutureServe<Service, Request>)
where
    Service: TowerService<Request>;
186
/// Internal state of `UpgradeHttpFuture`.
#[derive(Debug)]
// `project = UpgradeHttpFutureProj` names the generated projection enum that `poll`
// matches on; renaming it requires updating that `match` as well.
#[pin_project(project = UpgradeHttpFutureProj)]
enum FutureServe<Service, Request>
where
    Service: TowerService<Request>,
{
    /// The request was forwarded to the inner service; this is its future, pinned because
    /// it is polled in place.
    Service(#[pin] Service::Future),
    /// A response produced without calling the inner service. It is `take`n on the first
    /// poll, leaving `None`.
    Upgrade(Option<Response<Empty<Bytes>>>),
}
201
202impl<Service, Request> Debug for UpgradeHttpFuture<Service, Request>
204where
205 Service: TowerService<Request>,
206 FutureServe<Service, Request>: Debug,
207{
208 fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
209 formatter
210 .debug_tuple("UpgradeHttpFuture")
211 .field(&self.0)
212 .finish()
213 }
214}
215
216impl<Service, Request> UpgradeHttpFuture<Service, Request>
217where
218 Service: TowerService<Request>,
219{
220 const fn new_service(future: Service::Future) -> Self {
223 Self(FutureServe::Service(future))
224 }
225
226 const fn new_upgrade(response: Response<Empty<Bytes>>) -> Self {
229 Self(FutureServe::Upgrade(Some(response)))
230 }
231}
232
233impl<Service, Request, ResponseBody> Future for UpgradeHttpFuture<Service, Request>
234where
235 Service: TowerService<Request, Response = Response<ResponseBody>>,
236{
237 type Output = Result<Response<Either<ResponseBody, Empty<Bytes>>>, Service::Error>;
238
239 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
240 match self.project().0.project() {
241 UpgradeHttpFutureProj::Service(future) => {
242 future.poll(cx).map_ok(|result| result.map(Either::Left))
243 }
244 UpgradeHttpFutureProj::Upgrade(response) => Poll::Ready(Ok(response
245 .take()
246 .expect("polled again after `Poll::Ready`")
247 .map(Either::Right))),
248 }
249 }
250}
251
252fn extract_authority<Body>(request: &Request<Body>) -> Option<Authority> {
254 const X_FORWARDED_HOST: &str = "x-forwarded-host";
256
257 let headers = request.headers();
258
259 headers
260 .get(X_FORWARDED_HOST)
261 .or_else(|| headers.get(HOST))
262 .and_then(|header| header.to_str().ok())
263 .or_else(|| request.uri().host())
264 .and_then(|host| Authority::try_from(host).ok())
265}