mod apply;
mod binary_cache;
mod cache_dir;
mod monitor;
mod release_cache;
mod rollout;
mod signature;
pub use apply::{AutoApplyUpgrader, RESTART_EXIT_CODE};
pub use binary_cache::BinaryCache;
pub use cache_dir::upgrade_cache_dir;
pub use monitor::{find_platform_asset, version_from_tag, Asset, GitHubRelease, UpgradeMonitor};
pub use release_cache::ReleaseCache;
pub use rollout::StagedRollout;
pub use signature::{
verify_binary_signature, verify_binary_signature_with_key, verify_from_file,
verify_from_file_with_key, PUBLIC_KEY_SIZE, SIGNATURE_SIZE, SIGNING_CONTEXT,
};
use crate::error::{Error, Result};
use crate::logging::{debug, info, warn};
use semver::Version;
use std::fs;
use std::path::Path;
/// Largest response body, in bytes, that a single upgrade download may produce (200 MiB).
const MAX_BINARY_SIZE_BYTES: usize = 200 << 20;
/// Description of a release that an `Upgrader` can install.
#[derive(Debug, Clone)]
pub struct UpgradeInfo {
/// Version of the release; `Upgrader::validate_upgrade` only accepts versions strictly
/// newer than the current one.
pub version: Version,
/// URL the new binary is downloaded from.
pub download_url: String,
/// URL of the signature file that is checked against the downloaded binary before install.
pub signature_url: String,
/// Release notes text; not read by the upgrade logic in this module.
pub release_notes: String,
}
/// Outcome of an upgrade attempt.
#[derive(Debug)]
pub enum UpgradeResult {
/// The new binary was moved into place.
Success {
/// Version that was installed.
version: Version,
/// Exit code reported with the success; `Upgrader::perform_upgrade` in this module always
/// sets it to `0`.
exit_code: i32,
},
/// The upgrade did not complete (unsupported platform, download failure, signature
/// verification failure, or replacement failure followed by a successful restore).
RolledBack {
/// Human-readable description of the step that failed.
reason: String,
},
/// No upgrade was performed. Not produced by `Upgrader::perform_upgrade` in this module;
/// presumably returned by callers when no newer release exists — confirm against callers.
NoUpgrade,
}
/// Downloads, signature-checks and installs a new binary in place of an existing one,
/// keeping a backup copy for rollback.
pub struct Upgrader {
/// Version treated as currently installed; candidate releases must be newer than this.
current_version: Version,
/// HTTP client used for both binary and signature downloads.
client: reqwest::Client,
}
impl Upgrader {
#[must_use]
pub fn new() -> Self {
let current_version =
Version::parse(env!("CARGO_PKG_VERSION")).unwrap_or_else(|_| Version::new(0, 0, 0));
Self {
current_version,
client: reqwest::Client::new(),
}
}
#[cfg(test)]
#[must_use]
pub fn with_version(version: Version) -> Self {
Self {
current_version: version,
client: reqwest::Client::new(),
}
}
#[must_use]
pub fn current_version(&self) -> &Version {
&self.current_version
}
pub fn validate_upgrade(&self, info: &UpgradeInfo) -> Result<()> {
if info.version <= self.current_version {
return Err(Error::Upgrade(format!(
"Cannot downgrade from {} to {}",
self.current_version, info.version
)));
}
Ok(())
}
pub fn create_backup(&self, current: &Path, rollback_dir: &Path) -> Result<()> {
let filename = current
.file_name()
.ok_or_else(|| Error::Upgrade("Invalid binary path".to_string()))?;
let backup_path = rollback_dir.join(format!("{}.backup", filename.to_string_lossy()));
debug!("Creating backup at: {}", backup_path.display());
fs::copy(current, &backup_path)?;
Ok(())
}
pub fn restore_from_backup(&self, current: &Path, rollback_dir: &Path) -> Result<()> {
let filename = current
.file_name()
.ok_or_else(|| Error::Upgrade("Invalid binary path".to_string()))?;
let backup_path = rollback_dir.join(format!("{}.backup", filename.to_string_lossy()));
if !backup_path.exists() {
return Err(Error::Upgrade("No backup found for rollback".to_string()));
}
info!("Restoring from backup: {}", backup_path.display());
fs::copy(&backup_path, current)?;
Ok(())
}
pub fn atomic_replace(&self, new_binary: &Path, target: &Path) -> Result<()> {
#[cfg(unix)]
{
if let Ok(meta) = fs::metadata(target) {
let perms = meta.permissions();
fs::set_permissions(new_binary, perms)?;
}
}
fs::rename(new_binary, target)?;
debug!("Atomic replacement complete");
Ok(())
}
async fn download(&self, url: &str, dest: &Path) -> Result<()> {
debug!("Downloading: {}", url);
let response = self
.client
.get(url)
.send()
.await
.map_err(|e| Error::Network(format!("Download failed: {e}")))?;
if !response.status().is_success() {
return Err(Error::Network(format!(
"Download returned status: {}",
response.status()
)));
}
let bytes = response
.bytes()
.await
.map_err(|e| Error::Network(format!("Failed to read response: {e}")))?;
Self::enforce_max_binary_size(bytes.len())?;
fs::write(dest, &bytes)?;
debug!("Downloaded {} bytes to {}", bytes.len(), dest.display());
Ok(())
}
fn enforce_max_binary_size(len: usize) -> Result<()> {
if len > MAX_BINARY_SIZE_BYTES {
return Err(Error::Upgrade(format!(
"Downloaded binary too large: {len} bytes (max {MAX_BINARY_SIZE_BYTES})"
)));
}
Ok(())
}
fn create_tempdir_in_target_dir(current_binary: &Path) -> Result<tempfile::TempDir> {
let target_dir = current_binary
.parent()
.ok_or_else(|| Error::Upgrade("Current binary has no parent directory".to_string()))?;
tempfile::Builder::new()
.prefix("ant-upgrade-")
.tempdir_in(target_dir)
.map_err(|e| Error::Upgrade(format!("Failed to create temp dir: {e}")))
}
pub async fn perform_upgrade(
&self,
info: &UpgradeInfo,
current_binary: &Path,
rollback_dir: &Path,
) -> Result<UpgradeResult> {
if !Self::auto_upgrade_supported() {
warn!(
"Auto-upgrade is not supported on this platform; refusing upgrade to {}",
info.version
);
return Ok(UpgradeResult::RolledBack {
reason: "Auto-upgrade not supported on this platform".to_string(),
});
}
self.validate_upgrade(info)?;
self.create_backup(current_binary, rollback_dir)?;
let temp_dir = Self::create_tempdir_in_target_dir(current_binary)?;
let new_binary = temp_dir.path().join("new_binary");
let sig_path = temp_dir.path().join("signature");
if let Err(e) = self.download(&info.download_url, &new_binary).await {
warn!("Download failed: {e}");
return Ok(UpgradeResult::RolledBack {
reason: format!("Download failed: {e}"),
});
}
if let Err(e) = self.download(&info.signature_url, &sig_path).await {
warn!("Signature download failed: {e}");
return Ok(UpgradeResult::RolledBack {
reason: format!("Signature download failed: {e}"),
});
}
if let Err(e) = signature::verify_from_file(&new_binary, &sig_path) {
warn!("Signature verification failed: {e}");
return Ok(UpgradeResult::RolledBack {
reason: format!("Signature verification failed: {e}"),
});
}
if let Err(e) = self.atomic_replace(&new_binary, current_binary) {
warn!("Replacement failed, rolling back: {e}");
if let Err(restore_err) = self.restore_from_backup(current_binary, rollback_dir) {
return Err(Error::Upgrade(format!(
"Critical: replacement failed ({e}) AND rollback failed ({restore_err})"
)));
}
return Ok(UpgradeResult::RolledBack {
reason: format!("Replacement failed: {e}"),
});
}
info!("Successfully upgraded to version {}", info.version);
Ok(UpgradeResult::Success {
version: info.version.clone(),
exit_code: 0,
})
}
const fn auto_upgrade_supported() -> bool {
true
}
}
impl Default for Upgrader {
fn default() -> Self {
Self::new()
}
}
/// Convenience entry point: builds an [`Upgrader`] with [`Upgrader::new`] and runs
/// [`Upgrader::perform_upgrade`] with the given arguments.
///
/// # Errors
///
/// Propagates every error returned by [`Upgrader::perform_upgrade`].
pub async fn perform_upgrade(
    info: &UpgradeInfo,
    current_binary: &Path,
    rollback_dir: &Path,
) -> Result<UpgradeResult> {
    let upgrader = Upgrader::new();
    upgrader
        .perform_upgrade(info, current_binary, rollback_dir)
        .await
}
#[cfg(test)]
// Tests unwrap/expect freely. The lossy `as u8` cast is intentional when generating
// fixture bytes.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::doc_markdown,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::case_sensitive_file_extension_comparisons
)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_backup_created() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("current");
        let rollback_dir = temp.path().join("rollback");
        fs::create_dir(&rollback_dir).unwrap();
        let original_content = b"old binary content";
        fs::write(&current, original_content).unwrap();
        let upgrader = Upgrader::new();
        upgrader.create_backup(&current, &rollback_dir).unwrap();
        let backup_path = rollback_dir.join("current.backup");
        assert!(backup_path.exists(), "Backup file should exist");
        assert_eq!(
            fs::read(&backup_path).unwrap(),
            original_content,
            "Backup content should match"
        );
    }

    #[test]
    fn test_restore_from_backup() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("binary");
        let rollback_dir = temp.path().join("rollback");
        fs::create_dir(&rollback_dir).unwrap();
        let original = b"original content";
        fs::write(&current, original).unwrap();
        let upgrader = Upgrader::new();
        upgrader.create_backup(&current, &rollback_dir).unwrap();
        fs::write(&current, b"corrupted content").unwrap();
        upgrader
            .restore_from_backup(&current, &rollback_dir)
            .unwrap();
        assert_eq!(fs::read(&current).unwrap(), original);
    }

    #[test]
    fn test_restore_leaves_no_stray_files() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("binary");
        let rollback_dir = temp.path().join("rollback");
        fs::create_dir(&rollback_dir).unwrap();
        fs::write(&current, b"original").unwrap();
        let upgrader = Upgrader::new();
        upgrader.create_backup(&current, &rollback_dir).unwrap();
        fs::write(&current, b"broken").unwrap();
        upgrader
            .restore_from_backup(&current, &rollback_dir)
            .unwrap();
        let mut entries: Vec<String> = fs::read_dir(temp.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name().to_string_lossy().into_owned())
            .collect();
        entries.sort();
        assert_eq!(entries, ["binary", "rollback"]);
        assert_eq!(fs::read(&current).unwrap(), b"original");
    }

    #[test]
    fn test_atomic_replacement() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("binary");
        let new_binary = temp.path().join("new_binary");
        fs::write(&current, b"old").unwrap();
        fs::write(&new_binary, b"new").unwrap();
        let upgrader = Upgrader::new();
        upgrader.atomic_replace(&new_binary, &current).unwrap();
        assert_eq!(fs::read(&current).unwrap(), b"new");
        assert!(!new_binary.exists(), "Source should be moved, not copied");
    }

    #[test]
    fn test_downgrade_prevention() {
        let current_version = Version::new(1, 1, 0);
        let older_version = Version::new(1, 0, 0);
        let upgrader = Upgrader::with_version(current_version);
        let info = UpgradeInfo {
            version: older_version,
            download_url: "test".to_string(),
            signature_url: "test.sig".to_string(),
            release_notes: "Old".to_string(),
        };
        let result = upgrader.validate_upgrade(&info);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("downgrade") || err_msg.contains("Cannot"),
            "Error should mention downgrade prevention: {err_msg}"
        );
    }

    #[test]
    fn test_same_version_prevention() {
        let version = Version::new(1, 0, 0);
        let upgrader = Upgrader::with_version(version.clone());
        let info = UpgradeInfo {
            version,
            download_url: "test".to_string(),
            signature_url: "test.sig".to_string(),
            release_notes: "Same".to_string(),
        };
        let result = upgrader.validate_upgrade(&info);
        assert!(result.is_err(), "Same version should be rejected");
    }

    #[test]
    fn test_upgrade_validation_passes() {
        let upgrader = Upgrader::with_version(Version::new(1, 0, 0));
        let info = UpgradeInfo {
            version: Version::new(1, 1, 0),
            download_url: "test".to_string(),
            signature_url: "test.sig".to_string(),
            release_notes: "New".to_string(),
        };
        let result = upgrader.validate_upgrade(&info);
        assert!(result.is_ok(), "Newer version should be accepted");
    }

    #[test]
    fn test_restore_fails_without_backup() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("binary");
        let rollback_dir = temp.path().join("rollback");
        fs::create_dir(&rollback_dir).unwrap();
        fs::write(&current, b"content").unwrap();
        let upgrader = Upgrader::new();
        let result = upgrader.restore_from_backup(&current, &rollback_dir);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No backup"));
    }

    #[test]
    fn test_create_backup_rejects_path_without_file_name() {
        let temp = TempDir::new().unwrap();
        let upgrader = Upgrader::new();
        let result = upgrader.create_backup(Path::new("/"), temp.path());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Invalid binary path"));
    }

    #[cfg(unix)]
    #[test]
    fn test_permissions_preserved() {
        use std::os::unix::fs::PermissionsExt;
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("binary");
        let new_binary = temp.path().join("new");
        fs::write(&current, b"old").unwrap();
        fs::write(&new_binary, b"new").unwrap();
        let mut perms = fs::metadata(&current).unwrap().permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&current, perms).unwrap();
        let upgrader = Upgrader::new();
        upgrader.atomic_replace(&new_binary, &current).unwrap();
        let new_perms = fs::metadata(&current).unwrap().permissions();
        assert_eq!(
            new_perms.mode() & 0o777,
            0o755,
            "Permissions should be preserved"
        );
    }

    #[test]
    fn test_current_version_getter() {
        let version = Version::new(2, 3, 4);
        let upgrader = Upgrader::with_version(version.clone());
        assert_eq!(*upgrader.current_version(), version);
    }

    #[test]
    fn test_default_impl() {
        let upgrader = Upgrader::default();
        assert!(!upgrader.current_version().to_string().is_empty());
    }

    #[test]
    fn test_backup_special_filename() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("ant-node-v1.0.0");
        let rollback_dir = temp.path().join("rollback");
        fs::create_dir(&rollback_dir).unwrap();
        fs::write(&current, b"content").unwrap();
        let upgrader = Upgrader::new();
        let result = upgrader.create_backup(&current, &rollback_dir);
        assert!(result.is_ok());
        let backup_path = rollback_dir.join("ant-node-v1.0.0.backup");
        assert!(backup_path.exists());
    }

    #[test]
    fn test_upgrade_info() {
        let info = UpgradeInfo {
            version: Version::new(1, 2, 3),
            download_url: "https://example.com/binary".to_string(),
            signature_url: "https://example.com/binary.sig".to_string(),
            release_notes: "Bug fixes and improvements".to_string(),
        };
        assert_eq!(info.version, Version::new(1, 2, 3));
        assert!(info.download_url.contains("example.com"));
        assert!(info.signature_url.ends_with(".sig"));
    }

    #[test]
    fn test_upgrade_result_variants() {
        let success = UpgradeResult::Success {
            version: Version::new(1, 0, 0),
            exit_code: 0,
        };
        assert!(matches!(success, UpgradeResult::Success { .. }));
        let rolled_back = UpgradeResult::RolledBack {
            reason: "Test failure".to_string(),
        };
        assert!(matches!(rolled_back, UpgradeResult::RolledBack { .. }));
        let no_upgrade = UpgradeResult::NoUpgrade;
        assert!(matches!(no_upgrade, UpgradeResult::NoUpgrade));
    }

    #[test]
    fn test_large_file_backup() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("large_binary");
        let rollback_dir = temp.path().join("rollback");
        fs::create_dir(&rollback_dir).unwrap();
        let large_content: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        fs::write(&current, &large_content).unwrap();
        let upgrader = Upgrader::new();
        upgrader.create_backup(&current, &rollback_dir).unwrap();
        let backup_path = rollback_dir.join("large_binary.backup");
        assert_eq!(fs::read(&backup_path).unwrap(), large_content);
    }

    #[test]
    fn test_backup_nonexistent_rollback_dir() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("binary");
        let rollback_dir = temp.path().join("nonexistent");
        fs::write(&current, b"content").unwrap();
        let upgrader = Upgrader::new();
        let result = upgrader.create_backup(&current, &rollback_dir);
        assert!(result.is_err(), "Should fail if rollback dir doesn't exist");
    }

    #[test]
    fn test_tempdir_in_target_dir() {
        let temp = TempDir::new().unwrap();
        let current = temp.path().join("binary");
        fs::write(&current, b"content").unwrap();
        let tempdir = Upgrader::create_tempdir_in_target_dir(&current).unwrap();
        assert_eq!(
            tempdir.path().parent().unwrap(),
            temp.path(),
            "Upgrade tempdir should be in same dir as target"
        );
    }

    #[test]
    fn test_enforce_max_binary_size_rejects_large() {
        let too_large = MAX_BINARY_SIZE_BYTES + 1;
        let result = Upgrader::enforce_max_binary_size(too_large);
        assert!(result.is_err());
    }

    #[test]
    fn test_enforce_max_binary_size_accepts_small() {
        let result = Upgrader::enforce_max_binary_size(1024);
        assert!(result.is_ok());
    }

    #[test]
    fn test_auto_upgrade_supported_on_all_platforms() {
        assert!(Upgrader::auto_upgrade_supported());
    }
}