#![allow(deprecated)]
#[cfg(feature = "sessions")]
use async_trait::async_trait;
#[cfg(feature = "sessions")]
use std::sync::Arc;
#[cfg(feature = "sessions")]
use reinhardt_auth::{AuthenticationBackend, User};
#[cfg(feature = "sessions")]
use reinhardt_http::{
AuthState, Handler, IsActive, IsAdmin, IsAuthenticated, Middleware, Request, Response, Result,
};
#[cfg(feature = "sessions")]
/// Default name of the request header from which [`RemoteUserMiddleware`] and
/// [`PersistentRemoteUserMiddleware`] read the username. Override it with `with_header`.
///
/// NOTE(review): some reverse proxies drop request headers whose names contain `_`
/// (nginx does unless `underscores_in_headers` is enabled); if the header never arrives,
/// configure a hyphenated name via `with_header` — confirm against the deployment.
pub const REMOTE_USER_HEADER: &str = "REMOTE_USER";
#[cfg(feature = "sessions")]
/// Middleware that identifies the user from a username carried in a request header
/// and resolves it through an [`AuthenticationBackend`].
///
/// The header is read from the ordinary HTTP request headers, which a client can set
/// itself; only rely on this middleware when a trusted proxy in front of the
/// application strips or overwrites any client-supplied copy of the header.
pub struct RemoteUserMiddleware<A: AuthenticationBackend> {
    /// Name of the request header that carries the username.
    header_name: String,
    /// When `true`, a request arriving without the header is marked anonymous;
    /// when `false`, whatever authentication is already attached is left alone.
    force_logout_if_no_header: bool,
    /// Backend used to turn the header's username into a user.
    auth_backend: Arc<A>,
}
#[cfg(feature = "sessions")]
impl<A: AuthenticationBackend> RemoteUserMiddleware<A> {
    /// Creates a middleware that reads the username from [`REMOTE_USER_HEADER`] and
    /// marks the request anonymous whenever that header is missing.
    pub fn new(auth_backend: Arc<A>) -> Self {
        let header_name = String::from(REMOTE_USER_HEADER);
        Self {
            header_name,
            force_logout_if_no_header: true,
            auth_backend,
        }
    }

    /// Reads the username from `header_name` instead of the default header.
    pub fn with_header(self, header_name: &str) -> Self {
        Self {
            header_name: header_name.to_owned(),
            ..self
        }
    }

    /// Resolves `username` through the configured backend.
    ///
    /// Backend errors are collapsed into `None`, so callers cannot tell a failed
    /// lookup apart from an unknown user.
    async fn get_user_by_name(&self, username: &str) -> Option<Box<dyn User>> {
        match self.auth_backend.get_user(username).await {
            Ok(found) => found,
            Err(_) => None,
        }
    }

    /// Attaches `user` to the request extensions: its id, the legacy
    /// `IsAuthenticated` / `IsAdmin` / `IsActive` flags, and an [`AuthState`].
    ///
    /// The `AuthState` is anonymous when `user.is_authenticated()` is false, although
    /// the id and flags are inserted regardless.
    fn insert_user_extensions(request: &Request, user: &dyn User) {
        let authenticated = user.is_authenticated();
        let admin = user.is_admin();
        let active = user.is_active();
        let id = user.id();

        let extensions = &request.extensions;
        extensions.insert(id.clone());
        extensions.insert(IsAuthenticated(authenticated));
        extensions.insert(IsAdmin(admin));
        extensions.insert(IsActive(active));

        let state = if !authenticated {
            AuthState::anonymous()
        } else {
            AuthState::authenticated(id, admin, active)
        };
        extensions.insert(state);
    }
}
#[cfg(feature = "sessions")]
#[async_trait]
impl<A: AuthenticationBackend + 'static> Middleware for RemoteUserMiddleware<A> {
    /// Authenticates the request from the configured remote-user header, then
    /// forwards it to `next`.
    ///
    /// * Header absent: if `force_logout_if_no_header` is set the request is marked
    ///   anonymous; otherwise any authentication already attached is left untouched.
    /// * Header present with a readable, non-blank username the backend resolves:
    ///   the user's identity is attached to the request extensions.
    /// * Header present but blank, not visible ASCII, or naming a user the backend
    ///   cannot resolve: the request is marked anonymous (fail closed).
    ///
    /// # Errors
    ///
    /// Only errors from `next` are propagated; backend lookup failures are treated
    /// as "unknown user".
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        // Marks the request anonymous. The legacy flag extensions are reset too, so a
        // `true` left behind by an earlier layer cannot contradict the new `AuthState`.
        // NOTE(review): a user-id extension inserted earlier is not cleared here.
        fn mark_anonymous(request: &Request) {
            request.extensions.insert(IsAuthenticated(false));
            request.extensions.insert(IsAdmin(false));
            request.extensions.insert(IsActive(false));
            request.extensions.insert(AuthState::anonymous());
        }

        // `None`: header absent. `Some(None)`: present but unusable (not visible ASCII,
        // or blank). `Some(Some(name))`: a username worth looking up. The name is copied
        // out so no borrow of `request` is held across the backend `.await`.
        let remote_user: Option<Option<String>> =
            request.headers.get(&self.header_name).map(|value| {
                value
                    .to_str()
                    .ok()
                    .filter(|name| !name.trim().is_empty())
                    .map(str::to_owned)
            });

        match remote_user {
            None => {
                if self.force_logout_if_no_header {
                    mark_anonymous(&request);
                }
            }
            Some(Some(username)) => match self.get_user_by_name(&username).await {
                Some(user) => Self::insert_user_extensions(&request, user.as_ref()),
                None => mark_anonymous(&request),
            },
            Some(None) => mark_anonymous(&request),
        }

        next.handle(request).await
    }
}
#[cfg(feature = "sessions")]
/// Variant of [`RemoteUserMiddleware`] that does not log the request out when the
/// remote-user header is absent: existing authentication is kept in that case.
pub struct PersistentRemoteUserMiddleware<A: AuthenticationBackend> {
    /// Wrapped middleware, configured with its force-logout flag disabled.
    inner: RemoteUserMiddleware<A>,
}
#[cfg(feature = "sessions")]
impl<A: AuthenticationBackend> PersistentRemoteUserMiddleware<A> {
    /// Creates a middleware that reads [`REMOTE_USER_HEADER`] but, unlike
    /// [`RemoteUserMiddleware::new`], does not mark the request anonymous when that
    /// header is missing.
    pub fn new(auth_backend: Arc<A>) -> Self {
        let mut inner = RemoteUserMiddleware::new(auth_backend);
        // Keep whatever authentication is already attached when the header is absent.
        inner.force_logout_if_no_header = false;
        Self { inner }
    }

    /// Reads the username from `header_name` instead of the default header.
    pub fn with_header(self, header_name: &str) -> Self {
        Self {
            inner: self.inner.with_header(header_name),
        }
    }
}
#[cfg(feature = "sessions")]
#[async_trait]
impl<A: AuthenticationBackend + 'static> Middleware for PersistentRemoteUserMiddleware<A> {
    /// Delegates entirely to the wrapped [`RemoteUserMiddleware`], whose
    /// force-logout flag was disabled at construction.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let inner: &RemoteUserMiddleware<A> = &self.inner;
        Middleware::process(inner, request, next).await
    }
}
#[cfg(all(test, feature = "sessions"))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_auth::{AuthenticationError, SimpleUser};
    use reinhardt_http::{AuthState, Handler, Middleware, Request, Response};
    use rstest::rstest;
    use uuid::Uuid;

    /// Terminal handler that echoes the request's `AuthState` back as JSON.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let state = request.extensions.get::<AuthState>();
            let is_authenticated = state
                .as_ref()
                .map(|s| s.is_authenticated())
                .unwrap_or(false);
            let user_id = state
                .as_ref()
                .map(|s| s.user_id().to_string())
                .unwrap_or_default();
            Ok(Response::ok().with_json(&serde_json::json!({
                "is_authenticated": is_authenticated,
                "user_id": user_id,
            }))?)
        }
    }

    /// Backend that answers every lookup with the same optional user, ignoring its input.
    struct TestAuthBackend {
        user: Option<SimpleUser>,
    }

    impl TestAuthBackend {
        /// Boxed copy of the configured user, if any.
        fn boxed_user(&self) -> Option<Box<dyn User>> {
            self.user.clone().map(|u| Box::new(u) as Box<dyn User>)
        }
    }

    #[async_trait::async_trait]
    impl AuthenticationBackend for TestAuthBackend {
        async fn authenticate(
            &self,
            _request: &Request,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self.boxed_user())
        }

        async fn get_user(
            &self,
            _user_id: &str,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self.boxed_user())
        }
    }

    fn test_user() -> SimpleUser {
        SimpleUser {
            id: Uuid::now_v7(),
            username: "proxy-user".to_string(),
            email: "proxy@example.com".to_string(),
            is_active: true,
            is_admin: false,
            is_staff: false,
            is_superuser: false,
        }
    }

    /// Builds a `GET /test` HTTP/1.1 request with an empty body and the given headers.
    fn build_request(headers: HeaderMap) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    fn create_request_with_header(name: &'static str, value: &str) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        build_request(headers)
    }

    fn create_request_without_header() -> Request {
        build_request(HeaderMap::new())
    }

    /// Sends `request` through `middleware` into [`TestHandler`] and decodes the
    /// JSON body it echoes.
    async fn run<M: Middleware>(middleware: &M, request: Request) -> serde_json::Value {
        let handler: Arc<dyn Handler> = Arc::new(TestHandler);
        let response = middleware.process(request, handler).await.unwrap();
        let text = String::from_utf8(response.body.to_vec()).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    #[rstest]
    #[tokio::test]
    async fn test_remote_user_header_authenticates_user() {
        let user = test_user();
        let expected_id = user.id.to_string();
        let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }));

        let body = run(
            &middleware,
            create_request_with_header("REMOTE_USER", "proxy-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], expected_id);
    }

    #[rstest]
    #[tokio::test]
    async fn test_missing_header_produces_anonymous() {
        let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend {
            user: Some(test_user()),
        }));

        let body = run(&middleware, create_request_without_header()).await;

        assert_eq!(body["is_authenticated"], false);
    }

    #[rstest]
    #[tokio::test]
    async fn test_unknown_user_produces_anonymous() {
        let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: None }));

        let body = run(
            &middleware,
            create_request_with_header("REMOTE_USER", "unknown-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], false);
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_header_name() {
        let user = test_user();
        let expected_id = user.id.to_string();
        let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }))
            .with_header("X-Forwarded-User");

        let body = run(
            &middleware,
            create_request_with_header("X-Forwarded-User", "proxy-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], expected_id);
    }

    #[rstest]
    #[tokio::test]
    async fn test_persistent_middleware_preserves_auth_when_no_header() {
        let middleware = PersistentRemoteUserMiddleware::new(Arc::new(TestAuthBackend {
            user: Some(test_user()),
        }));
        let request = create_request_without_header();
        request
            .extensions
            .insert(AuthState::authenticated("existing-user", false, true));

        let body = run(&middleware, request).await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], "existing-user");
    }

    #[rstest]
    #[tokio::test]
    async fn test_persistent_middleware_authenticates_when_header_present() {
        let user = test_user();
        let expected_id = user.id.to_string();
        let middleware =
            PersistentRemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }));

        let body = run(
            &middleware,
            create_request_with_header("REMOTE_USER", "proxy-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], expected_id);
    }
}