use http::{Method, Request, Response, StatusCode, Uri};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tower::{Layer, Service, ServiceExt};
use tower_http_cache::backend::memory::InMemoryBackend;
use tower_http_cache::layer::CacheLayer;
use tower_http_cache::refresh::AutoRefreshConfig;
#[tokio::test]
async fn auto_refresh_disabled_by_default() {
    const URL: &str = "http://example.com/test";

    /// Builds a plain GET request for `uri` with an empty body.
    fn get(uri: &str) -> Request<()> {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .body(())
            .unwrap()
    }

    // Only a TTL is configured; `.auto_refresh(..)` is never called.
    let layer = CacheLayer::builder(InMemoryBackend::new(100))
        .ttl(Duration::from_secs(10))
        .build();

    let call_count = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&call_count);
    let handler = tower::service_fn(move |_req: Request<()>| {
        let counter = Arc::clone(&counter);
        async move {
            counter.fetch_add(1, Ordering::Relaxed);
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(http_body_util::Full::from(bytes::Bytes::from(
                    "test response",
                )))
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        }
    });
    let mut service = layer.layer(handler);
    let calls = || call_count.load(Ordering::Relaxed);

    // First request misses the cache and reaches the handler.
    let first = get(URL);
    let _ = service.ready().await.unwrap().call(first).await.unwrap();
    assert_eq!(calls(), 1);

    // Second request for the same URI must not reach the handler again.
    let second = get(URL);
    let _ = service.ready().await.unwrap().call(second).await.unwrap();
    assert_eq!(calls(), 1);

    // After waiting, the handler count must still be unchanged: nothing
    // re-fetched the entry in the background.
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(calls(), 1);
}
#[tokio::test]
async fn auto_refresh_tracks_hits() {
    const URL: &str = "http://example.com/test";

    /// Builds a plain GET request for `uri` with an empty body.
    fn get(uri: &str) -> Request<()> {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .body(())
            .unwrap()
    }

    // Auto-refresh is enabled, but the test never sleeps, so it only observes
    // foreground cache behaviour for a repeatedly requested key.
    let config = AutoRefreshConfig {
        enabled: true,
        min_hits_per_minute: 5.0,
        check_interval: Duration::from_millis(100),
        cleanup_interval: Duration::from_secs(60),
        ..Default::default()
    };
    let layer = CacheLayer::builder(InMemoryBackend::new(100))
        .ttl(Duration::from_secs(10))
        .refresh_before(Duration::from_secs(5))
        .auto_refresh(config)
        .build();

    let call_count = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&call_count);
    let handler = tower::service_fn(move |_req: Request<()>| {
        let counter = Arc::clone(&counter);
        async move {
            counter.fetch_add(1, Ordering::Relaxed);
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(http_body_util::Full::from(bytes::Bytes::from(
                    "test response",
                )))
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        }
    });
    let mut service = layer.layer(handler);
    let calls = || call_count.load(Ordering::Relaxed);

    // Prime the cache with a single upstream call.
    let priming = get(URL);
    let _ = service.ready().await.unwrap().call(priming).await.unwrap();
    assert_eq!(calls(), 1);

    // Ten further hits on the same key must all be answered without the handler.
    for _hit in 0..10 {
        let repeat = get(URL);
        let _ = service.ready().await.unwrap().call(repeat).await.unwrap();
    }
    assert_eq!(calls(), 1);
}
#[tokio::test]
async fn auto_refresh_respects_concurrency_limit() {
    /// Cap on simultaneous background refreshes handed to the layer.
    const MAX_CONCURRENT_REFRESHES: usize = 2;
    /// Highest number of overlapping handler invocations the test tolerates.
    /// The foreground loop below awaits every request before issuing the next
    /// one, so it adds at most one in-flight call on top of the refresh cap.
    const ALLOWED_IN_FLIGHT: usize = MAX_CONCURRENT_REFRESHES + 1;

    let backend = InMemoryBackend::new(100);
    let config = AutoRefreshConfig {
        enabled: true,
        min_hits_per_minute: 1.0,
        check_interval: Duration::from_millis(50),
        max_concurrent_refreshes: MAX_CONCURRENT_REFRESHES,
        cleanup_interval: Duration::from_secs(60),
        ..Default::default()
    };
    // 1s TTL with a 900ms refresh window: entries presumably become eligible
    // for refresh shortly after being cached.
    let layer = CacheLayer::builder(backend)
        .ttl(Duration::from_secs(1))
        .refresh_before(Duration::from_millis(900))
        .auto_refresh(config)
        .build();

    // `in_flight` counts handler calls currently running; `max_in_flight`
    // records the high-water mark observed across the whole test.
    let in_flight = Arc::new(AtomicUsize::new(0));
    let max_in_flight = Arc::new(AtomicUsize::new(0));
    let in_flight_clone = in_flight.clone();
    let max_in_flight_clone = max_in_flight.clone();
    let slow_service = tower::service_fn(move |_req: Request<()>| {
        let in_flight = in_flight_clone.clone();
        let max_in_flight = max_in_flight_clone.clone();
        async move {
            let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            max_in_flight.fetch_max(current, Ordering::SeqCst);
            // Keep each call open long enough for overlapping calls to be observed.
            tokio::time::sleep(Duration::from_millis(100)).await;
            in_flight.fetch_sub(1, Ordering::SeqCst);
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(http_body_util::Full::from(bytes::Bytes::from(
                        "slow response",
                    )))
                    .unwrap(),
            )
        }
    });
    let mut service = layer.layer(slow_service);

    // Two sequential passes over five distinct keys: the first pass populates
    // the cache, the second requests every key again.
    for _pass in 0..2 {
        for i in 0..5 {
            let req = Request::builder()
                .method(Method::GET)
                .uri(format!("http://example.com/test{}", i))
                .body(())
                .unwrap();
            let _ = service.ready().await.unwrap().call(req).await.unwrap();
        }
    }

    // Give the background checker time to fire refreshes for the entries.
    tokio::time::sleep(Duration::from_millis(1500)).await;

    let max = max_in_flight.load(Ordering::SeqCst);
    println!("Max concurrent refreshes: {}", max);
    assert!(
        max <= ALLOWED_IN_FLIGHT,
        "Expected max concurrent <= {}, got {}",
        ALLOWED_IN_FLIGHT,
        max
    );
}
#[tokio::test]
async fn auto_refresh_handles_service_errors_gracefully() {
    let backend = InMemoryBackend::new(100);
    let config = AutoRefreshConfig {
        enabled: true,
        min_hits_per_minute: 1.0,
        check_interval: Duration::from_millis(100),
        cleanup_interval: Duration::from_secs(60),
        ..Default::default()
    };
    let layer = CacheLayer::builder(backend)
        .ttl(Duration::from_secs(1))
        .refresh_before(Duration::from_millis(900))
        .auto_refresh(config)
        .build();

    // The handler succeeds on its very first call and answers 500 on every
    // later call, so any refresh or fresh fetch after priming sees an error.
    let call_count = Arc::new(AtomicUsize::new(0));
    let call_count_clone = call_count.clone();
    let error_service = tower::service_fn(move |_req: Request<()>| {
        let count = call_count_clone.clone();
        async move {
            let current = count.fetch_add(1, Ordering::Relaxed);
            if current == 0 {
                Ok::<_, std::convert::Infallible>(
                    Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::from(bytes::Bytes::from("success")))
                        .unwrap(),
                )
            } else {
                Ok::<_, std::convert::Infallible>(
                    Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(http_body_util::Full::from(bytes::Bytes::from("error")))
                        .unwrap(),
                )
            }
        }
    });
    let mut service = layer.layer(error_service);

    // Prime the cache with the single successful response.
    let req = Request::builder()
        .method(Method::GET)
        .uri("http://example.com/test")
        .body(())
        .unwrap();
    let resp = service.ready().await.unwrap().call(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);

    // An immediate repeat must be served from the cache: the handler would
    // answer 500 if it were called again.
    let req2 = Request::builder()
        .method(Method::GET)
        .uri("http://example.com/test")
        .body(())
        .unwrap();
    let resp2 = service.ready().await.unwrap().call(req2).await.unwrap();
    assert_eq!(resp2.status(), StatusCode::OK);

    // Let the 1s TTL lapse while the background checker runs against the
    // failing handler.
    tokio::time::sleep(Duration::from_millis(1200)).await;
    let count = call_count.load(Ordering::Relaxed);
    println!("Total service calls: {}", count);

    // The layered service must still answer normally after failed refreshes.
    // `/test2` was never cached, so this reaches the handler, which returns
    // 500 for every call after the first.
    let req3 = Request::builder()
        .method(Method::GET)
        .uri("http://example.com/test2")
        .body(())
        .unwrap();
    let resp3 = service
        .ready()
        .await
        .unwrap()
        .call(req3)
        .await
        .expect("service must keep serving requests after failed background refreshes");
    assert_eq!(resp3.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[tokio::test]
async fn auto_refresh_metadata_reconstruction() {
    use tower_http_cache::refresh::RefreshMetadata;

    // A request carrying a query string and one header that is explicitly
    // selected for capture.
    let req = Request::builder()
        .method(Method::GET)
        .uri("http://example.com/test?foo=bar")
        .header("authorization", "Bearer token")
        .body(())
        .unwrap();
    let metadata = RefreshMetadata::from_request_with_headers(&req, &["authorization".to_string()]);

    // The captured metadata mirrors the original request.
    assert_eq!(metadata.method, Method::GET);
    assert_eq!(metadata.uri.path(), "/test");
    assert_eq!(metadata.uri.query(), Some("foo=bar"));
    assert_eq!(metadata.headers.len(), 1);
    assert_eq!(metadata.headers[0].0, "authorization");

    // Rebuilding must yield an equivalent request: same method, path, query
    // and replayed header value.
    let reconstructed = metadata
        .try_into_request()
        .expect("metadata captured from a valid request should rebuild into a request");
    assert_eq!(reconstructed.method(), Method::GET);
    assert_eq!(reconstructed.uri().path(), "/test");
    assert_eq!(reconstructed.uri().query(), Some("foo=bar"));
    let authorization = reconstructed
        .headers()
        .get("authorization")
        .expect("authorization header should be replayed");
    assert_eq!(authorization, "Bearer token");
}
#[tokio::test]
async fn auto_refresh_config_validation() {
    // The untouched default configuration must pass validation.
    assert!(AutoRefreshConfig::default().validate().is_ok());

    // Each entry breaks exactly one constraint on top of the defaults:
    // a negative hit-rate threshold, a zero concurrency cap, a zero check interval.
    let rejected = [
        AutoRefreshConfig {
            min_hits_per_minute: -1.0,
            ..Default::default()
        },
        AutoRefreshConfig {
            max_concurrent_refreshes: 0,
            ..Default::default()
        },
        AutoRefreshConfig {
            check_interval: Duration::ZERO,
            ..Default::default()
        },
    ];
    for config in rejected {
        assert!(config.validate().is_err());
    }
}
#[tokio::test]
async fn auto_refresh_cleanup_stale_tracking() {
    use tower_http_cache::refresh::AccessTracker;

    let tracker = AccessTracker::new(AutoRefreshConfig {
        hit_rate_window: Duration::from_secs(60),
        ..Default::default()
    });

    // Record one hit on each of two distinct keys.
    for key in ["key1", "key2"] {
        tracker.record_hit(key);
    }
    assert_eq!(tracker.tracked_keys(), 2);

    // An hour-long age limit must keep both freshly recorded keys.
    tracker.cleanup_stale(Duration::from_secs(3600));
    assert_eq!(tracker.tracked_keys(), 2);

    // A zero age limit is exercised, but only an upper bound is asserted.
    // NOTE(review): `<= 2` always holds here, so this does not prove eviction;
    // presumably kept loose because eviction at zero age may be timing-dependent.
    // Consider a short sleep plus an exact assertion once the comparison used by
    // `cleanup_stale` is confirmed.
    tracker.cleanup_stale(Duration::ZERO);
    assert!(tracker.tracked_keys() <= 2);
}
#[tokio::test]
async fn auto_refresh_hit_rate_calculation() {
    use tower_http_cache::refresh::AccessTracker;

    let tracker = AccessTracker::new(AutoRefreshConfig {
        min_hits_per_minute: 10.0,
        hit_rate_window: Duration::from_secs(1),
        ..Default::default()
    });

    // Five hits spaced 10ms apart, all well inside the 1s window.
    let spacing = Duration::from_millis(10);
    for _hit in 0..5 {
        tracker.record_hit("hot_key");
        tokio::time::sleep(spacing).await;
    }

    let rate = tracker.hits_per_minute("hot_key");
    println!("Hit rate: {} hits/min", rate);
    assert!(rate > 0.0);
}
#[tokio::test]
async fn auto_refresh_with_different_methods() {
let backend = InMemoryBackend::new(100);
let config = AutoRefreshConfig::enabled(5.0);
let layer = CacheLayer::builder(backend)
.ttl(Duration::from_secs(10))
.auto_refresh(config)
.build();
let call_count = Arc::new(AtomicUsize::new(0));
let handler = tower::service_fn({
let count = call_count.clone();
move |_req: Request<()>| {
let count = count.clone();
async move {
count.fetch_add(1, Ordering::Relaxed);
Ok::<_, std::convert::Infallible>(
Response::builder()
.status(StatusCode::OK)
.body(http_body_util::Full::from(bytes::Bytes::from(
"test response",
)))
.unwrap(),
)
}
}
});
let mut service = layer.layer(handler);
let get_req = Request::builder()
.method(Method::GET)
.uri("http://example.com/test")
.body(())
.unwrap();
let _ = service.ready().await.unwrap().call(get_req).await.unwrap();
assert_eq!(call_count.load(Ordering::Relaxed), 1);
let post_req = Request::builder()
.method(Method::POST)
.uri("http://example.com/test")
.body(())
.unwrap();
let _ = service.ready().await.unwrap().call(post_req).await.unwrap();
assert_eq!(call_count.load(Ordering::Relaxed), 2);
}
#[tokio::test]
async fn auto_refresh_manager_lifecycle() {
    use tower_http_cache::refresh::{RefreshCallback, RefreshManager, RefreshMetadata};

    /// Refresh callback that only counts how many times it is invoked.
    struct CountingCallback {
        invocations: Arc<AtomicUsize>,
    }

    impl RefreshCallback for CountingCallback {
        fn refresh(
            &self,
            _key: String,
            _metadata: RefreshMetadata,
        ) -> tower_http_cache::refresh::RefreshFuture {
            let invocations = Arc::clone(&self.invocations);
            Box::pin(async move {
                invocations.fetch_add(1, Ordering::Relaxed);
                Ok(())
            })
        }
    }

    let manager = RefreshManager::new(AutoRefreshConfig {
        enabled: true,
        check_interval: Duration::from_millis(50),
        min_hits_per_minute: 0.1,
        cleanup_interval: Duration::from_secs(60),
        ..Default::default()
    });

    // Starting the manager with the counting callback must succeed.
    let call_count = Arc::new(AtomicUsize::new(0));
    let callback = Arc::new(CountingCallback {
        invocations: Arc::clone(&call_count),
    });
    let started = manager.start(callback).await;
    assert!(started.is_ok());

    // Register replay metadata for one key and push its hit rate well above
    // the 0.1 hits/min threshold.
    manager.store_metadata(
        "test_key".to_string(),
        RefreshMetadata {
            method: Method::GET,
            uri: Uri::from_static("http://example.com/test"),
            headers: Vec::new(),
        },
    );
    for _hit in 0..10 {
        manager.tracker().record_hit("test_key");
    }

    // Let a few check intervals elapse, then stop the manager.
    tokio::time::sleep(Duration::from_millis(200)).await;
    manager.shutdown().await;

    // The invocation count is only reported, not asserted.
    println!(
        "Callback was called {} times",
        call_count.load(Ordering::Relaxed)
    );
}
#[tokio::test]
async fn auto_refresh_only_refreshes_in_window() {
    const URL: &str = "http://example.com/test";

    /// Builds a plain GET request for `uri` with an empty body.
    fn get(uri: &str) -> Request<()> {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .body(())
            .unwrap()
    }

    let config = AutoRefreshConfig {
        enabled: true,
        min_hits_per_minute: 1.0,
        check_interval: Duration::from_millis(100),
        ..Default::default()
    };
    // 10s TTL with a 1s refresh window: within the test's short runtime the
    // entry stays far from its expiry.
    let layer = CacheLayer::builder(InMemoryBackend::new(100))
        .ttl(Duration::from_secs(10))
        .refresh_before(Duration::from_secs(1))
        .auto_refresh(config)
        .build();

    let call_count = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&call_count);
    let handler = tower::service_fn(move |_req: Request<()>| {
        let counter = Arc::clone(&counter);
        async move {
            counter.fetch_add(1, Ordering::Relaxed);
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(http_body_util::Full::from(bytes::Bytes::from(
                    "test response",
                )))
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        }
    });
    let mut service = layer.layer(handler);
    let calls = || call_count.load(Ordering::Relaxed);

    // Prime the cache.
    let priming = get(URL);
    let _ = service.ready().await.unwrap().call(priming).await.unwrap();
    assert_eq!(calls(), 1);

    // Build up hits on the key so it is "hot".
    for _hit in 0..5 {
        let repeat = get(URL);
        let _ = service.ready().await.unwrap().call(repeat).await.unwrap();
    }

    // Several check intervals pass, yet the handler must not be called again
    // while the entry is outside its refresh window.
    tokio::time::sleep(Duration::from_millis(500)).await;
    assert_eq!(calls(), 1);
}