1use crate::error::{AuthError, Result};
4use crate::models::Session;
5use chrono::{Duration, Utc};
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12pub struct SessionManager {
13 sessions: Arc<RwLock<HashMap<Uuid, Session>>>,
14 config: SessionConfig,
15}
16
17#[derive(Clone)]
18pub struct SessionConfig {
19 pub idle_timeout: Duration,
20 pub absolute_timeout: Duration,
21 pub max_concurrent_sessions: u32,
22 pub device_binding: bool,
23 pub ip_binding: bool,
24}
25
26impl SessionManager {
27 pub fn new(config: SessionConfig) -> Self {
28 Self {
29 sessions: Arc::new(RwLock::new(HashMap::new())),
30 config,
31 }
32 }
33
34 pub async fn create_session(
35 &self,
36 user_id: Uuid,
37 access_token: String,
38 refresh_token: String,
39 access_token_ttl: Duration,
40 refresh_token_ttl: Duration,
41 device_id: Option<String>,
42 ip_address: Option<IpAddr>,
43 user_agent: Option<String>,
44 scopes: Vec<String>,
45 ) -> Result<Session> {
46 self.enforce_concurrent_limit(&user_id).await?;
48
49 let now = Utc::now();
50 let session = Session {
51 id: Uuid::new_v4(),
52 user_id,
53 access_token,
54 refresh_token,
55 token_type: "Bearer".to_string(),
56 expires_at: now + access_token_ttl,
57 refresh_expires_at: now + refresh_token_ttl,
58 device_id,
59 ip_address,
60 user_agent,
61 created_at: now,
62 last_active_at: now,
63 scopes,
64 };
65
66 let mut sessions = self.sessions.write().await;
67 sessions.insert(session.id, session.clone());
68
69 tracing::info!(
70 session_id = %session.id,
71 user_id = %user_id,
72 "Created new session"
73 );
74
75 Ok(session)
76 }
77
78 async fn enforce_concurrent_limit(&self, user_id: &Uuid) -> Result<()> {
79 let sessions = self.sessions.read().await;
80
81 let user_sessions: Vec<_> = sessions
82 .values()
83 .filter(|s| s.user_id == *user_id)
84 .collect();
85
86 if user_sessions.len() >= self.config.max_concurrent_sessions as usize {
87 if let Some(oldest) = user_sessions
89 .iter()
90 .min_by_key(|s| s.created_at)
91 {
92 let oldest_id = oldest.id;
93 drop(sessions); self.delete_session(&oldest_id).await?;
95 }
96 }
97
98 Ok(())
99 } pub async fn get_session(&self, session_id: &Uuid) -> Result<Session> {
100 let sessions = self.sessions.read().await;
101 sessions
102 .get(session_id)
103 .cloned()
104 .ok_or(AuthError::SessionNotFound)
105 }
106
107 pub async fn validate_session(
108 &self,
109 session_id: &Uuid,
110 ip_address: Option<IpAddr>,
111 device_id: Option<&str>,
112 ) -> Result<Session> {
113 let sessions = self.sessions.read().await;
114 let session = sessions
115 .get(session_id)
116 .ok_or(AuthError::SessionNotFound)?;
117
118 let now = Utc::now();
119
120 if now > session.expires_at {
122 drop(sessions);
123 self.delete_session(session_id).await?;
124 return Err(AuthError::SessionExpired);
125 }
126
127 let idle_time = now - session.last_active_at;
129 if idle_time > self.config.idle_timeout {
130 drop(sessions);
131 self.delete_session(session_id).await?;
132 return Err(AuthError::SessionExpired);
133 }
134
135 let session_age = now - session.created_at;
137 if session_age > self.config.absolute_timeout {
138 drop(sessions);
139 self.delete_session(session_id).await?;
140 return Err(AuthError::SessionExpired);
141 }
142
143 if self.config.device_binding {
145 if let (Some(session_device), Some(req_device)) = (&session.device_id, device_id) {
146 if session_device != req_device {
147 return Err(AuthError::InvalidToken("Device mismatch".to_string()));
148 }
149 }
150 }
151
152 if self.config.ip_binding {
154 if let (Some(session_ip), Some(req_ip)) = (session.ip_address, ip_address) {
155 if session_ip != req_ip {
156 return Err(AuthError::InvalidToken("IP address mismatch".to_string()));
157 }
158 }
159 }
160
161 Ok(session.clone())
162 }
163
164 pub async fn update_activity(&self, session_id: &Uuid) -> Result<()> {
165 let mut sessions = self.sessions.write().await;
166
167 if let Some(session) = sessions.get_mut(session_id) {
168 session.last_active_at = Utc::now();
169 Ok(())
170 } else {
171 Err(AuthError::SessionNotFound)
172 }
173 }
174
175 pub async fn refresh_session(
176 &self,
177 session_id: &Uuid,
178 new_access_token: String,
179 access_token_ttl: Duration,
180 ) -> Result<Session> {
181 let mut sessions = self.sessions.write().await;
182
183 let session = sessions
184 .get_mut(session_id)
185 .ok_or(AuthError::SessionNotFound)?;
186
187 if Utc::now() > session.refresh_expires_at {
189 return Err(AuthError::SessionExpired);
190 }
191
192 session.access_token = new_access_token;
193 session.expires_at = Utc::now() + access_token_ttl;
194 session.last_active_at = Utc::now();
195
196 Ok(session.clone())
197 }
198
199 pub async fn delete_session(&self, session_id: &Uuid) -> Result<()> {
200 let mut sessions = self.sessions.write().await;
201 sessions.remove(session_id);
202
203 tracing::info!(session_id = %session_id, "Deleted session");
204 Ok(())
205 }
206
207 pub async fn delete_user_sessions(&self, user_id: &Uuid) -> Result<usize> {
208 let mut sessions = self.sessions.write().await;
209
210 let session_ids: Vec<_> = sessions
211 .iter()
212 .filter(|(_, s)| s.user_id == *user_id)
213 .map(|(id, _)| *id)
214 .collect();
215
216 let count = session_ids.len();
217
218 for id in session_ids {
219 sessions.remove(&id);
220 }
221
222 tracing::info!(user_id = %user_id, count, "Deleted user sessions");
223 Ok(count)
224 }
225
226 pub async fn list_user_sessions(&self, user_id: &Uuid) -> Vec<Session> {
227 let sessions = self.sessions.read().await;
228 sessions
229 .values()
230 .filter(|s| s.user_id == *user_id)
231 .cloned()
232 .collect()
233 }
234
235 pub async fn cleanup_expired_sessions(&self) -> Result<usize> {
236 let mut sessions = self.sessions.write().await;
237 let now = Utc::now();
238
239 let initial_count = sessions.len();
240
241 sessions.retain(|_, session| {
242 let expired = now > session.expires_at
243 || (now - session.last_active_at) > self.config.idle_timeout;
244 !expired
245 });
246
247 let removed = initial_count - sessions.len();
248
249 if removed > 0 {
250 tracing::info!("Cleaned up {} expired sessions", removed);
251 }
252
253 Ok(removed)
254 }
255
256 pub async fn get_stats(&self) -> SessionStats {
257 let sessions = self.sessions.read().await;
258
259 let total = sessions.len();
260 let unique_users = sessions
261 .values()
262 .map(|s| s.user_id)
263 .collect::<std::collections::HashSet<_>>()
264 .len();
265
266 let now = Utc::now();
267 let active_last_hour = sessions
268 .values()
269 .filter(|s| (now - s.last_active_at) < Duration::hours(1))
270 .count();
271
272 SessionStats {
273 total_sessions: total,
274 unique_users,
275 active_last_hour,
276 }
277 }
278}
279
280#[derive(Debug, Clone, serde::Serialize)]
281pub struct SessionStats {
282 pub total_sessions: usize,
283 pub unique_users: usize,
284 pub active_last_hour: usize,
285}