use futures::{SinkExt, StreamExt};
use std::path::Path;
use std::sync::Arc;
use tokio_tungstenite::tungstenite::Message;
use super::transport::CdpDispatcher;
use crate::error::{FerriError, Result};
/// CDP transport over a single websocket connection.
///
/// Outgoing frames are not written directly: they are queued on `write_tx` and a
/// background writer task (spawned in `WsTransport::connect_with_headers`) puts them
/// on the wire. Incoming frames are handed to the shared [`CdpDispatcher`].
pub struct WsTransport {
    /// Queue feeding the background websocket writer task.
    write_tx: tokio::sync::mpsc::Sender<Message>,
    /// Shared with the reader task; builds outgoing commands (`build_command`),
    /// routes incoming messages (`dispatch_message`) and hands out event receivers.
    dispatcher: Arc<CdpDispatcher>,
}
impl WsTransport {
pub async fn connect(ws_url: &str) -> Result<Self> {
Box::pin(Self::connect_with_headers(ws_url, &std::collections::HashMap::new())).await
}
pub async fn connect_with_headers(ws_url: &str, headers: &std::collections::HashMap<String, String>) -> Result<Self> {
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http;
let mut request = ws_url
.into_client_request()
.map_err(|e| FerriError::Backend(format!("WebSocket request build: {e}")))?;
for (key, value) in headers {
let header_name = http::header::HeaderName::from_bytes(key.as_bytes())
.map_err(|e| FerriError::Backend(format!("invalid header name '{key}': {e}")))?;
let header_value = http::header::HeaderValue::from_str(value)
.map_err(|e| FerriError::Backend(format!("invalid header value for '{key}': {e}")))?;
request.headers_mut().insert(header_name, header_value);
}
let (ws_stream, _) = Box::pin(tokio_tungstenite::connect_async(request))
.await
.map_err(|e| FerriError::Backend(format!("WebSocket connect to {ws_url}: {e}")))?;
let (write, read) = ws_stream.split();
let dispatcher = Arc::new(CdpDispatcher::new());
let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::<Message>(64);
tokio::spawn(async move {
let mut writer = write;
while let Some(msg) = write_rx.recv().await {
if writer.send(msg).await.is_err() {
break;
}
}
});
let dispatcher2 = dispatcher.clone();
tokio::spawn(async move {
let mut read = read;
while let Some(Ok(msg)) = read.next().await {
let Message::Text(text) = msg else { continue };
dispatcher2.dispatch_message(text.as_bytes());
}
dispatcher2.fail_all_pending("CDP transport closed (websocket ended)");
});
Ok(Self { write_tx, dispatcher })
}
pub async fn spawn(
chromium_path: &str,
user_data_dir: &Path,
extra_flags: &[String],
) -> Result<(Self, tokio::process::Child)> {
let mut command = tokio::process::Command::new(chromium_path);
command.arg(format!("--user-data-dir={}", user_data_dir.display()));
command.arg("--remote-debugging-port=0");
for flag in extra_flags {
command.arg(flag);
}
command.arg("--no-startup-window");
command
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
#[cfg(unix)]
#[allow(unsafe_code)]
unsafe {
command.pre_exec(super::super::process::setsid_pre_exec());
}
let mut child = command
.spawn()
.map_err(|e| FerriError::Backend(format!("Chrome launch: {e}")))?;
let port_file = user_data_dir.join("DevToolsActivePort");
let ws_url = discover_ws_url(&port_file, &mut child).await?;
let transport = Box::pin(Self::connect(&ws_url)).await?;
Ok((transport, child))
}
}
impl super::transport::CdpTransport for WsTransport {
    /// Sends one CDP command over the websocket and waits for its response.
    ///
    /// The dispatcher serializes the command and returns the channel its response will
    /// arrive on. The payload is queued for the writer task as a text frame, and the
    /// response is awaited for up to 30 seconds.
    ///
    /// # Errors
    /// - Propagates errors from `CdpDispatcher::build_command`.
    /// - Returns `FerriError::Backend` if the payload is not valid UTF-8, if the writer
    ///   task has stopped, or if the response channel is dropped without an answer.
    /// - Returns a timeout error if no response arrives within 30 seconds.
    // `method` is recorded on the span as an ordinary argument. Declaring it in
    // `fields(...)` without a value would shadow the argument with an empty field and
    // leave it unrecorded. `params` is skipped because it can be large.
    #[tracing::instrument(skip(self, session_id, params))]
    async fn send_command(
        &self,
        session_id: Option<&str>,
        method: &str,
        params: serde_json::Value,
    ) -> Result<serde_json::Value> {
        let (mut data, rx) = self.dispatcher.build_command(session_id, method, &params)?;
        // The dispatcher's encoding may end in a NUL terminator (presumably for a
        // NUL-delimited pipe transport — confirm). A websocket text frame must not carry it.
        if data.last() == Some(&0) {
            data.pop();
        }
        let text = String::from_utf8(data).map_err(|e| FerriError::Backend(format!("UTF-8: {e}")))?;
        self.write_tx
            .send(Message::Text(text.into()))
            .await
            .map_err(|_| FerriError::backend("WS writer closed"))?;
        // Keep the two 30-second values below in sync (the error reports milliseconds).
        match tokio::time::timeout(std::time::Duration::from_secs(30), rx).await {
            Ok(Ok(result)) => result,
            Ok(Err(_)) => Err(FerriError::Backend(format!("Response channel dropped for {method}"))),
            Err(_) => Err(FerriError::timeout(format!("waiting for {method} response"), 30_000)),
        }
    }

    /// Returns a new receiver for CDP events, delegating to the dispatcher.
    fn subscribe_events(&self) -> tokio::sync::broadcast::Receiver<std::sync::Arc<serde_json::Value>> {
        self.dispatcher.subscribe_events()
    }

    /// Registers a lifecycle tracker for `session_id`, delegating to the dispatcher.
    fn register_lifecycle_tracker(
        &self,
        session_id: &str,
        state: Arc<std::sync::Mutex<super::LifecycleState>>,
        notify: Arc<tokio::sync::Notify>,
    ) {
        self.dispatcher.register_lifecycle_tracker(session_id, state, notify);
    }
}
async fn discover_ws_url(port_file: &Path, child: &mut tokio::process::Child) -> Result<String> {
let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10);
loop {
if tokio::time::Instant::now() >= deadline {
return Err(FerriError::timeout("waiting for DevToolsActivePort", 10_000));
}
if let Ok(contents) = tokio::fs::read_to_string(port_file).await {
let lines: Vec<&str> = contents.lines().collect();
if lines.len() >= 2 {
let port = lines[0].trim();
let path = lines[1].trim();
return Ok(format!("ws://127.0.0.1:{port}{path}"));
}
}
if let Ok(Some(status)) = child.try_wait() {
return Err(FerriError::Backend(format!(
"Chrome exited with status {status} before providing DevToolsActivePort"
)));
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
}