use std::collections::HashMap;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use uuid::Uuid;

use crate::{AuthError, AuthState, Result, TokenSet, UserInfo};
13
14#[derive(Clone)]
16struct RateLimitEntry {
17 count: u32,
18 reset_at: SystemTime,
19}
20
21pub struct RateLimiter {
23 max_attempts: u32,
25 window: Duration,
27 entries: RwLock<HashMap<String, RateLimitEntry>>,
29}
30
31impl RateLimiter {
32 pub fn new(max_attempts: u32, window: Duration) -> Self {
34 Self {
35 max_attempts,
36 window,
37 entries: RwLock::new(HashMap::new()),
38 }
39 }
40
41 pub fn check(&self, key: &str) -> bool {
43 let mut entries = self.entries.write();
44 let now = SystemTime::now();
45
46 entries.retain(|_, entry| entry.reset_at > now);
48
49 let entry = entries.entry(key.to_string()).or_insert_with(|| {
50 RateLimitEntry {
51 count: 0,
52 reset_at: now + self.window,
53 }
54 });
55
56 if now >= entry.reset_at {
58 entry.count = 0;
59 entry.reset_at = now + self.window;
60 }
61
62 entry.count += 1;
63 let allowed = entry.count <= self.max_attempts;
64
65 if !allowed {
66 warn!("Rate limit exceeded for key: {}", key);
67 }
68
69 allowed
70 }
71
72 pub fn reset(&self, key: &str) {
74 self.entries.write().remove(key);
75 }
76
77 pub fn cleanup(&self) {
79 let now = SystemTime::now();
80 self.entries.write().retain(|_, entry| entry.reset_at > now);
81 }
82}
83
/// One authentication attempt / authenticated session with an identity provider.
///
/// Field order is kept as-is: it determines the serde serialization order and
/// the `Debug` output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSession {
    /// Session identifier; `AuthSession::new` sets it to a random (v4) UUID string.
    pub id: String,
    /// Pending OAuth state created by `new`; `complete_auth` clears it to `None`.
    pub auth_state: Option<AuthState>,
    /// Provider tokens; `None` until `update_tokens` or `complete_auth` stores them.
    pub tokens: Option<TokenSet>,
    /// User profile; `None` until `update_user_info` or `complete_auth` stores it.
    pub user_info: Option<UserInfo>,
    /// Provider name passed to `new`.
    pub provider: String,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Absolute expiry time checked by `is_expired`; reset by `extend`.
    pub expires_at: DateTime<Utc>,
    /// Updated by `touch` and by the `update_*`, `complete_auth` and
    /// `associate_vpn_session` mutators.
    pub last_activity: DateTime<Utc>,
    /// Linked VPN session id, set by `associate_vpn_session`.
    pub vpn_session_id: Option<String>,
    /// Presumably the client's IP address. NOTE(review): nothing in this file
    /// sets it; confirm where it is populated.
    pub client_ip: Option<String>,
    /// Free-form key/value data; `new` starts it empty.
    pub metadata: HashMap<String, String>,
}
110
111impl AuthSession {
112 pub fn new(provider: &str, lifetime: Duration) -> Self {
114 let now = Utc::now();
115 let auth_state = AuthState::new(Duration::from_secs(600)); Self {
118 id: Uuid::new_v4().to_string(),
119 auth_state: Some(auth_state),
120 tokens: None,
121 user_info: None,
122 provider: provider.to_string(),
123 created_at: now,
124 expires_at: now + chrono::Duration::from_std(lifetime)
125 .unwrap_or_else(|_| chrono::Duration::seconds(86400)), last_activity: now,
127 vpn_session_id: None,
128 client_ip: None,
129 metadata: HashMap::new(),
130 }
131 }
132
133 pub fn is_expired(&self) -> bool {
135 Utc::now() > self.expires_at
136 }
137
138 pub fn is_authenticated(&self) -> bool {
140 self.tokens.is_some() && self.user_info.is_some()
141 }
142
143 pub fn needs_token_refresh(&self) -> bool {
145 if let Some(tokens) = &self.tokens {
146 tokens.expires_within(chrono::Duration::minutes(5))
147 } else {
148 false
149 }
150 }
151
152 pub fn update_tokens(&mut self, tokens: TokenSet) {
154 self.tokens = Some(tokens);
155 self.last_activity = Utc::now();
156 }
157
158 pub fn update_user_info(&mut self, user_info: UserInfo) {
160 self.user_info = Some(user_info);
161 self.last_activity = Utc::now();
162 }
163
164 pub fn complete_auth(&mut self, tokens: TokenSet, user_info: UserInfo) {
166 self.tokens = Some(tokens);
167 self.user_info = Some(user_info);
168 self.auth_state = None; self.last_activity = Utc::now();
170 }
171
172 pub fn associate_vpn_session(&mut self, vpn_session_id: &str) {
174 self.vpn_session_id = Some(vpn_session_id.to_string());
175 self.last_activity = Utc::now();
176 }
177
178 pub fn extend(&mut self, duration: Duration) {
180 self.expires_at = Utc::now() + chrono::Duration::from_std(duration).unwrap();
181 }
182
183 pub fn touch(&mut self) {
185 self.last_activity = Utc::now();
186 }
187
188 pub fn duration(&self) -> chrono::Duration {
190 Utc::now() - self.created_at
191 }
192
193 pub fn idle_time(&self) -> chrono::Duration {
195 Utc::now() - self.last_activity
196 }
197
198 pub fn state(&self) -> Option<&str> {
200 self.auth_state.as_ref().map(|s| s.state.as_str())
201 }
202
203 pub fn email(&self) -> Option<&str> {
205 self.user_info.as_ref().and_then(|u| u.email.as_deref())
206 }
207
208 pub fn display_name(&self) -> Option<&str> {
210 self.user_info.as_ref().and_then(|u| u.name.as_deref())
211 }
212}
213
214pub struct AuthSessionManager {
216 sessions: RwLock<HashMap<String, AuthSession>>,
218 sessions_by_state: RwLock<HashMap<String, String>>,
220 default_lifetime: Duration,
222 max_sessions_per_user: usize,
224 lookup_rate_limiter: RateLimiter,
226}
227
228impl AuthSessionManager {
229 pub fn new(default_lifetime: Duration, max_sessions_per_user: usize) -> Self {
231 Self {
232 sessions: RwLock::new(HashMap::new()),
233 sessions_by_state: RwLock::new(HashMap::new()),
234 default_lifetime,
235 max_sessions_per_user,
236 lookup_rate_limiter: RateLimiter::new(100, Duration::from_secs(60)), }
238 }
239
240 pub fn create_session(&self, provider: &str) -> AuthSession {
242 let session = AuthSession::new(provider, self.default_lifetime);
243
244 let mut sessions = self.sessions.write();
246 let mut by_state = self.sessions_by_state.write();
247
248 if let Some(state) = session.state() {
249 by_state.insert(state.to_string(), session.id.clone());
250 }
251 sessions.insert(session.id.clone(), session.clone());
252
253 session
254 }
255
256 pub fn get_session(&self, id: &str, client_ip: Option<&str>) -> Option<AuthSession> {
258 let rate_limit_key = client_ip.unwrap_or(id);
260 if !self.lookup_rate_limiter.check(rate_limit_key) {
261 warn!("Rate limit exceeded for session lookup: {}", rate_limit_key);
262 return None;
263 }
264
265 if Uuid::parse_str(id).map(|u| u.get_version() != Some(uuid::Version::Random))
267 .unwrap_or(true) {
268 warn!("Invalid session ID format: {}", id);
269 return None;
270 }
271
272 self.sessions.read().get(id).cloned()
273 }
274
275 pub fn get_session_by_state(&self, state: &str, client_ip: Option<&str>) -> Option<AuthSession> {
277 let rate_limit_key = client_ip.unwrap_or(state);
279 if !self.lookup_rate_limiter.check(rate_limit_key) {
280 warn!("Rate limit exceeded for state lookup: {}", rate_limit_key);
281 return None;
282 }
283
284 let session_id = self.sessions_by_state.read().get(state)?.clone();
285 self.get_session(&session_id, client_ip)
286 }
287
288 pub fn update_session(&self, session: &AuthSession) -> Result<()> {
290 let mut sessions = self.sessions.write();
291 if sessions.contains_key(&session.id) {
292 sessions.insert(session.id.clone(), session.clone());
293 Ok(())
294 } else {
295 Err(AuthError::SessionNotFound)
296 }
297 }
298
299 pub fn remove_session(&self, id: &str) -> Option<AuthSession> {
301 let mut sessions = self.sessions.write();
302 let mut by_state = self.sessions_by_state.write();
303
304 if let Some(session) = sessions.remove(id) {
305 if let Some(state) = session.state() {
306 by_state.remove(state);
307 }
308 Some(session)
309 } else {
310 None
311 }
312 }
313
314 pub fn get_user_sessions(&self, email: &str, client_ip: Option<&str>) -> Vec<AuthSession> {
316 let rate_limit_key = client_ip.unwrap_or(email);
318 if !self.lookup_rate_limiter.check(rate_limit_key) {
319 warn!("Rate limit exceeded for user session lookup: {}", rate_limit_key);
320 return Vec::new();
321 }
322
323 self.sessions
324 .read()
325 .values()
326 .filter(|s| s.email() == Some(email))
327 .cloned()
328 .collect()
329 }
330
331 pub fn remove_user_sessions(&self, email: &str) -> usize {
333 let mut sessions = self.sessions.write();
334 let mut by_state = self.sessions_by_state.write();
335
336 let to_remove: Vec<_> = sessions
337 .iter()
338 .filter(|(_, s)| s.email() == Some(email))
339 .map(|(id, s)| (id.clone(), s.state().map(String::from)))
340 .collect();
341
342 for (id, state) in &to_remove {
343 sessions.remove(id);
344 if let Some(s) = state {
345 by_state.remove(s);
346 }
347 }
348
349 to_remove.len()
350 }
351
352 pub fn cleanup_expired(&self) -> usize {
354 let mut sessions = self.sessions.write();
355 let mut by_state = self.sessions_by_state.write();
356
357 let before = sessions.len();
358
359 let expired: Vec<_> = sessions
360 .iter()
361 .filter(|(_, s)| s.is_expired())
362 .map(|(id, s)| (id.clone(), s.state().map(String::from)))
363 .collect();
364
365 for (id, state) in &expired {
366 sessions.remove(id);
367 if let Some(s) = state {
368 by_state.remove(s);
369 }
370 }
371
372 before - sessions.len()
373 }
374
375 pub fn session_count(&self) -> usize {
377 self.sessions.read().len()
378 }
379
380 pub fn active_sessions(&self) -> Vec<AuthSession> {
382 self.sessions
383 .read()
384 .values()
385 .filter(|s| !s.is_expired() && s.is_authenticated())
386 .cloned()
387 .collect()
388 }
389}
390
391impl Default for AuthSessionManager {
392 fn default() -> Self {
393 Self::new(Duration::from_secs(86400), 5) }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Bearer token set valid for one hour.
    fn sample_tokens() -> TokenSet {
        TokenSet {
            access_token: "test-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            id_token: None,
            expires_at: Utc::now() + chrono::Duration::hours(1),
            token_type: "Bearer".to_string(),
            scopes: vec![],
        }
    }

    /// Verified Google user with email `user@example.com`.
    fn sample_user_info() -> UserInfo {
        UserInfo {
            sub: "user123".to_string(),
            email: Some("user@example.com".to_string()),
            email_verified: true,
            name: Some("Test User".to_string()),
            given_name: None,
            family_name: None,
            picture: None,
            groups: vec![],
            provider: "google".to_string(),
        }
    }

    #[test]
    fn test_auth_session() {
        let session = AuthSession::new("google", Duration::from_secs(3600));

        assert!(!session.is_expired());
        assert!(!session.is_authenticated());
        assert!(session.state().is_some());
    }

    #[test]
    fn test_session_manager() {
        let manager = AuthSessionManager::default();

        let session = manager.create_session("google");
        let state = session.state().unwrap().to_string();

        assert!(manager.get_session(&session.id, None).is_some());
        assert!(manager.get_session_by_state(&state, None).is_some());

        manager.remove_session(&session.id);
        assert!(manager.get_session(&session.id, None).is_none());
    }

    #[test]
    fn test_session_lifecycle() {
        let manager = AuthSessionManager::default();
        let mut session = manager.create_session("google");

        session.complete_auth(sample_tokens(), sample_user_info());

        assert!(session.is_authenticated());
        assert!(session.auth_state.is_none());
        assert_eq!(session.email(), Some("user@example.com"));
    }

    #[test]
    fn test_update_session_rejects_unknown_id() {
        let manager = AuthSessionManager::default();
        let stray = AuthSession::new("google", Duration::from_secs(60));

        assert!(manager.update_session(&stray).is_err());
    }

    #[test]
    fn test_remove_user_sessions() {
        let manager = AuthSessionManager::default();
        let mut session = manager.create_session("google");
        session.complete_auth(sample_tokens(), sample_user_info());
        assert!(manager.update_session(&session).is_ok());
        manager.create_session("google");

        assert_eq!(manager.get_user_sessions("user@example.com", None).len(), 1);
        assert_eq!(manager.remove_user_sessions("user@example.com"), 1);
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_extend_moves_expiry_forward() {
        let mut session = AuthSession::new("google", Duration::from_secs(1));
        session.extend(Duration::from_secs(3600));

        assert!(!session.is_expired());
        assert!(session.expires_at > Utc::now() + chrono::Duration::minutes(59));
    }

    #[test]
    fn test_rate_limiter_limits_per_key() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));

        assert!(limiter.check("k"));
        assert!(limiter.check("k"));
        assert!(!limiter.check("k"));
        // Other keys have their own budget.
        assert!(limiter.check("other"));

        limiter.reset("k");
        assert!(limiter.check("k"));
    }
}