Skip to main content

construct/commands/
update.rs

1//! `construct update` — self-update pipeline with rollback.
2
3use anyhow::{Context, Result, bail};
4use std::path::Path;
5use tracing::{info, warn};
6
/// GitHub REST API endpoint queried when no specific version is requested.
const GITHUB_RELEASES_LATEST_URL: &str =
    "https://api.github.com/repos/KumihoIO/construct/releases/latest";
/// GitHub REST API endpoint prefix for looking up a release by tag;
/// `check` appends `/<tag>` to it.
const GITHUB_RELEASES_TAG_URL: &str =
    "https://api.github.com/repos/KumihoIO/construct/releases/tags";
11
/// Outcome of an update check: the running version, the release that was
/// looked up, and whether that release should be installed.
#[derive(Debug)]
pub struct UpdateInfo {
    /// Version of the running binary (`CARGO_PKG_VERSION` at build time).
    pub current_version: String,
    /// Version taken from the release's `tag_name` with any leading `v`
    /// removed, or `unknown` when the release carried no tag name. When a
    /// specific version was requested this is that release, not the latest.
    pub latest_version: String,
    /// Download URL of the release asset chosen for this platform, if any.
    pub download_url: Option<String>,
    /// `true` when `latest_version` compares greater than `current_version`.
    pub is_newer: bool,
}
19
20/// Check for available updates without downloading.
21///
22/// If `target_version` is `Some`, fetch that specific release tag instead of latest.
23pub async fn check(target_version: Option<&str>) -> Result<UpdateInfo> {
24    let current = env!("CARGO_PKG_VERSION").to_string();
25
26    let client = reqwest::Client::builder()
27        .user_agent(format!("construct/{current}"))
28        .timeout(std::time::Duration::from_secs(15))
29        .build()?;
30
31    let url = match target_version {
32        Some(v) => {
33            let tag = if v.starts_with('v') {
34                v.to_string()
35            } else {
36                format!("v{v}")
37            };
38            format!("{GITHUB_RELEASES_TAG_URL}/{tag}")
39        }
40        None => GITHUB_RELEASES_LATEST_URL.to_string(),
41    };
42
43    let resp = client
44        .get(&url)
45        .send()
46        .await
47        .context("failed to reach GitHub releases API")?;
48
49    if !resp.status().is_success() {
50        bail!("GitHub API returned {}", resp.status());
51    }
52
53    let release: serde_json::Value = resp.json().await?;
54    let tag = release["tag_name"]
55        .as_str()
56        .unwrap_or("unknown")
57        .trim_start_matches('v')
58        .to_string();
59
60    let download_url = find_asset_url(&release);
61    let is_newer = version_is_newer(&current, &tag);
62
63    Ok(UpdateInfo {
64        current_version: current,
65        latest_version: tag,
66        download_url,
67        is_newer,
68    })
69}
70
71/// Run the full 6-phase update pipeline.
72///
73/// If `target_version` is `Some`, fetch that specific version instead of latest.
74pub async fn run(target_version: Option<&str>) -> Result<()> {
75    // Phase 1: Preflight
76    info!("Phase 1/6: Preflight checks...");
77    let update_info = check(target_version).await?;
78
79    if !update_info.is_newer {
80        println!("Already up to date (v{}).", update_info.current_version);
81        return Ok(());
82    }
83
84    println!(
85        "Update available: v{} -> v{}",
86        update_info.current_version, update_info.latest_version
87    );
88
89    let download_url = update_info
90        .download_url
91        .context("no suitable binary found for this platform")?;
92
93    let current_exe =
94        std::env::current_exe().context("cannot determine current executable path")?;
95
96    // Phase 2: Download
97    info!("Phase 2/6: Downloading...");
98    let temp_dir = tempfile::tempdir().context("failed to create temp dir")?;
99    let download_path = temp_dir.path().join("construct_new");
100    download_binary(&download_url, &download_path).await?;
101
102    // Phase 3: Backup
103    info!("Phase 3/6: Creating backup...");
104    let backup_path = current_exe.with_extension("bak");
105    tokio::fs::copy(&current_exe, &backup_path)
106        .await
107        .context("failed to backup current binary")?;
108
109    // Phase 4: Validate
110    info!("Phase 4/6: Validating download...");
111    validate_binary(&download_path).await?;
112
113    // Phase 5: Swap
114    info!("Phase 5/6: Swapping binary...");
115    if let Err(e) = swap_binary(&download_path, &current_exe).await {
116        // Rollback
117        warn!("Swap failed, rolling back: {e}");
118        if let Err(rollback_err) = rollback_binary(&backup_path, &current_exe).await {
119            eprintln!("CRITICAL: Rollback also failed: {rollback_err}");
120            eprintln!(
121                "Manual recovery: cp {} {}",
122                backup_path.display(),
123                current_exe.display()
124            );
125        }
126        bail!("Update failed during swap: {e}");
127    }
128
129    // Phase 6: Smoke test
130    info!("Phase 6/6: Smoke test...");
131    match smoke_test(&current_exe).await {
132        Ok(()) => {
133            // Cleanup backup on success
134            let _ = tokio::fs::remove_file(&backup_path).await;
135            println!("Successfully updated to v{}!", update_info.latest_version);
136            Ok(())
137        }
138        Err(e) => {
139            warn!("Smoke test failed, rolling back: {e}");
140            rollback_binary(&backup_path, &current_exe)
141                .await
142                .context("rollback after smoke test failure")?;
143            bail!("Update rolled back — smoke test failed: {e}");
144        }
145    }
146}
147
148fn find_asset_url(release: &serde_json::Value) -> Option<String> {
149    let target = current_target_triple();
150
151    release["assets"]
152        .as_array()?
153        .iter()
154        .find(|asset| {
155            asset["name"]
156                .as_str()
157                .map(|name| name.contains(target))
158                .unwrap_or(false)
159        })
160        .and_then(|asset| asset["browser_download_url"].as_str().map(String::from))
161}
162
163/// Return the exact Rust target triple for the current platform.
164///
165/// Using full triples (e.g. `aarch64-unknown-linux-gnu` instead of the
166/// shorter `aarch64-unknown-linux`) prevents substring matches from
167/// selecting the wrong asset (e.g. an Android binary on a GNU/Linux host).
168fn current_target_triple() -> &'static str {
169    if cfg!(target_os = "macos") {
170        if cfg!(target_arch = "aarch64") {
171            "aarch64-apple-darwin"
172        } else {
173            "x86_64-apple-darwin"
174        }
175    } else if cfg!(target_os = "linux") {
176        if cfg!(target_arch = "aarch64") {
177            "aarch64-unknown-linux-gnu"
178        } else {
179            "x86_64-unknown-linux-gnu"
180        }
181    } else {
182        "unknown"
183    }
184}
185
/// Return `true` when `candidate` is a strictly newer version than `current`.
///
/// Versions are compared by their dot-separated numeric core:
///
/// * a leading `v` and any `+build` metadata are ignored;
/// * missing trailing components count as zero, so `1.2` equals `1.2.0`;
/// * a pre-release (`1.0.0-rc.1`) is older than the same core without a
///   suffix; two pre-releases of the same core are treated as equal, so a
///   pre-release never replaces another pre-release of the same core.
///
/// A `candidate` that cannot be parsed (for example `unknown`) is never
/// newer. An unparsable `current` is considered older than any parsable
/// candidate.
fn version_is_newer(current: &str, candidate: &str) -> bool {
    /// Parse into `(numeric core without trailing zeros, is_final_release)`.
    /// The tuple orders correctly with the derived lexicographic comparison.
    fn parse(version: &str) -> Option<(Vec<u64>, bool)> {
        let version = version.trim();
        let version = version.strip_prefix('v').unwrap_or(version);
        // Build metadata never affects precedence.
        let version = version.split_once('+').map_or(version, |(head, _)| head);
        let (core, is_final) = match version.split_once('-') {
            Some((core, _pre_release)) => (core, false),
            None => (version, true),
        };

        // Every component must be numeric; silently dropping bad components
        // would shift the remaining ones and corrupt the comparison.
        let mut parts = core
            .split('.')
            .map(|part| part.parse::<u64>().ok())
            .collect::<Option<Vec<_>>>()?;

        // Normalise `1.2.0` and `1.2` to the same representation.
        while parts.last() == Some(&0) {
            parts.pop();
        }
        Some((parts, is_final))
    }

    match (parse(current), parse(candidate)) {
        (Some(cur), Some(cand)) => cand > cur,
        (None, Some(_)) => true,
        (_, None) => false,
    }
}
192
193async fn download_binary(url: &str, dest: &Path) -> Result<()> {
194    let client = reqwest::Client::builder()
195        .user_agent(format!("construct/{}", env!("CARGO_PKG_VERSION")))
196        .timeout(std::time::Duration::from_secs(300))
197        .build()?;
198
199    let resp = client
200        .get(url)
201        .send()
202        .await
203        .context("download request failed")?;
204    if !resp.status().is_success() {
205        bail!("download returned {}", resp.status());
206    }
207
208    let bytes = resp.bytes().await.context("failed to read download body")?;
209
210    // Release assets are .tar.gz archives containing a single `construct` binary.
211    // Extract the binary from the archive instead of writing the raw tarball.
212    if url.ends_with(".tar.gz") || url.ends_with(".tgz") {
213        extract_tar_gz(&bytes, dest).context("failed to extract binary from tar.gz archive")?;
214    } else {
215        tokio::fs::write(dest, &bytes)
216            .await
217            .context("failed to write downloaded binary")?;
218    }
219
220    // Make executable on Unix
221    #[cfg(unix)]
222    {
223        use std::os::unix::fs::PermissionsExt;
224        let perms = std::fs::Permissions::from_mode(0o755);
225        tokio::fs::set_permissions(dest, perms).await?;
226    }
227
228    Ok(())
229}
230
231/// Extract the `construct` binary from a `.tar.gz` archive.
232fn extract_tar_gz(archive_bytes: &[u8], dest: &Path) -> Result<()> {
233    use flate2::read::GzDecoder;
234    use std::io::Read;
235    use tar::Archive;
236
237    let gz = GzDecoder::new(archive_bytes);
238    let mut archive = Archive::new(gz);
239
240    for entry in archive.entries().context("failed to read tar entries")? {
241        let mut entry = entry.context("failed to read tar entry")?;
242        let path = entry.path().context("failed to read entry path")?;
243
244        // The archive contains a single binary named "construct" (or "construct.exe" on Windows).
245        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
246
247        if file_name == "construct" || file_name == "construct.exe" {
248            let mut buf = Vec::new();
249            entry
250                .read_to_end(&mut buf)
251                .context("failed to read binary from archive")?;
252            std::fs::write(dest, &buf).context("failed to write extracted binary")?;
253            return Ok(());
254        }
255    }
256
257    bail!("archive does not contain a 'construct' binary")
258}
259
260async fn validate_binary(path: &Path) -> Result<()> {
261    let meta = tokio::fs::metadata(path).await?;
262    if meta.len() < 1_000_000 {
263        bail!(
264            "downloaded binary too small ({} bytes), likely corrupt",
265            meta.len()
266        );
267    }
268
269    // Check binary architecture before attempting execution so we can give
270    // a clear diagnostic instead of the opaque "Exec format error (os error 8)".
271    check_binary_arch(path).await?;
272
273    // Quick check: try running --version
274    let output = tokio::process::Command::new(path)
275        .arg("--version")
276        .output()
277        .await
278        .context("cannot execute downloaded binary")?;
279
280    if !output.status.success() {
281        bail!("downloaded binary --version check failed");
282    }
283
284    let stdout = String::from_utf8_lossy(&output.stdout);
285    if !stdout.contains("construct") {
286        bail!("downloaded binary does not appear to be construct");
287    }
288
289    Ok(())
290}
291
292/// Read the binary header and verify its architecture matches the host.
293///
294/// On Linux/FreeBSD this reads the ELF header; on macOS the Mach-O header.
295/// If the binary is for a different architecture, returns a descriptive error
296/// instead of the opaque "Exec format error (os error 8)".
297async fn check_binary_arch(path: &Path) -> Result<()> {
298    let header = tokio::fs::read(path)
299        .await
300        .map(|bytes| bytes.into_iter().take(32).collect::<Vec<u8>>())
301        .context("failed to read binary header")?;
302
303    if header.len() < 20 {
304        bail!("downloaded file too small to be a valid binary");
305    }
306
307    let binary_arch = detect_arch_from_header(&header);
308    let host_arch = host_architecture();
309
310    if let (Some(bin), Some(host)) = (binary_arch, host_arch) {
311        if bin != host {
312            bail!(
313                "architecture mismatch: downloaded binary is {bin} but this host is {host} — \
314                 the release asset may be mispackaged"
315            );
316        }
317    }
318
319    Ok(())
320}
321
/// Identify the CPU architecture named in an ELF or Mach-O header.
///
/// Returns `None` when `header` carries neither an ELF magic (with at least
/// 20 bytes available) nor a 64-bit little-endian Mach-O magic (with at least
/// 8 bytes available). Recognised containers with an unlisted machine type
/// yield `"unknown-elf"` / `"unknown-macho"`.
fn detect_arch_from_header(header: &[u8]) -> Option<&'static str> {
    match header {
        // ELF: magic 0x7f 'E' 'L' 'F'; `e_machine` is the 16-bit field at
        // offset 18, decoded here as little-endian.
        [0x7f, b'E', b'L', b'F', ..] if header.len() >= 20 => {
            let e_machine = u16::from_le_bytes([header[18], header[19]]);
            let arch = match e_machine {
                0x3E => "x86_64",
                0xB7 => "aarch64",
                0x03 => "x86",
                0x28 => "arm",
                0xF3 => "riscv",
                _ => "unknown-elf",
            };
            Some(arch)
        }
        // Mach-O 64-bit little-endian: magic 0xFEEDFACF stored as CF FA ED FE,
        // followed by the 32-bit `cputype`.
        [0xCF, 0xFA, 0xED, 0xFE, c0, c1, c2, c3, ..] => {
            let cputype = u32::from_le_bytes([*c0, *c1, *c2, *c3]);
            let arch = match cputype {
                0x0100_0007 => "x86_64",
                0x0100_000C => "aarch64",
                _ => "unknown-macho",
            };
            Some(arch)
        }
        _ => None,
    }
}
350
/// Name the CPU architecture this program was compiled for.
///
/// Yields `Some` for `x86_64`, `aarch64`, `x86` and `arm`; every other
/// architecture yields `None`.
fn host_architecture() -> Option<&'static str> {
    match std::env::consts::ARCH {
        arch @ ("x86_64" | "aarch64" | "x86" | "arm") => Some(arch),
        _ => None,
    }
}
365
366async fn swap_binary(new: &Path, target: &Path) -> Result<()> {
367    // On Linux, a running binary cannot be overwritten in place (ETXTBSY).
368    // Remove the old file first, then copy the new one into the now-free path.
369    // This works because the kernel keeps the inode alive until the process exits.
370    tokio::fs::remove_file(target)
371        .await
372        .context("failed to remove old binary")?;
373    tokio::fs::copy(new, target)
374        .await
375        .context("failed to write new binary")?;
376    Ok(())
377}
378
379async fn rollback_binary(backup: &Path, target: &Path) -> Result<()> {
380    // Remove-then-copy to avoid ETXTBSY if the target is somehow still mapped.
381    let _ = tokio::fs::remove_file(target).await;
382    tokio::fs::copy(backup, target)
383        .await
384        .context("failed to restore backup binary")?;
385    Ok(())
386}
387
388async fn smoke_test(binary: &Path) -> Result<()> {
389    let output = tokio::process::Command::new(binary)
390        .arg("--version")
391        .output()
392        .await
393        .context("smoke test: cannot execute updated binary")?;
394
395    if !output.status.success() {
396        bail!("smoke test: updated binary returned non-zero exit code");
397    }
398
399    Ok(())
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a release JSON document whose assets carry the given file names.
    fn make_release(assets: &[&str]) -> serde_json::Value {
        let mut entries: Vec<serde_json::Value> = Vec::with_capacity(assets.len());
        for name in assets {
            entries.push(serde_json::json!({
                "name": name,
                "browser_download_url": format!("https://example.com/{name}")
            }));
        }
        serde_json::json!({ "assets": entries })
    }

    /// Minimal 20-byte ELF header with the given low byte of `e_machine`.
    fn elf_header(e_machine_lo: u8) -> Vec<u8> {
        let mut header = vec![0u8; 20];
        header[0..4].copy_from_slice(&[0x7f, b'E', b'L', b'F']);
        header[18] = e_machine_lo;
        header[19] = 0x00;
        header
    }

    /// Minimal 8-byte Mach-O header: 64-bit LE magic followed by `cputype`.
    fn macho_header(cputype: u32) -> Vec<u8> {
        let mut header = vec![0u8; 8];
        header[0..4].copy_from_slice(&[0xCF, 0xFA, 0xED, 0xFE]);
        header[4..8].copy_from_slice(&cputype.to_le_bytes());
        header
    }

    /// Build an in-memory `.tar.gz` holding exactly one file.
    fn single_file_tar_gz(name: &str, contents: &[u8], mode: u32) -> Vec<u8> {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;

        let mut tar_bytes = Vec::new();
        {
            let mut builder = tar::Builder::new(&mut tar_bytes);
            let mut header = tar::Header::new_gnu();
            header.set_size(contents.len() as u64);
            header.set_mode(mode);
            header.set_cksum();
            builder.append_data(&mut header, name, contents).unwrap();
            builder.finish().unwrap();
        }

        let mut gz_bytes = Vec::new();
        {
            let mut encoder = GzEncoder::new(&mut gz_bytes, Compression::fast());
            encoder.write_all(&tar_bytes).unwrap();
            encoder.finish().unwrap();
        }
        gz_bytes
    }

    #[test]
    fn test_version_comparison() {
        let newer = [("0.4.3", "0.5.0"), ("0.4.3", "0.4.4"), ("1.0.0", "2.0.0")];
        for (current, candidate) in newer {
            assert!(version_is_newer(current, candidate));
        }

        let not_newer = [("0.5.0", "0.4.3"), ("0.4.3", "0.4.3")];
        for (current, candidate) in not_newer {
            assert!(!version_is_newer(current, candidate));
        }
    }

    #[test]
    fn current_target_triple_is_not_empty() {
        let triple = current_target_triple();
        assert_ne!(triple, "unknown", "unsupported platform");
        // A triple is arch-vendor-os or arch-vendor-os-env, so it has at least two hyphens.
        let hyphens = triple.matches('-').count();
        assert!(
            hyphens >= 2,
            "triple should have at least two hyphens: {triple}"
        );
    }

    #[test]
    fn find_asset_url_picks_correct_gnu_over_android() {
        let release = make_release(&[
            "construct-aarch64-linux-android.tar.gz",
            "construct-aarch64-unknown-linux-gnu.tar.gz",
            "construct-x86_64-unknown-linux-gnu.tar.gz",
            "construct-x86_64-apple-darwin.tar.gz",
            "construct-aarch64-apple-darwin.tar.gz",
        ]);

        let found = find_asset_url(&release);
        assert!(found.is_some(), "should find an asset");
        let url = found.unwrap();
        // The Android asset must never be chosen for a GNU/Linux or macOS host.
        assert!(
            !url.contains("android"),
            "should not select android binary, got: {url}"
        );
    }

    #[test]
    fn find_asset_url_returns_none_for_empty_assets() {
        let release = serde_json::json!({ "assets": [] });
        assert!(find_asset_url(&release).is_none());
    }

    #[test]
    fn find_asset_url_returns_none_for_missing_assets() {
        let release = serde_json::json!({});
        assert!(find_asset_url(&release).is_none());
    }

    #[test]
    fn detect_arch_elf_x86_64() {
        // e_machine = 0x3E is x86_64.
        assert_eq!(detect_arch_from_header(&elf_header(0x3E)), Some("x86_64"));
    }

    #[test]
    fn detect_arch_elf_aarch64() {
        // e_machine = 0xB7 is aarch64.
        assert_eq!(detect_arch_from_header(&elf_header(0xB7)), Some("aarch64"));
    }

    #[test]
    fn detect_arch_macho_x86_64() {
        // cputype 0x01000007 is x86_64.
        assert_eq!(
            detect_arch_from_header(&macho_header(0x0100_0007)),
            Some("x86_64")
        );
    }

    #[test]
    fn detect_arch_macho_aarch64() {
        // cputype 0x0100000C is aarch64.
        assert_eq!(
            detect_arch_from_header(&macho_header(0x0100_000C)),
            Some("aarch64")
        );
    }

    #[test]
    fn detect_arch_unknown_format() {
        // All zeros is neither ELF nor Mach-O.
        let zeros = vec![0u8; 20];
        assert_eq!(detect_arch_from_header(&zeros), None);
    }

    #[test]
    fn detect_arch_too_short() {
        // ELF magic alone (4 bytes) is not enough to read e_machine.
        let magic_only = vec![0x7f, b'E', b'L', b'F'];
        assert_eq!(detect_arch_from_header(&magic_only), None);
    }

    #[test]
    fn host_architecture_is_known() {
        assert!(
            host_architecture().is_some(),
            "host architecture should be detected on CI platforms"
        );
    }

    #[test]
    fn extract_tar_gz_finds_binary() {
        // Archive containing a fake "construct" binary.
        let fake_binary = b"#!/bin/sh\necho construct";
        let archive = single_file_tar_gz("construct", &fake_binary[..], 0o755);

        let scratch = tempfile::tempdir().unwrap();
        let dest = scratch.path().join("construct_extracted");
        extract_tar_gz(&archive, &dest).unwrap();

        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written, fake_binary);
    }

    #[test]
    fn extract_tar_gz_errors_on_missing_binary() {
        // Archive whose only file is NOT named "construct".
        let archive = single_file_tar_gz("README.md", &b"hello"[..], 0o644);

        let scratch = tempfile::tempdir().unwrap();
        let dest = scratch.path().join("construct_extracted");
        let outcome = extract_tar_gz(&archive, &dest);
        assert!(outcome.is_err());
        assert!(
            outcome.unwrap_err().to_string().contains("does not contain"),
            "should report missing binary"
        );
    }
}