1use std::path::{Path, PathBuf};
2
3use futures_util::StreamExt;
4
5use crate::error::{Error, Result};
6use crate::node::types::BinarySource;
7
/// GitHub `owner/repo` slug that release assets and the "latest release" API are queried from.
const GITHUB_REPO: &str = "WithAutonomi/ant-node";
/// File name of the node executable inside release archives; also the prefix of the
/// versioned copies cached in the install directory (`ant-node-<version>`).
pub const BINARY_NAME: &str = "ant-node";
/// File name of the optional bootstrap peers file that may ship alongside the binary.
pub const BOOTSTRAP_PEERS_FILE: &str = "bootstrap_peers.toml";
11
/// A node binary that is ready to run, plus metadata gathered while resolving it.
#[derive(Debug, Clone)]
pub struct ResolvedBinary {
    /// Path to the executable on disk.
    pub path: PathBuf,
    /// Version of the binary. Local and freshly downloaded binaries report it via
    /// `--version` (falling back to the requested version, or `"unknown"` for URL downloads,
    /// when that fails); cache hits use the requested version without running the binary.
    pub version: String,
    /// A bootstrap peers file found next to the binary or cached alongside it, if any.
    pub bootstrap_peers_path: Option<PathBuf>,
}
22
/// Receives status notifications while a node binary is resolved, downloaded and extracted.
///
/// Implementations must be `Send + Sync`.
pub trait ProgressReporter: Send + Sync {
    /// A new phase has begun, e.g. "Downloading ..." or "Extracting archive...".
    fn report_started(&self, message: &str);
    /// A download chunk arrived: `bytes` received so far out of `total` expected
    /// (`total` is `0` when the server did not announce a content length).
    fn report_progress(&self, bytes: u64, total: u64);
    /// A cached or freshly downloaded binary is ready; `message` summarises the outcome.
    fn report_complete(&self, message: &str);
}

/// A [`ProgressReporter`] that silently ignores every notification.
pub struct NoopProgress;

impl ProgressReporter for NoopProgress {
    fn report_started(&self, _: &str) {}

    fn report_progress(&self, _: u64, _: u64) {}

    fn report_complete(&self, _: &str) {}
}
38
39pub async fn resolve_binary(
48 source: &BinarySource,
49 install_dir: &Path,
50 progress: &dyn ProgressReporter,
51) -> Result<ResolvedBinary> {
52 match source {
53 BinarySource::LocalPath(path) => resolve_local(path).await,
54 BinarySource::Latest => resolve_latest(install_dir, progress).await,
55 BinarySource::Version(version) => resolve_version(version, install_dir, progress).await,
56 BinarySource::Url(url) => resolve_url(url, install_dir, progress).await,
57 }
58}
59
60async fn resolve_local(path: &Path) -> Result<ResolvedBinary> {
64 if !path.exists() {
65 return Err(Error::BinaryNotFound(path.to_path_buf()));
66 }
67
68 let version = extract_version(path).await?;
69
70 let bootstrap_peers_path = path
72 .parent()
73 .map(|dir| dir.join(BOOTSTRAP_PEERS_FILE))
74 .filter(|p| p.exists());
75
76 Ok(ResolvedBinary {
77 path: path.to_path_buf(),
78 version,
79 bootstrap_peers_path,
80 })
81}
82
83async fn resolve_latest(
85 install_dir: &Path,
86 progress: &dyn ProgressReporter,
87) -> Result<ResolvedBinary> {
88 let version = fetch_latest_version().await?;
89 resolve_version(&version, install_dir, progress).await
90}
91
92async fn resolve_version(
94 version: &str,
95 install_dir: &Path,
96 progress: &dyn ProgressReporter,
97) -> Result<ResolvedBinary> {
98 let version = version.strip_prefix('v').unwrap_or(version);
99
100 let cached_path = install_dir.join(format!("{BINARY_NAME}-{version}"));
102 if cached_path.exists() {
103 progress.report_complete(&format!("Using cached {BINARY_NAME} v{version}"));
104 let bootstrap_peers_path =
105 install_dir.join(format!("{BINARY_NAME}-{version}.{BOOTSTRAP_PEERS_FILE}"));
106 let bootstrap_peers_path = Some(bootstrap_peers_path).filter(|p| p.exists());
107 return Ok(ResolvedBinary {
108 path: cached_path,
109 version: version.to_string(),
110 bootstrap_peers_path,
111 });
112 }
113
114 let asset_name = platform_asset_name()?;
115 let url = format!("https://github.com/{GITHUB_REPO}/releases/download/v{version}/{asset_name}");
116
117 download_and_extract(&url, install_dir, version, progress).await
118}
119
120async fn resolve_url(
122 url: &str,
123 install_dir: &Path,
124 progress: &dyn ProgressReporter,
125) -> Result<ResolvedBinary> {
126 download_and_extract(url, install_dir, "unknown", progress).await
128}
129
130async fn fetch_latest_version() -> Result<String> {
132 let url = format!("https://api.github.com/repos/{GITHUB_REPO}/releases/latest");
133 let client = reqwest::Client::new();
134 let resp = client
135 .get(&url)
136 .header("User-Agent", "ant-cli")
137 .header("Accept", "application/vnd.github+json")
138 .send()
139 .await
140 .map_err(|e| Error::BinaryResolution(format!("failed to fetch latest release: {e}")))?;
141
142 if !resp.status().is_success() {
143 return Err(Error::BinaryResolution(format!(
144 "GitHub API returned status {} when fetching latest release",
145 resp.status()
146 )));
147 }
148
149 let body: serde_json::Value = resp
150 .json()
151 .await
152 .map_err(|e| Error::BinaryResolution(format!("failed to parse release JSON: {e}")))?;
153
154 let tag = body["tag_name"]
155 .as_str()
156 .ok_or_else(|| Error::BinaryResolution("no tag_name in release response".to_string()))?;
157
158 Ok(tag.strip_prefix('v').unwrap_or(tag).to_string())
159}
160
161async fn download_and_extract(
165 url: &str,
166 install_dir: &Path,
167 version: &str,
168 progress: &dyn ProgressReporter,
169) -> Result<ResolvedBinary> {
170 progress.report_started(&format!("Downloading {BINARY_NAME} from {url}"));
171
172 let client = reqwest::Client::new();
173 let resp = client
174 .get(url)
175 .header("User-Agent", "ant-cli")
176 .send()
177 .await
178 .map_err(|e| Error::BinaryResolution(format!("download request failed: {e}")))?;
179
180 if !resp.status().is_success() {
181 return Err(Error::BinaryResolution(format!(
182 "download returned status {}",
183 resp.status()
184 )));
185 }
186
187 let total_size = resp.content_length().unwrap_or(0);
188 let mut downloaded: u64 = 0;
189
190 std::fs::create_dir_all(install_dir)?;
192 let tmp_path = install_dir.join(".download.tmp");
193 let mut tmp_file = std::fs::File::create(&tmp_path)
194 .map_err(|e| Error::BinaryResolution(format!("failed to create temp file: {e}")))?;
195
196 let mut stream = resp.bytes_stream();
197 while let Some(chunk) = stream.next().await {
198 let chunk =
199 chunk.map_err(|e| Error::BinaryResolution(format!("download stream error: {e}")))?;
200 downloaded += chunk.len() as u64;
201 std::io::Write::write_all(&mut tmp_file, &chunk)
202 .map_err(|e| Error::BinaryResolution(format!("failed to write temp file: {e}")))?;
203 progress.report_progress(downloaded, total_size);
204 }
205 drop(tmp_file);
206
207 progress.report_started("Extracting archive...");
208
209 let bytes = std::fs::read(&tmp_path)
211 .map_err(|e| Error::BinaryResolution(format!("failed to read temp file: {e}")))?;
212 let _ = std::fs::remove_file(&tmp_path);
213
214 let extracted = if url.ends_with(".zip") {
216 extract_zip(&bytes, install_dir, BINARY_NAME)?
217 } else {
218 extract_tar_gz(&bytes, install_dir, BINARY_NAME)?
220 };
221
222 let actual_version = match extract_version(&extracted.binary_path).await {
224 Ok(v) => v,
225 Err(_) => version.to_string(),
226 };
227
228 let cached_path = install_dir.join(format!("{BINARY_NAME}-{actual_version}"));
230 if extracted.binary_path != cached_path {
231 if !cached_path.exists() {
232 std::fs::rename(&extracted.binary_path, &cached_path)?;
233 } else {
234 let _ = std::fs::remove_file(&extracted.binary_path);
235 }
236 }
237
238 let bootstrap_peers_path = if let Some(bp_path) = extracted.bootstrap_peers_path {
240 let cached_bp = install_dir.join(format!(
241 "{BINARY_NAME}-{actual_version}.{BOOTSTRAP_PEERS_FILE}"
242 ));
243 if bp_path != cached_bp {
244 if !cached_bp.exists() {
245 std::fs::rename(&bp_path, &cached_bp)?;
246 } else {
247 let _ = std::fs::remove_file(&bp_path);
248 }
249 }
250 Some(cached_bp)
251 } else {
252 None
253 };
254
255 progress.report_complete(&format!(
256 "Downloaded {BINARY_NAME} v{actual_version} to {}",
257 cached_path.display()
258 ));
259
260 Ok(ResolvedBinary {
261 path: cached_path,
262 version: actual_version,
263 bootstrap_peers_path,
264 })
265}
266
/// Files written into the install directory by [`extract_tar_gz`] / [`extract_zip`].
///
/// `download_and_extract` subsequently moves them to their versioned cache names.
#[derive(Debug)]
pub struct ExtractionResult {
    /// Path of the extracted node binary.
    pub binary_path: PathBuf,
    /// Path of the extracted bootstrap peers file, if the archive contained one.
    pub bootstrap_peers_path: Option<PathBuf>,
}
275
276pub fn extract_tar_gz(
282 data: &[u8],
283 install_dir: &Path,
284 binary_name: &str,
285) -> Result<ExtractionResult> {
286 let decoder = flate2::read::GzDecoder::new(data);
287 let mut archive = tar::Archive::new(decoder);
288
289 let mut binary_path = None;
290 let mut bootstrap_peers_path = None;
291
292 for entry in archive
293 .entries()
294 .map_err(|e| Error::BinaryResolution(format!("failed to read tar entries: {e}")))?
295 {
296 let mut entry =
297 entry.map_err(|e| Error::BinaryResolution(format!("failed to read tar entry: {e}")))?;
298
299 let path = entry
300 .path()
301 .map_err(|e| Error::BinaryResolution(format!("invalid path in archive: {e}")))?;
302
303 for component in path.components() {
305 if matches!(component, std::path::Component::ParentDir) {
306 return Err(Error::BinaryResolution(format!(
307 "path traversal detected in archive: {}",
308 path.display()
309 )));
310 }
311 }
312
313 let file_name = path
314 .file_name()
315 .and_then(|n| n.to_str())
316 .unwrap_or_default();
317
318 if file_name == binary_name {
319 let dest = install_dir.join(binary_name);
320 let mut file = std::fs::File::create(&dest)?;
321 std::io::copy(&mut entry, &mut file)?;
322
323 #[cfg(unix)]
324 {
325 use std::os::unix::fs::PermissionsExt;
326 std::fs::set_permissions(&dest, std::fs::Permissions::from_mode(0o755))?;
327 }
328
329 binary_path = Some(dest);
330 } else if file_name == BOOTSTRAP_PEERS_FILE {
331 let dest = install_dir.join(BOOTSTRAP_PEERS_FILE);
332 let mut file = std::fs::File::create(&dest)?;
333 std::io::copy(&mut entry, &mut file)?;
334
335 bootstrap_peers_path = Some(dest);
336 }
337 }
338
339 let binary_path = binary_path
340 .ok_or_else(|| Error::BinaryResolution(format!("'{binary_name}' not found in archive")))?;
341
342 Ok(ExtractionResult {
343 binary_path,
344 bootstrap_peers_path,
345 })
346}
347
348pub fn extract_zip(data: &[u8], install_dir: &Path, binary_name: &str) -> Result<ExtractionResult> {
354 let cursor = std::io::Cursor::new(data);
355 let mut archive = zip::ZipArchive::new(cursor)
356 .map_err(|e| Error::BinaryResolution(format!("failed to open zip archive: {e}")))?;
357
358 let mut binary_path = None;
359 let mut bootstrap_peers_path = None;
360
361 for i in 0..archive.len() {
362 let mut file = archive
363 .by_index(i)
364 .map_err(|e| Error::BinaryResolution(format!("failed to read zip entry: {e}")))?;
365
366 let file_name = file
367 .enclosed_name()
368 .and_then(|p| p.file_name().map(|n| n.to_string_lossy().to_string()))
369 .unwrap_or_default();
370
371 if file_name == binary_name || file_name == format!("{binary_name}.exe") {
372 let dest = install_dir.join(&file_name);
373 let mut out = std::fs::File::create(&dest)?;
374 std::io::copy(&mut file, &mut out)?;
375
376 #[cfg(unix)]
377 {
378 use std::os::unix::fs::PermissionsExt;
379 std::fs::set_permissions(&dest, std::fs::Permissions::from_mode(0o755))?;
380 }
381
382 binary_path = Some(dest);
383 } else if file_name == BOOTSTRAP_PEERS_FILE {
384 let dest = install_dir.join(BOOTSTRAP_PEERS_FILE);
385 let mut out = std::fs::File::create(&dest)?;
386 std::io::copy(&mut file, &mut out)?;
387
388 bootstrap_peers_path = Some(dest);
389 }
390 }
391
392 let binary_path = binary_path
393 .ok_or_else(|| Error::BinaryResolution(format!("'{binary_name}' not found in archive")))?;
394
395 Ok(ExtractionResult {
396 binary_path,
397 bootstrap_peers_path,
398 })
399}
400
401pub(crate) async fn extract_version(binary_path: &Path) -> Result<String> {
406 let output = tokio::process::Command::new(binary_path)
407 .arg("--version")
408 .output()
409 .await
410 .map_err(|e| {
411 Error::BinaryResolution(format!(
412 "failed to run {} --version: {e}",
413 binary_path.display()
414 ))
415 })?;
416
417 if !output.status.success() {
418 return Err(Error::BinaryResolution(format!(
419 "{} --version exited with status {}",
420 binary_path.display(),
421 output.status
422 )));
423 }
424
425 let stdout = String::from_utf8_lossy(&output.stdout);
426 let version = stdout
428 .split_whitespace()
429 .last()
430 .unwrap_or("unknown")
431 .to_string();
432
433 Ok(version)
434}
435
436fn platform_asset_name() -> Result<String> {
438 let os = if cfg!(target_os = "linux") {
439 "linux"
440 } else if cfg!(target_os = "macos") {
441 "macos"
442 } else if cfg!(target_os = "windows") {
443 "windows"
444 } else {
445 return Err(Error::BinaryResolution(format!(
446 "unsupported platform: {}",
447 std::env::consts::OS
448 )));
449 };
450
451 let arch = if cfg!(target_arch = "aarch64") {
452 "arm64"
453 } else if cfg!(target_arch = "x86_64") {
454 "x64"
455 } else {
456 return Err(Error::BinaryResolution(format!(
457 "unsupported architecture: {}",
458 std::env::consts::ARCH
459 )));
460 };
461
462 let ext = if cfg!(target_os = "windows") {
463 "zip"
464 } else {
465 "tar.gz"
466 };
467
468 Ok(format!("ant-node-cli-{os}-{arch}.{ext}"))
469}
470
471pub fn binary_install_dir() -> crate::error::Result<PathBuf> {
473 Ok(crate::config::data_dir()?.join("bin"))
474}
475
// Unit tests for binary resolution and archive extraction. Archives are built in memory with
// the `tar`/`flate2` crates and extracted into throwaway `tempfile` directories; no test
// performs a network download (the resolve tests only exercise missing paths and cache hits).
#[cfg(test)]
mod tests {
    use super::*;

    // A missing `LocalPath` must fail with `BinaryNotFound`, before the binary is ever run.
    #[tokio::test]
    async fn local_path_not_found() {
        let result = resolve_binary(
            &BinarySource::LocalPath("/nonexistent/binary".into()),
            Path::new("/tmp"),
            &NoopProgress,
        )
        .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, Error::BinaryNotFound(_)));
    }

    // Only runs meaningfully on a supported OS/arch, where `platform_asset_name` returns Ok.
    #[test]
    fn platform_asset_name_has_correct_format() {
        let name = platform_asset_name().unwrap();
        assert!(name.starts_with("ant-node-cli-"));
        assert!(
            name.ends_with(".tar.gz") || name.ends_with(".zip"),
            "unexpected extension: {name}"
        );
    }

    // An archive containing only the binary yields its path and no bootstrap peers file.
    #[test]
    fn extract_tar_gz_finds_binary() {
        let tmp = tempfile::tempdir().unwrap();
        let mut builder = tar::Builder::new(Vec::new());

        let data = b"#!/bin/sh\necho test\n";
        let mut header = tar::Header::new_gnu();
        header.set_path(BINARY_NAME).unwrap();
        header.set_size(data.len() as u64);
        header.set_mode(0o755);
        header.set_cksum();
        builder.append(&header, &data[..]).unwrap();
        let tar_data = builder.into_inner().unwrap();

        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        std::io::Write::write_all(&mut encoder, &tar_data).unwrap();
        let gz_data = encoder.finish().unwrap();

        let result = extract_tar_gz(&gz_data, tmp.path(), BINARY_NAME);
        assert!(result.is_ok());
        let extracted = result.unwrap();
        assert!(extracted.binary_path.exists());
        assert_eq!(extracted.binary_path.file_name().unwrap(), BINARY_NAME);
        assert!(extracted.bootstrap_peers_path.is_none());
    }

    // A bootstrap peers file shipped next to the binary is extracted and reported too.
    #[test]
    fn extract_tar_gz_finds_bootstrap_peers() {
        let tmp = tempfile::tempdir().unwrap();
        let mut builder = tar::Builder::new(Vec::new());

        let bin_data = b"#!/bin/sh\necho test\n";
        let mut header = tar::Header::new_gnu();
        header.set_path(BINARY_NAME).unwrap();
        header.set_size(bin_data.len() as u64);
        header.set_mode(0o755);
        header.set_cksum();
        builder.append(&header, &bin_data[..]).unwrap();

        let bp_data = b"[peers]\naddrs = [\"1.2.3.4:5000\"]\n";
        let mut header = tar::Header::new_gnu();
        header.set_path(BOOTSTRAP_PEERS_FILE).unwrap();
        header.set_size(bp_data.len() as u64);
        header.set_mode(0o644);
        header.set_cksum();
        builder.append(&header, &bp_data[..]).unwrap();

        let tar_data = builder.into_inner().unwrap();

        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        std::io::Write::write_all(&mut encoder, &tar_data).unwrap();
        let gz_data = encoder.finish().unwrap();

        let result = extract_tar_gz(&gz_data, tmp.path(), BINARY_NAME).unwrap();
        assert!(result.binary_path.exists());
        assert!(result.bootstrap_peers_path.is_some());
        let bp_path = result.bootstrap_peers_path.unwrap();
        assert!(bp_path.exists());
        assert_eq!(bp_path.file_name().unwrap(), BOOTSTRAP_PEERS_FILE);
    }

    // An empty archive has no binary, which must be reported as an error.
    #[test]
    fn extract_tar_gz_missing_binary_errors() {
        let tmp = tempfile::tempdir().unwrap();
        let builder = tar::Builder::new(Vec::new());
        let tar_data = builder.into_inner().unwrap();

        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        std::io::Write::write_all(&mut encoder, &tar_data).unwrap();
        let gz_data = encoder.finish().unwrap();

        let result = extract_tar_gz(&gz_data, tmp.path(), BINARY_NAME);
        assert!(result.is_err());
    }

    // An entry whose path contains `..` must abort extraction with a path-traversal error.
    #[test]
    fn extract_tar_gz_rejects_path_traversal() {
        let tmp = tempfile::tempdir().unwrap();

        let data = b"malicious content";
        let mut header = tar::Header::new_gnu();
        header.set_path("placeholder").unwrap();
        header.set_size(data.len() as u64);
        header.set_mode(0o755);

        // The traversal path is written straight into the raw header bytes (the name field
        // starts at offset 0), presumably because `set_path` refuses `..` components.
        let traversal = b"../../../etc/evil";
        let raw = header.as_mut_bytes();
        raw[..traversal.len()].copy_from_slice(traversal);
        raw[traversal.len()] = 0;
        header.set_cksum();

        let mut builder = tar::Builder::new(Vec::new());
        builder.append(&header, &data[..]).unwrap();
        let tar_data = builder.into_inner().unwrap();

        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        std::io::Write::write_all(&mut encoder, &tar_data).unwrap();
        let gz_data = encoder.finish().unwrap();

        let result = extract_tar_gz(&gz_data, tmp.path(), BINARY_NAME);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("path traversal"),
            "expected path traversal error, got: {err}"
        );
    }

    // A pre-existing `ant-node-<version>` file is returned as-is, without downloading.
    #[tokio::test]
    async fn resolve_version_uses_cache() {
        let tmp = tempfile::tempdir().unwrap();
        let cached = tmp.path().join(format!("{BINARY_NAME}-1.2.3"));
        std::fs::write(&cached, "fake binary").unwrap();

        let result = resolve_version("1.2.3", tmp.path(), &NoopProgress).await;
        assert!(result.is_ok());
        let resolved = result.unwrap();
        assert_eq!(resolved.path, cached);
        assert_eq!(resolved.version, "1.2.3");
        assert!(resolved.bootstrap_peers_path.is_none());
    }

    // A cached `ant-node-<version>.bootstrap_peers.toml` is picked up with the cached binary.
    #[tokio::test]
    async fn resolve_version_uses_cached_bootstrap_peers() {
        let tmp = tempfile::tempdir().unwrap();
        let cached = tmp.path().join(format!("{BINARY_NAME}-1.2.3"));
        std::fs::write(&cached, "fake binary").unwrap();
        let cached_bp = tmp
            .path()
            .join(format!("{BINARY_NAME}-1.2.3.{BOOTSTRAP_PEERS_FILE}"));
        std::fs::write(&cached_bp, "[peers]").unwrap();

        let resolved = resolve_version("1.2.3", tmp.path(), &NoopProgress)
            .await
            .unwrap();
        assert_eq!(resolved.path, cached);
        assert_eq!(resolved.bootstrap_peers_path, Some(cached_bp));
    }

    // A leading `v` in the requested version is dropped before the cache lookup.
    #[tokio::test]
    async fn resolve_version_strips_v_prefix() {
        let tmp = tempfile::tempdir().unwrap();
        let cached = tmp.path().join(format!("{BINARY_NAME}-0.3.4"));
        std::fs::write(&cached, "fake binary").unwrap();

        let result = resolve_version("v0.3.4", tmp.path(), &NoopProgress).await;
        assert!(result.is_ok());
        let resolved = result.unwrap();
        assert_eq!(resolved.version, "0.3.4");
    }
}