browser_control/bidi/
mod.rs1pub mod protocol;
4
5use anyhow::{anyhow, Result};
6use futures_util::{SinkExt, StreamExt};
7use protocol::*;
8use serde_json::{json, Value};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
13use tokio_tungstenite::tungstenite::Message;
14
/// Longest time [`BidiClient::send`] waits for a reply before giving up (30 s).
const SEND_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Capacity of the broadcast channel that fans incoming events out to subscribers.
const EVENT_CHANNEL_CAPACITY: usize = 1 << 8;
17
18type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, BidiError>>>>>;
19
20#[derive(Debug, Clone)]
21pub struct BidiEvent {
22 pub method: String,
23 pub params: Value,
24}
25
26pub struct BidiClient {
27 next_id: Mutex<u64>,
28 pending: PendingMap,
29 events_tx: broadcast::Sender<BidiEvent>,
30 write_tx: mpsc::UnboundedSender<String>,
31 session_id: Mutex<Option<String>>,
32}
33
34impl BidiClient {
35 pub async fn connect(ws_url: &str) -> Result<Self> {
36 let (ws, _resp) = tokio_tungstenite::connect_async(ws_url).await?;
37 let (mut sink, mut stream) = ws.split();
38
39 let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
40 let (events_tx, _) = broadcast::channel::<BidiEvent>(EVENT_CHANNEL_CAPACITY);
41 let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
42
43 tokio::spawn(async move {
45 while let Some(msg) = write_rx.recv().await {
46 if sink.send(Message::Text(msg)).await.is_err() {
47 break;
48 }
49 }
50 let _ = sink.close().await;
51 });
52
53 let pending_reader = pending.clone();
55 let events_reader = events_tx.clone();
56 tokio::spawn(async move {
57 while let Some(Ok(msg)) = stream.next().await {
58 let text = match msg {
59 Message::Text(t) => t,
60 Message::Binary(b) => match String::from_utf8(b) {
61 Ok(s) => s,
62 Err(_) => continue,
63 },
64 Message::Close(_) => break,
65 _ => continue,
66 };
67 let parsed: Result<IncomingMessage, _> = serde_json::from_str(&text);
68 match parsed {
69 Ok(IncomingMessage::Success { id, result }) => {
70 if let Some(tx) = pending_reader.lock().await.remove(&id) {
71 let _ = tx.send(Ok(result));
72 }
73 }
74 Ok(IncomingMessage::Error { id, error, message }) => {
75 if let Some(id) = id {
76 if let Some(tx) = pending_reader.lock().await.remove(&id) {
77 let _ = tx.send(Err(BidiError {
78 code: error,
79 message,
80 }));
81 }
82 }
83 }
84 Ok(IncomingMessage::Event { method, params }) => {
85 let _ = events_reader.send(BidiEvent { method, params });
86 }
87 Err(_) => continue,
88 }
89 }
90 pending_reader.lock().await.clear();
91 });
92
93 Ok(Self {
94 next_id: Mutex::new(1),
95 pending,
96 events_tx,
97 write_tx,
98 session_id: Mutex::new(None),
99 })
100 }
101
102 pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
103 let id = {
104 let mut guard = self.next_id.lock().await;
105 let id = *guard;
106 *guard += 1;
107 id
108 };
109 let cmd = Command { id, method, params };
110 let text = serde_json::to_string(&cmd)?;
111
112 let (tx, rx) = oneshot::channel();
113 self.pending.lock().await.insert(id, tx);
114
115 self.write_tx
116 .send(text)
117 .map_err(|_| anyhow!("BiDi connection closed"))?;
118
119 match tokio::time::timeout(SEND_TIMEOUT, rx).await {
120 Ok(Ok(Ok(v))) => Ok(v),
121 Ok(Ok(Err(e))) => Err(e.into()),
122 Ok(Err(_)) => Err(anyhow!("BiDi response channel cancelled")),
123 Err(_) => {
124 self.pending.lock().await.remove(&id);
125 Err(anyhow!("BiDi send timed out after {:?}", SEND_TIMEOUT))
126 }
127 }
128 }
129
130 pub fn subscribe(&self) -> broadcast::Receiver<BidiEvent> {
131 self.events_tx.subscribe()
132 }
133
134 pub async fn session_new(&self) -> Result<String> {
135 let v = self
136 .send("session.new", json!({"capabilities": {}}))
137 .await?;
138 let sid = v["sessionId"]
139 .as_str()
140 .ok_or_else(|| anyhow!("no sessionId"))?
141 .to_string();
142 *self.session_id.lock().await = Some(sid.clone());
143 Ok(sid)
144 }
145
146 pub async fn session_end(&self) -> Result<()> {
147 let _ = self.send("session.end", json!({})).await;
149 Ok(())
150 }
151
152 pub async fn browsing_context_navigate(&self, context: &str, url: &str) -> Result<Value> {
153 self.send(
154 "browsingContext.navigate",
155 json!({"context": context, "url": url, "wait": "complete"}),
156 )
157 .await
158 }
159
160 pub async fn script_evaluate(&self, context: &str, expression: &str) -> Result<Value> {
161 self.send(
162 "script.evaluate",
163 json!({
164 "expression": expression,
165 "target": {"context": context},
166 "awaitPromise": true,
167 "resultOwnership": "none"
168 }),
169 )
170 .await
171 }
172
173 pub async fn browsing_context_capture_screenshot(&self, context: &str) -> Result<String> {
174 let v = self
175 .send(
176 "browsingContext.captureScreenshot",
177 json!({"context": context}),
178 )
179 .await?;
180 Ok(v["data"]
181 .as_str()
182 .ok_or_else(|| anyhow!("no data"))?
183 .to_string())
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;

    /// Starts a single-connection WebSocket server on an ephemeral localhost
    /// port and returns its `ws://` URL. The server parses every incoming text
    /// frame as JSON, calls `respond` on it, and sends back the returned
    /// frames in order.
    async fn spawn_server<F>(respond: F) -> String
    where
        F: Fn(&Value) -> Vec<Value> + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let mut ws = accept_async(stream).await.unwrap();
                while let Some(Ok(msg)) = ws.next().await {
                    if let Message::Text(text) = msg {
                        let cmd: Value = serde_json::from_str(&text).unwrap();
                        for frame in respond(&cmd) {
                            ws.send(Message::Text(frame.to_string())).await.unwrap();
                        }
                    }
                }
            }
        });
        format!("ws://{}", addr)
    }

    /// Builds a BiDi success reply to `cmd` carrying `result`.
    fn success(cmd: &Value, result: Value) -> Value {
        json!({
            "id": cmd["id"].as_u64().unwrap(),
            "type": "success",
            "result": result
        })
    }

    /// Server that answers every command with `{"echoed": <method>}`.
    async fn spawn_echo_server() -> String {
        spawn_server(|cmd| {
            let method = cmd["method"].as_str().unwrap().to_string();
            vec![success(cmd, json!({"echoed": method}))]
        })
        .await
    }

    #[tokio::test]
    async fn send_receives_success_result() {
        let url = spawn_echo_server().await;
        let client = BidiClient::connect(&url).await.unwrap();
        let result = client.send("session.status", json!({})).await.unwrap();
        assert_eq!(result["echoed"], "session.status");
    }

    #[tokio::test]
    async fn subscriber_receives_event() {
        // Broadcast channels drop events sent while nobody is subscribed, so
        // the server only emits the event in response to a command, which the
        // client sends after subscribing. This keeps the test deterministic
        // regardless of runtime flavor.
        let url = spawn_server(|cmd| {
            vec![
                json!({
                    "type": "event",
                    "method": "log.entryAdded",
                    "params": {"text": "hello"}
                }),
                success(cmd, json!({})),
            ]
        })
        .await;
        let client = BidiClient::connect(&url).await.unwrap();
        let mut rx = client.subscribe();
        client
            .send("session.subscribe", json!({"events": ["log.entryAdded"]}))
            .await
            .unwrap();
        let evt = tokio::time::timeout(Duration::from_secs(5), rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(evt.method, "log.entryAdded");
        assert_eq!(evt.params["text"], "hello");
    }

    #[tokio::test]
    async fn session_new_returns_and_records_session_id() {
        let url = spawn_server(|cmd| {
            vec![success(
                cmd,
                json!({"sessionId": "session-1", "capabilities": {}}),
            )]
        })
        .await;
        let client = BidiClient::connect(&url).await.unwrap();
        assert_eq!(client.session_new().await.unwrap(), "session-1");
        assert_eq!(
            client.session_id.lock().await.as_deref(),
            Some("session-1")
        );
    }

    #[tokio::test]
    async fn send_fails_promptly_when_server_closes() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut ws = accept_async(stream).await.unwrap();
            // Close only after the command arrives, so the client's pending
            // entry is guaranteed to exist when the reader task shuts down.
            let _ = ws.next().await;
            let _ = ws.close(None).await;
            while ws.next().await.is_some() {}
        });
        let client = BidiClient::connect(&format!("ws://{}", addr))
            .await
            .unwrap();
        let outcome = tokio::time::timeout(
            Duration::from_secs(5),
            client.send("session.status", json!({})),
        )
        .await
        .expect("send should fail well before SEND_TIMEOUT once the socket closes");
        assert!(outcome.is_err());
    }
}