use crate::{NylError, Result};
use sha2::{Digest, Sha256};
use std::path::PathBuf;
use std::process::Command;
/// Pulls Helm charts with the `helm` CLI and keeps the unpacked charts in an on-disk cache.
///
/// Each cached chart lives in its own subdirectory of `cache_dir`.
#[derive(Debug, Clone)]
pub struct OciChartPuller {
    /// Root of the chart cache (`<root>/helm/oci`); holds one subdirectory per cached chart.
    cache_dir: PathBuf,
}
impl OciChartPuller {
pub fn new() -> Result<Self> {
let root = if let Ok(cache_dir) = std::env::var("NYL_CACHE_DIR") {
PathBuf::from(cache_dir)
} else {
std::env::current_dir()?.join(".nyl").join("cache")
};
Ok(Self {
cache_dir: root.join("helm").join("oci"),
})
}
pub fn with_cache_dir(cache_dir: impl Into<PathBuf>) -> Self {
Self {
cache_dir: cache_dir.into().join("helm").join("oci"),
}
}
pub fn pull(&self, repository: &str, version: &str, chart_name: Option<&str>) -> Result<PathBuf> {
let chart_dir = self.chart_cache_path(repository, version);
if chart_dir.join("Chart.yaml").exists() {
tracing::debug!("Using cached Helm chart: {}", chart_dir.display());
return Ok(chart_dir);
}
tracing::debug!("Pulling Helm chart from {}", crate::util::sanitize_url(repository));
std::fs::create_dir_all(&self.cache_dir)
.map_err(|e| NylError::Process(format!("Failed to create chart cache directory: {}", e)))?;
let tmp_dir = tempfile::Builder::new()
.prefix(".pull-tmp-")
.tempdir_in(&self.cache_dir)
.map_err(|e| NylError::Process(format!("Failed to create temp pull directory: {}", e)))?;
let is_oci = repository.starts_with("oci://");
let mut cmd = Command::new("helm");
cmd.arg("pull");
if is_oci {
cmd.arg(repository);
} else {
let name = chart_name
.ok_or_else(|| NylError::Config("Chart name is required for non-OCI Helm repositories".to_string()))?;
cmd.arg("--repo").arg(repository).arg(name);
}
cmd.arg("--version")
.arg(version)
.arg("--untar")
.arg("-d")
.arg(tmp_dir.path());
tracing::debug!("Executing helm command: {:?}", cmd);
let output = cmd
.output()
.map_err(|e| NylError::Process(format!("Failed to execute helm pull: {}", e)))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(NylError::HelmChart(format!(
"helm pull failed for {}@{}: {}",
repository, version, stderr
)));
}
tracing::debug!("Helm chart pulled successfully");
let extracted_name = if is_oci {
extract_chart_name(repository)
} else {
chart_name.unwrap_or("chart").to_string()
};
let extracted = tmp_dir.path().join(&extracted_name);
if !extracted.join("Chart.yaml").exists() {
return Err(NylError::HelmChart(format!(
"Chart.yaml not found after pulling {}@{} (expected at {})",
repository,
version,
extracted.display()
)));
}
match std::fs::rename(&extracted, &chart_dir) {
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
if !chart_dir.join("Chart.yaml").exists() {
return Err(NylError::Process(format!("Failed to move cached chart: {}", e)));
}
}
Err(e) => {
return Err(NylError::Process(format!("Failed to move cached chart: {}", e)));
}
}
Ok(chart_dir)
}
fn chart_cache_path(&self, repository: &str, version: &str) -> PathBuf {
let mut hasher = Sha256::new();
hasher.update(repository.as_bytes());
let repo_hash = hex::encode(hasher.finalize());
let safe_version = Self::sanitize_version(version);
self.cache_dir.join(format!("{}-{}", &repo_hash[..16], safe_version))
}
fn sanitize_version(version: &str) -> String {
let sanitized: String = version
.chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' {
c
} else {
'_'
}
})
.collect();
if sanitized.is_empty() {
"unknown".to_string()
} else {
sanitized
}
}
}
/// Derives the chart name from an OCI chart reference: its last non-empty path segment.
///
/// Trailing slashes are ignored, so `oci://ghcr.io/owner/mychart/` yields `mychart`. When the
/// reference has no non-empty segment at all (an empty string or only slashes), this falls back
/// to `"chart"` instead of returning an empty name.
fn extract_chart_name(repository: &str) -> String {
    repository
        .rsplit('/')
        .find(|segment| !segment.is_empty())
        .unwrap_or("chart")
        .to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// OCI reference shared by most of the tests below.
    const NYL_CHART: &str = "oci://ghcr.io/owner/nyl/chart";

    /// Builds a puller rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned too, so the caller keeps it alive for the whole test.
    fn scratch_puller() -> (TempDir, OciChartPuller) {
        let dir = TempDir::new().unwrap();
        let puller = OciChartPuller::with_cache_dir(dir.path());
        (dir, puller)
    }

    #[test]
    fn test_extract_chart_name() {
        let cases = [
            ("oci://ghcr.io/owner/repo/mychart", "mychart"),
            ("oci://ghcr.io/niklasrosenstein/nyl/chart", "chart"),
            ("oci://registry.example.com/charts/nginx", "nginx"),
        ];
        for &(reference, expected) in &cases {
            assert_eq!(extract_chart_name(reference), expected, "reference: {}", reference);
        }
    }

    #[test]
    fn test_extract_chart_name_trailing_slash() {
        let name = extract_chart_name("oci://ghcr.io/owner/repo/mychart/");
        assert_eq!(name, "mychart");
    }

    #[test]
    fn test_chart_cache_path_deterministic() {
        let (_dir, puller) = scratch_puller();
        let first = puller.chart_cache_path(NYL_CHART, "0.1.0-sha-abc1234");
        let second = puller.chart_cache_path(NYL_CHART, "0.1.0-sha-abc1234");
        assert_eq!(first, second);
    }

    #[test]
    fn test_chart_cache_path_different_versions() {
        let (_dir, puller) = scratch_puller();
        assert_ne!(
            puller.chart_cache_path(NYL_CHART, "0.1.0-sha-abc1234"),
            puller.chart_cache_path(NYL_CHART, "0.1.0-sha-def5678")
        );
    }

    #[test]
    fn test_chart_cache_path_different_repos() {
        let (_dir, puller) = scratch_puller();
        assert_ne!(
            puller.chart_cache_path("oci://ghcr.io/owner1/nyl/chart", "0.1.0"),
            puller.chart_cache_path("oci://ghcr.io/owner2/nyl/chart", "0.1.0")
        );
    }

    #[test]
    fn test_pull_returns_cached_chart() {
        let (_dir, puller) = scratch_puller();
        let version = "0.1.0";

        // Pre-seed the cache entry so `pull` takes its early-return path.
        let cached = puller.chart_cache_path(NYL_CHART, version);
        std::fs::create_dir_all(&cached).unwrap();
        let manifest = "apiVersion: v2\nname: chart\nversion: 0.1.0\n";
        std::fs::write(cached.join("Chart.yaml"), manifest).unwrap();

        let pulled = puller.pull(NYL_CHART, version, None).unwrap();
        assert_eq!(pulled, cached);
        assert!(pulled.join("Chart.yaml").exists());
    }

    #[test]
    fn test_sanitize_version_safe() {
        for &version in &["1.0.0", "1.0.0-alpha", "1.0.0_beta"] {
            assert_eq!(OciChartPuller::sanitize_version(version), version);
        }
    }

    #[test]
    fn test_sanitize_version_path_traversal() {
        let cases = [
            ("../../../etc/passwd", ".._.._.._etc_passwd"),
            ("1.0/../../bad", "1.0_.._.._bad"),
            ("..", ".."),
        ];
        for &(input, expected) in &cases {
            assert_eq!(OciChartPuller::sanitize_version(input), expected);
        }
    }

    #[test]
    fn test_sanitize_version_special_chars() {
        let cases = [("1.0.0+build", "1.0.0_build"), ("v1.0.0!@#$", "v1.0.0____")];
        for &(input, expected) in &cases {
            assert_eq!(OciChartPuller::sanitize_version(input), expected);
        }
    }

    #[test]
    fn test_sanitize_version_empty() {
        let sanitized = OciChartPuller::sanitize_version("");
        assert_eq!(sanitized, "unknown");
    }

    #[test]
    fn test_sanitize_version_only_special_chars() {
        let sanitized = OciChartPuller::sanitize_version("!@#$%");
        assert_eq!(sanitized, "_____");
    }
}