use beet_core::exports::async_channel;
use beet_core::prelude::*;
use beet_net::prelude::*;
use beet_net::sockets::Message;
use beet_net::sockets::*;
use bevy::tasks::IoTaskPool;
use serde_json::Value;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
/// Shared state behind a [`Session`] handle.
///
/// NOTE: fields are dropped in declaration order, so reordering them changes
/// when the socket writer and the channels are released.
struct SessionInner {
/// Base URL of the driver; `kill` sends `DELETE {driver_url}/session/{session_id}`.
driver_url: String,
/// Id of the session on the driver; returned by `Session::id`.
session_id: String,
/// URL of the BiDi websocket this session connected to; shown in `Debug` output.
socket_url: String,
/// Next command id to hand out; initialised to 1 in `connect`.
next_id: AtomicUsize,
/// In-flight commands: command id -> sender the reader task uses to
/// deliver the response carrying that id.
pending: Mutex<HashMap<u64, async_channel::Sender<Value>>>,
/// Queue of serialized command payloads consumed by the writer task.
cmd_tx: async_channel::Sender<String>,
/// Extra clone of the command receiver. Holding it means the channel is not
/// closed merely because the writer task dropped its receiver.
/// NOTE(review): presumably intentional — confirm.
_cmd_rx: async_channel::Receiver<String>,
/// Write half of the socket; `None` after it has been taken out (e.g. by `kill`).
writer: Mutex<Option<SocketWrite>>,
/// Sender for incoming messages that have no numeric `id` but carry a `method`
/// (i.e. events rather than command responses).
events_tx: async_channel::Sender<Value>,
/// Receiving end of the event queue, read by `try_event` / `next_event`.
events_rx: async_channel::Receiver<Value>,
}
/// Cheaply cloneable handle to a BiDi session.
///
/// All clones share one `Arc<SessionInner>`: the same socket, command queue and
/// event queue. Events come from a shared `async_channel`, so each event is
/// received by only one caller, whichever clone reads it first.
#[derive(Debug, Clone)]
pub struct Session {
inner: Arc<SessionInner>,
}
/// Debug output shows the identifying URLs and id only.
///
/// Channels, locks and the socket writer are omitted. `finish_non_exhaustive`
/// marks that omission with a trailing `..` instead of hiding it.
impl std::fmt::Debug for SessionInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionInner")
            .field("driver_url", &self.driver_url)
            .field("session_id", &self.session_id)
            .field("socket_url", &self.socket_url)
            .finish_non_exhaustive()
    }
}
impl Session {
pub async fn kill(self) -> Result<()> {
let url = format!(
"{}/session/{}",
self.inner.driver_url, self.inner.session_id
);
Request::delete(&url).send().await?.into_result().await?;
if let Some(writer) = self.inner.writer.lock().unwrap().take() {
let _ = writer.close(None).await;
}
Ok(())
}
pub fn id(&self) -> &str { &self.inner.session_id }
pub fn try_event(&self) -> Option<Value> {
self.inner.events_rx.try_recv().ok()
}
pub async fn next_event(&self) -> Result<Value, async_channel::RecvError> {
self.inner.events_rx.recv().await
}
pub async fn command(&self, method: &str, params: Value) -> Result<Value> {
let id = self.inner.next_id.fetch_add(1, Ordering::SeqCst) as u64;
let (tx, rx) = async_channel::bounded(1);
{
let mut pending = self.inner.pending.lock().unwrap();
pending.insert(id, tx);
}
let payload = json!({
"id": id,
"method": method,
"params": params
});
let raw = serde_json::to_string(&payload)
.map_err(|e| bevyhow!("Failed to serialize command: {}", e))?;
self.inner
.cmd_tx
.send(raw)
.await
.map_err(|_| bevyhow!("Command channel closed"))?;
let resp = rx
.recv()
.await
.map_err(|_| bevyhow!("Response channel closed"))?;
if let Some(err_obj) = resp.get("error") {
return Err(bevyhow!(
"BiDi error for method '{}': {}",
method,
err_obj
));
}
Ok(resp)
}
pub async fn ping(&self) -> Result<()> {
let _ = self
.command("browsingContext.getTree", json!({"maxDepth": 0}))
.await?;
Ok(())
}
pub async fn connect(
driver_url: &str,
session_id: &str,
socket_url: &str,
) -> Result<Self> {
let socket = Socket::connect(socket_url).await?;
let (send, recv) = socket.split();
let (cmd_tx, cmd_rx) = async_channel::unbounded::<String>();
let (events_tx, events_rx) = async_channel::unbounded::<Value>();
let inner = Arc::new(SessionInner {
driver_url: driver_url.to_string(),
session_id: session_id.to_string(),
socket_url: socket_url.to_string(),
next_id: AtomicUsize::new(1),
pending: Mutex::new(HashMap::new()),
cmd_tx,
_cmd_rx: cmd_rx.clone(),
writer: Mutex::new(Some(send)),
events_tx,
events_rx,
});
Self::spawn_writer(inner.clone(), cmd_rx);
Self::spawn_reader(inner.clone(), recv);
Ok(Self { inner })
}
fn spawn_writer(
inner: Arc<SessionInner>,
cmd_rx: async_channel::Receiver<String>,
) {
IoTaskPool::get()
.spawn_local(async move {
while let Ok(raw) = cmd_rx.recv().await {
let send_result = {
let mut guard = inner.writer.lock().unwrap();
if let Some(writer) = guard.as_mut() {
writer.send(Message::text(raw)).await
} else {
Ok(())
}
};
if send_result.is_err() {
break;
}
}
})
.detach();
}
fn spawn_reader(inner: Arc<SessionInner>, mut read: SocketRead) {
IoTaskPool::get()
.spawn(async move {
while let Some(item) = read.next().await {
let Ok(Message::Text(text)) = item else {
continue;
};
let Ok(val) = serde_json::from_str::<Value>(&text) else {
continue;
};
if let Some(id) = val.get("id").and_then(|v| v.as_u64()) {
let pending = {
let mut pending_map = inner.pending.lock().unwrap();
pending_map.remove(&id)
};
if let Some(tx) = pending {
let _ = tx.send(val).await;
}
continue;
}
if val.get("method").is_some() {
let _ = inner.events_tx.try_send(val);
}
}
})
.detach();
}
}
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use beet_core::prelude::*;

    // End-to-end smoke test: start a driver, open a session, round-trip one
    // command, then tear both down.
    #[beet_core::test]
    async fn works() {
        let scenario = async move {
            let driver = ClientProcess::new().unwrap();
            let bidi = driver.new_session().await.unwrap();
            bidi.ping().await.unwrap();
            bidi.kill().await.unwrap();
            driver.kill().unwrap();
        };
        App::default().run_io_task_local(scenario).await;
    }
}