1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Duration, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use serde::Serialize;
8
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::types::{Email, Session, SessionId, SessionToken, TokenHash, UserId};
12
13pub struct SessionConfig {
15 pub ttl: Duration,
17 pub cookie_name: &'static str,
19 pub secure: bool,
22}
23
24impl Default for SessionConfig {
25 fn default() -> Self {
26 Self {
27 ttl: Duration::hours(24),
28 cookie_name: "allowthem_session",
29 secure: true,
30 }
31 }
32}
33
/// One row of the session listing produced by `Db::list_all_sessions`.
///
/// Unlike a full session record it carries no token hash, and it adds the
/// owning user's email, which the listing query reads from the joined
/// `allowthem_users` table.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct SessionListEntry {
    /// Identifier of the session row.
    pub id: SessionId,
    /// Identifier of the user who owns the session.
    pub user_id: UserId,
    /// Email of the owning user (selected as `u.email AS user_email`).
    pub user_email: Email,
    /// Client IP address stored with the session, if any.
    pub ip_address: Option<String>,
    /// Client user-agent string stored with the session, if any.
    pub user_agent: Option<String>,
    /// Instant after which the session is no longer returned by lookups.
    pub expires_at: DateTime<Utc>,
    /// Creation time of the session row.
    pub created_at: DateTime<Utc>,
}
46
47pub struct ListSessionsParams {
49 pub user_id: Option<UserId>,
50 pub limit: u32,
51 pub offset: u32,
52}
53
54pub struct ListSessionsResult {
56 pub sessions: Vec<SessionListEntry>,
57 pub total: u32,
58}
59
60pub fn generate_token() -> SessionToken {
65 let mut bytes = [0u8; 32];
66 OsRng
67 .try_fill_bytes(&mut bytes)
68 .expect("OS RNG unavailable");
69 SessionToken::from_encoded(Base64UrlUnpadded::encode_string(&bytes))
70}
71
72pub fn hash_token(token: &SessionToken) -> TokenHash {
77 let digest = Sha256::digest(token.as_str().as_bytes());
78 TokenHash::new_unchecked(format!("{digest:x}"))
79}
80
81impl Db {
82 pub async fn create_session(
87 &self,
88 user_id: UserId,
89 token_hash: TokenHash,
90 ip_address: Option<&str>,
91 user_agent: Option<&str>,
92 expires_at: DateTime<Utc>,
93 ) -> Result<Session, AuthError> {
94 let id = SessionId::new();
95 let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
96 sqlx::query_as::<_, Session>(
97 "INSERT INTO allowthem_sessions (id, token_hash, user_id, ip_address, user_agent, expires_at)
98 VALUES (?, ?, ?, ?, ?, ?)
99 RETURNING id, token_hash, user_id, ip_address, user_agent, expires_at, created_at",
100 )
101 .bind(id)
102 .bind(token_hash)
103 .bind(user_id)
104 .bind(ip_address)
105 .bind(user_agent)
106 .bind(expires_at_str)
107 .fetch_one(self.pool())
108 .await
109 .map_err(AuthError::Database)
110 }
111
112 pub async fn lookup_session(&self, token: &SessionToken) -> Result<Option<Session>, AuthError> {
118 let hash = hash_token(token);
119 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
120 sqlx::query_as::<_, Session>(
121 "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at
122 FROM allowthem_sessions
123 WHERE token_hash = ? AND expires_at > ?",
124 )
125 .bind(hash)
126 .bind(now)
127 .fetch_optional(self.pool())
128 .await
129 .map_err(AuthError::Database)
130 }
131
132 pub async fn validate_session(
139 &self,
140 token: &SessionToken,
141 ttl: Duration,
142 ) -> Result<Option<Session>, AuthError> {
143 let session = match self.lookup_session(token).await? {
144 Some(s) => s,
145 None => return Ok(None),
146 };
147
148 let now = Utc::now();
149 let halfway = session.expires_at - ttl / 2;
150
151 if now > halfway {
152 let new_expires_at = now + ttl;
153 let new_expires_str = new_expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
154 let hash = hash_token(token);
155 sqlx::query("UPDATE allowthem_sessions SET expires_at = ? WHERE token_hash = ?")
156 .bind(&new_expires_str)
157 .bind(hash)
158 .execute(self.pool())
159 .await
160 .map_err(AuthError::Database)?;
161
162 return Ok(Some(Session {
163 expires_at: new_expires_at,
164 ..session
165 }));
166 }
167
168 Ok(Some(session))
169 }
170
171 pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
176 let hash = hash_token(token);
177 let result = sqlx::query("DELETE FROM allowthem_sessions WHERE token_hash = ?")
178 .bind(hash)
179 .execute(self.pool())
180 .await
181 .map_err(AuthError::Database)?;
182 Ok(result.rows_affected() > 0)
183 }
184
185 pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
189 let result = sqlx::query("DELETE FROM allowthem_sessions WHERE user_id = ?")
190 .bind(*user_id)
191 .execute(self.pool())
192 .await
193 .map_err(AuthError::Database)?;
194 Ok(result.rows_affected())
195 }
196
197 pub async fn list_user_sessions(&self, user_id: UserId) -> Result<Vec<Session>, AuthError> {
201 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
202 sqlx::query_as::<_, Session>(
203 "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at \
204 FROM allowthem_sessions \
205 WHERE user_id = ? AND expires_at > ? \
206 ORDER BY created_at DESC",
207 )
208 .bind(user_id)
209 .bind(now)
210 .fetch_all(self.pool())
211 .await
212 .map_err(AuthError::Database)
213 }
214
215 pub async fn list_all_sessions(
220 &self,
221 params: ListSessionsParams,
222 ) -> Result<ListSessionsResult, AuthError> {
223 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
224
225 if let Some(user_id) = params.user_id {
226 let total = sqlx::query_scalar::<_, i64>(
227 "SELECT COUNT(*) FROM allowthem_sessions s \
228 JOIN allowthem_users u ON s.user_id = u.id \
229 WHERE s.expires_at > ? AND s.user_id = ?",
230 )
231 .bind(&now)
232 .bind(user_id)
233 .fetch_one(self.pool())
234 .await
235 .map_err(AuthError::Database)? as u32;
236
237 let sessions = sqlx::query_as::<_, SessionListEntry>(
238 "SELECT s.id, s.user_id, u.email AS user_email, \
239 s.ip_address, s.user_agent, s.expires_at, s.created_at \
240 FROM allowthem_sessions s \
241 JOIN allowthem_users u ON s.user_id = u.id \
242 WHERE s.expires_at > ? AND s.user_id = ? \
243 ORDER BY s.created_at DESC \
244 LIMIT ? OFFSET ?",
245 )
246 .bind(&now)
247 .bind(user_id)
248 .bind(params.limit)
249 .bind(params.offset)
250 .fetch_all(self.pool())
251 .await
252 .map_err(AuthError::Database)?;
253
254 Ok(ListSessionsResult { sessions, total })
255 } else {
256 let total = sqlx::query_scalar::<_, i64>(
257 "SELECT COUNT(*) FROM allowthem_sessions s \
258 JOIN allowthem_users u ON s.user_id = u.id \
259 WHERE s.expires_at > ?",
260 )
261 .bind(&now)
262 .fetch_one(self.pool())
263 .await
264 .map_err(AuthError::Database)? as u32;
265
266 let sessions = sqlx::query_as::<_, SessionListEntry>(
267 "SELECT s.id, s.user_id, u.email AS user_email, \
268 s.ip_address, s.user_agent, s.expires_at, s.created_at \
269 FROM allowthem_sessions s \
270 JOIN allowthem_users u ON s.user_id = u.id \
271 WHERE s.expires_at > ? \
272 ORDER BY s.created_at DESC \
273 LIMIT ? OFFSET ?",
274 )
275 .bind(&now)
276 .bind(params.limit)
277 .bind(params.offset)
278 .fetch_all(self.pool())
279 .await
280 .map_err(AuthError::Database)?;
281
282 Ok(ListSessionsResult { sessions, total })
283 }
284 }
285
286 pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
291 let result = sqlx::query("DELETE FROM allowthem_sessions WHERE id = ?")
292 .bind(id)
293 .execute(self.pool())
294 .await
295 .map_err(AuthError::Database)?;
296 Ok(result.rows_affected() > 0)
297 }
298}
299
300pub fn session_cookie(token: &SessionToken, config: &SessionConfig, domain: &str) -> String {
306 let max_age = config.ttl.num_seconds();
307 let mut cookie = format!(
308 "{}={}; HttpOnly; SameSite=Lax; Path=/; Max-Age={}",
309 config.cookie_name,
310 token.as_str(),
311 max_age,
312 );
313 if !domain.is_empty() {
314 cookie.push_str("; Domain=");
315 cookie.push_str(domain);
316 }
317 if config.secure {
318 cookie.push_str("; Secure");
319 }
320 cookie
321}
322
323pub fn clear_session_cookie(config: &SessionConfig, domain: &str) -> String {
328 let mut cookie = format!(
329 "{}=; HttpOnly; SameSite=Lax; Path=/; Max-Age=0",
330 config.cookie_name,
331 );
332 if !domain.is_empty() {
333 cookie.push_str("; Domain=");
334 cookie.push_str(domain);
335 }
336 if config.secure {
337 cookie.push_str("; Secure");
338 }
339 cookie
340}
341
342pub fn parse_session_cookie(cookie_header: &str, cookie_name: &str) -> Option<SessionToken> {
347 for pair in cookie_header.split("; ") {
348 if let Some((name, value)) = pair.split_once('=')
349 && name.trim() == cookie_name
350 {
351 return Some(SessionToken::from_encoded(value.trim().to_string()));
352 }
353 }
354 None
355}