Skip to main content

ferro_rs/session/driver/
database.rs

1//! Database-backed session storage driver
2
3use async_trait::async_trait;
4use sea_orm::entity::prelude::*;
5use sea_orm::{QueryFilter, Set};
6use std::collections::HashMap;
7use std::time::Duration;
8
9use crate::database::DB;
10use crate::error::FrameworkError;
11use crate::session::store::{SessionData, SessionStore};
12
/// Database session driver using SeaORM
///
/// Stores sessions in a `sessions` table with the following schema:
/// - id: VARCHAR (primary key) - session ID
/// - user_id: BIGINT (nullable) - authenticated user ID
/// - payload: TEXT - JSON serialized session data
/// - csrf_token: VARCHAR - CSRF protection token
/// - last_activity: TIMESTAMP - last access time
///
/// The driver itself only carries configuration (the session lifetime), so it
/// is cheap to clone and safe to print in diagnostics.
#[derive(Debug, Clone)]
pub struct DatabaseSessionDriver {
    /// How long a session stays valid after its `last_activity` timestamp.
    lifetime: Duration,
}

impl DatabaseSessionDriver {
    /// Create a new database session driver.
    ///
    /// `lifetime` is the maximum idle time of a session: it is compared
    /// against each row's `last_activity` to decide whether the session has
    /// expired.
    #[must_use]
    pub fn new(lifetime: Duration) -> Self {
        Self { lifetime }
    }
}
31
32#[async_trait]
33impl SessionStore for DatabaseSessionDriver {
34    async fn read(&self, id: &str) -> Result<Option<SessionData>, FrameworkError> {
35        let db = DB::connection()?;
36
37        let result = sessions::Entity::find_by_id(id)
38            .one(db.inner())
39            .await
40            .map_err(|e| FrameworkError::database(e.to_string()))?;
41
42        if let Some(session) = result {
43            // Check if expired
44            let now = chrono::Utc::now().naive_utc();
45            let expiry =
46                session.last_activity + chrono::Duration::seconds(self.lifetime.as_secs() as i64);
47
48            if now > expiry {
49                // Session expired, clean it up
50                let _ = self.destroy(id).await;
51                return Ok(None);
52            }
53
54            // Parse the payload
55            let data: HashMap<String, serde_json::Value> =
56                serde_json::from_str(&session.payload).unwrap_or_default();
57
58            Ok(Some(SessionData {
59                id: session.id,
60                data,
61                user_id: session.user_id,
62                csrf_token: session.csrf_token,
63                dirty: false,
64            }))
65        } else {
66            Ok(None)
67        }
68    }
69
70    async fn write(&self, session: &SessionData) -> Result<(), FrameworkError> {
71        let db = DB::connection()?;
72
73        let payload = serde_json::to_string(&session.data)
74            .map_err(|e| FrameworkError::internal(format!("Session serialize error: {e}")))?;
75
76        let now = chrono::Utc::now().naive_utc();
77
78        // Check if session exists
79        let existing = sessions::Entity::find_by_id(&session.id)
80            .one(db.inner())
81            .await
82            .map_err(|e| FrameworkError::database(e.to_string()))?;
83
84        if existing.is_some() {
85            // Update existing session
86            let update = sessions::ActiveModel {
87                id: Set(session.id.clone()),
88                user_id: Set(session.user_id),
89                payload: Set(payload),
90                csrf_token: Set(session.csrf_token.clone()),
91                last_activity: Set(now),
92            };
93
94            sessions::Entity::update(update)
95                .exec(db.inner())
96                .await
97                .map_err(|e| FrameworkError::database(e.to_string()))?;
98        } else {
99            // Insert new session
100            let model = sessions::ActiveModel {
101                id: Set(session.id.clone()),
102                user_id: Set(session.user_id),
103                payload: Set(payload),
104                csrf_token: Set(session.csrf_token.clone()),
105                last_activity: Set(now),
106            };
107
108            sessions::Entity::insert(model)
109                .exec(db.inner())
110                .await
111                .map_err(|e| FrameworkError::database(e.to_string()))?;
112        }
113
114        Ok(())
115    }
116
117    async fn destroy(&self, id: &str) -> Result<(), FrameworkError> {
118        let db = DB::connection()?;
119
120        sessions::Entity::delete_by_id(id)
121            .exec(db.inner())
122            .await
123            .map_err(|e| FrameworkError::database(e.to_string()))?;
124
125        Ok(())
126    }
127
128    async fn gc(&self) -> Result<u64, FrameworkError> {
129        let db = DB::connection()?;
130
131        let threshold = chrono::Utc::now().naive_utc()
132            - chrono::Duration::seconds(self.lifetime.as_secs() as i64);
133
134        let result = sessions::Entity::delete_many()
135            .filter(sessions::Column::LastActivity.lt(threshold))
136            .exec(db.inner())
137            .await
138            .map_err(|e| FrameworkError::database(e.to_string()))?;
139
140        Ok(result.rows_affected)
141    }
142}
143
144/// Sessions table entity for SeaORM
145pub mod sessions {
146    use sea_orm::entity::prelude::*;
147
148    #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
149    #[sea_orm(table_name = "sessions")]
150    pub struct Model {
151        #[sea_orm(primary_key, auto_increment = false)]
152        pub id: String,
153        pub user_id: Option<i64>,
154        #[sea_orm(column_type = "Text")]
155        pub payload: String,
156        pub csrf_token: String,
157        pub last_activity: chrono::NaiveDateTime,
158    }
159
160    #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
161    pub enum Relation {}
162
163    impl ActiveModelBehavior for ActiveModel {}
164}