Skip to main content

ferro_rs/session/
store.rs

1//! Session storage abstraction
2
3use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Serialize};
5use std::collections::HashMap;
6
7use crate::error::FrameworkError;
8
9/// Session data container
10///
11/// Holds all session data including user authentication state and CSRF token.
12#[derive(Clone, Debug, Default)]
13pub struct SessionData {
14    /// Unique session identifier
15    pub id: String,
16    /// Key-value data stored in the session
17    pub data: HashMap<String, serde_json::Value>,
18    /// Authenticated user ID (if any)
19    pub user_id: Option<i64>,
20    /// CSRF token for this session
21    pub csrf_token: String,
22    /// Whether the session has been modified
23    pub dirty: bool,
24}
25
26impl SessionData {
27    /// Create a new session with the given ID
28    pub fn new(id: String, csrf_token: String) -> Self {
29        Self {
30            id,
31            data: HashMap::new(),
32            user_id: None,
33            csrf_token,
34            dirty: false,
35        }
36    }
37
38    /// Get a value from the session
39    ///
40    /// # Example
41    ///
42    /// ```rust,ignore
43    /// let name: Option<String> = session.get("name");
44    /// ```
45    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
46        self.data
47            .get(key)
48            .and_then(|v| serde_json::from_value(v.clone()).ok())
49    }
50
51    /// Put a value into the session
52    ///
53    /// # Example
54    ///
55    /// ```rust,ignore
56    /// session.put("name", "John");
57    /// session.put("count", 42);
58    /// ```
59    pub fn put<T: Serialize>(&mut self, key: &str, value: T) {
60        if let Ok(v) = serde_json::to_value(value) {
61            self.data.insert(key.to_string(), v);
62            self.dirty = true;
63        }
64    }
65
66    /// Remove a value from the session
67    ///
68    /// Returns the removed value if it existed.
69    pub fn forget(&mut self, key: &str) -> Option<serde_json::Value> {
70        self.dirty = true;
71        self.data.remove(key)
72    }
73
74    /// Check if the session has a key
75    pub fn has(&self, key: &str) -> bool {
76        self.data.contains_key(key)
77    }
78
79    /// Flash a value to the session (available only for next request)
80    ///
81    /// # Example
82    ///
83    /// ```rust,ignore
84    /// session.flash("success", "Item saved successfully!");
85    /// ```
86    pub fn flash<T: Serialize>(&mut self, key: &str, value: T) {
87        self.put(&format!("_flash.new.{key}"), value);
88    }
89
90    /// Get a flashed value (clears it after reading)
91    pub fn get_flash<T: DeserializeOwned>(&mut self, key: &str) -> Option<T> {
92        let flash_key = format!("_flash.old.{key}");
93        let value = self.get(&flash_key);
94        if value.is_some() {
95            self.forget(&flash_key);
96        }
97        value
98    }
99
100    /// Age flash data (move new flash to old, clear old)
101    pub fn age_flash_data(&mut self) {
102        // Remove old flash data
103        let old_keys: Vec<String> = self
104            .data
105            .keys()
106            .filter(|k| k.starts_with("_flash.old."))
107            .cloned()
108            .collect();
109        let had_old = !old_keys.is_empty();
110        for key in old_keys {
111            self.data.remove(&key);
112        }
113
114        // Move new flash data to old
115        let new_keys: Vec<String> = self
116            .data
117            .keys()
118            .filter(|k| k.starts_with("_flash.new."))
119            .cloned()
120            .collect();
121        let had_new = !new_keys.is_empty();
122        for key in new_keys {
123            if let Some(value) = self.data.remove(&key) {
124                let old_key = key.replace("_flash.new.", "_flash.old.");
125                self.data.insert(old_key, value);
126            }
127        }
128
129        if had_new || had_old {
130            self.dirty = true;
131        }
132    }
133
134    /// Clear all session data (keeps ID and regenerates CSRF)
135    pub fn flush(&mut self) {
136        self.data.clear();
137        self.user_id = None;
138        self.dirty = true;
139    }
140
141    /// Check if the session has been modified
142    pub fn is_dirty(&self) -> bool {
143        self.dirty
144    }
145
146    /// Mark the session as clean (after saving)
147    pub fn mark_clean(&mut self) {
148        self.dirty = false;
149    }
150}
151
152/// Session store trait for different backends
153///
154/// Implement this trait to create custom session storage backends.
155#[async_trait]
156pub trait SessionStore: Send + Sync {
157    /// Read a session by its ID
158    ///
159    /// Returns None if the session doesn't exist or has expired.
160    async fn read(&self, id: &str) -> Result<Option<SessionData>, FrameworkError>;
161
162    /// Write a session to storage
163    ///
164    /// Creates a new session if it doesn't exist, updates if it does.
165    async fn write(&self, session: &SessionData) -> Result<(), FrameworkError>;
166
167    /// Destroy a session by its ID
168    async fn destroy(&self, id: &str) -> Result<(), FrameworkError>;
169
170    /// Garbage collect expired sessions
171    ///
172    /// Returns the number of sessions cleaned up.
173    async fn gc(&self) -> Result<u64, FrameworkError>;
174
175    /// Destroy all sessions for a user, optionally keeping one session.
176    ///
177    /// Used for "logout other devices" (pass current session ID to keep) or
178    /// "logout everywhere" (pass None). Returns the number of destroyed sessions.
179    ///
180    /// Default implementation returns an error; override in drivers that support it.
181    async fn destroy_for_user(
182        &self,
183        _user_id: i64,
184        _except_session_id: Option<&str>,
185    ) -> Result<u64, FrameworkError> {
186        Err(FrameworkError::internal(
187            "destroy_for_user not supported by this session driver".to_string(),
188        ))
189    }
190}