use crate::prelude::*;
use comprash::CompressedResponse;
use utils::SuperUnsafePointer;
/// A shareable, thread-safe function mapping the value of a request header to the string
/// the response cache is keyed on for that header.
pub(crate) type Transformation = Pin<Arc<dyn Fn(&str) -> Cow<'static, str> + Send + Sync>>;
/// One vary rule: which request header to read, how to turn its value into a cache key,
/// and which key to use when the header cannot be read.
#[derive(Clone)]
pub(crate) struct Rule {
/// Request header name; it is also written into the `vary` response header.
name: &'static str,
/// Maps the header's value to the cache key for this rule.
transformation: Transformation,
/// Cache key used when the header is missing or not readable as `&str`.
/// It is used as-is, not passed through `transformation`.
default: &'static str,
}
// Debug output for `Rule`. The transformation is a `dyn Fn`, which has no `Debug` impl,
// so the placeholder string `[transformation fn]` is formatted in its place.
impl Debug for Rule {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut s = f.debug_struct(utils::ident_str!(Rule));
utils::fmt_fields!(
s,
(self.name),
// Closures cannot be printed; substitute a placeholder.
(self.transformation, &"[transformation fn]".as_clean()),
(self.default)
);
s.finish()
}
}
impl Rule {
    /// The request header this rule reads (and lists in the `vary` response header).
    pub(crate) fn name(&self) -> &'static str {
        let Self { name, .. } = *self;
        name
    }
    /// The cache key used when the request header is absent or unreadable.
    pub(crate) fn default(&self) -> &'static str {
        let Self { default, .. } = *self;
        default
    }
    /// The function mapping a header value to its cache key.
    pub(crate) fn transformation(&self) -> &Transformation {
        let Self { transformation, .. } = self;
        transformation
    }
}
/// The set of request headers cached responses vary on.
///
/// Start from [`Settings::empty`] and add headers with [`Settings::add_rule`].
#[derive(Debug, Clone, Default)]
#[must_use = "supply your vary settings to Kvarn"]
pub struct Settings {
/// Rules in insertion order; a cache key holds one entry per rule, in this order.
pub(crate) rules: Vec<Rule>,
}
impl Settings {
    /// Creates settings without any rules.
    pub fn empty() -> Self {
        Self { rules: Vec::new() }
    }
    /// Adds a rule making cached responses vary on `request_header`.
    ///
    /// When a request carries `request_header`, its value is passed to `transformation`.
    /// The returned string selects which cached response is used. When the header is
    /// missing, or its value cannot be read as a `&str`, `default` is used as the key
    /// directly (it is not passed through `transformation`).
    ///
    /// A warning is logged once more than four rules are registered, because every rule
    /// multiplies the number of variants a single cached resource may hold.
    ///
    /// # Panics
    ///
    /// Panics if `request_header` contains a byte that is not valid in a header value.
    pub fn add_rule(
        mut self,
        request_header: &'static str,
        transformation: impl Fn(&str) -> Cow<'static, str> + Send + Sync + 'static,
        default: &'static str,
    ) -> Self {
        // The name is later copied unchecked into the `vary` response header *value*,
        // so it has to satisfy the header-value byte rules. Validate before anything
        // else so an invalid rule never produces side effects.
        assert!(
            request_header
                .bytes()
                .all(utils::is_valid_header_value_byte),
            "A Vary request header contains invalid bytes."
        );
        // The rule is pushed below, so with 4 existing rules this one is the 5th,
        // i.e. "more than 4". (The previous `> 4` check only fired on the 6th.)
        if self.rules.len() >= 4 {
            warn!("More than 4 headers affect the caching of requests. This will exponentially increase memory usage.");
        }
        self.rules.push(Rule {
            name: request_header,
            transformation: Arc::pin(transformation),
            default,
        });
        self
    }
}
/// Vary [`Settings`] registered per request path.
pub type Vary = extensions::RuleSet<Settings>;
impl Vary {
    /// Returns the vary [`Settings`] registered for the path of `request`.
    ///
    /// If no settings match the path, an owned, rule-less [`Settings`] is returned.
    pub fn rules_from_request<'a, T>(&'a self, request: &Request<T>) -> Cow<'a, Settings> {
        let path = request.uri().path();
        match self.get(path) {
            Some(settings) => Cow::Borrowed(settings),
            None => Cow::Owned(Settings::default()),
        }
    }
}
impl Default for Vary {
    /// Delegates to the rule set's `empty` constructor.
    fn default() -> Self {
        Self::empty()
    }
}
/// Builds the value of the `vary` response header.
///
/// `accept-encoding` and `range` are always listed. The name of every header in
/// `headers` follows, in order, separated by `", "`.
#[must_use]
fn get_header(headers: &[Header]) -> HeaderValue {
    use bytes::BufMut;
    const ALWAYS_ADD: &[u8] = b"accept-encoding, range";
    const SEPARATOR: &[u8] = b", ";

    // Exact final length, so the buffer never reallocates.
    let len = ALWAYS_ADD.len()
        + headers
            .iter()
            .map(|header| SEPARATOR.len() + header.name.len())
            .sum::<usize>();
    let mut buffer = BytesMut::with_capacity(len);
    buffer.put(ALWAYS_ADD);
    for header in headers {
        buffer.put(SEPARATOR);
        buffer.put(header.name.as_bytes());
    }
    // SAFETY: `ALWAYS_ADD` and `SEPARATOR` are plain visible ASCII. Every `Header::name`
    // is copied from a `Rule` name, and `Rule`s (private fields) are built by
    // `Settings::add_rule`, which asserts each byte is a valid header-value byte.
    //
    // `freeze` hands `HeaderValue` a `Bytes`, which it adopts without copying. A
    // `BytesMut` would be copied into a fresh allocation.
    unsafe { HeaderValue::from_maybe_shared_unchecked(buffer.freeze()) }
}
/// Sets the `vary` header of `response` from `headers`.
///
/// Responses with an empty body are left untouched.
pub(crate) fn apply_header(response: &mut Response<Bytes>, headers: &[Header]) {
    if response.body().is_empty() {
        return;
    }
    let vary_value = get_header(headers);
    response.headers_mut().insert("vary", vary_value);
}
/// A request header name paired with its transformed value; one element of a cache key.
///
/// Field order matters: the derived `Ord` compares `name` first, then `transformed`.
/// `VariedResponse` keeps its responses sorted by this ordering and binary-searches them.
#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord)]
pub(crate) struct Header {
/// Name of the request header (taken from the originating rule).
name: &'static str,
/// Output of the rule's transformation, or the rule's default if the header was unusable.
transformed: Cow<'static, str>,
}
/// Per-rule lookup data stored by a `VariedResponse`: the header to read, how to
/// transform it, and the fallback key.
#[derive(Debug, PartialEq, Eq, Clone)]
struct ReferenceHeader {
name: &'static str,
/// Points at the `Transformation` owned by the originating `Rule`.
/// It is dereferenced in an `unsafe` block when building a cache key, so that rule must
/// still be alive.
/// NOTE(review): the meaning of the derived `PartialEq` for this field depends on
/// `SuperUnsafePointer`'s impl (likely pointer identity) — confirm.
transformation: SuperUnsafePointer<Transformation>,
default: &'static str,
}
/// A complete cache key: one `Header` per vary rule, in rule order.
pub(crate) type HeaderCollection = Vec<Header>;
/// Where, and under which key, to insert a response missing from a `VariedResponse`.
///
/// Returned as the `Err` of `VariedResponse::get_by_request` and consumed by
/// `VariedResponse::push_response`.
pub(crate) struct CacheParams {
/// Index at which inserting keeps `responses` sorted by key.
position: usize,
/// The cache key of the request that missed.
headers: HeaderCollection,
}
/// A cached resource holding one response per distinct combination of transformed
/// request headers.
#[derive(Debug)]
pub struct VariedResponse {
/// One entry per vary rule, in rule order; used to build cache keys from requests.
reference_headers: Vec<ReferenceHeader>,
/// Responses with their cache keys, kept sorted by key (inserted at the binary-search
/// position) so lookups can binary-search.
responses: Vec<Arc<(CompressedResponse, HeaderCollection)>>,
}
impl VariedResponse {
    /// Creates a cache entry holding `response`, keyed on the headers of `request`
    /// selected by the rules in `settings`.
    ///
    /// # Safety
    ///
    /// The returned value keeps unchecked pointers (`SuperUnsafePointer`) to the
    /// transformation of every rule in `settings.rules`, and they are dereferenced on
    /// every later lookup. The caller must ensure those `Rule`s are neither dropped nor
    /// relocated (for example by the rules `Vec` reallocating) while the returned
    /// `VariedResponse` is in use.
    pub(crate) unsafe fn new<T>(
        response: CompressedResponse,
        request: &Request<T>,
        settings: &Settings,
    ) -> Self {
        let reference_headers = settings
            .rules
            .iter()
            .map(|rule| ReferenceHeader {
                name: rule.name(),
                transformation: SuperUnsafePointer::new(rule.transformation()),
                default: rule.default(),
            })
            .collect();
        let mut me = Self {
            reference_headers,
            responses: Vec::new(),
        };
        // No responses are stored yet, so this key is necessarily new and belongs at
        // index 0. This avoids a binary search plus `unwrap_err` to learn that.
        let headers = me.get_headers_for_request(request);
        me.push_response(response, CacheParams { position: 0, headers });
        me
    }
    /// Inserts `response` at the position and under the key described by `params`
    /// (obtained from a failed [`Self::get_by_request`]), and returns the stored entry.
    pub(crate) fn push_response(
        &mut self,
        response: CompressedResponse,
        params: CacheParams,
    ) -> &Arc<(CompressedResponse, HeaderCollection)> {
        debug_assert_eq!(self.reference_headers.len(), params.headers.len());
        let CacheParams { position, headers } = params;
        self.responses
            .insert(position, Arc::new((response, headers)));
        &self.responses[position]
    }
    /// Binary-searches the responses for the key `other`.
    ///
    /// Returns the index of a match, or the index where `other` would be inserted.
    fn get(&self, other: &[Header]) -> Result<usize, usize> {
        self.responses.binary_search_by_key(&other, |pair| &pair.1)
    }
    /// Builds the cache key of `request`: one `Header` per rule, in rule order.
    ///
    /// Headers that are missing, or whose value is not readable as `&str`, get the
    /// rule's default. The default is not transformed.
    fn get_headers_for_request<T>(&self, request: &Request<T>) -> HeaderCollection {
        self.reference_headers
            .iter()
            .map(|reference| {
                let value = request
                    .headers()
                    .get(reference.name)
                    .and_then(|value| value.to_str().ok());
                let transformed = match value {
                    Some(value) => {
                        // SAFETY: the caller of `VariedResponse::new` guaranteed the
                        // rules these pointers refer to outlive `self`.
                        let transformation = unsafe { reference.transformation.get() };
                        transformation(value)
                    }
                    None => Cow::Borrowed(reference.default),
                };
                Header {
                    name: reference.name,
                    transformed,
                }
            })
            .collect()
    }
    /// Looks up the response matching `request`.
    ///
    /// On a miss, returns the [`CacheParams`] needed to insert a response for it with
    /// [`Self::push_response`].
    pub(crate) fn get_by_request<T>(
        &self,
        request: &Request<T>,
    ) -> Result<&Arc<(CompressedResponse, HeaderCollection)>, CacheParams> {
        let headers = self.get_headers_for_request(request);
        match self.get(&headers) {
            Ok(position) => Ok(&self.responses[position]),
            Err(sorted_position) => Err(CacheParams {
                position: sorted_position,
                headers,
            }),
        }
    }
    /// Returns the response with the smallest key.
    ///
    /// # Panics
    ///
    /// Panics if no response is stored. [`Self::new`] always stores one, and responses
    /// are only ever inserted here.
    pub(crate) fn first(&self) -> &Arc<(CompressedResponse, HeaderCollection)> {
        self.responses
            .first()
            .expect("a VariedResponse always holds the response it was created with")
    }
}