Skip to main content

night_fury_core/domains/
websocket.rs

1use chromiumoxide_cdp::cdp::browser_protocol::network::{
2    EnableParams as NetworkEnableParams, EventWebSocketCreated, EventWebSocketFrameReceived,
3    EventWebSocketFrameSent,
4};
5use futures::StreamExt;
6use tokio::sync::oneshot;
7
8use crate::error::NightFuryError;
9use crate::session::BrowserSession;
10use crate::types::{WsDirection, WsMessage};
11use crate::worker::WorkerState;
12
13// ---------------------------------------------------------------------------
14// Command enum
15// ---------------------------------------------------------------------------
16
17/// Commands for the WebSocket monitoring domain.
18#[non_exhaustive]
19pub enum WebSocketCmd {
20    EnableWsCapture {
21        reply: oneshot::Sender<Result<String, String>>,
22    },
23    GetWsMessages {
24        reply: oneshot::Sender<Result<Vec<WsMessage>, String>>,
25    },
26}
27
28// ---------------------------------------------------------------------------
29// Dispatch
30// ---------------------------------------------------------------------------
31
32impl WebSocketCmd {
33    pub(crate) async fn dispatch(self, state: &mut WorkerState) {
34        match self {
35            WebSocketCmd::EnableWsCapture { reply } => handle_enable_ws_capture(state, reply).await,
36            WebSocketCmd::GetWsMessages { reply } => handle_get_ws_messages(state, reply).await,
37        }
38    }
39}
40
41// ---------------------------------------------------------------------------
42// Handlers
43// ---------------------------------------------------------------------------
44
45async fn handle_enable_ws_capture(
46    state: &mut WorkerState,
47    reply: oneshot::Sender<Result<String, String>>,
48) {
49    let result: Result<String, String> = async {
50        // If already enabled, return early.
51        if state
52            .ws_capture_enabled
53            .load(std::sync::atomic::Ordering::SeqCst)
54        {
55            return Ok("WebSocket capture already enabled".to_string());
56        }
57
58        let page = &state.tabs[state.active_tab].page;
59
60        // Network.enable is required for WebSocket events.
61        page.raw_page()
62            .execute(NetworkEnableParams::default())
63            .await
64            .map_err(|e| format!("Network.enable failed: {e}"))?;
65
66        state
67            .ws_capture_enabled
68            .store(true, std::sync::atomic::Ordering::SeqCst);
69
70        // Track WebSocket URLs by request ID so we can associate frames with URLs.
71        let ws_urls = std::sync::Arc::clone(&state.ws_urls);
72        let ws_messages = std::sync::Arc::clone(&state.ws_messages);
73
74        // Spawn listener for webSocketCreated events to track URLs.
75        let ws_urls_created = std::sync::Arc::clone(&ws_urls);
76        if let Ok(mut stream) = page
77            .raw_page()
78            .event_listener::<EventWebSocketCreated>()
79            .await
80        {
81            tokio::task::spawn_local(async move {
82                while let Some(event) = stream.next().await {
83                    let mut locked = ws_urls_created.lock().unwrap();
84                    locked.insert(event.request_id.inner().to_string(), event.url.clone());
85                }
86            });
87        }
88
89        // Spawn listener for webSocketFrameReceived events.
90        let ws_urls_recv = std::sync::Arc::clone(&ws_urls);
91        let ws_messages_recv = std::sync::Arc::clone(&ws_messages);
92        if let Ok(mut stream) = page
93            .raw_page()
94            .event_listener::<EventWebSocketFrameReceived>()
95            .await
96        {
97            tokio::task::spawn_local(async move {
98                while let Some(event) = stream.next().await {
99                    let url = {
100                        let locked = ws_urls_recv.lock().unwrap();
101                        locked
102                            .get(event.request_id.inner())
103                            .cloned()
104                            .unwrap_or_default()
105                    };
106                    let msg = WsMessage {
107                        url,
108                        direction: WsDirection::Received,
109                        data: event.response.payload_data.clone(),
110                        opcode: event.response.opcode as u8,
111                        timestamp: *event.timestamp.inner(),
112                    };
113                    let mut locked = ws_messages_recv.lock().unwrap();
114                    locked.push(msg);
115                }
116            });
117        }
118
119        // Spawn listener for webSocketFrameSent events.
120        let ws_urls_sent = std::sync::Arc::clone(&ws_urls);
121        let ws_messages_sent = std::sync::Arc::clone(&ws_messages);
122        if let Ok(mut stream) = page
123            .raw_page()
124            .event_listener::<EventWebSocketFrameSent>()
125            .await
126        {
127            tokio::task::spawn_local(async move {
128                while let Some(event) = stream.next().await {
129                    let url = {
130                        let locked = ws_urls_sent.lock().unwrap();
131                        locked
132                            .get(event.request_id.inner())
133                            .cloned()
134                            .unwrap_or_default()
135                    };
136                    let msg = WsMessage {
137                        url,
138                        direction: WsDirection::Sent,
139                        data: event.response.payload_data.clone(),
140                        opcode: event.response.opcode as u8,
141                        timestamp: *event.timestamp.inner(),
142                    };
143                    let mut locked = ws_messages_sent.lock().unwrap();
144                    locked.push(msg);
145                }
146            });
147        }
148
149        Ok("WebSocket capture enabled".to_string())
150    }
151    .await;
152    let _ = reply.send(result);
153}
154
155async fn handle_get_ws_messages(
156    state: &mut WorkerState,
157    reply: oneshot::Sender<Result<Vec<WsMessage>, String>>,
158) {
159    let messages = {
160        let mut locked = state.ws_messages.lock().unwrap();
161        std::mem::take(&mut *locked)
162    };
163    let _ = reply.send(Ok(messages));
164}
165
166// ---------------------------------------------------------------------------
167// Session API
168// ---------------------------------------------------------------------------
169
170impl BrowserSession {
171    /// Enable WebSocket message capture.
172    ///
173    /// Subscribes to CDP `Network.webSocketCreated`, `Network.webSocketFrameSent`,
174    /// and `Network.webSocketFrameReceived` events. Captured messages are
175    /// retrieved via `get_ws_messages()`.
176    ///
177    /// This method is idempotent: calling it multiple times is safe.
178    pub async fn enable_ws_capture(&self) -> Result<String, NightFuryError> {
179        send_cmd!(
180            self,
181            |tx| crate::cmd::BrowserCmd::WebSocket(WebSocketCmd::EnableWsCapture { reply: tx }),
182            NightFuryError::OperationFailed
183        )
184    }
185
186    /// Retrieve and drain all captured WebSocket messages since the last call.
187    ///
188    /// Returns a `Vec<WsMessage>` containing the URL, direction, payload data,
189    /// opcode, and timestamp of each captured frame. The internal buffer is
190    /// cleared after each call.
191    pub async fn get_ws_messages(&self) -> Result<Vec<WsMessage>, NightFuryError> {
192        send_cmd!(
193            self,
194            |tx| crate::cmd::BrowserCmd::WebSocket(WebSocketCmd::GetWsMessages { reply: tx }),
195            NightFuryError::OperationFailed
196        )
197    }
198}