1use serde_json::{json, Value};
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU32, Ordering};
4use std::sync::Arc;
5use tokio::sync::{broadcast, oneshot, RwLock};
6use tokio::time::{timeout, Duration};
7use futures_util::SinkExt;
8use tokio_tungstenite::tungstenite::Message;
9
10use crate::error::{BrowserError, Result};
11
12#[derive(Debug, Clone)]
14pub struct CDPRequest {
15 pub id: u32,
17 pub method: String,
19 pub params: Option<Value>,
21 pub session_id: Option<String>,
23}
24
25impl CDPRequest {
26 pub fn new(id: u32, method: String, params: Option<Value>) -> Self {
28 Self { id, method, params, session_id: None }
29 }
30
31 pub fn with_session(id: u32, method: String, params: Option<Value>, session_id: String) -> Self {
33 Self { id, method, params, session_id: Some(session_id) }
34 }
35
36 pub fn to_json(&self) -> Value {
38 let mut obj = json!({
39 "id": self.id,
40 "method": self.method,
41 });
42
43 if let Some(session_id) = &self.session_id {
44 obj["sessionId"] = json!(session_id);
45 }
46
47 if let Some(params) = &self.params {
48 obj["params"] = params.clone();
49 }
50
51 obj
52 }
53}
54
55#[derive(Debug, Clone)]
57pub struct CDPMessage {
58 pub id: Option<u32>,
60 pub method: Option<String>,
62 pub params: Option<Value>,
64 pub result: Option<Value>,
66 pub error: Option<Value>,
68 pub session_id: Option<String>,
71}
72
73impl CDPMessage {
74 pub fn from_json(value: Value) -> Result<Self> {
76 Ok(CDPMessage {
77 id: value.get("id").and_then(|v| v.as_u64()).map(|v| v as u32),
78 method: value.get("method").and_then(|v| v.as_str()).map(|s| s.to_string()),
79 params: value.get("params").cloned(),
80 result: value.get("result").cloned(),
81 error: value.get("error").cloned(),
82 session_id: value
84 .get("sessionId")
85 .and_then(|v| v.as_str())
86 .map(|s| s.to_string()),
87 })
88 }
89}
90
91pub type WebSocketSink = futures_util::stream::SplitSink<
93 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
94 Message,
95>;
96
97pub struct CDPClient {
99 ws_url: String,
100 message_id_counter: Arc<AtomicU32>,
101 pending_responses: Arc<RwLock<HashMap<u32, oneshot::Sender<Value>>>>,
102 event_broadcast: broadcast::Sender<CDPMessage>,
105 ws_sink: Arc<RwLock<Option<WebSocketSink>>>,
106}
107
108impl CDPClient {
109 pub fn new(ws_url: String) -> Self {
111 let (event_broadcast, _) = broadcast::channel(1024);
112 Self {
113 ws_url,
114 message_id_counter: Arc::new(AtomicU32::new(1)),
115 pending_responses: Arc::new(RwLock::new(HashMap::new())),
116 event_broadcast,
117 ws_sink: Arc::new(RwLock::new(None)),
118 }
119 }
120
121 pub async fn set_sink(&self, sink: WebSocketSink) {
123 let mut ws = self.ws_sink.write().await;
124 *ws = Some(sink);
125 }
126
127 pub fn next_id(&self) -> u32 {
129 self.message_id_counter.fetch_add(1, Ordering::SeqCst)
130 }
131
132 pub async fn connect(
134 &self,
135 ) -> Result<
136 tokio_tungstenite::WebSocketStream<
137 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
138 >,
139 > {
140 let (ws_stream, _) = tokio_tungstenite::connect_async(&self.ws_url)
141 .await
142 .map_err(|e| BrowserError::connection_failed(&self.ws_url, e.to_string()))?;
143
144 Ok(ws_stream)
145 }
146
147 pub async fn send_raw(&self, msg: String) -> Result<()> {
149 let mut ws = self.ws_sink.write().await;
150 if let Some(sink) = ws.as_mut() {
151 sink.send(Message::Text(msg))
152 .await
153 .map_err(|e| BrowserError::websocket("send_raw", e.to_string()))?;
154 } else {
155 return Err(BrowserError::websocket("send_raw", "WebSocket not connected"));
156 }
157 Ok(())
158 }
159
160 pub fn subscribe_events(&self) -> broadcast::Receiver<CDPMessage> {
169 self.event_broadcast.subscribe()
170 }
171
172 pub async fn send_command(&self, method: String, params: Option<Value>) -> Result<Value> {
177 let id = self.next_id();
178 let request = CDPRequest::new(id, method.clone(), params);
179
180 let (tx, rx) = oneshot::channel();
182 self.register_response_handler(id, tx).await;
183 let json_str = request.to_json().to_string();
184 self.send_raw(json_str).await?;
185 const TIMEOUT_SECS: u64 = 30;
188 match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
189 Ok(Ok(value)) => Ok(value),
190 Ok(Err(_)) => Err(BrowserError::command_failed(
191 &method,
192 "response channel closed unexpectedly",
193 )),
194 Err(_) => {
195 let mut pending = self.pending_responses.write().await;
196 pending.remove(&id);
197 Err(BrowserError::timeout(
198 format!("waiting for response to '{method}'"),
199 TIMEOUT_SECS,
200 ))
201 }
202 }
203 }
204
205 pub async fn send_command_with_session(
209 &self,
210 session_id: &str,
211 method: String,
212 params: Option<Value>,
213 ) -> Result<Value> {
214 let id = self.next_id();
215 let request =
216 CDPRequest::with_session(id, method.clone(), params, session_id.to_string());
217
218 let (tx, rx) = oneshot::channel();
220 self.register_response_handler(id, tx).await;
221 let json_str = request.to_json().to_string();
222 self.send_raw(json_str).await?;
223 const TIMEOUT_SECS: u64 = 30;
226 match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
227 Ok(Ok(value)) => Ok(value),
228 Ok(Err(_)) => Err(BrowserError::command_failed(
229 &method,
230 "response channel closed unexpectedly",
231 )),
232 Err(_) => {
233 let mut pending = self.pending_responses.write().await;
234 pending.remove(&id);
235 Err(BrowserError::timeout(
236 format!("waiting for response to '{method}'"),
237 TIMEOUT_SECS,
238 ))
239 }
240 }
241 }
242
243 pub async fn register_response_handler(&self, id: u32, tx: oneshot::Sender<Value>) {
245 let mut pending = self.pending_responses.write().await;
246 pending.insert(id, tx);
247 }
248
249 pub async fn handle_message(&self, msg: CDPMessage) -> Result<()> {
251 if let Some(id) = msg.id {
252 let mut pending = self.pending_responses.write().await;
254 if let Some(tx) = pending.remove(&id) {
255 if let Some(error) = msg.error {
256 let _ = tx.send(json!({ "error": error }));
257 } else if let Some(result) = msg.result {
258 let _ = tx.send(result);
259 } else {
260 let _ = tx.send(json!({}));
261 }
262 }
263 } else if msg.method.is_some() {
264 let _ = self.event_broadcast.send(msg);
267 }
268
269 Ok(())
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// The plain constructor stores its arguments unchanged.
    #[test]
    fn test_cdp_request_creation() {
        let params = json!({ "url": "https://example.com" });
        let request = CDPRequest::new(1, String::from("Page.navigate"), Some(params));

        assert_eq!(request.id, 1);
        assert_eq!(request.method, "Page.navigate");
        let stored = request.params.as_ref().unwrap();
        assert_eq!(stored["url"], "https://example.com");
    }

    /// Serialization exposes `id`, `method` and `params`.
    #[test]
    fn test_cdp_request_to_json() {
        let params = json!({ "url": "https://example.com" });
        let request = CDPRequest::new(1, String::from("Page.navigate"), Some(params));

        let encoded = request.to_json();
        assert_eq!(encoded["id"], 1);
        assert_eq!(encoded["method"], "Page.navigate");
        assert_eq!(encoded["params"]["url"], "https://example.com");
    }

    /// A reply carrying `id`, `result` and `sessionId` is parsed field by field.
    #[test]
    fn test_cdp_message_from_json() {
        let reply = json!({
            "id": 1,
            "result": { "url": "https://example.com" },
            "sessionId": "SES001"
        });

        let parsed = CDPMessage::from_json(reply).unwrap();
        assert_eq!(parsed.id, Some(1));
        let result = parsed.result.as_ref().unwrap();
        assert_eq!(result["url"], "https://example.com");
        assert_eq!(parsed.session_id.as_deref(), Some("SES001"));
    }

    /// Events keep both their method name and their session id.
    #[test]
    fn test_cdp_message_session_id_parsed() {
        let event = json!({
            "method": "Page.loadEventFired",
            "params": {},
            "sessionId": "ABC123"
        });

        let parsed = CDPMessage::from_json(event).unwrap();
        assert_eq!(parsed.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(parsed.session_id.as_deref(), Some("ABC123"));
    }

    /// A session-bound request serializes its `sessionId`.
    #[test]
    fn test_cdp_request_with_session() {
        let params = json!({ "expression": "1+1" });
        let request = CDPRequest::with_session(
            2,
            String::from("Runtime.evaluate"),
            Some(params),
            String::from("SES001"),
        );

        let encoded = request.to_json();
        assert_eq!(encoded["sessionId"], "SES001");
        assert_eq!(encoded["method"], "Runtime.evaluate");
    }
}