mcp_exec/
store.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use anyhow::{Context, Result, anyhow};
6use sha2::{Digest, Sha256};
7
/// Source from which tool components are resolved.
#[derive(Debug, Clone)]
pub enum ToolStore {
    /// A local directory; each `*.wasm` file in it (extension matched case-insensitively)
    /// is one tool, named after its file stem.
    LocalDir(PathBuf),
    /// One remote component, downloaded on first use and cached on disk afterwards.
    HttpSingleFile {
        /// The only tool name this store serves.
        name: String,
        /// Location the component is downloaded from.
        url: String,
        /// Directory holding the cached `<name>.wasm` file.
        cache_dir: PathBuf,
    },
    // Additional registries (OCI/Warg) will be supported in future revisions.
}
20
/// A resolved tool component on the local filesystem.
#[derive(Debug, Clone)]
pub struct ToolInfo {
    /// Tool name, i.e. the file stem of the component.
    pub name: String,
    /// Path of the `.wasm` component file.
    pub path: PathBuf,
    /// Hex-encoded SHA-256 of the file contents; `None` when hashing failed.
    pub sha256: Option<String>,
}
27
/// Error returned when a store has no tool with the requested name.
#[derive(Debug)]
pub struct ToolNotFound {
    name: String,
}

impl ToolNotFound {
    /// Creates the error for the missing tool `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        ToolNotFound { name }
    }
}

impl std::fmt::Display for ToolNotFound {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = &self.name;
        write!(f, "tool `{name}` not found")
    }
}

// No underlying cause is carried, so the default `source()` (returning `None`) is kept.
impl std::error::Error for ToolNotFound {}
46
47pub fn is_not_found(err: &anyhow::Error) -> bool {
48    err.downcast_ref::<ToolNotFound>().is_some()
49}
50
51impl ToolStore {
52    pub fn list(&self) -> Result<Vec<ToolInfo>> {
53        match self {
54            ToolStore::LocalDir(root) => list_local(root),
55            ToolStore::HttpSingleFile { name, .. } => {
56                let info = self.fetch(name)?;
57                Ok(vec![info])
58            }
59        }
60    }
61
62    pub fn fetch(&self, name: &str) -> Result<ToolInfo> {
63        match self {
64            ToolStore::LocalDir(root) => fetch_local(root, name),
65            ToolStore::HttpSingleFile {
66                name: expected,
67                url,
68                cache_dir,
69            } => fetch_http(expected, url, cache_dir, name),
70        }
71    }
72}
73
74fn list_local(root: &Path) -> Result<Vec<ToolInfo>> {
75    let mut items = Vec::new();
76    if !root.exists() {
77        return Ok(items);
78    }
79
80    for entry in fs::read_dir(root).with_context(|| format!("listing {}", root.display()))? {
81        let entry = entry?;
82        let path = entry.path();
83
84        if !path.is_file() {
85            continue;
86        }
87
88        if matches!(
89            path.extension().and_then(|ext| ext.to_str()),
90            Some(ext) if ext.eq_ignore_ascii_case("wasm")
91        ) && let Some(name) = path
92            .file_stem()
93            .and_then(|os| os.to_str())
94            .map(|s| s.to_string())
95        {
96            let sha = compute_sha256(&path).ok();
97            items.push(ToolInfo {
98                name,
99                path: path.clone(),
100                sha256: sha,
101            });
102        }
103    }
104
105    items.sort_by(|a, b| a.name.cmp(&b.name));
106    Ok(items)
107}
108
109fn fetch_local(root: &Path, name: &str) -> Result<ToolInfo> {
110    let tools = list_local(root)?;
111    tools
112        .into_iter()
113        .find(|info| info.name == name)
114        .ok_or_else(|| anyhow!(ToolNotFound::new(name)))
115}
116
117fn fetch_http(expected: &str, url: &str, cache_dir: &Path, name: &str) -> Result<ToolInfo> {
118    if name != expected {
119        return Err(anyhow!(ToolNotFound::new(name)));
120    }
121
122    fs::create_dir_all(cache_dir)
123        .with_context(|| format!("creating cache dir {}", cache_dir.display()))?;
124
125    let filename = format!("{expected}.wasm");
126    let dest_path = cache_dir.join(filename);
127
128    if !dest_path.exists() {
129        download_with_retry(url, &dest_path)?;
130    }
131
132    let sha = compute_sha256(&dest_path).ok();
133    Ok(ToolInfo {
134        name: expected.to_string(),
135        path: dest_path,
136        sha256: sha,
137    })
138}
139
140fn compute_sha256(path: &Path) -> Result<String> {
141    use std::io::Read;
142
143    let mut hasher = Sha256::new();
144    let mut file = fs::File::open(path).with_context(|| format!("opening {}", path.display()))?;
145    let mut buf = [0u8; 8192];
146    loop {
147        let read = file.read(&mut buf)?;
148        if read == 0 {
149            break;
150        }
151        hasher.update(&buf[..read]);
152    }
153    Ok(hex::encode(hasher.finalize()))
154}
155
156fn download_with_retry(url: &str, dest: &Path) -> Result<()> {
157    use std::thread::sleep;
158
159    let client = reqwest::blocking::Client::builder()
160        .use_rustls_tls()
161        .timeout(Duration::from_secs(30))
162        .build()
163        .context("building HTTP client")?;
164
165    let mut last_err = None;
166    for attempt in 1..=3 {
167        match download_once(&client, url, dest) {
168            Ok(()) => return Ok(()),
169            Err(err) => {
170                last_err = Some(err);
171                let backoff = Duration::from_secs(attempt * 2);
172                sleep(backoff);
173            }
174        }
175    }
176
177    Err(last_err.unwrap_or_else(|| anyhow!("download failed without specific error")))
178}
179
180fn download_once(client: &reqwest::blocking::Client, url: &str, dest: &Path) -> Result<()> {
181    let response = client
182        .get(url)
183        .send()
184        .with_context(|| format!("requesting {}", url))?
185        .error_for_status()
186        .with_context(|| format!("non-success status from {}", url))?;
187
188    let bytes = response
189        .bytes()
190        .with_context(|| format!("reading bytes from {}", url))?;
191
192    let tmp = dest.with_extension("download");
193    fs::write(&tmp, &bytes).with_context(|| format!("writing {}", tmp.display()))?;
194    fs::rename(&tmp, dest).with_context(|| format!("moving into {}", dest.display()))?;
195    Ok(())
196}