1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4
5use futures_util::{SinkExt, StreamExt};
6use tokio::net::TcpStream;
7use tokio::sync::{mpsc, oneshot};
8use tokio::time::{Duration, Instant};
9use tokio_tungstenite::tungstenite::Message;
10use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
11
12use super::error::CdpError;
13use super::types::{CdpCommand, CdpEvent, MessageKind, RawCdpMessage};
14
/// WebSocket stream produced by `tokio_tungstenite::connect_async`, over plain TCP or TLS.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Routing key for event subscriptions: `(event method name, optional session id)`.
///
/// Incoming events are delivered only to subscribers whose key equals the event's
/// method and session id exactly; there is no wildcard matching.
type SubscriberKey = (String, Option<String>);
19
/// Request sent from a [`TransportHandle`] to the background transport task.
pub enum TransportCommand {
    /// Send a CDP command over the WebSocket.
    ///
    /// The outcome (the response value, a protocol error, a serialization or write error,
    /// a timeout, or a connection error) is delivered on `response_tx`.
    SendCommand {
        /// The command to serialize and send; its `id` is used to match the response.
        command: CdpCommand,
        /// One-shot channel on which the task reports the command's outcome.
        response_tx: oneshot::Sender<Result<serde_json::Value, CdpError>>,
        /// If no response has arrived by this instant, the request is failed with
        /// `CdpError::CommandTimeout`.
        deadline: Instant,
    },
    /// Register `event_tx` to receive events whose method and session id match exactly.
    Subscribe {
        /// CDP event method name to match.
        method: String,
        /// Session to match; `None` matches only events that carry no session id.
        session_id: Option<String>,
        /// Channel matching events are forwarded to. Forwarding uses `try_send`, so events
        /// are dropped while this channel is full.
        event_tx: mpsc::Sender<CdpEvent>,
    },
    /// Fail all in-flight requests, close the WebSocket and stop the transport task.
    Shutdown,
}
37
/// An in-flight command awaiting its response, stored in the task's `pending` map under the
/// command id.
struct PendingRequest {
    /// Where the response, or the failure that ends the request, is delivered.
    response_tx: oneshot::Sender<Result<serde_json::Value, CdpError>>,
    /// CDP method name, kept so a timeout error can name the command.
    method: String,
    /// Instant after which the request is failed with `CdpError::CommandTimeout`.
    deadline: Instant,
}
44
/// Retry policy applied by the transport task after the WebSocket connection is lost.
///
/// The task sleeps before every attempt. The delay starts at `initial_backoff` and doubles
/// after each failed attempt; doubling stops at `max_backoff`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReconnectConfig {
    /// Maximum number of reconnection attempts before giving up; `0` disables reconnection.
    pub max_retries: u32,
    /// Delay before the first reconnection attempt.
    pub initial_backoff: Duration,
    /// Upper bound for the doubled delay between attempts.
    pub max_backoff: Duration,
}

impl Default for ReconnectConfig {
    /// Five attempts, starting at 100 ms and backing off to at most 5 s.
    fn default() -> Self {
        Self {
            max_retries: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
        }
    }
}
65
/// Cloneable handle for talking to the background transport task started by
/// [`spawn_transport`].
#[derive(Debug, Clone)]
pub struct TransportHandle {
    /// Sending side of the transport task's command queue.
    command_tx: mpsc::Sender<TransportCommand>,
    /// Connection flag written by the transport task and shared by every clone.
    connected: Arc<AtomicBool>,
    /// Message-id counter shared by every clone (initialised to 1 by `spawn_transport`).
    next_id: Arc<AtomicU64>,
}
73
74impl TransportHandle {
75 pub async fn send(&self, cmd: TransportCommand) -> Result<(), CdpError> {
81 self.command_tx
82 .send(cmd)
83 .await
84 .map_err(|_| CdpError::Internal("transport task is not running".into()))
85 }
86
87 #[must_use]
89 pub fn is_connected(&self) -> bool {
90 self.connected.load(Ordering::Relaxed)
91 }
92
93 pub fn next_message_id(&self) -> u64 {
95 self.next_id.fetch_add(1, Ordering::Relaxed)
96 }
97}
98
99pub async fn spawn_transport(
108 url: &str,
109 channel_capacity: usize,
110 reconnect_config: ReconnectConfig,
111 connect_timeout: Duration,
112) -> Result<TransportHandle, CdpError> {
113 let ws_stream = connect_ws(url, connect_timeout).await?;
114 let connected = Arc::new(AtomicBool::new(true));
115 let next_id = Arc::new(AtomicU64::new(1));
116 let (command_tx, command_rx) = mpsc::channel(channel_capacity);
117
118 let handle = TransportHandle {
119 command_tx,
120 connected: Arc::clone(&connected),
121 next_id,
122 };
123
124 let url_owned = url.to_owned();
125 tokio::spawn(async move {
126 let mut task = TransportTask {
127 ws_stream,
128 command_rx,
129 pending: HashMap::new(),
130 subscribers: HashMap::new(),
131 connected,
132 url: url_owned,
133 reconnect_config,
134 connect_timeout,
135 reconnect_failure: None,
136 };
137 task.run().await;
138 });
139
140 Ok(handle)
141}
142
143async fn connect_ws(url: &str, timeout: Duration) -> Result<WsStream, CdpError> {
145 match tokio::time::timeout(timeout, tokio_tungstenite::connect_async(url)).await {
146 Ok(Ok((stream, _response))) => Ok(stream),
147 Ok(Err(e)) => Err(CdpError::Connection(e.to_string())),
148 Err(_) => Err(CdpError::ConnectionTimeout),
149 }
150}
151
/// State owned by the background task that drives the WebSocket connection.
///
/// Created and run by `spawn_transport`; all fields are touched only from that task.
struct TransportTask {
    /// Current WebSocket connection; replaced after a successful reconnect.
    ws_stream: WsStream,
    /// Receiving side of the command queue fed by `TransportHandle::send`.
    command_rx: mpsc::Receiver<TransportCommand>,
    /// Sent commands awaiting a response, keyed by CDP message id.
    pending: HashMap<u64, PendingRequest>,
    /// Event subscribers, keyed by exact `(method, session_id)`.
    subscribers: HashMap<SubscriberKey, Vec<mpsc::Sender<CdpEvent>>>,
    /// Connection flag shared with every `TransportHandle`.
    connected: Arc<AtomicBool>,
    /// Endpoint used when reconnecting.
    url: String,
    /// Retry policy applied after the connection is lost.
    reconnect_config: ReconnectConfig,
    /// Time limit for each reconnection attempt.
    connect_timeout: Duration,
    /// `Some((attempts, last_error))` once reconnection has given up. From then on, every
    /// `SendCommand` is answered with `CdpError::ReconnectFailed`.
    reconnect_failure: Option<(u32, String)>,
}
164
165impl TransportTask {
166 async fn run(&mut self) {
167 loop {
168 if let Some((attempts, ref last_error)) = self.reconnect_failure {
171 match self.command_rx.recv().await {
172 Some(TransportCommand::SendCommand { response_tx, .. }) => {
173 let _ = response_tx.send(Err(CdpError::ReconnectFailed {
174 attempts,
175 last_error: last_error.clone(),
176 }));
177 continue;
178 }
179 Some(TransportCommand::Subscribe { .. }) => continue,
180 Some(TransportCommand::Shutdown) | None => return,
181 }
182 }
183
184 let next_deadline = self.earliest_deadline();
185 let timeout_sleep = async {
186 if let Some(deadline) = next_deadline {
187 tokio::time::sleep_until(deadline).await;
188 } else {
189 std::future::pending::<()>().await;
191 }
192 };
193
194 tokio::select! {
195 ws_msg = self.ws_stream.next() => {
197 match ws_msg {
198 Some(Ok(Message::Text(text))) => {
199 self.handle_text_message(&text);
200 }
201 Some(Ok(Message::Close(_)) | Err(_)) | None => {
202 self.handle_disconnect().await;
203 }
207 Some(Ok(_)) => {
208 }
210 }
211 }
212
213 cmd = self.command_rx.recv() => {
215 match cmd {
216 Some(TransportCommand::SendCommand { command, response_tx, deadline }) => {
217 self.handle_send_command(command, response_tx, deadline).await;
218 }
219 Some(TransportCommand::Subscribe { method, session_id, event_tx }) => {
220 self.subscribers
221 .entry((method, session_id))
222 .or_default()
223 .push(event_tx);
224 }
225 Some(TransportCommand::Shutdown) | None => {
226 self.drain_pending();
227 let _ = self.ws_stream.close(None).await;
228 self.connected.store(false, Ordering::Relaxed);
229 return;
230 }
231 }
232 }
233
234 () = timeout_sleep => {
236 self.sweep_timeouts();
237 }
238 }
239 }
240 }
241
242 fn handle_text_message(&mut self, text: &str) {
243 let raw: RawCdpMessage = match serde_json::from_str(text) {
244 Ok(msg) => msg,
245 Err(_) => {
246 return;
248 }
249 };
250
251 let Some(kind) = raw.classify() else {
252 return;
254 };
255
256 match kind {
257 MessageKind::Response(response) => {
258 if let Some(pending) = self.pending.remove(&response.id) {
259 let result = match response.result {
260 Ok(value) => Ok(value),
261 Err(proto_err) => Err(CdpError::Protocol {
262 code: proto_err.code,
263 message: proto_err.message,
264 }),
265 };
266 let _ = pending.response_tx.send(result);
267 }
268 }
269 MessageKind::Event(event) => {
270 self.dispatch_event(&event);
271 }
272 }
273 }
274
275 fn dispatch_event(&mut self, event: &CdpEvent) {
276 let key = (event.method.clone(), event.session_id.clone());
277 if let Some(senders) = self.subscribers.get_mut(&key) {
278 senders.retain(|tx| tx.try_send(event.clone()).is_ok() || !tx.is_closed());
280 if senders.is_empty() {
281 self.subscribers.remove(&key);
282 }
283 }
284 }
285
286 async fn handle_send_command(
287 &mut self,
288 command: CdpCommand,
289 response_tx: oneshot::Sender<Result<serde_json::Value, CdpError>>,
290 deadline: Instant,
291 ) {
292 let id = command.id;
293 let method = command.method.clone();
294
295 let json = match serde_json::to_string(&command) {
296 Ok(j) => j,
297 Err(e) => {
298 let _ =
299 response_tx.send(Err(CdpError::Internal(format!("serialization error: {e}"))));
300 return;
301 }
302 };
303
304 if let Err(e) = self.ws_stream.send(Message::Text(json.into())).await {
305 let _ = response_tx.send(Err(CdpError::Connection(format!(
306 "WebSocket write error: {e}"
307 ))));
308 return;
309 }
310
311 self.pending.insert(
312 id,
313 PendingRequest {
314 response_tx,
315 method,
316 deadline,
317 },
318 );
319 }
320
321 fn earliest_deadline(&self) -> Option<Instant> {
322 self.pending.values().map(|p| p.deadline).min()
323 }
324
325 fn sweep_timeouts(&mut self) {
326 let now = Instant::now();
327 let timed_out: Vec<u64> = self
328 .pending
329 .iter()
330 .filter(|(_, p)| p.deadline <= now)
331 .map(|(&id, _)| id)
332 .collect();
333
334 for id in timed_out {
335 if let Some(pending) = self.pending.remove(&id) {
336 let _ = pending.response_tx.send(Err(CdpError::CommandTimeout {
337 method: pending.method,
338 }));
339 }
340 }
341 }
342
343 fn drain_pending(&mut self) {
344 let pending = std::mem::take(&mut self.pending);
345 for (_, req) in pending {
346 let _ = req.response_tx.send(Err(CdpError::ConnectionClosed));
347 }
348 }
349
350 async fn handle_disconnect(&mut self) {
351 self.connected.store(false, Ordering::Relaxed);
352 self.drain_pending();
353
354 let mut backoff = self.reconnect_config.initial_backoff;
355 let mut last_error_msg = String::from("no retries configured");
356
357 for attempt in 1..=self.reconnect_config.max_retries {
358 tokio::time::sleep(backoff).await;
359
360 match connect_ws(&self.url, self.connect_timeout).await {
361 Ok(new_stream) => {
362 self.ws_stream = new_stream;
363 self.connected.store(true, Ordering::Relaxed);
364 return;
365 }
366 Err(e) => {
367 last_error_msg = e.to_string();
368 if attempt < self.reconnect_config.max_retries {
369 backoff = (backoff * 2).min(self.reconnect_config.max_backoff);
370 }
371 }
372 }
373 }
374
375 self.reconnect_failure = Some((self.reconnect_config.max_retries, last_error_msg));
378 }
379}