1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
8
9use better_auth_core::utils::cookie_utils::create_clear_session_cookie;
10use better_auth_core::{AuthError, AuthResult};
11use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
12
13use super::StatusResponse;
14use better_auth_core::SuccessResponse;
15
16pub struct SessionManagementPlugin {
18 config: SessionManagementConfig,
19}
20
21#[derive(Debug, Clone, better_auth_core::PluginConfig)]
22#[plugin(name = "SessionManagementPlugin")]
23pub struct SessionManagementConfig {
24 #[config(default = true)]
25 pub enable_session_listing: bool,
26 #[config(default = true)]
27 pub enable_session_revocation: bool,
28 #[config(default = true)]
29 pub require_authentication: bool,
30}
31
32#[derive(Debug, Deserialize, Validate)]
34struct RevokeSessionRequest {
35 #[validate(length(min = 1, message = "Token is required"))]
36 token: String,
37}
38
39#[derive(Debug, Serialize)]
40struct GetSessionResponse<S: Serialize, U: Serialize> {
41 session: S,
42 user: U,
43}
44
45#[async_trait]
46impl<DB: DatabaseAdapter> AuthPlugin<DB> for SessionManagementPlugin {
47 fn name(&self) -> &'static str {
48 "session-management"
49 }
50
51 fn routes(&self) -> Vec<AuthRoute> {
52 vec![
53 AuthRoute::get("/get-session", "get_session"),
54 AuthRoute::post("/get-session", "get_session_post"),
55 AuthRoute::post("/sign-out", "sign_out"),
56 AuthRoute::get("/list-sessions", "list_sessions"),
57 AuthRoute::post("/revoke-session", "revoke_session"),
58 AuthRoute::post("/revoke-sessions", "revoke_sessions"),
59 AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
60 ]
61 }
62
63 async fn on_request(
64 &self,
65 req: &AuthRequest,
66 ctx: &AuthContext<DB>,
67 ) -> AuthResult<Option<AuthResponse>> {
68 match (req.method(), req.path()) {
69 (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
70 Ok(Some(self.handle_get_session(req, ctx).await?))
71 }
72 (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
73 (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
74 Ok(Some(self.handle_list_sessions(req, ctx).await?))
75 }
76 (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
77 Ok(Some(self.handle_revoke_session(req, ctx).await?))
78 }
79 (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
80 Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
81 }
82 (HttpMethod::Post, "/revoke-other-sessions")
83 if self.config.enable_session_revocation =>
84 {
85 Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
86 }
87 _ => Ok(None),
88 }
89 }
90}
91
92pub(crate) async fn sign_out_core<DB: DatabaseAdapter>(
97 session: &DB::Session,
98 ctx: &AuthContext<DB>,
99) -> AuthResult<SuccessResponse> {
100 ctx.database.delete_session(session.token()).await?;
101 Ok(SuccessResponse { success: true })
102}
103
104pub(crate) async fn list_sessions_core<DB: DatabaseAdapter>(
105 user_id: &str,
106 ctx: &AuthContext<DB>,
107) -> AuthResult<Vec<DB::Session>> {
108 ctx.database.get_user_sessions(user_id).await
109}
110
111pub(crate) async fn revoke_session_core<DB: DatabaseAdapter>(
112 user: &DB::User,
113 token: &str,
114 ctx: &AuthContext<DB>,
115) -> AuthResult<StatusResponse> {
116 let session_manager = ctx.session_manager();
118 if let Some(session_to_revoke) = session_manager.get_session(token).await?
119 && session_to_revoke.user_id() != user.id()
120 {
121 return Err(AuthError::forbidden(
122 "Cannot revoke session that belongs to another user",
123 ));
124 }
125
126 ctx.database.delete_session(token).await?;
127 Ok(StatusResponse { status: true })
128}
129
130pub(crate) async fn revoke_sessions_core<DB: DatabaseAdapter>(
131 user_id: &str,
132 ctx: &AuthContext<DB>,
133) -> AuthResult<StatusResponse> {
134 ctx.database.delete_user_sessions(user_id).await?;
135 Ok(StatusResponse { status: true })
136}
137
138pub(crate) async fn revoke_other_sessions_core<DB: DatabaseAdapter>(
139 user_id: &str,
140 current_session: &DB::Session,
141 ctx: &AuthContext<DB>,
142) -> AuthResult<StatusResponse> {
143 let all_sessions: Vec<DB::Session> = ctx.database.get_user_sessions(user_id).await?;
144 for session in all_sessions {
145 if session.token() != current_session.token() {
146 ctx.database.delete_session(session.token()).await?;
147 }
148 }
149 Ok(StatusResponse { status: true })
150}
151
152impl SessionManagementPlugin {
157 async fn handle_get_session<DB: DatabaseAdapter>(
158 &self,
159 req: &AuthRequest,
160 ctx: &AuthContext<DB>,
161 ) -> AuthResult<AuthResponse> {
162 let (user, session) = ctx.require_session(req).await?;
163 let response = GetSessionResponse { session, user };
164 Ok(AuthResponse::json(200, &response)?)
165 }
166
167 async fn handle_sign_out<DB: DatabaseAdapter>(
168 &self,
169 req: &AuthRequest,
170 ctx: &AuthContext<DB>,
171 ) -> AuthResult<AuthResponse> {
172 let (_user, session) = ctx.require_session(req).await?;
173 let response = sign_out_core(&session, ctx).await?;
174 let clear_cookie = create_clear_session_cookie(&ctx.config);
175 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie))
176 }
177
178 async fn handle_list_sessions<DB: DatabaseAdapter>(
179 &self,
180 req: &AuthRequest,
181 ctx: &AuthContext<DB>,
182 ) -> AuthResult<AuthResponse> {
183 let (user, _) = ctx.require_session(req).await?;
184 let sessions = list_sessions_core(user.id(), ctx).await?;
185 Ok(AuthResponse::json(200, &sessions)?)
186 }
187
188 async fn handle_revoke_session<DB: DatabaseAdapter>(
189 &self,
190 req: &AuthRequest,
191 ctx: &AuthContext<DB>,
192 ) -> AuthResult<AuthResponse> {
193 let (user, _) = ctx.require_session(req).await?;
194
195 let revoke_req: RevokeSessionRequest = match better_auth_core::validate_request_body(req) {
196 Ok(v) => v,
197 Err(resp) => return Ok(resp),
198 };
199
200 let response = revoke_session_core(&user, &revoke_req.token, ctx).await?;
201 Ok(AuthResponse::json(200, &response)?)
202 }
203
204 async fn handle_revoke_sessions<DB: DatabaseAdapter>(
205 &self,
206 req: &AuthRequest,
207 ctx: &AuthContext<DB>,
208 ) -> AuthResult<AuthResponse> {
209 let (user, _) = ctx.require_session(req).await?;
210 let response = revoke_sessions_core(user.id(), ctx).await?;
211 Ok(AuthResponse::json(200, &response)?)
212 }
213
214 async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
215 &self,
216 req: &AuthRequest,
217 ctx: &AuthContext<DB>,
218 ) -> AuthResult<AuthResponse> {
219 let (user, current_session) = ctx.require_session(req).await?;
220 let response = revoke_other_sessions_core(user.id(), ¤t_session, ctx).await?;
221 Ok(AuthResponse::json(200, &response)?)
222 }
223}
224
#[cfg(feature = "axum")]
mod axum_impl {
    //! Axum adapter for the session-management plugin: the same endpoints as
    //! `on_request`, expressed as axum handlers over `AuthState`.

    use super::*;

    use axum::Json;
    use axum::extract::State;
    use axum::http::header;
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};

    /// `GET`/`POST /get-session`: returns the authenticated session and its user.
    async fn handle_get_session<DB: DatabaseAdapter>(
        CurrentSession { user, session }: CurrentSession<DB>,
    ) -> Result<Json<GetSessionResponse<DB::Session, DB::User>>, AuthError> {
        Ok(Json(GetSessionResponse { session, user }))
    }

    /// `POST /sign-out`: deletes the current session and clears the session cookie.
    async fn handle_sign_out<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
    ) -> Result<([(header::HeaderName, String); 1], Json<SuccessResponse>), AuthError> {
        let ctx = state.to_context();
        let response = sign_out_core(&session, &ctx).await?;
        let cookie = state.clear_session_cookie();
        Ok(([(header::SET_COOKIE, cookie)], Json(response)))
    }

    /// `GET /list-sessions`: returns the caller's sessions. Only routed when
    /// `enable_session_listing` is set.
    async fn handle_list_sessions<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<Vec<DB::Session>>, AuthError> {
        let ctx = state.to_context();
        let sessions = list_sessions_core(user.id(), &ctx).await?;
        Ok(Json(sessions))
    }

    /// `POST /revoke-session`: revokes the session named in the body. Only routed when
    /// `enable_session_revocation` is set.
    async fn handle_revoke_session<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<RevokeSessionRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = revoke_session_core(&user, &body.token, &ctx).await?;
        Ok(Json(response))
    }

    /// `POST /revoke-sessions`: deletes all of the caller's sessions. Only routed when
    /// `enable_session_revocation` is set.
    async fn handle_revoke_sessions<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = revoke_sessions_core(user.id(), &ctx).await?;
        Ok(Json(response))
    }

    /// `POST /revoke-other-sessions`: deletes every session of the caller except the
    /// current one. Only routed when `enable_session_revocation` is set.
    async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { user, session }: CurrentSession<DB>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = revoke_other_sessions_core(user.id(), &session, &ctx).await?;
        Ok(Json(response))
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for SessionManagementPlugin {
        fn name(&self) -> &'static str {
            "session-management"
        }

        /// Builds the axum router for this plugin.
        ///
        /// The listing and revocation endpoints are only registered when enabled in the
        /// config. Gating here, rather than inside the handlers, rejects a disabled
        /// endpoint before the `CurrentSession`/`ValidatedJson` extractors run. Callers
        /// therefore get the router's unmatched-route response instead of an auth or
        /// body error, which matches `on_request` returning `None` without authenticating.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let mut router = axum::Router::new()
                .route(
                    "/get-session",
                    get(handle_get_session::<DB>).post(handle_get_session::<DB>),
                )
                .route("/sign-out", post(handle_sign_out::<DB>));

            if self.config.enable_session_listing {
                router = router.route("/list-sessions", get(handle_list_sessions::<DB>));
            }

            if self.config.enable_session_revocation {
                router = router
                    .route("/revoke-session", post(handle_revoke_session::<DB>))
                    .route("/revoke-sessions", post(handle_revoke_sessions::<DB>))
                    .route(
                        "/revoke-other-sessions",
                        post(handle_revoke_other_sessions::<DB>),
                    );
            }

            router
        }
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::test_helpers;
    use better_auth_core::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps};
    use better_auth_core::{CreateSession, CreateUser, Session};
    use chrono::{Duration, Utc};

    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
        assert_eq!(
            response_data["session"]["token"].as_str().unwrap(),
            session.token
        );
        assert_eq!(
            response_data["user"]["email"]
                .as_str()
                .map(|s| s.to_string()),
            Some("test@example.com".to_string())
        );
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req =
            test_helpers::create_auth_request_no_query(HttpMethod::Get, "/get-session", None, None);
        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/sign-out",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: SuccessResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.success);

        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_user2 = CreateUser::new()
            .with_email("user2@example.com")
            .with_name("User Two");
        let user2 = ctx.database.create_user(create_user2).await.unwrap();

        let create_session2 = CreateSession {
            user_id: user2.id,
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes()),
        );

        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 403);

        // The rejected attempt must leave the other user's session untouched.
        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-sessions",
            Some(&session1.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_revoke_other_sessions_keeps_current() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin
            .handle_revoke_other_sessions(&req, &ctx)
            .await
            .unwrap();
        assert_eq!(response.status, 200);

        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let current_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(current_check.is_some());

        let remaining = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(remaining.len(), 1);
    }

    #[tokio::test]
    async fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);

        assert_eq!(routes.len(), 7);
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/get-session" && r.method == HttpMethod::Get)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/get-session" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/sign-out" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/list-sessions" && r.method == HttpMethod::Get)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-session" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-sessions" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-other-sessions" && r.method == HttpMethod::Post)
        );
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        // GET /get-session is handled by this plugin.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        // POST /get-session is routed to the same handler.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert_eq!(response.map(|r| r.status), Some(200));

        // Unknown paths are left for other plugins.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/invalid-route",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        // The bulk revocation endpoints share the same toggle.
        for path in ["/revoke-sessions", "/revoke-other-sessions"] {
            let req = test_helpers::create_auth_request_no_query(
                HttpMethod::Post,
                path,
                Some(&session.token),
                Some(b"{}".to_vec()),
            );
            let response = plugin.on_request(&req, &ctx).await.unwrap();
            assert!(response.is_none(), "{path} should be disabled");
        }
    }
}
679}