1use anyhow::{Context, Result, bail};
4use std::path::Path;
5use tracing::{info, warn};
6
/// GitHub REST API URL of the `releases/latest` endpoint for the construct
/// repository; queried when no explicit target version is requested.
const GITHUB_RELEASES_LATEST_URL: &str =
    "https://api.github.com/repos/KumihoIO/construct/releases/latest";
/// GitHub REST API URL prefix for looking up a release by tag; the tag name
/// is appended as one more path segment.
const GITHUB_RELEASES_TAG_URL: &str =
    "https://api.github.com/repos/KumihoIO/construct/releases/tags";
11
/// Outcome of comparing the running build against a GitHub release.
#[derive(Debug)]
pub struct UpdateInfo {
    /// Version of the running binary, taken from `CARGO_PKG_VERSION`.
    pub current_version: String,
    /// Tag name of the release that was looked up, without leading `v`.
    pub latest_version: String,
    /// Download URL of the release asset selected for this platform, if any.
    pub download_url: Option<String>,
    /// Whether `latest_version` was judged newer than `current_version`.
    pub is_newer: bool,
}
19
20pub async fn check(target_version: Option<&str>) -> Result<UpdateInfo> {
24 let current = env!("CARGO_PKG_VERSION").to_string();
25
26 let client = reqwest::Client::builder()
27 .user_agent(format!("construct/{current}"))
28 .timeout(std::time::Duration::from_secs(15))
29 .build()?;
30
31 let url = match target_version {
32 Some(v) => {
33 let tag = if v.starts_with('v') {
34 v.to_string()
35 } else {
36 format!("v{v}")
37 };
38 format!("{GITHUB_RELEASES_TAG_URL}/{tag}")
39 }
40 None => GITHUB_RELEASES_LATEST_URL.to_string(),
41 };
42
43 let resp = client
44 .get(&url)
45 .send()
46 .await
47 .context("failed to reach GitHub releases API")?;
48
49 if !resp.status().is_success() {
50 bail!("GitHub API returned {}", resp.status());
51 }
52
53 let release: serde_json::Value = resp.json().await?;
54 let tag = release["tag_name"]
55 .as_str()
56 .unwrap_or("unknown")
57 .trim_start_matches('v')
58 .to_string();
59
60 let download_url = find_asset_url(&release);
61 let is_newer = version_is_newer(¤t, &tag);
62
63 Ok(UpdateInfo {
64 current_version: current,
65 latest_version: tag,
66 download_url,
67 is_newer,
68 })
69}
70
71pub async fn run(target_version: Option<&str>) -> Result<()> {
75 info!("Phase 1/6: Preflight checks...");
77 let update_info = check(target_version).await?;
78
79 if !update_info.is_newer {
80 println!("Already up to date (v{}).", update_info.current_version);
81 return Ok(());
82 }
83
84 println!(
85 "Update available: v{} -> v{}",
86 update_info.current_version, update_info.latest_version
87 );
88
89 let download_url = update_info
90 .download_url
91 .context("no suitable binary found for this platform")?;
92
93 let current_exe =
94 std::env::current_exe().context("cannot determine current executable path")?;
95
96 info!("Phase 2/6: Downloading...");
98 let temp_dir = tempfile::tempdir().context("failed to create temp dir")?;
99 let download_path = temp_dir.path().join("construct_new");
100 download_binary(&download_url, &download_path).await?;
101
102 info!("Phase 3/6: Creating backup...");
104 let backup_path = current_exe.with_extension("bak");
105 tokio::fs::copy(¤t_exe, &backup_path)
106 .await
107 .context("failed to backup current binary")?;
108
109 info!("Phase 4/6: Validating download...");
111 validate_binary(&download_path).await?;
112
113 info!("Phase 5/6: Swapping binary...");
115 if let Err(e) = swap_binary(&download_path, ¤t_exe).await {
116 warn!("Swap failed, rolling back: {e}");
118 if let Err(rollback_err) = rollback_binary(&backup_path, ¤t_exe).await {
119 eprintln!("CRITICAL: Rollback also failed: {rollback_err}");
120 eprintln!(
121 "Manual recovery: cp {} {}",
122 backup_path.display(),
123 current_exe.display()
124 );
125 }
126 bail!("Update failed during swap: {e}");
127 }
128
129 info!("Phase 6/6: Smoke test...");
131 match smoke_test(¤t_exe).await {
132 Ok(()) => {
133 let _ = tokio::fs::remove_file(&backup_path).await;
135 println!("Successfully updated to v{}!", update_info.latest_version);
136 Ok(())
137 }
138 Err(e) => {
139 warn!("Smoke test failed, rolling back: {e}");
140 rollback_binary(&backup_path, ¤t_exe)
141 .await
142 .context("rollback after smoke test failure")?;
143 bail!("Update rolled back — smoke test failed: {e}");
144 }
145 }
146}
147
148fn find_asset_url(release: &serde_json::Value) -> Option<String> {
149 let target = current_target_triple();
150
151 release["assets"]
152 .as_array()?
153 .iter()
154 .find(|asset| {
155 asset["name"]
156 .as_str()
157 .map(|name| name.contains(target))
158 .unwrap_or(false)
159 })
160 .and_then(|asset| asset["browser_download_url"].as_str().map(String::from))
161}
162
/// Target triple used to pick a release asset for the platform this binary
/// was compiled for.
///
/// Only macOS and Linux are recognised; on either OS every architecture other
/// than `aarch64` maps to the `x86_64` triple. Any other OS yields `"unknown"`.
fn current_target_triple() -> &'static str {
    let on_macos = cfg!(target_os = "macos");
    let on_linux = cfg!(target_os = "linux");
    let is_arm64 = cfg!(target_arch = "aarch64");

    match (on_macos, on_linux, is_arm64) {
        (true, _, true) => "aarch64-apple-darwin",
        (true, _, false) => "x86_64-apple-darwin",
        (false, true, true) => "aarch64-unknown-linux-gnu",
        (false, true, false) => "x86_64-unknown-linux-gnu",
        (false, false, _) => "unknown",
    }
}
185
/// Returns `true` when `candidate` denotes a strictly newer version than
/// `current`.
///
/// Both strings may carry a leading `v`, a pre-release suffix (`-rc1`) and
/// build metadata (`+sha`). The dot-separated numeric cores are compared
/// numerically with missing components treated as zero, so `1.0` equals
/// `1.0.0`. When the cores are equal, a final release is newer than a
/// pre-release; two pre-releases of the same core are not ordered and yield
/// `false`.
///
/// A `candidate` that cannot be parsed is never newer. An unparsable
/// `current` is treated as older than any valid candidate.
fn version_is_newer(current: &str, candidate: &str) -> bool {
    /// Splits a version into its numeric core (trailing zeros removed) and a
    /// flag telling whether it carries a pre-release suffix.
    fn parse(version: &str) -> Option<(Vec<u64>, bool)> {
        let version = version.trim();
        let version = version.strip_prefix('v').unwrap_or(version);
        // Build metadata never takes part in version precedence.
        let version = version.split('+').next().unwrap_or(version);
        let (core, is_prerelease) = match version.split_once('-') {
            Some((core, _)) => (core, true),
            None => (version, false),
        };

        // Every component must be numeric; silently skipping bad components
        // would shift the remaining ones and corrupt the comparison.
        let mut parts = core
            .split('.')
            .map(|part| part.parse::<u64>().ok())
            .collect::<Option<Vec<u64>>>()?;

        // Drop trailing zeros so that `1.0` and `1.0.0` compare equal.
        while parts.last() == Some(&0) {
            parts.pop();
        }
        Some((parts, is_prerelease))
    }

    let (cand_core, cand_pre) = match parse(candidate) {
        Some(parsed) => parsed,
        None => return false,
    };
    let (cur_core, cur_pre) = match parse(current) {
        Some(parsed) => parsed,
        None => return true,
    };

    match cand_core.cmp(&cur_core) {
        std::cmp::Ordering::Greater => true,
        std::cmp::Ordering::Less => false,
        std::cmp::Ordering::Equal => cur_pre && !cand_pre,
    }
}
192
193async fn download_binary(url: &str, dest: &Path) -> Result<()> {
194 let client = reqwest::Client::builder()
195 .user_agent(format!("construct/{}", env!("CARGO_PKG_VERSION")))
196 .timeout(std::time::Duration::from_secs(300))
197 .build()?;
198
199 let resp = client
200 .get(url)
201 .send()
202 .await
203 .context("download request failed")?;
204 if !resp.status().is_success() {
205 bail!("download returned {}", resp.status());
206 }
207
208 let bytes = resp.bytes().await.context("failed to read download body")?;
209
210 if url.ends_with(".tar.gz") || url.ends_with(".tgz") {
213 extract_tar_gz(&bytes, dest).context("failed to extract binary from tar.gz archive")?;
214 } else {
215 tokio::fs::write(dest, &bytes)
216 .await
217 .context("failed to write downloaded binary")?;
218 }
219
220 #[cfg(unix)]
222 {
223 use std::os::unix::fs::PermissionsExt;
224 let perms = std::fs::Permissions::from_mode(0o755);
225 tokio::fs::set_permissions(dest, perms).await?;
226 }
227
228 Ok(())
229}
230
231fn extract_tar_gz(archive_bytes: &[u8], dest: &Path) -> Result<()> {
233 use flate2::read::GzDecoder;
234 use std::io::Read;
235 use tar::Archive;
236
237 let gz = GzDecoder::new(archive_bytes);
238 let mut archive = Archive::new(gz);
239
240 for entry in archive.entries().context("failed to read tar entries")? {
241 let mut entry = entry.context("failed to read tar entry")?;
242 let path = entry.path().context("failed to read entry path")?;
243
244 let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
246
247 if file_name == "construct" || file_name == "construct.exe" {
248 let mut buf = Vec::new();
249 entry
250 .read_to_end(&mut buf)
251 .context("failed to read binary from archive")?;
252 std::fs::write(dest, &buf).context("failed to write extracted binary")?;
253 return Ok(());
254 }
255 }
256
257 bail!("archive does not contain a 'construct' binary")
258}
259
260async fn validate_binary(path: &Path) -> Result<()> {
261 let meta = tokio::fs::metadata(path).await?;
262 if meta.len() < 1_000_000 {
263 bail!(
264 "downloaded binary too small ({} bytes), likely corrupt",
265 meta.len()
266 );
267 }
268
269 check_binary_arch(path).await?;
272
273 let output = tokio::process::Command::new(path)
275 .arg("--version")
276 .output()
277 .await
278 .context("cannot execute downloaded binary")?;
279
280 if !output.status.success() {
281 bail!("downloaded binary --version check failed");
282 }
283
284 let stdout = String::from_utf8_lossy(&output.stdout);
285 if !stdout.contains("construct") {
286 bail!("downloaded binary does not appear to be construct");
287 }
288
289 Ok(())
290}
291
292async fn check_binary_arch(path: &Path) -> Result<()> {
298 let header = tokio::fs::read(path)
299 .await
300 .map(|bytes| bytes.into_iter().take(32).collect::<Vec<u8>>())
301 .context("failed to read binary header")?;
302
303 if header.len() < 20 {
304 bail!("downloaded file too small to be a valid binary");
305 }
306
307 let binary_arch = detect_arch_from_header(&header);
308 let host_arch = host_architecture();
309
310 if let (Some(bin), Some(host)) = (binary_arch, host_arch) {
311 if bin != host {
312 bail!(
313 "architecture mismatch: downloaded binary is {bin} but this host is {host} — \
314 the release asset may be mispackaged"
315 );
316 }
317 }
318
319 Ok(())
320}
321
/// Identifies the CPU architecture of an executable from its leading bytes.
///
/// Recognises ELF files (magic `0x7f 'E' 'L' 'F'`, needs at least 20 bytes;
/// the machine field at offsets 18..20 is read little-endian) and files that
/// start with the bytes `CF FA ED FE` (needs at least 8 bytes; the CPU type at
/// offsets 4..8 is read little-endian). Returns `None` for anything else,
/// including headers that are too short.
fn detect_arch_from_header(header: &[u8]) -> Option<&'static str> {
    match header {
        [0x7f, b'E', b'L', b'F', ..] if header.len() >= 20 => {
            let machine = u16::from_le_bytes([header[18], header[19]]);
            let arch = match machine {
                0x3E => "x86_64",
                0xB7 => "aarch64",
                0x03 => "x86",
                0x28 => "arm",
                0xF3 => "riscv",
                _ => "unknown-elf",
            };
            Some(arch)
        }
        [0xCF, 0xFA, 0xED, 0xFE, c0, c1, c2, c3, ..] => {
            let cpu_type = u32::from_le_bytes([*c0, *c1, *c2, *c3]);
            let arch = match cpu_type {
                0x0100_0007 => "x86_64",
                0x0100_000C => "aarch64",
                _ => "unknown-macho",
            };
            Some(arch)
        }
        _ => None,
    }
}
350
/// Name of the CPU architecture this binary was compiled for, using the same
/// labels as `detect_arch_from_header`.
///
/// Returns `None` for architectures other than `x86_64`, `aarch64`, `x86`
/// and `arm`.
fn host_architecture() -> Option<&'static str> {
    // At most one flag is true for any given build target.
    const KNOWN: [(bool, &str); 4] = [
        (cfg!(target_arch = "x86_64"), "x86_64"),
        (cfg!(target_arch = "aarch64"), "aarch64"),
        (cfg!(target_arch = "x86"), "x86"),
        (cfg!(target_arch = "arm"), "arm"),
    ];

    KNOWN
        .iter()
        .find(|(is_host, _)| *is_host)
        .map(|&(_, name)| name)
}
365
366async fn swap_binary(new: &Path, target: &Path) -> Result<()> {
367 tokio::fs::remove_file(target)
371 .await
372 .context("failed to remove old binary")?;
373 tokio::fs::copy(new, target)
374 .await
375 .context("failed to write new binary")?;
376 Ok(())
377}
378
379async fn rollback_binary(backup: &Path, target: &Path) -> Result<()> {
380 let _ = tokio::fs::remove_file(target).await;
382 tokio::fs::copy(backup, target)
383 .await
384 .context("failed to restore backup binary")?;
385 Ok(())
386}
387
388async fn smoke_test(binary: &Path) -> Result<()> {
389 let output = tokio::process::Command::new(binary)
390 .arg("--version")
391 .output()
392 .await
393 .context("smoke test: cannot execute updated binary")?;
394
395 if !output.status.success() {
396 bail!("smoke test: updated binary returned non-zero exit code");
397 }
398
399 Ok(())
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a release payload whose assets carry the given names and an
    /// `https://example.com/<name>` download URL each.
    fn make_release(assets: &[&str]) -> serde_json::Value {
        let mut entries = Vec::with_capacity(assets.len());
        for name in assets {
            entries.push(serde_json::json!({
                "name": name,
                "browser_download_url": format!("https://example.com/{name}")
            }));
        }
        serde_json::json!({ "assets": entries })
    }

    /// 20-byte ELF header stub with the given little-endian machine value.
    fn elf_header(e_machine: u16) -> Vec<u8> {
        let mut header = vec![0u8; 20];
        header[..4].copy_from_slice(&[0x7f, b'E', b'L', b'F']);
        header[18..20].copy_from_slice(&e_machine.to_le_bytes());
        header
    }

    /// 8-byte Mach-O header stub with the given little-endian CPU type.
    fn macho_header(cputype: u32) -> Vec<u8> {
        let mut header = vec![0xCF, 0xFA, 0xED, 0xFE];
        header.extend_from_slice(&cputype.to_le_bytes());
        header
    }

    /// Gzip-compressed tar archive containing a single entry.
    fn gzipped_tar(entry_name: &str, mode: u32, payload: &[u8]) -> Vec<u8> {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;

        let mut tar_bytes = Vec::new();
        {
            let mut builder = tar::Builder::new(&mut tar_bytes);
            let mut header = tar::Header::new_gnu();
            header.set_size(payload.len() as u64);
            header.set_mode(mode);
            header.set_cksum();
            builder
                .append_data(&mut header, entry_name, payload)
                .unwrap();
            builder.finish().unwrap();
        }

        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(&tar_bytes).unwrap();
        encoder.finish().unwrap()
    }

    #[test]
    fn test_version_comparison() {
        let cases = [
            ("0.4.3", "0.5.0", true),
            ("0.4.3", "0.4.4", true),
            ("0.5.0", "0.4.3", false),
            ("0.4.3", "0.4.3", false),
            ("1.0.0", "2.0.0", true),
        ];
        for (current, candidate, expected) in cases {
            assert_eq!(version_is_newer(current, candidate), expected);
        }
    }

    #[test]
    fn current_target_triple_is_not_empty() {
        let triple = current_target_triple();
        assert_ne!(triple, "unknown", "unsupported platform");

        let hyphens = triple.matches('-').count();
        assert!(
            hyphens >= 2,
            "triple should have at least two hyphens: {triple}"
        );
    }

    #[test]
    fn find_asset_url_picks_correct_gnu_over_android() {
        // The android asset is listed first on purpose.
        let release = make_release(&[
            "construct-aarch64-linux-android.tar.gz",
            "construct-aarch64-unknown-linux-gnu.tar.gz",
            "construct-x86_64-unknown-linux-gnu.tar.gz",
            "construct-x86_64-apple-darwin.tar.gz",
            "construct-aarch64-apple-darwin.tar.gz",
        ]);

        let url = find_asset_url(&release);
        assert!(url.is_some(), "should find an asset");
        let url = url.unwrap();
        assert!(
            !url.contains("android"),
            "should not select android binary, got: {url}"
        );
    }

    #[test]
    fn find_asset_url_returns_none_for_empty_assets() {
        let release = serde_json::json!({ "assets": [] });
        assert!(find_asset_url(&release).is_none());
    }

    #[test]
    fn find_asset_url_returns_none_for_missing_assets() {
        let release = serde_json::json!({});
        assert!(find_asset_url(&release).is_none());
    }

    #[test]
    fn detect_arch_elf_x86_64() {
        assert_eq!(detect_arch_from_header(&elf_header(0x003E)), Some("x86_64"));
    }

    #[test]
    fn detect_arch_elf_aarch64() {
        assert_eq!(
            detect_arch_from_header(&elf_header(0x00B7)),
            Some("aarch64")
        );
    }

    #[test]
    fn detect_arch_macho_x86_64() {
        assert_eq!(
            detect_arch_from_header(&macho_header(0x0100_0007)),
            Some("x86_64")
        );
    }

    #[test]
    fn detect_arch_macho_aarch64() {
        assert_eq!(
            detect_arch_from_header(&macho_header(0x0100_000C)),
            Some("aarch64")
        );
    }

    #[test]
    fn detect_arch_unknown_format() {
        // All zeros: no recognised magic number.
        let header = vec![0u8; 20];
        assert_eq!(detect_arch_from_header(&header), None);
    }

    #[test]
    fn detect_arch_too_short() {
        // ELF magic only; the machine field is missing.
        let header = vec![0x7f, b'E', b'L', b'F'];
        assert_eq!(detect_arch_from_header(&header), None);
    }

    #[test]
    fn host_architecture_is_known() {
        assert!(
            host_architecture().is_some(),
            "host architecture should be detected on CI platforms"
        );
    }

    #[test]
    fn extract_tar_gz_finds_binary() {
        let fake_binary = b"#!/bin/sh\necho construct";
        let archive = gzipped_tar("construct", 0o755, &fake_binary[..]);

        let tmp = tempfile::tempdir().unwrap();
        let dest = tmp.path().join("construct_extracted");
        extract_tar_gz(&archive, &dest).unwrap();

        let content = std::fs::read(&dest).unwrap();
        assert_eq!(content, fake_binary);
    }

    #[test]
    fn extract_tar_gz_errors_on_missing_binary() {
        // Archive holds a single file that is not the binary.
        let archive = gzipped_tar("README.md", 0o644, &b"hello"[..]);

        let tmp = tempfile::tempdir().unwrap();
        let dest = tmp.path().join("construct_extracted");
        let result = extract_tar_gz(&archive, &dest);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("does not contain"),
            "should report missing binary"
        );
    }
}