// acton_htmx/auth/session.rs

//! Session types and data structures.
//!
//! This module provides the core session types used throughout the framework:
//! session identifiers, per-session data, flash messages, and session errors.

5use chrono::{DateTime, Duration, Utc};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use uuid::Uuid;
9
10/// Unique session identifier
11#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
12pub struct SessionId(String);
13
14impl SessionId {
15    /// Generate a new cryptographically secure session ID
16    #[must_use]
17    pub fn generate() -> Self {
18        Self(Uuid::new_v4().to_string())
19    }
20
21    /// Create from a string (validates format)
22    ///
23    /// # Errors
24    ///
25    /// Returns error if the string is not a valid UUID
26    pub fn try_from_string(s: String) -> Result<Self, SessionError> {
27        Uuid::parse_str(&s)
28            .map(|_| Self(s))
29            .map_err(|_| SessionError::InvalidSessionId)
30    }
31
32    /// Get the session ID as a string reference
33    #[must_use]
34    pub fn as_str(&self) -> &str {
35        &self.0
36    }
37}
38
39impl std::fmt::Display for SessionId {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
45impl std::str::FromStr for SessionId {
46    type Err = SessionError;
47
48    fn from_str(s: &str) -> Result<Self, Self::Err> {
49        Self::try_from_string(s.to_string())
50    }
51}
52
53/// Session data stored per-session
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SessionData {
56    /// When this session was created
57    pub created_at: DateTime<Utc>,
58    /// When this session was last accessed
59    pub last_accessed: DateTime<Utc>,
60    /// When this session expires
61    pub expires_at: DateTime<Utc>,
62    /// User ID (if authenticated)
63    pub user_id: Option<i64>,
64    /// Custom session data (key-value store)
65    pub data: HashMap<String, serde_json::Value>,
66    /// Flash messages queued for next request
67    pub flash_messages: Vec<FlashMessage>,
68}
69
70impl SessionData {
71    /// Create new session data with default expiration (24 hours)
72    #[must_use]
73    pub fn new() -> Self {
74        let now = Utc::now();
75        Self {
76            created_at: now,
77            last_accessed: now,
78            expires_at: now + Duration::hours(24),
79            user_id: None,
80            data: HashMap::new(),
81            flash_messages: Vec::new(),
82        }
83    }
84
85    /// Create session with custom expiration duration
86    #[must_use]
87    pub fn with_expiration(duration: Duration) -> Self {
88        let now = Utc::now();
89        Self {
90            created_at: now,
91            last_accessed: now,
92            expires_at: now + duration,
93            user_id: None,
94            data: HashMap::new(),
95            flash_messages: Vec::new(),
96        }
97    }
98
99    /// Check if session is expired
100    #[must_use]
101    pub fn is_expired(&self) -> bool {
102        Utc::now() > self.expires_at
103    }
104
105    /// Update last accessed time and extend expiration
106    pub fn touch(&mut self, extend_by: Duration) {
107        self.last_accessed = Utc::now();
108        self.expires_at = self.last_accessed + extend_by;
109    }
110
111    /// Validate session is not expired and touch it if valid
112    ///
113    /// This method combines the common pattern of checking expiry and
114    /// extending the session lifetime in a single operation.
115    ///
116    /// # Returns
117    ///
118    /// - `true` if the session is valid (not expired) - session is touched
119    /// - `false` if the session has expired - session is not modified
120    ///
121    /// # Example
122    ///
123    /// ```
124    /// use acton_htmx::auth::session::SessionData;
125    /// use chrono::Duration;
126    ///
127    /// let mut session = SessionData::new();
128    /// assert!(session.validate_and_touch(Duration::hours(24)));
129    ///
130    /// // Expired session returns false
131    /// let mut expired = SessionData::with_expiration(Duration::seconds(-1));
132    /// assert!(!expired.validate_and_touch(Duration::hours(24)));
133    /// ```
134    pub fn validate_and_touch(&mut self, extend_by: Duration) -> bool {
135        if self.is_expired() {
136            false
137        } else {
138            self.touch(extend_by);
139            true
140        }
141    }
142
143    /// Get a value from session data
144    #[must_use]
145    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
146        self.data
147            .get(key)
148            .and_then(|v| serde_json::from_value(v.clone()).ok())
149    }
150
151    /// Set a value in session data
152    ///
153    /// # Errors
154    ///
155    /// Returns error if value cannot be serialized to JSON
156    pub fn set<T: Serialize>(&mut self, key: String, value: T) -> Result<(), SessionError> {
157        let json_value = serde_json::to_value(value)?;
158        self.data.insert(key, json_value);
159        Ok(())
160    }
161
162    /// Remove a value from session data
163    pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
164        self.data.remove(key)
165    }
166
167    /// Clear all session data (keeps metadata)
168    pub fn clear(&mut self) {
169        self.data.clear();
170        self.flash_messages.clear();
171        self.user_id = None;
172    }
173}
174
175impl Default for SessionData {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181/// Flash message for one-time display
182#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
183pub struct FlashMessage {
184    /// Message level (success, info, warning, error)
185    pub level: FlashLevel,
186    /// Message text
187    pub message: String,
188    /// Optional title
189    pub title: Option<String>,
190}
191
192impl FlashMessage {
193    /// Create a success flash message
194    #[must_use]
195    pub fn success(message: impl Into<String>) -> Self {
196        Self {
197            level: FlashLevel::Success,
198            message: message.into(),
199            title: None,
200        }
201    }
202
203    /// Create an info flash message
204    #[must_use]
205    pub fn info(message: impl Into<String>) -> Self {
206        Self {
207            level: FlashLevel::Info,
208            message: message.into(),
209            title: None,
210        }
211    }
212
213    /// Create a warning flash message
214    #[must_use]
215    pub fn warning(message: impl Into<String>) -> Self {
216        Self {
217            level: FlashLevel::Warning,
218            message: message.into(),
219            title: None,
220        }
221    }
222
223    /// Create an error flash message
224    #[must_use]
225    pub fn error(message: impl Into<String>) -> Self {
226        Self {
227            level: FlashLevel::Error,
228            message: message.into(),
229            title: None,
230        }
231    }
232
233    /// Set the title for this flash message
234    #[must_use]
235    pub fn with_title(mut self, title: impl Into<String>) -> Self {
236        self.title = Some(title.into());
237        self
238    }
239
240    /// Get CSS class for this flash level
241    #[must_use]
242    pub const fn css_class(&self) -> &'static str {
243        self.level.css_class()
244    }
245}
246
247/// Flash message severity level
248#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
249#[serde(rename_all = "lowercase")]
250pub enum FlashLevel {
251    /// Success message (green)
252    Success,
253    /// Informational message (blue)
254    Info,
255    /// Warning message (yellow)
256    Warning,
257    /// Error message (red)
258    Error,
259}
260
261impl FlashLevel {
262    /// Get CSS class for this level
263    #[must_use]
264    pub const fn css_class(self) -> &'static str {
265        match self {
266            Self::Success => "flash-success",
267            Self::Info => "flash-info",
268            Self::Warning => "flash-warning",
269            Self::Error => "flash-error",
270        }
271    }
272}
273
274/// Session-related errors
275#[derive(Debug, thiserror::Error)]
276pub enum SessionError {
277    /// Invalid session ID format
278    #[error("Invalid session ID")]
279    InvalidSessionId,
280
281    /// Session not found
282    #[error("Session not found")]
283    NotFound,
284
285    /// Session expired
286    #[error("Session expired")]
287    Expired,
288
289    /// Serialization error
290    #[error("Serialization error: {0}")]
291    Serialization(#[from] serde_json::Error),
292
293    /// Redis error
294    #[cfg(feature = "redis")]
295    #[error("Redis error: {0}")]
296    Redis(String),
297
298    /// Agent communication error
299    #[error("Agent error: {0}")]
300    Agent(String),
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_id_generate() {
        let id1 = SessionId::generate();
        let id2 = SessionId::generate();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_session_id_from_string() {
        let uuid_str = "550e8400-e29b-41d4-a716-446655440000";
        let result = SessionId::try_from_string(uuid_str.to_string());
        assert!(result.is_ok());
    }

    #[test]
    fn test_session_id_invalid() {
        let result = SessionId::try_from_string("not-a-uuid".to_string());
        assert!(result.is_err());
        assert!("".parse::<SessionId>().is_err());
    }

    #[test]
    fn test_session_id_display_from_str_round_trip() {
        let id = SessionId::generate();
        let parsed: SessionId = id.to_string().parse().unwrap();
        assert_eq!(parsed, id);
        assert_eq!(parsed.as_str(), id.as_str());
    }

    #[test]
    fn test_session_data_new() {
        let data = SessionData::new();
        assert!(!data.is_expired());
        assert!(data.user_id.is_none());
        assert!(data.data.is_empty());
        assert!(data.flash_messages.is_empty());
        assert_eq!(data.expires_at - data.created_at, Duration::hours(24));
    }

    #[test]
    fn test_session_data_expiration() {
        let data = SessionData::with_expiration(Duration::seconds(-1));
        assert!(data.is_expired());
    }

    #[test]
    fn test_session_data_touch() {
        // Start from a short lifetime so a 24h extension is guaranteed to move the
        // expiry forward; no sleeping or reliance on clock resolution is needed.
        let mut data = SessionData::with_expiration(Duration::hours(1));
        let original_expiry = data.expires_at;
        data.touch(Duration::hours(24));
        assert!(data.expires_at > original_expiry);
        assert_eq!(data.expires_at, data.last_accessed + Duration::hours(24));
    }

    #[test]
    fn test_session_data_validate_and_touch_valid() {
        let mut data = SessionData::with_expiration(Duration::hours(1));
        let original_expiry = data.expires_at;

        // Valid session should return true and extend expiry
        assert!(data.validate_and_touch(Duration::hours(24)));
        assert!(data.expires_at > original_expiry);
    }

    #[test]
    fn test_session_data_validate_and_touch_expired() {
        let mut data = SessionData::with_expiration(Duration::seconds(-1));
        let original_expiry = data.expires_at;
        let original_access = data.last_accessed;

        // Expired session should return false and not be modified
        assert!(!data.validate_and_touch(Duration::hours(24)));
        assert_eq!(data.expires_at, original_expiry);
        assert_eq!(data.last_accessed, original_access);
    }

    #[test]
    fn test_session_data_get_set() {
        let mut data = SessionData::new();
        data.set("key".to_string(), "value").unwrap();
        let value: Option<String> = data.get("key");
        assert_eq!(value, Some("value".to_string()));
    }

    #[test]
    fn test_session_data_get_missing_or_wrong_type() {
        let mut data = SessionData::new();
        data.set("n".to_string(), 5).unwrap();
        let as_string: Option<String> = data.get("n");
        assert!(as_string.is_none());
        let as_int: Option<i64> = data.get("n");
        assert_eq!(as_int, Some(5));
        let missing: Option<i64> = data.get("absent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_session_data_remove() {
        let mut data = SessionData::new();
        data.set("key".to_string(), "value").unwrap();
        let removed = data.remove("key");
        assert!(removed.is_some());
        let value: Option<String> = data.get("key");
        assert!(value.is_none());
    }

    #[test]
    fn test_session_data_clear_keeps_timestamps() {
        let mut data = SessionData::new();
        data.user_id = Some(42);
        data.set("key".to_string(), 1).unwrap();
        data.flash_messages.push(FlashMessage::info("hi"));
        let (created, expires) = (data.created_at, data.expires_at);

        data.clear();

        assert!(data.user_id.is_none());
        assert!(data.data.is_empty());
        assert!(data.flash_messages.is_empty());
        assert_eq!(data.created_at, created);
        assert_eq!(data.expires_at, expires);
    }

    #[test]
    fn test_flash_message_creation() {
        let flash = FlashMessage::success("Test").with_title("Success");
        assert_eq!(flash.level, FlashLevel::Success);
        assert_eq!(flash.message, "Test");
        assert_eq!(flash.title, Some("Success".to_string()));
    }

    #[test]
    fn test_flash_message_levels() {
        assert_eq!(FlashMessage::info("m").level, FlashLevel::Info);
        assert_eq!(FlashMessage::warning("m").level, FlashLevel::Warning);
        assert_eq!(FlashMessage::error("m").level, FlashLevel::Error);
        assert!(FlashMessage::error("m").title.is_none());
        assert_eq!(FlashMessage::warning("m").css_class(), "flash-warning");
    }

    #[test]
    fn test_flash_level_css_class() {
        assert_eq!(FlashLevel::Success.css_class(), "flash-success");
        assert_eq!(FlashLevel::Info.css_class(), "flash-info");
        assert_eq!(FlashLevel::Warning.css_class(), "flash-warning");
        assert_eq!(FlashLevel::Error.css_class(), "flash-error");
    }

    #[test]
    fn test_flash_level_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&FlashLevel::Warning).unwrap(), "\"warning\"");
        let level: FlashLevel = serde_json::from_str("\"error\"").unwrap();
        assert_eq!(level, FlashLevel::Error);
    }
}