1use crate::config::SessionConfig;
2use crate::entity::session::{self, ActiveModel, Column, Entity};
3use crate::meta::SessionMeta;
4use crate::types::{SessionData, SessionId, SessionToken};
5use chrono::{DateTime, Utc};
6use modo::Error;
7use modo::cookies::CookieConfig;
8use modo_db::DbPool;
9use modo_db::sea_orm::{
10 ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder,
11 QuerySelect, Set,
12};
13
14#[derive(Clone)]
23pub struct SessionStore {
24 db: DbPool,
25 config: SessionConfig,
26 cookie_config: CookieConfig,
27}
28
29impl SessionStore {
30 pub fn new(db: &DbPool, config: SessionConfig, cookie_config: CookieConfig) -> Self {
32 Self {
33 db: db.clone(),
34 config,
35 cookie_config,
36 }
37 }
38
39 pub fn config(&self) -> &SessionConfig {
41 &self.config
42 }
43
44 pub fn cookie_config(&self) -> &CookieConfig {
46 &self.cookie_config
47 }
48
49 pub async fn create(
55 &self,
56 meta: &SessionMeta,
57 user_id: &str,
58 data: Option<serde_json::Value>,
59 ) -> Result<(SessionData, SessionToken), Error> {
60 let token = SessionToken::generate();
61 let token_hash = token.hash();
62 let now = Utc::now();
63 let expires_at = now + chrono::Duration::seconds(self.config.session_ttl_secs as i64);
64 let data_json = data.unwrap_or(serde_json::json!({}));
65
66 let model = ActiveModel {
67 id: Set(SessionId::new().to_string()),
68 token_hash: Set(token_hash),
69 user_id: Set(user_id.to_string()),
70 ip_address: Set(meta.ip_address.clone()),
71 user_agent: Set(meta.user_agent.clone()),
72 device_name: Set(meta.device_name.clone()),
73 device_type: Set(meta.device_type.clone()),
74 fingerprint: Set(meta.fingerprint.clone()),
75 data: Set(serde_json::to_string(&data_json)
76 .map_err(|e| Error::internal(format!("serialize session data: {e}")))?),
77 created_at: Set(now),
78 last_active_at: Set(now),
79 expires_at: Set(expires_at),
80 };
81
82 let result = model
83 .insert(self.db.connection())
84 .await
85 .map_err(|e| Error::internal(format!("insert session: {e}")))?;
86
87 self.enforce_session_limit(user_id).await?;
88
89 Ok((model_to_session_data(&result)?, token))
90 }
91
92 pub async fn read(&self, id: &SessionId) -> Result<Option<SessionData>, Error> {
96 let model = Entity::find_by_id(id.as_str())
97 .one(self.db.connection())
98 .await
99 .map_err(|e| Error::internal(format!("read session: {e}")))?;
100
101 match model {
102 Some(m) => Ok(Some(model_to_session_data(&m)?)),
103 None => Ok(None),
104 }
105 }
106
107 pub async fn read_by_token(&self, token: &SessionToken) -> Result<Option<SessionData>, Error> {
111 let hash = token.hash();
112 let model = Entity::find()
113 .filter(Column::TokenHash.eq(&hash))
114 .filter(Column::ExpiresAt.gt(Utc::now()))
115 .one(self.db.connection())
116 .await
117 .map_err(|e| Error::internal(format!("read session by token: {e}")))?;
118
119 match model {
120 Some(m) => Ok(Some(model_to_session_data(&m)?)),
121 None => Ok(None),
122 }
123 }
124
125 pub async fn destroy(&self, id: &SessionId) -> Result<(), Error> {
127 Entity::delete_by_id(id.as_str())
128 .exec(self.db.connection())
129 .await
130 .map_err(|e| Error::internal(format!("destroy session: {e}")))?;
131 Ok(())
132 }
133
134 pub async fn rotate_token(&self, id: &SessionId) -> Result<SessionToken, Error> {
137 let new_token = SessionToken::generate();
138 let new_hash = new_token.hash();
139
140 let model = ActiveModel {
141 id: Set(id.as_str().to_string()),
142 token_hash: Set(new_hash),
143 ..Default::default()
144 };
145
146 model
147 .update(self.db.connection())
148 .await
149 .map_err(|e| Error::internal(format!("rotate token: {e}")))?;
150
151 Ok(new_token)
152 }
153
154 pub async fn touch(&self, id: &SessionId, new_expires_at: DateTime<Utc>) -> Result<(), Error> {
156 let model = ActiveModel {
157 id: Set(id.as_str().to_string()),
158 last_active_at: Set(Utc::now()),
159 expires_at: Set(new_expires_at),
160 ..Default::default()
161 };
162
163 model
164 .update(self.db.connection())
165 .await
166 .map_err(|e| Error::internal(format!("touch session: {e}")))?;
167
168 Ok(())
169 }
170
171 pub async fn update_data(&self, id: &SessionId, data: serde_json::Value) -> Result<(), Error> {
173 let model = ActiveModel {
174 id: Set(id.as_str().to_string()),
175 data: Set(serde_json::to_string(&data)
176 .map_err(|e| Error::internal(format!("serialize session data: {e}")))?),
177 ..Default::default()
178 };
179
180 model
181 .update(self.db.connection())
182 .await
183 .map_err(|e| Error::internal(format!("update session data: {e}")))?;
184
185 Ok(())
186 }
187
188 pub async fn destroy_all_for_user(&self, user_id: &str) -> Result<(), Error> {
190 Entity::delete_many()
191 .filter(Column::UserId.eq(user_id))
192 .exec(self.db.connection())
193 .await
194 .map_err(|e| Error::internal(format!("destroy all sessions for user: {e}")))?;
195 Ok(())
196 }
197
198 pub async fn destroy_all_except(&self, user_id: &str, keep: &SessionId) -> Result<(), Error> {
201 Entity::delete_many()
202 .filter(Column::UserId.eq(user_id))
203 .filter(Column::Id.ne(keep.as_str()))
204 .exec(self.db.connection())
205 .await
206 .map_err(|e| Error::internal(format!("destroy all except: {e}")))?;
207 Ok(())
208 }
209
210 pub async fn list_for_user(&self, user_id: &str) -> Result<Vec<SessionData>, Error> {
213 let models = Entity::find()
214 .filter(Column::UserId.eq(user_id))
215 .filter(Column::ExpiresAt.gt(Utc::now()))
216 .order_by_desc(Column::LastActiveAt)
217 .all(self.db.connection())
218 .await
219 .map_err(|e| Error::internal(format!("list sessions: {e}")))?;
220
221 models.iter().map(model_to_session_data).collect()
222 }
223
224 pub async fn cleanup_expired(&self) -> Result<u64, Error> {
229 let result = Entity::delete_many()
230 .filter(Column::ExpiresAt.lt(Utc::now()))
231 .exec(self.db.connection())
232 .await
233 .map_err(|e| Error::internal(format!("cleanup expired sessions: {e}")))?;
234 Ok(result.rows_affected)
235 }
236
237 async fn enforce_session_limit(&self, user_id: &str) -> Result<(), Error> {
238 let now = Utc::now();
239
240 let count = Entity::find()
241 .filter(Column::UserId.eq(user_id))
242 .filter(Column::ExpiresAt.gt(now))
243 .count(self.db.connection())
244 .await
245 .map_err(|e| Error::internal(format!("count sessions: {e}")))?;
246
247 if count as usize <= self.config.max_sessions_per_user {
248 return Ok(());
249 }
250
251 let excess = count as usize - self.config.max_sessions_per_user;
252
253 let oldest = Entity::find()
255 .filter(Column::UserId.eq(user_id))
256 .filter(Column::ExpiresAt.gt(now))
257 .order_by_asc(Column::LastActiveAt)
258 .limit(excess as u64)
259 .all(self.db.connection())
260 .await
261 .map_err(|e| Error::internal(format!("find oldest sessions: {e}")))?;
262
263 let ids: Vec<String> = oldest.into_iter().map(|m| m.id).collect();
264 if !ids.is_empty() {
265 Entity::delete_many()
266 .filter(Column::Id.is_in(ids))
267 .exec(self.db.connection())
268 .await
269 .map_err(|e| Error::internal(format!("evict sessions: {e}")))?;
270 }
271
272 Ok(())
273 }
274}
275
276fn model_to_session_data(model: &session::Model) -> Result<SessionData, Error> {
277 let data: serde_json::Value = serde_json::from_str(&model.data)
278 .map_err(|e| Error::internal(format!("deserialize session data: {e}")))?;
279
280 Ok(SessionData {
281 id: SessionId::from_raw(&model.id),
282 token_hash: model.token_hash.clone(),
283 user_id: model.user_id.clone(),
284 ip_address: model.ip_address.clone(),
285 user_agent: model.user_agent.clone(),
286 device_name: model.device_name.clone(),
287 device_type: model.device_type.clone(),
288 fingerprint: model.fingerprint.clone(),
289 data,
290 created_at: model.created_at,
291 last_active_at: model.last_active_at,
292 expires_at: model.expires_at,
293 })
294}