use futures::{SinkExt, StreamExt};
use rustc_hash::FxHashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, trace, warn};
use crate::backend::json_scan;
use crate::error::{FerriError, Result};
/// An error attached to a BiDi command: either reported by the browser in an `error` frame,
/// or synthesised locally (e.g. `"target closed"` when the transport shuts down, or
/// `"parse_error"` when a response frame cannot be decoded).
#[derive(Debug, Clone)]
pub(crate) struct BidiError {
    /// Error code, taken from the frame's `error` field.
    pub error: String,
    /// Human-readable detail, taken from the frame's `message` field.
    pub message: String,
}

impl std::fmt::Display for BidiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "BiDi error '{}': {}", self.error, self.message)
    }
}

/// Lets a `BidiError` be boxed as `dyn std::error::Error` and handled by generic
/// error-reporting code. It wraps no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for BidiError {}
/// What a waiting command receives: the response's `result` payload on success, or a
/// [`BidiError`] describing why the command failed.
type BidiResult = std::result::Result<serde_json::Value, BidiError>;

/// Response channels for in-flight commands, keyed by the numeric `id` written into each
/// outgoing command frame.
type PendingMap = FxHashMap<u64, oneshot::Sender<BidiResult>>;

/// An event frame (`"type":"event"`) received from the browser and broadcast to subscribers.
#[derive(Debug, Clone)]
pub(crate) struct BidiEvent {
    /// The frame's `method` string naming the event.
    pub method: String,
    /// The frame's `params` value, or `Null` when the frame carried none.
    pub params: serde_json::Value,
}
pub(crate) struct BidiTransport {
next_id: AtomicU64,
pending: Arc<std::sync::Mutex<PendingMap>>,
write_tx: mpsc::Sender<Message>,
event_tx: broadcast::Sender<BidiEvent>,
}
impl BidiTransport {
pub async fn connect(ws_url: &str) -> Result<Self> {
debug!("BiDi connecting to {ws_url}");
let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url)
.await
.map_err(|e| FerriError::Backend(format!("BiDi WebSocket connect to {ws_url}: {e}")))?;
let (write, read) = ws_stream.split();
let pending: Arc<std::sync::Mutex<PendingMap>> = Arc::new(std::sync::Mutex::new(FxHashMap::default()));
let (write_tx, mut write_rx) = mpsc::channel::<Message>(128);
tokio::spawn(async move {
let mut writer = write;
while let Some(msg) = write_rx.recv().await {
if writer.send(msg).await.is_err() {
break;
}
}
});
let (event_tx, _) = broadcast::channel::<BidiEvent>(4096);
let event_tx2 = event_tx.clone();
let pending2 = pending.clone();
tokio::spawn(async move {
let mut read = read;
while let Some(result) = read.next().await {
let msg = match result {
Ok(m) => m,
Err(e) => {
warn!("BiDi WebSocket error: {e:?}");
break;
},
};
let text = match msg {
Message::Text(t) => t,
Message::Close(frame) => {
debug!("BiDi WebSocket close frame: {frame:?}");
break;
},
_ => continue,
};
let bytes = text.as_bytes();
let type_field = json_scan::json_string(json_scan::json_field(bytes, b"type"));
if type_field == b"success" || type_field == b"error" {
handle_command_response(bytes, type_field, &pending2);
} else if type_field == b"event" {
let method_bytes = json_scan::json_string(json_scan::json_field(bytes, b"method"));
if method_bytes.is_empty() {
continue;
}
let method = String::from_utf8_lossy(method_bytes).to_string();
match serde_json::from_slice::<serde_json::Value>(bytes) {
Ok(parsed) => {
let params = parsed.get("params").cloned().unwrap_or(serde_json::Value::Null);
trace!("BiDi event: {method}");
let _ = event_tx2.send(BidiEvent { method, params });
},
Err(e) => {
warn!("BiDi event parse error: {e}");
},
}
}
}
let mut map = pending2.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
for (_, tx) in map.drain() {
let _ = tx.send(Err(BidiError {
error: "target closed".into(),
message: "BiDi transport closed (browser exited)".into(),
}));
}
debug!("BiDi reader task ended");
});
debug!("BiDi transport connected");
Ok(Self {
next_id: AtomicU64::new(0),
pending,
write_tx,
event_tx,
})
}
pub async fn send_command(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed) + 1;
let (tx, rx) = oneshot::channel();
{
let mut map = self.pending.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
map.insert(id, tx);
}
let params_str = serde_json::to_string(¶ms).unwrap_or_else(|_| "{}".to_string());
let cmd = format!(r#"{{"id":{id},"method":"{method}","params":{params_str}}}"#);
trace!("BiDi send id={id}: {method}");
if self.write_tx.send(Message::Text(cmd.into())).await.is_err() {
let mut map = self.pending.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
map.remove(&id);
return Err(FerriError::backend("BiDi WebSocket connection closed"));
}
match tokio::time::timeout(std::time::Duration::from_secs(60), rx).await {
Ok(Ok(result)) => result.map_err(|e| FerriError::protocol(method, e.to_string())),
Ok(Err(_)) => Err(FerriError::backend("BiDi command response channel dropped")),
Err(_) => {
let mut map = self.pending.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
map.remove(&id);
Err(FerriError::timeout(format!("BiDi command '{method}'"), 60_000))
},
}
}
#[allow(dead_code)]
pub async fn send_batch(&self, commands: &[(&str, serde_json::Value)]) -> Vec<Result<serde_json::Value>> {
let mut receivers = Vec::with_capacity(commands.len());
for (method, params) in commands {
let id = self.next_id.fetch_add(1, Ordering::Relaxed) + 1;
let (tx, rx) = oneshot::channel();
{
let mut map = self.pending.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
map.insert(id, tx);
}
let params_str = serde_json::to_string(params).unwrap_or_else(|_| "{}".to_string());
let cmd = format!(r#"{{"id":{id},"method":"{method}","params":{params_str}}}"#);
trace!("BiDi batch send id={id}: {method}");
if self.write_tx.send(Message::Text(cmd.into())).await.is_err() {
receivers.push(Err(FerriError::backend("BiDi WebSocket connection closed")));
continue;
}
receivers.push(Ok((method.to_string(), rx)));
}
let mut results = Vec::with_capacity(receivers.len());
for recv in receivers {
match recv {
Ok((m, rx)) => match tokio::time::timeout(std::time::Duration::from_secs(60), rx).await {
Ok(Ok(result)) => results.push(result.map_err(|e| FerriError::protocol(m.clone(), e.to_string()))),
Ok(Err(_)) => results.push(Err(FerriError::backend("BiDi batch response channel dropped"))),
Err(_) => results.push(Err(FerriError::timeout(format!("BiDi batch command '{m}'"), 60_000))),
},
Err(e) => results.push(Err(e)),
}
}
results
}
pub fn subscribe_events(&self) -> broadcast::Receiver<BidiEvent> {
self.event_tx.subscribe()
}
}
fn handle_command_response(bytes: &[u8], type_field: &[u8], pending: &Arc<std::sync::Mutex<PendingMap>>) {
let id = json_scan::json_id(bytes);
if id == 0 {
warn!("BiDi response missing id");
return;
}
let tx = {
let mut map = pending.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
map.remove(&id)
};
let Some(tx) = tx else {
trace!("BiDi response for unknown id={id}");
return;
};
if type_field == b"error" {
let error_str = json_scan::json_string(json_scan::json_field(bytes, b"error"));
let message_str = json_scan::json_string(json_scan::json_field(bytes, b"message"));
let error = String::from_utf8_lossy(error_str).to_string();
let message = String::from_utf8_lossy(message_str).to_string();
trace!("BiDi error id={id}: {error} - {message}");
let _ = tx.send(Err(BidiError { error, message }));
} else {
match serde_json::from_slice::<serde_json::Value>(bytes) {
Ok(parsed) => {
let result = parsed.get("result").cloned().unwrap_or(serde_json::Value::Null);
trace!("BiDi response id={id}");
let _ = tx.send(Ok(result));
},
Err(e) => {
warn!("BiDi parse error id={id}: {e}");
let _ = tx.send(Err(BidiError {
error: "parse_error".into(),
message: e.to_string(),
}));
},
}
}
}