avl_auth/
session.rs

//! Session management backed by an in-process `HashMap` guarded by a `tokio` `RwLock` (not distributed or persisted).
2
3use crate::error::{AuthError, Result};
4use crate::models::Session;
5use chrono::{Duration, Utc};
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12pub struct SessionManager {
13    sessions: Arc<RwLock<HashMap<Uuid, Session>>>,
14    config: SessionConfig,
15}
16
17#[derive(Clone)]
18pub struct SessionConfig {
19    pub idle_timeout: Duration,
20    pub absolute_timeout: Duration,
21    pub max_concurrent_sessions: u32,
22    pub device_binding: bool,
23    pub ip_binding: bool,
24}
25
26impl SessionManager {
27    pub fn new(config: SessionConfig) -> Self {
28        Self {
29            sessions: Arc::new(RwLock::new(HashMap::new())),
30            config,
31        }
32    }
33
34    pub async fn create_session(
35        &self,
36        user_id: Uuid,
37        access_token: String,
38        refresh_token: String,
39        access_token_ttl: Duration,
40        refresh_token_ttl: Duration,
41        device_id: Option<String>,
42        ip_address: Option<IpAddr>,
43        user_agent: Option<String>,
44        scopes: Vec<String>,
45    ) -> Result<Session> {
46        // Check concurrent session limit
47        self.enforce_concurrent_limit(&user_id).await?;
48
49        let now = Utc::now();
50        let session = Session {
51            id: Uuid::new_v4(),
52            user_id,
53            access_token,
54            refresh_token,
55            token_type: "Bearer".to_string(),
56            expires_at: now + access_token_ttl,
57            refresh_expires_at: now + refresh_token_ttl,
58            device_id,
59            ip_address,
60            user_agent,
61            created_at: now,
62            last_active_at: now,
63            scopes,
64        };
65
66        let mut sessions = self.sessions.write().await;
67        sessions.insert(session.id, session.clone());
68
69        tracing::info!(
70            session_id = %session.id,
71            user_id = %user_id,
72            "Created new session"
73        );
74
75        Ok(session)
76    }
77
78    async fn enforce_concurrent_limit(&self, user_id: &Uuid) -> Result<()> {
79        let sessions = self.sessions.read().await;
80
81        let user_sessions: Vec<_> = sessions
82            .values()
83            .filter(|s| s.user_id == *user_id)
84            .collect();
85
86        if user_sessions.len() >= self.config.max_concurrent_sessions as usize {
87            // Get oldest session ID
88            if let Some(oldest) = user_sessions
89                .iter()
90                .min_by_key(|s| s.created_at)
91            {
92                let oldest_id = oldest.id;
93                drop(sessions); // Release read lock
94                self.delete_session(&oldest_id).await?;
95            }
96        }
97
98        Ok(())
99    }    pub async fn get_session(&self, session_id: &Uuid) -> Result<Session> {
100        let sessions = self.sessions.read().await;
101        sessions
102            .get(session_id)
103            .cloned()
104            .ok_or(AuthError::SessionNotFound)
105    }
106
107    pub async fn validate_session(
108        &self,
109        session_id: &Uuid,
110        ip_address: Option<IpAddr>,
111        device_id: Option<&str>,
112    ) -> Result<Session> {
113        let sessions = self.sessions.read().await;
114        let session = sessions
115            .get(session_id)
116            .ok_or(AuthError::SessionNotFound)?;
117
118        let now = Utc::now();
119
120        // Check if session expired
121        if now > session.expires_at {
122            drop(sessions);
123            self.delete_session(session_id).await?;
124            return Err(AuthError::SessionExpired);
125        }
126
127        // Check idle timeout
128        let idle_time = now - session.last_active_at;
129        if idle_time > self.config.idle_timeout {
130            drop(sessions);
131            self.delete_session(session_id).await?;
132            return Err(AuthError::SessionExpired);
133        }
134
135        // Check absolute timeout
136        let session_age = now - session.created_at;
137        if session_age > self.config.absolute_timeout {
138            drop(sessions);
139            self.delete_session(session_id).await?;
140            return Err(AuthError::SessionExpired);
141        }
142
143        // Verify device binding
144        if self.config.device_binding {
145            if let (Some(session_device), Some(req_device)) = (&session.device_id, device_id) {
146                if session_device != req_device {
147                    return Err(AuthError::InvalidToken("Device mismatch".to_string()));
148                }
149            }
150        }
151
152        // Verify IP binding
153        if self.config.ip_binding {
154            if let (Some(session_ip), Some(req_ip)) = (session.ip_address, ip_address) {
155                if session_ip != req_ip {
156                    return Err(AuthError::InvalidToken("IP address mismatch".to_string()));
157                }
158            }
159        }
160
161        Ok(session.clone())
162    }
163
164    pub async fn update_activity(&self, session_id: &Uuid) -> Result<()> {
165        let mut sessions = self.sessions.write().await;
166
167        if let Some(session) = sessions.get_mut(session_id) {
168            session.last_active_at = Utc::now();
169            Ok(())
170        } else {
171            Err(AuthError::SessionNotFound)
172        }
173    }
174
175    pub async fn refresh_session(
176        &self,
177        session_id: &Uuid,
178        new_access_token: String,
179        access_token_ttl: Duration,
180    ) -> Result<Session> {
181        let mut sessions = self.sessions.write().await;
182
183        let session = sessions
184            .get_mut(session_id)
185            .ok_or(AuthError::SessionNotFound)?;
186
187        // Check if refresh token is still valid
188        if Utc::now() > session.refresh_expires_at {
189            return Err(AuthError::SessionExpired);
190        }
191
192        session.access_token = new_access_token;
193        session.expires_at = Utc::now() + access_token_ttl;
194        session.last_active_at = Utc::now();
195
196        Ok(session.clone())
197    }
198
199    pub async fn delete_session(&self, session_id: &Uuid) -> Result<()> {
200        let mut sessions = self.sessions.write().await;
201        sessions.remove(session_id);
202
203        tracing::info!(session_id = %session_id, "Deleted session");
204        Ok(())
205    }
206
207    pub async fn delete_user_sessions(&self, user_id: &Uuid) -> Result<usize> {
208        let mut sessions = self.sessions.write().await;
209
210        let session_ids: Vec<_> = sessions
211            .iter()
212            .filter(|(_, s)| s.user_id == *user_id)
213            .map(|(id, _)| *id)
214            .collect();
215
216        let count = session_ids.len();
217
218        for id in session_ids {
219            sessions.remove(&id);
220        }
221
222        tracing::info!(user_id = %user_id, count, "Deleted user sessions");
223        Ok(count)
224    }
225
226    pub async fn list_user_sessions(&self, user_id: &Uuid) -> Vec<Session> {
227        let sessions = self.sessions.read().await;
228        sessions
229            .values()
230            .filter(|s| s.user_id == *user_id)
231            .cloned()
232            .collect()
233    }
234
235    pub async fn cleanup_expired_sessions(&self) -> Result<usize> {
236        let mut sessions = self.sessions.write().await;
237        let now = Utc::now();
238
239        let initial_count = sessions.len();
240
241        sessions.retain(|_, session| {
242            let expired = now > session.expires_at
243                || (now - session.last_active_at) > self.config.idle_timeout;
244            !expired
245        });
246
247        let removed = initial_count - sessions.len();
248
249        if removed > 0 {
250            tracing::info!("Cleaned up {} expired sessions", removed);
251        }
252
253        Ok(removed)
254    }
255
256    pub async fn get_stats(&self) -> SessionStats {
257        let sessions = self.sessions.read().await;
258
259        let total = sessions.len();
260        let unique_users = sessions
261            .values()
262            .map(|s| s.user_id)
263            .collect::<std::collections::HashSet<_>>()
264            .len();
265
266        let now = Utc::now();
267        let active_last_hour = sessions
268            .values()
269            .filter(|s| (now - s.last_active_at) < Duration::hours(1))
270            .count();
271
272        SessionStats {
273            total_sessions: total,
274            unique_users,
275            active_last_hour,
276        }
277    }
278}
279
280#[derive(Debug, Clone, serde::Serialize)]
281pub struct SessionStats {
282    pub total_sessions: usize,
283    pub unique_users: usize,
284    pub active_last_hour: usize,
285}