poem/middleware/
force_https.rs1use std::{borrow::Cow, sync::Arc};
2
3use http::{Uri, header, uri::Scheme};
4
5use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result, web::Redirect};
6
7type FilterFn = Arc<dyn Fn(&Request) -> bool + Send + Sync>;
8
9#[derive(Default)]
11pub struct ForceHttps {
12 https_port: Option<u16>,
13 filter_fn: Option<FilterFn>,
14}
15
16impl ForceHttps {
17 pub fn new() -> Self {
19 Default::default()
20 }
21
22 #[must_use]
24 pub fn https_port(self, port: u16) -> Self {
25 Self {
26 https_port: Some(port),
27 ..self
28 }
29 }
30
31 #[must_use]
33 pub fn filter(self, predicate: impl Fn(&Request) -> bool + Send + Sync + 'static) -> Self {
34 Self {
35 filter_fn: Some(Arc::new(predicate)),
36 ..self
37 }
38 }
39}
40
41impl<E> Middleware<E> for ForceHttps
42where
43 E: Endpoint,
44{
45 type Output = ForceHttpsEndpoint<E>;
46
47 fn transform(&self, ep: E) -> Self::Output {
48 ForceHttpsEndpoint {
49 inner: ep,
50 https_port: self.https_port,
51 filter_fn: self.filter_fn.clone(),
52 }
53 }
54}
55
56pub struct ForceHttpsEndpoint<E> {
58 inner: E,
59 https_port: Option<u16>,
60 filter_fn: Option<FilterFn>,
61}
62
63impl<E> Endpoint for ForceHttpsEndpoint<E>
64where
65 E: Endpoint,
66{
67 type Output = Response;
68
69 async fn call(&self, mut req: Request) -> Result<Self::Output> {
70 if req.scheme() == &Scheme::HTTP && self.filter_fn.as_ref().map(|f| f(&req)).unwrap_or(true)
71 {
72 if let Some(host) = req.headers().get(header::HOST).cloned() {
73 if let Ok(host) = host.to_str() {
74 let host = redirect_host(host, self.https_port);
75 let uri_parts = std::mem::take(req.uri_mut()).into_parts();
76 let mut builder = Uri::builder().scheme(Scheme::HTTPS).authority(&*host);
77 if let Some(path_and_query) = uri_parts.path_and_query {
78 builder = builder.path_and_query(path_and_query);
79 }
80 if let Ok(uri) = builder.build() {
81 return Ok(Redirect::permanent(uri).into_response());
82 }
83 }
84 }
85 }
86
87 self.inner.call(req).await.map(IntoResponse::into_response)
88 }
89}
90
/// Returns the authority to use in the HTTPS redirect for a `Host` value.
///
/// With `https_port` set, any port already present in `host` is replaced by
/// it, or it is appended when there is none. Without it, `host` is returned
/// unchanged, including any port it carries.
///
/// For bracketed IPv6 literals (`[::1]`, `[::1]:8080`), the port is looked
/// for only after the closing `]`. Splitting on the first `:` would cut the
/// address itself, e.g. turning `[::1]:8080` into `[`.
fn redirect_host(host: &str, https_port: Option<u16>) -> Cow<'_, str> {
    let Some(port) = https_port else {
        return Cow::Borrowed(host);
    };

    let name = match host.strip_prefix('[') {
        // IP-literal: keep everything up to and including the closing bracket.
        Some(rest) => rest.find(']').map_or(host, |end| &host[..end + 2]),
        // Registered name / IPv4: the only `:` (if any) introduces the port.
        None => host.split_once(':').map_or(host, |(name, _)| name),
    };
    Cow::Owned(format!("{name}:{port}"))
}
98
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_redirect_host() {
        // A configured HTTPS port is appended, or replaces an existing port.
        assert_eq!(redirect_host("example.com", Some(1234)), "example.com:1234");
        assert_eq!(
            redirect_host("example.com:5678", Some(1234)),
            "example.com:1234"
        );
        assert_eq!(redirect_host("localhost:80", Some(443)), "localhost:443");
        // Without a configured port the host, including any port, is kept verbatim.
        assert_eq!(redirect_host("example.com:1234", None), "example.com:1234");
        assert_eq!(redirect_host("example.com", None), "example.com");
    }
}