Skip to main content

cdp_use_rs/
client.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use serde_json::Value;
10use tokio::net::TcpStream;
11use tokio::sync::{oneshot, Mutex as AsyncMutex};
12use tokio::task::JoinHandle;
13use tokio_tungstenite::tungstenite::client::IntoClientRequest;
14use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
15use tokio_tungstenite::tungstenite::Message;
16use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
17
18use crate::CdpError;
19
/// Default per-command timeout applied by the client: 30 seconds.
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);

/// Default upper bound on a single WebSocket message: 100 MiB
/// (the same limit the Python cdp-use client uses).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 100 << 20;

/// Settings used when opening a CDP connection.
///
/// Start from `CdpClientConfig::default()` (which mirrors the Python cdp-use
/// client) and override individual fields with struct-update syntax.
#[derive(Debug, Clone)]
pub struct CdpClientConfig {
    /// Largest WebSocket message accepted, in bytes. Defaults to 100 MiB.
    /// Forwarded as-is to tungstenite's `WebSocketConfig::max_message_size`.
    pub max_message_size: Option<usize>,
    /// Largest WebSocket frame accepted, in bytes. Defaults to `None`, which keeps
    /// tungstenite's own default (16 MiB).
    pub max_frame_size: Option<usize>,
    /// Extra HTTP headers added to the WebSocket handshake request.
    pub additional_headers: HashMap<String, String>,
    /// How long to wait for the response to a CDP command. Defaults to 30 seconds.
    pub command_timeout: Duration,
}

impl Default for CdpClientConfig {
    fn default() -> Self {
        let message_limit = Some(DEFAULT_MAX_MESSAGE_SIZE);
        let timeout = DEFAULT_COMMAND_TIMEOUT;
        CdpClientConfig {
            max_message_size: message_limit,
            // Leave the frame limit to tungstenite.
            max_frame_size: None,
            additional_headers: HashMap::default(),
            command_timeout: timeout,
        }
    }
}
51
52type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
53type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
54type WsSource = futures_util::stream::SplitStream<WsStream>;
55type PendingRequests = HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>;
56
57/// Type-erased async event handler.
58pub type EventHandler = Arc<
59    dyn Fn(Value, Option<String>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
60>;
61
62/// Central registry for CDP event callbacks.
63///
64/// One handler per method (replacement semantics, matching the Python version).
65pub struct EventRegistry {
66    handlers: std::sync::Mutex<HashMap<String, EventHandler>>,
67}
68
69impl EventRegistry {
70    pub fn new() -> Self {
71        Self {
72            handlers: std::sync::Mutex::new(HashMap::new()),
73        }
74    }
75
76    /// Register a handler for a CDP event method. Replaces any existing handler.
77    pub fn register(&self, method: &str, handler: EventHandler) {
78        self.handlers
79            .lock()
80            .unwrap()
81            .insert(method.to_string(), handler);
82    }
83
84    /// Remove the handler for a CDP event method.
85    pub fn unregister(&self, method: &str) {
86        self.handlers.lock().unwrap().remove(method);
87    }
88
89    /// Dispatch an event to its registered handler. Returns true if handled.
90    ///
91    /// The handler is cloned and the lock is dropped before awaiting, so
92    /// handlers may safely call `register`/`unregister` without deadlocking.
93    pub async fn handle_event(
94        &self,
95        method: &str,
96        params: Value,
97        session_id: Option<String>,
98    ) -> bool {
99        let handler = {
100            let handlers = self.handlers.lock().unwrap();
101            handlers.get(method).cloned()
102        };
103
104        if let Some(handler) = handler {
105            handler(params, session_id).await;
106            true
107        } else {
108            false
109        }
110    }
111
112    /// Remove all registered handlers.
113    pub fn clear(&self) {
114        self.handlers.lock().unwrap().clear();
115    }
116}
117
118impl Default for EventRegistry {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124/// Chrome DevTools Protocol client.
125///
126/// Connects to a browser via WebSocket and provides access to CDP
127/// commands and events.
128///
129/// ```no_run
130/// # async fn example() -> Result<(), cdp_use::CdpError> {
131/// let cdp = cdp_use::CdpClient::connect("ws://127.0.0.1:9222/devtools/browser/...").await?;
132/// let result = cdp.send_raw("Target.getTargets", serde_json::json!({}), None).await?;
133/// cdp.close().await?;
134/// # Ok(())
135/// # }
136/// ```
137#[derive(Clone)]
138pub struct CdpClient {
139    inner: Arc<ClientInner>,
140}
141
142struct ClientInner {
143    sink: AsyncMutex<WsSink>,
144    next_id: AtomicU64,
145    pending: Arc<AsyncMutex<PendingRequests>>,
146    event_registry: Arc<EventRegistry>,
147    closed: AtomicBool,
148    command_timeout: Duration,
149    message_loop_handle: std::sync::Mutex<Option<JoinHandle<()>>>,
150}
151
152impl Drop for ClientInner {
153    fn drop(&mut self) {
154        if let Some(handle) = self.message_loop_handle.get_mut().unwrap().take() {
155            handle.abort();
156        }
157    }
158}
159
160impl CdpClient {
161    /// Connect to a CDP endpoint via WebSocket with default configuration.
162    pub async fn connect(url: &str) -> Result<Self, CdpError> {
163        Self::connect_with_config(url, CdpClientConfig::default()).await
164    }
165
166    /// Connect to a CDP endpoint via WebSocket with custom configuration.
167    ///
168    /// ```no_run
169    /// # async fn example() -> Result<(), cdp_use::CdpError> {
170    /// use cdp_use::{CdpClient, CdpClientConfig};
171    /// use std::time::Duration;
172    ///
173    /// let config = CdpClientConfig {
174    ///     command_timeout: Duration::from_secs(60),
175    ///     ..Default::default()
176    /// };
177    /// let cdp = CdpClient::connect_with_config("ws://127.0.0.1:9222/devtools/browser/...", config).await?;
178    /// # Ok(())
179    /// # }
180    /// ```
181    pub async fn connect_with_config(
182        url: &str,
183        config: CdpClientConfig,
184    ) -> Result<Self, CdpError> {
185        let mut request = url.into_client_request()?;
186
187        // Add custom headers to the WebSocket handshake request
188        for (key, value) in &config.additional_headers {
189            request.headers_mut().insert(
190                key.parse::<tokio_tungstenite::tungstenite::http::HeaderName>()
191                    .map_err(|e| CdpError::Protocol {
192                        code: -1,
193                        message: format!("Invalid header name '{key}': {e}"),
194                        data: None,
195                    })?,
196                value
197                    .parse()
198                    .map_err(|e| CdpError::Protocol {
199                        code: -1,
200                        message: format!("Invalid header value for '{key}': {e}"),
201                        data: None,
202                    })?,
203            );
204        }
205
206        let mut ws_config = WebSocketConfig::default();
207        ws_config.max_message_size = config.max_message_size;
208        ws_config.max_frame_size = config.max_frame_size;
209
210        let (ws_stream, _) =
211            connect_async_with_config(request, Some(ws_config), false).await?;
212        let (sink, stream) = ws_stream.split();
213
214        let pending = Arc::new(AsyncMutex::new(HashMap::new()));
215        let event_registry = Arc::new(EventRegistry::new());
216        let closed = Arc::new(AtomicBool::new(false));
217
218        let handle = tokio::spawn({
219            let pending = pending.clone();
220            let registry = event_registry.clone();
221            let closed = closed.clone();
222            async move {
223                message_loop(stream, pending, registry, closed).await;
224            }
225        });
226
227        Ok(Self {
228            inner: Arc::new(ClientInner {
229                sink: AsyncMutex::new(sink),
230                next_id: AtomicU64::new(0),
231                pending,
232                event_registry,
233                closed: AtomicBool::new(false),
234                command_timeout: config.command_timeout,
235                message_loop_handle: std::sync::Mutex::new(Some(handle)),
236            }),
237        })
238    }
239
240    /// Send a raw CDP command and await the response.
241    ///
242    /// Returns `CdpError::ConnectionClosed` if the connection is already closed,
243    /// or `CdpError::Timeout` if the browser does not respond within the timeout.
244    pub async fn send_raw(
245        &self,
246        method: &str,
247        params: Value,
248        session_id: Option<&str>,
249    ) -> Result<Value, CdpError> {
250        if self.inner.closed.load(Ordering::Acquire) {
251            return Err(CdpError::ConnectionClosed);
252        }
253
254        let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed) + 1;
255
256        let (tx, rx) = oneshot::channel();
257        self.inner.pending.lock().await.insert(id, tx);
258
259        let mut msg = serde_json::json!({
260            "id": id,
261            "method": method,
262            "params": params,
263        });
264        if let Some(sid) = session_id {
265            msg["sessionId"] = Value::String(sid.to_string());
266        }
267
268        let send_result = self
269            .inner
270            .sink
271            .lock()
272            .await
273            .send(Message::Text(msg.to_string().into()))
274            .await;
275
276        if let Err(e) = send_result {
277            // Clean up the pending entry if the WebSocket send failed
278            self.inner.pending.lock().await.remove(&id);
279            return Err(e.into());
280        }
281
282        // Await response with timeout
283        match tokio::time::timeout(self.inner.command_timeout, rx).await {
284            Ok(Ok(result)) => result,
285            Ok(Err(_)) => {
286                // Sender dropped (connection closed)
287                Err(CdpError::ConnectionClosed)
288            }
289            Err(_elapsed) => {
290                // Timeout — clean up the pending entry
291                self.inner.pending.lock().await.remove(&id);
292                Err(CdpError::Timeout)
293            }
294        }
295    }
296
297    /// Emit a synthetic event through the event registry.
298    ///
299    /// Useful for custom domains where events are produced by application
300    /// code rather than the browser.
301    pub async fn emit_event(
302        &self,
303        method: &str,
304        params: Value,
305        session_id: Option<&str>,
306    ) -> bool {
307        self.inner
308            .event_registry
309            .handle_event(method, params, session_id.map(String::from))
310            .await
311    }
312
313    /// Get a reference to the event registry.
314    // Used by generated code in generated.rs
315    pub(crate) fn event_registry(&self) -> &Arc<EventRegistry> {
316        &self.inner.event_registry
317    }
318
319    /// Gracefully close the WebSocket connection.
320    pub async fn close(&self) -> Result<(), CdpError> {
321        // Mark as closed first so new send_raw calls fail fast
322        self.inner.closed.store(true, Ordering::Release);
323
324        // Fail all pending requests before aborting the message loop,
325        // so in-flight callers get ConnectionClosed instead of hanging.
326        {
327            let mut pending = self.inner.pending.lock().await;
328            for (_, tx) in pending.drain() {
329                let _ = tx.send(Err(CdpError::ConnectionClosed));
330            }
331        }
332
333        // Abort the message handler
334        if let Some(handle) = self.inner.message_loop_handle.lock().unwrap().take() {
335            handle.abort();
336            let _ = handle.await;
337        }
338
339        // Send close frame
340        self.inner.sink.lock().await.close().await?;
341        Ok(())
342    }
343}
344
345async fn message_loop(
346    mut stream: WsSource,
347    pending: Arc<AsyncMutex<PendingRequests>>,
348    event_registry: Arc<EventRegistry>,
349    closed: Arc<AtomicBool>,
350) {
351    while let Some(msg_result) = stream.next().await {
352        match msg_result {
353            Ok(Message::Text(text)) => {
354                let data: Value = match serde_json::from_str(&text) {
355                    Ok(v) => v,
356                    Err(e) => {
357                        tracing::warn!("Failed to parse CDP message: {e}");
358                        continue;
359                    }
360                };
361
362                if let Some(id) = data.get("id").and_then(|v| v.as_u64()) {
363                    // Response message — match to pending request
364                    let mut pending = pending.lock().await;
365                    if let Some(tx) = pending.remove(&id) {
366                        let result = if let Some(error) = data.get("error") {
367                            let code = error.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
368                            let message = error
369                                .get("message")
370                                .and_then(|v| v.as_str())
371                                .unwrap_or("Unknown error")
372                                .to_string();
373                            let err_data = error.get("data").map(|v| v.to_string());
374                            Err(CdpError::Protocol {
375                                code,
376                                message,
377                                data: err_data,
378                            })
379                        } else {
380                            Ok(data
381                                .get("result")
382                                .cloned()
383                                .unwrap_or(Value::Object(Default::default())))
384                        };
385                        let _ = tx.send(result);
386                    }
387                } else if let Some(method) = data.get("method").and_then(|v| v.as_str()) {
388                    // Event message — spawn handler to avoid blocking the message loop
389                    let params = data.get("params").cloned().unwrap_or_default();
390                    let session_id = data
391                        .get("sessionId")
392                        .and_then(|v| v.as_str())
393                        .map(String::from);
394                    let registry = event_registry.clone();
395                    let method = method.to_string();
396                    tokio::spawn(async move {
397                        registry.handle_event(&method, params, session_id).await;
398                    });
399                }
400            }
401            Ok(Message::Close(_)) | Err(_) => {
402                // Connection closed or error — mark closed and fail all pending requests
403                closed.store(true, Ordering::Release);
404                let mut pending = pending.lock().await;
405                for (_, tx) in pending.drain() {
406                    let _ = tx.send(Err(CdpError::ConnectionClosed));
407                }
408                break;
409            }
410            _ => {} // Ping/pong handled automatically by tokio-tungstenite
411        }
412    }
413}