use std::path::{Path, PathBuf};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use tokio::fs;
use tracing::{debug, warn};
use crate::error::{RegistryError, RegistryResult};
use crate::types::{DsseEnvelope, FetchResult, PackHeaders};
use crate::verify::compute_digest;
/// Fallback cache lifetime in seconds (24 hours), applied by
/// `parse_cache_control_expiry` when no usable `max-age` directive is present.
const DEFAULT_TTL_SECS: i64 = 24 * 60 * 60;
/// Metadata stored alongside a cached pack as `metadata.json`.
///
/// Built by `PackCache::put` and read back by `PackCache::get` and
/// `PackCache::get_metadata`. Fields marked `#[serde(default)]` deserialize to
/// `None` when they are missing from the JSON document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheMeta {
/// Time at which `put` built this metadata record.
pub fetched_at: DateTime<Utc>,
/// Digest recorded for the pack content; `get` compares it against the
/// digest of the file it reads back from disk.
pub digest: String,
/// `etag` value copied from the fetch result's headers, if any.
#[serde(default)]
pub etag: Option<String>,
/// Instant after which `get` and `is_cached` treat the entry as expired.
pub expires_at: DateTime<Utc>,
/// `key_id` value copied from the fetch result's headers, if any
/// (presumably identifies the signing key — confirm against the header definition).
#[serde(default)]
pub key_id: Option<String>,
/// Registry URL passed to `put`, if the caller supplied one.
#[serde(default)]
pub registry_url: Option<String>,
}
/// File-system cache of packs, laid out as `<cache_dir>/<name>/<version>/`.
///
/// Each entry directory holds `pack.yaml`, `metadata.json` and, when a
/// signature could be parsed at `put` time, `signature.json`.
#[derive(Debug, Clone)]
pub struct PackCache {
/// Root directory under which every entry directory is created.
cache_dir: PathBuf,
}
/// A cache hit as returned by `PackCache::get`.
#[derive(Debug, Clone)]
pub struct CacheEntry {
/// Contents of the cached `pack.yaml`.
pub content: String,
/// Parsed contents of the entry's `metadata.json`.
pub metadata: CacheMeta,
/// Parsed `signature.json`, or `None` when that file is missing or unreadable.
pub signature: Option<DsseEnvelope>,
}
impl PackCache {
/// Creates a cache rooted at the directory chosen by `default_cache_dir`.
///
/// # Errors
///
/// Propagates the error from `default_cache_dir` when no base directory
/// can be determined.
pub fn new() -> RegistryResult<Self> {
    default_cache_dir().map(|cache_dir| Self { cache_dir })
}
/// Creates a cache rooted at `cache_dir`.
///
/// The directory is only recorded here; nothing is created on disk.
pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
    let cache_dir: PathBuf = cache_dir.into();
    Self { cache_dir }
}
/// Returns the root directory of this cache.
pub fn cache_dir(&self) -> &Path {
    self.cache_dir.as_path()
}
/// Returns the entry directory for `name` at `version`:
/// `<cache_dir>/<name>/<version>`.
///
/// NOTE(review): both components are joined verbatim, so a `..` segment or an
/// absolute path would resolve outside the cache root — confirm that callers
/// validate pack names and versions before they reach the cache.
fn pack_dir(&self, name: &str, version: &str) -> PathBuf {
    let mut dir = self.cache_dir.join(name);
    dir.push(version);
    dir
}
pub async fn get(&self, name: &str, version: &str) -> RegistryResult<Option<CacheEntry>> {
let pack_dir = self.pack_dir(name, version);
let pack_path = pack_dir.join("pack.yaml");
let meta_path = pack_dir.join("metadata.json");
if !pack_path.exists() || !meta_path.exists() {
debug!(name, version, "pack not in cache");
return Ok(None);
}
let meta_content =
fs::read_to_string(&meta_path)
.await
.map_err(|e| RegistryError::Cache {
message: format!("failed to read cache metadata: {}", e),
})?;
let metadata: CacheMeta =
serde_json::from_str(&meta_content).map_err(|e| RegistryError::Cache {
message: format!("failed to parse cache metadata: {}", e),
})?;
if metadata.expires_at < Utc::now() {
debug!(
name,
version,
expires_at = %metadata.expires_at,
"cache entry expired"
);
return Ok(None);
}
let content = fs::read_to_string(&pack_path)
.await
.map_err(|e| RegistryError::Cache {
message: format!("failed to read cached pack: {}", e),
})?;
let computed_digest = compute_digest(&content);
if computed_digest != metadata.digest {
warn!(
name,
version,
expected = %metadata.digest,
actual = %computed_digest,
"cache integrity check failed"
);
return Err(RegistryError::DigestMismatch {
name: name.to_string(),
version: version.to_string(),
expected: metadata.digest,
actual: computed_digest,
});
}
let sig_path = pack_dir.join("signature.json");
let signature = if sig_path.exists() {
let sig_content = fs::read_to_string(&sig_path).await.ok();
sig_content.and_then(|s| serde_json::from_str(&s).ok())
} else {
None
};
debug!(name, version, "cache hit");
Ok(Some(CacheEntry {
content,
metadata,
signature,
}))
}
pub async fn put(
&self,
name: &str,
version: &str,
result: &FetchResult,
registry_url: Option<&str>,
) -> RegistryResult<()> {
let pack_dir = self.pack_dir(name, version);
fs::create_dir_all(&pack_dir)
.await
.map_err(|e| RegistryError::Cache {
message: format!("failed to create cache directory: {}", e),
})?;
let expires_at = parse_cache_control_expiry(&result.headers);
let metadata = CacheMeta {
fetched_at: Utc::now(),
digest: result.computed_digest.clone(),
etag: result.headers.etag.clone(),
expires_at,
key_id: result.headers.key_id.clone(),
registry_url: registry_url.map(String::from),
};
let pack_path = pack_dir.join("pack.yaml");
let meta_path = pack_dir.join("metadata.json");
write_atomic(&pack_path, &result.content).await?;
let meta_json =
serde_json::to_string_pretty(&metadata).map_err(|e| RegistryError::Cache {
message: format!("failed to serialize metadata: {}", e),
})?;
write_atomic(&meta_path, &meta_json).await?;
if let Some(sig_b64) = &result.headers.signature {
if let Ok(envelope) = parse_signature(sig_b64) {
let sig_path = pack_dir.join("signature.json");
let sig_json =
serde_json::to_string_pretty(&envelope).map_err(|e| RegistryError::Cache {
message: format!("failed to serialize signature: {}", e),
})?;
write_atomic(&sig_path, &sig_json).await?;
}
}
debug!(name, version, "cached pack");
Ok(())
}
/// Reads and parses the entry's `metadata.json`.
///
/// Returns `None` when the file cannot be read or is not valid metadata JSON.
/// Expiry is not checked here.
pub async fn get_metadata(&self, name: &str, version: &str) -> Option<CacheMeta> {
    let meta_path = self.pack_dir(name, version).join("metadata.json");
    match fs::read_to_string(&meta_path).await {
        Ok(raw) => serde_json::from_str(&raw).ok(),
        Err(_) => None,
    }
}
/// Returns the `etag` stored in the entry's metadata, if the metadata can be
/// read and contains one.
pub async fn get_etag(&self, name: &str, version: &str) -> Option<String> {
    let meta = self.get_metadata(name, version).await?;
    meta.etag
}
/// Reports whether readable metadata exists for the entry and its
/// `expires_at` has not yet passed.
///
/// Only the metadata file is consulted; the pack content is not read.
pub async fn is_cached(&self, name: &str, version: &str) -> bool {
    self.get_metadata(name, version)
        .await
        .map_or(false, |meta| meta.expires_at >= Utc::now())
}
/// Removes the entry directory for `name` at `version`.
///
/// Succeeds when the entry does not exist, including when it is removed
/// concurrently by someone else.
///
/// # Errors
///
/// Returns `RegistryError::Cache` when the directory exists but cannot be
/// removed.
pub async fn evict(&self, name: &str, version: &str) -> RegistryResult<()> {
    let pack_dir = self.pack_dir(name, version);
    // Attempt the removal directly instead of checking `exists()` first: that
    // check blocks the executor thread and races with concurrent removals.
    match fs::remove_dir_all(&pack_dir).await {
        Ok(()) => {
            debug!(name, version, "evicted from cache");
            Ok(())
        }
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
        Err(e) => Err(RegistryError::Cache {
            message: format!("failed to evict cache entry: {}", e),
        }),
    }
}
/// Removes the whole cache directory, including every entry in it.
///
/// Succeeds when the directory does not exist, including when it is removed
/// concurrently by someone else.
///
/// # Errors
///
/// Returns `RegistryError::Cache` when the directory exists but cannot be
/// removed.
pub async fn clear(&self) -> RegistryResult<()> {
    // Attempt the removal directly instead of checking `exists()` first: that
    // check blocks the executor thread and races with concurrent removals.
    match fs::remove_dir_all(&self.cache_dir).await {
        Ok(()) => {
            debug!("cleared pack cache");
            Ok(())
        }
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
        Err(e) => Err(RegistryError::Cache {
            message: format!("failed to clear cache: {}", e),
        }),
    }
}
pub async fn list(&self) -> RegistryResult<Vec<(String, String, CacheMeta)>> {
let mut result = Vec::new();
if !self.cache_dir.exists() {
return Ok(result);
}
let mut names = fs::read_dir(&self.cache_dir)
.await
.map_err(|e| RegistryError::Cache {
message: format!("failed to read cache directory: {}", e),
})?;
while let Some(name_entry) = names.next_entry().await.map_err(|e| RegistryError::Cache {
message: format!("failed to read directory entry: {}", e),
})? {
let name_path = name_entry.path();
if !name_path.is_dir() {
continue;
}
let name = name_entry.file_name().to_string_lossy().to_string();
let mut versions =
fs::read_dir(&name_path)
.await
.map_err(|e| RegistryError::Cache {
message: format!("failed to read version directory: {}", e),
})?;
while let Some(version_entry) =
versions
.next_entry()
.await
.map_err(|e| RegistryError::Cache {
message: format!("failed to read directory entry: {}", e),
})?
{
let version_path = version_entry.path();
if !version_path.is_dir() {
continue;
}
let version = version_entry.file_name().to_string_lossy().to_string();
if let Some(meta) = self.get_metadata(&name, &version).await {
result.push((name.clone(), version, meta));
}
}
}
Ok(result)
}
}
impl Default for PackCache {
    /// Builds a cache at the default location, falling back to
    /// `<temp dir>/assay-cache/packs` when that location cannot be determined.
    ///
    /// The fallback uses the platform's temporary directory
    /// (`std::env::temp_dir`) instead of a hard-coded `/tmp`, so it also works
    /// on non-Unix systems and honours `TMPDIR`.
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| {
            Self::with_dir(std::env::temp_dir().join("assay-cache").join("packs"))
        })
    }
}
/// Resolves the default cache root: `<base>/assay/cache/packs`, where `<base>`
/// is `dirs::cache_dir()` or, failing that, `dirs::home_dir()`.
///
/// # Errors
///
/// Returns `RegistryError::Cache` when neither base directory is available.
fn default_cache_dir() -> RegistryResult<PathBuf> {
    match dirs::cache_dir().or_else(dirs::home_dir) {
        Some(mut dir) => {
            dir.push("assay");
            dir.push("cache");
            dir.push("packs");
            Ok(dir)
        }
        None => Err(RegistryError::Cache {
            message: "could not determine cache directory".to_string(),
        }),
    }
}
/// Extracts the `max-age` value, in seconds, from a `Cache-Control` header value.
///
/// The first `max-age=` directive decides the result; its name is matched
/// case-insensitively. The value is clamped to `0..=2^31` seconds, and a value
/// too large for `i64` is treated as that cap. Returns `None` when there is no
/// `max-age` directive or its value is not an integer.
fn max_age_secs(cache_control: &str) -> Option<i64> {
    const DIRECTIVE: &str = "max-age=";
    // Cap for over-large ages; also keeps the date arithmetic in the caller
    // well inside chrono's representable range.
    const MAX_AGE_CAP_SECS: i64 = 1 << 31;

    let value = cache_control.split(',').map(str::trim).find_map(|part| {
        // `get` returns `None` for parts shorter than the directive name.
        let head = part.get(..DIRECTIVE.len())?;
        if head.eq_ignore_ascii_case(DIRECTIVE) {
            part.get(DIRECTIVE.len()..)
        } else {
            None
        }
    })?;

    match value.parse::<i64>() {
        Ok(secs) => Some(secs.clamp(0, MAX_AGE_CAP_SECS)),
        Err(e) => match e.kind() {
            std::num::IntErrorKind::PosOverflow => Some(MAX_AGE_CAP_SECS),
            std::num::IntErrorKind::NegOverflow => Some(0),
            _ => None,
        },
    }
}

/// Computes the expiry instant for a fetched pack from its response headers.
///
/// Uses the `max-age` directive of `headers.cache_control` when one can be
/// parsed and `DEFAULT_TTL_SECS` otherwise, counted from the current time.
///
/// The age is bounded (see `max_age_secs`) because it comes from the remote
/// server: an unbounded value would make `Duration::seconds` or the date
/// addition panic.
fn parse_cache_control_expiry(headers: &PackHeaders) -> DateTime<Utc> {
    let now = Utc::now();
    let ttl_secs = headers
        .cache_control
        .as_deref()
        .and_then(max_age_secs)
        .unwrap_or(DEFAULT_TTL_SECS);
    now + Duration::seconds(ttl_secs)
}
/// Decodes a signature header value: standard base64 wrapping a JSON-encoded
/// DSSE envelope.
///
/// # Errors
///
/// Returns `RegistryError::Cache` when the input is not valid base64 or the
/// decoded bytes do not deserialize into a `DsseEnvelope`.
fn parse_signature(b64: &str) -> RegistryResult<DsseEnvelope> {
    use base64::{engine::general_purpose::STANDARD as BASE64, Engine};

    match BASE64.decode(b64) {
        Ok(raw) => serde_json::from_slice(&raw).map_err(|e| RegistryError::Cache {
            message: format!("invalid DSSE envelope: {}", e),
        }),
        Err(e) => Err(RegistryError::Cache {
            message: format!("invalid base64 signature: {}", e),
        }),
    }
}
/// Writes `content` to `path` by writing a temporary file in the same
/// directory and renaming it over the target.
///
/// The temporary name contains the process id and a per-process counter, so
/// concurrent writers of the same target never share a temporary file. The
/// temporary file is removed again (best effort) when the write or the rename
/// fails.
///
/// # Errors
///
/// Returns `RegistryError::Cache` when the temporary file cannot be written or
/// cannot be renamed to `path`.
async fn write_atomic(path: &Path, content: &str) -> RegistryResult<()> {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Distinguishes temporary files created by concurrent tasks in this process.
    static NEXT_TEMP_ID: AtomicUsize = AtomicUsize::new(0);

    let file_name = path
        .file_name()
        .map(|n| n.to_string_lossy().into_owned())
        .unwrap_or_default();
    let temp_name = format!(
        "{}.{}.{}.tmp",
        file_name,
        std::process::id(),
        NEXT_TEMP_ID.fetch_add(1, Ordering::Relaxed)
    );
    // Same directory as the target, so the rename stays on one file system.
    let temp_path = path.with_file_name(temp_name);

    if let Err(e) = fs::write(&temp_path, content).await {
        let _ = fs::remove_file(&temp_path).await;
        return Err(RegistryError::Cache {
            message: format!("failed to write temp file: {}", e),
        });
    }
    if let Err(e) = fs::rename(&temp_path, path).await {
        let _ = fs::remove_file(&temp_path).await;
        return Err(RegistryError::Cache {
            message: format!("failed to rename temp file: {}", e),
        });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine;
use tempfile::TempDir;
/// Creates a cache rooted in a fresh temporary directory.
///
/// The `TempDir` is returned as well so the caller keeps it alive for the
/// duration of the test.
fn create_test_cache() -> (PackCache, TempDir) {
    let temp_dir = TempDir::new().unwrap();
    let root = temp_dir.path().join("cache");
    (PackCache::with_dir(root), temp_dir)
}
/// Builds an unsigned fetch result for `content` with a matching digest, a
/// fixed ETag and a one-hour `max-age`.
fn create_fetch_result(content: &str) -> FetchResult {
    let headers = PackHeaders {
        digest: Some(compute_digest(content)),
        signature: None,
        key_id: None,
        etag: Some("\"abc123\"".to_string()),
        cache_control: Some("max-age=3600".to_string()),
        content_length: Some(content.len() as u64),
    };
    let computed_digest = compute_digest(content);
    FetchResult {
        content: content.to_string(),
        headers,
        computed_digest,
    }
}
/// A stored pack can be read back with identical content and digest.
#[tokio::test]
async fn test_cache_roundtrip() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";
    let fetched = create_fetch_result(body);

    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    let hit = cache.get("test-pack", "1.0.0").await.unwrap();
    let entry = hit.unwrap();
    assert_eq!(entry.content, body);
    assert_eq!(entry.metadata.digest, compute_digest(body));
}
/// Looking up a pack that was never stored is a miss, not an error.
#[tokio::test]
async fn test_cache_miss() {
    let (cache, _temp_dir) = create_test_cache();
    let lookup = cache.get("nonexistent", "1.0.0").await;
    assert!(lookup.unwrap().is_none());
}
/// Tampering with the cached `pack.yaml` is reported as a digest mismatch.
#[tokio::test]
async fn test_cache_integrity_failure() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";
    let fetched = create_fetch_result(body);
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    // Overwrite the cached content behind the cache's back.
    let entry_dir = cache.pack_dir("test-pack", "1.0.0");
    fs::write(entry_dir.join("pack.yaml"), "corrupted content")
        .await
        .unwrap();

    let err = cache.get("test-pack", "1.0.0").await.unwrap_err();
    assert!(matches!(err, RegistryError::DigestMismatch { .. }));
}
/// An entry stored with `max-age=0` is no longer served shortly afterwards.
#[tokio::test]
async fn test_cache_expiry() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";
    let headers = PackHeaders {
        digest: Some(compute_digest(body)),
        signature: None,
        key_id: None,
        etag: None,
        cache_control: Some("max-age=0".to_string()),
        content_length: None,
    };
    let fetched = FetchResult {
        content: body.to_string(),
        headers,
        computed_digest: compute_digest(body),
    };
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    // Let the clock move past the expiry instant.
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;

    let lookup = cache.get("test-pack", "1.0.0").await.unwrap();
    assert!(lookup.is_none());
}
/// Evicting an entry makes `is_cached` report it as absent.
#[tokio::test]
async fn test_cache_evict() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";
    let fetched = create_fetch_result(body);
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    let before = cache.is_cached("test-pack", "1.0.0").await;
    assert!(before);

    cache.evict("test-pack", "1.0.0").await.unwrap();

    let after = cache.is_cached("test-pack", "1.0.0").await;
    assert!(!after);
}
/// Clearing the cache removes every stored pack.
#[tokio::test]
async fn test_cache_clear() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";
    let fetched = create_fetch_result(body);
    cache.put("pack1", "1.0.0", &fetched, None).await.unwrap();
    cache.put("pack2", "1.0.0", &fetched, None).await.unwrap();

    cache.clear().await.unwrap();

    let first_present = cache.is_cached("pack1", "1.0.0").await;
    assert!(!first_present);
    let second_present = cache.is_cached("pack2", "1.0.0").await;
    assert!(!second_present);
}
/// `list` reports one item per stored name/version pair.
#[tokio::test]
async fn test_cache_list() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";
    let fetched = create_fetch_result(body);
    cache.put("pack1", "1.0.0", &fetched, None).await.unwrap();
    cache.put("pack1", "2.0.0", &fetched, None).await.unwrap();
    cache.put("pack2", "1.0.0", &fetched, None).await.unwrap();

    let listed = cache.list().await.unwrap();
    assert_eq!(listed.len(), 3);
}
/// The ETag from the fetch headers is persisted and returned by `get_etag`.
#[tokio::test]
async fn test_get_etag() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";
    let fetched = create_fetch_result(body);
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    let expected = Some("\"abc123\"".to_string());
    let etag = cache.get_etag("test-pack", "1.0.0").await;
    assert_eq!(etag, expected);
}
/// A `max-age` directive followed by other directives sets the lifetime.
#[tokio::test]
async fn test_parse_cache_control() {
    let headers = PackHeaders {
        digest: None,
        signature: None,
        key_id: None,
        etag: None,
        cache_control: Some("max-age=7200, public".to_string()),
        content_length: None,
    };
    let expires = parse_cache_control_expiry(&headers);
    let now = Utc::now();
    // Allow a few seconds of slack around the nominal 7200 s.
    let secs = (expires - now).num_seconds();
    assert!(secs >= 7190 && secs <= 7210);
}
/// Without a `Cache-Control` header the lifetime is roughly 24 hours.
#[tokio::test]
async fn test_default_ttl() {
    let headers = PackHeaders {
        digest: None,
        signature: None,
        key_id: None,
        etag: None,
        cache_control: None,
        content_length: None,
    };
    let expires = parse_cache_control_expiry(&headers);
    let now = Utc::now();
    let hours = (expires - now).num_hours();
    assert!(hours >= 23 && hours <= 25);
}
/// A DSSE envelope supplied in the signature header is stored and returned
/// with the entry.
#[tokio::test]
async fn test_cache_with_signature() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: 1.0.0";

    // The header carries the base64 of the JSON-serialized envelope.
    let envelope = DsseEnvelope {
        payload_type: "application/vnd.assay.pack+yaml;v=1".to_string(),
        payload: base64::engine::general_purpose::STANDARD.encode(body),
        signatures: vec![],
    };
    let envelope_json = serde_json::to_vec(&envelope).unwrap();
    let header_value = base64::engine::general_purpose::STANDARD.encode(&envelope_json);

    let headers = PackHeaders {
        digest: Some(compute_digest(body)),
        signature: Some(header_value),
        key_id: Some("sha256:test-key".to_string()),
        etag: None,
        cache_control: Some("max-age=3600".to_string()),
        content_length: None,
    };
    let fetched = FetchResult {
        content: body.to_string(),
        headers,
        computed_digest: compute_digest(body),
    };
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    let entry = cache.get("test-pack", "1.0.0").await.unwrap().unwrap();
    assert!(entry.signature.is_some());
    let stored_envelope = entry.signature.unwrap();
    assert_eq!(
        stored_envelope.payload_type,
        "application/vnd.assay.pack+yaml;v=1"
    );
}
/// Corruption is detected on read, and evicting the entry afterwards leaves
/// a clean miss.
#[tokio::test]
async fn test_pack_yaml_corrupt_evict_refetch() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: \"1.0.0\"";
    let fetched = create_fetch_result(body);
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    // Sanity check: the untouched entry is served.
    let initial = cache.get("test-pack", "1.0.0").await.unwrap();
    assert!(initial.is_some());

    let entry_dir = cache.pack_dir("test-pack", "1.0.0");
    fs::write(
        entry_dir.join("pack.yaml"),
        "corrupted: content\nmalicious: true",
    )
    .await
    .unwrap();

    let err = cache.get("test-pack", "1.0.0").await.unwrap_err();
    assert!(
        matches!(err, RegistryError::DigestMismatch { .. }),
        "Should detect corruption: {:?}",
        err
    );

    cache.evict("test-pack", "1.0.0").await.unwrap();
    let after_evict = cache.get("test-pack", "1.0.0").await.unwrap();
    assert!(after_evict.is_none(), "Cache should be empty after evict");
}
/// A corrupt `signature.json` degrades to "no signature" while the content is
/// still served.
#[tokio::test]
async fn test_signature_json_corrupt_handling() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: \"1.0.0\"";

    let envelope = DsseEnvelope {
        payload_type: "application/vnd.assay.pack+yaml;v=1".to_string(),
        payload: base64::engine::general_purpose::STANDARD.encode(body),
        signatures: vec![],
    };
    let envelope_json = serde_json::to_vec(&envelope).unwrap();
    let header_value = base64::engine::general_purpose::STANDARD.encode(&envelope_json);

    let headers = PackHeaders {
        digest: Some(compute_digest(body)),
        signature: Some(header_value),
        key_id: Some("sha256:test-key".to_string()),
        etag: None,
        cache_control: Some("max-age=3600".to_string()),
        content_length: None,
    };
    let fetched = FetchResult {
        content: body.to_string(),
        headers,
        computed_digest: compute_digest(body),
    };
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    // Sanity check: the signature is present before it is corrupted.
    let intact = cache.get("test-pack", "1.0.0").await.unwrap().unwrap();
    assert!(intact.signature.is_some());

    let entry_dir = cache.pack_dir("test-pack", "1.0.0");
    fs::write(entry_dir.join("signature.json"), "this is not valid json{{{")
        .await
        .unwrap();

    let degraded = cache.get("test-pack", "1.0.0").await.unwrap().unwrap();
    assert!(
        degraded.signature.is_none(),
        "Corrupt signature should be None, not error"
    );
    assert_eq!(degraded.content, body);
}
/// Unparsable `metadata.json` surfaces as a cache error from `get`.
#[tokio::test]
async fn test_metadata_json_corrupt_handling() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: \"1.0.0\"";
    let fetched = create_fetch_result(body);
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    let entry_dir = cache.pack_dir("test-pack", "1.0.0");
    fs::write(entry_dir.join("metadata.json"), "invalid json content")
        .await
        .unwrap();

    let outcome = cache.get("test-pack", "1.0.0").await;
    assert!(
        matches!(outcome, Err(RegistryError::Cache { .. })),
        "Should return cache error for corrupt metadata: {:?}",
        outcome
    );
}
/// After a successful `put`, no `.tmp` file is left in the entry directory.
#[tokio::test]
async fn test_atomic_write_prevents_partial_cache() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: \"1.0.0\"";
    let fetched = create_fetch_result(body);
    let stored = cache.put("test-pack", "1.0.0", &fetched, None).await;
    stored.unwrap();

    let entry_dir = cache.pack_dir("test-pack", "1.0.0");
    let mut entries = fs::read_dir(&entry_dir).await.unwrap();
    loop {
        let entry = match entries.next_entry().await.unwrap() {
            Some(entry) => entry,
            None => break,
        };
        let file_name = entry.file_name();
        let name_str = file_name.to_string_lossy();
        assert!(
            !name_str.ends_with(".tmp"),
            "Temp file should not remain: {}",
            name_str
        );
    }
}
/// The registry URL passed to `put` is recorded in the entry's metadata.
#[tokio::test]
async fn test_cache_registry_url_tracking() {
    let (cache, _temp_dir) = create_test_cache();
    let body = "name: test\nversion: \"1.0.0\"";
    let fetched = create_fetch_result(body);

    let registry = Some("https://registry.example.com/v1");
    let stored = cache.put("test-pack", "1.0.0", &fetched, registry).await;
    stored.unwrap();

    let meta = cache.get_metadata("test-pack", "1.0.0").await.unwrap();
    assert_eq!(
        meta.registry_url,
        Some("https://registry.example.com/v1".to_string())
    );
}
}