1use std::env::temp_dir;
2use std::io::ErrorKind;
3use std::path::{Path, PathBuf};
4
5use chrono::{DateTime, Utc};
6#[cfg(feature = "progress")]
7use indicatif::style::TemplateError;
8#[cfg(feature = "progress")]
9use indicatif::{MultiProgress, ProgressStyle};
10use reqwest::Url;
11use semver::Version;
12use thiserror::Error;
13
14mod impls;
15
16#[derive(Error, Debug)]
17pub enum Error {
18 #[error(transparent)]
19 ReqwestError(#[from] reqwest::Error),
20
21 #[error(transparent)]
22 UrlParseError(#[from] url::ParseError),
23
24 #[error("Invalid version")]
25 InvalidVersionError,
26
27 #[error(transparent)]
28 IoError(#[from] std::io::Error),
29
30 #[cfg(feature = "progress")]
31 #[error(transparent)]
32 TemplateError(#[from] TemplateError),
33
34 #[error("Failed to get content length from '{0}'")]
35 InvalidContentLengthError(String),
36
37 #[error("File integrity mismatch. Expected size {0}, found {1}")]
38 InvalidFileSize(u64, u64),
39
40 #[error("File checksum failed")]
41 InvalidFileChecksum,
42
43 #[error(transparent)]
44 InvalidCredentialsError(anyhow::Error),
45}
46
47pub trait Registry {
68 fn get_base_url(&self) -> Url;
71
72 fn get_update_path<C: Config>(&self, config: &C) -> String;
74
75 fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>>;
77}
78
79pub trait Config {
100 fn version(&self) -> Version;
102
103 fn target(&self) -> String;
105}
106
/// Release metadata fetched from a [`Registry`] by [`check_version`].
#[derive(Debug, serde::Deserialize)]
pub struct RemoteVersion {
    /// Version of the published release, decoded by `impls::value_to_version`.
    #[serde(deserialize_with = "impls::value_to_version")]
    pub version: Version,
    /// Checksum of the release archive, checked by `impls::verify_file` in [`update_self`].
    pub checksum: String,
    /// Expected archive size passed to `impls::verify_file` (presumably bytes; confirm in `impls`).
    pub size: usize,
    /// Archive path, joined onto [`Registry::get_base_url`] to build the download URL.
    /// Its file name is reused for the local download.
    pub path: String,
    /// Release timestamp in UTC.
    pub datetime: DateTime<Utc>,
}
134
135pub type Result<T> = std::result::Result<T, crate::Error>;
136
137pub async fn check_version<C: Config, R: Registry>(config: &C, registry: &R) -> Result<(bool, RemoteVersion)> {
166 impls::fetch_remote_version(config, registry)
167 .await
168 .and_then(|r| Ok((r.version > config.version(), r)))
169}
170
171pub async fn update_self<C: Config, R: Registry>(
174 config: &C,
175 registry: &R,
176 #[cfg(feature = "progress")] multi_progress: Option<MultiProgress>,
177 #[cfg(feature = "progress")] progress_style: Option<ProgressStyle>,
178) -> Result<()> {
179 let result = check_version(config, registry).await?;
180 let remote_version = result.1;
181 if result.0 {
182 let remote_path = remote_version.path;
183 let remote_file_path = PathBuf::from(&remote_path);
184 let filename = remote_file_path
185 .file_name()
186 .ok_or(Error::IoError(std::io::Error::from(ErrorKind::NotFound)))?;
187 let target_path = temp_dir().join(filename);
188 let remote_path = registry.get_base_url().join(&remote_path.as_str())?;
189 let client = reqwest::ClientBuilder::default().build().unwrap();
190
191 #[cfg(feature = "progress")]
192 let _ = impls::download_file(&client, &remote_path, &target_path, registry, multi_progress, progress_style).await?;
193
194 #[cfg(not(feature = "progress"))]
195 let _ = impls::download_file(&client, &remote_path, &target_path, registry).await?;
196
197 let _ = impls::verify_file(&target_path, remote_version.size as u64, remote_version.checksum.clone()).await?;
198
199 let bin_name = std::env::current_exe().or(Err(Error::IoError(std::io::Error::from(ErrorKind::NotFound))))?;
200 let bin_name_path = bin_name.parent().unwrap_or(Path::new("/")).to_path_buf();
201 impls::extract(&target_path, &bin_name_path).await
202 } else {
203 Err(Error::InvalidVersionError)
204 }
205}
206
#[cfg(test)]
mod tests {
    use std::any::type_name_of_val;
    use std::sync::Once;

    use console::Style;
    #[cfg(feature = "progress")]
    use indicatif::{MultiProgress, ProgressStyle};
    use reqwest::Url;
    use semver::Version;
    use tracing::level_filters::LevelFilter;
    use tracing::subscriber;
    use tracing_subscriber::prelude::*;
    use tracing_subscriber::EnvFilter;

    use crate::{check_version, update_self, Config, Registry, RemoteVersion};

    /// Target triple used when `TARGET` is not set in the environment.
    const DEFAULT_TARGET: &str = "aarch64-apple-darwin";

    /// Template for the download progress bar used by the update tests.
    #[cfg(feature = "progress")]
    const PROGRESS_TEMPLATE: &str =
        "{prefix:.green.bold} [{bar:40.cyan/blue.bold}] {percent:>5}% [ETA {eta}] {msg} ";

    /// Guards the one-time installation of the global tracing subscriber.
    static LOGGING: Once = Once::new();

    /// Version of this crate, embedded at compile time by Cargo.
    fn package_version() -> Version {
        Version::parse(env!("CARGO_PKG_VERSION")).expect("CARGO_PKG_VERSION is valid semver")
    }

    /// Target triple from the `TARGET` environment variable, or [`DEFAULT_TARGET`].
    ///
    /// NOTE(review): Cargo sets `TARGET` for build scripts, not for test binaries. This
    /// presumably uses the fallback unless `TARGET` is exported manually.
    fn target_triple() -> String {
        std::env::var("TARGET").unwrap_or_else(|_| DEFAULT_TARGET.to_string())
    }

    /// Anonymous registry serving manifests from `test.example.com`.
    struct LabRegistry;

    /// Config reporting this crate's version and target.
    struct LabConfig;

    impl Registry for LabRegistry {
        fn get_base_url(&self) -> Url {
            Url::parse("https://test.example.com/aot/").unwrap()
        }

        fn get_update_path<C: Config>(&self, config: &C) -> String {
            format!("{}.json", config.target())
        }

        fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
            Ok(None)
        }
    }

    impl Config for LabConfig {
        fn version(&self) -> Version {
            package_version()
        }

        fn target(&self) -> String {
            target_triple()
        }
    }

    /// Registry that supplies basic-auth credentials.
    struct ArtifactoryRegistry;

    /// Config reporting this crate's version and target.
    struct ArtifactoryConfig;

    impl Registry for ArtifactoryRegistry {
        fn get_base_url(&self) -> Url {
            Url::parse("https://test.artifactory.com/local/").unwrap()
        }

        fn get_update_path<C: Config>(&self, config: &C) -> String {
            format!("{}.json", config.target())
        }

        fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
            Ok(Some(("USER1".to_string(), Some("PASSWORD1".to_string()))))
        }
    }

    impl Config for ArtifactoryConfig {
        fn version(&self) -> Version {
            package_version()
        }

        fn target(&self) -> String {
            target_triple()
        }
    }

    /// Prints a "new release available" banner for `version`.
    fn print_update<C: Config, R: Registry>(config: &C, registry: &R, version: RemoteVersion) {
        let bin_name = console::style("binary_name").cyan().italic().bold();
        let this_version = config.version();
        let update_url = registry.get_base_url().join(version.path.as_str()).unwrap();
        let other_version = console::style(version.version).green();

        println!(
            "A new release of {} is available: {} → {}",
            bin_name, this_version, other_version
        );
        println!("Released on {}", version.datetime);
        println!("{}", console::style(update_url).yellow());
    }

    /// Prints an error returned by the updater.
    fn print_error(err: &crate::Error) {
        println!("Error checking for version update:");
        println!("{}", console::style(err.to_string()).red().italic());
    }

    /// Prints the variant name and payload of `err`.
    ///
    /// The match is intentionally exhaustive (including the feature-gated variant), so adding
    /// a variant to `crate::Error` forces this helper to be updated.
    fn describe_error(err: crate::Error) {
        match err {
            crate::Error::ReqwestError(err) => println!("ReqwestError: {:?}", err),
            crate::Error::UrlParseError(err) => println!("UrlParseError: {:?}", err),
            crate::Error::InvalidVersionError => println!("InvalidVersionError"),
            crate::Error::IoError(err) => println!("IoError: {:?}", err),
            crate::Error::InvalidContentLengthError(err) => {
                println!("InvalidContentLengthError: {:?}", err)
            }
            crate::Error::InvalidFileSize(size, found) => {
                println!("InvalidFileSize: size: {}, found: {}", size, found)
            }
            crate::Error::InvalidFileChecksum => println!("InvalidFileChecksum"),
            crate::Error::InvalidCredentialsError(err) => {
                println!("InvalidCredentialsError: {:?}", err)
            }
            #[cfg(feature = "progress")]
            crate::Error::TemplateError(err) => println!("TemplateError: {:?}", err),
        }
    }

    /// Progress style shared by the update tests.
    #[cfg(feature = "progress")]
    fn progress_style() -> ProgressStyle {
        ProgressStyle::with_template(PROGRESS_TEMPLATE)
            .unwrap()
            .progress_chars("=> ")
    }

    /// Runs `update_self` against `config`/`registry` and prints any error it returns.
    async fn run_update<C: Config, R: Registry>(config: &C, registry: &R) {
        #[cfg(feature = "progress")]
        let result = update_self(config, registry, Some(MultiProgress::new()), Some(progress_style())).await;

        #[cfg(not(feature = "progress"))]
        let result = update_self(config, registry).await;

        if let Err(err) = result {
            print_error(&err);
        }
    }

    #[tokio::test]
    async fn test_check_version() {
        let config = LabConfig;
        let registry = LabRegistry;
        let (has_update, version) = check_version(&config, &registry).await.unwrap();

        if has_update {
            print_update(&config, &registry, version);
        }
    }

    #[tokio::test]
    async fn test_update_version() {
        init_logging();
        println!("current exe: {:?}", std::env::current_exe());

        run_update(&LabConfig, &LabRegistry).await;
    }

    #[tokio::test]
    async fn test_check_version_with_basic_auth() {
        let config = ArtifactoryConfig;
        let registry = ArtifactoryRegistry;

        match check_version(&config, &registry).await {
            Ok((true, version)) => print_update(&config, &registry, version),
            Ok((false, _)) => {}
            Err(err) => {
                print_error(&err);
                println!("Error type: {:?}", type_name_of_val(&err));
                describe_error(err);
            }
        }
    }

    #[tokio::test]
    async fn test_update_version_with_basic_auth() {
        init_logging();
        println!("current exe: {:?}", std::env::current_exe());

        run_update(&ArtifactoryConfig, &ArtifactoryRegistry).await;
    }

    /// Installs a TRACE-level global tracing subscriber, at most once per test process.
    ///
    /// Several tests call this, and `set_global_default` fails if a global default is already
    /// installed. Without `Once`, the second caller's `unwrap` would panic.
    fn init_logging() {
        LOGGING.call_once(|| {
            let registry = tracing_subscriber::Registry::default();

            let term_subscriber = logging_subscriber::LoggingSubscriberBuilder::default()
                .with_time(true)
                .with_level(true)
                .with_target(true)
                .with_file(false)
                .with_line_number(false)
                .with_min_level(LevelFilter::TRACE)
                .with_format_level(logging_subscriber::LevelOutput::Long)
                .with_default_style(Style::default().dim())
                .with_level_style_warn(Style::new().color256(220).bold())
                .with_level_style_trace(Style::new().magenta().bold())
                .with_date_time_style(Style::new().white())
                .build();

            let filter = EnvFilter::builder()
                .with_default_directive(LevelFilter::TRACE.into())
                .from_env()
                .unwrap()
                .add_directive("hyper::proto=warn".parse().unwrap())
                .add_directive("hyper::client=warn".parse().unwrap());

            let subscriber = registry.with(filter).with(term_subscriber);
            subscriber::set_global_default(subscriber).unwrap();
        });
    }
}