1#[cfg(feature = "argon2-hasher")]
6use crate::DefaultUser;
7use crate::sessions::{Session, backends::SessionBackend};
8use crate::{AuthenticationBackend, AuthenticationError, SimpleUser, User};
9use reinhardt_http::Request;
10use std::sync::Arc;
11
12#[async_trait::async_trait]
16pub trait RestAuthentication: Send + Sync {
17 async fn authenticate(
19 &self,
20 request: &Request,
21 ) -> Result<Option<Box<dyn User>>, AuthenticationError>;
22}
23
/// Configuration values for Basic authentication.
#[derive(Clone, Debug)]
pub struct BasicAuthConfig {
    /// Realm name; defaults to `"api"`.
    ///
    /// NOTE(review): presumably the realm advertised in authentication
    /// challenges — the consumer of this field is not visible here.
    pub realm: String,
}

impl Default for BasicAuthConfig {
    /// Builds a configuration whose realm is `"api"`.
    fn default() -> Self {
        let realm = String::from("api");
        Self { realm }
    }
}
38
/// Configuration values for cookie/session based authentication.
#[derive(Clone, Debug)]
pub struct SessionAuthConfig {
    /// Name of the cookie that carries the session key; defaults to
    /// `"sessionid"`.
    pub cookie_name: String,
    /// Whether CSRF protection should be enforced; defaults to `true`.
    ///
    /// NOTE(review): no code shown next to this type reads the flag —
    /// confirm where (or whether) it is enforced.
    pub enforce_csrf: bool,
}

impl Default for SessionAuthConfig {
    /// Builds a configuration using the `"sessionid"` cookie with CSRF
    /// enforcement switched on.
    fn default() -> Self {
        let cookie_name = String::from("sessionid");
        Self {
            cookie_name,
            enforce_csrf: true,
        }
    }
}
56
/// Configuration values for header-token authentication.
#[derive(Clone, Debug)]
pub struct TokenAuthConfig {
    /// Name of the request header that carries the token; defaults to
    /// `"Authorization"`.
    pub header_name: String,
    /// Scheme word expected in front of the token inside the header value;
    /// defaults to `"Token"`.
    pub prefix: String,
}

impl Default for TokenAuthConfig {
    /// Builds a configuration that reads `Authorization` headers using the
    /// `Token` prefix.
    fn default() -> Self {
        let header_name = String::from("Authorization");
        let prefix = String::from("Token");
        Self {
            header_name,
            prefix,
        }
    }
}
74
75pub struct CompositeAuthentication {
91 backends: Vec<Arc<dyn AuthenticationBackend>>,
92}
93
94impl CompositeAuthentication {
95 pub fn new() -> Self {
105 Self {
106 backends: Vec::new(),
107 }
108 }
109
110 pub fn with_backend<B: AuthenticationBackend + 'static>(mut self, backend: B) -> Self {
124 self.backends.push(Arc::new(backend));
125 self
126 }
127
128 pub fn with_backends(mut self, backends: Vec<Arc<dyn AuthenticationBackend>>) -> Self {
130 self.backends.extend(backends);
131 self
132 }
133}
134
135impl Default for CompositeAuthentication {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141#[async_trait::async_trait]
142impl RestAuthentication for CompositeAuthentication {
143 async fn authenticate(
144 &self,
145 request: &Request,
146 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
147 for backend in &self.backends {
149 match backend.authenticate(request).await {
150 Ok(Some(user)) => return Ok(Some(user)),
151 Ok(None) => continue,
152 Err(e) => {
153 tracing::warn!("Authentication backend error occurred");
155 tracing::debug!(error = %e, "Authentication backend error details");
156 continue;
157 }
158 }
159 }
160 Ok(None)
161 }
162}
163
164#[async_trait::async_trait]
165impl AuthenticationBackend for CompositeAuthentication {
166 async fn authenticate(
167 &self,
168 request: &Request,
169 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
170 <Self as RestAuthentication>::authenticate(self, request).await
171 }
172
173 async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
174 for backend in &self.backends {
177 match backend.get_user(user_id).await {
178 Ok(Some(user)) => return Ok(Some(user)),
179 Ok(None) => continue,
180 Err(e) => {
181 tracing::warn!("get_user backend error occurred");
183 tracing::debug!(error = %e, "get_user backend error details");
184 continue;
185 }
186 }
187 }
188 Ok(None)
189 }
190}
191
192pub struct TokenAuthentication {
194 tokens: std::collections::HashMap<String, String>,
196 config: TokenAuthConfig,
198}
199
200impl TokenAuthentication {
201 pub fn new() -> Self {
211 Self {
212 tokens: std::collections::HashMap::new(),
213 config: TokenAuthConfig::default(),
214 }
215 }
216
217 pub fn with_config(config: TokenAuthConfig) -> Self {
219 Self {
220 tokens: std::collections::HashMap::new(),
221 config,
222 }
223 }
224
225 pub fn add_token(&mut self, token: impl Into<String>, user_id: impl Into<String>) {
227 self.tokens.insert(token.into(), user_id.into());
228 }
229}
230
231impl Default for TokenAuthentication {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
237#[async_trait::async_trait]
238impl RestAuthentication for TokenAuthentication {
239 async fn authenticate(
240 &self,
241 request: &Request,
242 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
243 let auth_header = request
244 .headers
245 .get(&self.config.header_name)
246 .and_then(|h| h.to_str().ok());
247
248 if let Some(header) = auth_header {
249 let prefix = format!("{} ", self.config.prefix);
250 if let Some(token) = header.strip_prefix(&prefix)
251 && let Some(user_id) = self.tokens.get(token)
252 {
253 let id = uuid::Uuid::parse_str(user_id).unwrap_or_else(|_| uuid::Uuid::new_v4());
255 return Ok(Some(Box::new(SimpleUser {
256 id,
257 username: user_id.clone(),
258 email: format!("{}@example.com", user_id),
259 is_active: true,
260 is_admin: false,
261 is_staff: false,
262 is_superuser: false,
263 })));
264 }
265 }
266
267 Ok(None)
268 }
269}
270
271#[async_trait::async_trait]
272impl AuthenticationBackend for TokenAuthentication {
273 async fn authenticate(
274 &self,
275 request: &Request,
276 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
277 <Self as RestAuthentication>::authenticate(self, request).await
278 }
279
280 async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
281 if self.tokens.values().any(|id| id == user_id) {
282 let id = uuid::Uuid::parse_str(user_id).unwrap_or_else(|_| uuid::Uuid::new_v4());
284 Ok(Some(Box::new(SimpleUser {
285 id,
286 username: user_id.to_string(),
287 email: format!("{}@example.com", user_id),
288 is_active: true,
289 is_admin: false,
290 is_staff: false,
291 is_superuser: false,
292 })))
293 } else {
294 Ok(None)
295 }
296 }
297}
298
/// Authentication that takes the username from a request header.
pub struct RemoteUserAuthentication {
    /// Name of the header that carries the username.
    header_name: String,
}

impl RemoteUserAuthentication {
    /// Creates an instance that reads the `REMOTE_USER` header.
    pub fn new() -> Self {
        Self {
            header_name: String::from("REMOTE_USER"),
        }
    }

    /// Replaces the header name with `header` and returns the updated
    /// instance, so calls can be chained.
    pub fn with_header(mut self, header: impl Into<String>) -> Self {
        let replacement: String = header.into();
        self.header_name = replacement;
        self
    }
}

impl Default for RemoteUserAuthentication {
    /// Equivalent to [`RemoteUserAuthentication::new`].
    fn default() -> Self {
        RemoteUserAuthentication::new()
    }
}
325
326#[async_trait::async_trait]
327impl RestAuthentication for RemoteUserAuthentication {
328 async fn authenticate(
329 &self,
330 request: &Request,
331 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
332 let header_value = request
333 .headers
334 .get(&self.header_name)
335 .and_then(|v| v.to_str().ok());
336
337 if let Some(username) = header_value
338 && !username.is_empty()
339 {
340 return Ok(Some(Box::new(SimpleUser {
341 id: uuid::Uuid::new_v4(),
342 username: username.to_string(),
343 email: format!("{}@example.com", username),
344 is_active: true,
345 is_admin: false,
346 is_staff: false,
347 is_superuser: false,
348 })));
349 }
350
351 Ok(None)
352 }
353}
354
355#[async_trait::async_trait]
356impl AuthenticationBackend for RemoteUserAuthentication {
357 async fn authenticate(
358 &self,
359 request: &Request,
360 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
361 <Self as RestAuthentication>::authenticate(self, request).await
362 }
363
364 async fn get_user(&self, _user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
365 Ok(None)
366 }
367}
368
369#[derive(Clone)]
371pub struct SessionAuthentication<B: SessionBackend> {
372 config: SessionAuthConfig,
374 session_backend: B,
376}
377
378impl<B: SessionBackend> SessionAuthentication<B> {
379 pub fn new(session_backend: B) -> Self {
381 Self {
382 config: SessionAuthConfig::default(),
383 session_backend,
384 }
385 }
386
387 pub fn with_config(config: SessionAuthConfig, session_backend: B) -> Self {
389 Self {
390 config,
391 session_backend,
392 }
393 }
394}
395
396impl<B: SessionBackend + Default> Default for SessionAuthentication<B> {
397 fn default() -> Self {
398 Self::new(B::default())
399 }
400}
401
402#[async_trait::async_trait]
403impl<B: SessionBackend> RestAuthentication for SessionAuthentication<B> {
404 async fn authenticate(
405 &self,
406 request: &Request,
407 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
408 let cookie_header = request.headers.get("Cookie").and_then(|h| h.to_str().ok());
410
411 if let Some(cookies) = cookie_header {
412 for cookie in cookies.split(';') {
413 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
414 if parts.len() == 2 && parts[0] == self.config.cookie_name {
415 let session_key = parts[1];
416
417 let mut session =
419 Session::from_key(self.session_backend.clone(), session_key.to_string())
420 .await
421 .map_err(|_| AuthenticationError::SessionExpired)?;
422
423 let user_id: String = match session.get("_auth_user_id") {
425 Ok(Some(id)) => id,
426 Ok(None) => return Ok(None), Err(_) => return Err(AuthenticationError::SessionExpired),
428 };
429
430 let username: String = session
432 .get("_auth_user_name")
433 .ok()
434 .flatten()
435 .unwrap_or_else(|| user_id.clone());
436 let email: String = session
437 .get("_auth_user_email")
438 .ok()
439 .flatten()
440 .unwrap_or_default();
441 let is_active: bool = session
442 .get("_auth_user_is_active")
443 .ok()
444 .flatten()
445 .unwrap_or(true);
446 let is_admin: bool = session
447 .get("_auth_user_is_admin")
448 .ok()
449 .flatten()
450 .unwrap_or(false);
451 let is_staff: bool = session
452 .get("_auth_user_is_staff")
453 .ok()
454 .flatten()
455 .unwrap_or(false);
456 let is_superuser: bool = session
457 .get("_auth_user_is_superuser")
458 .ok()
459 .flatten()
460 .unwrap_or(false);
461
462 let user = SimpleUser {
464 id: uuid::Uuid::parse_str(&user_id)
465 .map_err(|_| AuthenticationError::InvalidCredentials)?,
466 username,
467 email,
468 is_active,
469 is_admin,
470 is_staff,
471 is_superuser,
472 };
473
474 return Ok(Some(Box::new(user)));
475 }
476 }
477 }
478
479 Ok(None)
480 }
481}
482
483#[async_trait::async_trait]
484impl<B: SessionBackend> AuthenticationBackend for SessionAuthentication<B> {
485 async fn authenticate(
486 &self,
487 request: &Request,
488 ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
489 <Self as RestAuthentication>::authenticate(self, request).await
490 }
491
492 #[cfg(feature = "argon2-hasher")]
493 async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
494 let id =
496 uuid::Uuid::parse_str(user_id).map_err(|_| AuthenticationError::InvalidCredentials)?;
497
498 let conn = reinhardt_db::orm::manager::get_connection()
500 .await
501 .map_err(|e| AuthenticationError::DatabaseError(e.to_string()))?;
502
503 use reinhardt_db::orm::{
505 Alias, DatabaseBackend, Expr, ExprTrait, Model, MySqlQueryBuilder,
506 PostgresQueryBuilder, Query, QueryStatementBuilder, SqliteQueryBuilder,
507 };
508
509 let table_name = DefaultUser::table_name();
510
511 let stmt = Query::select()
513 .columns([
514 Alias::new("id"),
515 Alias::new("username"),
516 Alias::new("email"),
517 Alias::new("first_name"),
518 Alias::new("last_name"),
519 Alias::new("password_hash"),
520 Alias::new("last_login"),
521 Alias::new("is_active"),
522 Alias::new("is_staff"),
523 Alias::new("is_superuser"),
524 Alias::new("date_joined"),
525 Alias::new("user_permissions"),
526 Alias::new("groups"),
527 ])
528 .from(Alias::new(table_name))
529 .and_where(Expr::col(Alias::new("id")).eq(Expr::value(id.to_string())))
530 .to_owned();
531
532 let sql = match conn.backend() {
533 DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
534 DatabaseBackend::MySql => stmt.to_string(MySqlQueryBuilder),
535 DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
536 };
537
538 let row = conn
540 .query_one(&sql, vec![])
541 .await
542 .map_err(|e| AuthenticationError::DatabaseError(e.to_string()))?;
543
544 let user: DefaultUser = serde_json::from_value(row.data).map_err(|e| {
546 AuthenticationError::DatabaseError(format!("Deserialization failed: {}", e))
547 })?;
548
549 Ok(Some(Box::new(user)))
551 }
552
553 #[cfg(not(feature = "argon2-hasher"))]
554 async fn get_user(&self, _user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
555 Ok(None)
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "jwt")]
    use crate::basic::BasicAuthentication;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method};

    /// Builds a body-less `GET /` request carrying exactly one header.
    fn get_request_with_header(name: &'static str, value: &str) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().unwrap());

        Request::builder()
            .method(Method::GET)
            .uri("/")
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// A composite wrapping a Basic backend authenticates a Basic header.
    #[tokio::test]
    #[cfg(feature = "jwt")]
    async fn test_composite_authentication() {
        let mut basic = BasicAuthentication::new();
        basic.add_user("user1", "pass1");
        let composite = CompositeAuthentication::new().with_backend(basic);

        // The value is base64 for "user1:pass1".
        let request = get_request_with_header("Authorization", "Basic dXNlcjE6cGFzczE=");

        let outcome = RestAuthentication::authenticate(&composite, &request)
            .await
            .unwrap();
        assert!(outcome.is_some());
        assert_eq!(outcome.unwrap().get_username(), "user1");
    }

    /// A registered token in the default header/prefix resolves to its user.
    #[tokio::test]
    async fn test_token_authentication() {
        let mut auth = TokenAuthentication::new();
        auth.add_token("secret_token", "alice");

        let request = get_request_with_header("Authorization", "Token secret_token");

        let outcome = RestAuthentication::authenticate(&auth, &request)
            .await
            .unwrap();
        assert!(outcome.is_some());
        assert_eq!(outcome.unwrap().get_username(), "alice");
    }

    /// The `REMOTE_USER` header value becomes the username.
    #[tokio::test]
    async fn test_remote_user_authentication() {
        let auth = RemoteUserAuthentication::new();

        let request = get_request_with_header("REMOTE_USER", "bob");

        let outcome = RestAuthentication::authenticate(&auth, &request)
            .await
            .unwrap();
        assert!(outcome.is_some());
        assert_eq!(outcome.unwrap().get_username(), "bob");
    }

    /// A saved session referenced by the `sessionid` cookie yields its user.
    #[tokio::test]
    async fn test_session_authentication() {
        use crate::sessions::InMemorySessionBackend;
        use crate::sessions::Session;

        let session_backend = InMemorySessionBackend::new();

        // Store the user attributes in a session and persist it.
        let mut session = Session::new(session_backend.clone());
        session
            .set("_auth_user_id", "550e8400-e29b-41d4-a716-446655440000")
            .unwrap();
        session.set("_auth_user_name", "testuser").unwrap();
        session.set("_auth_user_email", "test@example.com").unwrap();
        session.set("_auth_user_is_active", true).unwrap();
        session.save().await.unwrap();

        let session_key = session.get_or_create_key().to_string();

        let auth = SessionAuthentication::new(session_backend);

        let cookie_value = format!("sessionid={}", session_key);
        let request = get_request_with_header("Cookie", &cookie_value);

        let outcome = RestAuthentication::authenticate(&auth, &request)
            .await
            .unwrap();
        assert!(outcome.is_some());

        let user = outcome.unwrap();
        assert_eq!(user.get_username(), "testuser");
    }

    /// A custom header name and prefix are honoured.
    #[tokio::test]
    async fn test_custom_token_config() {
        let config = TokenAuthConfig {
            header_name: "X-API-Key".to_string(),
            prefix: "Bearer".to_string(),
        };

        let mut auth = TokenAuthentication::with_config(config);
        auth.add_token("my_token", "charlie");

        let request = get_request_with_header("X-API-Key", "Bearer my_token");

        let outcome = RestAuthentication::authenticate(&auth, &request)
            .await
            .unwrap();
        assert!(outcome.is_some());
        assert_eq!(outcome.unwrap().get_username(), "charlie");
    }
}