use crate::common::error::{FlareError, Result};
use crate::common::protocol::{frame_with_system_command, pong, Reliability};
use crate::transport::connection::Connection;
use crate::transport::events::{ArcObserver, ConnectionEvent};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::{SplitSink, SplitStream, StreamExt};
use futures_util::SinkExt;
use prost::Message as ProstMessage;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tracing::debug;
/// Shared write half of a WebSocket built on `MaybeTlsStream` (TLS or plain socket).
type MaybeTlsSink = Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
/// Shared write half of a WebSocket built directly on a bare `TcpStream`.
type PlainTcpSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>;

/// Write half of the transport's WebSocket, one variant per supported stream type.
///
/// The `Arc<Mutex<..>>` is shared with the background receiver task, which writes
/// pong replies through it while `send`/`close` use it from the transport side.
enum WebSocketSink {
    /// Created by `WebSocketTransport::new`; the socket underneath may or may not be TLS.
    Tls(MaybeTlsSink),
    /// Created by `WebSocketTransport::from_tcp_stream`.
    Plain(PlainTcpSink),
}
/// Observer list shared between the transport and its receiver task.
type SharedObservers = Arc<std::sync::Mutex<Vec<ArcObserver>>>;
/// Last-activity timestamp shared between the transport and its receiver task.
type SharedInstant = Arc<std::sync::Mutex<std::time::Instant>>;

/// WebSocket-backed implementation of `Connection`.
///
/// Construction splits the WebSocket stream. The write half stays here for
/// `send`/`close`. The read half is moved into a spawned background task that turns
/// incoming messages into `ConnectionEvent`s for the registered observers.
pub struct WebSocketTransport {
    /// Write half of the split stream (also used by the receiver task for pongs).
    sink: WebSocketSink,
    /// Registered observers; cloned into the receiver task so it can notify them.
    observers: SharedObservers,
    /// Refreshed on every send and every received message.
    last_active: SharedInstant,
}
impl WebSocketTransport {
    /// Creates a transport from a WebSocket stream whose socket may or may not be
    /// TLS-wrapped (`MaybeTlsStream`).
    ///
    /// Delegates to `from_stream`, which spawns the receiver task with
    /// `tokio::spawn`; call this from inside a Tokio runtime.
    pub fn new(stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
        Self::from_stream(stream)
    }

    /// Creates a transport from a WebSocket stream over a bare `TcpStream`.
    ///
    /// Splits the stream, keeps the write half for `send`/`close`, and spawns a
    /// task (via `tokio::spawn`, so a Tokio runtime is required) that reads
    /// incoming messages and forwards them to observers.
    pub fn from_tcp_stream(stream: WebSocketStream<TcpStream>) -> Self {
        debug!("[DEBUG WebSocketTransport] from_tcp_stream 开始");
        let (sink_plain, receiver_plain) = stream.split();
        debug!("[DEBUG WebSocketTransport] from_tcp_stream: stream 已 split");
        let observers = Arc::new(std::sync::Mutex::new(Vec::new()));
        let sink_arc = Arc::new(Mutex::new(sink_plain));
        let last_active = Arc::new(std::sync::Mutex::new(std::time::Instant::now()));
        debug!("[DEBUG WebSocketTransport] from_tcp_stream: Arc 已创建");
        let task_observers = Arc::clone(&observers);
        let task_sink = Arc::clone(&sink_arc);
        let task_last_active = Arc::clone(&last_active);
        debug!("[DEBUG WebSocketTransport] from_tcp_stream: 准备 spawn receiver_task");
        tokio::spawn(async move {
            debug!("[DEBUG WebSocketTransport] receiver_task 开始 (Plain)");
            Self::receiver_task_plain(receiver_plain, task_observers, task_sink, task_last_active)
                .await;
            debug!("[DEBUG WebSocketTransport] receiver_task 结束 (Plain)");
        });
        debug!("[DEBUG WebSocketTransport] from_tcp_stream: receiver_task 已 spawn");
        let result = Self {
            sink: WebSocketSink::Plain(sink_arc),
            observers,
            last_active,
        };
        debug!("[DEBUG WebSocketTransport] from_tcp_stream 完成");
        result
    }

    /// Shared constructor for `MaybeTlsStream`-based streams: splits the stream,
    /// spawns `receiver_task` on the read half and keeps the write half.
    fn from_stream(stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
        let (sink, receiver) = stream.split();
        let observers = Arc::new(std::sync::Mutex::new(Vec::new()));
        let sink_arc = Arc::new(Mutex::new(sink));
        let last_active = Arc::new(std::sync::Mutex::new(std::time::Instant::now()));
        let task_observers = Arc::clone(&observers);
        let task_sink = Arc::clone(&sink_arc);
        let task_last_active = Arc::clone(&last_active);
        tokio::spawn(Self::receiver_task(receiver, task_observers, task_sink, task_last_active));
        Self {
            sink: WebSocketSink::Tls(sink_arc),
            observers,
            last_active,
        }
    }

    /// Read loop for streams created via `new`/`from_stream`.
    ///
    /// Refreshes `last_active` on every message, answers pings through `sink_arc`,
    /// and converts other messages with `event_for_message`. Returns after the
    /// first terminal event (`Disconnected` or `Error`). If the stream ends
    /// without one, observers receive a final `Disconnected`.
    async fn receiver_task(
        mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
        observers_arc: Arc<std::sync::Mutex<Vec<ArcObserver>>>,
        sink_arc: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
        last_active: Arc<std::sync::Mutex<std::time::Instant>>,
    ) {
        debug!("[DEBUG WebSocketTransport] receiver_task: 进入循环 (TLS)");
        while let Some(message) = receiver.next().await {
            debug!("[DEBUG WebSocketTransport] receiver_task: 收到消息 (TLS)");
            Self::touch_last_active(&last_active);
            let event = match message {
                // Pings need the sink, so they are answered here: a WebSocket-level
                // Pong echoing the payload, then a protocol-level pong frame.
                Ok(Message::Ping(data)) => {
                    if let Err(e) = Self::send_pong_response_tls(&sink_arc, &data).await {
                        Some(ConnectionEvent::Error(e))
                    } else if let Err(e) = Self::send_pong_frame_tls(&sink_arc).await {
                        Some(ConnectionEvent::Error(e))
                    } else {
                        None
                    }
                }
                other => Self::event_for_message(other),
            };
            if let Some(event) = event {
                if Self::deliver(&observers_arc, &event) {
                    return;
                }
            }
        }
        Self::notify_stream_ended(&observers_arc);
    }

    /// Read loop for streams created via `from_tcp_stream`; identical to
    /// `receiver_task` apart from the concrete stream/sink types.
    async fn receiver_task_plain(
        mut receiver: SplitStream<WebSocketStream<TcpStream>>,
        observers_arc: Arc<std::sync::Mutex<Vec<ArcObserver>>>,
        sink_arc: Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>,
        last_active: Arc<std::sync::Mutex<std::time::Instant>>,
    ) {
        debug!("[DEBUG WebSocketTransport] receiver_task: 进入循环 (Plain)");
        while let Some(message) = receiver.next().await {
            debug!("[DEBUG WebSocketTransport] receiver_task: 收到消息 (Plain)");
            Self::touch_last_active(&last_active);
            let event = match message {
                // See `receiver_task`: answer at both the WebSocket and protocol level.
                Ok(Message::Ping(data)) => {
                    if let Err(e) = Self::send_pong_response_plain(&sink_arc, &data).await {
                        Some(ConnectionEvent::Error(e))
                    } else if let Err(e) = Self::send_pong_frame_plain(&sink_arc).await {
                        Some(ConnectionEvent::Error(e))
                    } else {
                        None
                    }
                }
                other => Self::event_for_message(other),
            };
            if let Some(event) = event {
                if Self::deliver(&observers_arc, &event) {
                    return;
                }
            }
        }
        Self::notify_stream_ended(&observers_arc);
    }

    /// Maps a received message (or read error) to the event observers should see.
    ///
    /// - Text and binary payloads become `Message`.
    /// - A close frame becomes `Disconnected` carrying the peer's reason, or a
    ///   default reason if no frame was sent.
    /// - A WebSocket-level Pong is surfaced as an encoded protocol pong frame.
    /// - A read error becomes `Error`.
    /// - Pings (answered by the read loops) and any other kind produce no event.
    fn event_for_message<E: std::fmt::Display>(
        message: std::result::Result<Message, E>,
    ) -> Option<ConnectionEvent> {
        match message {
            Ok(Message::Text(text)) => Some(ConnectionEvent::Message(text.as_bytes().to_vec())),
            Ok(Message::Binary(data)) => Some(ConnectionEvent::Message(data.to_vec())),
            Ok(Message::Close(frame)) => {
                let reason = frame
                    .map(|f| f.reason.to_string())
                    .unwrap_or_else(|| "Connection closed by peer".to_string());
                Some(ConnectionEvent::Disconnected(reason))
            }
            Ok(Message::Pong(_)) => match Self::build_pong_frame() {
                Ok(pong_data) => Some(ConnectionEvent::Message(pong_data)),
                Err(e) => Some(ConnectionEvent::Error(e)),
            },
            Ok(_) => None,
            Err(e) => Some(ConnectionEvent::Error(FlareError::connection_failed(
                e.to_string(),
            ))),
        }
    }

    /// Notifies observers of `event` and returns `true` if it ends the read loop
    /// (`Disconnected` or `Error`).
    fn deliver(
        observers_arc: &Arc<std::sync::Mutex<Vec<ArcObserver>>>,
        event: &ConnectionEvent,
    ) -> bool {
        Self::_notify_observers(observers_arc, event);
        matches!(event, ConnectionEvent::Disconnected(_) | ConnectionEvent::Error(_))
    }

    /// Emitted when the read stream finishes without a close frame or error.
    ///
    /// tokio-tungstenite ends the stream with `None` on `ConnectionClosed`/`AlreadyClosed`.
    /// Without this notification, observers would never learn the link is gone.
    fn notify_stream_ended(observers_arc: &Arc<std::sync::Mutex<Vec<ArcObserver>>>) {
        Self::_notify_observers(
            observers_arc,
            &ConnectionEvent::Disconnected("Connection stream ended".to_string()),
        );
    }

    /// Stores "now" as the last activity time.
    ///
    /// Kept synchronous so the std guard can never be held across an `.await`.
    /// A poisoned lock is recovered, because an `Instant` is always a valid value.
    fn touch_last_active(last_active: &std::sync::Mutex<std::time::Instant>) {
        let mut active = last_active
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *active = std::time::Instant::now();
    }

    /// Delivers `event` to every registered observer.
    ///
    /// The observer `Arc`s are cloned and the lock is released before any callback
    /// runs. Two failure modes follow from holding the lock during callbacks:
    /// - A panicking observer would poison the list, so every later lock would panic.
    /// - A re-entrant add/remove would deadlock on the non-reentrant mutex.
    fn _notify_observers(
        observers_arc: &Arc<std::sync::Mutex<Vec<ArcObserver>>>,
        event: &ConnectionEvent,
    ) {
        let snapshot: Vec<ArcObserver> = observers_arc
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        for observer in &snapshot {
            observer.on_event(event);
        }
    }

    /// Delivers `event` to this transport's observers.
    fn notify_observers(&self, event: &ConnectionEvent) {
        Self::_notify_observers(&self.observers, event);
    }

    /// Sends a WebSocket-level Pong echoing `data` on a `MaybeTlsStream` sink.
    async fn send_pong_response_tls(
        sink: &Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
        data: &[u8],
    ) -> Result<()> {
        let mut sink = sink.lock().await;
        sink.send(Message::Pong(Bytes::from(data.to_vec())))
            .await
            .map_err(|e| FlareError::connection_failed(e.to_string()))?;
        Ok(())
    }

    /// Sends a WebSocket-level Pong echoing `data` on a plain `TcpStream` sink.
    async fn send_pong_response_plain(
        sink: &Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>,
        data: &[u8],
    ) -> Result<()> {
        let mut sink = sink.lock().await;
        sink.send(Message::Pong(Bytes::from(data.to_vec())))
            .await
            .map_err(|e| FlareError::connection_failed(e.to_string()))?;
        Ok(())
    }

    /// Encodes the protocol-level pong system command (best-effort reliability)
    /// as protobuf bytes.
    fn build_pong_frame() -> Result<Vec<u8>> {
        let pong_frame = frame_with_system_command(pong(), Reliability::BestEffort);
        // Size the buffer up front so encoding never reallocates.
        let mut buf = Vec::with_capacity(pong_frame.encoded_len());
        pong_frame
            .encode(&mut buf)
            .map_err(|e| FlareError::encoding_error(e.to_string()))?;
        Ok(buf)
    }

    /// Sends the encoded protocol pong frame as a binary message on a
    /// `MaybeTlsStream` sink.
    async fn send_pong_frame_tls(
        sink: &Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
    ) -> Result<()> {
        let pong_data = Self::build_pong_frame()?;
        let mut sink = sink.lock().await;
        sink.send(Message::Binary(Bytes::from(pong_data)))
            .await
            .map_err(|e| FlareError::connection_failed(e.to_string()))?;
        Ok(())
    }

    /// Sends the encoded protocol pong frame as a binary message on a plain
    /// `TcpStream` sink.
    async fn send_pong_frame_plain(
        sink: &Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>,
    ) -> Result<()> {
        let pong_data = Self::build_pong_frame()?;
        let mut sink = sink.lock().await;
        sink.send(Message::Binary(Bytes::from(pong_data)))
            .await
            .map_err(|e| FlareError::connection_failed(e.to_string()))?;
        Ok(())
    }
}
#[async_trait]
impl Connection for WebSocketTransport {
    /// Delivers `ConnectionEvent::Connected` to `observer` immediately, then
    /// registers it for subsequent events.
    fn add_observer(&mut self, observer: ArcObserver) {
        observer.on_event(&ConnectionEvent::Connected);
        // Recover from poisoning: the list is still a valid `Vec`. An earlier observer
        // panic must not make registration panic for the life of the connection.
        self.observers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push(observer);
    }

    /// Unregisters every entry that is the same `Arc` (pointer equality) as `observer`.
    fn remove_observer(&mut self, observer: ArcObserver) {
        self.observers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .retain(|o| !Arc::ptr_eq(o, &observer));
    }

    /// Sends `data` as one binary WebSocket message and refreshes the activity time.
    ///
    /// # Errors
    /// Returns a connection-failed error if the underlying sink rejects the message.
    async fn send(&mut self, data: &[u8]) -> Result<()> {
        debug!("[DEBUG WebSocketTransport] send: 开始发送数据");
        self.update_active_time();
        let message = Message::Binary(Bytes::copy_from_slice(data));
        debug!("[DEBUG WebSocketTransport] send: 消息已创建,准备发送");
        match &self.sink {
            WebSocketSink::Tls(sink) => {
                debug!("[DEBUG WebSocketTransport] send: 使用 TLS sink");
                let mut s = sink.lock().await;
                debug!("[DEBUG WebSocketTransport] send: TLS sink 锁已获取");
                s.send(message)
                    .await
                    .map_err(|e| FlareError::connection_failed(e.to_string()))?;
                debug!("[DEBUG WebSocketTransport] send: TLS 发送成功");
            }
            WebSocketSink::Plain(sink) => {
                debug!("[DEBUG WebSocketTransport] send: 使用 Plain sink");
                let mut s = sink.lock().await;
                debug!("[DEBUG WebSocketTransport] send: Plain sink 锁已获取");
                s.send(message)
                    .await
                    .map_err(|e| FlareError::connection_failed(e.to_string()))?;
                debug!("[DEBUG WebSocketTransport] send: Plain 发送成功");
            }
        }
        debug!("[DEBUG WebSocketTransport] send: 完成");
        Ok(())
    }

    /// Closes the write half of the WebSocket, then notifies observers with
    /// `Disconnected("Closed by client")`.
    ///
    /// # Errors
    /// Returns a connection-failed error if closing the sink fails. In that case
    /// observers are not notified.
    async fn close(&mut self) -> Result<()> {
        match &self.sink {
            WebSocketSink::Tls(sink) => {
                let mut s = sink.lock().await;
                s.close()
                    .await
                    .map_err(|e| FlareError::connection_failed(e.to_string()))?;
            }
            WebSocketSink::Plain(sink) => {
                let mut s = sink.lock().await;
                s.close()
                    .await
                    .map_err(|e| FlareError::connection_failed(e.to_string()))?;
            }
        }
        self.notify_observers(&ConnectionEvent::Disconnected("Closed by client".to_string()));
        Ok(())
    }

    /// Returns the time of the last send or received message.
    ///
    /// If the lock is poisoned, a time one hour in the past is reported. This
    /// presumably makes idle detection treat the connection as stale; confirm with
    /// callers. `checked_sub` is used because `Instant - Duration` panics when the
    /// result is not representable, and it falls back to "now" in that case.
    fn last_active_time(&self) -> std::time::Instant {
        match self.last_active.lock() {
            Ok(guard) => *guard,
            Err(_) => {
                let now = std::time::Instant::now();
                now.checked_sub(std::time::Duration::from_secs(3600))
                    .unwrap_or(now)
            }
        }
    }

    /// Marks the connection as active right now.
    ///
    /// A poisoned lock is recovered rather than silently skipped, since the
    /// stored `Instant` is always a valid value.
    fn update_active_time(&mut self) {
        let mut active = self
            .last_active
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *active = std::time::Instant::now();
    }
}