pub mod mem;
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{ready, Poll},
};
use bytes::Bytes;
use http::{header, request::Request, HeaderMap, HeaderValue, Response, StatusCode, Uri};
use http_body::{Body, Frame, SizeHint};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use pin_project::pin_project;
use tower::{Layer, Service};
/// Validator remembered for a cached response and replayed on the next request
/// as a conditional header.
///
/// `ETag` is sent back as `If-None-Match`; `LastModified` is sent back as
/// `If-Modified-Since`.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CacheKey {
    /// Text of the response's `ETag` header.
    ETag(String),
    /// Text of the response's `Last-Modified` header.
    LastModified(String),
}
/// A stored response as returned by [`CacheStorage::load`]. It is replayed when the
/// origin answers `304 Not Modified`.
#[derive(Clone, Debug, Default)]
pub struct CachedResponse {
    /// Raw bytes of the stored response body.
    pub body: Vec<u8>,
    /// Stored response headers. Presumably these are the headers handed to
    /// `CacheStorage::writer`; that depends on the storage implementation.
    pub headers: HeaderMap,
}
/// Backend that persists cached responses, addressed by request URI.
pub trait CacheStorage: Send + Sync {
    /// Returns the validator of the entry stored for `uri`, if there is one.
    ///
    /// A `Some` result makes the middleware send the request with
    /// `If-None-Match` (for `ETag`) or `If-Modified-Since` (for `LastModified`).
    fn try_hit(&self, uri: &Uri) -> Option<CacheKey>;

    /// Loads the stored body and headers for `uri`.
    ///
    /// The middleware calls this after the origin replies `304 Not Modified`.
    fn load(&self, uri: &Uri) -> Option<CachedResponse>;

    /// Opens a new entry for `uri` with the given validator and response headers.
    ///
    /// The response body is then fed, chunk by chunk, into the returned writer.
    fn writer(&self, uri: &Uri, key: CacheKey, headers: HeaderMap) -> Box<dyn CacheWriter>;
}
/// Sink that receives a response body while it streams to the caller.
///
/// The trait has no completion or abort callback; it only ever sees data chunks.
pub trait CacheWriter: Send + Sync {
    /// Receives the next data chunk of the body, in stream order.
    fn write_body(&mut self, data: &[u8]);
}
/// [`Layer`] that wraps a service in [`HttpCache`].
///
/// When the storage is `None`, the wrapped service passes requests and responses
/// through unchanged.
#[derive(Clone)]
pub struct HttpCacheLayer {
    storage: Option<Arc<dyn CacheStorage>>,
}

impl HttpCacheLayer {
    /// Builds a layer backed by `storage`; pass `None` to disable caching.
    pub fn new(storage: Option<Arc<dyn CacheStorage>>) -> Self {
        Self { storage }
    }
}
impl<S> Layer<S> for HttpCacheLayer {
    type Service = HttpCache<S>;

    /// Wraps `inner`, sharing this layer's storage handle with the new service.
    fn layer(&self, inner: S) -> Self::Service {
        let storage = self.storage.clone();
        HttpCache { inner, storage }
    }
}
/// Service produced by [`HttpCacheLayer`].
///
/// It revalidates cached responses with conditional requests and stores fresh
/// responses through the configured [`CacheStorage`].
///
/// The service is `Clone` whenever the inner service is, as most tower stacks expect.
/// Cloning only bumps the storage `Arc` in addition to cloning `S`.
#[derive(Clone)]
pub struct HttpCache<S> {
    inner: S,
    storage: Option<Arc<dyn CacheStorage>>,
}

/// Response body type that [`HttpCache`] accepts from, and returns to, its callers.
type ResBody = BoxBody<Bytes, crate::Error>;
impl<S, ReqBody> Service<Request<ReqBody>> for HttpCache<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
type Error = S::Error;
type Response = S::Response;
type Future = HttpCacheFuture<S::Future>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
let uri = req.uri().clone();
if let Some(ref storage) = self.storage {
if let Some(key) = storage.try_hit(&uri) {
match key {
CacheKey::ETag(etag) => {
req.headers_mut()
.append(header::IF_NONE_MATCH, HeaderValue::from_str(&etag).unwrap());
}
CacheKey::LastModified(last_modified) => {
req.headers_mut().append(
header::IF_MODIFIED_SINCE,
HeaderValue::from_str(&last_modified).unwrap(),
);
}
}
}
}
HttpCacheFuture {
inner: self.inner.call(req),
storage: self.storage.clone(),
uri,
}
}
}
/// Future returned by [`HttpCache`]'s `Service::call`.
///
/// It resolves to the inner service's response after the cache has been applied to it.
#[pin_project]
pub struct HttpCacheFuture<F> {
    /// Response future of the wrapped service.
    #[pin]
    inner: F,
    /// Storage to consult; `None` means the response is returned unchanged.
    storage: Option<Arc<dyn CacheStorage>>,
    /// URI of the originating request, used as the storage address.
    uri: Uri,
}
impl<F, E> Future for HttpCacheFuture<F>
where
F: Future<Output = Result<Response<ResBody>, E>>,
{
type Output = Result<Response<ResBody>, E>;
fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.project();
let mut response = ready!(this.inner.poll(cx))?;
if let Some(ref storage) = this.storage {
if response.status() == StatusCode::NOT_MODIFIED {
let cached = storage.load(this.uri).expect("no body for cache hit");
for (name, value) in cached.headers.iter() {
if [header::CONTENT_TYPE, header::CONTENT_LENGTH, header::LINK].contains(name) {
response.headers_mut().append(name, value.clone());
}
}
*response.body_mut() = BoxBody::new(Box::new(
Full::new(Bytes::from(cached.body)).map_err(|infallible| match infallible {}),
));
*response.status_mut() = StatusCode::OK;
} else {
let cache_key = CacheKey::extract_from_headers(response.headers());
if let Some(key) = cache_key {
let writer = storage.writer(this.uri, key, response.headers().clone());
let (parts, mut body) = response.into_parts();
body = BoxBody::new(Box::new(WriteToCacheBody::new(body, writer)));
response = Response::from_parts(parts, body);
}
}
}
Poll::Ready(Ok(response))
}
}
/// Body adapter that passes every frame of `inner` through unchanged. Each data frame
/// is also handed to a [`CacheWriter`] on the way.
#[pin_project]
struct WriteToCacheBody<B> {
    #[pin]
    inner: B,
    writer: Box<dyn CacheWriter>,
}

impl<B> WriteToCacheBody<B> {
    /// Wraps `inner` so that its data frames are mirrored into `writer`.
    fn new(inner: B, writer: Box<dyn CacheWriter>) -> Self {
        WriteToCacheBody { writer, inner }
    }
}
impl<B> Body for WriteToCacheBody<B>
where
B: Body<Data = Bytes, Error = crate::Error>,
{
type Data = Bytes;
type Error = crate::Error;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.project();
match this.inner.poll_frame(cx) {
Poll::Ready(frame) => {
if let Some(Ok(ref data)) = frame {
if let Some(data) = data.data_ref() {
this.writer.write_body(data);
}
}
Poll::Ready(frame)
}
Poll::Pending => Poll::Pending,
}
}
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
fn size_hint(&self) -> SizeHint {
self.inner.size_hint()
}
}
impl CacheKey {
    /// Picks the validator to remember for a response.
    ///
    /// `ETag` is preferred when present and readable as text (`HeaderValue::to_str`
    /// succeeds). Otherwise `Last-Modified` is used under the same condition. Returns
    /// `None` when neither qualifies.
    fn extract_from_headers(headers: &HeaderMap) -> Option<Self> {
        if let Some(etag) = headers.get(header::ETAG).and_then(|v| v.to_str().ok()) {
            return Some(CacheKey::ETag(etag.to_owned()));
        }
        headers
            .get(header::LAST_MODIFIED)
            .and_then(|v| v.to_str().ok())
            .map(|last_modified| CacheKey::LastModified(last_modified.to_owned()))
    }
}