use std::env::temp_dir;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};

use chrono::{DateTime, Utc};
#[cfg(feature = "progress")]
use indicatif::style::TemplateError;
#[cfg(feature = "progress")]
use indicatif::{MultiProgress, ProgressStyle};
use reqwest::Url;
use semver::Version;
use thiserror::Error;
13
14mod impls;
15
16#[derive(Error, Debug)]
17pub enum Error {
18 #[error(transparent)]
19 ReqwestError(#[from] reqwest::Error),
20
21 #[error(transparent)]
22 UrlParseError(#[from] url::ParseError),
23
24 #[error("Invalid version")]
25 InvalidVersionError,
26
27 #[error(transparent)]
28 IoError(#[from] std::io::Error),
29
30 #[cfg(feature = "progress")]
31 #[error(transparent)]
32 TemplateError(#[from] TemplateError),
33
34 #[error("Failed to get content length from '{0}'")]
35 InvalidContentLengthError(String),
36
37 #[error("File integrity mismatch. Expected size {0}, found {1}")]
38 InvalidFileSize(u64, u64),
39
40 #[error("File checksum failed")]
41 InvalidFileChecksum,
42
43 #[error(transparent)]
44 InvalidCredentialsError(anyhow::Error),
45}
46
47pub trait Registry {
68 fn get_base_url(&self) -> Url;
71
72 fn get_update_path<C: Config>(&self, config: &C) -> String;
74
75 fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>>;
77}
78
79pub trait Config {
100 fn version(&self) -> Version;
102
103 fn target(&self) -> String;
105}
106
107#[derive(Debug, serde::Deserialize)]
121pub struct RemoteVersion {
122 #[serde(deserialize_with = "impls::value_to_version")]
124 pub version: Version,
125 pub checksum: String,
127 pub size: usize,
129 pub path: String,
131 pub datetime: DateTime<Utc>,
133}
134
135pub type Result<T> = std::result::Result<T, crate::Error>;
136
137pub async fn check_version<C: Config, R: Registry>(
166 config: &C,
167 registry: &R,
168) -> Result<(bool, RemoteVersion)> {
169 impls::fetch_remote_version(config, registry)
170 .await
171 .and_then(|r| Ok((r.version > config.version(), r)))
172}
173
174pub async fn update_self<C: Config, R: Registry>(
177 config: &C,
178 registry: &R,
179 #[cfg(feature = "progress")] multi_progress: Option<MultiProgress>,
180 #[cfg(feature = "progress")] progress_style: Option<ProgressStyle>,
181) -> Result<()> {
182 let result = check_version(config, registry).await?;
183 let remote_version = result.1;
184 if result.0 {
185 let remote_path = remote_version.path;
186 let remote_file_path = PathBuf::from(&remote_path);
187 let filename = remote_file_path
188 .file_name()
189 .ok_or(Error::IoError(std::io::Error::from(ErrorKind::NotFound)))?;
190 let target_path = temp_dir().join(filename);
191 let remote_path = registry.get_base_url().join(&remote_path.as_str())?;
192 let client = reqwest::ClientBuilder::default().build().unwrap();
193
194 #[cfg(feature = "progress")]
195 let _ = impls::download_file(
196 &client,
197 &remote_path,
198 &target_path,
199 registry,
200 multi_progress,
201 progress_style,
202 )
203 .await?;
204
205 #[cfg(not(feature = "progress"))]
206 let _ = impls::download_file(&client, &remote_path, &target_path, registry).await?;
207
208 let _ = impls::verify_file(
209 &target_path,
210 remote_version.size as u64,
211 remote_version.checksum.clone(),
212 )
213 .await?;
214
215 let bin_name = std::env::current_exe().or(Err(Error::IoError(std::io::Error::from(
216 ErrorKind::NotFound,
217 ))))?;
218 let bin_name_path = bin_name.parent().unwrap_or(Path::new("/")).to_path_buf();
219 impls::extract(&target_path, &bin_name_path).await
220 } else {
221 Err(Error::InvalidVersionError)
222 }
223}
224
// Integration-style tests. They talk to the placeholder hosts configured in
// the registries below, so the update tests only print any error they get.
#[cfg(test)]
mod tests {
    use std::any::type_name_of_val;
    use std::sync::Once;

    use console::Style;
    #[cfg(feature = "progress")]
    use indicatif::{MultiProgress, ProgressStyle};
    use reqwest::Url;
    use semver::Version;
    use tracing::level_filters::LevelFilter;
    use tracing::subscriber;
    use tracing_subscriber::prelude::*;
    use tracing_subscriber::EnvFilter;

    use crate::{check_version, update_self, Config, Registry, RemoteVersion};

    /// Registry without authentication.
    struct LabRegistry;

    /// Configuration paired with [`LabRegistry`].
    struct LabConfig;

    impl Registry for LabRegistry {
        fn get_base_url(&self) -> Url {
            Url::parse("https://test.example.com/aot/").unwrap()
        }

        fn get_update_path<C: Config>(&self, config: &C) -> String {
            format!("{}.json", config.target())
        }

        fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
            Ok(None)
        }
    }

    impl Config for LabConfig {
        fn version(&self) -> Version {
            package_version()
        }

        fn target(&self) -> String {
            build_target()
        }
    }

    /// Registry that supplies basic-auth credentials.
    struct ArtifactoryRegistry;

    /// Configuration paired with [`ArtifactoryRegistry`].
    struct ArtifactoryConfig;

    impl Registry for ArtifactoryRegistry {
        fn get_base_url(&self) -> Url {
            Url::parse("https://test.artifactory.com/local/").unwrap()
        }

        fn get_update_path<C: Config>(&self, config: &C) -> String {
            format!("{}.json", config.target())
        }

        fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
            Ok(Some(("USER1".to_string(), Some("PASSWORD1".to_string()))))
        }
    }

    impl Config for ArtifactoryConfig {
        fn version(&self) -> Version {
            package_version()
        }

        fn target(&self) -> String {
            build_target()
        }
    }

    /// Version of this package, taken at compile time so the tests do not
    /// depend on `CARGO_PKG_VERSION` being set in the runtime environment.
    fn package_version() -> Version {
        Version::parse(env!("CARGO_PKG_VERSION")).expect("CARGO_PKG_VERSION is valid semver")
    }

    /// Target read from the `TARGET` environment variable, with a fixed fallback.
    fn build_target() -> String {
        std::env::var("TARGET").unwrap_or_else(|_| "aarch64-apple-darwin".to_string())
    }

    /// Prints the "new release available" banner for `remote`.
    fn announce_update<C: Config, R: Registry>(config: &C, registry: &R, remote: RemoteVersion) {
        let bin_name = console::style("binary_name").cyan().italic().bold();
        let this_version = config.version();
        let other_version = console::style(remote.version).green();
        let update_url = registry.get_base_url().join(remote.path.as_str()).unwrap();

        println!(
            "A new release of {} is available: {} → {}",
            bin_name, this_version, other_version
        );
        println!("Released on {}", remote.datetime);
        println!("{}", console::style(update_url).yellow());
    }

    /// Prints the common error banner.
    fn print_error(err: &crate::Error) {
        println!("Error checking for version update:");
        println!("{}", console::style(err.to_string()).red().italic());
    }

    /// Runs `update_self` with a progress bar configured.
    #[cfg(feature = "progress")]
    async fn run_update<C: Config, R: Registry>(config: &C, registry: &R) -> crate::Result<()> {
        let progress_style = ProgressStyle::with_template(
            "{prefix:.green.bold} [{bar:40.cyan/blue.bold}] {percent:>5}% [ETA {eta}] {msg} ",
        )
        .unwrap()
        .progress_chars("=> ");
        let multi_progress = MultiProgress::new();

        update_self(
            config,
            registry,
            Some(multi_progress),
            Some(progress_style),
        )
        .await
    }

    /// Runs `update_self` without progress reporting.
    #[cfg(not(feature = "progress"))]
    async fn run_update<C: Config, R: Registry>(config: &C, registry: &R) -> crate::Result<()> {
        update_self(config, registry).await
    }

    #[tokio::test]
    async fn test_check_version() {
        let config = LabConfig;
        let registry = LabRegistry;

        let (has_update, version) = check_version(&config, &registry)
            .await
            .expect("version check against the lab registry should succeed");

        if has_update {
            announce_update(&config, &registry, version);
        }
    }

    #[tokio::test]
    async fn test_update_version() {
        init_logging();
        let current_exe = std::env::current_exe();
        println!("current exe: {:?}", current_exe);

        let config = LabConfig;
        let registry = LabRegistry;

        if let Err(err) = run_update(&config, &registry).await {
            print_error(&err);
        }
    }

    #[tokio::test]
    async fn test_check_version_with_basic_auth() {
        let config = ArtifactoryConfig;
        let registry = ArtifactoryRegistry;

        match check_version(&config, &registry).await {
            Ok((has_update, version)) => {
                if has_update {
                    announce_update(&config, &registry, version);
                }
            }
            Err(err) => {
                print_error(&err);
                println!("Error type: {:?}", type_name_of_val(&err));

                match err {
                    crate::Error::ReqwestError(err) => {
                        println!("ReqwestError: {:?}", err);
                    }
                    crate::Error::UrlParseError(err) => {
                        println!("UrlParseError: {:?}", err);
                    }
                    crate::Error::InvalidVersionError => {
                        println!("InvalidVersionError");
                    }
                    crate::Error::IoError(err) => {
                        println!("IoError: {:?}", err);
                    }
                    // Without this arm the match is non-exhaustive when the
                    // `progress` feature is enabled.
                    #[cfg(feature = "progress")]
                    crate::Error::TemplateError(err) => {
                        println!("TemplateError: {:?}", err);
                    }
                    crate::Error::InvalidContentLengthError(err) => {
                        println!("InvalidContentLengthError: {:?}", err);
                    }
                    crate::Error::InvalidFileSize(size, found) => {
                        println!("InvalidFileSize: size: {}, found: {}", size, found);
                    }
                    crate::Error::InvalidFileChecksum => {
                        println!("InvalidFileChecksum");
                    }
                    crate::Error::InvalidCredentialsError(err) => {
                        println!("InvalidCredentialsError: {:?}", err);
                    }
                }
            }
        }
    }

    #[tokio::test]
    async fn test_update_version_with_basic_auth() {
        init_logging();
        let current_exe = std::env::current_exe();
        println!("current exe: {:?}", current_exe);

        let config = ArtifactoryConfig;
        let registry = ArtifactoryRegistry;

        if let Err(err) = run_update(&config, &registry).await {
            print_error(&err);
        }
    }

    /// Installs the global tracing subscriber exactly once per test process.
    ///
    /// Several tests call this and the harness runs them in one process, so an
    /// unguarded `set_global_default(..).unwrap()` would panic on the second call.
    fn init_logging() {
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            let registry = tracing_subscriber::Registry::default();

            let term_subscriber = logging_subscriber::LoggingSubscriberBuilder::default()
                .with_time(true)
                .with_level(true)
                .with_target(true)
                .with_file(false)
                .with_line_number(false)
                .with_min_level(LevelFilter::TRACE)
                .with_format_level(logging_subscriber::LevelOutput::Long)
                .with_default_style(Style::default().dim())
                .with_level_style_warn(Style::new().color256(220).bold())
                .with_level_style_trace(Style::new().magenta().bold())
                .with_date_time_style(Style::new().white())
                .build();

            let filter = EnvFilter::builder()
                .with_default_directive(LevelFilter::TRACE.into())
                .from_env()
                .unwrap()
                .add_directive("hyper::proto=warn".parse().unwrap())
                .add_directive("hyper::client=warn".parse().unwrap());

            let subscriber = registry.with(filter).with(term_subscriber);
            subscriber::set_global_default(subscriber)
                .expect("global tracing subscriber is installed only once");
        });
    }
}