reinhardt_middleware/
remote_user.rs1#![allow(deprecated)]
4
5#[cfg(feature = "sessions")]
6use async_trait::async_trait;
7#[cfg(feature = "sessions")]
8use std::sync::Arc;
9
10#[cfg(feature = "sessions")]
11use reinhardt_auth::{AuthenticationBackend, User};
12#[cfg(feature = "sessions")]
13use reinhardt_http::{
14 AuthState, Handler, IsActive, IsAdmin, IsAuthenticated, Middleware, Request, Response, Result,
15};
16
/// Default name of the request header that [`RemoteUserMiddleware`] and
/// [`PersistentRemoteUserMiddleware`] read the username from.
#[cfg(feature = "sessions")]
pub const REMOTE_USER_HEADER: &str = "REMOTE_USER";
20
/// Authenticates requests from a username carried in a request header
/// (by default [`REMOTE_USER_HEADER`]), resolving it through an [`AuthenticationBackend`].
///
/// # Security
///
/// The username is read directly from the incoming HTTP request headers, and any client can
/// send that header. Only deploy this middleware behind a front-end proxy that always sets or
/// strips the configured header. Otherwise a client can authenticate as any user the backend
/// knows.
#[cfg(feature = "sessions")]
pub struct RemoteUserMiddleware<A: AuthenticationBackend> {
    /// Backend used to turn the header value into a user.
    auth_backend: Arc<A>,
    /// Name of the header that carries the username.
    header_name: String,
    /// When `true`, a request without the header is treated as anonymous. When `false`,
    /// whatever authentication state earlier middleware attached is left as-is.
    force_logout_if_no_header: bool,
}
78
#[cfg(feature = "sessions")]
impl<A: AuthenticationBackend> RemoteUserMiddleware<A> {
    /// Creates a middleware that reads the username from [`REMOTE_USER_HEADER`].
    /// Requests lacking that header are treated as anonymous.
    pub fn new(auth_backend: Arc<A>) -> Self {
        let header_name = String::from(REMOTE_USER_HEADER);
        Self {
            force_logout_if_no_header: true,
            header_name,
            auth_backend,
        }
    }

    /// Builder-style setter: reads the username from `header_name` instead of the default header.
    pub fn with_header(self, header_name: &str) -> Self {
        Self {
            header_name: header_name.to_owned(),
            ..self
        }
    }

    /// Resolves `username` through the backend.
    ///
    /// Backend errors are collapsed into `None`, so the caller cannot tell a failed
    /// lookup from an unknown user.
    async fn get_user_by_name(&self, username: &str) -> Option<Box<dyn User>> {
        match self.auth_backend.get_user(username).await {
            Ok(found) => found,
            Err(_) => None,
        }
    }

    /// Records `user` on the request.
    ///
    /// Inserts the user's id, its `IsAuthenticated` / `IsAdmin` / `IsActive` flags and an
    /// [`AuthState`]. The state is authenticated only if the user reports itself as
    /// authenticated; otherwise it is anonymous.
    fn insert_user_extensions(request: &Request, user: &dyn User) {
        let authenticated = user.is_authenticated();
        let admin = user.is_admin();
        let active = user.is_active();
        let id = user.id();

        let extensions = &request.extensions;
        extensions.insert(id.clone());
        extensions.insert(IsAuthenticated(authenticated));
        extensions.insert(IsAdmin(admin));
        extensions.insert(IsActive(active));

        let state = if !authenticated {
            AuthState::anonymous()
        } else {
            AuthState::authenticated(id, admin, active)
        };
        extensions.insert(state);
    }
}
153
#[cfg(feature = "sessions")]
#[async_trait]
impl<A: AuthenticationBackend + 'static> Middleware for RemoteUserMiddleware<A> {
    /// Authenticates the request from the configured header, then forwards it to `next`.
    ///
    /// * Header present and naming a user the backend resolves: the user's id, flags and an
    ///   [`AuthState`] are inserted into the request extensions.
    /// * Header present but empty or not a visible-ASCII string, or naming a user the backend
    ///   does not resolve (including backend errors): the request is marked anonymous.
    /// * Header absent: the request is marked anonymous when `force_logout_if_no_header` is
    ///   set; otherwise its extensions are left untouched.
    ///
    /// "Marked anonymous" means `IsAuthenticated(false)`, `IsAdmin(false)`, `IsActive(false)`
    /// and `AuthState::anonymous()` are all inserted, overwriting values from earlier middleware.
    ///
    /// # Errors
    ///
    /// Returns whatever error `next` returns; this middleware adds no failure paths of its own.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        /// Overwrites every auth flag, not only `AuthState`. This stops flags inserted by
        /// an earlier middleware (e.g. a stale `IsAdmin(true)`) from contradicting the
        /// anonymous state.
        fn mark_anonymous(request: &Request) {
            request.extensions.insert(IsAuthenticated(false));
            request.extensions.insert(IsAdmin(false));
            request.extensions.insert(IsActive(false));
            // NOTE(review): a user-id extension inserted upstream is not cleared here;
            // confirm whether the extensions type offers a removal API.
            request.extensions.insert(AuthState::anonymous());
        }

        // The value has three cases:
        // - outer `None`: the header is absent;
        // - `Some(None)`: the header is present but unusable (empty, or not visible ASCII);
        // - `Some(Some(name))`: a candidate username to look up.
        // Converting to owned data here means no borrow of `request` is held across the await.
        let remote_user: Option<Option<String>> =
            request.headers.get(&self.header_name).map(|value| {
                value
                    .to_str()
                    .ok()
                    .filter(|name| !name.is_empty())
                    .map(str::to_owned)
            });

        match remote_user {
            Some(Some(username)) => match self.get_user_by_name(&username).await {
                Some(user) => Self::insert_user_extensions(&request, user.as_ref()),
                None => mark_anonymous(&request),
            },
            // A present-but-unusable header is a failed authentication, not a missing header.
            // Otherwise the persistent variant would keep an earlier login alive.
            Some(None) => mark_anonymous(&request),
            None if self.force_logout_if_no_header => mark_anonymous(&request),
            None => {}
        }

        next.handle(request).await
    }
}
180
/// Variant of [`RemoteUserMiddleware`] that does not clear authentication when the header is
/// missing.
///
/// Requests carrying the header are handled exactly like [`RemoteUserMiddleware`] handles them.
/// Requests without it pass through with whatever authentication state earlier middleware
/// attached. The same header-spoofing caveat applies: only trust the header when a front-end
/// proxy always sets or strips it.
#[cfg(feature = "sessions")]
pub struct PersistentRemoteUserMiddleware<A: AuthenticationBackend> {
    /// Wrapped middleware; `new` builds it with `force_logout_if_no_header` set to `false`.
    inner: RemoteUserMiddleware<A>,
}
226
#[cfg(feature = "sessions")]
impl<A: AuthenticationBackend> PersistentRemoteUserMiddleware<A> {
    /// Creates a persistent middleware that reads the username from [`REMOTE_USER_HEADER`].
    pub fn new(auth_backend: Arc<A>) -> Self {
        let mut inner = RemoteUserMiddleware::new(auth_backend);
        // The only difference from the plain middleware: a missing header leaves prior
        // authentication state untouched.
        inner.force_logout_if_no_header = false;
        Self { inner }
    }

    /// Builder-style setter: reads the username from `header_name` instead of the default header.
    pub fn with_header(self, header_name: &str) -> Self {
        Self {
            inner: self.inner.with_header(header_name),
        }
    }
}
253
#[cfg(feature = "sessions")]
#[async_trait]
impl<A: AuthenticationBackend + 'static> Middleware for PersistentRemoteUserMiddleware<A> {
    /// Delegates to the wrapped [`RemoteUserMiddleware`], which was configured not to clear
    /// authentication when the header is missing.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let Self { inner } = self;
        inner.process(request, next).await
    }
}
261
#[cfg(all(test, feature = "sessions"))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_auth::{AuthenticationError, SimpleUser};
    use reinhardt_http::{AuthState, Handler, Middleware, Request, Response};
    use rstest::rstest;
    use uuid::Uuid;

    /// Terminal handler that echoes the `AuthState` it receives as JSON.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let state = request.extensions.get::<AuthState>();
            let is_authenticated = state
                .as_ref()
                .map(|s| s.is_authenticated())
                .unwrap_or(false);
            let user_id = state
                .as_ref()
                .map(|s| s.user_id().to_string())
                .unwrap_or_default();
            Ok(Response::ok().with_json(&serde_json::json!({
                "is_authenticated": is_authenticated,
                "user_id": user_id,
            }))?)
        }
    }

    /// Backend that answers every lookup with the same optional user, whatever name is asked for.
    struct TestAuthBackend {
        user: Option<SimpleUser>,
    }

    impl TestAuthBackend {
        /// Returns a boxed copy of the configured user, if any.
        fn boxed_user(&self) -> Option<Box<dyn User>> {
            self.user.clone().map(|u| Box::new(u) as Box<dyn User>)
        }
    }

    #[async_trait::async_trait]
    impl AuthenticationBackend for TestAuthBackend {
        async fn authenticate(
            &self,
            _request: &Request,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self.boxed_user())
        }

        async fn get_user(
            &self,
            _user_id: &str,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self.boxed_user())
        }
    }

    fn test_user() -> SimpleUser {
        SimpleUser {
            id: Uuid::now_v7(),
            username: "proxy-user".to_string(),
            email: "proxy@example.com".to_string(),
            is_active: true,
            is_admin: false,
            is_staff: false,
            is_superuser: false,
        }
    }

    /// Builds a `GET /test` HTTP/1.1 request with the given headers and an empty body.
    fn build_request(headers: HeaderMap) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    fn create_request_with_header(name: &'static str, value: &str) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        build_request(headers)
    }

    fn create_request_without_header() -> Request {
        build_request(HeaderMap::new())
    }

    /// Runs `middleware` in front of [`TestHandler`] and decodes the JSON it echoes back.
    async fn run<M: Middleware>(middleware: &M, request: Request) -> serde_json::Value {
        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        let text = String::from_utf8(response.body.to_vec()).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    #[rstest]
    #[tokio::test]
    async fn test_remote_user_header_authenticates_user() {
        let user = test_user();
        let expected_id = user.id.to_string();
        let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }));

        let body = run(
            &middleware,
            create_request_with_header("REMOTE_USER", "proxy-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], expected_id);
    }

    #[rstest]
    #[tokio::test]
    async fn test_missing_header_produces_anonymous() {
        let backend = Arc::new(TestAuthBackend {
            user: Some(test_user()),
        });
        let middleware = RemoteUserMiddleware::new(backend);

        let body = run(&middleware, create_request_without_header()).await;

        assert_eq!(body["is_authenticated"], false);
    }

    #[rstest]
    #[tokio::test]
    async fn test_unknown_user_produces_anonymous() {
        let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: None }));

        let body = run(
            &middleware,
            create_request_with_header("REMOTE_USER", "unknown-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], false);
    }

    #[rstest]
    #[tokio::test]
    async fn test_custom_header_name() {
        let user = test_user();
        let expected_id = user.id.to_string();
        let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }))
            .with_header("X-Forwarded-User");

        let body = run(
            &middleware,
            create_request_with_header("X-Forwarded-User", "proxy-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], expected_id);
    }

    #[rstest]
    #[tokio::test]
    async fn test_persistent_middleware_preserves_auth_when_no_header() {
        let backend = Arc::new(TestAuthBackend {
            user: Some(test_user()),
        });
        let middleware = PersistentRemoteUserMiddleware::new(backend);

        // Simulate authentication attached by an earlier middleware.
        let request = create_request_without_header();
        request
            .extensions
            .insert(AuthState::authenticated("existing-user", false, true));

        let body = run(&middleware, request).await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], "existing-user");
    }

    #[rstest]
    #[tokio::test]
    async fn test_persistent_middleware_authenticates_when_header_present() {
        let user = test_user();
        let expected_id = user.id.to_string();
        let middleware =
            PersistentRemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }));

        let body = run(
            &middleware,
            create_request_with_header("REMOTE_USER", "proxy-user"),
        )
        .await;

        assert_eq!(body["is_authenticated"], true);
        assert_eq!(body["user_id"], expected_id);
    }
}