use std::collections::hash_map::DefaultHasher;
use std::future::Future;
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::response::IntoResponse;
use futures::stream::StreamExt as _;
use http::header::{
CACHE_CONTROL, CONTENT_LOCATION, DATE, ETAG, EXPIRES, IF_MODIFIED_SINCE, IF_NONE_MATCH,
LAST_MODIFIED, SET_COOKIE, VARY,
};
use http::{HeaderMap, HeaderValue, Response, StatusCode};
use http_body_util::BodyExt;
use sha2::Digest as _;
use tower::{Layer, Service};
/// An HTTP entity tag: an opaque validator string plus a weak/strong flag.
///
/// The tag is stored without surrounding quotes and without the `W/` prefix;
/// both are added when rendering via [`ETag::header_value`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ETag {
    tag: String,
    weak: bool,
}

impl ETag {
    /// Creates a strong entity tag from the raw (unquoted) tag text.
    #[must_use]
    pub fn strong(tag: impl Into<String>) -> Self {
        Self {
            tag: tag.into(),
            weak: false,
        }
    }

    /// Creates a weak entity tag (rendered with a `W/` prefix) from the raw tag text.
    #[must_use]
    pub fn weak(tag: impl Into<String>) -> Self {
        Self {
            tag: tag.into(),
            weak: true,
        }
    }

    /// Returns the raw tag text, without quotes or `W/` prefix.
    #[must_use]
    pub fn tag(&self) -> &str {
        &self.tag
    }

    /// Returns `true` if this is a weak entity tag.
    #[must_use]
    pub const fn is_weak(&self) -> bool {
        self.weak
    }

    /// Renders the tag as an `ETag` header value: `"tag"` or `W/"tag"`.
    ///
    /// If the tag contains bytes that are not permitted in a header value
    /// (e.g. control characters), an empty header value is returned instead.
    #[must_use]
    pub fn header_value(&self) -> HeaderValue {
        let formatted = if self.weak {
            format!("W/\"{}\"", self.tag)
        } else {
            format!("\"{}\"", self.tag)
        };
        HeaderValue::from_str(&formatted).unwrap_or_else(|_| HeaderValue::from_static(""))
    }

    /// Weakly compares this tag against an `If-None-Match` field value.
    ///
    /// `*` matches anything. Otherwise the value is scanned as a
    /// comma-separated list of entity-tags. A `W/` prefix is ignored, because
    /// `If-None-Match` uses weak comparison. A quoted tag extends to its
    /// closing quote, so commas inside quotes do not split it. Unquoted
    /// members from non-conforming clients are still accepted and end at the
    /// next comma. Empty list members (e.g. from `"a",,`) are skipped.
    fn matches_if_none_match(&self, if_none_match: &str) -> bool {
        let if_none_match = if_none_match.trim();
        if if_none_match == "*" {
            return true;
        }
        let mut rest = if_none_match;
        loop {
            // Skip separators and optional whitespace between list members.
            rest = rest.trim_start_matches(|c: char| c == ',' || c.is_ascii_whitespace());
            if rest.is_empty() {
                return false;
            }
            let member = rest.strip_prefix("W/").unwrap_or(rest);
            let (candidate, remainder) = match member.strip_prefix('"') {
                // Quoted entity-tag: everything up to the closing quote, commas included.
                Some(quoted) => match quoted.find('"') {
                    Some(end) => (&quoted[..end], &quoted[end + 1..]),
                    // Unterminated quote: treat the remainder as the tag.
                    None => (quoted, ""),
                },
                // Lenient path for unquoted members: they end at the next comma.
                None => {
                    let end = member.find(',').unwrap_or(member.len());
                    (member[..end].trim().trim_matches('"'), &member[end..])
                }
            };
            if candidate == self.tag {
                return true;
            }
            rest = remainder;
        }
    }
}
/// Conversion of a value into an [`ETag`] validator.
///
/// The string and integer impls below hash their input with SHA-256 and
/// produce strong tags, so equal inputs always yield equal tags.
pub trait IntoETag {
    /// Consumes `self` and returns the corresponding entity tag.
    fn into_etag(self) -> ETag;
}

impl IntoETag for ETag {
    fn into_etag(self) -> ETag {
        self
    }
}

impl IntoETag for String {
    fn into_etag(self) -> ETag {
        // Delegate to the `&str` impl so owned and borrowed strings agree.
        self.as_str().into_etag()
    }
}

impl IntoETag for &str {
    fn into_etag(self) -> ETag {
        let digest = sha256_hex(self.as_bytes());
        ETag::strong(digest)
    }
}

impl IntoETag for i64 {
    fn into_etag(self) -> ETag {
        let big_endian = self.to_be_bytes();
        ETag::strong(sha256_hex(&big_endian))
    }
}

impl IntoETag for i32 {
    fn into_etag(self) -> ETag {
        // Widen first so that e.g. `7_i32` and `7_i64` share a tag.
        let widened: i64 = self.into();
        widened.into_etag()
    }
}

impl IntoETag for (chrono::NaiveDateTime, i64) {
    /// Hashes the timestamp (seconds and sub-second nanos, interpreted as UTC)
    /// together with the version number.
    fn into_etag(self) -> ETag {
        let (updated_at, version) = self;
        let instant = updated_at.and_utc();
        let mut state = sha2::Sha256::new();
        state.update(instant.timestamp().to_be_bytes());
        state.update(instant.timestamp_subsec_nanos().to_be_bytes());
        state.update(version.to_be_bytes());
        ETag::strong(hex_lower(state.finalize()))
    }
}

impl IntoETag for (chrono::NaiveDateTime, i32) {
    fn into_etag(self) -> ETag {
        let (updated_at, version) = self;
        (updated_at, i64::from(version)).into_etag()
    }
}
/// Builds a weak [`ETag`] from any [`Hash`] value, rendered as 16 hex digits.
///
/// Uses [`DefaultHasher`], whose algorithm std does not guarantee to stay the
/// same across Rust releases. Tags are deterministic within one build but may
/// change after a toolchain upgrade, which only costs clients a cache miss.
#[must_use]
pub fn hash_etag<T: Hash>(value: &T) -> ETag {
    let fingerprint = {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    };
    ETag::weak(format!("{fingerprint:016x}"))
}
/// Returns the lowercase hexadecimal SHA-256 digest of `bytes`.
fn sha256_hex(bytes: &[u8]) -> String {
    let digest = sha2::Sha256::digest(bytes);
    hex_lower(digest)
}
/// Encodes `bytes` as lowercase hexadecimal: two characters per byte, high nibble first.
fn hex_lower(bytes: impl AsRef<[u8]>) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let input = bytes.as_ref();
    let mut encoded = String::with_capacity(input.len() * 2);
    for &byte in input {
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
#[must_use = "call `.or(response)` to resolve the conditional-GET result"]
pub struct FreshWhen {
etag: ETag,
last_modified: Option<chrono::DateTime<chrono::Utc>>,
is_fresh: bool,
if_none_match_present: bool,
raw_if_modified_since: Option<String>,
}
impl FreshWhen {
#[must_use]
pub const fn is_fresh(&self) -> bool {
self.is_fresh
}
pub fn last_modified(mut self, dt: impl Into<Option<chrono::DateTime<chrono::Utc>>>) -> Self {
self.last_modified = dt.into();
if !self.if_none_match_present
&& let Some(ref ims_str) = self.raw_if_modified_since
&& let Some(lm) = self.last_modified
{
let lm_secs = chrono::DateTime::from_timestamp(lm.timestamp(), 0).unwrap_or(lm);
self.is_fresh = parse_http_date(ims_str)
.is_some_and(|parsed| std::time::SystemTime::from(lm_secs) <= parsed);
}
self
}
pub fn or(self, response: impl IntoResponse) -> impl IntoResponse {
let r = response.into_response();
if self.is_fresh {
let mut not_modified = not_modified_response(&self.etag, self.last_modified);
for name in [CACHE_CONTROL, VARY, CONTENT_LOCATION, DATE, EXPIRES] {
for v in r.headers().get_all(&name) {
not_modified.headers_mut().append(name.clone(), v.clone());
}
}
if self.last_modified.is_none() {
for v in r.headers().get_all(LAST_MODIFIED) {
not_modified.headers_mut().append(LAST_MODIFIED, v.clone());
}
}
not_modified
} else {
let mut r = r;
r.headers_mut().insert(ETAG, self.etag.header_value());
if let Some(lm) = self.last_modified
&& let Ok(v) = HeaderValue::from_str(&http_date(lm))
{
r.headers_mut().insert(LAST_MODIFIED, v);
}
r
}
}
pub fn or_else<R: IntoResponse, F: FnOnce() -> R>(self, f: F) -> impl IntoResponse {
if self.is_fresh {
not_modified_response(&self.etag, self.last_modified)
} else {
let mut r = f().into_response();
r.headers_mut().insert(ETAG, self.etag.header_value());
if let Some(lm) = self.last_modified
&& let Ok(v) = HeaderValue::from_str(&http_date(lm))
{
r.headers_mut().insert(LAST_MODIFIED, v);
}
r
}
}
}
/// Starts a conditional-GET evaluation for the representation identified by `etag`.
///
/// If the request has `If-None-Match`, freshness is decided by it alone and
/// `If-Modified-Since` is never consulted. Otherwise the raw
/// `If-Modified-Since` value is kept, for [`FreshWhen::last_modified`] to
/// evaluate later.
pub fn fresh_when<E: IntoETag>(request_headers: &HeaderMap, etag: E) -> FreshWhen {
    let etag = etag.into_etag();
    let if_none_match_present = request_headers.contains_key(IF_NONE_MATCH);
    let (is_fresh, raw_if_modified_since) = if if_none_match_present {
        (check_if_none_match(request_headers, &etag), None)
    } else {
        let ims = request_headers
            .get(IF_MODIFIED_SINCE)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        (false, ims)
    };
    FreshWhen {
        etag,
        last_modified: None,
        is_fresh,
        if_none_match_present,
        raw_if_modified_since,
    }
}
fn check_if_none_match(headers: &HeaderMap, etag: &ETag) -> bool {
headers
.get_all(IF_NONE_MATCH)
.iter()
.any(|v| v.to_str().is_ok_and(|s| etag.matches_if_none_match(s)))
}
/// Builds an empty-bodied `304 Not Modified` carrying `ETag` and, when
/// `last_modified` is given, `Last-Modified`.
fn not_modified_response(
    etag: &ETag,
    last_modified: Option<chrono::DateTime<chrono::Utc>>,
) -> Response<Body> {
    let mut response = Response::new(Body::empty());
    *response.status_mut() = StatusCode::NOT_MODIFIED;
    let headers = response.headers_mut();
    headers.insert(ETAG, etag.header_value());
    let last_modified_value =
        last_modified.and_then(|lm| HeaderValue::from_str(&http_date(lm)).ok());
    if let Some(value) = last_modified_value {
        headers.insert(LAST_MODIFIED, value);
    }
    response
}
/// Formats `dt` as an IMF-fixdate HTTP date, e.g. `Sun, 06 Nov 1994 08:49:37 GMT`.
///
/// The format has whole-second resolution; sub-second precision is dropped.
fn http_date(dt: chrono::DateTime<chrono::Utc>) -> String {
    const IMF_FIXDATE: &str = "%a, %d %b %Y %H:%M:%S GMT";
    let rendered = dt.format(IMF_FIXDATE);
    rendered.to_string()
}
/// Parses an HTTP-date in any of the three formats RFC 9110 §5.6.7 requires
/// recipients to accept:
///
/// * IMF-fixdate, e.g. `Sun, 06 Nov 1994 08:49:37 GMT` (also other RFC 2822 dates)
/// * obsolete RFC 850, e.g. `Sunday, 06-Nov-94 08:49:37 GMT`
/// * obsolete asctime, e.g. `Sun Nov  6 08:49:37 1994`
///
/// The two obsolete formats are interpreted as UTC. Returns `None` if `s`
/// matches none of the formats.
fn parse_http_date(s: &str) -> Option<std::time::SystemTime> {
    let s = s.trim();
    if let Ok(dt) = chrono::DateTime::parse_from_rfc2822(s) {
        return Some(std::time::SystemTime::from(dt.with_timezone(&chrono::Utc)));
    }
    let naive = chrono::NaiveDateTime::parse_from_str(
        s.trim_end_matches(" GMT"),
        "%a, %d %b %Y %H:%M:%S",
    )
    // RFC 850: chrono infers the century of the two-digit `%y` year.
    .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%A, %d-%b-%y %H:%M:%S GMT"))
    .or_else(|_| {
        // asctime pads single-digit days with an extra space; collapse whitespace
        // runs so the parse does not depend on that padding.
        let normalized = s.split_ascii_whitespace().collect::<Vec<_>>().join(" ");
        chrono::NaiveDateTime::parse_from_str(&normalized, "%a %b %e %H:%M:%S %Y")
    })
    .ok()?;
    Some(std::time::SystemTime::from(
        chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(naive, chrono::Utc),
    ))
}
/// Tower layer that tags `GET` 200 responses with a weak, body-derived `ETag`.
/// When the request's `If-None-Match` already matches, it answers with
/// `304 Not Modified` instead.
///
/// A handler-supplied `ETag` is kept as-is. Bodies larger than
/// [`EtagLayer::MAX_BODY_BYTES`] are passed through without a tag.
#[derive(Clone, Debug, Default)]
pub struct EtagLayer;

impl EtagLayer {
    /// Largest body, in bytes (4 MiB), that the layer buffers to compute a tag.
    pub const MAX_BODY_BYTES: usize = 4 * 1024 * 1024;

    /// Creates the layer.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl<S> Layer<S> for EtagLayer {
    type Service = EtagService<S>;

    fn layer(&self, service: S) -> Self::Service {
        EtagService { inner: service }
    }
}

/// Service produced by [`EtagLayer`]; wraps the inner service `S`.
#[derive(Clone)]
pub struct EtagService<S> {
    inner: S,
}
impl<S, ReqBody> Service<http::Request<ReqBody>> for EtagService<S>
where
    S: Service<http::Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated entirely to the inner service.
        self.inner.poll_ready(cx)
    }

    /// Forwards the request. A `GET` answered with 200 is then passed to
    /// `apply_etag`; every other response is returned unchanged.
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
        let is_get = req.method() == http::Method::GET;
        // Multiple `If-None-Match` lines are joined into one comma-separated
        // list. Values that are not visible ASCII are dropped.
        let if_none_match = {
            let values: Vec<&str> = req
                .headers()
                .get_all(IF_NONE_MATCH)
                .iter()
                .filter_map(|value| value.to_str().ok())
                .collect();
            (!values.is_empty()).then(|| values.join(", "))
        };
        // Call the very service `poll_ready` just readied (not a clone).
        let pending = self.inner.call(req);
        Box::pin(async move {
            let response = pending.await?;
            if is_get && response.status() == StatusCode::OK {
                Ok(apply_etag(response, if_none_match.as_deref()).await)
            } else {
                Ok(response)
            }
        })
    }
}
/// Appends every value of `Cache-Control`, `Vary`, `Content-Location`,
/// `Date`, `Expires` and `Last-Modified` from `src` onto `dst`.
///
/// This is an allow-list: any header not named here (such as `Set-Cookie`)
/// is not copied.
fn copy_304_headers(src: &http::HeaderMap, dst: &mut Response<Body>) {
    let target = dst.headers_mut();
    let echoed = [
        CACHE_CONTROL,
        VARY,
        CONTENT_LOCATION,
        DATE,
        EXPIRES,
        LAST_MODIFIED,
    ];
    for name in echoed {
        src.get_all(&name).iter().for_each(|value| {
            target.append(name.clone(), value.clone());
        });
    }
}
async fn apply_etag(response: Response<Body>, if_none_match: Option<&str>) -> Response<Body> {
if let Some(existing_etag) = response.headers().get(ETAG).cloned() {
if let Some(inm) = if_none_match {
let existing_tag = existing_etag.to_str().unwrap_or("");
let tag = existing_tag
.strip_prefix("W/")
.unwrap_or(existing_tag)
.trim_matches('"');
let candidate_etag = ETag::strong(tag.to_owned());
if candidate_etag.matches_if_none_match(inm) {
let (parts, _body) = response.into_parts();
let mut not_modified = not_modified_response(&candidate_etag, None);
copy_304_headers(&parts.headers, &mut not_modified);
not_modified.headers_mut().remove(SET_COOKIE);
not_modified.headers_mut().insert(ETAG, existing_etag);
return not_modified;
}
}
return response;
}
let (mut parts, mut body) = response.into_parts();
if parts
.headers
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<usize>().ok())
.is_some_and(|len| len > EtagLayer::MAX_BODY_BYTES)
{
return Response::from_parts(parts, body);
}
let mut buf = bytes::BytesMut::new();
let mut overflow_frame: Option<http_body::Frame<bytes::Bytes>> = None;
let mut stream_errored = false;
loop {
match BodyExt::frame(&mut body).await {
None => break,
Some(Err(_)) => {
stream_errored = true;
break;
}
Some(Ok(frame)) => match frame.into_data() {
Ok(data) => {
if buf.len() + data.len() > EtagLayer::MAX_BODY_BYTES {
overflow_frame = Some(http_body::Frame::data(data));
break;
}
buf.extend_from_slice(&data);
}
Err(non_data) => {
overflow_frame = Some(non_data);
break;
}
},
}
}
if stream_errored {
let frozen = buf.freeze();
let err = axum::Error::new(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"upstream body error during ETag buffering",
));
let frames: Vec<Result<bytes::Bytes, axum::Error>> = if frozen.is_empty() {
vec![Err(err)]
} else {
vec![Ok(frozen), Err(err)]
};
return Response::from_parts(parts, Body::from_stream(futures::stream::iter(frames)));
}
if let Some(overflow) = overflow_frame {
return Response::from_parts(parts, rebuild_oversized_body(buf.freeze(), overflow, body));
}
let bytes = buf.freeze();
let etag = {
let mut hasher = DefaultHasher::new();
bytes.hash(&mut hasher);
ETag::weak(format!("{:016x}", hasher.finish()))
};
if if_none_match.is_some_and(|inm| etag.matches_if_none_match(inm)) {
let mut not_modified = not_modified_response(&etag, None);
copy_304_headers(&parts.headers, &mut not_modified);
not_modified.headers_mut().remove(SET_COOKIE);
return not_modified;
}
parts.headers.insert(ETAG, etag.header_value());
Response::from_parts(parts, Body::from(bytes))
}
fn rebuild_oversized_body(
prefix: bytes::Bytes,
overflow: http_body::Frame<bytes::Bytes>,
remaining: Body,
) -> Body {
use http_body_util::StreamBody;
let preamble = futures::stream::iter([
Ok::<http_body::Frame<bytes::Bytes>, axum::Error>(http_body::Frame::data(prefix)),
Ok(overflow),
]);
let tail = futures::stream::unfold(remaining, |mut b| async move {
BodyExt::frame(&mut b).await.map(|result| (result, b))
});
Body::new(StreamBody::new(preamble.chain(tail)))
}
/// Builds a `304 Not Modified` for `etag`. It echoes `Cache-Control`, `Vary`,
/// `Content-Location`, `Date` and `Expires` from `original_headers`.
///
/// `Last-Modified` comes from `last_modified` when given. Otherwise any
/// `Last-Modified` values in `original_headers` are copied. No other headers
/// (e.g. `Set-Cookie`) are carried over.
#[must_use]
pub fn build_not_modified(
    original_headers: &HeaderMap,
    etag: &ETag,
    last_modified: Option<chrono::DateTime<chrono::Utc>>,
) -> Response<Body> {
    let mut response = not_modified_response(etag, last_modified);
    let echo_last_modified = last_modified.is_none().then_some(LAST_MODIFIED);
    let forwarded = [CACHE_CONTROL, VARY, CONTENT_LOCATION, DATE, EXPIRES]
        .into_iter()
        .chain(echo_last_modified);
    let headers = response.headers_mut();
    for name in forwarded {
        for value in original_headers.get_all(&name) {
            headers.append(name.clone(), value.clone());
        }
    }
    response
}
#[cfg(test)]
mod tests {
    // Covers ETag construction and matching, `fresh_when` conditional-GET
    // resolution, and the `EtagLayer` middleware end to end.
    use super::*;
    use axum::body::Body;
    use http::{HeaderMap, HeaderValue, Method, Request, StatusCode};
    use tower::ServiceExt;

    // --- ETag rendering and derivation ---

    #[test]
    fn strong_etag_header_value_has_quotes() {
        let etag = ETag::strong("abc123");
        assert_eq!(etag.header_value().to_str().unwrap(), r#""abc123""#);
    }

    #[test]
    fn weak_etag_header_value_has_w_prefix() {
        let etag = ETag::weak("abc123");
        assert_eq!(etag.header_value().to_str().unwrap(), r#"W/"abc123""#);
    }

    #[test]
    fn etag_is_not_weak_by_default_strong_constructor() {
        let etag = ETag::strong("x");
        assert!(!etag.is_weak());
    }

    #[test]
    fn weak_etag_is_weak() {
        let etag = ETag::weak("x");
        assert!(etag.is_weak());
    }

    #[test]
    fn str_into_etag_produces_deterministic_strong_etag() {
        let e1: ETag = "hello".into_etag();
        let e2: ETag = "hello".into_etag();
        assert_eq!(e1, e2);
        assert!(!e1.is_weak());
    }

    #[test]
    fn different_strings_produce_different_etags() {
        let e1: ETag = "hello".into_etag();
        let e2: ETag = "world".into_etag();
        assert_ne!(e1, e2);
    }

    #[test]
    fn string_into_etag_same_as_str() {
        let e1: ETag = "hello".into_etag();
        let e2: ETag = String::from("hello").into_etag();
        assert_eq!(e1, e2);
    }

    #[test]
    fn i64_into_etag_is_deterministic() {
        let e1: ETag = 42_i64.into_etag();
        let e2: ETag = 42_i64.into_etag();
        assert_eq!(e1, e2);
        assert!(!e1.is_weak());
    }

    #[test]
    fn different_i64_values_produce_different_etags() {
        let e1: ETag = 1_i64.into_etag();
        let e2: ETag = 2_i64.into_etag();
        assert_ne!(e1, e2);
    }

    #[test]
    fn i32_into_etag_matches_equivalent_i64() {
        let e1: ETag = 7_i32.into_etag();
        let e2: ETag = 7_i64.into_etag();
        assert_eq!(e1, e2);
    }

    #[test]
    fn tuple_into_etag_is_deterministic() {
        let dt = chrono::DateTime::from_timestamp(1_000_000, 0)
            .unwrap()
            .naive_utc();
        let e1: ETag = (dt, 3_i64).into_etag();
        let e2: ETag = (dt, 3_i64).into_etag();
        assert_eq!(e1, e2);
        assert!(!e1.is_weak());
    }

    #[test]
    fn tuple_etag_differs_when_lock_version_differs() {
        let dt = chrono::DateTime::from_timestamp(1_000_000, 0)
            .unwrap()
            .naive_utc();
        let e1: ETag = (dt, 1_i64).into_etag();
        let e2: ETag = (dt, 2_i64).into_etag();
        assert_ne!(e1, e2);
    }

    #[test]
    fn tuple_etag_differs_when_timestamp_differs() {
        let dt1 = chrono::DateTime::from_timestamp(1_000_000, 0)
            .unwrap()
            .naive_utc();
        let dt2 = chrono::DateTime::from_timestamp(1_000_001, 0)
            .unwrap()
            .naive_utc();
        let e1: ETag = (dt1, 1_i64).into_etag();
        let e2: ETag = (dt2, 1_i64).into_etag();
        assert_ne!(e1, e2);
    }

    #[test]
    fn hash_etag_is_weak() {
        let etag = hash_etag(&vec![1u8, 2, 3]);
        assert!(etag.is_weak());
    }

    #[test]
    fn hash_etag_is_deterministic_for_same_input() {
        let etag1 = hash_etag(&"stable_value");
        let etag2 = hash_etag(&"stable_value");
        assert_eq!(etag1, etag2);
    }

    // --- If-None-Match matching ---

    #[test]
    fn etag_matches_exact_quoted_value() {
        let etag = ETag::strong("abc");
        assert!(etag.matches_if_none_match(r#""abc""#));
    }

    #[test]
    fn etag_matches_weak_variant_by_tag() {
        let etag = ETag::strong("abc");
        assert!(etag.matches_if_none_match(r#"W/"abc""#));
    }

    #[test]
    fn etag_matches_star_wildcard() {
        let etag = ETag::strong("anything");
        assert!(etag.matches_if_none_match("*"));
    }

    #[test]
    fn etag_does_not_match_different_value() {
        let etag = ETag::strong("abc");
        assert!(!etag.matches_if_none_match(r#""xyz""#));
    }

    #[test]
    fn etag_matches_one_of_many_in_list() {
        let etag = ETag::strong("abc");
        assert!(etag.matches_if_none_match(r#""xyz", "abc", "foo""#));
    }

    // --- fresh_when ---

    #[test]
    fn fresh_when_returns_stale_with_no_headers() {
        let headers = HeaderMap::new();
        let result = fresh_when(&headers, 1_i64);
        assert!(!result.is_fresh());
    }

    #[test]
    fn fresh_when_returns_fresh_on_matching_if_none_match() {
        let etag: ETag = 42_i64.into_etag();
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, etag.header_value());
        let result = fresh_when(&headers, 42_i64);
        assert!(result.is_fresh());
    }

    #[test]
    fn fresh_when_returns_stale_on_different_etag() {
        let etag: ETag = 1_i64.into_etag();
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, etag.header_value());
        let result = fresh_when(&headers, 2_i64);
        assert!(!result.is_fresh());
    }

    #[test]
    fn fresh_when_or_returns_304_when_fresh() {
        let etag: ETag = 7_i64.into_etag();
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, etag.header_value());
        let response = fresh_when(&headers, 7_i64)
            .or(StatusCode::OK)
            .into_response();
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
    }

    #[test]
    fn fresh_when_or_returns_200_and_sets_etag_when_stale() {
        let headers = HeaderMap::new();
        let response = fresh_when(&headers, 1_i64)
            .or(StatusCode::OK)
            .into_response();
        assert_eq!(response.status(), StatusCode::OK);
        let etag_header = response.headers().get(ETAG);
        assert!(
            etag_header.is_some(),
            "ETag header must be set on stale response"
        );
    }

    #[tokio::test]
    async fn fresh_when_304_has_empty_body() {
        let etag: ETag = 5_i64.into_etag();
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, etag.header_value());
        let response = fresh_when(&headers, 5_i64)
            .or(StatusCode::OK)
            .into_response();
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty(), "304 body must be empty, got {bytes:?}");
    }

    #[test]
    fn fresh_when_or_includes_etag_in_304_headers() {
        let etag: ETag = 3_i64.into_etag();
        let etag_val = etag.header_value();
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, etag_val.clone());
        let response = fresh_when(&headers, 3_i64)
            .or(StatusCode::OK)
            .into_response();
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(response.headers().get(ETAG), Some(&etag_val));
    }

    #[test]
    fn fresh_when_wildcard_if_none_match_returns_fresh() {
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, HeaderValue::from_static("*"));
        let result = fresh_when(&headers, "anything");
        assert!(result.is_fresh());
    }

    #[test]
    fn fresh_when_last_modified_sets_header_on_stale_response() {
        use chrono::TimeZone;
        let headers = HeaderMap::new();
        let last_modified = chrono::Utc.timestamp_opt(1_700_000_000, 0).unwrap();
        let response = fresh_when(&headers, 1_i64)
            .last_modified(last_modified)
            .or(StatusCode::OK)
            .into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key(LAST_MODIFIED));
    }

    #[test]
    fn fresh_when_last_modified_sets_header_on_304() {
        use chrono::TimeZone;
        let etag: ETag = 9_i64.into_etag();
        let last_modified = chrono::Utc.timestamp_opt(1_700_000_000, 0).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, etag.header_value());
        let response = fresh_when(&headers, 9_i64)
            .last_modified(last_modified)
            .or(StatusCode::OK)
            .into_response();
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        assert!(response.headers().contains_key(LAST_MODIFIED));
    }

    #[test]
    fn fresh_when_if_modified_since_fresh_when_last_modified_not_newer() {
        use chrono::TimeZone;
        let last_modified = chrono::Utc.timestamp_opt(1_700_000_000, 0).unwrap();
        let ims_time = chrono::Utc.timestamp_opt(1_700_000_001, 0).unwrap();
        let ims_str = http_date(ims_time);
        let mut headers = HeaderMap::new();
        headers.insert(IF_MODIFIED_SINCE, HeaderValue::from_str(&ims_str).unwrap());
        let result = fresh_when(&headers, 1_i64).last_modified(last_modified);
        assert!(result.is_fresh(), "IMS >= last_modified → fresh (304)");
    }

    #[test]
    fn fresh_when_if_modified_since_stale_when_resource_newer_than_ims() {
        use chrono::TimeZone;
        let last_modified = chrono::Utc.timestamp_opt(1_700_000_002, 0).unwrap();
        let ims_time = chrono::Utc.timestamp_opt(1_700_000_001, 0).unwrap();
        let ims_str = http_date(ims_time);
        let mut headers = HeaderMap::new();
        headers.insert(IF_MODIFIED_SINCE, HeaderValue::from_str(&ims_str).unwrap());
        let result = fresh_when(&headers, 1_i64).last_modified(last_modified);
        assert!(!result.is_fresh(), "last_modified > IMS → stale (200)");
    }

    #[test]
    fn fresh_when_ignores_ims_when_inm_present_rfc7232_s3_3() {
        use chrono::TimeZone;
        let last_modified = chrono::Utc.timestamp_opt(1_700_000_000, 0).unwrap();
        let ims_time = chrono::Utc.timestamp_opt(1_700_000_001, 0).unwrap();
        let ims_str = http_date(ims_time);
        let wrong_etag: ETag = 99_i64.into_etag();
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, wrong_etag.header_value());
        headers.insert(IF_MODIFIED_SINCE, HeaderValue::from_str(&ims_str).unwrap());
        let result = fresh_when(&headers, 1_i64).last_modified(last_modified);
        assert!(
            !result.is_fresh(),
            "IMS must be ignored when INM is present per RFC 7232 §3.3"
        );
    }

    // --- EtagLayer middleware ---

    #[tokio::test]
    async fn etag_layer_adds_etag_to_get_200() {
        let svc = EtagLayer::new().layer(tower::service_fn(|_req: Request<Body>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from("hello world"))
                    .unwrap(),
            )
        }));
        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();
        let response = svc.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(
            response.headers().contains_key(ETAG),
            "EtagLayer must inject ETag header"
        );
    }

    #[tokio::test]
    async fn etag_layer_returns_304_on_matching_if_none_match() {
        let svc = EtagLayer::new().layer(tower::service_fn(|_req: Request<Body>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from("hello world"))
                    .unwrap(),
            )
        }));
        let first_req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();
        let first_response = svc.clone().oneshot(first_req).await.unwrap();
        let etag = first_response.headers().get(ETAG).unwrap().clone();
        let second_req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(IF_NONE_MATCH, etag)
            .body(Body::empty())
            .unwrap();
        let second_response = svc.oneshot(second_req).await.unwrap();
        assert_eq!(second_response.status(), StatusCode::NOT_MODIFIED);
    }

    #[tokio::test]
    async fn etag_layer_does_not_add_etag_to_post() {
        let svc = EtagLayer::new().layer(tower::service_fn(|_req: Request<Body>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from("ok"))
                    .unwrap(),
            )
        }));
        let req = Request::builder()
            .method(Method::POST)
            .uri("/")
            .body(Body::empty())
            .unwrap();
        let response = svc.oneshot(req).await.unwrap();
        assert!(!response.headers().contains_key(ETAG));
    }

    #[tokio::test]
    async fn etag_layer_does_not_override_existing_etag() {
        let svc = EtagLayer::new().layer(tower::service_fn(|_req: Request<Body>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(ETAG, r#""handler-set""#)
                    .body(Body::from("body"))
                    .unwrap(),
            )
        }));
        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();
        let response = svc.oneshot(req).await.unwrap();
        assert_eq!(
            response.headers().get(ETAG).unwrap().to_str().unwrap(),
            r#""handler-set""#
        );
    }

    #[tokio::test]
    async fn etag_layer_preserves_cache_control_on_304() {
        let svc = EtagLayer::new().layer(tower::service_fn(|_req: Request<Body>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(CACHE_CONTROL, "max-age=60")
                    .body(Body::from("stable content"))
                    .unwrap(),
            )
        }));
        let first_req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap();
        let first = svc.clone().oneshot(first_req).await.unwrap();
        let etag = first.headers().get(ETAG).unwrap().clone();
        let second_req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(IF_NONE_MATCH, etag)
            .body(Body::empty())
            .unwrap();
        let second = svc.oneshot(second_req).await.unwrap();
        assert_eq!(second.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(
            second
                .headers()
                .get(CACHE_CONTROL)
                .unwrap()
                .to_str()
                .unwrap(),
            "max-age=60"
        );
    }

    #[tokio::test]
    async fn etag_layer_strips_set_cookie_from_304() {
        let svc = EtagLayer::new().layer(tower::service_fn(|_req: Request<Body>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header(SET_COOKIE, "session=abc; HttpOnly")
                    .body(Body::from("content"))
                    .unwrap(),
            )
        }));
        let first = svc
            .clone()
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let etag = first.headers().get(ETAG).unwrap().clone();
        let second = svc
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .uri("/")
                    .header(IF_NONE_MATCH, etag)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(second.status(), StatusCode::NOT_MODIFIED);
        assert!(
            !second.headers().contains_key(SET_COOKIE),
            "Set-Cookie must be stripped from 304"
        );
    }

    // --- Misc ---

    #[test]
    fn etag_derivation_is_deterministic_no_rng_or_clock() {
        let e1: ETag = (42_i64).into_etag();
        let e2: ETag = (42_i64).into_etag();
        let e3: ETag = (42_i64).into_etag();
        assert_eq!(e1, e2);
        assert_eq!(e2, e3);
    }

    #[test]
    fn build_not_modified_preserves_cache_control_and_vary() {
        let mut orig = HeaderMap::new();
        orig.insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
        orig.insert(VARY, HeaderValue::from_static("Accept"));
        orig.insert(SET_COOKIE, HeaderValue::from_static("tok=x"));
        let etag = ETag::strong("tag");
        let response = build_not_modified(&orig, &etag, None);
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(
            response
                .headers()
                .get(CACHE_CONTROL)
                .unwrap()
                .to_str()
                .unwrap(),
            "no-cache"
        );
        assert_eq!(
            response.headers().get(VARY).unwrap().to_str().unwrap(),
            "Accept"
        );
        assert!(!response.headers().contains_key(SET_COOKIE));
    }

    #[tokio::test]
    async fn integration_first_get_200_second_get_304() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI64, Ordering};
        let lock_version = Arc::new(AtomicI64::new(1));
        let lv = Arc::clone(&lock_version);
        let svc = EtagLayer::new().layer(tower::service_fn(move |_req: Request<Body>| {
            let v = lv.load(Ordering::SeqCst);
            async move {
                let etag: ETag = v.into_etag();
                Ok::<_, std::convert::Infallible>(
                    Response::builder()
                        .status(StatusCode::OK)
                        .header(ETAG, etag.header_value())
                        .body(Body::from(format!("version={v}")))
                        .unwrap(),
                )
            }
        }));
        let first = svc
            .clone()
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .uri("/resource")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);
        let etag = first.headers().get(ETAG).cloned().unwrap();
        let second = svc
            .clone()
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .uri("/resource")
                    .header(IF_NONE_MATCH, etag.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(second.status(), StatusCode::NOT_MODIFIED);
        let body_bytes = second.into_body().collect().await.unwrap().to_bytes();
        assert!(body_bytes.is_empty(), "304 body must be empty");
        lock_version.store(2, Ordering::SeqCst);
        let third = svc
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .uri("/resource")
                    .header(IF_NONE_MATCH, etag)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(third.status(), StatusCode::OK);
        let new_etag = third.headers().get(ETAG).unwrap();
        let old_etag: ETag = 1_i64.into_etag();
        assert_ne!(new_etag, &old_etag.header_value());
    }
}