1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
10
/// Auth plugin exposing session-management endpoints: `/get-session`, `/sign-out`,
/// `/list-sessions`, `/revoke-session`, `/revoke-sessions` and `/revoke-other-sessions`.
#[derive(Debug, Clone)]
pub struct SessionManagementPlugin {
    config: SessionManagementConfig,
}

/// Feature switches for [`SessionManagementPlugin`].
#[derive(Debug, Clone)]
pub struct SessionManagementConfig {
    /// When `false`, `GET /list-sessions` is not handled by the plugin.
    pub enable_session_listing: bool,
    /// When `false`, `POST /revoke-session`, `/revoke-sessions` and
    /// `/revoke-other-sessions` are not handled by the plugin.
    pub enable_session_revocation: bool,
    /// NOTE(review): stored and settable through the builder, but not consulted by the
    /// request handlers in this module — every endpoint requires a valid session
    /// regardless of this flag. Confirm the intended semantics before relying on it.
    pub require_authentication: bool,
}
22
23#[derive(Debug, Deserialize, Validate)]
25struct RevokeSessionRequest {
26 #[validate(length(min = 1, message = "Token is required"))]
27 token: String,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
32struct SignOutResponse {
33 success: bool,
34}
35
36#[derive(Debug, Serialize)]
37struct GetSessionResponse<S: Serialize, U: Serialize> {
38 session: S,
39 user: U,
40}
41
42#[derive(Debug, Serialize, Deserialize)]
43struct StatusResponse {
44 status: bool,
45}
46
47impl SessionManagementPlugin {
48 pub fn new() -> Self {
49 Self {
50 config: SessionManagementConfig::default(),
51 }
52 }
53
54 pub fn with_config(config: SessionManagementConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn enable_session_listing(mut self, enable: bool) -> Self {
59 self.config.enable_session_listing = enable;
60 self
61 }
62
63 pub fn enable_session_revocation(mut self, enable: bool) -> Self {
64 self.config.enable_session_revocation = enable;
65 self
66 }
67
68 pub fn require_authentication(mut self, require: bool) -> Self {
69 self.config.require_authentication = require;
70 self
71 }
72}
73
74impl Default for SessionManagementPlugin {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl Default for SessionManagementConfig {
81 fn default() -> Self {
82 Self {
83 enable_session_listing: true,
84 enable_session_revocation: true,
85 require_authentication: true,
86 }
87 }
88}
89
90#[async_trait]
91impl<DB: DatabaseAdapter> AuthPlugin<DB> for SessionManagementPlugin {
92 fn name(&self) -> &'static str {
93 "session-management"
94 }
95
96 fn routes(&self) -> Vec<AuthRoute> {
97 vec![
98 AuthRoute::get("/get-session", "get_session"),
99 AuthRoute::post("/get-session", "get_session_post"),
100 AuthRoute::post("/sign-out", "sign_out"),
101 AuthRoute::get("/list-sessions", "list_sessions"),
102 AuthRoute::post("/revoke-session", "revoke_session"),
103 AuthRoute::post("/revoke-sessions", "revoke_sessions"),
104 AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
105 ]
106 }
107
108 async fn on_request(
109 &self,
110 req: &AuthRequest,
111 ctx: &AuthContext<DB>,
112 ) -> AuthResult<Option<AuthResponse>> {
113 match (req.method(), req.path()) {
114 (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
115 Ok(Some(self.handle_get_session(req, ctx).await?))
116 }
117 (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
118 (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
119 Ok(Some(self.handle_list_sessions(req, ctx).await?))
120 }
121 (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
122 Ok(Some(self.handle_revoke_session(req, ctx).await?))
123 }
124 (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
125 Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
126 }
127 (HttpMethod::Post, "/revoke-other-sessions")
128 if self.config.enable_session_revocation =>
129 {
130 Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
131 }
132 _ => Ok(None),
133 }
134 }
135}
136
137impl SessionManagementPlugin {
139 async fn handle_get_session<DB: DatabaseAdapter>(
140 &self,
141 req: &AuthRequest,
142 ctx: &AuthContext<DB>,
143 ) -> AuthResult<AuthResponse> {
144 let (user, session) = self.require_session(req, ctx).await?;
145 let response = GetSessionResponse { session, user };
146 Ok(AuthResponse::json(200, &response)?)
147 }
148
149 async fn handle_sign_out<DB: DatabaseAdapter>(
150 &self,
151 req: &AuthRequest,
152 ctx: &AuthContext<DB>,
153 ) -> AuthResult<AuthResponse> {
154 let (_user, current_session) = self.require_session(req, ctx).await?;
155
156 ctx.database.delete_session(current_session.token()).await?;
157
158 let response = SignOutResponse { success: true };
159 let clear_cookie_header = self.create_clear_session_cookie(ctx);
160
161 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie_header))
162 }
163
164 fn create_clear_session_cookie<DB: DatabaseAdapter>(&self, ctx: &AuthContext<DB>) -> String {
165 let session_config = &ctx.config.session;
166 let secure = if session_config.cookie_secure {
167 "; Secure"
168 } else {
169 ""
170 };
171 let http_only = if session_config.cookie_http_only {
172 "; HttpOnly"
173 } else {
174 ""
175 };
176 let same_site = match session_config.cookie_same_site {
177 better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
178 better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
179 better_auth_core::config::SameSite::None => "; SameSite=None",
180 };
181
182 format!(
183 "{}=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT{}{}{}",
184 session_config.cookie_name, secure, http_only, same_site
185 )
186 }
187
188 async fn handle_list_sessions<DB: DatabaseAdapter>(
189 &self,
190 req: &AuthRequest,
191 ctx: &AuthContext<DB>,
192 ) -> AuthResult<AuthResponse> {
193 let (user, _) = self.require_session(req, ctx).await?;
194 let sessions = self.get_user_sessions(user.id(), ctx).await?;
195 Ok(AuthResponse::json(200, &sessions)?)
196 }
197
198 async fn handle_revoke_session<DB: DatabaseAdapter>(
199 &self,
200 req: &AuthRequest,
201 ctx: &AuthContext<DB>,
202 ) -> AuthResult<AuthResponse> {
203 let (user, _) = self.require_session(req, ctx).await?;
204
205 let revoke_req: RevokeSessionRequest = req
206 .body_as_json()
207 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
208
209 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
211 if let Some(session_to_revoke) = session_manager.get_session(&revoke_req.token).await?
212 && session_to_revoke.user_id() != user.id()
213 {
214 return Err(AuthError::forbidden(
215 "Cannot revoke session that belongs to another user",
216 ));
217 }
218
219 ctx.database.delete_session(&revoke_req.token).await?;
220
221 let response = StatusResponse { status: true };
222 Ok(AuthResponse::json(200, &response)?)
223 }
224
225 async fn handle_revoke_sessions<DB: DatabaseAdapter>(
226 &self,
227 req: &AuthRequest,
228 ctx: &AuthContext<DB>,
229 ) -> AuthResult<AuthResponse> {
230 let (user, _) = self.require_session(req, ctx).await?;
231 ctx.database.delete_user_sessions(user.id()).await?;
232
233 let response = StatusResponse { status: true };
234 Ok(AuthResponse::json(200, &response)?)
235 }
236
237 async fn handle_revoke_other_sessions<DB: DatabaseAdapter>(
238 &self,
239 req: &AuthRequest,
240 ctx: &AuthContext<DB>,
241 ) -> AuthResult<AuthResponse> {
242 let (user, current_session) = self.require_session(req, ctx).await?;
243
244 let all_sessions = self.get_user_sessions(user.id(), ctx).await?;
245 for session in all_sessions {
246 if session.token() != current_session.token() {
247 ctx.database.delete_session(session.token()).await?;
248 }
249 }
250
251 let response = StatusResponse { status: true };
252 Ok(AuthResponse::json(200, &response)?)
253 }
254
255 async fn require_session<DB: DatabaseAdapter>(
258 &self,
259 req: &AuthRequest,
260 ctx: &AuthContext<DB>,
261 ) -> AuthResult<(DB::User, DB::Session)> {
262 self.get_current_user_and_session(req, ctx)
263 .await?
264 .ok_or(AuthError::Unauthenticated)
265 }
266
267 async fn get_current_user_and_session<DB: DatabaseAdapter>(
268 &self,
269 req: &AuthRequest,
270 ctx: &AuthContext<DB>,
271 ) -> AuthResult<Option<(DB::User, DB::Session)>> {
272 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
273
274 if let Some(token) = session_manager.extract_session_token(req)
275 && let Some(session) = session_manager.get_session(&token).await?
276 && let Some(user) = ctx.database.get_user_by_id(session.user_id()).await?
277 {
278 return Ok(Some((user, session)));
279 }
280
281 Ok(None)
282 }
283
284 async fn get_user_sessions<DB: DatabaseAdapter>(
285 &self,
286 user_id: &str,
287 ctx: &AuthContext<DB>,
288 ) -> AuthResult<Vec<DB::Session>> {
289 ctx.database.get_user_sessions(user_id).await
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use better_auth_core::adapters::{DatabaseAdapter, MemoryDatabaseAdapter};
    use better_auth_core::config::AuthConfig;
    use better_auth_core::{CreateSession, CreateUser, Session, User};
    use chrono::{Duration, Utc};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds an in-memory context holding one user ("test@example.com") and one
    /// session that expires in 24 hours.
    async fn create_test_context_with_user() -> (AuthContext<MemoryDatabaseAdapter>, User, Session)
    {
        let config = Arc::new(AuthConfig::new("test-secret-key-at-least-32-chars-long"));
        let database = Arc::new(MemoryDatabaseAdapter::new());
        let ctx = AuthContext::new(config.clone(), database.clone());

        let create_user = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test User");
        let user = database.create_user(create_user).await.unwrap();

        let create_session = CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("127.0.0.1".to_string()),
            user_agent: Some("test-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        let session = database.create_session(create_session).await.unwrap();

        (ctx, user, session)
    }

    /// Adds another 24-hour session for `user_id`. It is distinguishable from the
    /// fixture session by its IP address and user agent.
    async fn create_extra_session(
        ctx: &AuthContext<MemoryDatabaseAdapter>,
        user_id: &str,
    ) -> Session {
        let create_session = CreateSession {
            user_id: user_id.to_string(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some("192.168.1.1".to_string()),
            user_agent: Some("another-agent".to_string()),
            impersonated_by: None,
            active_organization_id: None,
        };
        ctx.database.create_session(create_session).await.unwrap()
    }

    /// Builds a request, optionally authenticated with `Authorization: Bearer <token>`.
    fn create_auth_request(
        method: HttpMethod,
        path: &str,
        token: Option<&str>,
        body: Option<Vec<u8>>,
    ) -> AuthRequest {
        let mut headers = HashMap::new();
        if let Some(token) = token {
            headers.insert("authorization".to_string(), format!("Bearer {}", token));
        }

        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body,
            query: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
        assert_eq!(
            response_data["session"]["token"].as_str().unwrap(),
            session.token
        );
        assert_eq!(
            response_data["user"]["email"]
                .as_str()
                .map(|s| s.to_string()),
            Some("test@example.com".to_string())
        );
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", None, None);
        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/sign-out",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: SignOutResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.success);

        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;
        create_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;
        let session2 = create_extra_session(&ctx, &user.id).await;

        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) = create_test_context_with_user().await;

        let create_user2 = CreateUser::new()
            .with_email("user2@example.com")
            .with_name("User Two");
        let user2 = ctx.database.create_user(create_user2).await.unwrap();
        let session2 = create_extra_session(&ctx, &user2.id).await;

        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes()),
        );

        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 403);
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) = create_test_context_with_user().await;
        create_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-sessions",
            Some(&session1.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_revoke_other_sessions_keeps_current() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, current) = create_test_context_with_user().await;
        let other = create_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&current.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin
            .handle_revoke_other_sessions(&req, &ctx)
            .await
            .unwrap();

        assert_eq!(response.status, 200);
        assert!(ctx.database.get_session(&other.token).await.unwrap().is_none());
        assert!(ctx.database.get_session(&current.token).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);

        assert_eq!(routes.len(), 7);

        let expected = [
            ("/get-session", HttpMethod::Get),
            ("/get-session", HttpMethod::Post),
            ("/sign-out", HttpMethod::Post),
            ("/list-sessions", HttpMethod::Get),
            ("/revoke-session", HttpMethod::Post),
            ("/revoke-sessions", HttpMethod::Post),
            ("/revoke-other-sessions", HttpMethod::Post),
        ];
        for (path, method) in expected {
            assert!(
                routes.iter().any(|r| r.path == path && r.method == method),
                "missing route for {}",
                path
            );
        }
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        let req = create_auth_request(
            HttpMethod::Get,
            "/invalid-route",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        // Every revocation endpoint must be skipped when revocation is disabled.
        for path in ["/revoke-session", "/revoke-sessions", "/revoke-other-sessions"] {
            let req = create_auth_request(
                HttpMethod::Post,
                path,
                Some(&session.token),
                Some(b"{}".to_vec()),
            );
            let response = plugin.on_request(&req, &ctx).await.unwrap();
            assert!(response.is_none(), "{} should be disabled", path);
        }
    }
}