mcp_exec/
store.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use anyhow::{Context, Result, anyhow};
6use sha2::{Digest, Sha256};
7
/// Where WebAssembly tool components are resolved from.
#[derive(Clone, Debug)]
pub enum ToolStore {
    /// A directory on disk; each file in it with a `.wasm` extension (ASCII
    /// case-insensitive) is exposed as a tool named after its file stem.
    LocalDir(PathBuf),
    /// Exactly one remote component called `name`, fetched from `url` and cached
    /// as `<cache_dir>/<name>.wasm`.
    HttpSingleFile {
        name: String,
        url: String,
        cache_dir: PathBuf,
    },
    // OCI / Warg registry backends are planned but not implemented yet.
}
20
/// Description of a tool component that has been located on the local filesystem.
#[derive(Clone, Debug)]
pub struct ToolInfo {
    /// Tool name; equals the stem of the component's file name.
    pub name: String,
    /// Local path of the `.wasm` component.
    pub path: PathBuf,
    /// Hex-encoded SHA-256 of the file contents, or `None` when hashing failed.
    pub sha256: Option<String>,
}
27
/// Error raised when a store has no tool with the requested name.
///
/// Use `is_not_found` to recognise it inside an `anyhow::Error`.
#[derive(Debug)]
pub struct ToolNotFound {
    name: String,
}

impl ToolNotFound {
    /// Creates the error for the missing tool `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        ToolNotFound { name }
    }
}

impl std::fmt::Display for ToolNotFound {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tool `")?;
        f.write_str(&self.name)?;
        f.write_str("` not found")
    }
}

impl std::error::Error for ToolNotFound {}
46
47pub fn is_not_found(err: &anyhow::Error) -> bool {
48    err.downcast_ref::<ToolNotFound>().is_some()
49}
50
51impl ToolStore {
52    pub fn list(&self) -> Result<Vec<ToolInfo>> {
53        match self {
54            ToolStore::LocalDir(root) => list_local(root),
55            ToolStore::HttpSingleFile { name, .. } => {
56                let info = self.fetch(name)?;
57                Ok(vec![info])
58            }
59        }
60    }
61
62    pub fn fetch(&self, name: &str) -> Result<ToolInfo> {
63        match self {
64            ToolStore::LocalDir(root) => fetch_local(root, name),
65            ToolStore::HttpSingleFile {
66                name: expected,
67                url,
68                cache_dir,
69            } => fetch_http(expected, url, cache_dir, name),
70        }
71    }
72}
73
74fn list_local(root: &Path) -> Result<Vec<ToolInfo>> {
75    let mut items = Vec::new();
76    if !root.exists() {
77        return Ok(items);
78    }
79
80    for entry in fs::read_dir(root).with_context(|| format!("listing {}", root.display()))? {
81        let entry = entry?;
82        let path = entry.path();
83
84        if !path.is_file() {
85            continue;
86        }
87
88        if !matches!(
89            path.extension().and_then(|ext| ext.to_str()),
90            Some(ext) if ext.eq_ignore_ascii_case("wasm")
91        ) {
92            continue;
93        }
94
95        let Some(name) = path
96            .file_stem()
97            .and_then(|os| os.to_str())
98            .map(|s| s.to_string())
99        else {
100            continue;
101        };
102
103        let sha = compute_sha256(&path).ok();
104        items.push(ToolInfo {
105            name,
106            path: path.clone(),
107            sha256: sha,
108        });
109    }
110
111    items.sort_by(|a, b| a.name.cmp(&b.name));
112    Ok(items)
113}
114
115fn fetch_local(root: &Path, name: &str) -> Result<ToolInfo> {
116    let tools = list_local(root)?;
117    tools
118        .into_iter()
119        .find(|info| info.name == name)
120        .ok_or_else(|| anyhow!(ToolNotFound::new(name)))
121}
122
123fn fetch_http(expected: &str, url: &str, cache_dir: &Path, name: &str) -> Result<ToolInfo> {
124    if name != expected {
125        return Err(anyhow!(ToolNotFound::new(name)));
126    }
127
128    fs::create_dir_all(cache_dir)
129        .with_context(|| format!("creating cache dir {}", cache_dir.display()))?;
130
131    let filename = format!("{expected}.wasm");
132    let dest_path = cache_dir.join(filename);
133
134    if !dest_path.exists() {
135        download_with_retry(url, &dest_path)?;
136    }
137
138    let sha = compute_sha256(&dest_path).ok();
139    Ok(ToolInfo {
140        name: expected.to_string(),
141        path: dest_path,
142        sha256: sha,
143    })
144}
145
146fn compute_sha256(path: &Path) -> Result<String> {
147    use std::io::Read;
148
149    let mut hasher = Sha256::new();
150    let mut file = fs::File::open(path).with_context(|| format!("opening {}", path.display()))?;
151    let mut buf = [0u8; 8192];
152    loop {
153        let read = file.read(&mut buf)?;
154        if read == 0 {
155            break;
156        }
157        hasher.update(&buf[..read]);
158    }
159    Ok(hex::encode(hasher.finalize()))
160}
161
162fn download_with_retry(url: &str, dest: &Path) -> Result<()> {
163    use std::thread::sleep;
164
165    let client = reqwest::blocking::Client::builder()
166        .use_rustls_tls()
167        .timeout(Duration::from_secs(30))
168        .build()
169        .context("building HTTP client")?;
170
171    let mut last_err = None;
172    for attempt in 1..=3 {
173        match download_once(&client, url, dest) {
174            Ok(()) => return Ok(()),
175            Err(err) => {
176                last_err = Some(err);
177                let backoff = Duration::from_secs(attempt * 2);
178                sleep(backoff);
179            }
180        }
181    }
182
183    Err(last_err.unwrap_or_else(|| anyhow!("download failed without specific error")))
184}
185
186fn download_once(client: &reqwest::blocking::Client, url: &str, dest: &Path) -> Result<()> {
187    let response = client
188        .get(url)
189        .send()
190        .with_context(|| format!("requesting {}", url))?
191        .error_for_status()
192        .with_context(|| format!("non-success status from {}", url))?;
193
194    let bytes = response
195        .bytes()
196        .with_context(|| format!("reading bytes from {}", url))?;
197
198    let tmp = dest.with_extension("download");
199    fs::write(&tmp, &bytes).with_context(|| format!("writing {}", tmp.display()))?;
200    fs::rename(&tmp, dest).with_context(|| format!("moving into {}", dest.display()))?;
201    Ok(())
202}