use std::collections::HashSet;
use std::sync::Arc;
use serde_json::Value;
use tokio::sync::Mutex;
use super::CdpCommand;
use super::SessionId;
use super::capabilities::resolve_cdp_websocket_url;
use super::command::RawEvent;
use super::error::CdpError;
use super::stream::{EventStream, RawEventStream};
use super::transport::ws::WsTransport;
use crate::cdp::CdpEvent;
use crate::cdp::domains::target::{
AttachToTarget, AttachToTargetResult, DetachFromTarget, GetTargets,
};
use crate::error::WebDriverResult;
use crate::session::handle::SessionHandle;
/// A Chrome DevTools Protocol session layered on a WebSocket transport.
///
/// Created by `connect`, which attaches to a page target and stores the
/// resulting session id. Clones share the same `enabled` bookkeeping because
/// it lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct CdpSession {
/// Transport used to send commands and to obtain event receivers.
transport: WsTransport,
/// Session id returned by the attach-to-target command in `connect`; it is
/// passed along with every command issued through `send_raw`.
session_id: Option<SessionId>,
/// Enable-method names that `subscribe` has already sent successfully, so
/// that each one is sent at most once.
enabled: Arc<Mutex<HashSet<&'static str>>>,
}
impl CdpSession {
pub(crate) async fn connect(handle: Arc<SessionHandle>) -> WebDriverResult<Self> {
let url = resolve_cdp_websocket_url(&handle).await?;
let transport = WsTransport::connect(&url).await?;
let raw = transport
.send_raw_sessioned(GetTargets::METHOD, serde_json::json!({}), None)
.await
.map_err(into_wde)?;
let infos: super::domains::target::GetTargetsResult =
serde_json::from_value(raw).map_err(|e| {
crate::error::WebDriverError::Json(format!("Target.getTargets parse: {e}"))
})?;
let target =
infos.target_infos.into_iter().find(|t| t.r#type == "page").ok_or_else(|| {
crate::error::WebDriverError::from_inner(
crate::error::WebDriverErrorInner::NotFound(
"page target".to_string(),
"no page target found via Target.getTargets".to_string(),
),
)
})?;
let attach_params = AttachToTarget::flat(target.target_id);
let raw = transport
.send_raw_sessioned(AttachToTarget::METHOD, serde_json::to_value(&attach_params)?, None)
.await
.map_err(into_wde)?;
let attached: AttachToTargetResult = serde_json::from_value(raw)?;
Ok(Self {
transport,
session_id: Some(attached.session_id),
enabled: Arc::new(Mutex::new(HashSet::new())),
})
}
pub fn session_id(&self) -> Option<&SessionId> {
self.session_id.as_ref()
}
pub async fn send<C: CdpCommand>(&self, params: C) -> Result<C::Returns, CdpError> {
let raw =
self.send_raw(C::METHOD, serde_json::to_value(params).map_err(serde_err)?).await?;
serde_json::from_value(raw).map_err(serde_err)
}
pub async fn send_raw<P: serde::Serialize>(
&self,
method: &str,
params: P,
) -> Result<Value, CdpError> {
let value = serde_json::to_value(params).map_err(serde_err)?;
self.transport.send_raw_sessioned(method, value, self.session_id.as_ref()).await
}
pub async fn subscribe<E: CdpEvent>(&self) -> Result<EventStream<E>, CdpError> {
if let Some(enable_method) = E::ENABLE {
let mut enabled = self.enabled.lock().await;
if !enabled.contains(enable_method) {
self.send_raw(enable_method, serde_json::json!({})).await?;
enabled.insert(enable_method);
}
}
Ok(EventStream::new(self.transport.subscribe_events(), self.session_id.clone(), E::METHOD))
}
pub fn subscribe_all(&self) -> RawEventStream {
RawEventStream::new(self.transport.subscribe_events(), self.session_id.clone())
}
pub fn subscribe_connection(&self) -> tokio::sync::broadcast::Receiver<RawEvent> {
self.transport.subscribe_events()
}
pub async fn detach(self) -> Result<(), CdpError> {
if let Some(session) = self.session_id.clone() {
let raw = self
.transport
.send_raw_sessioned(
DetachFromTarget::METHOD,
serde_json::to_value(DetachFromTarget {
session_id: Some(session),
})
.map_err(serde_err)?,
None,
)
.await?;
drop(raw);
}
Ok(())
}
}
fn into_wde(e: CdpError) -> crate::error::WebDriverError {
crate::error::WebDriverError::from_inner(crate::error::WebDriverErrorInner::ParseError(
e.to_string(),
))
}
/// Wraps a serde_json failure in a `CdpError`.
///
/// The command is reported as `"<serde>"` and the code is `-32603`, which
/// matches the JSON-RPC 2.0 "Internal error" code.
fn serde_err(e: serde_json::Error) -> CdpError {
    let message = e.to_string();
    CdpError {
        command: String::from("<serde>"),
        code: -32603,
        message,
        data: None,
    }
}