1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use serde_json::Value;
10use tokio::net::TcpStream;
11use tokio::sync::{oneshot, Mutex as AsyncMutex};
12use tokio::task::JoinHandle;
13use tokio_tungstenite::tungstenite::client::IntoClientRequest;
14use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
15use tokio_tungstenite::tungstenite::Message;
16use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
17
18use crate::CdpError;
19
/// Default time a single command waits for its response (used by `CdpClientConfig::default`).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);

/// Default maximum WebSocket message size: 100 MiB (100 * 1024 * 1024 bytes).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 100 << 20;
/// Connection options accepted by `CdpClient::connect_with_config`.
#[derive(Debug, Clone)]
pub struct CdpClientConfig {
    /// Maximum WebSocket message size, copied verbatim into tungstenite's
    /// `WebSocketConfig::max_message_size`.
    pub max_message_size: Option<usize>,

    /// Maximum WebSocket frame size, copied verbatim into tungstenite's
    /// `WebSocketConfig::max_frame_size`.
    pub max_frame_size: Option<usize>,

    /// Extra HTTP headers (name -> value) inserted into the WebSocket handshake request.
    /// Invalid names or values make the connect call fail.
    pub additional_headers: HashMap<String, String>,

    /// How long a command waits for its response before failing with a timeout error.
    pub command_timeout: Duration,
}
40
41impl Default for CdpClientConfig {
42 fn default() -> Self {
43 Self {
44 max_message_size: Some(DEFAULT_MAX_MESSAGE_SIZE),
45 max_frame_size: None, additional_headers: HashMap::new(),
47 command_timeout: DEFAULT_COMMAND_TIMEOUT,
48 }
49 }
50}
51
52type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
53type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
54type WsSource = futures_util::stream::SplitStream<WsStream>;
55type PendingRequests = HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>;
56
57pub type EventHandler = Arc<
59 dyn Fn(Value, Option<String>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
60>;
61
62pub struct EventRegistry {
66 handlers: std::sync::Mutex<HashMap<String, EventHandler>>,
67}
68
69impl EventRegistry {
70 pub fn new() -> Self {
71 Self {
72 handlers: std::sync::Mutex::new(HashMap::new()),
73 }
74 }
75
76 pub fn register(&self, method: &str, handler: EventHandler) {
78 self.handlers
79 .lock()
80 .unwrap()
81 .insert(method.to_string(), handler);
82 }
83
84 pub fn unregister(&self, method: &str) {
86 self.handlers.lock().unwrap().remove(method);
87 }
88
89 pub async fn handle_event(
94 &self,
95 method: &str,
96 params: Value,
97 session_id: Option<String>,
98 ) -> bool {
99 let handler = {
100 let handlers = self.handlers.lock().unwrap();
101 handlers.get(method).cloned()
102 };
103
104 if let Some(handler) = handler {
105 handler(params, session_id).await;
106 true
107 } else {
108 false
109 }
110 }
111
112 pub fn clear(&self) {
114 self.handlers.lock().unwrap().clear();
115 }
116}
117
118impl Default for EventRegistry {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124#[derive(Clone)]
138pub struct CdpClient {
139 inner: Arc<ClientInner>,
140}
141
142struct ClientInner {
143 sink: AsyncMutex<WsSink>,
144 next_id: AtomicU64,
145 pending: Arc<AsyncMutex<PendingRequests>>,
146 event_registry: Arc<EventRegistry>,
147 closed: AtomicBool,
148 command_timeout: Duration,
149 message_loop_handle: std::sync::Mutex<Option<JoinHandle<()>>>,
150}
151
152impl Drop for ClientInner {
153 fn drop(&mut self) {
154 if let Some(handle) = self.message_loop_handle.get_mut().unwrap().take() {
155 handle.abort();
156 }
157 }
158}
159
160impl CdpClient {
161 pub async fn connect(url: &str) -> Result<Self, CdpError> {
163 Self::connect_with_config(url, CdpClientConfig::default()).await
164 }
165
166 pub async fn connect_with_config(
182 url: &str,
183 config: CdpClientConfig,
184 ) -> Result<Self, CdpError> {
185 let mut request = url.into_client_request()?;
186
187 for (key, value) in &config.additional_headers {
189 request.headers_mut().insert(
190 key.parse::<tokio_tungstenite::tungstenite::http::HeaderName>()
191 .map_err(|e| CdpError::Protocol {
192 code: -1,
193 message: format!("Invalid header name '{key}': {e}"),
194 data: None,
195 })?,
196 value
197 .parse()
198 .map_err(|e| CdpError::Protocol {
199 code: -1,
200 message: format!("Invalid header value for '{key}': {e}"),
201 data: None,
202 })?,
203 );
204 }
205
206 let mut ws_config = WebSocketConfig::default();
207 ws_config.max_message_size = config.max_message_size;
208 ws_config.max_frame_size = config.max_frame_size;
209
210 let (ws_stream, _) =
211 connect_async_with_config(request, Some(ws_config), false).await?;
212 let (sink, stream) = ws_stream.split();
213
214 let pending = Arc::new(AsyncMutex::new(HashMap::new()));
215 let event_registry = Arc::new(EventRegistry::new());
216 let closed = Arc::new(AtomicBool::new(false));
217
218 let handle = tokio::spawn({
219 let pending = pending.clone();
220 let registry = event_registry.clone();
221 let closed = closed.clone();
222 async move {
223 message_loop(stream, pending, registry, closed).await;
224 }
225 });
226
227 Ok(Self {
228 inner: Arc::new(ClientInner {
229 sink: AsyncMutex::new(sink),
230 next_id: AtomicU64::new(0),
231 pending,
232 event_registry,
233 closed: AtomicBool::new(false),
234 command_timeout: config.command_timeout,
235 message_loop_handle: std::sync::Mutex::new(Some(handle)),
236 }),
237 })
238 }
239
240 pub async fn send_raw(
245 &self,
246 method: &str,
247 params: Value,
248 session_id: Option<&str>,
249 ) -> Result<Value, CdpError> {
250 if self.inner.closed.load(Ordering::Acquire) {
251 return Err(CdpError::ConnectionClosed);
252 }
253
254 let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed) + 1;
255
256 let (tx, rx) = oneshot::channel();
257 self.inner.pending.lock().await.insert(id, tx);
258
259 let mut msg = serde_json::json!({
260 "id": id,
261 "method": method,
262 "params": params,
263 });
264 if let Some(sid) = session_id {
265 msg["sessionId"] = Value::String(sid.to_string());
266 }
267
268 let send_result = self
269 .inner
270 .sink
271 .lock()
272 .await
273 .send(Message::Text(msg.to_string().into()))
274 .await;
275
276 if let Err(e) = send_result {
277 self.inner.pending.lock().await.remove(&id);
279 return Err(e.into());
280 }
281
282 match tokio::time::timeout(self.inner.command_timeout, rx).await {
284 Ok(Ok(result)) => result,
285 Ok(Err(_)) => {
286 Err(CdpError::ConnectionClosed)
288 }
289 Err(_elapsed) => {
290 self.inner.pending.lock().await.remove(&id);
292 Err(CdpError::Timeout)
293 }
294 }
295 }
296
297 pub async fn emit_event(
302 &self,
303 method: &str,
304 params: Value,
305 session_id: Option<&str>,
306 ) -> bool {
307 self.inner
308 .event_registry
309 .handle_event(method, params, session_id.map(String::from))
310 .await
311 }
312
313 pub(crate) fn event_registry(&self) -> &Arc<EventRegistry> {
316 &self.inner.event_registry
317 }
318
319 pub async fn close(&self) -> Result<(), CdpError> {
321 self.inner.closed.store(true, Ordering::Release);
323
324 {
327 let mut pending = self.inner.pending.lock().await;
328 for (_, tx) in pending.drain() {
329 let _ = tx.send(Err(CdpError::ConnectionClosed));
330 }
331 }
332
333 if let Some(handle) = self.inner.message_loop_handle.lock().unwrap().take() {
335 handle.abort();
336 let _ = handle.await;
337 }
338
339 self.inner.sink.lock().await.close().await?;
341 Ok(())
342 }
343}
344
345async fn message_loop(
346 mut stream: WsSource,
347 pending: Arc<AsyncMutex<PendingRequests>>,
348 event_registry: Arc<EventRegistry>,
349 closed: Arc<AtomicBool>,
350) {
351 while let Some(msg_result) = stream.next().await {
352 match msg_result {
353 Ok(Message::Text(text)) => {
354 let data: Value = match serde_json::from_str(&text) {
355 Ok(v) => v,
356 Err(e) => {
357 tracing::warn!("Failed to parse CDP message: {e}");
358 continue;
359 }
360 };
361
362 if let Some(id) = data.get("id").and_then(|v| v.as_u64()) {
363 let mut pending = pending.lock().await;
365 if let Some(tx) = pending.remove(&id) {
366 let result = if let Some(error) = data.get("error") {
367 let code = error.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
368 let message = error
369 .get("message")
370 .and_then(|v| v.as_str())
371 .unwrap_or("Unknown error")
372 .to_string();
373 let err_data = error.get("data").map(|v| v.to_string());
374 Err(CdpError::Protocol {
375 code,
376 message,
377 data: err_data,
378 })
379 } else {
380 Ok(data
381 .get("result")
382 .cloned()
383 .unwrap_or(Value::Object(Default::default())))
384 };
385 let _ = tx.send(result);
386 }
387 } else if let Some(method) = data.get("method").and_then(|v| v.as_str()) {
388 let params = data.get("params").cloned().unwrap_or_default();
390 let session_id = data
391 .get("sessionId")
392 .and_then(|v| v.as_str())
393 .map(String::from);
394 let registry = event_registry.clone();
395 let method = method.to_string();
396 tokio::spawn(async move {
397 registry.handle_event(&method, params, session_id).await;
398 });
399 }
400 }
401 Ok(Message::Close(_)) | Err(_) => {
402 closed.store(true, Ordering::Release);
404 let mut pending = pending.lock().await;
405 for (_, tx) in pending.drain() {
406 let _ = tx.send(Err(CdpError::ConnectionClosed));
407 }
408 break;
409 }
410 _ => {} }
412 }
413}