use std::io::Write;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use futures::StreamExt;
pub async fn pull(image: &str, dest: &Path, cache_dir: &Path) -> anyhow::Result<()> {
let ref_ = ImageRef::parse(image)?;
let token = fetch_token(&ref_).await?;
let layers = fetch_manifest(&ref_, &token).await?;
std::fs::create_dir_all(dest)?;
for digest in &layers {
extract_or_cache_layer(&ref_, digest, dest, cache_dir, &token).await?;
}
Ok(())
}
struct ImageRef {
repo: String,
tag: String,
registry_api: String,
registry_auth: String,
}
impl ImageRef {
fn parse(image: &str) -> anyhow::Result<Self> {
let image = image.strip_prefix("docker.io/").unwrap_or(image);
let (rest, tag) = if let Some(last_slash) = image.rfind('/') {
let after_slash = &image[last_slash..];
if let Some(tag_pos) = after_slash.rfind(':') {
let abs_pos = last_slash + tag_pos;
(&image[..abs_pos], &image[abs_pos + 1..])
} else {
(image, "latest")
}
} else {
image.split_once(':').unwrap_or((image, "latest"))
};
let (registry, repo) = if let Some((first, remainder)) = rest.split_once('/') {
if is_registry_host(first) {
(first, remainder.to_string())
} else {
("docker.io", rest.to_string())
}
} else {
("docker.io", format!("library/{rest}"))
};
let (registry_api, registry_auth) = registry_endpoints(registry);
Ok(Self {
repo,
tag: tag.to_string(),
registry_api,
registry_auth,
})
}
}
fn is_registry_host(host: &str) -> bool {
host.contains('.') || host.contains(':') || host == "localhost"
}
fn registry_endpoints(host: &str) -> (String, String) {
if host == "docker.io" {
(
"registry-1.docker.io".to_string(),
"auth.docker.io".to_string(),
)
} else {
(host.to_string(), host.to_string())
}
}
fn resolve_platform_manifest(manifests: &[serde_json::Value]) -> anyhow::Result<String> {
#[cfg(target_arch = "x86_64")]
let target_arch = "amd64";
#[cfg(target_arch = "aarch64")]
let target_arch = "arm64";
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
let target_arch = "amd64";
for entry in manifests {
let plat = &entry["platform"];
if plat["os"].as_str() == Some("linux")
&& plat["architecture"].as_str() == Some(target_arch)
{
if let Some(d) = entry["digest"].as_str() {
return Ok(d.to_string());
}
}
}
if let Some(entry) = manifests.first()
&& let Some(d) = entry["digest"].as_str()
{
return Ok(d.to_string());
}
anyhow::bail!("no platform entries in manifest list")
}
async fn fetch_manifest_by_digest(
ref_: &ImageRef,
digest: &str,
token: &str,
) -> anyhow::Result<Vec<String>> {
let url = format!(
"https://{}/v2/{}/manifests/{}",
ref_.registry_api, ref_.repo, digest
);
let client = reqwest::Client::new();
let resp = client
.get(&url)
.header(
"Accept",
"application/vnd.oci.image.manifest.v1+json, \
application/vnd.docker.distribution.manifest.v2+json",
)
.bearer_auth(token)
.send()
.await
.map_err(|e| anyhow::anyhow!("child manifest request failed: {e}"))?;
let body = resp
.text()
.await
.map_err(|e| anyhow::anyhow!("reading child manifest: {e}"))?;
let parsed: serde_json::Value =
serde_json::from_str(&body).map_err(|e| anyhow::anyhow!("child manifest JSON: {e}"))?;
let mut layers = Vec::new();
if let Some(arr) = parsed["layers"].as_array() {
for layer in arr {
if let Some(d) = layer["digest"].as_str() {
layers.push(d.to_string());
}
}
}
if layers.is_empty() {
anyhow::bail!("no layers in child manifest {digest}");
}
Ok(layers)
}
async fn fetch_token(ref_: &ImageRef) -> anyhow::Result<String> {
if ref_.registry_auth == "auth.docker.io" {
let url = format!(
"https://auth.docker.io/token?service=registry.docker.io&scope=repository:{}:pull",
ref_.repo
);
let resp = reqwest::get(&url)
.await
.map_err(|e| anyhow::anyhow!("auth request failed: {e}"))?;
let body = resp
.text()
.await
.map_err(|e| anyhow::anyhow!("reading auth response: {e}"))?;
let parsed: serde_json::Value =
serde_json::from_str(&body).map_err(|e| anyhow::anyhow!("auth JSON: {e}"))?;
parsed["token"]
.as_str()
.map(String::from)
.ok_or_else(|| anyhow::anyhow!("no token in auth response: {body}"))
} else {
let client = reqwest::Client::new();
let v2_url = format!("https://{}/v2/", ref_.registry_auth);
let resp = client
.get(&v2_url)
.send()
.await
.map_err(|e| anyhow::anyhow!("auth challenge request failed: {e}"))?;
let www_auth = resp
.headers()
.get("www-authenticate")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
anyhow::anyhow!(
"no WWW-Authenticate header from registry {}. \
Private registries may require authentication — \
set the REGISTRY_TOKEN env var for bearer auth.",
ref_.registry_auth
)
})?;
let realm = parse_www_authenticate_param(www_auth, "realm")
.ok_or_else(|| anyhow::anyhow!("no realm in WWW-Authenticate: {www_auth}"))?;
let service = parse_www_authenticate_param(www_auth, "service")
.unwrap_or(ref_.registry_auth.as_str());
let token_url = format!(
"{realm}?service={service}&scope=repository:{}:pull",
ref_.repo
);
let mut req = client.get(&token_url);
if let Ok(token) = std::env::var("REGISTRY_TOKEN") {
req = req.bearer_auth(token);
}
let resp = req
.send()
.await
.map_err(|e| anyhow::anyhow!("token request failed: {e}"))?;
let body = resp
.text()
.await
.map_err(|e| anyhow::anyhow!("reading token response: {e}"))?;
let parsed: serde_json::Value =
serde_json::from_str(&body).map_err(|e| anyhow::anyhow!("token JSON: {e}"))?;
parsed["token"]
.as_str()
.or_else(|| parsed["access_token"].as_str())
.map(String::from)
.ok_or_else(|| anyhow::anyhow!("no token in auth response: {body}"))
}
}
fn parse_www_authenticate_param<'a>(header: &'a str, param: &str) -> Option<&'a str> {
let prefix = format!("{param}=\"");
let start = header.find(&prefix)? + prefix.len();
let end = header[start..].find('"')?;
Some(&header[start..start + end])
}
async fn fetch_manifest(ref_: &ImageRef, token: &str) -> anyhow::Result<Vec<String>> {
let url = format!(
"https://{}/v2/{}/manifests/{}",
ref_.registry_api, ref_.repo, ref_.tag
);
let client = reqwest::Client::new();
let resp = client
.get(&url)
.header(
"Accept",
"application/vnd.oci.image.manifest.v1+json, \
application/vnd.docker.distribution.manifest.v2+json",
)
.bearer_auth(token)
.send()
.await
.map_err(|e| anyhow::anyhow!("manifest request failed: {e}"))?;
let body = resp
.text()
.await
.map_err(|e| anyhow::anyhow!("reading manifest: {e}"))?;
let parsed: serde_json::Value =
serde_json::from_str(&body).map_err(|e| anyhow::anyhow!("manifest JSON: {e}"))?;
if let Some(manifests) = parsed["manifests"].as_array() {
let child_digest = resolve_platform_manifest(manifests)?;
return fetch_manifest_by_digest(ref_, &child_digest, token).await;
}
let mut layers = Vec::new();
if let Some(arr) = parsed["layers"].as_array() {
for layer in arr {
if let Some(d) = layer["digest"].as_str() {
layers.push(d.to_string());
}
}
}
if layers.is_empty() {
anyhow::bail!("no layers found in manifest for {}:{}", ref_.repo, ref_.tag);
}
Ok(layers)
}
async fn extract_or_cache_layer(
ref_: &ImageRef,
digest: &str,
dest: &Path,
cache_dir: &Path,
token: &str,
) -> anyhow::Result<()> {
let cache_path = blob_cache_path(cache_dir, digest);
let blob_bytes = if cache_path.exists() {
let bytes = std::fs::read(&cache_path)?;
verify_blob_digest(&bytes, digest)?;
bytes
} else {
let bytes = download_blob(ref_, digest, token).await?;
verify_blob_digest(&bytes, digest)?;
if let Some(parent) = cache_path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(&cache_path, &bytes)?;
bytes
};
let dest = dest.to_path_buf();
let digest_owned = digest.to_string();
tokio::task::spawn_blocking(move || {
validate_tar_entries(&blob_bytes)?;
let mut child = std::process::Command::new("tar")
.args([
"-x",
"--no-same-owner",
"--no-same-permissions",
"--no-absolute-filenames",
"-C",
])
.arg(&dest)
.stdin(Stdio::piped())
.stdout(Stdio::null())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow::anyhow!("spawning tar: {e}"))?;
let mut stdin = child.stdin.take().unwrap();
stdin
.write_all(&blob_bytes)
.map_err(|e| anyhow::anyhow!("writing to tar stdin: {e}"))?;
drop(stdin);
let output = child
.wait_with_output()
.map_err(|e| anyhow::anyhow!("waiting for tar: {e}"))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("tar extraction failed for {digest_owned}: {stderr}");
}
process_whiteouts(&dest)
.map_err(|e| anyhow::anyhow!("whiteout processing failed for {digest_owned}: {e}"))?;
Ok(())
})
.await
.map_err(|e| anyhow::anyhow!("spawn_blocking join: {e}"))?
}
async fn download_blob(ref_: &ImageRef, digest: &str, token: &str) -> anyhow::Result<Vec<u8>> {
let url = format!(
"https://{}/v2/{}/blobs/{}",
ref_.registry_api, ref_.repo, digest
);
let client = reqwest::Client::new();
let resp = client
.get(&url)
.bearer_auth(token)
.send()
.await
.map_err(|e| anyhow::anyhow!("blob download failed: {e}"))?;
const MAX_BLOB_BYTES: u64 = 2 * 1024 * 1024 * 1024;
if let Some(len) = resp.content_length() {
if len > MAX_BLOB_BYTES {
anyhow::bail!("blob {digest} Content-Length {len} exceeds {MAX_BLOB_BYTES} byte cap");
}
}
stream_blob_with_cap(resp.bytes_stream(), digest, MAX_BLOB_BYTES).await
}
async fn stream_blob_with_cap<S, T, E>(
mut stream: S,
digest: &str,
max_bytes: u64,
) -> anyhow::Result<Vec<u8>>
where
S: futures::Stream<Item = Result<T, E>> + Unpin,
T: AsRef<[u8]>,
E: std::error::Error + Send + Sync + 'static,
{
let mut buf = Vec::new();
let mut total: u64 = 0;
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| anyhow::anyhow!("reading blob chunk: {e}"))?;
let bytes = chunk.as_ref();
total += bytes.len() as u64;
if total > max_bytes {
anyhow::bail!("blob {digest} exceeds {max_bytes} byte cap");
}
buf.extend_from_slice(bytes);
}
Ok(buf)
}
fn validate_tar_entries(bytes: &[u8]) -> anyhow::Result<()> {
let mut offset = 0;
while offset + 512 <= bytes.len() {
let header = &bytes[offset..offset + 512];
if header.iter().take(100).all(|&b| b == 0) {
break;
}
let name_len = header[0..100].iter().position(|&b| b == 0).unwrap_or(100);
let name = std::str::from_utf8(&header[0..name_len]).unwrap_or("");
let is_ustar = header[257..263] == *b"ustar\x00" || header[257..263] == *b"ustar ";
let prefix = if is_ustar {
let prefix_end = header[345..500].iter().position(|&b| b == 0).unwrap_or(155);
std::str::from_utf8(&header[345..345 + prefix_end]).unwrap_or("")
} else {
""
};
let full_name = if prefix.is_empty() {
name.to_string()
} else {
format!("{prefix}/{name}")
};
if contains_dotdot(&full_name) {
anyhow::bail!("tar entry {full_name:?} contains '..' path component");
}
let size_str = std::str::from_utf8(&header[124..136]).unwrap_or("");
let size_str = size_str.trim_end_matches('\0').trim();
let size = u64::from_str_radix(size_str, 8).unwrap_or(0);
let data_blocks = ((size + 511) / 512) as usize;
offset += 512 + data_blocks * 512;
if offset > bytes.len() {
break;
}
}
Ok(())
}
fn contains_dotdot(path: &str) -> bool {
path == ".." || path.starts_with("../") || path.contains("/../") || path.ends_with("/..")
}
fn process_whiteouts(dest: &Path) -> std::io::Result<()> {
let mut to_remove: Vec<PathBuf> = Vec::new();
let mut dirs_to_clear: Vec<PathBuf> = Vec::new();
walk_whiteouts(dest, &mut to_remove, &mut dirs_to_clear)?;
for path in &to_remove {
if path.is_dir() {
let _ = std::fs::remove_dir_all(path);
} else {
let _ = std::fs::remove_file(path);
}
}
for dir in &dirs_to_clear {
if dir.is_dir() {
for child in std::fs::read_dir(dir)? {
let child = child?;
let _ = std::fs::remove_file(child.path());
}
}
}
Ok(())
}
fn walk_whiteouts(
dir: &Path,
to_remove: &mut Vec<PathBuf>,
dirs_to_clear: &mut Vec<PathBuf>,
) -> std::io::Result<()> {
let entries = match std::fs::read_dir(dir) {
Ok(entries) => entries,
Err(_) => return Ok(()),
};
for entry in entries {
let entry = match entry {
Ok(e) => e,
Err(_) => continue,
};
let path = entry.path();
let file_name = entry.file_name().to_string_lossy().into_owned();
if file_name == ".wh..wh..opq" {
if let Some(parent) = path.parent() {
dirs_to_clear.push(parent.to_path_buf());
}
to_remove.push(path);
} else if let Some(target) = file_name.strip_prefix(".wh.") {
if let Some(parent) = path.parent() {
to_remove.push(parent.join(target));
}
to_remove.push(path);
} else if path.is_dir() {
walk_whiteouts(&path, to_remove, dirs_to_clear)?;
}
}
Ok(())
}
fn blob_cache_path(cache_dir: &Path, digest: &str) -> PathBuf {
if let Some((algo, hex)) = digest.split_once(':') {
cache_dir.join("blobs").join(algo).join(hex)
} else {
cache_dir.join("blobs").join(digest)
}
}
fn verify_blob_digest(bytes: &[u8], expected_digest: &str) -> anyhow::Result<()> {
let (algo, expected_hex) = expected_digest
.split_once(':')
.ok_or_else(|| anyhow::anyhow!("invalid digest format: {expected_digest}"))?;
if algo != "sha256" {
anyhow::bail!("unsupported digest algorithm: {algo}");
}
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(bytes);
let actual_hex = hex::encode(hasher.finalize());
if actual_hex != expected_hex {
anyhow::bail!(
"blob digest mismatch: expected sha256:{expected_hex}, got sha256:{actual_hex}"
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn verify_blob_digest_valid() {
let data = b"hello world";
let digest = "sha256:b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
assert!(verify_blob_digest(data, digest).is_ok());
}
#[test]
fn verify_blob_digest_mismatch() {
let data = b"hello world";
let digest = "sha256:0000000000000000000000000000000000000000000000000000000000000000";
let err = verify_blob_digest(data, digest).unwrap_err();
assert!(err.to_string().contains("blob digest mismatch"));
}
#[test]
fn verify_blob_digest_invalid_format() {
let err = verify_blob_digest(b"data", "noseparator").unwrap_err();
assert!(err.to_string().contains("invalid digest format"));
}
#[test]
fn verify_blob_digest_unsupported_algo() {
let err = verify_blob_digest(b"data", "sha512:abcdef").unwrap_err();
assert!(err.to_string().contains("unsupported digest algorithm"));
}
#[test]
fn verify_blob_digest_empty_bytes() {
let data: &[u8] = &[];
let digest = "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
assert!(verify_blob_digest(data, digest).is_ok());
}
#[test]
fn image_ref_docker_hub_official() {
let r = ImageRef::parse("alpine:3.21").unwrap();
assert_eq!(r.registry_api, "registry-1.docker.io");
assert_eq!(r.registry_auth, "auth.docker.io");
assert_eq!(r.repo, "library/alpine");
assert_eq!(r.tag, "3.21");
}
#[test]
fn image_ref_docker_hub_official_default_tag() {
let r = ImageRef::parse("alpine").unwrap();
assert_eq!(r.repo, "library/alpine");
assert_eq!(r.tag, "latest");
}
#[test]
fn image_ref_docker_hub_explicit_registry() {
let r = ImageRef::parse("docker.io/library/alpine:latest").unwrap();
assert_eq!(r.registry_api, "registry-1.docker.io");
assert_eq!(r.repo, "library/alpine");
assert_eq!(r.tag, "latest");
}
#[test]
fn image_ref_docker_hub_user_image() {
let r = ImageRef::parse("myuser/myimage:v1.0").unwrap();
assert_eq!(r.registry_api, "registry-1.docker.io");
assert_eq!(r.repo, "myuser/myimage");
assert_eq!(r.tag, "v1.0");
}
#[test]
fn image_ref_docker_hub_user_image_default_tag() {
let r = ImageRef::parse("myuser/myimage").unwrap();
assert_eq!(r.repo, "myuser/myimage");
assert_eq!(r.tag, "latest");
}
#[test]
fn image_ref_ghcr() {
let r = ImageRef::parse("ghcr.io/owner/repo:v1").unwrap();
assert_eq!(r.registry_api, "ghcr.io");
assert_eq!(r.registry_auth, "ghcr.io");
assert_eq!(r.repo, "owner/repo");
assert_eq!(r.tag, "v1");
}
#[test]
fn image_ref_quay() {
let r = ImageRef::parse("quay.io/org/image:latest").unwrap();
assert_eq!(r.registry_api, "quay.io");
assert_eq!(r.repo, "org/image");
assert_eq!(r.tag, "latest");
}
#[test]
fn image_ref_custom_registry_with_port() {
let r = ImageRef::parse("localhost:5000/team/image:v1").unwrap();
assert_eq!(r.registry_api, "localhost:5000");
assert_eq!(r.repo, "team/image");
assert_eq!(r.tag, "v1");
}
#[test]
fn image_ref_custom_registry_port_no_tag() {
let r = ImageRef::parse("localhost:5000/team/image").unwrap();
assert_eq!(r.registry_api, "localhost:5000");
assert_eq!(r.repo, "team/image");
assert_eq!(r.tag, "latest");
}
#[test]
fn www_auth_realm() {
let header = r#"Bearer realm="https://ghcr.io/token",service="ghcr.io""#;
assert_eq!(
parse_www_authenticate_param(header, "realm"),
Some("https://ghcr.io/token")
);
}
#[test]
fn www_auth_service() {
let header = r#"Bearer realm="https://ghcr.io/token",service="ghcr.io""#;
assert_eq!(
parse_www_authenticate_param(header, "service"),
Some("ghcr.io")
);
}
#[test]
fn www_auth_missing_param() {
let header = r#"Bearer realm="https://ghcr.io/token""#;
assert_eq!(parse_www_authenticate_param(header, "service"), None);
}
#[test]
fn resolve_platform_matches_linux_amd64() {
let manifests = serde_json::json!([
{"platform": {"os": "windows", "architecture": "amd64"}, "digest": "sha256:win"},
{"platform": {"os": "linux", "architecture": "amd64"}, "digest": "sha256:linux"},
]);
let arr = manifests.as_array().unwrap();
let digest = resolve_platform_manifest(arr).unwrap();
assert_eq!(digest, "sha256:linux");
}
#[test]
fn resolve_platform_falls_back_to_first() {
let manifests = serde_json::json!([
{"platform": {"os": "linux", "architecture": "arm64"}, "digest": "sha256:arm"},
{"platform": {"os": "windows", "architecture": "amd64"}, "digest": "sha256:win"},
]);
let arr = manifests.as_array().unwrap();
let digest = resolve_platform_manifest(arr).unwrap();
assert_eq!(digest, "sha256:arm");
}
#[test]
fn resolve_platform_empty_is_error() {
let manifests: Vec<serde_json::Value> = vec![];
let err = resolve_platform_manifest(&manifests).unwrap_err();
assert!(
err.to_string().contains("no platform entries"),
"expected 'no platform entries', got: {err}"
);
}
#[test]
fn resolve_platform_missing_digest_is_error() {
let manifests = serde_json::json!([
{"platform": {"os": "linux", "architecture": "amd64"}, "no_digest": "nope"},
]);
let arr = manifests.as_array().unwrap();
let err = resolve_platform_manifest(arr).unwrap_err();
assert!(
err.to_string().contains("no platform entries"),
"expected 'no platform entries' when all entries lack digest, got: {err}"
);
}
#[test]
fn blob_cache_path_with_algo_and_hex() {
let cache = std::path::Path::new("/tmp/cache");
let path = blob_cache_path(cache, "sha256:abcdef123456");
assert_eq!(
path,
std::path::Path::new("/tmp/cache/blobs/sha256/abcdef123456")
);
}
#[test]
fn blob_cache_path_without_separator() {
let cache = std::path::Path::new("/tmp/cache");
let path = blob_cache_path(cache, "noseparator");
assert_eq!(path, std::path::Path::new("/tmp/cache/blobs/noseparator"));
}
#[test]
fn is_registry_host_true_for_dotted() {
assert!(is_registry_host("docker.io"));
assert!(is_registry_host("ghcr.io"));
assert!(is_registry_host("registry.internal.corp.com"));
}
#[test]
fn is_registry_host_true_for_port() {
assert!(is_registry_host("localhost:5000"));
assert!(is_registry_host("10.0.0.1:8080"));
}
#[test]
fn is_registry_host_true_for_bare_localhost() {
assert!(is_registry_host("localhost"));
}
#[test]
fn is_registry_host_false_for_plain_name() {
assert!(!is_registry_host("alpine"));
assert!(!is_registry_host("myuser"));
}
#[test]
fn registry_endpoints_docker_hub_uses_split_hosts() {
let (api, auth) = registry_endpoints("docker.io");
assert_eq!(api, "registry-1.docker.io");
assert_eq!(auth, "auth.docker.io");
}
#[test]
fn registry_endpoints_other_uses_same_host() {
let (api, auth) = registry_endpoints("ghcr.io");
assert_eq!(api, "ghcr.io");
assert_eq!(auth, "ghcr.io");
}
#[test]
fn registry_endpoints_localhost_uses_same_host() {
let (api, auth) = registry_endpoints("localhost:5000");
assert_eq!(api, "localhost:5000");
assert_eq!(auth, "localhost:5000");
}
#[test]
fn contains_dotdot_bare_dotdot() {
assert!(contains_dotdot(".."));
}
#[test]
fn contains_dotdot_starts_with_dotdot_slash() {
assert!(contains_dotdot("../etc/passwd"));
}
#[test]
fn contains_dotdot_middle_dotdot() {
assert!(contains_dotdot("foo/../bar"));
}
#[test]
fn contains_dotdot_ends_with_slash_dotdot() {
assert!(contains_dotdot("foo/bar/.."));
}
#[test]
fn contains_dotdot_safe_path() {
assert!(!contains_dotdot("foo/bar/baz"));
assert!(!contains_dotdot("foo.bar"));
assert!(!contains_dotdot("..."));
}
fn tar_header(name: &str, size: u64) -> Vec<u8> {
let mut header = [0u8; 512];
let name_bytes = name.as_bytes();
let copy_len = name_bytes.len().min(99);
header[..copy_len].copy_from_slice(&name_bytes[..copy_len]);
header[100..108].copy_from_slice(b"0000644\0");
header[108..116].copy_from_slice(b"0000000\0");
header[116..124].copy_from_slice(b"0000000\0");
let size_str = format!("{size:011o}\0");
header[124..136].copy_from_slice(size_str.as_bytes());
header[136..148].copy_from_slice(b"00000000000\0");
header[156] = b'0';
header[257..263].copy_from_slice(b"ustar\x00");
let mut cksum: u32 = 0;
for b in header.iter() {
cksum += *b as u32;
}
for i in 148..156 {
cksum -= header[i] as u32;
cksum += b' ' as u32;
}
let cksum_str = format!("{cksum:06o}\0 ");
header[148..156].copy_from_slice(cksum_str.as_bytes());
let mut tar = header.to_vec();
let data_blocks = ((size + 511) / 512) as usize;
tar.resize(512 + data_blocks * 512, 0);
tar
}
#[test]
fn validate_tar_entries_safe_paths() {
let mut tar = Vec::new();
tar.extend(tar_header("usr/bin/foo", 0));
tar.extend(tar_header("etc/config", 0));
tar.extend(vec![0u8; 1024]);
validate_tar_entries(&tar).expect("safe paths should pass");
}
#[test]
fn validate_tar_entries_rejects_dotdot() {
let mut tar = Vec::new();
tar.extend(tar_header("usr/bin/foo", 0));
tar.extend(tar_header("../etc/passwd", 0));
tar.extend(vec![0u8; 1024]);
let err = validate_tar_entries(&tar).unwrap_err();
assert!(err.to_string().contains(".."));
}
#[test]
fn validate_tar_entries_empty_tar() {
let tar = vec![0u8; 1024];
validate_tar_entries(&tar).expect("empty tar should pass");
}
#[test]
fn process_whiteouts_wh_file_removes_target() {
let dir = std::env::temp_dir().join(format!(
"dirge-oci-test-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
));
std::fs::create_dir_all(&dir).unwrap();
let target = dir.join("removed.txt");
let marker = dir.join(".wh.removed.txt");
std::fs::write(&target, b"lower layer content").unwrap();
std::fs::write(&marker, b"").unwrap();
process_whiteouts(&dir).unwrap();
assert!(!target.exists(), "whited-out file should be removed");
assert!(!marker.exists(), "whiteout marker should be removed");
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn process_whiteouts_opaque_clears_directory() {
let dir = std::env::temp_dir().join(format!(
"dirge-oci-test-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
));
std::fs::create_dir_all(&dir).unwrap();
std::fs::write(dir.join("keep.me"), b"keep").unwrap();
std::fs::write(dir.join(".wh..wh..opq"), b"").unwrap();
process_whiteouts(&dir).unwrap();
assert!(
!dir.join("keep.me").exists(),
"opaque dir should clear siblings"
);
assert!(
!dir.join(".wh..wh..opq").exists(),
"opaque marker should be removed"
);
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn process_whiteouts_no_whiteouts_leaves_files() {
let dir = std::env::temp_dir().join(format!(
"dirge-oci-test-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
));
std::fs::create_dir_all(&dir).unwrap();
std::fs::write(dir.join("stay.txt"), b"data").unwrap();
std::fs::create_dir(dir.join("subdir")).unwrap();
std::fs::write(dir.join("subdir/nested.txt"), b"nested").unwrap();
process_whiteouts(&dir).unwrap();
assert!(dir.join("stay.txt").exists());
assert!(dir.join("subdir/nested.txt").exists());
let _ = std::fs::remove_dir_all(&dir);
}
#[tokio::test]
async fn stream_blob_with_cap_below_limit() {
let chunks: Vec<Result<Vec<u8>, std::io::Error>> =
vec![Ok(b"hello ".to_vec()), Ok(b"world".to_vec())];
let stream = futures::stream::iter(chunks);
let result = stream_blob_with_cap(stream, "sha256:test", 100)
.await
.unwrap();
assert_eq!(result, b"hello world");
}
#[tokio::test]
async fn stream_blob_with_cap_empty_stream() {
let chunks: Vec<Result<Vec<u8>, std::io::Error>> = vec![];
let stream = futures::stream::iter(chunks);
let result = stream_blob_with_cap(stream, "sha256:test", 100)
.await
.unwrap();
assert!(result.is_empty());
}
#[tokio::test]
async fn stream_blob_with_cap_exceeds_limit_in_single_chunk() {
let chunks: Vec<Result<Vec<u8>, std::io::Error>> =
vec![Ok(b"12345678901234567890".to_vec())];
let stream = futures::stream::iter(chunks);
let err = stream_blob_with_cap(stream, "sha256:big", 10)
.await
.unwrap_err();
assert!(
err.to_string().contains("exceeds"),
"expected 'exceeds', got: {err}"
);
}
#[tokio::test]
async fn stream_blob_with_cap_exceeds_limit_across_chunks() {
let chunks: Vec<Result<Vec<u8>, std::io::Error>> =
vec![Ok(b"ab".to_vec()), Ok(b"cd".to_vec()), Ok(b"ef".to_vec())];
let stream = futures::stream::iter(chunks);
let err = stream_blob_with_cap(stream, "sha256:big", 5)
.await
.unwrap_err();
assert!(
err.to_string().contains("exceeds"),
"expected 'exceeds', got: {err}"
);
}
#[tokio::test]
async fn stream_blob_with_cap_exactly_at_limit() {
let chunks: Vec<Result<Vec<u8>, std::io::Error>> = vec![Ok(b"12345".to_vec())];
let stream = futures::stream::iter(chunks);
let result = stream_blob_with_cap(stream, "sha256:exact", 5)
.await
.unwrap();
assert_eq!(result, b"12345");
}
#[tokio::test]
async fn stream_blob_with_cap_propagates_error() {
let chunks: Vec<Result<Vec<u8>, std::io::Error>> = vec![
Ok(b"first ".to_vec()),
Err(std::io::Error::new(
std::io::ErrorKind::Other,
"connection lost",
)),
];
let stream = futures::stream::iter(chunks);
let err = stream_blob_with_cap(stream, "sha256:test", 100)
.await
.unwrap_err();
assert!(
err.to_string().contains("reading blob chunk"),
"expected 'reading blob chunk', got: {err}"
);
assert!(
err.to_string().contains("connection lost"),
"expected 'connection lost', got: {err}"
);
}
}