use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use std::time::Duration;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
use crate::codec::{FrameDecoder, encode_frame};
use crate::protocol::message::{IncomingMessage, MessageKind, OutgoingMessage};
use crate::{Error, Result};
/// Reply deadline used by [`Connection::send`]; callers that need a different
/// deadline use [`Connection::send_timeout`] directly.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Capacity of the event broadcast channel. A subscriber that falls more than this
/// many events behind loses the oldest ones (tokio `broadcast` lag semantics).
const EVENT_CHANNEL_CAPACITY: usize = 4096;
/// An unsolicited notification received from the peer (an incoming message
/// classified as `MessageKind::Event`), delivered to every [`Connection::subscribe`] receiver.
#[derive(Debug, Clone)]
pub struct Event {
    /// Event method name as sent by the peer.
    pub method: String,
    /// Event payload as sent by the peer.
    pub params: Value,
    /// Session the event was tagged with, if the incoming message carried one.
    pub session_id: Option<String>,
}
/// Outcome handed to a waiting request: `Ok(result)` for a successful response,
/// `Err(message)` for an error response or when the connection is torn down.
type ReplyResult = std::result::Result<Value, String>;
/// In-flight requests keyed by request id. Each entry holds the sender that completes
/// the corresponding waiting call; it is shared between the handle and the read task.
type PendingMap = Arc<Mutex<HashMap<i64, oneshot::Sender<ReplyResult>>>>;
/// Cloneable handle to a transport whose I/O runs on background tasks spawned by
/// the constructors. All clones share the same state through an `Arc`.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<Inner>,
}
/// Shared state behind a [`Connection`] handle.
struct Inner {
    /// Monotonic request-id source (starts at 1) used by `send_timeout` and `fire_session`.
    next_id: AtomicI64,
    /// Requests awaiting a reply; entries are completed by the read task.
    pending: PendingMap,
    /// Queue of serialized JSON consumed by the write task; sending fails once that
    /// task has exited and dropped its receiver.
    cmd_tx: mpsc::UnboundedSender<Vec<u8>>,
    /// Sender side of the event broadcast; receivers come from `subscribe`.
    event_tx: broadcast::Sender<Event>,
}
impl Connection {
    /// Builds a connection over a raw byte writer/reader pair (for example a child
    /// process's stdin/stdout) and spawns the background write and read tasks.
    ///
    /// Outgoing JSON is framed with `encode_frame`; incoming bytes are reassembled by
    /// `FrameDecoder`. Calls `tokio::spawn`, so it must run inside a Tokio runtime.
    pub fn from_pipe<W, R>(writer: W, reader: R) -> Self
    where
        W: AsyncWrite + Unpin + Send + 'static,
        R: AsyncRead + Unpin + Send + 'static,
    {
        let (inner, cmd_rx, pending, event_tx) = Self::scaffold();
        tokio::spawn(pipe_write_loop(writer, cmd_rx));
        tokio::spawn(pipe_read_loop(reader, pending, event_tx));
        Self { inner }
    }

    /// Builds a connection over an established websocket and spawns the background
    /// write and read tasks. Outgoing JSON is sent as text frames.
    ///
    /// Calls `tokio::spawn`, so it must run inside a Tokio runtime.
    pub fn from_ws<S>(ws: WebSocketStream<S>) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let (inner, cmd_rx, pending, event_tx) = Self::scaffold();
        let (sink, stream) = ws.split();
        tokio::spawn(ws_write_loop(sink, cmd_rx));
        tokio::spawn(ws_read_loop(stream, pending, event_tx));
        Self { inner }
    }

    /// Creates the shared state plus the pieces handed to the transport tasks.
    ///
    /// Returns the shared state, the command receiver for the write task, and
    /// clones of the pending map and event sender for the read task.
    fn scaffold() -> (
        Arc<Inner>,
        mpsc::UnboundedReceiver<Vec<u8>>,
        PendingMap,
        broadcast::Sender<Event>,
    ) {
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        // The initial receiver is dropped; consumers obtain their own via `subscribe`.
        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<Vec<u8>>();
        let inner = Arc::new(Inner {
            next_id: AtomicI64::new(1),
            pending: pending.clone(),
            cmd_tx,
            event_tx: event_tx.clone(),
        });
        (inner, cmd_rx, pending, event_tx)
    }

    /// Returns a new receiver for incoming events.
    ///
    /// Only events broadcast after this call are delivered. A receiver that lags by
    /// more than `EVENT_CHANNEL_CAPACITY` events observes `RecvError::Lagged`.
    pub fn subscribe(&self) -> broadcast::Receiver<Event> {
        self.inner.event_tx.subscribe()
    }

    /// Sends a request and waits up to [`DEFAULT_TIMEOUT`] for its reply.
    ///
    /// See [`Connection::send_timeout`] for the error cases.
    pub async fn send(
        &self,
        method: impl Into<String>,
        params: Value,
        session_id: Option<&str>,
    ) -> Result<Value> {
        self.send_timeout(method, params, session_id, DEFAULT_TIMEOUT)
            .await
    }

    /// Sends a request with a freshly allocated id and waits up to `timeout` for
    /// the matching reply.
    ///
    /// # Errors
    /// - Any error from serializing the request.
    /// - `Error::Transport` if the write queue is closed, or if the reply channel
    ///   is dropped without a value.
    /// - `Error::Protocol(message)` if the peer answers with an error. This also
    ///   happens when the read task stops and fails all pending requests.
    /// - `Error::Timeout(timeout)` if no reply arrives in time.
    ///
    /// If the returned future is dropped before it completes, the request stays
    /// registered until a reply arrives or the connection closes.
    pub async fn send_timeout(
        &self,
        method: impl Into<String>,
        params: Value,
        session_id: Option<&str>,
        timeout: Duration,
    ) -> Result<Value> {
        let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed);
        let msg = OutgoingMessage::new(id, method, params, session_id.map(str::to_string));
        // Serialize BEFORE registering the waiter: an early `?` return after
        // registration would strand the sender in `pending` until shutdown.
        let json = msg.to_json_bytes()?;
        let (tx, rx) = oneshot::channel();
        // Register before enqueueing so a fast reply can never beat its waiter.
        self.inner.pending.lock().await.insert(id, tx);
        if self.inner.cmd_tx.send(json).is_err() {
            self.inner.pending.lock().await.remove(&id);
            return Err(Error::Transport(
                "命令写通道已关闭(子进程可能已退出)".into(),
            ));
        }
        match tokio::time::timeout(timeout, rx).await {
            Ok(Ok(Ok(result))) => Ok(result),
            Ok(Ok(Err(message))) => Err(Error::Protocol(message)),
            Ok(Err(_)) => Err(Error::Transport("连接已关闭,响应通道被丢弃".into())),
            Err(_) => {
                // Drop the stale waiter so a late reply is silently discarded.
                self.inner.pending.lock().await.remove(&id);
                Err(Error::Timeout(timeout))
            }
        }
    }

    /// Enqueues a request with a caller-chosen `id` and no session, without
    /// waiting for a reply.
    ///
    /// NOTE(review): `id` is not checked against ids allocated internally. If it
    /// equals the id of an in-flight `send`, the peer's reply to this message would
    /// be routed to that waiter.
    ///
    /// # Errors
    /// Serialization errors, or `Error::Transport` if the write queue is closed.
    pub fn fire(&self, id: i64, method: impl Into<String>, params: Value) -> Result<()> {
        let msg = OutgoingMessage::new(id, method, params, None);
        let json = msg.to_json_bytes()?;
        self.enqueue(json)
    }

    /// Enqueues a request with a freshly allocated id and optional session,
    /// without waiting for a reply. Any reply the peer sends is discarded.
    ///
    /// # Errors
    /// Serialization errors, or `Error::Transport` if the write queue is closed.
    pub fn fire_session(
        &self,
        method: impl Into<String>,
        params: Value,
        session_id: Option<&str>,
    ) -> Result<()> {
        let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed);
        let msg = OutgoingMessage::new(id, method, params, session_id.map(str::to_string));
        let json = msg.to_json_bytes()?;
        self.enqueue(json)
    }

    /// Pushes already-serialized JSON onto the write task's queue (fire-and-forget paths).
    fn enqueue(&self, json: Vec<u8>) -> Result<()> {
        self.inner
            .cmd_tx
            .send(json)
            .map_err(|_| Error::Transport("命令写通道已关闭".into()))
    }
}
async fn pipe_write_loop<W>(mut writer: W, mut rx: mpsc::UnboundedReceiver<Vec<u8>>)
where
W: AsyncWrite + Unpin,
{
while let Some(json) = rx.recv().await {
let frame = encode_frame(&json);
if let Err(e) = writer.write_all(&frame).await {
tracing::error!(error = %e, "写入命令管道失败,写任务退出");
break;
}
}
tracing::debug!("命令写任务(pipe)结束");
}
async fn pipe_read_loop<R>(mut reader: R, pending: PendingMap, event_tx: broadcast::Sender<Event>)
where
R: AsyncRead + Unpin,
{
let mut decoder = FrameDecoder::new();
let mut buf = vec![0u8; 64 * 1024];
loop {
let n = match reader.read(&mut buf).await {
Ok(0) => {
tracing::debug!("响应管道 EOF,子进程已退出");
break;
}
Ok(n) => n,
Err(e) => {
tracing::error!(error = %e, "读取响应管道失败,读任务退出");
break;
}
};
decoder.push(&buf[..n]);
while let Some(frame) = decoder.next_frame() {
dispatch_json(&frame, &pending, &event_tx).await;
}
}
fail_all_pending(&pending).await;
}
async fn ws_write_loop<S>(
mut sink: SplitSink<WebSocketStream<S>, Message>,
mut rx: mpsc::UnboundedReceiver<Vec<u8>>,
) where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
while let Some(json) = rx.recv().await {
let text = match String::from_utf8(json) {
Ok(s) => s,
Err(e) => {
tracing::error!(error = %e, "命令 JSON 非 UTF-8,已跳过");
continue;
}
};
if let Err(e) = sink.send(Message::text(text)).await {
tracing::error!(error = %e, "写入 ws 失败,写任务退出");
break;
}
}
let _ = sink.close().await;
tracing::debug!("命令写任务(ws)结束");
}
/// Background read task for websocket transports.
///
/// Text and binary frames are both treated as JSON and passed to `dispatch_json`.
/// Other non-close frames are ignored. The loop stops when the stream ends, a
/// Close frame arrives, or a read error occurs. Every still-pending request is
/// then failed.
async fn ws_read_loop<S>(
    mut stream: SplitStream<WebSocketStream<S>>,
    pending: PendingMap,
    event_tx: broadcast::Sender<Event>,
) where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    loop {
        let message = match stream.next().await {
            Some(Ok(message)) => message,
            Some(Err(e)) => {
                tracing::error!(error = %e, "读取 ws 失败,读任务退出");
                break;
            }
            None => break,
        };
        match message {
            Message::Text(t) => dispatch_json(t.as_bytes(), &pending, &event_tx).await,
            Message::Binary(b) => dispatch_json(&b, &pending, &event_tx).await,
            Message::Close(_) => {
                tracing::debug!("ws 收到 Close,读任务退出");
                break;
            }
            _ => {}
        }
    }
    fail_all_pending(&pending).await;
}
async fn dispatch_json(bytes: &[u8], pending: &PendingMap, event_tx: &broadcast::Sender<Event>) {
match IncomingMessage::from_json_bytes(bytes) {
Ok(msg) => dispatch(msg, pending, event_tx).await,
Err(e) => tracing::warn!(error = %e, "解析入站帧失败,已跳过"),
}
}
/// Empties the pending map and completes every waiter with a "connection closed"
/// error. Waiters whose receiving side is already gone are skipped silently.
async fn fail_all_pending(pending: &PendingMap) {
    let mut waiters = pending.lock().await;
    waiters
        .drain()
        .map(|(_, tx)| tx)
        .for_each(|tx| {
            let _ = tx.send(Err("连接已关闭".to_string()));
        });
}
async fn dispatch(msg: IncomingMessage, pending: &PendingMap, event_tx: &broadcast::Sender<Event>) {
match msg.kind() {
MessageKind::Response { id, result } => {
if let Some(tx) = pending.lock().await.remove(&id) {
let _ = tx.send(Ok(result));
}
}
MessageKind::Error { id, message } => {
if let Some(tx) = pending.lock().await.remove(&id) {
let _ = tx.send(Err(message));
}
}
MessageKind::Event {
method,
params,
session_id,
} => {
let _ = event_tx.send(Event {
method,
params,
session_id,
});
}
MessageKind::Unknown => {}
}
}