use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::{RwLock, broadcast};
use tracing::{debug, trace};
use viewpoint_cdp::CdpConnection;
use viewpoint_cdp::protocol::{
WebSocketClosedEvent, WebSocketCreatedEvent, WebSocketFrame as CdpWebSocketFrame,
WebSocketFrameReceivedEvent, WebSocketFrameSentEvent,
};
/// Handle to a single WebSocket opened by the page, as observed over CDP.
///
/// Clones are cheap handles onto the same socket: they share the closed flag
/// (held in an `Arc`) and the same broadcast channels, so an event emitted
/// through one clone reaches subscribers registered through any other.
#[derive(Clone)]
pub struct WebSocket {
    /// CDP network request id that identifies this socket.
    request_id: String,
    /// URL the socket connected to.
    url: String,
    /// Set once the browser reports the socket as closed; shared by all clones.
    is_closed: Arc<AtomicBool>,
    /// Fan-out channel for frames sent by the page.
    frame_sent_tx: broadcast::Sender<WebSocketFrame>,
    /// Fan-out channel for frames received by the page.
    frame_received_tx: broadcast::Sender<WebSocketFrame>,
    /// Fan-out channel used to announce that the socket closed.
    close_tx: broadcast::Sender<()>,
}
/// Debug output shows the socket's identity and closed state; the broadcast
/// senders are intentionally omitted.
impl std::fmt::Debug for WebSocket {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let closed = self.is_closed.load(Ordering::SeqCst);
        let mut builder = formatter.debug_struct("WebSocket");
        builder.field("request_id", &self.request_id);
        builder.field("url", &self.url);
        builder.field("is_closed", &closed);
        builder.finish()
    }
}
impl WebSocket {
pub(crate) fn new(request_id: String, url: String) -> Self {
let (frame_sent_tx, _) = broadcast::channel(256);
let (frame_received_tx, _) = broadcast::channel(256);
let (close_tx, _) = broadcast::channel(16);
Self {
request_id,
url,
is_closed: Arc::new(AtomicBool::new(false)),
frame_sent_tx,
frame_received_tx,
close_tx,
}
}
pub fn url(&self) -> &str {
&self.url
}
pub fn is_closed(&self) -> bool {
self.is_closed.load(Ordering::SeqCst)
}
pub fn request_id(&self) -> &str {
&self.request_id
}
pub async fn on_framesent<F, Fut>(&self, handler: F)
where
F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let mut rx = self.frame_sent_tx.subscribe();
tokio::spawn(async move {
while let Ok(frame) = rx.recv().await {
handler(frame).await;
}
});
}
pub async fn on_framereceived<F, Fut>(&self, handler: F)
where
F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let mut rx = self.frame_received_tx.subscribe();
tokio::spawn(async move {
while let Ok(frame) = rx.recv().await {
handler(frame).await;
}
});
}
pub async fn on_close<F, Fut>(&self, handler: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let mut rx = self.close_tx.subscribe();
tokio::spawn(async move {
if rx.recv().await.is_ok() {
handler().await;
}
});
}
pub(crate) fn emit_frame_sent(&self, frame: WebSocketFrame) {
let _ = self.frame_sent_tx.send(frame);
}
pub(crate) fn emit_frame_received(&self, frame: WebSocketFrame) {
let _ = self.frame_received_tx.send(frame);
}
pub(crate) fn mark_closed(&self) {
self.is_closed.store(true, Ordering::SeqCst);
let _ = self.close_tx.send(());
}
}
/// A single WebSocket frame observed through CDP.
///
/// Frames are plain values: two frames are equal when both their opcode and
/// payload match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketFrame {
    /// Raw frame opcode (`1` is treated as text, `2` as binary).
    opcode: u8,
    /// Frame payload as reported by CDP.
    payload_data: String,
}
impl WebSocketFrame {
    /// Opcode value this type treats as a text frame.
    const TEXT_OPCODE: u8 = 1;
    /// Opcode value this type treats as a binary frame.
    const BINARY_OPCODE: u8 = 2;

    /// Builds a frame from a raw opcode and payload.
    pub(crate) fn new(opcode: u8, payload_data: String) -> Self {
        Self { opcode, payload_data }
    }

    /// Converts a CDP protocol frame into a [`WebSocketFrame`], copying its payload.
    ///
    /// NOTE(review): the CDP opcode's declared type is not visible here; the
    /// `as u8` cast truncates integers and saturates floats. Confirm that
    /// values always fit.
    pub(crate) fn from_cdp(cdp_frame: &CdpWebSocketFrame) -> Self {
        let opcode = cdp_frame.opcode as u8;
        let payload_data = cdp_frame.payload_data.clone();
        Self::new(opcode, payload_data)
    }

    /// Raw frame opcode.
    pub fn opcode(&self) -> u8 {
        self.opcode
    }

    /// Frame payload as a string slice.
    pub fn payload(&self) -> &str {
        self.payload_data.as_str()
    }

    /// `true` when the opcode is `1` (text frame).
    pub fn is_text(&self) -> bool {
        self.opcode == Self::TEXT_OPCODE
    }

    /// `true` when the opcode is `2` (binary frame).
    pub fn is_binary(&self) -> bool {
        self.opcode == Self::BINARY_OPCODE
    }
}
/// Type-erased callback invoked with each newly created [`WebSocket`].
///
/// The callback returns a boxed future, which the caller awaits. Both the
/// callback and its future are `Send`. The explicit `'static` bounds spell out
/// the object-lifetime defaults that `Box<dyn …>` implies anyway.
pub type WebSocketEventHandler = Box<
    dyn Fn(WebSocket) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
/// Tracks the WebSockets of one CDP session and dispatches their events.
pub struct WebSocketManager {
    /// Connection whose event stream is consumed.
    connection: Arc<CdpConnection>,
    /// Only events tagged with this session id are processed.
    session_id: String,
    /// Every socket seen so far, keyed by CDP request id; closed sockets stay in the map.
    websockets: Arc<RwLock<HashMap<String, WebSocket>>>,
    /// Optional callback invoked for each newly created socket.
    handler: Arc<RwLock<Option<WebSocketEventHandler>>>,
    /// Set once the background listener task has been spawned; never reset.
    is_listening: AtomicBool,
}
impl WebSocketManager {
pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
Self {
connection,
session_id,
websockets: Arc::new(RwLock::new(HashMap::new())),
handler: Arc::new(RwLock::new(None)),
is_listening: AtomicBool::new(false),
}
}
pub async fn set_handler<F, Fut>(&self, handler: F)
where
F: Fn(WebSocket) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let boxed_handler: WebSocketEventHandler = Box::new(move |ws| Box::pin(handler(ws)));
let mut h = self.handler.write().await;
*h = Some(boxed_handler);
self.start_listening().await;
}
pub async fn remove_handler(&self) {
let mut h = self.handler.write().await;
*h = None;
}
async fn start_listening(&self) {
if self.is_listening.swap(true, Ordering::SeqCst) {
return;
}
let mut events = self.connection.subscribe_events();
let session_id = self.session_id.clone();
let websockets = self.websockets.clone();
let handler = self.handler.clone();
tokio::spawn(async move {
debug!("WebSocket manager started listening for events");
while let Ok(event) = events.recv().await {
if event.session_id.as_deref() != Some(&session_id) {
continue;
}
match event.method.as_str() {
"Network.webSocketCreated" => {
if let Some(params) = &event.params {
if let Ok(created) =
serde_json::from_value::<WebSocketCreatedEvent>(params.clone())
{
trace!(
"WebSocket created: {} -> {}",
created.request_id, created.url
);
let ws = WebSocket::new(created.request_id.clone(), created.url);
{
let mut sockets = websockets.write().await;
sockets.insert(created.request_id, ws.clone());
}
let h = handler.read().await;
if let Some(ref handler_fn) = *h {
handler_fn(ws).await;
}
}
}
}
"Network.webSocketClosed" => {
if let Some(params) = &event.params {
if let Ok(closed) =
serde_json::from_value::<WebSocketClosedEvent>(params.clone())
{
trace!("WebSocket closed: {}", closed.request_id);
let sockets = websockets.read().await;
if let Some(ws) = sockets.get(&closed.request_id) {
ws.mark_closed();
}
}
}
}
"Network.webSocketFrameSent" => {
if let Some(params) = &event.params {
if let Ok(frame_event) =
serde_json::from_value::<WebSocketFrameSentEvent>(params.clone())
{
trace!("WebSocket frame sent: {}", frame_event.request_id);
let sockets = websockets.read().await;
if let Some(ws) = sockets.get(&frame_event.request_id) {
let frame = WebSocketFrame::from_cdp(&frame_event.response);
ws.emit_frame_sent(frame);
}
}
}
}
"Network.webSocketFrameReceived" => {
if let Some(params) = &event.params {
if let Ok(frame_event) = serde_json::from_value::<
WebSocketFrameReceivedEvent,
>(params.clone())
{
trace!("WebSocket frame received: {}", frame_event.request_id);
let sockets = websockets.read().await;
if let Some(ws) = sockets.get(&frame_event.request_id) {
let frame = WebSocketFrame::from_cdp(&frame_event.response);
ws.emit_frame_received(frame);
}
}
}
}
_ => {}
}
}
debug!("WebSocket manager stopped listening");
});
}
pub async fn get(&self, request_id: &str) -> Option<WebSocket> {
let sockets = self.websockets.read().await;
sockets.get(request_id).cloned()
}
pub async fn all(&self) -> Vec<WebSocket> {
let sockets = self.websockets.read().await;
sockets.values().cloned().collect()
}
}
#[cfg(test)]
mod tests;