1use futures_util::SinkExt;
2use serde_json::{json, Value};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::{Arc, Mutex as StdMutex};
6use tokio::sync::{broadcast, mpsc, oneshot};
7use tokio::time::{timeout, Duration};
8use tokio_tungstenite::tungstenite::Message;
9use tracing::Instrument;
10
11use crate::error::{BrowserError, Result};
12
13#[derive(Debug, Clone)]
15pub struct CDPRequest {
16 pub id: u32,
18 pub method: String,
20 pub params: Option<Value>,
22 pub session_id: Option<String>,
24}
25
26impl CDPRequest {
27 pub fn new(id: u32, method: String, params: Option<Value>) -> Self {
29 Self {
30 id,
31 method,
32 params,
33 session_id: None,
34 }
35 }
36
37 pub fn with_session(
39 id: u32,
40 method: String,
41 params: Option<Value>,
42 session_id: String,
43 ) -> Self {
44 Self {
45 id,
46 method,
47 params,
48 session_id: Some(session_id),
49 }
50 }
51
52 pub fn to_json(&self) -> Value {
54 let mut obj = json!({
55 "id": self.id,
56 "method": self.method,
57 });
58
59 if let Some(session_id) = &self.session_id {
60 obj["sessionId"] = json!(session_id);
61 }
62
63 if let Some(params) = &self.params {
64 obj["params"] = params.clone();
65 }
66
67 obj
68 }
69}
70
71#[derive(Debug, Clone)]
73pub struct CDPMessage {
74 pub id: Option<u32>,
76 pub method: Option<String>,
78 pub params: Option<Value>,
80 pub result: Option<Value>,
82 pub error: Option<Value>,
84 pub session_id: Option<String>,
87}
88
89impl CDPMessage {
90 pub fn from_json(value: Value) -> Result<Self> {
92 Ok(CDPMessage {
93 id: value.get("id").and_then(|v| v.as_u64()).map(|v| v as u32),
94 method: value
95 .get("method")
96 .and_then(|v| v.as_str())
97 .map(|s| s.to_string()),
98 params: value.get("params").cloned(),
99 result: value.get("result").cloned(),
100 error: value.get("error").cloned(),
101 session_id: value
103 .get("sessionId")
104 .and_then(|v| v.as_str())
105 .map(|s| s.to_string()),
106 })
107 }
108}
109
110pub type WebSocketSink = futures_util::stream::SplitSink<
112 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
113 Message,
114>;
115
116pub struct CDPClient {
126 ws_url: String,
127 message_id_counter: Arc<AtomicU32>,
128 pending_responses: Arc<StdMutex<HashMap<u32, oneshot::Sender<Value>>>>,
132 event_broadcast: broadcast::Sender<CDPMessage>,
135 ws_tx: Arc<StdMutex<Option<mpsc::UnboundedSender<Message>>>>,
138}
139
140impl CDPClient {
141 pub fn new(ws_url: String) -> Self {
143 let (event_broadcast, _) = broadcast::channel(1024);
144 Self {
145 ws_url,
146 message_id_counter: Arc::new(AtomicU32::new(1)),
147 pending_responses: Arc::new(StdMutex::new(HashMap::new())),
148 event_broadcast,
149 ws_tx: Arc::new(StdMutex::new(None)),
150 }
151 }
152
153 pub fn set_writer(&self, tx: mpsc::UnboundedSender<Message>) {
156 *self.ws_tx.lock().expect("ws_tx mutex poisoned") = Some(tx);
157 }
158
159 pub fn next_id(&self) -> u32 {
161 self.message_id_counter.fetch_add(1, Ordering::SeqCst)
162 }
163
164 pub async fn connect(
166 &self,
167 ) -> Result<
168 tokio_tungstenite::WebSocketStream<
169 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
170 >,
171 > {
172 let (ws_stream, _) = tokio_tungstenite::connect_async(&self.ws_url)
173 .await
174 .map_err(|e| BrowserError::connection_failed(&self.ws_url, e.to_string()))?;
175
176 Ok(ws_stream)
177 }
178
179 pub fn send_raw(&self, msg: String) -> Result<()> {
183 let tx_guard = self.ws_tx.lock().expect("ws_tx mutex poisoned");
184 let tx = tx_guard.as_ref().ok_or_else(|| {
185 BrowserError::websocket("send_raw", "WebSocket writer not initialised")
186 })?;
187 tx.send(Message::Text(msg))
188 .map_err(|_| BrowserError::websocket("send_raw", "WebSocket writer task ended"))
189 }
190
191 pub fn subscribe_events(&self) -> broadcast::Receiver<CDPMessage> {
200 self.event_broadcast.subscribe()
201 }
202
203 #[tracing::instrument(level = "info", skip(self, params), fields(method = %method, id))]
208 pub async fn send_command(&self, method: String, params: Option<Value>) -> Result<Value> {
209 let id = self.next_id();
210 tracing::Span::current().record("id", id);
211 let request = CDPRequest::new(id, method.clone(), params);
212
213 let (tx, rx) = oneshot::channel();
215 self.register_response_handler(id, tx);
216 let json_str = tracing::info_span!("serialize").in_scope(|| request.to_json().to_string());
217 let bytes = json_str.len();
218 tracing::info_span!("ws_send", bytes).in_scope(|| self.send_raw(json_str))?;
219 const TIMEOUT_SECS: u64 = 30;
222 let wait = async {
223 match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
224 Ok(Ok(value)) => Ok(value),
225 Ok(Err(_)) => Err(BrowserError::command_failed(
226 &method,
227 "response channel closed unexpectedly",
228 )),
229 Err(_) => {
230 self.pending_responses
231 .lock()
232 .expect("pending_responses mutex poisoned")
233 .remove(&id);
234 Err(BrowserError::timeout(
235 format!("waiting for response to '{method}'"),
236 TIMEOUT_SECS,
237 ))
238 }
239 }
240 };
241 wait.instrument(tracing::info_span!("await_response")).await
242 }
243
244 #[tracing::instrument(level = "info", skip(self, params), fields(method = %method, id, session_id = %session_id))]
248 pub async fn send_command_with_session(
249 &self,
250 session_id: &str,
251 method: String,
252 params: Option<Value>,
253 ) -> Result<Value> {
254 let id = self.next_id();
255 tracing::Span::current().record("id", id);
256 let request = CDPRequest::with_session(id, method.clone(), params, session_id.to_string());
257
258 let (tx, rx) = oneshot::channel();
260 self.register_response_handler(id, tx);
261 let json_str = tracing::info_span!("serialize").in_scope(|| request.to_json().to_string());
262 let bytes = json_str.len();
263 tracing::info_span!("ws_send", bytes).in_scope(|| self.send_raw(json_str))?;
264 const TIMEOUT_SECS: u64 = 30;
267 let wait = async {
268 match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
269 Ok(Ok(value)) => Ok(value),
270 Ok(Err(_)) => Err(BrowserError::command_failed(
271 &method,
272 "response channel closed unexpectedly",
273 )),
274 Err(_) => {
275 self.pending_responses
276 .lock()
277 .expect("pending_responses mutex poisoned")
278 .remove(&id);
279 Err(BrowserError::timeout(
280 format!("waiting for response to '{method}'"),
281 TIMEOUT_SECS,
282 ))
283 }
284 }
285 };
286 wait.instrument(tracing::info_span!("await_response")).await
287 }
288
289 pub fn register_response_handler(&self, id: u32, tx: oneshot::Sender<Value>) {
292 self.pending_responses
293 .lock()
294 .expect("pending_responses mutex poisoned")
295 .insert(id, tx);
296 }
297
298 pub fn fail_all_pending(&self, reason: &str) {
304 let mut pending = self
305 .pending_responses
306 .lock()
307 .expect("pending_responses mutex poisoned");
308 let count = pending.len();
309 pending.clear(); drop(pending);
311 if count > 0 {
312 tracing::warn!(
313 pending_count = count,
314 reason = reason,
315 "WebSocket terminated; failing in-flight CDP requests"
316 );
317 }
318 }
319
320 #[tracing::instrument(level = "debug", skip_all, fields(method = ?msg.method, id = ?msg.id))]
324 pub fn handle_message(&self, msg: CDPMessage) -> Result<()> {
325 if let Some(id) = msg.id {
326 let tx = self
328 .pending_responses
329 .lock()
330 .expect("pending_responses mutex poisoned")
331 .remove(&id);
332 if let Some(tx) = tx {
333 if let Some(error) = msg.error {
334 let _ = tx.send(json!({ "error": error }));
335 } else if let Some(result) = msg.result {
336 let _ = tx.send(result);
337 } else {
338 let _ = tx.send(json!({}));
339 }
340 }
341 } else if msg.method.is_some() {
342 let _ = self.event_broadcast.send(msg);
345 }
346 Ok(())
347 }
348}
349
350pub fn spawn_writer_task(
356 mut sink: WebSocketSink,
357 mut rx: mpsc::UnboundedReceiver<Message>,
358 cdp: Arc<CDPClient>,
359) -> tokio::task::JoinHandle<()> {
360 tokio::spawn(async move {
361 while let Some(msg) = rx.recv().await {
362 if let Err(e) = sink.send(msg).await {
363 tracing::error!(error = %e, "WebSocket write error; terminating writer");
364 cdp.fail_all_pending(&format!("write error: {e}"));
365 return;
366 }
367 }
368 tracing::debug!("WebSocket writer task exiting (channel closed)");
369 })
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// `CDPRequest::new` stores id, method and params as given.
    #[test]
    fn test_cdp_request_creation() {
        let params = json!({"url": "https://example.com"});
        let request = CDPRequest::new(1, String::from("Page.navigate"), Some(params));

        assert_eq!(request.id, 1);
        assert_eq!(request.method, "Page.navigate");
        let stored = request.params.as_ref().unwrap();
        assert_eq!(stored["url"], "https://example.com");
    }

    /// `to_json` emits `id`, `method` and the nested `params`.
    #[test]
    fn test_cdp_request_to_json() {
        let params = json!({"url": "https://example.com"});
        let encoded = CDPRequest::new(1, String::from("Page.navigate"), Some(params)).to_json();

        assert_eq!(encoded["id"], 1);
        assert_eq!(encoded["method"], "Page.navigate");
        assert_eq!(encoded["params"]["url"], "https://example.com");
    }

    /// A response payload yields its id, result and session id.
    #[test]
    fn test_cdp_message_from_json() {
        let payload = json!({
            "id": 1,
            "result": {"url": "https://example.com"},
            "sessionId": "SES001"
        });

        let decoded = CDPMessage::from_json(payload).unwrap();

        assert_eq!(decoded.id, Some(1));
        let result = decoded.result.as_ref().unwrap();
        assert_eq!(result["url"], "https://example.com");
        assert_eq!(decoded.session_id.as_deref(), Some("SES001"));
    }

    /// An event payload yields its method and session id.
    #[test]
    fn test_cdp_message_session_id_parsed() {
        let payload = json!({
            "method": "Page.loadEventFired",
            "params": {},
            "sessionId": "ABC123"
        });

        let decoded = CDPMessage::from_json(payload).unwrap();

        assert_eq!(decoded.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(decoded.session_id.as_deref(), Some("ABC123"));
    }

    /// A session-bound request serializes its `sessionId`.
    #[test]
    fn test_cdp_request_with_session() {
        let params = json!({"expression": "1+1"});
        let request = CDPRequest::with_session(
            2,
            String::from("Runtime.evaluate"),
            Some(params),
            String::from("SES001"),
        );

        let encoded = request.to_json();

        assert_eq!(encoded["sessionId"], "SES001");
        assert_eq!(encoded["method"], "Runtime.evaluate");
    }
}