#![deny(trivial_casts, trivial_numeric_casts, unused_extern_crates, unused_qualifications)]
#![warn(
missing_debug_implementations,
missing_docs,
unused_import_braces,
dead_code,
clippy::unwrap_used,
clippy::expect_used,
clippy::missing_docs_in_private_items
)]
use std::{sync::Arc, time::SystemTime};
use bytes::Bytes;
use chashmap_async::CHashMap;
pub use http_cache_semantics::CacheOptions;
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy, RequestLike};
use reqwest::Url;
use reqwest_middleware::Middleware;
/// A buffered response body stored together with the policy that governs its reuse.
#[derive(Debug)]
struct CacheEntry {
    /// Caching policy for `response`; replaced when a revalidation yields a new policy.
    policy: CachePolicy,
    /// Fully-buffered body of the cached response.
    response: Bytes,
}

impl CacheEntry {
    /// Bundles `policy` with the response body it applies to.
    pub fn new(policy: CachePolicy, response: Bytes) -> Self {
        Self { policy, response }
    }
}
/// A [`Middleware`] that keeps responses in an in-memory cache and decides
/// reuse and revalidation with [`CachePolicy`] from `http_cache_semantics`.
///
/// Entries are keyed by the request URL with its fragment removed.
#[derive(Default)]
pub struct CacheMiddleware {
    /// Cached entries keyed by request URL (fragment stripped).
    cache: Arc<CHashMap<Url, CacheEntry>>,
    /// Options passed to [`CachePolicy::new_options`] when a fresh response is evaluated.
    options: CacheOptions,
}

impl CacheMiddleware {
    /// Creates a middleware with an empty cache and the default [`CacheOptions`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a middleware with an empty cache that evaluates responses using `options`.
    #[must_use]
    pub fn with_options(options: CacheOptions) -> Self {
        Self { cache: Arc::new(CHashMap::new()), options }
    }
}
impl std::fmt::Debug for CacheMiddleware {
    /// Formats the middleware, summarising the cache by its entry count rather
    /// than dumping every cached response.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CacheMiddleware")
            // `format_args!` renders `<N entries>` verbatim; a `String` would be
            // printed quoted by its `Debug` impl and would need an allocation.
            .field("cache", &format_args!("<{} entries>", self.cache.len()))
            .field("options", &self.options)
            .finish()
    }
}
#[async_trait::async_trait]
impl Middleware for CacheMiddleware {
    /// Answers `req` from the in-memory cache when the stored policy allows it,
    /// otherwise forwards it down the middleware chain.
    ///
    /// * Fresh entry: returned directly; no request is sent.
    /// * Stale entry: the request is re-sent with the headers chosen by
    ///   [`CachePolicy::before_request`]. A not-modified answer reuses the cached
    ///   body. A modified answer is returned to the caller, and it replaces the
    ///   cached entry only when it matches the entry and is storable.
    /// * No entry: the request is forwarded, and the buffered response is stored
    ///   when [`CachePolicy::is_storable`] allows it.
    ///
    /// # Errors
    ///
    /// Propagates errors from the downstream chain and from reading response bodies.
    async fn handle(
        &self,
        mut req: reqwest::Request,
        extensions: &mut task_local_extensions::Extensions,
        next: reqwest_middleware::Next<'_>,
    ) -> reqwest_middleware::Result<reqwest::Response> {
        // Fragments are never sent to the server, so they must not split the cache.
        let mut url = req.url().clone();
        url.set_fragment(None);

        // NOTE(review): the write guard below is held across `next.run(..).await`,
        // which presumably blocks other tasks needing the same bucket until the
        // revalidation completes.
        if let Some(mut cached) = self.cache.get_mut(&url).await {
            let before = cached.policy.before_request(&req, SystemTime::now());
            match before {
                BeforeRequest::Fresh(parts) => {
                    let response = http::Response::from_parts(parts, cached.response.clone());
                    return Ok(response.into());
                }
                BeforeRequest::Stale { request: parts, matches } => {
                    // Send the revalidation headers (e.g. validators) chosen by the policy.
                    *req.headers_mut() = parts.headers.clone();
                    let response = next.run(req, extensions).await?;
                    let after = cached.policy.after_response(&parts, &response, SystemTime::now());
                    match after {
                        AfterResponse::NotModified(policy, parts) => {
                            if matches {
                                cached.policy = policy;
                            }
                            let response =
                                http::Response::from_parts(parts, cached.response.clone());
                            return Ok(response.into());
                        }
                        AfterResponse::Modified(policy, parts) => {
                            let body = response.bytes().await?;
                            // Policy and body must always describe the same response:
                            // only replace the entry when the new response matches it
                            // and may be stored. Otherwise a mismatched response (e.g. a
                            // different method) would overwrite the cached body.
                            if matches && policy.is_storable() {
                                cached.policy = policy;
                                cached.response = body.clone();
                            }
                            return Ok(http::Response::from_parts(parts, body).into());
                        }
                    }
                }
            }
        }

        // `next.run` consumes `req`, so capture the request metadata the policy
        // needs first. Starting from `Request::new(())` yields a `Parts` value
        // without going through the fallible builder.
        let (mut parts, ()) = http::Request::new(()).into_parts();
        parts.method = req.method().clone();
        parts.uri = req.uri();
        parts.version = req.version();
        parts.headers = req.headers().clone();

        let response = next.run(req, extensions).await?;
        let policy = CachePolicy::new_options(&parts, &response, SystemTime::now(), self.options);
        if !policy.is_storable() {
            return Ok(response);
        }

        let response = reqwest_to_http(response).await?;
        let fresh = CacheEntry::new(policy, response.body().clone());
        self.cache
            .alter(url, |existing| async move {
                // An entry may have been stored meanwhile (e.g. by a concurrent
                // request); keep whichever of the two is younger.
                match existing {
                    None => Some(fresh),
                    Some(existing) => {
                        let now = SystemTime::now();
                        if existing.policy.age(now) > fresh.policy.age(now) {
                            Some(fresh)
                        } else {
                            Some(existing)
                        }
                    }
                }
            })
            .await;
        Ok(response.into())
    }
}
/// Buffers a [`reqwest::Response`] into an [`http::Response`] whose body is the
/// fully-read response bytes.
///
/// Status, HTTP version and headers are carried over; the headers are moved out
/// of `response` rather than cloned.
/// NOTE(review): the response URL and extensions are not transferred, so the
/// `reqwest::Response` later rebuilt from the result will not report the
/// original URL — confirm whether callers rely on `Response::url()`.
///
/// # Errors
///
/// Returns the error from [`reqwest::Response::bytes`] if reading the body fails.
async fn reqwest_to_http(
    mut response: reqwest::Response,
) -> reqwest::Result<http::Response<Bytes>> {
    let status = response.status();
    let version = response.version();
    let headers = std::mem::take(response.headers_mut());
    let body = response.bytes().await?;

    let mut converted = http::Response::new(body);
    *converted.status_mut() = status;
    *converted.version_mut() = version;
    *converted.headers_mut() = headers;
    Ok(converted)
}