Skip to main content

romm_cli/update/
mod.rs

1use anyhow::{anyhow, bail, Context, Result};
2use self_update::cargo_crate_version;
3use self_update::update::ReleaseUpdate;
4use self_update::Extract;
5use serde::Deserialize;
6use sha2::{Digest, Sha256};
7use std::collections::HashMap;
8use std::env::consts::EXE_SUFFIX;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use tokio::io::AsyncWriteExt;
12
13use crate::core::interrupt::{cancelled_error, InterruptContext};
14
/// GitHub account that publishes the releases.
const REPO_OWNER: &str = "patricksmill";
/// GitHub repository holding the releases.
const REPO_NAME: &str = "romm-cli";
/// Link to the changelog on the repository's `main` branch.
const CHANGELOG_URL: &str = "https://github.com/patricksmill/romm-cli/blob/main/CHANGELOG.md";

/// Binary name assumed when the running executable's name cannot be determined; also passed
/// to self_update as the release-asset identifier.
const DEFAULT_BIN_NAME: &str = "romm-cli";
/// Executable stems (without the platform suffix) installed from a release archive.
const SHIPPED_BINARIES: &[&str] = &["romm-cli", "romm-tui"];
/// Release asset listing SHA-256 digests of the archives.
const CHECKSUMS_ASSET_NAME: &str = "checksums.txt";
21
/// Result of [`check_for_update`]: the running version compared with GitHub's latest release.
#[derive(Clone, Debug)]
pub struct UpdateStatus {
    /// Version of the running binary (the crate's package version).
    pub current_version: String,
    /// Latest release version with the leading `v` removed.
    pub latest_version: String,
    /// Tag name of the latest release exactly as GitHub reported it.
    pub release_tag: String,
    /// Whether `latest_version` is strictly newer than `current_version`.
    pub should_update: bool,
    /// Web page of the latest release.
    pub release_url: String,
    /// URL of the project changelog.
    pub changelog_url: String,
}
31
/// Settings for [`apply_update`].
#[derive(Clone, Debug)]
pub struct ApplyUpdateOptions {
    /// Draw a progress bar while the release archive downloads.
    pub show_progress: bool,
    /// Forwarded to self_update's `show_output`.
    pub show_output: bool,
    /// Forwarded to self_update's `no_confirm`.
    pub no_confirm: bool,
    /// Install this release tag; `None` installs the latest release only if it is newer.
    pub target_version_tag: Option<String>,
}

impl Default for ApplyUpdateOptions {
    /// Quiet defaults: no progress bar, no self_update output, `no_confirm` set, latest release.
    fn default() -> Self {
        Self {
            target_version_tag: None,
            no_confirm: true,
            show_output: false,
            show_progress: false,
        }
    }
}
50
/// What [`apply_update`] did; each variant carries a version string.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ApplyUpdateOutcome {
    /// A release was installed; holds the installed version.
    Updated(String),
    /// Nothing newer was available; holds the running version.
    UpToDate(String),
}
56
57#[derive(Debug, Deserialize)]
58struct GithubLatestRelease {
59    tag_name: String,
60    html_url: String,
61}
62
/// A release chosen for installation, with the download URLs this platform needs.
#[derive(Clone, Debug)]
struct ResolvedRelease {
    /// Version string reported by the self_update backend.
    version: String,
    /// File name of the platform archive (the key looked up in `checksums.txt`).
    archive_name: String,
    /// Download URL of the platform archive.
    archive_download_url: String,
    /// Download URL of the release's `checksums.txt`.
    checksums_download_url: String,
}
70
/// Returns the GitHub REST API base URL used for release lookups.
///
/// Reads `ROMM_GITHUB_API_BASE` and falls back to `https://api.github.com` when the variable
/// is unset, not valid Unicode, or blank. Surrounding whitespace and trailing `/` are removed,
/// because callers append `/repos/...` directly and a trailing slash would produce `//repos`.
pub fn github_api_base_url() -> String {
    std::env::var("ROMM_GITHUB_API_BASE")
        .ok()
        .map(|raw| raw.trim().trim_end_matches('/').to_string())
        .filter(|base| !base.is_empty())
        .unwrap_or_else(|| "https://api.github.com".to_string())
}
74
75fn github_latest_release_api_url() -> String {
76    std::env::var("ROMM_GITHUB_LATEST_RELEASE_API").unwrap_or_else(|_| {
77        format!(
78            "{}/repos/{}/{}/releases/latest",
79            github_api_base_url(),
80            REPO_OWNER,
81            REPO_NAME
82        )
83    })
84}
85
86pub fn github_release_asset_key() -> Result<&'static str> {
87    match (std::env::consts::OS, std::env::consts::ARCH) {
88        ("macos", "x86_64") => Ok("macos-x86_64"),
89        ("macos", "aarch64") => Ok("macos-aarch64"),
90        ("linux", "x86_64") => Ok("linux-x86_64"),
91        ("linux", "aarch64") => Ok("linux-aarch64"),
92        ("windows", "x86_64") => Ok("windows-x86_64"),
93        (os, arch) => Err(anyhow!("unsupported platform for self-update: {os}-{arch}")),
94    }
95}
96
/// Strips surrounding whitespace and any leading `v` characters from a version or tag.
fn normalize_version_tag(version: &str) -> &str {
    let trimmed = version.trim();
    trimmed.trim_start_matches(|c: char| c == 'v')
}
100
101fn is_latest_newer(latest: &str, current: &str) -> bool {
102    self_update::version::bump_is_greater(
103        normalize_version_tag(current),
104        normalize_version_tag(latest),
105    )
106    .unwrap_or(false)
107}
108
109pub fn changelog_url() -> &'static str {
110    CHANGELOG_URL
111}
112
113pub fn open_url_in_browser(url: &str) -> Result<()> {
114    #[cfg(target_os = "windows")]
115    {
116        Command::new("cmd")
117            .args(["/C", "start", "", url])
118            .spawn()
119            .context("failed to launch browser via start")?;
120        return Ok(());
121    }
122
123    #[cfg(target_os = "macos")]
124    {
125        Command::new("open")
126            .arg(url)
127            .spawn()
128            .context("failed to launch browser via open")?;
129        return Ok(());
130    }
131
132    #[cfg(all(unix, not(target_os = "macos")))]
133    {
134        Command::new("xdg-open")
135            .arg(url)
136            .spawn()
137            .context("failed to launch browser via xdg-open")?;
138        return Ok(());
139    }
140
141    #[allow(unreachable_code)]
142    Err(anyhow!("unsupported OS for opening browser"))
143}
144
145pub fn open_changelog_in_browser() -> Result<()> {
146    open_url_in_browser(changelog_url())
147}
148
/// Extracts the executable name from `path`, dropping a trailing `.exe` extension.
///
/// Both `/` and `\` count as separators on every host, so Windows-style paths parse anywhere.
/// The `.exe` suffix is matched ASCII case-insensitively (`.exe`, `.EXE`, `.Exe`, ...) because
/// Windows file names are case-insensitive. Returns `None` when the resulting name is empty
/// (for example when the path ends with a separator).
fn binary_name_from_path(path: &Path) -> Option<String> {
    const EXE_EXTENSION: &str = ".exe";

    let raw = path.as_os_str().to_string_lossy();
    // `rsplit` always yields at least one item: the text after the last separator.
    let file_name = raw.rsplit(['/', '\\']).next()?;
    let stem = match file_name.len().checked_sub(EXE_EXTENSION.len()) {
        // `get` is `None` unless `cut` is a char boundary, which makes the slice below safe.
        Some(cut)
            if matches!(
                file_name.get(cut..),
                Some(ext) if ext.eq_ignore_ascii_case(EXE_EXTENSION)
            ) =>
        {
            &file_name[..cut]
        }
        _ => file_name,
    };
    (!stem.is_empty()).then(|| stem.to_string())
}
161
162fn current_binary_name() -> String {
163    std::env::current_exe()
164        .ok()
165        .and_then(|path| binary_name_from_path(&path))
166        .unwrap_or_else(|| DEFAULT_BIN_NAME.to_string())
167}
168
/// Appends the platform executable suffix (`std::env::consts::EXE_SUFFIX`, e.g. `.exe` on
/// Windows) to `stem`.
fn shipped_binary_file_name(stem: &str) -> String {
    let mut file_name = String::with_capacity(stem.len() + EXE_SUFFIX.len());
    file_name.push_str(stem);
    file_name.push_str(EXE_SUFFIX);
    file_name
}
172
173fn build_release_updater(options: &ApplyUpdateOptions) -> Result<Box<dyn ReleaseUpdate>> {
174    let target = github_release_asset_key()?;
175    let bin_name = current_binary_name();
176    let mut builder = self_update::backends::github::Update::configure();
177    builder
178        .repo_owner(REPO_OWNER)
179        .repo_name(REPO_NAME)
180        .bin_name(&bin_name)
181        .target(target)
182        .identifier(DEFAULT_BIN_NAME)
183        .current_version(cargo_crate_version!())
184        .with_url(&github_api_base_url())
185        .show_download_progress(false)
186        .show_output(options.show_output)
187        .no_confirm(options.no_confirm);
188
189    if let Some(ref tag) = options.target_version_tag {
190        builder.target_version_tag(tag);
191    }
192
193    builder
194        .build()
195        .map_err(|e| anyhow!("build self_update config: {e}"))
196}
197
198fn resolve_release(options: &ApplyUpdateOptions) -> Result<Option<ResolvedRelease>> {
199    let current_version = cargo_crate_version!().to_string();
200    let target = github_release_asset_key()?;
201    let updater = build_release_updater(options)?;
202
203    let release = if let Some(ref tag) = options.target_version_tag {
204        updater.get_release_version(tag)?
205    } else {
206        let latest = updater.get_latest_release()?;
207        if !is_latest_newer(&latest.version, &current_version) {
208            return Ok(None);
209        }
210        latest
211    };
212
213    let archive = release
214        .asset_for(target, Some(DEFAULT_BIN_NAME))
215        .ok_or_else(|| anyhow!("no release asset found for target `{target}`"))?;
216
217    let checksums_download_url = release
218        .assets
219        .iter()
220        .find(|asset| asset.name == CHECKSUMS_ASSET_NAME)
221        .ok_or_else(|| anyhow!("release is missing `{CHECKSUMS_ASSET_NAME}` asset"))?
222        .download_url
223        .clone();
224
225    Ok(Some(ResolvedRelease {
226        version: release.version,
227        archive_name: archive.name,
228        archive_download_url: archive.download_url,
229        checksums_download_url,
230    }))
231}
232
/// Parses a `sha256sum`-style manifest into a map from file name to lowercase hex digest.
///
/// Each non-blank line is `<digest><whitespace><name>`. Both coreutils output modes are
/// accepted:
/// - text mode: `<digest>  <name>` (two spaces)
/// - binary mode: `<digest> *<name>` (one space, then `*`)
///
/// Lines with no separator are skipped. If a name appears more than once, the last entry wins.
fn parse_checksums(content: &str) -> HashMap<String, String> {
    content
        .lines()
        .map(str::trim)
        // Splitting at the first whitespace character handles a single space too. The old
        // two-space split missed binary-mode lines entirely.
        .filter_map(|line| line.split_once(char::is_whitespace))
        .map(|(hash, rest)| {
            let name = rest.trim_start().trim_start_matches('*').trim();
            (name.to_string(), hash.to_lowercase())
        })
        .collect()
}
248
249fn sha256_hex_file(path: &Path) -> Result<String> {
250    use std::io::Read;
251    let mut file = std::fs::File::open(path).with_context(|| format!("open {}", path.display()))?;
252    let mut hasher = Sha256::new();
253    let mut buffer = [0u8; 8192];
254    loop {
255        let read = file.read(&mut buffer).context("read file for sha256")?;
256        if read == 0 {
257            break;
258        }
259        hasher.update(&buffer[..read]);
260    }
261    Ok(hasher
262        .finalize()
263        .iter()
264        .map(|byte| format!("{byte:02x}"))
265        .collect())
266}
267
268fn verify_archive_checksum(
269    archive_path: &Path,
270    archive_name: &str,
271    checksums_content: &str,
272) -> Result<()> {
273    let checksums = parse_checksums(checksums_content);
274    let expected = checksums
275        .get(archive_name)
276        .ok_or_else(|| anyhow!("checksums.txt has no entry for `{archive_name}`"))?;
277    let actual = sha256_hex_file(archive_path)?;
278    if &actual != expected {
279        bail!("checksum mismatch for `{archive_name}`: expected {expected}, got {actual}");
280    }
281    Ok(())
282}
283
284fn github_asset_headers(user_agent: &str) -> reqwest::header::HeaderMap {
285    let mut headers = reqwest::header::HeaderMap::new();
286    headers.insert(
287        reqwest::header::USER_AGENT,
288        reqwest::header::HeaderValue::from_str(user_agent)
289            .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("romm-cli")),
290    );
291    headers.insert(
292        reqwest::header::ACCEPT,
293        reqwest::header::HeaderValue::from_static("application/octet-stream"),
294    );
295    headers
296}
297
298async fn download_url_to_file(
299    client: &reqwest::Client,
300    url: &str,
301    dest: &Path,
302    user_agent: &str,
303    interrupt: &InterruptContext,
304    show_progress: bool,
305) -> Result<()> {
306    if interrupt.is_cancelled() {
307        return Err(cancelled_error().into());
308    }
309
310    let response = client
311        .get(url)
312        .headers(github_asset_headers(user_agent))
313        .send()
314        .await
315        .with_context(|| format!("download request failed for {url}"))?
316        .error_for_status()
317        .with_context(|| format!("download returned error status for {url}"))?;
318
319    let total = response.content_length();
320    let mut file = tokio::fs::File::create(dest)
321        .await
322        .with_context(|| format!("create {}", dest.display()))?;
323
324    let progress = if show_progress {
325        total.map(|len| {
326            let pb = indicatif::ProgressBar::new(len);
327            pb.set_style(
328                indicatif::ProgressStyle::default_bar()
329                    .template("{wide_bar} {bytes}/{total_bytes}")
330                    .expect("progress template"),
331            );
332            pb
333        })
334    } else {
335        None
336    };
337
338    let mut downloaded = 0u64;
339    let mut response = response;
340    while let Some(chunk) = response.chunk().await.context("read download chunk")? {
341        if interrupt.is_cancelled() {
342            return Err(cancelled_error().into());
343        }
344        file.write_all(&chunk)
345            .await
346            .context("write download chunk")?;
347        downloaded += chunk.len() as u64;
348        if let Some(ref pb) = progress {
349            pb.set_position(downloaded);
350        }
351    }
352
353    if let Some(pb) = progress {
354        pb.finish_and_clear();
355    }
356
357    Ok(())
358}
359
360fn install_extracted_binaries(extract_dir: &Path, running_bin_stem: &str) -> Result<()> {
361    let current_exe = std::env::current_exe().context("resolve current executable path")?;
362    let install_dir = current_exe
363        .parent()
364        .ok_or_else(|| anyhow!("current executable has no parent directory"))?;
365
366    let mut running_source = None;
367
368    for stem in SHIPPED_BINARIES {
369        let file_name = shipped_binary_file_name(stem);
370        let source = extract_dir.join(&file_name);
371        if !source.is_file() {
372            continue;
373        }
374
375        let dest = install_dir.join(&file_name);
376        if stem == &running_bin_stem {
377            running_source = Some(source);
378            continue;
379        }
380
381        std::fs::copy(&source, &dest).with_context(|| {
382            format!(
383                "copy sibling binary `{}` to `{}`",
384                source.display(),
385                dest.display()
386            )
387        })?;
388        if let Ok(meta) = std::fs::metadata(&source) {
389            let _ = std::fs::set_permissions(&dest, meta.permissions());
390        }
391    }
392
393    let Some(new_running) = running_source else {
394        bail!("extracted archive did not contain `{running_bin_stem}`");
395    };
396
397    self_update::self_replace::self_replace(new_running).context("replace running executable")?;
398
399    Ok(())
400}
401
402fn install_from_archive(
403    archive_path: &Path,
404    archive_name: &str,
405    checksums_content: &str,
406) -> Result<()> {
407    verify_archive_checksum(archive_path, archive_name, checksums_content)?;
408
409    let extract_dir = self_update::TempDir::new().context("create temp extract dir")?;
410    Extract::from_source(archive_path)
411        .extract_into(extract_dir.path())
412        .with_context(|| format!("extract `{archive_name}`"))?;
413
414    install_extracted_binaries(extract_dir.path(), &current_binary_name())?;
415    Ok(())
416}
417
418pub async fn check_for_update() -> Result<UpdateStatus> {
419    let current_version = cargo_crate_version!().to_string();
420    let response = reqwest::Client::new()
421        .get(github_latest_release_api_url())
422        .header(
423            reqwest::header::USER_AGENT,
424            format!("romm-cli/{current_version}"),
425        )
426        .send()
427        .await
428        .context("failed to query latest release")?
429        .error_for_status()
430        .context("latest release endpoint returned an error status")?;
431
432    let latest_release: GithubLatestRelease = response
433        .json()
434        .await
435        .context("failed to parse latest release response")?;
436
437    let release_tag = latest_release.tag_name.clone();
438    let latest_version = release_tag.trim_start_matches('v').to_string();
439    Ok(UpdateStatus {
440        should_update: is_latest_newer(&latest_version, &current_version),
441        current_version,
442        latest_version,
443        release_tag,
444        release_url: latest_release.html_url,
445        changelog_url: changelog_url().to_string(),
446    })
447}
448
449pub async fn apply_update(
450    interrupt: Option<InterruptContext>,
451    options: ApplyUpdateOptions,
452) -> Result<ApplyUpdateOutcome> {
453    let interrupt = interrupt.unwrap_or_default();
454    let current_version = cargo_crate_version!().to_string();
455    let user_agent = format!("romm-cli/{current_version}");
456
457    let resolved = tokio::task::spawn_blocking({
458        let options = options.clone();
459        move || resolve_release(&options)
460    })
461    .await
462    .map_err(|e| anyhow!("update resolve task failed: {e}"))??;
463
464    let Some(resolved) = resolved else {
465        return Ok(ApplyUpdateOutcome::UpToDate(current_version));
466    };
467
468    let archive_dir = self_update::TempDir::new().context("create temp download dir")?;
469    let archive_path: PathBuf = archive_dir.path().join(&resolved.archive_name);
470
471    let client = reqwest::Client::new();
472
473    if interrupt.is_cancelled() {
474        return Err(cancelled_error().into());
475    }
476    let checksums_content = client
477        .get(&resolved.checksums_download_url)
478        .headers(github_asset_headers(&user_agent))
479        .send()
480        .await
481        .context("download checksums.txt")?
482        .error_for_status()
483        .context("checksums.txt request failed")?
484        .text()
485        .await
486        .context("read checksums.txt")?;
487
488    download_url_to_file(
489        &client,
490        &resolved.archive_download_url,
491        &archive_path,
492        &user_agent,
493        &interrupt,
494        options.show_progress,
495    )
496    .await?;
497
498    let version = resolved.version.clone();
499    let archive_name = resolved.archive_name.clone();
500    let install_task = tokio::task::spawn_blocking(move || {
501        install_from_archive(&archive_path, &archive_name, &checksums_content).map(|_| version)
502    });
503
504    let installed_version = tokio::select! {
505        out = install_task => out
506            .map_err(|e| anyhow!("update install task failed: {e}"))??,
507        _ = interrupt.cancelled() => return Err(cancelled_error().into()),
508    };
509
510    Ok(ApplyUpdateOutcome::Updated(installed_version))
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `hello` to `sample.tar.gz` in a fresh temp dir. Keep the returned `TempDir`
    /// alive for as long as the path is used.
    fn sample_archive() -> (self_update::TempDir, PathBuf) {
        let dir = self_update::TempDir::new().expect("tempdir");
        let path = dir.path().join("sample.tar.gz");
        std::fs::write(&path, b"hello").expect("write sample");
        (dir, path)
    }

    #[test]
    fn version_compare_handles_patch_and_minor() {
        let cases = [
            ("0.25.1", "0.25.0", true),
            ("0.26.0", "0.25.9", true),
            ("0.25.0", "0.25.0", false),
            ("0.24.9", "0.25.0", false),
        ];
        for (latest, current, newer) in cases {
            assert_eq!(
                is_latest_newer(latest, current),
                newer,
                "latest={latest} current={current}"
            );
        }
    }

    #[test]
    fn version_compare_handles_v_prefix() {
        let newer = is_latest_newer("v1.2.4", "1.2.3");
        assert!(newer);
    }

    #[test]
    fn version_compare_handles_prerelease_to_stable() {
        let newer = is_latest_newer("0.25.0", "0.25.0-alpha");
        assert!(newer);
    }

    #[test]
    fn parse_checksums_reads_sha256sum_format() {
        let parsed = parse_checksums("abc123  romm-cli-linux-x86_64.tar.gz\n");
        let digest = parsed
            .get("romm-cli-linux-x86_64.tar.gz")
            .map(String::as_str);
        assert_eq!(digest, Some("abc123"));
    }

    #[test]
    fn verify_archive_checksum_matches() {
        let (_dir, path) = sample_archive();
        let digest = sha256_hex_file(&path).expect("hash");
        let manifest = format!("{digest}  sample.tar.gz\n");
        verify_archive_checksum(&path, "sample.tar.gz", &manifest).expect("verify");
    }

    #[test]
    fn verify_archive_checksum_rejects_mismatch() {
        let (_dir, path) = sample_archive();
        let outcome =
            verify_archive_checksum(&path, "sample.tar.gz", "deadbeef  sample.tar.gz\n");
        assert!(outcome.is_err());
    }

    #[test]
    fn binary_name_from_path_strips_windows_exe_extension() {
        let name = binary_name_from_path(Path::new(r"C:\tools\romm-tui.exe"));
        assert_eq!(name.as_deref(), Some("romm-tui"));
    }

    #[test]
    fn current_binary_name_is_available() {
        let name = current_binary_name();
        assert!(!name.is_empty());
    }

    #[test]
    fn github_release_asset_key_supports_windows() {
        let platform = (std::env::consts::OS, std::env::consts::ARCH);
        if platform != ("windows", "x86_64") {
            return;
        }
        assert_eq!(
            github_release_asset_key().expect("target"),
            "windows-x86_64"
        );
    }
}