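//! `ferrous_browser/cdp.rs` — Chrome DevTools Protocol (CDP) request/message
//! types and the WebSocket client that routes command responses and events.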

1use serde_json::{json, Value};
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU32, Ordering};
4use std::sync::Arc;
5use tokio::sync::{broadcast, oneshot, RwLock};
6use tokio::time::{timeout, Duration};
7use futures_util::SinkExt;
8use tokio_tungstenite::tungstenite::Message;
9
10use crate::error::{BrowserError, Result};
11
12/// Represents a CDP command request
13#[derive(Debug, Clone)]
14pub struct CDPRequest {
15    /// Unique request ID
16    pub id: u32,
17    /// CDP method name
18    pub method: String,
19    /// Optional parameters for the method
20    pub params: Option<Value>,
21    /// Optional session ID for targeting specific pages
22    pub session_id: Option<String>,
23}
24
25impl CDPRequest {
26    /// Create a new CDP request
27    pub fn new(id: u32, method: String, params: Option<Value>) -> Self {
28        Self { id, method, params, session_id: None }
29    }
30
31    /// Create a CDP request with session ID
32    pub fn with_session(id: u32, method: String, params: Option<Value>, session_id: String) -> Self {
33        Self { id, method, params, session_id: Some(session_id) }
34    }
35
36    /// Convert to JSON value for sending
37    pub fn to_json(&self) -> Value {
38        let mut obj = json!({
39            "id": self.id,
40            "method": self.method,
41        });
42
43        if let Some(session_id) = &self.session_id {
44            obj["sessionId"] = json!(session_id);
45        }
46
47        if let Some(params) = &self.params {
48            obj["params"] = params.clone();
49        }
50
51        obj
52    }
53}
54
55/// Represents a CDP event or response
56#[derive(Debug, Clone)]
57pub struct CDPMessage {
58    /// Response ID (if this is a response)
59    pub id: Option<u32>,
60    /// Event method name (if this is an event)
61    pub method: Option<String>,
62    /// Event parameters
63    pub params: Option<Value>,
64    /// Command result (if successful)
65    pub result: Option<Value>,
66    /// Error object (if failed)
67    pub error: Option<Value>,
68    /// Session ID — identifies which page/target this message belongs to.
69    /// This is the critical field for multi-page session isolation.
70    pub session_id: Option<String>,
71}
72
73impl CDPMessage {
74    /// Parse a CDP message from JSON value
75    pub fn from_json(value: Value) -> Result<Self> {
76        Ok(CDPMessage {
77            id: value.get("id").and_then(|v| v.as_u64()).map(|v| v as u32),
78            method: value.get("method").and_then(|v| v.as_str()).map(|s| s.to_string()),
79            params: value.get("params").cloned(),
80            result: value.get("result").cloned(),
81            error: value.get("error").cloned(),
82            // Chrome always includes sessionId in session-scoped messages
83            session_id: value
84                .get("sessionId")
85                .and_then(|v| v.as_str())
86                .map(|s| s.to_string()),
87        })
88    }
89}
90
91/// Type for WebSocket sink
92pub type WebSocketSink = futures_util::stream::SplitSink<
93    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
94    Message,
95>;
96
97/// Core CDP client that manages WebSocket connection and message routing
98pub struct CDPClient {
99    ws_url: String,
100    message_id_counter: Arc<AtomicU32>,
101    pending_responses: Arc<RwLock<HashMap<u32, oneshot::Sender<Value>>>>,
102    /// Broadcast channel carrying ALL CDP events (method is_some()).
103    /// Subscribers filter by method name and session_id themselves.
104    event_broadcast: broadcast::Sender<CDPMessage>,
105    ws_sink: Arc<RwLock<Option<WebSocketSink>>>,
106}
107
108impl CDPClient {
109    /// Create a new CDP client
110    pub fn new(ws_url: String) -> Self {
111        let (event_broadcast, _) = broadcast::channel(1024);
112        Self {
113            ws_url,
114            message_id_counter: Arc::new(AtomicU32::new(1)),
115            pending_responses: Arc::new(RwLock::new(HashMap::new())),
116            event_broadcast,
117            ws_sink: Arc::new(RwLock::new(None)),
118        }
119    }
120
121    /// Set the WebSocket sink (called from Connection)
122    pub async fn set_sink(&self, sink: WebSocketSink) {
123        let mut ws = self.ws_sink.write().await;
124        *ws = Some(sink);
125    }
126
127    /// Generate the next message ID
128    pub fn next_id(&self) -> u32 {
129        self.message_id_counter.fetch_add(1, Ordering::SeqCst)
130    }
131
132    /// Connect to the Chrome DevTools Protocol WebSocket
133    pub async fn connect(
134        &self,
135    ) -> Result<
136        tokio_tungstenite::WebSocketStream<
137            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
138        >,
139    > {
140        let (ws_stream, _) = tokio_tungstenite::connect_async(&self.ws_url)
141            .await
142            .map_err(|e| BrowserError::connection_failed(&self.ws_url, e.to_string()))?;
143
144        Ok(ws_stream)
145    }
146
147    /// Send raw message through WebSocket
148    pub async fn send_raw(&self, msg: String) -> Result<()> {
149        let mut ws = self.ws_sink.write().await;
150        if let Some(sink) = ws.as_mut() {
151            sink.send(Message::Text(msg))
152                .await
153                .map_err(|e| BrowserError::websocket("send_raw", e.to_string()))?;
154        } else {
155            return Err(BrowserError::websocket("send_raw", "WebSocket not connected"));
156        }
157        Ok(())
158    }
159
160    /// Subscribe to all CDP events (unfiltered broadcast receiver).
161    ///
162    /// Callers are responsible for filtering by `msg.method` and
163    /// `msg.session_id` as needed.
164    ///
165    /// **IMPORTANT:** Subscribe *before* sending the CDP command that
166    /// triggers the event to avoid the race where Chrome replies before the
167    /// receiver is registered.
168    pub fn subscribe_events(&self) -> broadcast::Receiver<CDPMessage> {
169        self.event_broadcast.subscribe()
170    }
171
172    /// Send a command and wait for response with timeout.
173    ///
174    /// The response handler is registered **before** the message is sent so
175    /// that fast Chrome replies are never dropped.
176    pub async fn send_command(&self, method: String, params: Option<Value>) -> Result<Value> {
177        let id = self.next_id();
178        let request = CDPRequest::new(id, method.clone(), params);
179
180        // ── Register handler BEFORE sending ──────────────────────────────────
181        let (tx, rx) = oneshot::channel();
182        self.register_response_handler(id, tx).await;
183        let json_str = request.to_json().to_string();
184        self.send_raw(json_str).await?;
185        // ─────────────────────────────────────────────────────────────────────
186
187        const TIMEOUT_SECS: u64 = 30;
188        match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
189            Ok(Ok(value)) => Ok(value),
190            Ok(Err(_)) => Err(BrowserError::command_failed(
191                &method,
192                "response channel closed unexpectedly",
193            )),
194            Err(_) => {
195                let mut pending = self.pending_responses.write().await;
196                pending.remove(&id);
197                Err(BrowserError::timeout(
198                    format!("waiting for response to '{method}'"),
199                    TIMEOUT_SECS,
200                ))
201            }
202        }
203    }
204
205    /// Send a command to a specific page session.
206    ///
207    /// The response handler is registered **before** the message is sent.
208    pub async fn send_command_with_session(
209        &self,
210        session_id: &str,
211        method: String,
212        params: Option<Value>,
213    ) -> Result<Value> {
214        let id = self.next_id();
215        let request =
216            CDPRequest::with_session(id, method.clone(), params, session_id.to_string());
217
218        // ── Register handler BEFORE sending ──────────────────────────────────
219        let (tx, rx) = oneshot::channel();
220        self.register_response_handler(id, tx).await;
221        let json_str = request.to_json().to_string();
222        self.send_raw(json_str).await?;
223        // ─────────────────────────────────────────────────────────────────────
224
225        const TIMEOUT_SECS: u64 = 30;
226        match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
227            Ok(Ok(value)) => Ok(value),
228            Ok(Err(_)) => Err(BrowserError::command_failed(
229                &method,
230                "response channel closed unexpectedly",
231            )),
232            Err(_) => {
233                let mut pending = self.pending_responses.write().await;
234                pending.remove(&id);
235                Err(BrowserError::timeout(
236                    format!("waiting for response to '{method}'"),
237                    TIMEOUT_SECS,
238                ))
239            }
240        }
241    }
242
243    /// Register a pending response handler
244    pub async fn register_response_handler(&self, id: u32, tx: oneshot::Sender<Value>) {
245        let mut pending = self.pending_responses.write().await;
246        pending.insert(id, tx);
247    }
248
249    /// Handle an incoming CDP message — called by `Connection`
250    pub async fn handle_message(&self, msg: CDPMessage) -> Result<()> {
251        if let Some(id) = msg.id {
252            // It's a response to one of our commands
253            let mut pending = self.pending_responses.write().await;
254            if let Some(tx) = pending.remove(&id) {
255                if let Some(error) = msg.error {
256                    let _ = tx.send(json!({ "error": error }));
257                } else if let Some(result) = msg.result {
258                    let _ = tx.send(result);
259                } else {
260                    let _ = tx.send(json!({}));
261                }
262            }
263        } else if msg.method.is_some() {
264            // It's an event — broadcast to all subscribers.
265            // Subscribers filter by method + session_id.
266            let _ = self.event_broadcast.send(msg);
267        }
268
269        Ok(())
270    }
271}
272
#[cfg(test)]
mod tests {
    //! Unit tests for request serialization and inbound-frame parsing.
    use super::*;

    #[test]
    fn test_cdp_request_creation() {
        let req = CDPRequest::new(
            1,
            "Page.navigate".to_string(),
            Some(json!({"url": "https://example.com"})),
        );
        assert_eq!(req.id, 1);
        assert_eq!(req.method, "Page.navigate");
        assert_eq!(req.params.as_ref().unwrap()["url"], "https://example.com");
        assert!(req.session_id.is_none());
    }

    #[test]
    fn test_cdp_request_to_json() {
        let req = CDPRequest::new(
            1,
            "Page.navigate".to_string(),
            Some(json!({"url": "https://example.com"})),
        );
        let json = req.to_json();
        assert_eq!(json["id"], 1);
        assert_eq!(json["method"], "Page.navigate");
        assert_eq!(json["params"]["url"], "https://example.com");
    }

    /// Optional fields must be left out entirely, not serialized as `null`.
    #[test]
    fn test_cdp_request_to_json_omits_absent_fields() {
        let req = CDPRequest::new(7, "Browser.getVersion".to_string(), None);
        let json = req.to_json();
        assert_eq!(json["id"], 7);
        assert_eq!(json["method"], "Browser.getVersion");
        assert!(json.get("params").is_none());
        assert!(json.get("sessionId").is_none());
    }

    #[test]
    fn test_cdp_message_from_json() {
        let json_val = json!({
            "id": 1,
            "result": {"url": "https://example.com"},
            "sessionId": "SES001"
        });
        let msg = CDPMessage::from_json(json_val).unwrap();
        assert_eq!(msg.id, Some(1));
        assert_eq!(msg.result.as_ref().unwrap()["url"], "https://example.com");
        assert_eq!(msg.session_id.as_deref(), Some("SES001"));
    }

    #[test]
    fn test_cdp_message_session_id_parsed() {
        let event = json!({
            "method": "Page.loadEventFired",
            "params": {},
            "sessionId": "ABC123"
        });
        let msg = CDPMessage::from_json(event).unwrap();
        assert_eq!(msg.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(msg.session_id.as_deref(), Some("ABC123"));
    }

    /// Events carry no `id`; `handle_message` relies on this to tell them apart.
    #[test]
    fn test_cdp_message_event_has_no_id() {
        let event = json!({
            "method": "Target.targetCreated",
            "params": {"targetInfo": {}}
        });
        let msg = CDPMessage::from_json(event).unwrap();
        assert_eq!(msg.id, None);
        assert_eq!(msg.method.as_deref(), Some("Target.targetCreated"));
        assert!(msg.params.is_some());
        assert!(msg.session_id.is_none());
    }

    #[test]
    fn test_cdp_message_error_response() {
        let response = json!({
            "id": 3,
            "error": {"code": -32601, "message": "'Foo.bar' wasn't found"}
        });
        let msg = CDPMessage::from_json(response).unwrap();
        assert_eq!(msg.id, Some(3));
        assert!(msg.result.is_none());
        assert_eq!(msg.error.as_ref().unwrap()["code"], -32601);
    }

    #[test]
    fn test_cdp_request_with_session() {
        let req = CDPRequest::with_session(
            2,
            "Runtime.evaluate".to_string(),
            Some(json!({"expression": "1+1"})),
            "SES001".to_string(),
        );
        let json = req.to_json();
        assert_eq!(json["sessionId"], "SES001");
        assert_eq!(json["method"], "Runtime.evaluate");
    }
}