//! chrome_cli/cdp/transport.rs
//!
//! WebSocket transport layer for the Chrome DevTools Protocol (CDP).

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4
5use futures_util::{SinkExt, StreamExt};
6use tokio::net::TcpStream;
7use tokio::sync::{mpsc, oneshot};
8use tokio::time::{Duration, Instant};
9use tokio_tungstenite::tungstenite::Message;
10use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
11
12use super::error::CdpError;
13use super::types::{CdpCommand, CdpEvent, MessageKind, RawCdpMessage};
14
/// Concrete WebSocket stream type for the CDP connection (plain TCP or TLS).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Key for the subscriber map: (`method_name`, `session_id`).
type SubscriberKey = (String, Option<String>);
19
/// Command sent from the client handle to the transport task.
pub enum TransportCommand {
    /// Send a CDP command and deliver the response via the oneshot channel.
    SendCommand {
        /// The command to serialize and write to the WebSocket.
        command: CdpCommand,
        /// Receives the command's result (or an error) exactly once.
        response_tx: oneshot::Sender<Result<serde_json::Value, CdpError>>,
        /// Absolute time after which the request fails with a timeout.
        deadline: Instant,
    },
    /// Subscribe to events matching a method name (and optional session).
    Subscribe {
        /// CDP event method name to match exactly.
        method: String,
        /// If set, only events carrying this session ID are delivered.
        session_id: Option<String>,
        /// Channel on which matching events are delivered (best-effort).
        event_tx: mpsc::Sender<CdpEvent>,
    },
    /// Shut down the transport gracefully.
    Shutdown,
}
37
/// Tracks an in-flight command awaiting its response.
struct PendingRequest {
    /// Completed once a response with the matching message ID arrives.
    response_tx: oneshot::Sender<Result<serde_json::Value, CdpError>>,
    /// Method name, kept so timeout errors can report what timed out.
    method: String,
    /// Absolute time after which this request fails with `CommandTimeout`.
    deadline: Instant,
}
44
/// Reconnection configuration.
///
/// Controls the exponential-backoff retry loop that runs after the
/// WebSocket connection drops unexpectedly.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of reconnection attempts (default: 5).
    pub max_retries: u32,
    /// Initial backoff delay, doubled after each failed attempt (default: 100ms).
    pub initial_backoff: Duration,
    /// Maximum backoff delay; the doubling is capped here (default: 5s).
    pub max_backoff: Duration,
}
55
56impl Default for ReconnectConfig {
57    fn default() -> Self {
58        Self {
59            max_retries: 5,
60            initial_backoff: Duration::from_millis(100),
61            max_backoff: Duration::from_secs(5),
62        }
63    }
64}
65
/// Clonable handle for communicating with the transport task.
///
/// All clones share the same command channel, connection flag, and
/// message-ID counter.
#[derive(Debug, Clone)]
pub struct TransportHandle {
    /// Sends commands to the background transport task.
    command_tx: mpsc::Sender<TransportCommand>,
    /// Mirrors the task's current connection state.
    connected: Arc<AtomicBool>,
    /// Monotonic counter used to mint unique CDP message IDs.
    next_id: Arc<AtomicU64>,
}
73
74impl TransportHandle {
75    /// Send a transport command to the background task.
76    ///
77    /// # Errors
78    ///
79    /// Returns `CdpError::Internal` if the transport task has exited.
80    pub async fn send(&self, cmd: TransportCommand) -> Result<(), CdpError> {
81        self.command_tx
82            .send(cmd)
83            .await
84            .map_err(|_| CdpError::Internal("transport task is not running".into()))
85    }
86
87    /// Check whether the transport is currently connected.
88    #[must_use]
89    pub fn is_connected(&self) -> bool {
90        self.connected.load(Ordering::Relaxed)
91    }
92
93    /// Generate the next unique message ID for this connection.
94    pub fn next_message_id(&self) -> u64 {
95        self.next_id.fetch_add(1, Ordering::Relaxed)
96    }
97}
98
99/// Spawn the transport background task.
100///
101/// Returns a `TransportHandle` for sending commands to the task.
102///
103/// # Errors
104///
105/// Returns `CdpError::Connection` or `CdpError::ConnectionTimeout` if the
106/// initial WebSocket connection cannot be established.
107pub async fn spawn_transport(
108    url: &str,
109    channel_capacity: usize,
110    reconnect_config: ReconnectConfig,
111    connect_timeout: Duration,
112) -> Result<TransportHandle, CdpError> {
113    let ws_stream = connect_ws(url, connect_timeout).await?;
114    let connected = Arc::new(AtomicBool::new(true));
115    let next_id = Arc::new(AtomicU64::new(1));
116    let (command_tx, command_rx) = mpsc::channel(channel_capacity);
117
118    let handle = TransportHandle {
119        command_tx,
120        connected: Arc::clone(&connected),
121        next_id,
122    };
123
124    let url_owned = url.to_owned();
125    tokio::spawn(async move {
126        let mut task = TransportTask {
127            ws_stream,
128            command_rx,
129            pending: HashMap::new(),
130            subscribers: HashMap::new(),
131            connected,
132            url: url_owned,
133            reconnect_config,
134            connect_timeout,
135            reconnect_failure: None,
136        };
137        task.run().await;
138    });
139
140    Ok(handle)
141}
142
143/// Establish a WebSocket connection with a timeout.
144async fn connect_ws(url: &str, timeout: Duration) -> Result<WsStream, CdpError> {
145    match tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url)).await {
146        Ok(Ok((stream, _response))) => Ok(stream),
147        Ok(Err(e)) => Err(CdpError::Connection(e.to_string())),
148        Err(_) => Err(CdpError::ConnectionTimeout),
149    }
150}
151
/// The background transport task that owns the WebSocket connection.
struct TransportTask {
    /// The WebSocket connection to the CDP endpoint.
    ws_stream: WsStream,
    /// Receives commands from `TransportHandle` clones.
    command_rx: mpsc::Receiver<TransportCommand>,
    /// In-flight commands keyed by CDP message ID.
    pending: HashMap<u64, PendingRequest>,
    /// Event subscribers keyed by (method, optional session).
    subscribers: HashMap<SubscriberKey, Vec<mpsc::Sender<CdpEvent>>>,
    /// Shared flag read by `TransportHandle::is_connected`.
    connected: Arc<AtomicBool>,
    /// WebSocket URL, kept for reconnection attempts.
    url: String,
    /// Backoff policy applied after an unexpected disconnect.
    reconnect_config: ReconnectConfig,
    /// Timeout for each individual (re)connection attempt.
    connect_timeout: Duration,
    /// Set to `(attempts, last_error)` once reconnection has permanently
    /// failed; the run loop then drains commands with `ReconnectFailed`.
    reconnect_failure: Option<(u32, String)>,
}
164
impl TransportTask {
    /// Main event loop: multiplexes WebSocket reads, handle commands, and
    /// request-timeout sweeps until shutdown or handle-channel closure.
    async fn run(&mut self) {
        loop {
            // If reconnection has permanently failed, drain remaining
            // commands with ReconnectFailed errors until shutdown.
            if let Some((attempts, ref last_error)) = self.reconnect_failure {
                match self.command_rx.recv().await {
                    Some(TransportCommand::SendCommand { response_tx, .. }) => {
                        let _ = response_tx.send(Err(CdpError::ReconnectFailed {
                            attempts,
                            last_error: last_error.clone(),
                        }));
                        continue;
                    }
                    // No connection will ever produce events again, so late
                    // subscriptions are silently discarded.
                    Some(TransportCommand::Subscribe { .. }) => continue,
                    Some(TransportCommand::Shutdown) | None => return,
                }
            }

            // Recomputed every iteration: both new commands and completed
            // responses change the earliest pending deadline.
            let next_deadline = self.earliest_deadline();
            let timeout_sleep = async {
                if let Some(deadline) = next_deadline {
                    tokio::time::sleep_until(deadline).await;
                } else {
                    // No pending requests — sleep forever (will be cancelled by select)
                    std::future::pending::<()>().await;
                }
            };

            tokio::select! {
                // Branch 1: WebSocket read
                ws_msg = self.ws_stream.next() => {
                    match ws_msg {
                        Some(Ok(Message::Text(text))) => {
                            self.handle_text_message(&text);
                        }
                        Some(Ok(Message::Close(_)) | Err(_)) | None => {
                            self.handle_disconnect().await;
                            // If reconnected, continue normally.
                            // If reconnect failed, reconnect_failure is set and
                            // the top-of-loop check will drain commands.
                        }
                        Some(Ok(_)) => {
                            // Binary, Ping, Pong, Frame — ignore
                        }
                    }
                }

                // Branch 2: Command channel
                cmd = self.command_rx.recv() => {
                    match cmd {
                        Some(TransportCommand::SendCommand { command, response_tx, deadline }) => {
                            self.handle_send_command(command, response_tx, deadline).await;
                        }
                        Some(TransportCommand::Subscribe { method, session_id, event_tx }) => {
                            self.subscribers
                                .entry((method, session_id))
                                .or_default()
                                .push(event_tx);
                        }
                        // Explicit shutdown, or every handle was dropped.
                        Some(TransportCommand::Shutdown) | None => {
                            self.drain_pending();
                            let _ = self.ws_stream.close(None).await;
                            self.connected.store(false, Ordering::Relaxed);
                            return;
                        }
                    }
                }

                // Branch 3: Timeout sweep
                () = timeout_sleep => {
                    self.sweep_timeouts();
                }
            }
        }
    }

    /// Parse one text frame and route it as a response or an event.
    /// Malformed or unclassifiable frames are ignored.
    fn handle_text_message(&mut self, text: &str) {
        let raw: RawCdpMessage = match serde_json::from_str(text) {
            Ok(msg) => msg,
            Err(_) => {
                // Malformed JSON — ignore and continue
                return;
            }
        };

        let Some(kind) = raw.classify() else {
            // Unclassifiable message — ignore
            return;
        };

        match kind {
            MessageKind::Response(response) => {
                // Responses for unknown IDs (e.g. already timed out) are dropped.
                if let Some(pending) = self.pending.remove(&response.id) {
                    let result = match response.result {
                        Ok(value) => Ok(value),
                        Err(proto_err) => Err(CdpError::Protocol {
                            code: proto_err.code,
                            message: proto_err.message,
                        }),
                    };
                    // The caller may have given up (receiver dropped) — ignore.
                    let _ = pending.response_tx.send(result);
                }
            }
            MessageKind::Event(event) => {
                self.dispatch_event(&event);
            }
        }
    }

    /// Deliver an event to every subscriber of (method, session_id).
    ///
    /// Delivery is best-effort via `try_send`: a subscriber with a full
    /// queue misses this event but stays subscribed. Subscribers whose
    /// receiver has been dropped are pruned, and the map entry is removed
    /// once no subscribers remain.
    fn dispatch_event(&mut self, event: &CdpEvent) {
        let key = (event.method.clone(), event.session_id.clone());
        if let Some(senders) = self.subscribers.get_mut(&key) {
            // Remove senders whose receiver has been dropped
            senders.retain(|tx| tx.try_send(event.clone()).is_ok() || !tx.is_closed());
            if senders.is_empty() {
                self.subscribers.remove(&key);
            }
        }
    }

    /// Serialize and write one command frame, then register the request as
    /// pending. Serialization and write failures are reported immediately
    /// through `response_tx` without registering anything.
    async fn handle_send_command(
        &mut self,
        command: CdpCommand,
        response_tx: oneshot::Sender<Result<serde_json::Value, CdpError>>,
        deadline: Instant,
    ) {
        let id = command.id;
        let method = command.method.clone();

        let json = match serde_json::to_string(&command) {
            Ok(j) => j,
            Err(e) => {
                let _ =
                    response_tx.send(Err(CdpError::Internal(format!("serialization error: {e}"))));
                return;
            }
        };

        if let Err(e) = self.ws_stream.send(Message::Text(json.into())).await {
            let _ = response_tx.send(Err(CdpError::Connection(format!(
                "WebSocket write error: {e}"
            ))));
            return;
        }

        // Registered only after a successful write; the response is matched
        // back by ID in handle_text_message.
        self.pending.insert(
            id,
            PendingRequest {
                response_tx,
                method,
                deadline,
            },
        );
    }

    /// The soonest deadline among pending requests, if any — used to arm
    /// the timeout-sweep branch of the select loop.
    fn earliest_deadline(&self) -> Option<Instant> {
        self.pending.values().map(|p| p.deadline).min()
    }

    /// Fail every pending request whose deadline has passed with
    /// `CommandTimeout`, removing it from the pending map.
    fn sweep_timeouts(&mut self) {
        let now = Instant::now();
        // Collect IDs first: entries cannot be removed while iterating the map.
        let timed_out: Vec<u64> = self
            .pending
            .iter()
            .filter(|(_, p)| p.deadline <= now)
            .map(|(&id, _)| id)
            .collect();

        for id in timed_out {
            if let Some(pending) = self.pending.remove(&id) {
                let _ = pending.response_tx.send(Err(CdpError::CommandTimeout {
                    method: pending.method,
                }));
            }
        }
    }

    /// Fail all in-flight requests with `ConnectionClosed`, leaving the
    /// pending map empty.
    fn drain_pending(&mut self) {
        let pending = std::mem::take(&mut self.pending);
        for (_, req) in pending {
            let _ = req.response_tx.send(Err(CdpError::ConnectionClosed));
        }
    }

    /// Handle an unexpected disconnect: fail in-flight requests, then try
    /// to reconnect with exponential backoff per `reconnect_config`.
    ///
    /// On success the new stream replaces the old one and the connected
    /// flag is restored. On exhaustion, `reconnect_failure` is recorded so
    /// the run loop switches to drain mode. Commands arriving during the
    /// backoff sleeps simply queue up in `command_rx`.
    async fn handle_disconnect(&mut self) {
        self.connected.store(false, Ordering::Relaxed);
        self.drain_pending();

        let mut backoff = self.reconnect_config.initial_backoff;
        let mut last_error_msg = String::from("no retries configured");

        for attempt in 1..=self.reconnect_config.max_retries {
            tokio::time::sleep(backoff).await;

            match connect_ws(&self.url, self.connect_timeout).await {
                Ok(new_stream) => {
                    self.ws_stream = new_stream;
                    self.connected.store(true, Ordering::Relaxed);
                    return;
                }
                Err(e) => {
                    last_error_msg = e.to_string();
                    // Grow the delay (capped) only if another attempt follows.
                    if attempt < self.reconnect_config.max_retries {
                        backoff = (backoff * 2).min(self.reconnect_config.max_backoff);
                    }
                }
            }
        }

        // All retries exhausted — store failure and let the run loop
        // drain remaining commands with ReconnectFailed errors.
        self.reconnect_failure = Some((self.reconnect_config.max_retries, last_error_msg));
    }
}
379}