#[cfg(test)]
#[path = "tests/client.rs"]
mod tests;
use crate::responses::{LatestVersionResponse, VersionHashResponse};
use core::{
fmt::{Display, self},
str::FromStr,
sync::atomic::{AtomicUsize, Ordering},
};
use ed25519_dalek::Signature;
use flume::{Sender, self};
use futures_util::StreamExt as _;
use hex;
use parking_lot::RwLock;
use reqwest::{
StatusCode,
Url,
header::{AsHeaderName, CONTENT_LENGTH, CONTENT_TYPE},
};
use rubedo::{
crypto::{Sha256Hash, VerifyingKey},
sugar::s,
};
use semver::Version;
use serde::de::DeserializeOwned;
use sha2::{Sha256, Digest as _};
use std::{
env::args,
io::Error as IoError,
os::unix::fs::PermissionsExt as _,
path::PathBuf,
sync::Arc,
};
use tempfile::{tempdir, TempDir};
use thiserror::Error as ThisError;
use tokio::{
fs::{File as AsyncFile, self},
io::AsyncWriteExt as _,
select,
spawn,
sync::broadcast::{Receiver as Listener, Sender as Broadcaster, self},
time::{Duration, interval},
};
use tracing::{debug, error, info, warn};
#[cfg(not(test))]
use ::{
reqwest::{Client, Response},
std::{
env::current_exe,
os::unix::process::CommandExt as _,
process::{Command, Stdio, exit},
},
};
#[cfg(test)]
use crate::mocks::std_env::mock_current_exe as current_exe;
#[cfg(test)]
use sham::{
reqwest::{MockClient as Client, MockResponse as Response},
std_process::{FakeCommand as Command, MockStdio as Stdio, mock_exit as exit}
};
/// Current state of an [`Updater`].
///
/// Read with `Updater::status()`, changed with `Updater::set_status()`, and
/// broadcast to subscribers (`Updater::subscribe()`) on every change.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum Status {
/// No update activity is in progress; a new check may start.
Idle,
/// Querying the `latest` endpoint for the newest available version.
Checking,
/// Downloading the given version; the `u8` is the progress percentage
/// derived from the response's `Content-Length`.
Downloading(Version, u8),
/// Replacing the current executable with the given downloaded version.
Installing(Version),
/// The given version is installed, but a restart is deferred until all
/// registered critical actions have been deregistered.
PendingRestart(Version),
/// The application is about to re-execute itself to run the given version.
Restarting(Version),
}
impl Display for Status {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", match *self {
Self::Idle => s!( "Idle"),
Self::Checking => s!( "Checking"),
Self::Installing(ref version) => format!("Installing: {version}"),
Self::Downloading(ref version, ref percent) => format!("Downloading: {version} ({percent}%)"),
Self::PendingRestart(ref version) => format!("Pending restart: {version}"),
Self::Restarting(ref version) => format!("Restarting: {version}"),
})
}
}
/// Errors that can occur while checking for, downloading, verifying or
/// installing an update.
///
/// Most variants carry the URL or file path involved, plus the underlying
/// error rendered to a `String`, so that the type stays `Clone + Eq`.
///
/// NOTE(review): several messages wrap `{0:?}` in literal double quotes, but
/// the `Debug` output of a `PathBuf` is already quoted, so those messages show
/// doubled quotes (`""/tmp/x""`). Consider dropping either the literal quotes
/// or the `:?`. Check any tests that compare the message text first.
#[derive(Clone, Debug, Eq, PartialEq, ThisError)]
#[non_exhaustive]
pub enum UpdaterError {
/// The SHA-256 hash of the downloaded file does not match the hash
/// published by the `hashes/{version}` endpoint.
#[error("Failed hash verification for downloaded version {0}")]
FailedHashVerification(Version),
/// The `x-signature` header does not verify against the response body with
/// the configured public key.
#[error("Failed signature verification for response from {0}")]
FailedSignatureVerification(Url),
/// The server answered with a non-success HTTP status code.
#[error("HTTP status code {1} received when calling {0}")]
HttpError(Url, StatusCode),
/// Sending the HTTP request failed (the request never produced a response).
#[error("HTTP request to {0} failed: {1}")]
HttpRequestFailed(Url, String),
/// The response body could not be read as text.
#[error("Invalid HTTP body received from {0}")]
InvalidBody(Url),
/// The body could not be deserialised into the expected type, or its
/// contents did not match what was requested (e.g. a different version).
#[error("Invalid payload received from {0}")]
InvalidPayload(Url),
/// The `x-signature` header is not valid hex, or does not decode to exactly
/// 64 bytes.
#[error(r#"Invalid signature header "{1}" received from {0}"#)]
InvalidSignature(Url, String),
/// Joining the configured API base URL with an endpoint path failed.
#[error("Invalid URL specified: {0} plus {1}")]
InvalidUrl(Url, String),
/// Fewer body bytes were received than the `Content-Length` header stated
/// (received, expected).
#[error("HTTP response body from {0} is shorter than expected: {1} < {2}")]
MissingData(Url, usize, usize),
/// The response has no (or an empty) `x-signature` header.
#[error("HTTP response from {0} does not contain a signature header")]
MissingSignature(Url),
/// More body bytes were received than the `Content-Length` header stated
/// (received, expected).
#[error("HTTP response body from {0} is longer than expected: {1} > {2}")]
TooMuchData(Url, usize, usize),
/// The local file for the download could not be created.
#[error(r#"Unable to create download file "{0:?}": {1}"#)]
UnableToCreateDownload(PathBuf, String),
/// The temporary directory for the download could not be created.
#[error("Unable to create temporary directory: {0}")]
UnableToCreateTempDir(String),
/// Reading the metadata (for permissions) of the installed executable failed.
#[error(r#"Unable to get file metadata for the new executable "{0:?}": {1}"#)]
UnableToGetFileMetadata(PathBuf, String),
/// Moving (or copying) the downloaded file into place failed.
#[error("Unable to move the new executable {0:?}: {1}")]
UnableToMoveNewExe(PathBuf, String),
/// The path of the running executable could not be determined.
#[error("Unable to obtain current executable path: {0}")]
UnableToObtainCurrentExePath(String),
/// Renaming the running executable out of the way (to `.old`) failed.
#[error("Unable to rename the current executable {0:?}: {1}")]
UnableToRenameCurrentExe(PathBuf, String),
/// Setting the execute bits on the installed executable failed.
#[error(r#"Unable to set file permissions for the new executable "{0:?}": {1}"#)]
UnableToSetFilePermissions(PathBuf, String),
/// Writing a downloaded chunk to the local file failed.
#[error(r#"Unable to write to download file "{0:?}": {1}"#)]
UnableToWriteToDownload(PathBuf, String),
/// The response's `Content-Type` header did not match (received, expected).
#[error(r#"HTTP response from {0} had unexpected content type: "{1}", expected: "{2}""#)]
UnexpectedContentType(Url, String, String),
}
/// Configuration for an [`Updater`].
#[expect(clippy::exhaustive_structs, reason = "Provided for configuration")]
#[derive(Clone, Debug)]
pub struct Config {
/// Version of the running application; only strictly newer versions are
/// installed.
pub version: Version,
/// Base URL of the update API. The endpoints `latest`, `releases/{version}`
/// and `hashes/{version}` are resolved against it with `Url::join`.
/// NOTE(review): per `Url::join` semantics, a base path without a trailing
/// slash has its last segment replaced — confirm configured URLs end in `/`.
pub api: Url,
/// Public key used to verify the `x-signature` header of JSON responses.
pub key: VerifyingKey,
/// Whether to spawn an update check immediately when the updater is created.
pub check_on_startup: bool,
/// If set, periodically check for updates at this interval. The interval's
/// immediate first tick is skipped, so the first periodic check runs one
/// full interval after creation.
pub check_interval: Option<Duration>,
}
/// Checks for, downloads, verifies and installs application updates, then
/// restarts the application.
///
/// Created with `Updater::new()`, which returns it wrapped in an `Arc`.
#[derive(Debug)]
pub struct Updater {
/// Number of currently registered critical actions; while non-zero, a
/// restart after installing an update is deferred.
actions: AtomicUsize,
/// Sender used to publish every status change to subscribers.
broadcast: Broadcaster<Status>,
/// Configuration supplied at construction.
config: Config,
/// Path of the running executable, captured at construction; this file is
/// replaced on install and re-executed on restart.
exe_path: PathBuf,
/// HTTP client used for all API requests.
http_client: Client,
/// Stop signal for the periodic-check task; a message is sent on `Drop`.
queue: Sender<()>,
/// Current status, guarded for concurrent reads and writes.
status: RwLock<Status>,
}
impl Updater {
#[expect(clippy::result_large_err, reason = "Doesn't matter here")]
pub fn new(config: Config) -> Result<Arc<Self>, UpdaterError> {
let http_client = Client::new();
let (sender, receiver) = flume::unbounded();
let (tx, mut rx) = broadcast::channel(1);
let updater = Arc::new(Self {
actions: AtomicUsize::new(0),
broadcast: tx,
config,
exe_path: current_exe().map_err(|err| UpdaterError::UnableToObtainCurrentExePath(err.to_string()))?,
http_client,
queue: sender,
status: RwLock::new(Status::Idle),
});
#[expect(clippy::pattern_type_mismatch, reason = "Cannot dereference here")]
drop(spawn(async move { loop { select! {
Ok(status) = rx.recv() => {
debug!("Status changed: {status}");
}
else => break,
}}}));
if updater.config.check_on_startup {
let startup_updater = Arc::clone(&updater);
drop(spawn(async move {
startup_updater.check_for_updates().await;
}));
}
if let Some(check_interval) = updater.config.check_interval {
let mut timer = interval(check_interval);
let mut first_tick = true;
let timer_updater = Arc::clone(&updater);
drop(spawn(async move { loop { select!{
_ = timer.tick() => {
if first_tick {
first_tick = false;
continue;
}
timer_updater.check_for_updates().await;
}
_ = receiver.recv_async() => {
info!("Stopping updater");
break;
}
}}}));
}
Ok(updater)
}
pub fn register_action(&self) -> Option<usize> {
match self.status() {
Status::Idle |
Status::Checking |
Status::Downloading(_, _) |
Status::Installing(_) => {},
Status::PendingRestart(_) |
Status::Restarting(_) => return None,
}
let value = self.actions
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { value.checked_add(1) })
.ok()?
;
Some(value.saturating_add(1))
}
pub fn deregister_action(&self) -> Option<usize> {
let mut value = self.actions
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { value.checked_sub(1) })
.ok()?
;
value = value.saturating_sub(1);
if let Status::PendingRestart(version) = self.status() {
if value > 0 {
info!("Pending restart: {} critical actions in progress", self.actions.load(Ordering::SeqCst));
} else {
self.set_status(Status::Restarting(version));
info!("Restarting");
self.restart();
}
}
Some(value)
}
pub fn is_safe_to_update(&self) -> bool {
self.actions.load(Ordering::SeqCst) == 0
}
pub fn status(&self) -> Status {
let lock = self.status.read(); (*lock).clone()
}
pub fn set_status(&self, status: Status) {
let mut lock = self.status.write(); *lock = status.clone();
drop(lock); if let Err(err) = self.broadcast.send(status) {
error!("Failed to broadcast status change: {err}");
}
}
pub fn subscribe(&self) -> Listener<Status> {
self.broadcast.subscribe()
}
async fn check_for_updates(&self) {
if self.status() != Status::Idle {
return;
}
self.set_status(Status::Checking);
info!("Checking for updates");
let (url, response) = match self.request("latest").await {
Ok(data) => data,
Err(err) => {
self.set_status(Status::Idle);
error!("Error checking for updates: {err}");
return;
},
};
let version = match self.decode_and_verify::<LatestVersionResponse>(url, response).await {
Ok(json) => json.version,
Err(err) => {
self.set_status(Status::Idle);
error!("Error checking for updates: {err}");
return;
},
};
if version <= self.config.version {
self.set_status(Status::Idle);
info!("The current version {} is the latest available", self.config.version);
return;
}
info!("New version {} available", version);
self.set_status(Status::Downloading(version.clone(), 0));
info!("Downloading update {version}");
let (_download_dir, update_path, file_hash) = match self.download_update(&version).await {
Ok(data) => data,
Err(err) => {
error!("Error downloading update file: {err}");
return;
},
};
info!("Update file downloaded");
info!("Verifying update {version}");
if let Err(err) = self.verify_update(&version, file_hash).await {
error!("Error verifying update file: {err}");
return;
}
info!("Update file verified");
self.set_status(Status::Installing(version.clone()));
info!("Installing update");
if let Err(err) = self.replace_executable(&update_path).await {
error!("Error installing update: {err}");
return;
}
if !self.is_safe_to_update() {
self.set_status(Status::PendingRestart(version.clone()));
info!("Pending restart: {} critical actions in progress", self.actions.load(Ordering::SeqCst));
return;
}
self.set_status(Status::Restarting(version.clone()));
info!("Restarting");
self.restart();
}
#[expect(tail_expr_drop_order, reason = "Drop order change is harmless here")]
async fn download_update(&self, version: &Version) -> Result<(TempDir, PathBuf, Sha256Hash), UpdaterError> {
let download_dir = tempdir().map_err(|err| UpdaterError::UnableToCreateTempDir(err.to_string()))?;
let update_path = download_dir.path().join(format!("update-{version}"));
let mut file = AsyncFile::create(&update_path).await.map_err(|err|
UpdaterError::UnableToCreateDownload(update_path.clone(), err.to_string())
)?;
let (url, response) = self.request(&format!("releases/{version}")).await?;
let content_type: String = get_header(&response, CONTENT_TYPE);
let content_length: usize = get_header(&response, CONTENT_LENGTH);
if content_type != "application/octet-stream" {
return Err(UpdaterError::UnexpectedContentType(url, content_type, s!("application/octet-stream")));
}
let mut response_stream = response.bytes_stream();
let mut hasher = Sha256::new();
let mut body_len = 0_usize;
while let Some(Ok(chunk)) = response_stream.next().await {
file.write_all(&chunk).await.map_err(|err|
UpdaterError::UnableToWriteToDownload(update_path.clone(), err.to_string())
)?;
hasher.update(&chunk);
body_len = body_len.saturating_add(chunk.len());
#[expect(clippy::cast_possible_truncation, reason = "Loss of precision is not important here")]
#[expect(clippy::cast_precision_loss, reason = "Loss of precision is not important here")]
#[expect(clippy::cast_sign_loss, reason = "Loss of sign is not important here")]
self.set_status(Status::Downloading(version.clone(), (body_len as f64 / content_length as f64 * 100.0) as u8));
}
if body_len < content_length {
return Err(UpdaterError::MissingData(url, body_len, content_length));
}
if body_len > content_length {
return Err(UpdaterError::TooMuchData(url, body_len, content_length));
}
Ok((download_dir, update_path, hasher.finalize().into()))
}
async fn verify_update(&self, version: &Version, hash: Sha256Hash) -> Result<(), UpdaterError> {
let (url, response) = self.request(&format!("hashes/{version}")).await?;
match self.decode_and_verify::<VersionHashResponse>(url.clone(), response).await {
Ok(json) => {
if json.version != *version {
return Err(UpdaterError::InvalidPayload(url));
}
if json.hash != hash {
return Err(UpdaterError::FailedHashVerification(version.clone()));
}
Ok(())
},
Err(err) => Err(err),
}
}
async fn request(&self, endpoint: &str) -> Result<(Url, Response), UpdaterError> {
let Ok(url) = self.config.api.join(endpoint) else {
return Err(UpdaterError::InvalidUrl(self.config.api.clone(), endpoint.to_owned()));
};
let response = self.http_client.get(url.clone()).send().await.map_err(|err|
UpdaterError::HttpRequestFailed(url.clone(), err.to_string())
)?;
let status = response.status();
if !status.is_success() {
return Err(UpdaterError::HttpError(url, status));
}
Ok((url, response))
}
async fn decode_and_verify<T: DeserializeOwned>(&self, url: Url, response: Response) -> Result<T, UpdaterError> {
let content_type: String = get_header(&response, CONTENT_TYPE);
let content_length: usize = get_header(&response, CONTENT_LENGTH);
let signature: String = get_header(&response, "x-signature");
let Ok(body) = response.text().await else {
return Err(UpdaterError::InvalidBody(url))
};
if content_type != "application/json" {
return Err(UpdaterError::UnexpectedContentType(url, content_type, s!("application/json")));
}
if body.len() < content_length {
return Err(UpdaterError::MissingData(url, body.len(), content_length));
}
if body.len() > content_length {
return Err(UpdaterError::TooMuchData(url, body.len(), content_length));
}
if signature.is_empty() {
return Err(UpdaterError::MissingSignature(url));
}
let Ok(signature_bytes) = hex::decode(&signature) else {
return Err(UpdaterError::InvalidSignature(url, signature))
};
let signature_array: &[u8; 64] = signature_bytes.as_slice().try_into().map_err(|_err|
UpdaterError::InvalidSignature(url.clone(), signature)
)?;
if self.config.key.verify_strict(body.as_bytes(), &Signature::from_bytes(signature_array)).is_err() {
return Err(UpdaterError::FailedSignatureVerification(url));
}
let Ok(parsed) = serde_json::from_str::<T>(&body) else {
return Err(UpdaterError::InvalidPayload(url));
};
Ok(parsed)
}
async fn replace_executable(&self, update_path: &PathBuf) -> Result<(), UpdaterError> {
let current_path = self.exe_path.clone();
let backup_path = current_path.with_extension("old");
let move_error = |err: IoError| -> UpdaterError {
UpdaterError::UnableToMoveNewExe(update_path.clone(), err.to_string())
};
fs::rename(¤t_path, &backup_path).await.map_err(|err|
UpdaterError::UnableToRenameCurrentExe(current_path.clone(), err.to_string())
)?;
if let Err(err) = fs::rename(&update_path, ¤t_path).await {
if err.raw_os_error() != Some(18_i32) {
return Err(move_error(err));
}
let _size = fs::copy(&update_path, ¤t_path).await.map_err(move_error)?;
if let Err(err2) = fs::remove_file(&update_path).await {
warn!("Failed to delete temporary update file {update_path:?}: {err2}");
}
}
let mut permissions = fs::metadata(¤t_path).await.map_err(|err|
UpdaterError::UnableToGetFileMetadata(current_path.clone(), err.to_string())
)?.permissions();
permissions.set_mode(permissions.mode() | 0o111);
fs::set_permissions(¤t_path, permissions).await.map_err(|err|
UpdaterError::UnableToSetFilePermissions(current_path.clone(), err.to_string())
)?;
Ok(())
}
fn restart(&self) {
let args = args().skip(1).collect::<Vec<_>>();
let err = Command::new(self.exe_path.clone())
.args(args)
.stdin(Stdio::inherit())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.exec()
;
error!("Failed to restart application: {err}");
exit(1);
}
}
impl Drop for Updater {
    /// Sends the stop signal to the periodic-check task, if one is running.
    fn drop(&mut self) {
        // A send error only means the receiving side has already gone away
        // (e.g. no periodic check was configured), so there is nothing to stop.
        let _: Result<(), _> = self.queue.send(());
    }
}
/// Reads a response header and parses it as `T`.
///
/// Returns `T::default()` (e.g. an empty `String` or `0`) if the header is
/// absent, is not visible ASCII, or does not parse as `T`.
fn get_header<K, T>(response: &Response, header: K) -> T
where
    K: AsHeaderName,
    T: Default + FromStr,
{
    let Some(raw_value) = response.headers().get(header) else {
        return T::default();
    };
    raw_value
        .to_str()
        .ok()
        .and_then(|text| text.parse::<T>().ok())
        .unwrap_or_default()
}