cli_autoupdate/
lib.rs

1use std::env::temp_dir;
2use std::io::ErrorKind;
3use std::path::{Path, PathBuf};
4
5use chrono::{DateTime, Utc};
6#[cfg(feature = "progress")]
7use indicatif::style::TemplateError;
8#[cfg(feature = "progress")]
9use indicatif::{MultiProgress, ProgressStyle};
10use reqwest::Url;
11use semver::Version;
12use thiserror::Error;
13
14mod impls;
15
16#[derive(Error, Debug)]
17pub enum Error {
18	#[error(transparent)]
19	ReqwestError(#[from] reqwest::Error),
20
21	#[error(transparent)]
22	UrlParseError(#[from] url::ParseError),
23
24	#[error("Invalid version")]
25	InvalidVersionError,
26
27	#[error(transparent)]
28	IoError(#[from] std::io::Error),
29
30	#[cfg(feature = "progress")]
31	#[error(transparent)]
32	TemplateError(#[from] TemplateError),
33
34	#[error("Failed to get content length from '{0}'")]
35	InvalidContentLengthError(String),
36
37	#[error("File integrity mismatch. Expected size {0}, found {1}")]
38	InvalidFileSize(u64, u64),
39
40	#[error("File checksum failed")]
41	InvalidFileChecksum,
42
43	#[error(transparent)]
44	InvalidCredentialsError(anyhow::Error),
45}
46
47/// This trait is responsible to gives the base url for and the final path
48/// of the update json package description
49///
50/// Eg:
51/// ```rust
52///     use cli_autoupdate::{Config, Registry};
53
54/// 	struct MyRegistry
55///     impl Registry for MyRegistry {
56/// 		fn get_base_url(&self) -> Url {
57/// 			Url::parse("https://registry.example/").unwrap()
58///    		}
59/// 		fn get_update_path<C: Config>(&self, config: &C) -> String {
60///    			format!("{}.json", config.target())
61///    		}
62/// 		fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
63/// 			Ok(None)
64/// 		}
65/// 	}
66/// ```
67pub trait Registry {
68	/// base url of the repository. This is also used together with the `RemoteVersion::path`
69	/// to create the final download url.
70	fn get_base_url(&self) -> Url;
71
72	/// This is the relative update json path, according to the configuration path.
73	fn get_update_path<C: Config>(&self, config: &C) -> String;
74
75	// Basic HTTP authentication (username, password)
76	fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>>;
77}
78
79/// Configuration implementation for the current package
80///
81/// Eg:
82/// ```rust
83///     use cli_autoupdate::Config;
84///    	use semver::Version;
85///     struct MyConfig;
86///
87///    	impl Config for MyConfig {
88/// 		fn version(&self) -> Version {
89/// 			let version_str = std::env::var("CARGO_PKG_VERSION").unwrap();
90/// 			Version::parse(version_str.as_str()).unwrap()
91/// 		}
92///
93///     fn target(&self) -> String {
94///         let target = std::env::var("TARGET").unwrap_or("aarch64-apple-darwin".to_string());
95///         format!("{}", target)
96///     }
97/// }
98/// ```
99pub trait Config {
100	/// Should return the current package version
101	fn version(&self) -> Version;
102
103	/// Returns the package specific path, used by the registry
104	fn target(&self) -> String;
105}
106
107/// Deserialized object of the remote json file, used to check for updates
108/// Example:
109///
110/// ```json
111/// 	{
112/// 		"version": "2.0.0",
113/// 		"datetime": "2024-01-14T14:40:43+0100",
114/// 		"checksum": "726b934c8263868090490cf626b25770bbbbc98900689af512eddf9c33e9f785",
115/// 		"size": 5538631,
116/// 		"path": "./aarch64-apple-darwin/2.0.0/my_binary.tgz"
117/// 	}
118///    ```
119///
120#[derive(Debug, serde::Deserialize)]
121pub struct RemoteVersion {
122	/// version of the remote update file
123	#[serde(deserialize_with = "impls::value_to_version")]
124	pub version: Version,
125	/// the SHA-256 checksum of the file "path"
126	pub checksum: String,
127	/// the size in bytes of the file "path"
128	pub size: usize,
129	/// update path, relative to the repository base url (Repository::get_base_url)
130	pub path: String,
131	/// file published datetime
132	pub datetime: DateTime<Utc>,
133}
134
135pub type Result<T> = std::result::Result<T, crate::Error>;
136
137/// Check if there's a new version available
138/// Example:
139/// ```rust
140///
141/// 	use cli_autoupdate::check_version;
142///    	struct LabConfig;
143/// 	struct LabRegistry;
144///
145/// 	#[tokio::main]
146/// 	async fn main() {
147/// 		let config = LabConfig;
148///    		let registry = LabRegistry;
149///    		let (has_update, version) = check_version(&config, &registry).await.unwrap();
150///    		if has_update {
151///    			let bin_name = console::style("binary_name").cyan().italic().bold();
152///    			let this_version = config.version();
153///    			let other_version = console::style(version.version).green();
154///    			let update_url = registry.get_base_url().join(version.path.as_str()).unwrap();
155///
156///    			println!(
157///    				"A new release of {} is available: {} → {}",
158///    				bin_name, this_version, other_version
159///  			);
160///    			println!("Released on {}", version.datetime);
161///    			println!("{}", console::style(update_url).yellow());
162/// 		}
163///    	}
164/// ```
165pub async fn check_version<C: Config, R: Registry>(config: &C, registry: &R) -> Result<(bool, RemoteVersion)> {
166	impls::fetch_remote_version(config, registry)
167		.await
168		.and_then(|r| Ok((r.version > config.version(), r)))
169}
170
171///    Check for updates and auto-update the current binary, if a new version is available
172///
173pub async fn update_self<C: Config, R: Registry>(
174	config: &C,
175	registry: &R,
176	#[cfg(feature = "progress")] multi_progress: Option<MultiProgress>,
177	#[cfg(feature = "progress")] progress_style: Option<ProgressStyle>,
178) -> Result<()> {
179	let result = check_version(config, registry).await?;
180	let remote_version = result.1;
181	if result.0 {
182		let remote_path = remote_version.path;
183		let remote_file_path = PathBuf::from(&remote_path);
184		let filename = remote_file_path
185			.file_name()
186			.ok_or(Error::IoError(std::io::Error::from(ErrorKind::NotFound)))?;
187		let target_path = temp_dir().join(filename);
188		let remote_path = registry.get_base_url().join(&remote_path.as_str())?;
189		let client = reqwest::ClientBuilder::default().build().unwrap();
190
191		#[cfg(feature = "progress")]
192		let _ = impls::download_file(&client, &remote_path, &target_path, registry, multi_progress, progress_style).await?;
193
194		#[cfg(not(feature = "progress"))]
195		let _ = impls::download_file(&client, &remote_path, &target_path, registry).await?;
196
197		let _ = impls::verify_file(&target_path, remote_version.size as u64, remote_version.checksum.clone()).await?;
198
199		let bin_name = std::env::current_exe().or(Err(Error::IoError(std::io::Error::from(ErrorKind::NotFound))))?;
200		let bin_name_path = bin_name.parent().unwrap_or(Path::new("/")).to_path_buf();
201		impls::extract(&target_path, &bin_name_path).await
202	} else {
203		Err(Error::InvalidVersionError)
204	}
205}
206
#[cfg(test)]
mod tests {
	use std::any::type_name_of_val;
	use std::sync::Once;

	use console::Style;
	#[cfg(feature = "progress")]
	use indicatif::{MultiProgress, ProgressStyle};
	use reqwest::Url;
	use semver::Version;
	use tracing::level_filters::LevelFilter;
	use tracing::subscriber;
	use tracing_subscriber::prelude::*;
	use tracing_subscriber::EnvFilter;

	use crate::{check_version, update_self, Config, Registry, RemoteVersion};

	/// Registry without authentication.
	struct LabRegistry;

	struct LabConfig;

	/// Registry that supplies basic-auth credentials.
	struct ArtifactoryRegistry;

	struct ArtifactoryConfig;

	/// Guards the one-time installation of the global tracing subscriber.
	static INIT_LOGGING: Once = Once::new();

	/// Version of the package under test, read from the environment of the test process.
	fn package_version() -> Version {
		let version_str = std::env::var("CARGO_PKG_VERSION").unwrap();
		Version::parse(version_str.as_str()).unwrap()
	}

	/// Target triple from the `TARGET` environment variable, with a fixed fallback.
	fn build_target() -> String {
		std::env::var("TARGET").unwrap_or_else(|_| "aarch64-apple-darwin".to_string())
	}

	impl Registry for LabRegistry {
		fn get_base_url(&self) -> Url {
			Url::parse("https://test.example.com/aot/").unwrap()
		}

		fn get_update_path<C: Config>(&self, config: &C) -> String {
			format!("{}.json", config.target())
		}

		fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
			Ok(None)
		}
	}

	impl Config for LabConfig {
		fn version(&self) -> Version {
			package_version()
		}

		fn target(&self) -> String {
			build_target()
		}
	}

	impl Registry for ArtifactoryRegistry {
		fn get_base_url(&self) -> Url {
			Url::parse("https://test.artifactory.com/local/").unwrap()
		}

		fn get_update_path<C: Config>(&self, config: &C) -> String {
			format!("{}.json", config.target())
		}

		fn get_basic_auth(&self) -> anyhow::Result<Option<(String, Option<String>)>> {
			Ok(Some(("USER1".to_string(), Some("PASSWORD1".to_string()))))
		}
	}

	impl Config for ArtifactoryConfig {
		fn version(&self) -> Version {
			package_version()
		}

		fn target(&self) -> String {
			build_target()
		}
	}

	/// Prints the "new release available" notice for `version`.
	fn print_update_notice<C: Config, R: Registry>(config: &C, registry: &R, version: RemoteVersion) {
		let bin_name = console::style("binary_name").cyan().italic().bold();
		let this_version = config.version();
		let other_version = console::style(version.version).green();
		let update_url = registry.get_base_url().join(version.path.as_str()).unwrap();

		println!(
			"A new release of {} is available: {} → {}",
			bin_name, this_version, other_version
		);
		println!("Released on {}", version.datetime);
		println!("{}", console::style(update_url).yellow());
	}

	/// Prints the common header used when a check or update fails.
	fn print_failure(err: &crate::Error) {
		println!("Error checking for version update:");
		println!("{}", console::style(err.to_string()).red().italic());
	}

	/// Prints a variant-specific line for `err`. The match is deliberately
	/// exhaustive so that adding a variant to `crate::Error` fails the build here.
	fn describe_error(err: crate::Error) {
		match err {
			crate::Error::ReqwestError(err) => {
				println!("ReqwestError: {:?}", err);
			}
			crate::Error::UrlParseError(err) => {
				println!("UrlParseError: {:?}", err);
			}
			crate::Error::InvalidVersionError => {
				println!("InvalidVersionError");
			}
			crate::Error::IoError(err) => {
				println!("IoError: {:?}", err);
			}
			// This variant only exists with the `progress` feature; without the
			// arm the match is non-exhaustive when that feature is enabled.
			#[cfg(feature = "progress")]
			crate::Error::TemplateError(err) => {
				println!("TemplateError: {:?}", err);
			}
			crate::Error::InvalidContentLengthError(err) => {
				println!("InvalidContentLengthError: {:?}", err);
			}
			crate::Error::InvalidFileSize(size, found) => {
				println!("InvalidFileSize: size: {}, found: {}", size, found);
			}
			crate::Error::InvalidFileChecksum => {
				println!("InvalidFileChecksum");
			}
			crate::Error::InvalidCredentialsError(err) => {
				println!("InvalidCredentialsError: {:?}", err);
			}
		}
	}

	/// Runs `update_self` against the given config/registry and prints any failure.
	async fn run_update<C: Config, R: Registry>(config: &C, registry: &R) {
		init_logging();
		let current_exe = std::env::current_exe();
		println!("current exe: {:?}", current_exe);

		#[cfg(feature = "progress")]
		let result = {
			let progress_style =
				ProgressStyle::with_template("{prefix:.green.bold} [{bar:40.cyan/blue.bold}] {percent:>5}% [ETA {eta}] {msg} ")
					.unwrap()
					.progress_chars("=> ");
			let multi_progress = MultiProgress::new();
			update_self(config, registry, Some(multi_progress), Some(progress_style)).await
		};

		#[cfg(not(feature = "progress"))]
		let result = update_self(config, registry).await;

		if let Err(err) = result {
			print_failure(&err);
		}
	}

	#[tokio::test]
	async fn test_check_version() {
		let config = LabConfig;
		let registry = LabRegistry;
		let (has_update, version) = check_version(&config, &registry).await.unwrap();

		if has_update {
			print_update_notice(&config, &registry, version);
		}
	}

	#[tokio::test]
	async fn test_update_version() {
		run_update(&LabConfig, &LabRegistry).await;
	}

	#[tokio::test]
	async fn test_check_version_with_basic_auth() {
		let config = ArtifactoryConfig;
		let registry = ArtifactoryRegistry;

		match check_version(&config, &registry).await {
			Ok((true, version)) => print_update_notice(&config, &registry, version),
			Ok((false, _)) => {}
			Err(err) => {
				print_failure(&err);
				println!("Error type: {:?}", type_name_of_val(&err));
				describe_error(err);
			}
		}
	}

	#[tokio::test]
	async fn test_update_version_with_basic_auth() {
		run_update(&ArtifactoryConfig, &ArtifactoryRegistry).await;
	}

	/// Installs the global tracing subscriber exactly once per test process.
	///
	/// Several tests call this and they share one process; a second call to
	/// `set_global_default` returns an error, which would make the `unwrap` panic.
	fn init_logging() {
		INIT_LOGGING.call_once(|| {
			let registry = tracing_subscriber::Registry::default();

			let term_subscriber = logging_subscriber::LoggingSubscriberBuilder::default()
				.with_time(true)
				.with_level(true)
				.with_target(true)
				.with_file(false)
				.with_line_number(false)
				.with_min_level(LevelFilter::TRACE)
				.with_format_level(logging_subscriber::LevelOutput::Long)
				.with_default_style(Style::default().dim())
				.with_level_style_warn(Style::new().color256(220).bold())
				.with_level_style_trace(Style::new().magenta().bold())
				.with_date_time_style(Style::new().white())
				.build();

			let filter = EnvFilter::builder()
				.with_default_directive(LevelFilter::TRACE.into())
				.from_env()
				.unwrap()
				.add_directive("hyper::proto=warn".parse().unwrap())
				.add_directive("hyper::client=warn".parse().unwrap());

			let subscriber = registry.with(filter).with(term_subscriber);
			subscriber::set_global_default(subscriber).unwrap();
		});
	}
}