Skip to main content

product_os_router/
dual_protocol.rs

1//! HTTP to HTTPS upgrade layer
2//!
3//! This module provides middleware for automatically redirecting HTTP requests to HTTPS.
4
5use std::prelude::v1::*;
6
7use std::fmt::{self, Debug, Formatter};
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use product_os_http_body::Bytes;
13use product_os_http::header::{HOST, LOCATION, UPGRADE};
14use product_os_http::uri::{Authority, Scheme};
15use product_os_http::{HeaderValue, Request, Response, StatusCode, Uri};
16use product_os_http_body::{Either, Empty};
17use pin_project::pin_project;
18use tower_layer::Layer;
19use tower_service::Service as TowerService;
20
21
22
/// Transport security of the connection a request arrived on.
///
/// The value distinguishes TLS-encrypted connections from plain-text ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// The connection is protected by TLS.
    Tls,
    /// The connection is plain text, without encryption.
    Plain,
}
33
34
/// [`Layer`] that wraps a service in [`UpgradeHttp`].
///
/// Requests arriving over plain HTTP are answered with a 301 redirect to the
/// HTTPS equivalent of the requested URI; WebSocket requests (`ws://`) are
/// redirected to `wss://` instead.
///
/// # Examples
///
/// ```rust
/// use product_os_router::UpgradeHttpLayer;
/// use tower::Layer;
///
/// let layer = UpgradeHttpLayer;
/// // Apply to your service
/// ```
#[derive(Debug, Clone, Copy)]
pub struct UpgradeHttpLayer;
51
52impl<Service> Layer<Service> for UpgradeHttpLayer {
53    type Service = UpgradeHttp<Service>;
54
55    fn layer(&self, inner: Service) -> Self::Service {
56        UpgradeHttp::new(inner)
57    }
58}
59
60
/// Middleware [`Service`](TowerService) that answers plain HTTP requests with a
/// [301 "Moved Permanently"](https://tools.ietf.org/html/rfc7231#section-6.4.2)
/// redirect to HTTPS.
///
/// The redirect always carries over the requested path and query. Depending on
/// where this [`Service`](TowerService) is applied, a client may therefore be
/// redirected to a destination that ends up answering with a 404 "Not Found".
#[derive(Clone, Debug)]
pub struct UpgradeHttp<S> {
    /// The user-provided [`Service`](TowerService) being wrapped.
    service: S,
}

impl<S> UpgradeHttp<S> {
    /// Wraps `service` in a new [`UpgradeHttp`].
    pub const fn new(service: S) -> Self {
        UpgradeHttp { service }
    }

    /// Unwraps the [`UpgradeHttp`], handing back the inner
    /// [`Service`](TowerService).
    pub fn into_inner(self) -> S {
        let Self { service } = self;
        service
    }

    /// Borrows the inner [`Service`](TowerService).
    pub const fn get_ref(&self) -> &S {
        let Self { service } = self;
        service
    }

    /// Mutably borrows the inner [`Service`](TowerService).
    pub fn get_mut(&mut self) -> &mut S {
        let Self { service } = self;
        service
    }
}
97
98impl<Service, RequestBody, ResponseBody> TowerService<Request<RequestBody>> for UpgradeHttp<Service>
99where
100    Service: TowerService<Request<RequestBody>, Response = Response<ResponseBody>>,
101{
102    type Response = Response<Either<ResponseBody, Empty<Bytes>>>;
103    type Error = Service::Error;
104    type Future = UpgradeHttpFuture<Service, Request<RequestBody>>;
105
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        self.service.poll_ready(cx)
108    }
109
110    fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
111        let protocol = req.extensions().get::<Protocol>().copied();
112
113        match protocol {
114            None => {
115                // Protocol extension was not set; return a 500 Internal Server Error
116                // instead of panicking. This can happen if the layer is used without
117                // `DualProtocolService` setting the extension.
118                let response = Response::builder()
119                    .status(StatusCode::INTERNAL_SERVER_ERROR)
120                    .body(Empty::new())
121                    .expect("building empty 500 response");
122                UpgradeHttpFuture::new_upgrade(response)
123            }
124            Some(Protocol::Tls) => UpgradeHttpFuture::new_service(self.service.call(req)),
125            Some(Protocol::Plain) => {
126                let response = Response::builder();
127
128                let response = if let Some((authority, scheme)) =
129                    extract_authority(&req).and_then(|authority| {
130                        let uri = req.uri();
131
132                        // Depending on the scheme we need a different scheme to redirect to.
133
134                        // WebSocket handshakes often don't send a scheme, so we check the "Upgrade"
135                        // header as well.
136                        if uri.scheme_str() == Some("ws")
137                            || req.headers().get(UPGRADE)
138                            == Some(&HeaderValue::from_static("websocket"))
139                        {
140                            Some((
141                                authority,
142                                Scheme::try_from("wss").expect("ASCII string is valid"),
143                            ))
144                        }
145                        // HTTP requests often don't send a scheme.
146                        else if uri.scheme() == Some(&Scheme::HTTP) || uri.scheme_str().is_none()
147                        {
148                            Some((authority, Scheme::HTTPS))
149                        }
150                        // Unknown scheme, abort.
151                        else {
152                            None
153                        }
154                    }) {
155                    // Build URI to redirect to.
156                    let mut uri = Uri::builder().scheme(scheme).authority(authority);
157
158                    if let Some(path_and_query) = req.uri().path_and_query() {
159                        uri = uri.path_and_query(path_and_query.clone());
160                    }
161
162                    let uri = uri.build().expect("invalid path and query");
163
164                    response
165                        .status(StatusCode::MOVED_PERMANENTLY)
166                        .header(LOCATION, uri.to_string())
167                } else {
168                    // If we can't extract the host or have an unknown scheme, tell the client there
169                    // is something wrong with their request.
170                    response.status(StatusCode::BAD_REQUEST)
171                }
172                    .body(Empty::new())
173                    .expect("invalid header or body");
174
175                UpgradeHttpFuture::new_upgrade(response)
176            }
177        }
178    }
179}
180
181/// [`Future`](TowerService::Future) type for [`UpgradeHttp`].
182#[pin_project]
183pub struct UpgradeHttpFuture<Service, Request>(#[pin] FutureServe<Service, Request>)
184where
185    Service: TowerService<Request>;
186
187/// Holds [`Future`] to serve for [`UpgradeHttpFuture`].
188#[derive(Debug)]
189#[pin_project(project = UpgradeHttpFutureProj)]
190enum FutureServe<Service, Request>
191where
192    Service: TowerService<Request>,
193{
194    /// The request was using the HTTPS protocol, so we
195    /// will pass-through the wrapped [`Service`](TowerService).
196    Service(#[pin] Service::Future),
197    /// The request was using the HTTP protocol, so we
198    /// will upgrade the connection.
199    Upgrade(Option<Response<Empty<Bytes>>>),
200}
201
202// Rust can't figure out the correct bounds.
203impl<Service, Request> Debug for UpgradeHttpFuture<Service, Request>
204where
205    Service: TowerService<Request>,
206    FutureServe<Service, Request>: Debug,
207{
208    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
209        formatter
210            .debug_tuple("UpgradeHttpFuture")
211            .field(&self.0)
212            .finish()
213    }
214}
215
216impl<Service, Request> UpgradeHttpFuture<Service, Request>
217where
218    Service: TowerService<Request>,
219{
220    /// Create a [`UpgradeHttpFuture`] in the [`Service`](FutureServe::Service)
221    /// state.
222    const fn new_service(future: Service::Future) -> Self {
223        Self(FutureServe::Service(future))
224    }
225
226    /// Create a [`UpgradeHttpFuture`] in the [`Upgrade`](FutureServe::Upgrade)
227    /// state.
228    const fn new_upgrade(response: Response<Empty<Bytes>>) -> Self {
229        Self(FutureServe::Upgrade(Some(response)))
230    }
231}
232
233impl<Service, Request, ResponseBody> Future for UpgradeHttpFuture<Service, Request>
234where
235    Service: TowerService<Request, Response = Response<ResponseBody>>,
236{
237    type Output = Result<Response<Either<ResponseBody, Empty<Bytes>>>, Service::Error>;
238
239    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
240        match self.project().0.project() {
241            UpgradeHttpFutureProj::Service(future) => {
242                future.poll(cx).map_ok(|result| result.map(Either::Left))
243            }
244            UpgradeHttpFutureProj::Upgrade(response) => Poll::Ready(Ok(response
245                .take()
246                .expect("polled again after `Poll::Ready`")
247                .map(Either::Right))),
248        }
249    }
250}
251
252/// Extracts the host from a request, converting it to an [`Authority`].
253fn extract_authority<Body>(request: &Request<Body>) -> Option<Authority> {
254    /// `X-Forwarded-Host` header string.
255    const X_FORWARDED_HOST: &str = "x-forwarded-host";
256
257    let headers = request.headers();
258
259    headers
260        .get(X_FORWARDED_HOST)
261        .or_else(|| headers.get(HOST))
262        .and_then(|header| header.to_str().ok())
263        .or_else(|| request.uri().host())
264        .and_then(|host| Authority::try_from(host).ok())
265}