use crate::client::connection::ConnectionStateManager;
use crate::client::heartbeat::HeartbeatManager;
use crate::client::router::MessageRouter;
use crate::common::MessageParser;
use crate::client::config::ClientConfig;
use crate::common::protocol::{Frame, connect, frame_with_system_command, Reliability};
use crate::common::protocol::flare::core::commands::command::Type;
use crate::common::error::Result;
use crate::transport::events::{ArcObserver, ConnectionEvent};
use crate::transport::connection::Connection;
use std::sync::{Arc, Mutex as StdMutex};
use tokio::sync::Mutex;
use std::collections::HashMap;
/// Client-side core state: connection lifecycle, frame parser, optional heartbeat and
/// message router, registered observers, and the handles used to react to server frames.
pub struct ClientCore {
    /// Connection lifecycle tracker (updated via `set_connected` / `set_disconnected` /
    /// `set_failed`).
    pub state_manager: Arc<ConnectionStateManager>,
    /// Parser used to decode inbound frames; replaced via `update_parser` after
    /// CONNECT_ACK negotiation.
    pub parser: Arc<Mutex<MessageParser>>,
    /// Heartbeat driver; set by `start_heartbeat` when heartbeats are enabled and cleared by
    /// `stop_heartbeat`.
    heartbeat_manager: Option<Arc<Mutex<HeartbeatManager>>>,
    /// Inbound frame router; only present when `ClientConfig::enable_router` is set.
    message_router: Option<MessageRouter>,
    /// Connection-event observers, guarded by a blocking (std) mutex.
    pub observers: Arc<StdMutex<Vec<ArcObserver>>>,
    /// Configuration captured at construction time.
    pub config: ClientConfig,
    /// Optional application callback for system / message / notification commands.
    event_handler: Option<Arc<dyn crate::client::events::handler::ClientEventHandler>>,
    /// Connection handle used to actively close the link (for example when kicked).
    client_connection: Option<Arc<Mutex<Box<dyn Connection>>>>,
}
impl ClientCore {
pub fn new(config: &ClientConfig) -> Self {
let format = if config.is_force_format() {
config.get_serialization_format()
} else {
crate::common::protocol::SerializationFormat::Json
};
let compression = if config.is_force_format() {
config.get_compression()
} else {
crate::common::compression::CompressionAlgorithm::None
};
let parser = MessageParser::new(format, compression);
let message_router = if config.enable_router {
Some(MessageRouter::new())
} else {
None
};
Self {
state_manager: Arc::new(ConnectionStateManager::new()),
parser: Arc::new(tokio::sync::Mutex::new(parser)),
heartbeat_manager: None,
message_router,
observers: Arc::new(StdMutex::new(Vec::new())),
config: config.clone(),
event_handler: None,
client_connection: None,
}
}
pub async fn update_parser(&self, format: crate::common::protocol::SerializationFormat, compression: crate::common::compression::CompressionAlgorithm) {
let mut parser = self.parser.lock().await;
*parser = MessageParser::new(format, compression);
tracing::debug!(
"[ClientCore] 协商完成,解析器已更新: 最终序列化方式={:?}, 最终压缩方式={:?}",
format,
compression
);
}
pub fn set_client_connection(&mut self, connection: Arc<Mutex<Box<dyn Connection>>>) {
self.client_connection = Some(connection);
}
pub fn set_event_handler(&mut self, handler: Option<Arc<dyn crate::client::events::handler::ClientEventHandler>>) {
self.event_handler = handler;
}
pub async fn start_heartbeat(
&mut self,
connection: Arc<Mutex<Box<dyn Connection>>>,
) {
self.client_connection = Some(Arc::clone(&connection));
if !self.config.heartbeat.enabled {
return;
}
let mut heartbeat = HeartbeatManager::new(
self.config.heartbeat.interval,
self.config.heartbeat.timeout,
);
let parser = self.parser.lock().await.clone();
heartbeat.start(connection, parser);
self.heartbeat_manager = Some(Arc::new(tokio::sync::Mutex::new(heartbeat)));
}
pub fn stop_heartbeat(&mut self) {
if let Some(ref heartbeat) = self.heartbeat_manager {
if let Some(mut hb) = self.heartbeat_manager.take() {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::try_current()
.map(|handle| handle.block_on(async {
let mut hb_guard = hb.lock().await;
hb_guard.stop();
}))
.unwrap_or_else(|_| {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let mut hb_guard = hb.lock().await;
hb_guard.stop();
})
})
});
}
}
}
pub async fn handle_message(&self, data: Vec<u8>) {
let parser = self.parser.lock().await;
let frame = match parser.parse(&data) {
Ok(frame) => frame,
Err(e) => {
tracing::warn!("Failed to parse message: {}", e);
return;
}
};
drop(parser);
if let Some(cmd) = &frame.command {
if let Some(crate::common::protocol::flare::core::commands::command::Type::System(sys_cmd)) = &cmd.r#type {
if sys_cmd.r#type == crate::common::protocol::flare::core::commands::system_command::Type::ConnectAck as i32 {
if let Some(ref handler) = self.event_handler {
let _ = handler.handle_system_command(
crate::common::protocol::flare::core::commands::system_command::Type::ConnectAck,
&frame
).await;
}
match self.handle_connect_ack(&frame) {
Ok((format, compression)) => {
tracing::info!(
"[ClientCore] ✅ 收到 CONNECT_ACK: 服务端确定的序列化方式={:?}, 压缩方式={:?}",
format,
compression
);
if !self.config.is_force_format() {
self.update_parser(format, compression).await;
tracing::info!(
"[ClientCore] ✅ 解析器已更新为协商后的格式: {:?}, 压缩: {:?}",
format,
compression
);
} else {
tracing::info!(
"[ClientCore] ℹ️ 强制模式:继续使用客户端强制指定的格式: {:?}, 压缩: {:?}",
self.config.get_serialization_format(),
self.config.get_compression()
);
}
}
Err(e) => {
tracing::warn!("Failed to handle CONNECT_ACK: {}", e);
}
}
return; }
if sys_cmd.r#type == crate::common::protocol::flare::core::commands::system_command::Type::Kicked as i32 {
let reason = sys_cmd.message.clone();
tracing::warn!(
"[ClientCore] ⚠️ 收到被踢消息: {}",
reason
);
let mut kick_reason = reason.clone();
if let Some(reason_bytes) = sys_cmd.metadata.get("reason") {
if let Ok(reason_str) = String::from_utf8(reason_bytes.clone()) {
if reason_str == "device_conflict" {
kick_reason = format!("设备冲突:{}", reason);
}
}
}
if let Some(ref handler) = self.event_handler {
if let Err(e) = handler.handle_system_command(
crate::common::protocol::flare::core::commands::system_command::Type::Kicked,
&frame
).await {
tracing::warn!("[ClientCore] 事件处理器处理 KICKED 失败: {}", e);
}
}
self.state_manager.set_disconnected();
if let Some(ref client_conn) = self.client_connection {
let mut conn = client_conn.lock().await;
if let Err(e) = conn.close().await {
tracing::error!("[ClientCore] 断开连接失败: {}", e);
} else {
tracing::info!("[ClientCore] ✅ 已主动断开连接(被踢)");
}
} else {
tracing::warn!("[ClientCore] ⚠️ 客户端连接未设置,无法主动断开");
}
if let Ok(observers) = self.observers.lock() {
for observer in observers.iter() {
observer.on_event(&crate::transport::events::ConnectionEvent::Disconnected(
kick_reason.clone()
));
}
}
tracing::info!(
"[ClientCore] 连接已断开(被踢): {}",
kick_reason
);
return; }
if sys_cmd.r#type == crate::common::protocol::flare::core::commands::system_command::Type::Pong as i32 {
if let Some(ref handler) = self.event_handler {
let _ = handler.handle_system_command(
crate::common::protocol::flare::core::commands::system_command::Type::Pong,
&frame
).await;
}
self.record_pong();
return; }
}
}
if let Some(ref handler) = self.event_handler {
if let Some(cmd) = &frame.command {
match &cmd.r#type {
Some(Type::Message(msg_cmd)) => {
if let Ok(cmd_type) = crate::common::protocol::flare::core::commands::message_command::Type::try_from(msg_cmd.r#type) {
let _ = handler.handle_message_command(cmd_type, &frame).await;
}
}
Some(Type::Notification(notif_cmd)) => {
if let Ok(cmd_type) = crate::common::protocol::flare::core::commands::notification_command::Type::try_from(notif_cmd.r#type) {
let _ = handler.handle_notification_command(cmd_type, &frame).await;
}
}
_ => {}
}
}
}
if let Some(ref router) = self.message_router {
match router.route(&frame).await {
Ok(replies) => {
tracing::debug!("Router generated {} replies", replies.len());
}
Err(e) => {
tracing::warn!("Router error: {}", e);
}
}
}
self.notify_observers(&ConnectionEvent::Message(data));
}
pub fn handle_connection_event(&self, event: &ConnectionEvent) {
if let Some(ref handler) = self.event_handler {
let handler_clone = Arc::clone(handler);
let event_clone = event.clone();
tokio::spawn(async move {
let _ = handler_clone.handle_connection_event(&event_clone).await;
});
}
match event {
ConnectionEvent::Connected => {
self.state_manager.set_connected();
}
ConnectionEvent::Disconnected(_) => {
self.state_manager.set_disconnected();
}
ConnectionEvent::Error(_) => {
self.state_manager.set_failed();
}
ConnectionEvent::Message(_) => {
}
}
self.notify_observers(event);
}
pub fn add_observer(&self, observer: ArcObserver) {
if let Ok(mut observers) = self.observers.lock() {
observers.push(observer);
}
}
pub fn remove_observer(&self, observer: ArcObserver) {
if let Ok(mut observers) = self.observers.lock() {
observers.retain(|o| !Arc::ptr_eq(o, &observer));
}
}
fn notify_observers(&self, event: &ConnectionEvent) {
if let Ok(observers) = self.observers.lock() {
for observer in observers.iter() {
observer.on_event(event);
}
}
}
pub fn router_mut(&mut self) -> Option<&mut MessageRouter> {
self.message_router.as_mut()
}
pub fn router(&self) -> Option<&MessageRouter> {
self.message_router.as_ref()
}
pub fn state(&self) -> crate::client::connection::ConnectionState {
self.state_manager.get_state()
}
pub fn can_send(&self) -> bool {
self.state_manager.get_state().can_send()
}
pub fn can_connect(&self) -> bool {
self.state_manager.get_state().can_connect()
}
pub fn record_pong(&self) {
if let Some(ref heartbeat) = self.heartbeat_manager {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::try_current()
.map(|handle| {
handle.block_on(async {
let hb_guard = heartbeat.lock().await;
hb_guard.record_pong();
})
})
.unwrap_or_else(|_| {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let hb_guard = heartbeat.lock().await;
hb_guard.record_pong();
})
})
});
}
}
pub async fn send_connect_message(&self, connection: Arc<Mutex<Box<dyn Connection>>>) -> Result<()> {
let mut metadata = HashMap::new();
let (format, compression) = if self.config.is_force_format() {
(self.config.get_serialization_format(), self.config.get_compression())
} else {
(self.config.serialization_format, self.config.compression)
};
tracing::debug!(
"[ClientCore] 发送 CONNECT 消息: 请求序列化方式={:?}, 请求压缩方式={:?}, 强制模式={}",
format,
compression,
self.config.is_force_format()
);
let format_str = match format {
crate::common::protocol::SerializationFormat::Protobuf => "protobuf",
crate::common::protocol::SerializationFormat::Json => "json",
};
metadata.insert("format".to_string(), format_str.as_bytes().to_vec());
metadata.insert("compression".to_string(), compression.as_str().as_bytes().to_vec());
if self.config.is_force_format() {
metadata.insert("force_format".to_string(), "true".as_bytes().to_vec());
}
if let Some(ref device_info) = self.config.device_info {
metadata.insert("device_id".to_string(), device_info.device_id.as_bytes().to_vec());
metadata.insert("platform".to_string(), device_info.platform.as_str().as_bytes().to_vec());
if let Some(ref model) = device_info.model {
metadata.insert("model".to_string(), model.as_bytes().to_vec());
}
if let Some(ref app_version) = device_info.app_version {
metadata.insert("app_version".to_string(), app_version.as_bytes().to_vec());
}
if let Some(ref system_version) = device_info.system_version {
metadata.insert("system_version".to_string(), system_version.as_bytes().to_vec());
}
for (key, value) in &device_info.metadata {
metadata.insert(key.clone(), value.as_bytes().to_vec());
}
}
if let Some(ref user_id) = self.config.user_id {
metadata.insert("user_id".to_string(), user_id.as_bytes().to_vec());
}
if let Some(ref token) = self.config.token {
metadata.insert("token".to_string(), token.as_bytes().to_vec());
tracing::debug!("[ClientCore] 已添加 token 到 CONNECT 消息元数据");
}
for (key, value) in &self.config.metadata {
metadata.insert(key.clone(), value.as_bytes().to_vec());
}
let connect_cmd = connect(format, metadata);
let connect_frame = frame_with_system_command(connect_cmd, Reliability::AtLeastOnce);
let data = if self.config.is_force_format() {
let parser = self.parser.lock().await;
parser.serialize(&connect_frame)?
} else {
MessageParser::json().serialize(&connect_frame)?
};
let mut conn = connection.lock().await;
conn.send(&data).await?;
if self.config.is_force_format() {
tracing::debug!("[ClientCore] CONNECT 消息已发送(强制模式: format={:?}, compression={:?})", format, compression);
} else {
tracing::debug!("[ClientCore] CONNECT 消息已发送(协商模式: 首选 format={:?}, compression={:?})", format, compression);
}
Ok(())
}
pub fn handle_connect_ack(&self, frame: &Frame) -> Result<(crate::common::protocol::SerializationFormat, crate::common::compression::CompressionAlgorithm)> {
if let Some(cmd) = &frame.command {
if let Some(crate::common::protocol::flare::core::commands::command::Type::System(sys_cmd)) = &cmd.r#type {
use prost::Enumeration;
let cmd_type = crate::common::protocol::flare::core::commands::system_command::Type::from_i32(sys_cmd.r#type)
.ok_or_else(|| crate::common::error::FlareError::protocol_error("Invalid system command type".to_string()))?;
if cmd_type == crate::common::protocol::flare::core::commands::system_command::Type::ConnectAck {
use prost::Enumeration;
let format = crate::common::protocol::SerializationFormat::from_i32(sys_cmd.format)
.unwrap_or(crate::common::protocol::SerializationFormat::Json);
let compression = if let Some(compression_bytes) = sys_cmd.metadata.get("compression") {
if let Ok(compression_str) = String::from_utf8(compression_bytes.clone()) {
crate::common::compression::CompressionAlgorithm::from_str(&compression_str)
.unwrap_or(crate::common::compression::CompressionAlgorithm::None)
} else {
crate::common::compression::CompressionAlgorithm::None
}
} else {
crate::common::compression::CompressionAlgorithm::None
};
tracing::debug!("[ClientCore] 收到 CONNECT_ACK,协商结果: format={:?}, compression={:?}", format, compression);
if let Some(conflicts_bytes) = sys_cmd.metadata.get("conflict_connections") {
if let Ok(conflicts_json) = String::from_utf8(conflicts_bytes.clone()) {
if let Ok(conflict_connections) = serde_json::from_str::<Vec<String>>(&conflicts_json) {
if !conflict_connections.is_empty() {
tracing::warn!("[ClientCore] 检测到设备冲突,以下连接被踢掉: {:?}", conflict_connections);
}
}
}
}
return Ok((format, compression));
}
}
}
Err(crate::common::error::FlareError::protocol_error(
"Not a CONNECT_ACK message".to_string()
))
}
}
impl Clone for ClientCore {
    /// Creates a handle that shares `state_manager`, `parser`, `observers` and `event_handler`
    /// with `self` (the `Arc`s are cloned), plus a copy of `config`.
    ///
    /// Some parts are deliberately *not* carried over:
    /// - `heartbeat_manager` and `client_connection` start as `None`. On the clone,
    ///   `record_pong` is therefore a no-op, and a KICKED frame cannot actively close the
    ///   connection. NOTE(review): confirm this is intended if clones dispatch inbound frames.
    /// - When routing is enabled, the clone gets a fresh `MessageRouter::new()` rather than a
    ///   copy of `self`'s router.
    fn clone(&self) -> Self {
        let message_router = match self.message_router {
            Some(_) => Some(MessageRouter::new()),
            None => None,
        };
        Self {
            state_manager: Arc::clone(&self.state_manager),
            parser: Arc::clone(&self.parser),
            heartbeat_manager: None,
            message_router,
            observers: Arc::clone(&self.observers),
            config: self.config.clone(),
            event_handler: self.event_handler.clone(),
            client_connection: None,
        }
    }
}