use async_trait::async_trait;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use hyper::header::{
ETAG, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_UNMODIFIED_SINCE, LAST_MODIFIED,
};
use hyper::{Method, StatusCode};
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use sha2::{Digest, Sha256};
use std::sync::Arc;
/// Middleware that adds HTTP conditional-request support to `GET`/`HEAD` responses.
///
/// It compares the request's `If-None-Match`, `If-Modified-Since`, `If-Match` and
/// `If-Unmodified-Since` headers against the response's `ETag` / `Last-Modified`
/// validators and may answer with `304 Not Modified` or `412 Precondition Failed`.
#[derive(Debug, Clone)]
pub struct ConditionalGetMiddleware {
    /// When `true`, an `ETag` is derived from the response body if the handler
    /// did not set one itself.
    generate_etag: bool,
}
impl ConditionalGetMiddleware {
    /// Creates a middleware that generates an `ETag` from the response body
    /// whenever the handler did not provide one.
    pub fn new() -> Self {
        Self {
            generate_etag: true,
        }
    }

    /// Creates a middleware that never generates an `ETag` on its own and only
    /// works with validators already present on the response.
    pub fn without_etag() -> Self {
        Self {
            generate_etag: false,
        }
    }

    /// Builds a quoted entity tag from `body`: the hex encoding of the first
    /// 16 bytes of its SHA-256 digest, wrapped in double quotes.
    fn generate_etag_from_body(&self, body: &[u8]) -> String {
        let digest = Sha256::digest(body);
        format!("\"{}\"", hex::encode(&digest[..16]))
    }

    /// Splits a comma-separated entity-tag list (as found in `If-None-Match` or
    /// `If-Match`) into its trimmed members.
    ///
    /// Empty members (produced by stray or trailing commas) are dropped so they
    /// can never be compared against a real entity tag.
    fn parse_if_none_match(&self, value: &str) -> Vec<String> {
        value
            .split(',')
            .map(str::trim)
            .filter(|tag| !tag.is_empty())
            .map(str::to_string)
            .collect()
    }

    /// Returns `true` when `etag` matches any member of `if_none_match`.
    ///
    /// A `*` member matches every entity tag. Otherwise the *weak comparison*
    /// that RFC 9110 prescribes for `If-None-Match` is used: an optional `W/`
    /// prefix is ignored on both sides and only the opaque tags are compared.
    /// Surrounding double quotes are ignored as well, so unquoted tags sent by
    /// lenient clients still match.
    ///
    /// Because weakness is ignored here, a caller that needs *strong*
    /// comparison must reject `W/`-prefixed tags itself.
    fn etag_matches(&self, etag: &str, if_none_match: &[String]) -> bool {
        /// Strips the weakness indicator and the surrounding quotes.
        fn opaque(tag: &str) -> &str {
            tag.strip_prefix("W/").unwrap_or(tag).trim_matches('"')
        }

        let target = opaque(etag);
        if_none_match
            .iter()
            .any(|candidate| candidate == "*" || opaque(candidate) == target)
    }

    /// Parses an HTTP date header value into a UTC timestamp, returning `None`
    /// when `httpdate` rejects the value.
    fn parse_http_date(&self, value: &str) -> Option<DateTime<Utc>> {
        httpdate::parse_http_date(value).ok().map(DateTime::from)
    }
}
impl Default for ConditionalGetMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for ConditionalGetMiddleware {
async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
let if_none_match = request.headers.get(IF_NONE_MATCH).cloned();
let if_modified_since = request.headers.get(IF_MODIFIED_SINCE).cloned();
let if_match = request.headers.get(IF_MATCH).cloned();
let if_unmodified_since = request.headers.get(IF_UNMODIFIED_SINCE).cloned();
let method = request.method.clone();
let mut response = match handler.handle(request).await {
Ok(resp) => resp,
Err(e) => Response::from(e),
};
if method != Method::GET && method != Method::HEAD {
return Ok(response);
}
if !response.status.is_success() {
return Ok(response);
}
let etag = if self.generate_etag && !response.headers.contains_key(ETAG) {
let generated = self.generate_etag_from_body(&response.body);
if let Ok(etag_value) = generated.parse() {
response.headers.insert(ETAG, etag_value);
Some(generated)
} else {
None
}
} else {
response
.headers
.get(ETAG)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
};
let last_modified = response
.headers
.get(LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.and_then(|s| self.parse_http_date(s));
if let Some(if_none_match) = if_none_match
&& let (Ok(inm_str), Some(etag_value)) = (if_none_match.to_str(), etag.as_ref())
{
let inm_list = self.parse_if_none_match(inm_str);
if self.etag_matches(etag_value, &inm_list) {
let mut not_modified = Response::new(StatusCode::NOT_MODIFIED);
if let Some(etag_header) = response.headers.get(ETAG) {
not_modified.headers.insert(ETAG, etag_header.clone());
}
if let Some(lm_header) = response.headers.get(LAST_MODIFIED) {
not_modified
.headers
.insert(LAST_MODIFIED, lm_header.clone());
}
return Ok(not_modified);
}
}
if let Some(if_modified_since) = if_modified_since
&& let (Ok(ims_str), Some(lm)) = (if_modified_since.to_str(), last_modified)
&& let Some(ims) = self.parse_http_date(ims_str)
{
if lm <= ims {
let mut not_modified = Response::new(StatusCode::NOT_MODIFIED);
if let Some(etag_header) = response.headers.get(ETAG) {
not_modified.headers.insert(ETAG, etag_header.clone());
}
if let Some(lm_header) = response.headers.get(LAST_MODIFIED) {
not_modified
.headers
.insert(LAST_MODIFIED, lm_header.clone());
}
return Ok(not_modified);
}
}
if let Some(if_match) = if_match
&& let (Ok(im_str), Some(etag_value)) = (if_match.to_str(), etag.as_ref())
{
let im_list = self.parse_if_none_match(im_str);
if !self.etag_matches(etag_value, &im_list) && !im_list.contains(&"*".to_string()) {
return Ok(Response::new(StatusCode::PRECONDITION_FAILED)
.with_body(Bytes::from(&b"Precondition Failed"[..])));
}
}
if let Some(if_unmodified_since) = if_unmodified_since
&& let (Ok(ius_str), Some(lm)) = (if_unmodified_since.to_str(), last_modified)
&& let Some(ius) = self.parse_http_date(ius_str)
{
if lm > ius {
return Ok(Response::new(StatusCode::PRECONDITION_FAILED)
.with_body(Bytes::from(&b"Precondition Failed"[..])));
}
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::{HeaderMap, Version};

    /// Handler stub that answers every request with `200 OK`, a fixed body and
    /// the optionally configured validators.
    struct TestHandler {
        body: &'static str,
        with_etag: Option<String>,
        with_last_modified: Option<DateTime<Utc>>,
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            let mut response = Response::new(StatusCode::OK).with_body(self.body.as_bytes());
            if let Some(etag) = &self.with_etag {
                response.headers.insert(ETAG, etag.parse().unwrap());
            }
            if let Some(modified_at) = self.with_last_modified {
                let formatted = httpdate::fmt_http_date(modified_at.into());
                response
                    .headers
                    .insert(LAST_MODIFIED, formatted.parse().unwrap());
            }
            Ok(response)
        }
    }

    /// Creates a stub handler with the standard test body and the given validators.
    fn stub_handler(
        etag: Option<&str>,
        last_modified: Option<DateTime<Utc>>,
    ) -> Arc<TestHandler> {
        Arc::new(TestHandler {
            body: "test response",
            with_etag: etag.map(|value| value.to_string()),
            with_last_modified: last_modified,
        })
    }

    /// Builds an HTTP/1.1 request for `/test` with an empty body.
    fn build_request(method: Method, headers: HeaderMap) -> Request {
        Request::builder()
            .method(method)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_generates_etag() {
        let middleware = ConditionalGetMiddleware::new();
        let handler = stub_handler(None, None);
        let request = build_request(Method::GET, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert!(response.headers.contains_key(ETAG));
    }

    #[tokio::test]
    async fn test_if_none_match_returns_304() {
        let middleware = ConditionalGetMiddleware::new();
        let etag = "\"abc123\"";
        let handler = stub_handler(Some(etag), None);
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, etag.parse().unwrap());
        let request = build_request(Method::GET, headers);

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::NOT_MODIFIED);
        assert!(response.body.is_empty());
    }

    #[tokio::test]
    async fn test_if_modified_since_returns_304() {
        let middleware = ConditionalGetMiddleware::new();
        let last_modified = Utc::now() - chrono::Duration::days(1);
        let handler = stub_handler(None, Some(last_modified));
        let an_hour_later = last_modified + chrono::Duration::hours(1);
        let mut headers = HeaderMap::new();
        headers.insert(
            IF_MODIFIED_SINCE,
            httpdate::fmt_http_date(an_hour_later.into())
                .parse()
                .unwrap(),
        );
        let request = build_request(Method::GET, headers);

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::NOT_MODIFIED);
    }

    #[tokio::test]
    async fn test_if_match_fails_returns_412() {
        let middleware = ConditionalGetMiddleware::new();
        let handler = stub_handler(Some("\"abc123\""), None);
        let mut headers = HeaderMap::new();
        headers.insert(IF_MATCH, "\"xyz789\"".parse().unwrap());
        let request = build_request(Method::GET, headers);

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::PRECONDITION_FAILED);
    }

    #[tokio::test]
    async fn test_middleware_wont_overwrite_etag() {
        let middleware = ConditionalGetMiddleware::new();
        let custom_etag = "\"custom-etag\"";
        let handler = stub_handler(Some(custom_etag), None);
        let request = build_request(Method::GET, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        let returned_etag = response.headers.get(ETAG).unwrap().to_str().unwrap();
        assert_eq!(returned_etag, custom_etag);
    }

    #[tokio::test]
    async fn test_if_none_match_and_different_etag() {
        let middleware = ConditionalGetMiddleware::new();
        let handler = stub_handler(Some("\"abc123\""), None);
        let mut headers = HeaderMap::new();
        headers.insert(IF_NONE_MATCH, "\"different-etag\"".parse().unwrap());
        let request = build_request(Method::GET, headers);

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_if_modified_since_and_last_modified_in_the_future() {
        let middleware = ConditionalGetMiddleware::new();
        let last_modified = Utc::now();
        let handler = stub_handler(None, Some(last_modified));
        let an_hour_earlier = last_modified - chrono::Duration::hours(1);
        let mut headers = HeaderMap::new();
        headers.insert(
            IF_MODIFIED_SINCE,
            httpdate::fmt_http_date(an_hour_earlier.into())
                .parse()
                .unwrap(),
        );
        let request = build_request(Method::GET, headers);

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_no_etag_on_post_request() {
        let middleware = ConditionalGetMiddleware::new();
        let handler = stub_handler(None, None);
        let request = build_request(Method::POST, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert!(!response.headers.contains_key(ETAG));
    }

    #[tokio::test]
    async fn test_without_etag_generation() {
        let middleware = ConditionalGetMiddleware::without_etag();
        let handler = stub_handler(None, None);
        let request = build_request(Method::GET, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert!(!response.headers.contains_key(ETAG));
    }
}