1use crate::error::{Error, Result};
11use crate::logging::{debug, error, info, warn};
12use crate::upgrade::binary_cache::BinaryCache;
13use crate::upgrade::{signature, UpgradeInfo, UpgradeResult};
14use flate2::read::GzDecoder;
15use semver::Version;
16use std::env;
17use std::fs::{self, File};
18use std::io::Read;
19use std::path::{Path, PathBuf};
20use tar::Archive;
21
/// Upper bound, in bytes, on the size of a single downloaded file (200 MiB).
///
/// `download` rejects any response body larger than this.
pub(super) const MAX_ARCHIVE_SIZE_BYTES: usize = 200 * 1024 * 1024;
24
/// Exit code (100) reported after a successful upgrade on Windows when the
/// upgrader is configured to stop instead of spawning a replacement process.
pub const RESTART_EXIT_CODE: i32 = 100;
31
/// Applies upgrades to the running `ant-node` binary: obtains a
/// signature-verified archive, swaps the on-disk executable, and prepares the
/// restart.
pub struct AutoApplyUpgrader {
    /// Version of the running binary, parsed from `CARGO_PKG_VERSION`.
    current_version: Version,
    /// HTTP client used to download archives and signatures.
    client: reqwest::Client,
    /// Optional cache of verified archives; with `None` every upgrade downloads.
    binary_cache: Option<BinaryCache>,
    /// When `true`, a successful upgrade only reports an exit code and does not
    /// spawn the new binary; when `false`, the upgrader spawns it itself.
    stop_on_upgrade: bool,
}
43
44impl AutoApplyUpgrader {
45 #[must_use]
47 pub fn new() -> Self {
48 let current_version =
49 Version::parse(env!("CARGO_PKG_VERSION")).unwrap_or_else(|_| Version::new(0, 0, 0));
50
51 Self {
52 current_version,
53 client: reqwest::Client::builder()
54 .user_agent(concat!("ant-node/", env!("CARGO_PKG_VERSION")))
55 .timeout(std::time::Duration::from_secs(300))
56 .build()
57 .unwrap_or_else(|_| reqwest::Client::new()),
58 binary_cache: None,
59 stop_on_upgrade: false,
60 }
61 }
62
63 #[must_use]
68 pub fn with_binary_cache(mut self, cache: BinaryCache) -> Self {
69 self.binary_cache = Some(cache);
70 self
71 }
72
73 #[must_use]
78 pub fn with_stop_on_upgrade(mut self, stop: bool) -> Self {
79 self.stop_on_upgrade = stop;
80 self
81 }
82
83 #[must_use]
85 pub fn current_version(&self) -> &Version {
86 &self.current_version
87 }
88
89 pub fn current_binary_path() -> Result<PathBuf> {
99 let invoked_path = env::args().next().map(PathBuf::from);
102
103 if let Some(ref invoked) = invoked_path {
104 let path_str = invoked.to_string_lossy();
109 let cleaned = if path_str.ends_with(" (deleted)") {
110 let stripped = path_str.trim_end_matches(" (deleted)");
111 debug!("Stripped '(deleted)' suffix from invoked path: {stripped}");
112 PathBuf::from(stripped)
113 } else {
114 invoked.clone()
115 };
116
117 if cleaned.exists() {
118 if let Ok(canonical) = cleaned.canonicalize() {
121 return Ok(canonical);
122 }
123 return Ok(cleaned);
124 }
125 }
126
127 let path = env::current_exe()
129 .map_err(|e| Error::Upgrade(format!("Cannot determine binary path: {e}")))?;
130
131 #[cfg(unix)]
132 {
133 let path_str = path.to_string_lossy();
134 if path_str.ends_with(" (deleted)") {
135 let cleaned = path_str.trim_end_matches(" (deleted)");
136 debug!("Stripped '(deleted)' suffix from binary path: {cleaned}");
137 return Ok(PathBuf::from(cleaned));
138 }
139 }
140
141 Ok(path)
142 }
143
144 pub async fn apply_upgrade(&self, info: &UpgradeInfo) -> Result<UpgradeResult> {
159 info!(
160 "Starting auto-apply upgrade from {} to {}",
161 self.current_version, info.version
162 );
163
164 if info.version <= self.current_version {
166 warn!(
167 "Ignoring downgrade attempt: {} -> {}",
168 self.current_version, info.version
169 );
170 return Ok(UpgradeResult::NoUpgrade);
171 }
172
173 let current_binary = Self::current_binary_path()?;
175 let binary_dir = current_binary
176 .parent()
177 .ok_or_else(|| Error::Upgrade("Cannot determine binary directory".to_string()))?;
178
179 let mut tempdir_builder = tempfile::Builder::new();
190 tempdir_builder.prefix("ant-upgrade-");
191 #[cfg(unix)]
192 {
193 use std::os::unix::fs::PermissionsExt;
194 tempdir_builder.permissions(std::fs::Permissions::from_mode(0o700));
195 }
196 let temp_dir = tempdir_builder
197 .tempdir_in(binary_dir)
198 .map_err(|e| Error::Upgrade(format!("Failed to create temp dir: {e}")))?;
199
200 let version_str = info.version.to_string();
201
202 let extracted_binary = match self
205 .resolve_upgrade_binary(info, temp_dir.path(), &version_str)
206 .await
207 {
208 Ok(path) => path,
209 Err(e) => {
210 warn!("Download/verify/extract failed: {e}");
211 return Ok(UpgradeResult::RolledBack {
212 reason: format!("{e}"),
213 });
214 }
215 };
216
217 if let Some(disk_version) = on_disk_version(¤t_binary).await {
220 if disk_version == info.version {
221 info!(
222 "Binary already upgraded to {} by another service, skipping replacement",
223 info.version
224 );
225 let exit_code = self.prepare_restart(¤t_binary)?;
226 return Ok(UpgradeResult::Success {
227 version: info.version.clone(),
228 exit_code,
229 });
230 }
231 }
232
233 let backup_path = binary_dir.join(format!(
235 "{}.backup",
236 current_binary
237 .file_name()
238 .map_or_else(|| "ant-node".into(), |s| s.to_string_lossy())
239 ));
240 info!("Creating backup at {}...", backup_path.display());
241 if let Err(e) = fs::copy(¤t_binary, &backup_path) {
242 warn!("Backup creation failed: {e}");
243 return Ok(UpgradeResult::RolledBack {
244 reason: format!("Backup failed: {e}"),
245 });
246 }
247
248 info!("Replacing binary...");
251 let new_bin = extracted_binary.clone();
252 let target_bin = current_binary.clone();
253 let replace_result =
254 tokio::task::spawn_blocking(move || Self::replace_binary(&new_bin, &target_bin))
255 .await
256 .map_err(|e| Error::Upgrade(format!("Binary replacement task panicked: {e}")))?;
257 if let Err(e) = replace_result {
258 warn!("Binary replacement failed: {e}");
259 if let Err(restore_err) = fs::copy(&backup_path, ¤t_binary) {
261 error!("CRITICAL: Replacement failed ({e}) AND rollback failed ({restore_err})");
262 return Err(Error::Upgrade(format!(
263 "Critical: replacement failed ({e}) AND rollback failed ({restore_err})"
264 )));
265 }
266 return Ok(UpgradeResult::RolledBack {
267 reason: format!("Replacement failed: {e}"),
268 });
269 }
270
271 info!(
272 "Successfully upgraded to version {}! Restarting...",
273 info.version
274 );
275
276 let exit_code = self.prepare_restart(¤t_binary)?;
278
279 Ok(UpgradeResult::Success {
280 version: info.version.clone(),
281 exit_code,
282 })
283 }
284
285 async fn download(&self, url: &str, dest: &Path) -> Result<()> {
287 debug!("Downloading: {}", url);
288
289 let response = self
290 .client
291 .get(url)
292 .send()
293 .await
294 .map_err(|e| Error::Network(format!("Download failed: {e}")))?;
295
296 if !response.status().is_success() {
297 return Err(Error::Network(format!(
298 "Download returned status: {}",
299 response.status()
300 )));
301 }
302
303 let bytes = response
304 .bytes()
305 .await
306 .map_err(|e| Error::Network(format!("Failed to read response: {e}")))?;
307
308 if bytes.len() > MAX_ARCHIVE_SIZE_BYTES {
309 return Err(Error::Upgrade(format!(
310 "Downloaded file too large: {} bytes (max {})",
311 bytes.len(),
312 MAX_ARCHIVE_SIZE_BYTES
313 )));
314 }
315
316 fs::write(dest, &bytes)?;
317 debug!("Downloaded {} bytes to {}", bytes.len(), dest.display());
318 Ok(())
319 }
320
321 async fn resolve_upgrade_binary(
329 &self,
330 info: &UpgradeInfo,
331 dest_dir: &Path,
332 version_str: &str,
333 ) -> Result<PathBuf> {
334 if let Some(ref cache) = self.binary_cache {
335 if let Some(verified_archive) = cache.get_verified_archive(version_str, dest_dir) {
343 match Self::extract_binary(&verified_archive, dest_dir) {
344 Ok(binary) => {
345 info!("Reused signature-verified cached archive for {version_str}");
346 return Ok(binary);
347 }
348 Err(e) => {
349 warn!("Failed to extract from cached archive, will re-download: {e}");
350 return self
351 .download_verify_extract(info, dest_dir, Some(cache))
352 .await;
353 }
354 }
355 }
356
357 let cache_clone = cache.clone();
360 let lock_guard =
364 tokio::task::spawn_blocking(move || cache_clone.acquire_download_lock())
365 .await
366 .map_err(|e| Error::Upgrade(format!("Lock task failed: {e}")))??;
367
368 if let Some(verified_archive) = cache.get_verified_archive(version_str, dest_dir) {
371 if let Ok(binary) = Self::extract_binary(&verified_archive, dest_dir) {
372 info!(
373 "Signature-verified cached archive became available under lock for {version_str}"
374 );
375 return Ok(binary);
376 }
377 }
378
379 let result = self
381 .download_verify_extract(info, dest_dir, Some(cache))
382 .await;
383 drop(lock_guard);
384 result
385 } else {
386 self.download_verify_extract(info, dest_dir, None).await
387 }
388 }
389
390 async fn download_verify_extract(
398 &self,
399 info: &UpgradeInfo,
400 dest_dir: &Path,
401 cache: Option<&BinaryCache>,
402 ) -> Result<PathBuf> {
403 let archive_path = dest_dir.join("archive");
404 let sig_path = dest_dir.join("signature");
405
406 info!("Downloading ant-node binary...");
408 self.download(&info.download_url, &archive_path).await?;
409
410 info!("Downloading signature...");
412 self.download(&info.signature_url, &sig_path).await?;
413
414 info!("Verifying ML-DSA signature on archive...");
416 signature::verify_from_file(&archive_path, &sig_path)?;
417 info!("Archive signature verified successfully");
418
419 info!("Extracting binary from archive...");
421 let extracted_binary = Self::extract_binary(&archive_path, dest_dir)?;
422
423 if let Some(c) = cache {
432 let version_str = info.version.to_string();
433 if let Err(e) = c.store_archive(&version_str, &archive_path, &sig_path) {
434 warn!("Failed to store verified archive in cache: {e}");
435 }
436 }
437
438 Ok(extracted_binary)
439 }
440
441 fn extract_binary(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
447 let mut file = File::open(archive_path)?;
448 let mut magic = [0u8; 2];
449 file.read_exact(&mut magic)
450 .map_err(|e| Error::Upgrade(format!("Failed to read archive header: {e}")))?;
451 drop(file);
452
453 match magic {
454 [0x1f, 0x8b] => Self::extract_from_tar_gz(archive_path, dest_dir),
455 [0x50, 0x4b] => Self::extract_from_zip(archive_path, dest_dir),
456 _ => Err(Error::Upgrade(format!(
457 "Unknown archive format (magic bytes: {:02x} {:02x})",
458 magic[0], magic[1]
459 ))),
460 }
461 }
462
463 fn extract_from_tar_gz(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
465 let file = File::open(archive_path)?;
466 let decoder = GzDecoder::new(file);
467 let mut archive = Archive::new(decoder);
468
469 let binary_name = if cfg!(windows) {
470 "ant-node.exe"
471 } else {
472 "ant-node"
473 };
474 let extracted_binary = dest_dir.join(binary_name);
475
476 for entry in archive
477 .entries()
478 .map_err(|e| Error::Upgrade(format!("Failed to read archive: {e}")))?
479 {
480 let mut entry =
481 entry.map_err(|e| Error::Upgrade(format!("Failed to read entry: {e}")))?;
482 let path = entry
483 .path()
484 .map_err(|e| Error::Upgrade(format!("Invalid path in archive: {e}")))?;
485
486 if let Some(name) = path.file_name() {
488 let name_str = name.to_string_lossy();
489 if name_str == "ant-node" || name_str == "ant-node.exe" {
490 debug!("Found binary in tar.gz archive: {}", path.display());
491
492 let mut out = File::create(&extracted_binary)?;
494 std::io::copy(&mut entry, &mut out)
495 .map_err(|e| Error::Upgrade(format!("Failed to write binary: {e}")))?;
496
497 #[cfg(unix)]
499 {
500 use std::os::unix::fs::PermissionsExt;
501 let mut perms = fs::metadata(&extracted_binary)?.permissions();
502 perms.set_mode(0o755);
503 fs::set_permissions(&extracted_binary, perms)?;
504 }
505
506 return Ok(extracted_binary);
507 }
508 }
509 }
510
511 Err(Error::Upgrade(
512 "ant-node binary not found in tar.gz archive".to_string(),
513 ))
514 }
515
516 fn extract_from_zip(archive_path: &Path, dest_dir: &Path) -> Result<PathBuf> {
518 let file = File::open(archive_path)?;
519 let mut archive = zip::ZipArchive::new(file)
520 .map_err(|e| Error::Upgrade(format!("Failed to open zip archive: {e}")))?;
521
522 let binary_name = if cfg!(windows) {
523 "ant-node.exe"
524 } else {
525 "ant-node"
526 };
527 let extracted_binary = dest_dir.join(binary_name);
528
529 for i in 0..archive.len() {
530 let mut entry = archive
531 .by_index(i)
532 .map_err(|e| Error::Upgrade(format!("Failed to read zip entry: {e}")))?;
533
534 let path = match entry.enclosed_name() {
535 Some(p) => p.clone(),
536 None => continue,
537 };
538
539 if let Some(name) = path.file_name() {
540 let name_str = name.to_string_lossy();
541 if name_str == "ant-node" || name_str == "ant-node.exe" {
542 debug!("Found binary in zip archive: {}", path.display());
543
544 let mut out = File::create(&extracted_binary)?;
546 std::io::copy(&mut entry, &mut out)
547 .map_err(|e| Error::Upgrade(format!("Failed to write binary: {e}")))?;
548
549 #[cfg(unix)]
551 {
552 use std::os::unix::fs::PermissionsExt;
553 let mut perms = fs::metadata(&extracted_binary)?.permissions();
554 perms.set_mode(0o755);
555 fs::set_permissions(&extracted_binary, perms)?;
556 }
557
558 return Ok(extracted_binary);
559 }
560 }
561 }
562
563 Err(Error::Upgrade(
564 "ant-node binary not found in zip archive".to_string(),
565 ))
566 }
567
568 fn replace_binary(new_binary: &Path, target: &Path) -> Result<()> {
573 #[cfg(unix)]
574 {
575 if let Ok(meta) = fs::metadata(target) {
577 let perms = meta.permissions();
578 fs::set_permissions(new_binary, perms)?;
579 }
580 fs::rename(new_binary, target)?;
582 }
583
584 #[cfg(windows)]
585 {
586 let _ = target; let delays = [500u64, 1000, 2000];
589 let mut last_err = None;
590 for (attempt, delay_ms) in delays.iter().enumerate() {
591 match self_replace::self_replace(new_binary) {
592 Ok(()) => {
593 last_err = None;
594 break;
595 }
596 Err(e) => {
597 warn!(
598 "self_replace attempt {} failed: {e}, retrying in {delay_ms}ms",
599 attempt + 1
600 );
601 last_err = Some(e);
602 std::thread::sleep(std::time::Duration::from_millis(*delay_ms));
603 }
604 }
605 }
606 if let Some(e) = last_err {
607 return Err(Error::Upgrade(format!(
608 "self_replace failed after retries: {e}"
609 )));
610 }
611 }
612
613 debug!("Binary replacement complete");
614 Ok(())
615 }
616
617 fn prepare_restart(&self, binary_path: &Path) -> Result<i32> {
632 if self.stop_on_upgrade {
633 let exit_code;
634
635 #[cfg(unix)]
636 {
637 info!("Service manager mode: will exit with code 0 after graceful shutdown");
638 exit_code = 0;
639 }
640
641 #[cfg(windows)]
642 {
643 let _ = binary_path;
644 info!(
645 "Service manager mode: will exit with code {} after graceful shutdown",
646 RESTART_EXIT_CODE
647 );
648 exit_code = RESTART_EXIT_CODE;
649 }
650
651 #[cfg(not(any(unix, windows)))]
652 {
653 let _ = binary_path;
654 warn!("Auto-restart not supported on this platform. Please restart manually.");
655 exit_code = 0;
656 }
657
658 Ok(exit_code)
659 } else {
660 let args: Vec<String> = env::args().skip(1).collect();
662
663 info!("Spawning new process: {} {:?}", binary_path.display(), args);
664
665 std::process::Command::new(binary_path)
666 .args(&args)
667 .stdin(std::process::Stdio::null())
668 .stdout(std::process::Stdio::inherit())
669 .stderr(std::process::Stdio::inherit())
670 .spawn()
671 .map_err(|e| Error::Upgrade(format!("Failed to spawn new binary: {e}")))?;
672
673 info!("New process spawned, will exit after graceful shutdown");
674 Ok(0)
675 }
676 }
677}
678
679async fn on_disk_version(binary_path: &Path) -> Option<Version> {
687 let output = tokio::time::timeout(
688 std::time::Duration::from_secs(5),
689 tokio::process::Command::new(binary_path)
690 .arg("--version")
691 .output(),
692 )
693 .await
694 .ok()?
695 .ok()?;
696 let stdout = String::from_utf8_lossy(&output.stdout);
697 let version_str = stdout.trim().strip_prefix("ant-node ")?;
698 Version::parse(version_str).ok()
699}
700
701impl Default for AutoApplyUpgrader {
702 fn default() -> Self {
703 Self::new()
704 }
705}
706
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Writes `test.tar.gz` into `dir` holding one entry named `binary_name`
    /// with `content`, and returns the archive path.
    fn create_tar_gz_archive(dir: &Path, binary_name: &str, content: &[u8]) -> PathBuf {
        use flate2::write::GzEncoder;
        use flate2::Compression;

        let archive_path = dir.join("test.tar.gz");
        let gz = GzEncoder::new(File::create(&archive_path).unwrap(), Compression::default());
        let mut tar_builder = tar::Builder::new(gz);

        let mut header = tar::Header::new_gnu();
        header.set_size(content.len() as u64);
        header.set_mode(0o755);
        header.set_cksum();
        tar_builder
            .append_data(&mut header, binary_name, content)
            .unwrap();
        tar_builder.finish().unwrap();

        archive_path
    }

    /// Writes `test.zip` into `dir` holding one stored (uncompressed) entry
    /// named `binary_name` with `content`, and returns the archive path.
    fn create_zip_archive(dir: &Path, binary_name: &str, content: &[u8]) -> PathBuf {
        use std::io::Write;

        let archive_path = dir.join("test.zip");
        let mut zip_writer = zip::ZipWriter::new(File::create(&archive_path).unwrap());
        let options = zip::write::SimpleFileOptions::default()
            .compression_method(zip::CompressionMethod::Stored);
        zip_writer.start_file(binary_name, options).unwrap();
        zip_writer.write_all(content).unwrap();
        zip_writer.finish().unwrap();

        archive_path
    }

    /// Extracts `archive` into a fresh temp dir and asserts that the
    /// extracted file exists and holds exactly `expected`.
    fn assert_extracts(archive: &Path, expected: &[u8]) {
        let dest = tempfile::tempdir().unwrap();
        let result = AutoApplyUpgrader::extract_binary(archive, dest.path());
        assert!(result.is_ok());

        let extracted = result.unwrap();
        assert!(extracted.exists());
        assert_eq!(fs::read(&extracted).unwrap(), expected);
    }

    /// Extracts `archive` into a fresh temp dir, asserts that extraction
    /// fails, and returns the error text.
    fn extraction_error(archive: &Path) -> String {
        let dest = tempfile::tempdir().unwrap();
        let result = AutoApplyUpgrader::extract_binary(archive, dest.path());
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[test]
    fn test_auto_apply_upgrader_creation() {
        let version_text = AutoApplyUpgrader::new().current_version().to_string();
        assert!(!version_text.is_empty());
    }

    #[test]
    fn test_current_binary_path() {
        let result = AutoApplyUpgrader::current_binary_path();
        assert!(result.is_ok());

        let path = result.unwrap();
        let looks_like_test_binary = path.to_string_lossy().contains("test");
        assert!(path.exists() || looks_like_test_binary);
    }

    #[test]
    fn test_default_impl() {
        let version_text = AutoApplyUpgrader::default().current_version().to_string();
        assert!(!version_text.is_empty());
    }

    #[test]
    fn test_extract_binary_from_tar_gz() {
        let dir = tempfile::tempdir().unwrap();
        let content = b"fake-binary-content";
        let archive = create_tar_gz_archive(dir.path(), "ant-node", content);

        assert_extracts(&archive, content);
    }

    #[test]
    fn test_extract_binary_from_zip() {
        let dir = tempfile::tempdir().unwrap();
        let content = b"fake-binary-content";
        let archive = create_zip_archive(dir.path(), "ant-node", content);

        assert_extracts(&archive, content);
    }

    #[test]
    fn test_extract_binary_from_zip_with_exe() {
        let dir = tempfile::tempdir().unwrap();
        let content = b"fake-windows-binary";
        let archive = create_zip_archive(dir.path(), "ant-node.exe", content);

        assert_extracts(&archive, content);
    }

    #[test]
    fn test_extract_binary_from_tar_gz_nested_path() {
        let dir = tempfile::tempdir().unwrap();
        let content = b"nested-binary";
        let archive = create_tar_gz_archive(dir.path(), "some/nested/path/ant-node", content);

        assert_extracts(&archive, content);
    }

    #[test]
    fn test_extract_binary_unknown_format() {
        let dir = tempfile::tempdir().unwrap();
        let archive_path = dir.path().join("bad_archive");
        fs::write(&archive_path, b"XX not a real archive").unwrap();

        let err = extraction_error(&archive_path);
        assert!(err.contains("Unknown archive format"));
    }

    #[test]
    fn test_extract_binary_missing_binary_in_tar_gz() {
        let dir = tempfile::tempdir().unwrap();
        let archive = create_tar_gz_archive(dir.path(), "other-file", b"not-the-binary");

        let err = extraction_error(&archive);
        assert!(err.contains("not found in tar.gz archive"));
    }

    #[test]
    fn test_extract_binary_missing_binary_in_zip() {
        let dir = tempfile::tempdir().unwrap();
        let archive = create_zip_archive(dir.path(), "other-file", b"not-the-binary");

        let err = extraction_error(&archive);
        assert!(err.contains("not found in zip archive"));
    }

    #[test]
    fn test_extract_binary_empty_file() {
        let dir = tempfile::tempdir().unwrap();
        let archive_path = dir.path().join("empty");
        fs::write(&archive_path, b"").unwrap();

        // Only the failure matters here; the message is not checked.
        let _ = extraction_error(&archive_path);
    }
}