use std::future::Future;
use std::pin::Pin;
use axum::Router;
use axum::body::Body;
use axum::http::{Method, Request, Response, StatusCode, Uri, header};
use tower::Service;
/// Policy for finding an alternate request path that differs from the
/// original only by a trailing slash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SlashRedirect {
    /// Never propose an alternate path (the default).
    #[default]
    Off,
    /// Propose the slashed form: `/foo` -> `/foo/`.
    Append,
    /// Propose the slash-less form: `/foo/` (or `/foo//`) -> `/foo`.
    Strip,
}

impl SlashRedirect {
    /// Returns the path to try instead of `path` under this policy, or `None`
    /// when the policy does not apply.
    ///
    /// - `Off` always returns `None`.
    /// - `Append` returns `path` with one `/` appended, unless `path` already
    ///   ends in `/` (this includes the root `/`).
    /// - `Strip` returns `path` with every trailing `/` removed. It returns
    ///   `None` when `path` has no trailing slash, or when the path consists
    ///   only of slashes (`/`, `//`, ...): stripping those would leave an
    ///   empty string, which is not a usable request path.
    pub fn alternate_path(&self, path: &str) -> Option<String> {
        match self {
            SlashRedirect::Off => None,
            // The root "/" ends with '/', so it is excluded here automatically.
            SlashRedirect::Append => (!path.ends_with('/')).then(|| format!("{path}/")),
            SlashRedirect::Strip => {
                let trimmed = path.trim_end_matches('/');
                // Same length => there was nothing to strip.
                // Empty => the path was all slashes; never propose "".
                if trimmed.len() == path.len() || trimmed.is_empty() {
                    None
                } else {
                    Some(trimmed.to_owned())
                }
            }
        }
    }
}
/// Builds an axum fallback handler that redirects requests differing from a
/// route in `snapshot` only by a trailing slash.
///
/// For each request the handler:
/// 1. asks `policy` for an alternate path (see [`SlashRedirect::alternate_path`]);
/// 2. probes `snapshot` with a `GET` to that path, keeping the original query;
/// 3. if the probe does not answer `404 Not Found`, replies
///    `308 Permanent Redirect` with `Location` set to the alternate path plus
///    the original query.
///
/// In every other case it renders the not-found response via
/// `crate::errors::render_not_found`, passing `not_found_template` and the
/// original path. "Every other case" means:
/// - the policy does not apply;
/// - the target is a protocol-relative URL;
/// - the target does not parse as a URI;
/// - the probe errors or returns 404;
/// - the target cannot be encoded as a `Location` header value.
///
/// NOTE(review): the probe runs the target's handler in full and discards
/// the response, so any cost or side effects of that `GET` handler are
/// incurred on every probed request.
pub fn slash_redirect_fallback(
    snapshot: Router,
    policy: SlashRedirect,
    not_found_template: Option<String>,
) -> impl Fn(Request<Body>) -> Pin<Box<dyn Future<Output = Response<Body>> + Send>>
       + Clone
       + Send
       + Sync
       + 'static {
    move |req: Request<Body>| {
        // Per-request owned copies so the boxed future can be `'static`.
        let snapshot = snapshot.clone();
        let template = not_found_template.clone();
        Box::pin(async move {
            let original_path = req.uri().path().to_owned();
            let query = req
                .uri()
                .query()
                .map(|q| format!("?{q}"))
                .unwrap_or_default();
            let default_404 =
                || crate::errors::render_not_found(template.as_deref(), &original_path);

            let Some(alt) = policy.alternate_path(&original_path) else {
                return default_404();
            };

            // A target starting with `//` (or `/\`, which browsers normalise to
            // `//`) would be read as a protocol-relative URL in `Location`,
            // sending the client to another host. Never redirect to one.
            if alt.starts_with("//") || alt.starts_with("/\\") {
                return default_404();
            }

            // Built once; used for both the probe URI and the `Location` header.
            let target = format!("{alt}{query}");
            let Ok(alt_uri) = target.parse::<Uri>() else {
                return default_404();
            };

            // Probe with GET rather than the original method: the probe really
            // executes the handler, so replaying e.g. a POST would trigger its
            // side effects. Any non-404 answer (including 405) means a route
            // exists at the alternate path.
            let Ok(probe_req) = Request::builder()
                .method(Method::GET)
                .uri(alt_uri)
                .body(Body::empty())
            else {
                return default_404();
            };

            let mut probe_service = snapshot;
            // Readiness is not expected to fail for a `Router`; the result is
            // deliberately ignored and `call` below handles any error.
            std::future::poll_fn(|cx| {
                <Router as Service<Request<Body>>>::poll_ready(&mut probe_service, cx)
            })
            .await
            .ok();
            let probe_resp =
                match <Router as Service<Request<Body>>>::call(&mut probe_service, probe_req).await
                {
                    Ok(r) => r,
                    Err(_) => return default_404(),
                };
            if probe_resp.status() == StatusCode::NOT_FOUND {
                return default_404();
            }

            // A 308 without `Location` cannot be followed by the client, so
            // fall back to 404 rather than emit one.
            let Ok(location) = target.parse::<header::HeaderValue>() else {
                return default_404();
            };
            let mut redirect = Response::new(Body::empty());
            *redirect.status_mut() = StatusCode::PERMANENT_REDIRECT;
            redirect.headers_mut().insert(header::LOCATION, location);
            redirect
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `policy.alternate_path(input)` against each `(input, expected)`
    /// pair, naming the offending input when an assertion fails.
    fn check(policy: SlashRedirect, cases: &[(&str, Option<&str>)]) {
        for &(input, expected) in cases {
            let actual = policy.alternate_path(input);
            assert_eq!(
                actual.as_deref(),
                expected,
                "{policy:?}.alternate_path({input:?})"
            );
        }
    }

    #[test]
    fn alternate_path_off_never_returns_anything() {
        check(
            SlashRedirect::Off,
            &[("/foo", None), ("/foo/", None), ("/", None)],
        );
    }

    #[test]
    fn alternate_path_append_adds_trailing_slash() {
        check(
            SlashRedirect::Append,
            &[
                ("/foo", Some("/foo/")),
                ("/api/articles", Some("/api/articles/")),
            ],
        );
    }

    #[test]
    fn alternate_path_append_skips_already_slashed() {
        check(SlashRedirect::Append, &[("/foo/", None), ("/", None)]);
    }

    #[test]
    fn alternate_path_strip_removes_trailing_slash() {
        check(
            SlashRedirect::Strip,
            &[
                ("/foo/", Some("/foo")),
                ("/api/articles/", Some("/api/articles")),
            ],
        );
    }

    #[test]
    fn alternate_path_strip_skips_slashless_and_root() {
        check(SlashRedirect::Strip, &[("/foo", None), ("/", None)]);
    }
}