1use crate::{
7 AuthContext,
8 jwt::{JwtConfig, JwtError, JwtManager},
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use thiserror::Error;
14use tokio::sync::RwLock;
15use tracing::{debug, error, info};
16use uuid::Uuid;
17
/// Errors produced by session management and session storage.
#[derive(Debug, Error)]
pub enum SessionError {
    /// No session is stored under `session_id`.
    #[error("Session not found: {session_id}")]
    SessionNotFound { session_id: String },

    /// The session was found but its `expires_at` has already passed.
    #[error("Session expired: {session_id}")]
    SessionExpired { session_id: String },

    /// The session (or the requested operation on it) cannot be used; `reason`
    /// explains why, e.g. an inactive session, a disabled feature, or a missing
    /// refresh token.
    #[error("Session invalid: {reason}")]
    SessionInvalid { reason: String },

    /// Creating another session would exceed the per-user session limit.
    #[error("Maximum sessions exceeded for user: {user_id}")]
    MaxSessionsExceeded { user_id: String },

    /// Session creation failed for the given `reason`.
    /// NOTE(review): not constructed anywhere in this file — confirm whether other
    /// modules or storage backends use it.
    #[error("Session creation failed: {reason}")]
    CreationFailed { reason: String },

    /// An error from the JWT layer; converted automatically by `?` via `#[from]`.
    #[error("JWT error: {0}")]
    JwtError(#[from] JwtError),

    /// A storage-backend failure described as free-form text.
    /// NOTE(review): `MemorySessionStorage` never returns this; presumably intended
    /// for persistent backends.
    #[error("Storage error: {0}")]
    StorageError(String),

    /// A caller-supplied token (e.g. a refresh token) did not match the stored one.
    #[error("Invalid session token")]
    InvalidToken,
}
45
/// A single authenticated session and its bookkeeping data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Unique session id; `Session::new` assigns a random v4 UUID string.
    pub session_id: String,

    /// Id of the user that owns this session.
    pub user_id: String,

    /// Authentication context (roles, API key id, ...) captured at creation time.
    pub auth_context: AuthContext,

    /// UTC time at which the session was created.
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// UTC time of the most recent access; refreshed by `Session::touch`.
    pub last_accessed: chrono::DateTime<chrono::Utc>,

    /// UTC instant after which `Session::is_expired` reports `true`.
    pub expires_at: chrono::DateTime<chrono::Utc>,

    /// Client IP address, if one was supplied.
    pub client_ip: Option<String>,

    /// Client User-Agent string, if one was supplied.
    pub user_agent: Option<String>,

    /// Free-form key/value annotations, typically added via `Session::with_metadata`.
    pub metadata: HashMap<String, String>,

    /// When `false` the session is unusable; `SessionManager::get_session` rejects it.
    pub is_active: bool,

    /// Opaque refresh token; populated by `SessionManager::create_session` when
    /// refresh support is enabled, otherwise `None`.
    pub refresh_token: Option<String>,
}
82
83impl Session {
84 pub fn new(user_id: String, auth_context: AuthContext, duration: chrono::Duration) -> Self {
86 let now = chrono::Utc::now();
87 let session_id = Uuid::new_v4().to_string();
88
89 Self {
90 session_id,
91 user_id,
92 auth_context,
93 created_at: now,
94 last_accessed: now,
95 expires_at: now + duration,
96 client_ip: None,
97 user_agent: None,
98 metadata: HashMap::new(),
99 is_active: true,
100 refresh_token: None,
101 }
102 }
103
104 pub fn is_expired(&self) -> bool {
106 chrono::Utc::now() > self.expires_at
107 }
108
109 pub fn touch(&mut self) {
111 self.last_accessed = chrono::Utc::now();
112 }
113
114 pub fn with_metadata(mut self, key: String, value: String) -> Self {
116 self.metadata.insert(key, value);
117 self
118 }
119
120 pub fn with_client_info(
122 mut self,
123 client_ip: Option<String>,
124 user_agent: Option<String>,
125 ) -> Self {
126 self.client_ip = client_ip;
127 self.user_agent = user_agent;
128 self
129 }
130}
131
/// Pluggable persistence backend for [`Session`]s.
///
/// Implementations must be `Send + Sync` so a single instance can be shared behind
/// an `Arc` (as `SessionManager` does).
#[async_trait::async_trait]
pub trait SessionStorage: Send + Sync {
    /// Persists `session` under its `session_id`.
    async fn store_session(&self, session: &Session) -> Result<(), SessionError>;

    /// Looks up a session by id; `Ok(None)` when no such session is stored.
    async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionError>;

    /// Replaces an already-stored session. The in-memory implementation returns
    /// `SessionNotFound` when the id is unknown.
    async fn update_session(&self, session: &Session) -> Result<(), SessionError>;

    /// Removes a session by id. The in-memory implementation returns
    /// `SessionNotFound` when the id is unknown.
    async fn delete_session(&self, session_id: &str) -> Result<(), SessionError>;

    /// Returns every stored session owned by `user_id`. No filtering by expiry or
    /// activity is implied here; `SessionManager::get_user_sessions` does that.
    async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<Session>, SessionError>;

    /// Deletes expired sessions and returns how many were removed.
    async fn cleanup_expired(&self) -> Result<u64, SessionError>;

    /// Number of sessions stored for `user_id`. The in-memory implementation counts
    /// expired sessions that have not been cleaned up yet.
    async fn get_session_count(&self, user_id: &str) -> Result<usize, SessionError>;
}
156
/// Process-local [`SessionStorage`] backed by two `HashMap`s behind async `RwLock`s.
///
/// Contents are lost when the process exits. Methods that need both maps lock
/// `sessions` before `user_sessions`.
pub struct MemorySessionStorage {
    /// All sessions, keyed by `session_id`.
    sessions: Arc<RwLock<HashMap<String, Session>>>,
    /// Secondary index: `user_id` -> ids of that user's sessions.
    user_sessions: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
162
163impl MemorySessionStorage {
164 pub fn new() -> Self {
165 Self {
166 sessions: Arc::new(RwLock::new(HashMap::new())),
167 user_sessions: Arc::new(RwLock::new(HashMap::new())),
168 }
169 }
170}
171
172impl Default for MemorySessionStorage {
173 fn default() -> Self {
174 Self::new()
175 }
176}
177
178#[async_trait::async_trait]
179impl SessionStorage for MemorySessionStorage {
180 async fn store_session(&self, session: &Session) -> Result<(), SessionError> {
181 let mut sessions = self.sessions.write().await;
182 let mut user_sessions = self.user_sessions.write().await;
183
184 sessions.insert(session.session_id.clone(), session.clone());
185
186 user_sessions
187 .entry(session.user_id.clone())
188 .or_insert_with(Vec::new)
189 .push(session.session_id.clone());
190
191 debug!(
192 "Stored session {} for user {}",
193 session.session_id, session.user_id
194 );
195 Ok(())
196 }
197
198 async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionError> {
199 let sessions = self.sessions.read().await;
200 Ok(sessions.get(session_id).cloned())
201 }
202
203 async fn update_session(&self, session: &Session) -> Result<(), SessionError> {
204 let mut sessions = self.sessions.write().await;
205 if sessions.contains_key(&session.session_id) {
206 sessions.insert(session.session_id.clone(), session.clone());
207 debug!("Updated session {}", session.session_id);
208 Ok(())
209 } else {
210 Err(SessionError::SessionNotFound {
211 session_id: session.session_id.clone(),
212 })
213 }
214 }
215
216 async fn delete_session(&self, session_id: &str) -> Result<(), SessionError> {
217 let mut sessions = self.sessions.write().await;
218 let mut user_sessions = self.user_sessions.write().await;
219
220 if let Some(session) = sessions.remove(session_id) {
221 if let Some(user_session_list) = user_sessions.get_mut(&session.user_id) {
222 user_session_list.retain(|id| id != session_id);
223 if user_session_list.is_empty() {
224 user_sessions.remove(&session.user_id);
225 }
226 }
227 debug!("Deleted session {}", session_id);
228 Ok(())
229 } else {
230 Err(SessionError::SessionNotFound {
231 session_id: session_id.to_string(),
232 })
233 }
234 }
235
236 async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<Session>, SessionError> {
237 let sessions = self.sessions.read().await;
238 let user_sessions = self.user_sessions.read().await;
239
240 let mut result = Vec::new();
241 if let Some(session_ids) = user_sessions.get(user_id) {
242 for session_id in session_ids {
243 if let Some(session) = sessions.get(session_id) {
244 result.push(session.clone());
245 }
246 }
247 }
248
249 Ok(result)
250 }
251
252 async fn cleanup_expired(&self) -> Result<u64, SessionError> {
253 let mut sessions = self.sessions.write().await;
254 let mut user_sessions = self.user_sessions.write().await;
255 let mut removed_count = 0u64;
256
257 let now = chrono::Utc::now();
258 let expired_sessions: Vec<String> = sessions
259 .iter()
260 .filter(|(_, session)| session.expires_at < now)
261 .map(|(id, _)| id.clone())
262 .collect();
263
264 for session_id in expired_sessions {
265 if let Some(session) = sessions.remove(&session_id) {
266 if let Some(user_session_list) = user_sessions.get_mut(&session.user_id) {
267 user_session_list.retain(|id| id != &session_id);
268 if user_session_list.is_empty() {
269 user_sessions.remove(&session.user_id);
270 }
271 }
272 removed_count += 1;
273 }
274 }
275
276 if removed_count > 0 {
277 info!("Cleaned up {} expired sessions", removed_count);
278 }
279
280 Ok(removed_count)
281 }
282
283 async fn get_session_count(&self, user_id: &str) -> Result<usize, SessionError> {
284 let user_sessions = self.user_sessions.read().await;
285 Ok(user_sessions.get(user_id).map(|v| v.len()).unwrap_or(0))
286 }
287}
288
/// Tunables for [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Lifetime used when `create_session` is called without an explicit duration.
    pub default_duration: chrono::Duration,

    /// Upper bound on a session's requested lifetime. Also consulted when a session
    /// is extended on access.
    pub max_duration: chrono::Duration,

    /// Maximum number of sessions a single user may hold.
    pub max_sessions_per_user: usize,

    /// Whether to build a `JwtManager` and issue access tokens with new sessions.
    pub enable_jwt: bool,

    /// Configuration handed to `JwtManager::new` when `enable_jwt` is set.
    pub jwt_config: JwtConfig,

    /// Whether sessions get a refresh token and `refresh_session` is allowed.
    pub enable_refresh: bool,

    /// Intended lifetime of refresh tokens.
    /// NOTE(review): not read anywhere in this file — confirm whether it is used elsewhere.
    pub refresh_duration: chrono::Duration,

    /// Period of the background task started by `start_cleanup_task`.
    pub cleanup_interval: chrono::Duration,

    /// Whether `validate_session` may push out a session's expiry.
    pub extend_on_access: bool,

    /// How far past "now" `validate_session` extends a session when
    /// `extend_on_access` is set.
    pub extension_duration: chrono::Duration,
}
322
323impl Default for SessionConfig {
324 fn default() -> Self {
325 Self {
326 default_duration: chrono::Duration::hours(24),
327 max_duration: chrono::Duration::days(7),
328 max_sessions_per_user: 10,
329 enable_jwt: true,
330 jwt_config: JwtConfig::default(),
331 enable_refresh: true,
332 refresh_duration: chrono::Duration::days(30),
333 cleanup_interval: chrono::Duration::hours(1),
334 extend_on_access: true,
335 extension_duration: chrono::Duration::hours(1),
336 }
337 }
338}
339
/// High-level session API: creation, validation, refresh, termination and cleanup.
pub struct SessionManager {
    /// Tunables supplied at construction.
    config: SessionConfig,
    /// Shared persistence backend.
    storage: Arc<dyn SessionStorage>,
    /// `None` when JWT is disabled in `config` or when `JwtManager::new` failed.
    jwt_manager: Option<Arc<JwtManager>>,
}
346
347impl SessionManager {
348 pub fn new(config: SessionConfig, storage: Arc<dyn SessionStorage>) -> Self {
350 let jwt_manager = if config.enable_jwt {
351 match JwtManager::new(config.jwt_config.clone()) {
352 Ok(manager) => Some(Arc::new(manager)),
353 Err(e) => {
354 error!("Failed to create JWT manager: {}", e);
355 None
356 }
357 }
358 } else {
359 None
360 };
361
362 Self {
363 config,
364 storage,
365 jwt_manager,
366 }
367 }
368
369 pub fn with_default_config() -> Self {
371 Self::new(
372 SessionConfig::default(),
373 Arc::new(MemorySessionStorage::new()),
374 )
375 }
376
377 pub async fn create_session(
379 &self,
380 user_id: String,
381 auth_context: AuthContext,
382 duration: Option<chrono::Duration>,
383 client_ip: Option<String>,
384 user_agent: Option<String>,
385 ) -> Result<(Session, Option<String>), SessionError> {
386 let session_count = self.storage.get_session_count(&user_id).await?;
388 if session_count >= self.config.max_sessions_per_user {
389 return Err(SessionError::MaxSessionsExceeded { user_id });
390 }
391
392 let session_duration = duration.unwrap_or(self.config.default_duration);
394
395 let final_duration = std::cmp::min(session_duration, self.config.max_duration);
397
398 let mut session = Session::new(user_id.clone(), auth_context, final_duration)
400 .with_client_info(client_ip, user_agent);
401
402 let jwt_token = if let Some(jwt_manager) = &self.jwt_manager {
404 let token = jwt_manager
405 .generate_access_token(
406 session
407 .auth_context
408 .user_id
409 .clone()
410 .unwrap_or_else(|| user_id.clone()),
411 session.auth_context.roles.clone(),
412 session.auth_context.api_key_id.clone(),
413 session.client_ip.clone(),
414 Some(session.session_id.clone()),
415 vec!["api".to_string()],
416 )
417 .await?;
418 Some(token)
419 } else {
420 None
421 };
422
423 if self.config.enable_refresh {
425 session.refresh_token = Some(Uuid::new_v4().to_string());
426 }
427
428 self.storage.store_session(&session).await?;
430
431 info!(
432 "Created session {} for user {} (duration: {} hours)",
433 session.session_id,
434 user_id,
435 final_duration.num_hours()
436 );
437
438 Ok((session, jwt_token))
439 }
440
441 pub async fn get_session(&self, session_id: &str) -> Result<Session, SessionError> {
443 let session = self.storage.get_session(session_id).await?.ok_or_else(|| {
444 SessionError::SessionNotFound {
445 session_id: session_id.to_string(),
446 }
447 })?;
448
449 if session.is_expired() {
450 let _ = self.storage.delete_session(session_id).await;
452 return Err(SessionError::SessionExpired {
453 session_id: session_id.to_string(),
454 });
455 }
456
457 if !session.is_active {
458 return Err(SessionError::SessionInvalid {
459 reason: "Session is inactive".to_string(),
460 });
461 }
462
463 Ok(session)
464 }
465
466 pub async fn validate_session(&self, session_id: &str) -> Result<Session, SessionError> {
468 let mut session = self.get_session(session_id).await?;
469
470 session.touch();
472
473 if self.config.extend_on_access {
475 let new_expiry = chrono::Utc::now() + self.config.extension_duration;
476 if new_expiry < session.expires_at + self.config.max_duration {
477 session.expires_at = new_expiry;
478 }
479 }
480
481 self.storage.update_session(&session).await?;
483
484 debug!("Validated and updated session {}", session_id);
485 Ok(session)
486 }
487
488 pub async fn validate_jwt_token(&self, token: &str) -> Result<AuthContext, SessionError> {
490 let jwt_manager =
491 self.jwt_manager
492 .as_ref()
493 .ok_or_else(|| SessionError::SessionInvalid {
494 reason: "JWT not enabled".to_string(),
495 })?;
496
497 let auth_context = jwt_manager.token_to_auth_context(token).await?;
498 Ok(auth_context)
499 }
500
501 pub async fn refresh_session(
503 &self,
504 session_id: &str,
505 refresh_token: &str,
506 ) -> Result<(Session, Option<String>), SessionError> {
507 let session = self.get_session(session_id).await?;
508
509 if !self.config.enable_refresh {
511 return Err(SessionError::SessionInvalid {
512 reason: "Session refresh not enabled".to_string(),
513 });
514 }
515
516 let stored_refresh_token =
517 session
518 .refresh_token
519 .as_ref()
520 .ok_or_else(|| SessionError::SessionInvalid {
521 reason: "No refresh token available".to_string(),
522 })?;
523
524 if stored_refresh_token != refresh_token {
525 return Err(SessionError::InvalidToken);
526 }
527
528 self.create_session(
530 session.user_id.clone(),
531 session.auth_context.clone(),
532 Some(self.config.default_duration),
533 session.client_ip.clone(),
534 session.user_agent.clone(),
535 )
536 .await
537 }
538
539 pub async fn terminate_session(&self, session_id: &str) -> Result<(), SessionError> {
541 self.storage.delete_session(session_id).await?;
542 info!("Terminated session {}", session_id);
543 Ok(())
544 }
545
546 pub async fn terminate_user_sessions(&self, user_id: &str) -> Result<u64, SessionError> {
548 let sessions = self.storage.get_user_sessions(user_id).await?;
549 let mut terminated_count = 0u64;
550
551 for session in sessions {
552 if self
553 .storage
554 .delete_session(&session.session_id)
555 .await
556 .is_ok()
557 {
558 terminated_count += 1;
559 }
560 }
561
562 info!(
563 "Terminated {} sessions for user {}",
564 terminated_count, user_id
565 );
566 Ok(terminated_count)
567 }
568
569 pub async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<Session>, SessionError> {
571 let sessions = self.storage.get_user_sessions(user_id).await?;
572 let active_sessions = sessions
573 .into_iter()
574 .filter(|s| !s.is_expired() && s.is_active)
575 .collect();
576
577 Ok(active_sessions)
578 }
579
580 pub async fn cleanup_expired_sessions(&self) -> Result<u64, SessionError> {
582 self.storage.cleanup_expired().await
583 }
584
585 pub async fn start_cleanup_task(&self) -> tokio::task::JoinHandle<()> {
587 let storage = Arc::clone(&self.storage);
588 let interval = self.config.cleanup_interval;
589
590 tokio::spawn(async move {
591 let mut cleanup_interval = tokio::time::interval(
592 interval
593 .to_std()
594 .unwrap_or(std::time::Duration::from_secs(3600)),
595 );
596
597 loop {
598 cleanup_interval.tick().await;
599
600 match storage.cleanup_expired().await {
601 Ok(count) => {
602 if count > 0 {
603 debug!("Cleanup task removed {} expired sessions", count);
604 }
605 }
606 Err(e) => {
607 error!("Session cleanup failed: {}", e);
608 }
609 }
610 }
611 })
612 }
613
614 pub async fn get_session_stats(&self) -> Result<SessionStats, SessionError> {
616 Ok(SessionStats {
619 total_sessions: 0, active_sessions: 0, expired_sessions: 0, })
623 }
624}
625
/// Aggregate session counts, as returned by `SessionManager::get_session_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionStats {
    /// Total number of sessions.
    pub total_sessions: u64,
    /// Sessions that are usable (presumably unexpired and active — confirm once
    /// `get_session_stats` computes real values).
    pub active_sessions: u64,
    /// Sessions past their expiry.
    pub expired_sessions: u64,
}
633
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Role;

    /// Auth context shared by all tests.
    fn create_test_auth_context() -> AuthContext {
        AuthContext {
            user_id: Some("test_user".to_string()),
            roles: vec![Role::Operator],
            api_key_id: Some("test_key".to_string()),
            permissions: vec!["read".to_string(), "write".to_string()],
        }
    }

    #[tokio::test]
    async fn test_session_creation() {
        let manager = SessionManager::with_default_config();
        let auth_context = create_test_auth_context();

        let result = manager
            .create_session(
                "test_user".to_string(),
                auth_context,
                None,
                Some("127.0.0.1".to_string()),
                Some("TestAgent/1.0".to_string()),
            )
            .await;

        assert!(result.is_ok());
        let (session, jwt_token) = result.unwrap();
        assert_eq!(session.user_id, "test_user");
        assert!(!session.is_expired());
        // JWT is enabled in the default config.
        assert!(jwt_token.is_some());
    }

    #[tokio::test]
    async fn test_session_validation() {
        let manager = SessionManager::with_default_config();
        let auth_context = create_test_auth_context();

        let (session, _) = manager
            .create_session("test_user".to_string(), auth_context, None, None, None)
            .await
            .unwrap();

        // Let the clock advance so the strict `>` below cannot flake when two
        // consecutive `Utc::now()` readings coincide on a coarse clock.
        tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;

        let validated_session = manager.validate_session(&session.session_id).await;
        assert!(validated_session.is_ok());

        let validated = validated_session.unwrap();
        assert!(validated.last_accessed > session.last_accessed);
    }

    #[tokio::test]
    async fn test_session_expiration() {
        let manager = SessionManager::with_default_config();
        let auth_context = create_test_auth_context();

        // 1 ms lifetime so the session expires during the sleep below.
        let (session, _) = manager
            .create_session(
                "test_user".to_string(),
                auth_context,
                Some(chrono::Duration::milliseconds(1)),
                None,
                None,
            )
            .await
            .unwrap();

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let result = manager.get_session(&session.session_id).await;
        assert!(matches!(result, Err(SessionError::SessionExpired { .. })));
    }

    #[tokio::test]
    async fn test_session_limits() {
        let config = SessionConfig {
            max_sessions_per_user: 2,
            ..Default::default()
        };
        let manager = SessionManager::new(config, Arc::new(MemorySessionStorage::new()));
        let auth_context = create_test_auth_context();

        let result1 = manager
            .create_session(
                "test_user".to_string(),
                auth_context.clone(),
                None,
                None,
                None,
            )
            .await;
        assert!(result1.is_ok());

        let result2 = manager
            .create_session(
                "test_user".to_string(),
                auth_context.clone(),
                None,
                None,
                None,
            )
            .await;
        assert!(result2.is_ok());

        // A third live session exceeds the limit of 2.
        let result3 = manager
            .create_session("test_user".to_string(), auth_context, None, None, None)
            .await;
        assert!(matches!(
            result3,
            Err(SessionError::MaxSessionsExceeded { .. })
        ));
    }

    #[tokio::test]
    async fn test_session_termination() {
        let manager = SessionManager::with_default_config();
        let auth_context = create_test_auth_context();

        let (session, _) = manager
            .create_session("test_user".to_string(), auth_context, None, None, None)
            .await
            .unwrap();

        assert!(manager.get_session(&session.session_id).await.is_ok());

        assert!(manager.terminate_session(&session.session_id).await.is_ok());

        // After termination the session must be gone, not merely inactive.
        assert!(matches!(
            manager.get_session(&session.session_id).await,
            Err(SessionError::SessionNotFound { .. })
        ));
    }

    #[tokio::test]
    async fn test_cleanup_expired_sessions() {
        let manager = SessionManager::with_default_config();
        let auth_context = create_test_auth_context();

        let (_, _) = manager
            .create_session(
                "test_user".to_string(),
                auth_context,
                Some(chrono::Duration::milliseconds(1)),
                None,
                None,
            )
            .await
            .unwrap();

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let cleanup_result = manager.cleanup_expired_sessions().await;
        assert!(cleanup_result.is_ok());
        assert!(cleanup_result.unwrap() > 0);
    }
}