use crate::error::{Error, Result};
use crate::logging::{debug, error, info, warn};
use crate::upgrade::binary_cache::BinaryCache;
use crate::upgrade::{signature, UpgradeInfo, UpgradeResult};
use flate2::read::GzDecoder;
use semver::Version;
use std::env;
use std::fs::{self, File};
use std::io::Read;
use std::path::{Path, PathBuf};
use tar::Archive;
/// Largest file, in bytes, that `download` will accept (200 MiB).
const MAX_ARCHIVE_SIZE_BYTES: usize = 200 << 20;
/// Exit code returned by `prepare_restart` in service-manager mode on Windows.
pub const RESTART_EXIT_CODE: i32 = 100;
/// Downloads, signature-verifies and installs newer `ant-node` releases in place of
/// the running binary, then prepares the restart.
pub struct AutoApplyUpgrader {
    /// Version of the running binary (parsed from `CARGO_PKG_VERSION` in `new`).
    current_version: Version,
    /// HTTP client used to fetch release archives and their signatures.
    client: reqwest::Client,
    /// Optional cache of verified binaries, consulted before downloading.
    binary_cache: Option<BinaryCache>,
    /// When `true`, `prepare_restart` returns an exit code for a service manager
    /// instead of spawning the new binary itself.
    stop_on_upgrade: bool,
}
impl AutoApplyUpgrader {
/// Creates an upgrader for the running binary.
///
/// The current version is parsed from `CARGO_PKG_VERSION`, falling back to `0.0.0`
/// if parsing fails. The HTTP client sends a `ant-node/<version>` user agent and uses
/// a 300-second timeout; if that client cannot be built, a default
/// `reqwest::Client` is used instead. No binary cache is attached and
/// stop-on-upgrade (service-manager) mode is off.
#[must_use]
pub fn new() -> Self {
    let current_version = Version::parse(env!("CARGO_PKG_VERSION"))
        .unwrap_or_else(|_parse_err| Version::new(0, 0, 0));

    let request_timeout = std::time::Duration::from_secs(300);
    let client = reqwest::Client::builder()
        .user_agent(concat!("ant-node/", env!("CARGO_PKG_VERSION")))
        .timeout(request_timeout)
        .build()
        .unwrap_or_else(|_build_err| reqwest::Client::new());

    Self {
        stop_on_upgrade: false,
        binary_cache: None,
        client,
        current_version,
    }
}
#[must_use]
pub fn with_binary_cache(mut self, cache: BinaryCache) -> Self {
self.binary_cache = Some(cache);
self
}
#[must_use]
pub fn with_stop_on_upgrade(mut self, stop: bool) -> Self {
self.stop_on_upgrade = stop;
self
}
#[must_use]
pub fn current_version(&self) -> &Version {
&self.current_version
}
/// Returns the path of the executable that an upgrade should replace.
///
/// `argv[0]` is preferred when it names an existing regular file through an explicit
/// path (absolute, or relative such as `./ant-node`); it is canonicalized when
/// possible. A bare command name (e.g. `ant-node`) was resolved through `PATH`, so it
/// is *not* looked up relative to the working directory — doing so could pick an
/// unrelated file or directory that happens to share the name. In every other case
/// [`env::current_exe`] is used.
///
/// A trailing `" (deleted)"` marker (reported for an executable whose file has been
/// replaced) is stripped from either source.
///
/// # Errors
///
/// Returns [`Error::Upgrade`] if `current_exe` is needed and fails.
pub fn current_binary_path() -> Result<PathBuf> {
    // `args_os` rather than `args`: `env::args` panics on a non-UTF-8 argv[0].
    if let Some(invoked) = env::args_os().next().map(PathBuf::from) {
        let invoked_str = invoked.to_string_lossy();
        let cleaned = match invoked_str.strip_suffix(" (deleted)") {
            Some(stripped) => {
                debug!("Stripped '(deleted)' suffix from invoked path: {stripped}");
                PathBuf::from(stripped)
            }
            None => invoked.clone(),
        };
        // `Path::new("ant-node").parent()` is `Some("")`: no directory component.
        let has_dir_component =
            matches!(cleaned.parent(), Some(dir) if !dir.as_os_str().is_empty());
        if has_dir_component && cleaned.is_file() {
            return Ok(cleaned.canonicalize().unwrap_or(cleaned));
        }
    }

    let exe = env::current_exe()
        .map_err(|e| Error::Upgrade(format!("Cannot determine binary path: {e}")))?;
    #[cfg(unix)]
    {
        let exe_str = exe.to_string_lossy();
        if let Some(cleaned) = exe_str.strip_suffix(" (deleted)") {
            debug!("Stripped '(deleted)' suffix from binary path: {cleaned}");
            return Ok(PathBuf::from(cleaned));
        }
    }
    Ok(exe)
}
pub async fn apply_upgrade(&self, info: &UpgradeInfo) -> Result<UpgradeResult> {
info!(
"Starting auto-apply upgrade from {} to {}",
self.current_version, info.version
);
if info.version <= self.current_version {
warn!(
"Ignoring downgrade attempt: {} -> {}",
self.current_version, info.version
);
return Ok(UpgradeResult::NoUpgrade);
}
let current_binary = Self::current_binary_path()?;
let binary_dir = current_binary
.parent()
.ok_or_else(|| Error::Upgrade("Cannot determine binary directory".to_string()))?;
let temp_dir = tempfile::Builder::new()
.prefix("ant-upgrade-")
.tempdir_in(binary_dir)
.map_err(|e| Error::Upgrade(format!("Failed to create temp dir: {e}")))?;
let version_str = info.version.to_string();
let extracted_binary = match self
.resolve_upgrade_binary(info, temp_dir.path(), &version_str)
.await
{
Ok(path) => path,
Err(e) => {
warn!("Download/verify/extract failed: {e}");
return Ok(UpgradeResult::RolledBack {
reason: format!("{e}"),
});
}
};
if let Some(disk_version) = on_disk_version(¤t_binary).await {
if disk_version == info.version {
info!(
"Binary already upgraded to {} by another service, skipping replacement",
info.version
);
let exit_code = self.prepare_restart(¤t_binary)?;
return Ok(UpgradeResult::Success {
version: info.version.clone(),
exit_code,
});
}
}
let backup_path = binary_dir.join(format!(
"{}.backup",
current_binary
.file_name()
.map_or_else(|| "ant-node".into(), |s| s.to_string_lossy())
));
info!("Creating backup at {}...", backup_path.display());
if let Err(e) = fs::copy(¤t_binary, &backup_path) {
warn!("Backup creation failed: {e}");
return Ok(UpgradeResult::RolledBack {
reason: format!("Backup failed: {e}"),
});
}
info!("Replacing binary...");
let new_bin = extracted_binary.clone();
let target_bin = current_binary.clone();
let replace_result =
tokio::task::spawn_blocking(move || Self::replace_binary(&new_bin, &target_bin))
.await
.map_err(|e| Error::Upgrade(format!("Binary replacement task panicked: {e}")))?;
if let Err(e) = replace_result {
warn!("Binary replacement failed: {e}");
if let Err(restore_err) = fs::copy(&backup_path, ¤t_binary) {
error!("CRITICAL: Replacement failed ({e}) AND rollback failed ({restore_err})");
return Err(Error::Upgrade(format!(
"Critical: replacement failed ({e}) AND rollback failed ({restore_err})"
)));
}
return Ok(UpgradeResult::RolledBack {
reason: format!("Replacement failed: {e}"),
});
}
info!(
"Successfully upgraded to version {}! Restarting...",
info.version
);
let exit_code = self.prepare_restart(¤t_binary)?;
Ok(UpgradeResult::Success {
version: info.version.clone(),
exit_code,
})
}
async fn download(&self, url: &str, dest: &Path) -> Result<()> {
debug!("Downloading: {}", url);
let response = self
.client
.get(url)
.send()
.await
.map_err(|e| Error::Network(format!("Download failed: {e}")))?;
if !response.status().is_success() {
return Err(Error::Network(format!(
"Download returned status: {}",
response.status()
)));
}
let bytes = response
.bytes()
.await
.map_err(|e| Error::Network(format!("Failed to read response: {e}")))?;
if bytes.len() > MAX_ARCHIVE_SIZE_BYTES {
return Err(Error::Upgrade(format!(
"Downloaded file too large: {} bytes (max {})",
bytes.len(),
MAX_ARCHIVE_SIZE_BYTES
)));
}
fs::write(dest, &bytes)?;
debug!("Downloaded {} bytes to {}", bytes.len(), dest.display());
Ok(())
}
/// Produces a verified `ant-node` binary for `info` inside `dest_dir`.
///
/// Without a binary cache, the release is always downloaded, verified and extracted.
///
/// With a cache:
/// 1. A verified cached copy for `version_str` is reused when it can be copied.
/// 2. Otherwise the cache's download lock is acquired, so every download-and-store
///    into the cache — including the fallback after a failed cache copy — is
///    serialized.
/// 3. The cache is re-checked under the lock, in case another process filled it
///    meanwhile. This step is skipped if copying from the cache already failed.
/// 4. Only then is the release downloaded and stored.
///
/// A cache-copy failure always falls back to downloading.
///
/// # Errors
///
/// Propagates lock-acquisition failures (including a panicked lock task) and any
/// error from `download_verify_extract`.
async fn resolve_upgrade_binary(
    &self,
    info: &UpgradeInfo,
    dest_dir: &Path,
    version_str: &str,
) -> Result<PathBuf> {
    /// Copies a cached binary into `dest_dir`, keeping its file name.
    fn copy_from_cache(cached_path: &Path, dest_dir: &Path) -> std::io::Result<PathBuf> {
        let dest = dest_dir.join(
            cached_path
                .file_name()
                .unwrap_or_else(|| std::ffi::OsStr::new("ant-node")),
        );
        fs::copy(cached_path, &dest)?;
        Ok(dest)
    }

    let Some(cache) = self.binary_cache.as_ref() else {
        return self.download_verify_extract(info, dest_dir, None).await;
    };

    // Fast path: no lock needed to read an already-verified cache entry.
    let mut cache_copy_failed = false;
    if let Some(cached_path) = cache.get_verified(version_str) {
        info!("Cached binary verified for version {}", version_str);
        match copy_from_cache(&cached_path, dest_dir) {
            Ok(dest) => return Ok(dest),
            Err(e) => {
                warn!("Failed to copy from cache, will re-download: {e}");
                cache_copy_failed = true;
            }
        }
    }

    // Acquiring the lock may block, so keep it off the async executor.
    let cache_clone = cache.clone();
    let lock_guard = tokio::task::spawn_blocking(move || cache_clone.acquire_download_lock())
        .await
        .map_err(|e| Error::Upgrade(format!("Lock task failed: {e}")))??;

    if !cache_copy_failed {
        if let Some(cached_path) = cache.get_verified(version_str) {
            info!(
                "Cached binary became available under lock for version {}",
                version_str
            );
            match copy_from_cache(&cached_path, dest_dir) {
                Ok(dest) => return Ok(dest),
                Err(e) => warn!("Failed to copy from cache, will re-download: {e}"),
            }
        }
    }

    let result = self
        .download_verify_extract(info, dest_dir, Some(cache))
        .await;
    drop(lock_guard);
    result
}
/// Fetches the release archive and its detached signature into `dest_dir`, checks the
/// ML-DSA signature, and extracts the `ant-node` binary.
///
/// Extraction only happens after the signature check passes. When `cache` is given,
/// the extracted binary is stored under `info.version`. A failed store is logged and
/// otherwise ignored.
///
/// # Errors
///
/// Propagates download, signature-verification and extraction errors.
async fn download_verify_extract(
    &self,
    info: &UpgradeInfo,
    dest_dir: &Path,
    cache: Option<&BinaryCache>,
) -> Result<PathBuf> {
    let (archive_path, sig_path) = (dest_dir.join("archive"), dest_dir.join("signature"));

    info!("Downloading ant-node binary...");
    self.download(&info.download_url, &archive_path).await?;
    info!("Downloading signature...");
    self.download(&info.signature_url, &sig_path).await?;

    info!("Verifying ML-DSA signature on archive...");
    signature::verify_from_file(&archive_path, &sig_path)?;
    info!("Archive signature verified successfully");

    info!("Extracting binary from archive...");
    let binary = Self::extract_binary(&archive_path, dest_dir)?;

    let store_outcome = cache.map(|c| c.store(&info.version.to_string(), &binary));
    if let Some(Err(e)) = store_outcome {
        warn!("Failed to store binary in cache: {e}");
    }
    Ok(binary)
}
fn extract_binary(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
let mut file = File::open(archive_path)?;
let mut magic = [0u8; 2];
file.read_exact(&mut magic)
.map_err(|e| Error::Upgrade(format!("Failed to read archive header: {e}")))?;
drop(file);
match magic {
[0x1f, 0x8b] => Self::extract_from_tar_gz(archive_path, dest_dir),
[0x50, 0x4b] => Self::extract_from_zip(archive_path, dest_dir),
_ => Err(Error::Upgrade(format!(
"Unknown archive format (magic bytes: {:02x} {:02x})",
magic[0], magic[1]
))),
}
}
/// Extracts the `ant-node` executable from a gzip-compressed tar archive.
///
/// Behavior:
/// - The first *regular file* entry whose file name is `ant-node` or `ant-node.exe`
///   (at any depth) is written to `dest_dir` under this platform's binary name.
/// - On Unix the extracted file is marked `0o755`.
/// - Non-file entries are ignored. In an archive laid out as `ant-node/ant-node`, the
///   directory entry `ant-node/` also has file name `ant-node`; matching it would
///   produce an empty "binary".
///
/// # Errors
///
/// Returns an error if the archive cannot be read, the output cannot be written, or
/// no matching file entry exists.
fn extract_from_tar_gz(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
    let file = File::open(archive_path)?;
    let mut archive = Archive::new(GzDecoder::new(file));
    let binary_name = if cfg!(windows) {
        "ant-node.exe"
    } else {
        "ant-node"
    };
    let extracted_binary = dest_dir.join(binary_name);

    for entry in archive
        .entries()
        .map_err(|e| Error::Upgrade(format!("Failed to read archive: {e}")))?
    {
        let mut entry =
            entry.map_err(|e| Error::Upgrade(format!("Failed to read entry: {e}")))?;
        if !entry.header().entry_type().is_file() {
            continue;
        }
        let path = entry
            .path()
            .map_err(|e| Error::Upgrade(format!("Invalid path in archive: {e}")))?;
        let is_binary = path.file_name().map_or(false, |name| {
            let name_str = name.to_string_lossy();
            name_str == "ant-node" || name_str == "ant-node.exe"
        });
        if !is_binary {
            continue;
        }
        debug!("Found binary in tar.gz archive: {}", path.display());

        let mut out = File::create(&extracted_binary)?;
        std::io::copy(&mut entry, &mut out)
            .map_err(|e| Error::Upgrade(format!("Failed to write binary: {e}")))?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mut perms = fs::metadata(&extracted_binary)?.permissions();
            perms.set_mode(0o755);
            fs::set_permissions(&extracted_binary, perms)?;
        }
        return Ok(extracted_binary);
    }

    Err(Error::Upgrade(
        "ant-node binary not found in tar.gz archive".to_string(),
    ))
}
/// Extracts the `ant-node` executable from a zip archive.
///
/// Behavior:
/// - Entries whose names would escape the extraction root (`enclosed_name()` returns
///   `None`) are skipped.
/// - Directory entries are skipped. A directory `ant-node/` has file name `ant-node`
///   and would otherwise be extracted as an empty "binary".
/// - The first remaining entry named `ant-node` or `ant-node.exe` is written to
///   `dest_dir` under this platform's binary name.
/// - On Unix the extracted file is marked `0o755`.
///
/// # Errors
///
/// Returns an error if the archive cannot be opened or read, the output cannot be
/// written, or no matching entry exists.
fn extract_from_zip(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
    let file = File::open(archive_path)?;
    let mut archive = zip::ZipArchive::new(file)
        .map_err(|e| Error::Upgrade(format!("Failed to open zip archive: {e}")))?;
    let binary_name = if cfg!(windows) {
        "ant-node.exe"
    } else {
        "ant-node"
    };
    let extracted_binary = dest_dir.join(binary_name);

    for i in 0..archive.len() {
        let mut entry = archive
            .by_index(i)
            .map_err(|e| Error::Upgrade(format!("Failed to read zip entry: {e}")))?;
        if entry.is_dir() {
            continue;
        }
        let Some(path) = entry.enclosed_name() else {
            continue;
        };
        let is_binary = path.file_name().map_or(false, |name| {
            let name_str = name.to_string_lossy();
            name_str == "ant-node" || name_str == "ant-node.exe"
        });
        if !is_binary {
            continue;
        }
        debug!("Found binary in zip archive: {}", path.display());

        let mut out = File::create(&extracted_binary)?;
        std::io::copy(&mut entry, &mut out)
            .map_err(|e| Error::Upgrade(format!("Failed to write binary: {e}")))?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mut perms = fs::metadata(&extracted_binary)?.permissions();
            perms.set_mode(0o755);
            fs::set_permissions(&extracted_binary, perms)?;
        }
        return Ok(extracted_binary);
    }

    Err(Error::Upgrade(
        "ant-node binary not found in zip archive".to_string(),
    ))
}
/// Moves `new_binary` into place at `target`.
///
/// On Unix:
/// - The target's permissions, when it exists, are copied onto the new file.
/// - The new file is then renamed over `target`. The caller stages `new_binary`
///   inside the binary's own directory, so both paths are presumably on one
///   filesystem.
///
/// On Windows:
/// - `self_replace` replaces the running executable, so `target` is not used.
/// - After a failure it is retried once after each of 500 ms, 1 s and 2 s
///   (four attempts in total).
/// - There is no pointless sleep after the last attempt.
///
/// # Errors
///
/// Returns an I/O error from the Unix permission copy or rename, or
/// [`Error::Upgrade`] once every Windows attempt has failed.
fn replace_binary(new_binary: &Path, target: &Path) -> Result<()> {
    #[cfg(unix)]
    {
        if let Ok(meta) = fs::metadata(target) {
            fs::set_permissions(new_binary, meta.permissions())?;
        }
        fs::rename(new_binary, target)?;
    }
    #[cfg(windows)]
    {
        let _ = target;
        let retry_delays_ms = [500u64, 1000, 2000];
        let mut remaining_delays = retry_delays_ms.iter();
        let mut attempt = 1usize;
        loop {
            match self_replace::self_replace(new_binary) {
                Ok(()) => break,
                Err(e) => {
                    let Some(delay_ms) = remaining_delays.next() else {
                        return Err(Error::Upgrade(format!(
                            "self_replace failed after retries: {e}"
                        )));
                    };
                    warn!(
                        "self_replace attempt {} failed: {e}, retrying in {delay_ms}ms",
                        attempt
                    );
                    std::thread::sleep(std::time::Duration::from_millis(*delay_ms));
                    attempt += 1;
                }
            }
        }
    }
    debug!("Binary replacement complete");
    Ok(())
}
/// Arranges for the upgraded binary at `binary_path` to take over, and returns the
/// exit code this process should use after its graceful shutdown.
///
/// In service-manager mode (`stop_on_upgrade`) nothing is spawned, and the returned
/// exit code depends on the platform:
/// - `0` on Unix.
/// - [`RESTART_EXIT_CODE`] on Windows.
/// - `0` on other platforms, after warning that a manual restart is needed.
///
/// Otherwise the new binary is spawned immediately with this process's arguments
/// (stdin null, stdout/stderr inherited) and `0` is returned.
///
/// # Errors
///
/// Returns [`Error::Upgrade`] if the new binary cannot be spawned.
fn prepare_restart(&self, binary_path: &Path) -> Result<i32> {
    if self.stop_on_upgrade {
        let exit_code;
        #[cfg(unix)]
        {
            info!("Service manager mode: will exit with code 0 after graceful shutdown");
            exit_code = 0;
        }
        #[cfg(windows)]
        {
            let _ = binary_path;
            info!(
                "Service manager mode: will exit with code {} after graceful shutdown",
                RESTART_EXIT_CODE
            );
            exit_code = RESTART_EXIT_CODE;
        }
        #[cfg(not(any(unix, windows)))]
        {
            let _ = binary_path;
            warn!("Auto-restart not supported on this platform. Please restart manually.");
            exit_code = 0;
        }
        Ok(exit_code)
    } else {
        // `args_os` instead of `args`: `env::args` panics when any argument is not
        // valid Unicode, which would abort the node right after installing the upgrade.
        let args: Vec<std::ffi::OsString> = env::args_os().skip(1).collect();
        info!("Spawning new process: {} {:?}", binary_path.display(), args);
        std::process::Command::new(binary_path)
            .args(&args)
            .stdin(std::process::Stdio::null())
            .stdout(std::process::Stdio::inherit())
            .stderr(std::process::Stdio::inherit())
            .spawn()
            .map_err(|e| Error::Upgrade(format!("Failed to spawn new binary: {e}")))?;
        info!("New process spawned, will exit after graceful shutdown");
        Ok(0)
    }
}
}
/// Reports the version of the binary at `binary_path` by running
/// `<binary_path> --version`.
///
/// Returns `None` in any of these cases:
/// - The process cannot be run.
/// - It does not finish within 5 seconds.
/// - Its trimmed stdout is not of the form `ant-node <semver>`.
///
/// The child is spawned with `kill_on_drop(true)`, so if the timeout fires the hung
/// process is killed instead of being left running in the background.
async fn on_disk_version(binary_path: &Path) -> Option<Version> {
    let mut command = tokio::process::Command::new(binary_path);
    command.arg("--version").kill_on_drop(true);
    let output = tokio::time::timeout(std::time::Duration::from_secs(5), command.output())
        .await
        .ok()?
        .ok()?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let version_str = stdout.trim().strip_prefix("ant-node ")?;
    Version::parse(version_str).ok()
}
impl Default for AutoApplyUpgrader {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Writes a gzip-compressed tar archive `test.tar.gz` into `dir` containing one
    /// entry named `binary_name` with the bytes `content`, and returns its path.
    fn create_tar_gz_archive(dir: &Path, binary_name: &str, content: &[u8]) -> PathBuf {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        let archive_path = dir.join("test.tar.gz");
        let mut header = tar::Header::new_gnu();
        header.set_size(content.len() as u64);
        header.set_mode(0o755);
        header.set_cksum();
        let gz = GzEncoder::new(File::create(&archive_path).unwrap(), Compression::default());
        let mut tar_builder = tar::Builder::new(gz);
        tar_builder
            .append_data(&mut header, binary_name, content)
            .unwrap();
        tar_builder.finish().unwrap();
        archive_path
    }

    /// Writes an uncompressed (stored) zip archive `test.zip` into `dir` containing
    /// one entry named `binary_name` with the bytes `content`, and returns its path.
    fn create_zip_archive(dir: &Path, binary_name: &str, content: &[u8]) -> PathBuf {
        use std::io::Write;
        let archive_path = dir.join("test.zip");
        let stored = zip::write::SimpleFileOptions::default()
            .compression_method(zip::CompressionMethod::Stored);
        let mut writer = zip::ZipWriter::new(File::create(&archive_path).unwrap());
        writer.start_file(binary_name, stored).unwrap();
        writer.write_all(content).unwrap();
        writer.finish().unwrap();
        archive_path
    }

    /// Runs `extract_binary` on `archive` into a fresh temp dir. The dir guard is
    /// returned so the extracted file stays on disk while the caller inspects it.
    fn extract_fresh(archive: &Path) -> (tempfile::TempDir, Result<PathBuf>) {
        let dest = tempfile::tempdir().unwrap();
        let outcome = AutoApplyUpgrader::extract_binary(archive, dest.path());
        (dest, outcome)
    }

    /// Asserts that extraction succeeded and produced a file holding `expected`.
    fn assert_extracted(outcome: Result<PathBuf>, expected: &[u8]) {
        assert!(outcome.is_ok());
        let extracted = outcome.unwrap();
        assert!(extracted.exists());
        assert_eq!(fs::read(&extracted).unwrap(), expected);
    }

    /// Asserts that extraction failed and returns the rendered error message.
    fn extraction_error(outcome: Result<PathBuf>) -> String {
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    #[test]
    fn test_auto_apply_upgrader_creation() {
        let version = AutoApplyUpgrader::new().current_version().to_string();
        assert!(!version.is_empty());
    }

    #[test]
    fn test_current_binary_path() {
        let outcome = AutoApplyUpgrader::current_binary_path();
        assert!(outcome.is_ok());
        let resolved = outcome.unwrap();
        assert!(resolved.exists() || resolved.to_string_lossy().contains("test"));
    }

    #[test]
    fn test_default_impl() {
        let upgrader: AutoApplyUpgrader = Default::default();
        let version = upgrader.current_version().to_string();
        assert!(!version.is_empty());
    }

    #[test]
    fn test_extract_binary_from_tar_gz() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"fake-binary-content";
        let archive = create_tar_gz_archive(src.path(), "ant-node", payload);
        let (_dest, outcome) = extract_fresh(&archive);
        assert_extracted(outcome, payload);
    }

    #[test]
    fn test_extract_binary_from_zip() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"fake-binary-content";
        let archive = create_zip_archive(src.path(), "ant-node", payload);
        let (_dest, outcome) = extract_fresh(&archive);
        assert_extracted(outcome, payload);
    }

    #[test]
    fn test_extract_binary_from_zip_with_exe() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"fake-windows-binary";
        let archive = create_zip_archive(src.path(), "ant-node.exe", payload);
        let (_dest, outcome) = extract_fresh(&archive);
        assert_extracted(outcome, payload);
    }

    #[test]
    fn test_extract_binary_from_tar_gz_nested_path() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"nested-binary";
        let archive = create_tar_gz_archive(src.path(), "some/nested/path/ant-node", payload);
        let (_dest, outcome) = extract_fresh(&archive);
        assert_extracted(outcome, payload);
    }

    #[test]
    fn test_extract_binary_unknown_format() {
        let src = tempfile::tempdir().unwrap();
        let bogus = src.path().join("bad_archive");
        fs::write(&bogus, b"XX not a real archive").unwrap();
        let (_dest, outcome) = extract_fresh(&bogus);
        assert!(extraction_error(outcome).contains("Unknown archive format"));
    }

    #[test]
    fn test_extract_binary_missing_binary_in_tar_gz() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"not-the-binary";
        let archive = create_tar_gz_archive(src.path(), "other-file", payload);
        let (_dest, outcome) = extract_fresh(&archive);
        assert!(extraction_error(outcome).contains("not found in tar.gz archive"));
    }

    #[test]
    fn test_extract_binary_missing_binary_in_zip() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"not-the-binary";
        let archive = create_zip_archive(src.path(), "other-file", payload);
        let (_dest, outcome) = extract_fresh(&archive);
        assert!(extraction_error(outcome).contains("not found in zip archive"));
    }

    #[test]
    fn test_extract_binary_empty_file() {
        let src = tempfile::tempdir().unwrap();
        let empty = src.path().join("empty");
        fs::write(&empty, b"").unwrap();
        let (_dest, outcome) = extract_fresh(&empty);
        assert!(outcome.is_err());
    }
}