1use std::sync::atomic::{AtomicU64, Ordering};
2use std::time::Duration;
3
4use futures_util::{SinkExt, StreamExt};
5use serde_json::{Value, json};
6use tokio::sync::Mutex;
7use tokio_tungstenite::tungstenite::Message;
8use tracing::{debug, trace};
9
10use roboticus_core::{Result, RoboticusError};
11
/// Concrete WebSocket stream held by [`CdpSession`]: the stream type produced by
/// `tokio_tungstenite::connect_async` when connecting over a `tokio::net::TcpStream`.
type WsStream =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
14
15pub struct CdpSession {
30 ws: Mutex<WsStream>,
31 command_id: AtomicU64,
32 timeout_ms: AtomicU64,
33}
34
35impl CdpSession {
36 pub async fn connect(ws_url: &str) -> Result<Self> {
38 debug!(url = ws_url, "connecting to CDP WebSocket");
39 let (ws, _response) = tokio_tungstenite::connect_async(ws_url)
40 .await
41 .map_err(|e| RoboticusError::Network(format!("CDP WebSocket connect failed: {e}")))?;
42
43 debug!("CDP WebSocket connected");
44 Ok(Self {
45 ws: Mutex::new(ws),
46 command_id: AtomicU64::new(1),
47 timeout_ms: AtomicU64::new(30_000),
48 })
49 }
50
51 pub fn set_timeout(&self, timeout: Duration) {
53 self.timeout_ms
54 .store(timeout.as_millis() as u64, Ordering::SeqCst);
55 }
56
57 fn timeout(&self) -> Duration {
58 Duration::from_millis(self.timeout_ms.load(Ordering::SeqCst))
59 }
60
61 fn next_id(&self) -> u64 {
62 self.command_id.fetch_add(1, Ordering::SeqCst)
63 }
64
65 pub async fn send_command(&self, method: &str, params: Value) -> Result<Value> {
71 let id = self.next_id();
72 let cmd = json!({
73 "id": id,
74 "method": method,
75 "params": params,
76 });
77
78 let text = serde_json::to_string(&cmd)
79 .map_err(|e| RoboticusError::Network(format!("serialize CDP command: {e}")))?;
80
81 trace!(id, method, "sending CDP command");
82
83 let mut ws = self.ws.lock().await;
84 ws.send(Message::Text(text))
85 .await
86 .map_err(|e| RoboticusError::Network(format!("CDP send failed: {e}")))?;
87
88 let timeout = self.timeout();
92 let deadline = tokio::time::Instant::now() + timeout;
93
94 loop {
95 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
96 if remaining.is_zero() {
97 return Err(RoboticusError::Network(format!(
98 "CDP command {method} (id={id}) timed out after {timeout:?}",
99 )));
100 }
101
102 let frame = tokio::time::timeout(remaining, ws.next()).await;
103
104 let msg = match frame {
105 Ok(Some(Ok(m))) => m,
106 Ok(Some(Err(e))) => {
107 return Err(RoboticusError::Network(format!("CDP read error: {e}")));
108 }
109 Ok(None) => {
110 return Err(RoboticusError::Network(
111 "CDP WebSocket closed unexpectedly".into(),
112 ));
113 }
114 Err(_) => {
115 return Err(RoboticusError::Network(format!(
116 "CDP command {method} (id={id}) timed out after {timeout:?}",
117 )));
118 }
119 };
120
121 match msg {
122 Message::Text(ref t) => {
123 let val: Value = serde_json::from_str(t).map_err(|e| {
124 RoboticusError::Network(format!("CDP response parse error: {e}"))
125 })?;
126
127 if val.get("id").and_then(|v| v.as_u64()) == Some(id) {
128 if let Some(error) = val.get("error") {
129 let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
130 let message = error
131 .get("message")
132 .and_then(|m| m.as_str())
133 .unwrap_or("unknown CDP error");
134 return Err(RoboticusError::Tool {
135 tool: "browser".into(),
136 message: format!("CDP error {code}: {message}"),
137 });
138 }
139 trace!(id, method, "CDP command response received");
140 return Ok(val.get("result").cloned().unwrap_or(json!({})));
141 }
142
143 if let Some(event_method) = val.get("method").and_then(|m| m.as_str()) {
144 trace!(event = event_method, "CDP event (skipped while waiting)");
145 }
146 }
147 Message::Ping(_) | Message::Pong(_) => {}
148 Message::Close(_) => {
149 return Err(RoboticusError::Network(
150 "CDP WebSocket closed by remote".into(),
151 ));
152 }
153 _ => {}
154 }
155 }
156 }
157
158 pub async fn close(self) -> Result<()> {
160 let mut ws = self.ws.into_inner();
161 ws.close(None)
162 .await
163 .map_err(|e| RoboticusError::Network(format!("CDP WebSocket close failed: {e}")))?;
164 debug!("CDP WebSocket closed");
165 Ok(())
166 }
167}
168
169#[cfg(test)]
170mod tests {
171 use super::*;
172
173 #[test]
174 fn command_id_counter_increments() {
175 let counter = AtomicU64::new(1);
176 let id1 = counter.fetch_add(1, Ordering::SeqCst);
177 let id2 = counter.fetch_add(1, Ordering::SeqCst);
178 let id3 = counter.fetch_add(1, Ordering::SeqCst);
179 assert_eq!(id1, 1);
180 assert_eq!(id2, 2);
181 assert_eq!(id3, 3);
182 }
183
184 #[tokio::test]
185 async fn connect_to_nonexistent_fails() {
186 let result = CdpSession::connect("ws://127.0.0.1:19999/devtools/nonexistent").await;
187 assert!(result.is_err());
188 let err = match result {
189 Err(e) => e.to_string(),
190 Ok(_) => panic!("expected error"),
191 };
192 assert!(
193 err.contains("connect") || err.contains("Connection refused") || err.contains("failed"),
194 "error should mention connection failure: {err}"
195 );
196 }
197
198 #[test]
199 fn cdp_command_json_shape() {
200 let id: u64 = 42;
201 let cmd = json!({
202 "id": id,
203 "method": "Page.navigate",
204 "params": {"url": "https://example.com"},
205 });
206 assert_eq!(cmd["id"], 42);
207 assert_eq!(cmd["method"], "Page.navigate");
208 assert_eq!(cmd["params"]["url"], "https://example.com");
209 }
210
211 #[test]
212 fn response_matching_logic() {
213 let response = json!({"id": 5, "result": {"frameId": "abc123"}});
214 let target_id: u64 = 5;
215
216 assert_eq!(response.get("id").and_then(|v| v.as_u64()), Some(target_id));
217
218 let result = response.get("result").cloned().unwrap_or(json!({}));
219 assert_eq!(result["frameId"], "abc123");
220 }
221
222 #[test]
223 fn error_response_detection() {
224 let error_response = json!({
225 "id": 3,
226 "error": {
227 "code": -32000,
228 "message": "Cannot navigate to invalid URL"
229 }
230 });
231
232 let error = error_response.get("error");
233 assert!(error.is_some());
234 let code = error.unwrap().get("code").and_then(|c| c.as_i64()).unwrap();
235 assert_eq!(code, -32000);
236 }
237
238 #[test]
239 fn event_detection() {
240 let event = json!({"method": "Page.loadEventFired", "params": {"timestamp": 12345.6}});
241 let method = event.get("method").and_then(|m| m.as_str());
242 assert_eq!(method, Some("Page.loadEventFired"));
243 assert!(event.get("id").is_none());
244 }
245
246 use tokio::net::TcpListener;
252
253 async fn mock_ws_server<F>(handler: F) -> (String, tokio::task::JoinHandle<()>)
254 where
255 F: Fn(String) -> Option<String> + Send + 'static,
256 {
257 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
258 let port = listener.local_addr().unwrap().port();
259 let url = format!("ws://127.0.0.1:{port}");
260
261 let handle = tokio::spawn(async move {
262 if let Ok((stream, _addr)) = listener.accept().await {
263 let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
264 let (mut sink, mut source) = ws.split();
265 while let Some(Ok(msg)) = source.next().await {
266 if let Message::Text(ref t) = msg
267 && let Some(reply) = handler(t.clone())
268 {
269 let _ = sink.send(Message::Text(reply)).await;
270 }
271 }
272 }
273 });
274
275 tokio::time::sleep(Duration::from_millis(50)).await;
277 (url, handle)
278 }
279
280 #[tokio::test]
281 async fn send_command_success() {
282 let (url, _server) = mock_ws_server(|text| {
283 let req: Value = serde_json::from_str(&text).ok()?;
284 let id = req.get("id")?.as_u64()?;
285 Some(serde_json::to_string(&json!({"id": id, "result": {"frameId": "F1"}})).unwrap())
286 })
287 .await;
288
289 let session = CdpSession::connect(&url).await.unwrap();
290 let result = session
291 .send_command("Page.navigate", json!({"url": "https://example.com"}))
292 .await
293 .unwrap();
294 assert_eq!(result["frameId"], "F1");
295 }
296
297 #[tokio::test]
298 async fn send_command_cdp_error() {
299 let (url, _server) = mock_ws_server(|text| {
300 let req: Value = serde_json::from_str(&text).ok()?;
301 let id = req.get("id")?.as_u64()?;
302 Some(
303 serde_json::to_string(&json!({
304 "id": id,
305 "error": {"code": -32000, "message": "Cannot navigate"}
306 }))
307 .unwrap(),
308 )
309 })
310 .await;
311
312 let session = CdpSession::connect(&url).await.unwrap();
313 let result = session
314 .send_command("Page.navigate", json!({"url": "invalid"}))
315 .await;
316 assert!(result.is_err());
317 let err_str = result.unwrap_err().to_string();
318 assert!(
319 err_str.contains("Cannot navigate"),
320 "expected CDP error message: {err_str}"
321 );
322 }
323
324 #[tokio::test]
325 async fn send_command_timeout() {
326 let (url, _server) = mock_ws_server(|_text| None).await;
328
329 let session = CdpSession::connect(&url).await.unwrap();
330 session.set_timeout(Duration::from_millis(200));
331
332 let result = session
333 .send_command("Page.navigate", json!({"url": "https://example.com"}))
334 .await;
335 assert!(result.is_err());
336 let err_str = result.unwrap_err().to_string();
337 assert!(
338 err_str.contains("timed out"),
339 "expected timeout error: {err_str}"
340 );
341 }
342
343 #[tokio::test]
344 async fn send_command_skips_events_before_response() {
345 let (url, _server) = mock_ws_server(|text| {
346 let req: Value = serde_json::from_str(&text).ok()?;
347 let id = req.get("id")?.as_u64()?;
348 Some(serde_json::to_string(&json!({"id": id, "result": {"ok": true}})).unwrap())
354 })
355 .await;
356
357 let session = CdpSession::connect(&url).await.unwrap();
358 let result = session
359 .send_command("Runtime.evaluate", json!({"expression": "1+1"}))
360 .await
361 .unwrap();
362 assert_eq!(result["ok"], true);
363 }
364
365 #[tokio::test]
366 async fn send_command_events_before_matching_response() {
367 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
369 let port = listener.local_addr().unwrap().port();
370 let url = format!("ws://127.0.0.1:{port}");
371
372 let _server = tokio::spawn(async move {
373 if let Ok((stream, _addr)) = listener.accept().await {
374 let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
375 let (mut sink, mut source) = ws.split();
376 while let Some(Ok(msg)) = source.next().await {
377 if let Message::Text(ref t) = msg
378 && let Ok(req) = serde_json::from_str::<Value>(t)
379 && let Some(id) = req.get("id").and_then(|v| v.as_u64())
380 {
381 let event = serde_json::to_string(
383 &json!({"method": "Page.loadEventFired", "params": {}}),
384 )
385 .unwrap();
386 let _ = sink.send(Message::Text(event)).await;
387
388 tokio::time::sleep(Duration::from_millis(10)).await;
390
391 let resp =
393 serde_json::to_string(&json!({"id": id, "result": {"value": 42}}))
394 .unwrap();
395 let _ = sink.send(Message::Text(resp)).await;
396 }
397 }
398 }
399 });
400
401 tokio::time::sleep(Duration::from_millis(50)).await;
402
403 let session = CdpSession::connect(&url).await.unwrap();
404 let result = session
405 .send_command("Runtime.evaluate", json!({"expression": "21*2"}))
406 .await
407 .unwrap();
408 assert_eq!(result["value"], 42);
409 }
410
411 #[tokio::test]
412 async fn send_command_ws_closed_unexpectedly() {
413 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
415 let port = listener.local_addr().unwrap().port();
416 let url = format!("ws://127.0.0.1:{port}");
417
418 let _server = tokio::spawn(async move {
419 if let Ok((stream, _addr)) = listener.accept().await {
420 let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
421 let (mut sink, _source) = ws.split();
422 let _ = sink.close().await;
424 }
425 });
426
427 tokio::time::sleep(Duration::from_millis(50)).await;
428 let session = CdpSession::connect(&url).await.unwrap();
429 session.set_timeout(Duration::from_millis(2000));
430
431 let result = session.send_command("Page.enable", json!({})).await;
432 assert!(result.is_err());
433 let err_str = result.unwrap_err().to_string();
434 assert!(
435 err_str.contains("closed") || err_str.contains("timed out"),
436 "expected close/timeout error: {err_str}"
437 );
438 }
439
440 #[tokio::test]
441 async fn set_timeout_affects_deadline() {
442 let (url, _server) = mock_ws_server(|_text| None).await;
443
444 let session = CdpSession::connect(&url).await.unwrap();
445
446 session.set_timeout(Duration::from_millis(100));
448 let start = tokio::time::Instant::now();
449 let result = session.send_command("Test", json!({})).await;
450 let elapsed = start.elapsed();
451
452 assert!(result.is_err());
453 assert!(
455 elapsed < Duration::from_millis(500),
456 "timeout took too long: {:?}",
457 elapsed
458 );
459 }
460
461 #[tokio::test]
462 async fn close_session() {
463 let (url, _server) = mock_ws_server(|_text| None).await;
464
465 let session = CdpSession::connect(&url).await.unwrap();
466 let result = session.close().await;
467 assert!(result.is_ok());
468 }
469
470 #[tokio::test]
471 async fn send_command_result_without_result_field() {
472 let (url, _server) = mock_ws_server(|text| {
474 let req: Value = serde_json::from_str(&text).ok()?;
475 let id = req.get("id")?.as_u64()?;
476 Some(serde_json::to_string(&json!({"id": id})).unwrap())
477 })
478 .await;
479
480 let session = CdpSession::connect(&url).await.unwrap();
481 let result = session
482 .send_command("Page.enable", json!({}))
483 .await
484 .unwrap();
485 assert_eq!(result, json!({}));
487 }
488
489 #[tokio::test]
490 async fn send_command_error_missing_message() {
491 let (url, _server) = mock_ws_server(|text| {
493 let req: Value = serde_json::from_str(&text).ok()?;
494 let id = req.get("id")?.as_u64()?;
495 Some(serde_json::to_string(&json!({"id": id, "error": {"code": -1}})).unwrap())
496 })
497 .await;
498
499 let session = CdpSession::connect(&url).await.unwrap();
500 let result = session.send_command("Bad.command", json!({})).await;
501 assert!(result.is_err());
502 let err_str = result.unwrap_err().to_string();
503 assert!(
505 err_str.contains("unknown CDP error") || err_str.contains("CDP error -1"),
506 "unexpected error: {err_str}"
507 );
508 }
509
510 #[tokio::test]
511 async fn send_command_mismatched_ids_eventually_matches() {
512 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
514 let port = listener.local_addr().unwrap().port();
515 let url = format!("ws://127.0.0.1:{port}");
516
517 let _server = tokio::spawn(async move {
518 if let Ok((stream, _addr)) = listener.accept().await {
519 let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
520 let (mut sink, mut source) = ws.split();
521 while let Some(Ok(msg)) = source.next().await {
522 if let Message::Text(ref t) = msg
523 && let Ok(req) = serde_json::from_str::<Value>(t)
524 && let Some(id) = req.get("id").and_then(|v| v.as_u64())
525 {
526 let wrong = serde_json::to_string(
528 &json!({"id": id + 999, "result": {"wrong": true}}),
529 )
530 .unwrap();
531 let _ = sink.send(Message::Text(wrong)).await;
532
533 tokio::time::sleep(Duration::from_millis(10)).await;
534
535 let correct =
537 serde_json::to_string(&json!({"id": id, "result": {"correct": true}}))
538 .unwrap();
539 let _ = sink.send(Message::Text(correct)).await;
540 }
541 }
542 }
543 });
544
545 tokio::time::sleep(Duration::from_millis(50)).await;
546
547 let session = CdpSession::connect(&url).await.unwrap();
548 let result = session.send_command("Test", json!({})).await.unwrap();
549 assert_eq!(result["correct"], true);
550 }
551}