better_auth/plugins/
session_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3
4use crate::core::{AuthPlugin, AuthRoute, AuthContext, SessionManager};
5use crate::types::{AuthRequest, AuthResponse, HttpMethod, User, Session};
6use crate::error::{AuthError, AuthResult};
7
8/// Session management plugin for handling session operations
9pub struct SessionManagementPlugin {
10    config: SessionManagementConfig,
11}
12
/// Feature switches for the session-management plugin.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionManagementConfig {
    /// When `false`, `GET /list-sessions` is not handled by the plugin.
    pub enable_session_listing: bool,
    /// When `false`, `POST /revoke-session` and `POST /revoke-sessions`
    /// are not handled by the plugin.
    pub enable_session_revocation: bool,
    /// Whether endpoints should require an authenticated caller.
    ///
    /// NOTE(review): this flag can be set, but no handler in this file reads
    /// it; every handler currently demands a valid session. Confirm intent.
    pub require_authentication: bool,
}
19
20// Request structures for session endpoints
21#[derive(Debug, Deserialize)]
22struct RevokeSessionRequest {
23    token: String,
24}
25
26// Response structures
27#[derive(Debug, Serialize, Deserialize)]
28struct GetSessionResponse {
29    session: Session,
30    user: User,
31}
32
33#[derive(Debug, Serialize, Deserialize)]
34struct SignOutResponse {
35    success: bool,
36}
37
38#[derive(Debug, Serialize, Deserialize)]
39struct StatusResponse {
40    status: bool,
41}
42
43impl SessionManagementPlugin {
44    pub fn new() -> Self {
45        Self {
46            config: SessionManagementConfig::default(),
47        }
48    }
49    
50    pub fn with_config(config: SessionManagementConfig) -> Self {
51        Self { config }
52    }
53    
54    pub fn enable_session_listing(mut self, enable: bool) -> Self {
55        self.config.enable_session_listing = enable;
56        self
57    }
58    
59    pub fn enable_session_revocation(mut self, enable: bool) -> Self {
60        self.config.enable_session_revocation = enable;
61        self
62    }
63    
64    pub fn require_authentication(mut self, require: bool) -> Self {
65        self.config.require_authentication = require;
66        self
67    }
68}
69
70impl Default for SessionManagementPlugin {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl Default for SessionManagementConfig {
77    fn default() -> Self {
78        Self {
79            enable_session_listing: true,
80            enable_session_revocation: true,
81            require_authentication: true,
82        }
83    }
84}
85
86#[async_trait]
87impl AuthPlugin for SessionManagementPlugin {
88    fn name(&self) -> &'static str {
89        "session-management"
90    }
91    
92    fn routes(&self) -> Vec<AuthRoute> {
93        vec![
94            AuthRoute::get("/get-session", "get_session"),
95            AuthRoute::post("/sign-out", "sign_out"),
96            AuthRoute::get("/list-sessions", "list_sessions"),
97            AuthRoute::post("/revoke-session", "revoke_session"),
98            AuthRoute::post("/revoke-sessions", "revoke_sessions"),
99        ]
100    }
101    
102    async fn on_request(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<Option<AuthResponse>> {
103        match (req.method(), req.path()) {
104            (HttpMethod::Get, "/get-session") => {
105                Ok(Some(self.handle_get_session(req, ctx).await?))
106            },
107            (HttpMethod::Post, "/sign-out") => {
108                Ok(Some(self.handle_sign_out(req, ctx).await?))
109            },
110            (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
111                Ok(Some(self.handle_list_sessions(req, ctx).await?))
112            },
113            (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
114                Ok(Some(self.handle_revoke_session(req, ctx).await?))
115            },
116            (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
117                Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
118            },
119            _ => Ok(None),
120        }
121    }
122}
123
124// Implementation methods outside the trait
125impl SessionManagementPlugin {
126    async fn handle_get_session(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
127        // Get current user and session
128        let (user, session) = match self.get_current_user_and_session(req, ctx).await? {
129            Some((user, session)) => (user, session),
130            None => {
131                return Ok(AuthResponse::json(401, &serde_json::json!({
132                    "error": "Unauthorized",
133                    "message": "No valid session found"
134                }))?);
135            }
136        };
137        
138        let response = GetSessionResponse { session, user };
139        Ok(AuthResponse::json(200, &response)?)
140    }
141    
142    async fn handle_sign_out(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
143        // Get current session
144        let (_user, current_session) = match self.get_current_user_and_session(req, ctx).await? {
145            Some((user, session)) => (user, session),
146            None => {
147                return Ok(AuthResponse::json(401, &serde_json::json!({
148                    "error": "Unauthorized",
149                    "message": "No valid session found"
150                }))?);
151            }
152        };
153        
154        // Delete the current session
155        ctx.database.delete_session(&current_session.token).await?;
156        
157        let response = SignOutResponse { success: true };
158        // Clear session cookie
159        let clear_cookie_header = self.create_clear_session_cookie(ctx);
160        
161        Ok(AuthResponse::json(200, &response)?
162            .with_header("Set-Cookie", clear_cookie_header))
163    }
164    
165    fn create_clear_session_cookie(&self, ctx: &AuthContext) -> String {
166        let session_config = &ctx.config.session;
167        let secure = if session_config.cookie_secure { "; Secure" } else { "" };
168        let http_only = if session_config.cookie_http_only { "; HttpOnly" } else { "" };
169        let same_site = match session_config.cookie_same_site {
170            crate::core::config::SameSite::Strict => "; SameSite=Strict",
171            crate::core::config::SameSite::Lax => "; SameSite=Lax", 
172            crate::core::config::SameSite::None => "; SameSite=None",
173        };
174        
175        format!("{}=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT{}{}{}",
176                session_config.cookie_name,
177                secure,
178                http_only,
179                same_site)
180    }
181    
182    async fn handle_list_sessions(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
183        // Get current user
184        let (user, _current_session) = match self.get_current_user_and_session(req, ctx).await? {
185            Some((user, session)) => (user, session),
186            None => {
187                return Ok(AuthResponse::json(401, &serde_json::json!({
188                    "error": "Unauthorized",
189                    "message": "No valid session found"
190                }))?);
191            }
192        };
193        
194        // Get all user sessions from database
195        let sessions = self.get_user_sessions(&user.id, ctx).await?;
196        
197        // Return sessions as an array directly (matching OpenAPI spec)
198        Ok(AuthResponse::json(200, &sessions)?)
199    }
200    
201    async fn handle_revoke_session(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
202        // Get current user to ensure they're authenticated
203        let (user, _current_session) = match self.get_current_user_and_session(req, ctx).await? {
204            Some((user, session)) => (user, session),
205            None => {
206                return Ok(AuthResponse::json(401, &serde_json::json!({
207                    "error": "Unauthorized",
208                    "message": "No valid session found"
209                }))?);
210            }
211        };
212        
213        let revoke_req: RevokeSessionRequest = match req.body_as_json() {
214            Ok(req) => req,
215            Err(e) => {
216                return Ok(AuthResponse::json(400, &serde_json::json!({
217                    "error": "Invalid request",
218                    "message": format!("Invalid JSON: {}", e)
219                }))?);
220            }
221        };
222        
223        // Get the session token to revoke
224        let session_token = &revoke_req.token;
225        
226        // Verify the session belongs to the current user before revoking
227        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
228        if let Some(session_to_revoke) = session_manager.get_session(session_token).await? {
229            if session_to_revoke.user_id != user.id {
230                return Ok(AuthResponse::json(403, &serde_json::json!({
231                    "error": "Forbidden",
232                    "message": "Cannot revoke session that belongs to another user"
233                }))?);
234            }
235        }
236        
237        // Revoke the session
238        ctx.database.delete_session(session_token).await?;
239        
240        let response = StatusResponse { status: true };
241        Ok(AuthResponse::json(200, &response)?)
242    }
243    
244    async fn handle_revoke_sessions(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
245        // Get current user to ensure they're authenticated
246        let (user, _current_session) = match self.get_current_user_and_session(req, ctx).await? {
247            Some((user, session)) => (user, session),
248            None => {
249                return Ok(AuthResponse::json(401, &serde_json::json!({
250                    "error": "Unauthorized",
251                    "message": "No valid session found"
252                }))?);
253            }
254        };
255        
256        // Revoke all sessions for the user
257        ctx.database.delete_user_sessions(&user.id).await?;
258        
259        let response = StatusResponse { status: true };
260        Ok(AuthResponse::json(200, &response)?)
261    }
262    
263    async fn handle_revoke_other_sessions(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
264        // Get current user and session
265        let (user, current_session) = match self.get_current_user_and_session(req, ctx).await? {
266            Some((user, session)) => (user, session),
267            None => {
268                return Ok(AuthResponse::json(401, &serde_json::json!({
269                    "error": "Unauthorized",
270                    "message": "No valid session found"
271                }))?);
272            }
273        };
274        
275        // Get all sessions for the user
276        let all_sessions = self.get_user_sessions(&user.id, ctx).await?;
277        
278        // Revoke all sessions except the current one
279        for session in all_sessions {
280            if session.token != current_session.token {
281                ctx.database.delete_session(&session.token).await?;
282            }
283        }
284        
285        let response = StatusResponse { status: true };
286        Ok(AuthResponse::json(200, &response)?)
287    }
288    
289    async fn get_current_user_and_session(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<Option<(User, Session)>> {
290        // Extract session token from Authorization header
291        let token = if let Some(auth_header) = req.headers.get("authorization") {
292            if auth_header.starts_with("Bearer ") {
293                Some(&auth_header[7..])
294            } else {
295                None
296            }
297        } else {
298            None
299        };
300        
301        if let Some(token) = token {
302            let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
303            if let Some(session) = session_manager.get_session(token).await? {
304                if let Some(user) = ctx.database.get_user_by_id(&session.user_id).await? {
305                    return Ok(Some((user, session)));
306                }
307            }
308        }
309        
310        Ok(None)
311    }
312    
313    async fn get_user_sessions(&self, user_id: &str, ctx: &AuthContext) -> AuthResult<Vec<Session>> {
314        ctx.database.get_user_sessions(user_id).await
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::AuthConfig;
    use crate::adapters::{MemoryDatabaseAdapter, DatabaseAdapter};
    use crate::types::{CreateUser, CreateSession};
    use chrono::{Utc, Duration};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds an in-memory context holding one user with one 24-hour session.
    async fn create_test_context_with_user() -> (AuthContext, User, Session) {
        let config = Arc::new(AuthConfig::new("test-secret-key-at-least-32-chars-long"));
        let database = Arc::new(MemoryDatabaseAdapter::new());
        let ctx = AuthContext::new(config.clone(), database.clone());

        // Create test user
        let create_user = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test User");
        let user = database.create_user(create_user).await.unwrap();

        // Create test session
        let create_session = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("127.0.0.1".to_string()),
            user_agent: Some("test-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session = database.create_session(create_session).await.unwrap();

        (ctx, user, session)
    }

    /// Stores one more 24-hour session for `user`, using a different IP and
    /// user agent than the fixture session.
    async fn create_extra_session(ctx: &AuthContext, user: &User) -> Session {
        let create_session = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session).await.unwrap()
    }

    /// Builds a request; when `token` is given it is sent as a bearer token
    /// in the `authorization` header.
    fn create_auth_request(method: HttpMethod, path: &str, token: Option<&str>, body: Option<Vec<u8>>) -> AuthRequest {
        let mut headers = HashMap::new();
        if let Some(token) = token {
            headers.insert("authorization".to_string(), format!("Bearer {}", token));
        }

        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body,
            query: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: GetSessionResponse = serde_json::from_str(&body_str).unwrap();
        assert_eq!(response_data.session.token, session.token);
        assert_eq!(response_data.user.email, Some("test@example.com".to_string()));
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", None, None);
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Post, "/sign-out", Some(&session.token), Some(b"{}".to_vec()));
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: SignOutResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.success);

        // Verify session was deleted
        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;

        // Create additional session for the same user
        create_extra_session(&ctx, &user).await;

        let req = create_auth_request(HttpMethod::Get, "/list-sessions", Some(&session.token), None);
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;

        // Create another session to revoke
        let session2 = create_extra_session(&ctx, &user).await;

        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes())
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.status);

        // Verify session2 was deleted but session1 still exists
        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_invalid_body() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        // An authenticated caller sending an unparseable body gets a 400.
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(b"not-json".to_vec())
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 400);

        // The caller's own session must be untouched.
        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) = create_test_context_with_user().await;

        // Create another user and session
        let create_user2 = CreateUser::new()
            .with_email("user2@example.com")
            .with_name("User Two");
        let user2 = ctx.database.create_user(create_user2).await.unwrap();
        let session2 = create_extra_session(&ctx, &user2).await;

        // Try to revoke user2's session using user1's session
        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes())
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 403);
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) = create_test_context_with_user().await;

        // Create additional sessions for the same user
        create_extra_session(&ctx, &user).await;

        let req = create_auth_request(HttpMethod::Post, "/revoke-sessions", Some(&session1.token), Some(b"{}".to_vec()));
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.status);

        // Verify all sessions for the user were deleted
        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_revoke_other_sessions_keeps_current() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;

        let other_session = create_extra_session(&ctx, &user).await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&session.token),
            Some(b"{}".to_vec())
        );
        let response = plugin.handle_revoke_other_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.status);

        // The other session is gone, the one used for the request survives
        let other_check = ctx.database.get_session(&other_session.token).await.unwrap();
        assert!(other_check.is_none());

        let current_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(current_check.is_some());
    }

    // Nothing here awaits, so a plain synchronous test is enough.
    #[test]
    fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = plugin.routes();

        assert_eq!(routes.len(), 5);
        assert!(routes.iter().any(|r| r.path == "/get-session" && r.method == HttpMethod::Get));
        assert!(routes.iter().any(|r| r.path == "/sign-out" && r.method == HttpMethod::Post));
        assert!(routes.iter().any(|r| r.path == "/list-sessions" && r.method == HttpMethod::Get));
        assert!(routes.iter().any(|r| r.path == "/revoke-session" && r.method == HttpMethod::Post));
        assert!(routes.iter().any(|r| r.path == "/revoke-sessions" && r.method == HttpMethod::Post));
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        // Test valid route
        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        // Test invalid route
        let req = create_auth_request(HttpMethod::Get, "/invalid-route", Some(&session.token), None);
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = create_test_context_with_user().await;

        // Test that disabled features return None
        let req = create_auth_request(HttpMethod::Get, "/list-sessions", Some(&session.token), None);
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        let req = create_auth_request(HttpMethod::Post, "/revoke-session", Some(&session.token), Some(b"{}".to_vec()));
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }
}