Skip to main content

ferrous_browser/
cdp.rs

1use futures_util::SinkExt;
2use serde_json::{json, Value};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::{Arc, Mutex as StdMutex};
6use tokio::sync::{broadcast, mpsc, oneshot};
7use tokio::time::{timeout, Duration};
8use tokio_tungstenite::tungstenite::Message;
9use tracing::Instrument;
10
11use crate::error::{BrowserError, Result};
12
13/// Represents a CDP command request
14#[derive(Debug, Clone)]
15pub struct CDPRequest {
16    /// Unique request ID
17    pub id: u32,
18    /// CDP method name
19    pub method: String,
20    /// Optional parameters for the method
21    pub params: Option<Value>,
22    /// Optional session ID for targeting specific pages
23    pub session_id: Option<String>,
24}
25
26impl CDPRequest {
27    /// Create a new CDP request
28    pub fn new(id: u32, method: String, params: Option<Value>) -> Self {
29        Self {
30            id,
31            method,
32            params,
33            session_id: None,
34        }
35    }
36
37    /// Create a CDP request with session ID
38    pub fn with_session(
39        id: u32,
40        method: String,
41        params: Option<Value>,
42        session_id: String,
43    ) -> Self {
44        Self {
45            id,
46            method,
47            params,
48            session_id: Some(session_id),
49        }
50    }
51
52    /// Convert to JSON value for sending
53    pub fn to_json(&self) -> Value {
54        let mut obj = json!({
55            "id": self.id,
56            "method": self.method,
57        });
58
59        if let Some(session_id) = &self.session_id {
60            obj["sessionId"] = json!(session_id);
61        }
62
63        if let Some(params) = &self.params {
64            obj["params"] = params.clone();
65        }
66
67        obj
68    }
69}
70
71/// Represents a CDP event or response
72#[derive(Debug, Clone)]
73pub struct CDPMessage {
74    /// Response ID (if this is a response)
75    pub id: Option<u32>,
76    /// Event method name (if this is an event)
77    pub method: Option<String>,
78    /// Event parameters
79    pub params: Option<Value>,
80    /// Command result (if successful)
81    pub result: Option<Value>,
82    /// Error object (if failed)
83    pub error: Option<Value>,
84    /// Session ID — identifies which page/target this message belongs to.
85    /// This is the critical field for multi-page session isolation.
86    pub session_id: Option<String>,
87}
88
89impl CDPMessage {
90    /// Parse a CDP message from JSON value
91    pub fn from_json(value: Value) -> Result<Self> {
92        Ok(CDPMessage {
93            id: value.get("id").and_then(|v| v.as_u64()).map(|v| v as u32),
94            method: value
95                .get("method")
96                .and_then(|v| v.as_str())
97                .map(|s| s.to_string()),
98            params: value.get("params").cloned(),
99            result: value.get("result").cloned(),
100            error: value.get("error").cloned(),
101            // Chrome always includes sessionId in session-scoped messages
102            session_id: value
103                .get("sessionId")
104                .and_then(|v| v.as_str())
105                .map(|s| s.to_string()),
106        })
107    }
108}
109
110/// Type for WebSocket sink
111pub type WebSocketSink = futures_util::stream::SplitSink<
112    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
113    Message,
114>;
115
116/// Core CDP client that manages WebSocket connection and message routing.
117///
118/// **Concurrency model:**
119/// - Outgoing writes go through an unbounded mpsc channel feeding a dedicated
120///   writer task (`spawn_writer_task`). Many callers can send concurrently
121///   without contending on a lock.
122/// - Incoming reads are dispatched by `Connection::run` calling
123///   `handle_message`, which only briefly holds a sync mutex on
124///   `pending_responses`, with no awaits while the lock is held.
125pub struct CDPClient {
126    ws_url: String,
127    message_id_counter: Arc<AtomicU32>,
128    /// Pending command responses, keyed by request id. Guarded by a *sync*
129    /// mutex; we never `.await` while holding it, so a `std::sync::Mutex`
130    /// is correct and skips `tokio::sync::RwLock`'s wait-queue overhead.
131    pending_responses: Arc<StdMutex<HashMap<u32, oneshot::Sender<Value>>>>,
132    /// Broadcast channel carrying ALL CDP events (`method.is_some()`).
133    /// Subscribers filter by method name and session_id themselves.
134    event_broadcast: broadcast::Sender<CDPMessage>,
135    /// Sender side of the writer-task mailbox. `None` until `set_writer` is
136    /// called from `Browser::connect_internal`.
137    ws_tx: Arc<StdMutex<Option<mpsc::UnboundedSender<Message>>>>,
138}
139
140impl CDPClient {
141    /// Create a new CDP client
142    pub fn new(ws_url: String) -> Self {
143        let (event_broadcast, _) = broadcast::channel(1024);
144        Self {
145            ws_url,
146            message_id_counter: Arc::new(AtomicU32::new(1)),
147            pending_responses: Arc::new(StdMutex::new(HashMap::new())),
148            event_broadcast,
149            ws_tx: Arc::new(StdMutex::new(None)),
150        }
151    }
152
153    /// Install the writer-task mailbox. Called once by
154    /// `Browser::connect_internal` right after the writer task is spawned.
155    pub fn set_writer(&self, tx: mpsc::UnboundedSender<Message>) {
156        *self.ws_tx.lock().expect("ws_tx mutex poisoned") = Some(tx);
157    }
158
159    /// Generate the next message ID
160    pub fn next_id(&self) -> u32 {
161        self.message_id_counter.fetch_add(1, Ordering::SeqCst)
162    }
163
164    /// Connect to the Chrome DevTools Protocol WebSocket
165    pub async fn connect(
166        &self,
167    ) -> Result<
168        tokio_tungstenite::WebSocketStream<
169            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
170        >,
171    > {
172        let (ws_stream, _) = tokio_tungstenite::connect_async(&self.ws_url)
173            .await
174            .map_err(|e| BrowserError::connection_failed(&self.ws_url, e.to_string()))?;
175
176        Ok(ws_stream)
177    }
178
179    /// Send raw message through the writer task. Synchronous and (modulo one
180    /// brief sync-mutex acquire) lock-free, because the mpsc channel fan-in
181    /// is the contention point now, not a per-call lock.
182    pub fn send_raw(&self, msg: String) -> Result<()> {
183        let tx_guard = self.ws_tx.lock().expect("ws_tx mutex poisoned");
184        let tx = tx_guard.as_ref().ok_or_else(|| {
185            BrowserError::websocket("send_raw", "WebSocket writer not initialised")
186        })?;
187        tx.send(Message::Text(msg))
188            .map_err(|_| BrowserError::websocket("send_raw", "WebSocket writer task ended"))
189    }
190
191    /// Subscribe to all CDP events (unfiltered broadcast receiver).
192    ///
193    /// Callers are responsible for filtering by `msg.method` and
194    /// `msg.session_id` as needed.
195    ///
196    /// **IMPORTANT:** Subscribe *before* sending the CDP command that
197    /// triggers the event to avoid the race where Chrome replies before the
198    /// receiver is registered.
199    pub fn subscribe_events(&self) -> broadcast::Receiver<CDPMessage> {
200        self.event_broadcast.subscribe()
201    }
202
203    /// Send a command and wait for response with timeout.
204    ///
205    /// The response handler is registered **before** the message is sent so
206    /// that fast Chrome replies are never dropped.
207    #[tracing::instrument(level = "info", skip(self, params), fields(method = %method, id))]
208    pub async fn send_command(&self, method: String, params: Option<Value>) -> Result<Value> {
209        let id = self.next_id();
210        tracing::Span::current().record("id", id);
211        let request = CDPRequest::new(id, method.clone(), params);
212
213        // ── Register handler BEFORE sending ──────────────────────────────────
214        let (tx, rx) = oneshot::channel();
215        self.register_response_handler(id, tx);
216        let json_str = tracing::info_span!("serialize").in_scope(|| request.to_json().to_string());
217        let bytes = json_str.len();
218        tracing::info_span!("ws_send", bytes).in_scope(|| self.send_raw(json_str))?;
219        // ─────────────────────────────────────────────────────────────────────
220
221        const TIMEOUT_SECS: u64 = 30;
222        let wait = async {
223            match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
224                Ok(Ok(value)) => Ok(value),
225                Ok(Err(_)) => Err(BrowserError::command_failed(
226                    &method,
227                    "response channel closed unexpectedly",
228                )),
229                Err(_) => {
230                    self.pending_responses
231                        .lock()
232                        .expect("pending_responses mutex poisoned")
233                        .remove(&id);
234                    Err(BrowserError::timeout(
235                        format!("waiting for response to '{method}'"),
236                        TIMEOUT_SECS,
237                    ))
238                }
239            }
240        };
241        wait.instrument(tracing::info_span!("await_response")).await
242    }
243
244    /// Send a command to a specific page session.
245    ///
246    /// The response handler is registered **before** the message is sent.
247    #[tracing::instrument(level = "info", skip(self, params), fields(method = %method, id, session_id = %session_id))]
248    pub async fn send_command_with_session(
249        &self,
250        session_id: &str,
251        method: String,
252        params: Option<Value>,
253    ) -> Result<Value> {
254        let id = self.next_id();
255        tracing::Span::current().record("id", id);
256        let request = CDPRequest::with_session(id, method.clone(), params, session_id.to_string());
257
258        // ── Register handler BEFORE sending ──────────────────────────────────
259        let (tx, rx) = oneshot::channel();
260        self.register_response_handler(id, tx);
261        let json_str = tracing::info_span!("serialize").in_scope(|| request.to_json().to_string());
262        let bytes = json_str.len();
263        tracing::info_span!("ws_send", bytes).in_scope(|| self.send_raw(json_str))?;
264        // ─────────────────────────────────────────────────────────────────────
265
266        const TIMEOUT_SECS: u64 = 30;
267        let wait = async {
268            match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
269                Ok(Ok(value)) => Ok(value),
270                Ok(Err(_)) => Err(BrowserError::command_failed(
271                    &method,
272                    "response channel closed unexpectedly",
273                )),
274                Err(_) => {
275                    self.pending_responses
276                        .lock()
277                        .expect("pending_responses mutex poisoned")
278                        .remove(&id);
279                    Err(BrowserError::timeout(
280                        format!("waiting for response to '{method}'"),
281                        TIMEOUT_SECS,
282                    ))
283                }
284            }
285        };
286        wait.instrument(tracing::info_span!("await_response")).await
287    }
288
289    /// Register a pending response handler. Synchronous: the sync mutex
290    /// only protects the HashMap insert.
291    pub fn register_response_handler(&self, id: u32, tx: oneshot::Sender<Value>) {
292        self.pending_responses
293            .lock()
294            .expect("pending_responses mutex poisoned")
295            .insert(id, tx);
296    }
297
298    /// Drop every pending response sender. Any `send_command` currently
299    /// awaiting one of these will see its oneshot close immediately and
300    /// return `BrowserError::command_failed("…", "response channel closed…")`,
301    /// instead of waiting out the 30-second timeout. Call this when the
302    /// underlying WebSocket dies.
303    pub fn fail_all_pending(&self, reason: &str) {
304        let mut pending = self
305            .pending_responses
306            .lock()
307            .expect("pending_responses mutex poisoned");
308        let count = pending.len();
309        pending.clear(); // dropping the senders signals the receivers
310        drop(pending);
311        if count > 0 {
312            tracing::warn!(
313                pending_count = count,
314                reason = reason,
315                "WebSocket terminated; failing in-flight CDP requests"
316            );
317        }
318    }
319
320    /// Handle an incoming CDP message; called by `Connection::run`.
321    /// Synchronous: no `.await` happens while the pending-responses mutex
322    /// is held.
323    #[tracing::instrument(level = "debug", skip_all, fields(method = ?msg.method, id = ?msg.id))]
324    pub fn handle_message(&self, msg: CDPMessage) -> Result<()> {
325        if let Some(id) = msg.id {
326            // It's a response to one of our commands
327            let tx = self
328                .pending_responses
329                .lock()
330                .expect("pending_responses mutex poisoned")
331                .remove(&id);
332            if let Some(tx) = tx {
333                if let Some(error) = msg.error {
334                    let _ = tx.send(json!({ "error": error }));
335                } else if let Some(result) = msg.result {
336                    let _ = tx.send(result);
337                } else {
338                    let _ = tx.send(json!({}));
339                }
340            }
341        } else if msg.method.is_some() {
342            // It's an event — broadcast to all subscribers.
343            // Subscribers filter by method + session_id.
344            let _ = self.event_broadcast.send(msg);
345        }
346        Ok(())
347    }
348}
349
350/// Spawn the dedicated writer task that drains the mpsc and writes to the
351/// WebSocket sink. The task ends when the channel closes (all senders
352/// dropped) or when a write fails. On a write failure it also fails every
353/// in-flight CDP request via `fail_all_pending` so callers see an immediate
354/// error instead of the 30 s timeout.
355pub fn spawn_writer_task(
356    mut sink: WebSocketSink,
357    mut rx: mpsc::UnboundedReceiver<Message>,
358    cdp: Arc<CDPClient>,
359) -> tokio::task::JoinHandle<()> {
360    tokio::spawn(async move {
361        while let Some(msg) = rx.recv().await {
362            if let Err(e) = sink.send(msg).await {
363                tracing::error!(error = %e, "WebSocket write error; terminating writer");
364                cdp.fail_all_pending(&format!("write error: {e}"));
365                return;
366            }
367        }
368        tracing::debug!("WebSocket writer task exiting (channel closed)");
369    })
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    fn test_client() -> CDPClient {
        CDPClient::new("ws://127.0.0.1:9222/devtools/browser/test".to_string())
    }

    #[test]
    fn test_cdp_request_creation() {
        let req = CDPRequest::new(
            1,
            "Page.navigate".to_string(),
            Some(json!({"url": "https://example.com"})),
        );
        assert_eq!(req.id, 1);
        assert_eq!(req.method, "Page.navigate");
        assert_eq!(req.params.as_ref().unwrap()["url"], "https://example.com");
    }

    #[test]
    fn test_cdp_request_to_json() {
        let req = CDPRequest::new(
            1,
            "Page.navigate".to_string(),
            Some(json!({"url": "https://example.com"})),
        );
        let json = req.to_json();
        assert_eq!(json["id"], 1);
        assert_eq!(json["method"], "Page.navigate");
        assert_eq!(json["params"]["url"], "https://example.com");
    }

    #[test]
    fn test_cdp_request_to_json_omits_unset_fields() {
        let json = CDPRequest::new(3, "Browser.getVersion".to_string(), None).to_json();
        assert_eq!(json["id"], 3);
        assert_eq!(json["method"], "Browser.getVersion");
        assert!(json.get("params").is_none());
        assert!(json.get("sessionId").is_none());
    }

    #[test]
    fn test_cdp_message_from_json() {
        let json_val = json!({
            "id": 1,
            "result": {"url": "https://example.com"},
            "sessionId": "SES001"
        });
        let msg = CDPMessage::from_json(json_val).unwrap();
        assert_eq!(msg.id, Some(1));
        assert_eq!(msg.result.as_ref().unwrap()["url"], "https://example.com");
        assert_eq!(msg.session_id.as_deref(), Some("SES001"));
    }

    #[test]
    fn test_cdp_message_session_id_parsed() {
        let event = json!({
            "method": "Page.loadEventFired",
            "params": {},
            "sessionId": "ABC123"
        });
        let msg = CDPMessage::from_json(event).unwrap();
        assert_eq!(msg.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(msg.session_id.as_deref(), Some("ABC123"));
    }

    #[test]
    fn test_cdp_request_with_session() {
        let req = CDPRequest::with_session(
            2,
            "Runtime.evaluate".to_string(),
            Some(json!({"expression": "1+1"})),
            "SES001".to_string(),
        );
        let json = req.to_json();
        assert_eq!(json["sessionId"], "SES001");
        assert_eq!(json["method"], "Runtime.evaluate");
    }

    #[test]
    fn test_handle_message_routes_result_to_pending_handler() {
        let client = test_client();
        let (tx, mut rx) = oneshot::channel();
        client.register_response_handler(7, tx);
        let msg = CDPMessage::from_json(json!({"id": 7, "result": {"ok": true}})).unwrap();
        client.handle_message(msg).unwrap();
        assert_eq!(rx.try_recv().unwrap(), json!({"ok": true}));
    }

    #[test]
    fn test_handle_message_wraps_error_payload() {
        let client = test_client();
        let (tx, mut rx) = oneshot::channel();
        client.register_response_handler(8, tx);
        let msg = CDPMessage::from_json(json!({
            "id": 8,
            "error": {"code": -32000, "message": "boom"}
        }))
        .unwrap();
        client.handle_message(msg).unwrap();
        assert_eq!(
            rx.try_recv().unwrap(),
            json!({"error": {"code": -32000, "message": "boom"}})
        );
    }

    #[test]
    fn test_handle_message_empty_response_yields_empty_object() {
        let client = test_client();
        let (tx, mut rx) = oneshot::channel();
        client.register_response_handler(9, tx);
        let msg = CDPMessage::from_json(json!({"id": 9})).unwrap();
        client.handle_message(msg).unwrap();
        assert_eq!(rx.try_recv().unwrap(), json!({}));
    }

    #[test]
    fn test_handle_message_broadcasts_events() {
        let client = test_client();
        let mut events = client.subscribe_events();
        let msg = CDPMessage::from_json(json!({
            "method": "Page.loadEventFired",
            "params": {},
            "sessionId": "S1"
        }))
        .unwrap();
        client.handle_message(msg).unwrap();
        let received = events.try_recv().unwrap();
        assert_eq!(received.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(received.session_id.as_deref(), Some("S1"));
    }

    #[test]
    fn test_fail_all_pending_closes_receivers() {
        let client = test_client();
        let (tx, mut rx) = oneshot::channel::<Value>();
        client.register_response_handler(10, tx);
        client.fail_all_pending("test");
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_send_raw_without_writer_fails() {
        let client = test_client();
        assert!(client.send_raw("{}".to_string()).is_err());
    }
}