1use crate::config::SessionConfig;
2use crate::entity::session::{self, ActiveModel, Column, Entity};
3use crate::meta::SessionMeta;
4use crate::types::{SessionData, SessionId, SessionToken};
5use chrono::{DateTime, Utc};
6use modo::Error;
7use modo::cookies::CookieConfig;
8use modo_db::DbPool;
9use modo_db::sea_orm::{
10 ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder,
11 QuerySelect, Set, TransactionTrait,
12};
13
14#[derive(Clone)]
23pub struct SessionStore {
24 db: DbPool,
25 config: SessionConfig,
26 cookie_config: CookieConfig,
27}
28
29impl SessionStore {
30 pub fn new(db: &DbPool, config: SessionConfig, cookie_config: CookieConfig) -> Self {
32 Self {
33 db: db.clone(),
34 config,
35 cookie_config,
36 }
37 }
38
39 pub fn config(&self) -> &SessionConfig {
41 &self.config
42 }
43
44 pub fn cookie_config(&self) -> &CookieConfig {
46 &self.cookie_config
47 }
48
49 pub async fn create(
55 &self,
56 meta: &SessionMeta,
57 user_id: &str,
58 data: Option<serde_json::Value>,
59 ) -> Result<(SessionData, SessionToken), Error> {
60 let token = SessionToken::generate();
61 let token_hash = token.hash();
62 let now = Utc::now();
63 let expires_at = now + chrono::Duration::seconds(self.config.session_ttl_secs as i64);
64 let data_json = data.unwrap_or(serde_json::json!({}));
65
66 let model = ActiveModel {
67 id: Set(SessionId::new().to_string()),
68 token_hash: Set(token_hash),
69 user_id: Set(user_id.to_string()),
70 ip_address: Set(meta.ip_address.clone()),
71 user_agent: Set(meta.user_agent.clone()),
72 device_name: Set(meta.device_name.clone()),
73 device_type: Set(meta.device_type.clone()),
74 fingerprint: Set(meta.fingerprint.clone()),
75 data: Set(serde_json::to_string(&data_json)
76 .map_err(|e| Error::internal(format!("serialize session data: {e}")))?),
77 created_at: Set(now),
78 last_active_at: Set(now),
79 expires_at: Set(expires_at),
80 };
81
82 let txn = self
85 .db
86 .connection()
87 .begin()
88 .await
89 .map_err(|e| Error::internal(format!("begin transaction: {e}")))?;
90
91 let result = model
92 .insert(&txn)
93 .await
94 .map_err(|e| Error::internal(format!("insert session: {e}")))?;
95
96 self.enforce_session_limit_txn(user_id, &txn).await?;
97
98 txn.commit()
99 .await
100 .map_err(|e| Error::internal(format!("commit transaction: {e}")))?;
101
102 Ok((model_to_session_data(&result)?, token))
103 }
104
105 pub async fn read(&self, id: &SessionId) -> Result<Option<SessionData>, Error> {
109 let model = Entity::find_by_id(id.as_str())
110 .one(self.db.connection())
111 .await
112 .map_err(|e| Error::internal(format!("read session: {e}")))?;
113
114 match model {
115 Some(m) => Ok(Some(model_to_session_data(&m)?)),
116 None => Ok(None),
117 }
118 }
119
120 pub async fn read_by_token(&self, token: &SessionToken) -> Result<Option<SessionData>, Error> {
124 let hash = token.hash();
125 let model = Entity::find()
126 .filter(Column::TokenHash.eq(&hash))
127 .filter(Column::ExpiresAt.gt(Utc::now()))
128 .one(self.db.connection())
129 .await
130 .map_err(|e| Error::internal(format!("read session by token: {e}")))?;
131
132 match model {
133 Some(m) => Ok(Some(model_to_session_data(&m)?)),
134 None => Ok(None),
135 }
136 }
137
138 pub async fn destroy(&self, id: &SessionId) -> Result<(), Error> {
140 Entity::delete_by_id(id.as_str())
141 .exec(self.db.connection())
142 .await
143 .map_err(|e| Error::internal(format!("destroy session: {e}")))?;
144 Ok(())
145 }
146
147 pub async fn rotate_token(&self, id: &SessionId) -> Result<SessionToken, Error> {
150 let new_token = SessionToken::generate();
151 let new_hash = new_token.hash();
152
153 let model = ActiveModel {
154 id: Set(id.as_str().to_string()),
155 token_hash: Set(new_hash),
156 ..Default::default()
157 };
158
159 model
160 .update(self.db.connection())
161 .await
162 .map_err(|e| Error::internal(format!("rotate token: {e}")))?;
163
164 Ok(new_token)
165 }
166
167 pub async fn touch(&self, id: &SessionId, new_expires_at: DateTime<Utc>) -> Result<(), Error> {
169 let model = ActiveModel {
170 id: Set(id.as_str().to_string()),
171 last_active_at: Set(Utc::now()),
172 expires_at: Set(new_expires_at),
173 ..Default::default()
174 };
175
176 model
177 .update(self.db.connection())
178 .await
179 .map_err(|e| Error::internal(format!("touch session: {e}")))?;
180
181 Ok(())
182 }
183
184 pub async fn update_data(&self, id: &SessionId, data: serde_json::Value) -> Result<(), Error> {
186 let model = ActiveModel {
187 id: Set(id.as_str().to_string()),
188 data: Set(serde_json::to_string(&data)
189 .map_err(|e| Error::internal(format!("serialize session data: {e}")))?),
190 ..Default::default()
191 };
192
193 model
194 .update(self.db.connection())
195 .await
196 .map_err(|e| Error::internal(format!("update session data: {e}")))?;
197
198 Ok(())
199 }
200
201 pub async fn destroy_all_for_user(&self, user_id: &str) -> Result<(), Error> {
203 Entity::delete_many()
204 .filter(Column::UserId.eq(user_id))
205 .exec(self.db.connection())
206 .await
207 .map_err(|e| Error::internal(format!("destroy all sessions for user: {e}")))?;
208 Ok(())
209 }
210
211 pub async fn destroy_all_except(&self, user_id: &str, keep: &SessionId) -> Result<(), Error> {
214 Entity::delete_many()
215 .filter(Column::UserId.eq(user_id))
216 .filter(Column::Id.ne(keep.as_str()))
217 .exec(self.db.connection())
218 .await
219 .map_err(|e| Error::internal(format!("destroy all except: {e}")))?;
220 Ok(())
221 }
222
223 pub async fn list_for_user(&self, user_id: &str) -> Result<Vec<SessionData>, Error> {
226 let models = Entity::find()
227 .filter(Column::UserId.eq(user_id))
228 .filter(Column::ExpiresAt.gt(Utc::now()))
229 .order_by_desc(Column::LastActiveAt)
230 .all(self.db.connection())
231 .await
232 .map_err(|e| Error::internal(format!("list sessions: {e}")))?;
233
234 models.iter().map(model_to_session_data).collect()
235 }
236
237 pub async fn cleanup_expired(&self) -> Result<u64, Error> {
242 let result = Entity::delete_many()
243 .filter(Column::ExpiresAt.lt(Utc::now()))
244 .exec(self.db.connection())
245 .await
246 .map_err(|e| Error::internal(format!("cleanup expired sessions: {e}")))?;
247 Ok(result.rows_affected)
248 }
249
250 async fn enforce_session_limit_txn(
252 &self,
253 user_id: &str,
254 txn: &modo_db::sea_orm::DatabaseTransaction,
255 ) -> Result<(), Error> {
256 let now = Utc::now();
257
258 let count = Entity::find()
259 .filter(Column::UserId.eq(user_id))
260 .filter(Column::ExpiresAt.gt(now))
261 .count(txn)
262 .await
263 .map_err(|e| Error::internal(format!("count sessions: {e}")))?;
264
265 if count as usize <= self.config.max_sessions_per_user {
266 return Ok(());
267 }
268
269 let excess = count as usize - self.config.max_sessions_per_user;
270
271 let oldest = Entity::find()
273 .filter(Column::UserId.eq(user_id))
274 .filter(Column::ExpiresAt.gt(now))
275 .order_by_asc(Column::LastActiveAt)
276 .limit(excess as u64)
277 .all(txn)
278 .await
279 .map_err(|e| Error::internal(format!("find oldest sessions: {e}")))?;
280
281 let ids: Vec<String> = oldest.into_iter().map(|m| m.id).collect();
282 if !ids.is_empty() {
283 Entity::delete_many()
284 .filter(Column::Id.is_in(ids))
285 .exec(txn)
286 .await
287 .map_err(|e| Error::internal(format!("evict sessions: {e}")))?;
288 }
289
290 Ok(())
291 }
292}
293
294fn model_to_session_data(model: &session::Model) -> Result<SessionData, Error> {
295 let data: serde_json::Value = serde_json::from_str(&model.data)
296 .map_err(|e| Error::internal(format!("deserialize session data: {e}")))?;
297
298 Ok(SessionData {
299 id: SessionId::from_raw(&model.id),
300 token_hash: model.token_hash.clone(),
301 user_id: model.user_id.clone(),
302 ip_address: model.ip_address.clone(),
303 user_agent: model.user_agent.clone(),
304 device_name: model.device_name.clone(),
305 device_type: model.device_type.clone(),
306 fingerprint: model.fingerprint.clone(),
307 data,
308 created_at: model.created_at,
309 last_active_at: model.last_active_at,
310 expires_at: model.expires_at,
311 })
312}