Skip to main content

playwright_rs/protocol/
web_socket.rs

1// WebSocket protocol object
2//
3// Represents a WebSocket connection in the page.
4//
5// # Example
6//
7// ```ignore
8// use playwright_rs::protocol::{Playwright, WebSocket};
9//
10// #[tokio::main]
11// async fn main() -> Result<(), Box<dyn std::error::Error>> {
12//     let playwright = Playwright::launch().await?;
13//     let browser = playwright.chromium().launch().await?;
14//     let page = browser.new_page().await?;
15//
16//     // Listen for WebSocket connections
17//     page.on_websocket(|ws| {
18//         println!("WebSocket opened: {}", ws.url());
19//
20//         // Listen for frames
21//         let ws_clone = ws.clone();
22//         Box::pin(async move {
23//             ws_clone.on_frame_received(|payload| {
24//                 Box::pin(async move {
25//                     println!("Received: {:?}", payload);
26//                     Ok(())
27//                 })
28//             }).await?;
29//             Ok(())
30//         })
31//     }).await?;
32//
33//     page.goto("https://websocket.org/echo.html", None).await?;
34//
35//     browser.close().await?;
36//     Ok(())
37// }
38// ```
39
use crate::error::Result;
use crate::server::channel::Channel;
use crate::server::channel_owner::{ChannelOwner, ChannelOwnerImpl, ParentOrConnection};
use serde_json::Value;
use std::any::Any;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, PoisonError};
48
49/// WebSocket represents a WebSocket connection in the page.
50#[derive(Clone)]
51pub struct WebSocket {
52    base: ChannelOwnerImpl,
53    url: String,
54    // Event handlers
55    handlers: Arc<Mutex<Vec<WebSocketEventHandler>>>,
56}
57
58/// Type alias for boxed event handler future
59type WebSocketEventHandlerFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
60
61/// WebSocket event handler
62type WebSocketEventHandler =
63    Arc<dyn Fn(WebSocketEvent) -> WebSocketEventHandlerFuture + Send + Sync>;
64
65#[derive(Clone, Debug)]
66enum WebSocketEvent {
67    FrameSent(String), // Payload is string (text) or base64 (binary)
68    FrameReceived(String),
69    SocketError(String),
70    Close,
71}
72
73impl WebSocket {
74    /// Creates a new WebSocket object
75    pub fn new(
76        parent: Arc<dyn ChannelOwner>,
77        type_name: String,
78        guid: Arc<str>,
79        initializer: Value,
80    ) -> Result<Self> {
81        let url = initializer["url"].as_str().unwrap_or("").to_string();
82
83        let base = ChannelOwnerImpl::new(
84            ParentOrConnection::Parent(parent),
85            type_name,
86            guid,
87            initializer,
88        );
89
90        let handlers = Arc::new(Mutex::new(Vec::new()));
91
92        Ok(Self {
93            base,
94            url,
95            handlers,
96        })
97    }
98
99    /// Returns the URL of the WebSocket.
100    pub fn url(&self) -> &str {
101        &self.url
102    }
103
104    /// Returns true if the WebSocket is closed.
105    pub fn is_closed(&self) -> bool {
106        // Simple check based on basic close event tracking could be added here
107        // For now, we rely on the protocol state or user tracking
108        false
109    }
110
111    /// Adds a listener for FrameSent events.
112    pub async fn on_frame_sent<F>(&self, handler: F) -> Result<()>
113    where
114        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
115    {
116        let handler_arc = Arc::new(move |event| match event {
117            WebSocketEvent::FrameSent(payload) => handler(payload),
118            _ => Box::pin(async { Ok(()) }),
119        });
120        self.handlers.lock().unwrap().push(handler_arc);
121        Ok(())
122    }
123
124    /// Adds a listener for FrameReceived events.
125    pub async fn on_frame_received<F>(&self, handler: F) -> Result<()>
126    where
127        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
128    {
129        let handler_arc = Arc::new(move |event| match event {
130            WebSocketEvent::FrameReceived(payload) => handler(payload),
131            _ => Box::pin(async { Ok(()) }),
132        });
133        self.handlers.lock().unwrap().push(handler_arc);
134        Ok(())
135    }
136
137    /// Adds a listener for SocketError events.
138    pub async fn on_error<F>(&self, handler: F) -> Result<()>
139    where
140        F: Fn(String) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
141    {
142        let handler_arc = Arc::new(move |event| match event {
143            WebSocketEvent::SocketError(msg) => handler(msg),
144            _ => Box::pin(async { Ok(()) }),
145        });
146        self.handlers.lock().unwrap().push(handler_arc);
147        Ok(())
148    }
149
150    /// Adds a listener for Close events.
151    pub async fn on_close<F>(&self, handler: F) -> Result<()>
152    where
153        F: Fn(()) -> WebSocketEventHandlerFuture + Send + Sync + 'static,
154    {
155        let handler_arc = Arc::new(move |event| match event {
156            WebSocketEvent::Close => handler(()),
157            _ => Box::pin(async { Ok(()) }),
158        });
159        self.handlers.lock().unwrap().push(handler_arc);
160        Ok(())
161    }
162
163    // Dispatch methods required by the protocol layer
164    // These are called when the server sends an event
165
166    pub(crate) fn handle_event(&self, event: &str, params: &Value) {
167        let ws_event = match event {
168            "frameSent" => {
169                let _payload = params["opcode"].as_i64().map_or("".to_string(), |op| {
170                    if op == 2 {
171                        // Binary
172                        params["data"].as_str().unwrap_or("").to_string()
173                    } else {
174                        // Text
175                        params["data"].as_str().unwrap_or("").to_string()
176                    }
177                });
178                // Simplified: Just returning data for now
179                WebSocketEvent::FrameSent(params["data"].as_str().unwrap_or("").to_string())
180            }
181            "frameReceived" => {
182                WebSocketEvent::FrameReceived(params["data"].as_str().unwrap_or("").to_string())
183            }
184            "socketError" => {
185                WebSocketEvent::SocketError(params["error"].as_str().unwrap_or("").to_string())
186            }
187            "close" => WebSocketEvent::Close,
188            _ => return,
189        };
190
191        let handlers = self.handlers.lock().unwrap();
192        for handler in handlers.iter() {
193            let handler = handler.clone();
194            let event = ws_event.clone();
195            // Fire and forget
196            tokio::spawn(async move {
197                let _ = handler(event).await;
198            });
199        }
200    }
201}
202
203impl ChannelOwner for WebSocket {
204    fn guid(&self) -> &str {
205        self.base.guid()
206    }
207
208    fn type_name(&self) -> &str {
209        self.base.type_name()
210    }
211
212    fn parent(&self) -> Option<Arc<dyn ChannelOwner>> {
213        self.base.parent()
214    }
215
216    fn connection(&self) -> Arc<dyn crate::server::connection::ConnectionLike> {
217        self.base.connection()
218    }
219
220    fn initializer(&self) -> &Value {
221        self.base.initializer()
222    }
223
224    fn channel(&self) -> &Channel {
225        self.base.channel()
226    }
227
228    fn dispose(&self, reason: crate::server::channel_owner::DisposeReason) {
229        self.base.dispose(reason)
230    }
231
232    fn adopt(&self, child: Arc<dyn ChannelOwner>) {
233        self.base.adopt(child)
234    }
235
236    fn add_child(&self, guid: Arc<str>, child: Arc<dyn ChannelOwner>) {
237        self.base.add_child(guid, child)
238    }
239
240    fn remove_child(&self, guid: &str) {
241        self.base.remove_child(guid)
242    }
243
244    fn on_event(&self, method: &str, params: Value) {
245        self.handle_event(method, &params);
246        self.base.on_event(method, params)
247    }
248
249    fn was_collected(&self) -> bool {
250        self.base.was_collected()
251    }
252
253    fn as_any(&self) -> &dyn Any {
254        self
255    }
256}