use crate::{
body::StreamingBody,
error::{Result, StreamingError},
runtime, StreamingCacheManager,
};
use async_trait::async_trait;
use bytes::Bytes;
use http::{Response, Version};
use http_body_util::{BodyExt, Empty};
use http_cache_semantics::CachePolicy;
use log::warn;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use url::Url;
use blake3;
use dashmap::DashMap;
use lru::LruCache;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
cfg_if::cfg_if! {
if #[cfg(feature = "streaming-tokio")] {
use tokio::sync::Mutex;
} else if #[cfg(feature = "streaming-smol")] {
use async_lock::Mutex;
} else {
use std::sync::Mutex;
}
}
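// Versioned subdirectory for the on-disk layout; metadata and content files live under it.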
const CACHE_VERSION: &str = "cache-v2";
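/// Configuration for the streaming cache: optional size and entry limits,
/// plus the buffer size used when computing digests of large files.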
#[derive(Debug, Clone, Copy)]
pub struct StreamingCacheConfig {
pub max_cache_size: Option<u64>,
pub max_entries: Option<usize>,
pub streaming_buffer_size: usize,
}
impl Default for StreamingCacheConfig {
fn default() -> Self {
Self {
max_cache_size: None,
max_entries: None,
streaming_buffer_size: 8192,
}
}
}
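/// LRU bookkeeping for one cache entry: which content digest the key points at,
/// when it was last accessed, and the size of the content file.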
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LruEntry {
cache_key: String,
content_digest: String,
last_accessed: u64,
file_size: u64,
}
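/// Shared bookkeeping for content-addressed storage: per-digest reference counts
/// (for deduplication), an LRU index of cache entries, and running totals of
/// cache size and entry count.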
#[derive(Debug)]
struct ContentRefCounter {
refs: DashMap<String, Arc<AtomicUsize>>,
lru_cache: Arc<Mutex<LruCache<String, LruEntry>>>,
current_cache_size: AtomicU64,
current_entries: AtomicUsize,
}
impl ContentRefCounter {
fn new() -> Self {
Self {
refs: DashMap::new(),
lru_cache: Arc::new(Mutex::new(LruCache::new(
NonZeroUsize::new(10000).unwrap(),
))),
current_cache_size: AtomicU64::new(0),
current_entries: AtomicUsize::new(0),
}
}
async fn current_cache_size(&self) -> Result<u64> {
Ok(self.current_cache_size.load(Ordering::Relaxed))
}
async fn current_entries(&self) -> Result<usize> {
Ok(self.current_entries.load(Ordering::Relaxed))
}
async fn add_cache_entry(
&self,
cache_key: String,
content_digest: String,
file_size: u64,
) -> Result<()> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let entry = LruEntry {
cache_key: cache_key.clone(),
content_digest,
last_accessed: now,
file_size,
};
cfg_if::cfg_if! {
if #[cfg(any(feature = "streaming-tokio", feature = "streaming-smol"))] {
let mut lru = self.lru_cache.lock().await;
lru.put(cache_key, entry);
} else {
let mut lru = self.lru_cache.lock().map_err(|e| {
StreamingError::concurrency(format!(
"Failed to acquire lock for lru_cache: {e}"
))
})?;
lru.put(cache_key, entry);
}
}
self.current_cache_size.fetch_add(file_size, Ordering::Relaxed);
self.current_entries.fetch_add(1, Ordering::Relaxed);
Ok(())
}
async fn update_access_time(&self, cache_key: &str) -> Result<()> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
cfg_if::cfg_if! {
if #[cfg(any(feature = "streaming-tokio", feature = "streaming-smol"))] {
let mut lru = self.lru_cache.lock().await;
if let Some(entry) = lru.get_mut(cache_key) {
entry.last_accessed = now;
}
} else {
let mut lru = self.lru_cache.lock().map_err(|e| {
StreamingError::concurrency(format!(
"Failed to acquire lock for lru_cache: {e}"
))
})?;
if let Some(entry) = lru.get_mut(cache_key) {
entry.last_accessed = now;
}
}
}
Ok(())
}
async fn get_lru_entries_for_eviction(
&self,
target_size: u64,
target_count: usize,
) -> Result<Vec<LruEntry>> {
let current_size = self.current_cache_size().await?;
let current_count = self.current_entries().await?;
cfg_if::cfg_if! {
if #[cfg(any(feature = "streaming-tokio", feature = "streaming-smol"))] {
let lru = self.lru_cache.lock().await;
let mut entries_to_evict = Vec::new();
let mut size_to_free = current_size.saturating_sub(target_size);
let mut entries_to_free = current_count.saturating_sub(target_count);
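// `iter()` yields entries from most- to least-recently used, so reverse it to consider eviction candidates LRU-first.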
for (_, entry) in lru.iter().rev() {
if size_to_free == 0 && entries_to_free == 0 {
break;
}
entries_to_evict.push(entry.clone());
if size_to_free > 0 {
size_to_free = size_to_free.saturating_sub(entry.file_size);
}
entries_to_free = entries_to_free.saturating_sub(1);
}
Ok(entries_to_evict)
} else {
let lru = self.lru_cache.lock().map_err(|e| {
StreamingError::concurrency(format!(
"Failed to acquire lock for lru_cache: {e}"
))
})?;
let mut entries_to_evict = Vec::new();
let mut size_to_free = current_size.saturating_sub(target_size);
let mut entries_to_free = current_count.saturating_sub(target_count);
for (_, entry) in lru.iter().rev() {
if size_to_free == 0 && entries_to_free == 0 {
break;
}
entries_to_evict.push(entry.clone());
if size_to_free > 0 {
size_to_free = size_to_free.saturating_sub(entry.file_size);
}
entries_to_free = entries_to_free.saturating_sub(1);
}
Ok(entries_to_evict)
}
}
}
async fn remove_cache_entry(
&self,
cache_key: &str,
) -> Result<Option<LruEntry>> {
cfg_if::cfg_if! {
if #[cfg(any(feature = "streaming-tokio", feature = "streaming-smol"))] {
let mut lru = self.lru_cache.lock().await;
if let Some(entry) = lru.pop(cache_key) {
self.current_cache_size.fetch_sub(entry.file_size, Ordering::Relaxed);
self.current_entries.fetch_sub(1, Ordering::Relaxed);
return Ok(Some(entry));
}
} else {
let mut lru = self.lru_cache.lock().map_err(|e| {
StreamingError::concurrency(format!(
"Failed to acquire lock for lru_cache: {e}"
))
})?;
if let Some(entry) = lru.pop(cache_key) {
self.current_cache_size.fetch_sub(entry.file_size, Ordering::Relaxed);
self.current_entries.fetch_sub(1, Ordering::Relaxed);
return Ok(Some(entry));
}
}
}
Ok(None)
}
async fn rollback_cache_entry(
&self,
cache_key: &str,
exact_file_size: u64,
) -> Result<()> {
cfg_if::cfg_if! {
if #[cfg(any(feature = "streaming-tokio", feature = "streaming-smol"))] {
let mut lru = self.lru_cache.lock().await;
if lru.pop(cache_key).is_some() {
self.current_cache_size.fetch_sub(exact_file_size, Ordering::Relaxed);
self.current_entries.fetch_sub(1, Ordering::Relaxed);
}
} else {
let mut lru = self.lru_cache.lock().map_err(|e| {
StreamingError::concurrency(format!(
"Failed to acquire lock for lru_cache: {e}"
))
})?;
if lru.pop(cache_key).is_some() {
self.current_cache_size.fetch_sub(exact_file_size, Ordering::Relaxed);
self.current_entries.fetch_sub(1, Ordering::Relaxed);
}
}
}
Ok(())
}
async fn add_ref(&self, digest: &str) -> Result<usize> {
let counter = self
.refs
.entry(digest.to_string())
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
Ok(counter.fetch_add(1, Ordering::Relaxed) + 1)
}
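/// Decrements the reference count for `digest`; returns true when no references
/// remain and the content file may be deleted.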
async fn remove_ref(&self, digest: &str) -> Result<bool> {
if let Some(entry) = self.refs.get_mut(digest) {
let current_count = entry.load(Ordering::Relaxed);
if current_count > 0 {
entry.fetch_sub(1, Ordering::Relaxed);
let new_count = entry.load(Ordering::Relaxed);
if new_count == 0 {
// Drop the DashMap guard before removing the key; removing while still holding a reference into the map can deadlock.
drop(entry);
self.refs.remove(digest);
return Ok(true);
}
}
}
Ok(false)
}
}
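/// Snapshot of current cache usage alongside the configured limits.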
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
pub current_size: u64,
pub current_entries: usize,
pub max_size: Option<u64>,
pub max_entries: Option<usize>,
}
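/// A cache entry whose stored content failed integrity verification.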
#[derive(Debug, Clone)]
pub struct CorruptedEntry {
pub cache_key: String,
pub digest: String,
pub reason: String,
}
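/// Summary produced by `verify_cache_integrity`: entry totals, corrupted entries,
/// and content files no longer referenced by any metadata.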
#[derive(Debug, Clone)]
pub struct CacheIntegrityReport {
pub total_entries: usize,
pub valid_entries: usize,
pub corrupted_entries: Vec<CorruptedEntry>,
pub orphaned_content: Vec<String>,
}
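/// Per-entry metadata stored on disk alongside the content-addressed body file.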
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheMetadata {
pub status: u16,
pub version: u8,
pub headers: HashMap<String, String>,
pub content_digest: String,
pub policy: CachePolicy,
pub created_at: u64,
}
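/// File-based cache manager with content-addressed, deduplicated storage,
/// reference counting, and LRU-based size/entry limits.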
#[derive(Debug)]
pub struct StreamingManager {
root_path: PathBuf,
ref_counter: ContentRefCounter,
config: StreamingCacheConfig,
}
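// Manual Clone impl: the atomic totals are snapshotted, the LRU cache is shared via Arc,
// and the ref-count map is copied entry by entry (the Arc'd counters themselves stay shared).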
impl Clone for StreamingManager {
fn clone(&self) -> Self {
Self {
root_path: self.root_path.clone(),
ref_counter: ContentRefCounter {
refs: self.ref_counter.refs.clone(),
lru_cache: self.ref_counter.lru_cache.clone(),
current_cache_size: AtomicU64::new(
self.ref_counter.current_cache_size.load(Ordering::Relaxed),
),
current_entries: AtomicUsize::new(
self.ref_counter.current_entries.load(Ordering::Relaxed),
),
},
config: self.config,
}
}
}
impl StreamingManager {
pub fn new(root_path: PathBuf) -> Self {
Self::new_with_config(root_path, StreamingCacheConfig::default())
}
pub fn new_with_config(
root_path: PathBuf,
config: StreamingCacheConfig,
) -> Self {
Self { root_path, ref_counter: ContentRefCounter::new(), config }
}
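/// Opens a manager over an existing cache directory with default configuration,
/// rebuilding in-memory reference counts and LRU state from the metadata on disk.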
pub async fn new_with_existing_cache(root_path: PathBuf) -> Result<Self> {
Self::new_with_existing_cache_and_config(
root_path,
StreamingCacheConfig::default(),
)
.await
}
pub async fn new_with_existing_cache_and_config(
root_path: PathBuf,
config: StreamingCacheConfig,
) -> Result<Self> {
let manager = Self::new_with_config(root_path, config);
let current_entries = manager.ref_counter.current_entries().await?;
if current_entries == 0 {
manager.rebuild_reference_counts().await?;
}
Ok(manager)
}
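/// Verifies that the file at `content_path` hashes to `digest`, reading the whole file into memory.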
pub async fn verify_content_integrity(
&self,
digest: &str,
content_path: &Path,
) -> Result<bool> {
if !content_path.exists() {
return Ok(false);
}
let content =
runtime::read(content_path).await.map_err(StreamingError::io)?;
let computed_digest = Self::calculate_digest(&content);
Ok(computed_digest == digest)
}
async fn verify_content_integrity_streaming(
&self,
digest: &str,
content_path: &Path,
) -> Result<bool> {
if !content_path.exists() {
return Ok(false);
}
let file_size = match runtime::metadata(content_path).await {
Ok(meta) => meta.len(),
Err(_) => return Ok(false),
};
let computed_digest =
if file_size > self.config.streaming_buffer_size as u64 {
self.compute_digest_streaming(content_path).await?
} else {
let content = runtime::read(content_path)
.await
.map_err(StreamingError::io)?;
Self::calculate_digest(&content)
};
Ok(computed_digest == digest)
}
async fn compute_digest_streaming(
&self,
file_path: &Path,
) -> Result<String> {
let file =
runtime::File::open(file_path).await.map_err(StreamingError::io)?;
let mut hasher = blake3::Hasher::new();
let mut buffer = vec![0u8; self.config.streaming_buffer_size];
cfg_if::cfg_if! {
if #[cfg(feature = "streaming-tokio")] {
use tokio::io::AsyncReadExt;
let mut file = file;
loop {
let bytes_read = file.read(&mut buffer).await.map_err(StreamingError::io)?;
if bytes_read == 0 {
break;
}
hasher.update(&buffer[..bytes_read]);
}
} else if #[cfg(feature = "streaming-smol")] {
use smol::io::AsyncReadExt;
let mut file = file;
loop {
let bytes_read = file.read(&mut buffer).await.map_err(StreamingError::io)?;
if bytes_read == 0 {
break;
}
hasher.update(&buffer[..bytes_read]);
}
}
}
Ok(hasher.finalize().to_hex().to_string())
}
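/// Scans every metadata file, verifies the referenced content, and collects
/// corrupted entries plus orphaned content files.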
pub async fn verify_cache_integrity(&self) -> Result<CacheIntegrityReport> {
let mut report = CacheIntegrityReport {
total_entries: 0,
valid_entries: 0,
corrupted_entries: Vec::new(),
orphaned_content: Vec::new(),
};
let metadata_dir = self.root_path.join(CACHE_VERSION).join("metadata");
let content_dir = self.root_path.join(CACHE_VERSION).join("content");
if !metadata_dir.exists() {
return Ok(report);
}
let mut referenced_digests = std::collections::HashSet::new();
cfg_if::cfg_if! {
if #[cfg(feature = "streaming-tokio")] {
let mut entries = runtime::read_dir(&metadata_dir).await.map_err(StreamingError::io)?;
while let Some(entry) = entries.next_entry().await.map_err(StreamingError::io)? {
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
report.total_entries += 1;
let content = runtime::read(&path).await.map_err(StreamingError::io)?;
match serde_json::from_slice::<CacheMetadata>(&content) {
Ok(metadata) => {
referenced_digests.insert(metadata.content_digest.clone());
let content_path = self.content_path(&metadata.content_digest);
match self.verify_content_integrity_streaming(&metadata.content_digest, &content_path).await {
Ok(true) => report.valid_entries += 1,
Ok(false) => {
let cache_key = path.file_stem()
.and_then(|s| s.to_str())
.and_then(|s| hex::decode(s).ok())
.and_then(|bytes| String::from_utf8(bytes).ok())
.unwrap_or_else(|| format!("unknown-{}", metadata.content_digest));
report.corrupted_entries.push(CorruptedEntry {
cache_key,
digest: metadata.content_digest.clone(),
reason: "Content digest mismatch".to_string(),
});
}
Err(e) => {
let cache_key = path.file_stem()
.and_then(|s| s.to_str())
.and_then(|s| hex::decode(s).ok())
.and_then(|bytes| String::from_utf8(bytes).ok())
.unwrap_or_else(|| format!("unknown-{}", metadata.content_digest));
report.corrupted_entries.push(CorruptedEntry {
cache_key,
digest: metadata.content_digest.clone(),
reason: format!("Verification error: {}", e),
});
}
}
}
Err(e) => {
report.corrupted_entries.push(CorruptedEntry {
cache_key: "unknown".to_string(),
digest: "unknown".to_string(),
reason: format!("Invalid metadata: {}", e),
});
}
}
}
}
} else if #[cfg(feature = "streaming-smol")] {
use futures::stream::StreamExt;
let mut entries = runtime::read_dir(&metadata_dir).await.map_err(StreamingError::io)?;
while let Some(entry_result) = entries.next().await {
let entry = entry_result.map_err(StreamingError::io)?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
report.total_entries += 1;
let content = runtime::read(&path).await.map_err(StreamingError::io)?;
match serde_json::from_slice::<CacheMetadata>(&content) {
Ok(metadata) => {
referenced_digests.insert(metadata.content_digest.clone());
let content_path = self.content_path(&metadata.content_digest);
match self.verify_content_integrity_streaming(&metadata.content_digest, &content_path).await {
Ok(true) => report.valid_entries += 1,
Ok(false) => {
let cache_key = path.file_stem()
.and_then(|s| s.to_str())
.and_then(|s| hex::decode(s).ok())
.and_then(|bytes| String::from_utf8(bytes).ok())
.unwrap_or_else(|| format!("unknown-{}", metadata.content_digest));
report.corrupted_entries.push(CorruptedEntry {
cache_key,
digest: metadata.content_digest.clone(),
reason: "Content digest mismatch".to_string(),
});
}
Err(e) => {
let cache_key = path.file_stem()
.and_then(|s| s.to_str())
.and_then(|s| hex::decode(s).ok())
.and_then(|bytes| String::from_utf8(bytes).ok())
.unwrap_or_else(|| format!("unknown-{}", metadata.content_digest));
report.corrupted_entries.push(CorruptedEntry {
cache_key,
digest: metadata.content_digest.clone(),
reason: format!("Verification error: {}", e),
});
}
}
}
Err(e) => {
report.corrupted_entries.push(CorruptedEntry {
cache_key: "unknown".to_string(),
digest: "unknown".to_string(),
reason: format!("Invalid metadata: {}", e),
});
}
}
}
}
}
}
if content_dir.exists() {
cfg_if::cfg_if! {
if #[cfg(feature = "streaming-tokio")] {
let mut content_entries = runtime::read_dir(&content_dir).await.map_err(StreamingError::io)?;
while let Some(entry) = content_entries.next_entry().await.map_err(StreamingError::io)? {
let path = entry.path();
if let Some(filename) = path.file_name().and_then(|s| s.to_str()) {
if !referenced_digests.contains(filename) {
report.orphaned_content.push(filename.to_string());
}
}
}
} else if #[cfg(feature = "streaming-smol")] {
use futures::stream::StreamExt;
let mut content_entries = runtime::read_dir(&content_dir).await.map_err(StreamingError::io)?;
while let Some(entry_result) = content_entries.next().await {
let entry = entry_result.map_err(StreamingError::io)?;
let path = entry.path();
if let Some(filename) = path.file_name().and_then(|s| s.to_str()) {
if !referenced_digests.contains(filename) {
report.orphaned_content.push(filename.to_string());
}
}
}
}
}
}
Ok(report)
}
pub async fn remove_corrupted_entries(
&self,
corrupted_digests: &[String],
) -> Result<usize> {
let mut removed_count = 0;
for digest in corrupted_digests {
let content_path = self.content_path(digest);
if content_path.exists() {
if let Err(e) = runtime::remove_file(&content_path).await {
warn!(
"Failed to remove corrupted content file {}: {}",
digest, e
);
} else {
removed_count += 1;
}
}
}
Ok(removed_count)
}
pub async fn cache_stats(&self) -> Result<CacheStats> {
Ok(CacheStats {
current_size: self.ref_counter.current_cache_size().await?,
current_entries: self.ref_counter.current_entries().await?,
max_size: self.config.max_cache_size,
max_entries: self.config.max_entries,
})
}
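/// Evicts least-recently-used entries until the cache fits within the configured size and entry limits.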
async fn enforce_cache_limits(&self) -> Result<()> {
let current_size = self.ref_counter.current_cache_size().await?;
let current_entries = self.ref_counter.current_entries().await?;
let target_size = self.config.max_cache_size.unwrap_or(u64::MAX);
let target_count = self.config.max_entries.unwrap_or(usize::MAX);
if current_size <= target_size && current_entries <= target_count {
return Ok(());
}
let entries_to_evict = self
.ref_counter
.get_lru_entries_for_eviction(target_size, target_count)
.await?;
for entry in entries_to_evict {
if let Err(e) = self.delete(&entry.cache_key).await {
warn!(
"Warning: Failed to evict cache entry '{}': {}",
entry.cache_key, e
);
}
}
Ok(())
}
fn metadata_path(&self, key: &str) -> PathBuf {
let encoded_key = hex::encode(key.as_bytes());
self.root_path
.join(CACHE_VERSION)
.join("metadata")
.join(format!("{encoded_key}.json"))
}
fn content_path(&self, digest: &str) -> PathBuf {
self.root_path.join(CACHE_VERSION).join("content").join(digest)
}
fn calculate_digest(content: &[u8]) -> String {
blake3::hash(content).to_hex().to_string()
}
async fn ensure_dir_exists(path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
runtime::create_dir_all(parent)
.await
.map_err(StreamingError::io)?;
}
Ok(())
}
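/// Writes `content` to a uniquely named temporary file and renames it over `path`;
/// if the rename fails but the target already exists, a concurrent writer is assumed
/// to have won and the temporary file is discarded.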
async fn atomic_write(path: &Path, content: &[u8]) -> Result<()> {
use std::ffi::OsString;
let mut temp_path = path.to_path_buf();
let mut temp_name = temp_path
.file_name()
.map(|n| n.to_os_string())
.unwrap_or_else(|| OsString::from("temp"));
temp_name.push(".tmp");
let pid = std::process::id();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
temp_name.push(format!(".{}.{}", pid, timestamp));
temp_path.set_file_name(&temp_name);
Self::ensure_dir_exists(&temp_path).await?;
runtime::write(&temp_path, content)
.await
.map_err(StreamingError::io)?;
if let Err(e) = runtime::rename(&temp_path, path).await {
if runtime::metadata(path).await.is_ok() {
let _ = runtime::remove_file(&temp_path).await;
return Ok(());
}
let _ = runtime::remove_file(&temp_path).await;
return Err(StreamingError::io(format!(
"Failed to atomically write file {:?}: {}",
path, e
))
.into());
}
Ok(())
}
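/// Rebuilds reference counts and LRU bookkeeping by scanning the metadata files on disk.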
async fn rebuild_reference_counts(&self) -> Result<()> {
let metadata_dir = self.root_path.join(CACHE_VERSION).join("metadata");
if !metadata_dir.exists() {
return Ok(());
}
cfg_if::cfg_if! {
if #[cfg(feature = "streaming-tokio")] {
let mut entries = runtime::read_dir(&metadata_dir).await.map_err(StreamingError::io)?;
while let Some(entry) = entries.next_entry().await.map_err(StreamingError::io)? {
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
let content = runtime::read(&path).await.map_err(StreamingError::io)?;
match serde_json::from_slice::<CacheMetadata>(&content) {
Ok(metadata) => {
if let Err(e) = self.ref_counter.add_ref(&metadata.content_digest).await {
return Err(StreamingError::consistency(format!("Failed to rebuild reference count for {}: {}", metadata.content_digest, e)).into());
}
let cache_key = path.file_stem()
.and_then(|s| s.to_str())
.and_then(|s| hex::decode(s).ok())
.and_then(|bytes| String::from_utf8(bytes).ok())
.unwrap_or_else(|| format!("unknown-{}", metadata.content_digest));
let content_path = self.content_path(&metadata.content_digest);
let file_size = if let Ok(meta) = runtime::metadata(&content_path).await {
meta.len()
} else {
0
};
if let Err(e) = self.ref_counter.add_cache_entry(
cache_key,
metadata.content_digest.clone(),
file_size
).await {
return Err(StreamingError::consistency(format!("Failed to rebuild LRU tracking for {}: {}", metadata.content_digest, e)).into());
}
}
Err(e) => {
return Err(StreamingError::serialization(format!("Failed to parse metadata file {:?}: {}", path, e)).into());
}
}
}
}
} else if #[cfg(feature = "streaming-smol")] {
use futures::stream::StreamExt;
let mut entries = runtime::read_dir(&metadata_dir).await.map_err(StreamingError::io)?;
while let Some(entry_result) = entries.next().await {
let entry = entry_result.map_err(StreamingError::io)?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
let content = runtime::read(&path).await.map_err(StreamingError::io)?;
match serde_json::from_slice::<CacheMetadata>(&content) {
Ok(metadata) => {
if let Err(e) = self.ref_counter.add_ref(&metadata.content_digest).await {
return Err(StreamingError::consistency(format!("Failed to rebuild reference count for {}: {}", metadata.content_digest, e)).into());
}
let cache_key = path.file_stem()
.and_then(|s| s.to_str())
.and_then(|s| hex::decode(s).ok())
.and_then(|bytes| String::from_utf8(bytes).ok())
.unwrap_or_else(|| format!("unknown-{}", metadata.content_digest));
let content_path = self.content_path(&metadata.content_digest);
let file_size = if let Ok(meta) = runtime::metadata(&content_path).await {
meta.len()
} else {
0
};
if let Err(e) = self.ref_counter.add_cache_entry(
cache_key,
metadata.content_digest.clone(),
file_size
).await {
return Err(StreamingError::consistency(format!("Failed to rebuild LRU tracking for {}: {}", metadata.content_digest, e)).into());
}
}
Err(e) => {
return Err(StreamingError::serialization(format!("Failed to parse metadata file {:?}: {}", path, e)).into());
}
}
}
}
}
}
Ok(())
}
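/// Collects the response body and computes its blake3 digest; note that the body is
/// buffered fully in memory here rather than streamed to disk.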
async fn process_body_streaming<B>(
&self,
body: B,
) -> Result<(String, Bytes, u64)>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
{
use http_body_util::BodyExt;
let collected =
body.collect().await.map_err(|e| StreamingError::new(e.into()))?;
let body_bytes = collected.to_bytes();
let content_digest = Self::calculate_digest(&body_bytes);
let file_size = body_bytes.len() as u64;
Ok((content_digest, body_bytes, file_size))
}
}
#[async_trait]
impl StreamingCacheManager for StreamingManager {
type Body = StreamingBody<Empty<Bytes>>;
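/// Looks up a cached response by key; the returned body streams from the content file on disk.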
async fn get(
&self,
cache_key: &str,
) -> Result<Option<(Response<Self::Body>, CachePolicy)>>
where
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static,
{
let metadata_path = self.metadata_path(cache_key);
if !metadata_path.exists() {
return Ok(None);
}
let _ = self.ref_counter.update_access_time(cache_key).await;
let metadata_content =
runtime::read(&metadata_path).await.map_err(StreamingError::io)?;
let metadata: CacheMetadata = serde_json::from_slice(&metadata_content)
.map_err(StreamingError::serialization)?;
let content_path = self.content_path(&metadata.content_digest);
let file = match runtime::File::open(&content_path).await {
Ok(file) => file,
Err(_) => return Ok(None),
};
let mut response_builder = Response::builder()
.status(metadata.status)
.version(match metadata.version {
9 => Version::HTTP_09,
10 => Version::HTTP_10,
11 => Version::HTTP_11,
2 => Version::HTTP_2,
3 => Version::HTTP_3,
_ => Version::HTTP_11,
});
for (name, value) in &metadata.headers {
if let (Ok(header_name), Ok(header_value)) = (
name.parse::<http::HeaderName>(),
value.parse::<http::HeaderValue>(),
) {
response_builder =
response_builder.header(header_name, header_value);
}
}
let body = StreamingBody::from_file(file);
let response =
response_builder.body(body).map_err(StreamingError::new)?;
Ok(Some((response, metadata.policy)))
}
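/// Stores a response: the body is written to content-addressed storage (deduplicated by
/// digest), the entry is registered for LRU tracking, limits are enforced, and the
/// metadata is written atomically with rollback on failure.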
async fn put<B>(
&self,
cache_key: String,
response: Response<B>,
policy: CachePolicy,
_request_url: Url,
) -> Result<Response<Self::Body>>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static,
{
let (parts, body) = response.into_parts();
let (content_digest, body_bytes, file_size) =
self.process_body_streaming(body).await?;
let content_path = self.content_path(&content_digest);
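// Write the content file only if this digest is not already stored (content-addressed deduplication).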
if runtime::metadata(&content_path).await.is_err() {
Self::atomic_write(&content_path, &body_bytes).await?;
}
let _ref_count = self.ref_counter.add_ref(&content_digest).await?;
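// Re-check that the content file still exists after taking a reference; a concurrent delete could have removed it.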
if runtime::metadata(&content_path).await.is_err() {
self.ref_counter.remove_ref(&content_digest).await?;
return Err(StreamingError::consistency(
"Content file was deleted during cache operation - possible race condition".to_string()
).into());
}
self.ref_counter
.add_cache_entry(
cache_key.clone(),
content_digest.clone(),
file_size,
)
.await?;
self.enforce_cache_limits().await?;
let metadata = CacheMetadata {
status: parts.status.as_u16(),
version: match parts.version {
Version::HTTP_09 => 9,
Version::HTTP_10 => 10,
Version::HTTP_11 => 11,
Version::HTTP_2 => 2,
Version::HTTP_3 => 3,
_ => 11,
},
headers: crate::HttpCacheOptions::headers_to_hashmap(
&parts.headers,
),
content_digest: content_digest.clone(),
policy,
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let metadata_path = self.metadata_path(&cache_key);
let metadata_json = serde_json::to_vec(&metadata)
.map_err(StreamingError::serialization)?;
if let Err(e) = Self::atomic_write(&metadata_path, &metadata_json).await
{
let content_removed = self
.ref_counter
.remove_ref(&content_digest)
.await
.unwrap_or(false);
let _ = self
.ref_counter
.rollback_cache_entry(&cache_key, file_size)
.await;
if content_removed {
let _ = runtime::remove_file(&content_path).await;
}
return Err(e);
}
let response =
Response::from_parts(parts, StreamingBody::buffered(body_bytes));
Ok(response)
}
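/// Buffers an arbitrary response body into the cache's streaming body type without touching the cache.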
async fn convert_body<B>(
&self,
response: Response<B>,
) -> Result<Response<Self::Body>>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static,
{
let (parts, body) = response.into_parts();
let collected =
body.collect().await.map_err(|e| StreamingError::new(e.into()))?;
let body_bytes = collected.to_bytes();
let streaming_body = StreamingBody::buffered(body_bytes);
Ok(Response::from_parts(parts, streaming_body))
}
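/// Removes a cache entry; the content file is deleted only once no other entry references
/// its digest. Fails if the entry's metadata cannot be read.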
async fn delete(&self, cache_key: &str) -> Result<()> {
let metadata_path = self.metadata_path(cache_key);
let metadata_content =
runtime::read(&metadata_path).await.map_err(StreamingError::io)?;
let metadata: CacheMetadata = serde_json::from_slice(&metadata_content)
.map_err(StreamingError::serialization)?;
let can_delete_content = {
let would_be_orphaned =
self.ref_counter.remove_ref(&metadata.content_digest).await?;
if would_be_orphaned {
self.ref_counter.add_ref(&metadata.content_digest).await?;
true
} else {
false
}
};
if can_delete_content {
let content_path = self.content_path(&metadata.content_digest);
if content_path.exists() {
match runtime::File::open(&content_path).await {
Ok(_) => {}
Err(e) => {
return Err(StreamingError::io(format!(
"Cannot delete content file {:?} (may be locked): {}",
content_path, e
)).into());
}
}
}
}
self.ref_counter.remove_cache_entry(cache_key).await?;
let should_delete_content = if can_delete_content {
self.ref_counter.remove_ref(&metadata.content_digest).await?
} else {
false
};
if should_delete_content {
let content_path = self.content_path(&metadata.content_digest);
if let Err(e) = runtime::remove_file(&content_path).await {
self.ref_counter.add_ref(&metadata.content_digest).await?;
self.ref_counter
.add_cache_entry(
cache_key.to_string(),
metadata.content_digest.clone(),
0,
)
.await?;
return Err(StreamingError::io(format!(
"Failed to remove content file {:?}: {}",
content_path, e
))
.into());
}
}
if let Err(e) = runtime::remove_file(&metadata_path).await {
return Err(StreamingError::io(format!(
"Warning: content deleted but metadata removal failed for {:?}: {}",
metadata_path, e
)).into());
}
Ok(())
}
#[cfg(feature = "streaming")]
fn body_to_bytes_stream(
body: Self::Body,
) -> impl futures_util::Stream<
Item = std::result::Result<
Bytes,
Box<dyn std::error::Error + Send + Sync>,
>,
> + Send
where
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error: Send + Sync + 'static,
{
body.into_bytes_stream()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::StreamingCacheManager as StreamingCacheManagerTrait;
use http_body_util::Full;
use tempfile::TempDir;
#[tokio::test]
async fn test_streaming_cache_put_get() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let original_body = Full::new(Bytes::from("test response body"));
let response = Response::builder()
.status(200)
.header("content-type", "text/plain")
.body(original_body)
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url = Url::parse("http://example.com/test").unwrap();
let cached_response = cache
.put("test-key".to_string(), response, policy.clone(), request_url)
.await
.unwrap();
assert_eq!(cached_response.status(), 200);
let retrieved = cache.get("test-key").await.unwrap();
assert!(retrieved.is_some());
let (cached_response, cached_policy) = retrieved.unwrap();
assert_eq!(cached_response.status(), 200);
assert_eq!(
cached_response.headers().get("content-type").unwrap(),
"text/plain"
);
let now = std::time::SystemTime::now();
assert_eq!(cached_policy.time_to_live(now), policy.time_to_live(now));
}
#[tokio::test]
async fn test_streaming_cache_delete() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let original_body = Full::new(Bytes::from("test response body"));
let response = Response::builder()
.status(200)
.header("content-type", "text/plain")
.body(original_body)
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url = Url::parse("http://example.com/test").unwrap();
let cache_key = "test-key-delete";
cache
.put(cache_key.to_string(), response, policy, request_url)
.await
.unwrap();
let retrieved = cache.get(cache_key).await.unwrap();
assert!(retrieved.is_some());
cache.delete(cache_key).await.unwrap();
let retrieved = cache.get(cache_key).await.unwrap();
assert!(retrieved.is_none());
}
#[tokio::test]
async fn test_content_deduplication() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let identical_content = Bytes::from("identical response body content");
let request_url = Url::parse("http://example.com/test").unwrap();
let response1 = Response::builder()
.status(200)
.header("cache-control", "max-age=3600")
.body(Full::new(identical_content.clone()))
.unwrap();
let response2 = Response::builder()
.status(200)
.header("content-type", "application/json")
.body(Full::new(identical_content.clone()))
.unwrap();
let policy1 = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test1")
.body(())
.unwrap()
.into_parts()
.0,
&response1.clone().map(|_| ()),
);
let policy2 = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test2")
.body(())
.unwrap()
.into_parts()
.0,
&response2.clone().map(|_| ()),
);
cache
.put("key1".to_string(), response1, policy1, request_url.clone())
.await
.unwrap();
cache
.put("key2".to_string(), response2, policy2, request_url)
.await
.unwrap();
let retrieved1 = cache.get("key1").await.unwrap().unwrap();
let retrieved2 = cache.get("key2").await.unwrap().unwrap();
assert_eq!(retrieved1.0.status(), 200);
assert_eq!(retrieved2.0.status(), 200);
let content_digest1 =
StreamingManager::calculate_digest(&identical_content);
let content_path1 = cache.content_path(&content_digest1);
assert!(content_path1.exists());
let content_dir = temp_dir.path().join(CACHE_VERSION).join("content");
let mut content_file_count = 0;
if content_dir.exists() {
for entry in std::fs::read_dir(&content_dir).unwrap() {
let entry = entry.unwrap();
if entry.file_type().unwrap().is_file() {
content_file_count += 1;
}
}
}
assert_eq!(
content_file_count, 1,
"Should have only one content file due to deduplication"
);
cache.delete("key1").await.unwrap();
assert!(
content_path1.exists(),
"Content file should still exist after deleting one reference"
);
let retrieved2_again = cache.get("key2").await.unwrap().unwrap();
assert_eq!(retrieved2_again.0.status(), 200);
cache.delete("key2").await.unwrap();
assert!(
!content_path1.exists(),
"Content file should be deleted when no references remain"
);
}
#[tokio::test]
async fn test_reference_count_persistence() {
let temp_dir = TempDir::new().unwrap();
let cache_path = temp_dir.path().to_path_buf();
let identical_content =
Bytes::from("persistent reference test content");
let request_url = Url::parse("http://example.com/test").unwrap();
{
let cache = StreamingManager::new(cache_path.clone());
let response1 = Response::builder()
.status(200)
.body(Full::new(identical_content.clone()))
.unwrap();
let response2 = Response::builder()
.status(404)
.body(Full::new(identical_content.clone()))
.unwrap();
let policy1 = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test1")
.body(())
.unwrap()
.into_parts()
.0,
&response1.clone().map(|_| ()),
);
let policy2 = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test2")
.body(())
.unwrap()
.into_parts()
.0,
&response2.clone().map(|_| ()),
);
cache
.put(
"persistent-key1".to_string(),
response1,
policy1,
request_url.clone(),
)
.await
.unwrap();
cache
.put(
"persistent-key2".to_string(),
response2,
policy2,
request_url.clone(),
)
.await
.unwrap();
let content_digest =
StreamingManager::calculate_digest(&identical_content);
let content_path = cache.content_path(&content_digest);
assert!(content_path.exists(), "Content file should exist");
}
{
let cache =
StreamingManager::new_with_existing_cache(cache_path.clone())
.await
.unwrap();
let retrieved1 = cache.get("persistent-key1").await.unwrap();
let retrieved2 = cache.get("persistent-key2").await.unwrap();
assert!(retrieved1.is_some());
assert!(retrieved2.is_some());
cache.delete("persistent-key1").await.unwrap();
let content_digest =
StreamingManager::calculate_digest(&identical_content);
let content_path = cache.content_path(&content_digest);
assert!(
content_path.exists(),
"Content file should still exist after one deletion"
);
let retrieved2_again = cache.get("persistent-key2").await.unwrap();
assert!(retrieved2_again.is_some());
cache.delete("persistent-key2").await.unwrap();
assert!(
!content_path.exists(),
"Content file should be deleted when all references removed"
);
}
}
#[tokio::test]
async fn test_concurrent_reference_counting() {
use std::sync::Arc;
use tokio::task;
let temp_dir = TempDir::new().unwrap();
let cache =
Arc::new(StreamingManager::new(temp_dir.path().to_path_buf()));
let request_url = Url::parse("http://example.com/test").unwrap();
let shared_content = Bytes::from("concurrent test content");
let tasks_count = 10;
let mut handles = Vec::new();
for i in 0..tasks_count {
let cache = Arc::clone(&cache);
let content = shared_content.clone();
let url = request_url.clone();
let handle = task::spawn(async move {
let response = Response::builder()
.status(200)
.header("x-task-id", i.to_string())
.body(Full::new(content))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri(format!("/concurrent-test-{}", i))
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
cache
.put(format!("concurrent-key-{}", i), response, policy, url)
.await
.unwrap();
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
for i in 0..tasks_count {
let retrieved =
cache.get(&format!("concurrent-key-{}", i)).await.unwrap();
assert!(retrieved.is_some(), "Entry {} should exist", i);
}
let content_digest =
StreamingManager::calculate_digest(&shared_content);
let content_path = cache.content_path(&content_digest);
assert!(content_path.exists(), "Shared content file should exist");
let mut delete_handles = Vec::new();
for i in 0..tasks_count / 2 {
let cache = Arc::clone(&cache);
let handle = task::spawn(async move {
cache.delete(&format!("concurrent-key-{}", i)).await.unwrap();
});
delete_handles.push(handle);
}
for handle in delete_handles {
handle.await.unwrap();
}
assert!(content_path.exists(), "Content file should still exist");
let mut final_delete_handles = Vec::new();
for i in tasks_count / 2..tasks_count {
let cache = Arc::clone(&cache);
let handle = task::spawn(async move {
cache.delete(&format!("concurrent-key-{}", i)).await.unwrap();
});
final_delete_handles.push(handle);
}
for handle in final_delete_handles {
handle.await.unwrap();
}
assert!(
!content_path.exists(),
"Content file should be deleted when all references removed"
);
}
#[tokio::test]
async fn test_large_content_streaming() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let large_content = vec![b'X'; 1024 * 1024];
let large_bytes = Bytes::from(large_content);
let response = Response::builder()
.status(200)
.header("content-type", "application/octet-stream")
.header("content-length", large_bytes.len().to_string())
.body(Full::new(large_bytes.clone()))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/large-file")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url = Url::parse("http://example.com/large-file").unwrap();
let cached_response = cache
.put("large-key".to_string(), response, policy, request_url)
.await
.unwrap();
assert_eq!(cached_response.status(), 200);
let (retrieved_response, _) =
cache.get("large-key").await.unwrap().unwrap();
assert_eq!(retrieved_response.status(), 200);
assert_eq!(
retrieved_response.headers().get("content-type").unwrap(),
"application/octet-stream"
);
let content_digest = StreamingManager::calculate_digest(&large_bytes);
let content_path = cache.content_path(&content_digest);
assert!(content_path.exists(), "Large content file should exist");
let metadata = std::fs::metadata(&content_path).unwrap();
assert_eq!(
metadata.len(),
1024 * 1024,
"Content file should have correct size"
);
}
#[tokio::test]
async fn test_error_handling() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let result = cache.get("non-existent").await.unwrap();
assert!(result.is_none(), "Should return None for non-existent key");
let result = cache.delete("non-existent").await;
assert!(result.is_err(), "Should error when deleting non-existent key");
let metadata_dir = temp_dir.path().join(CACHE_VERSION).join("metadata");
std::fs::create_dir_all(&metadata_dir).unwrap();
let corrupt_metadata_path = metadata_dir.join("corrupt.json");
std::fs::write(&corrupt_metadata_path, "invalid json").unwrap();
let response = Response::builder()
.status(200)
.body(Full::new(Bytes::from("test")))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url = Url::parse("http://example.com/test").unwrap();
let result = cache
.put("valid-key".to_string(), response, policy, request_url)
.await;
assert!(result.is_ok(), "Should handle corrupted metadata gracefully");
}
#[tokio::test]
async fn test_content_integrity_validation() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let original_content =
Bytes::from("original content for integrity test");
let response = Response::builder()
.status(200)
.body(Full::new(original_content.clone()))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/integrity-test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url =
Url::parse("http://example.com/integrity-test").unwrap();
cache
.put("integrity-key".to_string(), response, policy, request_url)
.await
.unwrap();
let retrieved = cache.get("integrity-key").await.unwrap();
assert!(retrieved.is_some(), "Content should be retrievable");
let content_digest =
StreamingManager::calculate_digest(&original_content);
let content_path = cache.content_path(&content_digest);
assert!(content_path.exists(), "Content file should exist");
std::fs::write(&content_path, "corrupted content").unwrap();
let retrieved_after_corruption =
cache.get("integrity-key").await.unwrap();
assert!(
retrieved_after_corruption.is_some(),
"Should still return metadata even with corrupted content"
);
}
#[tokio::test]
async fn test_http_version_preservation() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let test_versions = vec![
(Version::HTTP_09, 9u8),
(Version::HTTP_10, 10u8),
(Version::HTTP_11, 11u8),
(Version::HTTP_2, 2u8),
(Version::HTTP_3, 3u8),
];
for (i, (version, expected_stored)) in
test_versions.into_iter().enumerate()
{
let response = Response::builder()
.status(200)
.version(version)
.header("content-type", "text/plain")
.body(Full::new(Bytes::from(format!("version test {}", i))))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri(format!("/version-test-{}", i))
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url =
Url::parse(&format!("http://example.com/version-test-{}", i))
.unwrap();
let cache_key = format!("version-key-{}", i);
cache
.put(cache_key.clone(), response, policy, request_url)
.await
.unwrap();
let (retrieved_response, _) =
cache.get(&cache_key).await.unwrap().unwrap();
assert_eq!(
retrieved_response.version(),
version,
"HTTP version should be preserved for version {:?}",
version
);
let metadata_path = cache.metadata_path(&cache_key);
let metadata_content = std::fs::read(&metadata_path).unwrap();
let metadata: CacheMetadata =
serde_json::from_slice(&metadata_content).unwrap();
assert_eq!(
metadata.version, expected_stored,
"Stored version should match expected for {:?}",
version
);
}
}
#[tokio::test]
async fn test_header_edge_cases() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let response = Response::builder()
.status(200)
.header("content-type", "application/json; charset=utf-8")
.header("cache-control", "max-age=3600, public")
.header("custom-header", "custom-value")
.header("empty-header", "")
.header("unicode-header", "test-ñ-value")
.header("multiple-values", "value1")
.header("multiple-values", "value2") .body(Full::new(Bytes::from(r#"{"test": "json"}"#)))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/header-test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url = Url::parse("http://example.com/header-test").unwrap();
cache
.put(
"header-key".to_string(),
response.clone(),
policy,
request_url,
)
.await
.unwrap();
let (retrieved_response, _) =
cache.get("header-key").await.unwrap().unwrap();
assert_eq!(
retrieved_response.headers().get("content-type").unwrap(),
"application/json; charset=utf-8"
);
assert_eq!(
retrieved_response.headers().get("cache-control").unwrap(),
"max-age=3600, public"
);
assert_eq!(
retrieved_response.headers().get("custom-header").unwrap(),
"custom-value"
);
assert_eq!(
retrieved_response.headers().get("empty-header").unwrap(),
""
);
}
#[tokio::test]
async fn test_cache_key_edge_cases() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let response = Response::builder()
.status(200)
.body(Full::new(Bytes::from("test content")))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let request_url = Url::parse("http://example.com/test").unwrap();
let edge_case_keys = vec![
"simple-key",
"key:with:colons",
"key with spaces",
"key/with/slashes",
"key?with=query¶ms=true",
"key#with-fragment",
"very-long-key-that-exceeds-normal-filename-length-limits-and-should-still-work-properly-without-issues-abcdefghijklmnopqrstuvwxyz",
"unicode-key-ñáéÃóú-test",
"", ];
for (i, key) in edge_case_keys.into_iter().enumerate() {
let test_response = Response::builder()
.status(200)
.header("x-key-index", i.to_string())
.body(Full::new(Bytes::from(format!(
"content for key: {}",
key
))))
.unwrap();
let result = cache
.put(
key.to_string(),
test_response,
policy.clone(),
request_url.clone(),
)
.await;
if key.is_empty() || key.len() > 100 {
continue;
}
assert!(result.is_ok(), "Should handle key: '{}'", key);
let retrieved = cache.get(key).await.unwrap();
assert!(
retrieved.is_some(),
"Should retrieve content for key: '{}'",
key
);
let (retrieved_response, _) = retrieved.unwrap();
assert_eq!(retrieved_response.status(), 200);
assert_eq!(
retrieved_response.headers().get("x-key-index").unwrap(),
&i.to_string()
);
}
}
#[tokio::test]
async fn test_cache_size_limits_and_lru_eviction() {
let temp_dir = TempDir::new().unwrap();
let config = StreamingCacheConfig {
max_cache_size: Some(1000),
max_entries: Some(3),
..StreamingCacheConfig::default()
};
let cache = StreamingManager::new_with_config(
temp_dir.path().to_path_buf(),
config,
);
let request_url = Url::parse("http://example.com/test").unwrap();
let entries = vec![
("key1", "first content - should be evicted first", 500), ("key2", "second content - larger", 600), ("key3", "third content - should remain", 400), ];
for (key, content, _size) in &entries {
let response = Response::builder()
.status(200)
.body(Full::new(Bytes::from(*content)))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri(format!("/{}", key))
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
cache
.put(key.to_string(), response, policy, request_url.clone())
.await
.unwrap();
std::thread::sleep(std::time::Duration::from_millis(10));
}
let _retrieved = cache.get("key2").await.unwrap();
let large_content = "x".repeat(700); let response = Response::builder()
.status(200)
.body(Full::new(Bytes::from(large_content)))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/key4")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
cache
.put("key4".to_string(), response, policy, request_url)
.await
.unwrap();
assert!(
cache.get("key1").await.unwrap().is_none(),
"key1 should be evicted"
);
assert!(
cache.get("key2").await.unwrap().is_some(),
"key2 should remain"
);
assert!(
cache.get("key3").await.unwrap().is_some(),
"key3 should remain"
);
assert!(
cache.get("key4").await.unwrap().is_some(),
"key4 should remain"
);
}
#[tokio::test]
async fn test_background_cleanup() {
let temp_dir = TempDir::new().unwrap();
let config = StreamingCacheConfig {
..StreamingCacheConfig::default()
};
let cache = StreamingManager::new_with_config(
temp_dir.path().to_path_buf(),
config,
);
let request_url = Url::parse("http://example.com/test").unwrap();
let response = Response::builder()
.status(200)
.body(Full::new(Bytes::from("cleanup test content")))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/cleanup-test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
cache
.put("cleanup-key".to_string(), response, policy, request_url)
.await
.unwrap();
let content_dir = temp_dir.path().join(CACHE_VERSION).join("content");
let orphaned_file = content_dir.join("orphaned_content_file");
std::fs::write(&orphaned_file, "orphaned content").unwrap();
std::thread::sleep(std::time::Duration::from_secs(2));
let _retrieved = cache.get("cleanup-key").await.unwrap();
std::thread::sleep(std::time::Duration::from_millis(100));
assert!(cache.get("cleanup-key").await.unwrap().is_some());
}
#[tokio::test]
async fn test_metadata_write_failure_rollback() {
let temp_dir = TempDir::new().unwrap();
let cache = StreamingManager::new(temp_dir.path().to_path_buf());
let request_url = Url::parse("http://example.com/test").unwrap();
let content = Bytes::from("rollback test content");
let response = Response::builder()
.status(200)
.body(Full::new(content.clone()))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/rollback-test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let very_long_key = "a".repeat(300);
let result = cache
.put(very_long_key.clone(), response, policy, request_url)
.await;
assert!(result.is_err(), "Put should fail when metadata write fails");
let retrieved = cache.get(&very_long_key).await.unwrap();
assert!(retrieved.is_none(), "Entry should not exist after rollback");
let content_digest = StreamingManager::calculate_digest(&content);
let content_path = cache.content_path(&content_digest);
let content_exists = runtime::metadata(&content_path).await.is_ok();
if content_exists {
println!(
"Content file exists but cache entry was properly rolled back"
);
}
let content_digest = StreamingManager::calculate_digest(&content);
let _content_path = cache.content_path(&content_digest);
let retrieved = cache.get("rollback-key").await.unwrap();
assert!(retrieved.is_none(), "Entry should not exist after rollback");
}
#[tokio::test]
async fn test_atomic_operations_under_stress() {
let temp_dir = TempDir::new().unwrap();
let cache =
Arc::new(StreamingManager::new(temp_dir.path().to_path_buf()));
let request_url = Url::parse("http://example.com/test").unwrap();
let tasks_count = 20;
let mut handles = Vec::new();
for i in 0..tasks_count {
let cache = Arc::clone(&cache);
let url = request_url.clone();
let handle = tokio::task::spawn(async move {
for j in 0..10 {
let key = format!("stress-key-{}-{}", i, j);
let content = format!("stress test content {} {}", i, j);
let response = Response::builder()
.status(200)
.body(Full::new(Bytes::from(content)))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri(format!("/stress-{}-{}", i, j))
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let put_result = cache
.put(key.clone(), response, policy, url.clone())
.await;
assert!(
put_result.is_ok(),
"Put should succeed under stress"
);
let get_result = cache.get(&key).await.unwrap();
assert!(
get_result.is_some(),
"Get should succeed after put"
);
let delete_result = cache.delete(&key).await;
assert!(delete_result.is_ok(), "Delete should succeed");
let final_get = cache.get(&key).await.unwrap();
assert!(
final_get.is_none(),
"Entry should be gone after delete"
);
}
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
let content_dir = temp_dir.path().join(CACHE_VERSION).join("content");
let metadata_dir = temp_dir.path().join(CACHE_VERSION).join("metadata");
let content_count = if content_dir.exists() {
std::fs::read_dir(&content_dir).unwrap().count()
} else {
0
};
let metadata_count = if metadata_dir.exists() {
std::fs::read_dir(&metadata_dir).unwrap().count()
} else {
0
};
assert!(
content_count <= 5,
"Should have minimal content files after stress test"
);
assert!(
metadata_count <= 5,
"Should have minimal metadata files after stress test"
);
}
#[tokio::test]
async fn test_config_validation() {
let temp_dir = TempDir::new().unwrap();
let config = StreamingCacheConfig {
max_cache_size: Some(0),
max_entries: Some(0),
streaming_buffer_size: 1,
};
let cache = StreamingManager::new_with_config(
temp_dir.path().to_path_buf(),
config,
);
let request_url = Url::parse("http://example.com/test").unwrap();
let response = Response::builder()
.status(200)
.body(Full::new(Bytes::from("config test")))
.unwrap();
let policy = CachePolicy::new(
&http::request::Request::builder()
.method("GET")
.uri("/config-test")
.body(())
.unwrap()
.into_parts()
.0,
&response.clone().map(|_| ()),
);
let _result = cache
.put("config-key".to_string(), response, policy, request_url)
.await;
let _get_result = cache.get("config-key").await;
}
}