use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::time::Instant;
use crate::browser::tab::Tab;
use crate::protocol::Event;
use crate::util::base64_decode;
use crate::{Error, Result};
/// Upper bound on the number of captured frames kept in the shared buffer.
/// Once the buffer holds this many entries, `push_frame` drops the oldest one
/// before appending a new frame.
const MAX_BUFFERED: usize = 2000;
/// Direction of a captured WebSocket frame, from the page's point of view.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsDirection {
    /// Frame sent by the page.
    Sent,
    /// Frame received by the page.
    Received,
}

impl WsDirection {
    /// Returns the lowercase label for this direction: `"sent"` or `"received"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Received => "received",
            Self::Sent => "sent",
        }
    }
}
/// One captured WebSocket frame.
#[derive(Debug, Clone)]
pub struct WsMessage {
    /// Whether the frame was sent or received.
    pub direction: WsDirection,
    /// URL recorded for the socket when the frame was captured (may be empty).
    pub url: String,
    /// Socket key of the form `"{frameId}---{wsid}"`.
    pub socket_id: String,
    /// `frameId` field of the originating event.
    pub frame_id: String,
    /// `wsid` field of the originating event.
    pub wsid: String,
    /// WebSocket opcode; `-1` when the event carried none.
    pub opcode: i64,
    /// Payload as delivered: used verbatim for text frames (opcode 1) and
    /// treated as base64 for every other opcode.
    pub data: String,
}

impl WsMessage {
    /// `true` for text frames (opcode 1).
    pub fn is_text(&self) -> bool {
        self.opcode == 1
    }

    /// `true` for binary frames (opcode 2).
    pub fn is_binary(&self) -> bool {
        self.opcode == 2
    }

    /// `true` for close / ping / pong frames (opcodes 8 through 10).
    pub fn is_control(&self) -> bool {
        (8..=10).contains(&self.opcode)
    }

    /// Human-readable opcode name; unknown opcodes render as `opcode(<n>)`.
    pub fn opcode_name(&self) -> String {
        let known = match self.opcode {
            0 => "continuation",
            1 => "text",
            2 => "binary",
            8 => "close",
            9 => "ping",
            10 => "pong",
            other => return format!("opcode({other})"),
        };
        known.to_owned()
    }

    /// Returns a copy of the payload for text frames, `None` for anything else.
    pub fn text(&self) -> Option<String> {
        if self.is_text() {
            Some(self.data.clone())
        } else {
            None
        }
    }

    /// Returns the payload bytes.
    ///
    /// Text frames yield the UTF-8 bytes of `data`; all other frames are
    /// base64-decoded, falling back to an empty vector when decoding fails.
    pub fn bytes(&self) -> Vec<u8> {
        if !self.is_text() {
            return base64_decode(&self.data).unwrap_or_default();
        }
        self.data.as_bytes().to_vec()
    }

    /// Returns the payload as a string, decoding non-text frames and replacing
    /// invalid UTF-8 sequences.
    pub fn text_lossy(&self) -> String {
        match self.text() {
            Some(text) => text,
            None => String::from_utf8_lossy(&self.bytes()).into_owned(),
        }
    }

    /// Parses the payload as JSON, returning `None` when it is not valid JSON.
    ///
    /// Text frames are parsed directly; other frames are decoded via
    /// [`WsMessage::bytes`] first.
    pub fn json(&self) -> Option<Value> {
        let parsed: serde_json::Result<Value> = if self.is_text() {
            serde_json::from_str(&self.data)
        } else {
            serde_json::from_slice(&self.bytes())
        };
        parsed.ok()
    }
}
/// Lifecycle snapshot of one WebSocket, maintained by `ws_loop` from the
/// `Page.webSocketCreated` / `Opened` / `Closed` events.
#[derive(Debug, Clone, Default)]
pub struct WsSocket {
/// Key of the form `"{frameId}---{wsid}"`, as produced by `socket_id`.
pub socket_id: String,
/// `requestURL` from the created event, replaced by `effectiveURL` from the
/// opened event when that one is non-empty.
pub url: String,
/// Set once a `Page.webSocketOpened` event has been seen for this socket.
pub opened: bool,
/// Set once a `Page.webSocketClosed` event has been seen for this socket.
pub closed: bool,
/// `error` string from the closed event; empty when the event carried none.
pub error: String,
}
/// Criteria deciding which frames are kept by the listener.
///
/// The default filter keeps data frames in both directions for every URL and
/// drops control frames.
#[derive(Debug, Clone, Default)]
pub struct WsFilter {
    /// Substrings to look for in the socket URL; empty means "any URL".
    pub url_keywords: Vec<String>,
    /// Restrict capture to one direction; `None` keeps both.
    pub direction: Option<WsDirection>,
    /// Keep frames with opcodes 8 through 10 (close / ping / pong) as well.
    pub include_control: bool,
}

impl WsFilter {
    /// Creates the default filter.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a URL keyword; a frame passes when its socket URL contains at
    /// least one of the registered keywords.
    pub fn url_contains(mut self, needle: &str) -> Self {
        self.url_keywords.push(needle.to_owned());
        self
    }

    /// Keeps only frames sent by the page.
    pub fn sent_only(self) -> Self {
        Self {
            direction: Some(WsDirection::Sent),
            ..self
        }
    }

    /// Keeps only frames received by the page.
    pub fn received_only(self) -> Self {
        Self {
            direction: Some(WsDirection::Received),
            ..self
        }
    }

    /// Also keeps control frames.
    pub fn with_control(self) -> Self {
        Self {
            include_control: true,
            ..self
        }
    }

    /// `true` when no keyword is configured or `url` contains any keyword.
    fn url_matches(&self, url: &str) -> bool {
        let keywords = &self.url_keywords;
        keywords.is_empty() || keywords.iter().any(|k| url.contains(k.as_str()))
    }

    /// Decides whether a frame with the given direction, opcode and socket URL
    /// passes this filter.
    fn matches(&self, direction: WsDirection, opcode: i64, url: &str) -> bool {
        let direction_ok = self.direction.map_or(true, |wanted| wanted == direction);
        let is_control = (8..=10).contains(&opcode);
        direction_ok && (self.include_control || !is_control) && self.url_matches(url)
    }
}
/// Capture state shared between listener handles and the background event task.
pub(crate) struct WsShared {
    /// Captured frames, oldest first.
    pub buf: Mutex<VecDeque<WsMessage>>,
    /// Known sockets keyed by their socket id.
    pub sockets: Mutex<HashMap<String, WsSocket>>,
    /// Whether capturing is currently switched on.
    pub active: AtomicBool,
}

impl WsShared {
    /// Creates empty, inactive capture state.
    pub(crate) fn new() -> Self {
        Self {
            buf: Mutex::default(),
            sockets: Mutex::default(),
            active: AtomicBool::new(false),
        }
    }
}
pub struct WsListener {
tab: Tab,
}
impl WsListener {
pub(crate) fn new(tab: Tab) -> Self {
Self { tab }
}
pub async fn start(&self) -> Result<()> {
self.start_with(WsFilter::default()).await
}
pub async fn start_with(&self, filter: WsFilter) -> Result<()> {
let shared = self.tab.core.ws.clone();
if shared.active.swap(true, Ordering::SeqCst) {
return Ok(()); }
shared.buf.lock().await.clear();
shared.sockets.lock().await.clear();
let events = self.tab.core.conn.subscribe();
let session = self.tab.core.session_id.clone();
let task = tokio::spawn(ws_loop(events, session, shared, filter));
*self.tab.core.ws_task.lock().await = Some(task);
Ok(())
}
pub fn listening(&self) -> bool {
self.tab.core.ws.active.load(Ordering::SeqCst)
}
pub async fn wait(&self, timeout: Option<Duration>) -> Result<Option<WsMessage>> {
let shared = &self.tab.core.ws;
if !shared.active.load(Ordering::SeqCst) {
return Err(Error::Other("尚未调用 websocket().start()".into()));
}
let deadline = timeout.map(|d| Instant::now() + d);
loop {
if let Some(m) = shared.buf.lock().await.pop_front() {
return Ok(Some(m));
}
if !shared.active.load(Ordering::SeqCst) {
return Ok(None); }
if let Some(dl) = deadline
&& Instant::now() >= dl
{
return Ok(None);
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
}
pub async fn wait_count(
&self,
count: usize,
timeout: Option<Duration>,
) -> Result<Vec<WsMessage>> {
let total = timeout.unwrap_or_else(|| self.tab.core.timeout());
let deadline = Instant::now() + total;
let mut out = Vec::with_capacity(count);
while out.len() < count {
let remain = deadline.saturating_duration_since(Instant::now());
if remain.is_zero() {
break;
}
match self.wait(Some(remain)).await? {
Some(m) => out.push(m),
None => break,
}
}
Ok(out)
}
pub async fn messages(&self) -> Vec<WsMessage> {
self.tab.core.ws.buf.lock().await.drain(..).collect()
}
pub async fn sockets(&self) -> Vec<WsSocket> {
self.tab
.core
.ws
.sockets
.lock()
.await
.values()
.cloned()
.collect()
}
pub async fn clear(&self) {
self.tab.core.ws.buf.lock().await.clear();
}
pub fn steps(&self) -> WsSteps {
WsSteps {
tab: self.tab.clone(),
}
}
pub async fn stop(&self) -> Result<()> {
self.tab.core.ws.active.store(false, Ordering::SeqCst);
if let Some(h) = self.tab.core.ws_task.lock().await.take() {
h.abort();
}
self.tab.core.ws.buf.lock().await.clear();
self.tab.core.ws.sockets.lock().await.clear();
Ok(())
}
}
/// Step-wise reader over the frames captured for a tab.
pub struct WsSteps {
    tab: Tab,
}

impl WsSteps {
    /// Waits for the next captured frame by delegating to
    /// [`WsListener::wait`] on a listener built for the same tab.
    pub async fn next(&self, timeout: Option<Duration>) -> Result<Option<WsMessage>> {
        let listener = WsListener::new(self.tab.clone());
        listener.wait(timeout).await
    }
}
async fn ws_loop(
mut events: tokio::sync::broadcast::Receiver<Event>,
session: String,
shared: Arc<WsShared>,
filter: WsFilter,
) {
loop {
let ev = match events.recv().await {
Ok(ev) => ev,
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!(skipped = n, "WebSocket 监听落后,跳过部分事件");
continue;
}
Err(_) => break,
};
if !shared.active.load(Ordering::SeqCst) {
break;
}
if ev.session_id.as_deref() != Some(&session) {
continue;
}
match ev.method.as_str() {
"Page.webSocketCreated" => {
let id = socket_id(&ev.params);
let url = ev.params["requestURL"]
.as_str()
.unwrap_or_default()
.to_string();
let mut socks = shared.sockets.lock().await;
let s = socks.entry(id.clone()).or_default();
s.socket_id = id;
if !url.is_empty() {
s.url = url;
}
}
"Page.webSocketOpened" => {
let id = socket_id(&ev.params);
let url = ev.params["effectiveURL"]
.as_str()
.unwrap_or_default()
.to_string();
let mut socks = shared.sockets.lock().await;
let s = socks.entry(id.clone()).or_default();
s.socket_id = id;
s.opened = true;
if !url.is_empty() {
s.url = url;
}
}
"Page.webSocketClosed" => {
let id = socket_id(&ev.params);
let err = ev.params["error"].as_str().unwrap_or_default().to_string();
let mut socks = shared.sockets.lock().await;
let s = socks.entry(id.clone()).or_default();
s.socket_id = id;
s.closed = true;
s.error = err;
}
"Page.webSocketFrameSent" => {
push_frame(&shared, &filter, WsDirection::Sent, &ev.params).await;
}
"Page.webSocketFrameReceived" => {
push_frame(&shared, &filter, WsDirection::Received, &ev.params).await;
}
_ => {}
}
}
tracing::debug!(%session, "WebSocket 监听任务结束");
}
/// Builds a [`WsMessage`] from a frame event and appends it to the shared
/// buffer if it passes `filter`.
///
/// The socket URL is taken from the socket table (empty when the socket is
/// unknown). A missing opcode is recorded as `-1`. When the buffer already
/// holds `MAX_BUFFERED` frames, the oldest one is dropped first.
async fn push_frame(
    shared: &Arc<WsShared>,
    filter: &WsFilter,
    direction: WsDirection,
    params: &Value,
) {
    let key = socket_id(params);
    let opcode = params["opcode"].as_i64().unwrap_or(-1);

    // Scope the socket-table lock to the lookup only.
    let url = {
        let table = shared.sockets.lock().await;
        match table.get(&key) {
            Some(sock) => sock.url.clone(),
            None => String::new(),
        }
    };

    if !filter.matches(direction, opcode, &url) {
        return;
    }

    let frame_id = params["frameId"].as_str().unwrap_or_default().to_string();
    let wsid = params["wsid"].as_str().unwrap_or_default().to_string();
    let data = params["data"].as_str().unwrap_or_default().to_string();
    let msg = WsMessage {
        direction,
        url,
        socket_id: key,
        frame_id,
        wsid,
        opcode,
        data,
    };

    let mut queue = shared.buf.lock().await;
    if queue.len() >= MAX_BUFFERED {
        queue.pop_front();
    }
    queue.push_back(msg);
}
fn socket_id(params: &Value) -> String {
let frame = params["frameId"].as_str().unwrap_or_default();
let wsid = params["wsid"].as_str().unwrap_or_default();
format!("{frame}---{wsid}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a frame on socket `f---{wsid}` with the given opcode and payload.
    fn frame(
        direction: WsDirection,
        url: &str,
        wsid: &str,
        opcode: i64,
        data: String,
    ) -> WsMessage {
        WsMessage {
            direction,
            url: url.to_string(),
            socket_id: format!("f---{wsid}"),
            frame_id: "f".to_string(),
            wsid: wsid.to_string(),
            opcode,
            data,
        }
    }

    #[test]
    fn text_frame_helpers() {
        let payload = r#"{"a":1,"b":[2,3]}"#;
        let m = frame(
            WsDirection::Received,
            "wss://x/path",
            "1",
            1,
            payload.to_string(),
        );

        assert!(m.is_text());
        assert!(!m.is_binary());
        assert!(!m.is_control());
        assert_eq!(m.opcode_name(), "text");
        assert_eq!(m.text().as_deref(), Some(payload));
        assert_eq!(m.bytes(), payload.as_bytes().to_vec());

        let j = m.json().unwrap();
        assert_eq!(j["b"][1], 3);
    }

    #[test]
    fn binary_frame_is_base64() {
        let m = frame(WsDirection::Sent, "", "2", 2, "AQIDBA==".to_string());

        assert!(m.is_binary());
        assert!(m.text().is_none());
        assert_eq!(m.bytes(), vec![1, 2, 3, 4]);
        assert_eq!(m.opcode_name(), "binary");
    }

    #[test]
    fn binary_json_decodes_then_parses() {
        let encoded = crate::util::base64_encode(br#"{"k":42}"#);
        let m = frame(WsDirection::Received, "", "3", 2, encoded);

        assert_eq!(m.json().unwrap()["k"], 42);
        assert_eq!(m.text_lossy(), r#"{"k":42}"#);
    }

    #[test]
    fn filter_direction_and_control_and_url() {
        // (filter, direction, opcode, url, expected verdict)
        let cases = [
            (WsFilter::default(), WsDirection::Sent, 1, "wss://a", true),
            (WsFilter::default(), WsDirection::Received, 2, "wss://a", true),
            (WsFilter::default(), WsDirection::Sent, 9, "wss://a", false),
            (WsFilter::new().with_control(), WsDirection::Sent, 9, "wss://a", true),
            (WsFilter::new().sent_only(), WsDirection::Sent, 1, "x", true),
            (WsFilter::new().sent_only(), WsDirection::Received, 1, "x", false),
            (
                WsFilter::new().url_contains("/live/"),
                WsDirection::Sent,
                1,
                "wss://h/live/room",
                true,
            ),
            (
                WsFilter::new().url_contains("/live/"),
                WsDirection::Sent,
                1,
                "wss://h/other",
                false,
            ),
        ];

        for (filter, direction, opcode, url, expected) in cases {
            assert_eq!(
                filter.matches(direction, opcode, url),
                expected,
                "filter={filter:?} direction={direction:?} opcode={opcode} url={url}"
            );
        }
    }

    #[test]
    fn socket_id_combines_frame_and_wsid() {
        let params = json!({"frameId":"abc","wsid":"7"});
        assert_eq!(socket_id(&params), "abc---7");
    }
}