Skip to main content

browser_control/bidi/
mod.rs

1//! Minimal WebDriver BiDi WebSocket client.
2
3pub mod protocol;
4
use anyhow::{anyhow, Result};
use futures_util::{SinkExt, StreamExt};
use protocol::*;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
use tokio_tungstenite::tungstenite::Message;
14
/// Number of events the broadcast channel buffers for subscribers
/// (see `tokio::sync::broadcast` for what happens to a receiver that lags).
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// Upper bound on how long `BidiClient::send` waits for a command's response.
const SEND_TIMEOUT: Duration = Duration::from_secs(30);
17
18type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, BidiError>>>>>;
19
20#[derive(Debug, Clone)]
21pub struct BidiEvent {
22    pub method: String,
23    pub params: Value,
24}
25
26pub struct BidiClient {
27    next_id: Mutex<u64>,
28    pending: PendingMap,
29    events_tx: broadcast::Sender<BidiEvent>,
30    write_tx: mpsc::UnboundedSender<String>,
31    session_id: Mutex<Option<String>>,
32}
33
34impl BidiClient {
35    pub async fn connect(ws_url: &str) -> Result<Self> {
36        let (ws, _resp) = tokio_tungstenite::connect_async(ws_url).await?;
37        let (mut sink, mut stream) = ws.split();
38
39        let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
40        let (events_tx, _) = broadcast::channel::<BidiEvent>(EVENT_CHANNEL_CAPACITY);
41        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
42
43        // Writer task.
44        tokio::spawn(async move {
45            while let Some(msg) = write_rx.recv().await {
46                if sink.send(Message::Text(msg)).await.is_err() {
47                    break;
48                }
49            }
50            let _ = sink.close().await;
51        });
52
53        // Reader task.
54        let pending_reader = pending.clone();
55        let events_reader = events_tx.clone();
56        tokio::spawn(async move {
57            while let Some(Ok(msg)) = stream.next().await {
58                let text = match msg {
59                    Message::Text(t) => t,
60                    Message::Binary(b) => match String::from_utf8(b) {
61                        Ok(s) => s,
62                        Err(_) => continue,
63                    },
64                    Message::Close(_) => break,
65                    _ => continue,
66                };
67                let parsed: Result<IncomingMessage, _> = serde_json::from_str(&text);
68                match parsed {
69                    Ok(IncomingMessage::Success { id, result }) => {
70                        if let Some(tx) = pending_reader.lock().await.remove(&id) {
71                            let _ = tx.send(Ok(result));
72                        }
73                    }
74                    Ok(IncomingMessage::Error { id, error, message }) => {
75                        if let Some(id) = id {
76                            if let Some(tx) = pending_reader.lock().await.remove(&id) {
77                                let _ = tx.send(Err(BidiError {
78                                    code: error,
79                                    message,
80                                }));
81                            }
82                        }
83                    }
84                    Ok(IncomingMessage::Event { method, params }) => {
85                        let _ = events_reader.send(BidiEvent { method, params });
86                    }
87                    Err(_) => continue,
88                }
89            }
90            pending_reader.lock().await.clear();
91        });
92
93        Ok(Self {
94            next_id: Mutex::new(1),
95            pending,
96            events_tx,
97            write_tx,
98            session_id: Mutex::new(None),
99        })
100    }
101
102    pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
103        let id = {
104            let mut guard = self.next_id.lock().await;
105            let id = *guard;
106            *guard += 1;
107            id
108        };
109        let cmd = Command { id, method, params };
110        let text = serde_json::to_string(&cmd)?;
111
112        let (tx, rx) = oneshot::channel();
113        self.pending.lock().await.insert(id, tx);
114
115        self.write_tx
116            .send(text)
117            .map_err(|_| anyhow!("BiDi connection closed"))?;
118
119        match tokio::time::timeout(SEND_TIMEOUT, rx).await {
120            Ok(Ok(Ok(v))) => Ok(v),
121            Ok(Ok(Err(e))) => Err(e.into()),
122            Ok(Err(_)) => Err(anyhow!("BiDi response channel cancelled")),
123            Err(_) => {
124                self.pending.lock().await.remove(&id);
125                Err(anyhow!("BiDi send timed out after {:?}", SEND_TIMEOUT))
126            }
127        }
128    }
129
130    pub fn subscribe(&self) -> broadcast::Receiver<BidiEvent> {
131        self.events_tx.subscribe()
132    }
133
134    pub async fn session_new(&self) -> Result<String> {
135        let v = self
136            .send("session.new", json!({"capabilities": {}}))
137            .await?;
138        let sid = v["sessionId"]
139            .as_str()
140            .ok_or_else(|| anyhow!("no sessionId"))?
141            .to_string();
142        *self.session_id.lock().await = Some(sid.clone());
143        Ok(sid)
144    }
145
146    pub async fn session_end(&self) -> Result<()> {
147        // Best effort: ignore errors if no session is active.
148        let _ = self.send("session.end", json!({})).await;
149        Ok(())
150    }
151
152    pub async fn browsing_context_navigate(&self, context: &str, url: &str) -> Result<Value> {
153        self.send(
154            "browsingContext.navigate",
155            json!({"context": context, "url": url, "wait": "complete"}),
156        )
157        .await
158    }
159
160    pub async fn script_evaluate(&self, context: &str, expression: &str) -> Result<Value> {
161        self.send(
162            "script.evaluate",
163            json!({
164                "expression": expression,
165                "target": {"context": context},
166                "awaitPromise": true,
167                "resultOwnership": "none"
168            }),
169        )
170        .await
171    }
172
173    pub async fn browsing_context_capture_screenshot(&self, context: &str) -> Result<String> {
174        let v = self
175            .send(
176                "browsingContext.captureScreenshot",
177                json!({"context": context}),
178            )
179            .await?;
180        Ok(v["data"]
181            .as_str()
182            .ok_or_else(|| anyhow!("no data"))?
183            .to_string())
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;

    /// Starts a single-connection WebSocket server that answers every text
    /// command with a `success` reply whose result echoes the command's method.
    /// Returns the `ws://` URL to connect to.
    async fn spawn_echo_server() -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        tokio::spawn(async move {
            let tcp = match listener.accept().await {
                Ok((tcp, _peer)) => tcp,
                Err(_) => return,
            };
            let mut socket = accept_async(tcp).await.unwrap();
            while let Some(Ok(frame)) = socket.next().await {
                let body = match frame {
                    Message::Text(body) => body,
                    _ => continue,
                };
                let request: Value = serde_json::from_str(&body).unwrap();
                let request_id = request["id"].as_u64().unwrap();
                let echoed = request["method"].as_str().unwrap().to_string();
                let response = json!({
                    "id": request_id,
                    "type": "success",
                    "result": {"echoed": echoed}
                });
                socket
                    .send(Message::Text(response.to_string()))
                    .await
                    .unwrap();
            }
        });
        url
    }

    #[tokio::test]
    async fn send_receives_success_result() {
        let server_url = spawn_echo_server().await;
        let client = BidiClient::connect(&server_url).await.unwrap();
        let reply = client.send("session.status", json!({})).await.unwrap();
        assert_eq!(reply["echoed"], "session.status");
    }

    #[tokio::test]
    async fn subscriber_receives_event() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_url = format!("ws://{}", listener.local_addr().unwrap());
        tokio::spawn(async move {
            let (tcp, _peer) = listener.accept().await.unwrap();
            let mut socket = accept_async(tcp).await.unwrap();
            let pushed = json!({
                "type": "event",
                "method": "log.entryAdded",
                "params": {"text": "hello"}
            });
            socket.send(Message::Text(pushed.to_string())).await.unwrap();
            // Keep the connection open until the client side goes away.
            while socket.next().await.is_some() {}
        });

        let client = BidiClient::connect(&server_url).await.unwrap();
        // `connect` does not await after spawning its reader task. On
        // `#[tokio::test]`'s default current-thread runtime, the reader cannot
        // run before this synchronous subscribe, so the event is not missed.
        let mut events = client.subscribe();
        let received = tokio::time::timeout(Duration::from_secs(5), events.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received.method, "log.entryAdded");
        assert_eq!(received.params["text"], "hello");
    }
}