use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use rustc_hash::FxHashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use tokio::sync::{broadcast, oneshot};
use crate::backend::json_scan;
use crate::error::{FerriError, Result};
/// Shorten `s` to at most `max` bytes for log output, appending `...` when cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary at or
/// below `max`, so multi-byte characters (non-ASCII URLs, typed text, ...)
/// can never cause a slicing panic. Strings that already fit are returned as-is.
fn truncate_for_log(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Here `max < s.len()`, so `max` is a valid index; index 0 is always a
    // char boundary, which guarantees the loop terminates.
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
/// Outcome delivered to a waiting caller: the CDP `result` value, or an error.
type CdpResult = Result<serde_json::Value>;
/// Book-keeping for one in-flight CDP command, stored in [`PendingMap`] under its command id.
pub(crate) struct PendingEntry {
/// Resolves the receiver handed back by `CdpDispatcher::build_command` when the
/// response arrives (or when `fail_all_pending` fails it).
tx: oneshot::Sender<CdpResult>,
/// CDP method name; only filled in when RTT stats are enabled (empty otherwise).
method: String,
/// When the command frame was built; `Some` only when RTT stats are enabled.
send_at: Option<Instant>,
}
/// Concurrent map from CDP command id to its pending entry.
pub(crate) type PendingMap = DashMap<u64, PendingEntry>;
/// Per-method round-trip accumulator used by [`RttStats`].
#[derive(Default)]
struct RttBucket {
/// Number of round trips recorded.
count: u64,
/// Sum of all recorded round-trip times, in nanoseconds.
total_ns: u128,
/// Largest single recorded round-trip time, in nanoseconds.
max_ns: u128,
}
/// Render a nanosecond duration as milliseconds with two truncated decimals,
/// e.g. `1_234_567_890` becomes `"1234.56"`.
fn fmt_ms2(ns: u128) -> String {
    const NS_PER_MS: u128 = 1_000_000;
    const NS_PER_HUNDREDTH_MS: u128 = 10_000;
    let (ms, dec) = (ns / NS_PER_MS, (ns % NS_PER_MS) / NS_PER_HUNDREDTH_MS);
    format!("{ms}.{dec:02}")
}
/// Render a nanosecond duration as microseconds with one truncated decimal,
/// e.g. `1_234_567` becomes `"1234.5"`.
fn fmt_us1(ns: u128) -> String {
    let (us, dec) = (ns / 1_000, (ns % 1_000) / 100);
    format!("{us}.{dec}")
}
/// Render the mean of `total_ns` over `count` samples as microseconds with one
/// truncated decimal; yields `"0.0"` when there are no samples.
fn fmt_avg_us1(total_ns: u128, count: u64) -> String {
    match count {
        0 => "0.0".to_string(),
        samples => {
            // Work in tenths of a microsecond so one decimal survives integer division.
            let avg_tenths = (total_ns * 10 / 1_000) / u128::from(samples);
            let (us, dec) = (avg_tenths / 10, avg_tenths % 10);
            format!("{us}.{dec}")
        },
    }
}
#[derive(Default)]
pub(crate) struct RttStats {
buckets: FxHashMap<String, RttBucket>,
}
impl RttStats {
fn record(&mut self, method: &str, elapsed_ns: u128) {
let entry = self.buckets.entry(method.to_string()).or_default();
entry.count += 1;
entry.total_ns += elapsed_ns;
if elapsed_ns > entry.max_ns {
entry.max_ns = elapsed_ns;
}
}
fn merge(&mut self, other: &RttStats) {
for (method, b) in &other.buckets {
let entry = self.buckets.entry(method.clone()).or_default();
entry.count += b.count;
entry.total_ns += b.total_ns;
if b.max_ns > entry.max_ns {
entry.max_ns = b.max_ns;
}
}
}
fn dump(&self) {
if self.buckets.is_empty() {
return;
}
let mut rows: Vec<(&String, &RttBucket)> = self.buckets.iter().collect();
rows.sort_by_key(|r| std::cmp::Reverse(r.1.total_ns));
let total_count: u64 = self.buckets.values().map(|b| b.count).sum();
let total_ns: u128 = self.buckets.values().map(|b| b.total_ns).sum();
eprintln!(
"─── ferridriver CDP RTT stats ─── total_calls={total_count} total_time={}ms",
fmt_ms2(total_ns)
);
eprintln!(
" {:<48} {:>7} {:>10} {:>10} {:>10}",
"method", "count", "total_ms", "avg_us", "max_us"
);
for (method, bucket) in rows {
eprintln!(
" {:<48} {:>7} {:>10} {:>10} {:>10}",
method,
bucket.count,
fmt_ms2(bucket.total_ns),
fmt_avg_us1(bucket.total_ns, bucket.count),
fmt_us1(bucket.max_ns),
);
}
}
}
fn rtt_stats_enabled() -> bool {
static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
*ENABLED.get_or_init(|| {
let on = std::env::var("FERRIDRIVER_RTT_STATS").is_ok_and(|v| matches!(v.as_str(), "1" | "true" | "yes" | "on"));
if on {
#[allow(unsafe_code)]
unsafe {
libc::atexit(rtt_atexit_dump);
}
}
on
})
}
/// Process-wide RTT table, created on first access. Responses are recorded into
/// it as they arrive and each dispatcher merges its own table into it on drop.
fn global_rtt_stats() -> &'static std::sync::Mutex<RttStats> {
    static GLOBAL: std::sync::LazyLock<std::sync::Mutex<RttStats>> =
        std::sync::LazyLock::new(|| std::sync::Mutex::new(RttStats::default()));
    std::sync::LazyLock::force(&GLOBAL)
}
extern "C" fn rtt_atexit_dump() {
if let Ok(stats) = global_rtt_stats().lock() {
if !stats.buckets.is_empty() {
stats.dump();
}
}
}
pub fn dump_global_rtt_stats() {
if !rtt_stats_enabled() {
return;
}
if let Ok(stats) = global_rtt_stats().lock() {
if !stats.buckets.is_empty() {
stats.dump();
}
}
}
/// Abstraction over a CDP connection: command round-trips, event
/// subscriptions, and per-session lifecycle tracking.
pub trait CdpTransport: Send + Sync + 'static {
/// Send `method` with `params`, optionally targeted at `session_id`, resolving
/// to the command's result value or an error.
fn send_command(
&self,
session_id: Option<&str>,
method: &str,
params: &serde_json::Value,
) -> impl std::future::Future<Output = Result<serde_json::Value>> + Send;
/// Subscribe to all incoming CDP events.
fn subscribe_events(&self) -> broadcast::Receiver<Arc<serde_json::Value>>;
/// Subscribe to events for a single CDP method name.
/// NOTE(review): presumably mirrors `CdpDispatcher::subscribe_event_method`
/// (exact method match); confirm for each implementation.
fn subscribe_event_method(&self, method: &'static str) -> broadcast::Receiver<Arc<serde_json::Value>>;
/// Subscribe to events for a single CDP domain.
/// NOTE(review): presumably mirrors `CdpDispatcher::subscribe_event_domain`
/// (text before the first `.` of the method); confirm for each implementation.
fn subscribe_event_domain(&self, domain: &'static str) -> broadcast::Receiver<Arc<serde_json::Value>>;
/// Register the lifecycle `state` and `notify` handles to be updated by
/// events belonging to `session_id`.
fn register_lifecycle_tracker(
&self,
session_id: &str,
state: Arc<std::sync::Mutex<super::LifecycleState>>,
notify: Arc<tokio::sync::Notify>,
);
}
/// Shared lifecycle handles for one session. `CdpDispatcher` updates them when
/// navigation, lifecycle, or crash events arrive for that session.
pub(crate) struct LifecycleTracker {
/// Navigation state mutated by the dispatcher (loader id, fired flags,
/// navigation sequence counters, crash flag).
pub state: Arc<std::sync::Mutex<super::LifecycleState>>,
/// Woken with `notify_waiters` after the dispatcher updates `state`.
pub notify: Arc<tokio::sync::Notify>,
}
/// Core CDP message router. It assigns command ids, tracks in-flight commands,
/// resolves responses, updates lifecycle trackers, and fans events out to
/// broadcast subscribers.
pub(crate) struct CdpDispatcher {
/// Next command id to hand out. Starts at 1 because `dispatch_message` treats
/// an id of 0 as "not a response".
pub next_id: AtomicU64,
/// In-flight commands awaiting a response, keyed by command id.
pub pending: Arc<PendingMap>,
/// Lifecycle trackers keyed by session id.
lifecycle_trackers: Arc<DashMap<String, LifecycleTracker>>,
/// Channel receiving every event (only serialized when it has receivers).
pub event_tx: broadcast::Sender<Arc<serde_json::Value>>,
/// Per-method event channels, created on first subscription.
method_event_txs: Arc<DashMap<&'static str, broadcast::Sender<Arc<serde_json::Value>>>>,
/// Per-domain event channels (method text before the first `.`), created on first subscription.
domain_event_txs: Arc<DashMap<&'static str, broadcast::Sender<Arc<serde_json::Value>>>>,
/// RTT samples recorded by this dispatcher; merged into the global table on drop.
rtt_stats: Arc<std::sync::Mutex<RttStats>>,
}
/// Lock `m`, taking the guard even if a previous holder panicked.
///
/// Callers in this module treat poisoning as non-fatal and keep using the
/// protected value rather than propagating the poison.
fn lock_or_recover<T>(m: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    match m.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Capacity of every broadcast channel the dispatcher creates: the global event
/// channel and each per-method / per-domain channel created on subscription.
const EVENT_BROADCAST_CAPACITY: usize = 4096;
impl CdpDispatcher {
pub fn new() -> Self {
let (event_tx, _) = broadcast::channel(EVENT_BROADCAST_CAPACITY);
Self {
next_id: AtomicU64::new(1),
pending: Arc::new(DashMap::default()),
lifecycle_trackers: Arc::new(DashMap::default()),
event_tx,
method_event_txs: Arc::new(DashMap::default()),
domain_event_txs: Arc::new(DashMap::default()),
rtt_stats: Arc::new(std::sync::Mutex::new(RttStats::default())),
}
}
pub fn register_lifecycle_tracker(
&self,
session_id: &str,
state: Arc<std::sync::Mutex<super::LifecycleState>>,
notify: Arc<tokio::sync::Notify>,
) {
self
.lifecycle_trackers
.insert(session_id.to_string(), LifecycleTracker { state, notify });
}
pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<serde_json::Value>> {
self.event_tx.subscribe()
}
pub fn subscribe_event_method(&self, method: &'static str) -> broadcast::Receiver<Arc<serde_json::Value>> {
match self.method_event_txs.entry(method) {
Entry::Occupied(entry) => entry.get().subscribe(),
Entry::Vacant(entry) => {
let (tx, rx) = broadcast::channel(EVENT_BROADCAST_CAPACITY);
entry.insert(tx);
rx
},
}
}
pub fn subscribe_event_domain(&self, domain: &'static str) -> broadcast::Receiver<Arc<serde_json::Value>> {
match self.domain_event_txs.entry(domain) {
Entry::Occupied(entry) => entry.get().subscribe(),
Entry::Vacant(entry) => {
let (tx, rx) = broadcast::channel(EVENT_BROADCAST_CAPACITY);
entry.insert(tx);
rx
},
}
}
pub fn fail_all_pending(&self, reason: &str) {
let keys: Vec<u64> = self.pending.iter().map(|e| *e.key()).collect();
for id in keys {
if let Some((_, entry)) = self.pending.remove(&id) {
let _ = entry.tx.send(Err(FerriError::target_closed(Some(reason.to_string()))));
}
}
}
pub fn build_command(
&self,
session_id: Option<&str>,
method: &str,
params: &serde_json::Value,
) -> Result<(Vec<u8>, oneshot::Receiver<CdpResult>)> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let params_str = serde_json::to_string(params).map_err(|e| FerriError::Backend(format!("Serialize: {e}")))?;
let mut data = if let Some(sid) = session_id {
format!(r#"{{"id":{id},"method":"{method}","params":{params_str},"sessionId":"{sid}"}}"#).into_bytes()
} else {
format!(r#"{{"id":{id},"method":"{method}","params":{params_str}}}"#).into_bytes()
};
data.push(0);
tracing::debug!(
target: "ferridriver::cdp::send",
id,
method,
params = truncate_for_log(¶ms_str, 200),
"CDP >>",
);
let (tx, rx) = oneshot::channel();
let stats_enabled = rtt_stats_enabled();
let entry = PendingEntry {
tx,
method: if stats_enabled {
method.to_string()
} else {
String::new()
},
send_at: stats_enabled.then(Instant::now),
};
self.pending.insert(id, entry);
Ok((data, rx))
}
pub fn dispatch_message(&self, raw: &[u8]) {
let id = json_scan::json_id(raw);
if id > 0 {
let error_field = json_scan::json_field(raw, b"error");
let payload = if error_field.is_empty() {
let result_field = json_scan::json_field(raw, b"result");
if result_field.is_empty() {
Ok(serde_json::Value::Object(serde_json::Map::new()))
} else {
let val: serde_json::Value =
serde_json::from_slice(result_field).unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
Ok(val)
}
} else {
let msg_bytes = json_scan::error_message(error_field);
let msg_str = std::str::from_utf8(msg_bytes).unwrap_or("CDP error");
Err(FerriError::protocol("CDP", msg_str))
};
tracing::debug!(
target: "ferridriver::cdp::recv",
id,
ok = payload.is_ok(),
payload = truncate_for_log(&format!("{payload:?}"), 200),
"CDP << response",
);
if let Some((_, entry)) = self.pending.remove(&id) {
if rtt_stats_enabled()
&& let Some(send_at) = entry.send_at
{
let elapsed = send_at.elapsed().as_nanos();
lock_or_recover(&self.rtt_stats).record(&entry.method, elapsed);
lock_or_recover(global_rtt_stats()).record(&entry.method, elapsed);
}
let _ = entry.tx.send(payload);
}
} else {
let method = json_scan::json_string(json_scan::json_field(raw, b"method"));
let session_id = json_scan::json_string(json_scan::json_field(raw, b"sessionId"));
let method_str = std::str::from_utf8(method).unwrap_or("");
let sid_str = std::str::from_utf8(session_id).unwrap_or("");
self.dispatch_lifecycle(raw, method_str, sid_str);
tracing::trace!(
target: "ferridriver::cdp::recv",
method = method_str,
"CDP << event",
);
let method_tx = self.method_event_txs.get(method_str).map(|entry| entry.clone());
let domain_tx = method_str
.split_once('.')
.and_then(|(domain, _)| self.domain_event_txs.get(domain).map(|entry| entry.clone()));
let needs_global = self.event_tx.receiver_count() > 0;
let needs_method = method_tx.as_ref().is_some_and(|tx| tx.receiver_count() > 0);
let needs_domain = domain_tx.as_ref().is_some_and(|tx| tx.receiver_count() > 0);
if (needs_global || needs_method || needs_domain)
&& let Ok(msg) = serde_json::from_slice::<serde_json::Value>(raw)
{
let msg = Arc::new(msg);
if needs_global {
let _ = self.event_tx.send(msg.clone());
}
if needs_method && let Some(tx) = method_tx {
let _ = tx.send(msg.clone());
}
if needs_domain && let Some(tx) = domain_tx {
let _ = tx.send(msg);
}
}
}
}
fn dispatch_lifecycle(&self, raw: &[u8], method_str: &str, key: &str) {
if let Some(tracker) = self.lifecycle_trackers.get(key) {
match method_str {
"Page.frameNavigated" => {
let params = json_scan::json_field(raw, b"params");
let frame = json_scan::json_field(params, b"frame");
let parent_id = json_scan::json_field(frame, b"parentId");
if !parent_id.is_empty() {
return;
}
let loader_id = json_scan::json_string(json_scan::json_field(frame, b"loaderId"));
let loader_id_owned = std::str::from_utf8(loader_id).unwrap_or("").to_string();
{
let mut state = lock_or_recover(&tracker.state);
state.current_loader_id = loader_id_owned;
state.fired = super::LC_COMMIT;
state.nav_committed_seq = state.nav_committed_seq.wrapping_add(1);
}
tracker.notify.notify_waiters();
},
"Page.frameStartedNavigating" => {
{
let mut state = lock_or_recover(&tracker.state);
state.nav_started_seq = state.nav_started_seq.wrapping_add(1);
}
tracker.notify.notify_waiters();
},
"Page.lifecycleEvent" => {
let params = json_scan::json_field(raw, b"params");
let loader_id = json_scan::json_string(json_scan::json_field(params, b"loaderId"));
let loader_id_str = std::str::from_utf8(loader_id).unwrap_or("");
let name = json_scan::json_string(json_scan::json_field(params, b"name"));
let name_str = std::str::from_utf8(name).unwrap_or("");
let event_name = match name_str {
"DOMContentLoaded" => Some(super::LC_DOMCONTENTLOADED),
"load" => Some(super::LC_LOAD),
_ => None,
};
if let Some(event_flag) = event_name {
let matched = {
let mut state = lock_or_recover(&tracker.state);
if state.current_loader_id == loader_id_str {
state.fired |= event_flag;
true
} else {
false
}
};
if matched {
tracker.notify.notify_waiters();
}
}
},
"Inspector.targetCrashed" => {
let mut state = lock_or_recover(&tracker.state);
state.crashed = true;
drop(state);
tracker.notify.notify_waiters();
},
_ => {},
}
}
}
}
impl Drop for CdpDispatcher {
    /// When RTT stats are enabled, fold this dispatcher's samples into the
    /// process-wide table, then print this dispatcher's own table to stderr.
    fn drop(&mut self) {
        if !rtt_stats_enabled() {
            return;
        }
        let local_stats = lock_or_recover(&self.rtt_stats);
        // Scope the global guard so it is released before printing.
        {
            let mut global_stats = lock_or_recover(global_rtt_stats());
            global_stats.merge(&local_stats);
        }
        local_stats.dump();
    }
}
// Unit tests for command framing and lifecycle tracking, plus an ignored
// benchmark comparing global vs. routed event fan-out.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
/// Sample `Network.requestWillBeSent` event for session `s1`, used by the routing benchmark.
const NETWORK_EVENT: &[u8] = br#"{"method":"Network.requestWillBeSent","sessionId":"s1","params":{"requestId":"r1","request":{"url":"https://example.test/asset.js","method":"GET"},"type":"Script"}}"#;
/// Benchmark: time EVENTS dispatches fanned out to LISTENERS global subscribers,
/// then the same EVENTS routed to a single `Network` domain subscriber while
/// unrelated method/domain subscribers sit idle. Asserts exact wakeup counts and
/// that the idle subscribers receive nothing.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "benchmark; run with --ignored --nocapture"]
async fn bench_routed_event_dispatch_wakeups() {
const EVENTS: usize = 2_000;
const LISTENERS: usize = 8;
// Phase 1: every listener subscribes to the global event channel.
let global = CdpDispatcher::new();
let global_wakeups = Arc::new(AtomicUsize::new(0));
let mut global_handles = Vec::with_capacity(LISTENERS);
for _ in 0..LISTENERS {
// Subscribe before spawning so no event can be missed.
let mut rx = global.subscribe_events();
let wakeups = global_wakeups.clone();
global_handles.push(tokio::spawn(async move {
for _ in 0..EVENTS {
let event = rx.recv().await.expect("global event");
let _ = event.get("method").and_then(|m| m.as_str());
wakeups.fetch_add(1, Ordering::Relaxed);
}
}));
}
// Receivers already exist and EVENTS (2_000) is below the 4_096 channel
// capacity, so listeners that start late still see every event.
tokio::task::yield_now().await;
let global_started = Instant::now();
for _ in 0..EVENTS {
global.dispatch_message(NETWORK_EVENT);
}
for handle in global_handles {
handle.await.expect("global listener task");
}
let global_elapsed = global_started.elapsed();
// Phase 2: idle subscribers on unrelated methods/domains plus one
// `Network` domain listener that should receive every event.
let routed = CdpDispatcher::new();
let mut idle_method_receivers = [
routed.subscribe_event_method("Runtime.consoleAPICalled"),
routed.subscribe_event_method("Runtime.exceptionThrown"),
routed.subscribe_event_method("Page.javascriptDialogOpening"),
routed.subscribe_event_method("Page.fileChooserOpened"),
routed.subscribe_event_method("Runtime.bindingCalled"),
routed.subscribe_event_method("Page.screencastFrame"),
];
let mut idle_domain_receivers = [
routed.subscribe_event_domain("Browser"),
routed.subscribe_event_domain("Fetch"),
];
let routed_wakeups = Arc::new(AtomicUsize::new(0));
let mut network_rx = routed.subscribe_event_domain("Network");
let wakeups = routed_wakeups.clone();
let network_handle = tokio::spawn(async move {
for _ in 0..EVENTS {
let event = network_rx.recv().await.expect("routed network event");
let _ = event.get("method").and_then(|m| m.as_str());
wakeups.fetch_add(1, Ordering::Relaxed);
}
});
tokio::task::yield_now().await;
let routed_started = Instant::now();
for _ in 0..EVENTS {
routed.dispatch_message(NETWORK_EVENT);
}
network_handle.await.expect("network listener task");
let routed_elapsed = routed_started.elapsed();
// Any message on an idle receiver would indicate mis-routing.
let idle_method_wakeups: usize = idle_method_receivers
.iter_mut()
.map(|rx| rx.try_recv().ok().map_or(0, |_| 1))
.sum();
let idle_domain_wakeups: usize = idle_domain_receivers
.iter_mut()
.map(|rx| rx.try_recv().ok().map_or(0, |_| 1))
.sum();
println!(
"global broadcast: {:?}, wakeups={}",
global_elapsed,
global_wakeups.load(Ordering::Relaxed)
);
println!(
"routed dispatch: {:?}, wakeups={}, idle_wakeups={}",
routed_elapsed,
routed_wakeups.load(Ordering::Relaxed),
idle_method_wakeups + idle_domain_wakeups
);
assert_eq!(global_wakeups.load(Ordering::Relaxed), EVENTS * LISTENERS);
assert_eq!(routed_wakeups.load(Ordering::Relaxed), EVENTS);
assert_eq!(idle_method_wakeups + idle_domain_wakeups, 0);
}
/// Frames built from the shared empty-params static contain `"params":{}`,
/// and include `sessionId` only when a session is given.
#[test]
fn build_command_serializes_borrowed_empty_params_by_reference() {
let dispatcher = CdpDispatcher::new();
let empty = &crate::backend::EMPTY_PARAMS;
let (data1, _rx1) = dispatcher
.build_command(Some("sess-1"), "Page.enable", empty)
.expect("build with session");
let (data2, _rx2) = dispatcher
.build_command(None, "Runtime.enable", empty)
.expect("build without session");
let s1 = String::from_utf8(data1).expect("utf8");
let s2 = String::from_utf8(data2).expect("utf8");
assert!(s1.contains(r#""params":{}"#), "expected empty params object, got {s1}");
assert!(s1.contains(r#""method":"Page.enable""#), "method missing: {s1}");
assert!(s1.contains(r#""sessionId":"sess-1""#), "sessionId missing: {s1}");
assert!(s2.contains(r#""params":{}"#), "expected empty params object, got {s2}");
assert!(!s2.contains("sessionId"), "no sessionId expected: {s2}");
let a = std::ptr::from_ref::<serde_json::Value>(&crate::backend::EMPTY_PARAMS);
let b = std::ptr::from_ref::<serde_json::Value>(&crate::backend::EMPTY_PARAMS);
assert_eq!(a, b, "EMPTY_PARAMS must be a single shared static");
}
/// Register a fresh, zeroed lifecycle tracker for `session_id` and return its
/// state handle so tests can inspect what the dispatcher writes.
fn register_test_tracker(
dispatcher: &CdpDispatcher,
session_id: &str,
) -> Arc<std::sync::Mutex<super::super::LifecycleState>> {
let state = Arc::new(std::sync::Mutex::new(super::super::LifecycleState {
current_loader_id: String::new(),
nav_started_seq: 0,
nav_committed_seq: 0,
fired: 0,
crashed: false,
}));
let notify = Arc::new(tokio::sync::Notify::new());
dispatcher.register_lifecycle_tracker(session_id, state.clone(), notify);
state
}
/// A main-frame commit sets the loader id and COMMIT; a matching `load` adds LOAD.
#[test]
fn lifecycle_dispatch_commit_then_load_updates_state() {
let dispatcher = CdpDispatcher::new();
let state = register_test_tracker(&dispatcher, "s1");
dispatcher.dispatch_message(
br#"{"method":"Page.frameNavigated","sessionId":"s1","params":{"frame":{"id":"f1","loaderId":"L1","url":"https://example.test/"}}}"#,
);
{
let guard = lock_or_recover(&state);
assert_eq!(guard.current_loader_id, "L1");
assert_eq!(guard.fired, super::super::LC_COMMIT);
}
dispatcher.dispatch_message(
br#"{"method":"Page.lifecycleEvent","sessionId":"s1","params":{"loaderId":"L1","name":"load"}}"#,
);
{
let guard = lock_or_recover(&state);
assert_eq!(guard.current_loader_id, "L1");
assert_eq!(guard.fired, super::super::LC_COMMIT | super::super::LC_LOAD);
}
}
/// A navigation carrying `parentId` (a subframe) must not overwrite main-frame state.
#[test]
fn lifecycle_dispatch_subframe_navigation_does_not_clobber_main() {
let dispatcher = CdpDispatcher::new();
let state = register_test_tracker(&dispatcher, "s1");
dispatcher.dispatch_message(
br#"{"method":"Page.frameNavigated","sessionId":"s1","params":{"frame":{"id":"f1","loaderId":"L1","url":"https://example.test/"}}}"#,
);
dispatcher.dispatch_message(
br#"{"method":"Page.frameNavigated","sessionId":"s1","params":{"frame":{"id":"f2","parentId":"f1","loaderId":"L2","url":"https://sub.example.test/"}}}"#,
);
let guard = lock_or_recover(&state);
assert_eq!(guard.current_loader_id, "L1");
assert_eq!(guard.fired, super::super::LC_COMMIT);
}
/// Lifecycle events from a different loader id leave the fired flags unchanged.
#[test]
fn lifecycle_dispatch_mismatched_loader_id_is_ignored() {
let dispatcher = CdpDispatcher::new();
let state = register_test_tracker(&dispatcher, "s1");
dispatcher.dispatch_message(
br#"{"method":"Page.frameNavigated","sessionId":"s1","params":{"frame":{"id":"f1","loaderId":"L1","url":"https://example.test/"}}}"#,
);
dispatcher.dispatch_message(
br#"{"method":"Page.lifecycleEvent","sessionId":"s1","params":{"loaderId":"OTHER","name":"load"}}"#,
);
let guard = lock_or_recover(&state);
assert_eq!(guard.fired, super::super::LC_COMMIT);
}
/// `Inspector.targetCrashed` marks the session's tracker as crashed.
#[test]
fn lifecycle_dispatch_target_crashed_sets_flag() {
let dispatcher = CdpDispatcher::new();
let state = register_test_tracker(&dispatcher, "s1");
dispatcher.dispatch_message(br#"{"method":"Inspector.targetCrashed","sessionId":"s1","params":{}}"#);
assert!(lock_or_recover(&state).crashed);
}
}