use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use serde_json::json;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{RwLock, mpsc};
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::protocol::Message;
use tracing::{debug, error, info, warn};
use super::{RequestHandler, Transport, TransportMessage};
use crate::error::{Error, Result};
use crate::protocol::Request;
/// Shared registry of outbound message queues, one sender per live WebSocket connection.
type Connections = Arc<RwLock<Vec<mpsc::Sender<Message>>>>;
/// A transport that serves JSON requests over WebSocket connections.
///
/// Construct it with [`WebSocketTransport::new`] and install a handler via
/// `Transport::set_request_handler`. Then call `Transport::start` to bind `host:port`
/// and serve clients. The `Transport` impl is only compiled with the `websocket` feature.
pub struct WebSocketTransport {
    /// Host/interface to bind; combined with `port` as `host:port`.
    host: String,
    /// TCP port to listen on.
    port: u16,
    /// Handler invoked for every parsed request; `start` fails while this is `None`.
    request_handler: Option<RequestHandler>,
    /// Outbound queues of currently connected clients.
    connections: Connections,
}
impl fmt::Debug for WebSocketTransport {
    /// Renders a diagnostic view of the transport.
    ///
    /// The handler and the connection registry are not printable. The handler is reduced to
    /// whether it is installed, and the registry to a fixed placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler_summary = format!("<handler: {}>", self.request_handler.is_some());
        let connections_placeholder = "<connections>";

        let mut out = f.debug_struct("WebSocketTransport");
        out.field("host", &self.host);
        out.field("port", &self.port);
        out.field("request_handler", &handler_summary);
        out.field("connections", &connections_placeholder);
        out.finish()
    }
}
impl WebSocketTransport {
pub fn new(host: String, port: u16) -> Self {
Self {
host,
port,
request_handler: None,
connections: Arc::new(RwLock::new(Vec::new())),
}
}
async fn handle_connection(
websocket: tokio_tungstenite::WebSocketStream<TcpStream>, addr: SocketAddr,
request_handler: RequestHandler, connections: Connections,
) {
info!("新的WebSocket连接: {}", addr);
let (mut ws_sender, mut ws_receiver) = websocket.split();
let (msg_tx, mut msg_rx) = mpsc::channel::<Message>(32);
{
let mut connections = connections.write().await;
connections.push(msg_tx.clone());
}
let forward_task = tokio::spawn(async move {
while let Some(msg) = msg_rx.recv().await {
if let Err(e) = ws_sender.send(msg).await {
error!("发送WebSocket消息失败: {}", e);
break;
}
}
});
while let Some(result) = ws_receiver.next().await {
match result {
Ok(msg) => {
if msg.is_text() || msg.is_binary() {
let text = msg.to_text().unwrap_or_default();
match serde_json::from_str::<Request>(text) {
Ok(request) => {
debug!("收到WebSocket请求: {}", request.tool);
let (resp_tx, mut resp_rx) = mpsc::channel(10);
let tx = (request_handler)(request.clone());
let msg_tx_clone = msg_tx.clone();
tokio::spawn(async move {
if let Err(e) = tx.send((request, resp_tx)).await {
error!("发送请求到处理器失败: {}", e);
return;
}
while let Some(message) = resp_rx.recv().await {
match message {
TransportMessage::Response(response) => {
if let Ok(json) = serde_json::to_string(&response) {
if let Err(e) =
msg_tx_clone.send(Message::Text(json)).await
{
error!("发送响应失败: {}", e);
}
}
}
TransportMessage::Error(error) => {
if let Ok(json) = serde_json::to_string(&error) {
if let Err(e) =
msg_tx_clone.send(Message::Text(json)).await
{
error!("发送错误响应失败: {}", e);
}
}
}
TransportMessage::Notification(notification) => {
if let Ok(json) =
serde_json::to_string(¬ification)
{
if let Err(e) =
msg_tx_clone.send(Message::Text(json)).await
{
error!("发送通知失败: {}", e);
}
}
}
}
}
});
}
Err(e) => {
warn!("无法解析请求: {}", e);
let error_msg = json!({
"error": "无效的请求格式",
"details": e.to_string()
});
if let Ok(json) = serde_json::to_string(&error_msg) {
if let Err(e) = msg_tx.send(Message::Text(json)).await {
error!("发送错误消息失败: {}", e);
}
}
}
}
} else if msg.is_close() {
break;
}
}
Err(e) => {
error!("WebSocket错误: {}", e);
break;
}
}
}
forward_task.abort();
info!("WebSocket连接关闭: {}", addr);
{
let mut connections = connections.write().await;
connections.retain(|tx| !tx.is_closed());
}
}
}
#[cfg(feature = "websocket")]
#[async_trait]
impl Transport for WebSocketTransport {
    /// Binds `host:port` and serves WebSocket clients until a fatal accept error occurs.
    ///
    /// Each accepted TCP stream is upgraded in its own task, so one slow handshake does not
    /// block others.
    ///
    /// # Errors
    ///
    /// Returns `Error::Transport` when no request handler is set or the address cannot be bound.
    /// It also returns `Error::Transport` when accepting fails with an error that is not tied to
    /// a single connection. Per-connection accept errors (aborted/reset/interrupted) are logged
    /// and skipped.
    async fn start(&self) -> Result<()> {
        info!("启动WebSocket传输,地址 {}:{}", self.host, self.port);
        let request_handler = self
            .request_handler
            .clone()
            .ok_or_else(|| Error::Transport("请求处理器未设置".to_string()))?;
        let bind_addr = format!("{}:{}", self.host, self.port);
        let listener = TcpListener::bind(&bind_addr)
            .await
            .map_err(|e| Error::Transport(format!("绑定WebSocket地址失败: {e}")))?;
        info!("WebSocket服务器监听于 {}", bind_addr);

        loop {
            let (stream, peer_addr) = match listener.accept().await {
                Ok(accepted) => accepted,
                // These concern a single half-open connection; the listener itself is fine, so
                // they must not shut the whole server down.
                Err(e)
                    if matches!(
                        e.kind(),
                        std::io::ErrorKind::ConnectionAborted
                            | std::io::ErrorKind::ConnectionReset
                            | std::io::ErrorKind::Interrupted
                    ) =>
                {
                    warn!("接受WebSocket连接失败: {}", e);
                    continue;
                }
                Err(e) => {
                    return Err(Error::Transport(format!("接受WebSocket连接失败: {e}")));
                }
            };

            let request_handler = request_handler.clone();
            let connections = self.connections.clone();
            tokio::spawn(async move {
                match accept_async(stream).await {
                    Ok(ws_stream) => {
                        Self::handle_connection(ws_stream, peer_addr, request_handler, connections)
                            .await;
                    }
                    Err(e) => {
                        error!("WebSocket握手失败: {}", e);
                    }
                }
            });
        }
    }

    /// Installs the handler used for every request received by this transport.
    ///
    /// Must be called before `start`, which fails if no handler is set.
    fn set_request_handler(&mut self, handler: RequestHandler) {
        self.request_handler = Some(handler);
    }
}