1use chrono::{DateTime, Duration, Utc};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use uuid::Uuid;
9
10#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
12pub struct SessionId(String);
13
14impl SessionId {
15 #[must_use]
17 pub fn generate() -> Self {
18 Self(Uuid::new_v4().to_string())
19 }
20
21 pub fn try_from_string(s: String) -> Result<Self, SessionError> {
27 Uuid::parse_str(&s)
28 .map(|_| Self(s))
29 .map_err(|_| SessionError::InvalidSessionId)
30 }
31
32 #[must_use]
34 pub fn as_str(&self) -> &str {
35 &self.0
36 }
37}
38
39impl std::fmt::Display for SessionId {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0)
42 }
43}
44
45impl std::str::FromStr for SessionId {
46 type Err = SessionError;
47
48 fn from_str(s: &str) -> Result<Self, Self::Err> {
49 Self::try_from_string(s.to_string())
50 }
51}
52
/// State kept for a single session: timestamps, optional user, arbitrary
/// JSON key/value data, and pending flash messages.
///
/// Derives serde `Serialize`/`Deserialize`.
/// NOTE(review): presumably this is what a session store persists; confirm
/// against the store implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
    /// When the session was created; set by the constructors.
    pub created_at: DateTime<Utc>,
    /// Time of the most recent `touch`; equal to `created_at` for a fresh session.
    pub last_accessed: DateTime<Utc>,
    /// Instant after which `is_expired` reports `true`; recomputed by `touch`.
    pub expires_at: DateTime<Utc>,
    /// Associated user ID, if any. `None` for a new session and after `clear`.
    /// NOTE(review): presumably the authenticated user's ID; confirm with callers.
    pub user_id: Option<i64>,
    /// Arbitrary data stored as JSON values; prefer the typed `get`/`set` helpers.
    pub data: HashMap<String, serde_json::Value>,
    /// Pending flash messages for this session; emptied by `clear`.
    pub flash_messages: Vec<FlashMessage>,
}
69
70impl SessionData {
71 #[must_use]
73 pub fn new() -> Self {
74 let now = Utc::now();
75 Self {
76 created_at: now,
77 last_accessed: now,
78 expires_at: now + Duration::hours(24),
79 user_id: None,
80 data: HashMap::new(),
81 flash_messages: Vec::new(),
82 }
83 }
84
85 #[must_use]
87 pub fn with_expiration(duration: Duration) -> Self {
88 let now = Utc::now();
89 Self {
90 created_at: now,
91 last_accessed: now,
92 expires_at: now + duration,
93 user_id: None,
94 data: HashMap::new(),
95 flash_messages: Vec::new(),
96 }
97 }
98
99 #[must_use]
101 pub fn is_expired(&self) -> bool {
102 Utc::now() > self.expires_at
103 }
104
105 pub fn touch(&mut self, extend_by: Duration) {
107 self.last_accessed = Utc::now();
108 self.expires_at = self.last_accessed + extend_by;
109 }
110
111 pub fn validate_and_touch(&mut self, extend_by: Duration) -> bool {
135 if self.is_expired() {
136 false
137 } else {
138 self.touch(extend_by);
139 true
140 }
141 }
142
143 #[must_use]
145 pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
146 self.data
147 .get(key)
148 .and_then(|v| serde_json::from_value(v.clone()).ok())
149 }
150
151 pub fn set<T: Serialize>(&mut self, key: String, value: T) -> Result<(), SessionError> {
157 let json_value = serde_json::to_value(value)?;
158 self.data.insert(key, json_value);
159 Ok(())
160 }
161
162 pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
164 self.data.remove(key)
165 }
166
167 pub fn clear(&mut self) {
169 self.data.clear();
170 self.flash_messages.clear();
171 self.user_id = None;
172 }
173}
174
175impl Default for SessionData {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
/// A user-facing notification held in `SessionData::flash_messages`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FlashMessage {
    /// Severity; also determines the CSS class returned by `css_class`.
    pub level: FlashLevel,
    /// The message text.
    pub message: String,
    /// Optional heading, set via `with_title`; `None` from the level constructors.
    pub title: Option<String>,
}
191
192impl FlashMessage {
193 #[must_use]
195 pub fn success(message: impl Into<String>) -> Self {
196 Self {
197 level: FlashLevel::Success,
198 message: message.into(),
199 title: None,
200 }
201 }
202
203 #[must_use]
205 pub fn info(message: impl Into<String>) -> Self {
206 Self {
207 level: FlashLevel::Info,
208 message: message.into(),
209 title: None,
210 }
211 }
212
213 #[must_use]
215 pub fn warning(message: impl Into<String>) -> Self {
216 Self {
217 level: FlashLevel::Warning,
218 message: message.into(),
219 title: None,
220 }
221 }
222
223 #[must_use]
225 pub fn error(message: impl Into<String>) -> Self {
226 Self {
227 level: FlashLevel::Error,
228 message: message.into(),
229 title: None,
230 }
231 }
232
233 #[must_use]
235 pub fn with_title(mut self, title: impl Into<String>) -> Self {
236 self.title = Some(title.into());
237 self
238 }
239
240 #[must_use]
242 pub const fn css_class(&self) -> &'static str {
243 self.level.css_class()
244 }
245}
246
/// Severity of a [`FlashMessage`].
///
/// Serialized as the lowercase variant name: `"success"`, `"info"`,
/// `"warning"`, or `"error"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum FlashLevel {
    /// Operation succeeded.
    Success,
    /// Neutral information.
    Info,
    /// Something may need attention.
    Warning,
    /// Operation failed.
    Error,
}
260
261impl FlashLevel {
262 #[must_use]
264 pub const fn css_class(self) -> &'static str {
265 match self {
266 Self::Success => "flash-success",
267 Self::Info => "flash-info",
268 Self::Warning => "flash-warning",
269 Self::Error => "flash-error",
270 }
271 }
272}
273
/// Errors produced by session handling.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// The supplied session ID is not a valid UUID (see `SessionId::try_from_string`).
    #[error("Invalid session ID")]
    InvalidSessionId,

    /// No session exists for the given ID.
    /// NOTE(review): not produced in this section; presumably returned by a session store.
    #[error("Session not found")]
    NotFound,

    /// The session exists but has expired.
    /// NOTE(review): not produced in this section; presumably returned by a session store.
    #[error("Session expired")]
    Expired,

    /// A value could not be converted to or from JSON; converted
    /// automatically from `serde_json::Error` (e.g. by `SessionData::set`).
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// Redis backend failure carrying a description; only present with the
    /// `redis` feature enabled.
    #[cfg(feature = "redis")]
    #[error("Redis error: {0}")]
    Redis(String),

    /// Failure reported by an "agent" component, carrying a description.
    /// NOTE(review): the producer is not visible here; confirm what "agent" refers to.
    #[error("Agent error: {0}")]
    Agent(String),
}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_id_generate() {
        let id1 = SessionId::generate();
        let id2 = SessionId::generate();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_session_id_from_string() {
        let uuid_str = "550e8400-e29b-41d4-a716-446655440000";
        let id = SessionId::try_from_string(uuid_str.to_string())
            .expect("a canonical UUID must be accepted");
        // A UUID already in canonical form is stored verbatim.
        assert_eq!(id.as_str(), uuid_str);
    }

    #[test]
    fn test_session_id_invalid() {
        assert!(SessionId::try_from_string("not-a-uuid".to_string()).is_err());
        assert!(SessionId::try_from_string(String::new()).is_err());
    }

    #[test]
    fn test_session_id_display_parse_roundtrip() {
        let id = SessionId::generate();
        let parsed: SessionId = id
            .to_string()
            .parse()
            .expect("generated IDs must parse back");
        assert_eq!(parsed, id);
    }

    #[test]
    fn test_session_data_new() {
        let data = SessionData::new();
        assert!(!data.is_expired());
        assert!(data.user_id.is_none());
        assert!(data.data.is_empty());
        assert!(data.flash_messages.is_empty());
    }

    #[test]
    fn test_session_data_expiration() {
        let data = SessionData::with_expiration(Duration::seconds(-1));
        assert!(data.is_expired());
    }

    #[test]
    fn test_session_data_touch() {
        let mut data = SessionData::new();
        let original_expiry = data.expires_at;
        // Extending by more than the 24h default guarantees a strictly later
        // expiry without sleeping, even with a coarse system clock.
        data.touch(Duration::hours(25));
        assert!(data.expires_at > original_expiry);
    }

    #[test]
    fn test_session_data_validate_and_touch_valid() {
        let mut data = SessionData::new();
        let original_expiry = data.expires_at;

        assert!(data.validate_and_touch(Duration::hours(25)));
        assert!(data.expires_at > original_expiry);
    }

    #[test]
    fn test_session_data_validate_and_touch_expired() {
        let mut data = SessionData::with_expiration(Duration::seconds(-1));
        let original_expiry = data.expires_at;

        assert!(!data.validate_and_touch(Duration::hours(24)));
        // An expired session must not be revived.
        assert_eq!(data.expires_at, original_expiry);
    }

    #[test]
    fn test_session_data_get_set() {
        let mut data = SessionData::new();
        data.set("key".to_string(), "value").unwrap();
        let value: Option<String> = data.get("key");
        assert_eq!(value, Some("value".to_string()));
    }

    #[test]
    fn test_session_data_get_wrong_type_is_none() {
        let mut data = SessionData::new();
        data.set("n".to_string(), 5).unwrap();
        let as_string: Option<String> = data.get("n");
        assert!(as_string.is_none());
        let as_int: Option<i32> = data.get("n");
        assert_eq!(as_int, Some(5));
        let missing: Option<i32> = data.get("absent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_session_data_remove() {
        let mut data = SessionData::new();
        data.set("key".to_string(), "value").unwrap();
        let removed = data.remove("key");
        assert!(removed.is_some());
        let value: Option<String> = data.get("key");
        assert!(value.is_none());
    }

    #[test]
    fn test_session_data_clear() {
        let mut data = SessionData::new();
        data.user_id = Some(42);
        data.set("key".to_string(), 1).unwrap();
        data.flash_messages.push(FlashMessage::info("hi"));

        data.clear();

        assert!(data.user_id.is_none());
        assert!(data.data.is_empty());
        assert!(data.flash_messages.is_empty());
    }

    #[test]
    fn test_flash_message_creation() {
        let flash = FlashMessage::success("Test").with_title("Success");
        assert_eq!(flash.level, FlashLevel::Success);
        assert_eq!(flash.message, "Test");
        assert_eq!(flash.title, Some("Success".to_string()));
    }

    #[test]
    fn test_flash_message_levels() {
        assert_eq!(FlashMessage::info("m").level, FlashLevel::Info);
        assert_eq!(FlashMessage::warning("m").level, FlashLevel::Warning);
        assert_eq!(FlashMessage::error("m").level, FlashLevel::Error);
        assert_eq!(FlashMessage::error("m").title, None);
        assert_eq!(FlashMessage::warning("m").css_class(), "flash-warning");
    }

    #[test]
    fn test_flash_level_css_class() {
        assert_eq!(FlashLevel::Success.css_class(), "flash-success");
        assert_eq!(FlashLevel::Info.css_class(), "flash-info");
        assert_eq!(FlashLevel::Warning.css_class(), "flash-warning");
        assert_eq!(FlashLevel::Error.css_class(), "flash-error");
    }
}