use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use serde_json::{Value, json};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, trace};
use roboticus_core::{Result, RoboticusError};
type WsStream =
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
pub struct CdpSession {
ws: Mutex<WsStream>,
command_id: AtomicU64,
timeout_ms: AtomicU64,
}
impl CdpSession {
pub async fn connect(ws_url: &str) -> Result<Self> {
debug!(url = ws_url, "connecting to CDP WebSocket");
let (ws, _response) = tokio_tungstenite::connect_async(ws_url)
.await
.map_err(|e| RoboticusError::Network(format!("CDP WebSocket connect failed: {e}")))?;
debug!("CDP WebSocket connected");
Ok(Self {
ws: Mutex::new(ws),
command_id: AtomicU64::new(1),
timeout_ms: AtomicU64::new(30_000),
})
}
pub fn set_timeout(&self, timeout: Duration) {
self.timeout_ms
.store(timeout.as_millis() as u64, Ordering::SeqCst);
}
fn timeout(&self) -> Duration {
Duration::from_millis(self.timeout_ms.load(Ordering::SeqCst))
}
fn next_id(&self) -> u64 {
self.command_id.fetch_add(1, Ordering::SeqCst)
}
pub async fn send_command(&self, method: &str, params: Value) -> Result<Value> {
let id = self.next_id();
let cmd = json!({
"id": id,
"method": method,
"params": params,
});
let text = serde_json::to_string(&cmd)
.map_err(|e| RoboticusError::Network(format!("serialize CDP command: {e}")))?;
trace!(id, method, "sending CDP command");
let mut ws = self.ws.lock().await;
ws.send(Message::Text(text))
.await
.map_err(|e| RoboticusError::Network(format!("CDP send failed: {e}")))?;
let timeout = self.timeout();
let deadline = tokio::time::Instant::now() + timeout;
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(RoboticusError::Network(format!(
"CDP command {method} (id={id}) timed out after {timeout:?}",
)));
}
let frame = tokio::time::timeout(remaining, ws.next()).await;
let msg = match frame {
Ok(Some(Ok(m))) => m,
Ok(Some(Err(e))) => {
return Err(RoboticusError::Network(format!("CDP read error: {e}")));
}
Ok(None) => {
return Err(RoboticusError::Network(
"CDP WebSocket closed unexpectedly".into(),
));
}
Err(_) => {
return Err(RoboticusError::Network(format!(
"CDP command {method} (id={id}) timed out after {timeout:?}",
)));
}
};
match msg {
Message::Text(ref t) => {
let val: Value = serde_json::from_str(t).map_err(|e| {
RoboticusError::Network(format!("CDP response parse error: {e}"))
})?;
if val.get("id").and_then(|v| v.as_u64()) == Some(id) {
if let Some(error) = val.get("error") {
let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
let message = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("unknown CDP error");
return Err(RoboticusError::Tool {
tool: "browser".into(),
message: format!("CDP error {code}: {message}"),
});
}
trace!(id, method, "CDP command response received");
return Ok(val.get("result").cloned().unwrap_or(json!({})));
}
if let Some(event_method) = val.get("method").and_then(|m| m.as_str()) {
trace!(event = event_method, "CDP event (skipped while waiting)");
}
}
Message::Ping(_) | Message::Pong(_) => {}
Message::Close(_) => {
return Err(RoboticusError::Network(
"CDP WebSocket closed by remote".into(),
));
}
_ => {}
}
}
}
pub async fn close(self) -> Result<()> {
let mut ws = self.ws.into_inner();
ws.close(None)
.await
.map_err(|e| RoboticusError::Network(format!("CDP WebSocket close failed: {e}")))?;
debug!("CDP WebSocket closed");
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn command_id_counter_increments() {
let counter = AtomicU64::new(1);
let id1 = counter.fetch_add(1, Ordering::SeqCst);
let id2 = counter.fetch_add(1, Ordering::SeqCst);
let id3 = counter.fetch_add(1, Ordering::SeqCst);
assert_eq!(id1, 1);
assert_eq!(id2, 2);
assert_eq!(id3, 3);
}
#[tokio::test]
async fn connect_to_nonexistent_fails() {
let result = CdpSession::connect("ws://127.0.0.1:19999/devtools/nonexistent").await;
assert!(result.is_err());
let err = match result {
Err(e) => e.to_string(),
Ok(_) => panic!("expected error"),
};
assert!(
err.contains("connect") || err.contains("Connection refused") || err.contains("failed"),
"error should mention connection failure: {err}"
);
}
#[test]
fn cdp_command_json_shape() {
let id: u64 = 42;
let cmd = json!({
"id": id,
"method": "Page.navigate",
"params": {"url": "https://example.com"},
});
assert_eq!(cmd["id"], 42);
assert_eq!(cmd["method"], "Page.navigate");
assert_eq!(cmd["params"]["url"], "https://example.com");
}
#[test]
fn response_matching_logic() {
let response = json!({"id": 5, "result": {"frameId": "abc123"}});
let target_id: u64 = 5;
assert_eq!(response.get("id").and_then(|v| v.as_u64()), Some(target_id));
let result = response.get("result").cloned().unwrap_or(json!({}));
assert_eq!(result["frameId"], "abc123");
}
#[test]
fn error_response_detection() {
let error_response = json!({
"id": 3,
"error": {
"code": -32000,
"message": "Cannot navigate to invalid URL"
}
});
let error = error_response.get("error");
assert!(error.is_some());
let code = error.unwrap().get("code").and_then(|c| c.as_i64()).unwrap();
assert_eq!(code, -32000);
}
#[test]
fn event_detection() {
let event = json!({"method": "Page.loadEventFired", "params": {"timestamp": 12345.6}});
let method = event.get("method").and_then(|m| m.as_str());
assert_eq!(method, Some("Page.loadEventFired"));
assert!(event.get("id").is_none());
}
use tokio::net::TcpListener;
async fn mock_ws_server<F>(handler: F) -> (String, tokio::task::JoinHandle<()>)
where
F: Fn(String) -> Option<String> + Send + 'static,
{
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let url = format!("ws://127.0.0.1:{port}");
let handle = tokio::spawn(async move {
if let Ok((stream, _addr)) = listener.accept().await {
let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
let (mut sink, mut source) = ws.split();
while let Some(Ok(msg)) = source.next().await {
if let Message::Text(ref t) = msg
&& let Some(reply) = handler(t.clone())
{
let _ = sink.send(Message::Text(reply)).await;
}
}
}
});
tokio::time::sleep(Duration::from_millis(50)).await;
(url, handle)
}
#[tokio::test]
async fn send_command_success() {
let (url, _server) = mock_ws_server(|text| {
let req: Value = serde_json::from_str(&text).ok()?;
let id = req.get("id")?.as_u64()?;
Some(serde_json::to_string(&json!({"id": id, "result": {"frameId": "F1"}})).unwrap())
})
.await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session
.send_command("Page.navigate", json!({"url": "https://example.com"}))
.await
.unwrap();
assert_eq!(result["frameId"], "F1");
}
#[tokio::test]
async fn send_command_cdp_error() {
let (url, _server) = mock_ws_server(|text| {
let req: Value = serde_json::from_str(&text).ok()?;
let id = req.get("id")?.as_u64()?;
Some(
serde_json::to_string(&json!({
"id": id,
"error": {"code": -32000, "message": "Cannot navigate"}
}))
.unwrap(),
)
})
.await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session
.send_command("Page.navigate", json!({"url": "invalid"}))
.await;
assert!(result.is_err());
let err_str = result.unwrap_err().to_string();
assert!(
err_str.contains("Cannot navigate"),
"expected CDP error message: {err_str}"
);
}
#[tokio::test]
async fn send_command_timeout() {
let (url, _server) = mock_ws_server(|_text| None).await;
let session = CdpSession::connect(&url).await.unwrap();
session.set_timeout(Duration::from_millis(200));
let result = session
.send_command("Page.navigate", json!({"url": "https://example.com"}))
.await;
assert!(result.is_err());
let err_str = result.unwrap_err().to_string();
assert!(
err_str.contains("timed out"),
"expected timeout error: {err_str}"
);
}
#[tokio::test]
async fn send_command_skips_events_before_response() {
let (url, _server) = mock_ws_server(|text| {
let req: Value = serde_json::from_str(&text).ok()?;
let id = req.get("id")?.as_u64()?;
Some(serde_json::to_string(&json!({"id": id, "result": {"ok": true}})).unwrap())
})
.await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session
.send_command("Runtime.evaluate", json!({"expression": "1+1"}))
.await
.unwrap();
assert_eq!(result["ok"], true);
}
#[tokio::test]
async fn send_command_events_before_matching_response() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let url = format!("ws://127.0.0.1:{port}");
let _server = tokio::spawn(async move {
if let Ok((stream, _addr)) = listener.accept().await {
let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
let (mut sink, mut source) = ws.split();
while let Some(Ok(msg)) = source.next().await {
if let Message::Text(ref t) = msg
&& let Ok(req) = serde_json::from_str::<Value>(t)
&& let Some(id) = req.get("id").and_then(|v| v.as_u64())
{
let event = serde_json::to_string(
&json!({"method": "Page.loadEventFired", "params": {}}),
)
.unwrap();
let _ = sink.send(Message::Text(event)).await;
tokio::time::sleep(Duration::from_millis(10)).await;
let resp =
serde_json::to_string(&json!({"id": id, "result": {"value": 42}}))
.unwrap();
let _ = sink.send(Message::Text(resp)).await;
}
}
}
});
tokio::time::sleep(Duration::from_millis(50)).await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session
.send_command("Runtime.evaluate", json!({"expression": "21*2"}))
.await
.unwrap();
assert_eq!(result["value"], 42);
}
#[tokio::test]
async fn send_command_ws_closed_unexpectedly() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let url = format!("ws://127.0.0.1:{port}");
let _server = tokio::spawn(async move {
if let Ok((stream, _addr)) = listener.accept().await {
let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
let (mut sink, _source) = ws.split();
let _ = sink.close().await;
}
});
tokio::time::sleep(Duration::from_millis(50)).await;
let session = CdpSession::connect(&url).await.unwrap();
session.set_timeout(Duration::from_millis(2000));
let result = session.send_command("Page.enable", json!({})).await;
assert!(result.is_err());
let err_str = result.unwrap_err().to_string();
assert!(
err_str.contains("closed") || err_str.contains("timed out"),
"expected close/timeout error: {err_str}"
);
}
#[tokio::test]
async fn set_timeout_affects_deadline() {
let (url, _server) = mock_ws_server(|_text| None).await;
let session = CdpSession::connect(&url).await.unwrap();
session.set_timeout(Duration::from_millis(100));
let start = tokio::time::Instant::now();
let result = session.send_command("Test", json!({})).await;
let elapsed = start.elapsed();
assert!(result.is_err());
assert!(
elapsed < Duration::from_millis(500),
"timeout took too long: {:?}",
elapsed
);
}
#[tokio::test]
async fn close_session() {
let (url, _server) = mock_ws_server(|_text| None).await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session.close().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn send_command_result_without_result_field() {
let (url, _server) = mock_ws_server(|text| {
let req: Value = serde_json::from_str(&text).ok()?;
let id = req.get("id")?.as_u64()?;
Some(serde_json::to_string(&json!({"id": id})).unwrap())
})
.await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session
.send_command("Page.enable", json!({}))
.await
.unwrap();
assert_eq!(result, json!({}));
}
#[tokio::test]
async fn send_command_error_missing_message() {
let (url, _server) = mock_ws_server(|text| {
let req: Value = serde_json::from_str(&text).ok()?;
let id = req.get("id")?.as_u64()?;
Some(serde_json::to_string(&json!({"id": id, "error": {"code": -1}})).unwrap())
})
.await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session.send_command("Bad.command", json!({})).await;
assert!(result.is_err());
let err_str = result.unwrap_err().to_string();
assert!(
err_str.contains("unknown CDP error") || err_str.contains("CDP error -1"),
"unexpected error: {err_str}"
);
}
#[tokio::test]
async fn send_command_mismatched_ids_eventually_matches() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let url = format!("ws://127.0.0.1:{port}");
let _server = tokio::spawn(async move {
if let Ok((stream, _addr)) = listener.accept().await {
let ws = tokio_tungstenite::accept_async(stream).await.unwrap();
let (mut sink, mut source) = ws.split();
while let Some(Ok(msg)) = source.next().await {
if let Message::Text(ref t) = msg
&& let Ok(req) = serde_json::from_str::<Value>(t)
&& let Some(id) = req.get("id").and_then(|v| v.as_u64())
{
let wrong = serde_json::to_string(
&json!({"id": id + 999, "result": {"wrong": true}}),
)
.unwrap();
let _ = sink.send(Message::Text(wrong)).await;
tokio::time::sleep(Duration::from_millis(10)).await;
let correct =
serde_json::to_string(&json!({"id": id, "result": {"correct": true}}))
.unwrap();
let _ = sink.send(Message::Text(correct)).await;
}
}
}
});
tokio::time::sleep(Duration::from_millis(50)).await;
let session = CdpSession::connect(&url).await.unwrap();
let result = session.send_command("Test", json!({})).await.unwrap();
assert_eq!(result["correct"], true);
}
}