rune_axum_redirect_https/
layer.rs1use http::{
2 header::{HeaderValue, HOST, LOCATION},
3 Request, Response, StatusCode,
4};
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tower::{Layer, Service};
9
10#[derive(Clone, Debug)]
27pub struct RedirectHttps {
28 status: StatusCode,
29 https_port: Option<u16>,
30}
31
32impl Default for RedirectHttps {
33 fn default() -> Self {
34 Self {
35 status: StatusCode::PERMANENT_REDIRECT,
36 https_port: None,
37 }
38 }
39}
40
41impl RedirectHttps {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn status(mut self, status: StatusCode) -> Self {
57 self.status = status;
58 self
59 }
60
61 pub fn https_port(mut self, port: u16) -> Self {
68 self.https_port = Some(port);
69 self
70 }
71
72 fn is_http<B>(req: &Request<B>) -> bool {
73 if let Some(proto) = req.headers().get("x-forwarded-proto") {
74 return proto.as_bytes().eq_ignore_ascii_case(b"http");
75 }
76 req.uri().scheme() == Some(&http::uri::Scheme::HTTP)
77 }
78
79 fn location<B>(&self, req: &Request<B>) -> Option<HeaderValue> {
80 let host = req.headers().get(HOST)?.to_str().ok()?;
81
82 let hostname = host
83 .rsplit_once(':')
84 .filter(|(_, port)| port.parse::<u16>().is_ok())
85 .map_or(host, |(h, _)| h);
86
87 let authority = match self.https_port {
88 Some(port) => format!("{hostname}:{port}"),
89 None => hostname.to_owned(),
90 };
91
92 let path_and_query = req
93 .uri()
94 .path_and_query()
95 .map(|pq| pq.as_str())
96 .unwrap_or("/");
97
98 HeaderValue::from_str(&format!("https://{authority}{path_and_query}")).ok()
99 }
100}
101
102#[derive(Clone, Debug, Default)]
119pub struct RedirectHttpsLayer {
120 config: RedirectHttps,
121}
122
123impl RedirectHttpsLayer {
124 pub fn new(config: RedirectHttps) -> Self {
126 Self { config }
127 }
128}
129
130impl<S> Layer<S> for RedirectHttpsLayer {
131 type Service = RedirectHttpsService<S>;
132
133 fn layer(&self, inner: S) -> Self::Service {
134 RedirectHttpsService {
135 inner,
136 config: self.config.clone(),
137 }
138 }
139}
140
141#[derive(Clone, Debug)]
143pub struct RedirectHttpsService<S> {
144 inner: S,
145 config: RedirectHttps,
146}
147
148impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RedirectHttpsService<S>
149where
150 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
151 S::Future: Send + 'static,
152 S::Error: Send + 'static,
153 ResBody: Default + Send + 'static,
154{
155 type Response = Response<ResBody>;
156 type Error = S::Error;
157 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
158
159 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160 self.inner.poll_ready(cx)
161 }
162
163 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
164 if RedirectHttps::is_http(&req)
165 && let Some(location) = self.config.location(&req)
166 {
167 let status = self.config.status;
168 return Box::pin(async move {
169 let mut response = Response::builder()
170 .status(status)
171 .body(ResBody::default())
172 .expect("redirect response is valid");
173 response.headers_mut().insert(LOCATION, location);
174 Ok(response)
175 });
176 }
177 Box::pin(self.inner.call(req))
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::StatusCode;
    use tower::ServiceExt;

    /// Router with a single `GET /` handler wrapped in the redirect layer.
    fn build_app(config: RedirectHttps) -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RedirectHttpsLayer::new(config))
    }

    /// Drives a single request through `app`.
    async fn send(app: Router, req: http::Request<Body>) -> http::Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    /// Request for `uri` with `Host: example.com` and the given `X-Forwarded-Proto`.
    fn forwarded_request(proto: &str, uri: &str) -> http::Request<Body> {
        http::Request::builder()
            .uri(uri)
            .header(HOST, "example.com")
            .header("x-forwarded-proto", proto)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn redirects_on_x_forwarded_proto_http() {
        let response = send(
            build_app(RedirectHttps::new()),
            forwarded_request("http", "/path?q=1"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(
            response.headers()["location"],
            "https://example.com/path?q=1"
        );
    }

    #[tokio::test]
    async fn passes_through_on_x_forwarded_proto_https() {
        let response = send(
            build_app(RedirectHttps::new()),
            forwarded_request("https", "/"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn forwarded_proto_is_case_insensitive() {
        let response = send(
            build_app(RedirectHttps::new()),
            forwarded_request("HTTP", "/"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "https://example.com/");
    }

    #[tokio::test]
    async fn forwarded_proto_takes_precedence_over_uri_scheme() {
        let req = http::Request::builder()
            .uri("http://example.com/")
            .header(HOST, "example.com")
            .header("x-forwarded-proto", "https")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn redirects_on_http_uri_scheme() {
        let req = http::Request::builder()
            .uri("http://example.com/page")
            .header(HOST, "example.com")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(RedirectHttps::new()), req).await;

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "https://example.com/page");
    }

    #[tokio::test]
    async fn passes_through_on_https_uri_scheme() {
        let req = http::Request::builder()
            .uri("https://example.com/")
            .header(HOST, "example.com")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn passes_through_when_no_scheme_indicator() {
        let req = http::Request::builder()
            .uri("/")
            .header(HOST, "example.com")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn passes_through_when_no_host_header() {
        let req = http::Request::builder()
            .uri("/")
            .header("x-forwarded-proto", "http")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_status_301() {
        let config = RedirectHttps::new().status(StatusCode::MOVED_PERMANENTLY);
        let response = send(build_app(config), forwarded_request("http", "/")).await;
        assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);
    }

    #[tokio::test]
    async fn strips_http_port_from_host() {
        let req = http::Request::builder()
            .uri("/path")
            .header(HOST, "example.com:80")
            .header("x-forwarded-proto", "http")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.headers()["location"], "https://example.com/path");
    }

    #[tokio::test]
    async fn keeps_ipv6_brackets_when_stripping_port() {
        let req = http::Request::builder()
            .uri("/path")
            .header(HOST, "[::1]:8080")
            .header("x-forwarded-proto", "http")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(RedirectHttps::new()), req).await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "https://[::1]/path");
    }

    #[tokio::test]
    async fn custom_https_port() {
        let config = RedirectHttps::new().https_port(8443);
        let req = http::Request::builder()
            .uri("/path")
            .header(HOST, "example.com:8080")
            .header("x-forwarded-proto", "http")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(config), req).await;
        assert_eq!(
            response.headers()["location"],
            "https://example.com:8443/path"
        );
    }

    #[tokio::test]
    async fn default_layer_uses_308() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RedirectHttpsLayer::default());
        let response = send(app, forwarded_request("http", "/")).await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
    }
}