use anyhow::{Context, Result};
use sha2::{Digest, Sha256};
use std::path::Path;
use crate::http;
use crate::ui;
use mvm_runtime::shell::run_host;
/// GitHub `owner/repo` that publishes mvmctl releases. Used to build the releases-API
/// URL, the archive/checksum/bundle download URLs, and the expected cosign signing
/// identity.
const GITHUB_REPO: &str = "auser/mvm";
/// Version string of this build, baked in from Cargo's `CARGO_PKG_VERSION` at compile time.
fn current_version() -> &'static str {
    const BUILD_VERSION: &str = env!("CARGO_PKG_VERSION");
    BUILD_VERSION
}
fn detect_target() -> Result<&'static str> {
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
return Ok("aarch64-apple-darwin");
#[cfg(all(target_arch = "x86_64", target_os = "macos"))]
return Ok("x86_64-apple-darwin");
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
return Ok("x86_64-unknown-linux-gnu");
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
return Ok("aarch64-unknown-linux-gnu");
#[cfg(not(any(
all(target_arch = "aarch64", target_os = "macos"),
all(target_arch = "x86_64", target_os = "macos"),
all(target_arch = "x86_64", target_os = "linux"),
all(target_arch = "aarch64", target_os = "linux"),
)))]
anyhow::bail!(
"Unsupported platform: {} / {}",
std::env::consts::ARCH,
std::env::consts::OS
);
}
/// Query the GitHub releases API for the newest release of [`GITHUB_REPO`] and return
/// its tag name verbatim (e.g. possibly with a leading `v`).
///
/// # Errors
/// Fails if the API request fails or the JSON response has no string `tag_name`.
fn fetch_latest_version() -> Result<String> {
    let api_url = format!("https://api.github.com/repos/{GITHUB_REPO}/releases/latest");
    let release = http::fetch_json(&api_url)
        .context("Failed to query GitHub releases API. Check your network connection.")?;
    release["tag_name"]
        .as_str()
        .map(str::to_owned)
        .context("GitHub API response missing 'tag_name' field")
}
/// Drop a single leading lowercase `v` from a release tag (`"v1.2.3"` -> `"1.2.3"`);
/// any other input is returned unchanged.
fn strip_v_prefix(tag: &str) -> &str {
    match tag.strip_prefix('v') {
        Some(rest) => rest,
        None => tag,
    }
}
/// Parse the SHA-256 digest from one checksum-file line.
///
/// The first whitespace-separated field must be exactly 64 hexadecimal characters
/// (upper- or lowercase). Anything after it (typically the file name) is ignored.
///
/// # Errors
/// Fails if the line is blank, the first field is not 64 bytes long, or it contains
/// a character that is not a hex digit.
fn parse_checksum_line(line: &str) -> Result<[u8; 32]> {
    let hex = line
        .split_whitespace()
        .next()
        .context("Empty checksum line")?;
    if hex.len() != 64 {
        anyhow::bail!("Expected 64 hex chars in checksum, got {}", hex.len());
    }

    // Decode nibble-by-nibble with `to_digit(16)` rather than `u8::from_str_radix`,
    // which accepts a leading '+' (so "+a" would parse as 0x0a) and would let a
    // malformed digest through.
    let mut digest = [0u8; 32];
    for (slot, pair) in digest.iter_mut().zip(hex.as_bytes().chunks_exact(2)) {
        // `chunks_exact(2)` guarantees every `pair` has exactly two bytes.
        let hi = char::from(pair[0]).to_digit(16);
        let lo = char::from(pair[1]).to_digit(16);
        match (hi, lo) {
            // Both nibbles are < 16, so the combined value always fits in a u8.
            (Some(hi), Some(lo)) => *slot = (hi * 16 + lo) as u8,
            _ => anyhow::bail!("Invalid hex byte: {}", String::from_utf8_lossy(pair)),
        }
    }
    Ok(digest)
}
/// Verify `archive_path` against the SHA-256 digest listed for `archive_name` in the
/// release's `checksums-sha256.txt`.
///
/// Each line of the checksum file is read as `<hex digest> <file name>`. The entry whose
/// file name equals `archive_name` is used. A leading `*` (binary-mode marker) and any
/// directory prefix on the listed name are ignored.
///
/// # Errors
/// Fails if the checksum file cannot be downloaded, has no entry for `archive_name`,
/// the entry is malformed, the archive cannot be read, or the digests differ.
fn verify_checksum(version: &str, archive_name: &str, archive_path: &Path) -> Result<()> {
    let checksum_url = format!(
        "https://github.com/{}/releases/download/{}/checksums-sha256.txt",
        GITHUB_REPO, version
    );
    let checksum_text = http::fetch_text(&checksum_url)
        .context("Failed to download checksum file — cannot verify integrity")?;

    // Compare the file-name column exactly instead of doing a substring search: a line
    // for e.g. `<archive_name>.bundle` also *contains* `archive_name`, and if it were
    // listed first its digest would be checked against the archive and fail.
    let expected_digest = checksum_text
        .lines()
        .find(|line| {
            line.split_whitespace()
                .nth(1)
                .map(|name| name.trim_start_matches('*'))
                .and_then(|name| name.rsplit('/').next())
                == Some(archive_name)
        })
        .with_context(|| {
            format!(
                "Checksum for '{}' not found in checksums-sha256.txt",
                archive_name
            )
        })
        .and_then(parse_checksum_line)?;

    let bytes = std::fs::read(archive_path).with_context(|| {
        format!(
            "Failed to read archive for checksum: {}",
            archive_path.display()
        )
    })?;
    let actual_digest: [u8; 32] = Sha256::digest(&bytes).into();
    if actual_digest != expected_digest {
        anyhow::bail!(
            "Checksum mismatch for {}!\n expected: {}\n actual: {}\nThe download may be corrupted or tampered with.",
            archive_name,
            hex_encode(&expected_digest),
            hex_encode(&actual_digest),
        );
    }
    ui::success("Checksum verified.");
    Ok(())
}
/// Render `bytes` as a lowercase hexadecimal string, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(DIGITS[(byte >> 4) as usize] as char);
        out.push(DIGITS[(byte & 0x0f) as usize] as char);
    }
    out
}
/// Download the release archive `mvmctl-<target>.tar.gz` for tag `version` into
/// `tmp_dir`, saving it as `tmp_dir/mvmctl-<target>.tar.gz`.
///
/// # Errors
/// Fails if the download fails, typically because `version` has no archive for `target`.
fn download_release(version: &str, target: &str, tmp_dir: &Path) -> Result<()> {
    let archive_name = format!("mvmctl-{}.tar.gz", target);
    let download_url = format!(
        "https://github.com/{}/releases/download/{}/{}",
        GITHUB_REPO, version, archive_name
    );
    let dest = tmp_dir.join(&archive_name);

    let sp = ui::spinner(&format!("Downloading {}...", download_url));
    let downloaded = http::download_file(&download_url, &dest);
    // Clear the spinner on both paths so it does not linger above an error message.
    sp.finish_and_clear();
    downloaded.with_context(|| {
        format!(
            "Download failed. Check that {} has a release for {}.",
            version, target
        )
    })?;

    ui::success("Download complete.");
    Ok(())
}
/// Report whether the current user can create files in directory `path`.
///
/// Probes by creating a uniquely named `.mvm-write-test-*` temp file there. The probe
/// is a `tempfile::NamedTempFile`, so it is removed again when dropped.
fn is_writable(path: &Path) -> bool {
    let probe = tempfile::Builder::new()
        .prefix(".mvm-write-test-")
        .tempfile_in(path);
    matches!(probe, Ok(_))
}
/// Run `bin --version` and check that it exits successfully and prints something
/// version-like (at least one ASCII digit on stdout).
///
/// # Errors
/// Fails if the binary cannot be spawned, exits unsuccessfully (stderr is included in
/// the message), or its stdout contains no digit.
fn smoke_test_binary(bin: &Path) -> Result<()> {
    let result = std::process::Command::new(bin)
        .arg("--version")
        .output()
        .with_context(|| format!("Failed to execute smoke test for {}", bin.display()))?;

    if result.status.success() {
        let printed = String::from_utf8_lossy(&result.stdout);
        if printed.bytes().any(|b| b.is_ascii_digit()) {
            return Ok(());
        }
        anyhow::bail!(
            "smoke test output does not look like a version: {:?}",
            printed.trim()
        );
    }

    // A signal-terminated process has no exit code; report it as -1.
    let code = result.status.code().unwrap_or(-1);
    let stderr = String::from_utf8_lossy(&result.stderr);
    Err(anyhow::anyhow!(
        "smoke test failed (exit {}): {}",
        code,
        stderr.trim()
    ))
}
/// Extract the downloaded release archive in `tmp_dir` and install it over `current_exe`.
///
/// Steps:
/// 1. Untar `mvmctl-<target>.tar.gz` into `tmp_dir`.
/// 2. Smoke-test the extracted `mvmctl-<target>/mvmctl`.
/// 3. Swap it in for `current_exe`, keeping a `.old` backup until the installed copy
///    passes a second smoke test.
/// 4. Replace the `resources` directory next to `current_exe` if the archive ships one.
///
/// `sudo` is used for the install when the install directory is not writable by the
/// current user.
///
/// # Errors
/// Fails if extraction fails, the archive lacks the expected binary, a path needed on a
/// command line is not valid UTF-8, or the new binary cannot be installed or fails its
/// smoke test. In the last cases the previous binary is restored on a best-effort basis.
fn extract_and_install(target: &str, tmp_dir: &Path, current_exe: &Path) -> Result<()> {
    let archive_name = format!("mvmctl-{}.tar.gz", target);
    let archive_path = tmp_dir.join(&archive_name);
    let output = run_host(
        "tar",
        &[
            "xzf",
            archive_path
                .to_str()
                .context("Archive path is not valid UTF-8")?,
            "-C",
            tmp_dir
                .to_str()
                .context("Temporary directory path is not valid UTF-8")?,
        ],
    )?;
    if !output.status.success() {
        anyhow::bail!("Failed to extract archive");
    }

    let extracted_dir = tmp_dir.join(format!("mvmctl-{}", target));
    let new_binary = extracted_dir.join("mvmctl");
    if !new_binary.exists() {
        anyhow::bail!(
            "Binary not found in archive at expected path: mvmctl-{}/mvmctl",
            target
        );
    }

    ui::info("Verifying new binary...");
    smoke_test_binary(&new_binary).context("New binary failed pre-install smoke test")?;

    let install_dir = current_exe
        .parent()
        .context("Cannot determine install directory")?;
    let needs_sudo = !is_writable(install_dir);
    ui::info(&format!("Installing to {}...", install_dir.display()));
    if needs_sudo {
        ui::warn("Requires elevated permissions.");
    }

    let backup_path = current_exe.with_extension("old");
    if needs_sudo {
        install_binary_with_sudo(&new_binary, current_exe, &backup_path)?;
    } else {
        install_binary_direct(&new_binary, current_exe, &backup_path)?;
    }

    update_resources(&extracted_dir, install_dir, needs_sudo)
}

/// Replace `current_exe` with `new_binary` through `sudo`, keeping `backup_path` until
/// the installed copy passes a smoke test. Rollback on failure is best-effort.
fn install_binary_with_sudo(
    new_binary: &Path,
    current_exe: &Path,
    backup_path: &Path,
) -> Result<()> {
    // Resolve string forms up front so a non-UTF-8 path fails before anything is moved.
    let exe_str = current_exe
        .to_str()
        .context("Executable path is not valid UTF-8")?;
    let backup_str = backup_path
        .to_str()
        .context("Backup path is not valid UTF-8")?;

    run_sudo_mv(current_exe, backup_path)?;
    if let Err(install_err) = run_sudo_cp(new_binary, current_exe) {
        if let Err(rollback_err) = run_sudo_mv(backup_path, current_exe) {
            tracing::warn!("failed to rollback binary during update: {rollback_err}");
        }
        return Err(install_err);
    }

    // Best-effort: a binary left non-executable is caught by the smoke test below.
    match run_host("sudo", &["chmod", "+x", exe_str]) {
        Ok(out) if !out.status.success() => {
            tracing::warn!("sudo chmod +x exited unsuccessfully during update");
        }
        Ok(_) => {}
        Err(e) => {
            tracing::warn!("failed to chmod during update: {e}");
        }
    }

    if let Err(e) = smoke_test_binary(current_exe) {
        if let Err(re) = run_sudo_mv(backup_path, current_exe) {
            tracing::warn!("failed to restore backup after smoke test failure: {re}");
        }
        anyhow::bail!("New binary failed smoke test; restored previous version. ({e})");
    }

    // The update succeeded; a leftover backup is harmless, so only warn.
    match run_host("sudo", &["rm", "-f", backup_str]) {
        Ok(out) if !out.status.success() => {
            tracing::warn!("sudo rm of backup binary exited unsuccessfully during update");
        }
        Ok(_) => {}
        Err(e) => {
            tracing::warn!("failed to rm during update: {e}");
        }
    }
    Ok(())
}

/// Replace `current_exe` with `new_binary` using plain filesystem calls, keeping
/// `backup_path` until the installed copy passes a smoke test.
fn install_binary_direct(new_binary: &Path, current_exe: &Path, backup_path: &Path) -> Result<()> {
    std::fs::rename(current_exe, backup_path).context("Failed to back up current binary")?;

    // Copy and mark executable. If either step fails, put the old binary back rather
    // than leaving an unverified binary in place with a stray backup.
    let installed = std::fs::copy(new_binary, current_exe)
        .context("Failed to install new binary")
        .and_then(|_| set_executable(current_exe));
    if let Err(e) = installed {
        if let Err(re) = std::fs::rename(backup_path, current_exe) {
            tracing::warn!("failed to rollback binary during update: {re}");
        }
        return Err(e);
    }

    if let Err(e) = smoke_test_binary(current_exe) {
        if let Err(re) = std::fs::rename(backup_path, current_exe) {
            tracing::warn!("failed to restore backup after smoke test failure: {re}");
        }
        anyhow::bail!("New binary failed smoke test; restored previous version. ({e})");
    }

    if let Err(e) = std::fs::remove_file(backup_path) {
        tracing::warn!("failed to remove backup file: {e}");
    }
    Ok(())
}

/// Replace `<install_dir>/resources` with `<extracted_dir>/resources`, if the archive
/// contains one. Under `sudo`, a failed copy only produces a warning.
fn update_resources(extracted_dir: &Path, install_dir: &Path, needs_sudo: bool) -> Result<()> {
    let new_resources = extracted_dir.join("resources");
    if !new_resources.exists() {
        return Ok(());
    }
    let dest_resources = install_dir.join("resources");
    ui::info("Updating resources...");

    if needs_sudo {
        let new_str = new_resources
            .to_str()
            .context("New resources path is not valid UTF-8")?;
        let dest_str = dest_resources
            .to_str()
            .context("Destination resources path is not valid UTF-8")?;
        // Remove first so `cp -r` creates `dest` instead of nesting inside an existing one.
        if let Err(e) = run_host("sudo", &["rm", "-rf", dest_str]) {
            tracing::warn!("failed to remove old resources directory: {e}");
        }
        let output = run_host("sudo", &["cp", "-r", new_str, dest_str])?;
        if !output.status.success() {
            ui::warn("Failed to update resources directory");
        }
    } else {
        match std::fs::remove_dir_all(&dest_resources) {
            Ok(()) => {}
            // First install of resources: nothing to remove, so nothing to warn about.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => {
                tracing::warn!("failed to remove old resources: {e}");
            }
        }
        copy_dir_recursive(&new_resources, &dest_resources)
            .context("Failed to update resources directory")?;
    }
    Ok(())
}
fn run_sudo_mv(from: &Path, to: &Path) -> Result<()> {
let output = run_host(
"sudo",
&[
"mv",
from.to_str().expect("source path must be valid UTF-8"),
to.to_str().expect("dest path must be valid UTF-8"),
],
)?;
if !output.status.success() {
anyhow::bail!("sudo mv failed");
}
Ok(())
}
fn run_sudo_cp(from: &Path, to: &Path) -> Result<()> {
let output = run_host(
"sudo",
&[
"cp",
from.to_str().expect("source path must be valid UTF-8"),
to.to_str().expect("dest path must be valid UTF-8"),
],
)?;
if !output.status.success() {
anyhow::bail!("sudo cp failed");
}
Ok(())
}
/// Set the permission bits of `path` to `0o755` (rwxr-xr-x).
///
/// # Errors
/// Propagates the I/O error if the file's metadata cannot be read or its
/// permissions cannot be changed.
#[cfg(unix)]
fn set_executable(path: &Path) -> Result<()> {
    use std::os::unix::fs::PermissionsExt;

    let mut mode = std::fs::metadata(path)?.permissions();
    mode.set_mode(0o755);
    std::fs::set_permissions(path, mode).map_err(Into::into)
}

/// No-op on non-Unix platforms, which have no executable permission bit to set.
#[cfg(not(unix))]
fn set_executable(_path: &Path) -> Result<()> {
    Ok(())
}
/// Recursively copy the directory tree at `src` into `dst`, creating `dst` (and any
/// missing parents) first.
///
/// Entries are classified with `DirEntry::file_type`, which does not follow symlinks.
/// Anything that is not a directory is copied with `std::fs::copy`. Stops at the first
/// error.
fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<()> {
    std::fs::create_dir_all(dst)?;
    std::fs::read_dir(src)?.try_for_each(|item| -> Result<()> {
        let item = item?;
        let target = dst.join(item.file_name());
        if item.file_type()?.is_dir() {
            copy_dir_recursive(&item.path(), &target)
        } else {
            std::fs::copy(item.path(), &target)?;
            Ok(())
        }
    })
}
/// Verify the cosign signature bundle for a downloaded release archive.
///
/// Downloads `<archive_name>.bundle` for release `version` into the archive's
/// directory. It then runs `cosign verify-blob`, requiring a certificate issued through
/// GitHub Actions OIDC for this repository's `release.yml` workflow on a tag ref.
///
/// If `cosign` is not found on `PATH`, verification is skipped with a warning and
/// `Ok(())` is returned.
///
/// # Errors
/// Fails if the bundle cannot be downloaded, `cosign` cannot be run, or `cosign`
/// rejects the signature.
fn verify_signature(version: &str, archive_name: &str, archive_path: &Path) -> Result<()> {
    let cosign = match which::which("cosign") {
        Ok(p) => p,
        Err(_) => {
            tracing::warn!(
                "cosign not found — skipping signature verification. \
                 Install cosign to enable provenance checking."
            );
            return Ok(());
        }
    };

    let bundle_name = format!("{}.bundle", archive_name);
    let bundle_url = format!(
        "https://github.com/{}/releases/download/{}/{}",
        GITHUB_REPO, version, bundle_name
    );
    let bundle_path = archive_path
        .parent()
        .unwrap_or_else(|| std::path::Path::new("."))
        .join(&bundle_name);
    ui::info("Downloading signature bundle...");
    http::download_file(&bundle_url, &bundle_path)
        .context("Failed to download cosign bundle — cannot verify signature")?;

    // Anchored with `^…$` and with literal dots escaped, so the pattern only accepts
    // this repo's release workflow on a tag ref. It cannot match a substring of some
    // other identity.
    let identity_regexp = format!(
        r"^https://github\.com/{repo}/\.github/workflows/release\.yml@refs/tags/.*$",
        repo = GITHUB_REPO
    );

    // Paths are passed as OS strings, so non-UTF-8 paths work instead of panicking.
    let output = std::process::Command::new(&cosign)
        .arg("verify-blob")
        .arg("--bundle")
        .arg(&bundle_path)
        .args([
            "--certificate-oidc-issuer",
            "https://token.actions.githubusercontent.com",
        ])
        .arg("--certificate-identity-regexp")
        .arg(&identity_regexp)
        .arg(archive_path)
        .output()
        .context("Failed to run cosign verify-blob")?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        anyhow::bail!(
            "Signature verification failed — the archive may not have been built \
             by the official release pipeline.\ncosign output: {}",
            stderr.trim()
        );
    }
    ui::success("Signature verified.");
    Ok(())
}
/// Check GitHub for the latest mvmctl release and, unless `check_only`, install it over
/// the running executable.
///
/// * `check_only`: report whether a different release is available, then stop.
/// * `force`: reinstall even when the latest tag matches the running version.
/// * `skip_verify`: skip cosign signature verification. The SHA-256 checksum is
///   verified regardless.
///
/// Versions are compared for string equality only. Any tag that differs from the
/// running version, even an older one, is offered as an update.
///
/// # Errors
/// Propagates failures from the version lookup, platform detection, download,
/// checksum/signature verification, and installation steps.
pub fn update(check_only: bool, force: bool, skip_verify: bool) -> Result<()> {
    let current = current_version();
    ui::info(&format!("Current version: {}", current));

    let sp = ui::spinner("Checking for updates...");
    let latest = fetch_latest_version();
    // Clear the spinner before propagating a lookup error so it does not linger.
    sp.finish_and_clear();
    let latest_tag = latest?;
    let latest_version = strip_v_prefix(&latest_tag);

    if latest_version == current && !force {
        ui::success(&format!("Already up to date ({}).", current));
        return Ok(());
    }
    if latest_version == current {
        ui::info(&format!(
            "Already at {} but --force specified, reinstalling.",
            current
        ));
    } else {
        ui::info(&format!(
            "New version available: {} -> {}",
            current, latest_version
        ));
    }
    if check_only {
        return Ok(());
    }

    let target = detect_target()?;
    ui::info(&format!("Platform: {}", target));

    let current_exe =
        std::env::current_exe().context("Failed to determine path of current executable")?;
    // Resolve symlinks so the real binary file is replaced, not a link pointing at it.
    let current_exe = current_exe.canonicalize().unwrap_or(current_exe);

    // Must stay alive until installation finishes; the directory is removed on drop.
    let tmp_dir = tempfile::tempdir().context("Failed to create temporary directory")?;
    download_release(&latest_tag, target, tmp_dir.path())?;

    let archive_name = format!("mvmctl-{}.tar.gz", target);
    let archive_path = tmp_dir.path().join(&archive_name);
    verify_checksum(&latest_tag, &archive_name, &archive_path)?;
    if !skip_verify {
        verify_signature(&latest_tag, &archive_name, &archive_path)?;
    }

    extract_and_install(target, tmp_dir.path(), &current_exe)?;
    ui::success(&format!("\nSuccessfully updated to {}!", latest_tag));
    ui::info("The binary has been replaced on disk.");
    ui::info("To verify: Open a new shell and run 'mvmctl --version'");
    ui::info("Or run: hash -r (to clear your shell's command cache)");
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};
    use std::io::Write;

    #[cfg(unix)]
    #[test]
    fn test_smoke_test_binary_passes() {
        use std::os::unix::fs::PermissionsExt;
        // Created under /var/tmp rather than the default temp dir, presumably because
        // /tmp may be mounted noexec on some hosts (TODO confirm).
        let dir = tempfile::tempdir_in("/var/tmp").unwrap();
        let path = dir.path().join("mvm-smoke-test.sh");
        {
            // Scoped so the write handle is closed before the script is executed.
            let mut file = std::fs::File::create(&path).unwrap();
            writeln!(file, "#!/bin/sh\necho 'mvmctl 1.0.0'").unwrap();
            file.flush().unwrap();
        }
        let mut perms = std::fs::metadata(&path).unwrap().permissions();
        perms.set_mode(0o755);
        std::fs::set_permissions(&path, perms).unwrap();
        assert!(smoke_test_binary(&path).is_ok());
    }

    #[test]
    fn test_smoke_test_binary_nonexistent_fails() {
        let result = smoke_test_binary(std::path::Path::new("/nonexistent/binary/does-not-exist"));
        assert!(result.is_err());
    }

    // NOTE(review): this only checks a locally built string. It does not exercise the
    // rollback path in `extract_and_install`.
    #[test]
    fn test_smoke_test_binary_rollback_error_message() {
        let err_msg = format!(
            "New binary failed smoke test; restored previous version. ({})",
            "smoke test failed (exit 1): "
        );
        assert!(err_msg.contains("New binary failed smoke test; restored previous version."));
    }

    #[test]
    fn test_verify_signature_skipped_when_cosign_absent() {
        // With cosign installed, verify_signature would attempt a real network download.
        // Only exercise the skip path this test is named for.
        if which::which("cosign").is_ok() {
            return;
        }
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let result = verify_signature("v0.0.0-nonexistent", "mvmctl-test.tar.gz", tmp.path());
        assert!(
            result.is_ok(),
            "expected verification to be skipped without cosign, got {result:?}"
        );
    }

    // NOTE(review): placeholder that asserts nothing. `skip_verify` handling in `update`
    // is not covered by any test.
    #[test]
    fn test_skip_verify_flag_respected() {
        let _ = "skip_verify=true prevents any cosign invocation";
    }

    /// Lowercase hex SHA-256 of `data`, as written in `checksums-sha256.txt`.
    fn sha256_of(data: &[u8]) -> String {
        let digest: [u8; 32] = Sha256::digest(data).into();
        hex_encode(&digest)
    }

    #[test]
    fn test_parse_checksum_line_valid() {
        let hex = "a".repeat(64);
        let line = format!("{} mvmctl-aarch64-apple-darwin.tar.gz", hex);
        let digest = parse_checksum_line(&line).unwrap();
        assert_eq!(digest, [0xaa; 32]);
    }

    #[test]
    fn test_parse_checksum_line_wrong_length() {
        let err = parse_checksum_line("abc file.tar.gz").unwrap_err();
        assert!(err.to_string().contains("64 hex chars"));
    }

    #[test]
    fn test_parse_checksum_line_empty() {
        let err = parse_checksum_line("   ").unwrap_err();
        assert!(err.to_string().contains("Empty checksum line"));
    }

    #[test]
    fn test_hex_encode_round_trips_through_parse() {
        let hex = "0123456789abcdef".repeat(4);
        let digest = parse_checksum_line(&format!("{} some-file.tar.gz", hex)).unwrap();
        assert_eq!(hex_encode(&digest), hex);
    }

    #[test]
    fn test_checksum_correct_digest_passes() {
        let data = b"hello binary";
        let hash = sha256_of(data);
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        tmp.write_all(data).unwrap();
        tmp.flush().unwrap();
        let checksum_line = format!("{} mvmctl-test.tar.gz\n", hash);
        let expected = parse_checksum_line(checksum_line.trim()).unwrap();
        // Hash the bytes that actually landed on disk, as verify_checksum does.
        let on_disk = std::fs::read(tmp.path()).unwrap();
        let actual: [u8; 32] = Sha256::digest(&on_disk).into();
        assert_eq!(expected, actual, "Correct digest should match");
    }

    #[test]
    fn test_checksum_tampered_bytes_fail() {
        let data = b"hello binary";
        let tampered = b"TAMPERED!!!!";
        let hash_of_original = sha256_of(data);
        let checksum_line = format!("{} mvmctl-test.tar.gz", hash_of_original);
        let expected = parse_checksum_line(&checksum_line).unwrap();
        let actual: [u8; 32] = Sha256::digest(tampered).into();
        assert_ne!(
            expected, actual,
            "Tampered bytes should produce different digest"
        );
    }

    #[test]
    fn test_current_version_non_empty() {
        let v = current_version();
        assert!(!v.is_empty());
        assert!(v.contains('.'), "Version should contain dots: {}", v);
    }

    #[test]
    fn test_strip_v_prefix() {
        assert_eq!(strip_v_prefix("v0.1.0"), "0.1.0");
        assert_eq!(strip_v_prefix("0.1.0"), "0.1.0");
        assert_eq!(strip_v_prefix("v1.2.3-beta"), "1.2.3-beta");
    }

    #[test]
    fn test_detect_target_succeeds() {
        let target = detect_target().unwrap();
        let valid_targets = [
            "aarch64-apple-darwin",
            "x86_64-apple-darwin",
            "x86_64-unknown-linux-gnu",
            "aarch64-unknown-linux-gnu",
        ];
        assert!(
            valid_targets.contains(&target),
            "Unexpected target: {}",
            target
        );
    }
}