cdp_html_shot/
transport.rs

1#![allow(dead_code)]
2
3use anyhow::{Result, anyhow};
4use futures_util::stream::{SplitSink, SplitStream};
5use futures_util::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Duration;
11use tokio::net::TcpStream;
12use tokio::sync::{mpsc, oneshot};
13use tokio::time;
14use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
15
/// Crate-wide counter backing [`next_id`].
///
/// While it is only advanced through `next_id`, it holds the most recently issued id
/// (0 means no id has been issued yet).
pub(crate) static GLOBAL_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Returns a fresh message id, starting at 1 and increasing by one per call.
///
/// The counter is a single atomic, so concurrent callers never receive the same id.
pub(crate) fn next_id() -> usize {
    let previous = GLOBAL_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
    previous + 1
}
21
22#[derive(Debug)]
23pub(crate) enum TransportMessage {
24    Request(Value, oneshot::Sender<Result<TransportResponse>>),
25    ListenTargetMessage(u64, oneshot::Sender<Result<TransportResponse>>),
26    WaitForEvent(String, String, oneshot::Sender<()>),
27    Shutdown,
28}
29
30#[derive(Debug)]
31pub(crate) enum TransportResponse {
32    Response(Response),
33    Target(TargetMessage),
34}
35
36#[derive(Debug, Serialize, Deserialize)]
37pub(crate) struct Response {
38    pub(crate) id: u64,
39    pub(crate) result: Value,
40}
41
42#[derive(Debug, Serialize, Deserialize)]
43pub(crate) struct TargetMessage {
44    pub(crate) params: Value,
45}
46
47struct TransportActor {
48    pending_requests: HashMap<u64, oneshot::Sender<Result<TransportResponse>>>,
49    event_listeners: HashMap<(String, String), Vec<oneshot::Sender<()>>>,
50    ws_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
51    command_rx: mpsc::Receiver<TransportMessage>,
52}
53
54impl TransportActor {
55    async fn run(mut self, mut ws_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>) {
56        loop {
57            tokio::select! {
58                Some(msg) = ws_stream.next() => {
59                    match msg {
60                        Ok(Message::Text(text)) => {
61                            if let Ok(response) = serde_json::from_str::<Response>(&text) {
62                                if let Some(sender) = self.pending_requests.remove(&response.id) {
63                                    let _ = sender.send(Ok(TransportResponse::Response(response)));
64                                }
65                            }
66                            else if let Ok(target_msg) = serde_json::from_str::<TargetMessage>(&text)
67                                && let Some(inner_str) = target_msg.params.get("message").and_then(|v| v.as_str())
68                                    && let Ok(inner_json) = serde_json::from_str::<Value>(inner_str) {
69
70                                        if let Some(id) = inner_json.get("id").and_then(|i| i.as_u64()) {
71                                            if let Some(sender) = self.pending_requests.remove(&id) {
72                                                let _ = sender.send(Ok(TransportResponse::Target(target_msg)));
73                                            }
74                                        }
75                                        else if let Some(method) = inner_json.get("method").and_then(|s| s.as_str())
76                                            && let Some(session_id) = target_msg.params.get("sessionId").and_then(|s| s.as_str()) {
77                                                let key = (session_id.to_string(), method.to_string());
78                                                if let Some(senders) = self.event_listeners.remove(&key) {
79                                                    for tx in senders {
80                                                        let _ = tx.send(());
81                                                    }
82                                                }
83                                            }
84                                    }
85                        }
86                        Err(_) => break,
87                        _ => {}
88                    }
89                }
90                Some(msg) = self.command_rx.recv() => {
91                    match msg {
92                        TransportMessage::Request(cmd, tx) => {
93                            if let Some(id) = cmd["id"].as_u64()
94                                && let Ok(text) = serde_json::to_string(&cmd) {
95                                    if self.ws_sink.send(Message::Text(text.into())).await.is_ok() {
96                                        self.pending_requests.insert(id, tx);
97                                    } else {
98                                        let _ = tx.send(Err(anyhow!("WebSocket send failed")));
99                                    }
100                                }
101                        },
102                        TransportMessage::ListenTargetMessage(id, tx) => {
103                            self.pending_requests.insert(id, tx);
104                        },
105                        TransportMessage::WaitForEvent(session_id, method, tx) => {
106                            self.event_listeners.entry((session_id, method)).or_default().push(tx);
107                        },
108                        TransportMessage::Shutdown => {
109                            let _ = self.ws_sink.send(Message::Text(json!({
110                                "id": next_id(),
111                                "method": "Browser.close",
112                                "params": {}
113                            }).to_string().into())).await;
114                            let _ = self.ws_sink.close().await;
115                            break;
116                        }
117                    }
118                }
119                else => break,
120            }
121        }
122    }
123}
124
125#[derive(Debug)]
126pub(crate) struct Transport {
127    tx: mpsc::Sender<TransportMessage>,
128}
129
130impl Transport {
131    pub(crate) async fn new(ws_url: &str) -> Result<Self> {
132        let (ws_stream, _) = connect_async(ws_url).await?;
133        let (ws_sink, ws_stream) = ws_stream.split();
134        let (tx, rx) = mpsc::channel(100);
135
136        tokio::spawn(async move {
137            let actor = TransportActor {
138                pending_requests: HashMap::new(),
139                event_listeners: HashMap::new(),
140                ws_sink,
141                command_rx: rx,
142            };
143            actor.run(ws_stream).await;
144        });
145
146        Ok(Self { tx })
147    }
148
149    pub(crate) async fn send(&self, command: Value) -> Result<TransportResponse> {
150        let (tx, rx) = oneshot::channel();
151        self.tx
152            .send(TransportMessage::Request(command, tx))
153            .await
154            .map_err(|_| anyhow!("Transport actor dropped"))?;
155        time::timeout(Duration::from_secs(30), rx)
156            .await
157            .map_err(|_| anyhow!("Timeout waiting for response"))?
158            .map_err(|_| anyhow!("Response channel closed"))?
159    }
160
161    pub(crate) async fn get_target_msg(&self, msg_id: usize) -> Result<TransportResponse> {
162        let (tx, rx) = oneshot::channel();
163        self.tx
164            .send(TransportMessage::ListenTargetMessage(msg_id as u64, tx))
165            .await
166            .map_err(|_| anyhow!("Transport actor dropped"))?;
167        time::timeout(Duration::from_secs(30), rx)
168            .await
169            .map_err(|_| anyhow!("Timeout waiting for target message"))?
170            .map_err(|_| anyhow!("Response channel closed"))?
171    }
172
173    pub(crate) async fn listen_for_event(
174        &self,
175        session_id: &str,
176        method: &str,
177    ) -> Result<oneshot::Receiver<()>> {
178        let (tx, rx) = oneshot::channel();
179        self.tx
180            .send(TransportMessage::WaitForEvent(
181                session_id.to_string(),
182                method.to_string(),
183                tx,
184            ))
185            .await
186            .map_err(|_| anyhow!("Transport actor dropped"))?;
187        Ok(rx)
188    }
189
190    pub(crate) async fn wait_for_event(&self, session_id: &str, method: &str) -> Result<()> {
191        let (tx, rx) = oneshot::channel();
192        self.tx
193            .send(TransportMessage::WaitForEvent(
194                session_id.to_string(),
195                method.to_string(),
196                tx,
197            ))
198            .await
199            .map_err(|_| anyhow!("Transport actor dropped"))?;
200
201        time::timeout(Duration::from_secs(30), rx)
202            .await
203            .map_err(|_| anyhow!("Timeout waiting for event {}", method))?
204            .map_err(|_| anyhow!("Event channel closed"))?;
205        Ok(())
206    }
207
208    pub(crate) async fn shutdown(&self) {
209        let _ = self.tx.send(TransportMessage::Shutdown).await;
210    }
211}