use dashmap::DashMap;
use rustc_hash::FxHashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use tokio::sync::{broadcast, oneshot};
use crate::backend::json_scan;
use crate::error::{FerriError, Result};
/// Shortens `s` for log output to at most `max` bytes, appending `...` when cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary at or
/// below `max`, so multi-byte text (e.g. non-ASCII JSON payloads) never causes
/// a slicing panic. Strings that already fit are returned unchanged.
fn truncate_for_log(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // `&s[..max]` panics if `max` lands inside a multi-byte character; walk back
    // to a boundary. Index 0 is always a boundary, so this terminates.
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
/// Outcome delivered to a command's waiter: the CDP `result` value, or an error.
type CdpResult = Result<serde_json::Value>;
/// Bookkeeping for one in-flight CDP command, stored in [`PendingMap`] under its command id.
pub(crate) struct PendingEntry {
    /// Resolves the receiver that `CdpDispatcher::build_command` returned to the caller.
    tx: oneshot::Sender<CdpResult>,
    /// RTT-stats bucket key; `build_command` leaves it empty when RTT stats are disabled.
    method: String,
    /// Taken when the command is built; the response's round-trip time is measured from here.
    send_at: Instant,
}
/// In-flight commands keyed by CDP command id: inserted by `build_command`,
/// removed by `dispatch_message` (response) or `fail_all_pending`.
pub(crate) type PendingMap = DashMap<u64, PendingEntry>;
/// Aggregated round-trip timings for a single CDP method.
#[derive(Default)]
struct RttBucket {
    /// Number of responses recorded.
    count: u64,
    /// Sum of all recorded round-trip times, in nanoseconds.
    total_ns: u128,
    /// Largest single recorded round-trip time, in nanoseconds.
    max_ns: u128,
}
/// Renders a nanosecond count as milliseconds with two truncated decimals,
/// e.g. `1_234_567` → `"1.23"`.
fn fmt_ms2(ns: u128) -> String {
    const NS_PER_MS: u128 = 1_000_000;
    const NS_PER_HUNDREDTH_MS: u128 = 10_000;
    let (whole, rem) = (ns / NS_PER_MS, ns % NS_PER_MS);
    format!("{whole}.{:02}", rem / NS_PER_HUNDREDTH_MS)
}
/// Renders a nanosecond count as microseconds with one truncated decimal,
/// e.g. `1_234` → `"1.2"`.
fn fmt_us1(ns: u128) -> String {
    let tenths_of_us = ns / 100;
    format!("{}.{}", tenths_of_us / 10, tenths_of_us % 10)
}
/// Renders the mean of `count` samples totalling `total_ns` nanoseconds as
/// microseconds with one truncated decimal; `"0.0"` when there are no samples.
fn fmt_avg_us1(total_ns: u128, count: u64) -> String {
    match u128::from(count) {
        0 => String::from("0.0"),
        samples => {
            let avg_tenths_of_us = (total_ns * 10 / 1_000) / samples;
            format!("{}.{}", avg_tenths_of_us / 10, avg_tenths_of_us % 10)
        },
    }
}
/// Per-method round-trip statistics, keyed by CDP method name.
///
/// One instance lives in each `CdpDispatcher`, plus a process-wide instance
/// behind `global_rtt_stats()`.
#[derive(Default)]
pub(crate) struct RttStats {
    buckets: FxHashMap<String, RttBucket>,
}
impl RttStats {
fn record(&mut self, method: &str, elapsed_ns: u128) {
let entry = self.buckets.entry(method.to_string()).or_default();
entry.count += 1;
entry.total_ns += elapsed_ns;
if elapsed_ns > entry.max_ns {
entry.max_ns = elapsed_ns;
}
}
fn merge(&mut self, other: &RttStats) {
for (method, b) in &other.buckets {
let entry = self.buckets.entry(method.clone()).or_default();
entry.count += b.count;
entry.total_ns += b.total_ns;
if b.max_ns > entry.max_ns {
entry.max_ns = b.max_ns;
}
}
}
fn dump(&self) {
if self.buckets.is_empty() {
return;
}
let mut rows: Vec<(&String, &RttBucket)> = self.buckets.iter().collect();
rows.sort_by_key(|r| std::cmp::Reverse(r.1.total_ns));
let total_count: u64 = self.buckets.values().map(|b| b.count).sum();
let total_ns: u128 = self.buckets.values().map(|b| b.total_ns).sum();
eprintln!(
"─── ferridriver CDP RTT stats ─── total_calls={total_count} total_time={}ms",
fmt_ms2(total_ns)
);
eprintln!(
" {:<48} {:>7} {:>10} {:>10} {:>10}",
"method", "count", "total_ms", "avg_us", "max_us"
);
for (method, bucket) in rows {
eprintln!(
" {:<48} {:>7} {:>10} {:>10} {:>10}",
method,
bucket.count,
fmt_ms2(bucket.total_ns),
fmt_avg_us1(bucket.total_ns, bucket.count),
fmt_us1(bucket.max_ns),
);
}
}
}
fn rtt_stats_enabled() -> bool {
static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
*ENABLED.get_or_init(|| {
let on = std::env::var("FERRIDRIVER_RTT_STATS").is_ok_and(|v| matches!(v.as_str(), "1" | "true" | "yes" | "on"));
if on {
#[allow(unsafe_code)]
unsafe {
libc::atexit(rtt_atexit_dump);
}
}
on
})
}
fn global_rtt_stats() -> &'static std::sync::Mutex<RttStats> {
static GLOBAL: std::sync::OnceLock<std::sync::Mutex<RttStats>> = std::sync::OnceLock::new();
GLOBAL.get_or_init(|| std::sync::Mutex::new(RttStats::default()))
}
extern "C" fn rtt_atexit_dump() {
if let Ok(stats) = global_rtt_stats().lock() {
if !stats.buckets.is_empty() {
stats.dump();
}
}
}
pub fn dump_global_rtt_stats() {
if !rtt_stats_enabled() {
return;
}
if let Ok(stats) = global_rtt_stats().lock() {
if !stats.buckets.is_empty() {
stats.dump();
}
}
}
/// Abstraction over a CDP connection that can send commands, publish events,
/// and track page lifecycle per session.
///
/// NOTE(review): the exact semantics are defined by implementors (not visible
/// here); the method docs below describe the expected contract as suggested by
/// the signatures and by `CdpDispatcher`'s matching methods.
pub trait CdpTransport: Send + Sync + 'static {
    /// Sends CDP `method` with `params`, scoped to `session_id` when given.
    ///
    /// Presumably resolves with the command's `result` value, or an error when
    /// the browser reports one or the connection closes — confirm per implementor.
    fn send_command(
        &self,
        session_id: Option<&str>,
        method: &str,
        params: serde_json::Value,
    ) -> impl std::future::Future<Output = Result<serde_json::Value>> + Send;
    /// Returns a new receiver of CDP event messages (each shared as an `Arc`).
    fn subscribe_events(&self) -> broadcast::Receiver<Arc<serde_json::Value>>;
    /// Associates lifecycle `state` and its wake-up `notify` with `session_id`,
    /// presumably replacing any earlier registration for that session.
    fn register_lifecycle_tracker(
        &self,
        session_id: &str,
        state: Arc<std::sync::Mutex<super::LifecycleState>>,
        notify: Arc<tokio::sync::Notify>,
    );
}
/// Per-session lifecycle state that `CdpDispatcher` updates from page events.
pub(crate) struct LifecycleTracker {
    /// Shared lifecycle state; the dispatcher reads/writes its
    /// `current_loader_id`, `fired` and `crashed` fields.
    pub state: Arc<std::sync::Mutex<super::LifecycleState>>,
    /// Woken with `notify_waiters()` after each state change the dispatcher makes.
    pub notify: Arc<tokio::sync::Notify>,
}
/// Correlates CDP traffic: builds outgoing command frames, resolves responses
/// to their pending commands, and fans events out to lifecycle trackers and
/// broadcast subscribers.
pub(crate) struct CdpDispatcher {
    /// Next command id to assign; starts at 1 (see `new`).
    pub next_id: AtomicU64,
    /// Commands awaiting a response, keyed by command id.
    pub pending: Arc<PendingMap>,
    /// Lifecycle trackers keyed by CDP session id.
    lifecycle_trackers: Arc<DashMap<String, LifecycleTracker>>,
    /// Broadcast sender for event (non-response) messages.
    pub event_tx: broadcast::Sender<Arc<serde_json::Value>>,
    /// This dispatcher's own RTT table; only written when RTT stats are enabled.
    rtt_stats: Arc<std::sync::Mutex<RttStats>>,
}
/// Locks `m` and returns the guard, even if a previous holder panicked.
///
/// Poisoning only records that a panic happened while the lock was held. The
/// callers in this module keep using the guarded data, so the poison flag is
/// deliberately ignored instead of propagating the panic.
fn lock_or_recover<T>(m: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    match m.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Number of messages the event broadcast channel buffers. With tokio's
/// `broadcast`, a receiver that falls further behind than this loses the oldest
/// messages and observes a `Lagged` error.
const EVENT_BROADCAST_CAPACITY: usize = 4096;
impl CdpDispatcher {
    /// Creates a dispatcher with no pending commands and no lifecycle trackers.
    ///
    /// Command ids start at 1; `dispatch_message` treats a scanned id of 0 as
    /// "not a response". The event channel buffers [`EVENT_BROADCAST_CAPACITY`]
    /// messages. Its initial receiver is dropped, so events are only parsed and
    /// broadcast once someone calls [`Self::subscribe_events`].
    pub fn new() -> Self {
        let (event_tx, _) = broadcast::channel(EVENT_BROADCAST_CAPACITY);
        Self {
            next_id: AtomicU64::new(1),
            pending: Arc::new(DashMap::default()),
            lifecycle_trackers: Arc::new(DashMap::default()),
            event_tx,
            rtt_stats: Arc::new(std::sync::Mutex::new(RttStats::default())),
        }
    }

    /// Registers the lifecycle tracker for `session_id`, replacing any existing one.
    ///
    /// Later `Page.frameNavigated`, `Page.lifecycleEvent` and
    /// `Inspector.targetCrashed` events whose `sessionId` equals `session_id`
    /// update `state` and wake `notify`'s waiters (see `dispatch_lifecycle`).
    pub fn register_lifecycle_tracker(
        &self,
        session_id: &str,
        state: Arc<std::sync::Mutex<super::LifecycleState>>,
        notify: Arc<tokio::sync::Notify>,
    ) {
        self
            .lifecycle_trackers
            .insert(session_id.to_string(), LifecycleTracker { state, notify });
    }

    /// Returns a new receiver for every event (non-response) message dispatched from now on.
    pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<serde_json::Value>> {
        self.event_tx.subscribe()
    }

    /// Resolves every currently pending command with a `target_closed` error carrying `reason`.
    ///
    /// The set of ids is snapshotted first, so commands registered concurrently
    /// with this call may remain pending.
    pub fn fail_all_pending(&self, reason: &str) {
        // Collect ids first: removing while a DashMap iterator holds shard locks would deadlock.
        let ids: Vec<u64> = self.pending.iter().map(|e| *e.key()).collect();
        for id in ids {
            if let Some((_, entry)) = self.pending.remove(&id) {
                // The waiter may already have gone away (receiver dropped); nothing to do then.
                let _ = entry.tx.send(Err(FerriError::target_closed(Some(reason.to_string()))));
            }
        }
    }

    /// Serializes one CDP command and registers a pending slot for its response.
    ///
    /// Assigns the next command id and encodes
    /// `{"id":…,"method":…,"params":…[,"sessionId":…]}` followed by a trailing
    /// NUL byte. It then inserts a [`PendingEntry`] so `dispatch_message` can
    /// deliver the matching response to the returned receiver.
    ///
    /// # Errors
    /// Returns `FerriError::Backend` if `params` (or, in principle, `method` /
    /// `session_id`) fails to serialize; no pending entry is registered then.
    pub fn build_command(
        &self,
        session_id: Option<&str>,
        method: &str,
        params: &serde_json::Value,
    ) -> Result<(Vec<u8>, oneshot::Receiver<CdpResult>)> {
        let serialize_err = |e: serde_json::Error| FerriError::Backend(format!("Serialize: {e}"));
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let params_str = serde_json::to_string(params).map_err(serialize_err)?;
        // `method` and `sessionId` are spliced into a hand-built frame. Encode them
        // as JSON strings so a quote, backslash or control character cannot corrupt
        // the frame (ordinary names produce the exact same bytes as before).
        let method_json = serde_json::to_string(method).map_err(serialize_err)?;
        let frame = match session_id {
            Some(sid) => {
                let sid_json = serde_json::to_string(sid).map_err(serialize_err)?;
                format!(r#"{{"id":{id},"method":{method_json},"params":{params_str},"sessionId":{sid_json}}}"#)
            },
            None => format!(r#"{{"id":{id},"method":{method_json},"params":{params_str}}}"#),
        };
        let mut data = frame.into_bytes();
        // NUL terminator appended to every frame (presumably the transport's message
        // delimiter — confirm against the sender).
        data.push(0);
        tracing::debug!(
            target: "ferridriver::cdp::send",
            id,
            method,
            params = truncate_for_log(&params_str, 200),
            "CDP >>",
        );
        let (tx, rx) = oneshot::channel();
        let entry = PendingEntry {
            tx,
            // The method name is only needed as an RTT bucket key; skip the allocation otherwise.
            method: if rtt_stats_enabled() {
                method.to_string()
            } else {
                String::new()
            },
            send_at: Instant::now(),
        };
        self.pending.insert(id, entry);
        Ok((data, rx))
    }

    /// Routes one raw inbound CDP message.
    ///
    /// A message with a non-zero `id` is a response. The matching pending command
    /// is removed and resolved with its `result`, or with a protocol error built
    /// from `error`. When RTT stats are enabled, its round-trip time is recorded
    /// into both this dispatcher's table and the process-wide table. A missing
    /// `result` yields an empty object.
    ///
    /// Anything else is an event: lifecycle trackers are updated, and when at
    /// least one subscriber exists the parsed message is broadcast.
    pub fn dispatch_message(&self, raw: &[u8]) {
        let id = json_scan::json_id(raw);
        if id > 0 {
            let error_field = json_scan::json_field(raw, b"error");
            let payload: CdpResult = if error_field.is_empty() {
                let result_field = json_scan::json_field(raw, b"result");
                if result_field.is_empty() {
                    Ok(serde_json::Value::Object(serde_json::Map::new()))
                } else {
                    match serde_json::from_slice::<serde_json::Value>(result_field) {
                        Ok(val) => Ok(val),
                        Err(e) => {
                            // Keep the lenient fallback, but don't hide the malformed result.
                            tracing::warn!(
                                target: "ferridriver::cdp::recv",
                                id,
                                error = %e,
                                "CDP << unparseable result; substituting empty object",
                            );
                            Ok(serde_json::Value::Object(serde_json::Map::new()))
                        },
                    }
                }
            } else {
                let msg_bytes = json_scan::error_message(error_field);
                let msg_str = std::str::from_utf8(msg_bytes).unwrap_or("CDP error");
                Err(FerriError::protocol("CDP", msg_str))
            };
            tracing::debug!(
                target: "ferridriver::cdp::recv",
                id,
                ok = payload.is_ok(),
                payload = truncate_for_log(&format!("{payload:?}"), 200),
                "CDP << response",
            );
            if let Some((_, entry)) = self.pending.remove(&id) {
                if rtt_stats_enabled() {
                    let elapsed = entry.send_at.elapsed().as_nanos();
                    lock_or_recover(&self.rtt_stats).record(&entry.method, elapsed);
                    lock_or_recover(global_rtt_stats()).record(&entry.method, elapsed);
                }
                // The waiter may have stopped listening; dropping the payload is fine then.
                let _ = entry.tx.send(payload);
            }
        } else {
            let method = json_scan::json_string(json_scan::json_field(raw, b"method"));
            let session_id = json_scan::json_string(json_scan::json_field(raw, b"sessionId"));
            let method_str = std::str::from_utf8(method).unwrap_or("");
            let sid_str = std::str::from_utf8(session_id).unwrap_or("");
            // `DashMap<String, _>` can be queried with `&str`; no per-event key allocation.
            self.dispatch_lifecycle(raw, method_str, sid_str);
            tracing::trace!(
                target: "ferridriver::cdp::recv",
                method = method_str,
                "CDP << event",
            );
            // The full JSON parse is the expensive part of event handling. Skip it
            // when nobody is subscribed: `send` would fail with no receivers anyway.
            if self.event_tx.receiver_count() > 0 {
                match serde_json::from_slice::<serde_json::Value>(raw) {
                    Ok(msg) => {
                        let _ = self.event_tx.send(Arc::new(msg));
                    },
                    Err(e) => {
                        tracing::warn!(
                            target: "ferridriver::cdp::recv",
                            method = method_str,
                            error = %e,
                            "CDP << unparseable event dropped",
                        );
                    },
                }
            }
        }
    }

    /// Applies a lifecycle-relevant event to the tracker registered for session `key`, if any.
    ///
    /// - `Page.frameNavigated` for a frame without `parentId` adopts its
    ///   `loaderId` as the current loader and resets `fired` to `{"commit"}`.
    /// - `Page.lifecycleEvent` named `DOMContentLoaded` / `load` whose `loaderId`
    ///   matches the current loader adds `"domcontentloaded"` / `"load"` to `fired`.
    /// - `Inspector.targetCrashed` sets `crashed`.
    ///
    /// Every state change is followed by `notify_waiters()`; other methods are ignored.
    fn dispatch_lifecycle(&self, raw: &[u8], method_str: &str, key: &str) {
        let Some(tracker) = self.lifecycle_trackers.get(key) else {
            return;
        };
        match method_str {
            "Page.frameNavigated" => {
                let params = json_scan::json_field(raw, b"params");
                let frame = json_scan::json_field(params, b"frame");
                let parent_id = json_scan::json_field(frame, b"parentId");
                // Child-frame navigations don't affect the page's lifecycle state.
                if !parent_id.is_empty() {
                    return;
                }
                let loader_id = json_scan::json_string(json_scan::json_field(frame, b"loaderId"));
                let loader_id_str = std::str::from_utf8(loader_id).unwrap_or("");
                let mut state = lock_or_recover(&tracker.state);
                state.current_loader_id = loader_id_str.to_string();
                state.fired.clear();
                state.fired.insert("commit".to_string());
                // Release the lock before waking waiters so they can take it immediately.
                drop(state);
                tracker.notify.notify_waiters();
            },
            "Page.lifecycleEvent" => {
                let params = json_scan::json_field(raw, b"params");
                let loader_id = json_scan::json_string(json_scan::json_field(params, b"loaderId"));
                let loader_id_str = std::str::from_utf8(loader_id).unwrap_or("");
                let name = json_scan::json_string(json_scan::json_field(params, b"name"));
                let name_str = std::str::from_utf8(name).unwrap_or("");
                let event_name = match name_str {
                    "DOMContentLoaded" => Some("domcontentloaded"),
                    "load" => Some("load"),
                    _ => None,
                };
                if let Some(event_name) = event_name {
                    let mut state = lock_or_recover(&tracker.state);
                    // Ignore events belonging to a loader that has since been replaced.
                    if state.current_loader_id == loader_id_str {
                        state.fired.insert(event_name.to_string());
                        drop(state);
                        tracker.notify.notify_waiters();
                    }
                }
            },
            "Inspector.targetCrashed" => {
                let mut state = lock_or_recover(&tracker.state);
                state.crashed = true;
                drop(state);
                tracker.notify.notify_waiters();
            },
            _ => {},
        }
    }
}
impl Drop for CdpDispatcher {
    /// Prints this dispatcher's own RTT table to stderr when RTT stats are enabled.
    fn drop(&mut self) {
        if rtt_stats_enabled() {
            // Every sample was already recorded into the process-wide table as its
            // response arrived (`dispatch_message` records into both tables), so
            // merging `self.rtt_stats` into it here would count each call twice.
            lock_or_recover(&self.rtt_stats).dump();
        }
    }
}