Skip to main content

better_auth_api/plugins/
session_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
8
9use better_auth_core::utils::cookie_utils::create_clear_session_cookie;
10use better_auth_core::{AuthError, AuthResult};
11use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
12
13use super::StatusResponse;
14use better_auth_core::SuccessResponse;
15
/// Session management plugin for handling session operations.
///
/// Serves the `/get-session`, `/sign-out`, `/list-sessions`, `/revoke-session`,
/// `/revoke-sessions` and `/revoke-other-sessions` endpoints. Which of the
/// listing/revocation endpoints are active is decided by the stored
/// [`SessionManagementConfig`].
pub struct SessionManagementPlugin {
    // Feature switches consulted by the request dispatcher and the axum handlers.
    config: SessionManagementConfig,
}
20
/// Feature switches for [`SessionManagementPlugin`].
///
/// NOTE(review): the `PluginConfig` derive presumably generates the plugin's
/// `new()` constructor and the builder-style setters used by the tests
/// (e.g. `enable_session_listing(false)`) — confirm against the macro.
#[derive(Debug, Clone, better_auth_core::PluginConfig)]
#[plugin(name = "SessionManagementPlugin")]
pub struct SessionManagementConfig {
    /// Gates `GET /list-sessions`. When `false`, `on_request` leaves the
    /// route unhandled and the axum handler answers with a "Not found" error.
    #[config(default = true)]
    pub enable_session_listing: bool,
    /// Gates `POST /revoke-session`, `/revoke-sessions` and
    /// `/revoke-other-sessions` in the same way as the listing switch.
    #[config(default = true)]
    pub enable_session_revocation: bool,
    /// NOTE(review): this flag is not read by any handler visible in this
    /// file; every handler requires a session regardless — confirm whether it
    /// is consumed elsewhere.
    #[config(default = true)]
    pub require_authentication: bool,
}
31
// Request structures for session endpoints

/// JSON body accepted by `POST /revoke-session`.
#[derive(Debug, Deserialize, Validate)]
struct RevokeSessionRequest {
    /// Token of the session to revoke; validation rejects an empty string
    /// with the message "Token is required".
    #[validate(length(min = 1, message = "Token is required"))]
    token: String,
}
38
/// Payload returned by `/get-session`: the caller's session together with the
/// user it belongs to. Generic so it can carry whatever session/user types the
/// database adapter provides, as long as they are serializable.
#[derive(Debug, Serialize)]
struct GetSessionResponse<S: Serialize, U: Serialize> {
    /// The session resolved for the current request.
    session: S,
    /// The user that owns `session`.
    user: U,
}
44
45#[async_trait]
46impl<DB: DatabaseAdapter> AuthPlugin<DB> for SessionManagementPlugin {
47    fn name(&self) -> &'static str {
48        "session-management"
49    }
50
51    fn routes(&self) -> Vec<AuthRoute> {
52        vec![
53            AuthRoute::get("/get-session", "get_session"),
54            AuthRoute::post("/get-session", "get_session_post"),
55            AuthRoute::post("/sign-out", "sign_out"),
56            AuthRoute::get("/list-sessions", "list_sessions"),
57            AuthRoute::post("/revoke-session", "revoke_session"),
58            AuthRoute::post("/revoke-sessions", "revoke_sessions"),
59            AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
60        ]
61    }
62
63    async fn on_request(
64        &self,
65        req: &AuthRequest,
66        ctx: &AuthContext<DB>,
67    ) -> AuthResult<Option<AuthResponse>> {
68        match (req.method(), req.path()) {
69            (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
70                Ok(Some(self.handle_get_session(req, ctx).await?))
71            }
72            (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
73            (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
74                Ok(Some(self.handle_list_sessions(req, ctx).await?))
75            }
76            (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
77                Ok(Some(self.handle_revoke_session(req, ctx).await?))
78            }
79            (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
80                Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
81            }
82            (HttpMethod::Post, "/revoke-other-sessions")
83                if self.config.enable_session_revocation =>
84            {
85                Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
86            }
87            _ => Ok(None),
88        }
89    }
90}
91
92// ---------------------------------------------------------------------------
93// Core functions — framework-agnostic business logic
94// ---------------------------------------------------------------------------
95
96pub(crate) async fn sign_out_core<DB: DatabaseAdapter>(
97    session: &DB::Session,
98    ctx: &AuthContext<DB>,
99) -> AuthResult<SuccessResponse> {
100    ctx.database.delete_session(session.token()).await?;
101    Ok(SuccessResponse { success: true })
102}
103
104pub(crate) async fn list_sessions_core<DB: DatabaseAdapter>(
105    user_id: &str,
106    ctx: &AuthContext<DB>,
107) -> AuthResult<Vec<DB::Session>> {
108    ctx.database.get_user_sessions(user_id).await
109}
110
111pub(crate) async fn revoke_session_core<DB: DatabaseAdapter>(
112    user: &DB::User,
113    token: &str,
114    ctx: &AuthContext<DB>,
115) -> AuthResult<StatusResponse> {
116    // Verify the session belongs to the current user before revoking
117    let session_manager = ctx.session_manager();
118    if let Some(session_to_revoke) = session_manager.get_session(token).await?
119        && session_to_revoke.user_id() != user.id()
120    {
121        return Err(AuthError::forbidden(
122            "Cannot revoke session that belongs to another user",
123        ));
124    }
125
126    ctx.database.delete_session(token).await?;
127    Ok(StatusResponse { status: true })
128}
129
130pub(crate) async fn revoke_sessions_core<DB: DatabaseAdapter>(
131    user_id: &str,
132    ctx: &AuthContext<DB>,
133) -> AuthResult<StatusResponse> {
134    ctx.database.delete_user_sessions(user_id).await?;
135    Ok(StatusResponse { status: true })
136}
137
138pub(crate) async fn revoke_other_sessions_core<DB: DatabaseAdapter>(
139    user_id: &str,
140    current_session: &DB::Session,
141    ctx: &AuthContext<DB>,
142) -> AuthResult<StatusResponse> {
143    let all_sessions: Vec<DB::Session> = ctx.database.get_user_sessions(user_id).await?;
144    for session in all_sessions {
145        if session.token() != current_session.token() {
146            ctx.database.delete_session(session.token()).await?;
147        }
148    }
149    Ok(StatusResponse { status: true })
150}
151
152// ---------------------------------------------------------------------------
153// Old handler methods — delegate to core functions
154// ---------------------------------------------------------------------------
155
156impl SessionManagementPlugin {
157    async fn handle_get_session<DB: DatabaseAdapter>(
158        &self,
159        req: &AuthRequest,
160        ctx: &AuthContext<DB>,
161    ) -> AuthResult<AuthResponse> {
162        let (user, session) = ctx.require_session(req).await?;
163        let response = GetSessionResponse { session, user };
164        Ok(AuthResponse::json(200, &response)?)
165    }
166
167    async fn handle_sign_out<DB: DatabaseAdapter>(
168        &self,
169        req: &AuthRequest,
170        ctx: &AuthContext<DB>,
171    ) -> AuthResult<AuthResponse> {
172        let (_user, session) = ctx.require_session(req).await?;
173        let response = sign_out_core(&session, ctx).await?;
174        let clear_cookie = create_clear_session_cookie(&ctx.config);
175        Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie))
176    }
177
178    async fn handle_list_sessions<DB: DatabaseAdapter>(
179        &self,
180        req: &AuthRequest,
181        ctx: &AuthContext<DB>,
182    ) -> AuthResult<AuthResponse> {
183        let (user, _) = ctx.require_session(req).await?;
184        let sessions = list_sessions_core(user.id(), ctx).await?;
185        Ok(AuthResponse::json(200, &sessions)?)
186    }
187
188    async fn handle_revoke_session<DB: DatabaseAdapter>(
189        &self,
190        req: &AuthRequest,
191        ctx: &AuthContext<DB>,
192    ) -> AuthResult<AuthResponse> {
193        let (user, _) = ctx.require_session(req).await?;
194
195        let revoke_req: RevokeSessionRequest = match better_auth_core::validate_request_body(req) {
196            Ok(v) => v,
197            Err(resp) => return Ok(resp),
198        };
199
200        let response = revoke_session_core(&user, &revoke_req.token, ctx).await?;
201        Ok(AuthResponse::json(200, &response)?)
202    }
203
204    async fn handle_revoke_sessions<DB: DatabaseAdapter>(
205        &self,
206        req: &AuthRequest,
207        ctx: &AuthContext<DB>,
208    ) -> AuthResult<AuthResponse> {
209        let (user, _) = ctx.require_session(req).await?;
210        let response = revoke_sessions_core(user.id(), ctx).await?;
211        Ok(AuthResponse::json(200, &response)?)
212    }
213
214    async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
215        &self,
216        req: &AuthRequest,
217        ctx: &AuthContext<DB>,
218    ) -> AuthResult<AuthResponse> {
219        let (user, current_session) = ctx.require_session(req).await?;
220        let response = revoke_other_sessions_core(user.id(), &current_session, ctx).await?;
221        Ok(AuthResponse::json(200, &response)?)
222    }
223}
224
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, State};
    use axum::http::header;
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};

    /// Copy of the plugin configuration handed to the handlers through an
    /// `Extension` layer on the router.
    #[derive(Clone)]
    struct PluginState {
        config: SessionManagementConfig,
    }

    /// Turns a disabled feature switch into the "Not found" error that the
    /// gated handlers answer with.
    fn ensure_enabled(enabled: bool) -> Result<(), AuthError> {
        if enabled {
            Ok(())
        } else {
            Err(AuthError::not_found("Not found"))
        }
    }

    /// `/get-session`: the extractor already resolved user and session, so the
    /// handler only wraps them into the response payload.
    async fn handle_get_session<DB: DatabaseAdapter>(
        CurrentSession { user, session }: CurrentSession<DB>,
    ) -> Result<Json<GetSessionResponse<DB::Session, DB::User>>, AuthError> {
        let payload = GetSessionResponse { session, user };
        Ok(Json(payload))
    }

    /// `/sign-out`: deletes the current session and returns the JSON result
    /// together with a `Set-Cookie` header taken from
    /// `state.clear_session_cookie()`.
    async fn handle_sign_out<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
    ) -> Result<([(header::HeaderName, String); 1], Json<SuccessResponse>), AuthError> {
        let ctx = state.to_context();
        let payload = sign_out_core(&session, &ctx).await?;
        let headers = [(header::SET_COOKIE, state.clear_session_cookie())];
        Ok((headers, Json(payload)))
    }

    /// `/list-sessions`: lists the caller's sessions, unless listing is
    /// disabled in the plugin configuration.
    async fn handle_list_sessions<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<Vec<DB::Session>>, AuthError> {
        ensure_enabled(plugin.config.enable_session_listing)?;

        let ctx = state.to_context();
        list_sessions_core(user.id(), &ctx).await.map(Json)
    }

    /// `/revoke-session`: revokes the session named in the validated body,
    /// unless revocation is disabled.
    async fn handle_revoke_session<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<RevokeSessionRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        ensure_enabled(plugin.config.enable_session_revocation)?;

        let ctx = state.to_context();
        revoke_session_core(&user, &body.token, &ctx)
            .await
            .map(Json)
    }

    /// `/revoke-sessions`: revokes all of the caller's sessions, unless
    /// revocation is disabled.
    async fn handle_revoke_sessions<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        ensure_enabled(plugin.config.enable_session_revocation)?;

        let ctx = state.to_context();
        revoke_sessions_core(user.id(), &ctx).await.map(Json)
    }

    /// `/revoke-other-sessions`: revokes all of the caller's sessions except
    /// the current one, unless revocation is disabled.
    async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<PluginState>>,
        CurrentSession { user, session }: CurrentSession<DB>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        ensure_enabled(plugin.config.enable_session_revocation)?;

        let ctx = state.to_context();
        revoke_other_sessions_core(user.id(), &session, &ctx)
            .await
            .map(Json)
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for SessionManagementPlugin {
        /// Same identifier as the framework-agnostic plugin implementation.
        fn name(&self) -> &'static str {
            "session-management"
        }

        /// Builds the axum router for all session endpoints and attaches the
        /// plugin configuration as an `Extension` so gated handlers can read it.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let shared = Arc::new(PluginState {
                config: self.config.clone(),
            });

            let get_session_route = get(handle_get_session::<DB>).post(handle_get_session::<DB>);
            let revoke_others_route = post(handle_revoke_other_sessions::<DB>);

            axum::Router::new()
                .route("/get-session", get_session_route)
                .route("/sign-out", post(handle_sign_out::<DB>))
                .route("/list-sessions", get(handle_list_sessions::<DB>))
                .route("/revoke-session", post(handle_revoke_session::<DB>))
                .route("/revoke-sessions", post(handle_revoke_sessions::<DB>))
                .route("/revoke-other-sessions", revoke_others_route)
                .layer(Extension(shared))
        }
    }
}
338
#[cfg(test)]
mod tests {
    //! Handler-level tests for the session management plugin, run against the
    //! in-memory database adapter via the shared `test_helpers`.

    use super::*;
    use crate::plugins::test_helpers;
    use better_auth_core::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps};
    use better_auth_core::{CreateSession, CreateUser, Session};
    use chrono::{Duration, Utc};

    /// An authenticated GET `/get-session` returns the caller's session and user.
    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
        assert_eq!(
            response_data["session"]["token"].as_str().unwrap(),
            session.token
        );
        assert_eq!(
            response_data["user"]["email"]
                .as_str()
                .map(|s| s.to_string()),
            Some("test@example.com".to_string())
        );
    }

    /// Without a session token the handler fails with a 401 error.
    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req =
            test_helpers::create_auth_request_no_query(HttpMethod::Get, "/get-session", None, None);
        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);
    }

    /// Signing out reports success and removes the session from the database.
    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/sign-out",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: SuccessResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.success);

        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    /// Listing returns both the initial session and a second one created here.
    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    /// Revoking one of the caller's own sessions deletes only that session.
    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    /// Attempting to revoke a session owned by another user yields a 403 error.
    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_user2 = CreateUser::new()
            .with_email("user2@example.com")
            .with_name("User Two");
        let user2 = ctx.database.create_user(create_user2).await.unwrap();

        let create_session2 = CreateSession {
            user_id: user2.id,
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes()),
        );

        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 403);
    }

    /// Revoking all sessions leaves the user with no sessions at all.
    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-sessions",
            Some(&session1.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    /// Revoking the other sessions deletes every session except the one that
    /// authenticated the request.
    #[tokio::test]
    async fn test_revoke_other_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin
            .handle_revoke_other_sessions(&req, &ctx)
            .await
            .unwrap();

        assert_eq!(response.status, 200);

        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        // The session used for the request itself must survive.
        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    /// The advertised route table contains exactly the seven expected
    /// method/path pairs.
    #[tokio::test]
    async fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);

        assert_eq!(routes.len(), 7);
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/get-session" && r.method == HttpMethod::Get)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/get-session" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/sign-out" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/list-sessions" && r.method == HttpMethod::Get)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-session" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-sessions" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-other-sessions" && r.method == HttpMethod::Post)
        );
    }

    /// `on_request` handles a known route and ignores an unknown one.
    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        // Test valid route
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        // Test invalid route
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/invalid-route",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    /// With listing and revocation switched off, `on_request` leaves all of
    /// the gated routes unhandled.
    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        // All three revocation routes share the same switch.
        for path in [
            "/revoke-session",
            "/revoke-sessions",
            "/revoke-other-sessions",
        ] {
            let req = test_helpers::create_auth_request_no_query(
                HttpMethod::Post,
                path,
                Some(&session.token),
                Some(b"{}".to_vec()),
            );
            let response = plugin.on_request(&req, &ctx).await.unwrap();
            assert!(response.is_none());
        }
    }
}