Skip to main content

better_auth_core/
session.rs

1use chrono::Utc;
2use std::sync::Arc;
3
4use crate::adapters::DatabaseAdapter;
5use crate::config::AuthConfig;
6use crate::entity::{AuthSession, AuthUser};
7use crate::error::AuthResult;
8use crate::types::CreateSession;
9
10/// Session manager handles session creation, validation, and cleanup
11pub struct SessionManager<DB: DatabaseAdapter> {
12    config: Arc<AuthConfig>,
13    database: Arc<DB>,
14}
15
16impl<DB: DatabaseAdapter> Clone for SessionManager<DB> {
17    fn clone(&self) -> Self {
18        Self {
19            config: self.config.clone(),
20            database: self.database.clone(),
21        }
22    }
23}
24
25impl<DB: DatabaseAdapter> SessionManager<DB> {
26    pub fn new(config: Arc<AuthConfig>, database: Arc<DB>) -> Self {
27        Self { config, database }
28    }
29
30    /// Create a new session for a user
31    pub async fn create_session(
32        &self,
33        user: &impl AuthUser,
34        ip_address: Option<String>,
35        user_agent: Option<String>,
36    ) -> AuthResult<DB::Session> {
37        let expires_at = Utc::now() + self.config.session.expires_in;
38
39        let create_session = CreateSession {
40            user_id: user.id().to_string(),
41            expires_at,
42            ip_address,
43            user_agent,
44            impersonated_by: None,
45            active_organization_id: None,
46        };
47
48        let session = self.database.create_session(create_session).await?;
49        Ok(session)
50    }
51
52    /// Get session by token
53    pub async fn get_session(&self, token: &str) -> AuthResult<Option<DB::Session>> {
54        let mut session = self.database.get_session(token).await?;
55
56        let should_refresh = if let Some(ref s) = session {
57            let now = Utc::now();
58
59            if s.expires_at() < now || !s.active() {
60                // Session expired or inactive — best-effort cleanup. A DB
61                // hiccup here shouldn't turn "your session is expired" into
62                // a 500; the row will be caught by the next access or the
63                // periodic `cleanup_expired_sessions` sweep.
64                if let Err(err) = self.database.delete_session(token).await {
65                    tracing::warn!(
66                        error = %err,
67                        "Failed to delete expired session; will be retried later"
68                    );
69                }
70                return Ok(None);
71            }
72
73            if !self.config.session.disable_session_refresh {
74                match self.config.session.update_age {
75                    Some(age) => {
76                        // Only refresh if the session was last updated more than
77                        // `update_age` ago.
78                        let updated = s.updated_at();
79                        Utc::now() - updated >= age
80                    }
81                    // No update_age set → refresh on every access.
82                    None => true,
83                }
84            } else {
85                false
86            }
87        } else {
88            false
89        };
90
91        if should_refresh {
92            let new_expires_at = Utc::now() + self.config.session.expires_in;
93            match self
94                .database
95                .update_session_expiry(token, new_expires_at)
96                .await
97            {
98                Ok(()) => {
99                    // Re-read so the returned session reflects the new expiry.
100                    // Both failure modes fall back to the pre-refresh session:
101                    // a concurrent revoke (re-read returns None) shouldn't log
102                    // the user out mid-request, and a second DB hiccup
103                    // shouldn't turn a successful refresh into a 500.
104                    match self.database.get_session(token).await {
105                        Ok(Some(refreshed)) => session = Some(refreshed),
106                        Ok(None) => {
107                            tracing::warn!(
108                                "Session re-read after refresh returned None (concurrent revoke?); returning pre-refresh value"
109                            );
110                        }
111                        Err(err) => {
112                            tracing::warn!(
113                                error = %err,
114                                "Session re-read after refresh failed; returning pre-refresh value"
115                            );
116                        }
117                    }
118                }
119                Err(err) => {
120                    // Transient write failure (connection reset, contention,
121                    // etc.) must not fail the whole request. Keep the
122                    // pre-refresh session — auth still works, the refresh
123                    // window will be retried on the next call.
124                    tracing::warn!(
125                        error = %err,
126                        "Failed to refresh session expiry; returning pre-refresh session"
127                    );
128                }
129            }
130        }
131
132        Ok(session)
133    }
134
135    /// Delete a session
136    pub async fn delete_session(&self, token: &str) -> AuthResult<()> {
137        self.database.delete_session(token).await?;
138        Ok(())
139    }
140
141    /// Delete all sessions for a user
142    pub async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
143        self.database.delete_user_sessions(user_id).await?;
144        Ok(())
145    }
146
147    /// Get all active sessions for a user
148    pub async fn list_user_sessions(&self, user_id: &str) -> AuthResult<Vec<DB::Session>> {
149        let sessions = self.database.get_user_sessions(user_id).await?;
150        let now = Utc::now();
151
152        // Filter out expired sessions
153        let active_sessions: Vec<DB::Session> = sessions
154            .into_iter()
155            .filter(|session| session.expires_at() > now && session.active())
156            .collect();
157
158        Ok(active_sessions)
159    }
160
161    /// Revoke a specific session by token
162    pub async fn revoke_session(&self, token: &str) -> AuthResult<bool> {
163        // Check if session exists before trying to delete
164        let session_exists = self.get_session(token).await?.is_some();
165
166        if session_exists {
167            self.delete_session(token).await?;
168            Ok(true)
169        } else {
170            Ok(false)
171        }
172    }
173
174    /// Revoke all sessions for a user
175    pub async fn revoke_all_user_sessions(&self, user_id: &str) -> AuthResult<usize> {
176        // Get count of sessions before deletion for return value
177        let sessions = self.list_user_sessions(user_id).await?;
178        let count = sessions.len();
179
180        self.delete_user_sessions(user_id).await?;
181        Ok(count)
182    }
183
184    /// Revoke all sessions for a user except the current one
185    pub async fn revoke_other_user_sessions(
186        &self,
187        user_id: &str,
188        current_token: &str,
189    ) -> AuthResult<usize> {
190        let sessions = self.list_user_sessions(user_id).await?;
191        let mut count = 0;
192
193        for session in sessions {
194            if session.token() != current_token {
195                self.delete_session(session.token()).await?;
196                count += 1;
197            }
198        }
199
200        Ok(count)
201    }
202
203    /// Cleanup expired sessions
204    pub async fn cleanup_expired_sessions(&self) -> AuthResult<usize> {
205        let count = self.database.delete_expired_sessions().await?;
206        Ok(count)
207    }
208
209    /// Check whether a session is "fresh" (created recently enough for
210    /// sensitive operations like password change or account deletion).
211    ///
212    /// Returns `true` when `fresh_age` is set and
213    /// `session.created_at() + fresh_age > now`.
214    /// If `fresh_age` is `None`, the session is never considered fresh.
215    pub fn is_session_fresh(&self, session: &impl AuthSession) -> bool {
216        match self.config.session.fresh_age {
217            Some(fresh_age) => session.created_at() + fresh_age > Utc::now(),
218            None => false,
219        }
220    }
221
222    /// Validate session token format
223    pub fn validate_token_format(&self, token: &str) -> bool {
224        token.starts_with("session_") && token.len() > 40
225    }
226
227    /// Extract session token from a request.
228    ///
229    /// Tries Bearer token from Authorization header first, then falls back
230    /// to parsing the configured cookie from the Cookie header.
231    pub fn extract_session_token(&self, req: &crate::types::AuthRequest) -> Option<String> {
232        // Try Bearer token first
233        if let Some(auth_header) = req.headers.get("authorization")
234            && let Some(token) = auth_header.strip_prefix("Bearer ")
235        {
236            return Some(token.to_string());
237        }
238
239        // Fall back to cookie (using the `cookie` crate for correct parsing)
240        if let Some(cookie_header) = req.headers.get("cookie") {
241            let cookie_name = &self.config.session.cookie_name;
242            for c in cookie::Cookie::split_parse(cookie_header).flatten() {
243                if c.name() == cookie_name && !c.value().is_empty() {
244                    return Some(c.value().to_string());
245                }
246            }
247        }
248
249        None
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps};
    use crate::config::SessionConfig;
    use crate::types::{CreateUser, User};
    use chrono::Duration;

    /// Build an `AuthConfig` that is default except for its `session` section.
    fn test_config(session: SessionConfig) -> Arc<AuthConfig> {
        Arc::new(AuthConfig {
            session,
            ..AuthConfig::default()
        })
    }

    /// Fresh in-memory database containing a single test user.
    async fn setup() -> (Arc<MemoryDatabaseAdapter>, User) {
        let db = Arc::new(MemoryDatabaseAdapter::new());
        let user = db
            .create_user(CreateUser {
                email: Some("test@example.com".into()),
                name: Some("Test User".into()),
                ..Default::default()
            })
            .await
            .unwrap();
        (db, user)
    }

    #[tokio::test]
    async fn refresh_updates_returned_session_expires_at() {
        let (db, user) = setup().await;
        let config = test_config(SessionConfig {
            expires_in: Duration::hours(1),
            update_age: None,
            ..SessionConfig::default()
        });
        let mgr = SessionManager::new(config, db.clone());

        let initial = mgr.create_session(&user, None, None).await.unwrap();
        // Let the clock advance so the refreshed expiry is strictly later.
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let refreshed = mgr.get_session(initial.token()).await.unwrap().unwrap();
        assert!(refreshed.expires_at() > initial.expires_at());
    }

    #[tokio::test]
    async fn refresh_is_throttled_by_update_age() {
        let (db, user) = setup().await;
        let config = test_config(SessionConfig {
            expires_in: Duration::hours(1),
            update_age: Some(Duration::hours(1)),
            ..SessionConfig::default()
        });
        let mgr = SessionManager::new(config, db.clone());

        let initial = mgr.create_session(&user, None, None).await.unwrap();
        let observed = mgr.get_session(initial.token()).await.unwrap().unwrap();
        assert_eq!(observed.expires_at(), initial.expires_at());
    }

    #[tokio::test]
    async fn refresh_skipped_when_disabled() {
        let (db, user) = setup().await;
        let config = test_config(SessionConfig {
            expires_in: Duration::hours(1),
            update_age: None,
            disable_session_refresh: true,
            ..SessionConfig::default()
        });
        let mgr = SessionManager::new(config, db.clone());

        let initial = mgr.create_session(&user, None, None).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;

        let observed = mgr.get_session(initial.token()).await.unwrap().unwrap();
        assert_eq!(observed.expires_at(), initial.expires_at());
    }

    #[tokio::test]
    async fn expired_session_is_removed_and_returns_none() {
        let (db, user) = setup().await;
        let config = test_config(SessionConfig::default());
        let mgr = SessionManager::new(config, db.clone());

        let created = mgr.create_session(&user, None, None).await.unwrap();
        db.update_session_expiry(created.token(), Utc::now() - Duration::seconds(1))
            .await
            .unwrap();

        let result = mgr.get_session(created.token()).await.unwrap();
        assert!(result.is_none());
        // The expired row must also have been deleted from storage.
        let row_after = db.get_session(created.token()).await.unwrap();
        assert!(row_after.is_none());
    }

    #[tokio::test]
    async fn revoke_session_reports_whether_a_live_session_was_deleted() {
        let (db, user) = setup().await;
        let mgr = SessionManager::new(test_config(SessionConfig::default()), db.clone());

        let created = mgr.create_session(&user, None, None).await.unwrap();
        assert!(mgr.revoke_session(created.token()).await.unwrap());
        assert!(db.get_session(created.token()).await.unwrap().is_none());

        // A second revoke finds nothing to delete.
        assert!(!mgr.revoke_session(created.token()).await.unwrap());
    }

    #[tokio::test]
    async fn list_user_sessions_omits_expired_sessions() {
        let (db, user) = setup().await;
        let mgr = SessionManager::new(test_config(SessionConfig::default()), db.clone());

        let live = mgr.create_session(&user, None, None).await.unwrap();
        let stale = mgr.create_session(&user, None, None).await.unwrap();
        db.update_session_expiry(stale.token(), Utc::now() - Duration::seconds(1))
            .await
            .unwrap();

        let listed = mgr
            .list_user_sessions(&user.id().to_string())
            .await
            .unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].token(), live.token());
    }
}