use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde_json::{from_value, to_string};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, error, trace, warn};
use crate::error::{Error, Result};
use crate::identifiers::RequestId;
use crate::protocol::{Event, EventReply, Request, Response};
/// Response deadline used by `Connection::send` (30 seconds).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_millis(30_000);
/// `send_with_timeout` refuses new requests once this many waiters are
/// registered. The check is a plain load, and commands still queued in the
/// channel are not yet counted, so the limit is approximate.
const MAX_PENDING_REQUESTS: usize = 100;
/// How long `wait_ready` waits for the READY handshake (30 seconds).
const READY_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Capacity of the bounded handle → event-loop command channel.
const COMMAND_CHANNEL_CAPACITY: usize = 256;
/// In-flight request id → one-shot channel that will receive its outcome.
type CorrelationMap = FxHashMap<RequestId, oneshot::Sender<Result<Response>>>;
/// Event callback accepted by `Connection::add_event_handler`. Returning
/// `Some(reply)` makes the event loop serialize and send that reply.
pub type EventHandler = Box<dyn Fn(Event) -> Option<EventReply> + Send + Sync>;
/// Reference-counted callback, so the handler list can be snapshotted by
/// cloning pointers rather than the closures themselves.
type SharedEventHandler = Arc<dyn Fn(Event) -> Option<EventReply> + Send + Sync>;
/// A handler together with the key it was registered under.
type HandlerEntry = (String, SharedEventHandler);
/// Registered handlers, kept in insertion order.
type HandlerVec = Vec<HandlerEntry>;
/// Identifiers reported by the server in the READY handshake.
///
/// `Connection::wait_ready` floors missing or zero values to 1.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ReadyData {
    /// Value of the `tabId` field of the READY response.
    pub tab_id: u32,
    /// Value of the `sessionId` field of the READY response.
    pub session_id: u32,
}
/// Instructions sent from a `Connection` handle to its event loop.
enum ConnectionCommand {
    /// Register `response_tx` under `request.id`, then write `request`
    /// to the socket.
    Send {
        request: Request,
        response_tx: oneshot::Sender<Result<Response>>,
    },
    /// Forget the waiter for a request whose caller stopped waiting.
    RemoveCorrelation(RequestId),
    /// Send a Close frame and stop the loop.
    Shutdown,
}
/// Handle to a WebSocket connection serviced by a background event loop.
///
/// Requests reach the loop over a bounded command channel; responses come
/// back through per-request one-shot channels stored in a shared
/// correlation map keyed by request id.
pub struct Connection {
    /// Producer side of the command channel consumed by the event loop.
    command_tx: mpsc::Sender<ConnectionCommand>,
    /// Waiters for in-flight requests, shared with the event loop.
    correlation: Arc<Mutex<CorrelationMap>>,
    /// Registered `(key, handler)` pairs, shared with the event loop.
    event_handlers: Arc<Mutex<HandlerVec>>,
    /// Count of registered waiters, used to cap outstanding requests.
    pending_count: Arc<AtomicUsize>,
}
impl Connection {
pub(crate) fn new(ws_stream: WebSocketStream<TcpStream>) -> Self {
let (command_tx, command_rx) = mpsc::channel(COMMAND_CHANNEL_CAPACITY);
let correlation = Arc::new(Mutex::new(CorrelationMap::default()));
let event_handlers: Arc<Mutex<HandlerVec>> = Arc::new(Mutex::new(Vec::new()));
let pending_count = Arc::new(AtomicUsize::new(0));
let correlation_clone = Arc::clone(&correlation);
let event_handlers_clone = Arc::clone(&event_handlers);
let pending_count_clone = Arc::clone(&pending_count);
tokio::spawn(Self::run_event_loop(
ws_stream,
command_rx,
correlation_clone,
event_handlers_clone,
pending_count_clone,
));
Self {
command_tx,
correlation,
event_handlers,
pending_count,
}
}
pub async fn wait_ready(&self) -> Result<ReadyData> {
let (tx, rx) = oneshot::channel();
{
let mut correlation = self.correlation.lock();
correlation.insert(RequestId::ready(), tx);
}
self.pending_count.fetch_add(1, Ordering::Relaxed);
let response = timeout(READY_TIMEOUT, rx)
.await
.map_err(|_| Error::connection_timeout(READY_TIMEOUT.as_millis() as u64))??;
let response = response?;
let tab_id = response.get_u64("tabId").max(1) as u32;
let session_id = response.get_u64("sessionId").max(1) as u32;
debug!(tab_id, session_id, "READY handshake completed");
Ok(ReadyData { tab_id, session_id })
}
/// Registers `handler` under `key`.
///
/// If `key` is already registered, its handler is replaced in place, so it
/// keeps its position in the dispatch order. Otherwise the new entry is
/// appended.
pub fn add_event_handler(&self, key: String, handler: EventHandler) {
    let shared: Arc<dyn Fn(Event) -> Option<EventReply> + Send + Sync> = Arc::from(handler);
    let mut handlers = self.event_handlers.lock();
    let existing = handlers.iter().position(|(name, _)| *name == key);
    if let Some(index) = existing {
        handlers[index].1 = shared;
    } else {
        handlers.push((key, shared));
    }
}
/// Unregisters every handler stored under `key`. Unknown keys are ignored.
pub fn remove_event_handler(&self, key: &str) {
    self.event_handlers
        .lock()
        .retain(|(name, _)| name.as_str() != key);
}
/// Unregisters all event handlers.
pub fn clear_all_event_handlers(&self) {
    self.event_handlers.lock().clear();
}
/// Sends `request` and waits for its response using `DEFAULT_COMMAND_TIMEOUT`.
///
/// See `send_with_timeout` for the error cases.
pub async fn send(&self, request: Request) -> Result<Response> {
    let deadline = DEFAULT_COMMAND_TIMEOUT;
    self.send_with_timeout(request, deadline).await
}
pub async fn send_with_timeout(
&self,
request: Request,
request_timeout: Duration,
) -> Result<Response> {
let request_id = request.id;
let pending = self.pending_count.load(Ordering::Relaxed);
if pending >= MAX_PENDING_REQUESTS {
warn!(
pending = pending,
max = MAX_PENDING_REQUESTS,
"Too many pending requests"
);
return Err(Error::protocol(format!(
"Too many pending requests: {}/{}",
pending, MAX_PENDING_REQUESTS
)));
}
let (response_tx, response_rx) = oneshot::channel();
self.command_tx
.try_send(ConnectionCommand::Send {
request,
response_tx,
})
.map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => {
Error::protocol("Command channel full (backpressure)")
}
mpsc::error::TrySendError::Closed(_) => Error::ConnectionClosed,
})?;
match timeout(request_timeout, response_rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => Err(Error::ConnectionClosed),
Err(_) => {
let _ = self
.command_tx
.try_send(ConnectionCommand::RemoveCorrelation(request_id));
Err(Error::request_timeout(
request_id,
request_timeout.as_millis() as u64,
))
}
}
}
#[inline]
#[must_use]
pub fn pending_count(&self) -> usize {
self.pending_count.load(Ordering::Relaxed)
}
pub fn shutdown(&self) {
let _ = self.command_tx.try_send(ConnectionCommand::Shutdown);
}
/// Background task that owns the WebSocket and services a `Connection`.
///
/// Uses `tokio::select!` to multiplex two sources:
/// - Incoming frames. Text frames go through `handle_incoming_message`; a
///   reply returned by an event handler is serialized and written back.
///   A Close frame, a read error or end-of-stream stops the loop. Other
///   frame kinds are ignored.
/// - Commands from the handle. `Send` registers and writes a request,
///   `RemoveCorrelation` drops a waiter, and `Shutdown` sends a Close frame
///   and stops. A closed command channel (all senders dropped) also stops
///   the loop.
///
/// On exit, every still-registered waiter is failed with
/// `Error::ConnectionClosed`.
async fn run_event_loop(
    ws_stream: WebSocketStream<TcpStream>,
    mut command_rx: mpsc::Receiver<ConnectionCommand>,
    correlation: Arc<Mutex<CorrelationMap>>,
    event_handlers: Arc<Mutex<HandlerVec>>,
    pending_count: Arc<AtomicUsize>,
) {
    let (mut ws_write, mut ws_read) = ws_stream.split();
    loop {
        tokio::select! {
            message = ws_read.next() => {
                match message {
                    Some(Ok(Message::Text(text))) => {
                        let reply = Self::handle_incoming_message(
                            &text,
                            &correlation,
                            &event_handlers,
                            &pending_count,
                        );
                        if let Some(reply) = reply {
                            // Report serialization failures instead of
                            // silently dropping the reply.
                            match to_string(&reply) {
                                Ok(json) => {
                                    if let Err(e) = ws_write.send(Message::Text(json.into())).await {
                                        warn!(error = %e, "Failed to send event reply");
                                    }
                                }
                                Err(e) => {
                                    warn!(error = %e, "Failed to serialize event reply");
                                }
                            }
                        }
                    }
                    Some(Ok(Message::Close(_))) => {
                        debug!("WebSocket closed by remote");
                        break;
                    }
                    // Binary, ping/pong and raw frames are ignored.
                    Some(Ok(_)) => {}
                    Some(Err(e)) => {
                        error!(error = %e, "WebSocket error");
                        break;
                    }
                    None => {
                        debug!("WebSocket stream ended");
                        break;
                    }
                }
            }
            command = command_rx.recv() => {
                match command {
                    Some(ConnectionCommand::Send { request, response_tx }) => {
                        Self::handle_send_command(
                            request,
                            response_tx,
                            &mut ws_write,
                            &correlation,
                            &pending_count,
                        ).await;
                    }
                    Some(ConnectionCommand::RemoveCorrelation(request_id)) => {
                        if correlation.lock().remove(&request_id).is_some() {
                            pending_count.fetch_sub(1, Ordering::Relaxed);
                        }
                        debug!(?request_id, "Removed timed-out correlation");
                    }
                    Some(ConnectionCommand::Shutdown) => {
                        debug!("Shutdown command received");
                        let _ = ws_write.close().await;
                        break;
                    }
                    None => {
                        debug!("Command channel closed");
                        break;
                    }
                }
            }
        }
    }
    Self::fail_pending_requests(&correlation, &pending_count);
    debug!("Event loop terminated");
}
/// Routes one incoming text message.
///
/// - Invalid JSON is logged and dropped.
/// - A message whose `type` is `"success"` or `"error"` is decoded as a
///   `Response` and delivered to the waiter registered under its id, which
///   is removed and uncounted. Unknown ids are logged.
/// - Otherwise, a message with a `method` field is decoded as an `Event`
///   and offered to the handlers in order. The first `Some(reply)` wins,
///   and later handlers are not called.
/// - Anything else is logged and dropped.
///
/// The handler list is snapshotted before dispatch, so handlers run without
/// the handler lock held.
fn handle_incoming_message(
    text: &str,
    correlation: &Arc<Mutex<CorrelationMap>>,
    event_handlers: &Arc<Mutex<HandlerVec>>,
    pending_count: &Arc<AtomicUsize>,
) -> Option<EventReply> {
    let value = match serde_json::from_str::<serde_json::Value>(text) {
        Ok(parsed) => parsed,
        Err(e) => {
            warn!(error = %e, text = %text, "Failed to parse incoming message as JSON");
            return None;
        }
    };

    let is_response = matches!(
        value.get("type").and_then(serde_json::Value::as_str),
        Some("success" | "error")
    );

    if is_response {
        match from_value::<Response>(value) {
            Ok(response) => {
                let waiter = correlation.lock().remove(&response.id);
                match waiter {
                    Some(tx) => {
                        pending_count.fetch_sub(1, Ordering::Relaxed);
                        let _ = tx.send(Ok(response));
                    }
                    None => {
                        warn!(id = %response.id, "Response for unknown request");
                    }
                }
            }
            Err(e) => {
                warn!(error = %e, "Failed to deserialize Response from Value");
            }
        }
        return None;
    }

    if value.get("method").is_none() {
        warn!(text = %text, "Failed to parse incoming message: no type or method field");
        return None;
    }

    let event = match from_value::<Event>(value) {
        Ok(ev) => ev,
        Err(e) => {
            warn!(error = %e, "Failed to deserialize Event from Value");
            return None;
        }
    };
    let snapshot: Vec<HandlerEntry> = event_handlers.lock().clone();
    snapshot
        .iter()
        .find_map(|(_, handler)| handler(event.clone()))
}
async fn handle_send_command(
request: Request,
response_tx: oneshot::Sender<Result<Response>>,
ws_write: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, Message>,
correlation: &Arc<Mutex<CorrelationMap>>,
pending_count: &Arc<AtomicUsize>,
) {
let request_id = request.id;
let json = match to_string(&request) {
Ok(j) => j,
Err(e) => {
let _ = response_tx.send(Err(Error::Json(e)));
return;
}
};
correlation.lock().insert(request_id, response_tx);
pending_count.fetch_add(1, Ordering::Relaxed);
if let Err(e) = ws_write.send(Message::Text(json.into())).await {
if let Some(tx) = correlation.lock().remove(&request_id) {
pending_count.fetch_sub(1, Ordering::Relaxed);
let _ = tx.send(Err(Error::connection(e.to_string())));
}
}
trace!(?request_id, "Request sent");
}
fn fail_pending_requests(
correlation: &Arc<Mutex<CorrelationMap>>,
pending_count: &Arc<AtomicUsize>,
) {
let pending: Vec<_> = correlation.lock().drain().collect();
let count = pending.len();
for (_, tx) in pending {
let _ = tx.send(Err(Error::ConnectionClosed));
}
pending_count.store(0, Ordering::Relaxed);
if count > 0 {
debug!(count, "Failed pending requests on shutdown");
}
}
}
impl Drop for Connection {
fn drop(&mut self) {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the tuning constants so accidental changes are noticed.
    #[test]
    fn test_constants() {
        assert_eq!(30, DEFAULT_COMMAND_TIMEOUT.as_secs());
        assert_eq!(100, MAX_PENDING_REQUESTS);
        assert_eq!(30, READY_TIMEOUT.as_secs());
    }

    /// `ReadyData` stores exactly the ids it was built with.
    #[test]
    fn test_ready_data() {
        let ReadyData { tab_id, session_id } = ReadyData {
            tab_id: 1,
            session_id: 2,
        };
        assert_eq!(1, tab_id);
        assert_eq!(2, session_id);
    }
}