Skip to main content

reinhardt_middleware/
remote_user.rs

1// The `User` trait is deprecated in favour of the new `#[model]`-based user macro system.
2// Downstream crates still reference it during the transition period.
3#![allow(deprecated)]
4
5#[cfg(feature = "sessions")]
6use async_trait::async_trait;
7#[cfg(feature = "sessions")]
8use std::sync::Arc;
9
10#[cfg(feature = "sessions")]
11use reinhardt_auth::{AuthenticationBackend, User};
12#[cfg(feature = "sessions")]
13use reinhardt_http::{
14	AuthState, Handler, IsActive, IsAdmin, IsAuthenticated, Middleware, Request, Response, Result,
15};
16
/// Name of the request header that carries the proxy-authenticated username.
///
/// Used by both remote-user middlewares unless a different header is chosen
/// with their `with_header` builder method.
#[cfg(feature = "sessions")]
pub const REMOTE_USER_HEADER: &'static str = "REMOTE_USER";
20
/// Middleware that trusts a reverse proxy to authenticate users.
///
/// A front-end proxy (Apache, Nginx, etc.) performs the actual login and
/// forwards the resulting username in a request header (`REMOTE_USER` by
/// default). This middleware resolves that username through the supplied
/// [`AuthenticationBackend`] and records the outcome in the request
/// extensions. It mirrors Django's
/// [`RemoteUserMiddleware`](https://docs.djangoproject.com/en/5.1/ref/middleware/#django.contrib.auth.middleware.RemoteUserMiddleware).
///
/// When the header is missing, the request's `AuthState` is reset to
/// anonymous, overriding whatever earlier middleware stored. Use
/// [`PersistentRemoteUserMiddleware`] to keep that earlier state instead.
///
/// # Security Warning
///
/// Deploy this middleware **only** behind a trusted reverse proxy that sets
/// or strips the header on every request. If clients can send the header
/// directly, anyone can authenticate as any user just by supplying it.
///
/// # Examples
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use reinhardt_middleware::RemoteUserMiddleware;
/// use reinhardt_http::MiddlewareChain;
/// # use reinhardt_http::{Handler, Request, Response, Result};
/// # use reinhardt_auth::{AuthenticationBackend, AuthenticationError, User};
/// # use async_trait::async_trait;
/// # struct MyHandler;
/// # #[async_trait]
/// # impl Handler for MyHandler {
/// #     async fn handle(&self, _request: Request) -> Result<Response> {
/// #         Ok(Response::ok())
/// #     }
/// # }
/// # struct MyAuthBackend;
/// # #[async_trait]
/// # impl AuthenticationBackend for MyAuthBackend {
/// #     async fn authenticate(&self, _req: &Request) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> { Ok(None) }
/// #     async fn get_user(&self, _uid: &str) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> { Ok(None) }
/// # }
/// # let handler = Arc::new(MyHandler);
///
/// let auth_backend = Arc::new(MyAuthBackend);
/// let middleware = RemoteUserMiddleware::new(auth_backend);
///
/// let app = MiddlewareChain::new(handler)
///     .with_middleware(Arc::new(middleware));
/// ```
#[cfg(feature = "sessions")]
pub struct RemoteUserMiddleware<A: AuthenticationBackend> {
	/// Backend used to turn the header value into a user.
	auth_backend: Arc<A>,
	/// Name of the header that holds the username.
	header_name: String,
	/// `true`: a request without the header is forced to the anonymous state.
	/// `false`: a request without the header keeps its existing auth state.
	force_logout_if_no_header: bool,
}
78
#[cfg(feature = "sessions")]
impl<A: AuthenticationBackend> RemoteUserMiddleware<A> {
	/// Builds the middleware around `auth_backend`, reading the username from
	/// the [`REMOTE_USER_HEADER`] header.
	///
	/// Requests that arrive without that header are treated as anonymous.
	///
	/// # Arguments
	///
	/// * `auth_backend` - Backend used to look users up by name
	pub fn new(auth_backend: Arc<A>) -> Self {
		let header_name = String::from(REMOTE_USER_HEADER);
		Self {
			header_name,
			force_logout_if_no_header: true,
			auth_backend,
		}
	}

	/// Reads the username from `header_name` instead of the default header.
	///
	/// # Arguments
	///
	/// * `header_name` - HTTP header carrying the remote username
	///
	/// # Examples
	///
	/// ```rust,no_run
	/// # use std::sync::Arc;
	/// # use reinhardt_middleware::RemoteUserMiddleware;
	/// # use reinhardt_auth::{AuthenticationBackend, AuthenticationError, User};
	/// # use reinhardt_http::Request;
	/// # use async_trait::async_trait;
	/// # struct MyAuth;
	/// # #[async_trait]
	/// # impl AuthenticationBackend for MyAuth {
	/// #     async fn authenticate(&self, _req: &Request) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> { Ok(None) }
	/// #     async fn get_user(&self, _uid: &str) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> { Ok(None) }
	/// # }
	///
	/// let backend = Arc::new(MyAuth);
	/// let middleware = RemoteUserMiddleware::new(backend)
	///     .with_header("X-Forwarded-User");
	/// ```
	pub fn with_header(self, header_name: &str) -> Self {
		Self {
			header_name: header_name.to_owned(),
			..self
		}
	}

	/// Resolves `username` through the authentication backend.
	///
	/// A backend error is deliberately folded into `None`, so the caller
	/// treats a failing backend exactly like an unknown user.
	async fn get_user_by_name(&self, username: &str) -> Option<Box<dyn User>> {
		match self.auth_backend.get_user(username).await {
			Ok(found) => found,
			Err(_) => None,
		}
	}

	/// Publishes `user`'s identity and flags into the request extensions.
	///
	/// Writes the user id and the individual `IsAuthenticated`, `IsAdmin` and
	/// `IsActive` markers (kept for backward compatibility). It then writes an
	/// aggregate [`AuthState`], which is anonymous when the user reports itself
	/// as unauthenticated.
	fn insert_user_extensions(request: &Request, user: &dyn User) {
		let authenticated = user.is_authenticated();
		let admin = user.is_admin();
		let active = user.is_active();
		let id = user.id();
		let extensions = &request.extensions;

		// Legacy per-flag extensions.
		extensions.insert(id.clone());
		extensions.insert(IsAuthenticated(authenticated));
		extensions.insert(IsAdmin(admin));
		extensions.insert(IsActive(active));

		// Aggregate view consumed by newer code.
		extensions.insert(if authenticated {
			AuthState::authenticated(id, admin, active)
		} else {
			AuthState::anonymous()
		});
	}
}
153
#[cfg(feature = "sessions")]
#[async_trait]
impl<A: AuthenticationBackend + 'static> Middleware for RemoteUserMiddleware<A> {
	/// Authenticates the request from the configured remote-user header, then
	/// forwards it to `next`.
	///
	/// Outcomes:
	/// - Header present with a non-empty string value that the backend resolves
	///   to a user: that user's identity is written into the request extensions.
	/// - Header present but empty, not representable as a string, or naming a
	///   user the backend cannot resolve (backend errors included): the request
	///   is marked anonymous. This fails closed, so a bad header never falls back
	///   to upstream authentication.
	/// - Header absent: the request is marked anonymous when
	///   `force_logout_if_no_header` is set. Otherwise the extensions are left
	///   untouched, so authentication state from upstream middleware survives.
	///
	/// "Marked anonymous" resets `AuthState` as well as the legacy
	/// `IsAuthenticated`/`IsAdmin`/`IsActive` flags. This stops a stale
	/// `IsAdmin(true)` written upstream from contradicting the anonymous state.
	///
	/// NOTE(review): the user-id extension written on successful login is not
	/// removed on the anonymous paths. Only `insert` is used on `extensions` in
	/// this module, so consumers should rely on `AuthState`/`IsAuthenticated`
	/// rather than on the mere presence of a user id.
	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
		/// Overwrites every auth-related extension with its anonymous value.
		fn mark_anonymous(request: &Request) {
			request.extensions.insert(IsAuthenticated(false));
			request.extensions.insert(IsAdmin(false));
			request.extensions.insert(IsActive(false));
			request.extensions.insert(AuthState::anonymous());
		}

		match request.headers.get(&self.header_name) {
			Some(value) => {
				// A header that is present but unusable (non visible-ASCII
				// bytes, or empty) is a failed login, not "no header".
				let username = value.to_str().ok().filter(|name| !name.is_empty());
				let user = match username {
					Some(name) => self.get_user_by_name(name).await,
					None => None,
				};
				match user {
					Some(user) => Self::insert_user_extensions(&request, user.as_ref()),
					None => mark_anonymous(&request),
				}
			}
			None if self.force_logout_if_no_header => mark_anonymous(&request),
			// No header and persistent mode: keep upstream auth state as-is.
			None => {}
		}

		next.handle(request).await
	}
}
180
/// Remote-user middleware that keeps existing authentication when the proxy
/// header is missing.
///
/// Whenever the header is present, this behaves exactly like
/// [`RemoteUserMiddleware`]. When it is absent, the request extensions are left
/// as upstream middleware (for example a session layer) populated them,
/// instead of being reset to anonymous. It is the counterpart of Django's
/// [`PersistentRemoteUserMiddleware`](https://docs.djangoproject.com/en/5.1/ref/middleware/#django.contrib.auth.middleware.PersistentRemoteUserMiddleware).
///
/// Pick this variant when the proxy sets the header only on some routes (say,
/// a login page) and the session should keep the user signed in afterwards.
/// The security caveats of [`RemoteUserMiddleware`] apply unchanged.
///
/// # Examples
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use reinhardt_middleware::PersistentRemoteUserMiddleware;
/// use reinhardt_http::MiddlewareChain;
/// # use reinhardt_http::{Handler, Request, Response, Result};
/// # use reinhardt_auth::{AuthenticationBackend, AuthenticationError, User};
/// # use async_trait::async_trait;
/// # struct MyHandler;
/// # #[async_trait]
/// # impl Handler for MyHandler {
/// #     async fn handle(&self, _request: Request) -> Result<Response> {
/// #         Ok(Response::ok())
/// #     }
/// # }
/// # struct MyAuthBackend;
/// # #[async_trait]
/// # impl AuthenticationBackend for MyAuthBackend {
/// #     async fn authenticate(&self, _req: &Request) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> { Ok(None) }
/// #     async fn get_user(&self, _uid: &str) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> { Ok(None) }
/// # }
/// # let handler = Arc::new(MyHandler);
///
/// let auth_backend = Arc::new(MyAuthBackend);
/// let middleware = PersistentRemoteUserMiddleware::new(auth_backend);
///
/// let app = MiddlewareChain::new(handler)
///     .with_middleware(Arc::new(middleware));
/// ```
#[cfg(feature = "sessions")]
pub struct PersistentRemoteUserMiddleware<A: AuthenticationBackend> {
	/// Regular remote-user middleware configured not to force a logout.
	inner: RemoteUserMiddleware<A>,
}
226
#[cfg(feature = "sessions")]
impl<A: AuthenticationBackend> PersistentRemoteUserMiddleware<A> {
	/// Builds the middleware around `auth_backend`, reading the username from
	/// the [`REMOTE_USER_HEADER`] header.
	///
	/// In contrast to [`RemoteUserMiddleware::new`], a request without the
	/// header keeps whatever authentication state it already carries.
	///
	/// # Arguments
	///
	/// * `auth_backend` - Backend used to look users up by name
	pub fn new(auth_backend: Arc<A>) -> Self {
		let mut inner = RemoteUserMiddleware::new(auth_backend);
		// The single setting that distinguishes the persistent variant.
		inner.force_logout_if_no_header = false;
		Self { inner }
	}

	/// Reads the username from `header_name` instead of the default header.
	///
	/// # Arguments
	///
	/// * `header_name` - HTTP header carrying the remote username
	pub fn with_header(self, header_name: &str) -> Self {
		Self {
			inner: self.inner.with_header(header_name),
		}
	}
}
253
#[cfg(feature = "sessions")]
#[async_trait]
impl<A: AuthenticationBackend + 'static> Middleware for PersistentRemoteUserMiddleware<A> {
	/// Hands the request to the wrapped [`RemoteUserMiddleware`], which was
	/// built with forced logout disabled.
	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
		Middleware::process(&self.inner, request, next).await
	}
}
261
#[cfg(all(test, feature = "sessions"))]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};
	use reinhardt_auth::{AuthenticationError, SimpleUser};
	use reinhardt_http::{AuthState, Handler, Middleware, Request, Response};
	use rstest::rstest;
	use uuid::Uuid;

	/// Echoes the `AuthState` it receives as a JSON body so tests can inspect
	/// what the middleware decided.
	struct TestHandler;

	#[async_trait::async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, request: Request) -> Result<Response> {
			let state = request.extensions.get::<AuthState>();
			let is_authenticated = state
				.as_ref()
				.map(|s| s.is_authenticated())
				.unwrap_or(false);
			let user_id = state
				.as_ref()
				.map(|s| s.user_id().to_string())
				.unwrap_or_default();
			let payload = serde_json::json!({
				"is_authenticated": is_authenticated,
				"user_id": user_id,
			});
			Ok(Response::ok().with_json(&payload)?)
		}
	}

	/// Backend that hands back the same configured user for every lookup,
	/// whatever username or request it is given.
	struct TestAuthBackend {
		user: Option<SimpleUser>,
	}

	impl TestAuthBackend {
		/// Fresh boxed copy of the configured user, if any.
		fn boxed_user(&self) -> Option<Box<dyn User>> {
			self.user.clone().map(|u| Box::new(u) as Box<dyn User>)
		}
	}

	#[async_trait::async_trait]
	impl AuthenticationBackend for TestAuthBackend {
		async fn authenticate(
			&self,
			_request: &Request,
		) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
			Ok(self.boxed_user())
		}

		async fn get_user(
			&self,
			_user_id: &str,
		) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
			Ok(self.boxed_user())
		}
	}

	fn test_user() -> SimpleUser {
		SimpleUser {
			id: Uuid::now_v7(),
			username: "proxy-user".to_string(),
			email: "proxy@example.com".to_string(),
			is_active: true,
			is_admin: false,
			is_staff: false,
			is_superuser: false,
		}
	}

	/// Builds a `GET /test` request carrying `headers` and an empty body.
	fn build_request(headers: HeaderMap) -> Request {
		Request::builder()
			.method(Method::GET)
			.uri("/test")
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	fn create_request_with_header(name: &'static str, value: &str) -> Request {
		let mut headers = HeaderMap::new();
		headers.insert(name, value.parse().unwrap());
		build_request(headers)
	}

	fn create_request_without_header() -> Request {
		build_request(HeaderMap::new())
	}

	/// Decodes the JSON body produced by [`TestHandler`].
	fn body_json(response: Response) -> serde_json::Value {
		let text = String::from_utf8(response.body.to_vec()).unwrap();
		serde_json::from_str(&text).unwrap()
	}

	#[rstest]
	#[tokio::test]
	async fn test_remote_user_header_authenticates_user() {
		// Arrange
		let user = test_user();
		let expected_id = user.id.to_string();
		let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }));
		let request = create_request_with_header("REMOTE_USER", "proxy-user");

		// Act
		let response = middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap();

		// Assert
		let body = body_json(response);
		assert_eq!(body["is_authenticated"], true);
		assert_eq!(body["user_id"], expected_id);
	}

	#[rstest]
	#[tokio::test]
	async fn test_missing_header_produces_anonymous() {
		// Arrange
		let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend {
			user: Some(test_user()),
		}));
		let request = create_request_without_header();

		// Act
		let response = middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap();

		// Assert
		let body = body_json(response);
		assert_eq!(body["is_authenticated"], false);
	}

	#[rstest]
	#[tokio::test]
	async fn test_unknown_user_produces_anonymous() {
		// Arrange
		let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: None }));
		let request = create_request_with_header("REMOTE_USER", "unknown-user");

		// Act
		let response = middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap();

		// Assert
		let body = body_json(response);
		assert_eq!(body["is_authenticated"], false);
	}

	#[rstest]
	#[tokio::test]
	async fn test_custom_header_name() {
		// Arrange
		let user = test_user();
		let expected_id = user.id.to_string();
		let middleware = RemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }))
			.with_header("X-Forwarded-User");
		let request = create_request_with_header("X-Forwarded-User", "proxy-user");

		// Act
		let response = middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap();

		// Assert
		let body = body_json(response);
		assert_eq!(body["is_authenticated"], true);
		assert_eq!(body["user_id"], expected_id);
	}

	#[rstest]
	#[tokio::test]
	async fn test_persistent_middleware_preserves_auth_when_no_header() {
		// Arrange
		let middleware = PersistentRemoteUserMiddleware::new(Arc::new(TestAuthBackend {
			user: Some(test_user()),
		}));

		// Simulate an upstream middleware that already authenticated the request.
		let request = create_request_without_header();
		request
			.extensions
			.insert(AuthState::authenticated("existing-user", false, true));

		// Act
		let response = middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap();

		// Assert - the upstream auth state must survive untouched
		let body = body_json(response);
		assert_eq!(body["is_authenticated"], true);
		assert_eq!(body["user_id"], "existing-user");
	}

	#[rstest]
	#[tokio::test]
	async fn test_persistent_middleware_authenticates_when_header_present() {
		// Arrange
		let user = test_user();
		let expected_id = user.id.to_string();
		let middleware =
			PersistentRemoteUserMiddleware::new(Arc::new(TestAuthBackend { user: Some(user) }));
		let request = create_request_with_header("REMOTE_USER", "proxy-user");

		// Act
		let response = middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap();

		// Assert
		let body = body_json(response);
		assert_eq!(body["is_authenticated"], true);
		assert_eq!(body["user_id"], expected_id);
	}
}