use crate::error::Result;
use crate::server::channel::Channel;
use crate::server::channel_owner::{
ChannelOwner, ChannelOwnerImpl, DisposeReason, ParentOrConnection,
};
use crate::server::connection::ConnectionLike;
use parking_lot::Mutex;
use serde_json::Value;
use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tracing::Instrument;
/// Boxed, type-erased future shared by every session callback kind.
type BoxedUnitFuture = Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>;
/// Future produced by a CDP event handler.
type EventHandlerFuture = BoxedUnitFuture;
/// Type-erased CDP event callback; it is invoked with the event's `params` value.
type EventHandler = Arc<dyn Fn(Value) -> EventHandlerFuture + Send + Sync + 'static>;
/// Future produced by a close handler.
type CloseHandlerFuture = BoxedUnitFuture;
/// Type-erased callback invoked when the session reports `close`.
type CloseHandler = Arc<dyn Fn() -> CloseHandlerFuture + Send + Sync + 'static>;
/// Event handlers keyed by CDP event name, each list kept in registration order.
type EventHandlerMap = HashMap<String, Vec<EventHandler>>;

/// Wrapper around a remote `CDPSession` channel object.
///
/// It forwards CDP commands through the channel and dispatches incoming CDP events
/// to registered handlers. Clones share the same handler registries, because both
/// registries live behind an `Arc`.
#[derive(Clone)]
pub struct CDPSession {
    /// Common channel-owner state that all `ChannelOwner` methods delegate to.
    base: ChannelOwnerImpl,
    /// Handlers registered via `on`, looked up by CDP event name.
    event_handlers: Arc<Mutex<EventHandlerMap>>,
    /// Handlers registered via `on_close`.
    close_handlers: Arc<Mutex<Vec<CloseHandler>>>,
}
impl CDPSession {
    /// Creates a session wrapper for the channel object identified by `guid`.
    ///
    /// Both handler registries start out empty. The constructor currently always
    /// returns `Ok`; the `Result` return type is part of the public signature.
    pub fn new(
        parent: ParentOrConnection,
        type_name: String,
        guid: Arc<str>,
        initializer: Value,
    ) -> Result<Self> {
        let base = ChannelOwnerImpl::new(parent, type_name, guid, initializer);
        let event_handlers = Arc::new(Mutex::new(HashMap::new()));
        let close_handlers = Arc::new(Mutex::new(Vec::new()));
        Ok(Self {
            base,
            event_handlers,
            close_handlers,
        })
    }

    /// Sends the CDP command `method` over this session's channel and returns its raw
    /// JSON result.
    ///
    /// The request travels as the channel's `send` call with payload
    /// `{ "method": method, "params": params }`. When `params` is `None`, an empty
    /// JSON object is sent rather than `null`.
    ///
    /// # Errors
    /// Propagates any error returned by the channel's `send`.
    #[tracing::instrument(level = "debug", skip_all, fields(guid = %self.guid(), method = %method))]
    pub async fn send(&self, method: &str, params: Option<Value>) -> Result<Value> {
        let cdp_params = match params {
            Some(explicit) => explicit,
            None => serde_json::json!({}),
        };
        let payload = serde_json::json!({
            "method": method,
            "params": cdp_params,
        });
        self.channel().send("send", payload).await
    }

    /// Registers an async `handler` for the CDP event named `method`.
    ///
    /// Several handlers may be registered for the same event name; they are stored
    /// in registration order. Each handler receives the event's `params` value.
    pub fn on<F, Fut>(&self, method: impl Into<String>, handler: F)
    where
        F: Fn(Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<()>> + Send + 'static,
    {
        // Erase the concrete closure/future types so all handlers fit in one Vec.
        let boxed: EventHandler = Arc::new(move |value: Value| -> EventHandlerFuture {
            Box::pin(handler(value))
        });
        let mut registry = self.event_handlers.lock();
        registry.entry(method.into()).or_default().push(boxed);
    }

    /// Registers an async `handler` to run when the session reports `close`.
    pub fn on_close<F, Fut>(&self, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<()>> + Send + 'static,
    {
        let boxed: CloseHandler = Arc::new(move || -> CloseHandlerFuture { Box::pin(handler()) });
        let mut registry = self.close_handlers.lock();
        registry.push(boxed);
    }

    /// Asks the remote side to detach this session. The channel's `detach` call is
    /// sent with an empty JSON object.
    ///
    /// # Errors
    /// Propagates any error returned by the channel's `send_no_result`.
    #[tracing::instrument(level = "debug", skip_all, fields(guid = %self.guid()))]
    pub async fn detach(&self) -> Result<()> {
        let no_args = serde_json::json!({});
        self.channel().send_no_result("detach", no_args).await
    }
}
impl ChannelOwner for CDPSession {
fn guid(&self) -> &str {
self.base.guid()
}
fn type_name(&self) -> &str {
self.base.type_name()
}
fn parent(&self) -> Option<Arc<dyn ChannelOwner>> {
self.base.parent()
}
fn connection(&self) -> Arc<dyn ConnectionLike> {
self.base.connection()
}
fn initializer(&self) -> &Value {
self.base.initializer()
}
fn channel(&self) -> &Channel {
self.base.channel()
}
fn dispose(&self, reason: DisposeReason) {
self.base.dispose(reason)
}
fn adopt(&self, child: Arc<dyn ChannelOwner>) {
self.base.adopt(child)
}
fn add_child(&self, guid: Arc<str>, child: Arc<dyn ChannelOwner>) {
self.base.add_child(guid, child)
}
fn remove_child(&self, guid: &str) {
self.base.remove_child(guid)
}
fn on_event(&self, method: &str, params: Value) {
match method {
"event" => {
let cdp_method = params
.get("method")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let cdp_params = params.get("params").cloned().unwrap_or(Value::Null);
if let Some(cdp_method) = cdp_method {
let handlers = self
.event_handlers
.lock()
.get(&cdp_method)
.cloned()
.unwrap_or_default();
for h in handlers {
let p = cdp_params.clone();
tokio::spawn(
async move {
if let Err(e) = h(p).await {
tracing::warn!("CDPSession event handler error: {}", e);
}
}
.in_current_span(),
);
}
}
}
"close" => {
let handlers = self.close_handlers.lock().clone();
for h in handlers {
tokio::spawn(
async move {
if let Err(e) = h().await {
tracing::warn!("CDPSession close handler error: {}", e);
}
}
.in_current_span(),
);
}
}
_ => {}
}
self.base.on_event(method, params);
}
fn was_collected(&self) -> bool {
self.base.was_collected()
}
fn as_any(&self) -> &dyn Any {
self
}
}
impl std::fmt::Debug for CDPSession {
    /// Renders the session as `CDPSession { guid: "..." }`.
    ///
    /// The handler registries hold opaque closures, so they are deliberately left out.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guid = self.guid();
        let mut builder = formatter.debug_struct("CDPSession");
        builder.field("guid", &guid);
        builder.finish()
    }
}