Skip to main content

reinhardt_auth/
rest_authentication.rs

1//! REST API Authentication
2//!
3//! Provides REST API-compatible authentication wrappers and combinators.
4
5#[cfg(feature = "argon2-hasher")]
6use crate::DefaultUser;
7use crate::sessions::{Session, backends::SessionBackend};
8use crate::{AuthenticationBackend, AuthenticationError, SimpleUser, User};
9use reinhardt_http::Request;
10use std::sync::Arc;
11
12/// REST API authentication trait wrapper
13///
14/// Provides a REST API-compatible interface for authentication.
15#[async_trait::async_trait]
16pub trait RestAuthentication: Send + Sync {
17	/// Authenticate a request and return a user if successful
18	async fn authenticate(
19		&self,
20		request: &Request,
21	) -> Result<Option<Box<dyn User>>, AuthenticationError>;
22}
23
/// Settings for HTTP Basic authentication.
#[derive(Clone, Debug)]
pub struct BasicAuthConfig {
	/// Realm advertised in the `WWW-Authenticate` header.
	pub realm: String,
}

impl Default for BasicAuthConfig {
	/// Builds a configuration whose realm is `"api"`.
	fn default() -> Self {
		let realm = String::from("api");
		Self { realm }
	}
}
38
/// Settings for cookie/session based authentication.
#[derive(Clone, Debug)]
pub struct SessionAuthConfig {
	/// Name of the cookie that carries the session key.
	pub cookie_name: String,
	/// Whether CSRF protection should be enforced.
	pub enforce_csrf: bool,
}

impl Default for SessionAuthConfig {
	/// Builds a configuration using the `"sessionid"` cookie with CSRF
	/// enforcement switched on.
	fn default() -> Self {
		let cookie_name = String::from("sessionid");
		Self {
			cookie_name,
			enforce_csrf: true,
		}
	}
}
56
/// Settings for header-token authentication.
#[derive(Clone, Debug)]
pub struct TokenAuthConfig {
	/// Name of the header that carries the token (default: "Authorization").
	pub header_name: String,
	/// Scheme word expected in front of the token (default: "Token").
	pub prefix: String,
}

impl Default for TokenAuthConfig {
	/// Builds a configuration that reads `Authorization: Token <value>`.
	fn default() -> Self {
		let header_name = String::from("Authorization");
		let prefix = String::from("Token");
		Self {
			header_name,
			prefix,
		}
	}
}
74
75/// Composite authentication backend
76///
77/// Tries multiple authentication methods in sequence, similar to Django REST Framework.
78///
79/// # Examples
80///
81/// ```
82/// use reinhardt_auth::{CompositeAuthentication, SessionAuthentication, TokenAuthentication};
83/// use reinhardt_auth::sessions::backends::InMemorySessionBackend;
84///
85/// let session_backend = InMemorySessionBackend::new();
86/// let auth = CompositeAuthentication::new()
87///     .with_backend(SessionAuthentication::new(session_backend))
88///     .with_backend(TokenAuthentication::new());
89/// ```
90pub struct CompositeAuthentication {
91	backends: Vec<Arc<dyn AuthenticationBackend>>,
92}
93
94impl CompositeAuthentication {
95	/// Create a new composite authentication backend
96	///
97	/// # Examples
98	///
99	/// ```
100	/// use reinhardt_auth::CompositeAuthentication;
101	///
102	/// let auth = CompositeAuthentication::new();
103	/// ```
104	pub fn new() -> Self {
105		Self {
106			backends: Vec::new(),
107		}
108	}
109
110	/// Add an authentication backend (chainable)
111	///
112	/// Backends are tried in the order they are added.
113	/// The backend will be wrapped in an Arc internally.
114	///
115	/// # Examples
116	///
117	/// ```
118	/// use reinhardt_auth::{CompositeAuthentication, TokenAuthentication};
119	///
120	/// let auth = CompositeAuthentication::new()
121	///     .with_backend(TokenAuthentication::new());
122	/// ```
123	pub fn with_backend<B: AuthenticationBackend + 'static>(mut self, backend: B) -> Self {
124		self.backends.push(Arc::new(backend));
125		self
126	}
127
128	/// Add multiple backends at once (chainable)
129	pub fn with_backends(mut self, backends: Vec<Arc<dyn AuthenticationBackend>>) -> Self {
130		self.backends.extend(backends);
131		self
132	}
133}
134
135impl Default for CompositeAuthentication {
136	fn default() -> Self {
137		Self::new()
138	}
139}
140
141#[async_trait::async_trait]
142impl RestAuthentication for CompositeAuthentication {
143	async fn authenticate(
144		&self,
145		request: &Request,
146	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
147		// Try each backend in order
148		for backend in &self.backends {
149			match backend.authenticate(request).await {
150				Ok(Some(user)) => return Ok(Some(user)),
151				Ok(None) => continue,
152				Err(e) => {
153					// Log error but continue to next backend
154					tracing::warn!("Authentication backend error occurred");
155					tracing::debug!(error = %e, "Authentication backend error details");
156					continue;
157				}
158			}
159		}
160		Ok(None)
161	}
162}
163
164#[async_trait::async_trait]
165impl AuthenticationBackend for CompositeAuthentication {
166	async fn authenticate(
167		&self,
168		request: &Request,
169	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
170		<Self as RestAuthentication>::authenticate(self, request).await
171	}
172
173	async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
174		// Try each backend in order until one succeeds
175		// This is a fallback approach since we don't track which backend authenticated the user
176		for backend in &self.backends {
177			match backend.get_user(user_id).await {
178				Ok(Some(user)) => return Ok(Some(user)),
179				Ok(None) => continue,
180				Err(e) => {
181					// Log error but continue to next backend
182					tracing::warn!("get_user backend error occurred");
183					tracing::debug!(error = %e, "get_user backend error details");
184					continue;
185				}
186			}
187		}
188		Ok(None)
189	}
190}
191
192/// Token authentication using custom tokens
193pub struct TokenAuthentication {
194	/// Token store (token -> user_id)
195	tokens: std::collections::HashMap<String, String>,
196	/// Configuration
197	config: TokenAuthConfig,
198}
199
200impl TokenAuthentication {
201	/// Create a new token authentication backend
202	///
203	/// # Examples
204	///
205	/// ```
206	/// use reinhardt_auth::TokenAuthentication;
207	///
208	/// let auth = TokenAuthentication::new();
209	/// ```
210	pub fn new() -> Self {
211		Self {
212			tokens: std::collections::HashMap::new(),
213			config: TokenAuthConfig::default(),
214		}
215	}
216
217	/// Create with custom configuration
218	pub fn with_config(config: TokenAuthConfig) -> Self {
219		Self {
220			tokens: std::collections::HashMap::new(),
221			config,
222		}
223	}
224
225	/// Add a token for a user
226	pub fn add_token(&mut self, token: impl Into<String>, user_id: impl Into<String>) {
227		self.tokens.insert(token.into(), user_id.into());
228	}
229}
230
231impl Default for TokenAuthentication {
232	fn default() -> Self {
233		Self::new()
234	}
235}
236
237#[async_trait::async_trait]
238impl RestAuthentication for TokenAuthentication {
239	async fn authenticate(
240		&self,
241		request: &Request,
242	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
243		let auth_header = request
244			.headers
245			.get(&self.config.header_name)
246			.and_then(|h| h.to_str().ok());
247
248		if let Some(header) = auth_header {
249			let prefix = format!("{} ", self.config.prefix);
250			if let Some(token) = header.strip_prefix(&prefix)
251				&& let Some(user_id) = self.tokens.get(token)
252			{
253				// Try to parse user_id as UUID, or generate a new one if it fails
254				let id = uuid::Uuid::parse_str(user_id).unwrap_or_else(|_| uuid::Uuid::new_v4());
255				return Ok(Some(Box::new(SimpleUser {
256					id,
257					username: user_id.clone(),
258					email: format!("{}@example.com", user_id),
259					is_active: true,
260					is_admin: false,
261					is_staff: false,
262					is_superuser: false,
263				})));
264			}
265		}
266
267		Ok(None)
268	}
269}
270
271#[async_trait::async_trait]
272impl AuthenticationBackend for TokenAuthentication {
273	async fn authenticate(
274		&self,
275		request: &Request,
276	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
277		<Self as RestAuthentication>::authenticate(self, request).await
278	}
279
280	async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
281		if self.tokens.values().any(|id| id == user_id) {
282			// Try to parse user_id as UUID, or generate a new one if it fails
283			let id = uuid::Uuid::parse_str(user_id).unwrap_or_else(|_| uuid::Uuid::new_v4());
284			Ok(Some(Box::new(SimpleUser {
285				id,
286				username: user_id.to_string(),
287				email: format!("{}@example.com", user_id),
288				is_active: true,
289				is_admin: false,
290				is_staff: false,
291				is_superuser: false,
292			})))
293		} else {
294			Ok(None)
295		}
296	}
297}
298
/// Authentication that trusts a username header set by an upstream proxy.
pub struct RemoteUserAuthentication {
	/// Name of the request header that carries the username.
	header_name: String,
}

impl RemoteUserAuthentication {
	/// Create a backend that reads the `REMOTE_USER` header.
	pub fn new() -> Self {
		let header_name = String::from("REMOTE_USER");
		Self { header_name }
	}

	/// Read the username from `header` instead of the default header
	/// (chainable).
	pub fn with_header(self, header: impl Into<String>) -> Self {
		Self {
			header_name: header.into(),
		}
	}
}

impl Default for RemoteUserAuthentication {
	/// Equivalent to [`RemoteUserAuthentication::new`].
	fn default() -> Self {
		Self::new()
	}
}
325
326#[async_trait::async_trait]
327impl RestAuthentication for RemoteUserAuthentication {
328	async fn authenticate(
329		&self,
330		request: &Request,
331	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
332		let header_value = request
333			.headers
334			.get(&self.header_name)
335			.and_then(|v| v.to_str().ok());
336
337		if let Some(username) = header_value
338			&& !username.is_empty()
339		{
340			return Ok(Some(Box::new(SimpleUser {
341				id: uuid::Uuid::new_v4(),
342				username: username.to_string(),
343				email: format!("{}@example.com", username),
344				is_active: true,
345				is_admin: false,
346				is_staff: false,
347				is_superuser: false,
348			})));
349		}
350
351		Ok(None)
352	}
353}
354
355#[async_trait::async_trait]
356impl AuthenticationBackend for RemoteUserAuthentication {
357	async fn authenticate(
358		&self,
359		request: &Request,
360	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
361		<Self as RestAuthentication>::authenticate(self, request).await
362	}
363
364	async fn get_user(&self, _user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
365		Ok(None)
366	}
367}
368
369/// Session-based authentication
370#[derive(Clone)]
371pub struct SessionAuthentication<B: SessionBackend> {
372	/// Configuration
373	config: SessionAuthConfig,
374	/// Session backend for loading session data
375	session_backend: B,
376}
377
378impl<B: SessionBackend> SessionAuthentication<B> {
379	/// Create a new session authentication backend
380	pub fn new(session_backend: B) -> Self {
381		Self {
382			config: SessionAuthConfig::default(),
383			session_backend,
384		}
385	}
386
387	/// Create with custom configuration
388	pub fn with_config(config: SessionAuthConfig, session_backend: B) -> Self {
389		Self {
390			config,
391			session_backend,
392		}
393	}
394}
395
396impl<B: SessionBackend + Default> Default for SessionAuthentication<B> {
397	fn default() -> Self {
398		Self::new(B::default())
399	}
400}
401
402#[async_trait::async_trait]
403impl<B: SessionBackend> RestAuthentication for SessionAuthentication<B> {
404	async fn authenticate(
405		&self,
406		request: &Request,
407	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
408		// Check for session cookie
409		let cookie_header = request.headers.get("Cookie").and_then(|h| h.to_str().ok());
410
411		if let Some(cookies) = cookie_header {
412			for cookie in cookies.split(';') {
413				let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
414				if parts.len() == 2 && parts[0] == self.config.cookie_name {
415					let session_key = parts[1];
416
417					// Load session from backend
418					let mut session =
419						Session::from_key(self.session_backend.clone(), session_key.to_string())
420							.await
421							.map_err(|_| AuthenticationError::SessionExpired)?;
422
423					// Get user ID from session
424					let user_id: String = match session.get("_auth_user_id") {
425						Ok(Some(id)) => id,
426						Ok(None) => return Ok(None), // No user in session
427						Err(_) => return Err(AuthenticationError::SessionExpired),
428					};
429
430					// Get additional user fields from session
431					let username: String = session
432						.get("_auth_user_name")
433						.ok()
434						.flatten()
435						.unwrap_or_else(|| user_id.clone());
436					let email: String = session
437						.get("_auth_user_email")
438						.ok()
439						.flatten()
440						.unwrap_or_default();
441					let is_active: bool = session
442						.get("_auth_user_is_active")
443						.ok()
444						.flatten()
445						.unwrap_or(true);
446					let is_admin: bool = session
447						.get("_auth_user_is_admin")
448						.ok()
449						.flatten()
450						.unwrap_or(false);
451					let is_staff: bool = session
452						.get("_auth_user_is_staff")
453						.ok()
454						.flatten()
455						.unwrap_or(false);
456					let is_superuser: bool = session
457						.get("_auth_user_is_superuser")
458						.ok()
459						.flatten()
460						.unwrap_or(false);
461
462					// Create user from session data
463					let user = SimpleUser {
464						id: uuid::Uuid::parse_str(&user_id)
465							.map_err(|_| AuthenticationError::InvalidCredentials)?,
466						username,
467						email,
468						is_active,
469						is_admin,
470						is_staff,
471						is_superuser,
472					};
473
474					return Ok(Some(Box::new(user)));
475				}
476			}
477		}
478
479		Ok(None)
480	}
481}
482
483#[async_trait::async_trait]
484impl<B: SessionBackend> AuthenticationBackend for SessionAuthentication<B> {
485	async fn authenticate(
486		&self,
487		request: &Request,
488	) -> Result<Option<Box<dyn User>>, AuthenticationError> {
489		<Self as RestAuthentication>::authenticate(self, request).await
490	}
491
492	#[cfg(feature = "argon2-hasher")]
493	async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
494		// Parse user_id as UUID
495		let id =
496			uuid::Uuid::parse_str(user_id).map_err(|_| AuthenticationError::InvalidCredentials)?;
497
498		// Get database connection
499		let conn = reinhardt_db::orm::manager::get_connection()
500			.await
501			.map_err(|e| AuthenticationError::DatabaseError(e.to_string()))?;
502
503		// Build SQL query using reinhardt-query for type-safe query construction
504		use reinhardt_db::orm::{
505			Alias, DatabaseBackend, Expr, ExprTrait, Model, MySqlQueryBuilder,
506			PostgresQueryBuilder, Query, QueryStatementBuilder, SqliteQueryBuilder,
507		};
508
509		let table_name = DefaultUser::table_name();
510
511		// Build SELECT query using reinhardt-query
512		let stmt = Query::select()
513			.columns([
514				Alias::new("id"),
515				Alias::new("username"),
516				Alias::new("email"),
517				Alias::new("first_name"),
518				Alias::new("last_name"),
519				Alias::new("password_hash"),
520				Alias::new("last_login"),
521				Alias::new("is_active"),
522				Alias::new("is_staff"),
523				Alias::new("is_superuser"),
524				Alias::new("date_joined"),
525				Alias::new("user_permissions"),
526				Alias::new("groups"),
527			])
528			.from(Alias::new(table_name))
529			.and_where(Expr::col(Alias::new("id")).eq(Expr::value(id.to_string())))
530			.to_owned();
531
532		let sql = match conn.backend() {
533			DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
534			DatabaseBackend::MySql => stmt.to_string(MySqlQueryBuilder),
535			DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
536		};
537
538		// Execute query
539		let row = conn
540			.query_one(&sql, vec![])
541			.await
542			.map_err(|e| AuthenticationError::DatabaseError(e.to_string()))?;
543
544		// Deserialize to DefaultUser
545		let user: DefaultUser = serde_json::from_value(row.data).map_err(|e| {
546			AuthenticationError::DatabaseError(format!("Deserialization failed: {}", e))
547		})?;
548
549		// Return as trait object
550		Ok(Some(Box::new(user)))
551	}
552
553	#[cfg(not(feature = "argon2-hasher"))]
554	async fn get_user(&self, _user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
555		// When argon2-hasher feature is disabled, DefaultUser is not available
556		// Return None to indicate user retrieval is not supported
557		Ok(None)
558	}
559}
560
#[cfg(test)]
mod tests {
	use super::*;
	#[cfg(feature = "jwt")]
	use crate::basic::BasicAuthentication;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method};

	/// Build a body-less `GET /` request carrying exactly one header.
	fn get_request(header: &'static str, value: &str) -> Request {
		let mut headers = HeaderMap::new();
		headers.insert(header, value.parse().unwrap());

		Request::builder()
			.method(Method::GET)
			.uri("/")
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// A composite wrapping Basic auth authenticates a Basic-auth request.
	#[tokio::test]
	#[cfg(feature = "jwt")]
	async fn test_composite_authentication() {
		let mut basic = BasicAuthentication::new();
		basic.add_user("user1", "pass1");
		let composite = CompositeAuthentication::new().with_backend(basic);

		// "dXNlcjE6cGFzczE=" is base64 for "user1:pass1".
		let request = get_request("Authorization", "Basic dXNlcjE6cGFzczE=");

		let user = RestAuthentication::authenticate(&composite, &request)
			.await
			.unwrap();
		assert!(user.is_some());
		assert_eq!(user.unwrap().get_username(), "user1");
	}

	/// A registered token in the default header authenticates its owner.
	#[tokio::test]
	async fn test_token_authentication() {
		let mut auth = TokenAuthentication::new();
		auth.add_token("secret_token", "alice");

		let request = get_request("Authorization", "Token secret_token");

		let user = RestAuthentication::authenticate(&auth, &request)
			.await
			.unwrap();
		assert!(user.is_some());
		assert_eq!(user.unwrap().get_username(), "alice");
	}

	/// The `REMOTE_USER` header value becomes the authenticated username.
	#[tokio::test]
	async fn test_remote_user_authentication() {
		let auth = RemoteUserAuthentication::new();

		let request = get_request("REMOTE_USER", "bob");

		let user = RestAuthentication::authenticate(&auth, &request)
			.await
			.unwrap();
		assert!(user.is_some());
		assert_eq!(user.unwrap().get_username(), "bob");
	}

	/// A saved session referenced by the session cookie yields its user.
	#[tokio::test]
	async fn test_session_authentication() {
		use crate::sessions::InMemorySessionBackend;
		use crate::sessions::Session;

		let session_backend = InMemorySessionBackend::new();

		// Persist a session that carries the user's data.
		let mut session = Session::new(session_backend.clone());
		session
			.set("_auth_user_id", "550e8400-e29b-41d4-a716-446655440000")
			.unwrap();
		session.set("_auth_user_name", "testuser").unwrap();
		session.set("_auth_user_email", "test@example.com").unwrap();
		session.set("_auth_user_is_active", true).unwrap();
		session.save().await.unwrap();

		// The key generated for the session is what the cookie must carry.
		let session_key = session.get_or_create_key().to_string();

		let auth = SessionAuthentication::new(session_backend);

		let cookie_value = format!("sessionid={}", session_key);
		let request = get_request("Cookie", &cookie_value);

		let outcome = RestAuthentication::authenticate(&auth, &request)
			.await
			.unwrap();
		assert!(outcome.is_some());

		// Verify the authenticated user
		let user = outcome.unwrap();
		assert_eq!(user.get_username(), "testuser");
	}

	/// A custom header name and prefix are honoured by token authentication.
	#[tokio::test]
	async fn test_custom_token_config() {
		let config = TokenAuthConfig {
			header_name: "X-API-Key".to_string(),
			prefix: "Bearer".to_string(),
		};

		let mut auth = TokenAuthentication::with_config(config);
		auth.add_token("my_token", "charlie");

		let request = get_request("X-API-Key", "Bearer my_token");

		let user = RestAuthentication::authenticate(&auth, &request)
			.await
			.unwrap();
		assert!(user.is_some());
		assert_eq!(user.unwrap().get_username(), "charlie");
	}
}