use std::fs;
use std::io::{Read, Write};
use std::path::{Component, Path, PathBuf};
use anyhow::{Context, Result, bail};
use flate2::read::GzDecoder;
use crate::network_policy::NetworkPolicy;
use zagens_runtime_adapters::network_policy::{Decision, host_from_url};
use zagens_runtime_adapters::tools::host_policy_decision;
use super::registry::fetch_registry;
use super::types::{
DownloadAttempt, DownloadOutcome, InstallError, InstallSource, RegistryFetchResult,
UrlResolution,
};
/// Expands an install source into the ordered list of archive URLs to try.
///
/// - A GitHub repository yields the `main` then `master` branch tarball URLs.
/// - A direct URL is returned unchanged.
/// - A registry name is looked up in the registry document fetched from
///   `registry_url` through `network`. Its source string is parsed and resolved in turn;
///   it may not point at another registry entry.
///
/// If fetching the registry needs approval or is denied, that outcome is returned
/// instead of URLs.
///
/// # Errors
/// Fails when:
/// - the registry fetch fails;
/// - the skill is not listed in the registry;
/// - the listed source does not parse;
/// - the listed source refers to another registry.
pub(super) async fn candidate_urls(
    source: &InstallSource,
    network: &NetworkPolicy,
    registry_url: &str,
) -> Result<UrlResolution> {
    let name = match source {
        InstallSource::GitHubRepo(repo) => {
            let branch_tarball = |branch: &str| {
                format!("https://github.com/{repo}/archive/refs/heads/{branch}.tar.gz")
            };
            return Ok(UrlResolution::Resolved(vec![
                branch_tarball("main"),
                branch_tarball("master"),
            ]));
        }
        InstallSource::DirectUrl(url) => return Ok(UrlResolution::Resolved(vec![url.clone()])),
        InstallSource::Registry(name) => name,
    };

    let doc = match fetch_registry(network, registry_url).await? {
        RegistryFetchResult::Loaded(doc) => doc,
        RegistryFetchResult::NeedsApproval(host) => return Ok(UrlResolution::NeedsApproval(host)),
        RegistryFetchResult::Denied(host) => return Ok(UrlResolution::Denied(host)),
    };

    let inner = {
        let entry = doc
            .skills
            .get(name)
            .with_context(|| format!("skill '{name}' not found in registry"))?;
        InstallSource::parse(&entry.source).with_context(|| {
            format!(
                "registry entry for '{name}' has invalid source: {}",
                entry.source
            )
        })?
    };
    // Registry-to-registry indirection is refused, so the recursive call below never
    // fetches the registry again.
    if matches!(inner, InstallSource::Registry(_)) {
        bail!("registry entry for '{name}' must not point to another registry");
    }
    // Boxed because an async fn cannot recurse into itself directly.
    Box::pin(candidate_urls(&inner, network, registry_url)).await
}
pub(super) async fn download_first_success(
urls: &[String],
network: &NetworkPolicy,
max_size: u64,
) -> Result<DownloadOutcome> {
let mut last_status: Option<reqwest::StatusCode> = None;
let mut prompt_host: Option<String> = None;
let mut denied_host: Option<String> = None;
for url in urls {
let host = match host_from_url(url) {
Some(h) => h,
None => bail!("invalid download url: {url}"),
};
match host_policy_decision(network, &host) {
Decision::Allow => {}
Decision::Deny => {
denied_host.get_or_insert(host);
continue;
}
Decision::Prompt => {
prompt_host.get_or_insert(host);
continue;
}
}
match download_with_cap(url, max_size).await? {
DownloadAttempt::Bytes(bytes) => {
return Ok(DownloadOutcome::Bytes {
bytes,
url: url.clone(),
});
}
DownloadAttempt::NotFound(status) => {
last_status = Some(status);
continue;
}
}
}
if let Some(host) = denied_host {
return Ok(DownloadOutcome::Denied(host));
}
if let Some(host) = prompt_host {
return Ok(DownloadOutcome::NeedsApproval(host));
}
bail!(
"failed to download skill (last status: {})",
last_status
.map(|s| s.to_string())
.unwrap_or_else(|| "unknown".to_string())
);
}
/// GETs `url` and returns its body, refusing bodies larger than `max_size * 4` bytes.
///
/// The compressed-size cap is `max_size.saturating_mul(4)`. It is enforced in two ways:
/// - up front, when the response declares a `Content-Length` above the cap;
/// - while streaming, by counting received chunks.
///
/// An oversized or never-ending response is therefore rejected without being buffered
/// in full.
///
/// Returns [`DownloadAttempt::NotFound`] for HTTP 404 so callers can try another URL.
///
/// # Errors
/// Fails on:
/// - transport errors;
/// - any non-success status other than 404;
/// - a body exceeding the compressed size cap.
pub(super) async fn download_with_cap(url: &str, max_size: u64) -> Result<DownloadAttempt> {
    let mut resp = reqwest::get(url)
        .await
        .with_context(|| format!("failed to GET {url}"))?;
    let status = resp.status();
    if !status.is_success() {
        if status == reqwest::StatusCode::NOT_FOUND {
            return Ok(DownloadAttempt::NotFound(status));
        }
        bail!("download {url} returned {status}");
    }
    let compressed_cap = max_size.saturating_mul(4);
    // Reject immediately when the server already announces an oversized body.
    if matches!(resp.content_length(), Some(declared) if declared > compressed_cap) {
        bail!("download {url} exceeds compressed size cap of {compressed_cap} bytes");
    }
    // Content-Length may be absent or wrong, so count bytes as they arrive. Stop as soon
    // as the cap would be exceeded instead of reading the whole response into memory first.
    let mut body: Vec<u8> = Vec::new();
    while let Some(chunk) = resp
        .chunk()
        .await
        .with_context(|| format!("failed to read body of {url}"))?
    {
        let received = (body.len() as u64).saturating_add(chunk.len() as u64);
        if received > compressed_cap {
            bail!("download {url} exceeds compressed size cap of {compressed_cap} bytes");
        }
        body.extend_from_slice(&chunk);
    }
    Ok(DownloadAttempt::Bytes(body))
}
/// A skill that has been unpacked into a temporary staging directory.
pub(super) struct StagedSkill {
    /// `name` taken from the archive's SKILL.md frontmatter.
    pub(super) skill_name: String,
    /// `<skills_dir>/<skill_name>.tmp`, holding the extracted skill files.
    pub(super) staged_path: PathBuf,
}

/// Validates a gzipped skill tarball and unpacks it into `<skills_dir>/<name>.tmp`.
///
/// Steps:
/// 1. Create `skills_dir` if needed.
/// 2. Scan the archive with `scan_tarball` to learn the skill name and layout.
/// 3. Delete any leftover staging directory for that name.
/// 4. Extract the selected skill into a fresh staging directory.
///
/// If extraction fails, the staging directory is removed on a best-effort basis before
/// the error is returned.
///
/// # Errors
/// Fails when:
/// - a directory cannot be created or cleaned;
/// - the archive is rejected by the scan;
/// - extraction fails.
pub(super) fn stage_tarball(bytes: &[u8], skills_dir: &Path, max_size: u64) -> Result<StagedSkill> {
    fs::create_dir_all(skills_dir)
        .with_context(|| format!("failed to create skills directory {}", skills_dir.display()))?;

    let scan = scan_tarball(bytes, max_size)?;
    let staged_path = skills_dir.join(format!("{}.tmp", scan.skill_name));

    // A previous interrupted install may have left this directory behind.
    if staged_path.exists() {
        fs::remove_dir_all(&staged_path).with_context(|| {
            format!("failed to clean stale staging dir {}", staged_path.display())
        })?;
    }
    fs::create_dir_all(&staged_path)
        .with_context(|| format!("failed to create staging dir {}", staged_path.display()))?;

    match extract_into(&scan, bytes, &staged_path, max_size) {
        Ok(()) => Ok(StagedSkill {
            skill_name: scan.skill_name,
            staged_path,
        }),
        Err(err) => {
            // Cleanup is best-effort: the extraction error is the one worth reporting.
            let _ = fs::remove_dir_all(&staged_path);
            Err(err)
        }
    }
}
/// Result of [`scan_tarball`]: which part of a skill archive to extract.
pub(super) struct TarballScan {
    /// `name` declared in the selected SKILL.md frontmatter.
    skill_name: String,
    /// Top-level wrapper directory shared by every entry, or empty when entries live at
    /// the archive root. An example is the `repo-main/` folder of a GitHub archive.
    prefix: String,
    /// Directory of the selected SKILL.md relative to `prefix`: empty for a root-level
    /// SKILL.md, otherwise `skills/<dir>`.
    skill_root: String,
}

/// First pass over a gzipped skill tarball: validates it and locates the skill to install.
///
/// Validation:
/// - every entry path must be relative and free of `..`;
/// - the sum of declared entry sizes must not exceed `max_size`.
///
/// Metadata records (such as the leading `pax_global_header` of GitHub / `git archive`
/// tarballs) are ignored when locating the skill.
///
/// Locating the skill:
/// - A wrapper directory is stripped only when every entry lives under the same
///   top-level directory.
/// - The first SKILL.md (in archive order) at the root, or at `skills/<dir>/SKILL.md`
///   relative to the wrapper, is selected.
/// - If none matches, the unstripped layout is tried; this only matters when the wrapper
///   is itself `skills/`.
///
/// Symbolic or hard links inside the selected skill directory are rejected.
///
/// # Errors
/// - `InstallError::PathTraversal`, `OversizedTarball`, `MissingSkillMd` or
///   `SymlinkRejected` for the conditions above;
/// - any error from `parse_frontmatter_name`;
/// - a context error for a corrupt archive.
pub(super) fn scan_tarball(bytes: &[u8], max_size: u64) -> Result<TarballScan> {
    let cursor = std::io::Cursor::new(bytes);
    let gz = GzDecoder::new(cursor);
    let mut archive = tar::Archive::new(gz);
    let mut total_size: u64 = 0;
    // (path, is_dir) of every content entry. The wrapper directory can only be decided
    // once all of them have been seen.
    let mut content_entries: Vec<(PathBuf, bool)> = Vec::new();
    // Every regular file named SKILL.md that is shallow enough to be selectable, in
    // archive order.
    let mut skill_md_candidates: Vec<(String, Vec<u8>)> = Vec::new();
    let mut link_paths: Vec<String> = Vec::new();
    for entry in archive
        .entries()
        .context("failed to read tar entries (corrupt archive?)")?
    {
        let mut entry = entry.context("failed to read tar entry")?;
        let header = entry.header().clone();
        let entry_type = header.entry_type();
        let path = entry
            .path()
            .context("tar entry has invalid path")?
            .to_path_buf();
        let path_str = path.to_string_lossy().into_owned();
        if !is_safe_path(&path) {
            return Err(InstallError::PathTraversal(path_str).into());
        }
        if let Ok(size) = header.size() {
            total_size = total_size.saturating_add(size);
            if total_size > max_size {
                return Err(InstallError::OversizedTarball { limit: max_size }.into());
            }
        }
        // Metadata records are not archive content and must not influence wrapper
        // detection. The `tar` crate yields pax global headers as ordinary entries.
        if entry_type.is_pax_global_extensions()
            || entry_type.is_pax_local_extensions()
            || entry_type.is_gnu_longname()
            || entry_type.is_gnu_longlink()
        {
            continue;
        }
        // The deepest selectable location is `<wrapper>/skills/<dir>/SKILL.md`
        // (4 components).
        let is_skill_md_candidate = entry_type.is_file()
            && path.components().count() <= 4
            && matches!(
                path.file_name().and_then(|file| file.to_str()),
                Some(file) if file.eq_ignore_ascii_case("SKILL.md")
            );
        content_entries.push((path, entry_type.is_dir()));
        if entry_type.is_symlink() || entry_type.is_hard_link() {
            link_paths.push(path_str);
            continue;
        }
        if is_skill_md_candidate {
            let mut buf = Vec::new();
            entry
                .read_to_end(&mut buf)
                .context("failed to read SKILL.md from archive")?;
            skill_md_candidates.push((path_str, buf));
        }
    }

    let wrapper = shared_wrapper_dir(&content_entries);
    // Prefer the wrapper-stripped layout. The unstripped layout can only match when the
    // wrapper is the `skills/` directory itself.
    let layouts: Vec<&str> = if wrapper.is_empty() {
        vec![""]
    } else {
        vec![wrapper.as_str(), ""]
    };
    let (prefix, skill_md_path, skill_md_bytes) = layouts
        .into_iter()
        .find_map(|prefix| {
            skill_md_candidates.iter().find_map(|(path, body)| {
                let relative = strip_prefix(path, prefix);
                is_skill_md_location(&relative)
                    .then(|| (prefix.to_string(), relative.into_owned(), body.as_slice()))
            })
        })
        .ok_or(InstallError::MissingSkillMd)
        .map_err(anyhow::Error::from)?;

    let skill_root = if skill_md_path == "SKILL.md" {
        String::new()
    } else {
        skill_md_path
            .strip_suffix("/SKILL.md")
            .unwrap_or("")
            .to_string()
    };
    // Links outside the selected skill directory are never extracted, so only links
    // inside it are fatal.
    if link_paths
        .iter()
        .any(|link| is_within_selected_root(link, &prefix, &skill_root))
    {
        return Err(InstallError::SymlinkRejected.into());
    }
    let name = parse_frontmatter_name(skill_md_bytes)?;
    Ok(TarballScan {
        skill_name: name,
        prefix,
        skill_root,
    })
}

/// Returns the single top-level directory under which every entry lives, or an empty
/// string when there is none. There is none when entries have different first components,
/// a path does not start with a normal component, or a top-level non-directory entry exists.
fn shared_wrapper_dir(entries: &[(PathBuf, bool)]) -> String {
    let mut wrapper: Option<&std::ffi::OsStr> = None;
    for (path, is_dir) in entries {
        let mut components = path.components();
        let first = match components.next() {
            Some(Component::Normal(first)) => first,
            _ => return String::new(),
        };
        // A file or link directly at the top level means there is no wrapper.
        if components.next().is_none() && !is_dir {
            return String::new();
        }
        match wrapper {
            Some(existing) if existing != first => return String::new(),
            _ => wrapper = Some(first),
        }
    }
    wrapper
        .map(|dir| dir.to_string_lossy().into_owned())
        .unwrap_or_default()
}

/// Reports whether `relative` (a path with any wrapper directory removed) is a selectable
/// SKILL.md location: the archive root (case-insensitive) or exactly `skills/<dir>/SKILL.md`.
fn is_skill_md_location(relative: &str) -> bool {
    relative.eq_ignore_ascii_case("SKILL.md")
        || (relative.starts_with("skills/")
            && relative.ends_with("/SKILL.md")
            && relative.matches('/').count() == 2)
}
pub(super) fn extract_into(
scan: &TarballScan,
bytes: &[u8],
dest: &Path,
max_size: u64,
) -> Result<()> {
let cursor = std::io::Cursor::new(bytes);
let gz = GzDecoder::new(cursor);
let mut archive = tar::Archive::new(gz);
let mut total_size: u64 = 0;
let prefix_with_root = if scan.skill_root.is_empty() {
scan.prefix.clone()
} else if scan.prefix.is_empty() {
scan.skill_root.clone()
} else {
format!("{}/{}", scan.prefix, scan.skill_root)
};
for entry in archive
.entries()
.context("failed to read tar entries (corrupt archive?)")?
{
let mut entry = entry.context("failed to read tar entry")?;
let header = entry.header().clone();
let entry_type = header.entry_type();
let path = entry
.path()
.context("tar entry has invalid path")?
.to_path_buf();
let path_str = path.to_string_lossy().into_owned();
if !is_safe_path(&path) {
return Err(InstallError::PathTraversal(path_str).into());
}
let stripped = strip_prefix(&path_str, &prefix_with_root).into_owned();
if stripped.is_empty() && entry_type.is_dir() {
continue;
}
if stripped == path_str && !prefix_with_root.is_empty() {
continue;
}
let stripped_path = Path::new(&stripped);
if !is_safe_path(stripped_path) {
return Err(InstallError::PathTraversal(stripped).into());
}
if entry_type.is_symlink() || entry_type.is_hard_link() {
return Err(InstallError::SymlinkRejected.into());
}
let target = dest.join(stripped_path);
let target_components: Vec<_> = target.components().collect();
let dest_components: Vec<_> = dest.components().collect();
if !target_components.starts_with(dest_components.as_slice()) {
return Err(InstallError::PathTraversal(stripped).into());
}
if entry_type.is_dir() {
fs::create_dir_all(&target)
.with_context(|| format!("failed to create dir {}", target.display()))?;
continue;
}
if entry_type.is_file() {
if let Some(parent) = target.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create dir {}", parent.display()))?;
}
let mut buf = Vec::new();
entry
.read_to_end(&mut buf)
.with_context(|| format!("failed to read {}", path.display()))?;
total_size = total_size.saturating_add(buf.len() as u64);
if total_size > max_size {
return Err(InstallError::OversizedTarball { limit: max_size }.into());
}
let mut out = fs::OpenOptions::new()
.create_new(true)
.write(true)
.open(&target)
.with_context(|| format!("failed to create {}", target.display()))?;
out.write_all(&buf)
.with_context(|| format!("failed to write {}", target.display()))?;
}
}
Ok(())
}
/// Joins the archive wrapper `prefix` and the skill's `skill_root` into one archive path.
///
/// Either part may be empty; the separator is only inserted when both are present.
pub(super) fn selected_root(prefix: &str, skill_root: &str) -> String {
    match (prefix.is_empty(), skill_root.is_empty()) {
        (_, true) => prefix.to_owned(),
        (true, false) => skill_root.to_owned(),
        (false, false) => format!("{prefix}/{skill_root}"),
    }
}
/// Reports whether archive path `path` is the selected skill root or lies beneath it.
///
/// The root is `selected_root(prefix, skill_root)`. An empty root (no wrapper and a
/// root-level SKILL.md) covers every path. Matching is on whole path segments, so
/// `repo/skillsX` is not inside `repo/skills`.
pub(super) fn is_within_selected_root(path: &str, prefix: &str, skill_root: &str) -> bool {
    let root = selected_root(prefix, skill_root);
    root.is_empty()
        || matches!(
            path.strip_prefix(root.as_str()),
            Some(rest) if rest.is_empty() || rest.starts_with('/')
        )
}
/// Reports whether `path` is safe to resolve under an extraction directory.
///
/// The path must be relative and contain no `..`, root (`/`) or Windows drive/UNC
/// prefix components. `.` components and the empty path are allowed.
pub(super) fn is_safe_path(path: &Path) -> bool {
    !path.is_absolute()
        && path.components().all(|component| {
            !matches!(
                component,
                Component::ParentDir | Component::Prefix(_) | Component::RootDir
            )
        })
}
/// Removes a leading directory `prefix` (matched on whole segments) from archive path `path`.
///
/// Returns:
/// - the remainder after `prefix/`;
/// - an empty string when `path == prefix`;
/// - `path` unchanged when it does not lie under `prefix` or `prefix` is empty.
///
/// The result always borrows from `path`. The `Cow` return type is kept for
/// compatibility with existing callers.
pub(super) fn strip_prefix<'a>(path: &'a str, prefix: &str) -> std::borrow::Cow<'a, str> {
    if prefix.is_empty() {
        return std::borrow::Cow::Borrowed(path);
    }
    // Segment-aware match without allocating `"{prefix}/"`: `repository/x` must not be
    // treated as lying under `repo`.
    let relative = match path.strip_prefix(prefix) {
        Some("") => "",
        Some(rest) => rest.strip_prefix('/').unwrap_or(path),
        None => path,
    };
    std::borrow::Cow::Borrowed(relative)
}
/// Extracts and validates the skill `name` from SKILL.md frontmatter.
///
/// Frontmatter rules:
/// - The document (after leading whitespace) must start with `---`.
/// - The frontmatter runs until the next line consisting solely of `---`. A `---`
///   embedded inside a value does not close it.
/// - Only simple `key: value` lines are understood; this is not a full YAML parser.
/// - Keys are case-insensitive, and a value wrapped in one pair of matching single or
///   double quotes is unquoted.
///
/// Both `name` and `description` must be present and non-empty. `name` must be usable as
/// a single directory name on any platform: no path separators, whitespace, control
/// characters or Windows-reserved characters, and not `.` or `..`.
///
/// # Errors
/// Fails when:
/// - the bytes are not UTF-8;
/// - either fence is missing;
/// - a required field is missing (`InstallError::MissingFrontmatterField`);
/// - `name` is not path-safe.
pub(super) fn parse_frontmatter_name(bytes: &[u8]) -> Result<String> {
    /// Removes one pair of matching surrounding quotes, as in YAML `name: "my-skill"`.
    fn strip_matching_quotes(value: &str) -> &str {
        for quote in ['"', '\''] {
            if let Some(inner) = value
                .strip_prefix(quote)
                .and_then(|rest| rest.strip_suffix(quote))
            {
                return inner;
            }
        }
        value
    }

    // `/` and `\` are path separators; the rest are invalid in Windows file names
    // (`:` would even create an NTFS alternate data stream).
    const FORBIDDEN_NAME_CHARS: &[char] = &['/', '\\', '<', '>', ':', '"', '|', '?', '*'];

    let content = std::str::from_utf8(bytes).context("SKILL.md is not valid UTF-8")?;
    let trimmed = content.trim_start();
    if !trimmed.starts_with("---") {
        bail!("SKILL.md is missing the leading '---' frontmatter fence");
    }
    let after_open = &trimmed[3..];
    let mut name: Option<String> = None;
    let mut has_description = false;
    let mut closed = false;
    for raw in after_open.lines() {
        let line = raw.trim();
        // The closing fence must be a line of its own. A substring search would stop at
        // e.g. `description: before---after`.
        if line == "---" {
            closed = true;
            break;
        }
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        if let Some((key, value)) = line.split_once(':') {
            let key = key.trim().to_ascii_lowercase();
            let value = strip_matching_quotes(value.trim());
            match key.as_str() {
                "name" if !value.is_empty() => name = Some(value.to_string()),
                "description" if !value.is_empty() => has_description = true,
                _ => {}
            }
        }
    }
    if !closed {
        bail!("SKILL.md is missing the closing '---' frontmatter fence");
    }
    let name = name.ok_or(InstallError::MissingFrontmatterField("name"))?;
    if !has_description {
        return Err(InstallError::MissingFrontmatterField("description").into());
    }
    // The name becomes a directory name (`<skills_dir>/<name>.tmp`), so it must be a
    // single, portable path segment.
    if name == "."
        || name == ".."
        || name
            .chars()
            .any(|c| FORBIDDEN_NAME_CHARS.contains(&c) || c.is_whitespace() || c.is_control())
    {
        bail!("SKILL.md `name` must be a single path-safe segment (got '{name}')");
    }
    Ok(name)
}