1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Duration, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use serde::Serialize;
8
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::event_sink::AuthEvent;
12use crate::handle::AllowThem;
13use crate::types::{Email, Session, SessionId, SessionToken, TokenHash, UserId};
14
15pub struct SessionConfig {
17 pub ttl: Duration,
19 pub cookie_name: &'static str,
21 pub secure: bool,
24}
25
26impl Default for SessionConfig {
27 fn default() -> Self {
28 Self {
29 ttl: Duration::hours(24),
30 cookie_name: "allowthem_session",
31 secure: true,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
39pub struct SessionListEntry {
40 pub id: SessionId,
41 pub user_id: UserId,
42 pub user_email: Email,
43 pub ip_address: Option<String>,
44 pub user_agent: Option<String>,
45 pub expires_at: DateTime<Utc>,
46 pub created_at: DateTime<Utc>,
47}
48
49pub struct ListSessionsParams {
51 pub user_id: Option<UserId>,
52 pub limit: u32,
53 pub offset: u32,
54}
55
56pub struct ListSessionsResult {
58 pub sessions: Vec<SessionListEntry>,
59 pub total: u32,
60}
61
62pub fn generate_token() -> SessionToken {
67 let mut bytes = [0u8; 32];
68 OsRng
69 .try_fill_bytes(&mut bytes)
70 .expect("OS RNG unavailable");
71 SessionToken::from_encoded(Base64UrlUnpadded::encode_string(&bytes))
72}
73
74pub fn hash_token(token: &SessionToken) -> TokenHash {
79 let digest = Sha256::digest(token.as_str().as_bytes());
80 TokenHash::new_unchecked(format!("{digest:x}"))
81}
82
83impl Db {
84 pub async fn create_session(
89 &self,
90 user_id: UserId,
91 token_hash: TokenHash,
92 ip_address: Option<&str>,
93 user_agent: Option<&str>,
94 expires_at: DateTime<Utc>,
95 ) -> Result<Session, AuthError> {
96 let id = SessionId::new();
97 let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
98 sqlx::query_as::<_, Session>(
99 "INSERT INTO allowthem_sessions (id, token_hash, user_id, ip_address, user_agent, expires_at)
100 VALUES (?, ?, ?, ?, ?, ?)
101 RETURNING id, token_hash, user_id, ip_address, user_agent, expires_at, created_at",
102 )
103 .bind(id)
104 .bind(token_hash)
105 .bind(user_id)
106 .bind(ip_address)
107 .bind(user_agent)
108 .bind(expires_at_str)
109 .fetch_one(self.pool())
110 .await
111 .map_err(AuthError::Database)
112 }
113
114 pub async fn lookup_session(&self, token: &SessionToken) -> Result<Option<Session>, AuthError> {
120 let hash = hash_token(token);
121 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
122 sqlx::query_as::<_, Session>(
123 "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at
124 FROM allowthem_sessions
125 WHERE token_hash = ? AND expires_at > ?",
126 )
127 .bind(hash)
128 .bind(now)
129 .fetch_optional(self.pool())
130 .await
131 .map_err(AuthError::Database)
132 }
133
134 pub async fn validate_session(
141 &self,
142 token: &SessionToken,
143 ttl: Duration,
144 ) -> Result<Option<Session>, AuthError> {
145 let session = match self.lookup_session(token).await? {
146 Some(s) => s,
147 None => return Ok(None),
148 };
149
150 let now = Utc::now();
151 let halfway = session.expires_at - ttl / 2;
152
153 if now > halfway {
154 let new_expires_at = now + ttl;
155 let new_expires_str = new_expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
156 let hash = hash_token(token);
157 sqlx::query("UPDATE allowthem_sessions SET expires_at = ? WHERE token_hash = ?")
158 .bind(&new_expires_str)
159 .bind(hash)
160 .execute(self.pool())
161 .await
162 .map_err(AuthError::Database)?;
163
164 return Ok(Some(Session {
165 expires_at: new_expires_at,
166 ..session
167 }));
168 }
169
170 Ok(Some(session))
171 }
172
173 pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
178 let hash = hash_token(token);
179 let result = sqlx::query("DELETE FROM allowthem_sessions WHERE token_hash = ?")
180 .bind(hash)
181 .execute(self.pool())
182 .await
183 .map_err(AuthError::Database)?;
184 Ok(result.rows_affected() > 0)
185 }
186
187 pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
191 let result = sqlx::query("DELETE FROM allowthem_sessions WHERE user_id = ?")
192 .bind(*user_id)
193 .execute(self.pool())
194 .await
195 .map_err(AuthError::Database)?;
196 Ok(result.rows_affected())
197 }
198
199 pub async fn list_user_sessions(&self, user_id: UserId) -> Result<Vec<Session>, AuthError> {
203 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
204 sqlx::query_as::<_, Session>(
205 "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at \
206 FROM allowthem_sessions \
207 WHERE user_id = ? AND expires_at > ? \
208 ORDER BY created_at DESC",
209 )
210 .bind(user_id)
211 .bind(now)
212 .fetch_all(self.pool())
213 .await
214 .map_err(AuthError::Database)
215 }
216
217 pub async fn list_all_sessions(
222 &self,
223 params: ListSessionsParams,
224 ) -> Result<ListSessionsResult, AuthError> {
225 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
226
227 if let Some(user_id) = params.user_id {
228 let total = sqlx::query_scalar::<_, i64>(
229 "SELECT COUNT(*) FROM allowthem_sessions s \
230 JOIN allowthem_users u ON s.user_id = u.id \
231 WHERE s.expires_at > ? AND s.user_id = ?",
232 )
233 .bind(&now)
234 .bind(user_id)
235 .fetch_one(self.pool())
236 .await
237 .map_err(AuthError::Database)? as u32;
238
239 let sessions = sqlx::query_as::<_, SessionListEntry>(
240 "SELECT s.id, s.user_id, u.email AS user_email, \
241 s.ip_address, s.user_agent, s.expires_at, s.created_at \
242 FROM allowthem_sessions s \
243 JOIN allowthem_users u ON s.user_id = u.id \
244 WHERE s.expires_at > ? AND s.user_id = ? \
245 ORDER BY s.created_at DESC \
246 LIMIT ? OFFSET ?",
247 )
248 .bind(&now)
249 .bind(user_id)
250 .bind(params.limit)
251 .bind(params.offset)
252 .fetch_all(self.pool())
253 .await
254 .map_err(AuthError::Database)?;
255
256 Ok(ListSessionsResult { sessions, total })
257 } else {
258 let total = sqlx::query_scalar::<_, i64>(
259 "SELECT COUNT(*) FROM allowthem_sessions s \
260 JOIN allowthem_users u ON s.user_id = u.id \
261 WHERE s.expires_at > ?",
262 )
263 .bind(&now)
264 .fetch_one(self.pool())
265 .await
266 .map_err(AuthError::Database)? as u32;
267
268 let sessions = sqlx::query_as::<_, SessionListEntry>(
269 "SELECT s.id, s.user_id, u.email AS user_email, \
270 s.ip_address, s.user_agent, s.expires_at, s.created_at \
271 FROM allowthem_sessions s \
272 JOIN allowthem_users u ON s.user_id = u.id \
273 WHERE s.expires_at > ? \
274 ORDER BY s.created_at DESC \
275 LIMIT ? OFFSET ?",
276 )
277 .bind(&now)
278 .bind(params.limit)
279 .bind(params.offset)
280 .fetch_all(self.pool())
281 .await
282 .map_err(AuthError::Database)?;
283
284 Ok(ListSessionsResult { sessions, total })
285 }
286 }
287
288 pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
293 let result = sqlx::query("DELETE FROM allowthem_sessions WHERE id = ?")
294 .bind(id)
295 .execute(self.pool())
296 .await
297 .map_err(AuthError::Database)?;
298 Ok(result.rows_affected() > 0)
299 }
300}
301
302impl AllowThem {
305 pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
310 let deleted = self.db().delete_session(token).await?;
311 if deleted {
312 self.emit_event(AuthEvent::new(
313 "session.destroyed",
314 None,
315 serde_json::json!({ "scope": "single" }),
316 ))
317 .await;
318 }
319 Ok(deleted)
320 }
321
322 pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
327 let deleted = self.db().delete_session_by_id(id).await?;
328 if deleted {
329 self.emit_event(AuthEvent::new(
330 "session.destroyed",
331 None,
332 serde_json::json!({ "scope": "single" }),
333 ))
334 .await;
335 }
336 Ok(deleted)
337 }
338
339 pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
343 let count = self.db().delete_user_sessions(user_id).await?;
344 self.emit_event(AuthEvent::new(
345 "session.destroyed",
346 Some(*user_id),
347 serde_json::json!({ "user_id": user_id, "count": count, "scope": "all" }),
348 ))
349 .await;
350 Ok(count)
351 }
352}
353
354pub fn session_cookie(token: &SessionToken, config: &SessionConfig, domain: &str) -> String {
360 let max_age = config.ttl.num_seconds();
361 let mut cookie = format!(
362 "{}={}; HttpOnly; SameSite=Lax; Path=/; Max-Age={}",
363 config.cookie_name,
364 token.as_str(),
365 max_age,
366 );
367 if !domain.is_empty() {
368 cookie.push_str("; Domain=");
369 cookie.push_str(domain);
370 }
371 if config.secure {
372 cookie.push_str("; Secure");
373 }
374 cookie
375}
376
377pub fn clear_session_cookie(config: &SessionConfig, domain: &str) -> String {
382 let mut cookie = format!(
383 "{}=; HttpOnly; SameSite=Lax; Path=/; Max-Age=0",
384 config.cookie_name,
385 );
386 if !domain.is_empty() {
387 cookie.push_str("; Domain=");
388 cookie.push_str(domain);
389 }
390 if config.secure {
391 cookie.push_str("; Secure");
392 }
393 cookie
394}
395
396pub fn parse_session_cookie(cookie_header: &str, cookie_name: &str) -> Option<SessionToken> {
401 for pair in cookie_header.split("; ") {
402 if let Some((name, value)) = pair.split_once('=')
403 && name.trim() == cookie_name
404 {
405 return Some(SessionToken::from_encoded(value.trim().to_string()));
406 }
407 }
408 None
409}