use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
use serde::Serialize;
use serde_json::Value;
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
use tokio_tungstenite::tungstenite::Message;
use crate::bidi::command::RawEvent;
use crate::bidi::error::{BidiError, BidiErrorEnvelope};
use crate::error::{WebDriverError, WebDriverResult};
const EVENT_BUFFER: usize = 1024;
#[derive(Debug, Clone)]
pub(crate) struct BidiTransport {
inner: Arc<Inner>,
}
#[derive(Debug)]
struct Inner {
out: mpsc::UnboundedSender<String>,
pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, BidiError>>>>,
events: broadcast::Sender<RawEvent>,
next_id: AtomicU64,
subscriptions: Mutex<HashMap<&'static str, usize>>,
}
impl BidiTransport {
pub(crate) async fn connect(url: &str) -> WebDriverResult<Self> {
let (ws, _resp) = tokio_tungstenite::connect_async(url)
.await
.map_err(|e| WebDriverError::HttpError(format!("BiDi WebSocket connect: {e}")))?;
let (mut sink, mut stream) = ws.split();
let (out_tx, mut out_rx) = mpsc::unbounded_channel::<String>();
let (events_tx, _events_rx0) = broadcast::channel::<RawEvent>(EVENT_BUFFER);
let inner = Arc::new(Inner {
out: out_tx,
pending: Mutex::new(HashMap::new()),
events: events_tx,
next_id: AtomicU64::new(1),
subscriptions: Mutex::new(HashMap::new()),
});
tokio::spawn(async move {
while let Some(text) = out_rx.recv().await {
if sink.send(Message::Text(text.into())).await.is_err() {
break;
}
}
let _ = sink.close().await;
});
let inner_for_reader = Arc::clone(&inner);
tokio::spawn(async move {
while let Some(message) = stream.next().await {
let frame = match message {
Ok(Message::Text(t)) => t.as_str().to_owned(),
Ok(Message::Binary(b)) => match std::str::from_utf8(&b) {
Ok(s) => s.to_owned(),
Err(_) => continue,
},
Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_)) => continue,
Ok(Message::Close(_)) | Err(_) => break,
};
inner_for_reader.dispatch(&frame).await;
}
inner_for_reader.fail_all_pending("BiDi WebSocket closed").await;
});
Ok(Self {
inner,
})
}
pub(crate) async fn send_raw(&self, method: &str, params: Value) -> Result<Value, BidiError> {
let params = if params.is_null() {
serde_json::json!({})
} else {
params
};
let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed);
let (resp_tx, resp_rx) = oneshot::channel();
self.inner.pending.lock().await.insert(id, resp_tx);
let frame = OutgoingCommand {
id,
method,
params: ¶ms,
};
let serialized = serde_json::to_string(&frame).map_err(|e| BidiError {
command: method.to_string(),
error: "unknown error".to_string(),
message: format!("serialise command: {e}"),
stacktrace: None,
})?;
if self.inner.out.send(serialized).is_err() {
return Err(BidiError {
command: method.to_string(),
error: "unknown error".to_string(),
message: "BiDi WebSocket connection closed".to_string(),
stacktrace: None,
});
}
match resp_rx.await {
Ok(result) => result,
Err(_) => Err(BidiError {
command: method.to_string(),
error: "unknown error".to_string(),
message: "BiDi response channel dropped".to_string(),
stacktrace: None,
}),
}
}
pub(crate) fn subscribe_events(&self) -> broadcast::Receiver<RawEvent> {
self.inner.events.subscribe()
}
pub(crate) async fn ensure_subscribed(&self, method: &'static str) -> Result<(), BidiError> {
let mut subs = self.inner.subscriptions.lock().await;
let current = *subs.get(method).unwrap_or(&0);
if current == 0 {
self.send_raw("session.subscribe", serde_json::json!({ "events": [method] })).await?;
}
subs.insert(method, current + 1);
Ok(())
}
pub(crate) async fn release_subscription(&self, method: &'static str) {
let mut subs = self.inner.subscriptions.lock().await;
let Some(entry) = subs.get_mut(method) else {
return;
};
if *entry == 0 {
return;
}
*entry -= 1;
if *entry == 0 {
subs.remove(method);
let _ = self
.send_raw("session.unsubscribe", serde_json::json!({ "events": [method] }))
.await;
}
}
}
impl Inner {
async fn dispatch(self: &Arc<Self>, frame: &str) {
let value: Value = match serde_json::from_str(frame) {
Ok(v) => v,
Err(_) => return,
};
match value.get("type").and_then(Value::as_str) {
Some("success") => {
let Some(id) = value.get("id").and_then(Value::as_u64) else {
return;
};
if let Some(tx) = self.pending.lock().await.remove(&id) {
let result = value.get("result").cloned().unwrap_or(Value::Null);
let _ = tx.send(Ok(result));
}
}
Some("error") => {
let env: BidiErrorEnvelope = match serde_json::from_value(value.clone()) {
Ok(e) => e,
Err(_) => BidiErrorEnvelope {
error: "unknown error".to_string(),
message: "malformed BiDi error envelope".to_string(),
stacktrace: None,
},
};
if let Some(id) = value.get("id").and_then(Value::as_u64)
&& let Some(tx) = self.pending.lock().await.remove(&id)
{
let _ = tx.send(Err(env.into_error("<response>")));
}
}
Some("event") => {
if let Some(method) = value.get("method").and_then(Value::as_str) {
let params = value.get("params").cloned().unwrap_or(Value::Null);
let _ = self.events.send(RawEvent {
method: method.to_string(),
params,
});
}
}
_ => {}
}
}
async fn fail_all_pending(&self, message: &str) {
let mut pending = self.pending.lock().await;
for (_, tx) in pending.drain() {
let _ = tx.send(Err(BidiError {
command: "<connection>".to_string(),
error: "unknown error".to_string(),
message: message.to_string(),
stacktrace: None,
}));
}
}
}
#[derive(Serialize)]
struct OutgoingCommand<'a> {
id: u64,
method: &'a str,
params: &'a Value,
}