Skip to main content

better_auth_api/plugins/
session_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
8
9use better_auth_core::utils::cookie_utils::create_clear_session_cookie;
10use better_auth_core::{AuthError, AuthResult};
11use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
12
13use super::StatusResponse;
14
15/// Session management plugin for handling session operations
16pub struct SessionManagementPlugin {
17    config: SessionManagementConfig,
18}
19
/// Feature switches for [`SessionManagementPlugin`].
///
/// The [`Default`] configuration enables every switch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionManagementConfig {
    /// Serve `GET /list-sessions`.
    pub enable_session_listing: bool,
    /// Serve `POST /revoke-session`, `POST /revoke-sessions` and
    /// `POST /revoke-other-sessions`.
    pub enable_session_revocation: bool,
    /// Intended to control whether the endpoints require an authenticated caller.
    ///
    /// NOTE(review): the plugin's request handlers never read this flag; every
    /// endpoint always requires a valid session regardless of its value.
    pub require_authentication: bool,
}
26
27// Request structures for session endpoints
28#[derive(Debug, Deserialize, Validate)]
29struct RevokeSessionRequest {
30    #[validate(length(min = 1, message = "Token is required"))]
31    token: String,
32}
33
34// Response structures
35#[derive(Debug, Serialize, Deserialize)]
36struct SignOutResponse {
37    success: bool,
38}
39
40#[derive(Debug, Serialize)]
41struct GetSessionResponse<S: Serialize, U: Serialize> {
42    session: S,
43    user: U,
44}
45
46impl SessionManagementPlugin {
47    pub fn new() -> Self {
48        Self {
49            config: SessionManagementConfig::default(),
50        }
51    }
52
53    pub fn with_config(config: SessionManagementConfig) -> Self {
54        Self { config }
55    }
56
57    pub fn enable_session_listing(mut self, enable: bool) -> Self {
58        self.config.enable_session_listing = enable;
59        self
60    }
61
62    pub fn enable_session_revocation(mut self, enable: bool) -> Self {
63        self.config.enable_session_revocation = enable;
64        self
65    }
66
67    pub fn require_authentication(mut self, require: bool) -> Self {
68        self.config.require_authentication = require;
69        self
70    }
71}
72
73impl Default for SessionManagementPlugin {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl Default for SessionManagementConfig {
80    fn default() -> Self {
81        Self {
82            enable_session_listing: true,
83            enable_session_revocation: true,
84            require_authentication: true,
85        }
86    }
87}
88
89#[async_trait]
90impl<DB: DatabaseAdapter> AuthPlugin<DB> for SessionManagementPlugin {
91    fn name(&self) -> &'static str {
92        "session-management"
93    }
94
95    fn routes(&self) -> Vec<AuthRoute> {
96        vec![
97            AuthRoute::get("/get-session", "get_session"),
98            AuthRoute::post("/get-session", "get_session_post"),
99            AuthRoute::post("/sign-out", "sign_out"),
100            AuthRoute::get("/list-sessions", "list_sessions"),
101            AuthRoute::post("/revoke-session", "revoke_session"),
102            AuthRoute::post("/revoke-sessions", "revoke_sessions"),
103            AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
104        ]
105    }
106
107    async fn on_request(
108        &self,
109        req: &AuthRequest,
110        ctx: &AuthContext<DB>,
111    ) -> AuthResult<Option<AuthResponse>> {
112        match (req.method(), req.path()) {
113            (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
114                Ok(Some(self.handle_get_session(req, ctx).await?))
115            }
116            (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
117            (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
118                Ok(Some(self.handle_list_sessions(req, ctx).await?))
119            }
120            (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
121                Ok(Some(self.handle_revoke_session(req, ctx).await?))
122            }
123            (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
124                Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
125            }
126            (HttpMethod::Post, "/revoke-other-sessions")
127                if self.config.enable_session_revocation =>
128            {
129                Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
130            }
131            _ => Ok(None),
132        }
133    }
134}
135
136// Implementation methods outside the trait
137impl SessionManagementPlugin {
138    async fn handle_get_session<DB: DatabaseAdapter>(
139        &self,
140        req: &AuthRequest,
141        ctx: &AuthContext<DB>,
142    ) -> AuthResult<AuthResponse> {
143        let (user, session) = ctx.require_session(req).await?;
144        let response = GetSessionResponse { session, user };
145        Ok(AuthResponse::json(200, &response)?)
146    }
147
148    async fn handle_sign_out<DB: DatabaseAdapter>(
149        &self,
150        req: &AuthRequest,
151        ctx: &AuthContext<DB>,
152    ) -> AuthResult<AuthResponse> {
153        let (_user, current_session) = ctx.require_session(req).await?;
154
155        ctx.database.delete_session(current_session.token()).await?;
156
157        let response = SignOutResponse { success: true };
158        let clear_cookie_header = create_clear_session_cookie(ctx);
159
160        Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie_header))
161    }
162
163    async fn handle_list_sessions<DB: DatabaseAdapter>(
164        &self,
165        req: &AuthRequest,
166        ctx: &AuthContext<DB>,
167    ) -> AuthResult<AuthResponse> {
168        let (user, _) = ctx.require_session(req).await?;
169        let sessions = self.get_user_sessions(user.id(), ctx).await?;
170        Ok(AuthResponse::json(200, &sessions)?)
171    }
172
173    async fn handle_revoke_session<DB: DatabaseAdapter>(
174        &self,
175        req: &AuthRequest,
176        ctx: &AuthContext<DB>,
177    ) -> AuthResult<AuthResponse> {
178        let (user, _) = ctx.require_session(req).await?;
179
180        let revoke_req: RevokeSessionRequest = req
181            .body_as_json()
182            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
183
184        // Verify the session belongs to the current user before revoking
185        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
186        if let Some(session_to_revoke) = session_manager.get_session(&revoke_req.token).await?
187            && session_to_revoke.user_id() != user.id()
188        {
189            return Err(AuthError::forbidden(
190                "Cannot revoke session that belongs to another user",
191            ));
192        }
193
194        ctx.database.delete_session(&revoke_req.token).await?;
195
196        let response = StatusResponse { status: true };
197        Ok(AuthResponse::json(200, &response)?)
198    }
199
200    async fn handle_revoke_sessions<DB: DatabaseAdapter>(
201        &self,
202        req: &AuthRequest,
203        ctx: &AuthContext<DB>,
204    ) -> AuthResult<AuthResponse> {
205        let (user, _) = ctx.require_session(req).await?;
206        ctx.database.delete_user_sessions(user.id()).await?;
207
208        let response = StatusResponse { status: true };
209        Ok(AuthResponse::json(200, &response)?)
210    }
211
212    async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
213        &self,
214        req: &AuthRequest,
215        ctx: &AuthContext<DB>,
216    ) -> AuthResult<AuthResponse> {
217        let (user, current_session) = ctx.require_session(req).await?;
218
219        let all_sessions = self.get_user_sessions(user.id(), ctx).await?;
220        for session in all_sessions {
221            if session.token() != current_session.token() {
222                ctx.database.delete_session(session.token()).await?;
223            }
224        }
225
226        let response = StatusResponse { status: true };
227        Ok(AuthResponse::json(200, &response)?)
228    }
229
230    async fn get_user_sessions<DB: DatabaseAdapter>(
231        &self,
232        user_id: &str,
233        ctx: &AuthContext<DB>,
234    ) -> AuthResult<Vec<DB::Session>> {
235        ctx.database.get_user_sessions(user_id).await
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::test_helpers;
    use better_auth_core::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps};
    use better_auth_core::{CreateSession, CreateUser, Session};
    use chrono::{Duration, Utc};

    /// The user every test context is created with.
    fn test_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test User")
    }

    /// Input for an additional 24-hour session owned by `user_id`.
    fn extra_session(user_id: String) -> CreateSession {
        CreateSession {
            user_id,
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        }
    }

    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(body["session"]["token"].as_str().unwrap(), session.token);
        assert_eq!(body["user"]["email"].as_str(), Some("test@example.com"));
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;

        let req =
            test_helpers::create_auth_request_no_query(HttpMethod::Get, "/get-session", None, None);
        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/sign-out",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let body: SignOutResponse = serde_json::from_slice(&response.body).unwrap();
        assert!(body.success);

        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;
        ctx.database
            .create_session(extra_session(user.id.clone()))
            .await
            .unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let sessions: Vec<Session> = serde_json::from_slice(&response.body).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;
        let session2 = ctx
            .database
            .create_session(extra_session(user.id.clone()))
            .await
            .unwrap();

        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes()),
        );
        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        assert!(ctx.database.get_session(&session2.token).await.unwrap().is_none());
        assert!(ctx.database.get_session(&session.token).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;

        let user2 = ctx
            .database
            .create_user(
                CreateUser::new()
                    .with_email("user2@example.com")
                    .with_name("User Two"),
            )
            .await
            .unwrap();
        let session2 = ctx
            .database
            .create_session(extra_session(user2.id))
            .await
            .unwrap();

        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes()),
        );
        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 403);

        // The other user's session must survive the rejected request.
        assert!(ctx.database.get_session(&session2.token).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;
        ctx.database
            .create_session(extra_session(user.id.clone()))
            .await
            .unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-sessions",
            Some(&session1.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_revoke_other_sessions_keeps_current() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;
        let other = ctx
            .database
            .create_session(extra_session(user.id.clone()))
            .await
            .unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin
            .handle_revoke_other_sessions(&req, &ctx)
            .await
            .unwrap();
        assert_eq!(response.status, 200);

        assert!(ctx.database.get_session(&other.token).await.unwrap().is_none());
        assert!(ctx.database.get_session(&session.token).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);

        let expected = [
            (HttpMethod::Get, "/get-session"),
            (HttpMethod::Post, "/get-session"),
            (HttpMethod::Post, "/sign-out"),
            (HttpMethod::Get, "/list-sessions"),
            (HttpMethod::Post, "/revoke-session"),
            (HttpMethod::Post, "/revoke-sessions"),
            (HttpMethod::Post, "/revoke-other-sessions"),
        ];
        assert_eq!(routes.len(), expected.len());
        for (method, path) in expected {
            assert!(
                routes.iter().any(|r| r.path == path && r.method == method),
                "missing route for {path}"
            );
        }
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;

        // A route owned by the plugin is handled.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        // An unknown route is declined.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/invalid-route",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) =
            test_helpers::create_test_context_with_user(test_user(), Duration::hours(24)).await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }
}