Skip to main content

browser_control/cdp/
mod.rs

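//! Minimal Chrome DevTools Protocol (CDP) WebSocket client.
//!
//! Requests are matched to responses by their numeric `id`; events are fanned out to
//! subscribers through a broadcast channel.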
2
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{anyhow, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
use tokio_tungstenite::tungstenite::Message;
12
13pub mod protocol;
14use protocol::{CdpError, Request, Response};
15
/// How long a single request waits for its response before giving up.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Capacity of the broadcast channel that carries events to subscribers.
const EVENT_CHANNEL_CAPACITY: usize = 1 << 8;
18
19type PendingMap = HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>;
20
21#[derive(Debug, Clone)]
22pub struct CdpEvent {
23    pub method: String,
24    pub params: Value,
25    pub session_id: Option<String>,
26}
27
28pub struct CdpClient {
29    next_id: Mutex<u64>,
30    pending: Arc<Mutex<PendingMap>>,
31    events_tx: broadcast::Sender<CdpEvent>,
32    write_tx: mpsc::UnboundedSender<String>,
33    reader_handle: tokio::task::JoinHandle<()>,
34    writer_handle: tokio::task::JoinHandle<()>,
35}
36
37impl CdpClient {
38    /// Connect by full WebSocket URL (ws:// or wss://).
39    pub async fn connect(ws_url: &str) -> Result<Self> {
40        let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url).await?;
41        let (mut ws_sink, mut ws_stream) = ws_stream.split();
42
43        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
44        let (events_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
45        let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
46
47        let writer_handle = tokio::spawn(async move {
48            while let Some(text) = write_rx.recv().await {
49                if ws_sink.send(Message::Text(text)).await.is_err() {
50                    break;
51                }
52            }
53            let _ = ws_sink.close().await;
54        });
55
56        let pending_r = pending.clone();
57        let events_r = events_tx.clone();
58        let reader_handle = tokio::spawn(async move {
59            while let Some(msg) = ws_stream.next().await {
60                let text = match msg {
61                    Ok(Message::Text(t)) => t,
62                    Ok(Message::Binary(b)) => match String::from_utf8(b) {
63                        Ok(s) => s,
64                        Err(_) => continue,
65                    },
66                    Ok(Message::Close(_)) | Err(_) => break,
67                    Ok(_) => continue,
68                };
69                let resp: Response = match serde_json::from_str(&text) {
70                    Ok(r) => r,
71                    Err(_) => continue,
72                };
73                if let Some(id) = resp.id {
74                    let mut p = pending_r.lock().await;
75                    if let Some(tx) = p.remove(&id) {
76                        let res = if let Some(err) = resp.error {
77                            Err(err)
78                        } else {
79                            Ok(resp.result)
80                        };
81                        let _ = tx.send(res);
82                    }
83                } else if let Some(method) = resp.method {
84                    let _ = events_r.send(CdpEvent {
85                        method,
86                        params: resp.params,
87                        session_id: resp.session_id,
88                    });
89                }
90            }
91            // Reader closed: fail all pending requests.
92            let mut p = pending_r.lock().await;
93            for (_, tx) in p.drain() {
94                let _ = tx.send(Err(CdpError {
95                    code: -1,
96                    message: "connection closed".into(),
97                }));
98            }
99        });
100
101        Ok(Self {
102            next_id: Mutex::new(1),
103            pending,
104            events_tx,
105            write_tx,
106            reader_handle,
107            writer_handle,
108        })
109    }
110
111    /// Connect by HTTP base URL (e.g. http://127.0.0.1:9222). Fetches /json/version to discover the WS URL.
112    pub async fn connect_http(base_url: &str) -> Result<Self> {
113        let base = base_url.trim_end_matches('/');
114        let url = format!("{base}/json/version");
115        let resp: Value = reqwest::get(&url).await?.json().await?;
116        let ws_url = resp
117            .get("webSocketDebuggerUrl")
118            .and_then(|v| v.as_str())
119            .ok_or_else(|| anyhow!("webSocketDebuggerUrl missing from {url}"))?
120            .to_string();
121        Self::connect(&ws_url).await
122    }
123
124    /// Send a method on the root browser-level session.
125    pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
126        self.send_with_session(method, params, None).await
127    }
128
129    pub async fn send_with_session(
130        &self,
131        method: &str,
132        params: Value,
133        session_id: Option<&str>,
134    ) -> Result<Value> {
135        let id = {
136            let mut n = self.next_id.lock().await;
137            let id = *n;
138            *n += 1;
139            id
140        };
141
142        let req = Request {
143            id,
144            method,
145            params,
146            session_id: session_id.map(|s| s.to_string()),
147        };
148        let text = serde_json::to_string(&req)?;
149
150        let (tx, rx) = oneshot::channel();
151        {
152            let mut p = self.pending.lock().await;
153            p.insert(id, tx);
154        }
155
156        if self.write_tx.send(text).is_err() {
157            let mut p = self.pending.lock().await;
158            p.remove(&id);
159            return Err(anyhow!("writer task closed"));
160        }
161
162        match tokio::time::timeout(REQUEST_TIMEOUT, rx).await {
163            Ok(Ok(Ok(v))) => Ok(v),
164            Ok(Ok(Err(e))) => Err(anyhow!(e)),
165            Ok(Err(_)) => Err(anyhow!("response channel dropped")),
166            Err(_) => {
167                let mut p = self.pending.lock().await;
168                p.remove(&id);
169                Err(anyhow!("CDP request timed out after {:?}", REQUEST_TIMEOUT))
170            }
171        }
172    }
173
174    /// Subscribe to all events. Drop the receiver to unsubscribe.
175    pub fn subscribe(&self) -> broadcast::Receiver<CdpEvent> {
176        self.events_tx.subscribe()
177    }
178
179    /// Attach to a target via Target.attachToTarget(flatten=true) and return the session id.
180    pub async fn attach_to_target(&self, target_id: &str) -> Result<String> {
181        let v = self
182            .send(
183                "Target.attachToTarget",
184                json!({ "targetId": target_id, "flatten": true }),
185            )
186            .await?;
187        v.get("sessionId")
188            .and_then(|v| v.as_str())
189            .map(|s| s.to_string())
190            .ok_or_else(|| anyhow!("sessionId missing from Target.attachToTarget response"))
191    }
192
193    /// Convenience: list targets via Target.getTargets.
194    pub async fn list_targets(&self) -> Result<Vec<Value>> {
195        let v = self.send("Target.getTargets", Value::Null).await?;
196        match v.get("targetInfos") {
197            Some(Value::Array(a)) => Ok(a.clone()),
198            _ => Ok(vec![]),
199        }
200    }
201
202    /// Gracefully shut down.
203    pub async fn close(self) {
204        drop(self.write_tx);
205        let _ = self.writer_handle.await;
206        self.reader_handle.abort();
207        let _ = self.reader_handle.await;
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use tokio_tungstenite::tungstenite::Message;

    /// Binds a listener on an ephemeral loopback port and returns it with its `ws://` URL.
    async fn local_listener() -> (tokio::net::TcpListener, String) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        (listener, url)
    }

    #[tokio::test]
    async fn round_trip_request_response() {
        let (listener, url) = local_listener().await;
        // Fake browser: answer each text request with `{ok: true, echo: <method>}`.
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut server = tokio_tungstenite::accept_async(tcp).await.unwrap();
            while let Some(Ok(frame)) = server.next().await {
                let Message::Text(body) = frame else {
                    continue;
                };
                let request: Value = serde_json::from_str(&body).unwrap();
                let id = request["id"].as_u64().unwrap();
                let reply = json!({"id": id, "result": {"ok": true, "echo": request["method"]}});
                server.send(Message::Text(reply.to_string())).await.unwrap();
            }
        });

        let client = CdpClient::connect(&url).await.unwrap();
        let result = client
            .send("Page.navigate", json!({"url": "about:blank"}))
            .await
            .unwrap();
        assert_eq!(result["ok"], true);
        assert_eq!(result["echo"], "Page.navigate");
        client.close().await;
    }

    #[tokio::test]
    async fn broadcast_event_to_subscriber() {
        let (listener, url) = local_listener().await;
        let (subscribed_tx, subscribed_rx) = oneshot::channel::<()>();
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut server = tokio_tungstenite::accept_async(tcp).await.unwrap();
            // Broadcast receivers only see events sent after they subscribe, so hold the
            // event back until the test signals that its receiver exists.
            let _ = subscribed_rx.await;
            let event = json!({
                "method": "Target.targetCreated",
                "params": {"targetInfo": {"targetId": "abc"}},
                "sessionId": "S1"
            });
            server.send(Message::Text(event.to_string())).await.unwrap();
            // Keep reading (and so keep the socket open) until the client closes it.
            while let Some(Ok(_)) = server.next().await {}
        });

        let client = CdpClient::connect(&url).await.unwrap();
        let mut events = client.subscribe();
        subscribed_tx.send(()).unwrap();

        let received = tokio::time::timeout(Duration::from_secs(5), events.recv())
            .await
            .expect("event timeout")
            .expect("event recv");
        assert_eq!(received.method, "Target.targetCreated");
        assert_eq!(received.session_id.as_deref(), Some("S1"));
        assert_eq!(received.params["targetInfo"]["targetId"], "abc");
        client.close().await;
    }
}