#![cfg(feature = "streaming")]
use bytes::Bytes;
use http::{Request, Response, StatusCode};
use http_body_util::{BodyExt, Full};
use http_cache::{CACacheManager, StreamingManager};
use http_cache_tower::{HttpCacheLayer, HttpCacheStreamingLayer};
use std::alloc::{GlobalAlloc, Layout, System};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tower::{Layer, Service, ServiceExt};
/// State for the tracking global allocator installed below: a running
/// counter of the net bytes of live heap allocations in this process.
struct MemoryTracker {
    // Net outstanding allocation size in bytes. Incremented in `alloc`,
    // decremented in `dealloc`, read/zeroed by the measurement code.
    allocations: AtomicUsize,
}
impl MemoryTracker {
const fn new() -> Self {
Self { allocations: AtomicUsize::new(0) }
}
fn current_usage(&self) -> usize {
self.allocations.load(Ordering::Relaxed)
}
fn reset(&self) {
self.allocations.store(0, Ordering::Relaxed);
}
}
unsafe impl GlobalAlloc for MemoryTracker {
    /// Delegates to the system allocator and, on success, adds the
    /// requested size to the live-allocation counter.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let ptr = System.alloc(layout);
        // Only count allocations that actually succeeded.
        if !ptr.is_null() {
            self.allocations.fetch_add(layout.size(), Ordering::Relaxed);
        }
        ptr
    }

    /// Delegates to the system allocator and subtracts the freed size,
    /// saturating at zero.
    ///
    /// Saturation is essential: `reset()` zeroes the counter while blocks
    /// allocated before the reset are still live, so a plain `fetch_sub`
    /// would wrap the counter around to a huge bogus value when those
    /// blocks are later freed, corrupting every subsequent reading.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        System.dealloc(ptr, layout);
        // `fetch_update` with an always-`Some` closure cannot fail; the
        // Result is discarded deliberately.
        let _ = self.allocations.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |live| Some(live.saturating_sub(layout.size())),
        );
    }
}
// Install the tracker as the process-wide allocator so every heap
// allocation made by this binary is counted.
#[global_allocator]
static MEMORY_TRACKER: MemoryTracker = MemoryTracker::new();
/// Test origin service that answers every request with a fixed-size,
/// cacheable payload of `size` bytes (all `b'X'`).
#[derive(Clone)]
struct LargeResponseService {
    // Payload size in bytes for every generated response.
    size: usize,
}
impl LargeResponseService {
fn new(size: usize) -> Self {
Self { size }
}
}
impl Service<Request<Full<Bytes>>> for LargeResponseService {
    type Response = Response<Full<Bytes>>;
    type Error = Box<dyn std::error::Error + Send + Sync>;
    type Future = Pin<
        Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Always ready: responses are generated synchronously in memory.
    fn poll_ready(
        &mut self,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Ignores the request and produces a cacheable response of
    /// `self.size` bytes of `b'X'` with explicit length/type headers.
    fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
        let payload_len = self.size;
        Box::pin(async move {
            let payload = Bytes::from(vec![b'X'; payload_len]);
            Response::builder()
                .status(StatusCode::OK)
                .header("cache-control", "max-age=3600, public")
                .header("content-type", "application/octet-stream")
                .header("content-length", payload_len.to_string())
                .body(Full::new(payload))
                .map_err(|e| {
                    Box::new(e) as Box<dyn std::error::Error + Send + Sync>
                })
        })
    }
}
/// Measures allocator deltas while serving a cache hit of `payload_size`
/// bytes through either the streaming or the buffered cache layer.
///
/// One warm-up request populates the cache; the tracker is then reset and
/// a second (cached) request is issued. Returns
/// `(response_delta, peak_delta, final_delta)`:
/// - bytes live right after the cached response is returned,
/// - the peak observed while draining the body frame by frame,
/// - bytes live after the body is fully consumed.
///
/// All deltas use `saturating_sub`: frees of allocations made before the
/// reset can legitimately push the tracker below the recorded baseline,
/// and a plain `-` would panic in debug builds or wrap in release builds.
async fn measure_cache_hit_memory_usage(
    payload_size: usize,
    is_streaming: bool,
) -> (usize, usize, usize) {
    if is_streaming {
        // Streaming path: cached bodies are backed by temp files.
        let file_cache_manager = StreamingManager::with_temp_dir(1000)
            .await
            .expect("Failed to create streaming manager");
        let streaming_layer = HttpCacheStreamingLayer::new(file_cache_manager);
        let service = LargeResponseService::new(payload_size);
        let cached_service = streaming_layer.layer(service);
        // Warm-up: populate the cache entry. The body must be drained so
        // the entry is completely written; the result itself is ignored.
        let request1 = Request::builder()
            .uri("https://example.com/cache-hit-test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let _ = cached_service
            .clone()
            .oneshot(request1)
            .await
            .unwrap()
            .into_body()
            .collect()
            .await;
        // Establish the measurement baseline after warm-up allocations.
        MEMORY_TRACKER.reset();
        let initial_memory = MEMORY_TRACKER.current_usage();
        let request2 = Request::builder()
            .uri("https://example.com/cache-hit-test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let response = cached_service.oneshot(request2).await.unwrap();
        let peak_after_response = MEMORY_TRACKER.current_usage();
        let body = response.into_body();
        let mut peak_during_streaming = peak_after_response;
        let mut body_stream = std::pin::pin!(body);
        // Drain the body one frame at a time, sampling the tracker after
        // each data frame to capture the in-flight peak.
        while let Some(frame_result) = body_stream.frame().await {
            let frame = frame_result.unwrap();
            if let Some(_chunk) = frame.data_ref() {
                let current_memory = MEMORY_TRACKER.current_usage();
                peak_during_streaming =
                    peak_during_streaming.max(current_memory);
            }
        }
        let peak_after_consumption = MEMORY_TRACKER.current_usage();
        (
            peak_after_response.saturating_sub(initial_memory),
            peak_during_streaming.saturating_sub(initial_memory),
            peak_after_consumption.saturating_sub(initial_memory),
        )
    } else {
        // Buffered path: CACacheManager stores complete bodies, so a
        // cache hit materializes the whole payload in memory.
        let temp_dir = tempfile::tempdir().unwrap();
        let cache_manager =
            CACacheManager::new(temp_dir.path().to_path_buf(), false);
        let cache_layer = HttpCacheLayer::new(cache_manager);
        let service = LargeResponseService::new(payload_size);
        let cached_service = cache_layer.layer(service);
        // Warm-up: populate the cache entry (see streaming branch).
        let request1 = Request::builder()
            .uri("https://example.com/cache-hit-test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let _ = cached_service
            .clone()
            .oneshot(request1)
            .await
            .unwrap()
            .into_body()
            .collect()
            .await;
        // Establish the measurement baseline after warm-up allocations.
        MEMORY_TRACKER.reset();
        let initial_memory = MEMORY_TRACKER.current_usage();
        let request2 = Request::builder()
            .uri("https://example.com/cache-hit-test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let response = cached_service.oneshot(request2).await.unwrap();
        let peak_after_response = MEMORY_TRACKER.current_usage();
        let body = response.into_body();
        let mut peak_during_streaming = peak_after_response;
        let mut body_stream = std::pin::pin!(body);
        // Drain the body, sampling the tracker after each data frame.
        while let Some(frame_result) = body_stream.frame().await {
            let frame = frame_result.unwrap();
            if let Some(_chunk) = frame.data_ref() {
                let current_memory = MEMORY_TRACKER.current_usage();
                peak_during_streaming =
                    peak_during_streaming.max(current_memory);
            }
        }
        let peak_after_consumption = MEMORY_TRACKER.current_usage();
        (
            peak_after_response.saturating_sub(initial_memory),
            peak_during_streaming.saturating_sub(initial_memory),
            peak_after_consumption.saturating_sub(initial_memory),
        )
    }
}
/// Relative reduction of `new` versus `old`, in percent.
///
/// Returns `None` when `old` is zero (the ratio is undefined) or when
/// `new >= old` (no savings), so callers can use `if let` as the guard.
fn savings_percent(old: usize, new: usize) -> Option<f64> {
    (old > 0 && new < old).then(|| (old - new) as f64 / old as f64 * 100.0)
}

/// Runs the buffered-vs-streaming cache-hit memory comparison across a
/// range of payload sizes and prints per-size and overall summaries.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Memory Usage Analysis: Buffered vs Streaming Cache");
    println!("==================================================");
    println!("This analysis measures memory efficiency differences between");
    println!("traditional buffered caching and file-based streaming caching.");
    println!("Measurements are taken during cache hits to compare memory usage patterns.\n");
    // Payload sizes under test: 100 KB, 1 MB, 5 MB, 10 MB.
    let payload_sizes = vec![
        100 * 1024, 1024 * 1024, 5 * 1024 * 1024, 10 * 1024 * 1024, ];
    let mut overall_buffered_peak = 0;
    let mut overall_streaming_peak = 0;
    for size in &payload_sizes {
        println!("Testing cache hits with {}KB payload:", size / 1024);
        println!("{}", "=".repeat(60));
        let (buffered_response, buffered_peak, buffered_final) =
            measure_cache_hit_memory_usage(*size, false).await;
        println!("Buffered Cache Hit ({}KB payload):", size / 1024);
        println!(" Response memory delta: {buffered_response} bytes");
        println!(" Peak memory delta: {buffered_peak} bytes");
        println!(" Final memory delta: {buffered_final} bytes");
        let (streaming_response, streaming_peak, streaming_final) =
            measure_cache_hit_memory_usage(*size, true).await;
        println!("\nStreaming Cache Hit ({}KB payload):", size / 1024);
        println!(" Response memory delta: {streaming_response} bytes");
        println!(" Peak memory delta: {streaming_peak} bytes");
        println!(" Final memory delta: {streaming_final} bytes");
        println!("\nCache hit memory comparison:");
        if let Some(response_savings) =
            savings_percent(buffered_response, streaming_response)
        {
            println!(
                " Response memory savings: {response_savings:.1}% ({buffered_response} vs {streaming_response} bytes)"
            );
        }
        if let Some(peak_savings) =
            savings_percent(buffered_peak, streaming_peak)
        {
            println!(
                " Peak memory savings: {peak_savings:.1}% ({buffered_peak} vs {streaming_peak} bytes)"
            );
        } else if streaming_peak > buffered_peak && buffered_peak > 0 {
            // `buffered_peak > 0` guard added: the unguarded division
            // previously printed "inf%" when the buffered baseline was 0.
            let peak_increase = (streaming_peak - buffered_peak) as f64
                / buffered_peak as f64
                * 100.0;
            println!(
                " Peak memory increase: {peak_increase:.1}% ({buffered_peak} vs {streaming_peak} bytes)"
            );
        }
        if let Some(final_savings) =
            savings_percent(buffered_final, streaming_final)
        {
            println!(
                " Final memory savings: {final_savings:.1}% ({buffered_final} vs {streaming_final} bytes)"
            );
        }
        println!(
            " Absolute memory difference: {} bytes",
            (buffered_peak as i64 - streaming_peak as i64).abs()
        );
        // Track the worst-case peak across all payload sizes.
        overall_buffered_peak = overall_buffered_peak.max(buffered_peak);
        overall_streaming_peak = overall_streaming_peak.max(streaming_peak);
        println!("\n");
    }
    println!("Overall Analysis Summary:");
    println!("========================");
    println!("Max buffered peak memory: {overall_buffered_peak} bytes");
    println!("Max streaming peak memory: {overall_streaming_peak} bytes");
    if let Some(overall_savings) =
        savings_percent(overall_buffered_peak, overall_streaming_peak)
    {
        println!("Overall memory savings: {overall_savings:.1}%");
    }
    Ok(())
}