1use crate::error::{SecurityError, Result};
4use crate::config::{SessionConfig, SessionStoreType, SameSitePolicy};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SessionData {
13 pub session_id: String,
14 pub user_id: String,
15 pub roles: Vec<String>,
16 pub permissions: Vec<String>,
17 pub attributes: HashMap<String, serde_json::Value>,
18 pub created_at: chrono::DateTime<chrono::Utc>,
19 pub expires_at: chrono::DateTime<chrono::Utc>,
20 pub last_accessed_at: chrono::DateTime<chrono::Utc>,
21 pub ip_address: Option<String>,
22 pub user_agent: Option<String>,
23}
24
25impl SessionData {
26 pub fn new(
28 session_id: String,
29 user_id: String,
30 roles: Vec<String>,
31 permissions: Vec<String>,
32 max_age_seconds: Option<u64>,
33 ip_address: Option<String>,
34 user_agent: Option<String>,
35 ) -> Self {
36 let now = chrono::Utc::now();
37 let expires_at = max_age_seconds
38 .map(|secs| now + chrono::Duration::seconds(secs as i64))
39 .unwrap_or_else(|| now + chrono::Duration::hours(24)); Self {
42 session_id,
43 user_id,
44 roles,
45 permissions,
46 attributes: HashMap::new(),
47 created_at: now,
48 expires_at,
49 last_accessed_at: now,
50 ip_address,
51 user_agent,
52 }
53 }
54
55 pub fn is_expired(&self) -> bool {
57 chrono::Utc::now() > self.expires_at
58 }
59
60 pub fn touch(&mut self) {
62 self.last_accessed_at = chrono::Utc::now();
63 }
64
65 pub fn extend(&mut self, additional_seconds: i64) {
67 self.expires_at = chrono::Utc::now() + chrono::Duration::seconds(additional_seconds);
68 }
69
70 pub fn time_until_expiry(&self) -> i64 {
72 let now = chrono::Utc::now();
73 (self.expires_at - now).num_seconds()
74 }
75
76 pub fn has_role(&self, role: &str) -> bool {
78 self.roles.contains(&role.to_string())
79 }
80
81 pub fn has_permission(&self, permission: &str) -> bool {
83 self.permissions.contains(&permission.to_string())
84 }
85
86 pub fn set_attribute(&mut self, key: String, value: serde_json::Value) {
88 self.attributes.insert(key, value);
89 }
90
91 pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
93 self.attributes.get(key)
94 }
95
96 pub fn remove_attribute(&mut self, key: &str) -> Option<serde_json::Value> {
98 self.attributes.remove(key)
99 }
100}
101
102#[async_trait::async_trait]
104pub trait SessionStore: Send + Sync {
105 async fn store(&self, session: SessionData) -> Result<()>;
107
108 async fn get(&self, session_id: &str) -> Result<Option<SessionData>>;
110
111 async fn update(&self, session: SessionData) -> Result<()>;
113
114 async fn delete(&self, session_id: &str) -> Result<()>;
116
117 async fn delete_user_sessions(&self, user_id: &str) -> Result<usize>;
119
120 async fn cleanup_expired(&self) -> Result<usize>;
122
123 async fn count(&self) -> Result<usize>;
125}
126
127pub struct MemorySessionStore {
129 sessions: Arc<RwLock<HashMap<String, SessionData>>>,
130}
131
132impl MemorySessionStore {
133 pub fn new() -> Self {
134 Self {
135 sessions: Arc::new(RwLock::new(HashMap::new())),
136 }
137 }
138}
139
140#[async_trait::async_trait]
141impl SessionStore for MemorySessionStore {
142 async fn store(&self, session: SessionData) -> Result<()> {
143 let mut sessions = self.sessions.write().await;
144 sessions.insert(session.session_id.clone(), session);
145 Ok(())
146 }
147
148 async fn get(&self, session_id: &str) -> Result<Option<SessionData>> {
149 let sessions = self.sessions.read().await;
150 Ok(sessions.get(session_id).cloned())
151 }
152
153 async fn update(&self, session: SessionData) -> Result<()> {
154 let mut sessions = self.sessions.write().await;
155 sessions.insert(session.session_id.clone(), session);
156 Ok(())
157 }
158
159 async fn delete(&self, session_id: &str) -> Result<()> {
160 let mut sessions = self.sessions.write().await;
161 sessions.remove(session_id);
162 Ok(())
163 }
164
165 async fn delete_user_sessions(&self, user_id: &str) -> Result<usize> {
166 let mut sessions = self.sessions.write().await;
167 let keys_to_remove: Vec<String> = sessions
168 .iter()
169 .filter(|(_, session)| session.user_id == user_id)
170 .map(|(key, _)| key.clone())
171 .collect();
172
173 let count = keys_to_remove.len();
174 for key in keys_to_remove {
175 sessions.remove(&key);
176 }
177
178 Ok(count)
179 }
180
181 async fn cleanup_expired(&self) -> Result<usize> {
182 let mut sessions = self.sessions.write().await;
183 let expired_keys: Vec<String> = sessions
184 .iter()
185 .filter(|(_, session)| session.is_expired())
186 .map(|(key, _)| key.clone())
187 .collect();
188
189 let count = expired_keys.len();
190 for key in expired_keys {
191 sessions.remove(&key);
192 }
193
194 Ok(count)
195 }
196
197 async fn count(&self) -> Result<usize> {
198 let sessions = self.sessions.read().await;
199 Ok(sessions.len())
200 }
201}
202
203pub struct SessionManager {
205 config: SessionConfig,
206 store: Box<dyn SessionStore>,
207}
208
209impl SessionManager {
210 pub fn new(config: SessionConfig) -> Self {
212 let store: Box<dyn SessionStore> = match config.store_type {
213 SessionStoreType::Memory => Box::new(MemorySessionStore::new()),
214 SessionStoreType::Redis => {
215 Box::new(MemorySessionStore::new())
217 }
218 SessionStoreType::Database => {
219 Box::new(MemorySessionStore::new())
221 }
222 };
223
224 Self { config, store }
225 }
226
227 pub async fn create_session(
229 &self,
230 user_id: &str,
231 roles: Vec<String>,
232 permissions: Vec<String>,
233 ip_address: Option<String>,
234 user_agent: Option<String>,
235 ) -> Result<SessionData> {
236 let session_id = self.generate_session_id();
237 let session = SessionData::new(
238 session_id,
239 user_id.to_string(),
240 roles,
241 permissions,
242 self.config.max_age_seconds,
243 ip_address,
244 user_agent,
245 );
246
247 self.store.store(session.clone()).await?;
248 Ok(session)
249 }
250
251 pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
253 let mut session = self.store.get(session_id).await?;
254
255 if let Some(ref mut session) = session {
256 if session.is_expired() {
258 self.store.delete(session_id).await?;
260 return Ok(None);
261 }
262
263 session.touch();
265 self.store.update(session.clone()).await?;
266 }
267
268 Ok(session)
269 }
270
271 pub async fn update_session(&self, session: SessionData) -> Result<()> {
273 if session.is_expired() {
274 return Err(SecurityError::SessionExpired);
275 }
276
277 self.store.update(session).await
278 }
279
280 pub async fn delete_session(&self, session_id: &str) -> Result<()> {
282 self.store.delete(session_id).await
283 }
284
285 pub async fn delete_user_sessions(&self, user_id: &str) -> Result<usize> {
287 self.store.delete_user_sessions(user_id).await
288 }
289
290 pub async fn extend_session(&self, session_id: &str, additional_seconds: i64) -> Result<()> {
292 let mut session = self.store.get(session_id).await?
293 .ok_or_else(|| SecurityError::SessionInvalid)?;
294
295 if session.is_expired() {
296 return Err(SecurityError::SessionExpired);
297 }
298
299 session.extend(additional_seconds);
300 self.store.update(session).await
301 }
302
303 pub async fn validate_session(&self, session_id: &str) -> Result<Option<SessionData>> {
305 self.get_session(session_id).await
306 }
307
308 pub async fn cleanup_expired_sessions(&self) -> Result<usize> {
310 self.store.cleanup_expired().await
311 }
312
313 pub async fn get_stats(&self) -> Result<SessionStats> {
315 let count = self.store.count().await?;
316 Ok(SessionStats { total_sessions: count })
317 }
318
319 fn generate_session_id(&self) -> String {
321 use uuid::Uuid;
322 Uuid::new_v4().to_string()
323 }
324}
325
326#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct SessionStats {
329 pub total_sessions: usize,
330}
331
332#[derive(Debug, Clone)]
334pub struct CookieConfig {
335 pub name: String,
336 pub secure: bool,
337 pub http_only: bool,
338 pub same_site: SameSitePolicy,
339 pub domain: Option<String>,
340 pub path: Option<String>,
341}
342
343impl Default for CookieConfig {
344 fn default() -> Self {
345 Self {
346 name: "session_id".to_string(),
347 secure: true,
348 http_only: true,
349 same_site: SameSitePolicy::Lax,
350 domain: None,
351 path: Some("/".to_string()),
352 }
353 }
354}
355
356impl From<&SessionConfig> for CookieConfig {
357 fn from(config: &SessionConfig) -> Self {
358 Self {
359 name: config.cookie_name.clone(),
360 secure: config.cookie_secure,
361 http_only: config.cookie_http_only,
362 same_site: config.cookie_same_site.clone(),
363 domain: None,
364 path: Some("/".to_string()),
365 }
366 }
367}
368
369pub struct SessionCookie;
371
372impl SessionCookie {
373 pub fn generate_cookie_header(session_id: &str, config: &CookieConfig, max_age: Option<u64>) -> String {
375 let mut cookie = format!("{}={}", config.name, session_id);
376
377 if config.http_only {
378 cookie.push_str("; HttpOnly");
379 }
380
381 if config.secure {
382 cookie.push_str("; Secure");
383 }
384
385 match config.same_site {
386 SameSitePolicy::Strict => cookie.push_str("; SameSite=Strict"),
387 SameSitePolicy::Lax => cookie.push_str("; SameSite=Lax"),
388 SameSitePolicy::None => cookie.push_str("; SameSite=None"),
389 }
390
391 if let Some(domain) = &config.domain {
392 cookie.push_str(&format!("; Domain={}", domain));
393 }
394
395 if let Some(path) = &config.path {
396 cookie.push_str(&format!("; Path={}", path));
397 }
398
399 if let Some(max_age) = max_age {
400 cookie.push_str(&format!("; Max-Age={}", max_age));
401 }
402
403 cookie
404 }
405
406 pub fn parse_session_id(cookie_header: &str, cookie_name: &str) -> Option<String> {
408 for cookie in cookie_header.split(';') {
409 let cookie = cookie.trim();
410 if let Some(value) = cookie.strip_prefix(&format!("{}=", cookie_name)) {
411 return Some(value.to_string());
412 }
413 }
414 None
415 }
416
417 pub fn generate_delete_cookie_header(config: &CookieConfig) -> String {
419 let mut cookie = format!("{}=; Max-Age=0", config.name);
420
421 if config.http_only {
422 cookie.push_str("; HttpOnly");
423 }
424
425 if config.secure {
426 cookie.push_str("; Secure");
427 }
428
429 match config.same_site {
430 SameSitePolicy::Strict => cookie.push_str("; SameSite=Strict"),
431 SameSitePolicy::Lax => cookie.push_str("; SameSite=Lax"),
432 SameSitePolicy::None => cookie.push_str("; SameSite=None"),
433 }
434
435 if let Some(domain) = &config.domain {
436 cookie.push_str(&format!("; Domain={}", domain));
437 }
438
439 if let Some(path) = &config.path {
440 cookie.push_str(&format!("; Path={}", path));
441 }
442
443 cookie
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager backed by the default configuration.
    fn create_test_manager() -> SessionManager {
        SessionManager::new(SessionConfig::default())
    }

    /// Creates a session for "user123" with the single role "user" and no
    /// permissions, IP address or user agent.
    async fn create_user_session(manager: &SessionManager) -> SessionData {
        manager
            .create_session("user123", vec!["user".to_string()], vec![], None, None)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_session_creation() {
        let manager = create_test_manager();

        let session = manager
            .create_session(
                "user123",
                vec!["admin".to_string()],
                vec!["read".to_string(), "write".to_string()],
                Some("127.0.0.1".to_string()),
                Some("Test Browser".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(session.user_id, "user123");
        assert!(session.has_role("admin"));
        assert!(session.has_permission("read"));
        assert!(!session.is_expired());
        assert!(session.time_until_expiry() > 0);
    }

    #[tokio::test]
    async fn test_session_retrieval() {
        let manager = create_test_manager();
        let session = create_user_session(&manager).await;

        let retrieved = manager.get_session(&session.session_id).await.unwrap().unwrap();
        assert_eq!(retrieved.user_id, "user123");
        assert_eq!(retrieved.session_id, session.session_id);
    }

    #[tokio::test]
    async fn test_session_update() {
        let manager = create_test_manager();
        let mut session = create_user_session(&manager).await;

        session.set_attribute("theme".to_string(), serde_json::Value::String("dark".to_string()));
        manager.update_session(session.clone()).await.unwrap();

        let updated = manager.get_session(&session.session_id).await.unwrap().unwrap();
        assert_eq!(updated.get_attribute("theme").unwrap().as_str().unwrap(), "dark");
    }

    #[tokio::test]
    async fn test_expired_session_update_is_rejected() {
        let manager = create_test_manager();
        let mut session = create_user_session(&manager).await;

        // Force the expiry into the past instead of sleeping.
        session.expires_at = chrono::Utc::now() - chrono::Duration::seconds(1);

        let result = manager.update_session(session).await;
        assert!(matches!(result, Err(SecurityError::SessionExpired)));
    }

    #[tokio::test]
    async fn test_session_deletion() {
        let manager = create_test_manager();
        let session = create_user_session(&manager).await;

        let retrieved = manager.get_session(&session.session_id).await.unwrap();
        assert!(retrieved.is_some());

        manager.delete_session(&session.session_id).await.unwrap();

        let retrieved = manager.get_session(&session.session_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_user_session_deletion() {
        let manager = create_test_manager();
        let session1 = create_user_session(&manager).await;
        let session2 = create_user_session(&manager).await;

        let deleted_count = manager.delete_user_sessions("user123").await.unwrap();
        assert_eq!(deleted_count, 2);

        let retrieved1 = manager.get_session(&session1.session_id).await.unwrap();
        let retrieved2 = manager.get_session(&session2.session_id).await.unwrap();
        assert!(retrieved1.is_none());
        assert!(retrieved2.is_none());
    }

    #[tokio::test]
    async fn test_session_extension() {
        // Pin a short lifetime so the assertion below does not depend on the
        // default maximum age being smaller than the extension.
        let mut config = SessionConfig::default();
        config.max_age_seconds = Some(60);
        let manager = SessionManager::new(config);

        let session = create_user_session(&manager).await;
        let original_expiry = session.time_until_expiry();

        manager.extend_session(&session.session_id, 3600).await.unwrap();

        let updated = manager.get_session(&session.session_id).await.unwrap().unwrap();
        let new_expiry = updated.time_until_expiry();

        assert!(new_expiry > original_expiry);
    }

    #[test]
    fn test_cookie_header_generation() {
        let config = CookieConfig::default();
        let session_id = "session123";

        let cookie_header = SessionCookie::generate_cookie_header(session_id, &config, Some(3600));
        assert!(cookie_header.contains("session_id=session123"));
        assert!(cookie_header.contains("HttpOnly"));
        assert!(cookie_header.contains("Secure"));
        assert!(cookie_header.contains("SameSite=Lax"));
        assert!(cookie_header.contains("Max-Age=3600"));
    }

    #[test]
    fn test_cookie_parsing() {
        let cookie_header = "session_id=abc123; other=value; session_id=def456";
        let session_id = SessionCookie::parse_session_id(cookie_header, "session_id");

        // The first occurrence wins when the cookie appears more than once.
        assert_eq!(session_id, Some("abc123".to_string()));
    }

    #[test]
    fn test_cookie_parsing_missing_cookie() {
        assert_eq!(SessionCookie::parse_session_id("other=value", "session_id"), None);
        // A name that merely ends with the wanted name must not match.
        assert_eq!(SessionCookie::parse_session_id("xsession_id=abc", "session_id"), None);
    }

    #[test]
    fn test_delete_cookie_generation() {
        let config = CookieConfig::default();
        let delete_header = SessionCookie::generate_delete_cookie_header(&config);

        assert!(delete_header.contains("session_id="));
        assert!(delete_header.contains("Max-Age=0"));
        assert!(delete_header.contains("HttpOnly"));
        assert!(delete_header.contains("Secure"));
    }

    #[test]
    fn test_session_attributes() {
        let mut session = SessionData::new(
            "session123".to_string(),
            "user123".to_string(),
            vec!["user".to_string()],
            vec!["read".to_string()],
            Some(3600),
            None,
            None,
        );

        assert!(session.has_role("user"));
        assert!(!session.has_role("admin"));
        assert!(session.has_permission("read"));
        assert!(!session.has_permission("write"));

        session.set_attribute("theme".to_string(), serde_json::Value::String("dark".to_string()));
        session.set_attribute("locale".to_string(), serde_json::Value::String("en".to_string()));

        assert_eq!(session.get_attribute("theme").unwrap().as_str().unwrap(), "dark");
        assert_eq!(session.get_attribute("locale").unwrap().as_str().unwrap(), "en");

        let removed = session.remove_attribute("theme");
        assert_eq!(removed.unwrap().as_str().unwrap(), "dark");
        assert!(session.get_attribute("theme").is_none());
    }
}