Skip to main content

rune_axum_redirect_https/
layer.rs

1use http::{
2    header::{HeaderValue, HOST, LOCATION},
3    Request, Response, StatusCode,
4};
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tower::{Layer, Service};
9
10/// Configuration for the HTTP-to-HTTPS redirect middleware.
11///
12/// Build with [`RedirectHttps::new()`] and chain methods to adjust the redirect
13/// status code or target HTTPS port, then pass to [`RedirectHttpsLayer::new()`].
14///
15/// # Examples
16///
17/// ```rust
18/// use http::StatusCode;
19/// use rune_axum_redirect_https::RedirectHttps;
20///
21/// // HTTP on :8080 → HTTPS on :8443, 301 for legacy compatibility
22/// let config = RedirectHttps::new()
23///     .status(StatusCode::MOVED_PERMANENTLY)
24///     .https_port(8443);
25/// ```
26#[derive(Clone, Debug)]
27pub struct RedirectHttps {
28    status: StatusCode,
29    https_port: Option<u16>,
30}
31
32impl Default for RedirectHttps {
33    fn default() -> Self {
34        Self {
35            status: StatusCode::PERMANENT_REDIRECT,
36            https_port: None,
37        }
38    }
39}
40
41impl RedirectHttps {
42    /// Creates a `RedirectHttps` with defaults: `308 Permanent Redirect`, standard HTTPS port.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Sets the redirect status code.
48    ///
49    /// Defaults to `308 Permanent Redirect`, which preserves the HTTP method.
50    /// Use `301 Moved Permanently` when you need compatibility with older
51    /// clients that do not support 308.
52    ///
53    /// > [!WARNING]
54    /// > `301` converts POST to GET in many browsers and HTTP clients. Prefer
55    /// > `308` unless you have a specific reason to use `301`.
56    pub fn status(mut self, status: StatusCode) -> Self {
57        self.status = status;
58        self
59    }
60
61    /// Sets the HTTPS port in the redirect `Location` URL.
62    ///
63    /// When set, any port in the `Host` header is stripped and replaced with
64    /// this value. Useful when HTTP and HTTPS run on non-standard ports (e.g.
65    /// `8080` → `8443`). When unset the port is omitted from the URL, which
66    /// directs clients to the standard HTTPS port (443).
67    pub fn https_port(mut self, port: u16) -> Self {
68        self.https_port = Some(port);
69        self
70    }
71
72    fn is_http<B>(req: &Request<B>) -> bool {
73        if let Some(proto) = req.headers().get("x-forwarded-proto") {
74            return proto.as_bytes().eq_ignore_ascii_case(b"http");
75        }
76        req.uri().scheme() == Some(&http::uri::Scheme::HTTP)
77    }
78
79    fn location<B>(&self, req: &Request<B>) -> Option<HeaderValue> {
80        let host = req.headers().get(HOST)?.to_str().ok()?;
81
82        let hostname = host
83            .rsplit_once(':')
84            .filter(|(_, port)| port.parse::<u16>().is_ok())
85            .map_or(host, |(h, _)| h);
86
87        let authority = match self.https_port {
88            Some(port) => format!("{hostname}:{port}"),
89            None => hostname.to_owned(),
90        };
91
92        let path_and_query = req
93            .uri()
94            .path_and_query()
95            .map(|pq| pq.as_str())
96            .unwrap_or("/");
97
98        HeaderValue::from_str(&format!("https://{authority}{path_and_query}")).ok()
99    }
100}
101
102/// Tower [`Layer`] that redirects HTTP requests to HTTPS.
103///
104/// Apply with Axum's `.layer()` call. Use [`RedirectHttpsLayer::default()`] for
105/// a `308` redirect on standard ports, or [`RedirectHttpsLayer::new()`] to
106/// supply a custom [`RedirectHttps`] configuration.
107///
108/// # Examples
109///
110/// ```rust,no_run
111/// use axum::{routing::get, Router};
112/// use rune_axum_redirect_https::RedirectHttpsLayer;
113///
114/// let app: Router = Router::new()
115///     .route("/", get(|| async { "ok" }))
116///     .layer(RedirectHttpsLayer::default());
117/// ```
118#[derive(Clone, Debug, Default)]
119pub struct RedirectHttpsLayer {
120    config: RedirectHttps,
121}
122
123impl RedirectHttpsLayer {
124    /// Creates a `RedirectHttpsLayer` from a custom [`RedirectHttps`] configuration.
125    pub fn new(config: RedirectHttps) -> Self {
126        Self { config }
127    }
128}
129
130impl<S> Layer<S> for RedirectHttpsLayer {
131    type Service = RedirectHttpsService<S>;
132
133    fn layer(&self, inner: S) -> Self::Service {
134        RedirectHttpsService {
135            inner,
136            config: self.config.clone(),
137        }
138    }
139}
140
141/// Tower [`Service`] produced by [`RedirectHttpsLayer`].
142#[derive(Clone, Debug)]
143pub struct RedirectHttpsService<S> {
144    inner: S,
145    config: RedirectHttps,
146}
147
148impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RedirectHttpsService<S>
149where
150    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
151    S::Future: Send + 'static,
152    S::Error: Send + 'static,
153    ResBody: Default + Send + 'static,
154{
155    type Response = Response<ResBody>;
156    type Error = S::Error;
157    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
158
159    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160        self.inner.poll_ready(cx)
161    }
162
163    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
164        if RedirectHttps::is_http(&req)
165            && let Some(location) = self.config.location(&req)
166        {
167            let status = self.config.status;
168            return Box::pin(async move {
169                let mut response = Response::builder()
170                    .status(status)
171                    .body(ResBody::default())
172                    .expect("redirect response is valid");
173                response.headers_mut().insert(LOCATION, location);
174                Ok(response)
175            });
176        }
177        Box::pin(self.inner.call(req))
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::StatusCode;
    use tower::ServiceExt;

    /// Router with a single `/` route wrapped in the redirect layer.
    fn build_app(config: RedirectHttps) -> Router {
        let layer = RedirectHttpsLayer::new(config);
        Router::new().route("/", get(|| async { "ok" })).layer(layer)
    }

    /// Drives one request through the router and returns its response.
    async fn send(app: Router, req: http::Request<Body>) -> http::Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    /// Builds an empty-bodied request; passing `None` omits the corresponding header.
    fn request(uri: &str, host: Option<&str>, proto: Option<&str>) -> http::Request<Body> {
        let mut builder = http::Request::builder().uri(uri);
        if let Some(host) = host {
            builder = builder.header(HOST, host);
        }
        if let Some(proto) = proto {
            builder = builder.header("x-forwarded-proto", proto);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Request for `example.com` carrying the given `x-forwarded-proto` value.
    fn forwarded_request(proto: &str, uri: &str) -> http::Request<Body> {
        request(uri, Some("example.com"), Some(proto))
    }

    #[tokio::test]
    async fn redirects_on_x_forwarded_proto_http() {
        let app = build_app(RedirectHttps::new());
        let response = send(app, forwarded_request("http", "/path?q=1")).await;

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()[LOCATION], "https://example.com/path?q=1");
    }

    #[tokio::test]
    async fn passes_through_on_x_forwarded_proto_https() {
        let app = build_app(RedirectHttps::new());
        let response = send(app, forwarded_request("https", "/")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn redirects_on_http_uri_scheme() {
        let req = request("http://example.com/page", Some("example.com"), None);
        let response = send(build_app(RedirectHttps::new()), req).await;

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()[LOCATION], "https://example.com/page");
    }

    #[tokio::test]
    async fn passes_through_when_no_scheme_indicator() {
        let req = request("/", Some("example.com"), None);
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn passes_through_when_no_host_header() {
        let req = request("/", None, Some("http"));
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_status_301() {
        let app = build_app(RedirectHttps::new().status(StatusCode::MOVED_PERMANENTLY));
        let response = send(app, forwarded_request("http", "/")).await;
        assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);
    }

    #[tokio::test]
    async fn strips_http_port_from_host() {
        let req = request("/path", Some("example.com:80"), Some("http"));
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.headers()[LOCATION], "https://example.com/path");
    }

    #[tokio::test]
    async fn custom_https_port() {
        let app = build_app(RedirectHttps::new().https_port(8443));
        let req = request("/path", Some("example.com:8080"), Some("http"));
        let response = send(app, req).await;
        assert_eq!(response.headers()[LOCATION], "https://example.com:8443/path");
    }

    #[tokio::test]
    async fn default_layer_uses_308() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RedirectHttpsLayer::default());
        let response = send(app, forwarded_request("http", "/")).await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
    }
}