mcp_exec/
store.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use anyhow::{Context, Result, anyhow};
6use sha2::{Digest, Sha256};
7
8use crate::path_safety::normalize_under_root;
9
/// Source from which tool components (`.wasm` files) are obtained.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ToolStore {
    /// Local directory populated with `.wasm` tool components.
    ///
    /// Only the directory's immediate entries are scanned; a tool's name is the
    /// file stem, and the `.wasm` extension is matched case-insensitively.
    LocalDir(PathBuf),
    /// Single remote component downloaded and cached locally.
    HttpSingleFile {
        /// The only tool name this store answers to.
        name: String,
        /// URL the component is downloaded from when it is not yet cached.
        url: String,
        /// Directory holding the cached `<name>.wasm`; created if missing.
        cache_dir: PathBuf,
    },
    // Additional registries (OCI/Warg) will be supported in future revisions.
}
22
/// A resolved tool component.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolInfo {
    /// Tool name (the component file's stem).
    pub name: String,
    /// Filesystem path of the `.wasm` component.
    pub path: PathBuf,
    /// Hex-encoded SHA-256 of the file, or `None` if the file could not be hashed.
    pub sha256: Option<String>,
}
29
/// Error returned when a store has no tool with the requested name.
///
/// Wrapped in an [`anyhow::Error`], it can be detected with [`is_not_found`]
/// and recovered with `downcast_ref::<ToolNotFound>()`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolNotFound {
    name: String,
}

impl ToolNotFound {
    /// Creates the error for the missing tool `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }

    /// The tool name that was requested but not found.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl std::fmt::Display for ToolNotFound {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "tool `{}` not found", self.name)
    }
}

impl std::error::Error for ToolNotFound {}
48
49pub fn is_not_found(err: &anyhow::Error) -> bool {
50    err.downcast_ref::<ToolNotFound>().is_some()
51}
52
53impl ToolStore {
54    pub fn list(&self) -> Result<Vec<ToolInfo>> {
55        match self {
56            ToolStore::LocalDir(root) => list_local(root),
57            ToolStore::HttpSingleFile { name, .. } => {
58                let info = self.fetch(name)?;
59                Ok(vec![info])
60            }
61        }
62    }
63
64    pub fn fetch(&self, name: &str) -> Result<ToolInfo> {
65        match self {
66            ToolStore::LocalDir(root) => fetch_local(root, name),
67            ToolStore::HttpSingleFile {
68                name: expected,
69                url,
70                cache_dir,
71            } => fetch_http(expected, url, cache_dir, name),
72        }
73    }
74}
75
76fn list_local(root: &Path) -> Result<Vec<ToolInfo>> {
77    let mut items = Vec::new();
78    if !root.exists() {
79        return Ok(items);
80    }
81
82    let canonical_root = root
83        .canonicalize()
84        .with_context(|| format!("canonicalizing tool store root {}", root.display()))?;
85    let display_root = root.to_path_buf();
86
87    for entry in fs::read_dir(&canonical_root)
88        .with_context(|| format!("listing {}", canonical_root.display()))?
89    {
90        let entry = entry?;
91        let path = entry.path();
92
93        if !path.is_file() {
94            continue;
95        }
96
97        if !matches!(
98            path.extension().and_then(|ext| ext.to_str()),
99            Some(ext) if ext.eq_ignore_ascii_case("wasm")
100        ) {
101            continue;
102        }
103
104        let Some(name) = path
105            .file_stem()
106            .and_then(|os| os.to_str())
107            .map(|s| s.to_string())
108        else {
109            continue;
110        };
111
112        let relative = path.strip_prefix(&canonical_root).with_context(|| {
113            format!(
114                "entry {} not under {}",
115                path.display(),
116                canonical_root.display()
117            )
118        })?;
119        let safe_path = normalize_under_root(&canonical_root, relative)?;
120        let stable_path = display_root.join(relative);
121
122        let sha = compute_sha256(&safe_path).ok();
123        items.push(ToolInfo {
124            name,
125            path: stable_path,
126            sha256: sha,
127        });
128    }
129
130    items.sort_by(|a, b| a.name.cmp(&b.name));
131    Ok(items)
132}
133
134fn fetch_local(root: &Path, name: &str) -> Result<ToolInfo> {
135    let tools = list_local(root)?;
136    tools
137        .into_iter()
138        .find(|info| info.name == name)
139        .ok_or_else(|| anyhow!(ToolNotFound::new(name)))
140}
141
142fn fetch_http(expected: &str, url: &str, cache_dir: &Path, name: &str) -> Result<ToolInfo> {
143    if name != expected {
144        return Err(anyhow!(ToolNotFound::new(name)));
145    }
146
147    fs::create_dir_all(cache_dir)
148        .with_context(|| format!("creating cache dir {}", cache_dir.display()))?;
149
150    let cache_dir = cache_dir
151        .canonicalize()
152        .with_context(|| format!("canonicalizing cache dir {}", cache_dir.display()))?;
153
154    let filename = format!("{expected}.wasm");
155    let dest_path = normalize_under_root(&cache_dir, Path::new(&filename))?;
156
157    if !dest_path.exists() {
158        download_with_retry(url, &dest_path)?;
159    }
160
161    let sha = compute_sha256(&dest_path).ok();
162    Ok(ToolInfo {
163        name: expected.to_string(),
164        path: dest_path,
165        sha256: sha,
166    })
167}
168
169fn compute_sha256(path: &Path) -> Result<String> {
170    use std::io::Read;
171
172    let mut hasher = Sha256::new();
173    let mut file = fs::File::open(path).with_context(|| format!("opening {}", path.display()))?;
174    let mut buf = [0u8; 8192];
175    loop {
176        let read = file.read(&mut buf)?;
177        if read == 0 {
178            break;
179        }
180        hasher.update(&buf[..read]);
181    }
182    Ok(hex::encode(hasher.finalize()))
183}
184
185fn download_with_retry(url: &str, dest: &Path) -> Result<()> {
186    use std::thread::sleep;
187
188    let client = reqwest::blocking::Client::builder()
189        .use_rustls_tls()
190        .timeout(Duration::from_secs(30))
191        .build()
192        .context("building HTTP client")?;
193
194    let mut last_err = None;
195    for attempt in 1..=3 {
196        match download_once(&client, url, dest) {
197            Ok(()) => return Ok(()),
198            Err(err) => {
199                last_err = Some(err);
200                let backoff = Duration::from_secs(attempt * 2);
201                sleep(backoff);
202            }
203        }
204    }
205
206    Err(last_err.unwrap_or_else(|| anyhow!("download failed without specific error")))
207}
208
209fn download_once(client: &reqwest::blocking::Client, url: &str, dest: &Path) -> Result<()> {
210    let response = client
211        .get(url)
212        .send()
213        .with_context(|| format!("requesting {}", url))?
214        .error_for_status()
215        .with_context(|| format!("non-success status from {}", url))?;
216
217    let bytes = response
218        .bytes()
219        .with_context(|| format!("reading bytes from {}", url))?;
220
221    let tmp = dest.with_extension("download");
222    fs::write(&tmp, &bytes).with_context(|| format!("writing {}", tmp.display()))?;
223    fs::rename(&tmp, dest).with_context(|| format!("moving into {}", dest.display()))?;
224    Ok(())
225}