use crate::common::error::Result;
use crate::common::protocol::Frame;
use crate::common::MessageParser;
use crate::transport::connection::Connection;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use std::time::Duration;
/// Bookkeeping record describing a single managed connection.
///
/// A copy of this record is returned next to the connection handle by
/// `ConnectionManagerTrait::get_connection`.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// Identifier of the connection.
    pub connection_id: String,
    /// Identifier of the user bound to this connection, if any.
    pub user_id: Option<String>,
    /// Creation timestamp.
    // NOTE(review): the unit (seconds vs. milliseconds) is not visible here — confirm.
    pub created_at: u64,
    /// Timestamp of the most recent activity.
    // NOTE(review): presumably the same unit as `created_at` — confirm.
    pub last_active: u64,
    /// Free-form string key/value pairs attached to the connection.
    pub metadata: std::collections::HashMap<String, String>,
    /// Device description for the connection, when one is available.
    pub device_info: Option<crate::common::device::DeviceInfo>,
    /// Serialization format used to build this connection's `MessageParser`
    /// when a frame is sent without an explicit parser.
    pub serialization_format: crate::common::protocol::SerializationFormat,
    /// Compression algorithm used together with `serialization_format` when
    /// building this connection's `MessageParser`.
    pub compression: crate::common::compression::CompressionAlgorithm,
    /// Whether the connection is authenticated. While `false`, `send_frame_to`
    /// only lets a small set of system commands through.
    pub authenticated: bool,
    /// Timestamp at which authentication happened, if it has.
    // NOTE(review): presumably the same unit as `created_at` — confirm.
    pub authenticated_at: Option<u64>,
}
/// Abstraction over a registry of live connections.
///
/// Implementors supply the storage and raw byte-level send primitives; the
/// trait layers frame-level helpers (`send_frame_to`, `send_frame_to_user`,
/// `broadcast_frame`, `broadcast_frame_except`) on top of them as provided
/// methods.
#[async_trait]
pub trait ConnectionManagerTrait: Send + Sync + std::any::Any {
    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete
    /// manager type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Registers `connection` under `connection_id`, optionally associated
    /// with `user_id`.
    async fn add_connection(
        &self,
        connection_id: String,
        connection: Arc<Mutex<Box<dyn Connection>>>,
        user_id: Option<String>,
    ) -> Result<()>;

    /// Removes the connection registered under `connection_id`.
    async fn remove_connection(&self, connection_id: &str) -> Result<()>;

    /// Looks up a connection handle together with its [`ConnectionInfo`].
    ///
    /// Returns `None` when no connection is registered under `connection_id`;
    /// `send_frame_to` relies on this to report a missing connection.
    async fn get_connection(
        &self,
        connection_id: &str,
    ) -> Option<(Arc<Mutex<Box<dyn Connection>>>, ConnectionInfo)>;

    /// Returns the ids of the connections associated with `user_id`.
    async fn get_user_connections(&self, user_id: &str) -> Vec<String>;

    /// Associates `user_id` with an existing connection.
    async fn bind_user(&self, connection_id: &str, user_id: String) -> Result<()>;

    /// Records activity on the connection; called by `send_frame_to` after a
    /// successful send.
    async fn update_connection_active(&self, connection_id: &str) -> Result<()>;

    /// Marks the connection as authenticated.
    // NOTE(review): presumably also binds `user_id` when it is `Some` — confirm
    // against the implementation.
    async fn set_connection_authenticated(&self, connection_id: &str, user_id: Option<String>) -> Result<()>;

    /// Returns the ids of the managed connections.
    async fn list_connections(&self) -> Vec<String>;

    /// Returns the number of managed connections.
    async fn connection_count(&self) -> usize;

    /// Cleans up connections considered timed out with respect to `timeout`.
    // NOTE(review): the returned ids are presumably those of the connections
    // that were removed — confirm against the implementation.
    async fn cleanup_timeout_connections(&self, timeout: Duration) -> Vec<String>;

    /// Sends raw bytes to a single connection.
    async fn send_to_connection(&self, connection_id: &str, data: &[u8]) -> Result<()>;

    /// Sends raw bytes to the connections of `user_id`.
    async fn send_to_user(&self, user_id: &str, data: &[u8]) -> Result<()>;

    /// Sends raw bytes to the managed connections.
    async fn broadcast(&self, data: &[u8]) -> Result<()>;

    /// Like `broadcast`, but skipping `exclude_connection_id`.
    async fn broadcast_except(&self, data: &[u8], exclude_connection_id: &str) -> Result<()>;

    /// Serializes `frame` and sends it to a single connection.
    ///
    /// The connection is looked up exactly once, and that one snapshot of its
    /// [`ConnectionInfo`] drives both the authentication gate and the choice
    /// of parser, so the frame is never encoded in a format other than the
    /// explicit `parser` or the connection's own format/compression.
    ///
    /// While the connection is unauthenticated, only the system commands
    /// `ConnectAck`, `Ping`, `Pong`, `Error` and `Close` are let through.
    ///
    /// When `parser` is `None`, a `MessageParser` is built from the
    /// connection's `serialization_format` and `compression`.
    ///
    /// After a successful send the connection's activity is refreshed via
    /// `update_connection_active`.
    ///
    /// # Errors
    ///
    /// * a connection-failed error when `connection_id` is unknown;
    /// * an authentication-failed error when the connection is not
    ///   authenticated and `frame` is not one of the allowed system commands;
    /// * any error from serialization, `send_to_connection` or
    ///   `update_connection_active`.
    async fn send_frame_to(
        &self,
        connection_id: &str,
        frame: &Frame,
        parser: Option<&MessageParser>,
    ) -> Result<()> {
        use crate::common::protocol::flare::core::commands::command::Type as CommandType;
        use crate::common::protocol::flare::core::commands::system_command::Type as SystemType;

        // Everything that needs the connection snapshot lives in this block so
        // the snapshot is dropped before the send/update awaits below.
        let data = {
            let info = match self.get_connection(connection_id).await {
                Some((_, info)) => info,
                None => {
                    return Err(crate::common::error::FlareError::connection_failed(
                        format!("连接 {} 不存在", connection_id)
                    ));
                }
            };

            if !info.authenticated {
                // Handshake, keep-alive and teardown traffic must still flow
                // before authentication completes.
                let is_system_command =
                    match frame.command.as_ref().and_then(|cmd| cmd.r#type.as_ref()) {
                        Some(CommandType::System(sys_cmd)) => {
                            sys_cmd.r#type == SystemType::ConnectAck as i32
                                || sys_cmd.r#type == SystemType::Ping as i32
                                || sys_cmd.r#type == SystemType::Pong as i32
                                || sys_cmd.r#type == SystemType::Error as i32
                                || sys_cmd.r#type == SystemType::Close as i32
                        }
                        _ => false,
                    };
                if !is_system_command {
                    return Err(crate::common::error::FlareError::authentication_failed(
                        format!("连接 {} 未验证,无法发送消息", connection_id)
                    ));
                }
            }

            match parser {
                Some(p) => p.serialize(frame)?,
                // Reuse the snapshot instead of a second lookup.
                None => MessageParser::new(info.serialization_format, info.compression)
                    .serialize(frame)?,
            }
        };

        self.send_to_connection(connection_id, &data).await?;
        self.update_connection_active(connection_id).await?;
        Ok(())
    }

    /// Sends `frame` to every connection returned by `get_user_connections`.
    ///
    /// Delivery is best-effort: a failure on one connection is ignored so the
    /// remaining connections are still attempted, and the method always
    /// returns `Ok(())`.
    async fn send_frame_to_user(
        &self,
        user_id: &str,
        frame: &Frame,
        parser: Option<&MessageParser>,
    ) -> Result<()> {
        let connection_ids = self.get_user_connections(user_id).await;
        for conn_id in connection_ids {
            // Deliberately ignored: one failing connection must not stop the rest.
            let _ = self.send_frame_to(&conn_id, frame, parser).await;
        }
        Ok(())
    }

    /// Sends `frame` to every connection returned by `list_connections`.
    ///
    /// Delivery is best-effort: per-connection failures (including the
    /// authentication gate in `send_frame_to`) are ignored, and the method
    /// always returns `Ok(())`.
    async fn broadcast_frame(
        &self,
        frame: &Frame,
        parser: Option<&MessageParser>,
    ) -> Result<()> {
        let connection_ids = self.list_connections().await;
        for conn_id in connection_ids {
            // Deliberately ignored: one failing connection must not stop the rest.
            let _ = self.send_frame_to(&conn_id, frame, parser).await;
        }
        Ok(())
    }

    /// Like `broadcast_frame`, but skips the connection whose id equals
    /// `exclude_connection_id`.
    ///
    /// Delivery is best-effort and the method always returns `Ok(())`.
    async fn broadcast_frame_except(
        &self,
        frame: &Frame,
        exclude_connection_id: &str,
        parser: Option<&MessageParser>,
    ) -> Result<()> {
        let connection_ids = self.list_connections().await;
        for conn_id in connection_ids {
            if conn_id == exclude_connection_id {
                continue;
            }
            // Deliberately ignored: one failing connection must not stop the rest.
            let _ = self.send_frame_to(&conn_id, frame, parser).await;
        }
        Ok(())
    }
}
/// Aggregate counters describing a connection manager's current load.
#[derive(Clone, Debug)]
pub struct ConnectionStats {
    /// Number of connections counted.
    pub total_connections: usize,
    /// Number of users counted.
    pub total_users: usize,
}