1use async_trait::async_trait;
6use chrono::{DateTime, Duration, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use crate::{
13 traits::{AuthProvider, Authenticatable, AuthenticationResult, SessionStorage},
14 utils::CryptoUtils,
15 AuthError, AuthResult,
16};
17
18#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub struct SessionId(String);
21
22impl SessionId {
23 pub fn generate() -> Self {
25 Self(CryptoUtils::generate_token(32))
26 }
27
28 pub fn from_string(s: String) -> AuthResult<Self> {
30 if s.len() < 16 {
31 return Err(AuthError::token_error("Session ID too short"));
32 }
33 Ok(Self(s))
34 }
35
36 pub fn as_str(&self) -> &str {
38 &self.0
39 }
40}
41
42impl std::fmt::Display for SessionId {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "{}", self.0)
45 }
46}
47
48impl From<String> for SessionId {
49 fn from(s: String) -> Self {
50 Self(s)
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct SessionData {
57 pub user_id: String,
59
60 pub username: String,
62
63 pub roles: Vec<String>,
65
66 pub permissions: Vec<String>,
68
69 pub created_at: DateTime<Utc>,
71
72 pub last_accessed: DateTime<Utc>,
74
75 pub expires_at: DateTime<Utc>,
77
78 pub csrf_token: Option<String>,
80
81 pub ip_address: Option<String>,
83
84 pub user_agent: Option<String>,
86
87 pub metadata: HashMap<String, serde_json::Value>,
89}
90
91impl SessionData {
92 pub fn new<U: Authenticatable>(
94 user: &U,
95 expires_in: Duration,
96 csrf_token: Option<String>,
97 ip_address: Option<String>,
98 user_agent: Option<String>,
99 ) -> Self {
100 let now = Utc::now();
101 Self {
102 user_id: format!("{:?}", user.id()),
103 username: user.username().to_string(),
104 roles: user.roles(),
105 permissions: user.permissions(),
106 created_at: now,
107 last_accessed: now,
108 expires_at: now + expires_in,
109 csrf_token,
110 ip_address,
111 user_agent,
112 metadata: user.additional_data(),
113 }
114 }
115
116 pub fn is_expired(&self) -> bool {
118 Utc::now() > self.expires_at
119 }
120
121 pub fn touch(&mut self) {
123 self.last_accessed = Utc::now();
124 }
125
126 pub fn extend(&mut self, duration: Duration) {
128 self.expires_at = Utc::now() + duration;
129 self.touch();
130 }
131}
132
133#[derive(Debug)]
138pub struct MemorySessionStorage {
139 sessions: Arc<RwLock<HashMap<SessionId, (SessionData, DateTime<Utc>)>>>,
140}
141
142impl MemorySessionStorage {
143 pub fn new() -> Self {
145 Self {
146 sessions: Arc::new(RwLock::new(HashMap::new())),
147 }
148 }
149}
150
151impl Default for MemorySessionStorage {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157#[async_trait]
158impl SessionStorage for MemorySessionStorage {
159 type SessionId = SessionId;
160 type SessionData = SessionData;
161
162 async fn create_session(
163 &self,
164 data: Self::SessionData,
165 expires_at: DateTime<Utc>,
166 ) -> AuthResult<Self::SessionId> {
167 let session_id = SessionId::generate();
168 let mut sessions = self.sessions.write().await;
169 sessions.insert(session_id.clone(), (data, expires_at));
170 Ok(session_id)
171 }
172
173 async fn get_session(&self, id: &Self::SessionId) -> AuthResult<Option<Self::SessionData>> {
174 let sessions = self.sessions.read().await;
175 if let Some((data, expires_at)) = sessions.get(id) {
176 if Utc::now() > *expires_at {
178 drop(sessions);
180 let mut sessions = self.sessions.write().await;
181 sessions.remove(id);
182 return Ok(None);
183 }
184 Ok(Some(data.clone()))
185 } else {
186 Ok(None)
187 }
188 }
189
190 async fn update_session(
191 &self,
192 id: &Self::SessionId,
193 data: Self::SessionData,
194 expires_at: DateTime<Utc>,
195 ) -> AuthResult<()> {
196 let mut sessions = self.sessions.write().await;
197 if sessions.contains_key(id) {
198 sessions.insert(id.clone(), (data, expires_at));
199 Ok(())
200 } else {
201 Err(AuthError::token_error("Session not found"))
202 }
203 }
204
205 async fn delete_session(&self, id: &Self::SessionId) -> AuthResult<()> {
206 let mut sessions = self.sessions.write().await;
207 sessions.remove(id);
208 Ok(())
209 }
210
211 async fn cleanup_expired_sessions(&self) -> AuthResult<u64> {
212 let mut sessions = self.sessions.write().await;
213 let now = Utc::now();
214 let initial_count = sessions.len();
215
216 sessions.retain(|_, (_, expires_at)| now <= *expires_at);
217
218 let cleaned = (initial_count - sessions.len()) as u64;
219 Ok(cleaned)
220 }
221
222 async fn get_session_expiry(&self, id: &Self::SessionId) -> AuthResult<Option<DateTime<Utc>>> {
223 let sessions = self.sessions.read().await;
224 Ok(sessions.get(id).map(|(_, expires_at)| *expires_at))
225 }
226
227 async fn extend_session(
228 &self,
229 id: &Self::SessionId,
230 expires_at: DateTime<Utc>,
231 ) -> AuthResult<()> {
232 let mut sessions = self.sessions.write().await;
233 if let Some((data, _)) = sessions.get(id).cloned() {
234 sessions.insert(id.clone(), (data, expires_at));
235 Ok(())
236 } else {
237 Err(AuthError::token_error("Session not found"))
238 }
239 }
240}
241#[derive(Debug)]
246pub struct SessionProvider<S, U>
247where
248 S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
249 U: Authenticatable,
250{
251 storage: Arc<S>,
252 session_duration: Duration,
253 cleanup_interval: Duration,
254 _phantom: std::marker::PhantomData<U>,
255}
256
257impl<S, U> SessionProvider<S, U>
258where
259 S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
260 U: Authenticatable + Clone,
261{
262 pub fn new(storage: S, session_duration: Duration, cleanup_interval: Duration) -> Self {
264 Self {
265 storage: Arc::new(storage),
266 session_duration,
267 cleanup_interval,
268 _phantom: std::marker::PhantomData,
269 }
270 }
271
272 pub fn cleanup_interval(&self) -> Duration {
273 self.cleanup_interval
274 }
275
276 pub fn with_default_config(storage: S) -> Self {
278 Self::new(
279 storage,
280 Duration::hours(24), Duration::hours(1), )
283 }
284
285 pub fn storage(&self) -> &S {
287 &self.storage
288 }
289
290 pub async fn create_session(
292 &self,
293 user: &U,
294 csrf_token: Option<String>,
295 ip_address: Option<String>,
296 user_agent: Option<String>,
297 ) -> AuthResult<SessionId> {
298 let session_data = SessionData::new(
299 user,
300 self.session_duration,
301 csrf_token,
302 ip_address,
303 user_agent,
304 );
305
306 self.storage
307 .create_session(session_data, Utc::now() + self.session_duration)
308 .await
309 }
310
311 pub async fn validate_session(&self, session_id: &SessionId) -> AuthResult<SessionData> {
313 match self.storage.get_session(session_id).await? {
314 Some(mut session_data) => {
315 if session_data.is_expired() {
316 let _ = self.storage.delete_session(session_id).await;
318 return Err(AuthError::token_error("Session expired"));
319 }
320
321 session_data.touch();
323 self.storage
324 .update_session(session_id, session_data.clone(), session_data.expires_at)
325 .await?;
326
327 Ok(session_data)
328 }
329 None => Err(AuthError::token_error("Session not found")),
330 }
331 }
332
333 pub async fn extend_session(&self, session_id: &SessionId) -> AuthResult<()> {
335 let new_expiry = Utc::now() + self.session_duration;
336 self.storage.extend_session(session_id, new_expiry).await
337 }
338
339 pub async fn destroy_session(&self, session_id: &SessionId) -> AuthResult<()> {
341 self.storage.delete_session(session_id).await
342 }
343
344 pub async fn cleanup_expired(&self) -> AuthResult<u64> {
346 self.storage.cleanup_expired_sessions().await
347 }
348}
349
350#[async_trait]
352pub trait UserFinder<U: Authenticatable>: Send + Sync {
353 async fn find_by_id(&self, id: &str) -> AuthResult<Option<U>>;
355}
356
357#[derive(Debug, Clone)]
359pub struct SessionCredentials {
360 pub session_id: SessionId,
361}
362
363#[async_trait]
364impl<S, U> AuthProvider<U> for SessionProvider<S, U>
365where
366 S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
367 U: Authenticatable + Clone,
368{
369 type Token = SessionId;
370 type Credentials = SessionCredentials;
371
372 async fn authenticate(
373 &self,
374 credentials: &Self::Credentials,
375 ) -> AuthResult<AuthenticationResult<U, Self::Token>> {
376 let _session_data = self.validate_session(&credentials.session_id).await?;
378
379 Err(AuthError::generic_error(
383 "Session authentication requires a UserFinder implementation",
384 ))
385 }
386
387 async fn validate_token(&self, _token: &Self::Token) -> AuthResult<U> {
388 Err(AuthError::generic_error(
390 "Token validation requires a UserFinder implementation",
391 ))
392 }
393
394 async fn revoke_token(&self, token: &Self::Token) -> AuthResult<()> {
395 self.destroy_session(token).await
396 }
397
398 fn provider_name(&self) -> &str {
399 "session"
400 }
401}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Authenticatable` user used to build sessions in these tests.
    #[derive(Debug, Clone)]
    struct MockUser {
        id: String,
        username: String,
        roles: Vec<String>,
        permissions: Vec<String>,
        active: bool,
    }

    #[async_trait]
    impl Authenticatable for MockUser {
        type Id = String;
        type Credentials = String;

        fn id(&self) -> &Self::Id {
            &self.id
        }

        fn username(&self) -> &str {
            &self.username
        }

        fn is_active(&self) -> bool {
            self.active
        }

        fn roles(&self) -> Vec<String> {
            self.roles.clone()
        }

        fn permissions(&self) -> Vec<String> {
            self.permissions.clone()
        }

        async fn verify_credentials(&self, _credentials: &Self::Credentials) -> AuthResult<bool> {
            Ok(true)
        }
    }

    /// Builds the active test user (id "123", username "test@example.com") with
    /// the given roles and permissions.
    fn mock_user(roles: &[&str], permissions: &[&str]) -> MockUser {
        MockUser {
            id: "123".to_string(),
            username: "test@example.com".to_string(),
            roles: roles.iter().map(|r| r.to_string()).collect(),
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
            active: true,
        }
    }

    #[tokio::test]
    async fn test_session_id_generation() {
        let session_id = SessionId::generate();
        assert!(session_id.as_str().len() >= 32);

        let session_id2 = SessionId::generate();
        assert_ne!(session_id.as_str(), session_id2.as_str());
    }

    #[tokio::test]
    async fn test_session_id_from_string() {
        let valid_id = "a".repeat(32);
        let session_id = SessionId::from_string(valid_id.clone()).unwrap();
        assert_eq!(session_id.as_str(), &valid_id);

        let result = SessionId::from_string("short".to_string());
        assert!(result.is_err());

        // 16 bytes is the minimum accepted length.
        assert!(SessionId::from_string("a".repeat(16)).is_ok());
        assert!(SessionId::from_string("a".repeat(15)).is_err());
    }

    #[tokio::test]
    async fn test_session_data_creation() {
        let user = mock_user(&["admin"], &["read", "write"]);

        let session_data = SessionData::new(
            &user,
            Duration::hours(24),
            Some("csrf_token".to_string()),
            Some("192.168.1.1".to_string()),
            Some("Mozilla/5.0".to_string()),
        );

        // `user_id` is the Debug rendering of the id, hence the embedded quotes.
        assert_eq!(session_data.user_id, "\"123\"");
        assert_eq!(session_data.username, "test@example.com");
        assert_eq!(session_data.roles, vec!["admin"]);
        assert_eq!(session_data.permissions, vec!["read", "write"]);
        assert_eq!(session_data.csrf_token, Some("csrf_token".to_string()));
        assert!(!session_data.is_expired());
    }

    #[tokio::test]
    async fn test_session_data_expiration() {
        let user = mock_user(&[], &[]);

        let mut session_data =
            SessionData::new(&user, Duration::milliseconds(-1), None, None, None);
        assert!(session_data.is_expired());

        session_data.extend(Duration::hours(1));
        assert!(!session_data.is_expired());
    }

    #[tokio::test]
    async fn test_memory_session_storage() {
        let storage = MemorySessionStorage::new();
        let user = mock_user(&[], &[]);

        let session_data = SessionData::new(&user, Duration::hours(24), None, None, None);
        let expires_at = Utc::now() + Duration::hours(24);

        let session_id = storage
            .create_session(session_data.clone(), expires_at)
            .await
            .unwrap();

        let retrieved = storage.get_session(&session_id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().user_id, session_data.user_id);

        storage.delete_session(&session_id).await.unwrap();

        let retrieved = storage.get_session(&session_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_memory_session_storage_expired_cleanup() {
        let storage = MemorySessionStorage::new();
        let user = mock_user(&[], &[]);

        let session_data = SessionData::new(&user, Duration::milliseconds(-1), None, None, None);
        let expires_at = Utc::now() - Duration::hours(1);

        let session_id = storage
            .create_session(session_data, expires_at)
            .await
            .unwrap();

        // Reading an expired session reports it as absent.
        let retrieved = storage.get_session(&session_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_memory_session_storage_cleanup_removes_only_expired() {
        let storage = MemorySessionStorage::new();
        let user = mock_user(&[], &[]);

        let expired_id = storage
            .create_session(
                SessionData::new(&user, Duration::hours(-1), None, None, None),
                Utc::now() - Duration::hours(1),
            )
            .await
            .unwrap();
        let live_id = storage
            .create_session(
                SessionData::new(&user, Duration::hours(24), None, None, None),
                Utc::now() + Duration::hours(24),
            )
            .await
            .unwrap();

        assert_eq!(storage.cleanup_expired_sessions().await.unwrap(), 1);
        assert!(storage
            .get_session_expiry(&expired_id)
            .await
            .unwrap()
            .is_none());
        assert!(storage.get_session(&live_id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_session_provider_creation() {
        let storage = MemorySessionStorage::new();
        let provider: SessionProvider<MemorySessionStorage, MockUser> =
            SessionProvider::with_default_config(storage);

        assert_eq!(provider.provider_name(), "session");
    }

    #[tokio::test]
    async fn test_session_provider_session_lifecycle() {
        let storage = MemorySessionStorage::new();
        let provider = SessionProvider::with_default_config(storage);
        let user = mock_user(&["admin"], &["read"]);

        let session_id = provider
            .create_session(
                &user,
                Some("csrf_token".to_string()),
                Some("192.168.1.1".to_string()),
                Some("Mozilla/5.0".to_string()),
            )
            .await
            .unwrap();

        let session_data = provider.validate_session(&session_id).await.unwrap();
        assert_eq!(session_data.username, "test@example.com");
        assert_eq!(session_data.csrf_token, Some("csrf_token".to_string()));

        provider.extend_session(&session_id).await.unwrap();

        // An extended session must still validate.
        provider.validate_session(&session_id).await.unwrap();

        provider.destroy_session(&session_id).await.unwrap();

        let result = provider.validate_session(&session_id).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_session_cleanup() {
        let storage = MemorySessionStorage::new();
        let provider: SessionProvider<MemorySessionStorage, MockUser> =
            SessionProvider::with_default_config(storage);

        let cleaned = provider.cleanup_expired().await.unwrap();
        assert_eq!(cleaned, 0);
    }
}