use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, OnceLock, Weak};
use base64::Engine;
use serde_json::{Value, json};
use crate::backend::AnyPage;
use crate::url_matcher::UrlMatcher;
/// Name of the page binding through which the injected WebSocket mock reports events back to
/// Rust; payloads arriving on it are meant to be handled by [`binding_callback`]
/// (presumably the caller exposes the binding under this name — TODO confirm).
pub const WS_BINDING_NAME: &str = "__pwWebSocketBinding";
/// Minified page-side WebSocket mock, embedded at compile time and installed through
/// [`mock_init_script`]. The Rust side drives it by calling `globalThis.__pwWebSocketDispatch`,
/// which this script is assumed to define (NOTE(review): confirm against the JS source).
pub const WS_MOCK_SOURCE: &str = include_str!("injected/dist/websocket-mock.min.js");
/// A single WebSocket frame payload exchanged between the page, the route and the server.
///
/// Derives `PartialEq`/`Eq` so callers can compare messages (e.g. in route handlers or tests).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A UTF-8 text frame.
    Text(String),
    /// A binary frame; travels base64-encoded to and from the page-side mock.
    Binary(Vec<u8>),
}
impl WsMessage {
    /// Encodes the message as the `{ "data", "isBase64" }` object understood by the page mock.
    ///
    /// Text is carried verbatim; binary payloads use the standard base64 alphabet.
    fn to_wsdata(&self) -> Value {
        let (data, is_base64) = match self {
            WsMessage::Text(text) => (text.clone(), false),
            WsMessage::Binary(bytes) => (base64::engine::general_purpose::STANDARD.encode(bytes), true),
        };
        json!({ "data": data, "isBase64": is_base64 })
    }

    /// Decodes a `{ "data", "isBase64" }` object received from the page mock.
    ///
    /// A missing `data` field is treated as an empty string and a missing/non-boolean
    /// `isBase64` as `false`. Malformed base64 is logged and yields an empty binary message
    /// instead of an error.
    fn from_wsdata(data: &Value) -> Self {
        let payload = data.get("data").and_then(Value::as_str).unwrap_or("");
        if data.get("isBase64").and_then(Value::as_bool) != Some(true) {
            return WsMessage::Text(payload.to_owned());
        }
        match base64::engine::general_purpose::STANDARD.decode(payload) {
            Ok(bytes) => WsMessage::Binary(bytes),
            Err(e) => {
                tracing::warn!(error = %e, "WS binary frame: malformed base64 from page mock; treating as empty");
                WsMessage::Binary(Vec::new())
            },
        }
    }
}
/// Boxed future returned by a [`WsHandler`].
pub type WsHandlerFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;
/// Route handler: invoked with the [`WebSocketRoute`] of each page-created socket whose URL
/// matched the route. The returned future is awaited before the route is finalized.
pub type WsHandler = Arc<dyn Fn(WebSocketRoute) -> WsHandlerFuture + Send + Sync>;
/// Message callback accepted by the public `on_message` methods of [`WebSocketRoute`] and
/// [`WebSocketRouteServer`].
///
/// Public so that callers of those public methods can name the parameter type.
pub type WsMsgCb = Arc<dyn Fn(WsMessage) + Send + Sync>;
/// Close callback accepted by the public `on_close` methods; receives the close code and
/// reason, either of which may be `None` (both are `None` when the socket's execution context
/// disappears).
pub type WsCloseCb = Arc<dyn Fn(Option<u32>, Option<String>) + Send + Sync>;
/// User-installed callbacks for one routed socket.
///
/// Every slot starts as `None`. While a slot is `None`, the route falls back to forwarding the
/// corresponding event to the other side (see the `on_message_from_*` / `on_close_*` methods
/// of `WebSocketRoute`); installing a callback disables that forwarding.
#[derive(Default)]
struct WsCallbacks {
    /// Messages sent by the page (`WebSocketRoute::on_message`).
    page_message: Option<WsMsgCb>,
    /// Page-side close (`WebSocketRoute::on_close`).
    page_close: Option<WsCloseCb>,
    /// Messages received from the real server (`WebSocketRouteServer::on_message`).
    server_message: Option<WsMsgCb>,
    /// Server-side close (`WebSocketRouteServer::on_close`).
    server_close: Option<WsCloseCb>,
}
/// State shared by a [`WebSocketRoute`], its clones and its [`WebSocketRouteServer`].
struct WsRouteState {
    /// Identifier the page-side mock assigned to the socket; echoed in every dispatch request.
    id: String,
    /// URL reported by the mock when the socket was created.
    url: String,
    /// Protocols reported by the mock when the socket was created.
    protocols: Vec<String>,
    /// Page in which dispatch requests are evaluated.
    page: AnyPage,
    /// Frame the socket was created in, if the binding reported one; dispatches target it.
    frame_id: Option<String>,
    /// User callbacks; guarded by a std mutex that is never held across an `.await`.
    callbacks: Mutex<WsCallbacks>,
    /// Set by `WebSocketRoute::connect_to_server`.
    connected: AtomicBool,
}

impl WsRouteState {
    /// Hands `request` to the page-side mock by evaluating a call to
    /// `globalThis.__pwWebSocketDispatch` in this route's frame.
    ///
    /// Best effort: the call is skipped in the page if the dispatcher is not defined, and the
    /// evaluation outcome (including any error) is discarded.
    async fn dispatch(&self, request: Value) {
        let encoded = serde_json::to_string(&request).unwrap_or_else(|_| "null".to_string());
        let fn_source = format!(
            "() => {{ globalThis.__pwWebSocketDispatch && globalThis.__pwWebSocketDispatch({encoded}); }}"
        );
        let frame = self.frame_id.as_deref();
        // Intentionally ignored: there is no caller that could act on a failed dispatch.
        let _ = self
            .page
            .call_utility_evaluate(&fn_source, &[], &[], frame, Some(true), true)
            .await;
    }
}
/// Handle for one page-created WebSocket that matched a registered route.
///
/// Clones are cheap and share the same underlying state through an `Arc`.
#[derive(Clone)]
pub struct WebSocketRoute {
    inner: Arc<WsRouteState>,
}

impl WebSocketRoute {
    /// Builds a route with no callbacks installed and not yet connected to the server.
    fn new(id: String, url: String, protocols: Vec<String>, page: AnyPage, frame_id: Option<String>) -> Self {
        let state = WsRouteState {
            id,
            url,
            protocols,
            page,
            frame_id,
            callbacks: Mutex::new(WsCallbacks::default()),
            connected: AtomicBool::new(false),
        };
        Self { inner: Arc::new(state) }
    }

    /// URL reported by the page-side mock for this socket.
    #[must_use]
    pub fn url(&self) -> &str {
        self.inner.url.as_str()
    }

    /// Protocols reported by the page-side mock for this socket.
    #[must_use]
    pub fn protocols(&self) -> &[String] {
        self.inner.protocols.as_slice()
    }

    /// Delivers `message` to the page-side socket (`sendToPage`).
    pub async fn send(&self, message: WsMessage) {
        self.dispatch_data("sendToPage", message.to_wsdata()).await;
    }

    /// Closes the page-side socket (`closePage`), always reported as a clean close.
    pub async fn close(&self, code: Option<u32>, reason: Option<String>) {
        self.dispatch_close("closePage", code, reason, true).await;
    }

    /// Installs the callback for messages sent by the page, replacing any previous one.
    ///
    /// While installed, page messages are no longer forwarded to the server automatically.
    pub fn on_message(&self, cb: WsMsgCb) {
        self.lock().page_message = Some(cb);
    }

    /// Installs the callback for the page closing its socket, replacing any previous one.
    ///
    /// While installed, the close is no longer forwarded to the server automatically.
    pub fn on_close(&self, cb: WsCloseCb) {
        self.lock().page_close = Some(cb);
    }

    /// Marks the route as connected to the real server and returns the server-side handle.
    ///
    /// This only sets the `connected` flag; the `connect` request itself is sent by
    /// `after_handle`. NOTE(review): when called after `after_handle` has already run, nothing
    /// here dispatches `connect` — confirm late connections are not expected.
    #[must_use]
    pub fn connect_to_server(&self) -> WebSocketRouteServer {
        self.inner.connected.store(true, Ordering::SeqCst);
        WebSocketRouteServer {
            inner: Arc::clone(&self.inner),
        }
    }

    /// Whether `connect_to_server` has been called on this route or any of its clones.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.inner.connected.load(Ordering::SeqCst)
    }

    /// Finalizes the route once the user handler has returned: sends `connect` if the handler
    /// asked for a server connection, `ensureOpened` otherwise.
    async fn after_handle(&self) {
        let kind = if self.is_connected() { "connect" } else { "ensureOpened" };
        self.inner.dispatch(json!({ "id": self.inner.id, "type": kind })).await;
    }

    /// Locks the callback table, recovering the data if the mutex was poisoned.
    fn lock(&self) -> std::sync::MutexGuard<'_, WsCallbacks> {
        match self.inner.callbacks.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Dispatches `{ id, type: kind, data }` to the page-side mock.
    async fn dispatch_data(&self, kind: &str, data: Value) {
        self.inner
            .dispatch(json!({ "id": self.inner.id, "type": kind, "data": data }))
            .await;
    }

    /// Dispatches `{ id, type: kind, code, reason, wasClean }` to the page-side mock.
    async fn dispatch_close(&self, kind: &str, code: Option<u32>, reason: Option<String>, was_clean: bool) {
        self.inner
            .dispatch(json!({
                "id": self.inner.id, "type": kind,
                "code": code, "reason": reason, "wasClean": was_clean,
            }))
            .await;
    }

    /// A message arrived from the page: hand it to the page callback if one is installed,
    /// otherwise forward it to the server when connected (and drop it when not).
    async fn on_message_from_page(&self, data: &Value) {
        // Clone the callback out first so the lock is released before calling or awaiting.
        let handler = self.lock().page_message.clone();
        match handler {
            Some(cb) => cb(WsMessage::from_wsdata(data)),
            None if self.is_connected() => self.dispatch_data("sendToServer", data.clone()).await,
            None => {},
        }
    }

    /// A message arrived from the server: hand it to the server callback if one is installed,
    /// otherwise forward it to the page.
    async fn on_message_from_server(&self, data: &Value) {
        let handler = self.lock().server_message.clone();
        match handler {
            Some(cb) => cb(WsMessage::from_wsdata(data)),
            None => self.dispatch_data("sendToPage", data.clone()).await,
        }
    }

    /// The page closed its socket: notify the page-close callback if installed, otherwise
    /// propagate the close to the server side.
    async fn on_close_page(&self, code: Option<u32>, reason: Option<String>, was_clean: bool) {
        let handler = self.lock().page_close.clone();
        if let Some(cb) = handler {
            cb(code, reason);
            return;
        }
        self.dispatch_close("closeServer", code, reason, was_clean).await;
    }

    /// The server closed its socket: notify the server-close callback if installed, otherwise
    /// propagate the close to the page side.
    async fn on_close_server(&self, code: Option<u32>, reason: Option<String>, was_clean: bool) {
        let handler = self.lock().server_close.clone();
        if let Some(cb) = handler {
            cb(code, reason);
            return;
        }
        self.dispatch_close("closePage", code, reason, was_clean).await;
    }

    /// The socket's execution context disappeared: invoke the page-close and then the
    /// server-close callback (whichever are installed) with no code and no reason.
    fn execution_context_gone(&self) {
        let closers = {
            let table = self.lock();
            [table.page_close.clone(), table.server_close.clone()]
        };
        for cb in closers.iter().flatten() {
            cb(None, None);
        }
    }
}
#[derive(Clone)]
pub struct WebSocketRouteServer {
inner: Arc<WsRouteState>,
}
impl WebSocketRouteServer {
#[must_use]
pub fn url(&self) -> &str {
&self.inner.url
}
pub async fn send(&self, message: WsMessage) {
self
.inner
.dispatch(json!({ "id": self.inner.id, "type": "sendToServer", "data": message.to_wsdata() }))
.await;
}
pub async fn close(&self, code: Option<u32>, reason: Option<String>) {
self
.inner
.dispatch(json!({
"id": self.inner.id, "type": "closeServer",
"code": code, "reason": reason, "wasClean": true,
}))
.await;
}
pub fn on_message(&self, cb: WsMsgCb) {
self
.inner
.callbacks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.server_message = Some(cb);
}
pub fn on_close(&self, cb: WsCloseCb) {
self
.inner
.callbacks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.server_close = Some(cb);
}
}
/// Which route list a WebSocket route is registered in.
///
/// Page-scoped routes are consulted before context-scoped ones when a socket is created.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum WsRouteScope {
    /// Page-scoped route; takes precedence over context routes.
    Page,
    /// Context-scoped route; only used when no page route matches.
    Context,
}
pub struct PageWsRouter {
page: AnyPage,
page_routes: Mutex<Vec<(UrlMatcher, WsHandler)>>,
context_routes: Mutex<Vec<(UrlMatcher, WsHandler)>>,
active: Mutex<rustc_hash::FxHashMap<String, WebSocketRoute>>,
installed: AtomicBool,
}
impl PageWsRouter {
#[must_use]
pub fn new(page: AnyPage) -> Arc<Self> {
let router = Arc::new(Self {
page,
page_routes: Mutex::new(Vec::new()),
context_routes: Mutex::new(Vec::new()),
active: Mutex::new(rustc_hash::FxHashMap::default()),
installed: AtomicBool::new(false),
});
Self::spawn_lifecycle_listener(&router);
router
}
fn spawn_lifecycle_listener(router: &Arc<Self>) {
let mut rx = router.page.events().subscribe();
let weak = Arc::downgrade(router);
tokio::spawn(async move {
while let Some(event) = crate::events::recv_tolerant(&mut rx).await {
let Some(router) = weak.upgrade() else { break };
match event {
crate::events::PageEvent::FrameNavigated(info) => {
router.frame_context_gone(&info.frame_id);
},
crate::events::PageEvent::FrameDetached { frame_id } => {
router.frame_context_gone(&frame_id);
},
crate::events::PageEvent::Close => {
router.all_contexts_gone();
break;
},
_ => {},
}
}
});
}
fn frame_context_gone(&self, frame_id: &str) {
let gone: Vec<WebSocketRoute> = {
let mut active = self.active.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
let ids: Vec<String> = active
.iter()
.filter(|(_, r)| r.inner.frame_id.as_deref() == Some(frame_id))
.map(|(id, _)| id.clone())
.collect();
ids.iter().filter_map(|id| active.remove(id)).collect()
};
for route in gone {
route.execution_context_gone();
}
}
fn all_contexts_gone(&self) {
let gone: Vec<WebSocketRoute> = {
let mut active = self.active.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
active.drain().map(|(_, r)| r).collect()
};
for route in gone {
route.execution_context_gone();
}
}
pub fn add_route(&self, matcher: UrlMatcher, handler: WsHandler, scope: WsRouteScope) -> bool {
let list = match scope {
WsRouteScope::Page => &self.page_routes,
WsRouteScope::Context => &self.context_routes,
};
list
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.insert(0, (matcher, handler));
!self.installed.swap(true, Ordering::SeqCst)
}
pub async fn handle_binding(self: &Arc<Self>, payload: &Value, source_frame: Option<&str>) {
let kind = payload.get("type").and_then(Value::as_str).unwrap_or("");
if kind == "onCreate" {
self.handle_create(payload, source_frame).await;
return;
}
let Some(id) = payload.get("id").and_then(Value::as_str) else {
return;
};
let route = self
.active
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.get(id)
.cloned();
let Some(route) = route else { return };
match kind {
"onMessageFromPage" => {
if let Some(data) = payload.get("data") {
route.on_message_from_page(data).await;
}
},
"onMessageFromServer" => {
if let Some(data) = payload.get("data") {
route.on_message_from_server(data).await;
}
},
"onClosePage" => {
route
.on_close_page(close_code(payload), close_reason(payload), was_clean(payload))
.await;
},
"onCloseServer" => {
route
.on_close_server(close_code(payload), close_reason(payload), was_clean(payload))
.await;
},
_ => {},
}
}
async fn handle_create(self: &Arc<Self>, payload: &Value, source_frame: Option<&str>) {
let id = payload.get("id").and_then(Value::as_str).unwrap_or("").to_string();
let url = payload.get("url").and_then(Value::as_str).unwrap_or("").to_string();
let protocols = payload
.get("protocols")
.and_then(Value::as_array)
.map(|a| a.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.unwrap_or_default();
let handler = {
let page_routes = self
.page_routes
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
page_routes
.iter()
.find(|(m, _)| m.matches(&url))
.map(|(_, h)| h.clone())
}
.or_else(|| {
let ctx_routes = self
.context_routes
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
ctx_routes.iter().find(|(m, _)| m.matches(&url)).map(|(_, h)| h.clone())
});
let Some(handler) = handler else {
let req = json!({ "id": id, "type": "passthrough" });
let fn_source = format!(
"() => {{ globalThis.__pwWebSocketDispatch && globalThis.__pwWebSocketDispatch({}); }}",
serde_json::to_string(&req).unwrap_or_else(|_| "null".to_string())
);
let _ = self
.page
.call_utility_evaluate(&fn_source, &[], &[], source_frame, Some(true), true)
.await;
return;
};
let route = WebSocketRoute::new(
id.clone(),
url,
protocols,
self.page.clone(),
source_frame.map(String::from),
);
self
.active
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.insert(id, route.clone());
handler(route.clone()).await;
route.after_handle().await;
}
}
/// Process-wide registry from page id to that page's WebSocket router.
///
/// Only `Weak` handles are stored, so the registry never keeps a router alive by itself.
static WS_ROUTERS: OnceLock<Mutex<rustc_hash::FxHashMap<usize, Weak<PageWsRouter>>>> = OnceLock::new();

/// Returns the live router for `page_id`, or creates one for `page` and registers it.
///
/// `page` is used only when a new router has to be created. Registering a new router also
/// prunes entries whose routers have been dropped, so closed pages do not leave dead `Weak`
/// handles in the registry forever.
#[must_use]
pub fn router_for_page(page_id: usize, page: AnyPage) -> Arc<PageWsRouter> {
    let map = WS_ROUTERS.get_or_init(|| Mutex::new(rustc_hash::FxHashMap::default()));
    let mut guard = map.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
    if let Some(existing) = guard.get(&page_id).and_then(Weak::upgrade) {
        return existing;
    }
    // Creation is rare (once per page), so an O(n) sweep here keeps the map bounded by the
    // number of live routers.
    guard.retain(|_, router| router.strong_count() > 0);
    let router = PageWsRouter::new(page);
    guard.insert(page_id, Arc::downgrade(&router));
    router
}
/// Reads the optional close `code` from a binding payload.
///
/// Yields `None` when the field is absent, is not an unsigned integer, or does not fit in `u32`.
fn close_code(payload: &Value) -> Option<u32> {
    let raw = payload.get("code")?.as_u64()?;
    u32::try_from(raw).ok()
}

/// Reads the optional close `reason` string from a binding payload.
fn close_reason(payload: &Value) -> Option<String> {
    let reason = payload.get("reason")?.as_str()?;
    Some(reason.to_owned())
}

/// Reads the `wasClean` flag from a binding payload; anything other than JSON `true` is `false`.
fn was_clean(payload: &Value) -> bool {
    matches!(payload.get("wasClean"), Some(Value::Bool(true)))
}
#[must_use]
pub fn binding_callback(router: Arc<PageWsRouter>) -> crate::events::ExposedBinding {
Arc::new(move |source: crate::events::BindingSource, args: Vec<Value>| {
let router = router.clone();
Box::pin(async move {
if let Some(payload) = args.into_iter().next() {
let frame = (!source.frame.is_empty()).then_some(source.frame.as_str());
router.handle_binding(&payload, frame).await;
}
Value::Null
})
})
}
/// Init script that installs the page-side WebSocket mock ([`WS_MOCK_SOURCE`]).
#[must_use]
pub fn mock_init_script() -> crate::options::InitScriptSource {
    let source = String::from(WS_MOCK_SOURCE);
    crate::options::InitScriptSource::Source(source)
}