use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::mpsc;
use std::sync::mpsc::Sender;
use std::sync::mpsc::{Receiver, RecvTimeoutError, TryRecvError};
use std::time::Duration;
use anyhow::Result;
use thiserror::Error;
use log::{error, info, trace, warn};
use url::Url;
use waiting_call_registry::WaitingCallRegistry;
use web_socket_connection::WebSocketConnection;
use crate::protocol::cdp::{Target, types::Event, types::Method};
use crate::types::{CallId, Message, parse_raw_message, parse_response};
use crate::util;
mod waiting_call_registry;
mod web_socket_connection;
/// Identifier of a target session: the string carried in
/// `Target.receivedMessageFromTarget` events and passed back to
/// `Target.sendMessageToTarget` when addressing that target.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionId(String);
pub enum MethodDestination {
Target(SessionId),
Browser,
}
impl SessionId {
pub fn as_str(&self) -> &str {
&self.0
}
}
impl From<String> for SessionId {
fn from(session_id: String) -> Self {
Self(session_id)
}
}
/// Key under which an event listener is registered in [`Listeners`].
#[derive(Debug, Hash, PartialEq, Eq)]
enum ListenerId {
    /// Receives events unwrapped from messages sent by the target on this session.
    SessionId(SessionId),
    /// Receives every other event that arrives on the browser connection.
    Browser,
}
type Listeners = Arc<Mutex<HashMap<ListenerId, Sender<Event>>>>;
/// Sends DevTools method calls over a WebSocket connection and routes the
/// responses and events that come back, via a background dispatch thread
/// spawned in [`Transport::new`].
#[derive(Debug)]
pub struct Transport {
    /// Connection that browser-bound messages are written to; also held by the
    /// dispatch thread, which shuts it down when the loop exits.
    web_socket_connection: Arc<WebSocketConnection>,
    /// Calls that have been sent and are awaiting a response.
    waiting_call_registry: Arc<WaitingCallRegistry>,
    /// Event listeners, keyed by browser or session.
    listeners: Listeners,
    /// Set to `false` by the dispatch thread when it exits; checked before each call.
    open: Arc<AtomicBool>,
    /// Source of unique call ids.
    call_id_counter: Arc<AtomicU32>,
    /// Signals the dispatch loop to stop (see `shutdown`).
    loop_shutdown_tx: Mutex<mpsc::SyncSender<()>>,
    /// How long a call waits for its response, and how long the dispatch loop
    /// waits for any incoming message before shutting itself down.
    idle_browser_timeout: Duration,
}
/// Error returned by [`Transport::call_method`] once the transport's dispatch
/// loop has exited and marked the transport as closed.
#[derive(Debug, Error)]
#[error("Unable to make method calls because underlying connection is closed")]
pub struct ConnectionClosed {}
impl Transport {
pub fn new(
ws_url: Url,
process_id: Option<u32>,
idle_browser_timeout: Duration,
) -> Result<Self> {
let (messages_tx, messages_rx) = mpsc::channel();
let web_socket_connection =
Arc::new(WebSocketConnection::new(&ws_url, process_id, messages_tx)?);
let waiting_call_registry = Arc::new(WaitingCallRegistry::new());
let listeners = Arc::new(Mutex::new(HashMap::new()));
let open = Arc::new(AtomicBool::new(true));
let (shutdown_tx, shutdown_rx) = mpsc::sync_channel(100);
let guarded_shutdown_tx = Mutex::new(shutdown_tx);
Self::handle_incoming_messages(
messages_rx,
Arc::clone(&waiting_call_registry),
Arc::clone(&listeners),
Arc::clone(&open),
Arc::clone(&web_socket_connection),
shutdown_rx,
process_id,
idle_browser_timeout,
);
Ok(Self {
web_socket_connection,
waiting_call_registry,
listeners,
open,
call_id_counter: Arc::new(AtomicU32::new(0)),
loop_shutdown_tx: guarded_shutdown_tx,
idle_browser_timeout,
})
}
pub fn unique_call_id(&self) -> CallId {
self.call_id_counter.fetch_add(1, Ordering::SeqCst)
}
pub fn call_method<C>(
&self,
method: C,
destination: MethodDestination,
) -> Result<C::ReturnObject>
where
C: Method + serde::Serialize,
{
if !self.open.load(Ordering::SeqCst) {
return Err(ConnectionClosed {}.into());
}
let call_id = self.unique_call_id();
let call = method.to_method_call(call_id);
let message_text = serde_json::to_string(&call)?;
let response_rx = self.waiting_call_registry.register_call(call.id);
match destination {
MethodDestination::Target(session_id) => {
let message = message_text.clone();
let target_method = Target::SendMessageToTarget {
target_id: None,
session_id: Some(session_id.0),
message,
};
trace!(
"Msg to tab: {}",
message_text.chars().take(300).collect::<String>()
);
if let Err(e) = self.call_method_on_browser(target_method) {
warn!("Failed to call method on browser: {e:?}");
self.waiting_call_registry.unregister_call(call.id);
trace!("Unregistered callback: {:?}", call.id);
return Err(e);
}
}
MethodDestination::Browser => {
if let Err(e) = self.web_socket_connection.send_message(&message_text) {
self.waiting_call_registry.unregister_call(call.id);
return Err(e);
}
trace!("sent method call to browser via websocket");
}
}
let params_string = format!("{:?}", call.get_params());
trace!(
"waiting for response from call registry: {} {:?}",
&call_id,
params_string.chars().take(400).collect::<String>()
);
let response_result = util::Wait::new(self.idle_browser_timeout, Duration::from_millis(5))
.until(|| response_rx.try_recv().ok());
trace!("received response for: {} {:?}", &call_id, params_string);
parse_response::<C::ReturnObject>((response_result?)?)
}
pub fn call_method_on_target<C>(
&self,
session_id: SessionId,
method: C,
) -> Result<C::ReturnObject>
where
C: Method + serde::Serialize,
{
self.call_method(method, MethodDestination::Target(session_id))
}
pub fn call_method_on_browser<C>(&self, method: C) -> Result<C::ReturnObject>
where
C: Method + serde::Serialize,
{
self.call_method(method, MethodDestination::Browser)
}
pub fn listen_to_browser_events(&self) -> Receiver<Event> {
let (events_tx, events_rx) = mpsc::channel();
let mut listeners = self.listeners.lock().unwrap();
listeners.insert(ListenerId::Browser, events_tx);
events_rx
}
pub fn listen_to_target_events(&self, session_id: SessionId) -> Receiver<Event> {
let (events_tx, events_rx) = mpsc::channel();
let mut listeners = self.listeners.lock().unwrap();
listeners.insert(ListenerId::SessionId(session_id), events_tx);
events_rx
}
pub fn shutdown(&self) {
self.web_socket_connection.shutdown();
let shutdown_tx = self.loop_shutdown_tx.lock().unwrap();
let _ = shutdown_tx.send(());
}
#[allow(clippy::too_many_arguments)]
fn handle_incoming_messages(
messages_rx: Receiver<Message>,
waiting_call_registry: Arc<WaitingCallRegistry>,
listeners: Listeners,
open: Arc<AtomicBool>,
conn: Arc<WebSocketConnection>,
shutdown_rx: Receiver<()>,
process_id: Option<u32>,
idle_browser_timeout: Duration,
) {
trace!("Starting handle_incoming_messages");
std::thread::spawn(move || {
trace!("Inside handle_incoming_messages thread");
loop {
match shutdown_rx.try_recv() {
Ok(()) | Err(TryRecvError::Disconnected) => {
info!("Transport incoming message loop loop received shutdown message");
break;
}
Err(TryRecvError::Empty) => {}
}
match messages_rx.recv_timeout(idle_browser_timeout) {
Err(recv_timeout_error) => {
match recv_timeout_error {
RecvTimeoutError::Timeout => {
error!(
"Transport loop got a timeout while listening for messages (Chrome #{process_id:?})",
);
}
RecvTimeoutError::Disconnected => {
error!(
"Transport loop got disconnected from WS's sender (Chrome #{process_id:?})",
);
}
}
break;
}
Ok(message) => match message {
Message::ConnectionShutdown => {
info!("Received shutdown message");
break;
}
Message::Response(response_to_browser_method_call) => {
if waiting_call_registry
.resolve_call(response_to_browser_method_call)
.is_err()
{
warn!(
"The browser registered a call but then closed its receiving channel"
);
break;
}
}
Message::Event(browser_event) => match browser_event {
Event::ReceivedMessageFromTarget(target_message_event) => {
let session_id = target_message_event.params.session_id.into();
let raw_message = target_message_event.params.message;
let msg_res = parse_raw_message(&raw_message);
match msg_res {
Ok(target_message) => match target_message {
Message::Event(target_event) => {
if let Some(tx) = listeners
.lock()
.unwrap()
.get(&ListenerId::SessionId(session_id))
{
tx.send(target_event)
.expect("Couldn't send event to listener");
}
}
Message::Response(resp) => {
if waiting_call_registry.resolve_call(resp).is_err() {
warn!(
"The browser registered a call but then closed its receiving channel"
);
break;
}
}
Message::ConnectionShutdown => {}
},
Err(e) => {
trace!(
"Message from target isn't recognised: {:?} - {}",
&raw_message, e,
);
}
}
}
_ => {
if let Some(tx) =
listeners.lock().unwrap().get(&ListenerId::Browser)
{
if let Err(err) = tx.send(browser_event.clone()) {
let event_string = format!("{browser_event:?}");
warn!(
"Couldn't send browser an event: {:?}\n{:?}",
event_string.chars().take(400).collect::<String>(),
err
);
break;
}
}
}
},
},
}
}
info!("Shutting down message handling loop");
conn.shutdown();
open.store(false, Ordering::SeqCst);
waiting_call_registry.cancel_outstanding_method_calls();
let mut listeners = listeners.lock().unwrap();
*listeners = HashMap::new();
info!("cleared listeners, I think");
});
}
}
impl Drop for Transport {
    /// Only logs that the transport is going away. It does not close the
    /// connection or stop the dispatch thread (which holds its own handles);
    /// call `Transport::shutdown` for that.
    fn drop(&mut self) {
        info!("dropping transport");
    }
}