use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use futures::future::BoxFuture;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio::sync::{Mutex, oneshot};
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use super::{HumanLoopKind, HumanLoopProvider, HumanLoopRequest, HumanLoopResponse};
use echo_core::error::{ReactError, Result};
/// Requests awaiting a reply, keyed by request id. Each value is the one-shot
/// sender used to hand the matching client reply back to the waiting caller.
type PendingMap = Arc<Mutex<HashMap<String, oneshot::Sender<ClientResponse>>>>;
/// Outbound message queues of the registered client connections (one unbounded
/// sender is pushed per accepted WebSocket connection).
type ClientSenders = Arc<Mutex<Vec<tokio::sync::mpsc::UnboundedSender<String>>>>;
/// Human-in-the-loop provider that forwards each request to the connected
/// WebSocket clients and waits for one of them to answer.
pub struct WebSocketHumanLoopProvider {
/// Requests that have been sent and are still awaiting a client reply.
pending: PendingMap,
/// Outbound message queues of the registered client connections.
clients: ClientSenders,
/// Maximum time a single request waits for a reply before it times out.
timeout: Duration,
}
/// JSON payload pushed to clients for each human-in-the-loop request.
#[derive(Serialize)]
struct ServerMessage<'a> {
/// Request kind; the provider sets it to `"approval"` or `"input"`.
kind: &'a str,
/// Identifier the client has to echo back in its reply.
request_id: &'a str,
/// Prompt text taken from the request.
prompt: &'a str,
/// Tool name from the request; left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
tool_name: Option<&'a str>,
/// Tool arguments from the request; left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
args: Option<&'a serde_json::Value>,
}
/// JSON reply received from a client.
#[derive(Deserialize)]
struct ClientResponse {
/// Id of the request being answered; used to look up the pending waiter.
request_id: String,
/// Approval decision; only the exact value `"approved"` counts as approval.
decision: Option<String>,
/// Text returned for input requests (an empty string is used when absent).
text: Option<String>,
/// Optional reason passed along when an approval request is rejected.
reason: Option<String>,
}
impl WebSocketHumanLoopProvider {
    /// Starts the provider on `127.0.0.1:port` with a reply timeout of 300 seconds.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while binding the TCP listener.
    pub async fn bind(port: u16) -> std::io::Result<Self> {
        Self::bind_with_timeout(port, Duration::from_secs(300)).await
    }

    /// Starts the provider on `127.0.0.1:port` and spawns the background accept loop.
    ///
    /// Every accepted TCP connection is handed to its own task, which performs the
    /// WebSocket handshake and serves the client. `timeout` is how long each
    /// request waits for a client reply.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while binding the TCP listener.
    pub async fn bind_with_timeout(port: u16, timeout: Duration) -> std::io::Result<Self> {
        // Pause after a failed `accept` so that a persistent error (for example
        // file-descriptor exhaustion) cannot turn the loop into a busy spin that
        // floods the log.
        const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);

        let addr = SocketAddr::from(([127, 0, 0, 1], port));
        let listener = TcpListener::bind(addr).await?;
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        let clients: ClientSenders = Arc::new(Mutex::new(Vec::new()));
        let pending_bg = Arc::clone(&pending);
        let clients_bg = Arc::clone(&clients);

        tokio::spawn(async move {
            info!("WebSocket 人工介入服务器已启动: ws://127.0.0.1:{port}");
            loop {
                match listener.accept().await {
                    Ok((stream, addr)) => {
                        debug!("新的 WebSocket 客户端连接: {addr}");
                        tokio::spawn(handle_connection(
                            stream,
                            addr,
                            Arc::clone(&pending_bg),
                            Arc::clone(&clients_bg),
                        ));
                    }
                    Err(e) => {
                        error!("WebSocket accept 错误: {e}");
                        tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                    }
                }
            }
        });

        Ok(Self {
            pending,
            clients,
            timeout,
        })
    }

    /// Queues `msg` for every registered client and returns how many clients
    /// accepted it.
    ///
    /// Senders whose receiving side is gone are dropped from the registry as a
    /// side effect, so the returned count only covers queues that are still open.
    async fn broadcast(&self, msg: &str) -> usize {
        let mut clients = self.clients.lock().await;
        clients.retain(|tx| tx.send(msg.to_owned()).is_ok());
        clients.len()
    }
}
/// Serves one client connection from WebSocket handshake until disconnect.
///
/// After the handshake the connection's outbound queue is registered in
/// `clients`. A spawned writer task forwards queued messages to the socket and
/// sends a ping every 30 seconds. The reader loop parses incoming text frames as
/// [`ClientResponse`] and completes the matching entry in `pending`. The
/// connection is closed when the peer closes it, a read error occurs, the stream
/// ends, or nothing is received for 90 seconds.
///
/// On exit the writer task is aborted and the connection's sender is removed
/// from `clients`, so later broadcasts do not count a client that is gone.
async fn handle_connection(
    stream: tokio::net::TcpStream,
    addr: SocketAddr,
    pending: PendingMap,
    clients: ClientSenders,
) {
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
    const READ_TIMEOUT: Duration = Duration::from_secs(90);

    let ws_stream = match accept_async(stream).await {
        Ok(ws) => ws,
        Err(e) => {
            warn!("WebSocket 握手失败 ({addr}): {e}");
            return;
        }
    };
    let (mut write, mut read) = ws_stream.split();
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    // Keep `tx` so this connection can be found in the registry again when it ends.
    clients.lock().await.push(tx.clone());

    let write_task = tokio::spawn(async move {
        let mut heartbeat = tokio::time::interval(HEARTBEAT_INTERVAL);
        loop {
            tokio::select! {
                msg = rx.recv() => {
                    match msg {
                        Some(msg) => {
                            if let Err(e) = write.send(Message::Text(msg)).await {
                                warn!("WS 消息发送失败: {e}");
                                break;
                            }
                        }
                        // Every sender has been dropped; nothing more to forward.
                        None => break,
                    }
                }
                _ = heartbeat.tick() => {
                    if let Err(e) = write.send(Message::Ping(vec![])).await {
                        warn!("WS ping 发送失败: {e}");
                        break;
                    }
                }
            }
        }
    });

    loop {
        match tokio::time::timeout(READ_TIMEOUT, read.next()).await {
            Ok(Some(Ok(Message::Text(text)))) => {
                match serde_json::from_str::<ClientResponse>(&text) {
                    Ok(response) => {
                        let mut map = pending.lock().await;
                        if let Some(sender) = map.remove(&response.request_id) {
                            // The waiter may already have timed out; that is not an error.
                            let _ = sender.send(response);
                        } else {
                            warn!("收到未知 request_id 的 WS 响应: {}", response.request_id);
                        }
                    }
                    Err(e) => {
                        warn!("WebSocket 消息解析失败: {e},原始内容: {text}");
                    }
                }
            }
            Ok(Some(Ok(Message::Close(_)))) | Ok(Some(Err(_))) => break,
            Ok(Some(Ok(Message::Pong(_)))) => {
                debug!("收到 WebSocket pong ({addr})");
            }
            // Other frame kinds carry nothing this handler acts on.
            Ok(Some(Ok(_))) => {}
            Ok(None) => break,
            Err(_) => {
                warn!("WebSocket 读取超时 ({addr}),关闭死连接");
                break;
            }
        }
    }

    write_task.abort();
    // `abort` only requests cancellation, so the receiver may still be alive for
    // a moment. Deregister explicitly instead of relying on a failed send.
    clients
        .lock()
        .await
        .retain(|client| !client.same_channel(&tx));
    info!("WebSocket 客户端断开: {addr}");
}
impl HumanLoopProvider for WebSocketHumanLoopProvider {
    /// Sends `req` to every connected client and waits for the first reply.
    ///
    /// Returns:
    /// - `Approved` / `Rejected { reason }` for approval requests (anything other
    ///   than the decision `"approved"` is a rejection),
    /// - `Text` for input requests (empty when the reply carries no text),
    /// - `Timeout` when no reply arrives within the configured timeout.
    ///
    /// # Errors
    ///
    /// Fails when the message cannot be serialized, when no client is connected,
    /// or when the reply channel is closed without a reply.
    fn request(&self, req: HumanLoopRequest) -> BoxFuture<'_, Result<HumanLoopResponse>> {
        Box::pin(async move {
            let request_id = Uuid::new_v4().to_string();
            let kind_str = match req.kind {
                HumanLoopKind::Approval => "approval",
                HumanLoopKind::Input => "input",
            };

            // Serialize before registering the waiter: an early return on a
            // serialization error must not leave an entry behind in `pending`.
            let msg = serde_json::to_string(&ServerMessage {
                kind: kind_str,
                request_id: &request_id,
                prompt: &req.prompt,
                tool_name: req.tool_name.as_deref(),
                args: req.args.as_ref(),
            })
            .map_err(|e| ReactError::Other(format!("WS 消息序列化失败: {e}")))?;

            // Register before broadcasting so that a fast reply always finds its waiter.
            let (tx, rx) = oneshot::channel();
            self.pending.lock().await.insert(request_id.clone(), tx);

            let sent = self.broadcast(&msg).await;
            if sent == 0 {
                self.pending.lock().await.remove(&request_id);
                return Err(ReactError::Other(
                    "没有已连接的 WebSocket 客户端,无法发送人工介入请求".to_string(),
                ));
            }

            match tokio::time::timeout(self.timeout, rx).await {
                Ok(Ok(response)) => match req.kind {
                    HumanLoopKind::Approval => match response.decision.as_deref() {
                        Some("approved") => Ok(HumanLoopResponse::Approved),
                        _ => Ok(HumanLoopResponse::Rejected {
                            reason: response.reason,
                        }),
                    },
                    HumanLoopKind::Input => {
                        Ok(HumanLoopResponse::Text(response.text.unwrap_or_default()))
                    }
                },
                Ok(Err(_)) => {
                    self.pending.lock().await.remove(&request_id);
                    Err(ReactError::Other("介入 channel 意外关闭".to_string()))
                }
                Err(_) => {
                    self.pending.lock().await.remove(&request_id);
                    Ok(HumanLoopResponse::Timeout)
                }
            }
        })
    }
}