1use anyhow::{Context, Result, bail};
2use flate2::read::GzDecoder;
3use reqwest::blocking::Client;
4use std::fs::{self, File};
5use std::io::{Read, Write};
6use std::path::{Component as PathComponent, Path, PathBuf};
7use tar::Archive;
8use tempfile::TempDir;
9use url::Url;
10use zip::ZipArchive;
11
/// Largest archive accepted from a download, in bytes (100 MiB).
const MAX_DOWNLOAD_BYTES: u64 = 100 << 20;
/// Largest total amount of file data an archive may expand to, in bytes (500 MiB).
const MAX_EXPANDED_BYTES: u64 = 500 << 20;
/// Largest number of entries an archive may contain.
const MAX_ENTRIES: usize = 20_000;
15
16pub struct PreparedSource {
17 pub root: PathBuf,
18 pub display: String,
19 pub revision: Option<String>,
20 _temp: Option<TempDir>,
21}
22
23impl PreparedSource {
24 fn temporary(root: PathBuf, display: String, revision: Option<String>, temp: TempDir) -> Self {
25 Self {
26 root,
27 display,
28 revision,
29 _temp: Some(temp),
30 }
31 }
32
33 fn local(root: PathBuf, display: String) -> Self {
34 Self {
35 root,
36 display,
37 revision: None,
38 _temp: None,
39 }
40 }
41}
42
43pub trait SourceProvider {
44 fn prepare(&self, source: &str) -> Result<PreparedSource>;
45}
46
47pub struct DefaultSourceProvider {
48 client: Client,
49}
50
51impl Default for DefaultSourceProvider {
52 fn default() -> Self {
53 Self {
54 client: Client::builder()
55 .user_agent(concat!("agentport/", env!("CARGO_PKG_VERSION")))
56 .build()
57 .expect("valid HTTP client"),
58 }
59 }
60}
61
62impl SourceProvider for DefaultSourceProvider {
63 fn prepare(&self, source: &str) -> Result<PreparedSource> {
64 if source.starts_with("https://github.com/") || source.starts_with("http://github.com/") {
65 return self.prepare_github(source);
66 }
67
68 let path = PathBuf::from(source);
69 if path.is_dir() {
70 return Ok(PreparedSource::local(
71 path.canonicalize().context("resolve local source")?,
72 source.to_owned(),
73 ));
74 }
75 if !path.is_file() {
76 bail!("source is neither a supported GitHub URL nor a local file/directory");
77 }
78
79 let temp = tempfile::tempdir().context("create extraction directory")?;
80 let lower = path
81 .file_name()
82 .and_then(|name| name.to_str())
83 .unwrap_or_default()
84 .to_lowercase();
85 if lower.ends_with(".zip") {
86 extract_zip(&path, temp.path())?;
87 } else if lower.ends_with(".tar.gz") || lower.ends_with(".tgz") {
88 extract_tar_gz(&path, temp.path())?;
89 } else {
90 bail!("unsupported archive; expected .zip, .tar.gz, or .tgz");
91 }
92 let root = single_root(temp.path())?;
93 Ok(PreparedSource::temporary(
94 root,
95 source.to_owned(),
96 None,
97 temp,
98 ))
99 }
100}
101
102impl DefaultSourceProvider {
103 fn prepare_github(&self, source: &str) -> Result<PreparedSource> {
104 let parsed = Url::parse(source).context("parse GitHub URL")?;
105 let segments: Vec<_> = parsed
106 .path_segments()
107 .into_iter()
108 .flatten()
109 .filter(|part| !part.is_empty())
110 .collect();
111 if segments.len() < 2 {
112 bail!("GitHub URL must include owner and repository");
113 }
114 let owner = segments[0];
115 let repo = segments[1].trim_end_matches(".git");
116 let (revision, subpath) = if segments.get(2) == Some(&"tree") {
117 let revision = segments
118 .get(3)
119 .context("GitHub tree URL is missing a ref")?
120 .to_string();
121 let subpath = segments.get(4..).unwrap_or_default().join("/");
122 (Some(revision), subpath)
123 } else {
124 (None, String::new())
125 };
126 let archive_ref = revision.as_deref().unwrap_or("HEAD");
127 let archive_url = format!("https://github.com/{owner}/{repo}/archive/{archive_ref}.zip");
128 let temp = tempfile::tempdir().context("create download directory")?;
129 let archive_path = temp.path().join("source.zip");
130 self.download(&archive_url, &archive_path)?;
131 let unpacked = temp.path().join("unpacked");
132 fs::create_dir(&unpacked)?;
133 extract_zip(&archive_path, &unpacked)?;
134 let mut root = single_root(&unpacked)?;
135 if !subpath.is_empty() {
136 root = root.join(&subpath);
137 if !root.is_dir() {
138 bail!("path '{subpath}' was not found in the downloaded repository");
139 }
140 }
141 Ok(PreparedSource::temporary(
142 root,
143 source.to_owned(),
144 revision,
145 temp,
146 ))
147 }
148
149 fn download(&self, url: &str, path: &Path) -> Result<()> {
150 let mut response = self
151 .client
152 .get(url)
153 .send()
154 .context("download GitHub archive")?;
155 if !response.status().is_success() {
156 bail!("GitHub returned {} for {url}", response.status());
157 }
158 if response
159 .content_length()
160 .is_some_and(|length| length > MAX_DOWNLOAD_BYTES)
161 {
162 bail!("archive exceeds the 100 MiB download limit");
163 }
164 let mut output = File::create(path)?;
165 let mut limited = response.by_ref().take(MAX_DOWNLOAD_BYTES + 1);
166 let copied = std::io::copy(&mut limited, &mut output)?;
167 if copied > MAX_DOWNLOAD_BYTES {
168 bail!("archive exceeds the 100 MiB download limit");
169 }
170 output.flush()?;
171 Ok(())
172 }
173}
174
/// Reports whether `path` is a non-empty relative path built only from plain
/// names and `.` components: no root, no drive/UNC prefix, and no `..`. Such a
/// path cannot escape the directory it is joined onto.
fn safe_relative(path: &Path) -> bool {
    if path.as_os_str().is_empty() {
        return false;
    }
    for component in path.components() {
        match component {
            PathComponent::Normal(_) | PathComponent::CurDir => {}
            PathComponent::Prefix(_) | PathComponent::RootDir | PathComponent::ParentDir => {
                return false;
            }
        }
    }
    true
}
181
182fn extract_zip(path: &Path, destination: &Path) -> Result<()> {
183 let file = File::open(path).context("open ZIP archive")?;
184 let mut archive = ZipArchive::new(file).context("read ZIP archive")?;
185 if archive.len() > MAX_ENTRIES {
186 bail!("archive contains too many entries");
187 }
188 let mut expanded = 0_u64;
189 for index in 0..archive.len() {
190 let mut entry = archive.by_index(index)?;
191 let enclosed = entry
192 .enclosed_name()
193 .context("archive contains an unsafe path")?
194 .to_owned();
195 if !safe_relative(&enclosed) {
196 bail!("archive contains an unsafe path");
197 }
198 let unix_mode = entry.unix_mode().unwrap_or_default();
199 if unix_mode & 0o170000 == 0o120000 {
200 bail!("archive contains a symbolic link");
201 }
202 expanded = expanded.saturating_add(entry.size());
203 if expanded > MAX_EXPANDED_BYTES {
204 bail!("expanded archive exceeds the 500 MiB limit");
205 }
206 let output = destination.join(enclosed);
207 if entry.is_dir() {
208 fs::create_dir_all(&output)?;
209 } else {
210 if let Some(parent) = output.parent() {
211 fs::create_dir_all(parent)?;
212 }
213 let mut file = File::create(&output)?;
214 std::io::copy(&mut entry, &mut file)?;
215 }
216 }
217 Ok(())
218}
219
220fn extract_tar_gz(path: &Path, destination: &Path) -> Result<()> {
221 let file = File::open(path).context("open tar.gz archive")?;
222 let decoder = GzDecoder::new(file);
223 let mut archive = Archive::new(decoder);
224 let mut entries = 0_usize;
225 let mut expanded = 0_u64;
226 for item in archive.entries().context("read tar.gz archive")? {
227 let mut entry = item?;
228 entries += 1;
229 if entries > MAX_ENTRIES {
230 bail!("archive contains too many entries");
231 }
232 let path = entry.path()?.into_owned();
233 if !safe_relative(&path) {
234 bail!("archive contains an unsafe path");
235 }
236 let kind = entry.header().entry_type();
237 if kind.is_symlink() || kind.is_hard_link() {
238 bail!("archive contains a link");
239 }
240 if !(kind.is_dir() || kind.is_file()) {
241 bail!("archive contains an unsupported entry type");
242 }
243 expanded = expanded.saturating_add(entry.size());
244 if expanded > MAX_EXPANDED_BYTES {
245 bail!("expanded archive exceeds the 500 MiB limit");
246 }
247 let output = destination.join(path);
248 if kind.is_dir() {
249 fs::create_dir_all(&output)?;
250 } else {
251 if let Some(parent) = output.parent() {
252 fs::create_dir_all(parent)?;
253 }
254 entry.unpack(&output)?;
255 }
256 }
257 Ok(())
258}
259
260fn single_root(path: &Path) -> Result<PathBuf> {
261 let entries: Vec<_> = fs::read_dir(path)?.filter_map(Result::ok).collect();
262 if entries.len() == 1 && entries[0].path().is_dir() {
263 Ok(entries[0].path())
264 } else {
265 Ok(path.to_path_buf())
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes a ZIP archive at `zip_path` containing one file entry `name`.
    fn write_single_entry_zip(zip_path: &Path, name: &str, contents: &[u8]) {
        let mut writer = zip::ZipWriter::new(File::create(zip_path).unwrap());
        writer
            .start_file(name, zip::write::SimpleFileOptions::default())
            .unwrap();
        writer.write_all(contents).unwrap();
        writer.finish().unwrap();
    }

    /// Creates and returns an empty `out` extraction directory inside `temp`.
    fn output_dir(temp: &TempDir) -> PathBuf {
        let out = temp.path().join("out");
        fs::create_dir(&out).unwrap();
        out
    }

    #[test]
    fn rejects_zip_traversal() {
        let temp = tempfile::tempdir().unwrap();
        let zip_path = temp.path().join("bad.zip");
        write_single_entry_zip(&zip_path, "/absolute", b"unsafe");
        let out = output_dir(&temp);
        assert!(extract_zip(&zip_path, &out).is_err());
    }

    #[test]
    fn extracts_normal_zip() {
        let temp = tempfile::tempdir().unwrap();
        let zip_path = temp.path().join("ok.zip");
        write_single_entry_zip(&zip_path, "skill/SKILL.md", b"---\nname: skill\n---\n");
        let out = output_dir(&temp);
        extract_zip(&zip_path, &out).unwrap();
        assert!(out.join("skill/SKILL.md").is_file());
    }
}