// cdp_html_shot/transport.rs
#![allow(dead_code)]
2
3use anyhow::{Result, anyhow};
4use futures_util::stream::{SplitSink, SplitStream};
5use futures_util::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Duration;
11use tokio::net::TcpStream;
12use tokio::sync::{mpsc, oneshot};
13use tokio::time;
14use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
15
/// Process-wide counter backing [`next_id`]; holds the most recently issued id
/// (0 until the first id has been handed out).
pub(crate) static GLOBAL_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Returns the next message id.
///
/// Ids start at 1 and grow by one per call; the increment is a single atomic
/// read-modify-write on [`GLOBAL_ID_COUNTER`].
pub(crate) fn next_id() -> usize {
    let previously_issued = GLOBAL_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
    previously_issued + 1
}
21
22#[derive(Debug)]
23pub(crate) enum TransportMessage {
24 Request(Value, oneshot::Sender<Result<TransportResponse>>),
25 ListenTargetMessage(u64, oneshot::Sender<Result<TransportResponse>>),
26 WaitForEvent(String, String, oneshot::Sender<()>),
27 Shutdown,
28}
29
30#[derive(Debug)]
31pub(crate) enum TransportResponse {
32 Response(Response),
33 Target(TargetMessage),
34}
35
36#[derive(Debug, Serialize, Deserialize)]
37pub(crate) struct Response {
38 pub(crate) id: u64,
39 pub(crate) result: Value,
40}
41
42#[derive(Debug, Serialize, Deserialize)]
43pub(crate) struct TargetMessage {
44 pub(crate) params: Value,
45}
46
47struct TransportActor {
48 pending_requests: HashMap<u64, oneshot::Sender<Result<TransportResponse>>>,
49 event_listeners: HashMap<(String, String), Vec<oneshot::Sender<()>>>,
50 ws_sink: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
51 command_rx: mpsc::Receiver<TransportMessage>,
52}
53
54impl TransportActor {
55 async fn run(mut self, mut ws_stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>) {
56 loop {
57 tokio::select! {
58 Some(msg) = ws_stream.next() => {
59 match msg {
60 Ok(Message::Text(text)) => {
61 if let Ok(response) = serde_json::from_str::<Response>(&text) {
62 if let Some(sender) = self.pending_requests.remove(&response.id) {
63 let _ = sender.send(Ok(TransportResponse::Response(response)));
64 }
65 }
66 else if let Ok(target_msg) = serde_json::from_str::<TargetMessage>(&text)
67 && let Some(inner_str) = target_msg.params.get("message").and_then(|v| v.as_str())
68 && let Ok(inner_json) = serde_json::from_str::<Value>(inner_str) {
69
70 if let Some(id) = inner_json.get("id").and_then(|i| i.as_u64()) {
71 if let Some(sender) = self.pending_requests.remove(&id) {
72 let _ = sender.send(Ok(TransportResponse::Target(target_msg)));
73 }
74 }
75 else if let Some(method) = inner_json.get("method").and_then(|s| s.as_str())
76 && let Some(session_id) = target_msg.params.get("sessionId").and_then(|s| s.as_str()) {
77 let key = (session_id.to_string(), method.to_string());
78 if let Some(senders) = self.event_listeners.remove(&key) {
79 for tx in senders {
80 let _ = tx.send(());
81 }
82 }
83 }
84 }
85 }
86 Err(_) => break,
87 _ => {}
88 }
89 }
90 Some(msg) = self.command_rx.recv() => {
91 match msg {
92 TransportMessage::Request(cmd, tx) => {
93 if let Some(id) = cmd["id"].as_u64()
94 && let Ok(text) = serde_json::to_string(&cmd) {
95 if self.ws_sink.send(Message::Text(text.into())).await.is_ok() {
96 self.pending_requests.insert(id, tx);
97 } else {
98 let _ = tx.send(Err(anyhow!("WebSocket send failed")));
99 }
100 }
101 },
102 TransportMessage::ListenTargetMessage(id, tx) => {
103 self.pending_requests.insert(id, tx);
104 },
105 TransportMessage::WaitForEvent(session_id, method, tx) => {
106 self.event_listeners.entry((session_id, method)).or_default().push(tx);
107 },
108 TransportMessage::Shutdown => {
109 let _ = self.ws_sink.send(Message::Text(json!({
110 "id": next_id(),
111 "method": "Browser.close",
112 "params": {}
113 }).to_string().into())).await;
114 let _ = self.ws_sink.close().await;
115 break;
116 }
117 }
118 }
119 else => break,
120 }
121 }
122 }
123}
124
125#[derive(Debug)]
126pub(crate) struct Transport {
127 tx: mpsc::Sender<TransportMessage>,
128}
129
130impl Transport {
131 pub(crate) async fn new(ws_url: &str) -> Result<Self> {
132 let (ws_stream, _) = connect_async(ws_url).await?;
133 let (ws_sink, ws_stream) = ws_stream.split();
134 let (tx, rx) = mpsc::channel(100);
135
136 tokio::spawn(async move {
137 let actor = TransportActor {
138 pending_requests: HashMap::new(),
139 event_listeners: HashMap::new(),
140 ws_sink,
141 command_rx: rx,
142 };
143 actor.run(ws_stream).await;
144 });
145
146 Ok(Self { tx })
147 }
148
149 pub(crate) async fn send(&self, command: Value) -> Result<TransportResponse> {
150 let (tx, rx) = oneshot::channel();
151 self.tx
152 .send(TransportMessage::Request(command, tx))
153 .await
154 .map_err(|_| anyhow!("Transport actor dropped"))?;
155 time::timeout(Duration::from_secs(30), rx)
156 .await
157 .map_err(|_| anyhow!("Timeout waiting for response"))?
158 .map_err(|_| anyhow!("Response channel closed"))?
159 }
160
161 pub(crate) async fn get_target_msg(&self, msg_id: usize) -> Result<TransportResponse> {
162 let (tx, rx) = oneshot::channel();
163 self.tx
164 .send(TransportMessage::ListenTargetMessage(msg_id as u64, tx))
165 .await
166 .map_err(|_| anyhow!("Transport actor dropped"))?;
167 time::timeout(Duration::from_secs(30), rx)
168 .await
169 .map_err(|_| anyhow!("Timeout waiting for target message"))?
170 .map_err(|_| anyhow!("Response channel closed"))?
171 }
172
173 pub(crate) async fn listen_for_event(
174 &self,
175 session_id: &str,
176 method: &str,
177 ) -> Result<oneshot::Receiver<()>> {
178 let (tx, rx) = oneshot::channel();
179 self.tx
180 .send(TransportMessage::WaitForEvent(
181 session_id.to_string(),
182 method.to_string(),
183 tx,
184 ))
185 .await
186 .map_err(|_| anyhow!("Transport actor dropped"))?;
187 Ok(rx)
188 }
189
190 pub(crate) async fn wait_for_event(&self, session_id: &str, method: &str) -> Result<()> {
191 let (tx, rx) = oneshot::channel();
192 self.tx
193 .send(TransportMessage::WaitForEvent(
194 session_id.to_string(),
195 method.to_string(),
196 tx,
197 ))
198 .await
199 .map_err(|_| anyhow!("Transport actor dropped"))?;
200
201 time::timeout(Duration::from_secs(30), rx)
202 .await
203 .map_err(|_| anyhow!("Timeout waiting for event {}", method))?
204 .map_err(|_| anyhow!("Event channel closed"))?;
205 Ok(())
206 }
207
208 pub(crate) async fn shutdown(&self) {
209 let _ = self.tx.send(TransportMessage::Shutdown).await;
210 }
211}