use anyhow::{Context, Result};
use base64::Engine;
use base64::engine::general_purpose;
use futures::{SinkExt, StreamExt};
use serde_json::{Value, json};
use std::net::TcpListener;
use std::process::{Child, Command, Stdio};
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use crate::fetch_manager::BrowserConfig;
/// Handle to a headless Chromium-family browser process launched with
/// `--remote-debugging-port`; the `Drop` impl kills the process.
pub struct ChromeHeadless {
// The spawned browser. Behind `Arc<Mutex<_>>` because the signal-watcher task
// spawned in `new` also needs to reach it in order to kill it on shutdown.
process: Arc<Mutex<Child>>,
// Local port the browser's DevTools HTTP/WebSocket endpoint listens on.
port: u16,
}
/// Asks the OS for a currently free TCP port on the loopback interface.
///
/// The probe listener is closed before returning, so the port is not
/// reserved: another process could claim it before the browser binds it.
fn find_available_port() -> Result<u16> {
    // The temporary listener is dropped at the end of this statement,
    // releasing the port for the browser to bind.
    let free_port = TcpListener::bind("127.0.0.1:0")?.local_addr()?.port();
    Ok(free_port)
}
impl ChromeHeadless {
pub async fn new(uri: &str) -> Result<Self> {
let browser_config = BrowserConfig::default().context("chromium isn't installed. either install it manually (chrome/msedge will do so too) or call `mcat --fetch-chromium`")?;
let path = browser_config.path;
let port = find_available_port()?;
let process = Command::new(path)
.args([
"--headless=new",
&format!("--remote-debugging-port={}", port),
&format!(
"--user-data-dir={}",
std::env::temp_dir().join("mcat-chrome-profile").display()
),
"--hide-scrollbars",
"--force-color-profile=srgb",
"--disable-dev-shm-usage",
"--disable-extensions",
"--disable-plugins",
"--disable-default-apps",
"--disable-sync",
"--disable-background-networking",
"--disable-breakpad",
"--disable-hang-monitor",
"--block-new-web-contents",
"--disable-notifications",
"--run-all-compositor-stages-before-draw",
"--disable-threaded-animation",
"--disable-threaded-scrolling",
"--disable-checker-imaging",
"--disable-features=PaintHolding",
"--disable-features=LazyFrameLoading",
"--disable-features=ScriptStreaming",
"--disable-features=AutofillServerCommunication",
"--disable-features=OptimizationHints",
"--disable-features=CertificateTransparencyComponentUpdater",
"--disable-features=Translate",
"--disable-features=MediaRouter",
"--disable-features=DialMediaRouteProvider",
"--disable-features=IsolateOrigins",
"--disable-client-side-phishing-detection",
"--disable-component-update",
"--disable-component-extensions-with-background-pages",
"--no-default-browser-check",
"--mute-audio",
"--password-store=basic",
"--use-mock-keychain",
"--no-service-autorun",
"--deny-permission-prompts",
"--disable-external-intent-requests",
"--autoplay-policy=no-user-gesture-required",
"--noerrdialogs",
"--no-pings",
"--disable-domain-reliability",
"--no-first-run",
"--disable-prompt-on-repost",
"--disable-background-timer-throttling",
"--disable-renderer-backgrounding",
"--disable-backgrounding-occluded-windows",
"--disable-ipc-flooding-protection",
uri,
])
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()?;
let shutdown = rasteroid::term_misc::setup_signal_handler();
let pc_arc = Arc::new(Mutex::new(process));
let shutdown_arc = pc_arc.clone();
tokio::spawn(async move {
loop {
if shutdown.load(Ordering::SeqCst) {
let mut process = shutdown_arc.lock().unwrap();
process.kill().unwrap();
std::process::exit(1);
};
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
});
let instance = Self {
process: pc_arc,
port,
};
instance.wait_for_server().await?;
Ok(instance)
}
async fn wait_for_server(&self) -> Result<()> {
loop {
match timeout(
Duration::from_millis(2000),
TcpStream::connect(format!("127.0.0.1:{}", self.port)),
)
.await
{
Ok(Ok(_)) => return Ok(()),
Ok(Err(_)) | Err(_) => tokio::task::yield_now().await,
}
}
}
pub async fn capture_screenshot(&self) -> Result<Vec<u8>> {
let endpoint = self.get_websocket_endpoint().await?;
let (mut ws_stream, _) = tokio_tungstenite::connect_async(&endpoint).await?;
self.send_command(&mut ws_stream, 1, "Page.enable", None, true)
.await?;
self.wait_for_load_event(&mut ws_stream).await?;
let metrics = self
.send_command(&mut ws_stream, 2, "Page.getLayoutMetrics", None, true)
.await?;
let width = metrics["contentSize"]["width"].as_f64().unwrap();
let height = metrics["contentSize"]["height"].as_f64().unwrap();
let desired_width = 1920.0;
let desired_height = 1080.0;
let scalex = if width > desired_width {
1.0
} else {
desired_width / width
};
let scaleh = if height > desired_height {
1.0
} else {
desired_height / height
};
let scale = scalex.min(scaleh).round() as u16;
self.send_command(
&mut ws_stream,
3,
"Emulation.setDeviceMetricsOverride",
Some(json!({
"mobile": false,
"width": width,
"height": height,
"deviceScaleFactor": scale
})),
false,
)
.await?;
self.send_command(
&mut ws_stream,
6,
"Emulation.setDefaultBackgroundColorOverride",
Some(json!({ "color": { "r": 0, "g": 0, "b": 0, "a": 0 } })),
false,
)
.await?;
let response = self
.send_command(
&mut ws_stream,
4,
"Page.captureScreenshot",
Some(json!({ "format": "png", "captureBeyondViewport": true })),
true,
)
.await?;
let screenshot_data = response["data"]
.as_str()
.context("failed to get screenshot")?;
let bytes = general_purpose::STANDARD.decode(screenshot_data)?;
{
let mut process = self.process.lock().unwrap();
let _ = process.kill();
}
tracing::info!(
size = bytes.len(),
"captured screenshot via headless chrome"
);
Ok(bytes)
}
async fn get_websocket_endpoint(&self) -> Result<String> {
let max_attempts = 10;
for _ in 1..=max_attempts {
let url = format!("http://127.0.0.1:{}/json", self.port);
let body = reqwest::get(&url).await?.text().await?;
let json: Value = serde_json::from_str(&body)?;
if let Some(arr) = json.as_array()
&& let Some(page) = arr.iter().find(|entry| entry["type"] == "page")
&& let Some(ws_url) = page["webSocketDebuggerUrl"].as_str()
{
return Ok(ws_url.to_owned());
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
anyhow::bail!("Failed to get websocket for headless chrome")
}
async fn send_command(
&self,
ws_stream: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
id: u64,
method: &str,
params: Option<Value>,
wait: bool,
) -> Result<Value> {
let command = match params {
Some(params) => json!({ "id": id, "method": method, "params": params }),
None => json!({ "id": id, "method": method }),
};
ws_stream
.send(Message::Text(command.to_string().into()))
.await?;
if wait {
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if let Message::Text(text) = msg {
let response: Value = serde_json::from_str(&text)?;
if response["id"] == id {
if let Some(error) = response.get("error") {
anyhow::bail!("Chrome error: {}", error);
}
return Ok(response["result"].to_owned());
}
}
}
anyhow::bail!("WebSocket connection closed unexpectedly")
} else {
Ok(Value::Null)
}
}
async fn wait_for_load_event(
&self,
ws_stream: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
) -> Result<()> {
let ready_state = self
.send_command(
ws_stream,
100,
"Runtime.evaluate",
Some(json!({ "expression": "document.readyState" })),
true,
)
.await?;
if let Some(state) = ready_state["result"]["value"].as_str()
&& state == "complete"
{
return Ok(());
}
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if let Message::Text(text) = msg
&& let Ok(json) = serde_json::from_str::<Value>(&text)
&& json["method"] == "Page.loadEventFired"
{
return Ok(());
}
}
anyhow::bail!("WebSocket closed before loadEventFired")
}
}
impl Drop for ChromeHeadless {
    /// Kills the browser process when the handle is dropped and reaps it so
    /// it does not linger as a zombie.
    fn drop(&mut self) {
        // Never panic in `drop` (a panic during unwinding aborts): a poisoned
        // lock still guards a usable `Child`, so recover it.
        let mut process = self
            .process
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Only wait after a successful kill. If kill fails, the child has
        // already been reaped or cannot be signalled, and waiting could block.
        if process.kill().is_ok() {
            let _ = process.wait();
        }
    }
}