use async_trait::async_trait;
use std::sync::Arc;
use reinhardt_http::{AuthState, Handler, Middleware, Request, Response, Result};
/// Default URL that unauthenticated users are redirected to.
pub const DEFAULT_LOGIN_URL: &str = "/accounts/login/";
/// Default name of the query parameter that carries the originally requested
/// location to the login page.
pub const DEFAULT_REDIRECT_FIELD_NAME: &str = "next";

/// Configuration for `LoginRequiredMiddleware`.
///
/// Path rules (used for both `login_url` and `exempt_paths`): a rule ending in
/// `/` matches every request path that starts with it; any other rule matches
/// only the identical path. An empty rule therefore never matches a request path.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LoginRequiredConfig {
    /// URL unauthenticated requests are redirected to. Its path part is always
    /// exempt so that the login page itself stays reachable.
    pub login_url: String,
    /// Query parameter name used to pass the original location to the login page.
    pub redirect_field_name: String,
    /// Additional path rules that do not require authentication.
    pub exempt_paths: Vec<String>,
}

impl Default for LoginRequiredConfig {
    fn default() -> Self {
        Self {
            login_url: DEFAULT_LOGIN_URL.to_string(),
            redirect_field_name: DEFAULT_REDIRECT_FIELD_NAME.to_string(),
            exempt_paths: Vec::new(),
        }
    }
}

impl LoginRequiredConfig {
    /// Creates a configuration with the default login URL and redirect field
    /// name and no extra exempt paths.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the URL unauthenticated users are redirected to.
    #[must_use]
    pub fn with_login_url(mut self, url: &str) -> Self {
        self.login_url = url.to_string();
        self
    }

    /// Sets the query parameter name that carries the original location.
    #[must_use]
    pub fn with_redirect_field_name(mut self, name: &str) -> Self {
        self.redirect_field_name = name.to_string();
        self
    }

    /// Adds one exempt path rule (trailing `/` = prefix match, otherwise exact).
    #[must_use]
    pub fn with_exempt_path(mut self, path: &str) -> Self {
        self.exempt_paths.push(path.to_string());
        self
    }

    /// Adds several exempt path rules at once.
    #[must_use]
    pub fn with_exempt_paths(mut self, paths: &[&str]) -> Self {
        self.exempt_paths.extend(paths.iter().map(|p| p.to_string()));
        self
    }

    /// Returns `true` if requests to `path` bypass the authentication check.
    ///
    /// The login URL is always exempt, because otherwise anonymous users could
    /// never reach it. Only its path part (before any `?` or `#`) is compared, and
    /// it follows the same trailing-slash rule as `exempt_paths`. This means a
    /// login URL such as `/login` does NOT exempt `/login-admin`, and an empty
    /// login URL exempts nothing.
    fn is_exempt(&self, path: &str) -> bool {
        let login_path = self
            .login_url
            .split(|c: char| c == '?' || c == '#')
            .next()
            .unwrap_or_default();
        Self::path_matches(login_path, path)
            || self
                .exempt_paths
                .iter()
                .any(|rule| Self::path_matches(rule, path))
    }

    /// Matches `path` against one rule: prefix match when the rule ends with
    /// `/`, exact match otherwise.
    fn path_matches(rule: &str, path: &str) -> bool {
        if rule.ends_with('/') {
            path.starts_with(rule)
        } else {
            path == rule
        }
    }
}
/// Middleware that redirects unauthenticated requests to a login page.
///
/// Requests whose path is exempt according to the [`LoginRequiredConfig`] are
/// passed straight to the next handler. Any other request must carry an
/// [`AuthState`] extension whose `is_authenticated()` is `true`; a missing
/// `AuthState` is treated as anonymous. Anonymous requests get a temporary (302)
/// redirect to the configured login URL. The originally requested path (plus
/// query string, if any) is carried in the configured redirect field.
pub struct LoginRequiredMiddleware {
    config: LoginRequiredConfig,
}

impl LoginRequiredMiddleware {
    /// Creates the middleware with the given configuration.
    pub fn new(config: LoginRequiredConfig) -> Self {
        Self { config }
    }

    /// Builds the login redirect target, carrying `original_location` in the
    /// configured redirect field.
    ///
    /// The field name and value are percent-encoded, so characters such as `&`,
    /// `=`, `#` or `?` in the requested location cannot inject extra query
    /// parameters (e.g. a second `next`) into the login URL. If the login URL
    /// already has a query string, the field is appended with `&`. The login view
    /// must still validate the decoded value before redirecting to it.
    fn build_redirect_url(&self, original_location: &str) -> String {
        let login_url = self.config.login_url.as_str();
        let separator = if login_url.contains('?') { "&" } else { "?" };
        format!(
            "{}{}{}={}",
            login_url,
            separator,
            encode_query_component(&self.config.redirect_field_name),
            encode_query_component(original_location),
        )
    }
}

/// Percent-encodes `value` for use inside a URL query component.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - . _ ~`) and `/` are kept as-is
/// (so plain paths stay readable, e.g. `next=/dashboard`); every other byte of
/// the UTF-8 encoding is written as `%XX`.
fn encode_query_component(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' | b'/' => {
                encoded.push(char::from(byte));
            }
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}

impl Default for LoginRequiredMiddleware {
    fn default() -> Self {
        Self::new(LoginRequiredConfig::default())
    }
}

#[async_trait]
impl Middleware for LoginRequiredMiddleware {
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        // Exemptions are decided on the path alone; the query string is irrelevant.
        if self.config.is_exempt(request.uri.path()) {
            return next.handle(request).await;
        }
        let is_authenticated = request
            .extensions
            .get::<AuthState>()
            .map(|s| s.is_authenticated())
            .unwrap_or(false);
        if !is_authenticated {
            // Preserve the query string so the user lands on the exact page
            // they asked for after logging in.
            let location = match request.uri.query() {
                Some(query) => format!("{}?{}", request.uri.path(), query),
                None => request.uri.path().to_string(),
            };
            let redirect_url = self.build_redirect_url(&location);
            return Ok(Response::temporary_redirect(&redirect_url));
        }
        next.handle(request).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use reinhardt_http::{AuthState, Handler, Middleware, Request, Response};
    use rstest::rstest;

    /// Terminal handler that always answers `200 OK`, so a request that makes
    /// it past the middleware is observable as `StatusCode::OK`.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok())
        }
    }

    /// Builds a bare HTTP/1.1 `GET` request for `path`, optionally tagged with
    /// an `AuthState` extension.
    fn create_request(path: &str, auth_state: Option<AuthState>) -> Request {
        let request = Request::builder()
            .method(Method::GET)
            .uri(path)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap();
        if let Some(state) = auth_state {
            request.extensions.insert(state);
        }
        request
    }

    /// Sends one request for `path` through `middleware` (backed by
    /// `TestHandler`) and returns the resulting response.
    async fn run(
        middleware: &LoginRequiredMiddleware,
        path: &str,
        auth_state: Option<AuthState>,
    ) -> Response {
        let terminal: Arc<dyn Handler> = Arc::new(TestHandler);
        middleware
            .process(create_request(path, auth_state), terminal)
            .await
            .unwrap()
    }

    /// Extracts the `Location` header of a redirect response.
    fn location(response: &Response) -> &str {
        response.headers.get("Location").unwrap().to_str().unwrap()
    }

    #[rstest]
    #[tokio::test]
    async fn test_authenticated_user_passes_through() {
        let response = run(
            &LoginRequiredMiddleware::default(),
            "/dashboard",
            Some(AuthState::authenticated("user-1", false, true)),
        )
        .await;
        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_unauthenticated_user_gets_redirected() {
        let response = run(
            &LoginRequiredMiddleware::default(),
            "/dashboard",
            Some(AuthState::anonymous()),
        )
        .await;
        assert_eq!(response.status, StatusCode::FOUND);
        assert_eq!(location(&response), "/accounts/login/?next=/dashboard");
    }

    #[rstest]
    #[tokio::test]
    async fn test_no_auth_state_gets_redirected() {
        let response = run(&LoginRequiredMiddleware::default(), "/dashboard", None).await;
        assert_eq!(response.status, StatusCode::FOUND);
    }

    #[rstest]
    #[tokio::test]
    async fn test_login_url_is_exempt() {
        let response = run(
            &LoginRequiredMiddleware::default(),
            "/accounts/login/",
            Some(AuthState::anonymous()),
        )
        .await;
        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_exempt_path_prefix_match() {
        let middleware =
            LoginRequiredMiddleware::new(LoginRequiredConfig::new().with_exempt_path("/api/public/"));
        let response = run(&middleware, "/api/public/health", Some(AuthState::anonymous())).await;
        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_exempt_path_exact_match() {
        let middleware =
            LoginRequiredMiddleware::new(LoginRequiredConfig::new().with_exempt_path("/health"));
        let response = run(&middleware, "/health", Some(AuthState::anonymous())).await;
        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_exempt_path_exact_no_prefix() {
        let middleware =
            LoginRequiredMiddleware::new(LoginRequiredConfig::new().with_exempt_path("/health"));
        let response = run(&middleware, "/health/detail", Some(AuthState::anonymous())).await;
        assert_eq!(response.status, StatusCode::FOUND);
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_login_url_and_redirect_field() {
        let middleware = LoginRequiredMiddleware::new(
            LoginRequiredConfig::new()
                .with_login_url("/auth/signin/")
                .with_redirect_field_name("return_to"),
        );
        let response = run(&middleware, "/dashboard", Some(AuthState::anonymous())).await;
        assert_eq!(response.status, StatusCode::FOUND);
        assert_eq!(location(&response), "/auth/signin/?return_to=/dashboard");
    }

    #[rstest]
    #[tokio::test]
    async fn test_multiple_exempt_paths() {
        let middleware = LoginRequiredMiddleware::new(
            LoginRequiredConfig::new().with_exempt_paths(&["/api/", "/health", "/static/"]),
        );
        let cases = [
            ("/api/users", StatusCode::OK),
            ("/health", StatusCode::OK),
            ("/static/css/main.css", StatusCode::OK),
            ("/dashboard", StatusCode::FOUND),
        ];
        for (path, expected) in cases {
            let response = run(&middleware, path, Some(AuthState::anonymous())).await;
            assert_eq!(response.status, expected, "unexpected status for {}", path);
        }
    }
}