Skip to main content

better_auth_api/plugins/
session_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
10
/// Session management plugin for handling session operations.
///
/// Serves `/get-session` and `/sign-out` unconditionally, plus `/list-sessions`,
/// `/revoke-session`, `/revoke-sessions` and `/revoke-other-sessions` when the
/// corresponding switches in [`SessionManagementConfig`] are enabled.
#[derive(Debug, Clone)]
pub struct SessionManagementPlugin {
    config: SessionManagementConfig,
}

/// Feature switches for [`SessionManagementPlugin`].
#[derive(Debug, Clone)]
pub struct SessionManagementConfig {
    /// Serve `GET /list-sessions`.
    pub enable_session_listing: bool,
    /// Serve `POST /revoke-session`, `POST /revoke-sessions` and
    /// `POST /revoke-other-sessions`.
    pub enable_session_revocation: bool,
    /// NOTE(review): stored and settable, but no handler in this plugin reads
    /// it — every endpoint currently requires a valid session regardless.
    pub require_authentication: bool,
}
22
// Request structures for session endpoints
/// JSON body accepted by `POST /revoke-session`.
#[derive(Debug, Deserialize, Validate)]
struct RevokeSessionRequest {
    /// Token of the session to revoke. The non-empty rule below is only
    /// enforced when `Validate::validate` is invoked on the deserialized value.
    #[validate(length(min = 1, message = "Token is required"))]
    token: String,
}
29
// Response structures
/// JSON body returned by `POST /sign-out` (`{ "success": true }` on success).
#[derive(Debug, Serialize, Deserialize)]
struct SignOutResponse {
    success: bool,
}

/// JSON body returned by `/get-session`: the caller's session and its user,
/// serialized under the `session` and `user` keys.
#[derive(Debug, Serialize)]
struct GetSessionResponse<S: Serialize, U: Serialize> {
    session: S,
    user: U,
}

/// Acknowledgement body (`{ "status": true }`) returned by the revocation
/// endpoints.
#[derive(Debug, Serialize, Deserialize)]
struct StatusResponse {
    status: bool,
}
46
47impl SessionManagementPlugin {
48    pub fn new() -> Self {
49        Self {
50            config: SessionManagementConfig::default(),
51        }
52    }
53
54    pub fn with_config(config: SessionManagementConfig) -> Self {
55        Self { config }
56    }
57
58    pub fn enable_session_listing(mut self, enable: bool) -> Self {
59        self.config.enable_session_listing = enable;
60        self
61    }
62
63    pub fn enable_session_revocation(mut self, enable: bool) -> Self {
64        self.config.enable_session_revocation = enable;
65        self
66    }
67
68    pub fn require_authentication(mut self, require: bool) -> Self {
69        self.config.require_authentication = require;
70        self
71    }
72}
73
74impl Default for SessionManagementPlugin {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl Default for SessionManagementConfig {
81    fn default() -> Self {
82        Self {
83            enable_session_listing: true,
84            enable_session_revocation: true,
85            require_authentication: true,
86        }
87    }
88}
89
90#[async_trait]
91impl<DB: DatabaseAdapter> AuthPlugin<DB> for SessionManagementPlugin {
92    fn name(&self) -> &'static str {
93        "session-management"
94    }
95
96    fn routes(&self) -> Vec<AuthRoute> {
97        vec![
98            AuthRoute::get("/get-session", "get_session"),
99            AuthRoute::post("/get-session", "get_session_post"),
100            AuthRoute::post("/sign-out", "sign_out"),
101            AuthRoute::get("/list-sessions", "list_sessions"),
102            AuthRoute::post("/revoke-session", "revoke_session"),
103            AuthRoute::post("/revoke-sessions", "revoke_sessions"),
104            AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
105        ]
106    }
107
108    async fn on_request(
109        &self,
110        req: &AuthRequest,
111        ctx: &AuthContext<DB>,
112    ) -> AuthResult<Option<AuthResponse>> {
113        match (req.method(), req.path()) {
114            (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
115                Ok(Some(self.handle_get_session(req, ctx).await?))
116            }
117            (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
118            (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
119                Ok(Some(self.handle_list_sessions(req, ctx).await?))
120            }
121            (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
122                Ok(Some(self.handle_revoke_session(req, ctx).await?))
123            }
124            (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
125                Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
126            }
127            (HttpMethod::Post, "/revoke-other-sessions")
128                if self.config.enable_session_revocation =>
129            {
130                Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
131            }
132            _ => Ok(None),
133        }
134    }
135}
136
137// Implementation methods outside the trait
138impl SessionManagementPlugin {
139    async fn handle_get_session<DB: DatabaseAdapter>(
140        &self,
141        req: &AuthRequest,
142        ctx: &AuthContext<DB>,
143    ) -> AuthResult<AuthResponse> {
144        let (user, session) = self.require_session(req, ctx).await?;
145        let response = GetSessionResponse { session, user };
146        Ok(AuthResponse::json(200, &response)?)
147    }
148
149    async fn handle_sign_out<DB: DatabaseAdapter>(
150        &self,
151        req: &AuthRequest,
152        ctx: &AuthContext<DB>,
153    ) -> AuthResult<AuthResponse> {
154        let (_user, current_session) = self.require_session(req, ctx).await?;
155
156        ctx.database.delete_session(current_session.token()).await?;
157
158        let response = SignOutResponse { success: true };
159        let clear_cookie_header = self.create_clear_session_cookie(ctx);
160
161        Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie_header))
162    }
163
164    fn create_clear_session_cookie<DB: DatabaseAdapter>(&self, ctx: &AuthContext<DB>) -> String {
165        let session_config = &ctx.config.session;
166        let secure = if session_config.cookie_secure {
167            "; Secure"
168        } else {
169            ""
170        };
171        let http_only = if session_config.cookie_http_only {
172            "; HttpOnly"
173        } else {
174            ""
175        };
176        let same_site = match session_config.cookie_same_site {
177            better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
178            better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
179            better_auth_core::config::SameSite::None => "; SameSite=None",
180        };
181
182        format!(
183            "{}=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT{}{}{}",
184            session_config.cookie_name, secure, http_only, same_site
185        )
186    }
187
188    async fn handle_list_sessions<DB: DatabaseAdapter>(
189        &self,
190        req: &AuthRequest,
191        ctx: &AuthContext<DB>,
192    ) -> AuthResult<AuthResponse> {
193        let (user, _) = self.require_session(req, ctx).await?;
194        let sessions = self.get_user_sessions(user.id(), ctx).await?;
195        Ok(AuthResponse::json(200, &sessions)?)
196    }
197
198    async fn handle_revoke_session<DB: DatabaseAdapter>(
199        &self,
200        req: &AuthRequest,
201        ctx: &AuthContext<DB>,
202    ) -> AuthResult<AuthResponse> {
203        let (user, _) = self.require_session(req, ctx).await?;
204
205        let revoke_req: RevokeSessionRequest = req
206            .body_as_json()
207            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
208
209        // Verify the session belongs to the current user before revoking
210        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
211        if let Some(session_to_revoke) = session_manager.get_session(&revoke_req.token).await?
212            && session_to_revoke.user_id() != user.id()
213        {
214            return Err(AuthError::forbidden(
215                "Cannot revoke session that belongs to another user",
216            ));
217        }
218
219        ctx.database.delete_session(&revoke_req.token).await?;
220
221        let response = StatusResponse { status: true };
222        Ok(AuthResponse::json(200, &response)?)
223    }
224
225    async fn handle_revoke_sessions<DB: DatabaseAdapter>(
226        &self,
227        req: &AuthRequest,
228        ctx: &AuthContext<DB>,
229    ) -> AuthResult<AuthResponse> {
230        let (user, _) = self.require_session(req, ctx).await?;
231        ctx.database.delete_user_sessions(user.id()).await?;
232
233        let response = StatusResponse { status: true };
234        Ok(AuthResponse::json(200, &response)?)
235    }
236
237    async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
238        &self,
239        req: &AuthRequest,
240        ctx: &AuthContext<DB>,
241    ) -> AuthResult<AuthResponse> {
242        let (user, current_session) = self.require_session(req, ctx).await?;
243
244        let all_sessions = self.get_user_sessions(user.id(), ctx).await?;
245        for session in all_sessions {
246            if session.token() != current_session.token() {
247                ctx.database.delete_session(session.token()).await?;
248            }
249        }
250
251        let response = StatusResponse { status: true };
252        Ok(AuthResponse::json(200, &response)?)
253    }
254
255    /// Extract and validate the current user + session from the request.
256    /// Returns `Err(AuthError::Unauthenticated)` if no valid session.
257    async fn require_session<DB: DatabaseAdapter>(
258        &self,
259        req: &AuthRequest,
260        ctx: &AuthContext<DB>,
261    ) -> AuthResult<(DB::User, DB::Session)> {
262        self.get_current_user_and_session(req, ctx)
263            .await?
264            .ok_or(AuthError::Unauthenticated)
265    }
266
267    async fn get_current_user_and_session<DB: DatabaseAdapter>(
268        &self,
269        req: &AuthRequest,
270        ctx: &AuthContext<DB>,
271    ) -> AuthResult<Option<(DB::User, DB::Session)>> {
272        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
273
274        if let Some(token) = session_manager.extract_session_token(req)
275            && let Some(session) = session_manager.get_session(&token).await?
276            && let Some(user) = ctx.database.get_user_by_id(session.user_id()).await?
277        {
278            return Ok(Some((user, session)));
279        }
280
281        Ok(None)
282    }
283
284    async fn get_user_sessions<DB: DatabaseAdapter>(
285        &self,
286        user_id: &str,
287        ctx: &AuthContext<DB>,
288    ) -> AuthResult<Vec<DB::Session>> {
289        ctx.database.get_user_sessions(user_id).await
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use better_auth_core::adapters::{DatabaseAdapter, MemoryDatabaseAdapter};
    use better_auth_core::config::AuthConfig;
    use better_auth_core::{CreateSession, CreateUser, Session, User};
    use chrono::{Duration, Utc};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a context backed by an in-memory adapter, seeded with one user
    /// (`test@example.com`) and one 24h session for that user.
    async fn create_test_context_with_user() -> (AuthContext<MemoryDatabaseAdapter>, User, Session)
    {
        let config = Arc::new(AuthConfig::new("test-secret-key-at-least-32-chars-long"));
        let database = Arc::new(MemoryDatabaseAdapter::new());
        let ctx = AuthContext::new(config.clone(), database.clone());

        let create_user = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test User");
        let user = database.create_user(create_user).await.unwrap();

        let create_session = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("127.0.0.1".to_string()),
            user_agent: Some("test-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session = database.create_session(create_session).await.unwrap();

        (ctx, user, session)
    }

    /// Adds another 24h session for `user_id`, with an IP and user agent that
    /// differ from the seed session.
    async fn create_extra_session(
        ctx: &AuthContext<MemoryDatabaseAdapter>,
        user_id: &str,
    ) -> Session {
        let create_session = CreateSession {
            user_id: user_id.to_string(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session).await.unwrap()
    }

    /// Builds a request, sending `token` (if any) as a bearer `authorization` header.
    fn create_auth_request(
        method: HttpMethod,
        path: &str,
        token: Option<&str>,
        body: Option<Vec<u8>>,
    ) -> AuthRequest {
        let mut headers = HashMap::new();
        if let Some(token) = token {
            headers.insert("authorization".to_string(), format!("Bearer {}", token));
        }

        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body,
            query: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
        assert_eq!(
            response_data["session"]["token"].as_str().unwrap(),
            session.token
        );
        assert_eq!(
            response_data["user"]["email"]
                .as_str()
                .map(|s| s.to_string()),
            Some("test@example.com".to_string())
        );
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", None, None);
        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/sign-out",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: SignOutResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.success);

        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    #[tokio::test]
    async fn test_sign_out_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Post, "/sign-out", None, Some(b"{}".to_vec()));
        let err = plugin.handle_sign_out(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);

        // The existing session must be untouched by a rejected sign-out.
        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_some());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;
        create_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;
        let session2 = create_extra_session(&ctx, &user.id).await;

        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) = create_test_context_with_user().await;

        let create_user2 = CreateUser::new()
            .with_email("user2@example.com")
            .with_name("User Two");
        let user2 = ctx.database.create_user(create_user2).await.unwrap();
        let session2 = create_extra_session(&ctx, &user2.id).await;

        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes()),
        );

        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 403);

        // The other user's session must survive the rejected attempt.
        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) = create_test_context_with_user().await;
        create_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-sessions",
            Some(&session1.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_revoke_other_sessions_keeps_current() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;
        let other = create_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin
            .handle_revoke_other_sessions(&req, &ctx)
            .await
            .unwrap();
        assert_eq!(response.status, 200);

        let other_check = ctx.database.get_session(&other.token).await.unwrap();
        assert!(other_check.is_none());

        let current_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(current_check.is_some());
    }

    #[tokio::test]
    async fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);

        assert_eq!(routes.len(), 7);
        let has = |path: &str, method: HttpMethod| {
            routes.iter().any(|r| r.path == path && r.method == method)
        };
        assert!(has("/get-session", HttpMethod::Get));
        assert!(has("/get-session", HttpMethod::Post));
        assert!(has("/sign-out", HttpMethod::Post));
        assert!(has("/list-sessions", HttpMethod::Get));
        assert!(has("/revoke-session", HttpMethod::Post));
        assert!(has("/revoke-sessions", HttpMethod::Post));
        assert!(has("/revoke-other-sessions", HttpMethod::Post));
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        // Test valid route
        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        // Test invalid route
        let req = create_auth_request(
            HttpMethod::Get,
            "/invalid-route",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        // Every revocation endpoint must fall through when revocation is off.
        for path in ["/revoke-session", "/revoke-sessions", "/revoke-other-sessions"] {
            let req = create_auth_request(
                HttpMethod::Post,
                path,
                Some(&session.token),
                Some(b"{}".to_vec()),
            );
            let response = plugin.on_request(&req, &ctx).await.unwrap();
            assert!(response.is_none(), "{path} should be disabled");
        }

        // Nothing was revoked through the disabled endpoints.
        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_some());
    }
}