1use std::collections::HashMap;
23use std::fmt;
24use std::sync::Arc;
25use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
26use std::time::Duration;
27
28use async_tungstenite::tokio::{ConnectStream, connect_async};
29use async_tungstenite::tungstenite::Message;
30use async_tungstenite::tungstenite::error::Error as WsError;
31use async_tungstenite::{WebSocketReceiver, WebSocketSender};
32use futures::StreamExt as _;
33use serde::de::DeserializeOwned;
34use serde::{Deserialize, Serialize};
35use serde_json::Value;
36use tokio::sync::{Mutex as AsyncMutex, broadcast, oneshot};
37use tokio::task::JoinHandle;
38
/// Write half of the split CDP WebSocket connection.
type Sink = WebSocketSender<ConnectStream>;
/// Read half of the split CDP WebSocket connection.
type Stream = WebSocketReceiver<ConnectStream>;

/// Capacity of the event broadcast channel. A subscriber that falls further
/// behind than this observes `RecvError::Lagged` and misses the oldest events.
const EVENT_BUFFER: usize = 256;
46
/// Errors returned by [`CdpClient`] operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CdpError {
    /// Transport-level failure, carried as its display text.
    #[error("websocket: {0}")]
    WebSocket(String),

    /// The remote end answered a command with an error object.
    #[error("CDP {code}: {message}")]
    Remote {
        /// Numeric code from the response's `error.code`.
        code: i64,
        /// Description from the response's `error.message`.
        message: String,
    },

    /// A payload could not be (de)serialised, e.g. a command result that does
    /// not fit the caller's requested type.
    #[error("decode response: {0}")]
    Decode(String),

    /// An operation did not complete within its time budget.
    #[error("CDP {what} timed out after {elapsed:?}")]
    Timeout {
        /// The time budget that was exceeded.
        elapsed: Duration,
        /// Short label naming the operation that timed out.
        what: &'static str,
    },

    /// The client has been closed, or the connection to the remote end was lost.
    #[error("CDP client is closed")]
    Closed,
}
84
85impl CdpError {
86 fn ws(e: &WsError) -> Self {
87 Self::WebSocket(e.to_string())
88 }
89}
90
/// An event pushed by the remote end: a frame carrying a `method` but no `id`.
#[derive(Debug, Clone)]
pub struct CdpEvent {
    /// Event name, e.g. `Page.loadEventFired`.
    pub method: String,
    /// Event payload; `Value::Null` when the frame carried no `params`.
    pub params: Value,
    /// Session the event belongs to, if the frame carried a `sessionId`.
    pub session_id: Option<String>,
}
102
/// Outgoing command, serialised to the JSON text frame sent to the remote end.
#[derive(Serialize)]
struct Request<'a, P> {
    /// Correlates the eventual response with this request.
    id: u64,
    /// Command name, e.g. `Page.navigate`.
    method: &'a str,
    /// Command parameters; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<P>,
    /// Target session; serialised as `sessionId` and omitted when `None`.
    #[serde(rename = "sessionId", skip_serializing_if = "Option::is_none")]
    session_id: Option<&'a str>,
}
112
/// The `error` object of a failed command response.
#[derive(Deserialize)]
struct RemoteError {
    /// Numeric error code.
    code: i64,
    /// Human-readable error description.
    message: String,
}
118
/// Any incoming message. A frame with an `id` is a command response; a frame
/// with a `method` and no `id` is an event. Every field is optional so one type
/// can decode both shapes.
#[derive(Deserialize)]
struct Frame {
    /// Id of the request this frame responds to; absent for events.
    id: Option<u64>,
    /// Event name; only used when `id` is absent.
    method: Option<String>,
    /// Event payload.
    params: Option<Value>,
    /// Result of a successful command.
    result: Option<Value>,
    /// Error object of a failed command.
    error: Option<RemoteError>,
    /// Session the frame belongs to, from the wire field `sessionId`.
    #[serde(rename = "sessionId")]
    session_id: Option<String>,
}
131
/// In-flight requests keyed by request id. Each entry receives the result (or
/// error) for its id exactly once.
type PendingMap = std::sync::Mutex<HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>>;

/// State shared between [`CdpClient`] and its background read task.
struct Inner {
    /// Write half of the socket. This is an async mutex because it is held
    /// across `.await` while a frame is sent.
    sink: AsyncMutex<Sink>,
    /// Requests awaiting a response. The lock is held only briefly and never
    /// across an `.await`.
    pending: PendingMap,
    /// Next request id to hand out (starts at 1).
    next_id: AtomicU64,
    /// Fan-out channel through which incoming events reach subscribers.
    events: broadcast::Sender<CdpEvent>,
    /// Set once the connection is shut down (see `mark_closed`).
    closed: AtomicBool,
}
141
142impl Inner {
143 fn mark_closed(&self) {
144 if self.closed.swap(true, Ordering::AcqRel) {
145 return;
146 }
147 let drained: Vec<_> = {
149 let mut g = self
150 .pending
151 .lock()
152 .unwrap_or_else(std::sync::PoisonError::into_inner);
153 g.drain().collect()
154 };
155 for (_, tx) in drained {
156 let _ = tx.send(Err(CdpError::Closed));
157 }
158 }
159}
160
/// CDP client over a single WebSocket connection.
///
/// Commands are sent with [`CdpClient::execute`]; events are delivered through
/// [`CdpClient::subscribe_events`]. Dropping the client marks it closed and
/// aborts the background read task.
pub struct CdpClient {
    /// State shared with the background read task.
    inner: Arc<Inner>,
    /// Handle to the spawned task that reads and dispatches incoming frames.
    read_loop: JoinHandle<()>,
}
170
171impl fmt::Debug for CdpClient {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.debug_struct("CdpClient")
174 .field("closed", &self.inner.closed.load(Ordering::Acquire))
175 .field(
176 "pending",
177 &self
178 .inner
179 .pending
180 .lock()
181 .map(|g| g.len())
182 .unwrap_or_default(),
183 )
184 .field("read_loop_finished", &self.read_loop.is_finished())
185 .finish()
186 }
187}
188
189impl Drop for CdpClient {
190 fn drop(&mut self) {
191 self.inner.mark_closed();
192 self.read_loop.abort();
193 }
194}
195
196impl CdpClient {
197 pub async fn connect(url: &str) -> Result<Self, CdpError> {
203 let (ws, _resp) = connect_async(url).await.map_err(|e| CdpError::ws(&e))?;
204 let (sink, stream) = ws.split();
205 let (events_tx, _) = broadcast::channel(EVENT_BUFFER);
206 let inner = Arc::new(Inner {
207 sink: AsyncMutex::new(sink),
208 pending: std::sync::Mutex::new(HashMap::new()),
209 next_id: AtomicU64::new(1),
210 events: events_tx,
211 closed: AtomicBool::new(false),
212 });
213 let read_loop = tokio::spawn(read_loop(Arc::clone(&inner), stream));
214 Ok(Self { inner, read_loop })
215 }
216
217 pub async fn execute<P, R>(
231 &self,
232 method: &'static str,
233 params: P,
234 session_id: Option<&str>,
235 timeout: Duration,
236 ) -> Result<R, CdpError>
237 where
238 P: Serialize,
239 R: DeserializeOwned,
240 {
241 if self.inner.closed.load(Ordering::Acquire) {
242 return Err(CdpError::Closed);
243 }
244 let id = self.inner.next_id.fetch_add(1, Ordering::AcqRel);
245 let req = Request {
246 id,
247 method,
248 params: Some(params),
249 session_id,
250 };
251 let json = serde_json::to_string(&req).map_err(|e| CdpError::Decode(e.to_string()))?;
252
253 let (tx, rx) = oneshot::channel();
254 {
257 let mut g = self
258 .inner
259 .pending
260 .lock()
261 .map_err(|_| CdpError::WebSocket("pending mutex poisoned".into()))?;
262 g.insert(id, tx);
263 }
264
265 let send = {
267 let mut sink = self.inner.sink.lock().await;
268 sink.send(Message::Text(json.into())).await
269 };
270 if let Err(e) = send {
271 let _ = self
273 .inner
274 .pending
275 .lock()
276 .map(|mut g| g.remove(&id))
277 .unwrap_or_default();
278 return Err(CdpError::ws(&e));
279 }
280
281 let wait = async {
282 rx.await.map_err(|_| CdpError::Closed)?.and_then(|value| {
283 serde_json::from_value::<R>(value).map_err(|e| CdpError::Decode(e.to_string()))
284 })
285 };
286
287 tokio::time::timeout(timeout, wait).await.map_err(|_| {
288 let _ = self
291 .inner
292 .pending
293 .lock()
294 .map(|mut g| g.remove(&id))
295 .unwrap_or_default();
296 CdpError::Timeout {
297 elapsed: timeout,
298 what: method,
299 }
300 })?
301 }
302
303 #[must_use]
310 pub fn subscribe_events(&self) -> broadcast::Receiver<CdpEvent> {
311 self.inner.events.subscribe()
312 }
313
314 pub async fn wait_for_event<F>(
328 &self,
329 predicate: F,
330 timeout: Duration,
331 what: &'static str,
332 ) -> Result<CdpEvent, CdpError>
333 where
334 F: Fn(&CdpEvent) -> bool + Send + Sync,
335 {
336 let mut rx = self.subscribe_events();
337 Self::wait_for_event_on(&mut rx, predicate, timeout, what).await
338 }
339
340 pub async fn wait_for_event_on<F>(
348 rx: &mut broadcast::Receiver<CdpEvent>,
349 predicate: F,
350 timeout: Duration,
351 what: &'static str,
352 ) -> Result<CdpEvent, CdpError>
353 where
354 F: Fn(&CdpEvent) -> bool + Send + Sync,
355 {
356 let wait = async {
357 loop {
358 match rx.recv().await {
359 Ok(evt) if predicate(&evt) => return Ok::<CdpEvent, CdpError>(evt),
360 Ok(_) | Err(broadcast::error::RecvError::Lagged(_)) => {}
362 Err(broadcast::error::RecvError::Closed) => return Err(CdpError::Closed),
363 }
364 }
365 };
366 tokio::time::timeout(timeout, wait)
367 .await
368 .map_err(|_| CdpError::Timeout {
369 elapsed: timeout,
370 what,
371 })?
372 }
373
374 pub async fn close(self) {
378 self.inner.mark_closed();
379 let _ = self.inner.sink.lock().await.close(None).await;
380 self.read_loop.abort();
381 }
382}
383
384async fn read_loop(inner: Arc<Inner>, mut stream: Stream) {
388 while let Some(msg) = stream.next().await {
389 if inner.closed.load(Ordering::Acquire) {
390 break;
391 }
392 let text = match msg {
393 Ok(Message::Text(t)) => t,
394 Ok(Message::Binary(b)) => {
395 let Ok(decoded) = String::from_utf8(b.into()) else {
396 tracing::warn!("CDP: non-UTF8 binary frame, dropped");
397 continue;
398 };
399 decoded.into()
400 }
401 Ok(Message::Close(_)) => {
402 tracing::debug!("CDP: peer closed");
403 break;
404 }
405 Ok(_) => continue, Err(e) => {
407 tracing::warn!(error = %e, "CDP: stream error, closing");
408 break;
409 }
410 };
411
412 let frame: Frame = match serde_json::from_str(&text) {
413 Ok(f) => f,
414 Err(e) => {
415 tracing::warn!(error = %e, "CDP: malformed frame, dropped");
416 continue;
417 }
418 };
419
420 match (frame.id, frame.method) {
421 (Some(id), _) => {
422 let tx = inner.pending.lock().ok().and_then(|mut g| g.remove(&id));
424 if let Some(tx) = tx {
425 let result = if let Some(err) = frame.error {
426 Err(CdpError::Remote {
427 code: err.code,
428 message: err.message,
429 })
430 } else {
431 Ok(frame.result.unwrap_or(Value::Null))
432 };
433 let _ = tx.send(result);
434 } else {
435 tracing::debug!(id, "CDP: response for unknown / cancelled id");
436 }
437 }
438 (None, Some(method)) => {
439 let evt = CdpEvent {
440 method,
441 params: frame.params.unwrap_or(Value::Null),
442 session_id: frame.session_id,
443 };
444 let _ = inner.events.send(evt);
446 }
447 (None, None) => {
448 tracing::warn!("CDP: frame has neither id nor method, dropped");
449 }
450 }
451 }
452 inner.mark_closed();
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serialises `req` to a JSON value so tests can compare whole objects.
    fn to_json<P: Serialize>(req: &Request<'_, P>) -> Value {
        serde_json::to_value(req).unwrap()
    }

    /// Parses `raw` as a wire [`Frame`], panicking on malformed input.
    fn parse(raw: &str) -> Frame {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn request_serialises_with_optional_fields() {
        let bare: Request<'_, Value> = Request {
            id: 42,
            method: "Page.enable",
            params: None,
            session_id: None,
        };
        // Both optional fields are left out entirely rather than sent as `null`.
        assert_eq!(to_json(&bare), json!({ "id": 42, "method": "Page.enable" }));
    }

    #[test]
    fn request_serialises_with_session_id() {
        let expected = json!({
            "id": 7,
            "method": "Page.navigate",
            "params": {"url": "https://example.com"},
            "sessionId": "abc-123",
        });
        let navigate = Request {
            id: 7,
            method: "Page.navigate",
            params: Some(json!({ "url": "https://example.com" })),
            session_id: Some("abc-123"),
        };
        assert_eq!(to_json(&navigate), expected);
    }

    #[test]
    fn frame_parses_a_response() {
        let frame = parse(r#"{"id": 1, "result": {"targetId": "T1"}}"#);
        assert_eq!(frame.id, Some(1));
        assert_eq!(frame.method, None);
        assert_eq!(frame.result, Some(json!({"targetId": "T1"})));
    }

    #[test]
    fn frame_parses_a_remote_error() {
        let frame = parse(r#"{"id": 9, "error": {"code": -32601, "message": "Method not found"}}"#);
        let RemoteError { code, message } = frame.error.unwrap();
        assert_eq!(code, -32601);
        assert_eq!(message, "Method not found");
    }

    #[test]
    fn frame_parses_an_event_with_session_id() {
        let frame = parse(
            r#"{"method": "Page.loadEventFired", "params": {"timestamp": 1.0}, "sessionId": "S1"}"#,
        );
        assert_eq!(frame.id, None);
        assert_eq!(frame.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(frame.session_id.as_deref(), Some("S1"));
        assert!(matches!(frame.params, Some(_)));
    }
}
520}