1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
6use better_auth_core::{AuthError, AuthResult};
7use better_auth_core::{AuthRequest, AuthResponse, HttpMethod, Session, User};
8
9pub struct SessionManagementPlugin {
11 config: SessionManagementConfig,
12}
13
/// Feature switches for [`SessionManagementPlugin`].
#[derive(Debug, Clone)]
pub struct SessionManagementConfig {
    /// When `false`, `GET /list-sessions` is not handled by the plugin.
    pub enable_session_listing: bool,
    /// When `false`, the `/revoke-session`, `/revoke-sessions` and
    /// `/revoke-other-sessions` endpoints are not handled by the plugin.
    pub enable_session_revocation: bool,
    /// Stored with the configuration.
    ///
    /// NOTE(review): the handlers in this file never read this flag; every endpoint
    /// requires a valid session regardless of its value.
    pub require_authentication: bool,
}
20
21#[derive(Debug, Deserialize, Validate)]
23struct RevokeSessionRequest {
24 #[validate(length(min = 1, message = "Token is required"))]
25 token: String,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
30struct GetSessionResponse {
31 session: Session,
32 user: User,
33}
34
35#[derive(Debug, Serialize, Deserialize)]
36struct SignOutResponse {
37 success: bool,
38}
39
40#[derive(Debug, Serialize, Deserialize)]
41struct StatusResponse {
42 status: bool,
43}
44
45impl SessionManagementPlugin {
46 pub fn new() -> Self {
47 Self {
48 config: SessionManagementConfig::default(),
49 }
50 }
51
52 pub fn with_config(config: SessionManagementConfig) -> Self {
53 Self { config }
54 }
55
56 pub fn enable_session_listing(mut self, enable: bool) -> Self {
57 self.config.enable_session_listing = enable;
58 self
59 }
60
61 pub fn enable_session_revocation(mut self, enable: bool) -> Self {
62 self.config.enable_session_revocation = enable;
63 self
64 }
65
66 pub fn require_authentication(mut self, require: bool) -> Self {
67 self.config.require_authentication = require;
68 self
69 }
70}
71
72impl Default for SessionManagementPlugin {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl Default for SessionManagementConfig {
79 fn default() -> Self {
80 Self {
81 enable_session_listing: true,
82 enable_session_revocation: true,
83 require_authentication: true,
84 }
85 }
86}
87
88#[async_trait]
89impl AuthPlugin for SessionManagementPlugin {
90 fn name(&self) -> &'static str {
91 "session-management"
92 }
93
94 fn routes(&self) -> Vec<AuthRoute> {
95 vec![
96 AuthRoute::get("/get-session", "get_session"),
97 AuthRoute::post("/get-session", "get_session_post"),
98 AuthRoute::post("/sign-out", "sign_out"),
99 AuthRoute::get("/list-sessions", "list_sessions"),
100 AuthRoute::post("/revoke-session", "revoke_session"),
101 AuthRoute::post("/revoke-sessions", "revoke_sessions"),
102 AuthRoute::post("/revoke-other-sessions", "revoke_other_sessions"),
103 ]
104 }
105
106 async fn on_request(
107 &self,
108 req: &AuthRequest,
109 ctx: &AuthContext,
110 ) -> AuthResult<Option<AuthResponse>> {
111 match (req.method(), req.path()) {
112 (HttpMethod::Get | HttpMethod::Post, "/get-session") => {
113 Ok(Some(self.handle_get_session(req, ctx).await?))
114 }
115 (HttpMethod::Post, "/sign-out") => Ok(Some(self.handle_sign_out(req, ctx).await?)),
116 (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
117 Ok(Some(self.handle_list_sessions(req, ctx).await?))
118 }
119 (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
120 Ok(Some(self.handle_revoke_session(req, ctx).await?))
121 }
122 (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
123 Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
124 }
125 (HttpMethod::Post, "/revoke-other-sessions")
126 if self.config.enable_session_revocation =>
127 {
128 Ok(Some(self.handle_revoke_other_sessions(req, ctx).await?))
129 }
130 _ => Ok(None),
131 }
132 }
133}
134
135impl SessionManagementPlugin {
137 async fn handle_get_session(
138 &self,
139 req: &AuthRequest,
140 ctx: &AuthContext,
141 ) -> AuthResult<AuthResponse> {
142 let (user, session) = self.require_session(req, ctx).await?;
143 let response = GetSessionResponse { session, user };
144 Ok(AuthResponse::json(200, &response)?)
145 }
146
147 async fn handle_sign_out(
148 &self,
149 req: &AuthRequest,
150 ctx: &AuthContext,
151 ) -> AuthResult<AuthResponse> {
152 let (_user, current_session) = self.require_session(req, ctx).await?;
153
154 ctx.database.delete_session(¤t_session.token).await?;
155
156 let response = SignOutResponse { success: true };
157 let clear_cookie_header = self.create_clear_session_cookie(ctx);
158
159 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", clear_cookie_header))
160 }
161
162 fn create_clear_session_cookie(&self, ctx: &AuthContext) -> String {
163 let session_config = &ctx.config.session;
164 let secure = if session_config.cookie_secure {
165 "; Secure"
166 } else {
167 ""
168 };
169 let http_only = if session_config.cookie_http_only {
170 "; HttpOnly"
171 } else {
172 ""
173 };
174 let same_site = match session_config.cookie_same_site {
175 better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
176 better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
177 better_auth_core::config::SameSite::None => "; SameSite=None",
178 };
179
180 format!(
181 "{}=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT{}{}{}",
182 session_config.cookie_name, secure, http_only, same_site
183 )
184 }
185
186 async fn handle_list_sessions(
187 &self,
188 req: &AuthRequest,
189 ctx: &AuthContext,
190 ) -> AuthResult<AuthResponse> {
191 let (user, _) = self.require_session(req, ctx).await?;
192 let sessions = self.get_user_sessions(&user.id, ctx).await?;
193 Ok(AuthResponse::json(200, &sessions)?)
194 }
195
196 async fn handle_revoke_session(
197 &self,
198 req: &AuthRequest,
199 ctx: &AuthContext,
200 ) -> AuthResult<AuthResponse> {
201 let (user, _) = self.require_session(req, ctx).await?;
202
203 let revoke_req: RevokeSessionRequest = req
204 .body_as_json()
205 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
206
207 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
209 if let Some(session_to_revoke) = session_manager.get_session(&revoke_req.token).await?
210 && session_to_revoke.user_id != user.id
211 {
212 return Err(AuthError::forbidden(
213 "Cannot revoke session that belongs to another user",
214 ));
215 }
216
217 ctx.database.delete_session(&revoke_req.token).await?;
218
219 let response = StatusResponse { status: true };
220 Ok(AuthResponse::json(200, &response)?)
221 }
222
223 async fn handle_revoke_sessions(
224 &self,
225 req: &AuthRequest,
226 ctx: &AuthContext,
227 ) -> AuthResult<AuthResponse> {
228 let (user, _) = self.require_session(req, ctx).await?;
229 ctx.database.delete_user_sessions(&user.id).await?;
230
231 let response = StatusResponse { status: true };
232 Ok(AuthResponse::json(200, &response)?)
233 }
234
235 async fn handle_revoke_other_sessions(
236 &self,
237 req: &AuthRequest,
238 ctx: &AuthContext,
239 ) -> AuthResult<AuthResponse> {
240 let (user, current_session) = self.require_session(req, ctx).await?;
241
242 let all_sessions = self.get_user_sessions(&user.id, ctx).await?;
243 for session in all_sessions {
244 if session.token != current_session.token {
245 ctx.database.delete_session(&session.token).await?;
246 }
247 }
248
249 let response = StatusResponse { status: true };
250 Ok(AuthResponse::json(200, &response)?)
251 }
252
253 async fn require_session(
256 &self,
257 req: &AuthRequest,
258 ctx: &AuthContext,
259 ) -> AuthResult<(User, Session)> {
260 self.get_current_user_and_session(req, ctx)
261 .await?
262 .ok_or(AuthError::Unauthenticated)
263 }
264
265 async fn get_current_user_and_session(
266 &self,
267 req: &AuthRequest,
268 ctx: &AuthContext,
269 ) -> AuthResult<Option<(User, Session)>> {
270 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
271
272 if let Some(token) = session_manager.extract_session_token(req)
273 && let Some(session) = session_manager.get_session(&token).await?
274 && let Some(user) = ctx.database.get_user_by_id(&session.user_id).await?
275 {
276 return Ok(Some((user, session)));
277 }
278
279 Ok(None)
280 }
281
282 async fn get_user_sessions(
283 &self,
284 user_id: &str,
285 ctx: &AuthContext,
286 ) -> AuthResult<Vec<Session>> {
287 ctx.database.get_user_sessions(user_id).await
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use better_auth_core::adapters::{DatabaseAdapter, MemoryDatabaseAdapter};
    use better_auth_core::config::AuthConfig;
    use better_auth_core::{CreateSession, CreateUser};
    use chrono::{Duration, Utc};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a `CreateSession` for `user` that expires 24 hours from now.
    fn session_input(user: &User, ip: &str, agent: &str) -> CreateSession {
        CreateSession {
            user_id: user.id.clone(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some(ip.to_string()),
            user_agent: Some(agent.to_string()),
            impersonated_by: None,
            active_organization_id: None,
        }
    }

    /// Creates an in-memory context holding one user with one live session.
    async fn create_test_context_with_user() -> (AuthContext, User, Session) {
        let config = Arc::new(AuthConfig::new("test-secret-key-at-least-32-chars-long"));
        let database = Arc::new(MemoryDatabaseAdapter::new());
        let ctx = AuthContext::new(config, database.clone());

        let create_user = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test User");
        let user = database.create_user(create_user).await.unwrap();

        let session = database
            .create_session(session_input(&user, "127.0.0.1", "test-agent"))
            .await
            .unwrap();

        (ctx, user, session)
    }

    /// Adds a second live session for `user`.
    async fn add_session(ctx: &AuthContext, user: &User) -> Session {
        ctx.database
            .create_session(session_input(user, "192.168.1.1", "another-agent"))
            .await
            .unwrap()
    }

    /// Builds a request, sending `token` as a bearer token when present.
    fn create_auth_request(
        method: HttpMethod,
        path: &str,
        token: Option<&str>,
        body: Option<Vec<u8>>,
    ) -> AuthRequest {
        let mut headers = HashMap::new();
        if let Some(token) = token {
            headers.insert("authorization".to_string(), format!("Bearer {}", token));
        }

        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body,
            query: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_get_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: GetSessionResponse = serde_json::from_str(&body_str).unwrap();
        assert_eq!(response_data.session.token, session.token);
        assert_eq!(
            response_data.user.email,
            Some("test@example.com".to_string())
        );
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, _session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", None, None);
        let err = plugin.handle_get_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/sign-out",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let response_data: SignOutResponse = serde_json::from_str(&body_str).unwrap();
        assert!(response_data.success);

        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_none());
    }

    #[tokio::test]
    async fn test_sign_out_unauthorized() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Post, "/sign-out", None, Some(b"{}".to_vec()));
        let err = plugin.handle_sign_out(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 401);

        // The existing session must be untouched by the rejected request.
        let session_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session_check.is_some());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;
        add_session(&ctx, &user).await;

        let req = create_auth_request(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let body_str = String::from_utf8(response.body).unwrap();
        let sessions: Vec<Session> = serde_json::from_str(&body_str).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session) = create_test_context_with_user().await;
        let session2 = add_session(&ctx, &user).await;

        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(body.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let session2_check = ctx.database.get_session(&session2.token).await.unwrap();
        assert!(session2_check.is_none());

        let session1_check = ctx.database.get_session(&session.token).await.unwrap();
        assert!(session1_check.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user1, session1) = create_test_context_with_user().await;

        let create_user2 = CreateUser::new()
            .with_email("user2@example.com")
            .with_name("User Two");
        let user2 = ctx.database.create_user(create_user2).await.unwrap();
        let session2 = add_session(&ctx, &user2).await;

        let body = serde_json::json!({ "token": session2.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(body.to_string().into_bytes()),
        );

        let err = plugin.handle_revoke_session(&req, &ctx).await.unwrap_err();
        assert_eq!(err.status_code(), 403);
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, session1) = create_test_context_with_user().await;
        add_session(&ctx, &user).await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-sessions",
            Some(&session1.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();

        assert_eq!(response.status, 200);

        let user_sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(user_sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_revoke_other_sessions_keeps_current() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, user, current) = create_test_context_with_user().await;
        let other = add_session(&ctx, &user).await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-other-sessions",
            Some(&current.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin
            .handle_revoke_other_sessions(&req, &ctx)
            .await
            .unwrap();

        assert_eq!(response.status, 200);
        assert!(
            ctx.database
                .get_session(&other.token)
                .await
                .unwrap()
                .is_none()
        );
        assert!(
            ctx.database
                .get_session(&current.token)
                .await
                .unwrap()
                .is_some()
        );
    }

    #[tokio::test]
    async fn test_plugin_routes() {
        let plugin = SessionManagementPlugin::new();
        let routes = plugin.routes();

        assert_eq!(routes.len(), 7);

        let expected = [
            ("/get-session", HttpMethod::Get),
            ("/get-session", HttpMethod::Post),
            ("/sign-out", HttpMethod::Post),
            ("/list-sessions", HttpMethod::Get),
            ("/revoke-session", HttpMethod::Post),
            ("/revoke-sessions", HttpMethod::Post),
            ("/revoke-other-sessions", HttpMethod::Post),
        ];
        for (path, method) in expected {
            assert!(
                routes.iter().any(|r| r.path == path && r.method == method),
                "missing route {}",
                path
            );
        }
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let plugin = SessionManagementPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_some());
        assert_eq!(response.unwrap().status, 200);

        // Paths the plugin does not own are passed through.
        let req = create_auth_request(
            HttpMethod::Get,
            "/invalid-route",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Get,
            "/list-sessions",
            Some(&session.token),
            None,
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());

        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(b"{}".to_vec()),
        );
        let response = plugin.on_request(&req, &ctx).await.unwrap();
        assert!(response.is_none());
    }
}