#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
nonstandard_style,
unused_qualifications,
unused_import_braces,
unused_extern_crates,
trivial_casts,
trivial_numeric_casts
)]
#![allow(clippy::doc_lazy_continuation)]
#![cfg_attr(docsrs, feature(doc_cfg))]
// The `url_*` helpers below contain one implementation per URL backend
// (`ada_url` vs. `url`), selected with `#[cfg]`; exactly one backend must be
// enabled or those helpers would be duplicated or missing.
#[cfg(all(feature = "url-standard", feature = "url-ada"))]
compile_error!("features `url-standard` and `url-ada` are mutually exclusive");
#[cfg(not(any(feature = "url-standard", feature = "url-ada")))]
compile_error!("either feature `url-standard` or `url-ada` must be enabled");
mod body;
mod error;
mod managers;
#[cfg(feature = "rate-limiting")]
pub mod rate_limiting;
use std::{
collections::HashMap,
convert::TryFrom,
fmt::{self, Debug},
future::Future,
str::FromStr,
sync::Arc,
time::{Duration, SystemTime},
};
use http::{
header::CACHE_CONTROL, request, response, HeaderValue, Response, StatusCode,
};
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "url-ada")]
pub use ada_url::Url;
#[cfg(not(feature = "url-ada"))]
pub use url::Url;
#[inline]
pub fn url_parse(s: &str) -> Result<Url> {
#[cfg(feature = "url-ada")]
{
Url::parse(s, None).map_err(|e| -> BoxError { e.to_string().into() })
}
#[cfg(not(feature = "url-ada"))]
{
Url::parse(s).map_err(|e| -> BoxError { Box::new(e) })
}
}
/// Replaces the path component of `url` with `path`.
///
/// With the `url-ada` backend, the status value returned by
/// `set_pathname` is discarded.
#[inline]
pub fn url_set_path(url: &mut Url, path: &str) {
    #[cfg(feature = "url-ada")]
    let _ = url.set_pathname(Some(path));
    #[cfg(not(feature = "url-ada"))]
    url.set_path(path);
}
/// Sets the query component of `url` from `query`.
///
/// Forwards to the backend's query setter: `set_search` for `url-ada`,
/// `set_query` otherwise.
#[inline]
pub fn url_set_query(url: &mut Url, query: Option<&str>) {
    #[cfg(feature = "url-ada")]
    url.set_search(query);
    #[cfg(not(feature = "url-ada"))]
    url.set_query(query);
}
/// Returns the host name of `url`, or `None` when it has none.
///
/// With the `url-ada` backend an empty host name is reported as `None`.
#[inline]
#[must_use]
pub fn url_hostname(url: &Url) -> Option<&str> {
    #[cfg(feature = "url-ada")]
    {
        let host = url.hostname();
        (!host.is_empty()).then_some(host)
    }
    #[cfg(not(feature = "url-ada"))]
    {
        url.host_str()
    }
}
/// Returns the host of `url` as an owned string, or `"unknown"` when the URL
/// has no host.
///
/// Used to fill the agent field of `Warning` headers.
#[inline]
#[must_use]
pub fn url_host_str(url: &Url) -> String {
    #[cfg(feature = "url-ada")]
    {
        match url.hostname() {
            "" => String::from("unknown"),
            host => host.to_owned(),
        }
    }
    #[cfg(not(feature = "url-ada"))]
    {
        url.host()
            .map_or_else(|| String::from("unknown"), |host| host.to_string())
    }
}
pub use body::StreamingBody;
pub use error::{
BadHeader, BadRequest, BadVersion, BoxError, ClientStreamingError,
HttpCacheError, HttpCacheResult, Result, StreamingError,
};
#[cfg(any(
feature = "manager-cacache",
feature = "manager-cacache-bincode"
))]
pub use managers::cacache::CACacheManager;
#[cfg(feature = "streaming")]
pub use managers::streaming_cache::StreamingManager;
#[cfg(any(feature = "manager-moka", feature = "manager-moka-bincode"))]
pub use managers::moka::MokaManager;
#[cfg(feature = "manager-foyer")]
pub use managers::foyer::FoyerManager;
#[cfg(feature = "rate-limiting")]
pub use rate_limiting::{
CacheAwareRateLimiter, DirectRateLimiter, DomainRateLimiter,
};
#[cfg(feature = "rate-limiting")]
pub use rate_limiting::Quota;
#[cfg(any(feature = "manager-moka", feature = "manager-moka-bincode"))]
#[cfg_attr(docsrs, doc(cfg(feature = "manager-moka")))]
pub use moka::future::{Cache as MokaCache, CacheBuilder as MokaCacheBuilder};
/// Name of the response header that reports whether the response was served
/// from the cache (`HIT`) or not (`MISS`).
pub const XCACHE: &str = "x-cache";
/// Name of the response header that reports whether the cache lookup found a
/// stored entry (`HIT`) or not (`MISS`).
pub const XCACHELOOKUP: &str = "x-cache-lookup";
/// Name of the `Warning` header written/removed on cached responses.
const WARNING: &str = "warning";
/// Outcome written into the `x-cache` / `x-cache-lookup` headers.
#[derive(Debug, Copy, Clone)]
pub enum HitOrMiss {
    /// A stored entry was found (or the response came from the cache).
    HIT,
    /// No stored entry was found (or the response came from the network).
    MISS,
}

impl fmt::Display for HitOrMiss {
    /// Renders the variant as its upper-case header value.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            HitOrMiss::HIT => "HIT",
            HitOrMiss::MISS => "MISS",
        };
        f.write_str(label)
    }
}
/// HTTP protocol version recorded with a cached response.
///
/// Serialized as its `HTTP/x.y` string (see the `serde(rename)` attributes).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[non_exhaustive]
pub enum HttpVersion {
/// HTTP/0.9
#[serde(rename = "HTTP/0.9")]
Http09,
/// HTTP/1.0
#[serde(rename = "HTTP/1.0")]
Http10,
/// HTTP/1.1
#[serde(rename = "HTTP/1.1")]
Http11,
/// HTTP/2 (serialized as `HTTP/2.0`)
#[serde(rename = "HTTP/2.0")]
H2,
/// HTTP/3 (serialized as `HTTP/3.0`)
#[serde(rename = "HTTP/3.0")]
H3,
}
impl fmt::Display for HttpVersion {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
HttpVersion::Http09 => write!(f, "HTTP/0.9"),
HttpVersion::Http10 => write!(f, "HTTP/1.0"),
HttpVersion::Http11 => write!(f, "HTTP/1.1"),
HttpVersion::H2 => write!(f, "HTTP/2.0"),
HttpVersion::H3 => write!(f, "HTTP/3.0"),
}
}
}
/// Reconstructs the absolute request URL from `parts`.
///
/// A URI that already carries a scheme is parsed as-is. Otherwise the URL is
/// assembled from the `host` header, a scheme picked by `determine_scheme`,
/// and the URI's path and query.
///
/// # Errors
///
/// Returns [`BadHeader`] when the URI cannot be parsed, or when the `host`
/// header is missing or cannot be read as a string.
fn extract_url_from_request_parts(parts: &request::Parts) -> Result<Url> {
    if parts.uri.scheme().is_some() {
        let absolute = parts.uri.to_string();
        return url_parse(&absolute)
            .map_err(|_| -> BoxError { BadHeader.into() });
    }

    let host = match parts.headers.get("host") {
        Some(value) => value.to_str().map_err(|_| BadHeader)?,
        None => return Err(BadHeader.into()),
    };
    let scheme = determine_scheme(host, &parts.headers)?;
    let mut url = url_parse(&format!("{scheme}://{host}/"))
        .map_err(|_| -> BoxError { BadHeader.into() })?;

    if let Some(path_and_query) = parts.uri.path_and_query() {
        url_set_path(&mut url, path_and_query.path());
        if let Some(query) = path_and_query.query() {
            url_set_query(&mut url, Some(query));
        }
    }
    Ok(url)
}
/// Chooses the scheme (`"http"` or `"https"`) for a request whose URI did
/// not carry one.
///
/// The `x-forwarded-proto` header wins when present:
/// - Only its first comma-separated entry is used, because chained proxies
///   may append their own entries.
/// - That entry is compared case-insensitively, because URI schemes are
///   case-insensitive.
/// - Anything other than `http` yields `"https"`.
///
/// Without the header, hosts starting with `localhost` or `127.0.0.1` use
/// `"http"` and every other host uses `"https"`.
///
/// # Errors
///
/// Returns [`BadHeader`] if `x-forwarded-proto` cannot be read as a string.
fn determine_scheme(host: &str, headers: &http::HeaderMap) -> Result<String> {
    if let Some(forwarded_proto) = headers.get("x-forwarded-proto") {
        let proto = forwarded_proto.to_str().map_err(|_| BadHeader)?;
        // e.g. "https, http" after two proxies: conventionally the first
        // entry describes the client-facing connection.
        let first = proto.split(',').next().unwrap_or_default().trim();
        let scheme =
            if first.eq_ignore_ascii_case("http") { "http" } else { "https" };
        return Ok(scheme.to_string());
    }
    let is_local =
        host.starts_with("localhost") || host.starts_with("127.0.0.1");
    Ok(if is_local { "http" } else { "https" }.to_string())
}
/// Header collection stored inside a cached [`HttpResponse`].
///
/// Keys written through the `HttpHeaders` methods are ASCII-lowercased.
#[derive(Debug, Clone)]
pub enum HttpHeaders {
/// Each header name mapped to all of its values.
Modern(HashMap<String, Vec<String>>),
/// One value string per header name. With the `http-headers-compat`
/// feature, deserialization always produces this layout.
#[cfg(feature = "http-headers-compat")]
Legacy(HashMap<String, String>),
}
impl Serialize for HttpHeaders {
    fn serialize<S>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            // Compat mode always writes one string per name so the
            // compat deserializer can read it back.
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => legacy.serialize(serializer),
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Modern(modern) => modern
                .iter()
                .map(|(name, values)| (name.clone(), values.join(", ")))
                .collect::<HashMap<String, String>>()
                .serialize(serializer),
            // Otherwise the multi-value map is written as-is.
            #[cfg(not(feature = "http-headers-compat"))]
            HttpHeaders::Modern(modern) => modern.serialize(serializer),
        }
    }
}
impl<'de> Deserialize<'de> for HttpHeaders {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // The wire layout is fixed by the feature flag: single strings in
        // compat mode, value lists otherwise.
        #[cfg(feature = "http-headers-compat")]
        {
            HashMap::<String, String>::deserialize(deserializer)
                .map(HttpHeaders::Legacy)
        }
        #[cfg(not(feature = "http-headers-compat"))]
        {
            HashMap::<String, Vec<String>>::deserialize(deserializer)
                .map(HttpHeaders::Modern)
        }
    }
}
impl HttpHeaders {
    /// Creates an empty collection in the `Modern` layout.
    pub fn new() -> Self {
        HttpHeaders::Modern(HashMap::new())
    }

    /// Sets `key` to exactly `value`, discarding any previous values.
    ///
    /// The key is stored ASCII-lowercased.
    pub fn insert(&mut self, key: String, value: String) {
        let normalized_key = key.to_ascii_lowercase();
        match self {
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => {
                legacy.insert(normalized_key, value);
            }
            HttpHeaders::Modern(modern) => {
                modern.insert(normalized_key, vec![value]);
            }
        }
    }

    /// Adds `value` to `key`, keeping any values already present.
    ///
    /// - `Modern` layout: the value is pushed onto the name's list.
    /// - `Legacy` layout (one string per name): the value is joined onto the
    ///   existing string with `", "` instead of overwriting it. This is the
    ///   same separator the serializer uses to flatten `Modern` headers.
    ///
    /// The key is stored ASCII-lowercased.
    pub fn append(&mut self, key: String, value: String) {
        let normalized_key = key.to_ascii_lowercase();
        match self {
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => {
                legacy
                    .entry(normalized_key)
                    .and_modify(|existing| {
                        existing.push_str(", ");
                        existing.push_str(&value);
                    })
                    .or_insert(value);
            }
            HttpHeaders::Modern(modern) => {
                modern
                    .entry(normalized_key)
                    .or_insert_with(Vec::new)
                    .push(value);
            }
        }
    }

    /// Returns the first value for `key`, compared after ASCII-lowercasing.
    ///
    /// In the `Legacy` layout this is the single stored string.
    pub fn get(&self, key: &str) -> Option<&String> {
        let normalized_key = key.to_ascii_lowercase();
        match self {
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => legacy.get(&normalized_key),
            HttpHeaders::Modern(modern) => {
                modern.get(&normalized_key).and_then(|vals| vals.first())
            }
        }
    }

    /// Removes every value stored for `key` (ASCII case-insensitive).
    pub fn remove(&mut self, key: &str) {
        let normalized_key = key.to_ascii_lowercase();
        match self {
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => {
                legacy.remove(&normalized_key);
            }
            HttpHeaders::Modern(modern) => {
                modern.remove(&normalized_key);
            }
        }
    }

    /// Returns `true` if any value is stored for `key` (ASCII
    /// case-insensitive).
    pub fn contains_key(&self, key: &str) -> bool {
        let normalized_key = key.to_ascii_lowercase();
        match self {
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => legacy.contains_key(&normalized_key),
            HttpHeaders::Modern(modern) => modern.contains_key(&normalized_key),
        }
    }

    /// Iterates over borrowed `(name, value)` pairs.
    ///
    /// A name with several values appears once per value. The pairs are
    /// collected into a buffer up front.
    pub fn iter(&self) -> HttpHeadersIterator<'_> {
        match self {
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => {
                HttpHeadersIterator { inner: legacy.iter().collect(), index: 0 }
            }
            HttpHeaders::Modern(modern) => HttpHeadersIterator {
                inner: modern
                    .iter()
                    .flat_map(|(k, vals)| vals.iter().map(move |v| (k, v)))
                    .collect(),
                index: 0,
            },
        }
    }
}
impl From<&http::HeaderMap> for HttpHeaders {
    /// Copies `headers` into the `Modern` layout.
    ///
    /// Values that cannot be read as strings are skipped, and names left
    /// without any readable value are omitted entirely.
    fn from(headers: &http::HeaderMap) -> Self {
        let modern = headers
            .keys()
            .filter_map(|name| {
                let values: Vec<String> = headers
                    .get_all(name)
                    .iter()
                    .filter_map(|value| value.to_str().ok().map(str::to_owned))
                    .collect();
                (!values.is_empty()).then(|| (name.to_string(), values))
            })
            .collect::<HashMap<String, Vec<String>>>();
        HttpHeaders::Modern(modern)
    }
}
impl From<HttpHeaders> for HashMap<String, Vec<String>> {
    /// Converts into the multi-value map. Legacy single values become
    /// one-element lists.
    fn from(headers: HttpHeaders) -> Self {
        match headers {
            HttpHeaders::Modern(modern) => modern,
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => legacy
                .into_iter()
                .map(|(name, value)| (name, Vec::from([value])))
                .collect(),
        }
    }
}
impl Default for HttpHeaders {
fn default() -> Self {
HttpHeaders::new()
}
}
impl IntoIterator for HttpHeaders {
    type Item = (String, String);
    type IntoIter = HttpHeadersIntoIterator;

    /// Consumes the headers, yielding one owned `(name, value)` pair per
    /// stored value.
    fn into_iter(self) -> Self::IntoIter {
        let pairs: Vec<(String, String)> = match self {
            HttpHeaders::Modern(modern) => {
                let mut pairs = Vec::new();
                for (name, values) in modern {
                    for value in values {
                        pairs.push((name.clone(), value));
                    }
                }
                pairs
            }
            #[cfg(feature = "http-headers-compat")]
            HttpHeaders::Legacy(legacy) => legacy.into_iter().collect(),
        };
        HttpHeadersIntoIterator { inner: pairs, index: 0 }
    }
}
/// Owning iterator over `(name, value)` header pairs, produced by
/// `HttpHeaders::into_iter`.
#[derive(Debug)]
pub struct HttpHeadersIntoIterator {
    inner: Vec<(String, String)>,
    index: usize,
}

impl Iterator for HttpHeadersIntoIterator {
    type Item = (String, String);

    /// Moves the next pair out of the buffer, leaving an empty placeholder.
    ///
    /// The buffer is owned, so nothing needs to be cloned.
    fn next(&mut self) -> Option<Self::Item> {
        let slot = self.inner.get_mut(self.index)?;
        self.index += 1;
        Some(std::mem::take(slot))
    }

    /// The number of remaining pairs is known exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.len().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}
impl<'a> IntoIterator for &'a HttpHeaders {
type Item = (&'a String, &'a String);
type IntoIter = HttpHeadersIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Borrowing iterator over `(name, value)` header pairs, produced by
/// `HttpHeaders::iter`.
#[derive(Debug)]
pub struct HttpHeadersIterator<'a> {
    inner: Vec<(&'a String, &'a String)>,
    index: usize,
}

impl<'a> Iterator for HttpHeadersIterator<'a> {
    type Item = (&'a String, &'a String);

    /// Returns the next buffered pair; the cursor stops advancing once the
    /// buffer is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        let pair = self.inner.get(self.index).copied();
        if pair.is_some() {
            self.index += 1;
        }
        pair
    }
}
/// A fully buffered HTTP response as stored in, and returned from, the cache.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HttpResponse {
/// Response body bytes.
pub body: Vec<u8>,
/// Response headers.
pub headers: HttpHeaders,
/// Numeric HTTP status code.
pub status: u16,
/// URL of the request that produced this response.
pub url: Url,
/// HTTP version of the response.
pub version: HttpVersion,
/// Optional user-supplied metadata stored with the response. Because of
/// `#[serde(default)]`, it is `None` when the field is absent from
/// serialized data.
#[serde(default)]
pub metadata: Option<Vec<u8>>,
}
impl HttpResponse {
    /// Builds `http` response parts carrying this response's status and
    /// headers.
    ///
    /// NOTE(review): `self.version` is not copied, so the parts carry the
    /// builder's default version. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns an error if the status code, a header name, or a header value
    /// is rejected by the `http` crate.
    pub fn parts(&self) -> Result<response::Parts> {
        let mut converted =
            response::Builder::new().status(self.status).body(())?;
        {
            let headers = converted.headers_mut();
            for header in &self.headers {
                headers.append(
                    http::header::HeaderName::from_str(header.0.as_str())?,
                    HeaderValue::from_str(header.1.as_str())?,
                );
            }
        }
        Ok(converted.into_parts().0)
    }

    /// Parses the first three characters of the first `warning` value as a
    /// numeric warning code.
    #[must_use]
    fn warning_code(&self) -> Option<usize> {
        self.headers.get(WARNING).and_then(|hdr| {
            hdr.as_str().chars().take(3).collect::<String>().parse().ok()
        })
    }

    /// Replaces any `warning` header with `<code> <host> "<message>" "<date>"`.
    ///
    /// Double quotes in `message` become single quotes and line breaks
    /// become spaces, so the message stays inside its quoted string.
    fn add_warning(&mut self, url: &Url, code: usize, message: &str) {
        let host = url_host_str(url);
        let escaped_message =
            message.replace('"', "'").replace(['\n', '\r'], " ");
        self.headers.insert(
            WARNING.to_string(),
            format!(
                "{} {} \"{}\" \"{}\"",
                code,
                host,
                escaped_message,
                httpdate::fmt_http_date(SystemTime::now())
            ),
        );
    }

    /// Removes every `warning` value.
    fn remove_warning(&mut self) {
        self.headers.remove(WARNING);
    }

    /// Replaces, name by name, the stored headers with those in `parts`.
    ///
    /// Headers absent from `parts` are kept. Values that cannot be read as
    /// strings are skipped.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn update_headers(&mut self, parts: &response::Parts) -> Result<()> {
        for name in parts.headers.keys() {
            self.headers.remove(name.as_str());
        }
        for (name, value) in parts.headers.iter() {
            if let Ok(v) = value.to_str() {
                self.headers.append(name.as_str().to_string(), v.to_string());
            }
        }
        Ok(())
    }

    /// Returns `true` if any `cache-control` value contains `must-revalidate`
    /// (case-insensitively).
    ///
    /// Every stored value is checked, not only the first, because
    /// directives may be spread over several `Cache-Control` lines.
    #[must_use]
    fn must_revalidate(&self) -> bool {
        let cache_control = CACHE_CONTROL.as_str();
        self.headers.iter().any(|(name, value)| {
            name.eq_ignore_ascii_case(cache_control)
                && value.to_lowercase().contains("must-revalidate")
        })
    }

    /// Sets the `x-cache` header to `HIT` or `MISS`.
    pub fn cache_status(&mut self, hit_or_miss: HitOrMiss) {
        self.headers.insert(XCACHE.to_string(), hit_or_miss.to_string());
    }

    /// Sets the `x-cache-lookup` header to `HIT` or `MISS`.
    pub fn cache_lookup_status(&mut self, hit_or_miss: HitOrMiss) {
        self.headers.insert(XCACHELOOKUP.to_string(), hit_or_miss.to_string());
    }
}
/// Storage backend for fully buffered cache entries, addressed by string
/// cache keys.
pub trait CacheManager: Send + Sync + 'static {
/// Looks up `cache_key`. Returns the stored response and its cache policy,
/// or `None` when nothing is stored under that key.
fn get(
&self,
cache_key: &str,
) -> impl Future<Output = Result<Option<(HttpResponse, CachePolicy)>>> + Send;
/// Stores `res` with its `policy` under `cache_key` and returns an
/// `HttpResponse` (presumably the stored response — confirm per backend).
fn put(
&self,
cache_key: String,
res: HttpResponse,
policy: CachePolicy,
) -> impl Future<Output = Result<HttpResponse>> + Send;
/// Removes the entry stored under `cache_key`.
fn delete(
&self,
cache_key: &str,
) -> impl Future<Output = Result<()>> + Send;
}
/// Storage backend for the streaming cache. Responses are returned as
/// `http::Response` values whose body type the backend chooses.
pub trait StreamingCacheManager: Send + Sync + 'static {
/// Body type of the responses this manager returns.
type Body: http_body::Body + Send + 'static;
/// Looks up `cache_key`. Returns the stored response and its policy, or
/// `None` when nothing is stored.
fn get(
&self,
cache_key: &str,
) -> impl Future<Output = Result<Option<(Response<Self::Body>, CachePolicy)>>>
+ Send
where
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static;
/// Stores `response`, its `policy`, the originating `request_url` and the
/// optional user `metadata` under `cache_key`. Returns a response with a
/// `Self::Body` body (presumably replaying what was stored — confirm per
/// backend).
fn put<B>(
&self,
cache_key: String,
response: Response<B>,
policy: CachePolicy,
request_url: Url,
metadata: Option<Vec<u8>>,
) -> impl Future<Output = Result<Response<Self::Body>>> + Send
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static;
/// Converts a response with an arbitrary body into one using `Self::Body`
/// (presumably without storing it — confirm per backend).
fn convert_body<B>(
&self,
response: Response<B>,
) -> impl Future<Output = Result<Response<Self::Body>>> + Send
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static;
/// Removes the entry stored under `cache_key`.
fn delete(
&self,
cache_key: &str,
) -> impl Future<Output = Result<()>> + Send;
/// Returns an empty body. Used, for example, for the synthesized
/// `504 Gateway Timeout` on an `OnlyIfCached` miss.
fn empty_body(&self) -> Self::Body;
/// Converts `body` into a stream of byte chunks (`streaming` feature only).
#[cfg(feature = "streaming")]
fn body_to_bytes_stream(
body: Self::Body,
) -> impl futures_util::Stream<
Item = std::result::Result<
bytes::Bytes,
Box<dyn std::error::Error + Send + Sync>,
>,
> + Send
where
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error: Send + Sync + 'static;
}
/// Adapter that an HTTP client integration implements so the cache can
/// inspect and modify the pending request and perform the network fetch.
pub trait Middleware: Send {
/// Per-request cache mode that overrides the cache-wide mode. The default
/// implementation returns `None` (no override).
fn overridden_cache_mode(&self) -> Option<CacheMode> {
None
}
/// Whether the request method is `GET` or `HEAD`.
fn is_method_get_head(&self) -> bool;
/// Builds the cache policy for `response` paired with this request.
fn policy(&self, response: &HttpResponse) -> Result<CachePolicy>;
/// Like [`Middleware::policy`], but with explicit [`CacheOptions`].
fn policy_with_options(
&self,
response: &HttpResponse,
options: CacheOptions,
) -> Result<CachePolicy>;
/// Replaces the request's headers with those in `parts` (presumably used
/// to add revalidation headers — confirm in implementations).
fn update_headers(&mut self, parts: &request::Parts) -> Result<()>;
/// Makes the outgoing request bypass caches; the mechanism is up to the
/// implementation.
fn force_no_cache(&mut self) -> Result<()>;
/// Returns the request as `http` request parts.
fn parts(&self) -> Result<request::Parts>;
/// Returns the request URL.
fn url(&self) -> Result<Url>;
/// Returns the request method as a string.
fn method(&self) -> Result<String>;
/// Sends the request over the network and returns the buffered response.
fn remote_fetch(
&mut self,
) -> impl Future<Output = Result<HttpResponse>> + Send;
}
/// Step-by-step cache operations for integrations that drive the request
/// flow themselves. Responses carry body type `B` (default `Vec<u8>`).
pub trait HttpCacheInterface<B = Vec<u8>>: Send + Sync {
/// Decides how the request described by `parts` interacts with the cache,
/// optionally forcing `mode_override`.
fn analyze_request(
&self,
parts: &request::Parts,
mode_override: Option<CacheMode>,
) -> Result<CacheAnalysis>;
// `async fn` in a public trait cannot promise a `Send` future; the lint
// that warns about this is deliberately silenced on each async method.
#[allow(async_fn_in_trait)]
/// Returns the response and policy cached under `key`, if any.
async fn lookup_cached_response(
&self,
key: &str,
) -> Result<Option<(HttpResponse, CachePolicy)>>;
#[allow(async_fn_in_trait)]
/// Handles a freshly fetched `response` according to `analysis`, with
/// optional user `metadata`, and returns the response for the caller.
async fn process_response(
&self,
analysis: CacheAnalysis,
response: Response<B>,
metadata: Option<Vec<u8>>,
) -> Result<Response<B>>;
/// Adds revalidation information derived from `cached_response` and
/// `policy` to the outgoing request `parts`.
fn prepare_conditional_request(
&self,
parts: &mut request::Parts,
cached_response: &HttpResponse,
policy: &CachePolicy,
) -> Result<()>;
#[allow(async_fn_in_trait)]
/// Combines `cached_response` with the parts of a not-modified reply
/// (`fresh_parts`) and returns the resulting response.
async fn handle_not_modified(
&self,
cached_response: HttpResponse,
fresh_parts: &response::Parts,
) -> Result<HttpResponse>;
}
/// Streaming counterpart of [`HttpCacheInterface`]: responses use the
/// implementation's `Body` type instead of a buffered body.
pub trait HttpCacheStreamInterface: Send + Sync {
/// Body type of the responses produced by this cache.
type Body: http_body::Body + Send + 'static;
/// Decides how the request described by `parts` interacts with the cache,
/// optionally forcing `mode_override`.
fn analyze_request(
&self,
parts: &request::Parts,
mode_override: Option<CacheMode>,
) -> Result<CacheAnalysis>;
// `async fn` in a public trait cannot promise a `Send` future; the lint
// that warns about this is deliberately silenced on each async method.
#[allow(async_fn_in_trait)]
/// Returns the response and policy cached under `key`, if any.
async fn lookup_cached_response(
&self,
key: &str,
) -> Result<Option<(Response<Self::Body>, CachePolicy)>>
where
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static;
#[allow(async_fn_in_trait)]
/// Handles a freshly fetched `response` according to `analysis`, with
/// optional user `metadata`, and returns it with a `Self::Body` body.
async fn process_response<B>(
&self,
analysis: CacheAnalysis,
response: Response<B>,
metadata: Option<Vec<u8>>,
) -> Result<Response<Self::Body>>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static;
/// Adds revalidation information derived from `cached_response` and
/// `policy` to the outgoing request `parts`.
fn prepare_conditional_request(
&self,
parts: &mut request::Parts,
cached_response: &Response<Self::Body>,
policy: &CachePolicy,
) -> Result<()>;
#[allow(async_fn_in_trait)]
/// Combines `cached_response` with the parts of a not-modified reply
/// (`fresh_parts`) and returns the resulting response.
async fn handle_not_modified(
&self,
cached_response: Response<Self::Body>,
fresh_parts: &response::Parts,
) -> Result<Response<Self::Body>>
where
<Self::Body as http_body::Body>::Data: Send,
<Self::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static;
}
/// Result of analyzing a request before it is sent.
#[derive(Debug, Clone)]
pub struct CacheAnalysis {
/// Key the response is looked up and stored under.
pub cache_key: String,
/// Whether the cache should be consulted for this request.
pub should_cache: bool,
/// Effective cache mode after applying the override or per-request hook.
pub cache_mode: CacheMode,
/// Extra keys to invalidate, as returned by the `cache_bust` hook.
pub cache_bust_keys: Vec<String>,
/// Copy of the analyzed request parts.
pub request_parts: request::Parts,
/// Whether the request method is `GET` or `HEAD`.
pub is_get_head: bool,
}
/// Tells the fetch closure passed to `HttpStreamingCache::run` which kind of
/// network request to make.
#[derive(Debug)]
pub enum FetchRequest {
/// Send the original request.
Fresh,
/// Send the request bypassing caches; requested in `CacheMode::NoCache`.
/// (Presumably the caller adds a no-cache directive — confirm.)
FreshNoCache,
/// Send a request built from these parts (presumably a revalidation
/// request carrying conditional headers — confirm in `conditional_fetch`).
Conditional(Box<request::Parts>),
}
/// How a request interacts with the cache.
///
/// Most variant names match the Fetch standard's request cache modes. The
/// notes below describe what `HttpStreamingCache::run` and
/// `HttpCacheOptions::should_cache_response` in this file do with each one.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
/// An existing entry is handed to `conditional_fetch`; a miss fetches
/// from the network.
#[default]
Default,
/// The cache is bypassed: the request is not cache-eligible and the
/// response is never stored.
NoStore,
/// Always fetches from the network, even when an entry exists.
Reload,
/// Fetches with `FetchRequest::FreshNoCache`, even when an entry exists.
NoCache,
/// Serves an existing entry without contacting the network (adding a 112
/// warning) and fetches on a miss. GET/HEAD responses with a cacheable
/// status are stored even when the policy says they are not storable.
ForceCache,
/// Serves an existing entry like `ForceCache`. On a miss it returns
/// `504 Gateway Timeout` without fetching.
OnlyIfCached,
/// Serves an existing entry like `ForceCache`. Any method is
/// cache-eligible, and responses with a cacheable status are stored
/// regardless of the policy.
IgnoreRules,
}
impl TryFrom<http::Version> for HttpVersion {
type Error = BoxError;
fn try_from(value: http::Version) -> Result<Self> {
Ok(match value {
http::Version::HTTP_09 => Self::Http09,
http::Version::HTTP_10 => Self::Http10,
http::Version::HTTP_11 => Self::Http11,
http::Version::HTTP_2 => Self::H2,
http::Version::HTTP_3 => Self::H3,
_ => return Err(Box::new(BadVersion)),
})
}
}
impl From<HttpVersion> for http::Version {
fn from(value: HttpVersion) -> Self {
match value {
HttpVersion::Http09 => Self::HTTP_09,
HttpVersion::Http10 => Self::HTTP_10,
HttpVersion::Http11 => Self::HTTP_11,
HttpVersion::H2 => Self::HTTP_2,
HttpVersion::H3 => Self::HTTP_3,
}
}
}
#[cfg(feature = "http-types")]
impl TryFrom<http_types::Version> for HttpVersion {
    type Error = BoxError;

    /// Maps an `http-types` version onto [`HttpVersion`].
    ///
    /// # Errors
    ///
    /// Returns [`BadVersion`] for any version this enum does not model.
    fn try_from(value: http_types::Version) -> Result<Self> {
        let version = match value {
            http_types::Version::Http0_9 => HttpVersion::Http09,
            http_types::Version::Http1_0 => HttpVersion::Http10,
            http_types::Version::Http1_1 => HttpVersion::Http11,
            http_types::Version::Http2_0 => HttpVersion::H2,
            http_types::Version::Http3_0 => HttpVersion::H3,
            _ => return Err(Box::new(BadVersion)),
        };
        Ok(version)
    }
}
#[cfg(feature = "http-types")]
impl From<HttpVersion> for http_types::Version {
    /// Converts back to the matching `http-types` version.
    fn from(value: HttpVersion) -> Self {
        use HttpVersion::{Http09, Http10, Http11, H2, H3};
        match value {
            Http09 => http_types::Version::Http0_9,
            Http10 => http_types::Version::Http1_0,
            Http11 => http_types::Version::Http1_1,
            H2 => http_types::Version::Http2_0,
            H3 => http_types::Version::Http3_0,
        }
    }
}
pub use http_cache_semantics::CacheOptions;
/// Custom cache-key function. When set, it replaces the default
/// `"{method}:{uri}"` key.
pub type CacheKey = Arc<dyn Fn(&request::Parts) -> String + Send + Sync>;
/// Chooses the cache mode for a request when no explicit override is given.
pub type CacheModeFn = Arc<dyn Fn(&request::Parts) -> CacheMode + Send + Sync>;
/// Optionally overrides the cache mode once the response is known; `None`
/// keeps the request's mode.
pub type ResponseCacheModeFn = Arc<
dyn Fn(&request::Parts, &HttpResponse) -> Option<CacheMode> + Send + Sync,
>;
/// Returns additional cache keys to delete for a request. It receives the
/// request, the configured [`CacheKey`] function (if any), and the
/// request's computed cache key.
pub type CacheBust = Arc<
dyn Fn(&request::Parts, &Option<CacheKey>, &str) -> Vec<String>
+ Send
+ Sync,
>;
/// User metadata bytes stored alongside a cached response.
pub type HttpCacheMetadata = Vec<u8>;
/// Produces optional [`HttpCacheMetadata`] from the request and response
/// parts.
pub type MetadataProvider = Arc<
dyn Fn(&request::Parts, &response::Parts) -> Option<HttpCacheMetadata>
+ Send
+ Sync,
>;
/// Hook that mutates an [`HttpResponse`] before it is cached.
pub type ModifyResponse = Arc<dyn Fn(&mut HttpResponse) + Send + Sync>;
/// Tunable behavior shared by the cache front-ends.
#[derive(Clone)]
pub struct HttpCacheOptions {
/// Options passed to `CachePolicy::new_options`; the crate's default
/// `CacheOptions` are used when `None`.
pub cache_options: Option<CacheOptions>,
/// Custom cache-key function (default key: `"{method}:{uri}"`).
pub cache_key: Option<CacheKey>,
/// Per-request cache-mode chooser, consulted when no override is given.
pub cache_mode_fn: Option<CacheModeFn>,
/// Hook that may change the cache mode after the response is seen.
pub response_cache_mode_fn: Option<ResponseCacheModeFn>,
/// Hook returning extra cache keys to invalidate for a request.
pub cache_bust: Option<CacheBust>,
/// Hook that mutates responses before they are cached.
pub modify_response: Option<ModifyResponse>,
/// Whether to add the `x-cache` / `x-cache-lookup` headers (default `true`).
pub cache_status_headers: bool,
/// Upper bound applied to the response's `max-age` when building the cache
/// policy; used as `max-age` when the response has none.
pub max_ttl: Option<Duration>,
/// Rate limiter consulted, keyed by host name, before network fetches.
#[cfg(feature = "rate-limiting")]
pub rate_limiter: Option<Arc<dyn CacheAwareRateLimiter>>,
/// Hook producing user metadata to store with a response.
pub metadata_provider: Option<MetadataProvider>,
}
impl Default for HttpCacheOptions {
fn default() -> Self {
Self {
cache_options: None,
cache_key: None,
cache_mode_fn: None,
response_cache_mode_fn: None,
cache_bust: None,
modify_response: None,
cache_status_headers: true,
max_ttl: None,
#[cfg(feature = "rate-limiting")]
rate_limiter: None,
metadata_provider: None,
}
}
}
impl Debug for HttpCacheOptions {
    /// Formats the options for debugging.
    ///
    /// Closure-valued fields cannot be printed. Each is shown as
    /// `Some("<signature>")` when the hook is configured and `None`
    /// otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Describes an optional hook by its signature, if it is set.
        fn hook<T>(
            slot: &Option<T>,
            signature: &'static str,
        ) -> Option<&'static str> {
            slot.as_ref().map(|_| signature)
        }

        let mut out = f.debug_struct("HttpCacheOptions");
        out.field("cache_options", &self.cache_options)
            .field(
                "cache_key",
                &hook(&self.cache_key, "Fn(&request::Parts) -> String"),
            )
            .field(
                "cache_mode_fn",
                &hook(&self.cache_mode_fn, "Fn(&request::Parts) -> CacheMode"),
            )
            .field(
                "response_cache_mode_fn",
                &hook(
                    &self.response_cache_mode_fn,
                    "Fn(&request::Parts, &HttpResponse) -> Option<CacheMode>",
                ),
            )
            .field(
                "cache_bust",
                &hook(
                    &self.cache_bust,
                    "Fn(&request::Parts, &Option<CacheKey>, &str) -> Vec<String>",
                ),
            )
            .field(
                "modify_response",
                &hook(&self.modify_response, "Fn(&mut HttpResponse)"),
            )
            .field("cache_status_headers", &self.cache_status_headers)
            .field("max_ttl", &self.max_ttl);
        #[cfg(feature = "rate-limiting")]
        out.field(
            "rate_limiter",
            &hook(&self.rate_limiter, "Arc<dyn CacheAwareRateLimiter>"),
        );
        out.field(
            "metadata_provider",
            &hook(
                &self.metadata_provider,
                "Fn(&request::Parts, &response::Parts) -> Option<Vec<u8>>",
            ),
        )
        .finish()
    }
}
impl HttpCacheOptions {
fn create_cache_key(
&self,
parts: &request::Parts,
override_method: Option<&str>,
) -> String {
if let Some(cache_key) = &self.cache_key {
cache_key(parts)
} else {
format!(
"{}:{}",
override_method.unwrap_or_else(|| parts.method.as_str()),
parts.uri
)
}
}
pub fn create_cache_key_for_invalidation(
&self,
parts: &request::Parts,
method_override: &str,
) -> String {
self.create_cache_key(parts, Some(method_override))
}
pub fn http_response_to_response<B>(
http_response: &HttpResponse,
body: B,
) -> Result<Response<B>> {
let mut response_builder = Response::builder()
.status(http_response.status)
.version(http_response.version.into());
for (name, value) in &http_response.headers {
if let (Ok(header_name), Ok(header_value)) =
(name.parse::<http::HeaderName>(), value.parse::<HeaderValue>())
{
response_builder =
response_builder.header(header_name, header_value);
}
}
Ok(response_builder.body(body)?)
}
fn parts_to_http_response(
&self,
parts: &response::Parts,
request_parts: &request::Parts,
metadata: Option<Vec<u8>>,
) -> Result<HttpResponse> {
Ok(HttpResponse {
body: vec![], headers: (&parts.headers).into(),
status: parts.status.as_u16(),
url: extract_url_from_request_parts(request_parts)?,
version: parts.version.try_into()?,
metadata,
})
}
fn evaluate_response_cache_mode(
&self,
request_parts: &request::Parts,
http_response: &HttpResponse,
original_mode: CacheMode,
) -> CacheMode {
if let Some(response_cache_mode_fn) = &self.response_cache_mode_fn {
if let Some(override_mode) =
response_cache_mode_fn(request_parts, http_response)
{
return override_mode;
}
}
original_mode
}
pub fn generate_metadata(
&self,
request_parts: &request::Parts,
response_parts: &response::Parts,
) -> Option<HttpCacheMetadata> {
self.metadata_provider
.as_ref()
.and_then(|provider| provider(request_parts, response_parts))
}
pub fn modify_response_before_caching(&self, response: &mut HttpResponse) {
if let Some(modify_response) = &self.modify_response {
modify_response(response);
}
}
fn create_cache_policy(
&self,
request_parts: &request::Parts,
response_parts: &response::Parts,
) -> CachePolicy {
let cache_options = self.cache_options.unwrap_or_default();
if let Some(max_ttl) = self.max_ttl {
let cache_control = response_parts
.headers
.get("cache-control")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let existing_max_age =
cache_control.split(',').find_map(|directive| {
let directive = directive.trim();
if directive.starts_with("max-age=") {
directive.strip_prefix("max-age=")?.parse::<u64>().ok()
} else {
None
}
});
let max_ttl_seconds = max_ttl.as_secs();
let effective_max_age = match existing_max_age {
Some(existing) => std::cmp::min(existing, max_ttl_seconds),
None => max_ttl_seconds,
};
let mut new_directives = Vec::new();
for directive in cache_control.split(',').map(|d| d.trim()) {
if !directive.starts_with("max-age=") && !directive.is_empty() {
new_directives.push(directive.to_string());
}
}
new_directives.push(format!("max-age={}", effective_max_age));
let new_cache_control = new_directives.join(", ");
let mut modified_response_parts = response_parts.clone();
modified_response_parts.headers.insert(
"cache-control",
HeaderValue::from_str(&new_cache_control)
.unwrap_or_else(|_| HeaderValue::from_static("max-age=0")),
);
CachePolicy::new_options(
request_parts,
&modified_response_parts,
SystemTime::now(),
cache_options,
)
} else {
CachePolicy::new_options(
request_parts,
response_parts,
SystemTime::now(),
cache_options,
)
}
}
fn should_cache_response(
&self,
effective_cache_mode: CacheMode,
http_response: &HttpResponse,
is_get_head: bool,
policy: &CachePolicy,
) -> bool {
let is_cacheable_status = matches!(
http_response.status,
200 | 203 | 204 | 206 | 300 | 301 | 404 | 405 | 410 | 414 | 501
);
if is_cacheable_status {
match effective_cache_mode {
CacheMode::ForceCache => is_get_head,
CacheMode::IgnoreRules => true,
CacheMode::NoStore => false,
_ => is_get_head && policy.is_storable(),
}
} else {
false
}
}
fn analyze_request_internal(
&self,
parts: &request::Parts,
mode_override: Option<CacheMode>,
default_mode: CacheMode,
) -> Result<CacheAnalysis> {
let effective_mode = mode_override
.or_else(|| self.cache_mode_fn.as_ref().map(|f| f(parts)))
.unwrap_or(default_mode);
let is_get_head = parts.method == "GET" || parts.method == "HEAD";
let should_cache = effective_mode == CacheMode::IgnoreRules
|| (is_get_head && effective_mode != CacheMode::NoStore);
let cache_key = self.create_cache_key(parts, None);
let cache_bust_keys = if let Some(cache_bust) = &self.cache_bust {
cache_bust(parts, &self.cache_key, &cache_key)
} else {
Vec::new()
};
Ok(CacheAnalysis {
cache_key,
should_cache,
cache_mode: effective_mode,
cache_bust_keys,
request_parts: parts.clone(),
is_get_head,
})
}
}
/// Buffered HTTP cache: a cache-wide mode, a [`CacheManager`] backend, and
/// shared options.
#[derive(Debug, Clone)]
pub struct HttpCache<T: CacheManager> {
/// Cache mode used when a request supplies no override.
pub mode: CacheMode,
/// Storage backend.
pub manager: T,
/// Hooks and tuning options.
pub options: HttpCacheOptions,
}
/// Wrapper around optional user metadata of a cached entry (presumably
/// carried in response extensions — confirm at its use sites).
#[derive(Debug, Clone)]
pub(crate) struct CachedUserMetadata(pub Option<Vec<u8>>);
/// Streaming HTTP cache: a cache-wide mode, a [`StreamingCacheManager`]
/// backend, and shared options.
#[derive(Debug, Clone)]
pub struct HttpStreamingCache<T: StreamingCacheManager> {
/// Cache mode used when a request supplies no override.
pub mode: CacheMode,
/// Storage backend.
pub manager: T,
/// Hooks and tuning options.
pub options: HttpCacheOptions,
}
/// Parses the first three characters of `response`'s first `warning` header
/// as a numeric code (e.g. `112` from `112 host "..." "..."`).
fn response_warning_code<B>(response: &Response<B>) -> Option<usize> {
    let value = response.headers().get(WARNING)?.to_str().ok()?;
    let code: String = value.chars().take(3).collect();
    code.parse().ok()
}
/// Replaces `response`'s `warning` header with
/// `<code> <host> "<message>" "<date>"`.
///
/// Double quotes in `message` become single quotes and line breaks become
/// spaces. If the resulting text is not a valid header value, the response
/// is left unchanged.
fn response_add_warning<B>(
    response: &mut Response<B>,
    url: &Url,
    code: usize,
    message: &str,
) {
    let agent = url_host_str(url);
    let text = message.replace('"', "'").replace(['\n', '\r'], " ");
    let date = httpdate::fmt_http_date(SystemTime::now());
    let warning = format!("{code} {agent} \"{text}\" \"{date}\"");
    let Ok(header_value) = HeaderValue::from_str(&warning) else {
        return;
    };
    response.headers_mut().insert(WARNING, header_value);
}
/// Removes the `warning` header from `response`.
fn response_remove_warning<B>(response: &mut Response<B>) {
    let headers = response.headers_mut();
    let _ = headers.remove(WARNING);
}
/// Returns `true` if any `cache-control` header line on `response` contains
/// `must-revalidate` (case-insensitively).
///
/// All lines are inspected, because directives may be split across several
/// `Cache-Control` fields. Values that cannot be read as strings are
/// ignored.
fn response_must_revalidate<B>(response: &Response<B>) -> bool {
    response
        .headers()
        .get_all(CACHE_CONTROL)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .any(|val| val.to_lowercase().contains("must-revalidate"))
}
/// Sets `response`'s `x-cache` header to `HIT` or `MISS`.
fn response_cache_status<B>(
    response: &mut Response<B>,
    hit_or_miss: HitOrMiss,
) {
    let Ok(value) = HeaderValue::from_str(&hit_or_miss.to_string()) else {
        return;
    };
    response.headers_mut().insert(XCACHE, value);
}
/// Sets `response`'s `x-cache-lookup` header to `HIT` or `MISS`.
fn response_cache_lookup_status<B>(
    response: &mut Response<B>,
    hit_or_miss: HitOrMiss,
) {
    let Ok(value) = HeaderValue::from_str(&hit_or_miss.to_string()) else {
        return;
    };
    response.headers_mut().insert(XCACHELOOKUP, value);
}
fn apply_modify_response_shim<B>(
options: &HttpCacheOptions,
response: &mut Response<B>,
url: &Url,
) {
let modify = match &options.modify_response {
Some(f) => f,
None => return,
};
let mut shim = HttpResponse {
body: Vec::new(),
headers: HttpHeaders::from(response.headers()),
status: response.status().as_u16(),
url: url.clone(),
version: response.version().try_into().unwrap_or(HttpVersion::Http11),
metadata: None,
};
modify(&mut shim);
response.headers_mut().clear();
for (name, value) in shim.headers.iter() {
if let (Ok(hn), Ok(hv)) = (
http::header::HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(value),
) {
response.headers_mut().append(hn, hv);
}
}
if let Ok(new_status) = StatusCode::from_u16(shim.status) {
*response.status_mut() = new_status;
}
}
impl<T: StreamingCacheManager> HttpStreamingCache<T>
where
<T::Body as http_body::Body>::Data: Send,
<T::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static,
{
/// Reports whether the request described by `parts` is cache-eligible,
/// i.e. the `should_cache` flag of its analysis.
///
/// # Errors
///
/// Propagates any error from request analysis.
pub fn can_cache_request(
    &self,
    parts: &request::Parts,
    mode_override: Option<CacheMode>,
) -> Result<bool> {
    <Self as HttpCacheStreamInterface>::analyze_request(
        self,
        parts,
        mode_override,
    )
    .map(|analysis| analysis.should_cache)
}
/// Waits until the configured rate limiter admits a request to `url`'s host.
///
/// The limiter is keyed by host name, or `"unknown"` when the URL has none.
/// Returns immediately when no limiter is configured.
#[cfg(feature = "rate-limiting")]
async fn apply_rate_limiting(&self, url: &Url) {
    let Some(limiter) = self.options.rate_limiter.as_ref() else {
        return;
    };
    let key = url_hostname(url).unwrap_or("unknown");
    limiter.until_key_ready(key).await;
}
/// No-op stand-in used when the `rate-limiting` feature is disabled.
#[cfg(not(feature = "rate-limiting"))]
async fn apply_rate_limiting(&self, _url: &Url) {}
/// Invalidates cache entries for a request that will not use the cache.
///
/// The GET and HEAD keys for `parts` are deleted on a best-effort basis
/// (errors ignored). Any keys returned by the `cache_bust` hook are then
/// deleted.
///
/// # Errors
///
/// Propagates the first error from deleting a cache-bust key.
pub async fn run_no_cache(&self, parts: &request::Parts) -> Result<()> {
    for method in ["GET", "HEAD"] {
        let key = self.options.create_cache_key(parts, Some(method));
        let _ = self.manager.delete(&key).await;
    }
    // Computed unconditionally: a custom key function may be observed by
    // callers even when no cache-bust hook is set.
    let cache_key = self.options.create_cache_key(parts, None);
    let Some(cache_bust) = &self.options.cache_bust else {
        return Ok(());
    };
    for key in cache_bust(parts, &self.options.cache_key, &cache_key) {
        self.manager.delete(&key).await?;
    }
    Ok(())
}
/// Serves a request through the streaming cache.
///
/// Analyzes `parts` and then either calls `fetch` or answers from the cache.
/// The choice depends on the effective [`CacheMode`] and on whether an entry
/// exists under the analysis' cache key. `fetch` receives a
/// [`FetchRequest`] describing the kind of network request to make. Network
/// responses are passed through `remote_fetch_and_cache`.
///
/// # Errors
///
/// Propagates errors from any of these steps:
/// - request analysis and URL reconstruction;
/// - cache-bust deletions and the cache lookup;
/// - the `fetch` closure and response processing.
pub async fn run<B, F, Fut>(
&self,
parts: &request::Parts,
mode_override: Option<CacheMode>,
fetch: F,
) -> Result<Response<T::Body>>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
F: FnOnce(FetchRequest) -> Fut,
Fut: Future<Output = Result<Response<B>>>,
{
let analysis = <Self as HttpCacheStreamInterface>::analyze_request(
self,
parts,
mode_override,
)?;
// Not cache-eligible: go straight to the network (after per-host rate
// limiting).
if !analysis.should_cache {
let url = extract_url_from_request_parts(parts)?;
self.apply_rate_limiting(&url).await;
let response = fetch(FetchRequest::Fresh).await?;
return self.remote_fetch_and_cache(analysis, response).await;
}
// Invalidate the keys requested by the `cache_bust` hook before lookup.
for key in &analysis.cache_bust_keys {
self.manager.delete(key).await?;
}
// Lookup hit: an entry exists under the cache key.
if let Some((mut cached_response, policy)) =
<Self as HttpCacheStreamInterface>::lookup_cached_response(
self,
&analysis.cache_key,
)
.await?
{
if self.options.cache_status_headers {
response_cache_lookup_status(
&mut cached_response,
HitOrMiss::HIT,
);
}
// Warning codes in the 100..200 range are stripped from the stored
// copy before it is reused.
if let Some(warning_code) = response_warning_code(&cached_response)
{
if (100..200).contains(&warning_code) {
response_remove_warning(&mut cached_response);
}
}
match analysis.cache_mode {
// Delegate to `conditional_fetch`, which decides between serving
// the entry and revalidating it.
CacheMode::Default => {
self.conditional_fetch(
&analysis,
fetch,
cached_response,
policy,
)
.await
}
// Ignore the entry and fetch with FreshNoCache; the lookup still
// reports HIT.
CacheMode::NoCache => {
let url = extract_url_from_request_parts(parts)?;
self.apply_rate_limiting(&url).await;
let response = fetch(FetchRequest::FreshNoCache).await?;
let mut res =
self.remote_fetch_and_cache(analysis, response).await?;
if self.options.cache_status_headers {
response_cache_lookup_status(&mut res, HitOrMiss::HIT);
}
Ok(res)
}
// Serve the entry without touching the network, tagged with
// warning 112 ("Disconnected operation").
CacheMode::ForceCache
| CacheMode::OnlyIfCached
| CacheMode::IgnoreRules => {
let url = extract_url_from_request_parts(parts)?;
response_add_warning(
&mut cached_response,
&url,
112,
"Disconnected operation",
);
if self.options.cache_status_headers {
response_cache_status(
&mut cached_response,
HitOrMiss::HIT,
);
}
Ok(cached_response)
}
// Refetch unconditionally; the lookup still reports HIT.
CacheMode::Reload => {
let url = extract_url_from_request_parts(parts)?;
self.apply_rate_limiting(&url).await;
let response = fetch(FetchRequest::Fresh).await?;
let mut res =
self.remote_fetch_and_cache(analysis, response).await?;
if self.options.cache_status_headers {
response_cache_lookup_status(&mut res, HitOrMiss::HIT);
}
Ok(res)
}
// Any remaining mode (only `NoStore`) falls back to a fresh fetch.
_ => {
let url = extract_url_from_request_parts(parts)?;
self.apply_rate_limiting(&url).await;
let response = fetch(FetchRequest::Fresh).await?;
self.remote_fetch_and_cache(analysis, response).await
}
}
} else {
// Lookup miss.
match analysis.cache_mode {
// Only-if-cached never contacts the network: synthesize a 504
// with an empty body.
CacheMode::OnlyIfCached => {
let mut res = Response::builder()
.status(StatusCode::GATEWAY_TIMEOUT)
.body(self.manager.empty_body())
.map_err(|e| -> BoxError { e.into() })?;
if self.options.cache_status_headers {
response_cache_status(&mut res, HitOrMiss::MISS);
response_cache_lookup_status(&mut res, HitOrMiss::MISS);
}
Ok(res)
}
// Every other mode fetches and hands the response on for caching.
_ => {
let url = extract_url_from_request_parts(parts)?;
self.apply_rate_limiting(&url).await;
let response = fetch(FetchRequest::Fresh).await?;
self.remote_fetch_and_cache(analysis, response).await
}
}
}
}
/// Processes a freshly fetched upstream response through the cache.
///
/// Delegates to `HttpCacheStreamInterface::process_response` with no
/// caller-supplied metadata. That method decides whether the response is
/// stored and converts its body to the manager's body type.
///
/// # Errors
///
/// Propagates any error returned by `process_response`.
async fn remote_fetch_and_cache<B>(
    &self,
    analysis: CacheAnalysis,
    response: Response<B>,
) -> Result<Response<T::Body>>
where
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: Into<StreamingError>,
{
    // `analysis` is owned and unused afterwards, so move it rather than
    // deep-cloning its request parts and key strings.
    <Self as HttpCacheStreamInterface>::process_response(
        self, analysis, response, None,
    )
    .await
}
async fn conditional_fetch<B, F, Fut>(
&self,
analysis: &CacheAnalysis,
fetch: F,
mut cached_res: Response<T::Body>,
mut policy: CachePolicy,
) -> Result<Response<T::Body>>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
F: FnOnce(FetchRequest) -> Fut,
Fut: Future<Output = Result<Response<B>>>,
{
let parts = &analysis.request_parts;
let before_req = policy.before_request(parts, SystemTime::now());
match before_req {
BeforeRequest::Fresh(fresh_parts) => {
for name in fresh_parts.headers.keys() {
cached_res.headers_mut().remove(name);
}
for (name, value) in fresh_parts.headers.iter() {
cached_res
.headers_mut()
.append(name.clone(), value.clone());
}
if self.options.cache_status_headers {
response_cache_status(&mut cached_res, HitOrMiss::HIT);
response_cache_lookup_status(
&mut cached_res,
HitOrMiss::HIT,
);
}
Ok(cached_res)
}
BeforeRequest::Stale { request: stale_parts, matches } => {
let req_url = extract_url_from_request_parts(parts)?;
self.apply_rate_limiting(&req_url).await;
let fetch_result = if matches {
fetch(FetchRequest::Conditional(Box::new(stale_parts)))
.await
} else {
fetch(FetchRequest::Fresh).await
};
match fetch_result {
Ok(cond_res) => {
let status = cond_res.status();
if status.is_server_error()
&& response_must_revalidate(&cached_res)
{
response_add_warning(
&mut cached_res,
&req_url,
111,
"Revalidation failed",
);
if self.options.cache_status_headers {
response_cache_status(
&mut cached_res,
HitOrMiss::HIT,
);
}
Ok(cached_res)
} else if status == StatusCode::NOT_MODIFIED {
let (cond_parts, _cond_body) =
cond_res.into_parts();
let after_res = policy.after_response(
parts,
&cond_parts,
SystemTime::now(),
);
match after_res {
AfterResponse::Modified(
new_policy,
new_parts,
)
| AfterResponse::NotModified(
new_policy,
new_parts,
) => {
policy = new_policy;
for name in new_parts.headers.keys() {
cached_res.headers_mut().remove(name);
}
for (name, value) in
new_parts.headers.iter()
{
cached_res.headers_mut().append(
name.clone(),
value.clone(),
);
}
}
}
if self.options.cache_status_headers {
response_cache_status(
&mut cached_res,
HitOrMiss::HIT,
);
response_cache_lookup_status(
&mut cached_res,
HitOrMiss::HIT,
);
}
apply_modify_response_shim(
&self.options,
&mut cached_res,
&req_url,
);
let request_url =
extract_url_from_request_parts(parts)?;
let metadata = cached_res
.extensions()
.get::<CachedUserMetadata>()
.and_then(|m| m.0.clone());
let res = self
.manager
.put(
self.options.create_cache_key(parts, None),
cached_res,
policy,
request_url,
metadata,
)
.await?;
Ok(res)
} else if status == StatusCode::OK {
let (cond_parts, cond_body) = cond_res.into_parts();
let new_policy = self
.options
.create_cache_policy(parts, &cond_parts);
let metadata = self
.options
.generate_metadata(parts, &cond_parts);
let cond_res =
Response::from_parts(cond_parts, cond_body);
let request_url =
extract_url_from_request_parts(parts)?;
let mut cond_res = cond_res;
apply_modify_response_shim(
&self.options,
&mut cond_res,
&request_url,
);
let http_response_shim = HttpResponse {
body: vec![],
headers: cond_res.headers().into(),
status: cond_res.status().as_u16(),
url: request_url.clone(),
version: cond_res
.version()
.try_into()
.unwrap_or(HttpVersion::Http11),
metadata: metadata.clone(),
};
let effective_mode =
self.options.evaluate_response_cache_mode(
parts,
&http_response_shim,
analysis.cache_mode,
);
let is_cacheable =
self.options.should_cache_response(
effective_mode,
&http_response_shim,
analysis.is_get_head,
&new_policy,
);
if self.options.cache_status_headers {
response_cache_status(
&mut cond_res,
HitOrMiss::MISS,
);
response_cache_lookup_status(
&mut cond_res,
HitOrMiss::HIT,
);
}
if is_cacheable {
let res = self
.manager
.put(
self.options
.create_cache_key(parts, None),
cond_res,
new_policy,
request_url,
metadata,
)
.await?;
Ok(res)
} else {
let res =
self.manager.convert_body(cond_res).await?;
Ok(res)
}
} else {
let mut res =
self.manager.convert_body(cond_res).await?;
if self.options.cache_status_headers {
response_cache_status(
&mut res,
HitOrMiss::MISS,
);
response_cache_lookup_status(
&mut res,
HitOrMiss::HIT,
);
}
Ok(res)
}
}
Err(e) => {
if response_must_revalidate(&cached_res) {
Err(e)
} else {
response_add_warning(
&mut cached_res,
&req_url,
111,
"Revalidation failed",
);
if self.options.cache_status_headers {
response_cache_status(
&mut cached_res,
HitOrMiss::HIT,
);
}
Ok(cached_res)
}
}
}
}
}
}
}
impl<T: CacheManager> HttpCache<T> {
    /// Returns whether the cache participates in handling the request
    /// described by `middleware` (the `should_cache` result of the analysis).
    ///
    /// # Errors
    ///
    /// Propagates failures from `Middleware::parts` and request analysis.
    pub fn can_cache_request(
        &self,
        middleware: &impl Middleware,
    ) -> Result<bool> {
        let analysis = self.analyze_request(
            &middleware.parts()?,
            middleware.overridden_cache_mode(),
        )?;
        Ok(analysis.should_cache)
    }

    /// Waits on the configured rate limiter, if any, before a network call.
    ///
    /// The limiter is keyed by the URL's hostname and falls back to
    /// `"unknown"` when the URL has none.
    #[cfg(feature = "rate-limiting")]
    async fn apply_rate_limiting(&self, url: &Url) {
        if let Some(rate_limiter) = &self.options.rate_limiter {
            let rate_limit_key = url_hostname(url).unwrap_or("unknown");
            rate_limiter.until_key_ready(rate_limit_key).await;
        }
    }

    /// No-op stand-in compiled when the `rate-limiting` feature is off.
    #[cfg(not(feature = "rate-limiting"))]
    async fn apply_rate_limiting(&self, _url: &Url) {}

    /// Purges cached entries for a request that must bypass the cache.
    ///
    /// The `GET` and `HEAD` variants of the request's cache key are deleted
    /// on a best-effort basis (failures are ignored). After that, every key
    /// returned by the `cache_bust` callback, if configured, is deleted.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while deleting a cache-bust key.
    pub async fn run_no_cache_from_parts(
        &self,
        parts: &request::Parts,
    ) -> Result<()> {
        self.manager
            .delete(&self.options.create_cache_key(parts, Some("GET")))
            .await
            .ok();
        self.manager
            .delete(&self.options.create_cache_key(parts, Some("HEAD")))
            .await
            .ok();
        let cache_key = self.options.create_cache_key(parts, None);
        if let Some(cache_bust) = &self.options.cache_bust {
            for key_to_cache_bust in
                cache_bust(parts, &self.options.cache_key, &cache_key)
            {
                self.manager.delete(&key_to_cache_bust).await?;
            }
        }
        Ok(())
    }

    /// Middleware-based wrapper around `run_no_cache_from_parts`.
    ///
    /// # Errors
    ///
    /// Propagates failures from `Middleware::parts` and from the purge.
    pub async fn run_no_cache(
        &self,
        middleware: &mut impl Middleware,
    ) -> Result<()> {
        let parts = middleware.parts()?;
        self.run_no_cache_from_parts(&parts).await
    }

    /// Runs a request through the cache using the supplied middleware.
    ///
    /// Requests that must not use the cache are fetched remotely. Otherwise
    /// the cache-bust keys are deleted and the cache is consulted:
    ///
    /// * **hit** – `Default` revalidates via `conditional_fetch`; `NoCache`
    ///   forces an uncached upstream fetch; `ForceCache`, `OnlyIfCached` and
    ///   `IgnoreRules` return the stored response with a
    ///   `112 Disconnected operation` warning; `Reload` and other modes
    ///   fetch remotely.
    /// * **miss** – `OnlyIfCached` returns an empty 504; other modes fetch
    ///   remotely.
    ///
    /// # Errors
    ///
    /// Propagates middleware, analysis, cache-manager and fetch errors.
    pub async fn run(
        &self,
        mut middleware: impl Middleware,
    ) -> Result<HttpResponse> {
        let analysis = self.analyze_request(
            &middleware.parts()?,
            middleware.overridden_cache_mode(),
        )?;
        if !analysis.should_cache {
            return self.remote_fetch(&mut middleware).await;
        }
        for key in &analysis.cache_bust_keys {
            self.manager.delete(key).await?;
        }
        if let Some((mut cached_response, policy)) =
            self.lookup_cached_response(&analysis.cache_key).await?
        {
            if self.options.cache_status_headers {
                cached_response.cache_lookup_status(HitOrMiss::HIT);
            }
            // Stored 1xx warnings are stripped before the response is reused.
            if let Some(warning_code) = cached_response.warning_code() {
                if (100..200).contains(&warning_code) {
                    cached_response.remove_warning();
                }
            }
            match analysis.cache_mode {
                CacheMode::Default => {
                    self.conditional_fetch(middleware, cached_response, policy)
                        .await
                }
                CacheMode::NoCache => {
                    middleware.force_no_cache()?;
                    let mut res = self.remote_fetch(&mut middleware).await?;
                    if self.options.cache_status_headers {
                        res.cache_lookup_status(HitOrMiss::HIT);
                    }
                    Ok(res)
                }
                CacheMode::ForceCache
                | CacheMode::OnlyIfCached
                | CacheMode::IgnoreRules => {
                    cached_response.add_warning(
                        &cached_response.url.clone(),
                        112,
                        "Disconnected operation",
                    );
                    if self.options.cache_status_headers {
                        cached_response.cache_status(HitOrMiss::HIT);
                    }
                    Ok(cached_response)
                }
                CacheMode::Reload => {
                    let mut res = self.remote_fetch(&mut middleware).await?;
                    if self.options.cache_status_headers {
                        res.cache_lookup_status(HitOrMiss::HIT);
                    }
                    Ok(res)
                }
                _ => self.remote_fetch(&mut middleware).await,
            }
        } else {
            match analysis.cache_mode {
                CacheMode::OnlyIfCached => {
                    let mut res = HttpResponse {
                        body: Vec::new(),
                        headers: HttpHeaders::default(),
                        status: 504,
                        url: middleware.url()?,
                        version: HttpVersion::Http11,
                        metadata: None,
                    };
                    if self.options.cache_status_headers {
                        res.cache_status(HitOrMiss::MISS);
                        res.cache_lookup_status(HitOrMiss::MISS);
                    }
                    Ok(res)
                }
                _ => self.remote_fetch(&mut middleware).await,
            }
        }
    }

    /// Resolves the cache mode for a request.
    ///
    /// Precedence: the middleware's override, then `cache_mode_fn`, then the
    /// cache's default mode.
    fn cache_mode(&self, middleware: &impl Middleware) -> Result<CacheMode> {
        Ok(if let Some(mode) = middleware.overridden_cache_mode() {
            mode
        } else if let Some(cache_mode_fn) = &self.options.cache_mode_fn {
            cache_mode_fn(&middleware.parts()?)
        } else {
            self.mode
        })
    }

    /// Fetches the request from upstream and stores the response if allowed.
    ///
    /// When the request method is not GET/HEAD and the response is a success
    /// or redirect, the GET and HEAD entries for the same request are
    /// deleted on a best-effort basis. This happens after the response has
    /// been stored, if it was stored.
    ///
    /// # Errors
    ///
    /// Propagates middleware, policy, status-code and cache-manager errors.
    async fn remote_fetch(
        &self,
        middleware: &mut impl Middleware,
    ) -> Result<HttpResponse> {
        let url = middleware.url()?;
        self.apply_rate_limiting(&url).await;
        let mut res = middleware.remote_fetch().await?;
        if self.options.cache_status_headers {
            res.cache_status(HitOrMiss::MISS);
            res.cache_lookup_status(HitOrMiss::MISS);
        }
        let policy = match self.options.cache_options {
            Some(options) => middleware.policy_with_options(&res, options)?,
            None => middleware.policy(&res)?,
        };
        let is_get_head = middleware.is_method_get_head();
        let mut mode = self.cache_mode(middleware)?;
        let parts = middleware.parts()?;
        if let Some(response_cache_mode_fn) =
            &self.options.response_cache_mode_fn
        {
            if let Some(override_mode) = response_cache_mode_fn(&parts, &res) {
                mode = override_mode;
            }
        }
        let is_cacheable = self.options.should_cache_response(
            mode,
            &res,
            is_get_head,
            &policy,
        );
        let res = if is_cacheable {
            let response_parts = res.parts()?;
            res.metadata =
                self.options.generate_metadata(&parts, &response_parts);
            self.options.modify_response_before_caching(&mut res);
            self.manager
                .put(self.options.create_cache_key(&parts, None), res, policy)
                .await?
        } else {
            res
        };
        // A successful unsafe request invalidates the safe-method entries.
        if !is_get_head {
            let status = StatusCode::from_u16(res.status)?;
            if status.is_success() || status.is_redirection() {
                self.manager
                    .delete(&self.options.create_cache_key(&parts, Some("GET")))
                    .await
                    .ok();
                self.manager
                    .delete(
                        &self.options.create_cache_key(&parts, Some("HEAD")),
                    )
                    .await
                    .ok();
            }
        }
        Ok(res)
    }

    /// Revalidates a cached response against its cache policy.
    ///
    /// A fresh entry is returned directly with updated headers. For a stale
    /// entry the middleware performs an upstream request, made conditional
    /// when the policy's revalidation request matches. Then:
    ///
    /// * a 5xx for a must-revalidate entry serves the cache with a `111`
    ///   warning;
    /// * a 304 refreshes headers and policy and stores the entry again;
    /// * a 200 is stored if cacheable;
    /// * other statuses pass through.
    ///
    /// If the fetch fails, the error is returned only for must-revalidate
    /// entries; otherwise the cached entry is served with a `111` warning.
    ///
    /// # Errors
    ///
    /// Propagates middleware, status-code and cache-manager errors.
    async fn conditional_fetch(
        &self,
        mut middleware: impl Middleware,
        mut cached_res: HttpResponse,
        mut policy: CachePolicy,
    ) -> Result<HttpResponse> {
        let parts = middleware.parts()?;
        let before_req = policy.before_request(&parts, SystemTime::now());
        match before_req {
            BeforeRequest::Fresh(fresh_parts) => {
                cached_res.update_headers(&fresh_parts)?;
                if self.options.cache_status_headers {
                    cached_res.cache_status(HitOrMiss::HIT);
                    cached_res.cache_lookup_status(HitOrMiss::HIT);
                }
                return Ok(cached_res);
            }
            BeforeRequest::Stale { request: stale_parts, matches } => {
                if matches {
                    middleware.update_headers(&stale_parts)?;
                }
            }
        }
        let req_url = middleware.url()?;
        self.apply_rate_limiting(&req_url).await;
        match middleware.remote_fetch().await {
            Ok(mut cond_res) => {
                let status = StatusCode::from_u16(cond_res.status)?;
                if status.is_server_error() && cached_res.must_revalidate() {
                    cached_res.add_warning(
                        &req_url,
                        111,
                        "Revalidation failed",
                    );
                    if self.options.cache_status_headers {
                        cached_res.cache_status(HitOrMiss::HIT);
                    }
                    Ok(cached_res)
                } else if cond_res.status == 304 {
                    let after_res = policy.after_response(
                        &parts,
                        &cond_res.parts()?,
                        SystemTime::now(),
                    );
                    match after_res {
                        AfterResponse::Modified(new_policy, new_parts)
                        | AfterResponse::NotModified(new_policy, new_parts) => {
                            policy = new_policy;
                            cached_res.update_headers(&new_parts)?;
                        }
                    }
                    if self.options.cache_status_headers {
                        cached_res.cache_status(HitOrMiss::HIT);
                        cached_res.cache_lookup_status(HitOrMiss::HIT);
                    }
                    self.options
                        .modify_response_before_caching(&mut cached_res);
                    let res = self
                        .manager
                        .put(
                            self.options.create_cache_key(&parts, None),
                            cached_res,
                            policy,
                        )
                        .await?;
                    Ok(res)
                } else if cond_res.status == 200 {
                    let policy = match self.options.cache_options {
                        Some(options) => middleware
                            .policy_with_options(&cond_res, options)?,
                        None => middleware.policy(&cond_res)?,
                    };
                    if self.options.cache_status_headers {
                        cond_res.cache_status(HitOrMiss::MISS);
                        cond_res.cache_lookup_status(HitOrMiss::HIT);
                    }
                    let response_parts = cond_res.parts()?;
                    cond_res.metadata =
                        self.options.generate_metadata(&parts, &response_parts);
                    self.options.modify_response_before_caching(&mut cond_res);
                    let mode = self.cache_mode(&middleware)?;
                    let mode = self
                        .options
                        .evaluate_response_cache_mode(&parts, &cond_res, mode);
                    let is_get_head = middleware.is_method_get_head();
                    let is_cacheable = self.options.should_cache_response(
                        mode,
                        &cond_res,
                        is_get_head,
                        &policy,
                    );
                    if is_cacheable {
                        let res = self
                            .manager
                            .put(
                                self.options.create_cache_key(&parts, None),
                                cond_res,
                                policy,
                            )
                            .await?;
                        Ok(res)
                    } else {
                        Ok(cond_res)
                    }
                } else {
                    if self.options.cache_status_headers {
                        cond_res.cache_status(HitOrMiss::MISS);
                        cond_res.cache_lookup_status(HitOrMiss::HIT);
                    }
                    Ok(cond_res)
                }
            }
            Err(e) => {
                if cached_res.must_revalidate() {
                    Err(e)
                } else {
                    cached_res.add_warning(
                        &req_url,
                        111,
                        "Revalidation failed",
                    );
                    if self.options.cache_status_headers {
                        cached_res.cache_status(HitOrMiss::HIT);
                    }
                    Ok(cached_res)
                }
            }
        }
    }
}
impl<T: StreamingCacheManager> HttpCacheStreamInterface
for HttpStreamingCache<T>
where
<T::Body as http_body::Body>::Data: Send,
<T::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static,
{
type Body = T::Body;
fn analyze_request(
&self,
parts: &request::Parts,
mode_override: Option<CacheMode>,
) -> Result<CacheAnalysis> {
self.options.analyze_request_internal(parts, mode_override, self.mode)
}
async fn lookup_cached_response(
&self,
key: &str,
) -> Result<Option<(Response<Self::Body>, CachePolicy)>> {
self.manager.get(key).await
}
async fn process_response<B>(
&self,
analysis: CacheAnalysis,
response: Response<B>,
metadata: Option<Vec<u8>>,
) -> Result<Response<Self::Body>>
where
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<StreamingError>,
<T::Body as http_body::Body>::Data: Send,
<T::Body as http_body::Body>::Error:
Into<StreamingError> + Send + Sync + 'static,
{
if !analysis.should_cache {
if !analysis.is_get_head {
let status = response.status();
if status.is_success() || status.is_redirection() {
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("GET"),
))
.await
.ok();
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("HEAD"),
))
.await
.ok();
}
}
let mut converted_response =
self.manager.convert_body(response).await?;
if self.options.cache_status_headers {
converted_response.headers_mut().insert(
XCACHE,
"MISS".parse().map_err(StreamingError::new)?,
);
converted_response.headers_mut().insert(
XCACHELOOKUP,
"MISS".parse().map_err(StreamingError::new)?,
);
}
return Ok(converted_response);
}
for key in &analysis.cache_bust_keys {
self.manager.delete(key).await?;
}
let (parts, body) = response.into_parts();
let effective_metadata = metadata.or_else(|| {
self.options.generate_metadata(&analysis.request_parts, &parts)
});
let http_response = self.options.parts_to_http_response(
&parts,
&analysis.request_parts,
effective_metadata.clone(),
)?;
let effective_cache_mode = self.options.evaluate_response_cache_mode(
&analysis.request_parts,
&http_response,
analysis.cache_mode,
);
let response = Response::from_parts(parts, body);
if effective_cache_mode == CacheMode::NoStore {
if !analysis.is_get_head {
let status_code = StatusCode::from_u16(http_response.status)?;
if status_code.is_success() || status_code.is_redirection() {
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("GET"),
))
.await
.ok();
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("HEAD"),
))
.await
.ok();
}
}
let mut converted_response =
self.manager.convert_body(response).await?;
if self.options.cache_status_headers {
converted_response.headers_mut().insert(
XCACHE,
"MISS".parse().map_err(StreamingError::new)?,
);
converted_response.headers_mut().insert(
XCACHELOOKUP,
"MISS".parse().map_err(StreamingError::new)?,
);
}
return Ok(converted_response);
}
let (parts, body) = response.into_parts();
let policy =
self.options.create_cache_policy(&analysis.request_parts, &parts);
let response = Response::from_parts(parts, body);
let should_cache_response = self.options.should_cache_response(
effective_cache_mode,
&http_response,
analysis.is_get_head,
&policy,
);
if should_cache_response {
let request_url =
extract_url_from_request_parts(&analysis.request_parts)?;
let mut response = response;
apply_modify_response_shim(
&self.options,
&mut response,
&request_url,
);
let mut cached_response = self
.manager
.put(
analysis.cache_key,
response,
policy,
request_url,
effective_metadata,
)
.await?;
if !analysis.is_get_head {
let status_code = StatusCode::from_u16(http_response.status)?;
if status_code.is_success() || status_code.is_redirection() {
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("GET"),
))
.await
.ok();
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("HEAD"),
))
.await
.ok();
}
}
if self.options.cache_status_headers {
cached_response.headers_mut().insert(
XCACHE,
"MISS".parse().map_err(StreamingError::new)?,
);
cached_response.headers_mut().insert(
XCACHELOOKUP,
"MISS".parse().map_err(StreamingError::new)?,
);
}
Ok(cached_response)
} else {
if !analysis.is_get_head {
let status_code = StatusCode::from_u16(http_response.status)?;
if status_code.is_success() || status_code.is_redirection() {
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("GET"),
))
.await
.ok();
self.manager
.delete(&self.options.create_cache_key(
&analysis.request_parts,
Some("HEAD"),
))
.await
.ok();
}
}
let mut converted_response =
self.manager.convert_body(response).await?;
if self.options.cache_status_headers {
converted_response.headers_mut().insert(
XCACHE,
"MISS".parse().map_err(StreamingError::new)?,
);
converted_response.headers_mut().insert(
XCACHELOOKUP,
"MISS".parse().map_err(StreamingError::new)?,
);
}
Ok(converted_response)
}
}
fn prepare_conditional_request(
&self,
parts: &mut request::Parts,
_cached_response: &Response<Self::Body>,
policy: &CachePolicy,
) -> Result<()> {
let before_req = policy.before_request(parts, SystemTime::now());
if let BeforeRequest::Stale { request, .. } = before_req {
parts.headers.extend(request.headers);
}
Ok(())
}
async fn handle_not_modified(
&self,
cached_response: Response<Self::Body>,
fresh_parts: &response::Parts,
) -> Result<Response<Self::Body>> {
let (mut parts, body) = cached_response.into_parts();
for name in fresh_parts.headers.keys() {
parts.headers.remove(name);
}
for (name, value) in fresh_parts.headers.iter() {
parts.headers.append(name.clone(), value.clone());
}
let mut response = Response::from_parts(parts, body);
if self.options.cache_status_headers {
response_cache_status(&mut response, HitOrMiss::HIT);
response_cache_lookup_status(&mut response, HitOrMiss::HIT);
}
Ok(response)
}
}
impl<T: CacheManager> HttpCacheInterface for HttpCache<T> {
    /// Analyzes a request using the cache's options and default mode.
    fn analyze_request(
        &self,
        parts: &request::Parts,
        mode_override: Option<CacheMode>,
    ) -> Result<CacheAnalysis> {
        self.options.analyze_request_internal(parts, mode_override, self.mode)
    }

    /// Looks up a stored response and its policy by cache key.
    async fn lookup_cached_response(
        &self,
        key: &str,
    ) -> Result<Option<(HttpResponse, CachePolicy)>> {
        self.manager.get(key).await
    }

    /// Stores a buffered upstream response when the caching rules allow.
    ///
    /// In every path, when the request method is not GET/HEAD and the
    /// response is a success or redirect, the GET and HEAD entries for the
    /// request are deleted on a best-effort basis. The body is moved rather
    /// than cloned, so it is never copied.
    ///
    /// # Errors
    ///
    /// Propagates cache-manager, status-code and response-building errors.
    async fn process_response(
        &self,
        analysis: CacheAnalysis,
        response: Response<Vec<u8>>,
        metadata: Option<Vec<u8>>,
    ) -> Result<Response<Vec<u8>>> {
        if !analysis.should_cache {
            if !analysis.is_get_head {
                let status = response.status();
                if status.is_success() || status.is_redirection() {
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("GET"),
                        ))
                        .await
                        .ok();
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("HEAD"),
                        ))
                        .await
                        .ok();
                }
            }
            return Ok(response);
        }
        for key in &analysis.cache_bust_keys {
            self.manager.delete(key).await?;
        }
        let (parts, body) = response.into_parts();
        let effective_metadata = metadata.or_else(|| {
            self.options.generate_metadata(&analysis.request_parts, &parts)
        });
        let mut http_response = self.options.parts_to_http_response(
            &parts,
            &analysis.request_parts,
            effective_metadata,
        )?;
        // Move the body in; paths that return the upstream response take it
        // back out with `mem::take` instead of keeping a cloned copy.
        http_response.body = body;
        let effective_cache_mode = self.options.evaluate_response_cache_mode(
            &analysis.request_parts,
            &http_response,
            analysis.cache_mode,
        );
        if effective_cache_mode == CacheMode::NoStore {
            if !analysis.is_get_head {
                let status = StatusCode::from_u16(http_response.status)?;
                if status.is_success() || status.is_redirection() {
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("GET"),
                        ))
                        .await
                        .ok();
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("HEAD"),
                        ))
                        .await
                        .ok();
                }
            }
            let response = Response::from_parts(
                parts,
                std::mem::take(&mut http_response.body),
            );
            return Ok(response);
        }
        let policy = self.options.create_cache_policy(
            &analysis.request_parts,
            &http_response.parts()?,
        );
        let should_cache_response = self.options.should_cache_response(
            effective_cache_mode,
            &http_response,
            analysis.is_get_head,
            &policy,
        );
        if should_cache_response {
            self.options.modify_response_before_caching(&mut http_response);
            let cached_response = self
                .manager
                .put(analysis.cache_key, http_response, policy)
                .await?;
            if !analysis.is_get_head {
                let status = StatusCode::from_u16(cached_response.status)?;
                if status.is_success() || status.is_redirection() {
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("GET"),
                        ))
                        .await
                        .ok();
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("HEAD"),
                        ))
                        .await
                        .ok();
                }
            }
            let response_parts = cached_response.parts()?;
            let mut response = Response::builder()
                .status(response_parts.status)
                .version(response_parts.version)
                .body(cached_response.body)?;
            *response.headers_mut() = response_parts.headers;
            Ok(response)
        } else {
            if !analysis.is_get_head {
                let status = StatusCode::from_u16(http_response.status)?;
                if status.is_success() || status.is_redirection() {
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("GET"),
                        ))
                        .await
                        .ok();
                    self.manager
                        .delete(&self.options.create_cache_key(
                            &analysis.request_parts,
                            Some("HEAD"),
                        ))
                        .await
                        .ok();
                }
            }
            let response = Response::from_parts(
                parts,
                std::mem::take(&mut http_response.body),
            );
            Ok(response)
        }
    }

    /// Adds the policy's revalidation headers to `parts` when the cached
    /// entry is stale; fresh entries leave `parts` unchanged.
    fn prepare_conditional_request(
        &self,
        parts: &mut request::Parts,
        _cached_response: &HttpResponse,
        policy: &CachePolicy,
    ) -> Result<()> {
        let before_req = policy.before_request(parts, SystemTime::now());
        if let BeforeRequest::Stale { request, .. } = before_req {
            parts.headers.extend(request.headers);
        }
        Ok(())
    }

    /// Applies the headers from a 304 to the cached response and, when
    /// enabled, marks it as a HIT in both status headers.
    async fn handle_not_modified(
        &self,
        mut cached_response: HttpResponse,
        fresh_parts: &response::Parts,
    ) -> Result<HttpResponse> {
        cached_response.update_headers(fresh_parts)?;
        if self.options.cache_status_headers {
            cached_response.cache_status(HitOrMiss::HIT);
            cached_response.cache_lookup_status(HitOrMiss::HIT);
        }
        Ok(cached_response)
    }
}
#[cfg(test)]
mod test;