1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use anyhow::{anyhow, Result};
8use futures_util::{SinkExt, StreamExt};
9use serde_json::{json, Value};
10use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
11use tokio_tungstenite::tungstenite::Message;
12
13pub mod protocol;
14use protocol::{CdpError, Request, Response};
15
/// Upper bound on how long a single CDP request waits for its reply.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Capacity of the broadcast channel that fans incoming CDP events out to subscribers.
const EVENT_CHANNEL_CAPACITY: usize = 256;
18
19type PendingMap = HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>;
20
21#[derive(Debug, Clone)]
22pub struct CdpEvent {
23 pub method: String,
24 pub params: Value,
25 pub session_id: Option<String>,
26}
27
28pub struct CdpClient {
29 next_id: Mutex<u64>,
30 pending: Arc<Mutex<PendingMap>>,
31 events_tx: broadcast::Sender<CdpEvent>,
32 write_tx: mpsc::UnboundedSender<String>,
33 reader_handle: tokio::task::JoinHandle<()>,
34 writer_handle: tokio::task::JoinHandle<()>,
35}
36
37impl CdpClient {
38 pub async fn connect(ws_url: &str) -> Result<Self> {
40 let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url).await?;
41 let (mut ws_sink, mut ws_stream) = ws_stream.split();
42
43 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
44 let (events_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
45 let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
46
47 let writer_handle = tokio::spawn(async move {
48 while let Some(text) = write_rx.recv().await {
49 if ws_sink.send(Message::Text(text)).await.is_err() {
50 break;
51 }
52 }
53 let _ = ws_sink.close().await;
54 });
55
56 let pending_r = pending.clone();
57 let events_r = events_tx.clone();
58 let reader_handle = tokio::spawn(async move {
59 while let Some(msg) = ws_stream.next().await {
60 let text = match msg {
61 Ok(Message::Text(t)) => t,
62 Ok(Message::Binary(b)) => match String::from_utf8(b) {
63 Ok(s) => s,
64 Err(_) => continue,
65 },
66 Ok(Message::Close(_)) | Err(_) => break,
67 Ok(_) => continue,
68 };
69 let resp: Response = match serde_json::from_str(&text) {
70 Ok(r) => r,
71 Err(_) => continue,
72 };
73 if let Some(id) = resp.id {
74 let mut p = pending_r.lock().await;
75 if let Some(tx) = p.remove(&id) {
76 let res = if let Some(err) = resp.error {
77 Err(err)
78 } else {
79 Ok(resp.result)
80 };
81 let _ = tx.send(res);
82 }
83 } else if let Some(method) = resp.method {
84 let _ = events_r.send(CdpEvent {
85 method,
86 params: resp.params,
87 session_id: resp.session_id,
88 });
89 }
90 }
91 let mut p = pending_r.lock().await;
93 for (_, tx) in p.drain() {
94 let _ = tx.send(Err(CdpError {
95 code: -1,
96 message: "connection closed".into(),
97 }));
98 }
99 });
100
101 Ok(Self {
102 next_id: Mutex::new(1),
103 pending,
104 events_tx,
105 write_tx,
106 reader_handle,
107 writer_handle,
108 })
109 }
110
111 pub async fn connect_http(base_url: &str) -> Result<Self> {
113 let base = base_url.trim_end_matches('/');
114 let url = format!("{base}/json/version");
115 let resp: Value = reqwest::get(&url).await?.json().await?;
116 let ws_url = resp
117 .get("webSocketDebuggerUrl")
118 .and_then(|v| v.as_str())
119 .ok_or_else(|| anyhow!("webSocketDebuggerUrl missing from {url}"))?
120 .to_string();
121 Self::connect(&ws_url).await
122 }
123
124 pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
126 self.send_with_session(method, params, None).await
127 }
128
129 pub async fn send_with_session(
130 &self,
131 method: &str,
132 params: Value,
133 session_id: Option<&str>,
134 ) -> Result<Value> {
135 let id = {
136 let mut n = self.next_id.lock().await;
137 let id = *n;
138 *n += 1;
139 id
140 };
141
142 let req = Request {
143 id,
144 method,
145 params,
146 session_id: session_id.map(|s| s.to_string()),
147 };
148 let text = serde_json::to_string(&req)?;
149
150 let (tx, rx) = oneshot::channel();
151 {
152 let mut p = self.pending.lock().await;
153 p.insert(id, tx);
154 }
155
156 if self.write_tx.send(text).is_err() {
157 let mut p = self.pending.lock().await;
158 p.remove(&id);
159 return Err(anyhow!("writer task closed"));
160 }
161
162 match tokio::time::timeout(REQUEST_TIMEOUT, rx).await {
163 Ok(Ok(Ok(v))) => Ok(v),
164 Ok(Ok(Err(e))) => Err(anyhow!(e)),
165 Ok(Err(_)) => Err(anyhow!("response channel dropped")),
166 Err(_) => {
167 let mut p = self.pending.lock().await;
168 p.remove(&id);
169 Err(anyhow!("CDP request timed out after {:?}", REQUEST_TIMEOUT))
170 }
171 }
172 }
173
174 pub fn subscribe(&self) -> broadcast::Receiver<CdpEvent> {
176 self.events_tx.subscribe()
177 }
178
179 pub async fn attach_to_target(&self, target_id: &str) -> Result<String> {
181 let v = self
182 .send(
183 "Target.attachToTarget",
184 json!({ "targetId": target_id, "flatten": true }),
185 )
186 .await?;
187 v.get("sessionId")
188 .and_then(|v| v.as_str())
189 .map(|s| s.to_string())
190 .ok_or_else(|| anyhow!("sessionId missing from Target.attachToTarget response"))
191 }
192
193 pub async fn list_targets(&self) -> Result<Vec<Value>> {
195 let v = self.send("Target.getTargets", Value::Null).await?;
196 match v.get("targetInfos") {
197 Some(Value::Array(a)) => Ok(a.clone()),
198 _ => Ok(vec![]),
199 }
200 }
201
202 pub async fn close(self) {
204 drop(self.write_tx);
205 let _ = self.writer_handle.await;
206 self.reader_handle.abort();
207 let _ = self.reader_handle.await;
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use tokio_tungstenite::tungstenite::Message;

    /// Binds a mock CDP server socket on an ephemeral localhost port and returns it
    /// together with the `ws://` URL a client should dial.
    async fn bind_mock_server() -> (tokio::net::TcpListener, String) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        (listener, url)
    }

    #[tokio::test]
    async fn round_trip_request_response() {
        let (listener, url) = bind_mock_server().await;
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut ws = tokio_tungstenite::accept_async(stream).await.unwrap();
            // Echo each request's method back inside a successful result.
            while let Some(Ok(msg)) = ws.next().await {
                if let Message::Text(t) = msg {
                    let req: Value = serde_json::from_str(&t).unwrap();
                    let id = req["id"].as_u64().unwrap();
                    let resp = json!({"id": id, "result": {"ok": true, "echo": req["method"]}});
                    ws.send(Message::Text(resp.to_string())).await.unwrap();
                }
            }
        });
        let client = CdpClient::connect(&url).await.unwrap();
        let v = client
            .send("Page.navigate", json!({"url": "about:blank"}))
            .await
            .unwrap();
        assert_eq!(v["ok"], true);
        assert_eq!(v["echo"], "Page.navigate");
        client.close().await;
    }

    #[tokio::test]
    async fn broadcast_event_to_subscriber() {
        let (listener, url) = bind_mock_server().await;
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut ws = tokio_tungstenite::accept_async(stream).await.unwrap();
            // Wait until the client has subscribed so the event cannot be missed.
            let _ = ready_rx.await;
            let evt = json!({
                "method": "Target.targetCreated",
                "params": {"targetInfo": {"targetId": "abc"}},
                "sessionId": "S1"
            });
            ws.send(Message::Text(evt.to_string())).await.unwrap();
            while let Some(Ok(_)) = ws.next().await {}
        });

        let client = CdpClient::connect(&url).await.unwrap();
        let mut rx = client.subscribe();
        ready_tx.send(()).unwrap();

        let evt = tokio::time::timeout(Duration::from_secs(5), rx.recv())
            .await
            .expect("event timeout")
            .expect("event recv");
        assert_eq!(evt.method, "Target.targetCreated");
        assert_eq!(evt.session_id.as_deref(), Some("S1"));
        assert_eq!(evt.params["targetInfo"]["targetId"], "abc");
        client.close().await;
    }

    #[tokio::test]
    async fn error_response_is_returned_as_err() {
        let (listener, url) = bind_mock_server().await;
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut ws = tokio_tungstenite::accept_async(stream).await.unwrap();
            // Answer every request with a CDP protocol error.
            while let Some(Ok(msg)) = ws.next().await {
                if let Message::Text(t) = msg {
                    let req: Value = serde_json::from_str(&t).unwrap();
                    let id = req["id"].as_u64().unwrap();
                    let resp = json!({
                        "id": id,
                        "error": {"code": -32601, "message": "method not found"}
                    });
                    let _ = ws.send(Message::Text(resp.to_string())).await;
                }
            }
        });
        let client = CdpClient::connect(&url).await.unwrap();
        let res = tokio::time::timeout(
            Duration::from_secs(5),
            client.send("Nope.nothing", json!({})),
        )
        .await
        .expect("error reply should arrive before the test timeout");
        assert!(res.is_err());
        client.close().await;
    }

    #[tokio::test]
    async fn pending_request_fails_when_connection_drops() {
        let (listener, url) = bind_mock_server().await;
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut ws = tokio_tungstenite::accept_async(stream).await.unwrap();
            // Read one request, then hang up without answering it.
            let _ = ws.next().await;
            drop(ws);
        });
        let client = CdpClient::connect(&url).await.unwrap();
        let res = tokio::time::timeout(
            Duration::from_secs(5),
            client.send("Page.enable", json!({})),
        )
        .await
        .expect("request should fail promptly, not wait for REQUEST_TIMEOUT");
        assert!(res.is_err());
        client.close().await;
    }
}