// better_auth_core/session.rs
use chrono::{DateTime, Utc};
use std::sync::Arc;

use crate::adapters::DatabaseAdapter;
use crate::config::AuthConfig;
use crate::entity::{AuthSession, AuthUser};
use crate::error::AuthResult;
use crate::types::CreateSession;

10pub struct SessionManager<DB: DatabaseAdapter> {
12 config: Arc<AuthConfig>,
13 database: Arc<DB>,
14}
15
16impl<DB: DatabaseAdapter> Clone for SessionManager<DB> {
17 fn clone(&self) -> Self {
18 Self {
19 config: self.config.clone(),
20 database: self.database.clone(),
21 }
22 }
23}
24
25impl<DB: DatabaseAdapter> SessionManager<DB> {
26 pub fn new(config: Arc<AuthConfig>, database: Arc<DB>) -> Self {
27 Self { config, database }
28 }
29
30 pub async fn create_session(
32 &self,
33 user: &impl AuthUser,
34 ip_address: Option<String>,
35 user_agent: Option<String>,
36 ) -> AuthResult<DB::Session> {
37 let expires_at = Utc::now() + self.config.session.expires_in;
38
39 let create_session = CreateSession {
40 user_id: user.id().to_string(),
41 expires_at,
42 ip_address,
43 user_agent,
44 impersonated_by: None,
45 active_organization_id: None,
46 };
47
48 let session = self.database.create_session(create_session).await?;
49 Ok(session)
50 }
51
52 pub async fn get_session(&self, token: &str) -> AuthResult<Option<DB::Session>> {
54 let mut session = self.database.get_session(token).await?;
55
56 let should_refresh = if let Some(ref s) = session {
57 let now = Utc::now();
58
59 if s.expires_at() < now || !s.active() {
60 if let Err(err) = self.database.delete_session(token).await {
65 tracing::warn!(
66 error = %err,
67 "Failed to delete expired session; will be retried later"
68 );
69 }
70 return Ok(None);
71 }
72
73 if !self.config.session.disable_session_refresh {
74 match self.config.session.update_age {
75 Some(age) => {
76 let updated = s.updated_at();
79 Utc::now() - updated >= age
80 }
81 None => true,
83 }
84 } else {
85 false
86 }
87 } else {
88 false
89 };
90
91 if should_refresh {
92 let new_expires_at = Utc::now() + self.config.session.expires_in;
93 match self
94 .database
95 .update_session_expiry(token, new_expires_at)
96 .await
97 {
98 Ok(()) => {
99 match self.database.get_session(token).await {
105 Ok(Some(refreshed)) => session = Some(refreshed),
106 Ok(None) => {
107 tracing::warn!(
108 "Session re-read after refresh returned None (concurrent revoke?); returning pre-refresh value"
109 );
110 }
111 Err(err) => {
112 tracing::warn!(
113 error = %err,
114 "Session re-read after refresh failed; returning pre-refresh value"
115 );
116 }
117 }
118 }
119 Err(err) => {
120 tracing::warn!(
125 error = %err,
126 "Failed to refresh session expiry; returning pre-refresh session"
127 );
128 }
129 }
130 }
131
132 Ok(session)
133 }
134
135 pub async fn delete_session(&self, token: &str) -> AuthResult<()> {
137 self.database.delete_session(token).await?;
138 Ok(())
139 }
140
141 pub async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
143 self.database.delete_user_sessions(user_id).await?;
144 Ok(())
145 }
146
147 pub async fn list_user_sessions(&self, user_id: &str) -> AuthResult<Vec<DB::Session>> {
149 let sessions = self.database.get_user_sessions(user_id).await?;
150 let now = Utc::now();
151
152 let active_sessions: Vec<DB::Session> = sessions
154 .into_iter()
155 .filter(|session| session.expires_at() > now && session.active())
156 .collect();
157
158 Ok(active_sessions)
159 }
160
161 pub async fn revoke_session(&self, token: &str) -> AuthResult<bool> {
163 let session_exists = self.get_session(token).await?.is_some();
165
166 if session_exists {
167 self.delete_session(token).await?;
168 Ok(true)
169 } else {
170 Ok(false)
171 }
172 }
173
174 pub async fn revoke_all_user_sessions(&self, user_id: &str) -> AuthResult<usize> {
176 let sessions = self.list_user_sessions(user_id).await?;
178 let count = sessions.len();
179
180 self.delete_user_sessions(user_id).await?;
181 Ok(count)
182 }
183
184 pub async fn revoke_other_user_sessions(
186 &self,
187 user_id: &str,
188 current_token: &str,
189 ) -> AuthResult<usize> {
190 let sessions = self.list_user_sessions(user_id).await?;
191 let mut count = 0;
192
193 for session in sessions {
194 if session.token() != current_token {
195 self.delete_session(session.token()).await?;
196 count += 1;
197 }
198 }
199
200 Ok(count)
201 }
202
203 pub async fn cleanup_expired_sessions(&self) -> AuthResult<usize> {
205 let count = self.database.delete_expired_sessions().await?;
206 Ok(count)
207 }
208
209 pub fn is_session_fresh(&self, session: &impl AuthSession) -> bool {
216 match self.config.session.fresh_age {
217 Some(fresh_age) => session.created_at() + fresh_age > Utc::now(),
218 None => false,
219 }
220 }
221
222 pub fn validate_token_format(&self, token: &str) -> bool {
224 token.starts_with("session_") && token.len() > 40
225 }
226
227 pub fn extract_session_token(&self, req: &crate::types::AuthRequest) -> Option<String> {
232 if let Some(auth_header) = req.headers.get("authorization")
234 && let Some(token) = auth_header.strip_prefix("Bearer ")
235 {
236 return Some(token.to_string());
237 }
238
239 if let Some(cookie_header) = req.headers.get("cookie") {
241 let cookie_name = &self.config.session.cookie_name;
242 for c in cookie::Cookie::split_parse(cookie_header).flatten() {
243 if c.name() == cookie_name && !c.value().is_empty() {
244 return Some(c.value().to_string());
245 }
246 }
247 }
248
249 None
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps};
    use crate::config::SessionConfig;
    use crate::types::{CreateUser, User};
    use chrono::Duration;

    type Manager = SessionManager<MemoryDatabaseAdapter>;

    /// Builds an in-memory adapter seeded with one user, plus a manager over it.
    ///
    /// The manager's `AuthConfig` is the default except for the given session
    /// settings.
    async fn fixture(session: SessionConfig) -> (Arc<MemoryDatabaseAdapter>, Manager, User) {
        let db = Arc::new(MemoryDatabaseAdapter::new());
        let seed = CreateUser {
            email: Some("test@example.com".into()),
            name: Some("Test User".into()),
            ..Default::default()
        };
        let user = db
            .create_user(seed)
            .await
            .expect("seeding the test user must succeed");

        let config = Arc::new(AuthConfig {
            session,
            ..AuthConfig::default()
        });
        let manager = SessionManager::new(config, Arc::clone(&db));
        (db, manager, user)
    }

    /// Lets the wall clock advance so that a refreshed expiry is observably later.
    async fn let_clock_tick() {
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
    }

    #[tokio::test]
    async fn refresh_updates_returned_session_expires_at() {
        let (_db, manager, user) = fixture(SessionConfig {
            expires_in: Duration::hours(1),
            update_age: None,
            ..SessionConfig::default()
        })
        .await;

        let before = manager
            .create_session(&user, None, None)
            .await
            .expect("create");
        let_clock_tick().await;

        let after = manager
            .get_session(before.token())
            .await
            .expect("lookup")
            .expect("session should still exist");
        assert!(after.expires_at() > before.expires_at());
    }

    #[tokio::test]
    async fn refresh_is_throttled_by_update_age() {
        let (_db, manager, user) = fixture(SessionConfig {
            expires_in: Duration::hours(1),
            update_age: Some(Duration::hours(1)),
            ..SessionConfig::default()
        })
        .await;

        let created = manager
            .create_session(&user, None, None)
            .await
            .expect("create");
        let seen = manager
            .get_session(created.token())
            .await
            .expect("lookup")
            .expect("session should still exist");
        assert_eq!(seen.expires_at(), created.expires_at());
    }

    #[tokio::test]
    async fn refresh_skipped_when_disabled() {
        let (_db, manager, user) = fixture(SessionConfig {
            expires_in: Duration::hours(1),
            update_age: None,
            disable_session_refresh: true,
            ..SessionConfig::default()
        })
        .await;

        let created = manager
            .create_session(&user, None, None)
            .await
            .expect("create");
        let_clock_tick().await;

        let seen = manager
            .get_session(created.token())
            .await
            .expect("lookup")
            .expect("session should still exist");
        assert_eq!(seen.expires_at(), created.expires_at());
    }

    #[tokio::test]
    async fn expired_session_is_removed_and_returns_none() {
        let (db, manager, user) = fixture(SessionConfig::default()).await;

        let created = manager
            .create_session(&user, None, None)
            .await
            .expect("create");
        let already_past = Utc::now() - Duration::seconds(1);
        db.update_session_expiry(created.token(), already_past)
            .await
            .expect("backdating the expiry must succeed");

        let lookup = manager.get_session(created.token()).await.expect("lookup");
        assert!(lookup.is_none());

        let stored = db.get_session(created.token()).await.expect("raw lookup");
        assert!(stored.is_none());
    }
}