Skip to main content

playwright_rs/protocol/
web_socket.rs

1//! WebSocket protocol object — represents a WebSocket connection in the page.
2//!
3//! # Example
4//!
5//! ```no_run
6//! use playwright_rs::protocol::Playwright;
7//!
8//! #[tokio::main]
9//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
10//!     let playwright = Playwright::launch().await?;
11//!     let browser = playwright.chromium().launch().await?;
12//!     let page = browser.new_page().await?;
13//!
14//!     // Register the handler BEFORE the action that opens the WebSocket
15//!     page.on_websocket(|ws| async move {
16//!         println!("WebSocket URL: {}", ws.url());
17//!
18//!         // Wait for the connection to close
19//!         let close_waiter = ws.expect_close(Some(5000.0)).await?;
20//!         close_waiter.wait().await?;
21//!         assert!(ws.is_closed());
22//!         Ok(())
23//!     }).await?;
24//!
25//!     // Navigate to a page that opens a WebSocket
26//!     page.goto("https://example.com/ws-demo", None).await?;
27//!
28//!     browser.close().await?;
29//!     Ok(())
30//! }
31//! ```
32
33use crate::error::Result;
34use crate::protocol::EventWaiter;
35use crate::server::channel::Channel;
36use crate::server::channel_owner::{ChannelOwner, ChannelOwnerImpl, ParentOrConnection};
37use serde_json::Value;
38use std::any::Any;
39use std::future::Future;
40use std::pin::Pin;
41use std::sync::atomic::{AtomicBool, Ordering};
42use std::sync::{Arc, Mutex};
43use tokio::sync::oneshot;
44
45/// Represents a WebSocket connection initiated by a page.
46///
47/// `WebSocket` objects are created by the Playwright server when the page
48/// opens a WebSocket connection. Use [`crate::protocol::Page::on_websocket`] to receive
49/// `WebSocket` objects.
50///
51/// See: <https://playwright.dev/docs/api/class-websocket>
52#[derive(Clone)]
53pub struct WebSocket {
54    base: ChannelOwnerImpl,
55    /// The URL of the WebSocket connection.
56    url: String,
57    /// Tracks whether the WebSocket has been closed.
58    is_closed: Arc<AtomicBool>,
59    /// General event handlers (frameSent, frameReceived, socketError, close).
60    handlers: Arc<Mutex<Vec<WebSocketEventHandler>>>,
61    /// One-shot senders waiting for the next "close" event.
62    close_waiters: Arc<Mutex<Vec<oneshot::Sender<()>>>>,
63    /// One-shot senders waiting for the next "frameReceived" event.
64    frame_received_waiters: Arc<Mutex<Vec<oneshot::Sender<String>>>>,
65    /// One-shot senders waiting for the next "frameSent" event.
66    frame_sent_waiters: Arc<Mutex<Vec<oneshot::Sender<String>>>>,
67}
68
69/// Type alias for boxed event handler future.
70type WebSocketEventHandlerFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
71
72/// WebSocket event handler type.
73type WebSocketEventHandler =
74    Arc<dyn Fn(WebSocketEvent) -> WebSocketEventHandlerFuture + Send + Sync>;
75
/// A server-sent WebSocket event, as decoded by `WebSocket::handle_event`.
#[derive(Debug, Clone)]
enum WebSocketEvent {
    /// `"frameSent"`: payload of a frame sent from the page to the server.
    FrameSent(String),
    /// `"frameReceived"`: payload of a frame received from the server.
    FrameReceived(String),
    /// `"socketError"`: the error message reported for the socket.
    SocketError(String),
    /// `"close"`: the socket was closed.
    Close,
}
83
84impl WebSocket {
85    /// Creates a new `WebSocket` object.
86    pub fn new(
87        parent: Arc<dyn ChannelOwner>,
88        type_name: String,
89        guid: Arc<str>,
90        initializer: Value,
91    ) -> Result<Self> {
92        let url = initializer["url"].as_str().unwrap_or("").to_string();
93
94        let base = ChannelOwnerImpl::new(
95            ParentOrConnection::Parent(parent),
96            type_name,
97            guid,
98            initializer,
99        );
100
101        Ok(Self {
102            base,
103            url,
104            is_closed: Arc::new(AtomicBool::new(false)),
105            handlers: Arc::new(Mutex::new(Vec::new())),
106            close_waiters: Arc::new(Mutex::new(Vec::new())),
107            frame_received_waiters: Arc::new(Mutex::new(Vec::new())),
108            frame_sent_waiters: Arc::new(Mutex::new(Vec::new())),
109        })
110    }
111
112    /// Returns the URL of the WebSocket connection.
113    ///
114    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-url>
115    pub fn url(&self) -> &str {
116        &self.url
117    }
118
119    /// Returns `true` if the WebSocket is closed.
120    ///
121    /// The value becomes `true` when the `"close"` event fires (i.e. when the
122    /// underlying TCP connection is torn down). It remains `false` from
123    /// construction until that point.
124    ///
125    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-is-closed>
126    pub fn is_closed(&self) -> bool {
127        self.is_closed.load(Ordering::Acquire)
128    }
129
130    /// Registers a handler that fires when a frame is sent from the page to the server.
131    ///
132    /// The handler receives the frame payload as a `String`. For binary frames the
133    /// value is the base-64-encoded representation.
134    ///
135    /// # Errors
136    ///
137    /// Returns an error only if the handler cannot be registered (in practice
138    /// this never fails).
139    ///
140    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-frame-sent>
141    pub async fn on_frame_sent<F>(&self, handler: F) -> Result<()>
142    where
143        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
144    {
145        let handler_arc = Arc::new(move |event| match event {
146            WebSocketEvent::FrameSent(payload) => handler(payload),
147            _ => Box::pin(async { Ok(()) }),
148        });
149        self.handlers.lock().unwrap().push(handler_arc);
150        Ok(())
151    }
152
153    /// Registers a handler that fires when a frame is received from the server.
154    ///
155    /// The handler receives the frame payload as a `String`. For binary frames the
156    /// value is the base-64-encoded representation.
157    ///
158    /// # Errors
159    ///
160    /// Returns an error only if the handler cannot be registered (in practice
161    /// this never fails).
162    ///
163    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-frame-received>
164    pub async fn on_frame_received<F>(&self, handler: F) -> Result<()>
165    where
166        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
167    {
168        let handler_arc = Arc::new(move |event| match event {
169            WebSocketEvent::FrameReceived(payload) => handler(payload),
170            _ => Box::pin(async { Ok(()) }),
171        });
172        self.handlers.lock().unwrap().push(handler_arc);
173        Ok(())
174    }
175
176    /// Registers a handler that fires when the WebSocket encounters an error.
177    ///
178    /// The handler receives the error message as a `String`.
179    ///
180    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-socket-error>
181    pub async fn on_error<F>(&self, handler: F) -> Result<()>
182    where
183        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
184    {
185        let handler_arc = Arc::new(move |event| match event {
186            WebSocketEvent::SocketError(msg) => handler(msg),
187            _ => Box::pin(async { Ok(()) }),
188        });
189        self.handlers.lock().unwrap().push(handler_arc);
190        Ok(())
191    }
192
193    /// Registers a handler that fires when the WebSocket is closed.
194    ///
195    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-close>
196    pub async fn on_close<F>(&self, handler: F) -> Result<()>
197    where
198        F: Fn(()) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
199    {
200        let handler_arc = Arc::new(move |event| match event {
201            WebSocketEvent::Close => handler(()),
202            _ => Box::pin(async { Ok(()) }),
203        });
204        self.handlers.lock().unwrap().push(handler_arc);
205        Ok(())
206    }
207
208    /// Creates a one-shot waiter that resolves when the WebSocket is closed.
209    ///
210    /// The waiter **must** be created before the action that closes the
211    /// WebSocket to avoid a race condition.
212    ///
213    /// # Arguments
214    ///
215    /// * `timeout` — Timeout in milliseconds. Defaults to 30 000 ms if `None`.
216    ///
217    /// # Errors
218    ///
219    /// Returns [`Error::Timeout`](crate::error::Error::Timeout) if the WebSocket
220    /// is not closed within the timeout.
221    ///
222    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-wait-for-event>
223    pub async fn expect_close(&self, timeout: Option<f64>) -> Result<EventWaiter<()>> {
224        let (tx, rx) = oneshot::channel();
225        self.close_waiters.lock().unwrap().push(tx);
226        Ok(EventWaiter::new(rx, timeout.or(Some(30_000.0))))
227    }
228
229    /// Creates a one-shot waiter that resolves when the next frame is received from the server.
230    ///
231    /// The waiter **must** be created before the action that causes a frame to be
232    /// received to avoid a race condition.
233    ///
234    /// # Arguments
235    ///
236    /// * `timeout` — Timeout in milliseconds. Defaults to 30 000 ms if `None`.
237    ///
238    /// # Errors
239    ///
240    /// Returns [`Error::Timeout`](crate::error::Error::Timeout) if no frame is
241    /// received within the timeout.
242    ///
243    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-wait-for-event>
244    pub async fn expect_frame_received(&self, timeout: Option<f64>) -> Result<EventWaiter<String>> {
245        let (tx, rx) = oneshot::channel();
246        self.frame_received_waiters.lock().unwrap().push(tx);
247        Ok(EventWaiter::new(rx, timeout.or(Some(30_000.0))))
248    }
249
250    /// Creates a one-shot waiter that resolves when the next frame is sent from the page.
251    ///
252    /// The waiter **must** be created before the action that sends the frame to
253    /// avoid a race condition.
254    ///
255    /// # Arguments
256    ///
257    /// * `timeout` — Timeout in milliseconds. Defaults to 30 000 ms if `None`.
258    ///
259    /// # Errors
260    ///
261    /// Returns [`Error::Timeout`](crate::error::Error::Timeout) if no frame is
262    /// sent within the timeout.
263    ///
264    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-wait-for-event>
265    pub async fn expect_frame_sent(&self, timeout: Option<f64>) -> Result<EventWaiter<String>> {
266        let (tx, rx) = oneshot::channel();
267        self.frame_sent_waiters.lock().unwrap().push(tx);
268        Ok(EventWaiter::new(rx, timeout.or(Some(30_000.0))))
269    }
270
271    /// Dispatches a server-sent event to all registered handlers and waiters.
272    pub(crate) fn handle_event(&self, event: &str, params: &Value) {
273        let ws_event = match event {
274            "frameSent" => {
275                WebSocketEvent::FrameSent(params["data"].as_str().unwrap_or("").to_string())
276            }
277            "frameReceived" => {
278                WebSocketEvent::FrameReceived(params["data"].as_str().unwrap_or("").to_string())
279            }
280            "socketError" => {
281                WebSocketEvent::SocketError(params["error"].as_str().unwrap_or("").to_string())
282            }
283            "close" => {
284                // Mark as closed before notifying waiters so is_closed() is true
285                // when any await continuation runs.
286                self.is_closed.store(true, Ordering::Release);
287                WebSocketEvent::Close
288            }
289            _ => return,
290        };
291
292        // Notify one-shot waiters for specific event types
293        match &ws_event {
294            WebSocketEvent::Close => {
295                let waiters: Vec<_> = std::mem::take(&mut *self.close_waiters.lock().unwrap());
296                for tx in waiters {
297                    let _ = tx.send(());
298                }
299            }
300            WebSocketEvent::FrameReceived(payload) => {
301                if let Some(tx) = self.frame_received_waiters.lock().unwrap().pop() {
302                    let _ = tx.send(payload.clone());
303                }
304            }
305            WebSocketEvent::FrameSent(payload) => {
306                if let Some(tx) = self.frame_sent_waiters.lock().unwrap().pop() {
307                    let _ = tx.send(payload.clone());
308                }
309            }
310            WebSocketEvent::SocketError(_) => {}
311        }
312
313        // Notify general handlers (fire-and-forget)
314        let handlers = self.handlers.lock().unwrap().clone();
315        for handler in handlers {
316            let event = ws_event.clone();
317            tokio::spawn(async move {
318                let _ = handler(event).await;
319            });
320        }
321    }
322}
323
324impl ChannelOwner for WebSocket {
325    fn guid(&self) -> &str {
326        self.base.guid()
327    }
328
329    fn type_name(&self) -> &str {
330        self.base.type_name()
331    }
332
333    fn parent(&self) -> Option<Arc<dyn ChannelOwner>> {
334        self.base.parent()
335    }
336
337    fn connection(&self) -> Arc<dyn crate::server::connection::ConnectionLike> {
338        self.base.connection()
339    }
340
341    fn initializer(&self) -> &Value {
342        self.base.initializer()
343    }
344
345    fn channel(&self) -> &Channel {
346        self.base.channel()
347    }
348
349    fn dispose(&self, reason: crate::server::channel_owner::DisposeReason) {
350        // When the WebSocket object is disposed (page closed, or server-initiated),
351        // mark it as closed and satisfy any pending close waiters — even if the
352        // "close" event was never explicitly delivered before __dispose__.
353        self.is_closed.store(true, Ordering::Release);
354        let waiters: Vec<_> = std::mem::take(&mut *self.close_waiters.lock().unwrap());
355        for tx in waiters {
356            let _ = tx.send(());
357        }
358        self.base.dispose(reason)
359    }
360
361    fn adopt(&self, child: Arc<dyn ChannelOwner>) {
362        self.base.adopt(child)
363    }
364
365    fn add_child(&self, guid: Arc<str>, child: Arc<dyn ChannelOwner>) {
366        self.base.add_child(guid, child)
367    }
368
369    fn remove_child(&self, guid: &str) {
370        self.base.remove_child(guid)
371    }
372
373    fn on_event(&self, method: &str, params: Value) {
374        self.handle_event(method, &params);
375        self.base.on_event(method, params)
376    }
377
378    fn was_collected(&self) -> bool {
379        self.base.was_collected()
380    }
381
382    fn as_any(&self) -> &dyn Any {
383        self
384    }
385}