1use std::fs;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use anyhow::{Context, Result, anyhow};
6use sha2::{Digest, Sha256};
7
8use crate::path_safety::normalize_under_root;
9
/// Where tool `.wasm` binaries come from.
#[derive(Debug, Clone)]
pub enum ToolStore {
    /// A local directory: every `*.wasm` file directly inside it is a tool
    /// named after its file stem.
    LocalDir(PathBuf),
    /// Exactly one tool, downloaded from `url` and cached under `cache_dir`.
    HttpSingleFile {
        /// The only tool name this store answers to.
        name: String,
        /// Location the tool binary is downloaded from.
        url: String,
        /// Directory the downloaded binary is cached in, as `<name>.wasm`.
        cache_dir: PathBuf,
    },
}
22
/// Description of one tool resolved from a [`ToolStore`].
#[derive(Debug, Clone)]
pub struct ToolInfo {
    /// Tool name (the `.wasm` file stem).
    pub name: String,
    /// Location of the tool's `.wasm` file on disk.
    pub path: PathBuf,
    /// Hex-encoded SHA-256 of the file, or `None` if hashing failed.
    pub sha256: Option<String>,
}
29
/// Error raised when a requested tool does not exist in a [`ToolStore`].
///
/// Callers holding an `anyhow::Error` can recognise it with [`is_not_found`].
#[derive(Debug)]
pub struct ToolNotFound {
    name: String,
}

impl ToolNotFound {
    /// Builds the error for the tool called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name }
    }
}

impl std::fmt::Display for ToolNotFound {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "tool `{name}` not found", name = self.name)
    }
}

impl std::error::Error for ToolNotFound {}
48
49pub fn is_not_found(err: &anyhow::Error) -> bool {
50 err.downcast_ref::<ToolNotFound>().is_some()
51}
52
53impl ToolStore {
54 pub fn list(&self) -> Result<Vec<ToolInfo>> {
55 match self {
56 ToolStore::LocalDir(root) => list_local(root),
57 ToolStore::HttpSingleFile { name, .. } => {
58 let info = self.fetch(name)?;
59 Ok(vec![info])
60 }
61 }
62 }
63
64 pub fn fetch(&self, name: &str) -> Result<ToolInfo> {
65 match self {
66 ToolStore::LocalDir(root) => fetch_local(root, name),
67 ToolStore::HttpSingleFile {
68 name: expected,
69 url,
70 cache_dir,
71 } => fetch_http(expected, url, cache_dir, name),
72 }
73 }
74}
75
76fn list_local(root: &Path) -> Result<Vec<ToolInfo>> {
77 let mut items = Vec::new();
78 if !root.exists() {
79 return Ok(items);
80 }
81
82 let canonical_root = root
83 .canonicalize()
84 .with_context(|| format!("canonicalizing tool store root {}", root.display()))?;
85 let display_root = root.to_path_buf();
86
87 for entry in fs::read_dir(&canonical_root)
88 .with_context(|| format!("listing {}", canonical_root.display()))?
89 {
90 let entry = entry?;
91 let path = entry.path();
92
93 if !path.is_file() {
94 continue;
95 }
96
97 if !matches!(
98 path.extension().and_then(|ext| ext.to_str()),
99 Some(ext) if ext.eq_ignore_ascii_case("wasm")
100 ) {
101 continue;
102 }
103
104 let Some(name) = path
105 .file_stem()
106 .and_then(|os| os.to_str())
107 .map(|s| s.to_string())
108 else {
109 continue;
110 };
111
112 let relative = path.strip_prefix(&canonical_root).with_context(|| {
113 format!(
114 "entry {} not under {}",
115 path.display(),
116 canonical_root.display()
117 )
118 })?;
119 let safe_path = normalize_under_root(&canonical_root, relative)?;
120 let stable_path = display_root.join(relative);
121
122 let sha = compute_sha256(&safe_path).ok();
123 items.push(ToolInfo {
124 name,
125 path: stable_path,
126 sha256: sha,
127 });
128 }
129
130 items.sort_by(|a, b| a.name.cmp(&b.name));
131 Ok(items)
132}
133
134fn fetch_local(root: &Path, name: &str) -> Result<ToolInfo> {
135 let tools = list_local(root)?;
136 tools
137 .into_iter()
138 .find(|info| info.name == name)
139 .ok_or_else(|| anyhow!(ToolNotFound::new(name)))
140}
141
142fn fetch_http(expected: &str, url: &str, cache_dir: &Path, name: &str) -> Result<ToolInfo> {
143 if name != expected {
144 return Err(anyhow!(ToolNotFound::new(name)));
145 }
146
147 fs::create_dir_all(cache_dir)
148 .with_context(|| format!("creating cache dir {}", cache_dir.display()))?;
149
150 let cache_dir = cache_dir
151 .canonicalize()
152 .with_context(|| format!("canonicalizing cache dir {}", cache_dir.display()))?;
153
154 let filename = format!("{expected}.wasm");
155 let dest_path = normalize_under_root(&cache_dir, Path::new(&filename))?;
156
157 if !dest_path.exists() {
158 download_with_retry(url, &dest_path)?;
159 }
160
161 let sha = compute_sha256(&dest_path).ok();
162 Ok(ToolInfo {
163 name: expected.to_string(),
164 path: dest_path,
165 sha256: sha,
166 })
167}
168
169fn compute_sha256(path: &Path) -> Result<String> {
170 use std::io::Read;
171
172 let mut hasher = Sha256::new();
173 let mut file = fs::File::open(path).with_context(|| format!("opening {}", path.display()))?;
174 let mut buf = [0u8; 8192];
175 loop {
176 let read = file.read(&mut buf)?;
177 if read == 0 {
178 break;
179 }
180 hasher.update(&buf[..read]);
181 }
182 Ok(hex::encode(hasher.finalize()))
183}
184
185fn download_with_retry(url: &str, dest: &Path) -> Result<()> {
186 use std::thread::sleep;
187
188 let client = reqwest::blocking::Client::builder()
189 .use_rustls_tls()
190 .timeout(Duration::from_secs(30))
191 .build()
192 .context("building HTTP client")?;
193
194 let mut last_err = None;
195 for attempt in 1..=3 {
196 match download_once(&client, url, dest) {
197 Ok(()) => return Ok(()),
198 Err(err) => {
199 last_err = Some(err);
200 let backoff = Duration::from_secs(attempt * 2);
201 sleep(backoff);
202 }
203 }
204 }
205
206 Err(last_err.unwrap_or_else(|| anyhow!("download failed without specific error")))
207}
208
209fn download_once(client: &reqwest::blocking::Client, url: &str, dest: &Path) -> Result<()> {
210 let response = client
211 .get(url)
212 .send()
213 .with_context(|| format!("requesting {}", url))?
214 .error_for_status()
215 .with_context(|| format!("non-success status from {}", url))?;
216
217 let bytes = response
218 .bytes()
219 .with_context(|| format!("reading bytes from {}", url))?;
220
221 let tmp = dest.with_extension("download");
222 fs::write(&tmp, &bytes).with_context(|| format!("writing {}", tmp.display()))?;
223 fs::rename(&tmp, dest).with_context(|| format!("moving into {}", dest.display()))?;
224 Ok(())
225}