use std::path::{Path, PathBuf};
use std::time::Duration;
use super::error::MediaError;
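/// Algorithm prefix attached to every hash string this store hands out.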
const HASH_PREFIX: &str = "blake3:";
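/// Blobs of at least this size (1 MiB) are read back and re-hashed after storing.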
const VERIFY_THRESHOLD: u64 = 1024 * 1024;
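/// Hard upper bound on a single stored blob (100 MiB); larger inputs are rejected.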
const MAX_STORE_SIZE: usize = 100 * 1024 * 1024;
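// On-disk framing, used only when the "media-compression" feature is enabled.
// Every blob written by this build carries a 4-byte header so reads can tell
// compressed, raw, and legacy (pre-framing) files apart:
//
//   bytes 0..2   b"NK"     CAS magic
//   byte  2      flag      0x00 = raw payload, 0x01 = zstd payload
//   byte  3      0x00      framing version
//   bytes 4..    payload   the original bytes, or a zstd frame
//
// Files without the "NK" magic are treated as legacy and returned unchanged.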
#[cfg(feature = "media-compression")]
const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
#[cfg(feature = "media-compression")]
const CAS_MAGIC: &[u8; 2] = b"NK";
#[cfg(feature = "media-compression")]
const CAS_FLAG_ZSTD: u8 = 0x01;
#[cfg(feature = "media-compression")]
const CAS_FLAG_RAW: u8 = 0x00;
#[cfg(feature = "media-compression")]
const CAS_FRAMING_VERSION: u8 = 0x00;
#[cfg(feature = "media-compression")]
const CAS_HEADER_LEN: usize = 4;
#[cfg(feature = "media-compression")]
const ZSTD_LEVEL: i32 = 3;
#[cfg(feature = "media-compression")]
const MAX_DECOMPRESS_SIZE: u64 = 200 * 1024 * 1024;
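/// Outcome of a successful `store` call: the `blake3:`-prefixed hash, the
/// final on-disk path, the original (pre-compression) size, and whether the
/// blob was deduplicated or read back for verification.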
#[derive(Debug, Clone)]
pub struct StoreResult {
pub hash: String,
pub path: PathBuf,
pub size: u64,
pub deduplicated: bool,
pub verified: bool,
pub pipeline_ms: u64,
}
#[derive(Debug, Clone)]
pub struct CasEntry {
pub hash: String,
pub path: PathBuf,
pub size: u64,
}
#[derive(Debug, Clone)]
pub struct CleanResult {
pub removed: u64,
pub bytes_freed: u64,
}
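/// Content-addressed store: each blob lives at `<root>/<hash[..2]>/<hash[2..]>`,
/// keyed by the BLAKE3 hash of its uncompressed content, so identical data
/// always deduplicates to a single file.
///
/// Illustrative usage (a sketch, not compiled here; the surrounding async
/// context and any import path are assumed, not part of this module):
///
/// ```ignore
/// let store = CasStore::new("/tmp/cas-demo");
/// let stored = store.store(b"hello media pipeline").await?;
/// assert!(stored.hash.starts_with("blake3:"));
/// let bytes = store.read(&stored.hash).await?;
/// assert_eq!(bytes, b"hello media pipeline");
/// ```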
pub struct CasStore {
root: PathBuf,
}
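/// Heuristic for whether compression is worth attempting: skips blobs under
/// 64 bytes, data that already looks like a zstd frame, and MIME types
/// detected by `infer` that are already compressed (images, audio, video,
/// and common archive formats).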
#[cfg(feature = "media-compression")]
fn should_compress(data: &[u8]) -> bool {
    if data.len() < 64 {
        return false;
    }
if data.len() >= 4 && data[..4] == ZSTD_MAGIC {
return false;
}
let mime = infer::get(data).map(|t| t.mime_type());
!matches!(mime, Some(m) if m.starts_with("image/")
|| m.starts_with("audio/")
|| m.starts_with("video/")
|| m == "application/zip"
|| m == "application/gzip"
|| m == "application/x-bzip2"
|| m == "application/x-xz"
)
}
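/// Compresses `data` with zstd and keeps the result only if the framed
/// compressed form is strictly smaller than the input; otherwise frames the
/// data raw. Either way the output starts with the 4-byte CAS header.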
#[cfg(feature = "media-compression")]
fn compress_if_beneficial(data: &[u8]) -> Vec<u8> {
match zstd::encode_all(std::io::Cursor::new(data), ZSTD_LEVEL) {
Ok(compressed) if compressed.len() + CAS_HEADER_LEN < data.len() => {
let mut framed = Vec::with_capacity(CAS_HEADER_LEN + compressed.len());
framed.extend_from_slice(CAS_MAGIC);
framed.push(CAS_FLAG_ZSTD);
framed.push(CAS_FRAMING_VERSION);
framed.extend_from_slice(&compressed);
framed
}
_ => {
let mut framed = Vec::with_capacity(CAS_HEADER_LEN + data.len());
framed.extend_from_slice(CAS_MAGIC);
framed.push(CAS_FLAG_RAW);
framed.push(CAS_FRAMING_VERSION);
framed.extend_from_slice(data);
framed
}
}
}
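/// Wraps `data` in the CAS header with the raw flag, skipping compression.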
#[cfg(feature = "media-compression")]
fn frame_uncompressed(data: &[u8]) -> Vec<u8> {
let mut framed = Vec::with_capacity(CAS_HEADER_LEN + data.len());
framed.extend_from_slice(CAS_MAGIC);
framed.push(CAS_FLAG_RAW);
framed.push(CAS_FRAMING_VERSION);
framed.extend_from_slice(data);
framed
}
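/// Reverses the CAS framing: zstd-flagged payloads are decompressed (capped
/// at `MAX_DECOMPRESS_SIZE`), raw-flagged payloads have the header stripped,
/// and anything without the `NK` magic (legacy blobs, short data, or an
/// unknown flag byte) is returned unchanged.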
#[cfg(feature = "media-compression")]
fn transparent_decompress(data: Vec<u8>) -> Result<Vec<u8>, MediaError> {
if data.len() < CAS_HEADER_LEN || &data[..2] != CAS_MAGIC {
return Ok(data);
}
let flag = data[2];
match flag {
CAS_FLAG_ZSTD => {
let zstd_data = &data[CAS_HEADER_LEN..];
let cursor = std::io::Cursor::new(zstd_data);
let mut decoder = zstd::Decoder::new(cursor).map_err(|e| MediaError::MediaStoreIo {
path: PathBuf::from("<zstd-decompress>"),
source: std::io::Error::new(std::io::ErrorKind::InvalidData, e),
})?;
let mut output = Vec::new();
let mut limited = std::io::Read::take(&mut decoder, MAX_DECOMPRESS_SIZE);
std::io::Read::read_to_end(&mut limited, &mut output).map_err(|e| {
MediaError::MediaStoreIo {
path: PathBuf::from("<zstd-decompress>"),
source: e,
}
})?;
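            // The limited read above stops exactly at MAX_DECOMPRESS_SIZE;
            // probing one extra byte distinguishes "fits exactly" from
            // "exceeds the cap", which is rejected.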
let mut probe = [0u8; 1];
if std::io::Read::read(&mut decoder, &mut probe).unwrap_or(0) > 0 {
return Err(MediaError::Base64InputTooLarge {
size: MAX_DECOMPRESS_SIZE as usize + 1,
max: MAX_DECOMPRESS_SIZE as usize,
});
}
Ok(output)
}
CAS_FLAG_RAW => {
Ok(data[CAS_HEADER_LEN..].to_vec())
}
_ => {
Ok(data)
}
}
}
impl CasStore {
pub fn new(root: impl Into<PathBuf>) -> Self {
Self { root: root.into() }
}
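    /// Builds the default store for a workspace. The `NIKA_MEDIA_STORE`
    /// environment variable overrides the location (canonicalized when the
    /// path exists); otherwise the store lives under
    /// `<workspace>/.nika/media/store`.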
pub fn workspace_default(workspace_root: &Path) -> Self {
if let Ok(override_path) = std::env::var("NIKA_MEDIA_STORE") {
let path = PathBuf::from(&override_path);
let resolved = path.canonicalize().unwrap_or(path);
return Self::new(resolved);
}
Self::new(workspace_root.join(".nika").join("media").join("store"))
}
pub fn root(&self) -> &Path {
&self.root
}
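    /// Stores `data` and returns its prefixed hash, final path, and original
    /// size. Empty input and blobs over `MAX_STORE_SIZE` are rejected up
    /// front. Content that already exists on disk is reported as
    /// `deduplicated`; blobs of at least `VERIFY_THRESHOLD` bytes are read
    /// back and re-hashed to confirm the write.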
pub async fn store(&self, data: &[u8]) -> Result<StoreResult, MediaError> {
if data.is_empty() {
return Err(MediaError::EmptyMediaContent {
task_id: "(cas-direct)".to_string(),
});
}
if data.len() > MAX_STORE_SIZE {
return Err(MediaError::Base64InputTooLarge {
size: data.len(),
max: MAX_STORE_SIZE,
});
}
let process_start = std::time::Instant::now();
let raw_hash = blake3::hash(data).to_hex().to_string();
let size = data.len() as u64;
let prefixed_hash = format!("{HASH_PREFIX}{raw_hash}");
let dir = self.root.join(&raw_hash[..2]);
let final_path = dir.join(&raw_hash[2..]);
        #[cfg(feature = "media-compression")]
        let framed = if should_compress(data) {
            compress_if_beneficial(data)
        } else {
            frame_uncompressed(data)
        };
        #[cfg(feature = "media-compression")]
        let write_data: &[u8] = &framed;
#[cfg(not(feature = "media-compression"))]
let write_data: &[u8] = data;
tokio::fs::create_dir_all(&dir)
.await
.map_err(|e| MediaError::MediaStoreIo {
path: dir.clone(),
source: e,
})?;
match crate::io::atomic::write_fail(&final_path, write_data).await {
            Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
return Ok(StoreResult {
hash: prefixed_hash,
path: final_path,
size,
deduplicated: true,
verified: false,
pipeline_ms: process_start.elapsed().as_millis() as u64,
});
}
Err(e) => {
let _ = tokio::fs::remove_file(&final_path).await;
return Err(MediaError::MediaStoreIo {
path: final_path,
source: e,
});
}
}
let verified = if size >= VERIFY_THRESHOLD {
let stored =
tokio::fs::read(&final_path)
.await
.map_err(|e| MediaError::MediaStoreIo {
path: final_path.clone(),
source: e,
})?;
#[cfg(feature = "media-compression")]
let stored = transparent_decompress(stored)?;
let verify_hash = blake3::hash(&stored).to_hex().to_string();
if verify_hash != raw_hash {
let _ = tokio::fs::remove_file(&final_path).await;
return Err(MediaError::HashMismatch {
expected: prefixed_hash,
actual: format!("{HASH_PREFIX}{verify_hash}"),
});
}
true
        } else {
            false
        };
Ok(StoreResult {
hash: prefixed_hash,
path: final_path,
size,
deduplicated: false,
verified,
pipeline_ms: process_start.elapsed().as_millis() as u64,
})
}
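    /// Cheap existence check. Malformed hashes (too short, or containing
    /// non-hex characters) short-circuit to `false` so untrusted input never
    /// reaches the filesystem.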
pub fn exists(&self, hash: &str) -> bool {
let raw = strip_hash_prefix(hash);
if raw.len() < 3 || validate_hash_hex(raw).is_err() {
return false;
}
let path = self.root.join(&raw[..2]).join(&raw[2..]);
path.exists()
}
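    /// Reads a blob by its (optionally `blake3:`-prefixed) hash, mapping a
    /// missing file to `MediaNotFound` and transparently stripping any CAS
    /// framing.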
pub async fn read(&self, hash: &str) -> Result<Vec<u8>, MediaError> {
let raw = strip_hash_prefix(hash);
if raw.len() < 3 {
return Err(MediaError::MediaNotFound {
hash: hash.to_string(),
});
}
validate_hash_hex(raw)?;
let path = self.root.join(&raw[..2]).join(&raw[2..]);
let data = tokio::fs::read(&path).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
MediaError::MediaNotFound {
hash: hash.to_string(),
}
} else {
MediaError::MediaStoreIo { path, source: e }
}
})?;
#[cfg(feature = "media-compression")]
let data = transparent_decompress(data)?;
Ok(data)
}
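    /// Reads a stored file directly by path. Framing is still stripped, so
    /// callers receive the original bytes, not the on-disk representation.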
pub async fn read_raw(path: &Path) -> Result<Vec<u8>, MediaError> {
let data = tokio::fs::read(path)
.await
.map_err(|e| MediaError::MediaStoreIo {
path: path.to_path_buf(),
source: e,
})?;
#[cfg(feature = "media-compression")]
let data = transparent_decompress(data)?;
Ok(data)
}
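    /// Walks the two-level shard layout synchronously and returns every
    /// stored entry; unreadable directories and files are skipped.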
pub fn list(&self) -> Vec<CasEntry> {
let mut entries = Vec::new();
let Ok(shards) = std::fs::read_dir(&self.root) else {
return entries;
};
for shard in shards.flatten() {
let shard_name = shard.file_name().to_string_lossy().to_string();
if shard_name.len() != 2 || !shard.path().is_dir() {
continue;
}
let Ok(files) = std::fs::read_dir(shard.path()) else {
continue;
};
for file in files.flatten() {
let file_name = file.file_name().to_string_lossy().to_string();
let raw_hash = format!("{}{}", shard_name, file_name);
let size = file.metadata().map(|m| m.len()).unwrap_or(0);
entries.push(CasEntry {
hash: format!("{HASH_PREFIX}{raw_hash}"),
path: file.path(),
size,
});
}
}
entries
}
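    /// Removes every stored blob, reporting how many files and bytes were freed.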
pub fn clean_all(&self) -> CleanResult {
let mut removed = 0u64;
let mut bytes_freed = 0u64;
for entry in self.list() {
if let Ok(meta) = std::fs::metadata(&entry.path) {
bytes_freed += meta.len();
}
if std::fs::remove_file(&entry.path).is_ok() {
removed += 1;
} else {
tracing::warn!(path = %entry.path.display(), "failed to remove CAS file");
}
}
CleanResult {
removed,
bytes_freed,
}
}
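    /// Removes blobs whose modification time is older than `duration`.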
pub fn clean_older_than(&self, duration: Duration) -> CleanResult {
let mut removed = 0u64;
let mut bytes_freed = 0u64;
let now = std::time::SystemTime::now();
for entry in self.list() {
let Ok(meta) = std::fs::metadata(&entry.path) else {
continue;
};
let Ok(modified) = meta.modified() else {
continue;
};
            let Ok(age) = now.duration_since(modified) else {
continue;
};
if age > duration {
bytes_freed += meta.len();
if std::fs::remove_file(&entry.path).is_ok() {
removed += 1;
} else {
tracing::warn!(path = %entry.path.display(), "failed to remove CAS file");
}
}
}
CleanResult {
removed,
bytes_freed,
}
}
}
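/// Strips the `blake3:` prefix when present; bare hashes pass through untouched.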
fn strip_hash_prefix(hash: &str) -> &str {
hash.strip_prefix(HASH_PREFIX).unwrap_or(hash)
}
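/// Rejects hashes containing non-hex characters before they are ever joined
/// into a filesystem path, closing off traversal via crafted hash strings.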
fn validate_hash_hex(raw: &str) -> Result<(), MediaError> {
if !raw.chars().all(|c| c.is_ascii_hexdigit()) {
return Err(MediaError::MediaNotFound {
hash: format!("{HASH_PREFIX}{raw}"),
});
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn store_and_read_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let data = b"hello media pipeline";
let result = store.store(data).await.unwrap();
assert!(result.hash.starts_with("blake3:"));
assert!(!result.deduplicated);
assert!(!result.verified);
assert_eq!(result.size, data.len() as u64);
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(read_back, data);
}
#[tokio::test]
async fn store_dedup_same_content() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let data = b"identical content";
let r1 = store.store(data).await.unwrap();
let r2 = store.store(data).await.unwrap();
assert_eq!(r1.hash, r2.hash);
assert!(!r1.deduplicated);
assert!(r2.deduplicated);
}
#[tokio::test]
async fn exists_after_store() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let data = b"existence check";
let result = store.store(data).await.unwrap();
assert!(store.exists(&result.hash));
assert!(!store.exists("blake3:nonexistent"));
}
#[tokio::test]
async fn read_nonexistent_hash_returns_error() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let result = store
.read("blake3:abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890")
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn hash_only_filename_no_extension() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let data = b"no extension in path";
let result = store.store(data).await.unwrap();
assert!(
result.path.extension().is_none(),
"CAS path should have no extension: {:?}",
result.path
);
}
#[tokio::test]
async fn hash_has_blake3_prefix() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let result = store.store(b"prefix test").await.unwrap();
assert!(
result.hash.starts_with("blake3:"),
"hash should have blake3: prefix, got: {}",
result.hash
);
}
#[tokio::test]
async fn list_returns_stored_entries() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
store.store(b"file one").await.unwrap();
store.store(b"file two").await.unwrap();
let entries = store.list();
assert_eq!(entries.len(), 2);
assert!(entries.iter().all(|e| e.hash.starts_with("blake3:")));
}
#[tokio::test]
async fn clean_all_removes_files() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
store.store(b"data one").await.unwrap();
store.store(b"data two").await.unwrap();
let clean = store.clean_all();
assert_eq!(clean.removed, 2);
assert_eq!(store.list().len(), 0);
}
#[test]
fn workspace_default_uses_workspace_root() {
let root = std::path::PathBuf::from("/tmp/test-workspace");
let store = CasStore::workspace_default(&root);
assert!(!store.exists("blake3:nonexistent"));
}
#[tokio::test]
async fn store_rejects_empty_data() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let result = store.store(b"").await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().code(), "NIKA-258");
}
#[tokio::test]
async fn store_rejects_oversized_data() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let big = vec![0u8; MAX_STORE_SIZE + 1];
let result = store.store(&big).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().code(), "NIKA-257");
}
#[tokio::test]
async fn store_accepts_exactly_max_store_size() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let data = vec![0xAB_u8; MAX_STORE_SIZE];
let result = store.store(&data).await;
assert!(
result.is_ok(),
"exactly MAX_STORE_SIZE should be accepted, got: {:?}",
result.err()
);
let sr = result.unwrap();
assert_eq!(sr.size, MAX_STORE_SIZE as u64);
assert!(!sr.deduplicated);
assert!(
sr.verified,
"100MB file should trigger read-back verification"
);
let read_back = store.read(&sr.hash).await.unwrap();
assert_eq!(read_back.len(), MAX_STORE_SIZE);
assert!(
read_back.iter().all(|&b| b == 0xAB),
"data corruption: not all bytes are 0xAB"
);
}
#[tokio::test]
async fn concurrent_cas_writes_dedup_correctly() {
let dir = tempfile::tempdir().unwrap();
let store = std::sync::Arc::new(CasStore::new(dir.path()));
let data: Vec<u8> = b"identical content for all tasks".to_vec();
let handles: Vec<_> = (0..10)
.map(|_| {
let store = std::sync::Arc::clone(&store);
let data = data.clone();
tokio::spawn(async move { store.store(&data).await })
})
.collect();
let results: Vec<StoreResult> = futures::future::join_all(handles)
.await
.into_iter()
.map(|h| h.unwrap().unwrap())
.collect();
let hash = &results[0].hash;
assert!(hash.starts_with("blake3:"));
assert!(results.iter().all(|r| &r.hash == hash));
let non_dedup_count = results.iter().filter(|r| !r.deduplicated).count();
assert_eq!(non_dedup_count, 1, "exactly one writer should be non-dedup");
}
#[test]
fn root_accessor_returns_store_root() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
assert_eq!(store.root(), dir.path());
}
#[test]
fn workspace_default_without_env_uses_nika_media_store_path() {
let saved = std::env::var("NIKA_MEDIA_STORE").ok();
std::env::remove_var("NIKA_MEDIA_STORE");
let workspace = PathBuf::from("/tmp/test-workspace");
let store = CasStore::workspace_default(&workspace);
let expected = workspace.join(".nika").join("media").join("store");
assert_eq!(store.root(), expected.as_path());
if let Some(val) = saved {
std::env::set_var("NIKA_MEDIA_STORE", val);
}
}
#[test]
fn workspace_default_respects_nika_media_store_env() {
let dir = tempfile::tempdir().unwrap();
let override_path = dir.path().join("custom-store");
let saved = std::env::var("NIKA_MEDIA_STORE").ok();
std::env::set_var("NIKA_MEDIA_STORE", override_path.to_str().unwrap());
let store = CasStore::workspace_default(Path::new("/ignored/workspace"));
assert_eq!(store.root(), override_path.as_path());
match saved {
Some(val) => std::env::set_var("NIKA_MEDIA_STORE", val),
None => std::env::remove_var("NIKA_MEDIA_STORE"),
}
}
#[test]
fn workspace_default_canonicalizes_existing_override_path() {
let dir = tempfile::tempdir().unwrap();
let actual = dir.path().join("store");
std::fs::create_dir_all(&actual).unwrap();
let dotdot_path = dir.path().join("store").join("..").join("store");
let saved = std::env::var("NIKA_MEDIA_STORE").ok();
std::env::set_var("NIKA_MEDIA_STORE", dotdot_path.to_str().unwrap());
let store = CasStore::workspace_default(Path::new("/ignored"));
let resolved = store.root().to_path_buf();
assert!(
!resolved.to_str().unwrap().contains(".."),
"path should be canonicalized, got: {}",
resolved.display()
);
assert_eq!(
resolved.canonicalize().unwrap(),
actual.canonicalize().unwrap(),
"resolved path should point to the same directory"
);
match saved {
Some(val) => std::env::set_var("NIKA_MEDIA_STORE", val),
None => std::env::remove_var("NIKA_MEDIA_STORE"),
}
}
#[tokio::test]
async fn env_override_store_is_fully_functional() {
let dir = tempfile::tempdir().unwrap();
let override_path = dir.path().join("custom-cas");
let saved = std::env::var("NIKA_MEDIA_STORE").ok();
std::env::set_var("NIKA_MEDIA_STORE", override_path.to_str().unwrap());
let store = CasStore::workspace_default(Path::new("/ignored/workspace"));
let data = b"env override test data";
let result = store.store(data).await.unwrap();
assert!(result.hash.starts_with("blake3:"));
assert!(
result.path.starts_with(&override_path),
"stored file should be inside override path, got: {}",
result.path.display()
);
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(read_back, data);
let entries = store.list();
assert_eq!(entries.len(), 1);
assert!(entries[0].path.starts_with(&override_path));
assert!(store.exists(&result.hash));
let clean = store.clean_all();
assert_eq!(clean.removed, 1);
assert_eq!(store.list().len(), 0);
match saved {
Some(val) => std::env::set_var("NIKA_MEDIA_STORE", val),
None => std::env::remove_var("NIKA_MEDIA_STORE"),
}
}
#[tokio::test]
async fn cas_path_is_always_within_root() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let payloads: Vec<&[u8]> = vec![
b"payload one",
b"payload two",
b"\x00\x01\x02\xff\xfe\xfd",
b"../../../etc/passwd",
];
let canonical_root = dir.path().canonicalize().unwrap();
for payload in payloads {
let result = store.store(payload).await.unwrap();
let canonical_path = result.path.canonicalize().unwrap();
assert!(
canonical_path.starts_with(&canonical_root),
"CAS file {:?} escapes root {:?}",
canonical_path,
canonical_root,
);
let shard = result
.path
.parent()
.unwrap()
.file_name()
.unwrap()
.to_string_lossy();
let filename = result.path.file_name().unwrap().to_string_lossy();
assert_eq!(shard.len(), 2, "Shard directory must be 2 hex chars");
assert!(
shard.chars().all(|c| c.is_ascii_hexdigit()),
"Shard '{}' contains non-hex chars",
shard
);
assert!(
filename.chars().all(|c| c.is_ascii_hexdigit()),
"Filename '{}' contains non-hex chars",
filename
);
}
}
#[test]
fn cas_hash_prefix_strip_safety() {
assert_eq!(strip_hash_prefix("blake3:abcdef"), "abcdef");
assert_eq!(strip_hash_prefix("abcdef"), "abcdef");
assert_eq!(strip_hash_prefix("blake3:../../etc"), "../../etc");
}
#[test]
fn cas_exists_rejects_short_hash() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
assert!(!store.exists("ab"));
assert!(!store.exists("a"));
assert!(!store.exists(""));
assert!(!store.exists("blake3:ab"));
assert!(!store.exists("blake3:a"));
}
#[tokio::test]
async fn cas_read_rejects_short_hash() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let result = store.read("ab").await;
assert!(result.is_err());
let result = store.read("blake3:ab").await;
assert!(result.is_err());
}
#[cfg(feature = "media-compression")]
mod compression_tests {
use super::*;
#[tokio::test]
async fn store_read_json_roundtrip_with_compression() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let json = br#"{"name":"test","items":[1,2,3,4,5],"nested":{"a":"b"}}"#;
let result = store.store(json).await.unwrap();
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(
read_back, json,
"JSON round-trip must preserve data exactly"
);
}
#[tokio::test]
async fn store_png_passes_through_uncompressed() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let mut png_data = vec![0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
png_data.extend_from_slice(&[0u8; 100]);
let result = store.store(&png_data).await.unwrap();
let raw = strip_hash_prefix(&result.hash);
let path = dir.path().join(&raw[..2]).join(&raw[2..]);
let on_disk = tokio::fs::read(&path).await.unwrap();
assert_eq!(&on_disk[..2], CAS_MAGIC, "PNG should have NK magic prefix");
assert_eq!(
on_disk[2], CAS_FLAG_RAW,
"PNG should have raw flag, not zstd"
);
assert_eq!(
on_disk[3], CAS_FRAMING_VERSION,
"PNG should have version 0x00"
);
assert_eq!(
&on_disk[CAS_HEADER_LEN..],
&png_data[..],
"PNG data should follow the 4-byte framing header"
);
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(read_back, png_data);
}
#[tokio::test]
async fn store_text_is_compressed_on_disk() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let text: Vec<u8> = "hello world! ".repeat(100).into_bytes();
let result = store.store(&text).await.unwrap();
let raw = strip_hash_prefix(&result.hash);
let path = dir.path().join(&raw[..2]).join(&raw[2..]);
let on_disk = tokio::fs::read(&path).await.unwrap();
assert_eq!(&on_disk[..2], CAS_MAGIC, "should have NK magic prefix");
assert_eq!(on_disk[2], CAS_FLAG_ZSTD, "should have zstd flag");
assert_eq!(on_disk[3], CAS_FRAMING_VERSION, "should have version 0x00");
assert_eq!(
&on_disk[CAS_HEADER_LEN..CAS_HEADER_LEN + 4],
&ZSTD_MAGIC,
"text should be zstd-compressed after 4-byte header"
);
assert!(on_disk.len() < text.len(), "compressed should be smaller");
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(read_back, text);
}
#[tokio::test]
async fn dedup_works_with_compression() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let data = b"deduplicate me please".repeat(10);
let r1 = store.store(&data).await.unwrap();
let r2 = store.store(&data).await.unwrap();
assert_eq!(r1.hash, r2.hash, "same content must produce same hash");
assert!(!r1.deduplicated);
assert!(r2.deduplicated, "second store should detect dedup");
}
#[tokio::test]
async fn budget_charged_on_original_size() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let text: Vec<u8> = "budget test ".repeat(100).into_bytes();
let original_size = text.len() as u64;
let result = store.store(&text).await.unwrap();
assert_eq!(
result.size, original_size,
"size should be original data length"
);
}
#[tokio::test]
async fn already_zstd_data_roundtrips_correctly() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let original = b"pre-compressed data content here that is long enough to be over sixty-four bytes for threshold!";
let pre_compressed =
zstd::encode_all(std::io::Cursor::new(original.as_slice()), ZSTD_LEVEL).unwrap();
assert!(pre_compressed.len() >= 4 && pre_compressed[..4] == ZSTD_MAGIC);
assert!(
!should_compress(&pre_compressed),
"zstd data should not be re-compressed"
);
let result = store.store(&pre_compressed).await.unwrap();
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(
read_back, pre_compressed,
"user-stored zstd data should round-trip exactly"
);
}
#[tokio::test]
async fn concurrent_compressed_writes() {
let dir = tempfile::tempdir().unwrap();
let store = std::sync::Arc::new(CasStore::new(dir.path()));
let data: Vec<u8> = "concurrent compression test ".repeat(50).into_bytes();
let handles: Vec<_> = (0..5)
.map(|_| {
let store = std::sync::Arc::clone(&store);
let data = data.clone();
tokio::spawn(async move { store.store(&data).await })
})
.collect();
let results: Vec<StoreResult> = futures::future::join_all(handles)
.await
.into_iter()
.map(|h| h.unwrap().unwrap())
.collect();
let hash = &results[0].hash;
assert!(results.iter().all(|r| &r.hash == hash));
let read_back = store.read(hash).await.unwrap();
assert_eq!(read_back, data);
}
#[test]
fn should_compress_text_yes() {
let text = b"hello world this is some text that should compress, adding more to be over 64 bytes for the threshold";
assert!(text.len() >= 64, "fixture must be >= 64 bytes");
assert!(should_compress(text));
}
#[test]
fn should_compress_json_yes() {
let json =
br#"{"key":"value","list":[1,2,3],"nested":{"a":"b","c":"d","e":"f","g":"h"}}"#;
assert!(json.len() >= 64, "fixture must be >= 64 bytes");
assert!(should_compress(json));
}
#[test]
fn should_compress_png_no() {
let mut png = vec![0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
png.extend_from_slice(&[0u8; 100]);
assert!(!should_compress(&png));
}
#[test]
fn should_compress_jpeg_no() {
let mut jpeg = vec![0xFF, 0xD8, 0xFF, 0xE0];
jpeg.extend_from_slice(&[0u8; 100]);
assert!(!should_compress(&jpeg));
}
#[test]
fn should_compress_small_data_no() {
assert!(
!should_compress(b"tiny"),
"data < 64 bytes should skip compression"
);
}
#[test]
fn cas_marker_framing_compressed() {
let data = b"test data that is long enough to be over sixty-four bytes threshold for compression!";
let framed = compress_if_beneficial(data);
assert_eq!(&framed[..2], CAS_MAGIC, "must start with NK magic");
if framed[2] == CAS_FLAG_ZSTD {
assert_eq!(
&framed[CAS_HEADER_LEN..CAS_HEADER_LEN + 4],
&ZSTD_MAGIC,
"zstd magic after header"
);
assert!(framed.len() < data.len(), "compressed should be smaller");
} else {
assert_eq!(framed[2], CAS_FLAG_RAW);
assert_eq!(&framed[CAS_HEADER_LEN..], data.as_slice());
}
}
#[test]
fn cas_marker_framing_uncompressed() {
let data: Vec<u8> = (0..=255).cycle().take(100).collect();
let framed = compress_if_beneficial(&data);
assert_eq!(
&framed[..2],
CAS_MAGIC,
"framed data must start with NK magic"
);
assert!(
framed[2] == CAS_FLAG_ZSTD || framed[2] == CAS_FLAG_RAW,
"framed data must have a valid compression flag"
);
}
#[test]
fn frame_uncompressed_roundtrips() {
let data = b"hello world";
let framed = frame_uncompressed(data);
assert_eq!(&framed[..2], CAS_MAGIC);
assert_eq!(framed[2], CAS_FLAG_RAW);
assert_eq!(framed[3], CAS_FRAMING_VERSION);
assert_eq!(&framed[CAS_HEADER_LEN..], data.as_slice());
let decompressed = transparent_decompress(framed).unwrap();
assert_eq!(decompressed, data);
}
#[tokio::test]
async fn bug6_small_data_starting_with_cas_marker_and_zstd_magic() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let mut evil_data: Vec<u8> = vec![0x01, 0x28, 0xB5, 0x2F, 0xFD];
evil_data.extend_from_slice(b"this is user data, not compressed!");
assert!(evil_data.len() < 64, "must be below compression threshold");
let result = store.store(&evil_data).await.unwrap();
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(
read_back, evil_data,
"Bug 6 regression: small data starting with [0x01][zstd-magic] \
must NOT be falsely decompressed"
);
}
#[tokio::test]
async fn bug6_large_data_starting_with_cas_marker_and_zstd_magic() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let mut evil_data: Vec<u8> = vec![0x01, 0x28, 0xB5, 0x2F, 0xFD];
            evil_data.extend_from_slice(&[0xAB; 200]);
            assert!(evil_data.len() >= 64, "must be above compression threshold");
let result = store.store(&evil_data).await.unwrap();
let read_back = store.read(&result.hash).await.unwrap();
assert_eq!(
read_back, evil_data,
"Bug 6 regression: large data starting with [0x01][zstd-magic] \
must NOT be falsely decompressed"
);
}
#[test]
fn bug6_transparent_decompress_three_cases() {
let original = b"hello world! ".repeat(20);
let compressed =
zstd::encode_all(std::io::Cursor::new(original.as_slice()), ZSTD_LEVEL).unwrap();
let mut framed_compressed = Vec::new();
framed_compressed.extend_from_slice(CAS_MAGIC);
framed_compressed.push(CAS_FLAG_ZSTD);
framed_compressed.push(CAS_FRAMING_VERSION);
framed_compressed.extend_from_slice(&compressed);
let result = transparent_decompress(framed_compressed).unwrap();
assert_eq!(result, original, "case 1: compressed must decompress");
let raw_data = b"raw user data bytes";
let mut framed_raw = Vec::new();
framed_raw.extend_from_slice(CAS_MAGIC);
framed_raw.push(CAS_FLAG_RAW);
framed_raw.push(CAS_FRAMING_VERSION);
framed_raw.extend_from_slice(raw_data);
let result = transparent_decompress(framed_raw).unwrap();
assert_eq!(result, raw_data, "case 2: uncompressed must strip header");
            let legacy = vec![0x89, 0x50, 0x4E, 0x47];
            let result = transparent_decompress(legacy.clone()).unwrap();
assert_eq!(result, legacy, "case 3: legacy must return as-is");
let legacy_false_positive = vec![0x01, 0x28, 0xB5, 0x2F, 0xFD, 0xAA, 0xBB];
let result = transparent_decompress(legacy_false_positive.clone()).unwrap();
assert_eq!(
result, legacy_false_positive,
"case 3: data starting with 0x01 but no NK prefix — must be legacy"
);
let legacy_null = vec![0x00, 0x50, 0x4E, 0x47, 0xAA, 0xBB];
let result = transparent_decompress(legacy_null.clone()).unwrap();
assert_eq!(
result, legacy_null,
"case 3: data starting with 0x00 but no NK prefix — must be legacy"
);
}
#[test]
fn bug6_empty_data_decompress() {
let result = transparent_decompress(vec![]).unwrap();
assert!(result.is_empty(), "empty data must return empty");
}
#[test]
fn bug4_legacy_blobs_with_old_single_byte_markers_not_corrupted() {
let data_with_null = vec![0x00, 0xFF, 0xFE, 0xFD];
let result = transparent_decompress(data_with_null.clone()).unwrap();
assert_eq!(
result, data_with_null,
"0x00-prefixed legacy blob must not be stripped"
);
let data_with_one = vec![0x01, 0x28, 0xB5, 0x2F, 0xFD];
let result = transparent_decompress(data_with_one.clone()).unwrap();
assert_eq!(
result, data_with_one,
"0x01-prefixed legacy blob must not be decompressed"
);
            let short = vec![0x4E, 0x4B];
            let result = transparent_decompress(short.clone()).unwrap();
assert_eq!(result, short, "data shorter than header must be legacy");
}
#[tokio::test]
async fn bug6_read_raw_decompresses_correctly() {
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let mut user_data: Vec<u8> = vec![0x01, 0x28, 0xB5, 0x2F, 0xFD];
user_data.extend_from_slice(b"not actually compressed!!!");
assert!(user_data.len() < 64);
let result = store.store(&user_data).await.unwrap();
let via_read_raw = CasStore::read_raw(&result.path).await.unwrap();
assert_eq!(
via_read_raw, user_data,
"read_raw must strip framing and return original data"
);
}
}
#[cfg(unix)]
#[tokio::test]
async fn cas_store_uses_o_excl_prevents_symlink_attack() {
use std::os::unix::fs::symlink;
let dir = tempfile::tempdir().unwrap();
let store = CasStore::new(dir.path());
let data = b"symlink attack test data";
let raw_hash = blake3::hash(data).to_hex().to_string();
let shard_dir = dir.path().join(&raw_hash[..2]);
std::fs::create_dir_all(&shard_dir).unwrap();
let final_path = shard_dir.join(&raw_hash[2..]);
let decoy = dir.path().join("decoy");
std::fs::write(&decoy, b"decoy content").unwrap();
symlink(&decoy, &final_path).unwrap();
let result = store.store(data).await.unwrap();
assert!(
result.deduplicated,
"Symlink at CAS path must be treated as existing file (O_EXCL semantics)"
);
}
}