use crate::common::error::{FlareError, Result};
use crate::server::connection::r#trait::{ConnectionManagerTrait, ConnectionStats as TraitConnectionStats};
use crate::transport::connection::Connection;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::Mutex;
use std::time::{Duration, Instant};
/// Per-connection bookkeeping that [`ConnectionManager`] stores next to each connection handle.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
/// Id under which the connection is registered in the manager.
pub connection_id: String,
/// User currently bound to this connection, if any; the manager mirrors this in its
/// user -> connection-ids index.
pub user_id: Option<String>,
/// Instant at which this record was created.
pub created_at: Instant,
/// Last recorded activity; refreshed by `update_active` and checked by `is_timeout`.
pub last_active: Instant,
/// Arbitrary string key/value pairs attached to the connection; starts empty.
pub metadata: HashMap<String, String>,
/// Device information supplied during negotiation, if any.
pub device_info: Option<crate::common::device::DeviceInfo>,
/// Serialization format negotiated for this connection (`new` defaults it to JSON).
pub serialization_format: crate::common::protocol::SerializationFormat,
/// Compression algorithm negotiated for this connection (`new` defaults it to none).
pub compression: crate::common::compression::CompressionAlgorithm,
/// Whether the connection is authenticated (also `true` when authentication was not required).
pub authenticated: bool,
/// Unix timestamp, in seconds, at which the connection became authenticated.
pub authenticated_at: Option<u64>,
}
impl ConnectionInfo {
    /// Current wall-clock time as whole seconds since the Unix epoch.
    ///
    /// Falls back to `0` if the system clock reports a time before the epoch.
    fn unix_time_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_secs())
            .unwrap_or(0)
    }

    /// Creates the record for a newly registered connection.
    ///
    /// `created_at` and `last_active` both start at the current instant. When
    /// `requires_auth` is `false` the connection starts out authenticated and
    /// `authenticated_at` is stamped immediately; otherwise both stay unset until
    /// [`ConnectionInfo::set_authenticated`] is called. No user is bound, metadata is
    /// empty, serialization defaults to JSON and compression to none.
    pub fn new(connection_id: String, requires_auth: bool) -> Self {
        let started = Instant::now();
        let authenticated = !requires_auth;
        let authenticated_at = authenticated.then(Self::unix_time_secs);
        Self {
            connection_id,
            user_id: None,
            created_at: started,
            last_active: started,
            metadata: HashMap::new(),
            device_info: None,
            serialization_format: crate::common::protocol::SerializationFormat::Json,
            compression: crate::common::compression::CompressionAlgorithm::None,
            authenticated,
            authenticated_at,
        }
    }

    /// Marks the connection as authenticated and stamps `authenticated_at` with the
    /// current Unix time.
    ///
    /// A `Some` user id replaces the bound user; `None` leaves the existing binding as is.
    pub fn set_authenticated(&mut self, user_id: Option<String>) {
        self.authenticated_at = Some(Self::unix_time_secs());
        self.authenticated = true;
        if user_id.is_some() {
            self.user_id = user_id;
        }
    }

    /// Returns whether the connection has been authenticated.
    pub fn is_authenticated(&self) -> bool {
        self.authenticated
    }

    /// Builder-style setter for the negotiated device information.
    pub fn with_device_info(self, device_info: crate::common::device::DeviceInfo) -> Self {
        Self {
            device_info: Some(device_info),
            ..self
        }
    }

    /// Builder-style setter for the serialization format.
    pub fn with_serialization_format(
        self,
        format: crate::common::protocol::SerializationFormat,
    ) -> Self {
        Self {
            serialization_format: format,
            ..self
        }
    }

    /// Builder-style setter for the compression algorithm.
    pub fn with_compression(
        self,
        compression: crate::common::compression::CompressionAlgorithm,
    ) -> Self {
        Self { compression, ..self }
    }

    /// Returns `true` when strictly more than `timeout` has passed since the last recorded activity.
    pub fn is_timeout(&self, timeout: Duration) -> bool {
        timeout < self.last_active.elapsed()
    }

    /// Records "now" as the most recent activity.
    pub fn update_active(&mut self) {
        self.last_active = Instant::now();
    }
}
/// Registry of live connections plus a reverse index from user id to connection ids.
///
/// Both maps sit behind `std::sync::RwLock`, so their guards must never be held across an
/// `.await`. Every method in this file that needs both locks acquires `connections` before
/// `user_connections`; keep that order to avoid lock-order deadlocks.
pub struct ConnectionManager {
// Connection id -> (tokio-mutex-guarded connection handle, bookkeeping record).
connections: Arc<RwLock<HashMap<String, (Arc<Mutex<Box<dyn Connection>>>, ConnectionInfo)>>>,
// User id -> ids of the connections currently bound to that user.
user_connections: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
impl ConnectionManager {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
user_connections: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn add_connection(
&self,
connection_id: String,
connection: Box<dyn Connection>,
user_id: Option<String>,
requires_auth: bool,
) -> Result<()> {
let mut connections = self.connections.write()
.map_err(|_| FlareError::general_error("Failed to lock connections"))?;
if connections.contains_key(&connection_id) {
return Err(FlareError::protocol_error(format!(
"Connection {} already exists",
connection_id
)));
}
let mut info = ConnectionInfo::new(connection_id.clone(), requires_auth);
info.user_id = user_id.clone();
connections.insert(connection_id.clone(), (Arc::new(Mutex::new(connection)), info));
if let Some(user_id) = user_id {
let mut user_connections = self.user_connections.write()
.map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
user_connections
.entry(user_id)
.or_insert_with(Vec::new)
.push(connection_id);
}
Ok(())
}
pub fn remove_connection(&self, connection_id: &str) -> Result<()> {
let mut connections = self.connections.write()
.map_err(|_| FlareError::general_error("Failed to lock connections"))?;
let (_, info) = connections.remove(connection_id)
.ok_or_else(|| FlareError::protocol_error(format!("Connection {} not found", connection_id)))?;
if let Some(user_id) = info.user_id {
let mut user_connections = self.user_connections.write()
.map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
if let Some(conn_ids) = user_connections.get_mut(&user_id) {
conn_ids.retain(|id| id != connection_id);
if conn_ids.is_empty() {
user_connections.remove(&user_id);
}
}
}
Ok(())
}
pub fn get_connection(
&self,
connection_id: &str,
) -> Option<(Arc<Mutex<Box<dyn Connection>>>, ConnectionInfo)> {
self.connections.read()
.ok()
.and_then(|connections| {
connections.get(connection_id).map(|(conn, info)| {
(Arc::clone(conn), info.clone())
})
})
}
pub fn get_user_connections(&self, user_id: &str) -> Vec<String> {
self.user_connections.read()
.ok()
.and_then(|user_connections| {
user_connections.get(user_id).cloned()
})
.unwrap_or_default()
}
pub fn bind_user(&self, connection_id: &str, user_id: String) -> Result<()> {
let mut connections = self.connections.write()
.map_err(|_| FlareError::general_error("Failed to lock connections"))?;
let (_, info) = connections.get_mut(connection_id)
.ok_or_else(|| FlareError::protocol_error(format!("Connection {} not found", connection_id)))?;
if let Some(old_user_id) = &info.user_id {
let mut user_connections = self.user_connections.write()
.map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
if let Some(conn_ids) = user_connections.get_mut(old_user_id) {
conn_ids.retain(|id| id != connection_id);
if conn_ids.is_empty() {
user_connections.remove(old_user_id);
}
}
}
info.user_id = Some(user_id.clone());
let mut user_connections = self.user_connections.write()
.map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
user_connections
.entry(user_id)
.or_insert_with(Vec::new)
.push(connection_id.to_string());
Ok(())
}
pub fn update_connection_active(&self, connection_id: &str) -> Result<()> {
let mut connections = self.connections.write()
.map_err(|_| FlareError::general_error("Failed to lock connections"))?;
let (_, info) = connections.get_mut(connection_id)
.ok_or_else(|| FlareError::protocol_error(format!("Connection {} not found", connection_id)))?;
info.update_active();
Ok(())
}
pub fn set_connection_authenticated(&self, connection_id: &str, user_id: Option<String>) -> Result<()> {
let mut connections = self.connections.write()
.map_err(|_| FlareError::general_error("Failed to lock connections"))?;
let (_, info) = connections.get_mut(connection_id)
.ok_or_else(|| FlareError::protocol_error(format!("Connection {} not found", connection_id)))?;
info.set_authenticated(user_id.clone());
if let Some(user_id) = user_id {
let mut user_connections = self.user_connections.write()
.map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
if let Some(old_user_id) = &info.user_id {
if old_user_id != &user_id {
if let Some(conn_ids) = user_connections.get_mut(old_user_id) {
conn_ids.retain(|id| id != connection_id);
if conn_ids.is_empty() {
user_connections.remove(old_user_id);
}
}
}
}
user_connections
.entry(user_id)
.or_insert_with(Vec::new)
.push(connection_id.to_string());
}
Ok(())
}
pub fn update_connection_negotiation(
&self,
connection_id: &str,
device_info: Option<crate::common::device::DeviceInfo>,
serialization_format: crate::common::protocol::SerializationFormat,
compression: crate::common::compression::CompressionAlgorithm,
user_id: Option<String>,
) -> Result<()> {
let mut connections = self.connections.write()
.map_err(|_| FlareError::general_error("Failed to lock connections"))?;
let (_, info) = connections.get_mut(connection_id)
.ok_or_else(|| FlareError::protocol_error(format!("Connection {} not found", connection_id)))?;
info.device_info = device_info;
info.serialization_format = serialization_format;
info.compression = compression;
if let Some(user_id) = user_id {
if let Some(old_user_id) = &info.user_id {
let mut user_connections = self.user_connections.write()
.map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
if let Some(conn_ids) = user_connections.get_mut(old_user_id) {
conn_ids.retain(|id| id != connection_id);
if conn_ids.is_empty() {
user_connections.remove(old_user_id);
}
}
}
info.user_id = Some(user_id.clone());
let mut user_connections = self.user_connections.write()
.map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
user_connections
.entry(user_id)
.or_insert_with(Vec::new)
.push(connection_id.to_string());
}
Ok(())
}
pub fn list_connections(&self) -> Vec<String> {
self.connections.read()
.ok()
.map(|connections| connections.keys().cloned().collect())
.unwrap_or_default()
}
pub fn connection_count(&self) -> usize {
self.connections.read()
.ok()
.map(|connections| connections.len())
.unwrap_or(0)
}
pub fn cleanup_timeout_connections(&self, timeout: Duration) -> Vec<String> {
let timeout_connections: Vec<String> = {
let connections = self.connections.read().ok();
if let Some(connections) = connections {
connections
.iter()
.filter(|(_, (_, info))| info.is_timeout(timeout))
.map(|(id, _)| id.clone())
.collect()
} else {
Vec::new()
}
};
for connection_id in &timeout_connections {
let _ = self.remove_connection(connection_id);
}
timeout_connections
}
pub fn stats(&self) -> TraitConnectionStats {
let connections = self.connections.read().ok();
let user_connections = self.user_connections.read().ok();
let total_connections = connections.as_ref().map(|c| c.len()).unwrap_or(0);
let total_users = user_connections.as_ref().map(|u| u.len()).unwrap_or(0);
TraitConnectionStats {
total_connections,
total_users,
}
}
}
impl Default for ConnectionManager {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ConnectionManagerTrait for ConnectionManager {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Registers a connection whose handle is already shared.
    ///
    /// Connections added through the trait always require authentication. When `user_id`
    /// is given, the connection is also indexed under that user.
    ///
    /// # Errors
    /// Returns a protocol error if `connection_id` is already registered, or a general
    /// error if a lock is poisoned.
    async fn add_connection(
        &self,
        connection_id: String,
        connection: Arc<Mutex<Box<dyn Connection>>>,
        user_id: Option<String>,
    ) -> Result<()> {
        // Connections arriving through the trait interface must authenticate explicitly.
        let requires_auth = true;
        let mut connections = self
            .connections
            .write()
            .map_err(|_| FlareError::general_error("Failed to lock connections"))?;
        if connections.contains_key(&connection_id) {
            return Err(FlareError::protocol_error(format!(
                "Connection {} already exists",
                connection_id
            )));
        }
        let mut info = ConnectionInfo::new(connection_id.clone(), requires_auth);
        info.user_id = user_id.clone();
        // We own this Arc; store it directly rather than cloning and dropping the original.
        connections.insert(connection_id.clone(), (connection, info));
        if let Some(user_id) = user_id {
            let mut user_connections = self
                .user_connections
                .write()
                .map_err(|_| FlareError::general_error("Failed to lock user_connections"))?;
            user_connections
                .entry(user_id)
                .or_insert_with(Vec::new)
                .push(connection_id);
        }
        Ok(())
    }

    async fn remove_connection(&self, connection_id: &str) -> Result<()> {
        ConnectionManager::remove_connection(self, connection_id)
    }

    /// Returns the handle plus a trait-level snapshot of the connection's bookkeeping.
    ///
    /// `Instant` has no epoch, so `created_at` and `last_active` are projected onto Unix
    /// seconds by subtracting their elapsed time from the current wall-clock time
    /// (saturating at 0).
    async fn get_connection(
        &self,
        connection_id: &str,
    ) -> Option<(Arc<Mutex<Box<dyn Connection>>>, crate::server::connection::r#trait::ConnectionInfo)> {
        let (conn, info) = ConnectionManager::get_connection(self, connection_id)?;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let created_at_secs = now.saturating_sub(info.created_at.elapsed().as_secs());
        let last_active_secs = now.saturating_sub(info.last_active.elapsed().as_secs());
        let trait_info = crate::server::connection::r#trait::ConnectionInfo {
            connection_id: info.connection_id,
            user_id: info.user_id,
            created_at: created_at_secs,
            last_active: last_active_secs,
            metadata: info.metadata,
            // `info` is an owned snapshot, so its fields can be moved rather than cloned.
            device_info: info.device_info,
            serialization_format: info.serialization_format,
            compression: info.compression,
            authenticated: info.authenticated,
            authenticated_at: info.authenticated_at,
        };
        Some((conn, trait_info))
    }

    async fn get_user_connections(&self, user_id: &str) -> Vec<String> {
        ConnectionManager::get_user_connections(self, user_id)
    }

    async fn bind_user(&self, connection_id: &str, user_id: String) -> Result<()> {
        ConnectionManager::bind_user(self, connection_id, user_id)
    }

    async fn update_connection_active(&self, connection_id: &str) -> Result<()> {
        ConnectionManager::update_connection_active(self, connection_id)
    }

    async fn set_connection_authenticated(&self, connection_id: &str, user_id: Option<String>) -> Result<()> {
        ConnectionManager::set_connection_authenticated(self, connection_id, user_id)
    }

    async fn list_connections(&self) -> Vec<String> {
        ConnectionManager::list_connections(self)
    }

    async fn connection_count(&self) -> usize {
        ConnectionManager::connection_count(self)
    }

    async fn cleanup_timeout_connections(&self, timeout: Duration) -> Vec<String> {
        ConnectionManager::cleanup_timeout_connections(self, timeout)
    }

    /// Sends `data` on a single connection.
    ///
    /// # Errors
    /// Returns a protocol error if the connection is unknown (or the registry lock is
    /// poisoned); otherwise propagates the connection's own send error.
    async fn send_to_connection(&self, connection_id: &str, data: &[u8]) -> Result<()> {
        // Clone only the shared handle. Going through `get_connection` would deep-copy the
        // whole `ConnectionInfo` (including its metadata map) on every send. The std read
        // guard is consumed inside this statement, so it is released before the `.await`s.
        let connection = self
            .connections
            .read()
            .ok()
            .and_then(|connections| {
                connections
                    .get(connection_id)
                    .map(|(conn, _)| Arc::clone(conn))
            })
            .ok_or_else(|| {
                FlareError::protocol_error(format!("Connection {} not found", connection_id))
            })?;
        let mut conn = connection.lock().await;
        conn.send(data).await
    }

    /// Best-effort send to every connection bound to `user_id`; failures are logged, not returned.
    async fn send_to_user(&self, user_id: &str, data: &[u8]) -> Result<()> {
        let connection_ids = ConnectionManager::get_user_connections(self, user_id);
        for connection_id in connection_ids {
            if let Err(e) = self.send_to_connection(&connection_id, data).await {
                tracing::warn!("Failed to send to connection {}: {:?}", connection_id, e);
            }
        }
        Ok(())
    }

    /// Best-effort send to every registered connection; failures are logged, not returned.
    async fn broadcast(&self, data: &[u8]) -> Result<()> {
        let connection_ids = ConnectionManager::list_connections(self);
        for connection_id in connection_ids {
            if let Err(e) = self.send_to_connection(&connection_id, data).await {
                tracing::warn!("Failed to broadcast to connection {}: {:?}", connection_id, e);
            }
        }
        Ok(())
    }

    /// Like `broadcast`, but skips `exclude_connection_id`.
    async fn broadcast_except(&self, data: &[u8], exclude_connection_id: &str) -> Result<()> {
        let connection_ids: Vec<String> = ConnectionManager::list_connections(self)
            .into_iter()
            .filter(|id| id != exclude_connection_id)
            .collect();
        for connection_id in connection_ids {
            if let Err(e) = self.send_to_connection(&connection_id, data).await {
                tracing::warn!("Failed to broadcast to connection {}: {:?}", connection_id, e);
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::connection::Connection;
    use crate::transport::events::ArcObserver;
    use async_trait::async_trait;

    /// In-memory `Connection` that accepts every send and only tracks its activity timestamp.
    struct MockConnection {
        // `update_active_time` takes `&mut self`, so no interior mutability is needed.
        last_active: Instant,
    }

    impl MockConnection {
        fn new() -> Self {
            MockConnection {
                last_active: Instant::now(),
            }
        }
    }

    #[async_trait]
    impl Connection for MockConnection {
        fn add_observer(&mut self, _observer: ArcObserver) {}

        fn remove_observer(&mut self, _observer: ArcObserver) {}

        async fn send(&mut self, _data: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }

        fn last_active_time(&self) -> Instant {
            self.last_active
        }

        fn update_active_time(&mut self) {
            self.last_active = Instant::now();
        }
    }

    /// Builds a manager holding one mock connection under `id`.
    ///
    /// The connection has no user bound and does not require authentication.
    fn manager_with(id: &str) -> ConnectionManager {
        let manager = ConnectionManager::new();
        manager
            .add_connection(id.to_string(), Box::new(MockConnection::new()), None, false)
            .expect("adding a fresh connection id must succeed");
        manager
    }

    #[test]
    fn test_add_and_get_connection() {
        let manager = manager_with("conn1");
        let (_, info) = manager
            .get_connection("conn1")
            .expect("conn1 was just registered");
        assert_eq!(info.connection_id, "conn1");
    }

    #[test]
    fn test_remove_connection() {
        let manager = manager_with("conn1");
        assert_eq!(manager.connection_count(), 1);
        manager
            .remove_connection("conn1")
            .expect("conn1 is registered");
        assert_eq!(manager.connection_count(), 0);
    }

    #[test]
    fn test_user_binding() {
        let manager = manager_with("conn1");
        manager
            .bind_user("conn1", "user1".to_string())
            .expect("conn1 is registered");
        assert_eq!(manager.get_user_connections("user1"), vec!["conn1"]);
    }

    #[test]
    fn test_cleanup_timeout() {
        let manager = manager_with("conn1");
        std::thread::sleep(Duration::from_millis(10));
        let cleaned = manager.cleanup_timeout_connections(Duration::from_millis(5));
        assert!(cleaned.iter().any(|id| id == "conn1"));
    }
}