#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
nonstandard_style,
unused_qualifications,
unused_import_braces,
unused_extern_crates,
trivial_casts,
trivial_numeric_casts
)]
#![allow(clippy::doc_lazy_continuation)]
#![cfg_attr(docsrs, feature(doc_cfg))]
pub use http_cache::{BadRequest, HttpCacheError};
/// Error type surfaced by the streaming middleware; an alias for
/// `http_cache::ClientStreamingError`.
#[cfg(feature = "streaming")]
pub type ReqwestStreamingError = http_cache::ClientStreamingError;
#[cfg(feature = "streaming")]
use http_cache::StreamingCacheManager;
use std::{convert::TryInto, str::FromStr, time::SystemTime};
pub use http::request::Parts;
use http::{
header::{HeaderName, CACHE_CONTROL},
Extensions, HeaderValue, Method,
};
use http_cache::{
url_parse, BoxError, HitOrMiss, Middleware, Result, Url, XCACHE,
XCACHELOOKUP,
};
use http_cache_semantics::CachePolicy;
use reqwest::{Request, Response, ResponseBuilderExt};
use reqwest_middleware::{Error, Next};
fn to_middleware_error<E: std::error::Error + Send + Sync + 'static>(
error: E,
) -> Error {
Error::Middleware(anyhow::Error::new(error))
}
pub use http_cache::{
CacheManager, CacheMode, CacheOptions, HttpCache, HttpCacheMetadata,
HttpCacheOptions, HttpResponse, MetadataProvider, ResponseCacheModeFn,
};
#[cfg(feature = "streaming")]
pub use http_cache::{
HttpCacheStreamInterface, HttpStreamingCache, StreamingBody,
StreamingManager,
};
#[cfg(feature = "manager-cacache")]
#[cfg_attr(docsrs, doc(cfg(feature = "manager-cacache")))]
pub use http_cache::CACacheManager;
#[cfg(feature = "manager-moka")]
#[cfg_attr(docsrs, doc(cfg(feature = "manager-moka")))]
pub use http_cache::{MokaCache, MokaCacheBuilder, MokaManager};
#[cfg(feature = "rate-limiting")]
#[cfg_attr(docsrs, doc(cfg(feature = "rate-limiting")))]
pub use http_cache::rate_limiting::{
CacheAwareRateLimiter, DirectRateLimiter, DomainRateLimiter, Quota,
};
/// Caching middleware for `reqwest_middleware`.
///
/// Wraps an `HttpCache`. Fetched response bodies are read fully into memory
/// (`Vec<u8>`) before being handed to the cache.
#[derive(Debug)]
pub struct Cache<T: CacheManager>(pub HttpCache<T>);
/// Caching middleware for `reqwest_middleware` whose cached bodies are
/// produced by a `StreamingCacheManager` and returned to the caller as
/// streams.
#[cfg(feature = "streaming")]
#[derive(Debug, Clone)]
pub struct StreamingCache<T: StreamingCacheManager> {
    /// Underlying streaming cache (mode, manager and options).
    cache: HttpStreamingCache<T>,
}
#[cfg(feature = "streaming")]
impl<T: StreamingCacheManager> StreamingCache<T> {
    /// Creates a streaming cache middleware from `manager` and `mode`,
    /// using `HttpCacheOptions::default()`.
    pub fn new(manager: T, mode: CacheMode) -> Self {
        Self::with_options(manager, mode, HttpCacheOptions::default())
    }

    /// Creates a streaming cache middleware from `manager` and `mode`
    /// with explicitly supplied `options`.
    pub fn with_options(
        manager: T,
        mode: CacheMode,
        options: HttpCacheOptions,
    ) -> Self {
        let cache = HttpStreamingCache { mode, manager, options };
        Self { cache }
    }
}
/// Adapter that exposes a reqwest request, the remaining middleware chain
/// and the request extensions through `http_cache`'s `Middleware` trait.
pub(crate) struct ReqwestMiddleware<'a> {
    /// The outgoing request; header updates from the cache are applied to it
    /// in place.
    pub req: Request,
    /// The rest of the reqwest middleware chain, used to send requests.
    pub next: Next<'a>,
    /// Per-request extensions; a `CacheMode` stored here overrides the
    /// cache-wide mode for this request.
    pub extensions: &'a mut Extensions,
}
/// Clones `request`, failing with a middleware-wrapped `BadRequest` when
/// reqwest cannot clone it (`try_clone` returns `None`, e.g. for a streamed
/// body).
fn clone_req(request: &Request) -> std::result::Result<Request, Error> {
    request
        .try_clone()
        .ok_or_else(|| to_middleware_error(BadRequest))
}
#[async_trait::async_trait]
impl Middleware for ReqwestMiddleware<'_> {
    /// Returns the `CacheMode` stored in the request extensions, if any,
    /// letting a single request override the cache's configured mode.
    fn overridden_cache_mode(&self) -> Option<CacheMode> {
        self.extensions.get().cloned()
    }

    /// Returns `true` when the request method is `GET` or `HEAD`.
    fn is_method_get_head(&self) -> bool {
        self.req.method() == Method::GET || self.req.method() == Method::HEAD
    }

    /// Builds a `CachePolicy` for this request and `response` with default
    /// options.
    fn policy(&self, response: &HttpResponse) -> Result<CachePolicy> {
        Ok(CachePolicy::new(&self.parts()?, &response.parts()?))
    }

    /// Builds a `CachePolicy` for this request and `response` with the given
    /// `options`, evaluated at the current system time.
    fn policy_with_options(
        &self,
        response: &HttpResponse,
        options: CacheOptions,
    ) -> Result<CachePolicy> {
        Ok(CachePolicy::new_options(
            &self.parts()?,
            &response.parts()?,
            SystemTime::now(),
            options,
        ))
    }

    /// Applies the headers in `parts` to the outgoing request.
    ///
    /// `parts` may already contain this request's own headers (for example
    /// the revalidation request head produced by
    /// `CachePolicy::before_request`). For that reason values are replaced
    /// per header name instead of appended, which would otherwise send every
    /// such header twice.
    ///
    /// All values of a multi-valued header in `parts` are kept. Headers not
    /// named in `parts` are left untouched.
    fn update_headers(&mut self, parts: &Parts) -> Result<()> {
        let headers = self.req.headers_mut();
        // `keys()` yields each distinct name once, so old values are cleared
        // before the new ones are appended.
        for name in parts.headers.keys() {
            headers.remove(name);
        }
        for (name, value) in parts.headers.iter() {
            headers.append(name.clone(), value.clone());
        }
        Ok(())
    }

    /// Sets `Cache-Control: no-cache` on the outgoing request, replacing any
    /// existing `Cache-Control` value.
    fn force_no_cache(&mut self) -> Result<()> {
        // A static, known-valid header value: no fallible parse needed.
        self.req
            .headers_mut()
            .insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
        Ok(())
    }

    /// Returns an `http::request::Parts` snapshot of the request: method,
    /// URL, version, headers and a clone of the request extensions.
    fn parts(&self) -> Result<Parts> {
        let mut builder = http::Request::builder()
            .method(self.req.method().as_str())
            .uri(self.req.url().as_str())
            .version(self.req.version());
        for (name, value) in self.req.headers() {
            builder = builder.header(name, value);
        }
        // `extensions_mut` is `None` only if the builder already holds an
        // error; in that case the error surfaces from `body()` below.
        if let Some(builder_extensions) = builder.extensions_mut() {
            *builder_extensions = self.extensions.clone();
        }
        let http_req = builder.body(()).map_err(Box::new)?;
        Ok(http_req.into_parts().0)
    }

    /// Returns the request URL parsed as an `http_cache::Url`.
    fn url(&self) -> Result<Url> {
        url_parse(self.req.url().as_str())
    }

    /// Returns the request method as an owned string (e.g. `"GET"`).
    fn method(&self) -> Result<String> {
        Ok(self.req.method().as_ref().to_string())
    }

    /// Sends a clone of the request down the middleware chain and reads the
    /// whole response body into an `HttpResponse` (no cache metadata).
    ///
    /// # Errors
    /// Fails if the request cannot be cloned (wrapped `BadRequest`), if the
    /// downstream chain or body read fails, or if the response URL or
    /// version cannot be converted.
    async fn remote_fetch(&mut self) -> Result<HttpResponse> {
        let copied_req = clone_req(&self.req)?;
        let res = self
            .next
            .clone()
            .run(copied_req, self.extensions)
            .await
            .map_err(BoxError::from)?;
        let headers = res.headers().into();
        let url = url_parse(res.url().as_str())?;
        let status = res.status().into();
        let version = res.version();
        let body: Vec<u8> = res.bytes().await.map_err(BoxError::from)?.to_vec();
        Ok(HttpResponse {
            body,
            headers,
            status,
            url,
            version: version.try_into()?,
            metadata: None,
        })
    }
}
/// Converts an `HttpResponse` (cached or freshly fetched) into a
/// `reqwest::Response`.
///
/// Restores status, URL, version, body and headers. If the response carries
/// cache metadata, it is attached as an `HttpCacheMetadata` extension.
///
/// # Errors
/// Fails if the URL does not parse, the response cannot be built, or a
/// header name/value is invalid.
fn convert_response(response: HttpResponse) -> Result<Response> {
    let metadata = response.metadata.clone();
    let parsed_url =
        ::url::Url::parse(response.url.as_str()).map_err(BoxError::from)?;
    let mut http_res = http::Response::builder()
        .status(response.status)
        .url(parsed_url)
        .version(response.version.into())
        .body(response.body)?;
    let header_map = http_res.headers_mut();
    for (raw_name, raw_value) in response.headers {
        let name = HeaderName::from_str(&raw_name)?;
        let value = HeaderValue::from_str(&raw_value)?;
        header_map.append(name, value);
    }
    if let Some(meta) = metadata {
        http_res.extensions_mut().insert(HttpCacheMetadata::from(meta));
    }
    Ok(Response::from(http_res))
}
/// Reads a `reqwest::Response` to completion and rebuilds it as an
/// `http::Response` with a fully buffered body.
///
/// Only status, version and headers are carried over.
#[cfg(feature = "streaming")]
async fn convert_reqwest_response_to_http_full_body(
    response: Response,
) -> Result<http::Response<http_body_util::Full<bytes::Bytes>>> {
    let (status, version) = (response.status(), response.version());
    let header_map = response.headers().clone();
    let payload = response.bytes().await.map_err(BoxError::from)?;
    let builder = header_map.iter().fold(
        http::Response::builder().status(status).version(version),
        |acc, (name, value)| acc.header(name, value),
    );
    builder
        .body(http_body_util::Full::new(payload))
        .map_err(BoxError::from)
}
/// Extracts the response head (status, version, headers) of a
/// `reqwest::Response` as `http::response::Parts`, discarding the body.
#[cfg(feature = "streaming")]
fn convert_reqwest_response_to_http_parts(
    response: Response,
) -> Result<(http::response::Parts, ())> {
    let mut builder = http::Response::builder()
        .status(response.status())
        .version(response.version());
    for (name, value) in response.headers() {
        builder = builder.header(name, value);
    }
    let head_only = builder.body(()).map_err(BoxError::from)?;
    Ok(head_only.into_parts())
}
/// Sets the `XCACHE` and `XCACHELOOKUP` headers on `response` to
/// `hit_or_miss` and `cache_lookup` respectively.
///
/// A value that is not a valid header value is silently skipped.
#[cfg(feature = "streaming")]
fn add_cache_status_headers_to_response<T>(
    mut response: http::Response<T>,
    hit_or_miss: &str,
    cache_lookup: &str,
) -> http::Response<T> {
    let headers = response.headers_mut();
    for (name, raw) in [(XCACHE, hit_or_miss), (XCACHELOOKUP, cache_lookup)] {
        if let Ok(value) = HeaderValue::from_str(raw) {
            headers.insert(name, value);
        }
    }
    response
}
/// Converts an `http::Response` with a cache-manager body into a
/// `reqwest::Response` whose body streams the manager's bytes.
///
/// Status, version and headers are copied. The response URL is not set.
#[cfg(feature = "streaming")]
async fn convert_streaming_body_to_reqwest<T>(
    response: http::Response<T::Body>,
) -> Result<Response>
where
    T: StreamingCacheManager,
    <T::Body as http_body::Body>::Data: Send,
    <T::Body as http_body::Body>::Error: Send + Sync + 'static,
{
    let (head, body) = response.into_parts();
    let streamed = reqwest::Body::wrap_stream(T::body_to_bytes_stream(body));
    let builder = head.headers.iter().fold(
        http::Response::builder().status(head.status).version(head.version),
        |acc, (name, value)| acc.header(name, value),
    );
    let rebuilt = builder.body(streamed)?;
    Ok(Response::from(rebuilt))
}
/// Converts an invalid header value into a middleware error carrying
/// `HttpCacheError::Cache` with the error's message.
fn bad_header(e: reqwest::header::InvalidHeaderValue) -> Error {
    let message = e.to_string();
    to_middleware_error(HttpCacheError::Cache(message))
}
/// Converts a boxed cache-layer error into a middleware error carrying
/// `HttpCacheError::Cache` with the error's message.
fn from_box_error(e: BoxError) -> Error {
    let message = e.to_string();
    to_middleware_error(HttpCacheError::Cache(message))
}
#[async_trait::async_trait]
impl<T: CacheManager> reqwest_middleware::Middleware for Cache<T> {
    /// Runs `req` through the HTTP cache.
    ///
    /// Cacheable requests are handled entirely by `HttpCache::run`, and the
    /// resulting `HttpResponse` is converted back into a `reqwest::Response`.
    /// Other requests go through `run_no_cache` and are then forwarded down
    /// the chain. When `options.cache_status_headers` is enabled they are
    /// tagged as a miss, matching the streaming middleware.
    ///
    /// # Errors
    /// Cache-layer failures are wrapped as `HttpCacheError::Cache(message)`.
    /// Errors from the downstream chain are propagated unchanged.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> std::result::Result<Response, Error> {
        let mut middleware = ReqwestMiddleware { req, next, extensions };
        let can_cache =
            self.0.can_cache_request(&middleware).map_err(from_box_error)?;
        if can_cache {
            let res = self.0.run(middleware).await.map_err(from_box_error)?;
            return convert_response(res).map_err(from_box_error);
        }
        self.0.run_no_cache(&mut middleware).await.map_err(from_box_error)?;
        let mut res = middleware
            .next
            .run(middleware.req, middleware.extensions)
            .await?;
        // Respect the same option that gates status headers elsewhere, so
        // disabling them also silences them on uncacheable requests.
        if self.0.options.cache_status_headers {
            let miss =
                HeaderValue::from_str(HitOrMiss::MISS.to_string().as_ref())
                    .map_err(bad_header)?;
            res.headers_mut().insert(XCACHE, miss.clone());
            res.headers_mut().insert(XCACHELOOKUP, miss);
        }
        Ok(res)
    }
}
/// `reqwest_middleware` integration for `StreamingCache`: responses served
/// from the cache are returned with streaming bodies.
#[cfg(feature = "streaming")]
#[async_trait::async_trait]
impl<T: StreamingCacheManager> reqwest_middleware::Middleware
    for StreamingCache<T>
where
    T::Body: Send + 'static,
    <T::Body as http_body::Body>::Data: Send,
    <T::Body as http_body::Body>::Error:
        Into<http_cache::StreamingError> + Send + Sync + 'static,
{
    /// Serves `req` through the streaming cache.
    ///
    /// Flow:
    /// 1. If the request cannot be cloned, it is forwarded without caching.
    /// 2. `analyze_request` decides cacheability; a non-cacheable request is
    ///    forwarded unchanged.
    /// 3. On a cache hit the stored `CachePolicy` chooses between serving the
    ///    entry (fresh) and sending a conditional request (stale). A 304
    ///    response refreshes the entry via `handle_not_modified` and is
    ///    reported as a hit. Any other status goes through
    ///    `process_response` and is reported as a miss.
    /// 4. On a miss in `OnlyIfCached` mode, a 504 with an empty body is
    ///    returned without contacting the network.
    /// 5. Otherwise the request is sent, the response is fully buffered and
    ///    then passed through `process_response`.
    ///
    /// The `XCACHE` / `XCACHELOOKUP` headers are added only when
    /// `options.cache_status_headers` is set.
    ///
    /// # Errors
    /// Cache-layer failures are wrapped as `HttpCacheError::Cache(message)`.
    /// Errors from the downstream chain (`next.run`) are propagated
    /// unchanged.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> std::result::Result<Response, Error> {
        use http_cache::HttpCacheStreamInterface;
        // A clone is converted into an `http::Request` to read its head,
        // while `req` itself is kept for forwarding. If it cannot be
        // cloned, the cache is bypassed entirely.
        let copied_req = match clone_req(&req) {
            Ok(req) => req,
            Err(_) => {
                let response = next.run(req, extensions).await?;
                return Ok(response);
            }
        };
        let http_req = match http::Request::try_from(copied_req) {
            Ok(r) => r,
            Err(e) => {
                return Err(to_middleware_error(HttpCacheError::Cache(
                    e.to_string(),
                )))
            }
        };
        // Only the request head is needed; the cloned body is dropped.
        let (parts, _) = http_req.into_parts();
        // A `CacheMode` stored in the request extensions overrides the
        // cache-wide mode for this request.
        let mode_override = extensions.get::<CacheMode>().cloned();
        let analysis = match self.cache.analyze_request(&parts, mode_override) {
            Ok(a) => a,
            Err(e) => {
                return Err(to_middleware_error(HttpCacheError::Cache(
                    e.to_string(),
                )))
            }
        };
        // Not cacheable: forward untouched (no cache status headers added).
        if !analysis.should_cache {
            let response = next.run(req, extensions).await?;
            return Ok(response);
        }
        // ---- Cache hit path ----
        if let Some((cached_response, policy)) = self
            .cache
            .lookup_cached_response(&analysis.cache_key)
            .await
            .map_err(|e| {
                to_middleware_error(HttpCacheError::Cache(e.to_string()))
            })?
        {
            use http_cache_semantics::BeforeRequest;
            // NOTE(review): on a hit, only the stored policy decides between
            // serving and revalidating; `analysis.cache_mode` is not
            // consulted here. Confirm that modes such as `ForceCache` /
            // `NoCache` are handled upstream (e.g. in `analyze_request` or
            // `lookup_cached_response`).
            let before_req = policy.before_request(&parts, SystemTime::now());
            match before_req {
                // Fresh: serve the stored entry without any network I/O.
                BeforeRequest::Fresh(_fresh_parts) => {
                    let mut cached_response = cached_response;
                    if self.cache.options.cache_status_headers {
                        cached_response = add_cache_status_headers_to_response(
                            cached_response,
                            "HIT",
                            "HIT",
                        );
                    }
                    return convert_streaming_body_to_reqwest::<T>(
                        cached_response,
                    )
                    .await
                    .map_err(|e| {
                        to_middleware_error(HttpCacheError::Cache(
                            e.to_string(),
                        ))
                    });
                }
                // Stale: revalidate with a conditional request built from the
                // policy's revalidation headers.
                // NOTE(review): the `matches` flag is ignored (`..`), so the
                // conditional headers are applied even if the policy reports
                // that the cached entry does not match this request — confirm
                // this is intended.
                BeforeRequest::Stale { request: conditional_parts, .. } => {
                    #[cfg(feature = "rate-limiting")]
                    if let Some(rate_limiter) = &self.cache.options.rate_limiter
                    {
                        // Rate limiting is keyed by host; requests without a
                        // host share the "unknown" key.
                        let url = req.url().clone();
                        let rate_limit_key =
                            url.host_str().unwrap_or("unknown");
                        rate_limiter.until_key_ready(rate_limit_key).await;
                    }
                    let mut conditional_req = req;
                    // `insert` replaces existing values for each name.
                    // NOTE(review): for a multi-valued header in
                    // `conditional_parts`, only the last value survives.
                    for (name, value) in conditional_parts.headers.iter() {
                        conditional_req
                            .headers_mut()
                            .insert(name.clone(), value.clone());
                    }
                    // A network error here is propagated via `?`; the stale
                    // entry is not served as a fallback.
                    let conditional_response =
                        next.run(conditional_req, extensions).await?;
                    if conditional_response.status() == 304 {
                        // 304 Not Modified: merge the new response head into
                        // the stored entry and serve it as a hit.
                        let (fresh_parts, _) =
                            convert_reqwest_response_to_http_parts(
                                conditional_response,
                            )
                            .map_err(|e| {
                                to_middleware_error(HttpCacheError::Cache(
                                    e.to_string(),
                                ))
                            })?;
                        let updated_response = self
                            .cache
                            .handle_not_modified(cached_response, &fresh_parts)
                            .await
                            .map_err(|e| {
                                to_middleware_error(HttpCacheError::Cache(
                                    e.to_string(),
                                ))
                            })?;
                        let mut final_response = updated_response;
                        if self.cache.options.cache_status_headers {
                            final_response =
                                add_cache_status_headers_to_response(
                                    final_response,
                                    "HIT",
                                    "HIT",
                                );
                        }
                        return convert_streaming_body_to_reqwest::<T>(
                            final_response,
                        )
                        .await
                        .map_err(|e| {
                            to_middleware_error(HttpCacheError::Cache(
                                e.to_string(),
                            ))
                        });
                    } else {
                        // Any other status: buffer it, hand it to
                        // `process_response` (presumably decides whether to
                        // store it — verify), and report a miss.
                        let http_response =
                            convert_reqwest_response_to_http_full_body(
                                conditional_response,
                            )
                            .await
                            .map_err(|e| {
                                to_middleware_error(HttpCacheError::Cache(
                                    e.to_string(),
                                ))
                            })?;
                        let cached_response = self
                            .cache
                            .process_response(analysis, http_response, None)
                            .await
                            .map_err(|e| {
                                to_middleware_error(HttpCacheError::Cache(
                                    e.to_string(),
                                ))
                            })?;
                        let mut final_response = cached_response;
                        if self.cache.options.cache_status_headers {
                            final_response =
                                add_cache_status_headers_to_response(
                                    final_response,
                                    "MISS",
                                    "MISS",
                                );
                        }
                        return convert_streaming_body_to_reqwest::<T>(
                            final_response,
                        )
                        .await
                        .map_err(|e| {
                            to_middleware_error(HttpCacheError::Cache(
                                e.to_string(),
                            ))
                        });
                    }
                }
            }
        }
        // ---- Cache miss path ----
        // `OnlyIfCached` must not touch the network: answer 504 (Gateway
        // Timeout) with an empty body from the manager.
        if analysis.cache_mode == CacheMode::OnlyIfCached {
            let http_response = http::Response::builder()
                .status(504)
                .body(self.cache.manager.empty_body())
                .map_err(|e| {
                    to_middleware_error(HttpCacheError::Cache(e.to_string()))
                })?;
            let mut final_response = http_response;
            if self.cache.options.cache_status_headers {
                final_response = add_cache_status_headers_to_response(
                    final_response,
                    "MISS",
                    "MISS",
                );
            }
            return convert_streaming_body_to_reqwest::<T>(final_response)
                .await
                .map_err(|e| {
                    to_middleware_error(HttpCacheError::Cache(e.to_string()))
                });
        }
        // Same per-host rate limiting as in the revalidation branch.
        #[cfg(feature = "rate-limiting")]
        if let Some(rate_limiter) = &self.cache.options.rate_limiter {
            let url = req.url().clone();
            let rate_limit_key = url.host_str().unwrap_or("unknown");
            rate_limiter.until_key_ready(rate_limit_key).await;
        }
        // Fetch from origin, buffer the full body, and pass it through
        // `process_response`. The returned response is streamed back to the
        // caller.
        let response = next.run(req, extensions).await?;
        let http_response =
            convert_reqwest_response_to_http_full_body(response)
                .await
                .map_err(|e| {
                    to_middleware_error(HttpCacheError::Cache(e.to_string()))
                })?;
        let cached_response = self
            .cache
            .process_response(analysis, http_response, None)
            .await
            .map_err(|e| {
                to_middleware_error(HttpCacheError::Cache(e.to_string()))
            })?;
        let mut final_response = cached_response;
        if self.cache.options.cache_status_headers {
            final_response = add_cache_status_headers_to_response(
                final_response,
                "MISS",
                "MISS",
            );
        }
        convert_streaming_body_to_reqwest::<T>(final_response).await.map_err(
            |e| to_middleware_error(HttpCacheError::Cache(e.to_string())),
        )
    }
}
#[cfg(test)]
mod test;