Skip to main content

agentport/
source.rs

1use anyhow::{Context, Result, bail};
2use flate2::read::GzDecoder;
3use reqwest::blocking::Client;
4use std::fs::{self, File};
5use std::io::{Read, Write};
6use std::path::{Component as PathComponent, Path, PathBuf};
7use tar::Archive;
8use tempfile::TempDir;
9use url::Url;
10use zip::ZipArchive;
11
/// Largest archive body, in bytes, accepted from a GitHub download (100 MiB).
const MAX_DOWNLOAD_BYTES: u64 = 100 << 20;
/// Largest total expanded size, in bytes, allowed when extracting an archive (500 MiB).
const MAX_EXPANDED_BYTES: u64 = 500 << 20;
/// Largest number of entries accepted from a single archive.
const MAX_ENTRIES: usize = 20_000;
15
16pub struct PreparedSource {
17    pub root: PathBuf,
18    pub display: String,
19    pub revision: Option<String>,
20    _temp: Option<TempDir>,
21}
22
23impl PreparedSource {
24    fn temporary(root: PathBuf, display: String, revision: Option<String>, temp: TempDir) -> Self {
25        Self {
26            root,
27            display,
28            revision,
29            _temp: Some(temp),
30        }
31    }
32
33    fn local(root: PathBuf, display: String) -> Self {
34        Self {
35            root,
36            display,
37            revision: None,
38            _temp: None,
39        }
40    }
41}
42
43pub trait SourceProvider {
44    fn prepare(&self, source: &str) -> Result<PreparedSource>;
45}
46
47pub struct DefaultSourceProvider {
48    client: Client,
49}
50
51impl Default for DefaultSourceProvider {
52    fn default() -> Self {
53        Self {
54            client: Client::builder()
55                .user_agent(concat!("agentport/", env!("CARGO_PKG_VERSION")))
56                .build()
57                .expect("valid HTTP client"),
58        }
59    }
60}
61
62impl SourceProvider for DefaultSourceProvider {
63    fn prepare(&self, source: &str) -> Result<PreparedSource> {
64        if source.starts_with("https://github.com/") || source.starts_with("http://github.com/") {
65            return self.prepare_github(source);
66        }
67
68        let path = PathBuf::from(source);
69        if path.is_dir() {
70            return Ok(PreparedSource::local(
71                path.canonicalize().context("resolve local source")?,
72                source.to_owned(),
73            ));
74        }
75        if !path.is_file() {
76            bail!("source is neither a supported GitHub URL nor a local file/directory");
77        }
78
79        let temp = tempfile::tempdir().context("create extraction directory")?;
80        let lower = path
81            .file_name()
82            .and_then(|name| name.to_str())
83            .unwrap_or_default()
84            .to_lowercase();
85        if lower.ends_with(".zip") {
86            extract_zip(&path, temp.path())?;
87        } else if lower.ends_with(".tar.gz") || lower.ends_with(".tgz") {
88            extract_tar_gz(&path, temp.path())?;
89        } else {
90            bail!("unsupported archive; expected .zip, .tar.gz, or .tgz");
91        }
92        let root = single_root(temp.path())?;
93        Ok(PreparedSource::temporary(
94            root,
95            source.to_owned(),
96            None,
97            temp,
98        ))
99    }
100}
101
102impl DefaultSourceProvider {
103    fn prepare_github(&self, source: &str) -> Result<PreparedSource> {
104        let parsed = Url::parse(source).context("parse GitHub URL")?;
105        let segments: Vec<_> = parsed
106            .path_segments()
107            .into_iter()
108            .flatten()
109            .filter(|part| !part.is_empty())
110            .collect();
111        if segments.len() < 2 {
112            bail!("GitHub URL must include owner and repository");
113        }
114        let owner = segments[0];
115        let repo = segments[1].trim_end_matches(".git");
116        let (revision, subpath) = if segments.get(2) == Some(&"tree") {
117            let revision = segments
118                .get(3)
119                .context("GitHub tree URL is missing a ref")?
120                .to_string();
121            let subpath = segments.get(4..).unwrap_or_default().join("/");
122            (Some(revision), subpath)
123        } else {
124            (None, String::new())
125        };
126        let archive_ref = revision.as_deref().unwrap_or("HEAD");
127        let archive_url = format!("https://github.com/{owner}/{repo}/archive/{archive_ref}.zip");
128        let temp = tempfile::tempdir().context("create download directory")?;
129        let archive_path = temp.path().join("source.zip");
130        self.download(&archive_url, &archive_path)?;
131        let unpacked = temp.path().join("unpacked");
132        fs::create_dir(&unpacked)?;
133        extract_zip(&archive_path, &unpacked)?;
134        let mut root = single_root(&unpacked)?;
135        if !subpath.is_empty() {
136            root = root.join(&subpath);
137            if !root.is_dir() {
138                bail!("path '{subpath}' was not found in the downloaded repository");
139            }
140        }
141        Ok(PreparedSource::temporary(
142            root,
143            source.to_owned(),
144            revision,
145            temp,
146        ))
147    }
148
149    fn download(&self, url: &str, path: &Path) -> Result<()> {
150        let mut response = self
151            .client
152            .get(url)
153            .send()
154            .context("download GitHub archive")?;
155        if !response.status().is_success() {
156            bail!("GitHub returned {} for {url}", response.status());
157        }
158        if response
159            .content_length()
160            .is_some_and(|length| length > MAX_DOWNLOAD_BYTES)
161        {
162            bail!("archive exceeds the 100 MiB download limit");
163        }
164        let mut output = File::create(path)?;
165        let mut limited = response.by_ref().take(MAX_DOWNLOAD_BYTES + 1);
166        let copied = std::io::copy(&mut limited, &mut output)?;
167        if copied > MAX_DOWNLOAD_BYTES {
168            bail!("archive exceeds the 100 MiB download limit");
169        }
170        output.flush()?;
171        Ok(())
172    }
173}
174
/// Reports whether `path` is a non-empty relative path that stays inside the
/// directory it is joined onto. Every component must be a plain name or `.`;
/// roots, platform prefixes, and `..` are all refused.
fn safe_relative(path: &Path) -> bool {
    if path.as_os_str().is_empty() {
        return false;
    }
    for component in path.components() {
        match component {
            PathComponent::Normal(_) | PathComponent::CurDir => {}
            PathComponent::Prefix(_) | PathComponent::RootDir | PathComponent::ParentDir => {
                return false;
            }
        }
    }
    true
}
181
182fn extract_zip(path: &Path, destination: &Path) -> Result<()> {
183    let file = File::open(path).context("open ZIP archive")?;
184    let mut archive = ZipArchive::new(file).context("read ZIP archive")?;
185    if archive.len() > MAX_ENTRIES {
186        bail!("archive contains too many entries");
187    }
188    let mut expanded = 0_u64;
189    for index in 0..archive.len() {
190        let mut entry = archive.by_index(index)?;
191        let enclosed = entry
192            .enclosed_name()
193            .context("archive contains an unsafe path")?
194            .to_owned();
195        if !safe_relative(&enclosed) {
196            bail!("archive contains an unsafe path");
197        }
198        let unix_mode = entry.unix_mode().unwrap_or_default();
199        if unix_mode & 0o170000 == 0o120000 {
200            bail!("archive contains a symbolic link");
201        }
202        expanded = expanded.saturating_add(entry.size());
203        if expanded > MAX_EXPANDED_BYTES {
204            bail!("expanded archive exceeds the 500 MiB limit");
205        }
206        let output = destination.join(enclosed);
207        if entry.is_dir() {
208            fs::create_dir_all(&output)?;
209        } else {
210            if let Some(parent) = output.parent() {
211                fs::create_dir_all(parent)?;
212            }
213            let mut file = File::create(&output)?;
214            std::io::copy(&mut entry, &mut file)?;
215        }
216    }
217    Ok(())
218}
219
220fn extract_tar_gz(path: &Path, destination: &Path) -> Result<()> {
221    let file = File::open(path).context("open tar.gz archive")?;
222    let decoder = GzDecoder::new(file);
223    let mut archive = Archive::new(decoder);
224    let mut entries = 0_usize;
225    let mut expanded = 0_u64;
226    for item in archive.entries().context("read tar.gz archive")? {
227        let mut entry = item?;
228        entries += 1;
229        if entries > MAX_ENTRIES {
230            bail!("archive contains too many entries");
231        }
232        let path = entry.path()?.into_owned();
233        if !safe_relative(&path) {
234            bail!("archive contains an unsafe path");
235        }
236        let kind = entry.header().entry_type();
237        if kind.is_symlink() || kind.is_hard_link() {
238            bail!("archive contains a link");
239        }
240        if !(kind.is_dir() || kind.is_file()) {
241            bail!("archive contains an unsupported entry type");
242        }
243        expanded = expanded.saturating_add(entry.size());
244        if expanded > MAX_EXPANDED_BYTES {
245            bail!("expanded archive exceeds the 500 MiB limit");
246        }
247        let output = destination.join(path);
248        if kind.is_dir() {
249            fs::create_dir_all(&output)?;
250        } else {
251            if let Some(parent) = output.parent() {
252                fs::create_dir_all(parent)?;
253            }
254            entry.unpack(&output)?;
255        }
256    }
257    Ok(())
258}
259
260fn single_root(path: &Path) -> Result<PathBuf> {
261    let entries: Vec<_> = fs::read_dir(path)?.filter_map(Result::ok).collect();
262    if entries.len() == 1 && entries[0].path().is_dir() {
263        Ok(entries[0].path())
264    } else {
265        Ok(path.to_path_buf())
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes a ZIP archive at `zip_path` containing one file `name` with `contents`.
    fn write_single_entry_zip(zip_path: &Path, name: &str, contents: &[u8]) {
        let mut writer = zip::ZipWriter::new(File::create(zip_path).unwrap());
        writer
            .start_file(name, zip::write::SimpleFileOptions::default())
            .unwrap();
        writer.write_all(contents).unwrap();
        writer.finish().unwrap();
    }

    /// Creates an empty `out` directory under `parent` to extract into.
    fn fresh_output_dir(parent: &Path) -> PathBuf {
        let out = parent.join("out");
        fs::create_dir(&out).unwrap();
        out
    }

    #[test]
    fn rejects_zip_traversal() {
        let scratch = tempfile::tempdir().unwrap();
        let archive = scratch.path().join("bad.zip");
        write_single_entry_zip(&archive, "/absolute", b"unsafe");
        let out = fresh_output_dir(scratch.path());
        assert!(extract_zip(&archive, &out).is_err());
    }

    #[test]
    fn extracts_normal_zip() {
        let scratch = tempfile::tempdir().unwrap();
        let archive = scratch.path().join("ok.zip");
        write_single_entry_zip(&archive, "skill/SKILL.md", b"---\nname: skill\n---\n");
        let out = fresh_output_dir(scratch.path());
        extract_zip(&archive, &out).unwrap();
        assert!(out.join("skill/SKILL.md").is_file());
    }
}