Skip to main content

better_auth_api/plugins/
session_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
6use better_auth_core::{AuthError, AuthResult};
7use better_auth_core::{AuthRequest, AuthResponse, HttpMethod, Session, User};
8
/// Session management plugin for handling session operations.
///
/// Serves `/get-session`, `/sign-out`, `/list-sessions`, `/revoke-session`,
/// `/revoke-sessions` and `/revoke-other-sessions`. Whether the listing and revocation
/// endpoints are served is controlled by [`SessionManagementConfig`].
#[derive(Debug, Clone)]
pub struct SessionManagementPlugin {
    config: SessionManagementConfig,
}

/// Feature switches for [`SessionManagementPlugin`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionManagementConfig {
    /// Serve `GET /list-sessions`.
    pub enable_session_listing: bool,
    /// Serve `POST /revoke-session`, `POST /revoke-sessions` and
    /// `POST /revoke-other-sessions`.
    pub enable_session_revocation: bool,
    /// NOTE(review): not read by any request handler in this plugin — every endpoint
    /// requires a valid session regardless of this value. Confirm the intended semantics
    /// before relying on it.
    pub require_authentication: bool,
}
20
21// Request structures for session endpoints
22#[derive(Debug, Deserialize, Validate)]
23struct RevokeSessionRequest {
24    #[validate(length(min = 1, message = "Token is required"))]
25    token: String,
26}
27
28// Response structures
29#[derive(Debug, Serialize, Deserialize)]
30struct GetSessionResponse {
31    session: Session,
32    user: User,
33}
34
35#[derive(Debug, Serialize, Deserialize)]
36struct SignOutResponse {
37    success: bool,
38}
39
40#[derive(Debug, Serialize, Deserialize)]
41struct StatusResponse {
42    status: bool,
43}
44
45impl SessionManagementPlugin {
46    pub fn new() -> Self {
47        Self {
48            config: SessionManagementConfig::default(),
49        }
50    }
51
52    pub fn with_config(config: SessionManagementConfig) -> Self {
53        Self { config }
54    }
55
56    pub fn enable_session_listing(mut self, enable: bool) -> Self {
57        self.config.enable_session_listing = enable;
58        self
59    }
60
61    pub fn enable_session_revocation(mut self, enable: bool) -> Self {
62        self.config.enable_session_revocation = enable;
63        self
64    }
65
66    pub fn require_authentication(mut self, require: bool) -> Self {
67        self.config.require_authentication = require;
68        self
69    }
70}
71
72impl Default for SessionManagementPlugin {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl Default for SessionManagementConfig {
79    fn default() -> Self {
80        Self {
81            enable_session_listing: true,
82            enable_session_revocation: true,
83            require_authentication: true,
84        }
85    }
86}
87
88#[async_trait]
89impl AuthPlugin for SessionManagementPlugin {
90    fn name(&self) -> &'static str {
91        "session-management"
92    }
93
94    fn routes(&self) -> Vec<AuthRoute> {
95        vec![
96            AuthRoute::get("/get-session", "get_session"),
97            AuthRoute::post("/get-session", "get_session_post"),
98            AuthRoute::post("/sign-out", "sign_out"),
99            AuthRoute::get("/list-sessions", "list_sessions"),
100            AuthRoute::post("/revoke-session", "revoke_session"),
101            AuthRoute::post("/revoke-sessions", "revoke_sessions"),
102            AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
103        ]
104    }
105
106    async fn on_request(
107        &self,
108        req: &AuthRequest,
109        ctx: &AuthContext,
110    ) -> AuthResult<Option<AuthResponse>> {
111        match (req.method(), req.path()) {
112            (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
113                Ok(Some(self.handle_get_session(req, ctx).await?))
114            }
115            (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
116            (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
117                Ok(Some(self.handle_list_sessions(req, ctx).await?))
118            }
119            (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
120                Ok(Some(self.handle_revoke_session(req, ctx).await?))
121            }
122            (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
123                Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
124            }
125            (HttpMethod::Post, "/revoke-other-sessions")
126                if self.config.enable_session_revocation =>
127            {
128                Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
129            }
130            _ => Ok(None),
131        }
132    }
133}
134
135// Implementation methods outside the trait
136impl SessionManagementPlugin {
137    async fn handle_get_session(
138        &self,
139        req: &AuthRequest,
140        ctx: &AuthContext,
141    ) -> AuthResult<AuthResponse> {
142        let (user, session) = self.require_session(req, ctx).await?;
143        let response = GetSessionResponse { session, user };
144        Ok(AuthResponse::json(200, &response)?)
145    }
146
147    async fn handle_sign_out(
148        &self,
149        req: &AuthRequest,
150        ctx: &AuthContext,
151    ) -> AuthResult<AuthResponse> {
152        let (_user, current_session) = self.require_session(req, ctx).await?;
153
154        ctx.database.delete_session(&current_session.token).await?;
155
156        let response = SignOutResponse { success: true };
157        let clear_cookie_header = self.create_clear_session_cookie(ctx);
158
159        Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie_header))
160    }
161
162    fn create_clear_session_cookie(&self, ctx: &AuthContext) -> String {
163        let session_config = &ctx.config.session;
164        let secure = if session_config.cookie_secure {
165            "; Secure"
166        } else {
167            ""
168        };
169        let http_only = if session_config.cookie_http_only {
170            "; HttpOnly"
171        } else {
172            ""
173        };
174        let same_site = match session_config.cookie_same_site {
175            better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
176            better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
177            better_auth_core::config::SameSite::None => "; SameSite=None",
178        };
179
180        format!(
181            "{}=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT{}{}{}",
182            session_config.cookie_name, secure, http_only, same_site
183        )
184    }
185
186    async fn handle_list_sessions(
187        &self,
188        req: &AuthRequest,
189        ctx: &AuthContext,
190    ) -> AuthResult<AuthResponse> {
191        let (user, _) = self.require_session(req, ctx).await?;
192        let sessions = self.get_user_sessions(&user.id, ctx).await?;
193        Ok(AuthResponse::json(200, &sessions)?)
194    }
195
196    async fn handle_revoke_session(
197        &self,
198        req: &AuthRequest,
199        ctx: &AuthContext,
200    ) -> AuthResult<AuthResponse> {
201        let (user, _) = self.require_session(req, ctx).await?;
202
203        let revoke_req: RevokeSessionRequest = req
204            .body_as_json()
205            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
206
207        // Verify the session belongs to the current user before revoking
208        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
209        if let Some(session_to_revoke) = session_manager.get_session(&revoke_req.token).await?
210            && session_to_revoke.user_id != user.id
211        {
212            return Err(AuthError::forbidden(
213                "Cannot revoke session that belongs to another user",
214            ));
215        }
216
217        ctx.database.delete_session(&revoke_req.token).await?;
218
219        let response = StatusResponse { status: true };
220        Ok(AuthResponse::json(200, &response)?)
221    }
222
223    async fn handle_revoke_sessions(
224        &self,
225        req: &AuthRequest,
226        ctx: &AuthContext,
227    ) -> AuthResult<AuthResponse> {
228        let (user, _) = self.require_session(req, ctx).await?;
229        ctx.database.delete_user_sessions(&user.id).await?;
230
231        let response = StatusResponse { status: true };
232        Ok(AuthResponse::json(200, &response)?)
233    }
234
235    async fn handle_revoke_other_sessions(
236        &self,
237        req: &AuthRequest,
238        ctx: &AuthContext,
239    ) -> AuthResult<AuthResponse> {
240        let (user, current_session) = self.require_session(req, ctx).await?;
241
242        let all_sessions = self.get_user_sessions(&user.id, ctx).await?;
243        for session in all_sessions {
244            if session.token != current_session.token {
245                ctx.database.delete_session(&session.token).await?;
246            }
247        }
248
249        let response = StatusResponse { status: true };
250        Ok(AuthResponse::json(200, &response)?)
251    }
252
253    /// Extract and validate the current user + session from the request.
254    /// Returns `Err(AuthError::Unauthenticated)` if no valid session.
255    async fn require_session(
256        &self,
257        req: &AuthRequest,
258        ctx: &AuthContext,
259    ) -> AuthResult<(User, Session)> {
260        self.get_current_user_and_session(req, ctx)
261            .await?
262            .ok_or(AuthError::Unauthenticated)
263    }
264
265    async fn get_current_user_and_session(
266        &self,
267        req: &AuthRequest,
268        ctx: &AuthContext,
269    ) -> AuthResult<Option<(User, Session)>> {
270        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
271
272        if let Some(token) = session_manager.extract_session_token(req)
273            && let Some(session) = session_manager.get_session(&token).await?
274            && let Some(user) = ctx.database.get_user_by_id(&session.user_id).await?
275        {
276            return Ok(Some((user, session)));
277        }
278
279        Ok(None)
280    }
281
282    async fn get_user_sessions(
283        &self,
284        user_id: &str,
285        ctx: &AuthContext,
286    ) -> AuthResult<Vec<Session>> {
287        ctx.database.get_user_sessions(user_id).await
288    }
289}
290
291#[cfg(test)]
292mod tests {
293    use super::*;
294    use better_auth_core::adapters::{DatabaseAdapter, MemoryDatabaseAdapter};
295    use better_auth_core::config::AuthConfig;
296    use better_auth_core::{CreateSession, CreateUser};
297    use chrono::{Duration, Utc};
298    use std::collections::HashMap;
299    use std::sync::Arc;
300
301    async fn create_test_context_with_user() -> (AuthContext, User, Session) {
302        let config = Arc::new(AuthConfig::new("test-secret-key-at-least-32-chars-long"));
303        let database = Arc::new(MemoryDatabaseAdapter::new());
304        let ctx = AuthContext::new(config.clone(), database.clone());
305
306        let create_user = CreateUser::new()
307            .with_email("test@example.com")
308            .with_name("Test User");
309        let user = database.create_user(create_user).await.unwrap();
310
311        let create_session = CreateSession {
312            user_id: user.id.clone(),
313            expires_at: Utc::now() + Duration::hours(24),
314            ip_address: Some("127.0.0.1".to_string()),
315            user_agent: Some("test-agent".to_string()),
316            impersonated_by: None,
317            active_organization_id: None,
318        };
319        let session = database.create_session(create_session).await.unwrap();
320
321        (ctx, user, session)
322    }
323
324    fn create_auth_request(
325        method: HttpMethod,
326        path: &str,
327        token: Option<&str>,
328        body: Option<Vec<u8>>,
329    ) -> AuthRequest {
330        let mut headers = HashMap::new();
331        if let Some(token) = token {
332            headers.insert("authorization".to_string(), format!("Bearer {}", token));
333        }
334
335        AuthRequest {
336            method,
337            path: path.to_string(),
338            headers,
339            body,
340            query: HashMap::new(),
341        }
342    }
343
344    #[tokio::test]
345    async fn test_get_session_success() {
346        let plugin = SessionManagementPlugin::new();
347        let (ctx, _user, session) = create_test_context_with_user().await;
348
349        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
350        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();
351
352        assert_eq!(response.status, 200);
353
354        let body_str = String::from_utf8(response.body).unwrap();
355        let response_data: GetSessionResponse = serde_json::from_str(&body_str).unwrap();
356        assert_eq!(response_data.session.token, session.token);
357        assert_eq!(
358            response_data.user.email,
359            Some("test@example.com".to_string())
360        );
361    }
362
363    #[tokio::test]
364    async fn test_get_session_unauthorized() {
365        let plugin = SessionManagementPlugin::new();
366        let (ctx, _user, _session) = create_test_context_with_user().await;
367
368        let req = create_auth_request(HttpMethod::Get, "/get-session", None, None);
369        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
370        assert_eq!(err.status_code(), 401);
371    }
372
373    #[tokio::test]
374    async fn test_sign_out_success() {
375        let plugin = SessionManagementPlugin::new();
376        let (ctx, _user, session) = create_test_context_with_user().await;
377
378        let req = create_auth_request(
379            HttpMethod::Post,
380            "/sign-out",
381            Some(&session.token),
382            Some(b"{}".to_vec()),
383        );
384        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();
385
386        assert_eq!(response.status, 200);
387
388        let body_str = String::from_utf8(response.body).unwrap();
389        let response_data: SignOutResponse = serde_json::from_str(&body_str).unwrap();
390        assert!(response_data.success);
391
392        let session_check = ctx.database.get_session(&session.token).await.unwrap();
393        assert!(session_check.is_none());
394    }
395
396    #[tokio::test]
397    async fn test_list_sessions_success() {
398        let plugin = SessionManagementPlugin::new();
399        let (ctx, user, session) = create_test_context_with_user().await;
400
401        let create_session2 = CreateSession {
402            user_id: user.id.clone(),
403            expires_at: Utc::now() + Duration::hours(24),
404            ip_address: Some("192.168.1.1".to_string()),
405            user_agent: Some("another-agent".to_string()),
406            impersonated_by: None,
407            active_organization_id: None,
408        };
409        ctx.database.create_session(create_session2).await.unwrap();
410
411        let req = create_auth_request(
412            HttpMethod::Get,
413            "/list-sessions",
414            Some(&session.token),
415            None,
416        );
417        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();
418
419        assert_eq!(response.status, 200);
420
421        let body_str = String::from_utf8(response.body).unwrap();
422        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
423        assert_eq!(sessions.len(), 2);
424    }
425
426    #[tokio::test]
427    async fn test_revoke_session_success() {
428        let plugin = SessionManagementPlugin::new();
429        let (ctx, user, session) = create_test_context_with_user().await;
430
431        let create_session2 = CreateSession {
432            user_id: user.id.clone(),
433            expires_at: Utc::now() + Duration::hours(24),
434            ip_address: Some("192.168.1.1".to_string()),
435            user_agent: Some("another-agent".to_string()),
436            impersonated_by: None,
437            active_organization_id: None,
438        };
439        let session2 = ctx.database.create_session(create_session2).await.unwrap();
440
441        let body = serde_json::json!({ "token": session2.token });
442        let req = create_auth_request(
443            HttpMethod::Post,
444            "/revoke-session",
445            Some(&session.token),
446            Some(body.to_string().into_bytes()),
447        );
448
449        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
450        assert_eq!(response.status, 200);
451
452        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
453        assert!(session2_check.is_none());
454
455        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
456        assert!(session1_check.is_some());
457    }
458
459    #[tokio::test]
460    async fn test_revoke_session_forbidden_different_user() {
461        let plugin = SessionManagementPlugin::new();
462        let (ctx, _user1, session1) = create_test_context_with_user().await;
463
464        let create_user2 = CreateUser::new()
465            .with_email("user2@example.com")
466            .with_name("User Two");
467        let user2 = ctx.database.create_user(create_user2).await.unwrap();
468
469        let create_session2 = CreateSession {
470            user_id: user2.id,
471            expires_at: Utc::now() + Duration::hours(24),
472            ip_address: Some("192.168.1.1".to_string()),
473            user_agent: Some("another-agent".to_string()),
474            impersonated_by: None,
475            active_organization_id: None,
476        };
477        let session2 = ctx.database.create_session(create_session2).await.unwrap();
478
479        let body = serde_json::json!({ "token": session2.token });
480        let req = create_auth_request(
481            HttpMethod::Post,
482            "/revoke-session",
483            Some(&session1.token),
484            Some(body.to_string().into_bytes()),
485        );
486
487        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
488        assert_eq!(err.status_code(), 403);
489    }
490
491    #[tokio::test]
492    async fn test_revoke_sessions_success() {
493        let plugin = SessionManagementPlugin::new();
494        let (ctx, user, session1) = create_test_context_with_user().await;
495
496        let create_session2 = CreateSession {
497            user_id: user.id.clone(),
498            expires_at: Utc::now() + Duration::hours(24),
499            ip_address: Some("192.168.1.1".to_string()),
500            user_agent: Some("another-agent".to_string()),
501            impersonated_by: None,
502            active_organization_id: None,
503        };
504        ctx.database.create_session(create_session2).await.unwrap();
505
506        let req = create_auth_request(
507            HttpMethod::Post,
508            "/revoke-sessions",
509            Some(&session1.token),
510            Some(b"{}".to_vec()),
511        );
512        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();
513
514        assert_eq!(response.status, 200);
515
516        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
517        assert_eq!(user_sessions.len(), 0);
518    }
519
520    #[tokio::test]
521    async fn test_plugin_routes() {
522        let plugin = SessionManagementPlugin::new();
523        let routes = plugin.routes();
524
525        assert_eq!(routes.len(), 7);
526        assert!(
527            routes
528                .iter()
529                .any(|r| r.path == "/get-session" && r.method == HttpMethod::Get)
530        );
531        assert!(
532            routes
533                .iter()
534                .any(|r| r.path == "/sign-out" && r.method == HttpMethod::Post)
535        );
536        assert!(
537            routes
538                .iter()
539                .any(|r| r.path == "/list-sessions" && r.method == HttpMethod::Get)
540        );
541        assert!(
542            routes
543                .iter()
544                .any(|r| r.path == "/revoke-session" && r.method == HttpMethod::Post)
545        );
546        assert!(
547            routes
548                .iter()
549                .any(|r| r.path == "/revoke-sessions" && r.method == HttpMethod::Post)
550        );
551    }
552
553    #[tokio::test]
554    async fn test_plugin_on_request_routing() {
555        let plugin = SessionManagementPlugin::new();
556        let (ctx, _user, session) = create_test_context_with_user().await;
557
558        // Test valid route
559        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
560        let response = plugin.on_request(&req, &ctx).await.unwrap();
561        assert!(response.is_some());
562        assert_eq!(response.unwrap().status, 200);
563
564        // Test invalid route
565        let req = create_auth_request(
566            HttpMethod::Get,
567            "/invalid-route",
568            Some(&session.token),
569            None,
570        );
571        let response = plugin.on_request(&req, &ctx).await.unwrap();
572        assert!(response.is_none());
573    }
574
575    #[tokio::test]
576    async fn test_configuration() {
577        let plugin = SessionManagementPlugin::new()
578            .enable_session_listing(false)
579            .enable_session_revocation(false)
580            .require_authentication(false);
581
582        assert!(!plugin.config.enable_session_listing);
583        assert!(!plugin.config.enable_session_revocation);
584        assert!(!plugin.config.require_authentication);
585
586        let (ctx, _user, session) = create_test_context_with_user().await;
587
588        let req = create_auth_request(
589            HttpMethod::Get,
590            "/list-sessions",
591            Some(&session.token),
592            None,
593        );
594        let response = plugin.on_request(&req, &ctx).await.unwrap();
595        assert!(response.is_none());
596
597        let req = create_auth_request(
598            HttpMethod::Post,
599            "/revoke-session",
600            Some(&session.token),
601            Some(b"{}".to_vec()),
602        );
603        let response = plugin.on_request(&req, &ctx).await.unwrap();
604        assert!(response.is_none());
605    }
606}