use bytes::Bytes;
use http::{
header::CACHE_CONTROL, request, HeaderValue, Method, Request, Response,
};
use http_body::Body;
use http_body_util::BodyExt;
#[cfg(feature = "manager-cacache")]
pub use http_cache::CACacheManager;
#[cfg(feature = "rate-limiting")]
pub use http_cache::rate_limiting::{
CacheAwareRateLimiter, DirectRateLimiter, DomainRateLimiter, Quota,
};
#[cfg(feature = "streaming")]
use http_cache::StreamingError;
use http_cache::{
url_parse, BoxError, CacheManager, CacheMode, CacheOptions, HitOrMiss,
HttpCache, HttpCacheOptions, HttpResponse, Middleware, Url, XCACHE,
XCACHELOOKUP,
};
#[cfg(feature = "streaming")]
use http_cache::{HttpStreamingCache, StreamingCacheManager};
use http_cache_semantics::CachePolicy;
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::SystemTime,
};
use tower::{Layer, Service, ServiceExt};
pub use http_cache::HttpCacheError;
/// Streaming error type exposed by this crate under a tower-specific name;
/// it is an alias of `http_cache::ClientStreamingError`.
#[cfg(feature = "streaming")]
pub type TowerStreamingError = http_cache::ClientStreamingError;
/// Internal helper for turning any `Result` with a printable error into a
/// `Result<_, HttpCacheError>`.
trait HttpCacheErrorExt<T> {
    /// Converts an `Err` into `HttpCacheError::cache`, using the error's
    /// `to_string()` rendering as the message. `Ok` values pass through.
    fn cache_err(self) -> Result<T, HttpCacheError>;
}

impl<T, Cause> HttpCacheErrorExt<T> for Result<T, Cause>
where
    Cause: ToString,
{
    fn cache_err(self) -> Result<T, HttpCacheError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(HttpCacheError::cache(cause.to_string())),
        }
    }
}
/// Sets the `XCACHE` and `XCACHELOOKUP` status headers on `response`.
///
/// `hit_or_miss` becomes the `XCACHE` value and `cache_lookup` the
/// `XCACHELOOKUP` value. Any value that is not a valid header value is skipped
/// silently, leaving that header untouched.
fn add_cache_status_headers<B>(
    mut response: Response<HttpCacheBody<B>>,
    hit_or_miss: &str,
    cache_lookup: &str,
) -> Response<HttpCacheBody<B>> {
    let status_headers = [(XCACHE, hit_or_miss), (XCACHELOOKUP, cache_lookup)];
    for (name, raw) in status_headers {
        if let Ok(value) = HeaderValue::from_str(raw) {
            response.headers_mut().insert(name, value);
        }
    }
    response
}
/// Adapter that presents one tower request to `http_cache` through its
/// `Middleware` trait.
///
/// `body` and `service` are wrapped in `Option` so `remote_fetch` can move
/// them out with `take()`; a second fetch reports "already consumed".
struct TowerMiddleware<S, ReqBody> {
// Request head; cache-driven header updates are written here.
parts: request::Parts,
// Request body, taken by the first `remote_fetch`.
body: Option<ReqBody>,
// Inner service, taken by the first `remote_fetch` and driven with `oneshot`.
service: Option<S>,
}
impl<S, ReqBody, ResBody> Middleware for TowerMiddleware<S, ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>
+ Clone
+ Send
+ 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
S::Future: Send + 'static,
ReqBody: Body + Send + 'static,
ReqBody::Data: Send,
ReqBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
ResBody: Body + Send + 'static,
ResBody::Data: Send,
ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
fn is_method_get_head(&self) -> bool {
self.parts.method == Method::GET || self.parts.method == Method::HEAD
}
fn policy(
&self,
response: &HttpResponse,
) -> http_cache::Result<CachePolicy> {
Ok(CachePolicy::new(&self.parts, &response.parts()?))
}
fn policy_with_options(
&self,
response: &HttpResponse,
options: CacheOptions,
) -> http_cache::Result<CachePolicy> {
Ok(CachePolicy::new_options(
&self.parts,
&response.parts()?,
SystemTime::now(),
options,
))
}
fn update_headers(
&mut self,
parts: &request::Parts,
) -> http_cache::Result<()> {
for (name, value) in parts.headers.iter() {
self.parts.headers.insert(name.clone(), value.clone());
}
Ok(())
}
fn force_no_cache(&mut self) -> http_cache::Result<()> {
self.parts
.headers
.insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
Ok(())
}
fn parts(&self) -> http_cache::Result<request::Parts> {
Ok(self.parts.clone())
}
fn url(&self) -> http_cache::Result<Url> {
url_parse(self.parts.uri.to_string().as_str())
}
fn method(&self) -> http_cache::Result<String> {
Ok(self.parts.method.as_ref().to_string())
}
async fn remote_fetch(&mut self) -> http_cache::Result<HttpResponse> {
let body = self
.body
.take()
.ok_or_else(|| BoxError::from("request body already consumed"))?;
let service = self
.service
.take()
.ok_or_else(|| BoxError::from("inner service already consumed"))?;
let request = Request::from_parts(self.parts.clone(), body);
let response = service.oneshot(request).await.map_err(|e| {
let boxed: Box<dyn std::error::Error + Send + Sync> = e.into();
boxed
})?;
let (res_parts, res_body) = response.into_parts();
let collected = BodyExt::collect(res_body).await.map_err(|e| {
let boxed: Box<dyn std::error::Error + Send + Sync> = e.into();
boxed
})?;
let body_bytes = collected.to_bytes().to_vec();
let url = url_parse(self.parts.uri.to_string().as_str())?;
let headers = (&res_parts.headers).into();
let status = res_parts.status.as_u16();
let version = res_parts.version.try_into()?;
Ok(HttpResponse {
body: body_bytes,
headers,
status,
url,
version,
metadata: None,
})
}
}
/// Converts a cached/fetched `HttpResponse` into a tower response whose body
/// is the buffered bytes.
///
/// Any metadata on the `HttpResponse` is attached as an
/// `http_cache::HttpCacheMetadata` response extension.
///
/// # Errors
/// Returns `HttpCacheError::other` if the response head cannot be rebuilt.
fn http_response_to_tower_response<B>(
    http_response: HttpResponse,
) -> Result<Response<HttpCacheBody<B>>, HttpCacheError> {
    let buffered = HttpCacheBody::Buffered(http_response.body.clone());
    let converted =
        HttpCacheOptions::http_response_to_response(&http_response, buffered);
    let mut response = converted.map_err(HttpCacheError::other)?;
    if let Some(metadata) =
        http_response.metadata.map(http_cache::HttpCacheMetadata::from)
    {
        response.extensions_mut().insert(metadata);
    }
    Ok(response)
}
/// Sets the `XCACHE` and `XCACHELOOKUP` status headers on a response with an
/// arbitrary body type (used by the streaming service).
///
/// Values that do not form a valid `HeaderValue` are skipped silently.
#[cfg(feature = "streaming")]
fn add_cache_status_headers_streaming<B>(
    mut response: Response<B>,
    hit_or_miss: &str,
    cache_lookup: &str,
) -> Response<B> {
    let hit_value = HeaderValue::from_str(hit_or_miss).ok();
    let lookup_value = HeaderValue::from_str(cache_lookup).ok();
    let headers = response.headers_mut();
    if let Some(value) = hit_value {
        headers.insert(XCACHE, value);
    }
    if let Some(value) = lookup_value {
        headers.insert(XCACHELOOKUP, value);
    }
    response
}
/// Tower `Layer` that wraps services in [`HttpCacheService`].
///
/// The cache lives behind an `Arc`, so every service produced by this layer
/// shares the same `HttpCache`.
pub struct HttpCacheLayer<CM>
where
    CM: CacheManager,
{
    cache: Arc<HttpCache<CM>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `CM: Clone` bound, but cloning only bumps the `Arc` refcount, so the cache
// manager itself never needs to be `Clone`. This matches the manual `Clone`
// impl used for `HttpCacheService`.
impl<CM> Clone for HttpCacheLayer<CM>
where
    CM: CacheManager,
{
    fn clone(&self) -> Self {
        Self { cache: Arc::clone(&self.cache) }
    }
}
impl<CM> HttpCacheLayer<CM>
where
    CM: CacheManager,
{
    /// Creates a layer with `CacheMode::Default` and default
    /// `HttpCacheOptions`.
    pub fn new(cache_manager: CM) -> Self {
        Self::with_options(cache_manager, HttpCacheOptions::default())
    }

    /// Creates a layer with `CacheMode::Default` and the given `options`.
    pub fn with_options(cache_manager: CM, options: HttpCacheOptions) -> Self {
        let cache =
            HttpCache { mode: CacheMode::Default, manager: cache_manager, options };
        Self::with_cache(cache)
    }

    /// Creates a layer from a fully configured `HttpCache`.
    pub fn with_cache(cache: HttpCache<CM>) -> Self {
        let shared = Arc::new(cache);
        Self { cache: shared }
    }
}
/// Tower `Layer` that wraps services in [`HttpCacheStreamingService`].
///
/// The streaming cache lives behind an `Arc` and is shared by every service
/// this layer produces.
#[cfg(feature = "streaming")]
pub struct HttpCacheStreamingLayer<CM>
where
    CM: StreamingCacheManager,
{
    cache: Arc<HttpStreamingCache<CM>>,
}

// Manual impl instead of `#[derive(Clone)]`, which would needlessly require
// `CM: Clone`; cloning only bumps the `Arc` refcount.
#[cfg(feature = "streaming")]
impl<CM> Clone for HttpCacheStreamingLayer<CM>
where
    CM: StreamingCacheManager,
{
    fn clone(&self) -> Self {
        Self { cache: Arc::clone(&self.cache) }
    }
}
#[cfg(feature = "streaming")]
impl<CM> HttpCacheStreamingLayer<CM>
where
    CM: StreamingCacheManager,
{
    /// Creates a streaming layer with `CacheMode::Default` and default
    /// `HttpCacheOptions`.
    pub fn new(cache_manager: CM) -> Self {
        Self::with_options(cache_manager, HttpCacheOptions::default())
    }

    /// Creates a streaming layer with `CacheMode::Default` and the given
    /// `options`.
    pub fn with_options(cache_manager: CM, options: HttpCacheOptions) -> Self {
        Self::with_cache(HttpStreamingCache {
            mode: CacheMode::Default,
            manager: cache_manager,
            options,
        })
    }

    /// Creates a streaming layer from a fully configured `HttpStreamingCache`.
    pub fn with_cache(cache: HttpStreamingCache<CM>) -> Self {
        let shared = Arc::new(cache);
        Self { cache: shared }
    }
}
impl<S, CM> Layer<S> for HttpCacheLayer<CM>
where
CM: CacheManager,
{
type Service = HttpCacheService<S, CM>;
fn layer(&self, inner: S) -> Self::Service {
HttpCacheService { inner, cache: self.cache.clone() }
}
}
#[cfg(feature = "streaming")]
impl<S, CM> Layer<S> for HttpCacheStreamingLayer<CM>
where
    CM: StreamingCacheManager,
{
    type Service = HttpCacheStreamingService<S, CM>;

    /// Wraps `inner`, giving the new service a handle to this layer's
    /// streaming cache.
    fn layer(&self, inner: S) -> Self::Service {
        let cache = Arc::clone(&self.cache);
        HttpCacheStreamingService { inner, cache }
    }
}
/// Tower service produced by [`HttpCacheLayer`]: consults the shared
/// `HttpCache` and falls back to the wrapped `inner` service.
pub struct HttpCacheService<S, CM>
where
CM: CacheManager,
{
// The wrapped service that performs the real request.
inner: S,
// Cache shared with the layer and with every clone of this service.
cache: Arc<HttpCache<CM>>,
}
impl<S, CM> Clone for HttpCacheService<S, CM>
where
S: Clone,
CM: CacheManager,
{
fn clone(&self) -> Self {
Self { inner: self.inner.clone(), cache: self.cache.clone() }
}
}
/// Tower service produced by [`HttpCacheStreamingLayer`]: consults the shared
/// `HttpStreamingCache` and falls back to the wrapped `inner` service.
#[cfg(feature = "streaming")]
pub struct HttpCacheStreamingService<S, CM>
where
CM: StreamingCacheManager,
{
// The wrapped service that performs the real request.
inner: S,
// Streaming cache shared with the layer and with every clone of this service.
cache: Arc<HttpStreamingCache<CM>>,
}
// Manual impl: only the inner service must be `Clone`; the streaming cache is
// shared via the `Arc`.
#[cfg(feature = "streaming")]
impl<S, CM> Clone for HttpCacheStreamingService<S, CM>
where
    S: Clone,
    CM: StreamingCacheManager,
{
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        let cache = Arc::clone(&self.cache);
        Self { inner, cache }
    }
}
impl<S, CM, ReqBody, ResBody> Service<Request<ReqBody>>
    for HttpCacheService<S, CM>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>
        + Clone
        + Send
        + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: Send + 'static,
    ReqBody: Body + Send + 'static,
    ReqBody::Data: Send,
    ReqBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    ResBody: Body + Send + 'static,
    ResBody::Data: Send,
    ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    CM: CacheManager,
{
    type Response = Response<HttpCacheBody<ResBody>>;
    type Error = HttpCacheError;
    type Future = Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<Self::Response, Self::Error>,
                > + Send,
        >,
    >;

    /// Delegates readiness to the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(|e| HttpCacheError::http(e.into()))
    }

    /// Handles one request.
    ///
    /// Cacheable requests go through `HttpCache::run`. Other requests are
    /// forwarded to the inner service; for unsafe methods that answer with a
    /// success or redirect status, the cached entry is invalidated via
    /// `run_no_cache_from_parts`. When `cache_status_headers` is enabled, the
    /// response is tagged as a MISS.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let cache = self.cache.clone();
        let (parts, body) = req.into_parts();
        // `poll_ready` was driven on `self.inner`, so that instance is the one
        // holding any reserved capacity (e.g. a concurrency permit or buffer
        // slot). Move it into this request and leave a fresh clone behind for
        // the next `poll_ready`. Using a clone for the request instead would
        // strand the reservation and can deadlock capacity-limited services.
        let replacement = self.inner.clone();
        let inner_service = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move {
            let middleware = TowerMiddleware {
                parts: parts.clone(),
                body: Some(body),
                service: Some(inner_service),
            };
            let can_cache = cache.can_cache_request(&middleware).cache_err()?;
            if can_cache {
                let res = cache.run(middleware).await.cache_err()?;
                return http_response_to_tower_response(res);
            }

            let parts_for_invalidation = middleware.parts().cache_err()?;
            let body = middleware.body.ok_or_else(|| {
                HttpCacheError::cache(
                    "request body already consumed".to_string(),
                )
            })?;
            let service = middleware.service.ok_or_else(|| {
                HttpCacheError::cache(
                    "inner service already consumed".to_string(),
                )
            })?;
            let req = Request::from_parts(parts, body);
            let response = service.oneshot(req).await.map_err(|e| {
                let boxed: Box<dyn std::error::Error + Send + Sync> = e.into();
                HttpCacheError::http(boxed)
            })?;
            // Unsafe methods may have modified the resource; drop any cached
            // copy once the origin has accepted the request.
            if !parts_for_invalidation.method.is_safe()
                && (response.status().is_success()
                    || response.status().is_redirection())
            {
                cache
                    .run_no_cache_from_parts(&parts_for_invalidation)
                    .await
                    .cache_err()?;
            }
            let mut response = response.map(HttpCacheBody::Original);
            if cache.options.cache_status_headers {
                let miss = HitOrMiss::MISS.to_string();
                response = add_cache_status_headers(response, &miss, &miss);
            }
            Ok(response)
        })
    }
}
/// Lets `HttpCacheService` be handed directly to hyper as a server service by
/// delegating to the tower `Service` implementation.
impl<S, CM> hyper::service::Service<Request<hyper::body::Incoming>>
    for HttpCacheService<S, CM>
where
    S: Service<
            Request<hyper::body::Incoming>,
            Response = Response<http_body_util::Full<Bytes>>,
        > + Clone
        + Send
        + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: Send + 'static,
    CM: CacheManager,
{
    type Response = Response<HttpCacheBody<http_body_util::Full<Bytes>>>;
    type Error = HttpCacheError;
    type Future = Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<Self::Response, Self::Error>,
                > + Send,
        >,
    >;

    fn call(&self, req: Request<hyper::body::Incoming>) -> Self::Future {
        // hyper only provides `&self`, while tower's `call` needs `&mut self`;
        // run the request on an owned clone inside the returned future.
        let mut owned = self.clone();
        let fut = async move { tower::Service::call(&mut owned, req).await };
        Box::pin(fut)
    }
}
#[cfg(feature = "streaming")]
impl<S, CM, ReqBody, ResBody> Service<Request<ReqBody>>
    for HttpCacheStreamingService<S, CM>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>
        + Clone
        + Send
        + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: Send + 'static,
    ReqBody: Body + Send + 'static,
    ReqBody::Data: Send,
    ReqBody::Error: Into<StreamingError>,
    ResBody: Body + Send + 'static,
    ResBody::Data: Send,
    ResBody::Error: Into<StreamingError>,
    CM: StreamingCacheManager,
    <CM::Body as http_body::Body>::Data: Send,
    <CM::Body as http_body::Body>::Error:
        Into<StreamingError> + Send + Sync + 'static,
{
    type Response = Response<CM::Body>;
    type Error = HttpCacheError;
    type Future = Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<Self::Response, Self::Error>,
                > + Send,
        >,
    >;

    /// Delegates readiness to the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(|e| HttpCacheError::http(e.into()))
    }

    /// Handles one request with the streaming cache.
    ///
    /// Non-cacheable requests are forwarded to the inner service. For unsafe
    /// methods with a success or redirect status, the cache is invalidated via
    /// `run_no_cache`. The body is converted with the manager's
    /// `convert_body`, and the response is optionally tagged as a MISS.
    /// Cacheable requests go through `HttpStreamingCache::run`, which picks
    /// the fresh / no-cache / conditional request form to send upstream.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let cache = self.cache.clone();
        let (parts, body) = req.into_parts();
        // Use the instance that `poll_ready` actually readied (and that may
        // hold reserved capacity) for this request, leaving a fresh clone for
        // the next `poll_ready`. Calling through a clone instead strands the
        // reservation and can deadlock capacity-limited inner services.
        let replacement = self.inner.clone();
        let inner_service = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move {
            let can_cache =
                cache.can_cache_request(&parts, None).cache_err()?;
            if !can_cache {
                let req = Request::from_parts(parts.clone(), body);
                let response =
                    inner_service.oneshot(req).await.map_err(|e| {
                        let boxed: Box<dyn std::error::Error + Send + Sync> =
                            e.into();
                        HttpCacheError::http(boxed)
                    })?;
                // Invalidate cached copies after a successful unsafe request.
                if !parts.method.is_safe()
                    && (response.status().is_success()
                        || response.status().is_redirection())
                {
                    cache.run_no_cache(&parts).await.cache_err()?;
                }
                let mut converted =
                    cache.manager.convert_body(response).await.cache_err()?;
                if cache.options.cache_status_headers {
                    converted = add_cache_status_headers_streaming(
                        converted, "MISS", "MISS",
                    );
                }
                return Ok(converted);
            }
            let result = cache
                .run(&parts, None, |fetch_req| {
                    let parts_ref = parts.clone();
                    async move {
                        let request_parts = match fetch_req {
                            http_cache::FetchRequest::Fresh => parts_ref,
                            http_cache::FetchRequest::FreshNoCache => {
                                let mut p = parts_ref;
                                p.headers.insert(
                                    CACHE_CONTROL,
                                    HeaderValue::from_static("no-cache"),
                                );
                                p
                            }
                            http_cache::FetchRequest::Conditional(
                                cond_parts,
                            ) => *cond_parts,
                        };
                        let req = Request::from_parts(request_parts, body);
                        inner_service.oneshot(req).await.map_err(|e| {
                            let boxed: Box<
                                dyn std::error::Error + Send + Sync,
                            > = e.into();
                            boxed
                        })
                    }
                })
                .await
                .cache_err()?;
            Ok(result)
        })
    }
}
/// Response body returned by [`HttpCacheService`]: either bytes held in memory
/// or the inner service's own body passed through.
pub enum HttpCacheBody<B> {
/// In-memory body bytes, built from an `HttpResponse` returned by the cache
/// layer; yielded as a single data frame.
Buffered(Vec<u8>),
/// The inner service's response body, forwarded frame by frame.
Original(B),
}
/// `http_body::Body` implementation for [`HttpCacheBody`].
///
/// `Buffered` emits all of its bytes as one data frame and then ends.
/// `Original` forwards the wrapped body's frames, converting data chunks into
/// `Bytes` and errors into boxed errors.
impl<B> Body for HttpCacheBody<B>
where
    B: Body + Unpin,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    B::Data: Into<bytes::Bytes>,
{
    type Data = bytes::Bytes;
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        match self.get_mut() {
            Self::Buffered(buf) if buf.is_empty() => Poll::Ready(None),
            Self::Buffered(buf) => {
                // Move everything out at once; the now-empty Vec signals
                // end-of-stream on the next poll.
                let chunk = bytes::Bytes::from(std::mem::take(buf));
                Poll::Ready(Some(Ok(http_body::Frame::data(chunk))))
            }
            Self::Original(inner) => match Pin::new(inner).poll_frame(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(None) => Poll::Ready(None),
                Poll::Ready(Some(Ok(frame))) => {
                    Poll::Ready(Some(Ok(frame.map_data(Into::into))))
                }
                Poll::Ready(Some(Err(err))) => {
                    Poll::Ready(Some(Err(err.into())))
                }
            },
        }
    }

    fn is_end_stream(&self) -> bool {
        match self {
            Self::Buffered(buf) => buf.is_empty(),
            Self::Original(inner) => inner.is_end_stream(),
        }
    }

    fn size_hint(&self) -> http_body::SizeHint {
        match self {
            Self::Buffered(buf) => {
                http_body::SizeHint::with_exact(buf.len() as u64)
            }
            Self::Original(inner) => inner.size_hint(),
        }
    }
}
#[cfg(test)]
mod test;