cli_autoupdate/
lib.rs

1use std::env::temp_dir;
2use std::io::ErrorKind;
3use std::path::{Path, PathBuf};
4
5use chrono::{DateTime, Utc};
6#[cfg(feature = "progress")]
7use indicatif::style::TemplateError;
8#[cfg(feature = "progress")]
9use indicatif::{MultiProgress, ProgressStyle};
10use reqwest::Url;
11use semver::Version;
12use thiserror::Error;
13
14mod impls;
15
/// Errors returned by the public functions of this crate.
///
/// Variants marked `transparent` forward both `Display` and `source()` to the
/// wrapped error.
#[derive(Error, Debug)]
pub enum Error {
    /// An HTTP-level failure reported by `reqwest`.
    #[error(transparent)]
    ReqwestError(#[from] reqwest::Error),

    /// A URL could not be parsed or joined (e.g. base url + remote path).
    #[error(transparent)]
    UrlParseError(#[from] url::ParseError),

    /// Returned by `update_self` when the remote version is not newer than the
    /// current one.
    // NOTE(review): may also be produced while deserializing the remote version
    // in `impls` — confirm there.
    #[error("Invalid version")]
    InvalidVersionError,

    /// A filesystem or other I/O failure.
    #[error(transparent)]
    IoError(#[from] std::io::Error),

    /// The progress-bar template was rejected by `indicatif`
    /// (only present with the `progress` feature).
    #[cfg(feature = "progress")]
    #[error(transparent)]
    TemplateError(#[from] TemplateError),

    /// The content length could not be obtained; the payload is interpolated
    /// into the message as the source it was requested from.
    // NOTE(review): presumably the download URL — confirm in `impls`.
    #[error("Failed to get content length from '{0}'")]
    InvalidContentLengthError(String),

    /// Size mismatch: `(expected, found)`, in that order, as used by the message.
    #[error("File integrity mismatch. Expected size {0}, found {1}")]
    InvalidFileSize(u64, u64),

    /// The file checksum verification failed.
    #[error("File checksum failed")]
    InvalidFileChecksum,

    /// Wraps an `anyhow::Error` related to credentials.
    // NOTE(review): presumably the error returned by `Registry::get_basic_auth`
    // — confirm in `impls`. No `#[from]`, so it must be constructed explicitly.
    #[error(transparent)]
    InvalidCredentialsError(anyhow::Error),
}
46
/// Describes where update metadata and update archives are hosted.
///
/// A registry provides the base url of the repository and the relative path of
/// the update json package description.
///
/// # Examples
///
/// ```rust
/// use cli_autoupdate::{Config, Registry};
/// use reqwest::Url;
///
/// struct MyRegistry;
///
/// impl Registry for MyRegistry {
///     fn get_base_url(&self) -> Url {
///         Url::parse("https://registry.example/").unwrap()
///     }
///
///     fn get_update_path<C: Config>(&self, config: &C) -> String {
///         format!("{}.json", config.target())
///     }
///
///     fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
///         Ok(None)
///     }
/// }
/// ```
pub trait Registry {
    /// Base url of the repository. This is also used together with the
    /// `RemoteVersion::path` to create the final download url.
    fn get_base_url(&self) -> Url;

    /// The relative path of the update json file, derived from the given
    /// configuration (typically from `Config::target`).
    fn get_update_path<C: Config>(&self, config: &C) -> String;

    /// Basic HTTP authentication as `(username, optional password)`.
    ///
    /// Return `Ok(None)` when the registry needs no credentials.
    // NOTE(review): how an `Err` is surfaced to the caller is decided in
    // `impls` (not visible here) — presumably `Error::InvalidCredentialsError`.
    fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>>;
}
78
/// Configuration of the currently running package.
///
/// # Examples
///
/// ```rust
/// use cli_autoupdate::Config;
/// use semver::Version;
///
/// struct MyConfig;
///
/// impl Config for MyConfig {
///     fn version(&self) -> Version {
///         let version_str = std::env::var("CARGO_PKG_VERSION").unwrap();
///         Version::parse(version_str.as_str()).unwrap()
///     }
///
///     fn target(&self) -> String {
///         std::env::var("TARGET").unwrap_or_else(|_| "aarch64-apple-darwin".to_string())
///     }
/// }
/// ```
pub trait Config {
    /// Should return the current package version; `check_version` compares it
    /// against the remote version.
    fn version(&self) -> Version;

    /// Returns the package specific path, used by the registry
    /// (see `Registry::get_update_path`).
    fn target(&self) -> String;
}
106
/// Deserialized object of the remote json file, used to check for updates.
///
/// Example document:
///
/// ```json
/// {
///     "version": "2.0.0",
///     "datetime": "2024-01-14T14:40:43+0100",
///     "checksum": "726b934c8263868090490cf626b25770bbbbc98900689af512eddf9c33e9f785",
///     "size": 5538631,
///     "path": "./aarch64-apple-darwin/2.0.0/my_binary.tgz"
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
pub struct RemoteVersion {
    /// Version of the remote update file (deserialized through
    /// `impls::value_to_version`).
    #[serde(deserialize_with = "impls::value_to_version")]
    pub version: Version,
    /// The SHA-256 checksum of the file at `path`.
    pub checksum: String,
    /// The size in bytes of the file at `path`.
    pub size: usize,
    /// Update path, relative to the registry base url (`Registry::get_base_url`).
    pub path: String,
    /// Datetime at which the file was published.
    pub datetime: DateTime<Utc>,
}
134
/// Convenience alias for results whose error type is this crate's [`Error`].
pub type Result<T> = std::result::Result<T, crate::Error>;
136
137/// Check if there's a new version available
138/// Example:
139/// ```rust
140///
141/// 	use cli_autoupdate::check_version;
142///    	struct LabConfig;
143/// 	struct LabRegistry;
144///
145/// 	#[tokio::main]
146/// 	async fn main() {
147/// 		let config = LabConfig;
148///    		let registry = LabRegistry;
149///    		let (has_update, version) = check_version(&config, &registry).await.unwrap();
150///    		if has_update {
151///    			let bin_name = console::style("binary_name").cyan().italic().bold();
152///    			let this_version = config.version();
153///    			let other_version = console::style(version.version).green();
154///    			let update_url = registry.get_base_url().join(version.path.as_str()).unwrap();
155///
156///    			println!(
157///    				"A new release of {} is available: {} → {}",
158///    				bin_name, this_version, other_version
159///  			);
160///    			println!("Released on {}", version.datetime);
161///    			println!("{}", console::style(update_url).yellow());
162/// 		}
163///    	}
164/// ```
165pub async fn check_version<C: Config, R: Registry>(
166    config: &C,
167    registry: &R,
168) -> Result<(bool, RemoteVersion)> {
169    impls::fetch_remote_version(config, registry)
170        .await
171        .and_then(|r| Ok((r.version > config.version(), r)))
172}
173
174///    Check for updates and auto-update the current binary, if a new version is available
175///
176pub async fn update_self<C: Config, R: Registry>(
177    config: &C,
178    registry: &R,
179    #[cfg(feature = "progress")] multi_progress: Option<MultiProgress>,
180    #[cfg(feature = "progress")] progress_style: Option<ProgressStyle>,
181) -> Result<()> {
182    let result = check_version(config, registry).await?;
183    let remote_version = result.1;
184    if result.0 {
185        let remote_path = remote_version.path;
186        let remote_file_path = PathBuf::from(&remote_path);
187        let filename = remote_file_path
188            .file_name()
189            .ok_or(Error::IoError(std::io::Error::from(ErrorKind::NotFound)))?;
190        let target_path = temp_dir().join(filename);
191        let remote_path = registry.get_base_url().join(&remote_path.as_str())?;
192        let client = reqwest::ClientBuilder::default().build().unwrap();
193
194        #[cfg(feature = "progress")]
195        let _ = impls::download_file(
196            &client,
197            &remote_path,
198            &target_path,
199            registry,
200            multi_progress,
201            progress_style,
202        )
203        .await?;
204
205        #[cfg(not(feature = "progress"))]
206        let _ = impls::download_file(&client, &remote_path, &target_path, registry).await?;
207
208        let _ = impls::verify_file(
209            &target_path,
210            remote_version.size as u64,
211            remote_version.checksum.clone(),
212        )
213        .await?;
214
215        let bin_name = std::env::current_exe().or(Err(Error::IoError(std::io::Error::from(
216            ErrorKind::NotFound,
217        ))))?;
218        let bin_name_path = bin_name.parent().unwrap_or(Path::new("/")).to_path_buf();
219        impls::extract(&target_path, &bin_name_path).await
220    } else {
221        Err(Error::InvalidVersionError)
222    }
223}
224
// Integration-style tests for `check_version` and `update_self`.
//
// NOTE(review): the registries below point at `test.example.com` and
// `test.artifactory.com`; these tests presumably need network access to a
// registry that actually serves the update json — confirm before relying on
// them in CI.
#[cfg(test)]
mod tests {
    use std::any::type_name_of_val;
    use std::sync::Once;

    use console::Style;
    #[cfg(feature = "progress")]
    use indicatif::{MultiProgress, ProgressStyle};
    use reqwest::Url;
    use semver::Version;
    use tracing::level_filters::LevelFilter;
    use tracing::subscriber;
    use tracing_subscriber::prelude::*;
    use tracing_subscriber::EnvFilter;

    use crate::{check_version, update_self, Config, Registry, RemoteVersion};

    /// Registry without authentication.
    struct LabRegistry;

    /// Configuration paired with [`LabRegistry`].
    struct LabConfig;

    /// Registry that requires HTTP basic authentication.
    struct ArtifactoryRegistry;

    /// Configuration paired with [`ArtifactoryRegistry`].
    struct ArtifactoryConfig;

    /// Version of the package under test, read from the environment at runtime.
    fn package_version() -> Version {
        let version_str = std::env::var("CARGO_PKG_VERSION").unwrap();
        Version::parse(version_str.as_str()).unwrap()
    }

    /// Build target, falling back to `aarch64-apple-darwin` when `TARGET` is unset.
    fn build_target() -> String {
        std::env::var("TARGET").unwrap_or_else(|_| "aarch64-apple-darwin".to_string())
    }

    impl Registry for LabRegistry {
        fn get_base_url(&self) -> Url {
            Url::parse("https://test.example.com/aot/").unwrap()
        }

        fn get_update_path<C: Config>(&self, config: &C) -> String {
            format!("{}.json", config.target())
        }

        fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
            Ok(None)
        }
    }

    impl Config for LabConfig {
        fn version(&self) -> Version {
            package_version()
        }

        fn target(&self) -> String {
            build_target()
        }
    }

    impl Registry for ArtifactoryRegistry {
        fn get_base_url(&self) -> Url {
            Url::parse("https://test.artifactory.com/local/").unwrap()
        }

        fn get_update_path<C: Config>(&self, config: &C) -> String {
            format!("{}.json", config.target())
        }

        fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
            Ok(Some(("USER1".to_string(), Some("PASSWORD1".to_string()))))
        }
    }

    impl Config for ArtifactoryConfig {
        fn version(&self) -> Version {
            package_version()
        }

        fn target(&self) -> String {
            build_target()
        }
    }

    /// Prints the "new release available" banner for `version`.
    fn announce_update<C: Config, R: Registry>(config: &C, registry: &R, version: RemoteVersion) {
        let bin_name = console::style("binary_name").cyan().italic().bold();
        let this_version = config.version();
        let other_version = console::style(version.version).green();
        let update_url = registry.get_base_url().join(version.path.as_str()).unwrap();

        println!(
            "A new release of {} is available: {} → {}",
            bin_name, this_version, other_version
        );
        println!("Released on {}", version.datetime);
        println!("{}", console::style(update_url).yellow());
    }

    /// Runs `update_self` against the given pair and prints any resulting error.
    async fn run_update<C: Config, R: Registry>(config: &C, registry: &R) {
        init_logging();
        let current_exe = std::env::current_exe();
        println!("current exe: {:?}", current_exe);

        #[cfg(feature = "progress")]
        let result = {
            let progress_style = ProgressStyle::with_template(
                "{prefix:.green.bold} [{bar:40.cyan/blue.bold}] {percent:>5}% [ETA {eta}] {msg} ",
            )
            .unwrap()
            .progress_chars("=> ");
            let multi_progress = MultiProgress::new();
            update_self(
                config,
                registry,
                Some(multi_progress),
                Some(progress_style),
            )
            .await
        };

        #[cfg(not(feature = "progress"))]
        let result = update_self(config, registry).await;

        if let Err(err) = result {
            println!("Error checking for version update:");
            println!("{}", console::style(err.to_string()).red().italic());
        }
    }

    #[tokio::test]
    async fn test_check_version() {
        let config = LabConfig;
        let registry = LabRegistry;

        let result = check_version(&config, &registry).await;
        assert!(result.is_ok());

        let (has_update, version) = result.unwrap();
        if has_update {
            announce_update(&config, &registry, version);
        }
    }

    #[tokio::test]
    async fn test_update_version() {
        run_update(&LabConfig, &LabRegistry).await;
    }

    #[tokio::test]
    async fn test_check_version_with_basic_auth() {
        let config = ArtifactoryConfig;
        let registry = ArtifactoryRegistry;

        match check_version(&config, &registry).await {
            Ok((has_update, version)) => {
                if has_update {
                    announce_update(&config, &registry, version);
                }
            }
            Err(err) => {
                println!("Error checking for version update:");
                println!("{}", console::style(err.to_string()).red().italic());
                println!("Error type: {:?}", type_name_of_val(&err));

                // Deliberately exhaustive (no `_` arm) so that adding a variant
                // to `crate::Error` forces this test to be updated.
                match err {
                    crate::Error::ReqwestError(err) => {
                        println!("ReqwestError: {:?}", err);
                    }
                    crate::Error::UrlParseError(err) => {
                        println!("UrlParseError: {:?}", err);
                    }
                    crate::Error::InvalidVersionError => {
                        println!("InvalidVersionError");
                    }
                    crate::Error::IoError(err) => {
                        println!("IoError: {:?}", err);
                    }
                    // This variant only exists with the `progress` feature; without
                    // this arm the match fails to compile when the feature is on.
                    #[cfg(feature = "progress")]
                    crate::Error::TemplateError(err) => {
                        println!("TemplateError: {:?}", err);
                    }
                    crate::Error::InvalidContentLengthError(err) => {
                        println!("InvalidContentLengthError: {:?}", err);
                    }
                    crate::Error::InvalidFileSize(size, found) => {
                        println!("InvalidFileSize: size: {}, found: {}", size, found);
                    }
                    crate::Error::InvalidFileChecksum => {
                        println!("InvalidFileChecksum");
                    }
                    crate::Error::InvalidCredentialsError(err) => {
                        println!("InvalidCredentialsError: {:?}", err);
                    }
                }
            }
        }
    }

    #[tokio::test]
    async fn test_update_version_with_basic_auth() {
        run_update(&ArtifactoryConfig, &ArtifactoryRegistry).await;
    }

    /// Installs the global tracing subscriber exactly once per test process.
    ///
    /// Several tests call this and the test harness runs them in one process;
    /// installing a global default twice fails, so the setup is guarded.
    fn init_logging() {
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            let registry = tracing_subscriber::Registry::default();

            let term_subscriber = logging_subscriber::LoggingSubscriberBuilder::default()
                .with_time(true)
                .with_level(true)
                .with_target(true)
                .with_file(false)
                .with_line_number(false)
                .with_min_level(LevelFilter::TRACE)
                .with_format_level(logging_subscriber::LevelOutput::Long)
                .with_default_style(Style::default().dim())
                .with_level_style_warn(Style::new().color256(220).bold())
                .with_level_style_trace(Style::new().magenta().bold())
                .with_date_time_style(Style::new().white())
                .build();

            let filter = EnvFilter::builder()
                .with_default_directive(LevelFilter::TRACE.into())
                .from_env()
                .unwrap()
                .add_directive("hyper::proto=warn".parse().unwrap())
                .add_directive("hyper::client=warn".parse().unwrap());

            let subscriber = registry.with(filter).with(term_subscriber);
            // Logging is best-effort in tests: do not fail a test because a
            // global subscriber was already installed elsewhere.
            if let Err(err) = subscriber::set_global_default(subscriber) {
                eprintln!("tracing subscriber already installed: {err}");
            }
        });
    }
}