1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3
4use crate::core::{AuthPlugin, AuthRoute, AuthContext, SessionManager};
5use crate::types::{AuthRequest, AuthResponse, HttpMethod, User, Session};
6use crate::error::{AuthError, AuthResult};
7
8pub struct SessionManagementPlugin {
10 config: SessionManagementConfig,
11}
12
/// Feature switches for [`SessionManagementPlugin`].
#[derive(Clone, Debug)]
pub struct SessionManagementConfig {
    /// When `false`, `GET /list-sessions` is not handled by the plugin.
    pub enable_session_listing: bool,
    /// When `false`, `POST /revoke-session` and `POST /revoke-sessions` are not handled.
    pub enable_session_revocation: bool,
    /// NOTE(review): not read by any handler in this file — every handler demands a valid
    /// session regardless of this flag. Confirm whether it is meant to be wired up.
    pub require_authentication: bool,
}
19
20#[derive(Debug, Deserialize)]
22struct RevokeSessionRequest {
23 token: String,
24}
25
26#[derive(Debug, Serialize, Deserialize)]
28struct GetSessionResponse {
29 session: Session,
30 user: User,
31}
32
33#[derive(Debug, Serialize, Deserialize)]
34struct SignOutResponse {
35 success: bool,
36}
37
38#[derive(Debug, Serialize, Deserialize)]
39struct StatusResponse {
40 status: bool,
41}
42
43impl SessionManagementPlugin {
44 pub fn new() -> Self {
45 Self {
46 config: SessionManagementConfig::default(),
47 }
48 }
49
50 pub fn with_config(config: SessionManagementConfig) -> Self {
51 Self { config }
52 }
53
54 pub fn enable_session_listing(mut self, enable: bool) -> Self {
55 self.config.enable_session_listing = enable;
56 self
57 }
58
59 pub fn enable_session_revocation(mut self, enable: bool) -> Self {
60 self.config.enable_session_revocation = enable;
61 self
62 }
63
64 pub fn require_authentication(mut self, require: bool) -> Self {
65 self.config.require_authentication = require;
66 self
67 }
68}
69
70impl Default for SessionManagementPlugin {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Default for SessionManagementConfig {
77 fn default() -> Self {
78 Self {
79 enable_session_listing: true,
80 enable_session_revocation: true,
81 require_authentication: true,
82 }
83 }
84}
85
86#[async_trait]
87impl AuthPlugin for SessionManagementPlugin {
88 fn name(&self) -> &'static str {
89 "session-management"
90 }
91
92 fn routes(&self) -> Vec<AuthRoute> {
93 vec![
94 AuthRoute::get("/get-session", "get_session"),
95 AuthRoute::post("/sign-out", "sign_out"),
96 AuthRoute::get("/list-sessions", "list_sessions"),
97 AuthRoute::post("/revoke-session", "revoke_session"),
98 AuthRoute::post("/revoke-sessions", "revoke_sessions"),
99 ]
100 }
101
102 async fn on_request(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<Option<AuthResponse>> {
103 match (req.method(), req.path()) {
104 (HttpMethod::Get, "/get-session") => {
105 Ok(Some(self.handle_get_session(req, ctx).await?))
106 },
107 (HttpMethod::Post, "/sign-out") => {
108 Ok(Some(self.handle_sign_out(req, ctx).await?))
109 },
110 (HttpMethod::Get, "/list-sessions") if self.config.enable_session_listing => {
111 Ok(Some(self.handle_list_sessions(req, ctx).await?))
112 },
113 (HttpMethod::Post, "/revoke-session") if self.config.enable_session_revocation => {
114 Ok(Some(self.handle_revoke_session(req, ctx).await?))
115 },
116 (HttpMethod::Post, "/revoke-sessions") if self.config.enable_session_revocation => {
117 Ok(Some(self.handle_revoke_sessions(req, ctx).await?))
118 },
119 _ => Ok(None),
120 }
121 }
122}
123
124impl SessionManagementPlugin {
126 async fn handle_get_session(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
127 let (user, session) = match self.get_current_user_and_session(req, ctx).await? {
129 Some((user, session)) => (user, session),
130 None => {
131 return Ok(AuthResponse::json(401, &serde_json::json!({
132 "error": "Unauthorized",
133 "message": "No valid session found"
134 }))?);
135 }
136 };
137
138 let response = GetSessionResponse { session, user };
139 Ok(AuthResponse::json(200, &response)?)
140 }
141
142 async fn handle_sign_out(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
143 let (_user, current_session) = match self.get_current_user_and_session(req, ctx).await? {
145 Some((user, session)) => (user, session),
146 None => {
147 return Ok(AuthResponse::json(401, &serde_json::json!({
148 "error": "Unauthorized",
149 "message": "No valid session found"
150 }))?);
151 }
152 };
153
154 ctx.database.delete_session(¤t_session.token).await?;
156
157 let response = SignOutResponse { success: true };
158 let clear_cookie_header = self.create_clear_session_cookie(ctx);
160
161 Ok(AuthResponse::json(200, &response)?
162 .with_header("Set-Cookie", clear_cookie_header))
163 }
164
165 fn create_clear_session_cookie(&self, ctx: &AuthContext) -> String {
166 let session_config = &ctx.config.session;
167 let secure = if session_config.cookie_secure { "; Secure" } else { "" };
168 let http_only = if session_config.cookie_http_only { "; HttpOnly" } else { "" };
169 let same_site = match session_config.cookie_same_site {
170 crate::core::config::SameSite::Strict => "; SameSite=Strict",
171 crate::core::config::SameSite::Lax => "; SameSite=Lax",
172 crate::core::config::SameSite::None => "; SameSite=None",
173 };
174
175 format!("{}=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT{}{}{}",
176 session_config.cookie_name,
177 secure,
178 http_only,
179 same_site)
180 }
181
182 async fn handle_list_sessions(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
183 let (user, _current_session) = match self.get_current_user_and_session(req, ctx).await? {
185 Some((user, session)) => (user, session),
186 None => {
187 return Ok(AuthResponse::json(401, &serde_json::json!({
188 "error": "Unauthorized",
189 "message": "No valid session found"
190 }))?);
191 }
192 };
193
194 let sessions = self.get_user_sessions(&user.id, ctx).await?;
196
197 Ok(AuthResponse::json(200, &sessions)?)
199 }
200
201 async fn handle_revoke_session(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
202 let (user, _current_session) = match self.get_current_user_and_session(req, ctx).await? {
204 Some((user, session)) => (user, session),
205 None => {
206 return Ok(AuthResponse::json(401, &serde_json::json!({
207 "error": "Unauthorized",
208 "message": "No valid session found"
209 }))?);
210 }
211 };
212
213 let revoke_req: RevokeSessionRequest = match req.body_as_json() {
214 Ok(req) => req,
215 Err(e) => {
216 return Ok(AuthResponse::json(400, &serde_json::json!({
217 "error": "Invalid request",
218 "message": format!("Invalid JSON: {}", e)
219 }))?);
220 }
221 };
222
223 let session_token = &revoke_req.token;
225
226 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
228 if let Some(session_to_revoke) = session_manager.get_session(session_token).await? {
229 if session_to_revoke.user_id != user.id {
230 return Ok(AuthResponse::json(403, &serde_json::json!({
231 "error": "Forbidden",
232 "message": "Cannot revoke session that belongs to another user"
233 }))?);
234 }
235 }
236
237 ctx.database.delete_session(session_token).await?;
239
240 let response = StatusResponse { status: true };
241 Ok(AuthResponse::json(200, &response)?)
242 }
243
244 async fn handle_revoke_sessions(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
245 let (user, _current_session) = match self.get_current_user_and_session(req, ctx).await? {
247 Some((user, session)) => (user, session),
248 None => {
249 return Ok(AuthResponse::json(401, &serde_json::json!({
250 "error": "Unauthorized",
251 "message": "No valid session found"
252 }))?);
253 }
254 };
255
256 ctx.database.delete_user_sessions(&user.id).await?;
258
259 let response = StatusResponse { status: true };
260 Ok(AuthResponse::json(200, &response)?)
261 }
262
263 async fn handle_revoke_other_sessions(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
264 let (user, current_session) = match self.get_current_user_and_session(req, ctx).await? {
266 Some((user, session)) => (user, session),
267 None => {
268 return Ok(AuthResponse::json(401, &serde_json::json!({
269 "error": "Unauthorized",
270 "message": "No valid session found"
271 }))?);
272 }
273 };
274
275 let all_sessions = self.get_user_sessions(&user.id, ctx).await?;
277
278 for session in all_sessions {
280 if session.token != current_session.token {
281 ctx.database.delete_session(&session.token).await?;
282 }
283 }
284
285 let response = StatusResponse { status: true };
286 Ok(AuthResponse::json(200, &response)?)
287 }
288
289 async fn get_current_user_and_session(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<Option<(User, Session)>> {
290 let token = if let Some(auth_header) = req.headers.get("authorization") {
292 if auth_header.starts_with("Bearer ") {
293 Some(&auth_header[7..])
294 } else {
295 None
296 }
297 } else {
298 None
299 };
300
301 if let Some(token) = token {
302 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
303 if let Some(session) = session_manager.get_session(token).await? {
304 if let Some(user) = ctx.database.get_user_by_id(&session.user_id).await? {
305 return Ok(Some((user, session)));
306 }
307 }
308 }
309
310 Ok(None)
311 }
312
313 async fn get_user_sessions(&self, user_id: &str, ctx: &AuthContext) -> AuthResult<Vec<Session>> {
314 ctx.database.get_user_sessions(user_id).await
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::{DatabaseAdapter, MemoryDatabaseAdapter};
    use crate::core::config::AuthConfig;
    use crate::types::{CreateSession, CreateUser};
    use chrono::{Duration, Utc};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Describes a session for `user_id` that expires 24 hours from now.
    fn new_session(user_id: &str, ip: &str, agent: &str) -> CreateSession {
        CreateSession {
            user_id: user_id.to_string(),
            expires_at: Utc::now() + Duration::hours(24),
            ip_address: Some(ip.to_string()),
            user_agent: Some(agent.to_string()),
            impersonated_by: None,
            active_organization_id: None,
        }
    }

    /// Builds an in-memory context holding one user (`test@example.com`) with one session.
    async fn create_test_context_with_user() -> (AuthContext, User, Session) {
        let database = Arc::new(MemoryDatabaseAdapter::new());
        let config = Arc::new(AuthConfig::new("test-secret-key-at-least-32-chars-long"));
        let ctx = AuthContext::new(config.clone(), database.clone());

        let user = database
            .create_user(CreateUser::new().with_email("test@example.com").with_name("Test User"))
            .await
            .unwrap();
        let session = database
            .create_session(new_session(&user.id, "127.0.0.1", "test-agent"))
            .await
            .unwrap();

        (ctx, user, session)
    }

    /// Stores an additional session (from `192.168.1.1`) for `user_id`.
    async fn add_extra_session(ctx: &AuthContext, user_id: &str) -> Session {
        ctx.database
            .create_session(new_session(user_id, "192.168.1.1", "another-agent"))
            .await
            .unwrap()
    }

    /// Decodes a response body as UTF-8 text.
    fn body_text(response: AuthResponse) -> String {
        String::from_utf8(response.body).unwrap()
    }

    /// Builds a request, adding an `authorization: Bearer <token>` header when a token is given.
    fn create_auth_request(method: HttpMethod, path: &str, token: Option<&str>, body: Option<Vec<u8>>) -> AuthRequest {
        let headers = token
            .map(|token| ("authorization".to_string(), format!("Bearer {}", token)))
            .into_iter()
            .collect();

        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body,
            query: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_get_session_success() {
        let (ctx, _user, session) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();
        let req = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);

        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let data: GetSessionResponse = serde_json::from_str(&body_text(response)).unwrap();
        assert_eq!(data.session.token, session.token);
        assert_eq!(data.user.email, Some("test@example.com".to_string()));
    }

    #[tokio::test]
    async fn test_get_session_unauthorized() {
        let (ctx, _user, _session) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();
        let req = create_auth_request(HttpMethod::Get, "/get-session", None, None);

        let response = plugin.handle_get_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 401);
    }

    #[tokio::test]
    async fn test_sign_out_success() {
        let (ctx, _user, session) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();
        let req = create_auth_request(HttpMethod::Post, "/sign-out", Some(&session.token), Some(b"{}".to_vec()));

        let response = plugin.handle_sign_out(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let data: SignOutResponse = serde_json::from_str(&body_text(response)).unwrap();
        assert!(data.success);

        let remaining = ctx.database.get_session(&session.token).await.unwrap();
        assert!(remaining.is_none());
    }

    #[tokio::test]
    async fn test_list_sessions_success() {
        let (ctx, user, session) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();
        add_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(HttpMethod::Get, "/list-sessions", Some(&session.token), None);
        let response = plugin.handle_list_sessions(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let sessions: Vec<Session> = serde_json::from_str(&body_text(response)).unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_revoke_session_success() {
        let (ctx, user, session) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();
        let target = add_extra_session(&ctx, &user.id).await;

        let payload = serde_json::json!({ "token": target.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session.token),
            Some(payload.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let data: StatusResponse = serde_json::from_str(&body_text(response)).unwrap();
        assert!(data.status);

        let revoked = ctx.database.get_session(&target.token).await.unwrap();
        assert!(revoked.is_none());

        let kept = ctx.database.get_session(&session.token).await.unwrap();
        assert!(kept.is_some());
    }

    #[tokio::test]
    async fn test_revoke_session_forbidden_different_user() {
        let (ctx, _user1, session1) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();

        let other_user = ctx
            .database
            .create_user(CreateUser::new().with_email("user2@example.com").with_name("User Two"))
            .await
            .unwrap();
        let foreign_session = add_extra_session(&ctx, &other_user.id).await;

        let payload = serde_json::json!({ "token": foreign_session.token });
        let req = create_auth_request(
            HttpMethod::Post,
            "/revoke-session",
            Some(&session1.token),
            Some(payload.to_string().into_bytes()),
        );

        let response = plugin.handle_revoke_session(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 403);
    }

    #[tokio::test]
    async fn test_revoke_sessions_success() {
        let (ctx, user, session1) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();
        add_extra_session(&ctx, &user.id).await;

        let req = create_auth_request(HttpMethod::Post, "/revoke-sessions", Some(&session1.token), Some(b"{}".to_vec()));
        let response = plugin.handle_revoke_sessions(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let data: StatusResponse = serde_json::from_str(&body_text(response)).unwrap();
        assert!(data.status);

        let leftover = ctx.database.get_user_sessions(&user.id).await.unwrap();
        assert_eq!(leftover.len(), 0);
    }

    #[tokio::test]
    async fn test_plugin_routes() {
        let routes = SessionManagementPlugin::new().routes();
        assert_eq!(routes.len(), 5);

        let expected = [
            ("/get-session", HttpMethod::Get),
            ("/sign-out", HttpMethod::Post),
            ("/list-sessions", HttpMethod::Get),
            ("/revoke-session", HttpMethod::Post),
            ("/revoke-sessions", HttpMethod::Post),
        ];
        for (path, method) in expected.iter() {
            assert!(routes.iter().any(|r| r.path == *path && r.method == *method));
        }
    }

    #[tokio::test]
    async fn test_plugin_on_request_routing() {
        let (ctx, _user, session) = create_test_context_with_user().await;
        let plugin = SessionManagementPlugin::new();

        let known = create_auth_request(HttpMethod::Get, "/get-session", Some(&session.token), None);
        let handled = plugin.on_request(&known, &ctx).await.unwrap();
        assert!(handled.is_some());
        assert_eq!(handled.unwrap().status, 200);

        let unknown = create_auth_request(HttpMethod::Get, "/invalid-route", Some(&session.token), None);
        let ignored = plugin.on_request(&unknown, &ctx).await.unwrap();
        assert!(ignored.is_none());
    }

    #[tokio::test]
    async fn test_configuration() {
        let plugin = SessionManagementPlugin::new()
            .enable_session_listing(false)
            .enable_session_revocation(false)
            .require_authentication(false);

        assert!(!plugin.config.enable_session_listing);
        assert!(!plugin.config.enable_session_revocation);
        assert!(!plugin.config.require_authentication);

        let (ctx, _user, session) = create_test_context_with_user().await;

        let list_req = create_auth_request(HttpMethod::Get, "/list-sessions", Some(&session.token), None);
        assert!(plugin.on_request(&list_req, &ctx).await.unwrap().is_none());

        let revoke_req = create_auth_request(HttpMethod::Post, "/revoke-session", Some(&session.token), Some(b"{}".to_vec()));
        assert!(plugin.on_request(&revoke_req, &ctx).await.unwrap().is_none());
    }
}