use crate::middleware::v2::{Middleware, Next, NextFuture};
use crate::request::{ElifMethod, ElifRequest};
use crate::response::{ElifHeaderValue, ElifResponse};
use axum::http::{HeaderMap, HeaderName, HeaderValue};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// An HTTP entity tag as carried by `ETag`, `If-Match` and `If-None-Match`.
///
/// The stored string is the opaque tag *without* its enclosing double quotes
/// or the `W/` weakness prefix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ETagType {
    /// A strong validator, serialized as `"value"`.
    Strong(String),
    /// A weak validator, serialized as `W/"value"`.
    Weak(String),
}

impl ETagType {
    /// Parses a single entity tag such as `"abc"` or `W/"abc"`.
    ///
    /// Surrounding whitespace is ignored. Returns `None` when the value is not
    /// a properly quoted tag; a lone `"` is rejected instead of being treated
    /// as both the opening and the closing quote.
    pub fn from_header_value(value: &str) -> Option<Self> {
        let value = value.trim();
        match value.strip_prefix("W/") {
            Some(quoted) => Self::unquote(quoted).map(|tag| Self::Weak(tag.to_string())),
            None => Self::unquote(value).map(|tag| Self::Strong(tag.to_string())),
        }
    }

    /// Strips one leading and one trailing `"`, requiring both to be present
    /// as distinct characters.
    fn unquote(quoted: &str) -> Option<&str> {
        quoted.strip_prefix('"')?.strip_suffix('"')
    }

    /// Serializes the tag back into header form (`"v"` or `W/"v"`).
    pub fn to_header_value(&self) -> String {
        match self {
            Self::Strong(value) => format!("\"{}\"", value),
            Self::Weak(value) => format!("W/\"{}\"", value),
        }
    }

    /// Returns the opaque tag without quotes or weakness prefix.
    pub fn value(&self) -> &str {
        match self {
            Self::Strong(value) | Self::Weak(value) => value,
        }
    }

    /// Weak comparison, as used by `If-None-Match`: only the opaque values
    /// must be equal; strength is ignored.
    pub fn matches_for_if_none_match(&self, other: &Self) -> bool {
        self.value() == other.value()
    }

    /// Strong comparison, as used by `If-Match`: both tags must be strong and
    /// have equal values; any weak tag never matches.
    pub fn matches_for_if_match(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Strong(a), Self::Strong(b)) => a == b,
            _ => false,
        }
    }
}
/// How [`ETagMiddleware`] derives an entity tag from a response.
#[derive(Debug, Clone, Default)]
pub enum ETagStrategy {
    /// Strong tag computed from a hash of the body and the response headers.
    /// This is the default strategy.
    #[default]
    BodyHash,
    /// Weak tag computed from a hash of the body only.
    WeakBodyHash,
    /// Caller-supplied function given the body bytes and the response
    /// headers; returning `None` leaves the response without an ETag.
    Custom(fn(&[u8], &HeaderMap) -> Option<ETagType>),
}
/// Tunables controlling which responses [`ETagMiddleware`] tags, and how.
#[derive(Debug, Clone)]
pub struct ETagConfig {
    /// Algorithm used to derive the tag.
    pub strategy: ETagStrategy,
    /// Smallest eligible body size in bytes (inclusive).
    pub min_size: usize,
    /// Largest eligible body size in bytes (inclusive).
    pub max_size: usize,
    /// Eligible content-type prefixes, compared case-insensitively.
    pub content_types: Vec<String>,
}

impl Default for ETagConfig {
    /// Default strategy, bodies from 0 bytes up to 10 MiB, and common textual
    /// content types.
    fn default() -> Self {
        const DEFAULT_CONTENT_TYPES: [&str; 9] = [
            "text/html",
            "text/css",
            "text/javascript",
            "text/plain",
            "application/json",
            "application/javascript",
            "application/xml",
            "text/xml",
            "image/svg+xml",
        ];
        let content_types = DEFAULT_CONTENT_TYPES
            .iter()
            .map(|ct| String::from(*ct))
            .collect();
        Self {
            strategy: ETagStrategy::default(),
            min_size: 0,
            max_size: 10 * 1024 * 1024,
            content_types,
        }
    }
}
#[derive(Debug)]
pub struct ETagMiddleware {
config: ETagConfig,
}
impl ETagMiddleware {
pub fn new() -> Self {
Self {
config: ETagConfig::default(),
}
}
pub fn with_config(config: ETagConfig) -> Self {
Self { config }
}
pub fn strategy(mut self, strategy: ETagStrategy) -> Self {
self.config.strategy = strategy;
self
}
pub fn min_size(mut self, min_size: usize) -> Self {
self.config.min_size = min_size;
self
}
pub fn max_size(mut self, max_size: usize) -> Self {
self.config.max_size = max_size;
self
}
pub fn content_type(mut self, content_type: impl Into<String>) -> Self {
self.config.content_types.push(content_type.into());
self
}
pub fn weak(mut self) -> Self {
self.config.strategy = ETagStrategy::WeakBodyHash;
self
}
fn should_generate_etag(&self, headers: &HeaderMap, body_size: usize) -> bool {
if body_size < self.config.min_size || body_size > self.config.max_size {
return false;
}
if headers.contains_key("etag") {
return false;
}
if let Some(content_type) = headers.get("content-type") {
if let Ok(content_type_str) = content_type.to_str() {
let content_type_lower = content_type_str.to_lowercase();
return self
.config
.content_types
.iter()
.any(|ct| content_type_lower.starts_with(&ct.to_lowercase()));
}
}
true
}
fn generate_etag(&self, body: &[u8], headers: &HeaderMap) -> Option<ETagType> {
match &self.config.strategy {
ETagStrategy::BodyHash => {
let mut hasher = DefaultHasher::new();
body.hash(&mut hasher);
for (name, value) in headers.iter() {
name.as_str().hash(&mut hasher);
if let Ok(value_str) = value.to_str() {
value_str.hash(&mut hasher);
}
}
let hash = hasher.finish();
Some(ETagType::Strong(format!("{:x}", hash)))
}
ETagStrategy::WeakBodyHash => {
let mut hasher = DefaultHasher::new();
body.hash(&mut hasher);
let hash = hasher.finish();
Some(ETagType::Weak(format!("{:x}", hash)))
}
ETagStrategy::Custom(func) => func(body, headers),
}
}
fn parse_if_none_match(&self, header_value: &str) -> Vec<ETagType> {
let mut etags = Vec::new();
if header_value.trim() == "*" {
return etags; }
for etag_str in header_value.split(',') {
if let Some(etag) = ETagType::from_header_value(etag_str) {
etags.push(etag);
}
}
etags
}
fn parse_if_match(&self, header_value: &str) -> Vec<ETagType> {
let mut etags = Vec::new();
if header_value.trim() == "*" {
return etags; }
for etag_str in header_value.split(',') {
if let Some(etag) = ETagType::from_header_value(etag_str) {
etags.push(etag);
}
}
etags
}
fn check_if_none_match(&self, request_etags: &[ETagType], response_etag: &ETagType) -> bool {
if request_etags.is_empty() {
return true; }
!request_etags
.iter()
.any(|req_etag| response_etag.matches_for_if_none_match(req_etag))
}
fn check_if_match(&self, request_etags: &[ETagType], response_etag: &ETagType) -> bool {
if request_etags.is_empty() {
return true; }
request_etags
.iter()
.any(|req_etag| response_etag.matches_for_if_match(req_etag))
}
async fn process_response_with_headers(
&self,
response: ElifResponse,
if_none_match: Option<ElifHeaderValue>,
if_match: Option<ElifHeaderValue>,
request_method: ElifMethod,
) -> ElifResponse {
let axum_if_none_match = if_none_match.as_ref().map(|v| v.to_axum());
let axum_if_match = if_match.as_ref().map(|v| v.to_axum());
let axum_method = request_method.to_axum();
let axum_response = response.into_axum_response();
let (parts, body) = axum_response.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(_) => {
let response =
axum::response::Response::from_parts(parts, axum::body::Body::empty());
return ElifResponse::from_axum_response(response).await;
}
};
if !self.should_generate_etag(&parts.headers, body_bytes.len()) {
let response =
axum::response::Response::from_parts(parts, axum::body::Body::from(body_bytes));
return ElifResponse::from_axum_response(response).await;
}
let etag = match self.generate_etag(&body_bytes, &parts.headers) {
Some(etag) => etag,
None => {
let response =
axum::response::Response::from_parts(parts, axum::body::Body::from(body_bytes));
return ElifResponse::from_axum_response(response).await;
}
};
if let Some(if_none_match) = axum_if_none_match {
if let Ok(if_none_match_str) = if_none_match.to_str() {
let request_etags = self.parse_if_none_match(if_none_match_str);
if if_none_match_str.trim() == "*" {
return if axum_method == axum::http::Method::GET
|| axum_method == axum::http::Method::HEAD
{
ElifResponse::from_axum_response(
axum::response::Response::builder()
.status(axum::http::StatusCode::NOT_MODIFIED)
.header("etag", etag.to_header_value())
.body(axum::body::Body::empty())
.unwrap(),
)
.await
} else {
ElifResponse::from_axum_response(
axum::response::Response::builder()
.status(axum::http::StatusCode::PRECONDITION_FAILED)
.header("etag", etag.to_header_value())
.body(axum::body::Body::from(
serde_json::to_vec(&serde_json::json!({
"error": {
"code": "precondition_failed",
"message": "If-None-Match: * failed - resource exists"
}
}))
.unwrap_or_default(),
))
.unwrap(),
)
.await
};
}
if !self.check_if_none_match(&request_etags, &etag) {
return if axum_method == axum::http::Method::GET
|| axum_method == axum::http::Method::HEAD
{
ElifResponse::from_axum_response(
axum::response::Response::builder()
.status(axum::http::StatusCode::NOT_MODIFIED)
.header("etag", etag.to_header_value())
.body(axum::body::Body::empty())
.unwrap(),
)
.await
} else {
ElifResponse::from_axum_response(
axum::response::Response::builder()
.status(axum::http::StatusCode::PRECONDITION_FAILED)
.header("etag", etag.to_header_value())
.body(axum::body::Body::from(
serde_json::to_vec(&serde_json::json!({
"error": {
"code": "precondition_failed",
"message": "If-None-Match precondition failed - resource unchanged"
}
})).unwrap_or_default()
))
.unwrap()
).await
};
}
}
}
if let Some(if_match) = axum_if_match {
if let Ok(if_match_str) = if_match.to_str() {
let request_etags = self.parse_if_match(if_match_str);
if if_match_str.trim() == "*" {
} else if !self.check_if_match(&request_etags, &etag) {
return ElifResponse::from_axum_response(
axum::response::Response::builder()
.status(axum::http::StatusCode::PRECONDITION_FAILED)
.header("etag", etag.to_header_value())
.body(axum::body::Body::from(
serde_json::to_vec(&serde_json::json!({
"error": {
"code": "precondition_failed",
"message": "Request ETag does not match current resource ETag"
}
})).unwrap_or_default()
))
.unwrap()
).await;
}
}
}
let mut new_parts = parts;
new_parts.headers.insert(
HeaderName::from_static("etag"),
HeaderValue::from_str(&etag.to_header_value()).unwrap(),
);
if !new_parts.headers.contains_key("cache-control") {
new_parts.headers.insert(
HeaderName::from_static("cache-control"),
HeaderValue::from_static("private, max-age=0"),
);
}
let response =
axum::response::Response::from_parts(new_parts, axum::body::Body::from(body_bytes));
ElifResponse::from_axum_response(response).await
}
}
impl Default for ETagMiddleware {
fn default() -> Self {
Self::new()
}
}
impl Middleware for ETagMiddleware {
    /// Snapshots the request's method and conditional headers, and runs the
    /// rest of the chain. The resulting response is then handed to the ETag
    /// post-processor.
    fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
        // The returned future must be `'static`, so it owns its own copy of the
        // configuration instead of borrowing `self`.
        let owned = ETagMiddleware {
            config: self.config.clone(),
        };
        Box::pin(async move {
            let method = request.method.clone();
            let if_match = request.header("if-match").cloned();
            let if_none_match = request.header("if-none-match").cloned();
            let downstream = next.run(request).await;
            owned
                .process_response_with_headers(downstream, if_none_match, if_match, method)
                .await
        })
    }

    /// Identifier reported for this middleware.
    fn name(&self) -> &'static str {
        "ETagMiddleware"
    }
}
// Unit tests for tag parsing, matching and configuration, plus end-to-end tests
// that run requests through `ETagMiddleware::handle` with a stub `Next`.
#[cfg(test)]
mod tests {
use super::*;
use crate::request::ElifRequest;
use crate::response::ElifResponse;
/// Strong and weak tags round-trip through parse/format; unquoted and
/// unterminated values are rejected.
#[test]
fn test_etag_parsing() {
let etag = ETagType::from_header_value("\"abc123\"").unwrap();
assert_eq!(etag, ETagType::Strong("abc123".to_string()));
assert_eq!(etag.to_header_value(), "\"abc123\"");
let etag = ETagType::from_header_value("W/\"abc123\"").unwrap();
assert_eq!(etag, ETagType::Weak("abc123".to_string()));
assert_eq!(etag.to_header_value(), "W/\"abc123\"");
assert!(ETagType::from_header_value("invalid").is_none());
assert!(ETagType::from_header_value("\"unclosed").is_none());
}
/// If-None-Match matching ignores strength (strong equals weak with the same
/// value); If-Match matching requires both tags to be strong.
#[test]
fn test_etag_matching() {
let strong1 = ETagType::Strong("abc123".to_string());
let strong2 = ETagType::Strong("abc123".to_string());
let strong3 = ETagType::Strong("def456".to_string());
let weak1 = ETagType::Weak("abc123".to_string());
assert!(strong1.matches_for_if_none_match(&strong2));
assert!(strong1.matches_for_if_none_match(&weak1));
assert!(!strong1.matches_for_if_none_match(&strong3));
assert!(strong1.matches_for_if_match(&strong2));
assert!(!strong1.matches_for_if_match(&weak1));
assert!(!strong1.matches_for_if_match(&strong3));
}
/// Default size limits, and JSON among the default eligible content types.
#[test]
fn test_etag_config() {
let config = ETagConfig::default();
assert_eq!(config.min_size, 0);
assert_eq!(config.max_size, 10 * 1024 * 1024);
assert!(config
.content_types
.contains(&"application/json".to_string()));
}
/// A JSON body is eligible. An existing `etag` header, or a content type not
/// in the list (image/jpeg), suppresses generation.
#[test]
fn test_should_generate_etag() {
let middleware = ETagMiddleware::new();
let mut headers = HeaderMap::new();
headers.insert("content-type", "application/json".parse().unwrap());
assert!(middleware.should_generate_etag(&headers, 1024));
headers.insert("etag", "\"existing\"".parse().unwrap());
assert!(!middleware.should_generate_etag(&headers, 1024));
let mut headers = HeaderMap::new();
headers.insert("content-type", "image/jpeg".parse().unwrap());
assert!(!middleware.should_generate_etag(&headers, 1024));
}
/// A plain GET keeps its 200 status and gains an `etag` header.
#[tokio::test]
async fn test_etag_generation() {
let middleware = ETagMiddleware::new();
let request = ElifRequest::new(
crate::request::ElifMethod::GET,
"/api/data".parse().unwrap(),
crate::response::headers::ElifHeaderMap::new(),
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Hello, World!"
}))
})
});
let response = middleware.handle(request, next).await;
assert_eq!(
response.status_code(),
crate::response::status::ElifStatusCode::OK
);
let axum_response = response.into_axum_response();
let (parts, _) = axum_response.into_parts();
assert!(parts.headers.contains_key("etag"));
}
/// Replaying the ETag from a first GET as `If-None-Match` on a second GET
/// yields 304.
#[tokio::test]
async fn test_if_none_match_304() {
let middleware = ETagMiddleware::new();
// First request: capture the ETag the middleware generates.
let request = ElifRequest::new(
crate::request::ElifMethod::GET,
"/api/data".parse().unwrap(),
crate::response::headers::ElifHeaderMap::new(),
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Hello, World!"
}))
})
});
let response = middleware.handle(request, next).await;
let axum_response = response.into_axum_response();
let (parts, _) = axum_response.into_parts();
let etag_header = parts.headers.get("etag").unwrap();
let etag_value = etag_header.to_str().unwrap();
// Second request: send the captured tag back in If-None-Match.
let mut headers = crate::response::headers::ElifHeaderMap::new();
let header_name =
crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
let header_value = crate::response::headers::ElifHeaderValue::from_str(etag_value).unwrap();
headers.insert(header_name, header_value);
let request = ElifRequest::new(
crate::request::ElifMethod::GET,
"/api/data".parse().unwrap(),
headers,
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Hello, World!"
}))
})
});
let response = middleware.handle(request, next).await;
assert_eq!(
response.status_code(),
crate::response::status::ElifStatusCode::NOT_MODIFIED
);
}
/// A PUT whose `If-Match` tag does not match the response ETag yields 412.
#[tokio::test]
async fn test_if_match_412() {
let middleware = ETagMiddleware::new();
let mut headers = crate::response::headers::ElifHeaderMap::new();
let header_name = crate::response::headers::ElifHeaderName::from_str("if-match").unwrap();
let header_value =
crate::response::headers::ElifHeaderValue::from_str("\"non-matching-etag\"").unwrap();
headers.insert(header_name, header_value);
let request = ElifRequest::new(
crate::request::ElifMethod::PUT,
"/api/data".parse().unwrap(),
headers,
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Updated!"
}))
})
});
let response = middleware.handle(request, next).await;
assert_eq!(
response.status_code(),
crate::response::status::ElifStatusCode::PRECONDITION_FAILED
);
}
/// `If-None-Match: *` on a PUT yields 412, because the resource exists.
#[tokio::test]
async fn test_if_none_match_star_put_request() {
let middleware = ETagMiddleware::new();
let mut headers = crate::response::headers::ElifHeaderMap::new();
let header_name =
crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
let header_value = crate::response::headers::ElifHeaderValue::from_str("*").unwrap();
headers.insert(header_name, header_value);
let request = ElifRequest::new(
crate::request::ElifMethod::PUT, "/api/data".parse().unwrap(),
headers,
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Created!"
}))
})
});
let response = middleware.handle(request, next).await;
assert_eq!(
response.status_code(),
crate::response::status::ElifStatusCode::PRECONDITION_FAILED
);
}
/// `If-None-Match: *` on a GET yields 304.
#[tokio::test]
async fn test_if_none_match_star_get_request() {
let middleware = ETagMiddleware::new();
let mut headers = crate::response::headers::ElifHeaderMap::new();
let header_name =
crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
let header_value = crate::response::headers::ElifHeaderValue::from_str("*").unwrap();
headers.insert(header_name, header_value);
let request = ElifRequest::new(
crate::request::ElifMethod::GET, "/api/data".parse().unwrap(),
headers,
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Data"
}))
})
});
let response = middleware.handle(request, next).await;
assert_eq!(
response.status_code(),
crate::response::status::ElifStatusCode::NOT_MODIFIED
);
}
/// A matching `If-None-Match` tag on a PUT yields 412 (not 304). The tag is
/// captured from a prior GET.
#[tokio::test]
async fn test_if_none_match_etag_put_request() {
let middleware = ETagMiddleware::new();
// First request (GET): capture the generated ETag.
let request = ElifRequest::new(
crate::request::ElifMethod::GET,
"/api/data".parse().unwrap(),
crate::response::headers::ElifHeaderMap::new(),
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Data"
}))
})
});
let response = middleware.handle(request, next).await;
let axum_response = response.into_axum_response();
let (parts, _) = axum_response.into_parts();
let etag_header = parts.headers.get("etag").unwrap();
let etag_value = etag_header.to_str().unwrap();
// Second request (PUT): send the same tag in If-None-Match.
let mut headers = crate::response::headers::ElifHeaderMap::new();
let header_name =
crate::response::headers::ElifHeaderName::from_str("if-none-match").unwrap();
let header_value = crate::response::headers::ElifHeaderValue::from_str(etag_value).unwrap();
headers.insert(header_name, header_value);
let request = ElifRequest::new(
crate::request::ElifMethod::PUT,
"/api/data".parse().unwrap(),
headers,
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Data"
}))
})
});
let response = middleware.handle(request, next).await;
assert_eq!(
response.status_code(),
crate::response::status::ElifStatusCode::PRECONDITION_FAILED
);
}
/// With `.weak()`, the emitted ETag carries the `W/` prefix.
#[tokio::test]
async fn test_weak_etag_strategy() {
let middleware = ETagMiddleware::new().weak();
let request = ElifRequest::new(
crate::request::ElifMethod::GET,
"/api/data".parse().unwrap(),
crate::response::headers::ElifHeaderMap::new(),
);
let next = Next::new(|_req| {
Box::pin(async move {
ElifResponse::ok().json_value(serde_json::json!({
"message": "Hello, World!"
}))
})
});
let response = middleware.handle(request, next).await;
let axum_response = response.into_axum_response();
let (parts, _) = axum_response.into_parts();
let etag_header = parts.headers.get("etag").unwrap();
let etag_value = etag_header.to_str().unwrap();
assert!(etag_value.starts_with("W/"));
}
/// Builder methods are reflected in the stored configuration.
#[test]
fn test_etag_middleware_builder() {
let middleware = ETagMiddleware::new()
.min_size(1024)
.max_size(5 * 1024 * 1024)
.content_type("application/xml")
.weak();
assert_eq!(middleware.config.min_size, 1024);
assert_eq!(middleware.config.max_size, 5 * 1024 * 1024);
assert!(middleware
.config
.content_types
.contains(&"application/xml".to_string()));
assert!(matches!(
middleware.config.strategy,
ETagStrategy::WeakBodyHash
));
}
}