#![cfg_attr(docsrs, feature(doc_cfg))]
// Feature-combination guards: each invalid combination below becomes an explicit
// compile error carrying a message that says how to fix the feature selection.
// Both HTTP client backends enabled at once:
#[cfg(all(feature = "reqwest", feature = "ureq"))]
compile_error!(
"features `reqwest` and `ureq` are mutually exclusive - enable exactly one HTTP client \
(for `ureq`, set `default-features = false`)"
);
// Neither HTTP client backend enabled:
#[cfg(not(any(feature = "reqwest", feature = "ureq")))]
compile_error!(
"no HTTP client selected - enable exactly one of the `reqwest` (default) or `ureq` features"
);
// The `default-tls` and `rustls` features cannot be combined:
#[cfg(all(feature = "default-tls", feature = "rustls"))]
compile_error!(
"features `default-tls` and `rustls` are mutually exclusive - to use `rustls`, set \
`default-features = false`"
);
// The async API is only available with the `reqwest` backend:
#[cfg(all(feature = "async", feature = "ureq"))]
compile_error!(
"feature `async` requires the `reqwest` client and is incompatible with `ureq` - \
`ureq` has no async API"
);
pub use http;
#[cfg(feature = "reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest")))]
pub use reqwest;
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub use update::AsyncReleaseSource;
pub use update::{
Release, ReleaseAsset, ReleaseBuilder, ReleaseSource, ReleaseStatus, ReleaseUpdate, Releases,
UpdateConfig,
};
#[cfg(feature = "ureq")]
#[cfg_attr(docsrs, doc(cfg(feature = "ureq")))]
pub use ureq;
#[cfg(feature = "signatures")]
#[cfg_attr(docsrs, doc(cfg(feature = "signatures")))]
pub use zipsign_api;
/// Raw public-key bytes, exactly `zipsign_api::PUBLIC_KEY_LENGTH` long.
///
/// Only available with the `signatures` feature.
/// NOTE(review): presumably used to verify signed release archives — confirm at use sites.
#[cfg(feature = "signatures")]
#[cfg_attr(docsrs, doc(cfg(feature = "signatures")))]
pub type VerifyingKey = [u8; zipsign_api::PUBLIC_KEY_LENGTH];
#[cfg(feature = "compression-flate2")]
use either::Either;
use indicatif::{ProgressBar, ProgressStyle};
use log::debug;
use std::cmp::min;
use std::fs;
use std::io;
use std::path;
#[macro_use]
mod macros;
pub mod backends;
#[cfg(feature = "checksums")]
mod checksum;
pub mod errors;
mod http_client;
pub mod update;
pub mod version;
pub use errors::{Error, Result};
#[cfg(feature = "checksums")]
#[cfg_attr(docsrs, doc(cfg(feature = "checksums")))]
pub use checksum::Checksum;
use http_client::{header, HttpResponse};
/// Default `indicatif` template for the download progress bar; a `Download` starts with
/// this and it can be replaced via `Download::progress_style`.
pub(crate) const DEFAULT_PROGRESS_TEMPLATE: &str =
"[{elapsed_precise}] [{bar:40}] {bytes}/{total_bytes} ({eta}) {msg}";
/// Default characters handed to `indicatif`'s `ProgressStyle::progress_chars`.
pub(crate) const DEFAULT_PROGRESS_CHARS: &str = "=>-";
/// Return the `TARGET` environment value that was baked in when this crate was compiled.
///
/// NOTE(review): presumably the target triple exported by the build script — confirm.
pub fn get_target() -> &'static str {
    const BUILD_TARGET: &str = env!("TARGET");
    BUILD_TARGET
}
fn confirm(msg: &str) -> Result<()> {
print_flush!("{}", msg);
let mut s = String::new();
io::stdin().read_line(&mut s)?;
let s = s.trim().to_lowercase();
if !s.is_empty() && s != "y" {
return Err(Error::Aborted);
}
Ok(())
}
/// Result of an update run, carrying a version string.
///
/// NOTE(review): `UpToDate` presumably carries the current version and `Updated` the
/// newly installed one — confirm against the code that constructs these.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum VersionStatus {
    /// No update was performed.
    UpToDate(String),
    /// An update was performed.
    Updated(String),
}

impl VersionStatus {
    /// The version string carried by either variant.
    pub fn version(&self) -> &str {
        match self {
            Self::UpToDate(version) | Self::Updated(version) => version.as_str(),
        }
    }

    /// `true` for [`VersionStatus::UpToDate`].
    pub fn is_up_to_date(&self) -> bool {
        matches!(self, Self::UpToDate(_))
    }

    /// `true` for [`VersionStatus::Updated`].
    pub fn is_updated(&self) -> bool {
        matches!(self, Self::Updated(_))
    }
}

impl std::fmt::Display for VersionStatus {
    /// Renders as `UpToDate(<version>)` or `Updated(<version>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::UpToDate(version) => write!(f, "UpToDate({})", version),
            Self::Updated(version) => write!(f, "Updated({})", version),
        }
    }
}
/// The container format of a downloaded release file.
///
/// `Tar` and `Zip` only exist when the `archive-tar` / `archive-zip` features are enabled.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum ArchiveKind {
    /// A tarball, optionally compressed.
    #[cfg(feature = "archive-tar")]
    #[cfg_attr(docsrs, doc(cfg(feature = "archive-tar")))]
    Tar(Option<Compression>),
    /// A single file, optionally compressed.
    Plain(Option<Compression>),
    /// A zip archive.
    #[cfg(feature = "archive-zip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "archive-zip")))]
    Zip,
}

impl std::fmt::Display for ArchiveKind {
    /// Short extension-like label: `tar.gz`, `tar`, `gz`, `plain` or `zip`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let label = match *self {
            #[cfg(feature = "archive-tar")]
            ArchiveKind::Tar(Some(Compression::Gz)) => "tar.gz",
            #[cfg(feature = "archive-tar")]
            ArchiveKind::Tar(None) => "tar",
            ArchiveKind::Plain(Some(Compression::Gz)) => "gz",
            ArchiveKind::Plain(None) => "plain",
            #[cfg(feature = "archive-zip")]
            ArchiveKind::Zip => "zip",
        };
        f.write_str(label)
    }
}

/// Compression applied on top of an [`ArchiveKind`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum Compression {
    /// gzip.
    Gz,
}
/// Guess the [`ArchiveKind`] of `path` from its file extension(s).
///
/// Extensions are compared ASCII-case-insensitively, so `App.ZIP` or `tool.TAR.GZ` are
/// recognised like their lowercase forms instead of falling through to `Plain(None)`.
/// `.zip` → `Zip`, `.tar` → `Tar(None)`, `.tgz` / `.tar.gz` → `Tar(Some(Gz))`, any other
/// `.gz` → `Plain(Some(Gz))`, everything else → `Plain(None)`.
///
/// # Errors
/// `Error::ArchiveNotEnabled` when the extension names an archive format whose feature
/// (`archive-zip` / `archive-tar`) is not compiled in.
fn detect_archive(path: &path::Path) -> Result<ArchiveKind> {
    /// `true` when `ext` is present and equals `wanted`, ignoring ASCII case.
    fn is_ext(ext: Option<&std::ffi::OsStr>, wanted: &str) -> bool {
        matches!(ext, Some(ext) if ext.eq_ignore_ascii_case(wanted))
    }

    let ext = path.extension();
    debug!("Detecting archive type using extension: {:?}", ext);
    let res = if is_ext(ext, "zip") {
        #[cfg(feature = "archive-zip")]
        {
            debug!("Detected .zip archive");
            Ok(ArchiveKind::Zip)
        }
        #[cfg(not(feature = "archive-zip"))]
        {
            Err(Error::ArchiveNotEnabled("zip".to_string()))
        }
    } else if is_ext(ext, "tar") {
        #[cfg(feature = "archive-tar")]
        {
            debug!("Detected .tar archive");
            Ok(ArchiveKind::Tar(None))
        }
        #[cfg(not(feature = "archive-tar"))]
        {
            Err(Error::ArchiveNotEnabled("tar".to_string()))
        }
    } else if is_ext(ext, "tgz") {
        #[cfg(feature = "archive-tar")]
        {
            debug!("Detected .tgz archive");
            Ok(ArchiveKind::Tar(Some(Compression::Gz)))
        }
        #[cfg(not(feature = "archive-tar"))]
        {
            Err(Error::ArchiveNotEnabled("tar".to_string()))
        }
    } else if is_ext(ext, "gz") {
        // For `foo.tar.gz` the stem is `foo.tar`; its own extension decides tar vs plain.
        let inner = path
            .file_stem()
            .map(path::Path::new)
            .and_then(|stem| stem.extension());
        if is_ext(inner, "tar") {
            #[cfg(feature = "archive-tar")]
            {
                debug!("Detected .tar.gz archive");
                Ok(ArchiveKind::Tar(Some(Compression::Gz)))
            }
            #[cfg(not(feature = "archive-tar"))]
            {
                Err(Error::ArchiveNotEnabled("tar".to_string()))
            }
        } else {
            Ok(ArchiveKind::Plain(Some(Compression::Gz)))
        }
    } else {
        Ok(ArchiveKind::Plain(None))
    };
    debug!("Detected archive type: {:?}", res);
    res
}
/// Extractor for a downloaded release file (a plain file, or a tar/zip archive when the
/// corresponding features are enabled), built with [`Extract::from_source`].
#[derive(Debug)]
#[non_exhaustive]
pub struct Extract<'a> {
/// Path of the file to extract.
source: &'a path::Path,
/// Explicit archive kind; when `None`, the kind is detected from `source`'s extension.
archive: Option<ArchiveKind>,
}
// Reader type produced for plain/tar extraction. With `compression-flate2` it is either
// the raw file or a gzip decoder over it; without that feature only the raw file exists.
#[cfg(feature = "compression-flate2")]
type GetArchiveReaderResult = Either<fs::File, flate2::read::GzDecoder<fs::File>>;
#[cfg(not(feature = "compression-flate2"))]
type GetArchiveReaderResult = fs::File;
impl<'a> Extract<'a> {
pub fn from_source(source: &'a path::Path) -> Extract<'a> {
Self {
source,
archive: None,
}
}
pub fn archive(&mut self, kind: ArchiveKind) -> &mut Self {
self.archive = Some(kind);
self
}
#[allow(unused_variables)]
fn get_archive_reader(
source: fs::File,
compression: Option<Compression>,
) -> GetArchiveReaderResult {
#[cfg(feature = "compression-flate2")]
match compression {
Some(Compression::Gz) => Either::Right(flate2::read::GzDecoder::new(source)),
None => Either::Left(source),
}
#[cfg(not(feature = "compression-flate2"))]
source
}
pub fn extract_into(&self, into_dir: &path::Path) -> Result<()> {
let source = fs::File::open(self.source)?;
let archive = match self.archive {
Some(archive) => archive,
None => detect_archive(self.source)?,
};
let extract_into_plain_or_tar = |source: fs::File, compression: Option<Compression>| {
let mut reader = Self::get_archive_reader(source, compression);
match archive {
ArchiveKind::Plain(_) => {
match fs::create_dir_all(into_dir) {
Ok(_) => (),
Err(e) => {
if e.kind() != io::ErrorKind::AlreadyExists {
return Err(Error::Io(e));
}
}
}
let file_name = self
.source
.file_name()
.ok_or_else(|| Error::Update("Extractor source has no file-name".into()))?;
let mut out_path = into_dir.join(file_name);
out_path.set_extension("");
let mut out_file = fs::File::create(&out_path)?;
io::copy(&mut reader, &mut out_file)?;
}
#[cfg(feature = "archive-tar")]
ArchiveKind::Tar(_) => {
let mut archive = tar::Archive::new(reader);
archive.unpack(into_dir)?;
}
#[allow(unreachable_patterns)]
_ => unreachable!(
"detect_archive() returns in case the proper feature flag is not enabled"
),
};
Ok(())
};
match archive {
#[cfg(feature = "archive-tar")]
ArchiveKind::Plain(compression) | ArchiveKind::Tar(compression) => {
extract_into_plain_or_tar(source, compression)?;
}
#[cfg(not(feature = "archive-tar"))]
ArchiveKind::Plain(compression) => {
extract_into_plain_or_tar(source, compression)?;
}
#[cfg(feature = "archive-zip")]
ArchiveKind::Zip => {
let mut archive = zip::ZipArchive::new(source)?;
for i in 0..archive.len() {
let mut file = archive.by_index(i)?;
let output_path = into_dir.join(file.name());
if let Some(parent_dir) = output_path.parent() {
if let Err(e) = fs::create_dir_all(parent_dir) {
if e.kind() != io::ErrorKind::AlreadyExists {
return Err(Error::Io(e));
}
}
}
let mut output = fs::File::create(output_path)?;
io::copy(&mut file, &mut output)?;
}
}
};
Ok(())
}
pub fn extract_file<T: AsRef<path::Path>>(
&self,
into_dir: &path::Path,
file_to_extract: T,
) -> Result<()> {
let file_to_extract = file_to_extract.as_ref();
let source = fs::File::open(self.source)?;
let archive = match self.archive {
Some(archive) => archive,
None => detect_archive(self.source)?,
};
debug!(
"Attempting to extract {:?} file from {:?}",
file_to_extract, self.source
);
let extract_file_plain_or_tar = |source: fs::File, compression: Option<Compression>| {
let mut reader = Self::get_archive_reader(source, compression);
match archive {
ArchiveKind::Plain(_) => {
debug!("Copying file directly");
match fs::create_dir_all(into_dir) {
Ok(_) => (),
Err(e) => {
if e.kind() != io::ErrorKind::AlreadyExists {
return Err(Error::Io(e));
}
}
}
let file_name = file_to_extract
.file_name()
.ok_or_else(|| Error::Update("Extractor source has no file-name".into()))?;
let out_path = into_dir.join(file_name);
let mut out_file = fs::File::create(out_path)?;
io::copy(&mut reader, &mut out_file)?;
}
#[cfg(feature = "archive-tar")]
ArchiveKind::Tar(_) => {
debug!("Extracting from tar");
let mut archive = tar::Archive::new(reader);
let mut entry = archive
.entries()?
.filter_map(|e| e.ok())
.find(|e| {
let p = e.path();
debug!("Archive path: {:?}", p);
p.ok().filter(|p| p == file_to_extract).is_some()
})
.ok_or_else(|| {
Error::Update(format!(
"Could not find the required path in the archive: {:?}",
file_to_extract
))
})?;
entry.unpack_in(into_dir)?;
}
#[allow(unreachable_patterns)]
_ => unreachable!(
"detect_archive() returns in case the proper feature flag is not enabled"
),
};
Ok(())
};
match archive {
#[cfg(feature = "archive-tar")]
ArchiveKind::Plain(compression) | ArchiveKind::Tar(compression) => {
extract_file_plain_or_tar(source, compression)?;
}
#[cfg(not(feature = "archive-tar"))]
ArchiveKind::Plain(compression) => {
extract_file_plain_or_tar(source, compression)?;
}
#[cfg(feature = "archive-zip")]
ArchiveKind::Zip => {
let mut archive = zip::ZipArchive::new(source)?;
let file_name = file_to_extract.to_str().ok_or_else(|| {
Error::Update(format!(
"cannot extract file with a non-UTF-8 path: {:?}",
file_to_extract
))
})?;
let mut file = archive.by_name(file_name)?;
let output_path = into_dir.join(file.name());
if let Some(parent_dir) = output_path.parent() {
if let Err(e) = fs::create_dir_all(parent_dir) {
if e.kind() != io::ErrorKind::AlreadyExists {
return Err(Error::Io(e));
}
}
}
let mut output = fs::File::create(output_path)?;
io::copy(&mut file, &mut output)?;
}
};
Ok(())
}
}
#[derive(Debug)]
#[non_exhaustive]
pub struct Move<'a> {
source: &'a path::Path,
temp: Option<&'a path::Path>,
}
impl<'a> Move<'a> {
pub fn from_source(source: &'a path::Path) -> Move<'a> {
Self { source, temp: None }
}
pub fn replace_using_temp(&mut self, temp: &'a path::Path) -> &mut Self {
self.temp = Some(temp);
self
}
pub fn to_dest(&self, dest: &path::Path) -> Result<()> {
match self.temp {
Some(temp) if dest.exists() => {
fs::rename(dest, temp)?;
if let Err(e) = fs::rename(self.source, dest) {
fs::rename(temp, dest)?;
return Err(Error::from(e));
}
}
_ => {
fs::rename(self.source, dest)?;
}
};
Ok(())
}
}
/// A batch of file moves applied all-or-nothing by [`MoveAll::commit`].
///
/// Destinations that already exist are stashed inside `temp` while the batch is applied,
/// so they can be restored if a later move fails.
#[derive(Debug)]
#[must_use = "queued moves are only applied when `.commit()` is called"]
#[non_exhaustive]
pub struct MoveAll<'a> {
    temp: &'a path::Path,
    moves: Vec<(path::PathBuf, path::PathBuf)>,
}

impl<'a> MoveAll<'a> {
    /// Start an empty batch that stashes replaced destinations under `temp`.
    ///
    /// NOTE(review): stashing uses `fs::rename`, so `temp` presumably has to be on the
    /// same filesystem as the destinations — confirm with callers.
    pub fn from_temp(temp: &'a path::Path) -> Self {
        Self {
            temp,
            moves: Vec::new(),
        }
    }

    /// Queue a move of `source` to `dest`; the filesystem is not touched until `commit`.
    pub fn add(
        &mut self,
        source: impl AsRef<path::Path>,
        dest: impl AsRef<path::Path>,
    ) -> &mut Self {
        self.moves
            .push((source.as_ref().to_path_buf(), dest.as_ref().to_path_buf()));
        self
    }

    /// Apply every queued move, in insertion order.
    ///
    /// An existing destination is first renamed to `temp/self_update-stash-{i}`. If any
    /// step fails, already-applied moves are undone newest-first: each installed file is
    /// moved back to its source path and each stashed original is restored. The first
    /// error is then returned. Rollback failures are logged, not returned.
    ///
    /// The queue is drained either way, so calling `commit` again is a no-op. On success
    /// the stashed originals are left in `temp`; this method does not delete them.
    pub fn commit(&mut self) -> Result<()> {
        let moves = std::mem::take(&mut self.moves);
        let mut applied: Vec<Applied> = Vec::with_capacity(moves.len());
        for (i, (source, dest)) in moves.into_iter().enumerate() {
            let stash = if dest.exists() {
                let stash = self.temp.join(format!("self_update-stash-{i}"));
                if let Err(e) = fs::rename(&dest, &stash) {
                    rollback(&applied);
                    return Err(Error::from(e));
                }
                Some(stash)
            } else {
                None
            };
            if let Err(e) = fs::rename(&source, &dest) {
                // The failed move left `source` untouched; only its stash needs undoing.
                if let Some(stash) = &stash {
                    restore_stash(stash, &dest);
                }
                rollback(&applied);
                return Err(Error::from(e));
            }
            applied.push(Applied {
                source,
                dest,
                stash,
            });
        }
        Ok(())
    }
}

/// Record of one move that `MoveAll::commit` completed, kept so it can be undone.
#[derive(Debug)]
struct Applied {
    /// Where the moved file came from; it is returned here on rollback.
    source: path::PathBuf,
    /// Where the file was installed.
    dest: path::PathBuf,
    /// Stash holding the previous occupant of `dest`, if there was one.
    stash: Option<path::PathBuf>,
}

/// Rename a stashed original back to `dest`. A failure is logged rather than propagated,
/// so the rest of a rollback still runs.
fn restore_stash(stash: &path::Path, dest: &path::Path) {
    if let Err(e) = fs::rename(stash, dest) {
        log::error!(
            "failed to restore {:?} from stash {:?} during rollback: {}",
            dest,
            stash,
            e
        );
    }
}

/// Undo `applied` moves newest-first, restoring the pre-commit filesystem state.
///
/// Each installed file is moved back to its source (previously it was deleted or
/// overwritten, so a failed commit lost the new files). Then any stashed original is
/// renamed back over `dest`. Errors are logged and the rollback continues.
fn rollback(applied: &[Applied]) {
    for entry in applied.iter().rev() {
        let returned = fs::rename(&entry.dest, &entry.source);
        if let Err(e) = &returned {
            log::error!(
                "failed to move {:?} back to {:?} during rollback: {}",
                entry.dest,
                entry.source,
                e
            );
        }
        match &entry.stash {
            // `fs::rename` replaces an existing target, so this also restores `dest`
            // when moving the new file back failed.
            Some(stash) => restore_stash(stash, &entry.dest),
            None => {
                if returned.is_err() {
                    // Could not hand the file back; at least leave `dest` absent as before.
                    if let Err(e) = fs::remove_file(&entry.dest) {
                        log::error!("failed to remove {:?} during rollback: {}", entry.dest, e);
                    }
                }
            }
        }
    }
}
/// Progress hook signature: called with the bytes received so far and the expected
/// total, when known.
pub(crate) type DynProgressFn = dyn Fn(u64, Option<u64>) + Send + Sync;

/// Cheaply cloneable, thread-shareable handle to a [`DynProgressFn`].
#[derive(Clone)]
pub(crate) struct ProgressCallback(pub(crate) std::sync::Arc<DynProgressFn>);

impl std::fmt::Debug for ProgressCallback {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped closure has no meaningful `Debug` output; print a placeholder.
        write!(f, "ProgressCallback(..)")
    }
}

/// Path-predicate signature: receives a path and returns a yes/no verdict.
///
/// NOTE(review): presumably used to accept or reject a downloaded file — confirm at use sites.
pub(crate) type DynVerifyFn = dyn Fn(&std::path::Path) -> bool + Send + Sync;

/// Cheaply cloneable, thread-shareable handle to a [`DynVerifyFn`].
#[derive(Clone)]
pub(crate) struct VerifyCallback(pub(crate) std::sync::Arc<DynVerifyFn>);

impl std::fmt::Debug for VerifyCallback {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // As above: closures are opaque, so only a fixed label is printed.
        write!(f, "VerifyCallback(..)")
    }
}
/// Asset-selector signature: given a release's assets, return one of them or `None`.
///
/// NOTE(review): presumably the returned asset is the one to download — confirm at use sites.
pub(crate) type DynAssetMatcher = dyn Fn(&[ReleaseAsset]) -> Option<ReleaseAsset> + Send + Sync;

/// Cheaply cloneable, thread-shareable handle to a [`DynAssetMatcher`].
#[derive(Clone)]
pub(crate) struct AssetMatcher(pub(crate) std::sync::Arc<DynAssetMatcher>);

impl std::fmt::Debug for AssetMatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures cannot be meaningfully debug-printed; emit a fixed placeholder.
        write!(f, "AssetMatcher(..)")
    }
}
/// Configuration for downloading one URL into an `io::Write` sink; build with
/// `Download::from_url` and the setter methods.
#[derive(Debug)]
#[non_exhaustive]
pub struct Download {
/// Whether to draw a terminal progress bar (only drawn when `Content-Length` is non-zero).
show_progress: bool,
/// URL to fetch.
url: String,
/// Extra request headers; a default `User-Agent` is added at request time if absent.
headers: http_client::header::HeaderMap,
/// `indicatif` template used for the progress bar.
progress_template: String,
/// Characters passed to `ProgressStyle::progress_chars`.
progress_chars: String,
/// Optional request timeout handed to the HTTP client.
timeout: Option<std::time::Duration>,
/// Optional callback invoked with `(downloaded, total)` after each received chunk.
on_progress: Option<ProgressCallback>,
/// Caller-supplied HTTP client(s), set via `reqwest_client` / `reqwest_async_client` /
/// `ureq_agent`.
client: http_client::ClientOverride,
}
impl Download {
pub fn from_url(url: &str) -> Self {
Self {
show_progress: false,
url: url.to_owned(),
headers: http_client::header::HeaderMap::new(),
progress_template: DEFAULT_PROGRESS_TEMPLATE.to_string(),
progress_chars: DEFAULT_PROGRESS_CHARS.to_string(),
timeout: None,
on_progress: None,
client: http_client::ClientOverride::default(),
}
}
pub fn show_download_progress(&mut self, b: bool) -> &mut Self {
self.show_progress = b;
self
}
pub fn timeout(&mut self, timeout: std::time::Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
pub fn progress_callback(
&mut self,
callback: impl Fn(u64, Option<u64>) + Send + Sync + 'static,
) -> &mut Self {
self.on_progress = Some(ProgressCallback(std::sync::Arc::new(callback)));
self
}
pub(crate) fn set_progress_callback_arc(
&mut self,
callback: std::sync::Arc<DynProgressFn>,
) -> &mut Self {
self.on_progress = Some(ProgressCallback(callback));
self
}
pub fn progress_style(
&mut self,
progress_template: impl Into<String>,
progress_chars: impl Into<String>,
) -> &mut Self {
self.progress_template = progress_template.into();
self.progress_chars = progress_chars.into();
self
}
pub fn replace_headers(&mut self, headers: http_client::header::HeaderMap) -> &mut Self {
self.headers = headers;
self
}
#[cfg(feature = "reqwest")]
pub fn reqwest_client(&mut self, client: ::reqwest::blocking::Client) -> &mut Self {
self.client.blocking = Some(client);
self
}
#[cfg(feature = "async")]
pub fn reqwest_async_client(&mut self, client: ::reqwest::Client) -> &mut Self {
self.client.r#async = Some(client);
self
}
#[cfg(feature = "ureq")]
pub fn ureq_agent(&mut self, agent: ::ureq::Agent) -> &mut Self {
self.client.agent = Some(agent);
self
}
pub(crate) fn set_client_override(&mut self, client: http_client::ClientOverride) -> &mut Self {
self.client = client;
self
}
pub fn header<N, V>(&mut self, name: N, value: V) -> Result<&mut Self>
where
N: ::core::convert::TryInto<http_client::header::HeaderName>,
V: ::core::convert::TryInto<http_client::header::HeaderValue>,
{
let name = name.try_into().map_err(|_| {
Error::Config("invalid HTTP header name passed to `header`".to_string())
})?;
let value = value.try_into().map_err(|_| {
Error::Config("invalid HTTP header value passed to `header`".to_string())
})?;
self.headers.insert(name, value);
Ok(self)
}
pub fn download_to<T: io::Write>(&self, mut dest: T) -> Result<()> {
use io::BufRead;
let mut headers = self.headers.clone();
if !headers.contains_key(header::USER_AGENT) {
headers.insert(
header::USER_AGENT,
"rust-reqwest/self-update"
.parse()
.expect("invalid user-agent"),
);
}
let resp = http_client::get(&self.url, headers, self.timeout, &self.client)?;
let size = resp
.headers()
.get(http_client::header::CONTENT_LENGTH)
.map(|val| {
val.to_str()
.map(|s| s.parse::<u64>().unwrap_or(0))
.unwrap_or(0)
})
.unwrap_or(0);
let total = if size == 0 { None } else { Some(size) };
let show_progress = if size == 0 { false } else { self.show_progress };
let mut src = io::BufReader::new(resp.body());
let mut downloaded: u64 = 0;
let mut bar = if show_progress {
let pb = ProgressBar::new(size);
pb.set_style(
ProgressStyle::default_bar()
.template(&self.progress_template)
.expect("set ProgressStyle template failed")
.progress_chars(&self.progress_chars),
);
Some(pb)
} else {
None
};
loop {
let n = {
let buf = src.fill_buf()?;
dest.write_all(buf)?;
buf.len()
};
if n == 0 {
break;
}
src.consume(n);
downloaded += n as u64;
if let Some(ref mut bar) = bar {
bar.set_position(min(downloaded, size));
}
if let Some(ref cb) = self.on_progress {
(cb.0)(downloaded, total);
}
}
if let Some(ref mut bar) = bar {
bar.finish_with_message("Done");
}
Ok(())
}
#[cfg(feature = "async")]
pub async fn download_to_async<T: io::Write>(&self, mut dest: T) -> Result<()> {
use futures_util::StreamExt;
let mut headers = self.headers.clone();
if !headers.contains_key(header::USER_AGENT) {
headers.insert(
header::USER_AGENT,
"rust-reqwest/self-update"
.parse()
.expect("invalid user-agent"),
);
}
let resp = http_client::get_async(&self.url, headers, self.timeout, &self.client).await?;
let size = resp
.headers()
.get(http_client::header::CONTENT_LENGTH)
.map(|val| {
val.to_str()
.map(|s| s.parse::<u64>().unwrap_or(0))
.unwrap_or(0)
})
.unwrap_or(0);
let total = if size == 0 { None } else { Some(size) };
let show_progress = if size == 0 { false } else { self.show_progress };
let mut downloaded: u64 = 0;
let mut bar = if show_progress {
let pb = ProgressBar::new(size);
pb.set_style(
ProgressStyle::default_bar()
.template(&self.progress_template)
.expect("set ProgressStyle template failed")
.progress_chars(&self.progress_chars),
);
Some(pb)
} else {
None
};
let mut stream = resp.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
dest.write_all(&chunk)?;
downloaded += chunk.len() as u64;
if let Some(ref mut bar) = bar {
bar.set_position(min(downloaded, size));
}
if let Some(ref cb) = self.on_progress {
(cb.0)(downloaded, total);
}
}
if let Some(ref mut bar) = bar {
bar.finish_with_message("Done");
}
Ok(())
}
}
#[cfg(test)]
mod tests {
#![allow(dead_code, unused_mut, unused_variables)]
use super::*;
#[cfg(feature = "compression-flate2")]
use flate2::{self, write::GzEncoder};
#[allow(unused_imports)]
use std::{
fs::{self, File},
io::{self, Read, Write},
path::{Path, PathBuf},
};
#[test]
fn version_status_is_up_to_date() {
assert!(VersionStatus::UpToDate("1.2.3".to_string()).is_up_to_date());
assert!(!VersionStatus::Updated("1.2.3".to_string()).is_up_to_date());
assert!(VersionStatus::Updated("1.2.3".to_string()).is_updated());
assert!(!VersionStatus::UpToDate("1.2.3".to_string()).is_updated());
}
#[test]
fn version_status_version_accessor() {
assert_eq!(
VersionStatus::UpToDate("1.0.0".to_string()).version(),
"1.0.0"
);
assert_eq!(
VersionStatus::Updated("2.0.0".to_string()).version(),
"2.0.0"
);
}
#[test]
fn version_status_display() {
assert_eq!(
VersionStatus::UpToDate("1.0.0".to_string()).to_string(),
"UpToDate(1.0.0)"
);
assert_eq!(
VersionStatus::Updated("2.0.0".to_string()).to_string(),
"Updated(2.0.0)"
);
}
#[test]
fn archive_kind_display_is_human_readable() {
assert_eq!(ArchiveKind::Plain(None).to_string(), "plain");
assert_eq!(ArchiveKind::Plain(Some(Compression::Gz)).to_string(), "gz");
#[cfg(feature = "archive-tar")]
{
assert_eq!(ArchiveKind::Tar(None).to_string(), "tar");
assert_eq!(
ArchiveKind::Tar(Some(Compression::Gz)).to_string(),
"tar.gz"
);
}
#[cfg(feature = "archive-zip")]
assert_eq!(ArchiveKind::Zip.to_string(), "zip");
}
#[test]
fn download_header_accepts_str_name_and_value() {
let mut dl = Download::from_url("https://example.com/app.tar.gz");
dl.header("x-custom-header", "custom-value")
.expect("valid str header should be accepted");
let stored = dl
.headers
.get("x-custom-header")
.expect("header should be inserted");
assert_eq!(stored, "custom-value");
}
#[test]
fn download_header_accepts_typed_name_and_value() {
let mut dl = Download::from_url("https://example.com/app.tar.gz");
dl.header(http_client::header::ACCEPT, "application/octet-stream")
.expect("typed name + str value should be accepted");
assert_eq!(
dl.headers.get(http_client::header::ACCEPT).unwrap(),
"application/octet-stream"
);
}
#[test]
fn download_header_overwrites_on_repeated_name() {
let mut dl = Download::from_url("https://example.com/app.tar.gz");
dl.header("x-dup", "first").unwrap();
dl.header("x-dup", "second").unwrap();
assert_eq!(dl.headers.get("x-dup").unwrap(), "second");
assert_eq!(
dl.headers.get_all("x-dup").iter().count(),
1,
"a repeated header name must overwrite, not accumulate"
);
}
#[test]
fn replace_headers_wholesale_replaces_after_header_calls() {
let mut dl = Download::from_url("https://example.com/app.tar.gz");
dl.header("x-old-a", "a").unwrap();
dl.header("x-old-b", "b").unwrap();
let mut fresh = http_client::header::HeaderMap::new();
fresh.insert("x-new", "n".parse().unwrap());
dl.replace_headers(fresh);
assert!(
dl.headers.get("x-old-a").is_none(),
"replace_headers must drop previously-added headers"
);
assert!(dl.headers.get("x-old-b").is_none());
assert_eq!(dl.headers.get("x-new").unwrap(), "n");
assert_eq!(
dl.headers.len(),
1,
"replace_headers installs exactly the supplied map"
);
dl.header("x-after", "y").unwrap();
assert_eq!(dl.headers.get("x-after").unwrap(), "y");
assert_eq!(dl.headers.get("x-new").unwrap(), "n");
}
#[test]
fn download_header_rejects_invalid_value() {
let mut dl = Download::from_url("https://example.com/app.tar.gz");
let err = dl
.header("x-ok", "bad\nvalue")
.expect_err("invalid header value must be rejected");
assert!(
matches!(err, Error::Config(_)),
"expected Error::Config, got {:?}",
err
);
assert!(dl.headers.get("x-ok").is_none());
}
#[test]
fn download_header_rejects_invalid_name() {
let mut dl = Download::from_url("https://example.com/app.tar.gz");
let err = dl
.header("inva lid", "ok")
.expect_err("invalid header name must be rejected");
assert!(matches!(err, Error::Config(_)));
assert!(
dl.headers.is_empty(),
"an invalid header name must not leave a partial value inserted"
);
}
#[test]
fn detect_plain() {
assert_eq!(
ArchiveKind::Plain(None),
detect_archive(&PathBuf::from("Something.exe")).unwrap()
);
}
#[test]
fn move_all_commits_every_move() {
let dir = tempfile::tempdir().unwrap();
let temp = tempfile::tempdir().unwrap();
let src_a = dir.path().join("src_a");
let src_b = dir.path().join("src_b");
fs::write(&src_a, b"new-a").unwrap();
fs::write(&src_b, b"new-b").unwrap();
let dst_a = dir.path().join("dst_a");
let dst_b = dir.path().join("dst_b");
fs::write(&dst_a, b"old-a").unwrap();
fs::write(&dst_b, b"old-b").unwrap();
MoveAll::from_temp(temp.path())
.add(&src_a, &dst_a)
.add(&src_b, &dst_b)
.commit()
.unwrap();
assert_eq!(fs::read(&dst_a).unwrap(), b"new-a");
assert_eq!(fs::read(&dst_b).unwrap(), b"new-b");
}
#[test]
fn move_all_rolls_back_on_failure() {
let dir = tempfile::tempdir().unwrap();
let temp = tempfile::tempdir().unwrap();
let src_a = dir.path().join("src_a");
let src_b = dir.path().join("src_b");
fs::write(&src_a, b"new-a").unwrap();
fs::write(&src_b, b"new-b").unwrap();
let missing_src = dir.path().join("does_not_exist");
let dst_a = dir.path().join("dst_a");
let dst_b = dir.path().join("dst_b");
let dst_c = dir.path().join("dst_c");
fs::write(&dst_a, b"old-a").unwrap();
fs::write(&dst_b, b"old-b").unwrap();
fs::write(&dst_c, b"old-c").unwrap();
let res = MoveAll::from_temp(temp.path())
.add(&src_a, &dst_a)
.add(&src_b, &dst_b)
.add(&missing_src, &dst_c)
.commit();
assert!(res.is_err(), "a failing move must abort the transaction");
assert_eq!(
fs::read(&dst_a).unwrap(),
b"old-a",
"the first applied move must be rolled back"
);
assert_eq!(
fs::read(&dst_b).unwrap(),
b"old-b",
"the second applied move must be rolled back"
);
assert_eq!(
fs::read(&dst_c).unwrap(),
b"old-c",
"the failed move's stashed destination must be restored"
);
}
#[test]
fn move_all_installs_fresh_destinations() {
let dir = tempfile::tempdir().unwrap();
let temp = tempfile::tempdir().unwrap();
let src = dir.path().join("src");
fs::write(&src, b"fresh").unwrap();
let dst = dir.path().join("new_dst");
MoveAll::from_temp(temp.path())
.add(&src, &dst)
.commit()
.unwrap();
assert_eq!(fs::read(&dst).unwrap(), b"fresh");
}
#[test]
fn move_all_second_commit_is_a_noop() {
let dir = tempfile::tempdir().unwrap();
let temp = tempfile::tempdir().unwrap();
let src = dir.path().join("src");
fs::write(&src, b"new").unwrap();
let dst = dir.path().join("dst");
fs::write(&dst, b"old").unwrap();
let mut mover = MoveAll::from_temp(temp.path());
mover.add(&src, &dst);
mover.commit().unwrap();
assert_eq!(fs::read(&dst).unwrap(), b"new");
mover.commit().unwrap();
assert_eq!(fs::read(&dst).unwrap(), b"new");
}
#[test]
fn download_invokes_progress_callback() {
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
let body = "x".repeat(20_000);
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let served = body.clone();
std::thread::spawn(move || {
let (mut stream, _) = listener.accept().unwrap();
let mut buf = [0u8; 1024];
let _ = stream.read(&mut buf);
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
served.len(),
served
);
let _ = stream.write_all(resp.as_bytes());
});
let progress = Arc::new(Mutex::new(Vec::<(u64, Option<u64>)>::new()));
let sink_progress = progress.clone();
let mut out = Vec::new();
Download::from_url(&format!("http://{addr}/file"))
.progress_callback(move |downloaded, total| {
sink_progress.lock().unwrap().push((downloaded, total));
})
.download_to(&mut out)
.unwrap();
assert_eq!(out.len(), 20_000);
let calls = progress.lock().unwrap();
assert!(!calls.is_empty(), "callback should have been invoked");
assert!(calls.iter().all(|(_, total)| *total == Some(20_000)));
let mut last = 0u64;
for (downloaded, _) in calls.iter() {
assert!(*downloaded >= last);
last = *downloaded;
}
assert_eq!(calls.last().unwrap().0, 20_000);
}
#[test]
fn detect_plain_gz() {
assert_eq!(
ArchiveKind::Plain(Some(Compression::Gz)),
detect_archive(&PathBuf::from("Something.exe.gz")).unwrap()
);
}
#[cfg(not(feature = "archive-tar"))]
#[test]
#[ignore]
fn detect_tar_gz() {
println!("WARNING: Please enable 'archive-tar' feature!");
}
#[cfg(feature = "archive-tar")]
#[test]
fn detect_tar_gz() {
assert_eq!(
ArchiveKind::Tar(Some(Compression::Gz)),
detect_archive(&PathBuf::from("Something.tar.gz")).unwrap()
);
}
#[cfg(not(feature = "archive-tar"))]
#[test]
#[ignore]
fn detect_plain_tar() {
println!("WARNING: Please enable 'archive-tar' feature!");
}
#[cfg(feature = "archive-tar")]
#[test]
fn detect_plain_tar() {
assert_eq!(
ArchiveKind::Tar(None),
detect_archive(&PathBuf::from("Something.tar")).unwrap()
);
}
#[cfg(not(feature = "archive-zip"))]
#[test]
#[ignore]
fn detect_zip() {
println!("WARNING: Please enable 'archive-zip' feature!");
}
#[cfg(feature = "archive-zip")]
#[test]
fn detect_zip() {
assert_eq!(
ArchiveKind::Zip,
detect_archive(&PathBuf::from("Something.zip")).unwrap()
);
}
#[allow(dead_code)]
fn cmp_content<T: AsRef<Path>>(path: T, s: &str) {
let mut content = String::new();
let mut f = File::open(&path).unwrap();
f.read_to_string(&mut content).unwrap();
assert!(s == content);
}
#[cfg(not(feature = "compression-flate2"))]
#[test]
#[ignore]
fn unpack_plain_gzip() {
println!("WARNING: Please enable 'compression-flate2' feature!");
}
#[cfg(feature = "compression-flate2")]
#[test]
fn unpack_plain_gzip() {
let tmp_dir = tempfile::Builder::new()
.prefix("self_update_unpack_plain_gzip_src")
.tempdir()
.expect("tempdir fail");
let fp = tmp_dir.path().with_file_name("temp.gz");
let mut tmp_file = File::create(&fp).expect("temp file create fail");
let mut e = GzEncoder::new(&mut tmp_file, flate2::Compression::default());
e.write_all(b"This is a test!").expect("gz encode fail");
e.finish().expect("gz finish fail");
let out_tmp = tempfile::Builder::new()
.prefix("self_update_unpack_plain_gzip_outdir")
.tempdir()
.expect("tempdir fail");
let out_path = out_tmp.path();
Extract::from_source(&fp)
.extract_into(out_path)
.expect("extract fail");
let out_file = out_path.join("temp");
assert!(out_file.exists());
cmp_content(out_file, "This is a test!");
}
#[cfg(not(feature = "compression-flate2"))]
#[test]
#[ignore]
fn unpack_plain_gzip_double_ext() {
println!("WARNING: Please enable 'compression-flate2' feature!");
}
#[cfg(feature = "compression-flate2")]
#[test]
fn unpack_plain_gzip_double_ext() {
let tmp_dir = tempfile::Builder::new()
.prefix("self_update_unpack_plain_gzip_double_ext_src")
.tempdir()
.expect("tempdir fail");
let fp = tmp_dir.path().with_file_name("temp.txt.gz");
let mut tmp_file = File::create(&fp).expect("temp file create fail");
let mut e = GzEncoder::new(&mut tmp_file, flate2::Compression::default());
e.write_all(b"This is a test!").expect("gz encode fail");
e.finish().expect("gz finish fail");
let out_tmp = tempfile::Builder::new()
.prefix("self_update_unpack_plain_gzip_double_ext_outdir")
.tempdir()
.expect("tempdir fail");
let out_path = out_tmp.path();
Extract::from_source(&fp)
.extract_into(out_path)
.expect("extract fail");
let out_file = out_path.join("temp.txt");
assert!(out_file.exists());
cmp_content(out_file, "This is a test!");
}
#[cfg(not(all(feature = "archive-tar", feature = "compression-flate2")))]
#[test]
#[ignore]
fn unpack_tar_gzip() {
println!("WARNING: Please enable 'archive-tar compression-flate2' features!");
}
#[cfg(all(feature = "archive-tar", feature = "compression-flate2"))]
#[test]
fn unpack_tar_gzip() {
test_extract_into(
"self_update_unpack_tar_gzip_src",
"archive.tar.gz",
ArchiveKind::Tar(Some(Compression::Gz)),
);
}
#[cfg(not(feature = "compression-flate2"))]
#[test]
#[ignore]
fn unpack_file_plain_gzip() {
println!("WARNING: Please enable 'compression-flate2' feature!");
}
#[cfg(feature = "compression-flate2")]
#[test]
/// Compresses a small payload into a plain (non-archive) `.gz` file, then uses
/// `Extract::extract_file` to decompress it under a caller-chosen name.
fn unpack_file_plain_gzip() {
    let tmp_dir = tempfile::Builder::new()
        .prefix("self_update_unpack_file_plain_gzip_src")
        .tempdir()
        .expect("tempdir fail");
    // `join`, not `with_file_name`: the latter replaces the temp dir's own last path
    // component, which would place the fixture *next to* the temp dir (e.g.
    // `/tmp/temp.gz`), where it collides with other tests and is never cleaned up.
    let fp = tmp_dir.path().join("temp.gz");
    {
        // Scoped so the writer handle is closed before the file is read back.
        let mut tmp_file = File::create(&fp).expect("temp file create fail");
        let mut e = GzEncoder::new(&mut tmp_file, flate2::Compression::default());
        e.write_all(b"This is a test!").expect("gz encode fail");
        e.finish().expect("gz finish fail");
    }
    let out_tmp = tempfile::Builder::new()
        .prefix("self_update_unpack_file_plain_gzip_outdir")
        .tempdir()
        .expect("tempdir fail");
    let out_path = out_tmp.path();
    Extract::from_source(&fp)
        .extract_file(out_path, "renamed_file")
        .expect("extract fail");
    let out_file = out_path.join("renamed_file");
    assert!(out_file.exists());
    cmp_content(out_file, "This is a test!");
}
#[cfg(not(all(feature = "archive-tar", feature = "compression-flate2")))]
#[test]
#[ignore]
/// Placeholder that keeps `unpack_file_tar_gzip` visible (as ignored) when the
/// tar/gzip features are compiled out; running it only prints a hint.
fn unpack_file_tar_gzip() {
    const HINT: &str = "WARNING: Please enable 'archive-tar compression-flate2' features!";
    println!("{}", HINT);
}
#[cfg(all(feature = "archive-tar", feature = "compression-flate2"))]
#[test]
/// Pulls individual members out of a gzip-compressed tarball via `Extract::extract_file`.
fn unpack_file_tar_gzip() {
    let gz_tar = ArchiveKind::Tar(Some(Compression::Gz));
    test_extract_file("self_update_unpack_file_tar_gzip_src", "archive.tar.gz", gz_tar);
}
#[cfg(not(feature = "archive-zip"))]
#[test]
#[ignore]
/// Placeholder that keeps `unpack_zip` visible (as ignored) when the zip feature
/// is compiled out; running it only prints a hint.
fn unpack_zip() {
    const HINT: &str = "WARNING: Please enable 'archive-zip' feature!";
    println!("{}", HINT);
}
#[cfg(feature = "archive-zip")]
#[test]
/// Round-trips a zip archive through `Extract::extract_into`.
fn unpack_zip() {
    let (prefix, archive_name) = ("self_update_unpack_zip_src", "archive.zip");
    test_extract_into(prefix, archive_name, ArchiveKind::Zip);
}
#[cfg(not(feature = "archive-zip"))]
#[test]
#[ignore]
/// Placeholder that keeps `unpack_zip_file` visible (as ignored) when the zip
/// feature is compiled out; running it only prints a hint.
fn unpack_zip_file() {
    const HINT: &str = "WARNING: Please enable 'archive-zip' feature!";
    println!("{}", HINT);
}
#[cfg(feature = "archive-zip")]
#[test]
/// Pulls individual members out of a zip archive via `Extract::extract_file`.
fn unpack_zip_file() {
    test_extract_file(
        // Distinct from `unpack_zip`'s prefix (and following the `<test name>_src`
        // convention used by the other tests) so leftover temp dirs are attributable.
        "self_update_unpack_zip_file_src",
        "archive.zip",
        ArchiveKind::Zip,
    );
}
/// Builds a fixture archive of `archive_kind`, named `src_archive_path`, in a fresh
/// temp directory. It then extracts the whole archive with `Extract::extract_into`
/// and asserts that both `temp.txt` and `inner_archive/temp2.txt` come out with the
/// expected contents.
fn test_extract_into(tmpfile_prefix: &str, src_archive_path: &str, archive_kind: ArchiveKind) {
    let src_dir = tempfile::Builder::new()
        .prefix(tmpfile_prefix)
        .tempdir()
        .expect("Failed to create temp dir");
    let archive_path = src_dir.path().join(src_archive_path);
    build_test_archive(
        File::create(&archive_path).expect("Failed to create archive file"),
        &archive_path,
        archive_kind,
    );

    let dest_prefix = format!("{}_outdir", tmpfile_prefix);
    let dest_dir = tempfile::Builder::new()
        .prefix(&dest_prefix)
        .tempdir()
        .expect("tempdir fail");
    let dest = dest_dir.path();
    Extract::from_source(&archive_path)
        .extract_into(dest)
        .expect("extract fail");

    for &(relative, expected) in &[
        ("temp.txt", "This is a test!"),
        ("inner_archive/temp2.txt", "This is a second test!"),
    ] {
        let extracted = dest.join(relative);
        assert!(extracted.exists());
        cmp_content(&extracted, expected);
    }
}
/// Builds a fixture archive of `archive_kind`, named `src_archive_path`, in a fresh
/// temp directory. It then pulls `temp.txt` and `inner_archive/temp2.txt` out one at
/// a time with `Extract::extract_file`, checking each member's contents right after
/// it is extracted.
fn test_extract_file(tmpfile_prefix: &str, src_archive_path: &str, archive_kind: ArchiveKind) {
    let src_dir = tempfile::Builder::new()
        .prefix(tmpfile_prefix)
        .tempdir()
        .expect("Failed to create temp dir");
    let archive_path = src_dir.path().join(src_archive_path);
    build_test_archive(
        File::create(&archive_path).expect("Failed to create archive file"),
        &archive_path,
        archive_kind,
    );

    let dest_prefix = format!("{}_outdir", tmpfile_prefix);
    let dest_dir = tempfile::Builder::new()
        .prefix(&dest_prefix)
        .tempdir()
        .expect("tempdir fail");
    let dest = dest_dir.path();

    // Each member is extracted from a freshly opened source, then verified.
    for &(member, expected) in &[
        ("temp.txt", "This is a test!"),
        ("inner_archive/temp2.txt", "This is a second test!"),
    ] {
        Extract::from_source(&archive_path)
            .extract_file(dest, member)
            .expect("extract fail");
        let extracted = dest.join(member);
        assert!(extracted.exists());
        cmp_content(&extracted, expected);
    }
}
/// Writes a small fixture archive of kind `archive_kind` into `archive_file`.
///
/// The archive always holds two entries:
/// - `temp.txt` containing "This is a test!"
/// - `inner_archive/temp2.txt` containing "This is a second test!"
///
/// For tar.gz, the entries are first materialised on disk in a `tar_contents`
/// directory beside `archive_file_path`, then packed.
///
/// # Panics
/// Panics on any I/O or encoding failure. Also panics (`unimplemented!`) for archive
/// kinds not handled here, including kinds whose cargo features are disabled.
fn build_test_archive<T: AsRef<Path>>(
    archive_file: fs::File,
    archive_file_path: T,
    archive_kind: ArchiveKind,
) {
    let archive_file_path = archive_file_path.as_ref();
    match archive_kind {
        #[cfg(all(feature = "archive-tar", feature = "compression-flate2"))]
        ArchiveKind::Tar(Some(Compression::Gz)) => {
            let tmp_tar_path = archive_file_path
                .parent()
                .expect("Missing archive file path parent")
                .join("tar_contents");
            let tmp_tar_inner_path = tmp_tar_path.join("inner_archive");
            fs::create_dir_all(&tmp_tar_inner_path).expect("Failed to create temp tar path");
            let fp = tmp_tar_path.join("temp.txt");
            let mut tmp_file = File::create(fp).expect("temp file create fail");
            tmp_file.write_all(b"This is a test!").unwrap();
            let fp = tmp_tar_inner_path.join("temp2.txt");
            let mut tmp_file = File::create(fp).expect("temp file create fail");
            tmp_file.write_all(b"This is a second test!").unwrap();
            // Stream the tar straight through the gzip encoder rather than buffering
            // the whole tar in a Vec first. The encoder owns the file, so the parameter
            // needs no `mut`, which would otherwise trip `unused_mut` when these
            // features are disabled.
            let gz = GzEncoder::new(archive_file, flate2::Compression::default());
            let mut ar = tar::Builder::new(gz);
            ar.append_dir_all(".", &tmp_tar_path)
                .expect("tar append dir all fail");
            // `into_inner` writes the tar end-of-archive trailer before returning the
            // encoder; `finish` then flushes the gzip footer.
            let gz = ar.into_inner().expect("failed getting tar writer");
            gz.finish().expect("gz finish fail");
        }
        #[cfg(feature = "archive-zip")]
        ArchiveKind::Zip => {
            let mut zip = zip::ZipWriter::new(archive_file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("temp.txt", options)
                .expect("failed starting zip file");
            zip.write_all(b"This is a test!")
                .expect("failed writing to zip");
            zip.start_file("inner_archive/temp2.txt", options)
                .expect("failed starting second zip file");
            zip.write_all(b"This is a second test!")
                .expect("failed writing to second zip");
            zip.finish().expect("failed finishing zip");
        }
        _ => {
            unimplemented!("{:?} not handled", archive_kind);
        }
    }
}
}