chromedriver_update/
lib.rs1use regex::Regex;
53use std::{
54 io::{Cursor, Read},
55 process::Output,
56};
57use thiserror::Error;
58use tokio::{fs::File, io::AsyncWriteExt, process::Command};
59
60pub mod constant;
61use constant::{
62 CHROME_BROWSER_PATH, CHROME_DRIVER_PATH, CONNECT_TIMEOUT, DRIVER_FILE, TIMEOUT, ZIP_PATH,
63};
64
/// Keeps a local chromedriver executable in step with the installed Chrome browser.
///
/// Configure paths and timeouts with the `set_*` methods, call
/// [`ChromeDriver::init`] to detect both versions, then use
/// [`ChromeDriver::need_download`] / [`ChromeDriver::try_download`] to fetch a
/// matching driver when they differ.
#[derive(Debug, Clone)]
pub struct ChromeDriver {
    /// Version parsed from the driver's `--version` output; empty when the
    /// driver could not be executed.
    pub version: String,
    /// Version parsed from the browser; filled in by [`ChromeDriver::init`].
    pub browser_version: String,
    /// Filesystem path of the chromedriver executable (queried and overwritten).
    path: String,
    /// Browser executable (on Windows: a command resolvable by `Get-Command`).
    browser_path: String,
    /// HTTP connect timeout, in milliseconds.
    connect_timeout: u64,
    /// Total HTTP request timeout, in milliseconds.
    timeout: u64,
}
75
76impl ChromeDriver {
77 pub fn new() -> Self {
79 Self {
80 version: String::new(),
81 path: CHROME_DRIVER_PATH.to_string(),
82 browser_version: String::new(),
83 browser_path: CHROME_BROWSER_PATH.to_string(),
84 connect_timeout: CONNECT_TIMEOUT,
85 timeout: TIMEOUT,
86 }
87 }
88
89 pub fn set_driver_path(&mut self, path: &str) -> &mut Self {
94 self.path = path.to_string();
95 self
96 }
97
98 pub fn set_browser_path(&mut self, path: &str) -> &mut Self {
103 self.browser_path = path.to_string();
104 self
105 }
106
107 pub fn set_connect_timeout(&mut self, timeout: u64) -> &mut Self {
109 self.connect_timeout = timeout;
110 self
111 }
112
113 pub fn set_timeout(&mut self, timeout: u64) -> &mut Self {
115 self.timeout = timeout;
116 self
117 }
118
119 pub async fn init(&mut self) -> DriverResult<()> {
121 self.version = self.get_driver_version().await;
122 self.browser_version = self.get_browser_version().await?;
123
124 Ok(())
125 }
126
127 pub fn need_download(&self) -> bool {
129 !self.version.eq(&self.browser_version)
130 }
131
132 pub async fn try_download(&self) -> DriverResult<()> {
134 let client = reqwest::Client::builder()
135 .danger_accept_invalid_certs(true)
136 .connect_timeout(std::time::Duration::from_millis(self.connect_timeout))
137 .timeout(std::time::Duration::from_millis(self.timeout))
138 .build()
139 .map_err(|_| DriverError::RequestInvalid)?;
140
141 let url = format!(
142 "https://storage.googleapis.com/chrome-for-testing-public/{}/{}",
143 self.browser_version,
144 ZIP_PATH.as_str()
145 );
146 let bytes = client
147 .get(url)
148 .send()
149 .await
150 .map_err(|_| DriverError::RequestTimeout)?
151 .bytes()
152 .await
153 .map_err(|_| DriverError::RequestInvalid)?;
154
155 let cursor = Cursor::new(bytes.as_ref());
156 let mut archive = zip::ZipArchive::new(cursor)
157 .map_err(|_| DriverError::ResourceInvalid(ZIP_PATH.to_string()))?;
158
159 for i in 0..archive.len() {
160 let mut file = archive.by_index(i).unwrap();
161 if file.name().eq(DRIVER_FILE.as_str()) {
162 let mut output_file = File::create(&self.path)
163 .await
164 .map_err(|_| DriverError::ResourceInvalid(self.path.clone()))?;
165 let mut buffer = Vec::new();
166 file.read_to_end(&mut buffer)
167 .map_err(|_| DriverError::ResourceInvalid(self.path.clone()))?;
168 output_file
169 .write_all(&buffer)
170 .await
171 .map_err(|_| DriverError::ResourceInvalid(self.path.clone()))?;
172
173 #[cfg(unix)]
174 {
175 use std::{fs, os::unix::fs::PermissionsExt};
176 let mut permissions = fs::metadata(&self.path).unwrap().permissions();
177 permissions.set_mode(0o755);
178 fs::set_permissions(&self.path, permissions).unwrap();
179 }
180
181 return Ok(());
182 }
183 }
184
185 Err(DriverError::ResourceNotFound(DRIVER_FILE.to_string()))
186 }
187
188 async fn get_driver_version(&self) -> String {
189 match Command::new(self.path.clone())
190 .arg("--version")
191 .output()
192 .await
193 {
194 Ok(res) => get_version_from_output(res),
195 Err(_) => String::new(),
196 }
197 }
198
199 async fn get_browser_version(&self) -> DriverResult<String> {
200 let path = self.browser_path.clone();
201
202 #[cfg(unix)]
203 {
204 let output = Command::new(&path)
205 .arg("--version")
206 .output()
207 .await
208 .map_err(|_| DriverError::BrowserNotFound(path))?;
209 Ok(get_version_from_output(output))
210 }
211
212 #[cfg(windows)]
213 {
214 use std::process::Stdio;
215 let cmd = format!(
216 r#"(Get-Item (Get-Command '{}').Source).VersionInfo.ProductVersion"#,
217 &path
218 );
219
220 let output = Command::new("powershell")
221 .arg("-Command")
222 .arg(&cmd)
223 .stdout(Stdio::piped())
224 .stderr(Stdio::piped())
225 .output()
226 .await
227 .map_err(|_| DriverError::BrowserNotFound(path.clone()))?;
228
229 if !output.status.success() {
230 Err(DriverError::BrowserNotFound(path))
232 } else {
233 Ok(get_version_from_output(output))
234 }
235 }
236 }
237}
238
239fn get_version_from_output(output: Output) -> String {
240 let text = String::from_utf8_lossy(&output.stdout).into_owned();
241 let re = Regex::new(r"\d+\.\d+\.\d+\.\d+").unwrap();
242 re.captures(&text)
243 .unwrap()
244 .get(0)
245 .unwrap()
246 .as_str()
247 .to_string()
248}
249
250#[derive(Error, Debug)]
251pub enum DriverError {
252 #[error("browser not found `{0}`")]
253 BrowserNotFound(String),
254 #[error("resource not found `{0}`")]
255 ResourceNotFound(String),
256 #[error("resource invalid `{0}`")]
257 ResourceInvalid(String),
258 #[error("download request timeout, please increase connect_timeout/timeout or use vpn")]
259 RequestTimeout,
260 #[error("failed to send request")]
261 RequestInvalid,
262}
263
264type DriverResult<T> = Result<T, DriverError>;