chromedriver_update/
lib.rs

1/*!
2Automatically download Chromedriver when the browser/driver versions do not match.
3
4### Use default values
5```no_run
6use chromedriver_update::ChromeDriver;
7
8#[tokio::main]
9async fn main() {
10    let mut driver = ChromeDriver::new();
11    driver.init().await.unwrap();
12    if driver.need_download() {
13        driver.try_download().await.unwrap();
14    }
15}
16```
17
18### Use custom values
19
20```no_run
21use chromedriver_update::ChromeDriver;
22
23#[tokio::main]
24async fn main() {
25    let mut driver = ChromeDriver::new();
26    driver
27        .set_driver_path("/usr/local/bin/chromedriver")
28        .set_browser_path("/Applications/Google Chrome.app/Contents/MacOS/Google Chrome")
29        .set_connect_timeout(2000)
30        .set_timeout(5000)
31        .init()
32        .await
33        .unwrap();
34
35    println!("driver version {}", driver.version);
36    println!("browser version {}", driver.browser_version);
37
38    if !driver.need_download() {
39        println!("no need to update driver");
40        return;
41    }
42
43    println!("updating driver ...");
44
45    match driver.try_download().await {
46        Ok(_) => println!("Download driver successful"),
47        Err(err) => eprintln!("Download driver failed, {}", err),
48    }
49}
50```
51*/
52use regex::Regex;
53use std::{
54    io::{Cursor, Read},
55    process::Output,
56};
57use thiserror::Error;
58use tokio::{fs::File, io::AsyncWriteExt, process::Command};
59
60pub mod constant;
61use constant::{
62    CHROME_BROWSER_PATH, CHROME_DRIVER_PATH, CONNECT_TIMEOUT, DRIVER_FILE, TIMEOUT, ZIP_PATH,
63};
64
/// Handle for version-checking and updating a local chromedriver binary.
///
/// `version` and `browser_version` start out empty (see `new`) and are filled in by `init`.
#[derive(Debug, Clone)]
pub struct ChromeDriver {
    /// Chrome driver version
    pub version: String,
    /// Chrome browser version
    pub browser_version: String,
    /// Filesystem path of the chromedriver executable (queried for its version and
    /// overwritten by a download).
    path: String,
    /// Filesystem path of the Chrome browser executable used to look up the browser version.
    browser_path: String,
    /// Connect timeout for download requests, in milliseconds.
    connect_timeout: u64,
    /// Overall timeout for download requests, in milliseconds.
    timeout: u64,
}
75
76impl ChromeDriver {
77    /// Create driver
78    pub fn new() -> Self {
79        Self {
80            version: String::new(),
81            path: CHROME_DRIVER_PATH.to_string(),
82            browser_version: String::new(),
83            browser_path: CHROME_BROWSER_PATH.to_string(),
84            connect_timeout: CONNECT_TIMEOUT,
85            timeout: TIMEOUT,
86        }
87    }
88
89    /// Update chromedriver path. Default:
90    /// - mac:    `/usr/local/bin/chromedriver`
91    /// - linux:  `/usr/bin/chromedriver`
92    /// - windows:  ``
93    pub fn set_driver_path(&mut self, path: &str) -> &mut Self {
94        self.path = path.to_string();
95        self
96    }
97
98    /// Update chrome browser path. Default:
99    /// - mac:    `/Applications/Google Chrome.app/Contents/MacOS/Google Chrome`
100    /// - linux:  `/usr/bin/google-chrome`
101    /// - windows:  ``
102    pub fn set_browser_path(&mut self, path: &str) -> &mut Self {
103        self.browser_path = path.to_string();
104        self
105    }
106
107    /// Update connect_timeout (ms) for download requests. Default: 5000.
108    pub fn set_connect_timeout(&mut self, timeout: u64) -> &mut Self {
109        self.connect_timeout = timeout;
110        self
111    }
112
113    /// Update timeout (ms) for download requests. Default: 5000.
114    pub fn set_timeout(&mut self, timeout: u64) -> &mut Self {
115        self.timeout = timeout;
116        self
117    }
118
119    /// Setup driver & browser version
120    pub async fn init(&mut self) -> DriverResult<()> {
121        self.version = self.get_driver_version().await;
122        self.browser_version = self.get_browser_version().await?;
123
124        Ok(())
125    }
126
127    /// Compare driver & browser version
128    pub fn need_download(&self) -> bool {
129        !self.version.eq(&self.browser_version)
130    }
131
132    /// Download Chromedriver
133    pub async fn try_download(&self) -> DriverResult<()> {
134        let client = reqwest::Client::builder()
135            .danger_accept_invalid_certs(true)
136            .connect_timeout(std::time::Duration::from_millis(self.connect_timeout))
137            .timeout(std::time::Duration::from_millis(self.timeout))
138            .build()
139            .map_err(|_| DriverError::RequestInvalid)?;
140
141        let url = format!(
142            "https://storage.googleapis.com/chrome-for-testing-public/{}/{}",
143            self.browser_version,
144            ZIP_PATH.as_str()
145        );
146        let bytes = client
147            .get(url)
148            .send()
149            .await
150            .map_err(|_| DriverError::RequestTimeout)?
151            .bytes()
152            .await
153            .map_err(|_| DriverError::RequestInvalid)?;
154
155        let cursor = Cursor::new(bytes.as_ref());
156        let mut archive = zip::ZipArchive::new(cursor)
157            .map_err(|_| DriverError::ResourceInvalid(ZIP_PATH.to_string()))?;
158
159        for i in 0..archive.len() {
160            let mut file = archive.by_index(i).unwrap();
161            if file.name().eq(DRIVER_FILE.as_str()) {
162                let mut output_file = File::create(&self.path)
163                    .await
164                    .map_err(|_| DriverError::ResourceInvalid(self.path.clone()))?;
165                let mut buffer = Vec::new();
166                file.read_to_end(&mut buffer)
167                    .map_err(|_| DriverError::ResourceInvalid(self.path.clone()))?;
168                output_file
169                    .write_all(&buffer)
170                    .await
171                    .map_err(|_| DriverError::ResourceInvalid(self.path.clone()))?;
172
173                #[cfg(unix)]
174                {
175                    use std::{fs, os::unix::fs::PermissionsExt};
176                    let mut permissions = fs::metadata(&self.path).unwrap().permissions();
177                    permissions.set_mode(0o755);
178                    fs::set_permissions(&self.path, permissions).unwrap();
179                }
180
181                return Ok(());
182            }
183        }
184
185        Err(DriverError::ResourceNotFound(DRIVER_FILE.to_string()))
186    }
187
188    async fn get_driver_version(&self) -> String {
189        match Command::new(self.path.clone())
190            .arg("--version")
191            .output()
192            .await
193        {
194            Ok(res) => get_version_from_output(res),
195            Err(_) => String::new(),
196        }
197    }
198
199    async fn get_browser_version(&self) -> DriverResult<String> {
200        let path = self.browser_path.clone();
201
202        #[cfg(unix)]
203        {
204            let output = Command::new(&path)
205                .arg("--version")
206                .output()
207                .await
208                .map_err(|_| DriverError::BrowserNotFound(path))?;
209            Ok(get_version_from_output(output))
210        }
211
212        #[cfg(windows)]
213        {
214            use std::process::Stdio;
215            let cmd = format!(
216                r#"(Get-Item (Get-Command '{}').Source).VersionInfo.ProductVersion"#,
217                &path
218            );
219
220            let output = Command::new("powershell")
221                .arg("-Command")
222                .arg(&cmd)
223                .stdout(Stdio::piped())
224                .stderr(Stdio::piped())
225                .output()
226                .await
227                .map_err(|_| DriverError::BrowserNotFound(path.clone()))?;
228
229            if !output.status.success() {
230                // let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
231                Err(DriverError::BrowserNotFound(path))
232            } else {
233                Ok(get_version_from_output(output))
234            }
235        }
236    }
237}
238
239fn get_version_from_output(output: Output) -> String {
240    let text = String::from_utf8_lossy(&output.stdout).into_owned();
241    let re = Regex::new(r"\d+\.\d+\.\d+\.\d+").unwrap();
242    re.captures(&text)
243        .unwrap()
244        .get(0)
245        .unwrap()
246        .as_str()
247        .to_string()
248}
249
/// Errors reported by `ChromeDriver` operations.
///
/// The `String` payload of a variant names the path or resource the error refers to and
/// is interpolated into the message.
#[derive(Error, Debug)]
pub enum DriverError {
    /// The browser version could not be obtained for the browser path in the payload.
    #[error("browser not found `{0}`")]
    BrowserNotFound(String),
    /// A required resource (named in the payload) was not found.
    #[error("resource not found `{0}`")]
    ResourceNotFound(String),
    /// A resource (named in the payload) could not be read, parsed or written.
    #[error("resource invalid `{0}`")]
    ResourceInvalid(String),
    /// Sending the download request failed.
    // NOTE(review): `try_download` maps every send failure to this variant, not only timeouts.
    #[error("download request timeout, please increase connect_timeout/timeout or use vpn")]
    RequestTimeout,
    /// The HTTP client could not be built or the response body could not be read.
    #[error("failed to send request")]
    RequestInvalid,
}
263
/// Shorthand for a `Result` whose error type is [`DriverError`].
type DriverResult<T> = Result<T, DriverError>;