1use anyhow::{bail, Context, Result};
7use serde::{Deserialize, Serialize};
8use std::env;
9use std::fs;
10use std::io::{self, Write};
11use std::path::{Path, PathBuf};
12use std::process::Command;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
/// GitHub `owner/name` slug that release archives are downloaded from.
const GITHUB_REPO: &str = "8b-is/smart-tree";

/// GitHub REST endpoint describing the most recent published release of the repository.
const GITHUB_RELEASES_API: &str = "https://api.github.com/repos/8b-is/smart-tree/releases/latest";

/// Minimum time between automatic update checks: one day, expressed in seconds.
const UPDATE_CHECK_INTERVAL_SECS: u64 = 24 * 60 * 60;

/// Executables the updater looks for in each release archive and installs side by side.
const BINARIES: &[&str] = &["st", "mq", "m8", "n8x"];
27
/// Version of this build, taken from the package manifest at compile time
/// (`CARGO_PKG_VERSION`). It carries no leading `v`; callers add one when displaying it.
const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
30
31#[derive(Debug, Deserialize)]
33struct GitHubRelease {
34 tag_name: String,
35 assets: Vec<GitHubAsset>,
36}
37
38#[derive(Debug, Deserialize)]
39struct GitHubAsset {
40 name: String,
41 browser_download_url: String,
42}
43
44#[derive(Debug, Default, Serialize, Deserialize)]
46struct UpdateCache {
47 #[serde(default)]
48 last_check: u64,
49 #[serde(default)]
50 latest_version: Option<String>,
51}
52
53fn get_cache_path() -> Result<PathBuf> {
55 let home = dirs::home_dir().context("Could not find home directory")?;
56 let st_dir = home.join(".st");
57 fs::create_dir_all(&st_dir)?;
58 Ok(st_dir.join("update_check.json"))
59}
60
61fn load_cache() -> UpdateCache {
63 let cache_path = match get_cache_path() {
64 Ok(p) => p,
65 Err(_) => return UpdateCache::default(),
66 };
67
68 match fs::read_to_string(&cache_path) {
69 Ok(contents) => serde_json::from_str(&contents).unwrap_or_default(),
70 Err(_) => UpdateCache::default(),
71 }
72}
73
74fn save_cache(cache: &UpdateCache) -> Result<()> {
76 let cache_path = get_cache_path()?;
77 let contents = serde_json::to_string_pretty(cache)?;
78 fs::write(&cache_path, contents)?;
79 Ok(())
80}
81
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Returns `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
89
90pub fn should_check_update() -> bool {
92 let cache = load_cache();
93 let now = now_secs();
94 now.saturating_sub(cache.last_check) > UPDATE_CHECK_INTERVAL_SECS
95}
96
/// Returns `true` when `latest` is a strictly newer version than `current`.
///
/// Comparison rules:
/// - Either string may carry a leading `v`, as GitHub tags do.
/// - `MAJOR.MINOR.PATCH` components are compared numerically by position. A missing
///   component counts as `0`, and only the leading digits of a component are read.
/// - Build metadata (`+...`) is ignored.
/// - When the numeric cores are equal, a release outranks a pre-release
///   (`1.2.3 > 1.2.3-rc.1`).
/// - Two pre-releases are ordered per SemVer 2.0.0 §11: numeric identifiers compare
///   numerically, alphanumeric ones lexically, and numeric ranks below alphanumeric.
fn is_newer_version(current: &str, latest: &str) -> bool {
    use std::cmp::Ordering;

    /// Splits a version into its numeric `(major, minor, patch)` core and optional
    /// pre-release tag.
    fn parse(version: &str) -> ((u64, u64, u64), Option<&str>) {
        let version = version.strip_prefix('v').unwrap_or(version);
        // Build metadata never affects precedence.
        let version = version.split_once('+').map_or(version, |(core, _)| core);
        let (core, pre) = match version.split_once('-') {
            Some((core, pre)) => (core, Some(pre)),
            None => (version, None),
        };
        // Map every component in place (instead of filtering out bad ones) so a
        // garbled middle component cannot shift later components to the left.
        let mut parts = core.split('.').map(|part| {
            let digits_end = part
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(part.len());
            part[..digits_end].parse::<u64>().unwrap_or(0)
        });
        let major = parts.next().unwrap_or(0);
        let minor = parts.next().unwrap_or(0);
        let patch = parts.next().unwrap_or(0);
        ((major, minor, patch), pre.filter(|p| !p.is_empty()))
    }

    /// Orders two pre-release tags identifier by identifier (SemVer 2.0.0 §11.4).
    fn cmp_pre(a: &str, b: &str) -> Ordering {
        let mut a_ids = a.split('.');
        let mut b_ids = b.split('.');
        loop {
            match (a_ids.next(), b_ids.next()) {
                (None, None) => return Ordering::Equal,
                // A shorter identifier list ranks lower when all shared fields match.
                (None, Some(_)) => return Ordering::Less,
                (Some(_), None) => return Ordering::Greater,
                (Some(x), Some(y)) => {
                    let ord = match (x.parse::<u64>(), y.parse::<u64>()) {
                        (Ok(m), Ok(n)) => m.cmp(&n),
                        (Ok(_), Err(_)) => Ordering::Less,
                        (Err(_), Ok(_)) => Ordering::Greater,
                        (Err(_), Err(_)) => x.cmp(y),
                    };
                    if ord != Ordering::Equal {
                        return ord;
                    }
                }
            }
        }
    }

    let (cur_core, cur_pre) = parse(current);
    let (lat_core, lat_pre) = parse(latest);

    let ordering = lat_core.cmp(&cur_core).then_with(|| match (lat_pre, cur_pre) {
        (None, None) => Ordering::Equal,
        (None, Some(_)) => Ordering::Greater,
        (Some(_), None) => Ordering::Less,
        (Some(l), Some(c)) => cmp_pre(l, c),
    });
    ordering == Ordering::Greater
}
117
118pub async fn check_for_update() -> Result<Option<String>> {
120 let client = reqwest::Client::builder()
121 .user_agent("smart-tree-updater")
122 .timeout(Duration::from_secs(10))
123 .build()?;
124
125 let response: GitHubRelease = client
126 .get(GITHUB_RELEASES_API)
127 .send()
128 .await
129 .context("Failed to connect to GitHub")?
130 .json()
131 .await
132 .context("Failed to parse GitHub response")?;
133
134 let mut cache = load_cache();
136 cache.last_check = now_secs();
137 cache.latest_version = Some(response.tag_name.clone());
138 let _ = save_cache(&cache);
139
140 let latest = response.tag_name;
141 if is_newer_version(CURRENT_VERSION, &latest) {
142 Ok(Some(latest))
143 } else {
144 Ok(None)
145 }
146}
147
148pub async fn check_for_update_cached() -> Option<String> {
150 let cache = load_cache();
151
152 if should_check_update() {
153 match check_for_update().await {
155 Ok(Some(version)) => Some(version),
156 Ok(None) => None,
157 Err(_) => None, }
159 } else {
160 cache
162 .latest_version
163 .filter(|v| is_newer_version(CURRENT_VERSION, v))
164 }
165}
166
167pub fn print_update_banner(latest_version: &str) {
169 let current = format!("v{}", CURRENT_VERSION);
170 eprintln!();
171 eprintln!("\x1b[36m╭─────────────────────────────────────────────────────╮\x1b[0m");
172 eprintln!(
173 "\x1b[36m│\x1b[0m \x1b[32m🌳 Smart Tree {} is available!\x1b[0m (you have {})",
174 latest_version, current
175 );
176 eprintln!("\x1b[36m│\x1b[0m Run '\x1b[1mst --update\x1b[0m' to upgrade");
177 eprintln!("\x1b[36m╰─────────────────────────────────────────────────────╯\x1b[0m");
178 eprintln!();
179}
180
181fn get_platform() -> Result<(&'static str, &'static str)> {
183 let os = if cfg!(target_os = "macos") {
184 "apple-darwin"
185 } else if cfg!(target_os = "linux") {
186 "unknown-linux-gnu"
187 } else if cfg!(target_os = "windows") {
188 "pc-windows-msvc"
189 } else {
190 bail!("Unsupported operating system");
191 };
192
193 let arch = if cfg!(target_arch = "x86_64") {
194 "x86_64"
195 } else if cfg!(target_arch = "aarch64") {
196 "aarch64"
197 } else {
198 bail!("Unsupported architecture");
199 };
200
201 Ok((arch, os))
202}
203
204fn create_temp_dir() -> Result<PathBuf> {
206 let base = env::temp_dir();
207 let unique_name = format!("st-update-{}", now_secs());
208 let temp_dir = base.join(unique_name);
209 fs::create_dir_all(&temp_dir).context("Failed to create temp directory")?;
210 Ok(temp_dir)
211}
212
/// Recursively deletes `path`. This is best effort: any failure, including the path
/// not existing, is ignored.
fn cleanup_temp_dir(path: &Path) {
    fs::remove_dir_all(path).ok();
}
217
218fn find_install_dir() -> Result<PathBuf> {
220 let current_exe = env::current_exe().context("Could not determine current executable path")?;
222 let install_dir = current_exe
223 .parent()
224 .context("Could not determine installation directory")?
225 .to_path_buf();
226
227 Ok(install_dir)
228}
229
/// Reports whether installing into `install_dir` will require `sudo`.
///
/// On Unix the check tries to create, then immediately removes, a uniquely named empty
/// probe file in the directory. Group permissions, ACLs and read-only mounts all affect
/// writability independently of ownership, so an actual create is the reliable test.
///
/// A directory whose metadata cannot be read is reported as not needing `sudo`, leaving
/// the install step to surface the real error. Always `false` on non-Unix platforms.
fn needs_sudo(install_dir: &Path) -> bool {
    #[cfg(unix)]
    {
        if install_dir.metadata().is_err() {
            return false;
        }
        // `create_new` actually creates the probe. A plain write-open of a missing file
        // fails with ENOENT, which would misreport a writable directory. It also
        // refuses to clobber anything that already exists.
        let probe = install_dir.join(format!(".st-write-probe-{}", std::process::id()));
        match fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&probe)
        {
            Ok(file) => {
                drop(file);
                let _ = fs::remove_file(&probe);
                false
            }
            Err(_) => true,
        }
    }
    #[cfg(not(unix))]
    {
        let _ = install_dir;
        false
    }
}
256
257pub async fn download_and_install(version: &str, yes: bool) -> Result<()> {
259 let (arch, os) = get_platform()?;
260 let install_dir = find_install_dir()?;
261
262 println!("\x1b[36m🌳 Smart Tree Updater\x1b[0m");
263 println!();
264 println!("Current version: v{}", CURRENT_VERSION);
265 println!("Latest version: {}", version);
266 println!("Install path: {}", install_dir.display());
267 println!("Binaries: {}", BINARIES.join(", "));
268 println!();
269
270 if !yes {
271 print!("Proceed with update? [Y/n] ");
272 io::stdout().flush()?;
273
274 let mut input = String::new();
275 io::stdin().read_line(&mut input)?;
276 let input = input.trim().to_lowercase();
277
278 if !input.is_empty() && input != "y" && input != "yes" {
279 println!("Update cancelled.");
280 return Ok(());
281 }
282 }
283
284 let use_sudo = needs_sudo(&install_dir);
285 if use_sudo {
286 println!("\x1b[33m⚠ Installation directory requires elevated permissions.\x1b[0m");
287 println!(" You may be prompted for your password.\n");
288 }
289
290 let ext = if cfg!(target_os = "windows") {
292 "zip"
293 } else {
294 "tar.gz"
295 };
296 let archive_name = format!("st-{}-{}-{}.{}", version, arch, os, ext);
297 let download_url = format!(
298 "https://github.com/{}/releases/download/{}/{}",
299 GITHUB_REPO, version, archive_name
300 );
301
302 println!("Downloading {}...", archive_name);
303
304 let temp_dir = create_temp_dir()?;
306 let archive_path = temp_dir.join(&archive_name);
307
308 let client = reqwest::Client::builder()
310 .user_agent("smart-tree-updater")
311 .timeout(Duration::from_secs(300))
312 .build()?;
313
314 let response = client
315 .get(&download_url)
316 .send()
317 .await
318 .context("Failed to download release")?;
319
320 if !response.status().is_success() {
321 bail!("Download failed: HTTP {}", response.status());
322 }
323
324 let bytes = response.bytes().await?;
325 fs::write(&archive_path, &bytes)?;
326
327 println!("Extracting...");
328
329 #[cfg(unix)]
331 {
332 let output = Command::new("tar")
333 .args(["-xzf", archive_path.to_str().unwrap()])
334 .current_dir(&temp_dir)
335 .output()
336 .context("Failed to extract archive")?;
337
338 if !output.status.success() {
339 bail!(
340 "Failed to extract archive: {}",
341 String::from_utf8_lossy(&output.stderr)
342 );
343 }
344 }
345
346 #[cfg(windows)]
347 {
348 let output = Command::new("powershell")
350 .args([
351 "-Command",
352 &format!(
353 "Expand-Archive -Path '{}' -DestinationPath '{}' -Force",
354 archive_path.display(),
355 &temp_dir.display()
356 ),
357 ])
358 .output()
359 .context("Failed to extract archive")?;
360
361 if !output.status.success() {
362 bail!(
363 "Failed to extract archive: {}",
364 String::from_utf8_lossy(&output.stderr)
365 );
366 }
367 }
368
369 println!("Installing binaries...");
371
372 let mut installed_count = 0;
373 for binary in BINARIES {
374 let binary_name = if cfg!(windows) {
375 format!("{}.exe", binary)
376 } else {
377 binary.to_string()
378 };
379
380 let src_path = match find_binary_in_dir(&temp_dir, &binary_name) {
382 Ok(path) => path,
383 Err(_) => {
384 println!(" \x1b[33m⚠\x1b[0m {} (not in archive, skipping)", binary);
386 continue;
387 }
388 };
389 let dest_path = install_dir.join(&binary_name);
390
391 #[cfg(unix)]
393 {
394 if use_sudo {
395 let _ = Command::new("sudo")
396 .args(["rm", "-f", dest_path.to_str().unwrap()])
397 .status();
398
399 Command::new("sudo")
400 .args([
401 "cp",
402 src_path.to_str().unwrap(),
403 dest_path.to_str().unwrap(),
404 ])
405 .status()
406 .context(format!("Failed to install {}", binary))?;
407
408 Command::new("sudo")
409 .args(["chmod", "+x", dest_path.to_str().unwrap()])
410 .status()?;
411 } else {
412 let _ = fs::remove_file(&dest_path);
413 fs::copy(&src_path, &dest_path).context(format!("Failed to install {}", binary))?;
414
415 use std::os::unix::fs::PermissionsExt;
417 let mut perms = fs::metadata(&dest_path)?.permissions();
418 perms.set_mode(0o755);
419 fs::set_permissions(&dest_path, perms)?;
420 }
421 }
422
423 #[cfg(windows)]
424 {
425 let old_path = install_dir.join(format!("{}.old", binary_name));
427 let _ = fs::remove_file(&old_path);
428 let _ = fs::rename(&dest_path, &old_path);
429
430 fs::copy(&src_path, &dest_path).context(format!("Failed to install {}", binary))?;
431 }
432
433 println!(" \x1b[32m✓\x1b[0m {}", binary);
434 installed_count += 1;
435 }
436
437 if installed_count == 0 {
439 bail!("No binaries were installed from the archive");
440 }
441
442 let mut cache = load_cache();
444 cache.latest_version = Some(version.to_string());
445 let _ = save_cache(&cache);
446
447 cleanup_temp_dir(&temp_dir);
449
450 println!();
451 println!("\x1b[32m✨ Successfully updated to {}!\x1b[0m", version);
452
453 #[cfg(windows)]
454 {
455 println!();
456 println!(
457 "\x1b[33mNote: Please restart your terminal for the update to take effect.\x1b[0m"
458 );
459 }
460
461 Ok(())
462}
463
464fn find_binary_in_dir(dir: &Path, binary_name: &str) -> Result<PathBuf> {
466 let root_path = dir.join(binary_name);
468 if root_path.exists() {
469 return Ok(root_path);
470 }
471
472 for entry in fs::read_dir(dir)? {
474 let entry = entry?;
475 let path = entry.path();
476 if path.is_dir() {
477 let nested = path.join(binary_name);
478 if nested.exists() {
479 return Ok(nested);
480 }
481 }
482 }
483
484 bail!("Could not find {} in downloaded archive", binary_name)
485}
486
487pub async fn run_update(yes: bool) -> Result<()> {
489 println!("Checking for updates...");
490
491 match check_for_update().await? {
492 Some(version) => {
493 download_and_install(&version, yes).await?;
494 }
495 None => {
496 println!(
497 "\x1b[32m✓\x1b[0m Already up to date! (v{})",
498 CURRENT_VERSION
499 );
500 }
501 }
502
503 Ok(())
504}
505
/// Returns this build's version string as recorded by Cargo at compile time
/// (no leading `v`).
pub fn current_version() -> &'static str {
    CURRENT_VERSION
}
510
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_comparison() {
        // (installed, candidate) pairs where the candidate is a genuine upgrade.
        let upgrades = [
            ("5.5.0", "5.5.1"),
            ("5.5.1", "5.6.0"),
            ("5.5.1", "6.0.0"),
            ("v5.5.0", "v5.5.1"),
        ];
        for &(installed, candidate) in upgrades.iter() {
            assert!(
                is_newer_version(installed, candidate),
                "{} -> {} should count as an update",
                installed,
                candidate
            );
        }

        // Pairs where the candidate is the same version or older.
        let non_upgrades = [("5.5.1", "5.5.1"), ("5.5.1", "5.5.0"), ("6.0.0", "5.5.1")];
        for &(installed, candidate) in non_upgrades.iter() {
            assert!(
                !is_newer_version(installed, candidate),
                "{} -> {} must not count as an update",
                installed,
                candidate
            );
        }
    }

    #[test]
    fn test_platform_detection() {
        let platform = get_platform();
        assert!(
            platform.is_ok(),
            "unsupported build target: {:?}",
            platform.err()
        );
    }
}