// poem/middleware/force_https.rs

use std::{borrow::Cow, sync::Arc};

use http::{Uri, header, uri::Scheme};

use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result, web::Redirect};

7type FilterFn = Arc<dyn Fn(&Request) -> bool + Send + Sync>;
8
9/// Middleware which forces redirects to a HTTPS uri.
10#[derive(Default)]
11pub struct ForceHttps {
12    https_port: Option<u16>,
13    filter_fn: Option<FilterFn>,
14}
15
16impl ForceHttps {
17    /// Create a new `ForceHttps` middleware.
18    pub fn new() -> Self {
19        Default::default()
20    }
21
22    /// Specify the https port.
23    #[must_use]
24    pub fn https_port(self, port: u16) -> Self {
25        Self {
26            https_port: Some(port),
27            ..self
28        }
29    }
30
31    /// Uses a closure to determine if a request should be redirect.
32    #[must_use]
33    pub fn filter(self, predicate: impl Fn(&Request) -> bool + Send + Sync + 'static) -> Self {
34        Self {
35            filter_fn: Some(Arc::new(predicate)),
36            ..self
37        }
38    }
39}
40
41impl<E> Middleware<E> for ForceHttps
42where
43    E: Endpoint,
44{
45    type Output = ForceHttpsEndpoint<E>;
46
47    fn transform(&self, ep: E) -> Self::Output {
48        ForceHttpsEndpoint {
49            inner: ep,
50            https_port: self.https_port,
51            filter_fn: self.filter_fn.clone(),
52        }
53    }
54}
55
56/// Endpoint for the ForceHttps middleware.
57pub struct ForceHttpsEndpoint<E> {
58    inner: E,
59    https_port: Option<u16>,
60    filter_fn: Option<FilterFn>,
61}
62
63impl<E> Endpoint for ForceHttpsEndpoint<E>
64where
65    E: Endpoint,
66{
67    type Output = Response;
68
69    async fn call(&self, mut req: Request) -> Result<Self::Output> {
70        if req.scheme() == &Scheme::HTTP && self.filter_fn.as_ref().map(|f| f(&req)).unwrap_or(true)
71        {
72            if let Some(host) = req.headers().get(header::HOST).cloned() {
73                if let Ok(host) = host.to_str() {
74                    let host = redirect_host(host, self.https_port);
75                    let uri_parts = std::mem::take(req.uri_mut()).into_parts();
76                    let mut builder = Uri::builder().scheme(Scheme::HTTPS).authority(&*host);
77                    if let Some(path_and_query) = uri_parts.path_and_query {
78                        builder = builder.path_and_query(path_and_query);
79                    }
80                    if let Ok(uri) = builder.build() {
81                        return Ok(Redirect::permanent(uri).into_response());
82                    }
83                }
84            }
85        }
86
87        self.inner.call(req).await.map(IntoResponse::into_response)
88    }
89}
90
/// Computes the authority (`host[:port]`) for the HTTPS redirect target.
///
/// `host` is the raw value of the incoming `Host` header. When `https_port`
/// is `Some`, any port already present in `host` is replaced by it; when it
/// is `None`, `host` is returned unchanged (including any port it carries).
///
/// Bracketed IPv6 literals such as `[::1]` or `[::1]:8080` are supported: the
/// colons inside the brackets are not mistaken for a port separator.
fn redirect_host(host: &str, https_port: Option<u16>) -> Cow<'_, str> {
    match https_port {
        Some(port) => Cow::Owned(format!("{}:{port}", host_without_port(host))),
        None => Cow::Borrowed(host),
    }
}

/// Returns `host` with any trailing `:port` suffix removed.
fn host_without_port(host: &str) -> &str {
    if host.starts_with('[') {
        // IPv6 literal: a port, if present, can only follow the closing bracket.
        return match host.find(']') {
            Some(end) => &host[..=end],
            None => host,
        };
    }
    host.split_once(':').map_or(host, |(name, _)| name)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_redirect_host() {
        // An explicit HTTPS port is appended when the host has none...
        assert_eq!(redirect_host("example.com", Some(1234)), "example.com:1234");
        // ...and replaces any port the host already carries.
        assert_eq!(
            redirect_host("example.com:5678", Some(1234)),
            "example.com:1234"
        );
        assert_eq!(redirect_host("example.com:80", Some(443)), "example.com:443");
        // Without an HTTPS port, the host is passed through unchanged.
        assert_eq!(redirect_host("example.com:1234", None), "example.com:1234");
        assert_eq!(redirect_host("example.com", None), "example.com");
    }
}