use std::collections::HashMap;
use std::sync::Arc;
use axum::body::Body;
use axum::extract::Request;
use axum::http::{header, HeaderValue, StatusCode};
use axum::middleware::Next;
use axum::response::Response;
/// Settings shared by the login-redirect middleware in this module.
#[derive(Debug, Clone)]
pub struct LoginRequiredConfig {
    /// URL that rejected requests are redirected to.
    pub login_url: String,
    /// Name of the query parameter that carries the originally requested
    /// path-and-query on the redirect.
    pub redirect_field: String,
}

impl Default for LoginRequiredConfig {
    /// Returns the stock configuration: redirect to `/login`, carrying the
    /// original target in the `next` query parameter.
    fn default() -> Self {
        let login_url = String::from("/login");
        let redirect_field = String::from("next");
        Self {
            login_url,
            redirect_field,
        }
    }
}
/// Builds a middleware layer that only forwards requests carrying a session user.
///
/// Requests without a resolvable session user are answered by
/// `handle_login_required` with a redirect to `login_url`. The redirect
/// parameter name is taken from [`LoginRequiredConfig::default`].
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn login_required(
    login_url: impl Into<String>,
) -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone {
    let login_url = login_url.into();
    let shared = Arc::new(LoginRequiredConfig {
        login_url,
        ..LoginRequiredConfig::default()
    });
    axum::middleware::from_fn(move |request: Request<Body>, next: Next| {
        // The async block takes ownership of its own handle on the config.
        let settings = Arc::clone(&shared);
        async move { handle_login_required(settings, request, next).await }
    })
}
/// Builds a middleware layer that forwards a request only when a session user
/// exists and `predicate` returns `true` for it.
///
/// Every other request (no user, or predicate returned `false`) is answered by
/// `handle_user_passes_test` with a redirect to `login_url`. The redirect
/// parameter name is taken from [`LoginRequiredConfig::default`].
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn user_passes_test<F>(
    login_url: impl Into<String>,
    predicate: F,
) -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone
where
    F: Fn(&crate::tenancy::auth::User) -> bool + Send + Sync + 'static,
{
    let login_url = login_url.into();
    let shared_cfg = Arc::new(LoginRequiredConfig {
        login_url,
        ..LoginRequiredConfig::default()
    });
    let shared_pred = Arc::new(predicate);
    axum::middleware::from_fn(move |request: Request<Body>, next: Next| {
        // Fresh handles per request; the async block owns them.
        let settings = Arc::clone(&shared_cfg);
        let check = Arc::clone(&shared_pred);
        async move { handle_user_passes_test(settings, check, request, next).await }
    })
}
/// Request handler behind [`user_passes_test`].
///
/// Resolves the session user from the request head. When a user is present and
/// `pred` accepts it, the request is reassembled and passed to `next`.
/// Otherwise the caller is redirected to `cfg.login_url`, with the original
/// path-and-query (or `/` when the URI has none) in `cfg.redirect_field`.
/// A failed `SessionUser` extraction is treated the same as "no user".
#[cfg(all(feature = "tenancy", feature = "postgres"))]
async fn handle_user_passes_test<F>(
    cfg: Arc<LoginRequiredConfig>,
    pred: Arc<F>,
    req: Request<Body>,
    next: Next,
) -> Response
where
    F: Fn(&crate::tenancy::auth::User) -> bool + Send + Sync + 'static,
{
    use axum::extract::FromRequestParts as _;
    use crate::extractors::SessionUser;

    let (mut head, payload) = req.into_parts();
    let session = match SessionUser::from_request_parts(&mut head, &()).await {
        Ok(found) => found,
        Err(_) => SessionUser(None),
    };

    // The predicate is only consulted when there is a user to show it.
    let allowed = match session.0.as_ref() {
        Some(user) => pred(user),
        None => false,
    };
    if allowed {
        return next.run(Request::from_parts(head, payload)).await;
    }

    let original = match head.uri.path_and_query() {
        Some(target) => target.as_str().to_owned(),
        None => "/".to_owned(),
    };
    redirect_to_login(&cfg.login_url, &cfg.redirect_field, &original)
}
/// Builds a middleware layer that answers with status codes instead of redirects.
///
/// The work is done by `handle_user_passes_test_or_403`: no session user gives
/// `401`, a user rejected by `predicate` gives `403`, and an accepted user
/// reaches the inner service.
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn user_passes_test_or_403<F>(
    predicate: F,
) -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone
where
    F: Fn(&crate::tenancy::auth::User) -> bool + Send + Sync + 'static,
{
    let shared_pred = Arc::new(predicate);
    axum::middleware::from_fn(move |request: Request<Body>, next: Next| {
        // Fresh handle per request; the async block owns it.
        let check = Arc::clone(&shared_pred);
        async move { handle_user_passes_test_or_403(check, request, next).await }
    })
}
/// Request handler behind [`user_passes_test_or_403`].
///
/// Outcomes, in order of evaluation:
/// - no session user (including a failed `SessionUser` extraction): empty `401`;
/// - a user that `pred` rejects: empty `403`;
/// - a user that `pred` accepts: the request is reassembled and passed to `next`.
#[cfg(all(feature = "tenancy", feature = "postgres"))]
async fn handle_user_passes_test_or_403<F>(pred: Arc<F>, req: Request<Body>, next: Next) -> Response
where
    F: Fn(&crate::tenancy::auth::User) -> bool + Send + Sync + 'static,
{
    use axum::extract::FromRequestParts as _;
    use crate::extractors::SessionUser;

    let (mut head, payload) = req.into_parts();
    let session = match SessionUser::from_request_parts(&mut head, &()).await {
        Ok(found) => found,
        Err(_) => SessionUser(None),
    };

    let user = match session.0.as_ref() {
        Some(user) => user,
        None => {
            return Response::builder()
                .status(StatusCode::UNAUTHORIZED)
                .body(Body::empty())
                .expect("401 + empty body is always valid");
        }
    };

    if !pred(user) {
        return Response::builder()
            .status(StatusCode::FORBIDDEN)
            .body(Body::empty())
            .expect("403 + empty body is always valid");
    }

    next.run(Request::from_parts(head, payload)).await
}
/// Layer that requires a session user and answers `401` when there is none.
///
/// Implemented as [`user_passes_test_or_403`] with a predicate that accepts
/// every user, so the `403` branch of that layer is never taken.
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn login_required_or_401() -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone {
    /// Accepts any user; being logged in is the only requirement.
    fn accept_any(_user: &crate::tenancy::auth::User) -> bool {
        true
    }

    user_passes_test_or_403(accept_any)
}
/// Layer that admits only users with both `is_superuser` and `active` set,
/// redirecting everyone else to `login_url` via [`user_passes_test`].
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn superuser_required(
    login_url: impl Into<String>,
) -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone {
    /// True for users flagged as superuser that are also active.
    fn is_active_superuser(user: &crate::tenancy::auth::User) -> bool {
        user.is_superuser && user.active
    }

    user_passes_test(login_url, is_active_superuser)
}
/// Status-code variant of the superuser check, built on
/// [`user_passes_test_or_403`]: `401` without a session user, `403` for a user
/// that is not both `is_superuser` and `active`.
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn superuser_required_or_403() -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone {
    /// True for users flagged as superuser that are also active.
    fn is_active_superuser(user: &crate::tenancy::auth::User) -> bool {
        user.is_superuser && user.active
    }

    user_passes_test_or_403(is_active_superuser)
}
/// Layer that admits only users whose `active` flag is set, redirecting
/// everyone else to `login_url` via [`user_passes_test`].
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn active_required(
    login_url: impl Into<String>,
) -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone {
    /// True for users whose `active` flag is set.
    fn is_active(user: &crate::tenancy::auth::User) -> bool {
        user.active
    }

    user_passes_test(login_url, is_active)
}
/// Status-code variant of the active-user check, built on
/// [`user_passes_test_or_403`]: `401` without a session user, `403` for a user
/// whose `active` flag is not set.
#[cfg(all(feature = "tenancy", feature = "postgres"))]
pub fn active_required_or_403() -> impl tower::Layer<
    axum::routing::Route,
    Service = impl tower::Service<
        Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + Sync
                  + 'static,
> + Clone {
    /// True for users whose `active` flag is set.
    fn is_active(user: &crate::tenancy::auth::User) -> bool {
        user.active
    }

    user_passes_test_or_403(is_active)
}
/// Request handler behind [`login_required`].
///
/// Resolves the session user from the request head. Without one (including a
/// failed `SessionUser` extraction) the caller is redirected to
/// `cfg.login_url`, with the original path-and-query (or `/` when the URI has
/// none) in `cfg.redirect_field`. With one, the request is reassembled and
/// passed to `next`.
#[cfg(all(feature = "tenancy", feature = "postgres"))]
async fn handle_login_required(
    cfg: Arc<LoginRequiredConfig>,
    req: Request<Body>,
    next: Next,
) -> Response {
    use axum::extract::FromRequestParts as _;
    use crate::extractors::SessionUser;

    let (mut head, payload) = req.into_parts();
    let session = match SessionUser::from_request_parts(&mut head, &()).await {
        Ok(found) => found,
        Err(_) => SessionUser(None),
    };

    if session.0.is_none() {
        let original = match head.uri.path_and_query() {
            Some(target) => target.as_str().to_owned(),
            None => "/".to_owned(),
        };
        return redirect_to_login(&cfg.login_url, &cfg.redirect_field, &original);
    }

    next.run(Request::from_parts(head, payload)).await
}
/// Builds an empty-bodied `302 Found` response that sends the client to the
/// login page.
///
/// The target is assembled by `build_login_url` from `login_url`,
/// `redirect_field` and `original`. If the assembled target cannot be
/// represented as a header value, the response is still returned but carries
/// no `Location` header.
#[must_use]
pub fn redirect_to_login(login_url: &str, redirect_field: &str, original: &str) -> Response {
    let location = build_login_url(login_url, redirect_field, original);
    let parsed = HeaderValue::from_str(&location);

    let mut response = Response::builder()
        .status(StatusCode::FOUND)
        .body(Body::empty())
        .expect("302 + empty body is always valid");

    if let Ok(value) = parsed {
        response.headers_mut().insert(header::LOCATION, value);
    }
    response
}
/// Appends `redirect_field=<encoded original>` to `login_url`.
///
/// `original` is passed through `crate::url_codec::url_encode`. The pair is
/// joined with `&` when `login_url` already contains a `?`, and with `?`
/// otherwise. `redirect_field` is inserted as given.
fn build_login_url(login_url: &str, redirect_field: &str, original: &str) -> String {
    let separator = match login_url.find('?') {
        Some(_) => '&',
        None => '?',
    };
    let escaped = crate::url_codec::url_encode(original);
    format!("{login_url}{separator}{redirect_field}={escaped}")
}
/// Reads the post-login target from the `next` entry of `query`.
///
/// Shorthand for [`extract_next_named`] with the field name `"next"`.
#[must_use]
pub fn extract_next(query: &HashMap<String, String>) -> Option<String> {
    const FIELD: &str = "next";
    extract_next_named(query, FIELD)
}
/// Reads the post-login target from the `field` entry of `query`.
///
/// Returns the value with surrounding whitespace removed, or `None` when the
/// entry is missing or contains nothing but whitespace. The value is not
/// validated here.
#[must_use]
pub fn extract_next_named(query: &HashMap<String, String>, field: &str) -> Option<String> {
    let raw = query.get(field)?;
    let value = raw.trim();
    if value.is_empty() {
        None
    } else {
        Some(value.to_owned())
    }
}
/// Validates a user-supplied post-login target.
///
/// The input is trimmed, and an empty result yields `None`. The trimmed value
/// is accepted only when both its percent-decoded form (via
/// `crate::url_codec::url_decode`) and its raw form pass `is_safe_path`.
/// Checking the decoded form first catches encoded prefixes such as `%2F%2F`.
/// On success the trimmed, still-encoded value is returned.
#[must_use]
pub fn safe_next(next: &str) -> Option<String> {
    let candidate = next.trim();
    if candidate.is_empty() {
        return None;
    }

    let decoded = crate::url_codec::url_decode(candidate);
    let acceptable = is_safe_path(&decoded) && is_safe_path(candidate);
    acceptable.then(|| candidate.to_owned())
}
/// Reports whether `s` is a root-relative path that cannot be read as a URL
/// pointing at another host.
///
/// Returns `false` for:
/// - anything that does not start with `/` (absolute URLs, relative paths,
///   the empty string);
/// - `//host…`, a scheme-relative URL;
/// - `/\host…`, since browsers treat `\` like `/` in http(s) URLs;
/// - either of those prefixes with ASCII tab, LF or CR placed between the two
///   characters, e.g. `"/\t/host"`. Browsers delete those characters before
///   parsing a URL, so such a value would navigate to `//host`.
///
/// Returns `true` for everything else that starts with `/`.
fn is_safe_path(s: &str) -> bool {
    let rest = match s.strip_prefix('/') {
        Some(rest) => rest,
        None => return false,
    };
    // Look at the first character a browser would actually see after the
    // leading slash, i.e. skip the characters it strips from URLs.
    let second = rest.chars().find(|c| !matches!(c, '\t' | '\n' | '\r'));
    !matches!(second, Some('/' | '\\'))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the `Location` header of `res` as text; panics when the header
    /// is missing or is not visible ASCII.
    fn location_of(res: &Response) -> &str {
        res.headers()
            .get(header::LOCATION)
            .and_then(|value| value.to_str().ok())
            .unwrap()
    }

    /// Sends a body-less request for `uri`, carrying no headers, through `app`
    /// and returns the response.
    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    async fn anonymous_get(app: axum::Router, uri: &str) -> Response {
        use tower::ServiceExt as _;

        let request = axum::http::Request::builder()
            .uri(uri)
            .body(Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap()
    }

    // ---- build_login_url ----

    #[test]
    fn build_login_url_appends_next_as_query() {
        let url = build_login_url("/login", "next", "/protected");
        assert_eq!(url, "/login?next=%2Fprotected");
    }

    #[test]
    fn build_login_url_uses_ampersand_when_login_already_has_query() {
        let url = build_login_url("/login?foo=1", "next", "/protected");
        assert_eq!(url, "/login?foo=1&next=%2Fprotected");
    }

    #[test]
    fn build_login_url_url_encodes_path_with_special_chars() {
        let url = build_login_url("/login", "next", "/posts/hello world?page=2");
        assert!(url.contains("%20"));
        // Any one of these shows the query part of the original was escaped.
        let query_escaped = ["%3F", "%26", "page%3D2"]
            .iter()
            .any(|needle| url.contains(*needle));
        assert!(query_escaped);
    }

    // ---- redirect_to_login ----

    #[test]
    fn redirect_to_login_returns_302_with_location() {
        let res = redirect_to_login("/login", "next", "/profile");
        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(location_of(&res), "/login?next=%2Fprofile");
    }

    #[test]
    fn redirect_to_login_drops_location_on_crlf_attempt() {
        // The header must be present, and must not carry the raw CR/LF bytes.
        let res = redirect_to_login("/login", "next", "/profile\r\nSet-Cookie: pwned=1");
        assert_eq!(res.status(), StatusCode::FOUND);
        let loc = location_of(&res);
        assert!(!loc.contains('\r'), "raw \\r in Location header: {loc}");
        assert!(!loc.contains('\n'), "raw \\n in Location header: {loc}");
    }

    // ---- extract_next / extract_next_named ----

    #[test]
    fn extract_next_returns_value_when_present() {
        let mut query = HashMap::new();
        query.insert("next".to_owned(), "/profile".to_owned());
        assert_eq!(extract_next(&query), Some("/profile".to_owned()));
    }

    #[test]
    fn extract_next_returns_none_when_absent_or_empty() {
        assert_eq!(extract_next(&HashMap::new()), None);

        let mut query = HashMap::new();
        for blank in ["", " "] {
            query.insert("next".to_owned(), blank.to_owned());
            assert_eq!(extract_next(&query), None);
        }
    }

    #[test]
    fn extract_next_named_uses_custom_field() {
        let mut query = HashMap::new();
        query.insert("redirect_to".to_owned(), "/profile".to_owned());
        assert_eq!(extract_next(&query), None);
        assert_eq!(
            extract_next_named(&query, "redirect_to"),
            Some("/profile".to_owned())
        );
    }

    // ---- safe_next ----

    #[test]
    fn safe_next_accepts_root_relative_paths() {
        for path in ["/profile", "/posts/42", "/search?q=hello"] {
            assert_eq!(safe_next(path), Some(path.to_owned()), "input: {path:?}");
        }
    }

    #[test]
    fn safe_next_rejects_absolute_urls() {
        for url in [
            "http://evil.example/x",
            "https://evil.example/x",
            "ftp://evil.example/",
            "//evil.example/x",
        ] {
            assert_eq!(safe_next(url), None, "input: {url:?}");
        }
    }

    #[test]
    fn safe_next_rejects_backslash_variant() {
        assert_eq!(safe_next("/\\evil.example/x"), None);
    }

    #[test]
    fn safe_next_rejects_percent_encoded_bypass() {
        for attempt in [
            "%2F%2Fevil.example/x",
            "%2f%2fevil.example/x",
            "/%2Fevil.example/x",
            "/%5Cevil.example/x",
        ] {
            assert_eq!(safe_next(attempt), None, "input: {attempt:?}");
        }
    }

    #[test]
    fn safe_next_accepts_legitimate_percent_encodes_in_path() {
        assert_eq!(
            safe_next("/profile/hello%20world"),
            Some("/profile/hello%20world".to_owned())
        );
    }

    #[test]
    fn safe_next_rejects_empty_and_whitespace() {
        for blank in ["", " ", "\t\n"] {
            assert_eq!(safe_next(blank), None, "input: {blank:?}");
        }
    }

    #[test]
    fn safe_next_strips_surrounding_whitespace() {
        assert_eq!(safe_next(" /profile "), Some("/profile".to_owned()));
    }

    // ---- middleware layers (need the tenancy + postgres features) ----

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[tokio::test]
    async fn login_required_layer_redirects_anonymous_to_login_url() {
        use axum::routing::get;

        async fn protected() -> &'static str {
            "secret"
        }

        let app = axum::Router::new()
            .route("/profile", get(protected))
            .layer(login_required("/login"));
        let res = anonymous_get(app, "/profile").await;

        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(location_of(&res), "/login?next=%2Fprofile");
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[tokio::test]
    async fn user_passes_test_redirects_anonymous_to_login_url() {
        use axum::routing::get;

        async fn staff_only() -> &'static str {
            "staff zone"
        }

        let app = axum::Router::new()
            .route("/admin/dashboard", get(staff_only))
            .layer(user_passes_test("/login", |u| u.is_superuser));
        let res = anonymous_get(app, "/admin/dashboard").await;

        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(location_of(&res), "/login?next=%2Fadmin%2Fdashboard");
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[test]
    fn user_passes_test_signature_compiles_with_closure_predicate() {
        // Compile-only check: the closure is never invoked.
        let _ = || {
            let _layer = user_passes_test("/login", |u: &crate::tenancy::auth::User| {
                u.is_superuser && u.active
            });
        };
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[tokio::test]
    async fn user_passes_test_or_403_returns_401_for_anonymous() {
        use axum::routing::get;

        async fn admin_only() -> &'static str {
            "ok"
        }

        let app = axum::Router::new()
            .route("/api/admin/stats", get(admin_only))
            .layer(user_passes_test_or_403(|u| u.is_superuser));
        let res = anonymous_get(app, "/api/admin/stats").await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[test]
    fn login_required_or_401_signature_compiles() {
        // Compile-only check: the closure is never invoked.
        let _ = || {
            let _layer = login_required_or_401();
        };
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[tokio::test]
    async fn superuser_required_redirects_anonymous_to_login() {
        use axum::routing::get;

        async fn admin_only() -> &'static str {
            "admin zone"
        }

        let app = axum::Router::new()
            .route("/admin/dashboard", get(admin_only))
            .layer(superuser_required("/login"));
        let res = anonymous_get(app, "/admin/dashboard").await;

        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(location_of(&res), "/login?next=%2Fadmin%2Fdashboard");
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[tokio::test]
    async fn superuser_required_or_403_returns_401_for_anonymous() {
        use axum::routing::get;

        async fn admin_api() -> &'static str {
            "ok"
        }

        let app = axum::Router::new()
            .route("/api/admin", get(admin_api))
            .layer(superuser_required_or_403());
        let res = anonymous_get(app, "/api/admin").await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[tokio::test]
    async fn active_required_redirects_anonymous_to_login() {
        use axum::routing::get;

        async fn dashboard() -> &'static str {
            "dashboard"
        }

        let app = axum::Router::new()
            .route("/dashboard", get(dashboard))
            .layer(active_required("/login"));
        let res = anonymous_get(app, "/dashboard").await;

        assert_eq!(res.status(), StatusCode::FOUND);
        assert_eq!(location_of(&res), "/login?next=%2Fdashboard");
    }

    #[cfg(all(feature = "tenancy", feature = "postgres"))]
    #[tokio::test]
    async fn active_required_or_403_returns_401_for_anonymous() {
        use axum::routing::get;

        async fn me_api() -> &'static str {
            "ok"
        }

        let app = axum::Router::new()
            .route("/api/me", get(me_api))
            .layer(active_required_or_403());
        let res = anonymous_get(app, "/api/me").await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }
}