Skip to main content

playwright_rs/protocol/
web_socket.rs

1//! WebSocket protocol object — represents a WebSocket connection in the page.
2//!
3//! # Example
4//!
5//! ```ignore
6//! use playwright_rs::protocol::Playwright;
7//!
8//! #[tokio::main]
9//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
10//!     let playwright = Playwright::launch().await?;
11//!     let browser = playwright.chromium().launch().await?;
12//!     let page = browser.new_page().await?;
13//!
14//!     // Set up the waiter BEFORE the action that opens the WebSocket
15//!     let ws_waiter = page.expect_websocket(None).await?;
16//!
17//!     // Navigate to a page that opens a WebSocket
18//!     page.goto("https://example.com/ws-demo", None).await?;
19//!
20//!     let ws = ws_waiter.wait().await?;
21//!     println!("WebSocket URL: {}", ws.url());
22//!
23//!     // Wait for the connection to close
24//!     let close_waiter = ws.expect_close(Some(5000.0)).await?;
25//!     // ... trigger close ...
26//!     close_waiter.wait().await?;
27//!
28//!     assert!(ws.is_closed());
29//!
30//!     browser.close().await?;
31//!     Ok(())
32//! }
33//! ```
34
35use crate::error::Result;
36use crate::protocol::EventWaiter;
37use crate::server::channel::Channel;
38use crate::server::channel_owner::{ChannelOwner, ChannelOwnerImpl, ParentOrConnection};
39use serde_json::Value;
40use std::any::Any;
41use std::future::Future;
42use std::pin::Pin;
43use std::sync::atomic::{AtomicBool, Ordering};
44use std::sync::{Arc, Mutex};
45use tokio::sync::oneshot;
46
47/// Represents a WebSocket connection initiated by a page.
48///
49/// `WebSocket` objects are created by the Playwright server when the page
50/// opens a WebSocket connection. Use [`crate::protocol::Page::on_websocket`] to receive
51/// `WebSocket` objects.
52///
53/// See: <https://playwright.dev/docs/api/class-websocket>
54#[derive(Clone)]
55pub struct WebSocket {
56    base: ChannelOwnerImpl,
57    /// The URL of the WebSocket connection.
58    url: String,
59    /// Tracks whether the WebSocket has been closed.
60    is_closed: Arc<AtomicBool>,
61    /// General event handlers (frameSent, frameReceived, socketError, close).
62    handlers: Arc<Mutex<Vec<WebSocketEventHandler>>>,
63    /// One-shot senders waiting for the next "close" event.
64    close_waiters: Arc<Mutex<Vec<oneshot::Sender<()>>>>,
65    /// One-shot senders waiting for the next "frameReceived" event.
66    frame_received_waiters: Arc<Mutex<Vec<oneshot::Sender<String>>>>,
67    /// One-shot senders waiting for the next "frameSent" event.
68    frame_sent_waiters: Arc<Mutex<Vec<oneshot::Sender<String>>>>,
69}
70
71/// Type alias for boxed event handler future.
72type WebSocketEventHandlerFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
73
74/// WebSocket event handler type.
75type WebSocketEventHandler =
76    Arc<dyn Fn(WebSocketEvent) -> WebSocketEventHandlerFuture + Send + Sync>;
77
/// Internal representation of the events the server dispatches for a WebSocket.
///
/// The frame and error variants carry the string extracted from the event's
/// params.
#[derive(Debug, Clone)]
enum WebSocketEvent {
    /// `frameSent`: a frame travelled from the page to the server.
    FrameSent(String),
    /// `frameReceived`: a frame travelled from the server to the page.
    FrameReceived(String),
    /// `socketError`: carries the reported error message.
    SocketError(String),
    /// `close`: the connection was closed.
    Close,
}
85
86impl WebSocket {
87    /// Creates a new `WebSocket` object.
88    pub fn new(
89        parent: Arc<dyn ChannelOwner>,
90        type_name: String,
91        guid: Arc<str>,
92        initializer: Value,
93    ) -> Result<Self> {
94        let url = initializer["url"].as_str().unwrap_or("").to_string();
95
96        let base = ChannelOwnerImpl::new(
97            ParentOrConnection::Parent(parent),
98            type_name,
99            guid,
100            initializer,
101        );
102
103        Ok(Self {
104            base,
105            url,
106            is_closed: Arc::new(AtomicBool::new(false)),
107            handlers: Arc::new(Mutex::new(Vec::new())),
108            close_waiters: Arc::new(Mutex::new(Vec::new())),
109            frame_received_waiters: Arc::new(Mutex::new(Vec::new())),
110            frame_sent_waiters: Arc::new(Mutex::new(Vec::new())),
111        })
112    }
113
114    /// Returns the URL of the WebSocket connection.
115    ///
116    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-url>
117    pub fn url(&self) -> &str {
118        &self.url
119    }
120
121    /// Returns `true` if the WebSocket is closed.
122    ///
123    /// The value becomes `true` when the `"close"` event fires (i.e. when the
124    /// underlying TCP connection is torn down). It remains `false` from
125    /// construction until that point.
126    ///
127    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-is-closed>
128    pub fn is_closed(&self) -> bool {
129        self.is_closed.load(Ordering::Acquire)
130    }
131
132    /// Registers a handler that fires when a frame is sent from the page to the server.
133    ///
134    /// The handler receives the frame payload as a `String`. For binary frames the
135    /// value is the base-64-encoded representation.
136    ///
137    /// # Errors
138    ///
139    /// Returns an error only if the handler cannot be registered (in practice
140    /// this never fails).
141    ///
142    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-frame-sent>
143    pub async fn on_frame_sent<F>(&self, handler: F) -> Result<()>
144    where
145        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
146    {
147        let handler_arc = Arc::new(move |event| match event {
148            WebSocketEvent::FrameSent(payload) => handler(payload),
149            _ => Box::pin(async { Ok(()) }),
150        });
151        self.handlers.lock().unwrap().push(handler_arc);
152        Ok(())
153    }
154
155    /// Registers a handler that fires when a frame is received from the server.
156    ///
157    /// The handler receives the frame payload as a `String`. For binary frames the
158    /// value is the base-64-encoded representation.
159    ///
160    /// # Errors
161    ///
162    /// Returns an error only if the handler cannot be registered (in practice
163    /// this never fails).
164    ///
165    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-frame-received>
166    pub async fn on_frame_received<F>(&self, handler: F) -> Result<()>
167    where
168        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
169    {
170        let handler_arc = Arc::new(move |event| match event {
171            WebSocketEvent::FrameReceived(payload) => handler(payload),
172            _ => Box::pin(async { Ok(()) }),
173        });
174        self.handlers.lock().unwrap().push(handler_arc);
175        Ok(())
176    }
177
178    /// Registers a handler that fires when the WebSocket encounters an error.
179    ///
180    /// The handler receives the error message as a `String`.
181    ///
182    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-socket-error>
183    pub async fn on_error<F>(&self, handler: F) -> Result<()>
184    where
185        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
186    {
187        let handler_arc = Arc::new(move |event| match event {
188            WebSocketEvent::SocketError(msg) => handler(msg),
189            _ => Box::pin(async { Ok(()) }),
190        });
191        self.handlers.lock().unwrap().push(handler_arc);
192        Ok(())
193    }
194
195    /// Registers a handler that fires when the WebSocket is closed.
196    ///
197    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-event-close>
198    pub async fn on_close<F>(&self, handler: F) -> Result<()>
199    where
200        F: Fn(()) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
201    {
202        let handler_arc = Arc::new(move |event| match event {
203            WebSocketEvent::Close => handler(()),
204            _ => Box::pin(async { Ok(()) }),
205        });
206        self.handlers.lock().unwrap().push(handler_arc);
207        Ok(())
208    }
209
210    /// Creates a one-shot waiter that resolves when the WebSocket is closed.
211    ///
212    /// The waiter **must** be created before the action that closes the
213    /// WebSocket to avoid a race condition.
214    ///
215    /// # Arguments
216    ///
217    /// * `timeout` — Timeout in milliseconds. Defaults to 30 000 ms if `None`.
218    ///
219    /// # Errors
220    ///
221    /// Returns [`Error::Timeout`](crate::error::Error::Timeout) if the WebSocket
222    /// is not closed within the timeout.
223    ///
224    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-wait-for-event>
225    pub async fn expect_close(&self, timeout: Option<f64>) -> Result<EventWaiter<()>> {
226        let (tx, rx) = oneshot::channel();
227        self.close_waiters.lock().unwrap().push(tx);
228        Ok(EventWaiter::new(rx, timeout.or(Some(30_000.0))))
229    }
230
231    /// Creates a one-shot waiter that resolves when the next frame is received from the server.
232    ///
233    /// The waiter **must** be created before the action that causes a frame to be
234    /// received to avoid a race condition.
235    ///
236    /// # Arguments
237    ///
238    /// * `timeout` — Timeout in milliseconds. Defaults to 30 000 ms if `None`.
239    ///
240    /// # Errors
241    ///
242    /// Returns [`Error::Timeout`](crate::error::Error::Timeout) if no frame is
243    /// received within the timeout.
244    ///
245    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-wait-for-event>
246    pub async fn expect_frame_received(&self, timeout: Option<f64>) -> Result<EventWaiter<String>> {
247        let (tx, rx) = oneshot::channel();
248        self.frame_received_waiters.lock().unwrap().push(tx);
249        Ok(EventWaiter::new(rx, timeout.or(Some(30_000.0))))
250    }
251
252    /// Creates a one-shot waiter that resolves when the next frame is sent from the page.
253    ///
254    /// The waiter **must** be created before the action that sends the frame to
255    /// avoid a race condition.
256    ///
257    /// # Arguments
258    ///
259    /// * `timeout` — Timeout in milliseconds. Defaults to 30 000 ms if `None`.
260    ///
261    /// # Errors
262    ///
263    /// Returns [`Error::Timeout`](crate::error::Error::Timeout) if no frame is
264    /// sent within the timeout.
265    ///
266    /// See: <https://playwright.dev/docs/api/class-websocket#web-socket-wait-for-event>
267    pub async fn expect_frame_sent(&self, timeout: Option<f64>) -> Result<EventWaiter<String>> {
268        let (tx, rx) = oneshot::channel();
269        self.frame_sent_waiters.lock().unwrap().push(tx);
270        Ok(EventWaiter::new(rx, timeout.or(Some(30_000.0))))
271    }
272
273    /// Dispatches a server-sent event to all registered handlers and waiters.
274    pub(crate) fn handle_event(&self, event: &str, params: &Value) {
275        let ws_event = match event {
276            "frameSent" => {
277                WebSocketEvent::FrameSent(params["data"].as_str().unwrap_or("").to_string())
278            }
279            "frameReceived" => {
280                WebSocketEvent::FrameReceived(params["data"].as_str().unwrap_or("").to_string())
281            }
282            "socketError" => {
283                WebSocketEvent::SocketError(params["error"].as_str().unwrap_or("").to_string())
284            }
285            "close" => {
286                // Mark as closed before notifying waiters so is_closed() is true
287                // when any await continuation runs.
288                self.is_closed.store(true, Ordering::Release);
289                WebSocketEvent::Close
290            }
291            _ => return,
292        };
293
294        // Notify one-shot waiters for specific event types
295        match &ws_event {
296            WebSocketEvent::Close => {
297                let waiters: Vec<_> = std::mem::take(&mut *self.close_waiters.lock().unwrap());
298                for tx in waiters {
299                    let _ = tx.send(());
300                }
301            }
302            WebSocketEvent::FrameReceived(payload) => {
303                if let Some(tx) = self.frame_received_waiters.lock().unwrap().pop() {
304                    let _ = tx.send(payload.clone());
305                }
306            }
307            WebSocketEvent::FrameSent(payload) => {
308                if let Some(tx) = self.frame_sent_waiters.lock().unwrap().pop() {
309                    let _ = tx.send(payload.clone());
310                }
311            }
312            WebSocketEvent::SocketError(_) => {}
313        }
314
315        // Notify general handlers (fire-and-forget)
316        let handlers = self.handlers.lock().unwrap().clone();
317        for handler in handlers {
318            let event = ws_event.clone();
319            tokio::spawn(async move {
320                let _ = handler(event).await;
321            });
322        }
323    }
324}
325
326impl ChannelOwner for WebSocket {
327    fn guid(&self) -> &str {
328        self.base.guid()
329    }
330
331    fn type_name(&self) -> &str {
332        self.base.type_name()
333    }
334
335    fn parent(&self) -> Option<Arc<dyn ChannelOwner>> {
336        self.base.parent()
337    }
338
339    fn connection(&self) -> Arc<dyn crate::server::connection::ConnectionLike> {
340        self.base.connection()
341    }
342
343    fn initializer(&self) -> &Value {
344        self.base.initializer()
345    }
346
347    fn channel(&self) -> &Channel {
348        self.base.channel()
349    }
350
351    fn dispose(&self, reason: crate::server::channel_owner::DisposeReason) {
352        // When the WebSocket object is disposed (page closed, or server-initiated),
353        // mark it as closed and satisfy any pending close waiters — even if the
354        // "close" event was never explicitly delivered before __dispose__.
355        self.is_closed.store(true, Ordering::Release);
356        let waiters: Vec<_> = std::mem::take(&mut *self.close_waiters.lock().unwrap());
357        for tx in waiters {
358            let _ = tx.send(());
359        }
360        self.base.dispose(reason)
361    }
362
363    fn adopt(&self, child: Arc<dyn ChannelOwner>) {
364        self.base.adopt(child)
365    }
366
367    fn add_child(&self, guid: Arc<str>, child: Arc<dyn ChannelOwner>) {
368        self.base.add_child(guid, child)
369    }
370
371    fn remove_child(&self, guid: &str) {
372        self.base.remove_child(guid)
373    }
374
375    fn on_event(&self, method: &str, params: Value) {
376        self.handle_event(method, &params);
377        self.base.on_event(method, params)
378    }
379
380    fn was_collected(&self) -> bool {
381        self.base.was_collected()
382    }
383
384    fn as_any(&self) -> &dyn Any {
385        self
386    }
387}