//! reqwest-middleware integration for the `http-cache` crate.
//!
//! This crate provides [`Cache`], a wrapper around [`HttpCache`] that
//! implements `reqwest_middleware::Middleware`, so HTTP responses obtained
//! through a `reqwest_middleware` client can be served from, and stored in, a
//! [`CacheManager`] backend.
//!
//! With the `streaming` feature enabled, `StreamingCache` offers the same
//! integration on top of `HttpStreamingCache`.
//!
//! The most commonly needed `http-cache` types (cache modes, options,
//! managers, ...) are re-exported from this crate.
#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
nonstandard_style,
unused_qualifications,
unused_import_braces,
unused_extern_crates,
trivial_casts,
trivial_numeric_casts
)]
#![allow(clippy::doc_lazy_continuation)]
#![cfg_attr(docsrs, feature(doc_cfg))]
pub use http_cache::{BadRequest, HttpCacheError};
/// Alias exposing `http_cache::ClientStreamingError` under a
/// reqwest-specific name. Only available with the `streaming` feature.
#[cfg(feature = "streaming")]
pub type ReqwestStreamingError = http_cache::ClientStreamingError;
#[cfg(feature = "streaming")]
use http_cache::StreamingCacheManager;
use std::{str::FromStr, time::SystemTime};
pub use http::request::Parts;
use http::{
header::{HeaderName, CACHE_CONTROL},
Extensions, HeaderValue, Method,
};
use http_cache::{
url_parse, BoxError, HitOrMiss, Middleware, Result, Url, XCACHE,
XCACHELOOKUP,
};
use http_cache_semantics::CachePolicy;
use reqwest::{Request, Response, ResponseBuilderExt};
use reqwest_middleware::{Error, Next};
fn to_middleware_error<E: std::error::Error + Send + Sync + 'static>(
error: E,
) -> Error {
Error::Middleware(anyhow::Error::new(error))
}
pub use http_cache::{
CacheManager, CacheMode, CacheOptions, HttpCache, HttpCacheMetadata,
HttpCacheOptions, HttpResponse, MetadataProvider, ResponseCacheModeFn,
};
#[cfg(feature = "streaming")]
pub use http_cache::{
HttpCacheStreamInterface, HttpStreamingCache, StreamingBody,
StreamingManager,
};
#[cfg(feature = "manager-cacache")]
#[cfg_attr(docsrs, doc(cfg(feature = "manager-cacache")))]
pub use http_cache::CACacheManager;
#[cfg(feature = "manager-moka")]
#[cfg_attr(docsrs, doc(cfg(feature = "manager-moka")))]
pub use http_cache::{MokaCache, MokaCacheBuilder, MokaManager};
#[cfg(feature = "rate-limiting")]
#[cfg_attr(docsrs, doc(cfg(feature = "rate-limiting")))]
pub use http_cache::rate_limiting::{
CacheAwareRateLimiter, DirectRateLimiter, DomainRateLimiter, Quota,
};
/// Wrapper around [`HttpCache`] that implements
/// `reqwest_middleware::Middleware`.
///
/// The wrapped cache is public so its mode, manager and options can be set
/// when constructing the middleware.
#[derive(Debug)]
pub struct Cache<T>(pub HttpCache<T>)
where
    T: CacheManager;
/// Streaming counterpart of `Cache`: wraps an `HttpStreamingCache` and
/// implements `reqwest_middleware::Middleware` on top of it.
///
/// Only available with the `streaming` feature.
#[cfg(feature = "streaming")]
#[derive(Debug, Clone)]
pub struct StreamingCache<T>
where
    T: StreamingCacheManager,
{
    /// The underlying streaming cache (mode, manager and options).
    cache: HttpStreamingCache<T>,
}
#[cfg(feature = "streaming")]
impl<T> StreamingCache<T>
where
    T: StreamingCacheManager,
{
    /// Creates a streaming cache middleware using `manager` as the storage
    /// backend, `mode` as the cache mode and the default
    /// [`HttpCacheOptions`].
    pub fn new(manager: T, mode: CacheMode) -> Self {
        Self::with_options(manager, mode, HttpCacheOptions::default())
    }

    /// Creates a streaming cache middleware using `manager` as the storage
    /// backend, `mode` as the cache mode and the caller-supplied `options`.
    pub fn with_options(
        manager: T,
        mode: CacheMode,
        options: HttpCacheOptions,
    ) -> Self {
        let cache = HttpStreamingCache { mode, manager, options };
        Self { cache }
    }
}
/// Adapter bundling everything needed to drive one reqwest request through
/// the `http_cache::Middleware` trait (implemented for this type below).
pub(crate) struct ReqwestMiddleware<'a> {
/// The outgoing request; cloned for every remote fetch.
pub req: Request,
/// The remainder of the reqwest middleware chain.
pub next: Next<'a>,
/// Request extensions; also consulted for a per-request `CacheMode`
/// override.
pub extensions: &'a mut Extensions,
}
/// Duplicates `request` so the original stays available for later use.
///
/// # Errors
///
/// Returns a middleware error wrapping [`BadRequest`] when
/// `Request::try_clone` yields `None` (reqwest documents this as happening
/// when the body is a stream).
fn clone_req(request: &Request) -> std::result::Result<Request, Error> {
    request.try_clone().ok_or_else(|| to_middleware_error(BadRequest))
}
impl Middleware for ReqwestMiddleware<'_> {
    /// Returns the `CacheMode` stored in the request extensions, if any.
    fn overridden_cache_mode(&self) -> Option<CacheMode> {
        self.extensions.get::<CacheMode>().cloned()
    }

    /// Reports whether the request method is `GET` or `HEAD`.
    fn is_method_get_head(&self) -> bool {
        matches!(*self.req.method(), Method::GET | Method::HEAD)
    }

    /// Builds a [`CachePolicy`] from this request and `response` using the
    /// default policy options.
    fn policy(&self, response: &HttpResponse) -> Result<CachePolicy> {
        let request_parts = self.parts()?;
        let response_parts = response.parts()?;
        Ok(CachePolicy::new(&request_parts, &response_parts))
    }

    /// Builds a [`CachePolicy`] from this request and `response` with the
    /// supplied `options`, using the current system time as the response
    /// time.
    fn policy_with_options(
        &self,
        response: &HttpResponse,
        options: CacheOptions,
    ) -> Result<CachePolicy> {
        let request_parts = self.parts()?;
        let response_parts = response.parts()?;
        let now = SystemTime::now();
        Ok(CachePolicy::new_options(
            &request_parts,
            &response_parts,
            now,
            options,
        ))
    }

    /// Copies every header in `parts` onto the outgoing request.
    ///
    /// `insert` is used, so an existing header with the same name is
    /// replaced rather than extended.
    fn update_headers(&mut self, parts: &Parts) -> Result<()> {
        let target = self.req.headers_mut();
        for (name, value) in &parts.headers {
            target.insert(name.clone(), value.clone());
        }
        Ok(())
    }

    /// Sets `Cache-Control: no-cache` on the outgoing request, replacing any
    /// existing `Cache-Control` value.
    fn force_no_cache(&mut self) -> Result<()> {
        let directive = HeaderValue::from_str("no-cache")?;
        self.req.headers_mut().insert(CACHE_CONTROL, directive);
        Ok(())
    }

    /// Rebuilds `http::request::Parts` from the reqwest request (method,
    /// URL, version and headers) and attaches a clone of the extensions.
    fn parts(&self) -> Result<Parts> {
        let base = http::Request::builder()
            .method(self.req.method().as_str())
            .uri(self.req.url().as_str())
            .version(self.req.version());
        let mut builder = self
            .req
            .headers()
            .iter()
            .fold(base, |acc, (name, value)| acc.header(name, value));
        // `extensions_mut` yields `None` only if the builder already holds an
        // error; that error is reported by `body` just below.
        if let Some(slot) = builder.extensions_mut() {
            *slot = self.extensions.clone();
        }
        let (head, ()) = builder.body(()).map_err(Box::new)?.into_parts();
        Ok(head)
    }

    /// Returns the request URL converted to the `http_cache` URL type.
    fn url(&self) -> Result<Url> {
        let raw = self.req.url().as_str();
        url_parse(raw)
    }

    /// Returns the request method as an owned string.
    fn method(&self) -> Result<String> {
        Ok(self.req.method().as_str().to_owned())
    }

    /// Sends a clone of the request through the rest of the middleware chain
    /// and buffers the whole response into an [`HttpResponse`].
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be cloned, if the downstream chain or the
    /// body read fails, or if the response URL or HTTP version cannot be
    /// converted.
    async fn remote_fetch(&mut self) -> Result<HttpResponse> {
        let outbound = clone_req(&self.req)?;
        let response = self
            .next
            .clone()
            .run(outbound, self.extensions)
            .await
            .map_err(BoxError::from)?;

        // Everything borrowed from the response must be captured before
        // `bytes()` consumes it.
        let headers = response.headers().into();
        let url = url_parse(response.url().as_str())?;
        let status = response.status().into();
        let version = response.version();
        let payload = response.bytes().await.map_err(BoxError::from)?;
        let body: Vec<u8> = payload.to_vec();

        Ok(HttpResponse {
            body,
            headers,
            status,
            url,
            version: version.try_into()?,
            metadata: None,
        })
    }
}
fn convert_response(response: HttpResponse) -> Result<Response> {
let metadata = response.metadata.clone();
let reqwest_url =
::url::Url::parse(response.url.as_str()).map_err(BoxError::from)?;
let mut ret_res = http::Response::builder()
.status(response.status)
.url(reqwest_url)
.version(response.version.into())
.body(response.body)?;
for header in response.headers {
ret_res.headers_mut().append(
HeaderName::from_str(&header.0)?,
HeaderValue::from_str(&header.1)?,
);
}
if let Some(metadata) = metadata {
ret_res.extensions_mut().insert(HttpCacheMetadata::from(metadata));
}
Ok(Response::from(ret_res))
}
/// Buffers a reqwest [`Response`] into an `http::Response` whose body is a
/// fully loaded `http_body_util::Full<bytes::Bytes>`.
///
/// Status, version and headers are carried over; the URL and extensions of
/// the reqwest response are not.
///
/// # Errors
///
/// Fails if reading the body fails or the response cannot be built.
#[cfg(feature = "streaming")]
async fn convert_reqwest_response_to_http_full_body(
    response: Response,
) -> Result<http::Response<http_body_util::Full<bytes::Bytes>>> {
    // Capture everything borrowed from the response before `bytes()`
    // consumes it.
    let status = response.status();
    let version = response.version();
    let headers = response.headers().clone();
    let payload = response.bytes().await.map_err(BoxError::from)?;

    headers
        .iter()
        .fold(
            http::Response::builder().status(status).version(version),
            |builder, (name, value)| builder.header(name, value),
        )
        .body(http_body_util::Full::new(payload))
        .map_err(BoxError::from)
}
/// Converts an `http::Response` carrying the manager's streaming body into a
/// reqwest [`Response`].
///
/// The body is turned into a byte stream via `T::body_to_bytes_stream` and
/// wrapped with `reqwest::Body::wrap_stream`. Status, version, headers and
/// extensions are carried over.
///
/// # Errors
///
/// Fails if the response cannot be built.
#[cfg(feature = "streaming")]
async fn convert_streaming_body_to_reqwest<T>(
    response: http::Response<T::Body>,
) -> Result<Response>
where
    T: StreamingCacheManager,
    <T::Body as http_body::Body>::Data: Send,
    <T::Body as http_body::Body>::Error: Send + Sync + 'static,
{
    let (head, body) = response.into_parts();
    let reqwest_body = reqwest::Body::wrap_stream(T::body_to_bytes_stream(body));

    let builder = head.headers.iter().fold(
        http::Response::builder().status(head.status).version(head.version),
        |builder, (name, value)| builder.header(name, value),
    );

    let mut converted = builder.body(reqwest_body)?;
    *converted.extensions_mut() = head.extensions;
    Ok(Response::from(converted))
}
/// Converts an invalid-header-value error into a middleware error carrying
/// its message as `HttpCacheError::Cache`.
fn bad_header(e: reqwest::header::InvalidHeaderValue) -> Error {
    let message = e.to_string();
    to_middleware_error(HttpCacheError::Cache(message))
}
/// Converts a boxed `http_cache` error into a middleware error carrying its
/// message as `HttpCacheError::Cache`.
fn from_box_error(e: BoxError) -> Error {
    let message = e.to_string();
    to_middleware_error(HttpCacheError::Cache(message))
}
#[async_trait::async_trait]
impl<T> reqwest_middleware::Middleware for Cache<T>
where
    T: CacheManager,
{
    /// Handles one request on behalf of the reqwest middleware chain.
    ///
    /// If the wrapped [`HttpCache`] reports the request as cacheable, the
    /// cache drives the request and its result is converted back into a
    /// reqwest response. Otherwise the request is forwarded unchanged, and
    /// when `cache_status_headers` is enabled the response is tagged with
    /// `MISS` in both cache status headers.
    ///
    /// # Errors
    ///
    /// Returns cache failures as middleware errors and propagates any error
    /// from the downstream chain.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> std::result::Result<Response, Error> {
        let middleware = ReqwestMiddleware { req, next, extensions };

        if self.0.can_cache_request(&middleware).map_err(from_box_error)? {
            let cached = self.0.run(middleware).await.map_err(from_box_error)?;
            return convert_response(cached).map_err(from_box_error);
        }

        // Snapshot the request parts before the request is consumed by the
        // downstream chain.
        let parts = middleware.parts().map_err(from_box_error)?;
        let ReqwestMiddleware { req, next, extensions } = middleware;
        let mut res = next.run(req, extensions).await?;

        // A non-safe method that succeeded (2xx) or redirected (3xx) is
        // reported to the cache; presumably this invalidates stored entries
        // for the resource (see `run_no_cache_from_parts` in http-cache).
        let status = res.status();
        let is_unsafe_method = !parts.method.is_safe();
        if is_unsafe_method && (status.is_success() || status.is_redirection())
        {
            self.0
                .run_no_cache_from_parts(&parts)
                .await
                .map_err(from_box_error)?;
        }

        if self.0.options.cache_status_headers {
            let label = HitOrMiss::MISS.to_string();
            let miss = HeaderValue::from_str(&label).map_err(bad_header)?;
            let headers = res.headers_mut();
            headers.insert(XCACHE, miss.clone());
            headers.insert(XCACHELOOKUP, miss);
        }

        Ok(res)
    }
}
// `reqwest_middleware::Middleware` implementation for `StreamingCache`.
// Mirrors the `Cache` implementation but works on request `Parts` and hands
// the cache a fetch closure instead of a `Middleware` adapter.
#[cfg(feature = "streaming")]
#[async_trait::async_trait]
impl<T: StreamingCacheManager> reqwest_middleware::Middleware
for StreamingCache<T>
where
T::Body: Send + 'static,
<T::Body as http_body::Body>::Data: Send,
<T::Body as http_body::Body>::Error:
Into<http_cache::StreamingError> + Send + Sync + 'static,
{
/// Handles one request on behalf of the reqwest middleware chain.
///
/// Requests that cannot be cloned bypass the cache entirely. Cacheable
/// requests are driven by the streaming cache, which calls back into the
/// downstream chain through a fetch closure. Other requests are forwarded
/// unchanged and, when `cache_status_headers` is enabled, tagged with `MISS`
/// in both cache status headers.
///
/// # Errors
///
/// Returns cache failures as middleware errors and propagates errors from
/// the downstream chain.
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> std::result::Result<Response, Error> {
use http_cache::FetchRequest;
// The original `req` must stay intact to be sent downstream, so the
// request parts are derived from a clone. If cloning fails, the request
// is forwarded as-is without touching the cache.
let copied_req = match clone_req(&req) {
Ok(r) => r,
Err(_) => return next.run(req, extensions).await,
};
// Only the head of the request is needed; the body of the clone is
// dropped.
let http_req = http::Request::try_from(copied_req).map_err(|e| {
to_middleware_error(HttpCacheError::Cache(e.to_string()))
})?;
let (parts, _) = http_req.into_parts();
// A `CacheMode` placed in the extensions overrides the configured mode
// for this request.
let mode_override = extensions.get::<CacheMode>().cloned();
let can_cache = self
.cache
.can_cache_request(&parts, mode_override)
.map_err(from_box_error)?;
if can_cache {
let result = self
.cache
// The closure receives a `FetchRequest` describing how the outgoing
// request should be prepared; presumably the cache calls it only
// when it needs to contact the origin.
.run(&parts, mode_override, |fetch_req| {
let mut req = req;
let next = next.clone();
match fetch_req {
// Send the request unmodified.
FetchRequest::Fresh => {}
// Force revalidation by setting `Cache-Control: no-cache`.
FetchRequest::FreshNoCache => {
req.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_static("no-cache"),
);
}
// Copy the headers supplied by the cache onto the request,
// replacing existing values with the same name.
FetchRequest::Conditional(cond_parts) => {
for (name, value) in cond_parts.headers.iter() {
req.headers_mut()
.insert(name.clone(), value.clone());
}
}
}
async move {
// The downstream error is flattened to its message so it fits
// in a `BoxError`.
let resp = next.run(req, extensions).await.map_err(
|e| -> BoxError { e.to_string().into() },
)?;
// The origin response body is fully buffered here.
convert_reqwest_response_to_http_full_body(resp).await
}
})
.await
.map_err(from_box_error)?;
convert_streaming_body_to_reqwest::<T>(result).await.map_err(|e| {
to_middleware_error(HttpCacheError::Cache(e.to_string()))
})
} else {
let mut res = next.run(req, extensions).await?;
// A non-safe method that succeeded (2xx) or redirected (3xx) is reported
// to the cache; presumably this invalidates stored entries for the
// resource (see `run_no_cache` in http-cache).
if !parts.method.is_safe()
&& (res.status().is_success() || res.status().is_redirection())
{
self.cache
.run_no_cache(&parts)
.await
.map_err(from_box_error)?;
}
// Tag the pass-through response as a cache miss when requested.
if self.cache.options.cache_status_headers {
let miss =
HeaderValue::from_str(HitOrMiss::MISS.to_string().as_ref())
.map_err(bad_header)?;
res.headers_mut().insert(XCACHE, miss.clone());
res.headers_mut().insert(XCACHELOOKUP, miss);
}
Ok(res)
}
}
}
#[cfg(test)]
mod test;