1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
8
9use better_auth_core::utils::cookie_utils::create_clear_session_cookie;
10use better_auth_core::{AuthError, AuthResult};
11use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
12
13use super::StatusResponse;
14
15pub struct SessionManagementPlugin {
17 config: SessionManagementConfig,
18}
19
/// Feature switches for [`SessionManagementPlugin`].
#[derive(Clone, Debug)]
pub struct SessionManagementConfig {
    /// When `false`, the plugin does not handle `GET /list-sessions`.
    pub enable_session_listing: bool,
    /// When `false`, the plugin does not handle the `POST /revoke-session`,
    /// `POST /revoke-sessions` and `POST /revoke-other-sessions` endpoints.
    pub enable_session_revocation: bool,
    /// Whether endpoints should demand an authenticated caller.
    ///
    /// NOTE(review): no handler visible in this file reads this flag; each
    /// handler requires a session unconditionally. Confirm the intended use.
    pub require_authentication: bool,
}
26
27#[derive(Debug, Deserialize, Validate)]
29struct RevokeSessionRequest {
30 #[validate(length(min = 1, message = "Token is required"))]
31 token: String,
32}
33
34#[derive(Debug, Serialize, Deserialize)]
36struct SignOutResponse {
37 success: bool,
38}
39
40#[derive(Debug, Serialize)]
41struct GetSessionResponse<S: Serialize, U: Serialize> {
42 session: S,
43 user: U,
44}
45
46impl SessionManagementPlugin {
47 pub fn new() -> Self {
48 Self {
49 config: SessionManagementConfig::default(),
50 }
51 }
52
53 pub fn with_config(config: SessionManagementConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn enable_session_listing(mut self, enable: bool) -> Self {
58 self.config.enable_session_listing = enable;
59 self
60 }
61
62 pub fn enable_session_revocation(mut self, enable: bool) -> Self {
63 self.config.enable_session_revocation = enable;
64 self
65 }
66
67 pub fn require_authentication(mut self, require: bool) -> Self {
68 self.config.require_authentication = require;
69 self
70 }
71}
72
73impl Default for SessionManagementPlugin {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl Default for SessionManagementConfig {
80 fn default() -> Self {
81 Self {
82 enable_session_listing: true,
83 enable_session_revocation: true,
84 require_authentication: true,
85 }
86 }
87}
88
89#[async_trait]
90impl<DB: DatabaseAdapter> AuthPlugin<DB> for SessionManagementPlugin {
91 fn name(&self) -> &'static str {
92 "session-management"
93 }
94
95 fn routes(&self) -> Vec<AuthRoute> {
96 vec![
97 AuthRoute::get("/get-session", "get_session"),
98 AuthRoute::post("/get-session", "get_session_post"),
99 AuthRoute::post("/sign-out", "sign_out"),
100 AuthRoute::get("/list-sessions", "list_sessions"),
101 AuthRoute::post("/revoke-session", "revoke_session"),
102 AuthRoute::post("/revoke-sessions", "revoke_sessions"),
103 AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
104 ]
105 }
106
107 async fn on_request(
108 &self,
109 req: &AuthRequest,
110 ctx: &AuthContext<DB>,
111 ) -> AuthResult<Option<AuthResponse>> {
112 match (req.method(), req.path()) {
113 (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
114 Ok(Some(self.handle_get_session(req, ctx).await?))
115 }
116 (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
117 (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
118 Ok(Some(self.handle_list_sessions(req, ctx).await?))
119 }
120 (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
121 Ok(Some(self.handle_revoke_session(req, ctx).await?))
122 }
123 (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
124 Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
125 }
126 (HttpMethod::Post, "/revoke-other-sessions")
127 if self.config.enable_session_revocation =>
128 {
129 Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
130 }
131 _ => Ok(None),
132 }
133 }
134}
135
136impl SessionManagementPlugin {
138 async fn handle_get_session<DB: DatabaseAdapter>(
139 &self,
140 req: &AuthRequest,
141 ctx: &AuthContext<DB>,
142 ) -> AuthResult<AuthResponse> {
143 let (user, session) = ctx.require_session(req).await?;
144 let response = GetSessionResponse { session, user };
145 Ok(AuthResponse::json(200, &response)?)
146 }
147
148 async fn handle_sign_out<DB: DatabaseAdapter>(
149 &self,
150 req: &AuthRequest,
151 ctx: &AuthContext<DB>,
152 ) -> AuthResult<AuthResponse> {
153 let (_user, current_session) = ctx.require_session(req).await?;
154
155 ctx.database.delete_session(current_session.token()).await?;
156
157 let response = SignOutResponse { success: true };
158 let clear_cookie_header = create_clear_session_cookie(ctx);
159
160 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie_header))
161 }
162
163 async fn handle_list_sessions<DB: DatabaseAdapter>(
164 &self,
165 req: &AuthRequest,
166 ctx: &AuthContext<DB>,
167 ) -> AuthResult<AuthResponse> {
168 let (user, _) = ctx.require_session(req).await?;
169 let sessions = self.get_user_sessions(user.id(), ctx).await?;
170 Ok(AuthResponse::json(200, &sessions)?)
171 }
172
173 async fn handle_revoke_session<DB: DatabaseAdapter>(
174 &self,
175 req: &AuthRequest,
176 ctx: &AuthContext<DB>,
177 ) -> AuthResult<AuthResponse> {
178 let (user, _) = ctx.require_session(req).await?;
179
180 let revoke_req: RevokeSessionRequest = req
181 .body_as_json()
182 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
183
184 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
186 if let Some(session_to_revoke) = session_manager.get_session(&revoke_req.token).await?
187 && session_to_revoke.user_id() != user.id()
188 {
189 return Err(AuthError::forbidden(
190 "Cannot revoke session that belongs to another user",
191 ));
192 }
193
194 ctx.database.delete_session(&revoke_req.token).await?;
195
196 let response = StatusResponse { status: true };
197 Ok(AuthResponse::json(200, &response)?)
198 }
199
200 async fn handle_revoke_sessions<DB: DatabaseAdapter>(
201 &self,
202 req: &AuthRequest,
203 ctx: &AuthContext<DB>,
204 ) -> AuthResult<AuthResponse> {
205 let (user, _) = ctx.require_session(req).await?;
206 ctx.database.delete_user_sessions(user.id()).await?;
207
208 let response = StatusResponse { status: true };
209 Ok(AuthResponse::json(200, &response)?)
210 }
211
212 async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
213 &self,
214 req: &AuthRequest,
215 ctx: &AuthContext<DB>,
216 ) -> AuthResult<AuthResponse> {
217 let (user, current_session) = ctx.require_session(req).await?;
218
219 let all_sessions = self.get_user_sessions(user.id(), ctx).await?;
220 for session in all_sessions {
221 if session.token() != current_session.token() {
222 ctx.database.delete_session(session.token()).await?;
223 }
224 }
225
226 let response = StatusResponse { status: true };
227 Ok(AuthResponse::json(200, &response)?)
228 }
229
230 async fn get_user_sessions<DB: DatabaseAdapter>(
231 &self,
232 user_id: &str,
233 ctx: &AuthContext<DB>,
234 ) -> AuthResult<Vec<DB::Session>> {
235 ctx.database.get_user_sessions(user_id).await
236 }
237}
238
/// Unit tests for the session-management handlers and request routing,
/// run against the in-memory database adapter.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::test_helpers;
    use better_auth_core::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps};
    use better_auth_core::{CreateSession, CreateUser, Session};
    use chrono::{Duration, Utc};

    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
        assert_eq!(
            response_data["session"]["token"].as_str().unwrap(),
            session.token
        );
        assert_eq!(
            response_data["user"]["email"]
                .as_str()
                .map(|s| s.to_string()),
            Some("test@example.com".to_string())
        );
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        // No session token on the request: the handler must refuse.
        let req =
            test_helpers::create_auth_request_no_query(HttpMethod::Get, "/get-session", None, None);
        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/sign-out",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: SignOutResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.success);

        // The session must be gone from the store after sign-out.
        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        // Only the targeted session is removed; the caller's own stays.
        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_user2 = CreateUser::new()
            .with_email("user2@example.com")
            .with_name("User Two");
        let user2 = ctx.database.create_user(create_user2).await.unwrap();

        let create_session2 = CreateSession {
            user_id: user2.id,
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        // User one tries to revoke a session owned by user two.
        let body = serde_json::json!({ "token": session2.token });
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes()),
        );

        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 403);
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-sessions",
            Some(&session1.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_revoke_other_sessions_keeps_current() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        let create_session2 = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session2 = ctx.database.create_session(create_session2).await.unwrap();

        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin
            .handle_revoke_other_sessions(&req, &ctx)
            .await
            .unwrap();
        assert_eq!(response.status, 200);

        // The other session is revoked, the one making the request survives.
        let other_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(other_check.is_none());

        let current_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(current_check.is_some());
    }

    #[test]
    fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);

        // Every one of the seven declared routes is checked individually.
        assert_eq!(routes.len(), 7);
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/get-session" && r.method == HttpMethod::Get)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/get-session" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/sign-out" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/list-sessions" && r.method == HttpMethod::Get)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-session" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-sessions" && r.method == HttpMethod::Post)
        );
        assert!(
            routes
                .iter()
                .any(|r| r.path == "/revoke-other-sessions" && r.method == HttpMethod::Post)
        );
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        // A known route is handled by the plugin.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/get-session",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        // An unknown route is declined with `None`.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/invalid-route",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = test_helpers::create_test_context_with_user(
            CreateUser::new()
                .with_email("test@example.com")
                .with_name("Test User"),
            Duration::hours(24),
        )
        .await;

        // Disabled listing: the request is declined rather than handled.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        // Disabled revocation: the request is declined rather than handled.
        let req = test_helpers::create_auth_request_no_query(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }
}