chromedriver_launch/
chromedriver.rs

1// Copyright (C) 2025 Daniel Mueller <deso@posteo.net>
2// SPDX-License-Identifier: GPL-3.0-or-later
3
4use std::collections::HashSet;
5use std::net::IpAddr;
6use std::net::Ipv4Addr;
7use std::net::SocketAddr;
8use std::os::unix::process::CommandExt as _;
9use std::path::Path;
10use std::path::PathBuf;
11use std::process::Child;
12use std::process::Command;
13use std::process::Stdio;
14use std::thread::sleep;
15use std::time::Duration;
16use std::time::Instant;
17
18use anyhow::bail;
19use anyhow::Context as _;
20use anyhow::Result;
21
22use libc::killpg;
23use libc::setpgid;
24use libc::SIGKILL;
25
26use crate::socket;
27use crate::tcp;
28use crate::util::check;
29
30
31/// The name of the `chromedriver` binary.
32const CHROME_DRIVER: &str = "chromedriver";
33/// The timeout used when searching for a bound local port.
34const PORT_FIND_TIMEOUT: Duration = Duration::from_secs(30);
35
36
37fn find_localhost_port(pid: u32) -> Result<u16> {
38  let start = Instant::now();
39
40  // Wait for the driver process to bind to a local host address.
41  let port = loop {
42    let inodes = socket::socket_inodes(pid)?.collect::<Result<HashSet<_>>>()?;
43    let result = tcp::parse(pid)?.find(|result| match result {
44      Ok(entry) => {
45        if inodes.contains(&entry.inode) {
46          entry.addr == Ipv4Addr::LOCALHOST
47        } else {
48          false
49        }
50      },
51      Err(_) => true,
52    });
53    match result {
54      None => {
55        if start.elapsed() >= PORT_FIND_TIMEOUT {
56          bail!("failed to find local host port for process {pid}");
57        }
58        sleep(Duration::from_millis(1))
59      },
60      Some(result) => {
61        break result
62          .context("failed to find localhost proc tcp entry")?
63          .port
64      },
65    }
66  };
67
68  Ok(port)
69}
70
71
72/// A builder for configurable launch of a Chromedriver process.
73#[derive(Debug)]
74pub struct Builder {
75  /// The path to the `chromedriver` binary to use.
76  chromedriver: PathBuf,
77  /// The timeout to use waiting for `chromedriver` to start up
78  /// properly.
79  timeout: Duration,
80}
81
82impl Builder {
83  /// Set the Chromedriver to use.
84  pub fn set_chromedriver(mut self, chromedriver: impl AsRef<Path>) -> Self {
85    self.chromedriver = chromedriver.as_ref().to_path_buf();
86    self
87  }
88
89  /// Set the timeout to wait for Chromedriver to start up properly.
90  pub fn set_timeout(mut self, timeout: Duration) -> Self {
91    self.timeout = timeout;
92    self
93  }
94
95  /// Launch the Chromedriver process and wait for it to be fully
96  /// initialized and serving a webdriver service.
97  pub fn launch(self) -> Result<Chromedriver> {
98    let process = unsafe {
99      Command::new(CHROME_DRIVER)
100        .arg("--port=0")
101        .stdout(Stdio::piped())
102        .stderr(Stdio::piped())
103        .pre_exec(|| {
104          // Create a new process group, so that we can destroy the
105          // entire group and have reasonable assurance that nothing
106          // stray stays around.
107          let result = setpgid(0, 0);
108          check(result, -1)
109        })
110        .spawn()
111        .with_context(|| format!("failed to launch `{CHROME_DRIVER}` instance"))
112    }?;
113
114    let pid = process.id();
115    let port = find_localhost_port(pid)?;
116
117    let slf = Chromedriver { process, port };
118    Ok(slf)
119  }
120}
121
122impl Default for Builder {
123  fn default() -> Self {
124    Self {
125      chromedriver: PathBuf::from(CHROME_DRIVER),
126      timeout: PORT_FIND_TIMEOUT,
127    }
128  }
129}
130
131
/// A running Chromedriver process serving the webdriver protocol on a
/// local host port.
#[derive(Debug)]
pub struct Chromedriver {
  /// The Chromedriver process; it is launched as the leader of its own
  /// process group.
  process: Child,
  /// The local host port on which the webdriver protocol is being served.
  port: u16,
}
140
141impl Chromedriver {
142  /// Launch a Chromedriver process and wait for it to be serving a
143  /// webdriver service.
144  pub fn launch() -> Result<Self> {
145    Self::builder().launch()
146  }
147
148  /// Create a [`Builder`] for configurable launch of a Chromedriver
149  /// process.
150  pub fn builder() -> Builder {
151    Builder::default()
152  }
153
154  /// Destroy the Chromedriver process, freeing up all resources.
155  fn destroy_impl(&mut self) -> Result<()> {
156    // NB: We created the process in a new process group, so the process
157    //     ID equals the process group ID here.
158    let pid = self.process.id();
159    // SAFETY: `killpg` is always save to call.
160    let result = unsafe { killpg(pid as _, SIGKILL) };
161    let () = check(result, -1).context("failed to shut down chromedriver process group")?;
162
163    // Clean up the child to prevent any build up of zombie processes.
164    // The `killpg()` should pretty much be immediate, so the `wait()`
165    // shouldn't be blocking for long. However, using `try_wait()`
166    // instead could probably be racy, as `kill()` will only deliver the
167    // signal, not ensure that it got processed to completion.
168    let _status = self.process.wait()?;
169    Ok(())
170  }
171
172  /// Destroy the Chromedriver process, freeing up all resources.
173  #[inline]
174  pub fn destroy(mut self) -> Result<()> {
175    self.destroy_impl()
176  }
177
178  /// Retrieve the socket address on which the webdriver service is
179  /// listening.
180  #[inline]
181  pub fn socket_addr(&self) -> SocketAddr {
182    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), self.port)
183  }
184}
185
186impl Drop for Chromedriver {
187  fn drop(&mut self) {
188    let _result = self.destroy_impl();
189  }
190}
191
192
#[cfg(test)]
mod tests {
  use super::*;

  use std::net::TcpListener;
  use std::process;


  /// Bind a listener on the IPv4 loopback address and verify that the
  /// port search reports exactly that port for the current process.
  #[test]
  fn localhost_port_finding() {
    let socket = TcpListener::bind("127.0.0.1:0").unwrap();
    let expected = socket.local_addr().unwrap().port();
    let found = find_localhost_port(process::id()).unwrap();
    assert_eq!(found, expected);
  }
}