pjson_rs/domain/services/connection_manager.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::RwLock;
5
6use crate::domain::DomainError;
7use crate::domain::value_objects::{SessionId, StreamId};
8
9/// Connection state tracking
10#[derive(Debug, Clone)]
11pub struct ConnectionState {
12    pub session_id: SessionId,
13    pub stream_id: Option<StreamId>,
14    pub connected_at: Instant,
15    pub last_activity: Instant,
16    pub bytes_sent: usize,
17    pub bytes_received: usize,
18    pub is_active: bool,
19}
20
21/// Connection lifecycle events
22#[derive(Debug, Clone)]
23pub enum ConnectionEvent {
24    Connected(SessionId),
25    Disconnected(SessionId),
26    Timeout(SessionId),
27    Error(SessionId, String),
28}
29
30/// Connection manager service
31pub struct ConnectionManager {
32    connections: Arc<RwLock<HashMap<SessionId, ConnectionState>>>,
33    timeout_duration: Duration,
34    max_connections: usize,
35}
36
37impl ConnectionManager {
38    pub fn new(timeout_duration: Duration, max_connections: usize) -> Self {
39        Self {
40            connections: Arc::new(RwLock::new(HashMap::new())),
41            timeout_duration,
42            max_connections,
43        }
44    }
45
46    /// Register a new connection
47    pub async fn register_connection(&self, session_id: SessionId) -> Result<(), DomainError> {
48        let mut connections = self.connections.write().await;
49
50        if connections.len() >= self.max_connections {
51            return Err(DomainError::ValidationError(
52                "Maximum connections reached".to_string(),
53            ));
54        }
55
56        let state = ConnectionState {
57            session_id,
58            stream_id: None,
59            connected_at: Instant::now(),
60            last_activity: Instant::now(),
61            bytes_sent: 0,
62            bytes_received: 0,
63            is_active: true,
64        };
65
66        connections.insert(session_id, state);
67        Ok(())
68    }
69
70    /// Update connection activity
71    pub async fn update_activity(&self, session_id: &SessionId) -> Result<(), DomainError> {
72        let mut connections = self.connections.write().await;
73
74        match connections.get_mut(session_id) {
75            Some(state) => {
76                state.last_activity = Instant::now();
77                Ok(())
78            }
79            None => Err(DomainError::ValidationError(format!(
80                "Connection not found: {session_id}"
81            ))),
82        }
83    }
84
85    /// Update connection metrics
86    pub async fn update_metrics(
87        &self,
88        session_id: &SessionId,
89        bytes_sent: usize,
90        bytes_received: usize,
91    ) -> Result<(), DomainError> {
92        let mut connections = self.connections.write().await;
93
94        match connections.get_mut(session_id) {
95            Some(state) => {
96                state.bytes_sent += bytes_sent;
97                state.bytes_received += bytes_received;
98                state.last_activity = Instant::now();
99                Ok(())
100            }
101            None => Err(DomainError::ValidationError(format!(
102                "Connection not found: {session_id}"
103            ))),
104        }
105    }
106
107    /// Associate stream with connection
108    pub async fn set_stream(
109        &self,
110        session_id: &SessionId,
111        stream_id: StreamId,
112    ) -> Result<(), DomainError> {
113        let mut connections = self.connections.write().await;
114
115        match connections.get_mut(session_id) {
116            Some(state) => {
117                state.stream_id = Some(stream_id);
118                state.last_activity = Instant::now();
119                Ok(())
120            }
121            None => Err(DomainError::ValidationError(format!(
122                "Connection not found: {session_id}"
123            ))),
124        }
125    }
126
127    /// Close connection
128    pub async fn close_connection(&self, session_id: &SessionId) -> Result<(), DomainError> {
129        let mut connections = self.connections.write().await;
130
131        match connections.get_mut(session_id) {
132            Some(state) => {
133                state.is_active = false;
134                Ok(())
135            }
136            None => Err(DomainError::ValidationError(format!(
137                "Connection not found: {session_id}"
138            ))),
139        }
140    }
141
142    /// Remove connection completely
143    pub async fn remove_connection(&self, session_id: &SessionId) -> Result<(), DomainError> {
144        let mut connections = self.connections.write().await;
145
146        match connections.remove(session_id) {
147            Some(_) => Ok(()),
148            None => Err(DomainError::ValidationError(format!(
149                "Connection not found: {session_id}"
150            ))),
151        }
152    }
153
154    /// Get connection state
155    pub async fn get_connection(&self, session_id: &SessionId) -> Option<ConnectionState> {
156        let connections = self.connections.read().await;
157        connections.get(session_id).cloned()
158    }
159
160    /// Get all active connections
161    pub async fn get_active_connections(&self) -> Vec<ConnectionState> {
162        let connections = self.connections.read().await;
163        connections
164            .values()
165            .filter(|state| state.is_active)
166            .cloned()
167            .collect()
168    }
169
170    /// Check for timed out connections
171    pub async fn check_timeouts(&self) -> Vec<SessionId> {
172        let now = Instant::now();
173        let connections = self.connections.read().await;
174
175        connections
176            .values()
177            .filter(|state| {
178                state.is_active && now.duration_since(state.last_activity) > self.timeout_duration
179            })
180            .map(|state| state.session_id)
181            .collect()
182    }
183
184    /// Process timeout check iteration (to be called by infrastructure layer)
185    pub async fn process_timeouts(&self) {
186        let timed_out = self.check_timeouts().await;
187        for session_id in timed_out {
188            if let Err(e) = self.close_connection(&session_id).await {
189                tracing::warn!("Failed to close timed out connection: {e}");
190            }
191        }
192    }
193
194    /// Get connection statistics
195    pub async fn get_statistics(&self) -> ConnectionStatistics {
196        let connections = self.connections.read().await;
197
198        let active_count = connections.values().filter(|s| s.is_active).count();
199        let total_bytes_sent: usize = connections.values().map(|s| s.bytes_sent).sum();
200        let total_bytes_received: usize = connections.values().map(|s| s.bytes_received).sum();
201
202        ConnectionStatistics {
203            total_connections: connections.len(),
204            active_connections: active_count,
205            inactive_connections: connections.len() - active_count,
206            total_bytes_sent,
207            total_bytes_received,
208        }
209    }
210}
211
/// Aggregate view over every connection currently stored by the manager.
#[derive(Clone, Debug)]
pub struct ConnectionStatistics {
    /// Number of stored connections, active or closed.
    pub total_connections: usize,
    /// Number of stored connections that are still active.
    pub active_connections: usize,
    /// Number of stored connections that have been closed but not removed.
    pub inactive_connections: usize,
    /// Sum of bytes sent across all stored connections.
    pub total_bytes_sent: usize,
    /// Sum of bytes received across all stored connections.
    pub total_bytes_received: usize,
}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_connection_lifecycle() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();

        // Register connection
        assert!(manager.register_connection(session_id).await.is_ok());

        // Get connection state
        let state = manager.get_connection(&session_id).await;
        assert!(state.is_some());
        assert!(state.unwrap().is_active);

        // Update activity
        assert!(manager.update_activity(&session_id).await.is_ok());

        // Update metrics
        assert!(manager.update_metrics(&session_id, 100, 50).await.is_ok());

        // Close connection
        assert!(manager.close_connection(&session_id).await.is_ok());

        // Verify closed
        let state = manager.get_connection(&session_id).await;
        assert!(state.is_some());
        assert!(!state.unwrap().is_active);

        // Remove connection
        assert!(manager.remove_connection(&session_id).await.is_ok());

        // Verify removed
        let state = manager.get_connection(&session_id).await;
        assert!(state.is_none());
    }

    #[tokio::test]
    async fn test_max_connections() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 2);

        // Register max connections
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        let session3 = SessionId::new();

        assert!(manager.register_connection(session1).await.is_ok());
        assert!(manager.register_connection(session2).await.is_ok());

        // Should fail - max reached
        assert!(manager.register_connection(session3).await.is_err());
    }

    #[tokio::test]
    async fn test_timeout_detection() {
        let manager = ConnectionManager::new(Duration::from_millis(100), 10);
        let session_id = SessionId::new();

        assert!(manager.register_connection(session_id).await.is_ok());

        // Wait for timeout
        tokio::time::sleep(Duration::from_millis(150)).await;

        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], session_id);
    }

    #[tokio::test]
    async fn test_unknown_session_is_rejected() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 10);
        let unknown = SessionId::new();

        // Every mutating operation must report a missing session as an error.
        assert!(manager.update_activity(&unknown).await.is_err());
        assert!(manager.update_metrics(&unknown, 1, 1).await.is_err());
        assert!(manager.close_connection(&unknown).await.is_err());
        assert!(manager.remove_connection(&unknown).await.is_err());
        assert!(manager.get_connection(&unknown).await.is_none());
    }

    #[tokio::test]
    async fn test_metrics_accumulate_and_statistics() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 10);
        let first = SessionId::new();
        let second = SessionId::new();

        assert!(manager.register_connection(first).await.is_ok());
        assert!(manager.register_connection(second).await.is_ok());

        // Metrics for one session accumulate across calls.
        assert!(manager.update_metrics(&first, 60, 30).await.is_ok());
        assert!(manager.update_metrics(&first, 40, 20).await.is_ok());
        assert!(manager.update_metrics(&second, 10, 5).await.is_ok());

        let state = manager
            .get_connection(&first)
            .await
            .expect("first session should be registered");
        assert_eq!(state.bytes_sent, 100);
        assert_eq!(state.bytes_received, 50);

        assert!(manager.close_connection(&second).await.is_ok());

        let stats = manager.get_statistics().await;
        assert_eq!(stats.total_connections, 2);
        assert_eq!(stats.active_connections, 1);
        assert_eq!(stats.inactive_connections, 1);
        assert_eq!(stats.total_bytes_sent, 110);
        assert_eq!(stats.total_bytes_received, 55);

        let active = manager.get_active_connections().await;
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].session_id, first);
    }

    #[tokio::test]
    async fn test_process_timeouts_closes_only_idle_connections() {
        let manager = ConnectionManager::new(Duration::from_millis(100), 10);
        let idle = SessionId::new();
        let busy = SessionId::new();

        assert!(manager.register_connection(idle).await.is_ok());
        assert!(manager.register_connection(busy).await.is_ok());

        tokio::time::sleep(Duration::from_millis(150)).await;

        // Refreshing activity must keep `busy` alive.
        assert!(manager.update_activity(&busy).await.is_ok());
        manager.process_timeouts().await;

        let idle_state = manager
            .get_connection(&idle)
            .await
            .expect("timed-out connections stay stored until removed");
        assert!(!idle_state.is_active);

        let busy_state = manager
            .get_connection(&busy)
            .await
            .expect("busy session should be registered");
        assert!(busy_state.is_active);

        // Closed connections are no longer reported as timed out.
        assert!(manager.check_timeouts().await.is_empty());
    }
}