use http::{Method, Request, Response, StatusCode};
use http_body_util::Full;
use std::time::Duration;
use tower::{Service, ServiceBuilder, ServiceExt};
use tower_http_cache::prelude::*;
use tower_http_cache::{
backend::multi_tier::{MultiTierBackend, PromotionStrategy},
logging::{CacheEvent, CacheEventType, MLLoggingConfig},
request_id::RequestId,
tags::TagPolicy,
};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Emit crate-level debug traces only when the `tracing` feature is on.
    #[cfg(feature = "tracing")]
    tracing_subscriber::fmt()
        .with_env_filter("tower_http_cache=debug")
        .init();

    println!("=== tower-http-cache v0.3.0 Features Demo ===\n");

    // Walk through each v0.3.0 feature demo in order; the `?` operator
    // stops the tour at the first failing demo.
    println!("1. Cache Tags & Invalidation Groups");
    demo_cache_tags().await?;

    println!("\n2. Multi-Tier Caching (L1 + L2)");
    demo_multi_tier().await?;

    println!("\n3. ML-Ready Structured Logging");
    demo_ml_logging().await?;

    println!("\n4. Request ID Propagation");
    demo_request_id().await?;

    println!("\n5. All Features Combined");
    demo_combined_features().await?;

    Ok(())
}
/// Demonstrates tagging cached entries and invalidating a whole group by tag.
async fn demo_cache_tags() -> Result<(), Box<dyn std::error::Error>> {
    let store = InMemoryBackend::new(1000);

    // Allow up to five tags on each cached entry.
    let tag_rules = TagPolicy::new()
        .with_enabled(true)
        .with_max_tags_per_entry(5);

    // GET /users/<id> responses are tagged both per-user and with the
    // shared "users" group tag so they can be invalidated together.
    let policy = CachePolicy::default()
        .with_tag_policy(tag_rules)
        .with_tag_extractor(|method, uri| {
            if method == &Method::GET {
                if let Some(rest) = uri.path().strip_prefix("/users/") {
                    let id = rest.split('/').next().unwrap_or("");
                    return vec![format!("user:{}", id), "users".to_string()];
                }
            }
            vec![]
        });

    let layer = CacheLayer::builder(store.clone())
        .ttl(Duration::from_secs(300))
        .policy(policy)
        .build();

    let service = ServiceBuilder::new()
        .layer(layer)
        .service(tower::service_fn(|_req: Request<()>| async {
            Ok::<_, std::convert::Infallible>(Response::new(Full::from("User data")))
        }));

    // Warm the cache with two distinct user entries.
    let first = Request::builder().uri("/users/123").body(()).unwrap();
    let second = Request::builder().uri("/users/456").body(()).unwrap();
    let _ = service
        .clone()
        .oneshot(first)
        .await
        .map_err(|e| format!("{}", e))?;
    let _ = service
        .clone()
        .oneshot(second)
        .await
        .map_err(|e| format!("{}", e))?;

    println!(" ✓ Cached 2 user entries with tags");
    println!(" ✓ Tags: user:123, user:456, users");

    // Dropping the shared tag removes every entry that carries it.
    let count = store.invalidate_by_tag("users").await?;
    println!(" ✓ Invalidated {} entries by tag 'users'", count);
    Ok(())
}
/// Demonstrates an L1/L2 cache hierarchy with hit-count based promotion.
async fn demo_multi_tier() -> Result<(), Box<dyn std::error::Error>> {
    // Small, fast front tier backed by a much larger second tier.
    let fast_tier = InMemoryBackend::new(100);
    let big_tier = InMemoryBackend::new(10_000);

    let backend = MultiTierBackend::builder()
        .l1(fast_tier.clone())
        .l2(big_tier.clone())
        // Promote an entry back into L1 after three hits served from L2.
        .promotion_strategy(PromotionStrategy::HitCount { threshold: 3 })
        .write_through(true)
        .build();

    let mut service = ServiceBuilder::new()
        .layer(
            CacheLayer::builder(backend.clone())
                .ttl(Duration::from_secs(300))
                .build(),
        )
        .service(tower::service_fn(|_req: Request<()>| async {
            Ok::<_, std::convert::Infallible>(Response::new(Full::from("Data")))
        }));

    let req = Request::builder().uri("/api/hot-data").body(()).unwrap();

    let _ = service
        .ready()
        .await
        .map_err(|e| format!("{}", e))?
        .call(req.clone())
        .await
        .map_err(|e| format!("{}", e))?;
    println!(" ✓ First request: stored in L1 and L2");

    // Drop the entry from L1 so the next requests must hit L2.
    fast_tier.invalidate("/api/hot-data").await?;
    println!(" ✓ Simulated L1 eviction");

    for attempt in 1..=3 {
        let _ = service
            .ready()
            .await
            .map_err(|e| format!("{}", e))?
            .call(req.clone())
            .await
            .map_err(|e| format!("{}", e))?;
        println!(" ✓ Request {}: L2 hit", attempt);
    }

    // Short pause — presumably promotion happens asynchronously, so give it
    // a moment before inspecting L1 (NOTE(review): confirm against backend).
    tokio::time::sleep(Duration::from_millis(50)).await;
    if fast_tier.get("/api/hot-data").await?.is_some() {
        println!(" ✓ Hot data promoted back to L1 after threshold");
    }
    Ok(())
}
/// Demonstrates ML-ready structured logging of cache events.
///
/// Builds a cache layer with ML logging wired into its policy, then
/// constructs and logs a synthetic cache-hit event carrying the full
/// metadata set (latency, size, TTL, tags, tier).
async fn demo_ml_logging() -> Result<(), Box<dyn std::error::Error>> {
    // Log every event (sample rate 1.0) and hash cache keys before logging.
    let ml_config = MLLoggingConfig::new()
        .with_enabled(true)
        .with_sample_rate(1.0)
        .with_hash_keys(true);

    let backend = InMemoryBackend::new(1000);
    // The layer is built only to show how ML logging plugs into a policy;
    // it is never driven here. The underscore prefix documents that and
    // silences the unused-variable warning the original binding produced.
    let _layer = CacheLayer::builder(backend)
        .policy(CachePolicy::default().with_ml_logging(ml_config))
        .build();

    let request_id = RequestId::new();
    println!(" ✓ Request ID: {}", request_id);

    // Hand-build a fully populated hit event to show the available fields.
    let event = CacheEvent::new(
        CacheEventType::Hit,
        request_id.clone(),
        "/api/users/123".to_string(),
    )
    .with_hit(true)
    .with_latency(Duration::from_micros(150))
    .with_size(1024)
    .with_ttl(Duration::from_secs(300))
    .with_tags(vec!["user:123".to_string(), "users".to_string()])
    .with_tier("l1");

    event.log(&MLLoggingConfig::new().with_enabled(true));
    println!(" ✓ Logged cache hit event with full metadata");
    println!(" ✓ Event includes: latency, size, TTL, tags, tier");
    Ok(())
}
/// Demonstrates request-ID propagation via the `x-request-id` header.
async fn demo_request_id() -> Result<(), Box<dyn std::error::Error>> {
    // Attach a freshly generated ID to an outgoing request.
    let id = RequestId::new();
    let mut tagged = Request::builder().uri("/api/data").body(()).unwrap();
    tagged
        .headers_mut()
        .insert("x-request-id", id.as_str().parse().unwrap());
    println!(" ✓ Request with X-Request-ID: {}", id);

    // Round-trip: extracting from the headers must yield the same ID.
    let recovered = RequestId::from_request(&tagged);
    assert_eq!(recovered.as_str(), id.as_str());
    println!(" ✓ Request ID successfully extracted from headers");

    // When the header is absent, an ID is generated automatically.
    let bare = Request::builder().uri("/api/data").body(()).unwrap();
    let generated = RequestId::from_request(&bare);
    println!(" ✓ Auto-generated request ID: {}", generated);
    Ok(())
}
/// Demonstrates every v0.3.0 feature at once: multi-tier caching, cache
/// tags, ML logging, and request-ID propagation through one service call.
async fn demo_combined_features() -> Result<(), Box<dyn std::error::Error>> {
    let l1 = InMemoryBackend::new(100);
    let l2 = InMemoryBackend::new(10_000);
    let backend = MultiTierBackend::builder()
        .l1(l1)
        .l2(l2)
        .promotion_threshold(2)
        .build();

    // One policy carrying all features: TTL, tag extraction, ML logging.
    let layer = CacheLayer::builder(backend.clone())
        .policy(
            CachePolicy::default()
                .with_ttl(Duration::from_secs(300))
                .with_tag_policy(
                    TagPolicy::new()
                        .with_enabled(true)
                        .with_max_tags_per_entry(10),
                )
                // Tag every cached entry with its request path.
                .with_tag_extractor(|_method, uri| vec![format!("path:{}", uri.path())])
                .with_ml_logging(
                    MLLoggingConfig::new()
                        .with_enabled(true)
                        .with_sample_rate(1.0),
                ),
        )
        .build();
    let mut service = ServiceBuilder::new()
        .layer(layer)
        .service(tower::service_fn(|_req: Request<()>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Full::from("Combined features response"))
                    .unwrap(),
            )
        }));

    // Propagate an explicit request ID through the cached service call.
    let request_id = RequestId::new();
    let mut req = Request::builder().uri("/api/combined").body(()).unwrap();
    req.headers_mut()
        .insert("x-request-id", request_id.as_str().parse().unwrap());
    let _resp = service
        .ready()
        .await
        .map_err(|e| format!("{}", e))?
        .call(req)
        .await
        .map_err(|e| format!("{}", e))?;
    println!(" ✓ Request processed with all features:");
    println!(" - Request ID: {}", request_id);
    println!(" - Multi-tier caching (L1 + L2)");
    println!(" - Cache tags: path:/api/combined");
    println!(" - ML logging enabled");

    let tags = backend.list_tags().await?;
    println!(" ✓ Tags in cache: {:?}", tags);

    // Fetch tier stats only to show the API exists; the value is
    // intentionally unused (underscore fixes the original's unused-variable
    // warning on `stats`).
    let _stats = backend.stats();
    println!(" ✓ Tier stats available for monitoring");
    Ok(())
}