use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::task::JoinHandle;
use tracing::{debug, warn};
use super::browser::BrowserKind;
use super::error::ManagerError;
use super::status::{
DriverId, DriverLogLine, DriverLogSubscription, DriverStream, Emitter, LogSubscribers, Status,
};
/// How the spawned driver's stdout and stderr are wired up.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum StdioMode {
    /// Pipe both streams and forward every line to `tracing` (debug level) and to
    /// any registered driver log subscribers. This is the default.
    #[default]
    Tracing,
    /// Let the driver write directly to this process's stdout/stderr.
    Inherit,
    /// Discard the driver's output.
    Null,
}
impl StdioMode {
    /// Builds the `Stdio` handle for one child stream.
    ///
    /// A fresh value is produced on every call because `Stdio` is not `Clone`,
    /// so stdout and stderr each need their own.
    fn to_stdio(self) -> Stdio {
        match self {
            StdioMode::Tracing => Stdio::piped(),
            StdioMode::Inherit => Stdio::inherit(),
            StdioMode::Null => Stdio::null(),
        }
    }
}
/// Options controlling how a driver binary is launched and awaited.
pub(crate) struct SpawnConfig {
    /// Host used to build the driver URL and to poll its `/status` endpoint.
    pub host: IpAddr,
    /// How long to keep polling `/status` before giving up on the driver.
    pub ready_timeout: Duration,
    /// Wiring for the driver's stdout/stderr.
    pub stdio: StdioMode,
}
impl Default for SpawnConfig {
    /// IPv4 loopback, a 30 second readiness timeout and the default [`StdioMode`].
    fn default() -> Self {
        let host = Ipv4Addr::LOCALHOST.into();
        let ready_timeout = Duration::from_secs(30);
        let stdio = StdioMode::default();
        Self { host, ready_timeout, stdio }
    }
}
/// A driver process spawned by the manager, together with its log plumbing.
///
/// Dropping this value requests a kill of the child process, aborts the log pump
/// tasks and emits `Status::DriverShutdown` (see the `Drop` impl below).
/// NOTE: field declaration order is also the order fields are dropped in after
/// `Drop::drop` runs.
pub(crate) struct ManagedDriverProcess {
    /// Host used for the driver URL and readiness polling.
    pub host: IpAddr,
    /// Port passed to the driver via `--port`.
    pub port: u16,
    pub browser: BrowserKind,
    /// Driver version this process was spawned for.
    pub version: String,
    pub driver_id: DriverId,
    /// Per-process subscribers fed by the stdout/stderr pumps; lines are only
    /// delivered when the process was spawned with `StdioMode::Tracing`.
    pub log_subscribers: LogSubscribers,
    /// `Some` until `Drop` takes it out to send the kill request.
    child: Option<Child>,
    /// Set to `true` in `Drop`; the pump tasks check it between lines.
    shutdown: Arc<AtomicBool>,
    /// Handles of the stdout/stderr pump tasks; aborted in `Drop`.
    pump_handles: Vec<JoinHandle<()>>,
    /// Used to emit `Status::DriverShutdown` when this value is dropped.
    emitter: Emitter,
}
impl std::fmt::Debug for ManagedDriverProcess {
    /// Formats the identifying fields only. The child handle, task handles,
    /// subscribers and emitter are left out, and `finish_non_exhaustive` marks
    /// that with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ManagedDriverProcess")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("browser", &self.browser)
            .field("version", &self.version)
            .finish_non_exhaustive()
    }
}
/// Per-launch data shared, by reference, across every port attempt of one spawn.
pub(crate) struct SpawnContext<'a> {
    /// Identifier stamped onto every captured driver log line.
    pub driver_id: DriverId,
    /// Manager-wide subscribers that also receive this driver's log lines.
    pub manager_log_subscribers: LogSubscribers,
    /// Sink for lifecycle `Status` events (spawned, ready, shutdown).
    pub emitter: &'a Emitter,
    /// Driver version, recorded in status events and log lines.
    pub version: &'a str,
}
impl ManagedDriverProcess {
    /// Spawns `binary` on a freshly picked port and waits until it reports ready.
    ///
    /// Up to three attempts are made. Before each one an ephemeral port is
    /// picked and immediately released, so another process may claim it first.
    /// Only errors classified by `is_port_in_use` trigger a retry on a new port.
    ///
    /// # Errors
    /// Returns the first error that is not a port conflict (port selection,
    /// process spawn, readiness timeout, ...), or the last port-conflict error
    /// once all attempts are used up.
    pub(crate) async fn spawn(
        binary: &Path,
        browser: BrowserKind,
        cfg: &SpawnConfig,
        ctx: SpawnContext<'_>,
    ) -> Result<Self, ManagerError> {
        const MAX_PORT_ATTEMPTS: u8 = 3;
        let mut last_err: Option<ManagerError> = None;
        for attempt in 0..MAX_PORT_ATTEMPTS {
            let port = pick_port(cfg.host)?;
            match spawn_at_port(binary, browser, cfg, port, &ctx).await {
                Ok(p) => return Ok(p),
                Err(e) if is_port_in_use(&e) => {
                    debug!("driver port {port} already in use (attempt {attempt}): {e}");
                    last_err = Some(e);
                }
                Err(e) => return Err(e),
            }
        }
        Err(last_err.unwrap_or_else(|| ManagerError::Spawn("port allocation exhausted".into())))
    }

    /// Base URL of the driver, e.g. `http://127.0.0.1:4444` or `http://[::1]:4444`.
    ///
    /// Formatted through `SocketAddr` because an IPv6 host must be bracketed in a
    /// URL; plain `"{host}:{port}"` would yield the invalid `http://::1:4444`.
    pub(crate) fn url(&self) -> String {
        format!("http://{}", SocketAddr::new(self.host, self.port))
    }

    /// Registers `f` to receive every stdout/stderr line captured from this process.
    ///
    /// Lines are only captured when the process was spawned with
    /// `StdioMode::Tracing`. Returns the handle produced by `LogSubscribers::add`.
    pub(crate) fn subscribe_log<F>(&self, f: F) -> DriverLogSubscription
    where
        F: Fn(&DriverLogLine) + Send + Sync + 'static,
    {
        self.log_subscribers.add(f)
    }
}
/// Heuristically decides whether `err` means the chosen port was already taken.
///
/// Only `ManagerError::Spawn` messages are inspected. The test is a
/// case-insensitive substring match against several phrasings of an
/// "address in use" failure.
fn is_port_in_use(err: &ManagerError) -> bool {
    const MARKERS: [&str; 3] = [
        "address already in use",
        "only one usage of each socket address",
        "addrinuse",
    ];
    let ManagerError::Spawn(message) = err else {
        return false;
    };
    let message = message.to_ascii_lowercase();
    MARKERS.iter().any(|&marker| message.contains(marker))
}
async fn spawn_at_port(
binary: &Path,
browser: BrowserKind,
cfg: &SpawnConfig,
port: u16,
ctx: &SpawnContext<'_>,
) -> Result<ManagedDriverProcess, ManagerError> {
let mut cmd = Command::new(binary);
cmd.arg(format!("--port={port}"));
cmd.stdout(cfg.stdio.to_stdio());
cmd.stderr(cfg.stdio.to_stdio());
cmd.stdin(Stdio::null());
cmd.kill_on_drop(true);
if matches!(browser, BrowserKind::Chrome | BrowserKind::Edge)
&& cfg.host != IpAddr::V4(Ipv4Addr::LOCALHOST)
{
cmd.arg(format!("--allowed-ips={}", cfg.host));
}
let mut child = cmd
.spawn()
.map_err(|e| ManagerError::Spawn(format!("spawn {}: {}", binary.display(), e)))?;
let pid = child.id().unwrap_or(0);
ctx.emitter.emit(Status::DriverProcessSpawned {
browser,
version: ctx.version.to_string(),
pid,
port,
binary: PathBuf::from(binary),
});
let shutdown = Arc::new(AtomicBool::new(false));
let log_subscribers = LogSubscribers::new();
let mut pump_handles = Vec::new();
if matches!(cfg.stdio, StdioMode::Tracing) {
let line_ctx = LogLineContext {
driver_id: ctx.driver_id,
browser,
version: ctx.version.to_string(),
port,
};
if let Some(stdout) = child.stdout.take() {
pump_handles.push(spawn_pump(
DriverStream::Stdout,
stdout,
Arc::clone(&shutdown),
line_ctx.clone(),
log_subscribers.clone(),
ctx.manager_log_subscribers.clone(),
));
}
if let Some(stderr) = child.stderr.take() {
pump_handles.push(spawn_pump(
DriverStream::Stderr,
stderr,
Arc::clone(&shutdown),
line_ctx,
log_subscribers.clone(),
ctx.manager_log_subscribers.clone(),
));
}
}
let ready_started = Instant::now();
if let Err(e) = wait_until_ready(cfg.host, port, cfg.ready_timeout).await {
let _ = child.kill().await;
for h in &pump_handles {
h.abort();
}
return Err(e);
}
let url = format!("http://{}:{}", cfg.host, port);
ctx.emitter.emit(Status::DriverReady {
browser,
version: ctx.version.to_string(),
url,
elapsed: ready_started.elapsed(),
});
Ok(ManagedDriverProcess {
host: cfg.host,
port,
browser,
version: ctx.version.to_string(),
driver_id: ctx.driver_id,
log_subscribers,
child: Some(child),
shutdown,
pump_handles,
emitter: ctx.emitter.clone(),
})
}
/// Asks the OS for a TCP port on `host` that is free right now.
///
/// The probe listener is closed before returning, so the port is only *likely*
/// to stay free. Another process can claim it before the driver binds, which is
/// why `ManagedDriverProcess::spawn` may retry with a new port.
fn pick_port(host: IpAddr) -> Result<u16, ManagerError> {
    let port = {
        let probe = TcpListener::bind(SocketAddr::new(host, 0))
            .map_err(|e| ManagerError::Spawn(format!("bind ephemeral port: {e}")))?;
        probe
            .local_addr()
            .map_err(|e| ManagerError::Spawn(format!("local_addr: {e}")))?
            .port()
    };
    Ok(port)
}
/// Identifying data copied into every `DriverLogLine` a pump produces.
#[derive(Clone)]
struct LogLineContext {
    driver_id: DriverId,
    browser: BrowserKind,
    version: String,
    port: u16,
}
/// Spawns a task that reads `reader` line by line until EOF, shutdown or an I/O error.
///
/// Each line, with its trailing `\n` or `\r\n` removed, is logged at debug level.
/// It is then dispatched first to `process_subs` and then to `manager_subs`.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD instead of ending the
/// pump. Ending the task drops `reader`, which closes the pipe's read end, so the
/// driver's next write to that stream would fail (EPIPE/SIGPIPE on Unix).
fn spawn_pump<R>(
    stream: DriverStream,
    reader: R,
    shutdown: Arc<AtomicBool>,
    ctx: LogLineContext,
    process_subs: LogSubscribers,
    manager_subs: LogSubscribers,
) -> JoinHandle<()>
where
    R: tokio::io::AsyncRead + Unpin + Send + 'static,
{
    let stream_label = match stream {
        DriverStream::Stdout => "stdout",
        DriverStream::Stderr => "stderr",
    };
    tokio::spawn(async move {
        let mut reader = BufReader::new(reader);
        // One raw byte buffer reused for every line.
        let mut buf = Vec::new();
        loop {
            // Only checked between lines; a read that is blocked is ended by
            // aborting the task instead.
            if shutdown.load(Ordering::Relaxed) {
                break;
            }
            buf.clear();
            match reader.read_until(b'\n', &mut buf).await {
                Ok(0) => break,
                Ok(_) => {
                    // Mirror `Lines::next_line`: strip "\n", then a preceding "\r".
                    let mut bytes = buf.as_slice();
                    if let Some(rest) = bytes.strip_suffix(b"\n") {
                        bytes = rest.strip_suffix(b"\r").unwrap_or(rest);
                    }
                    let line = String::from_utf8_lossy(bytes).into_owned();
                    debug!(target: "thirtyfour::manager::driver", stream = stream_label, line = %line);
                    let log = DriverLogLine {
                        driver_id: ctx.driver_id,
                        browser: ctx.browser,
                        version: ctx.version.clone(),
                        port: ctx.port,
                        stream,
                        line,
                    };
                    process_subs.dispatch(&log);
                    manager_subs.dispatch(&log);
                }
                Err(e) => {
                    warn!(target: "thirtyfour::manager::driver", stream = stream_label, error = %e);
                    break;
                }
            }
        }
    })
}
async fn wait_until_ready(host: IpAddr, port: u16, timeout: Duration) -> Result<(), ManagerError> {
let url = format!("http://{host}:{port}/status");
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(2))
.build()
.map_err(|e| ManagerError::Spawn(e.to_string()))?;
let deadline = Instant::now() + timeout;
while Instant::now() < deadline {
if let Ok(resp) = client.get(&url).send().await
&& resp.status().is_success()
&& let Ok(body) = resp.json::<serde_json::Value>().await
&& body
.get("value")
.and_then(|v| v.get("ready"))
.and_then(|v| v.as_bool())
.unwrap_or(false)
{
return Ok(());
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
Err(ManagerError::DriverNotReady(timeout))
}
impl Drop for ManagedDriverProcess {
    /// Tears the driver down in four steps, in this order:
    /// 1. flag the log pumps to stop;
    /// 2. send the child a kill request without waiting for it to exit;
    /// 3. abort the pump tasks;
    /// 4. emit `Status::DriverShutdown`.
    fn drop(&mut self) {
        self.shutdown.store(true, Ordering::Relaxed);
        // `start_kill` does not wait for the process, which matters because
        // `drop` cannot await.
        if let Some(Err(e)) = self.child.take().map(|mut child| child.start_kill()) {
            warn!(target: "thirtyfour::manager", error = %e, "failed to kill driver");
        }
        self.pump_handles.drain(..).for_each(|handle| handle.abort());
        let browser = self.browser;
        let version = self.version.clone();
        let port = self.port;
        self.emitter.emit(Status::DriverShutdown { browser, version, port });
    }
}