use crate::error::{Error, Result};
use crate::upgrade::binary_cache::BinaryCache;
use crate::upgrade::{signature, UpgradeInfo, UpgradeResult};
use flate2::read::GzDecoder;
use semver::Version;
use std::env;
use std::ffi::{OsStr, OsString};
use std::fs::{self, File};
use std::io::{BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use tar::Archive;
use tracing::{debug, error, info, warn};
21
/// Largest download (release archive or signature) accepted, in bytes: 200 MiB.
const MAX_ARCHIVE_SIZE_BYTES: usize = 200 << 20;

/// Exit code reported after a successful upgrade on Windows when running in
/// service-manager mode (`stop_on_upgrade`).
pub const RESTART_EXIT_CODE: i32 = 100;
31
32pub struct AutoApplyUpgrader {
34 current_version: Version,
36 client: reqwest::Client,
38 binary_cache: Option<BinaryCache>,
40 stop_on_upgrade: bool,
42}
43
44impl AutoApplyUpgrader {
45 #[must_use]
47 pub fn new() -> Self {
48 let current_version =
49 Version::parse(env!("CARGO_PKG_VERSION")).unwrap_or_else(|_| Version::new(0, 0, 0));
50
51 Self {
52 current_version,
53 client: reqwest::Client::builder()
54 .user_agent(concat!("ant-node/", env!("CARGO_PKG_VERSION")))
55 .timeout(std::time::Duration::from_secs(300))
56 .build()
57 .unwrap_or_else(|_| reqwest::Client::new()),
58 binary_cache: None,
59 stop_on_upgrade: false,
60 }
61 }
62
63 #[must_use]
68 pub fn with_binary_cache(mut self, cache: BinaryCache) -> Self {
69 self.binary_cache = Some(cache);
70 self
71 }
72
73 #[must_use]
78 pub fn with_stop_on_upgrade(mut self, stop: bool) -> Self {
79 self.stop_on_upgrade = stop;
80 self
81 }
82
83 #[must_use]
85 pub fn current_version(&self) -> &Version {
86 &self.current_version
87 }
88
89 pub fn current_binary_path() -> Result<PathBuf> {
99 let invoked_path = env::args().next().map(PathBuf::from);
102
103 if let Some(ref invoked) = invoked_path {
104 let path_str = invoked.to_string_lossy();
109 let cleaned = if path_str.ends_with(" (deleted)") {
110 let stripped = path_str.trim_end_matches(" (deleted)");
111 debug!("Stripped '(deleted)' suffix from invoked path: {stripped}");
112 PathBuf::from(stripped)
113 } else {
114 invoked.clone()
115 };
116
117 if cleaned.exists() {
118 if let Ok(canonical) = cleaned.canonicalize() {
121 return Ok(canonical);
122 }
123 return Ok(cleaned);
124 }
125 }
126
127 let path = env::current_exe()
129 .map_err(|e| Error::Upgrade(format!("Cannot determine binary path: {e}")))?;
130
131 #[cfg(unix)]
132 {
133 let path_str = path.to_string_lossy();
134 if path_str.ends_with(" (deleted)") {
135 let cleaned = path_str.trim_end_matches(" (deleted)");
136 debug!("Stripped '(deleted)' suffix from binary path: {cleaned}");
137 return Ok(PathBuf::from(cleaned));
138 }
139 }
140
141 Ok(path)
142 }
143
144 pub async fn apply_upgrade(&self, info: &UpgradeInfo) -> Result<UpgradeResult> {
159 info!(
160 "Starting auto-apply upgrade from {} to {}",
161 self.current_version, info.version
162 );
163
164 if info.version <= self.current_version {
166 warn!(
167 "Ignoring downgrade attempt: {} -> {}",
168 self.current_version, info.version
169 );
170 return Ok(UpgradeResult::NoUpgrade);
171 }
172
173 let current_binary = Self::current_binary_path()?;
175 let binary_dir = current_binary
176 .parent()
177 .ok_or_else(|| Error::Upgrade("Cannot determine binary directory".to_string()))?;
178
179 let temp_dir = tempfile::Builder::new()
181 .prefix("ant-upgrade-")
182 .tempdir_in(binary_dir)
183 .map_err(|e| Error::Upgrade(format!("Failed to create temp dir: {e}")))?;
184
185 let version_str = info.version.to_string();
186
187 let extracted_binary = match self
190 .resolve_upgrade_binary(info, temp_dir.path(), &version_str)
191 .await
192 {
193 Ok(path) => path,
194 Err(e) => {
195 warn!("Download/verify/extract failed: {e}");
196 return Ok(UpgradeResult::RolledBack {
197 reason: format!("{e}"),
198 });
199 }
200 };
201
202 if let Some(disk_version) = on_disk_version(¤t_binary).await {
205 if disk_version == info.version {
206 info!(
207 "Binary already upgraded to {} by another service, skipping replacement",
208 info.version
209 );
210 let exit_code = self.prepare_restart(¤t_binary)?;
211 return Ok(UpgradeResult::Success {
212 version: info.version.clone(),
213 exit_code,
214 });
215 }
216 }
217
218 let backup_path = binary_dir.join(format!(
220 "{}.backup",
221 current_binary
222 .file_name()
223 .map_or_else(|| "ant-node".into(), |s| s.to_string_lossy())
224 ));
225 info!("Creating backup at {}...", backup_path.display());
226 if let Err(e) = fs::copy(¤t_binary, &backup_path) {
227 warn!("Backup creation failed: {e}");
228 return Ok(UpgradeResult::RolledBack {
229 reason: format!("Backup failed: {e}"),
230 });
231 }
232
233 info!("Replacing binary...");
236 let new_bin = extracted_binary.clone();
237 let target_bin = current_binary.clone();
238 let replace_result =
239 tokio::task::spawn_blocking(move || Self::replace_binary(&new_bin, &target_bin))
240 .await
241 .map_err(|e| Error::Upgrade(format!("Binary replacement task panicked: {e}")))?;
242 if let Err(e) = replace_result {
243 warn!("Binary replacement failed: {e}");
244 if let Err(restore_err) = fs::copy(&backup_path, ¤t_binary) {
246 error!("CRITICAL: Replacement failed ({e}) AND rollback failed ({restore_err})");
247 return Err(Error::Upgrade(format!(
248 "Critical: replacement failed ({e}) AND rollback failed ({restore_err})"
249 )));
250 }
251 return Ok(UpgradeResult::RolledBack {
252 reason: format!("Replacement failed: {e}"),
253 });
254 }
255
256 info!(
257 "Successfully upgraded to version {}! Restarting...",
258 info.version
259 );
260
261 let exit_code = self.prepare_restart(¤t_binary)?;
263
264 Ok(UpgradeResult::Success {
265 version: info.version.clone(),
266 exit_code,
267 })
268 }
269
270 async fn download(&self, url: &str, dest: &Path) -> Result<()> {
272 debug!("Downloading: {}", url);
273
274 let response = self
275 .client
276 .get(url)
277 .send()
278 .await
279 .map_err(|e| Error::Network(format!("Download failed: {e}")))?;
280
281 if !response.status().is_success() {
282 return Err(Error::Network(format!(
283 "Download returned status: {}",
284 response.status()
285 )));
286 }
287
288 let bytes = response
289 .bytes()
290 .await
291 .map_err(|e| Error::Network(format!("Failed to read response: {e}")))?;
292
293 if bytes.len() > MAX_ARCHIVE_SIZE_BYTES {
294 return Err(Error::Upgrade(format!(
295 "Downloaded file too large: {} bytes (max {})",
296 bytes.len(),
297 MAX_ARCHIVE_SIZE_BYTES
298 )));
299 }
300
301 fs::write(dest, &bytes)?;
302 debug!("Downloaded {} bytes to {}", bytes.len(), dest.display());
303 Ok(())
304 }
305
306 async fn resolve_upgrade_binary(
314 &self,
315 info: &UpgradeInfo,
316 dest_dir: &Path,
317 version_str: &str,
318 ) -> Result<PathBuf> {
319 if let Some(ref cache) = self.binary_cache {
320 if let Some(cached_path) = cache.get_verified(version_str) {
322 info!("Cached binary verified for version {}", version_str);
323 let dest = dest_dir.join(
324 cached_path
325 .file_name()
326 .unwrap_or_else(|| std::ffi::OsStr::new("ant-node")),
327 );
328 if let Err(e) = fs::copy(&cached_path, &dest) {
329 warn!("Failed to copy from cache, will re-download: {e}");
330 return self
331 .download_verify_extract(info, dest_dir, Some(cache))
332 .await;
333 }
334 return Ok(dest);
335 }
336
337 let cache_clone = cache.clone();
340 let lock_guard =
344 tokio::task::spawn_blocking(move || cache_clone.acquire_download_lock())
345 .await
346 .map_err(|e| Error::Upgrade(format!("Lock task failed: {e}")))??;
347
348 if let Some(cached_path) = cache.get_verified(version_str) {
350 info!(
351 "Cached binary became available under lock for version {}",
352 version_str
353 );
354 let dest = dest_dir.join(
355 cached_path
356 .file_name()
357 .unwrap_or_else(|| std::ffi::OsStr::new("ant-node")),
358 );
359 fs::copy(&cached_path, &dest)?;
360 return Ok(dest);
361 }
362
363 let result = self
365 .download_verify_extract(info, dest_dir, Some(cache))
366 .await;
367 drop(lock_guard);
368 result
369 } else {
370 self.download_verify_extract(info, dest_dir, None).await
371 }
372 }
373
374 async fn download_verify_extract(
382 &self,
383 info: &UpgradeInfo,
384 dest_dir: &Path,
385 cache: Option<&BinaryCache>,
386 ) -> Result<PathBuf> {
387 let archive_path = dest_dir.join("archive");
388 let sig_path = dest_dir.join("signature");
389
390 info!("Downloading ant-node binary...");
392 self.download(&info.download_url, &archive_path).await?;
393
394 info!("Downloading signature...");
396 self.download(&info.signature_url, &sig_path).await?;
397
398 info!("Verifying ML-DSA signature on archive...");
400 signature::verify_from_file(&archive_path, &sig_path)?;
401 info!("Archive signature verified successfully");
402
403 info!("Extracting binary from archive...");
405 let extracted_binary = Self::extract_binary(&archive_path, dest_dir)?;
406
407 if let Some(c) = cache {
409 let version_str = info.version.to_string();
410 if let Err(e) = c.store(&version_str, &extracted_binary) {
411 warn!("Failed to store binary in cache: {e}");
412 }
413 }
414
415 Ok(extracted_binary)
416 }
417
418 fn extract_binary(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
424 let mut file = File::open(archive_path)?;
425 let mut magic = [0u8; 2];
426 file.read_exact(&mut magic)
427 .map_err(|e| Error::Upgrade(format!("Failed to read archive header: {e}")))?;
428 drop(file);
429
430 match magic {
431 [0x1f, 0x8b] => Self::extract_from_tar_gz(archive_path, dest_dir),
432 [0x50, 0x4b] => Self::extract_from_zip(archive_path, dest_dir),
433 _ => Err(Error::Upgrade(format!(
434 "Unknown archive format (magic bytes: {:02x} {:02x})",
435 magic[0], magic[1]
436 ))),
437 }
438 }
439
440 fn extract_from_tar_gz(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
442 let file = File::open(archive_path)?;
443 let decoder = GzDecoder::new(file);
444 let mut archive = Archive::new(decoder);
445
446 let binary_name = if cfg!(windows) {
447 "ant-node.exe"
448 } else {
449 "ant-node"
450 };
451 let extracted_binary = dest_dir.join(binary_name);
452
453 for entry in archive
454 .entries()
455 .map_err(|e| Error::Upgrade(format!("Failed to read archive: {e}")))?
456 {
457 let mut entry =
458 entry.map_err(|e| Error::Upgrade(format!("Failed to read entry: {e}")))?;
459 let path = entry
460 .path()
461 .map_err(|e| Error::Upgrade(format!("Invalid path in archive: {e}")))?;
462
463 if let Some(name) = path.file_name() {
465 let name_str = name.to_string_lossy();
466 if name_str == "ant-node" || name_str == "ant-node.exe" {
467 debug!("Found binary in tar.gz archive: {}", path.display());
468
469 let mut out = File::create(&extracted_binary)?;
471 std::io::copy(&mut entry, &mut out)
472 .map_err(|e| Error::Upgrade(format!("Failed to write binary: {e}")))?;
473
474 #[cfg(unix)]
476 {
477 use std::os::unix::fs::PermissionsExt;
478 let mut perms = fs::metadata(&extracted_binary)?.permissions();
479 perms.set_mode(0o755);
480 fs::set_permissions(&extracted_binary, perms)?;
481 }
482
483 return Ok(extracted_binary);
484 }
485 }
486 }
487
488 Err(Error::Upgrade(
489 "ant-node binary not found in tar.gz archive".to_string(),
490 ))
491 }
492
493 fn extract_from_zip(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
495 let file = File::open(archive_path)?;
496 let mut archive = zip::ZipArchive::new(file)
497 .map_err(|e| Error::Upgrade(format!("Failed to open zip archive: {e}")))?;
498
499 let binary_name = if cfg!(windows) {
500 "ant-node.exe"
501 } else {
502 "ant-node"
503 };
504 let extracted_binary = dest_dir.join(binary_name);
505
506 for i in 0..archive.len() {
507 let mut entry = archive
508 .by_index(i)
509 .map_err(|e| Error::Upgrade(format!("Failed to read zip entry: {e}")))?;
510
511 let path = match entry.enclosed_name() {
512 Some(p) => p.clone(),
513 None => continue,
514 };
515
516 if let Some(name) = path.file_name() {
517 let name_str = name.to_string_lossy();
518 if name_str == "ant-node" || name_str == "ant-node.exe" {
519 debug!("Found binary in zip archive: {}", path.display());
520
521 let mut out = File::create(&extracted_binary)?;
523 std::io::copy(&mut entry, &mut out)
524 .map_err(|e| Error::Upgrade(format!("Failed to write binary: {e}")))?;
525
526 #[cfg(unix)]
528 {
529 use std::os::unix::fs::PermissionsExt;
530 let mut perms = fs::metadata(&extracted_binary)?.permissions();
531 perms.set_mode(0o755);
532 fs::set_permissions(&extracted_binary, perms)?;
533 }
534
535 return Ok(extracted_binary);
536 }
537 }
538 }
539
540 Err(Error::Upgrade(
541 "ant-node binary not found in zip archive".to_string(),
542 ))
543 }
544
545 fn replace_binary(new_binary: &Path, target: &Path) -> Result<()> {
550 #[cfg(unix)]
551 {
552 if let Ok(meta) = fs::metadata(target) {
554 let perms = meta.permissions();
555 fs::set_permissions(new_binary, perms)?;
556 }
557 fs::rename(new_binary, target)?;
559 }
560
561 #[cfg(windows)]
562 {
563 let _ = target; let delays = [500u64, 1000, 2000];
566 let mut last_err = None;
567 for (attempt, delay_ms) in delays.iter().enumerate() {
568 match self_replace::self_replace(new_binary) {
569 Ok(()) => {
570 last_err = None;
571 break;
572 }
573 Err(e) => {
574 warn!(
575 "self_replace attempt {} failed: {e}, retrying in {delay_ms}ms",
576 attempt + 1
577 );
578 last_err = Some(e);
579 std::thread::sleep(std::time::Duration::from_millis(*delay_ms));
580 }
581 }
582 }
583 if let Some(e) = last_err {
584 return Err(Error::Upgrade(format!(
585 "self_replace failed after retries: {e}"
586 )));
587 }
588 }
589
590 debug!("Binary replacement complete");
591 Ok(())
592 }
593
594 fn prepare_restart(&self, binary_path: &Path) -> Result<i32> {
609 if self.stop_on_upgrade {
610 let exit_code;
611
612 #[cfg(unix)]
613 {
614 info!("Service manager mode: will exit with code 0 after graceful shutdown");
615 exit_code = 0;
616 }
617
618 #[cfg(windows)]
619 {
620 let _ = binary_path;
621 info!(
622 "Service manager mode: will exit with code {} after graceful shutdown",
623 RESTART_EXIT_CODE
624 );
625 exit_code = RESTART_EXIT_CODE;
626 }
627
628 #[cfg(not(any(unix, windows)))]
629 {
630 let _ = binary_path;
631 warn!("Auto-restart not supported on this platform. Please restart manually.");
632 exit_code = 0;
633 }
634
635 Ok(exit_code)
636 } else {
637 let args: Vec<String> = env::args().skip(1).collect();
639
640 info!("Spawning new process: {} {:?}", binary_path.display(), args);
641
642 std::process::Command::new(binary_path)
643 .args(&args)
644 .stdin(std::process::Stdio::null())
645 .stdout(std::process::Stdio::inherit())
646 .stderr(std::process::Stdio::inherit())
647 .spawn()
648 .map_err(|e| Error::Upgrade(format!("Failed to spawn new binary: {e}")))?;
649
650 info!("New process spawned, will exit after graceful shutdown");
651 Ok(0)
652 }
653 }
654}
655
656async fn on_disk_version(binary_path: &Path) -> Option<Version> {
664 let output = tokio::time::timeout(
665 std::time::Duration::from_secs(5),
666 tokio::process::Command::new(binary_path)
667 .arg("--version")
668 .output(),
669 )
670 .await
671 .ok()?
672 .ok()?;
673 let stdout = String::from_utf8_lossy(&output.stdout);
674 let version_str = stdout.trim().strip_prefix("ant-node ")?;
675 Version::parse(version_str).ok()
676}
677
678impl Default for AutoApplyUpgrader {
679 fn default() -> Self {
680 Self::new()
681 }
682}
683
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Runs `extract_binary` on `archive` into a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned with the result so the extracted file is
    /// still on disk while the caller inspects it.
    fn extract_to_fresh_dir(archive: &Path) -> (tempfile::TempDir, Result<PathBuf>) {
        let dest = tempfile::tempdir().unwrap();
        let outcome = AutoApplyUpgrader::extract_binary(archive, dest.path());
        (dest, outcome)
    }

    /// Asserts extraction succeeded and the extracted file holds exactly `expected`.
    fn assert_extracted_contents(outcome: Result<PathBuf>, expected: &[u8]) {
        assert!(outcome.is_ok());
        let extracted = outcome.unwrap();
        assert!(extracted.exists());
        assert_eq!(fs::read(&extracted).unwrap(), expected);
    }

    /// Asserts extraction failed with an error whose message contains `needle`.
    fn assert_extract_error(outcome: Result<PathBuf>, needle: &str) {
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains(needle));
    }

    /// Writes a single-entry `.tar.gz` named `test.tar.gz` into `dir`.
    fn create_tar_gz_archive(dir: &Path, binary_name: &str, content: &[u8]) -> PathBuf {
        use flate2::write::GzEncoder;
        use flate2::Compression;

        let archive_path = dir.join("test.tar.gz");
        let gz = GzEncoder::new(File::create(&archive_path).unwrap(), Compression::default());
        let mut tar_builder = tar::Builder::new(gz);

        let mut header = tar::Header::new_gnu();
        header.set_size(content.len() as u64);
        header.set_mode(0o755);
        header.set_cksum();
        tar_builder
            .append_data(&mut header, binary_name, content)
            .unwrap();

        // Close the tar stream, then the gzip stream, so the file is complete.
        tar_builder.into_inner().unwrap().finish().unwrap();

        archive_path
    }

    /// Writes a single-entry, uncompressed `.zip` named `test.zip` into `dir`.
    fn create_zip_archive(dir: &Path, binary_name: &str, content: &[u8]) -> PathBuf {
        use std::io::Write;

        let archive_path = dir.join("test.zip");
        let mut writer = zip::ZipWriter::new(File::create(&archive_path).unwrap());
        let stored = zip::write::SimpleFileOptions::default()
            .compression_method(zip::CompressionMethod::Stored);
        writer.start_file(binary_name, stored).unwrap();
        writer.write_all(content).unwrap();
        writer.finish().unwrap();

        archive_path
    }

    #[test]
    fn test_auto_apply_upgrader_creation() {
        let version_text = AutoApplyUpgrader::new().current_version().to_string();
        assert!(!version_text.is_empty());
    }

    #[test]
    fn test_current_binary_path() {
        let path = AutoApplyUpgrader::current_binary_path().expect("binary path should resolve");
        let plausible = path.exists() || path.to_string_lossy().contains("test");
        assert!(plausible);
    }

    #[test]
    fn test_default_impl() {
        let version_text = AutoApplyUpgrader::default().current_version().to_string();
        assert!(!version_text.is_empty());
    }

    #[test]
    fn test_extract_binary_from_tar_gz() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"fake-binary-content";
        let archive = create_tar_gz_archive(src.path(), "ant-node", payload);

        let (_dest, outcome) = extract_to_fresh_dir(&archive);
        assert_extracted_contents(outcome, payload);
    }

    #[test]
    fn test_extract_binary_from_zip() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"fake-binary-content";
        let archive = create_zip_archive(src.path(), "ant-node", payload);

        let (_dest, outcome) = extract_to_fresh_dir(&archive);
        assert_extracted_contents(outcome, payload);
    }

    #[test]
    fn test_extract_binary_from_zip_with_exe() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"fake-windows-binary";
        let archive = create_zip_archive(src.path(), "ant-node.exe", payload);

        let (_dest, outcome) = extract_to_fresh_dir(&archive);
        assert_extracted_contents(outcome, payload);
    }

    #[test]
    fn test_extract_binary_from_tar_gz_nested_path() {
        let src = tempfile::tempdir().unwrap();
        let payload = b"nested-binary";
        let archive = create_tar_gz_archive(src.path(), "some/nested/path/ant-node", payload);

        let (_dest, outcome) = extract_to_fresh_dir(&archive);
        assert_extracted_contents(outcome, payload);
    }

    #[test]
    fn test_extract_binary_unknown_format() {
        let src = tempfile::tempdir().unwrap();
        let bogus = src.path().join("bad_archive");
        fs::write(&bogus, b"XX not a real archive").unwrap();

        let (_dest, outcome) = extract_to_fresh_dir(&bogus);
        assert_extract_error(outcome, "Unknown archive format");
    }

    #[test]
    fn test_extract_binary_missing_binary_in_tar_gz() {
        let src = tempfile::tempdir().unwrap();
        let archive = create_tar_gz_archive(src.path(), "other-file", b"not-the-binary");

        let (_dest, outcome) = extract_to_fresh_dir(&archive);
        assert_extract_error(outcome, "not found in tar.gz archive");
    }

    #[test]
    fn test_extract_binary_missing_binary_in_zip() {
        let src = tempfile::tempdir().unwrap();
        let archive = create_zip_archive(src.path(), "other-file", b"not-the-binary");

        let (_dest, outcome) = extract_to_fresh_dir(&archive);
        assert_extract_error(outcome, "not found in zip archive");
    }

    #[test]
    fn test_extract_binary_empty_file() {
        let src = tempfile::tempdir().unwrap();
        let empty = src.path().join("empty");
        fs::write(&empty, b"").unwrap();

        let (_dest, outcome) = extract_to_fresh_dir(&empty);
        assert!(outcome.is_err());
    }
}