Skip to main content

mvm_cli/
update.rs

1use anyhow::{Context, Result};
2use sha2::{Digest, Sha256};
3use std::path::Path;
4
5use crate::http;
6use crate::ui;
7use mvm_runtime::shell::run_host;
8
9const GITHUB_REPO: &str = "auser/mvm";
10
11/// Current version compiled into the binary (from Cargo.toml).
12fn current_version() -> &'static str {
13    env!("CARGO_PKG_VERSION")
14}
15
16/// Detect the target triple for the current platform at compile time.
17/// Returns strings matching the release artifact naming from release.yml.
18fn detect_target() -> Result<&'static str> {
19    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
20    return Ok("aarch64-apple-darwin");
21
22    #[cfg(all(target_arch = "x86_64", target_os = "macos"))]
23    return Ok("x86_64-apple-darwin");
24
25    #[cfg(all(target_arch = "x86_64", target_os = "linux"))]
26    return Ok("x86_64-unknown-linux-gnu");
27
28    #[cfg(all(target_arch = "aarch64", target_os = "linux"))]
29    return Ok("aarch64-unknown-linux-gnu");
30
31    #[cfg(not(any(
32        all(target_arch = "aarch64", target_os = "macos"),
33        all(target_arch = "x86_64", target_os = "macos"),
34        all(target_arch = "x86_64", target_os = "linux"),
35        all(target_arch = "aarch64", target_os = "linux"),
36    )))]
37    anyhow::bail!(
38        "Unsupported platform: {} / {}",
39        std::env::consts::ARCH,
40        std::env::consts::OS
41    );
42}
43
44/// Query the GitHub releases API for the latest release tag name.
45fn fetch_latest_version() -> Result<String> {
46    let url = format!(
47        "https://api.github.com/repos/{}/releases/latest",
48        GITHUB_REPO
49    );
50
51    let json = http::fetch_json(&url)
52        .context("Failed to query GitHub releases API. Check your network connection.")?;
53
54    let tag = json["tag_name"]
55        .as_str()
56        .context("GitHub API response missing 'tag_name' field")?;
57
58    Ok(tag.to_string())
59}
60
/// Drop a single leading lowercase `v` from a release tag (`"v1.2.3"` -> `"1.2.3"`).
///
/// Input without that prefix is returned unchanged.
fn strip_v_prefix(tag: &str) -> &str {
    match tag.strip_prefix('v') {
        Some(rest) => rest,
        None => tag,
    }
}
65
66/// Parse a hex-encoded SHA256 digest from a `checksums-sha256.txt` entry.
67///
68/// Each line is: `<64 hex chars>  <filename>`  (two spaces, shasum format).
69/// Returns the raw 32-byte digest.
70fn parse_checksum_line(line: &str) -> Result<[u8; 32]> {
71    let hex = line
72        .split_whitespace()
73        .next()
74        .context("Empty checksum line")?;
75    if hex.len() != 64 {
76        anyhow::bail!("Expected 64 hex chars in checksum, got {}", hex.len());
77    }
78    let mut digest = [0u8; 32];
79    for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
80        let s = std::str::from_utf8(chunk).context("Non-UTF8 in checksum hex")?;
81        digest[i] =
82            u8::from_str_radix(s, 16).with_context(|| format!("Invalid hex byte: {}", s))?;
83    }
84    Ok(digest)
85}
86
87/// Verify the SHA256 digest of a downloaded archive against `checksums-sha256.txt`.
88///
89/// Downloads the combined checksum file, finds the line for `archive_name`,
90/// and confirms it matches the digest of the file at `archive_path`.
91fn verify_checksum(version: &str, archive_name: &str, archive_path: &Path) -> Result<()> {
92    let checksum_url = format!(
93        "https://github.com/{}/releases/download/{}/checksums-sha256.txt",
94        GITHUB_REPO, version
95    );
96
97    let checksum_text = http::fetch_text(&checksum_url)
98        .context("Failed to download checksum file — cannot verify integrity")?;
99
100    // Find the line that corresponds to this archive.
101    let expected_digest = checksum_text
102        .lines()
103        .find(|line| line.contains(archive_name))
104        .with_context(|| {
105            format!(
106                "Checksum for '{}' not found in checksums-sha256.txt",
107                archive_name
108            )
109        })
110        .and_then(parse_checksum_line)?;
111
112    // Compute the SHA256 of the downloaded file.
113    let bytes = std::fs::read(archive_path).with_context(|| {
114        format!(
115            "Failed to read archive for checksum: {}",
116            archive_path.display()
117        )
118    })?;
119    let actual_digest: [u8; 32] = Sha256::digest(&bytes).into();
120
121    if actual_digest != expected_digest {
122        anyhow::bail!(
123            "Checksum mismatch for {}!\n  expected: {}\n  actual:   {}\nThe download may be corrupted or tampered with.",
124            archive_name,
125            hex_encode(&expected_digest),
126            hex_encode(&actual_digest),
127        );
128    }
129
130    ui::success("Checksum verified.");
131    Ok(())
132}
133
/// Render `bytes` as lowercase hexadecimal, two digits per byte, for display.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
138
139/// Download the release archive into the given temp directory.
140fn download_release(version: &str, target: &str, tmp_dir: &Path) -> Result<()> {
141    let archive_name = format!("mvmctl-{}.tar.gz", target);
142    let download_url = format!(
143        "https://github.com/{}/releases/download/{}/{}",
144        GITHUB_REPO, version, archive_name
145    );
146    let dest = tmp_dir.join(&archive_name);
147
148    let sp = ui::spinner(&format!("Downloading {}...", download_url));
149
150    http::download_file(&download_url, &dest).with_context(|| {
151        format!(
152            "Download failed. Check that {} has a release for {}.",
153            version, target
154        )
155    })?;
156
157    sp.finish_and_clear();
158    ui::success("Download complete.");
159    Ok(())
160}
161
162/// Check if a directory is writable by the current user.
163fn is_writable(path: &Path) -> bool {
164    tempfile::Builder::new()
165        .prefix(".mvm-write-test-")
166        .tempfile_in(path)
167        .is_ok()
168}
169
170/// Verify that a binary responds to `--version`, exits 0, and prints version-like output.
171///
172/// Called before and after swapping the binary to prevent a defective release from
173/// bricking an installation.
174fn smoke_test_binary(bin: &Path) -> Result<()> {
175    let output = std::process::Command::new(bin)
176        .arg("--version")
177        .output()
178        .with_context(|| format!("Failed to execute smoke test for {}", bin.display()))?;
179
180    if !output.status.success() {
181        let stderr = String::from_utf8_lossy(&output.stderr);
182        anyhow::bail!(
183            "smoke test failed (exit {}): {}",
184            output.status.code().unwrap_or(-1),
185            stderr.trim()
186        );
187    }
188
189    let stdout = String::from_utf8_lossy(&output.stdout);
190    if !stdout.chars().any(|c| c.is_ascii_digit()) {
191        anyhow::bail!(
192            "smoke test output does not look like a version: {:?}",
193            stdout.trim()
194        );
195    }
196
197    Ok(())
198}
199
200/// Extract the archive and install the binary + resources, replacing the current installation.
201fn extract_and_install(target: &str, tmp_dir: &Path, current_exe: &Path) -> Result<()> {
202    let archive_name = format!("mvmctl-{}.tar.gz", target);
203    let archive_path = tmp_dir.join(&archive_name);
204
205    let output = run_host(
206        "tar",
207        &[
208            "xzf",
209            archive_path
210                .to_str()
211                .expect("archive path must be valid UTF-8"),
212            "-C",
213            tmp_dir.to_str().expect("tmp dir path must be valid UTF-8"),
214        ],
215    )?;
216
217    if !output.status.success() {
218        anyhow::bail!("Failed to extract archive");
219    }
220
221    let extracted_dir = tmp_dir.join(format!("mvmctl-{}", target));
222    let new_binary = extracted_dir.join("mvmctl");
223    if !new_binary.exists() {
224        anyhow::bail!(
225            "Binary not found in archive at expected path: mvmctl-{}/mvmctl",
226            target
227        );
228    }
229
230    // Pre-swap smoke test: verify the new binary works before touching the current installation.
231    ui::info("Verifying new binary...");
232    smoke_test_binary(&new_binary).context("New binary failed pre-install smoke test")?;
233
234    let install_dir = current_exe
235        .parent()
236        .context("Cannot determine install directory")?;
237
238    let needs_sudo = !is_writable(install_dir);
239
240    ui::info(&format!("Installing to {}...", install_dir.display()));
241    if needs_sudo {
242        ui::warn("Requires elevated permissions.");
243    }
244
245    // --- Replace binary ---
246    let backup_path = current_exe.with_extension("old");
247
248    if needs_sudo {
249        run_sudo_mv(current_exe, &backup_path)?;
250        if let Err(e) = run_sudo_cp(&new_binary, current_exe) {
251            if let Err(e) = run_sudo_mv(&backup_path, current_exe) {
252                tracing::warn!("failed to rollback binary during update: {e}");
253            }
254            return Err(e);
255        }
256        if let Err(e) = run_host(
257            "sudo",
258            &[
259                "chmod",
260                "+x",
261                current_exe.to_str().expect("exe path must be valid UTF-8"),
262            ],
263        ) {
264            tracing::warn!("failed to chmod during update: {e}");
265        }
266        // Post-swap smoke test: verify installed binary before removing the backup.
267        if let Err(e) = smoke_test_binary(current_exe) {
268            if let Err(re) = run_sudo_mv(&backup_path, current_exe) {
269                tracing::warn!("failed to restore backup after smoke test failure: {re}");
270            }
271            anyhow::bail!("New binary failed smoke test; restored previous version. ({e})");
272        }
273        if let Err(e) = run_host(
274            "sudo",
275            &[
276                "rm",
277                "-f",
278                backup_path
279                    .to_str()
280                    .expect("backup path must be valid UTF-8"),
281            ],
282        ) {
283            tracing::warn!("failed to rm during update: {e}");
284        }
285    } else {
286        std::fs::rename(current_exe, &backup_path).context("Failed to back up current binary")?;
287        if let Err(e) = std::fs::copy(&new_binary, current_exe) {
288            if let Err(e) = std::fs::rename(&backup_path, current_exe) {
289                tracing::warn!("failed to rollback binary during update: {e}");
290            }
291            return Err(anyhow::anyhow!(e).context("Failed to install new binary"));
292        }
293        set_executable(current_exe)?;
294        // Post-swap smoke test: verify installed binary before removing the backup.
295        if let Err(e) = smoke_test_binary(current_exe) {
296            if let Err(re) = std::fs::rename(&backup_path, current_exe) {
297                tracing::warn!("failed to restore backup after smoke test failure: {re}");
298            }
299            anyhow::bail!("New binary failed smoke test; restored previous version. ({e})");
300        }
301        if let Err(e) = std::fs::remove_file(&backup_path) {
302            tracing::warn!("failed to remove backup file: {e}");
303        }
304    }
305
306    // --- Replace resources ---
307    let new_resources = extracted_dir.join("resources");
308    if new_resources.exists() {
309        let dest_resources = install_dir.join("resources");
310        ui::info("Updating resources...");
311
312        if needs_sudo {
313            if let Err(e) = run_host(
314                "sudo",
315                &[
316                    "rm",
317                    "-rf",
318                    dest_resources
319                        .to_str()
320                        .expect("resources path must be valid UTF-8"),
321                ],
322            ) {
323                tracing::warn!("failed to remove old resources directory: {e}");
324            }
325            let output = run_host(
326                "sudo",
327                &[
328                    "cp",
329                    "-r",
330                    new_resources
331                        .to_str()
332                        .expect("new resources path must be valid UTF-8"),
333                    dest_resources
334                        .to_str()
335                        .expect("dest resources path must be valid UTF-8"),
336                ],
337            )?;
338            if !output.status.success() {
339                ui::warn("Failed to update resources directory");
340            }
341        } else {
342            if let Err(e) = std::fs::remove_dir_all(&dest_resources) {
343                tracing::warn!("failed to remove old resources: {e}");
344            }
345            copy_dir_recursive(&new_resources, &dest_resources)
346                .context("Failed to update resources directory")?;
347        }
348    }
349
350    Ok(())
351}
352
353fn run_sudo_mv(from: &Path, to: &Path) -> Result<()> {
354    let output = run_host(
355        "sudo",
356        &[
357            "mv",
358            from.to_str().expect("source path must be valid UTF-8"),
359            to.to_str().expect("dest path must be valid UTF-8"),
360        ],
361    )?;
362    if !output.status.success() {
363        anyhow::bail!("sudo mv failed");
364    }
365    Ok(())
366}
367
368fn run_sudo_cp(from: &Path, to: &Path) -> Result<()> {
369    let output = run_host(
370        "sudo",
371        &[
372            "cp",
373            from.to_str().expect("source path must be valid UTF-8"),
374            to.to_str().expect("dest path must be valid UTF-8"),
375        ],
376    )?;
377    if !output.status.success() {
378        anyhow::bail!("sudo cp failed");
379    }
380    Ok(())
381}
382
383#[cfg(unix)]
384fn set_executable(path: &Path) -> Result<()> {
385    use std::os::unix::fs::PermissionsExt;
386    let mut perms = std::fs::metadata(path)?.permissions();
387    perms.set_mode(0o755);
388    std::fs::set_permissions(path, perms)?;
389    Ok(())
390}
391
392#[cfg(not(unix))]
393fn set_executable(_path: &Path) -> Result<()> {
394    Ok(())
395}
396
397/// Recursively copy a directory.
398fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<()> {
399    std::fs::create_dir_all(dst)?;
400    for entry in std::fs::read_dir(src)? {
401        let entry = entry?;
402        let ty = entry.file_type()?;
403        let dest_path = dst.join(entry.file_name());
404        if ty.is_dir() {
405            copy_dir_recursive(&entry.path(), &dest_path)?;
406        } else {
407            std::fs::copy(entry.path(), &dest_path)?;
408        }
409    }
410    Ok(())
411}
412
413/// Verify the cosign signature of a release archive bundle if cosign is available.
414///
415/// Downloads `<archive_name>.bundle` from the release and runs `cosign verify-blob`.
416/// Non-fatal if cosign is not installed — checksum verification still runs.
417fn verify_signature(version: &str, archive_name: &str, archive_path: &Path) -> Result<()> {
418    let cosign = match which::which("cosign") {
419        Ok(p) => p,
420        Err(_) => {
421            tracing::warn!(
422                "cosign not found — skipping signature verification. \
423                 Install cosign to enable provenance checking."
424            );
425            return Ok(());
426        }
427    };
428
429    let bundle_name = format!("{}.bundle", archive_name);
430    let bundle_url = format!(
431        "https://github.com/{}/releases/download/{}/{}",
432        GITHUB_REPO, version, bundle_name
433    );
434    let bundle_path = archive_path
435        .parent()
436        .unwrap_or_else(|| std::path::Path::new("."))
437        .join(&bundle_name);
438
439    ui::info("Downloading signature bundle...");
440    http::download_file(&bundle_url, &bundle_path)
441        .context("Failed to download cosign bundle — cannot verify signature")?;
442
443    let output = std::process::Command::new(&cosign)
444        .args([
445            "verify-blob",
446            "--bundle",
447            bundle_path
448                .to_str()
449                .expect("bundle path must be valid UTF-8"),
450            "--certificate-oidc-issuer",
451            "https://token.actions.githubusercontent.com",
452            "--certificate-identity-regexp",
453            &format!(
454                "https://github.com/{repo}/.github/workflows/release.yml@refs/tags/.*",
455                repo = GITHUB_REPO
456            ),
457            archive_path
458                .to_str()
459                .expect("archive path must be valid UTF-8"),
460        ])
461        .output()
462        .context("Failed to run cosign verify-blob")?;
463
464    if !output.status.success() {
465        let stderr = String::from_utf8_lossy(&output.stderr);
466        anyhow::bail!(
467            "Signature verification failed — the archive may not have been built \
468             by the official release pipeline.\ncosign output: {}",
469            stderr.trim()
470        );
471    }
472
473    ui::success("Signature verified.");
474    Ok(())
475}
476
477/// Main entry point: check for updates and optionally install.
478pub fn update(check_only: bool, force: bool, skip_verify: bool) -> Result<()> {
479    let current = current_version();
480    ui::info(&format!("Current version: {}", current));
481
482    let sp = ui::spinner("Checking for updates...");
483    let latest_tag = fetch_latest_version()?;
484    let latest_version = strip_v_prefix(&latest_tag);
485    sp.finish_and_clear();
486
487    if latest_version == current && !force {
488        ui::success(&format!("Already up to date ({}).", current));
489        return Ok(());
490    }
491
492    if latest_version == current {
493        ui::info(&format!(
494            "Already at {} but --force specified, reinstalling.",
495            current
496        ));
497    } else {
498        ui::info(&format!(
499            "New version available: {} -> {}",
500            current, latest_version
501        ));
502    }
503
504    if check_only {
505        return Ok(());
506    }
507
508    let target = detect_target()?;
509    ui::info(&format!("Platform: {}", target));
510
511    let current_exe =
512        std::env::current_exe().context("Failed to determine path of current executable")?;
513    let current_exe = current_exe.canonicalize().unwrap_or(current_exe);
514
515    let tmp_dir = tempfile::tempdir().context("Failed to create temporary directory")?;
516
517    download_release(&latest_tag, target, tmp_dir.path())?;
518    let archive_name = format!("mvmctl-{}.tar.gz", target);
519    let archive_path = tmp_dir.path().join(&archive_name);
520    verify_checksum(&latest_tag, &archive_name, &archive_path)?;
521    if !skip_verify {
522        verify_signature(&latest_tag, &archive_name, &archive_path)?;
523    }
524    extract_and_install(target, tmp_dir.path(), &current_exe)?;
525
526    ui::success(&format!("\nSuccessfully updated to {}!", latest_tag));
527    ui::info("The binary has been replaced on disk.");
528    ui::info("To verify: Open a new shell and run 'mvmctl --version'");
529    ui::info("Or run: hash -r  (to clear your shell's command cache)");
530
531    Ok(())
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};
    use std::io::Write;

    /// Write an executable `/bin/sh` script with `body` as `dir/name` and return its path.
    #[cfg(unix)]
    fn write_script(dir: &std::path::Path, name: &str, body: &str) -> std::path::PathBuf {
        use std::os::unix::fs::PermissionsExt;

        let path = dir.join(name);
        {
            // Scoped so the file is closed before it is executed.
            let mut file = std::fs::File::create(&path).unwrap();
            writeln!(file, "#!/bin/sh\n{}", body).unwrap();
            file.flush().unwrap();
        }
        let mut perms = std::fs::metadata(&path).unwrap().permissions();
        perms.set_mode(0o755);
        std::fs::set_permissions(&path, perms).unwrap();
        path
    }

    /// Scratch directory for executable test scripts, created under /var/tmp.
    ///
    /// NOTE(review): presumably /var/tmp is used because /tmp may be mounted `noexec` on some
    /// hosts — confirm.
    #[cfg(unix)]
    fn script_dir() -> tempfile::TempDir {
        tempfile::tempdir_in("/var/tmp").unwrap()
    }

    // --- Phase 2: smoke test ---

    #[cfg(unix)]
    #[test]
    fn test_smoke_test_binary_passes() {
        // A tiny shell script that prints a version-like string and exits 0.
        let dir = script_dir();
        let path = write_script(dir.path(), "mvm-smoke-test.sh", "echo 'mvmctl 1.0.0'");
        assert!(smoke_test_binary(&path).is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_smoke_test_binary_nonzero_exit_fails() {
        let dir = script_dir();
        let path = write_script(dir.path(), "mvm-smoke-fail.sh", "echo boom >&2\nexit 3");
        let msg = smoke_test_binary(&path).unwrap_err().to_string();
        assert!(msg.contains("exit 3"), "exit code should be reported: {msg}");
        assert!(msg.contains("boom"), "stderr should be reported: {msg}");
    }

    #[cfg(unix)]
    #[test]
    fn test_smoke_test_binary_non_version_output_fails() {
        let dir = script_dir();
        let path = write_script(dir.path(), "mvm-smoke-noversion.sh", "echo hello");
        let msg = smoke_test_binary(&path).unwrap_err().to_string();
        assert!(
            msg.contains("does not look like a version"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_smoke_test_binary_nonexistent_fails() {
        let result = smoke_test_binary(std::path::Path::new("/nonexistent/binary/does-not-exist"));
        assert!(result.is_err());
    }

    // --- Phase 3: signature verification ---

    #[test]
    fn test_verify_signature_skipped_when_cosign_absent() {
        // If cosign is not installed, verify_signature returns Ok (non-fatal).
        // We can't control whether cosign is installed, so we accept either Ok (no cosign)
        // or an Err with a cosign/bundle/download-related message (cosign present but the
        // bundle download for a nonsense version fails).
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let result = verify_signature("v0.0.0-nonexistent", "mvmctl-test.tar.gz", tmp.path());
        match result {
            Ok(()) => {} // cosign not installed → warning + Ok
            Err(e) => {
                let msg = e.to_string();
                assert!(
                    msg.contains("cosign") || msg.contains("bundle") || msg.contains("download"),
                    "unexpected error: {msg}"
                );
            }
        }
    }

    #[test]
    fn test_skip_verify_flag_respected() {
        // Documents intended semantics only; update() needs the network, so it is not run here.
        // update() calls verify_signature only inside `if !skip_verify`, while
        // verify_checksum runs unconditionally.
        let _ = "skip_verify=true prevents any cosign invocation";
    }

    // --- Phase 2: checksum verification ---

    fn sha256_of(data: &[u8]) -> String {
        let digest: [u8; 32] = Sha256::digest(data).into();
        hex_encode(&digest)
    }

    #[test]
    fn test_sha256_hex_known_vector() {
        // FIPS 180-2 test vector: SHA-256("abc").
        assert_eq!(
            sha256_of(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_parse_checksum_line_valid() {
        let hex = "a".repeat(64);
        let line = format!("{}  mvmctl-aarch64-apple-darwin.tar.gz", hex);
        let digest = parse_checksum_line(&line).unwrap();
        assert_eq!(digest, [0xaa; 32]);
    }

    #[test]
    fn test_parse_checksum_line_uppercase() {
        let line = format!("{}  mvmctl-x86_64-unknown-linux-gnu.tar.gz", "AB".repeat(32));
        assert_eq!(parse_checksum_line(&line).unwrap(), [0xab; 32]);
    }

    #[test]
    fn test_parse_checksum_line_wrong_length() {
        let err = parse_checksum_line("abc  file.tar.gz").unwrap_err();
        assert!(err.to_string().contains("64 hex chars"));
    }

    #[test]
    fn test_parse_checksum_line_empty() {
        let err = parse_checksum_line("   ").unwrap_err();
        assert!(err.to_string().contains("Empty checksum line"));
    }

    #[test]
    fn test_checksum_correct_digest_passes() {
        let data = b"hello binary";
        let hash = sha256_of(data);

        // Write the "archive" to a temp file
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        tmp.write_all(data).unwrap();
        tmp.flush().unwrap();

        // Build a checksums-sha256.txt line that matches
        let checksum_line = format!("{}  mvmctl-test.tar.gz\n", hash);

        // parse_checksum_line + manual comparison (verify_checksum needs HTTP)
        let expected = parse_checksum_line(checksum_line.trim()).unwrap();
        let actual: [u8; 32] = Sha256::digest(data).into();
        assert_eq!(expected, actual, "Correct digest should match");
    }

    #[test]
    fn test_checksum_tampered_bytes_fail() {
        let data = b"hello binary";
        let tampered = b"TAMPERED!!!!";
        let hash_of_original = sha256_of(data);
        let checksum_line = format!("{}  mvmctl-test.tar.gz", hash_of_original);

        let expected = parse_checksum_line(&checksum_line).unwrap();
        let actual: [u8; 32] = Sha256::digest(tampered).into();
        assert_ne!(
            expected, actual,
            "Tampered bytes should produce different digest"
        );
    }

    // --- Filesystem helpers ---

    #[test]
    fn test_is_writable_on_fresh_tempdir() {
        let dir = tempfile::tempdir().unwrap();
        assert!(is_writable(dir.path()));
    }

    #[test]
    fn test_copy_dir_recursive_copies_nested_tree() {
        let src = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(src.path().join("nested/deeper")).unwrap();
        std::fs::write(src.path().join("top.txt"), b"top").unwrap();
        std::fs::write(src.path().join("nested/deeper/leaf.txt"), b"leaf").unwrap();

        let dst_root = tempfile::tempdir().unwrap();
        let dst = dst_root.path().join("copy");
        copy_dir_recursive(src.path(), &dst).unwrap();

        assert_eq!(std::fs::read(dst.join("top.txt")).unwrap(), b"top");
        assert_eq!(
            std::fs::read(dst.join("nested/deeper/leaf.txt")).unwrap(),
            b"leaf"
        );
    }

    // --- Existing tests ---

    #[test]
    fn test_current_version_non_empty() {
        let v = current_version();
        assert!(!v.is_empty());
        assert!(v.contains('.'), "Version should contain dots: {}", v);
    }

    #[test]
    fn test_strip_v_prefix() {
        assert_eq!(strip_v_prefix("v0.1.0"), "0.1.0");
        assert_eq!(strip_v_prefix("0.1.0"), "0.1.0");
        assert_eq!(strip_v_prefix("v1.2.3-beta"), "1.2.3-beta");
    }

    #[test]
    fn test_detect_target_succeeds() {
        let target = detect_target().unwrap();
        let valid_targets = [
            "aarch64-apple-darwin",
            "x86_64-apple-darwin",
            "x86_64-unknown-linux-gnu",
            "aarch64-unknown-linux-gnu",
        ];
        assert!(
            valid_targets.contains(&target),
            "Unexpected target: {}",
            target
        );
    }
}