use actix_web::{
Error, HttpResponse,
body::{BoxBody, MessageBody, to_bytes},
dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
http::{Method, StatusCode, header},
web::Bytes,
};
use futures_util::future::{LocalBoxFuture, Ready, ok};
use std::rc::Rc;
use base64::Engine;
use xxhash_rust::xxh3::xxh3_128;
/// Actix-web middleware factory that attaches an `ETag` header to responses and
/// answers `If-Match` / `If-None-Match` preconditions against that tag.
///
/// Generated tags are derived from an XXH3-128 hash of the buffered response
/// body. Use [`ETag::strong`] (also [`ETag::new`] and [`Default`]) or
/// [`ETag::weak`] to choose the strength of generated tags.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ETag {
    strength: Strength,
}

/// Strength of the entity tags produced by the middleware.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Strength {
    /// Plain quoted tag (`"…"`).
    Strong,
    /// Tag prefixed with the `W/` weak indicator (`W/"…"`).
    Weak,
}

impl ETag {
    /// Creates middleware that emits strong entity tags; same as [`ETag::strong`].
    #[must_use]
    pub const fn new() -> Self {
        Self::strong()
    }

    /// Creates middleware that emits strong entity tags (`"…"`).
    #[must_use]
    pub const fn strong() -> Self {
        Self {
            strength: Strength::Strong,
        }
    }

    /// Creates middleware that emits weak entity tags (`W/"…"`).
    ///
    /// `If-Match` uses strong comparison, so a weak tag can only satisfy
    /// `If-Match: *`; it still matches `If-None-Match` under weak comparison.
    #[must_use]
    pub const fn weak() -> Self {
        Self {
            strength: Strength::Weak,
        }
    }
}

impl Default for ETag {
    /// Same as [`ETag::strong`].
    fn default() -> Self {
        Self::strong()
    }
}
/// Builds an [`ETagMiddleware`] around the next service in the chain.
impl<S, B> Transform<S, ServiceRequest> for ETag
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    B: MessageBody + 'static,
    B::Error: Into<Error>,
{
    type Response = ServiceResponse<BoxBody>;
    type Error = Error;
    type InitError = ();
    type Transform = ETagMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in middleware that uses this factory's tag strength.
    /// Construction cannot fail, so the future is immediately ready.
    fn new_transform(&self, service: S) -> Self::Future {
        let middleware = ETagMiddleware {
            strength: self.strength,
            service: Rc::new(service),
        };
        ok(middleware)
    }
}
pub struct ETagMiddleware<S> {
service: Rc<S>,
strength: Strength,
}
impl<S, B> Service<ServiceRequest> for ETagMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
B: MessageBody + 'static,
B::Error: Into<Error>,
{
type Response = ServiceResponse<BoxBody>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let srv = Rc::clone(&self.service);
let strength = self.strength;
Box::pin(async move {
let res = srv.call(req).await?;
let (req, res) = res.into_parts();
let (mut head, body) = res.into_parts();
let body_bytes = to_bytes(body).await.map_err(Into::into)?;
let etag_value = extract_or_compute_etag(&mut head, &body_bytes, strength);
if let Some(precondition) = evaluate_conditionals(&req, &etag_value) {
return Ok(ServiceResponse::new(req, precondition));
}
let response = head.set_body(body_bytes).map_body(|_, body| body.boxed());
Ok(ServiceResponse::new(req, response))
})
}
}
/// Returns the entity tag for this response, generating one when needed.
///
/// If the response already has an `ETag` header that is readable as a string
/// (`HeaderValue::to_str`), that value is returned trimmed and the header is
/// left as-is. Otherwise a tag is built from `body` at `strength` and inserted
/// into `head`'s headers, replacing any unreadable existing value. The
/// generated tag is returned even if it could not be turned into a header value.
fn extract_or_compute_etag(
    head: &mut HttpResponse<()>,
    body: &Bytes,
    strength: Strength,
) -> String {
    let existing = head
        .headers()
        .get(header::ETAG)
        .and_then(|raw| raw.to_str().ok())
        .map(|text| text.trim().to_string());

    match existing {
        Some(tag) => tag,
        None => {
            let generated = build_entity_tag(body, strength);
            if let Ok(header_value) = header::HeaderValue::from_str(&generated) {
                head.headers_mut().insert(header::ETAG, header_value);
            }
            generated
        }
    }
}
fn evaluate_conditionals(req: &actix_web::HttpRequest, etag: &str) -> Option<HttpResponse> {
if let Some(if_match) = req
.headers()
.get(header::IF_MATCH)
.and_then(|h| h.to_str().ok())
{
if !match_if_match(etag, if_match) {
return Some(
HttpResponse::build(StatusCode::PRECONDITION_FAILED)
.insert_header((header::ETAG, etag.to_string()))
.finish(),
);
}
}
if let Some(if_none_match) = req
.headers()
.get(header::IF_NONE_MATCH)
.and_then(|h| h.to_str().ok())
{
if match_if_none_match(etag, if_none_match) {
let status = match *req.method() {
Method::GET | Method::HEAD => StatusCode::NOT_MODIFIED,
_ => StatusCode::PRECONDITION_FAILED,
};
return Some(
HttpResponse::build(status)
.insert_header((header::ETAG, etag.to_string()))
.finish(),
);
}
}
None
}
/// Returns `true` when some member of an `If-Match` list matches `etag`.
///
/// `header_value` is split on commas and each member is trimmed. A `*` member
/// matches anything; any other member must pass `strong_compare` against `etag`.
fn match_if_match(etag: &str, header_value: &str) -> bool {
    for candidate in header_value.split(',').map(str::trim) {
        if candidate == "*" || strong_compare(candidate, etag) {
            return true;
        }
    }
    false
}
/// Returns `true` when some member of an `If-None-Match` list matches `etag`.
///
/// Members are comma-separated and trimmed. A `*` member matches anything.
/// Other members are compared after stripping any `W/` prefix from both sides,
/// i.e. weak comparison.
fn match_if_none_match(etag: &str, header_value: &str) -> bool {
    let wanted = strip_weak_prefix(etag);
    for candidate in header_value.split(',').map(str::trim) {
        if candidate == "*" || strip_weak_prefix(candidate) == wanted {
            return true;
        }
    }
    false
}
/// Builds an entity tag from the XXH3-128 hash of `body`.
///
/// The 16-byte little-endian digest is encoded with URL-safe base64 (with
/// padding). The quoted tag is `"<len>-<base64>"`, where `<len>` is the encoded
/// length in lowercase hex. Weak tags additionally get the `W/` prefix.
fn build_entity_tag(body: &Bytes, strength: Strength) -> String {
    let digest = xxh3_128(body).to_le_bytes();
    let encoded = base64::prelude::BASE64_URL_SAFE.encode(digest);
    let quoted = format!("\"{:x}-{}\"", encoded.len(), encoded);
    match strength {
        Strength::Strong => quoted,
        Strength::Weak => format!("W/{quoted}"),
    }
}
fn strong_compare(left: &str, right: &str) -> bool {
!is_weak(left) && !is_weak(right) && left == right
}
/// Removes a leading case-sensitive `W/` weak indicator.
/// Returns `value` unchanged when the prefix is absent.
fn strip_weak_prefix(value: &str) -> &str {
    match value.strip_prefix("W/") {
        Some(rest) => rest,
        None => value,
    }
}
/// Reports whether `value` starts with the case-sensitive `W/` weak indicator.
fn is_weak(value: &str) -> bool {
    value.as_bytes().starts_with(b"W/")
}
// Integration tests: each case mounts the middleware on a small actix-web app
// and drives it with `init_service` / `call_service`.
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{
        App, HttpResponse,
        dev::ServiceResponse,
        http::header,
        test::{TestRequest, call_service, init_service},
        web,
    };

    // Tag the middleware should generate for `payload` at `strength`. It reuses
    // `build_entity_tag` so the assertions follow the real tag format.
    fn expected_etag(payload: &[u8], strength: Strength) -> String {
        let bytes = Bytes::copy_from_slice(payload);
        build_entity_tag(&bytes, strength)
    }

    // With no handler-set ETag, a strong tag computed from the body is added.
    #[actix_web::test]
    async fn sets_etag_header_when_missing() {
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let response: ServiceResponse =
            call_service(&app, TestRequest::get().uri("/").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        let value = response.headers().get(header::ETAG).unwrap();
        assert_eq!(
            value.to_str().unwrap(),
            expected_etag(b"hello", Strength::Strong)
        );
    }

    // A GET whose If-None-Match equals the current tag gets 304 plus the ETag.
    #[actix_web::test]
    async fn returns_not_modified_for_matching_if_none_match() {
        let etag = expected_etag(b"hello", Strength::Strong);
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_NONE_MATCH, etag.clone()))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(
            response
                .headers()
                .get(header::ETAG)
                .unwrap()
                .to_str()
                .unwrap(),
            etag
        );
    }

    // An If-Match naming a different tag yields 412.
    #[actix_web::test]
    async fn returns_precondition_failed_for_non_matching_if_match() {
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_MATCH, "\"deadbeef\""))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::PRECONDITION_FAILED);
    }

    // An If-Match naming the current strong tag lets the full 200 response through.
    #[actix_web::test]
    async fn allows_if_match_when_strong_tag_matches() {
        let body = b"hello";
        let expected = expected_etag(body, Strength::Strong);
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_MATCH, expected.clone()))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get(header::ETAG)
                .unwrap()
                .to_str()
                .unwrap(),
            expected
        );
    }

    // `ETag::weak()` emits the `W/`-prefixed form of the tag.
    #[actix_web::test]
    async fn sets_weak_etag_header_when_configured() {
        let app = init_service(App::new().wrap(ETag::weak()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let response: ServiceResponse =
            call_service(&app, TestRequest::get().uri("/").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get(header::ETAG)
                .unwrap()
                .to_str()
                .unwrap(),
            expected_etag(b"hello", Strength::Weak)
        );
    }

    // If-None-Match uses weak comparison, so the strong form of a weak tag still
    // produces a 304.
    #[actix_web::test]
    async fn weak_etag_triggers_not_modified_with_strong_if_none_match() {
        let etag = expected_etag(b"hello", Strength::Weak);
        let app = init_service(App::new().wrap(ETag::weak()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_NONE_MATCH, etag.trim_start_matches("W/")))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(
            response
                .headers()
                .get(header::ETAG)
                .unwrap()
                .to_str()
                .unwrap(),
            etag
        );
    }

    // If-Match uses strong comparison, so an identical weak tag still fails with 412.
    #[actix_web::test]
    async fn weak_etag_fails_if_match_even_when_value_matches() {
        let etag = expected_etag(b"hello", Strength::Weak);
        let app = init_service(App::new().wrap(ETag::weak()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_MATCH, etag.clone()))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::PRECONDITION_FAILED);
    }

    // `ETag::new()` behaves like `ETag::strong()`.
    #[actix_web::test]
    async fn new_method_creates_strong_etag_middleware() {
        let app = init_service(App::new().wrap(ETag::new()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let response: ServiceResponse =
            call_service(&app, TestRequest::get().uri("/").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        let value = response.headers().get(header::ETAG).unwrap();
        assert_eq!(
            value.to_str().unwrap(),
            expected_etag(b"hello", Strength::Strong)
        );
        assert!(!value.to_str().unwrap().starts_with("W/"));
    }

    // `ETag::default()` behaves like `ETag::strong()`.
    #[actix_web::test]
    async fn default_trait_creates_strong_etag_middleware() {
        let app = init_service(App::new().wrap(ETag::default()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let response: ServiceResponse =
            call_service(&app, TestRequest::get().uri("/").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        let value = response.headers().get(header::ETAG).unwrap();
        assert_eq!(
            value.to_str().unwrap(),
            expected_etag(b"hello", Strength::Strong)
        );
        assert!(!value.to_str().unwrap().starts_with("W/"));
    }

    // A matching If-None-Match on a non-GET/HEAD method yields 412, not 304.
    #[actix_web::test]
    async fn if_none_match_with_post_returns_precondition_failed() {
        let etag = expected_etag(b"hello", Strength::Strong);
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::post().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::post()
            .uri("/")
            .insert_header((header::IF_NONE_MATCH, etag.clone()))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::PRECONDITION_FAILED);
        assert_eq!(
            response
                .headers()
                .get(header::ETAG)
                .unwrap()
                .to_str()
                .unwrap(),
            etag
        );
    }

    // Same rule as the POST case, exercised with PUT.
    #[actix_web::test]
    async fn if_none_match_with_put_returns_precondition_failed() {
        let etag = expected_etag(b"hello", Strength::Strong);
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::put().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::put()
            .uri("/")
            .insert_header((header::IF_NONE_MATCH, etag.clone()))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::PRECONDITION_FAILED);
    }

    // `If-None-Match: *` matches whatever tag the resource has.
    #[actix_web::test]
    async fn if_none_match_with_wildcard_matches_any_etag() {
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_NONE_MATCH, "*"))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        assert!(response.headers().contains_key(header::ETAG));
    }

    // `If-Match: *` is satisfied by any tag, so the normal 200 is returned.
    #[actix_web::test]
    async fn if_match_with_wildcard_matches_any_etag() {
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_MATCH, "*"))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key(header::ETAG));
    }

    // A handler-supplied ETag is kept instead of being replaced by a computed one.
    #[actix_web::test]
    async fn preserves_handler_set_etag() {
        let custom_etag = "\"custom-etag-12345\"";
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async {
                HttpResponse::Ok()
                    .insert_header((header::ETAG, "\"custom-etag-12345\""))
                    .body("hello")
            }),
        ))
        .await;
        let response: ServiceResponse =
            call_service(&app, TestRequest::get().uri("/").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get(header::ETAG)
                .unwrap()
                .to_str()
                .unwrap(),
            custom_etag
        );
    }

    // Preconditions are evaluated against the handler-supplied ETag.
    #[actix_web::test]
    async fn if_none_match_matches_handler_set_etag() {
        let custom_etag = "\"custom-etag-12345\"";
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async {
                HttpResponse::Ok()
                    .insert_header((header::ETAG, "\"custom-etag-12345\""))
                    .body("hello")
            }),
        ))
        .await;
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_NONE_MATCH, custom_etag))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
    }

    // A comma-separated If-None-Match list matches when any member matches.
    #[actix_web::test]
    async fn multiple_etags_in_if_none_match() {
        let etag = expected_etag(b"hello", Strength::Strong);
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let multiple_etags = format!("\"other-etag\", {}, \"another-etag\"", etag);
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_NONE_MATCH, multiple_etags))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
    }

    // A comma-separated If-Match list passes when any member strongly matches.
    #[actix_web::test]
    async fn multiple_etags_in_if_match_with_match() {
        let etag = expected_etag(b"hello", Strength::Strong);
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let multiple_etags = format!("\"other-etag\", {}, \"another-etag\"", etag);
        let request = TestRequest::get()
            .uri("/")
            .insert_header((header::IF_MATCH, multiple_etags))
            .to_request();
        let response: ServiceResponse = call_service(&app, request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Without conditional headers the handler's 200 passes through with an ETag added.
    #[actix_web::test]
    async fn response_without_precondition_headers_passes_through() {
        let app = init_service(App::new().wrap(ETag::strong()).route(
            "/",
            web::get().to(|| async { HttpResponse::Ok().body("hello") }),
        ))
        .await;
        let response: ServiceResponse =
            call_service(&app, TestRequest::get().uri("/").to_request()).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key(header::ETAG));
    }
}