use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::domain::DomainError;
use crate::domain::value_objects::{SessionId, StreamId};
/// Per-session record kept by [`ConnectionManager`]; callers receive clones of it.
#[derive(Clone, Debug)]
pub struct ConnectionState {
    /// Session this record belongs to; it is also the record's key in the manager's map.
    pub session_id: SessionId,
    /// Stream attached through `set_stream`; `None` until one is set.
    pub stream_id: Option<StreamId>,
    /// Moment the session was registered.
    pub connected_at: Instant,
    /// Most recent moment the session was registered or touched by an activity, metrics,
    /// or stream update; timeout detection measures idleness from here.
    pub last_activity: Instant,
    /// Running total of bytes reported as sent via `update_metrics`.
    pub bytes_sent: usize,
    /// Running total of bytes reported as received via `update_metrics`.
    pub bytes_received: usize,
    /// Cleared when the connection is closed (explicitly or by timeout processing);
    /// closed records remain in the map until removed.
    pub is_active: bool,
}
/// Lifecycle notifications for a single session.
///
/// NOTE(review): nothing visible in this module constructs or consumes these events;
/// presumably they are produced or consumed elsewhere — confirm against callers.
#[derive(Clone, Debug)]
pub enum ConnectionEvent {
    /// The session connected.
    Connected(SessionId),
    /// The session disconnected.
    Disconnected(SessionId),
    /// The session exceeded its idle timeout.
    Timeout(SessionId),
    /// The session hit an error, described by the accompanying message.
    Error(SessionId, String),
}
/// Registry of live and closed connections, keyed by session, guarded by an async
/// `RwLock` so it can be used from concurrent tasks.
///
/// Closed connections stay in the registry until `remove_connection` deletes them.
pub struct ConnectionManager {
    /// Session-to-state map; all reads and writes go through the async lock.
    connections: Arc<RwLock<HashMap<SessionId, ConnectionState>>>,
    /// Idle period after which an active connection is considered timed out.
    timeout_duration: Duration,
    /// Upper bound on the number of map entries accepted by `register_connection`.
    max_connections: usize,
}
impl ConnectionManager {
pub fn new(timeout_duration: Duration, max_connections: usize) -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
timeout_duration,
max_connections,
}
}
pub async fn register_connection(&self, session_id: SessionId) -> Result<(), DomainError> {
let mut connections = self.connections.write().await;
if connections.len() >= self.max_connections {
return Err(DomainError::ValidationError(
"Maximum connections reached".to_string(),
));
}
let state = ConnectionState {
session_id,
stream_id: None,
connected_at: Instant::now(),
last_activity: Instant::now(),
bytes_sent: 0,
bytes_received: 0,
is_active: true,
};
connections.insert(session_id, state);
Ok(())
}
pub async fn update_activity(&self, session_id: &SessionId) -> Result<(), DomainError> {
let mut connections = self.connections.write().await;
match connections.get_mut(session_id) {
Some(state) => {
state.last_activity = Instant::now();
Ok(())
}
None => Err(DomainError::ValidationError(format!(
"Connection not found: {session_id}"
))),
}
}
pub async fn update_metrics(
&self,
session_id: &SessionId,
bytes_sent: usize,
bytes_received: usize,
) -> Result<(), DomainError> {
let mut connections = self.connections.write().await;
match connections.get_mut(session_id) {
Some(state) => {
state.bytes_sent += bytes_sent;
state.bytes_received += bytes_received;
state.last_activity = Instant::now();
Ok(())
}
None => Err(DomainError::ValidationError(format!(
"Connection not found: {session_id}"
))),
}
}
pub async fn set_stream(
&self,
session_id: &SessionId,
stream_id: StreamId,
) -> Result<(), DomainError> {
let mut connections = self.connections.write().await;
match connections.get_mut(session_id) {
Some(state) => {
state.stream_id = Some(stream_id);
state.last_activity = Instant::now();
Ok(())
}
None => Err(DomainError::ValidationError(format!(
"Connection not found: {session_id}"
))),
}
}
pub async fn close_connection(&self, session_id: &SessionId) -> Result<(), DomainError> {
let mut connections = self.connections.write().await;
match connections.get_mut(session_id) {
Some(state) => {
state.is_active = false;
Ok(())
}
None => Err(DomainError::ValidationError(format!(
"Connection not found: {session_id}"
))),
}
}
pub async fn remove_connection(&self, session_id: &SessionId) -> Result<(), DomainError> {
let mut connections = self.connections.write().await;
match connections.remove(session_id) {
Some(_) => Ok(()),
None => Err(DomainError::ValidationError(format!(
"Connection not found: {session_id}"
))),
}
}
pub async fn get_connection(&self, session_id: &SessionId) -> Option<ConnectionState> {
let connections = self.connections.read().await;
connections.get(session_id).cloned()
}
pub async fn get_active_connections(&self) -> Vec<ConnectionState> {
let connections = self.connections.read().await;
connections
.values()
.filter(|state| state.is_active)
.cloned()
.collect()
}
pub async fn check_timeouts(&self) -> Vec<SessionId> {
let now = Instant::now();
let connections = self.connections.read().await;
connections
.values()
.filter(|state| {
state.is_active && now.duration_since(state.last_activity) > self.timeout_duration
})
.map(|state| state.session_id)
.collect()
}
pub async fn process_timeouts(&self) {
let timed_out = self.check_timeouts().await;
for session_id in timed_out {
if let Err(e) = self.close_connection(&session_id).await {
tracing::warn!("Failed to close timed out connection: {e}");
}
}
}
pub async fn get_statistics(&self) -> ConnectionStatistics {
let connections = self.connections.read().await;
let active_count = connections.values().filter(|s| s.is_active).count();
let total_bytes_sent: usize = connections.values().map(|s| s.bytes_sent).sum();
let total_bytes_received: usize = connections.values().map(|s| s.bytes_received).sum();
ConnectionStatistics {
total_connections: connections.len(),
active_connections: active_count,
inactive_connections: connections.len() - active_count,
total_bytes_sent,
total_bytes_received,
}
}
}
/// Point-in-time aggregate over every connection a manager tracks.
#[derive(Clone, Debug)]
pub struct ConnectionStatistics {
    /// Number of tracked entries, active plus closed-but-not-removed.
    pub total_connections: usize,
    /// Entries whose `is_active` flag is still set.
    pub active_connections: usize,
    /// Entries that have been closed but not yet removed.
    pub inactive_connections: usize,
    /// Sum of `bytes_sent` over all tracked entries.
    pub total_bytes_sent: usize,
    /// Sum of `bytes_received` over all tracked entries.
    pub total_bytes_received: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is the "Connection not found" validation error the manager
    /// returns for an unknown session.
    fn assert_not_found(result: Result<(), DomainError>) {
        match result {
            Err(DomainError::ValidationError(msg)) => {
                assert!(msg.contains("Connection not found"), "unexpected message: {msg}");
            }
            _ => panic!("Expected ValidationError"),
        }
    }

    #[tokio::test]
    async fn test_connection_lifecycle() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        assert!(manager.register_connection(session_id).await.is_ok());
        let state = manager
            .get_connection(&session_id)
            .await
            .expect("connection should be registered");
        assert!(state.is_active);
        assert!(manager.update_activity(&session_id).await.is_ok());
        assert!(manager.update_metrics(&session_id, 100, 50).await.is_ok());
        assert!(manager.close_connection(&session_id).await.is_ok());
        let state = manager
            .get_connection(&session_id)
            .await
            .expect("closed connection should still be tracked");
        assert!(!state.is_active);
        assert!(manager.remove_connection(&session_id).await.is_ok());
        assert!(manager.get_connection(&session_id).await.is_none());
    }

    #[tokio::test]
    async fn test_max_connections() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 2);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        let session3 = SessionId::new();
        assert!(manager.register_connection(session1).await.is_ok());
        assert!(manager.register_connection(session2).await.is_ok());
        assert!(manager.register_connection(session3).await.is_err());
    }

    #[tokio::test]
    async fn test_timeout_detection() {
        let manager = ConnectionManager::new(Duration::from_millis(100), 10);
        let session_id = SessionId::new();
        assert!(manager.register_connection(session_id).await.is_ok());
        tokio::time::sleep(Duration::from_millis(150)).await;
        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], session_id);
    }

    #[tokio::test]
    async fn test_set_stream_success() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        let stream_id = StreamId::new();
        manager.register_connection(session_id).await.unwrap();
        let state = manager.get_connection(&session_id).await.unwrap();
        assert!(state.stream_id.is_none());
        assert!(manager.set_stream(&session_id, stream_id).await.is_ok());
        let state = manager.get_connection(&session_id).await.unwrap();
        assert_eq!(state.stream_id, Some(stream_id));
    }

    #[tokio::test]
    async fn test_set_stream_connection_not_found() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        let stream_id = StreamId::new();
        assert_not_found(manager.set_stream(&session_id, stream_id).await);
    }

    #[tokio::test]
    async fn test_get_active_connections() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        let session3 = SessionId::new();
        manager.register_connection(session1).await.unwrap();
        manager.register_connection(session2).await.unwrap();
        manager.register_connection(session3).await.unwrap();
        let active = manager.get_active_connections().await;
        assert_eq!(active.len(), 3);
        manager.close_connection(&session2).await.unwrap();
        let active = manager.get_active_connections().await;
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|s| s.session_id != session2));
    }

    #[tokio::test]
    async fn test_get_active_connections_empty() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let active = manager.get_active_connections().await;
        assert!(active.is_empty());
    }

    #[tokio::test]
    async fn test_process_timeouts() {
        let manager = ConnectionManager::new(Duration::from_millis(50), 10);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        manager.register_connection(session1).await.unwrap();
        manager.register_connection(session2).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        manager.process_timeouts().await;
        let state1 = manager.get_connection(&session1).await.unwrap();
        let state2 = manager.get_connection(&session2).await.unwrap();
        assert!(!state1.is_active);
        assert!(!state2.is_active);
    }

    #[tokio::test]
    async fn test_process_timeouts_no_timeouts() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 10);
        let session_id = SessionId::new();
        manager.register_connection(session_id).await.unwrap();
        manager.process_timeouts().await;
        let state = manager.get_connection(&session_id).await.unwrap();
        assert!(state.is_active);
    }

    #[tokio::test]
    async fn test_get_statistics() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        manager.register_connection(session1).await.unwrap();
        manager.register_connection(session2).await.unwrap();
        manager.update_metrics(&session1, 100, 50).await.unwrap();
        manager.update_metrics(&session1, 200, 100).await.unwrap();
        manager.update_metrics(&session2, 50, 25).await.unwrap();
        manager.close_connection(&session2).await.unwrap();
        let stats = manager.get_statistics().await;
        assert_eq!(stats.total_connections, 2);
        assert_eq!(stats.active_connections, 1);
        assert_eq!(stats.inactive_connections, 1);
        assert_eq!(stats.total_bytes_sent, 350);
        assert_eq!(stats.total_bytes_received, 175);
    }

    #[tokio::test]
    async fn test_get_statistics_empty() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let stats = manager.get_statistics().await;
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.inactive_connections, 0);
        assert_eq!(stats.total_bytes_sent, 0);
        assert_eq!(stats.total_bytes_received, 0);
    }

    #[tokio::test]
    async fn test_update_activity_not_found() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        assert_not_found(manager.update_activity(&session_id).await);
    }

    #[tokio::test]
    async fn test_update_metrics_not_found() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        assert_not_found(manager.update_metrics(&session_id, 100, 50).await);
    }

    #[tokio::test]
    async fn test_close_connection_not_found() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        assert_not_found(manager.close_connection(&session_id).await);
    }

    #[tokio::test]
    async fn test_remove_connection_not_found() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        assert_not_found(manager.remove_connection(&session_id).await);
    }

    #[tokio::test]
    async fn test_connection_state_initial_values() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        manager.register_connection(session_id).await.unwrap();
        let state = manager.get_connection(&session_id).await.unwrap();
        assert_eq!(state.session_id, session_id);
        assert!(state.stream_id.is_none());
        assert_eq!(state.bytes_sent, 0);
        assert_eq!(state.bytes_received, 0);
        assert!(state.is_active);
    }

    // Pure value checks: no async runtime needed.
    #[test]
    fn test_connection_event_variants() {
        let session_id = SessionId::new();
        let connected = ConnectionEvent::Connected(session_id);
        let disconnected = ConnectionEvent::Disconnected(session_id);
        let timeout = ConnectionEvent::Timeout(session_id);
        let error = ConnectionEvent::Error(session_id, "test error".to_string());
        assert!(format!("{:?}", connected).contains("Connected"));
        assert!(format!("{:?}", disconnected).contains("Disconnected"));
        assert!(format!("{:?}", timeout).contains("Timeout"));
        assert!(format!("{:?}", error).contains("Error"));
        assert!(format!("{:?}", error).contains("test error"));
    }

    #[test]
    fn test_connection_event_clone() {
        let session_id = SessionId::new();
        let event = ConnectionEvent::Error(session_id, "clone test".to_string());
        let cloned = event.clone();
        match cloned {
            ConnectionEvent::Error(id, msg) => {
                assert_eq!(id, session_id);
                assert_eq!(msg, "clone test");
            }
            _ => panic!("Expected Error variant"),
        }
    }

    #[tokio::test]
    async fn test_update_metrics_cumulative() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        manager.register_connection(session_id).await.unwrap();
        manager.update_metrics(&session_id, 100, 50).await.unwrap();
        manager.update_metrics(&session_id, 200, 100).await.unwrap();
        manager.update_metrics(&session_id, 50, 25).await.unwrap();
        let state = manager.get_connection(&session_id).await.unwrap();
        assert_eq!(state.bytes_sent, 350);
        assert_eq!(state.bytes_received, 175);
    }

    #[tokio::test]
    async fn test_get_connection_not_found() {
        let manager = ConnectionManager::new(Duration::from_secs(60), 100);
        let session_id = SessionId::new();
        assert!(manager.get_connection(&session_id).await.is_none());
    }

    #[test]
    fn test_connection_statistics_debug() {
        let stats = ConnectionStatistics {
            total_connections: 10,
            active_connections: 8,
            inactive_connections: 2,
            total_bytes_sent: 1000,
            total_bytes_received: 500,
        };
        let debug_str = format!("{:?}", stats);
        assert!(debug_str.contains("total_connections: 10"));
        assert!(debug_str.contains("active_connections: 8"));
    }

    #[test]
    fn test_connection_statistics_clone() {
        let stats = ConnectionStatistics {
            total_connections: 5,
            active_connections: 3,
            inactive_connections: 2,
            total_bytes_sent: 500,
            total_bytes_received: 250,
        };
        let cloned = stats.clone();
        assert_eq!(cloned.total_connections, 5);
        assert_eq!(cloned.active_connections, 3);
        assert_eq!(cloned.inactive_connections, 2);
        assert_eq!(cloned.total_bytes_sent, 500);
        assert_eq!(cloned.total_bytes_received, 250);
    }

    #[tokio::test]
    async fn test_timeout_check_excludes_inactive() {
        let manager = ConnectionManager::new(Duration::from_millis(50), 10);
        let session1 = SessionId::new();
        let session2 = SessionId::new();
        manager.register_connection(session1).await.unwrap();
        manager.register_connection(session2).await.unwrap();
        manager.close_connection(&session1).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], session2);
    }

    #[tokio::test]
    async fn test_activity_update_prevents_timeout() {
        let manager = ConnectionManager::new(Duration::from_millis(100), 10);
        let session_id = SessionId::new();
        manager.register_connection(session_id).await.unwrap();
        // 80 ms + 30 ms exceeds the 100 ms timeout measured from registration, so an empty
        // result below proves the update reset the idle clock. Only 30 ms elapse after the
        // update, which leaves 70 ms of slack for scheduler delays.
        tokio::time::sleep(Duration::from_millis(80)).await;
        manager.update_activity(&session_id).await.unwrap();
        tokio::time::sleep(Duration::from_millis(30)).await;
        let timed_out = manager.check_timeouts().await;
        assert!(timed_out.is_empty());
        // At least 110 ms after the update, the connection must now be reported.
        tokio::time::sleep(Duration::from_millis(80)).await;
        let timed_out = manager.check_timeouts().await;
        assert_eq!(timed_out.len(), 1);
    }
}