1use std::fs;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use anyhow::{Context, Result, anyhow};
6use sha2::{Digest, Sha256};
7
/// Where WebAssembly tools are resolved from.
#[derive(Debug, Clone)]
pub enum ToolStore {
    /// Tools are the `*.wasm` files found directly inside this directory.
    LocalDir(PathBuf),
    /// Exactly one tool, downloaded over HTTP and cached on disk.
    HttpSingleFile {
        /// The only tool name this store answers to.
        name: String,
        /// URL the `.wasm` module is downloaded from.
        url: String,
        /// Directory in which the downloaded module is cached.
        cache_dir: PathBuf,
    },
}
20
/// A resolved tool: its name, on-disk location and (when available) content hash.
#[derive(Debug, Clone)]
pub struct ToolInfo {
    /// Tool name; equals the file stem of the module at `path`.
    pub name: String,
    /// Location of the `.wasm` module on disk.
    pub path: PathBuf,
    /// Hex-encoded SHA-256 of the module, or `None` if it could not be computed.
    pub sha256: Option<String>,
}
27
/// Error returned when a store has no tool with the requested name.
///
/// Detect it on an `anyhow::Error` with [`is_not_found`], or downcast to it and read
/// [`ToolNotFound::name`] to learn which tool was requested.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolNotFound {
    name: String,
}

impl ToolNotFound {
    /// Creates the error for the tool called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }

    /// The name of the tool that could not be found.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl std::fmt::Display for ToolNotFound {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "tool `{}` not found", self.name)
    }
}

impl std::error::Error for ToolNotFound {}
46
47pub fn is_not_found(err: &anyhow::Error) -> bool {
48 err.downcast_ref::<ToolNotFound>().is_some()
49}
50
51impl ToolStore {
52 pub fn list(&self) -> Result<Vec<ToolInfo>> {
53 match self {
54 ToolStore::LocalDir(root) => list_local(root),
55 ToolStore::HttpSingleFile { name, .. } => {
56 let info = self.fetch(name)?;
57 Ok(vec![info])
58 }
59 }
60 }
61
62 pub fn fetch(&self, name: &str) -> Result<ToolInfo> {
63 match self {
64 ToolStore::LocalDir(root) => fetch_local(root, name),
65 ToolStore::HttpSingleFile {
66 name: expected,
67 url,
68 cache_dir,
69 } => fetch_http(expected, url, cache_dir, name),
70 }
71 }
72}
73
74fn list_local(root: &Path) -> Result<Vec<ToolInfo>> {
75 let mut items = Vec::new();
76 if !root.exists() {
77 return Ok(items);
78 }
79
80 for entry in fs::read_dir(root).with_context(|| format!("listing {}", root.display()))? {
81 let entry = entry?;
82 let path = entry.path();
83
84 if !path.is_file() {
85 continue;
86 }
87
88 if matches!(
89 path.extension().and_then(|ext| ext.to_str()),
90 Some(ext) if ext.eq_ignore_ascii_case("wasm")
91 ) && let Some(name) = path
92 .file_stem()
93 .and_then(|os| os.to_str())
94 .map(|s| s.to_string())
95 {
96 let sha = compute_sha256(&path).ok();
97 items.push(ToolInfo {
98 name,
99 path: path.clone(),
100 sha256: sha,
101 });
102 }
103 }
104
105 items.sort_by(|a, b| a.name.cmp(&b.name));
106 Ok(items)
107}
108
109fn fetch_local(root: &Path, name: &str) -> Result<ToolInfo> {
110 let tools = list_local(root)?;
111 tools
112 .into_iter()
113 .find(|info| info.name == name)
114 .ok_or_else(|| anyhow!(ToolNotFound::new(name)))
115}
116
117fn fetch_http(expected: &str, url: &str, cache_dir: &Path, name: &str) -> Result<ToolInfo> {
118 if name != expected {
119 return Err(anyhow!(ToolNotFound::new(name)));
120 }
121
122 fs::create_dir_all(cache_dir)
123 .with_context(|| format!("creating cache dir {}", cache_dir.display()))?;
124
125 let filename = format!("{expected}.wasm");
126 let dest_path = cache_dir.join(filename);
127
128 if !dest_path.exists() {
129 download_with_retry(url, &dest_path)?;
130 }
131
132 let sha = compute_sha256(&dest_path).ok();
133 Ok(ToolInfo {
134 name: expected.to_string(),
135 path: dest_path,
136 sha256: sha,
137 })
138}
139
140fn compute_sha256(path: &Path) -> Result<String> {
141 use std::io::Read;
142
143 let mut hasher = Sha256::new();
144 let mut file = fs::File::open(path).with_context(|| format!("opening {}", path.display()))?;
145 let mut buf = [0u8; 8192];
146 loop {
147 let read = file.read(&mut buf)?;
148 if read == 0 {
149 break;
150 }
151 hasher.update(&buf[..read]);
152 }
153 Ok(hex::encode(hasher.finalize()))
154}
155
156fn download_with_retry(url: &str, dest: &Path) -> Result<()> {
157 use std::thread::sleep;
158
159 let client = reqwest::blocking::Client::builder()
160 .use_rustls_tls()
161 .timeout(Duration::from_secs(30))
162 .build()
163 .context("building HTTP client")?;
164
165 let mut last_err = None;
166 for attempt in 1..=3 {
167 match download_once(&client, url, dest) {
168 Ok(()) => return Ok(()),
169 Err(err) => {
170 last_err = Some(err);
171 let backoff = Duration::from_secs(attempt * 2);
172 sleep(backoff);
173 }
174 }
175 }
176
177 Err(last_err.unwrap_or_else(|| anyhow!("download failed without specific error")))
178}
179
180fn download_once(client: &reqwest::blocking::Client, url: &str, dest: &Path) -> Result<()> {
181 let response = client
182 .get(url)
183 .send()
184 .with_context(|| format!("requesting {}", url))?
185 .error_for_status()
186 .with_context(|| format!("non-success status from {}", url))?;
187
188 let bytes = response
189 .bytes()
190 .with_context(|| format!("reading bytes from {}", url))?;
191
192 let tmp = dest.with_extension("download");
193 fs::write(&tmp, &bytes).with_context(|| format!("writing {}", tmp.display()))?;
194 fs::rename(&tmp, dest).with_context(|| format!("moving into {}", dest.display()))?;
195 Ok(())
196}