#[cfg(test)]
mod tests;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::application::Application;
use crate::header::Header;
use crate::middleware::Middleware;
use crate::request::{METHOD, Request};
use crate::response::Response;
use crate::server::ConnectionInfo;
/// One stored response plus the moment it was put into the cache.
struct CachedEntry {
    /// Insertion time; freshness checks and the `Age` header are derived from it.
    inserted_at: Instant,
    /// The response exactly as the downstream application produced it.
    response: Response,
}
/// Keyed response storage with a first-in-first-out eviction queue.
struct CacheStore {
    /// Keys in insertion order; the front key is evicted first when full.
    order: VecDeque<String>,
    /// Cached entries indexed by cache key.
    entries: HashMap<String, CachedEntry>,
}
impl CacheStore {
    /// Creates an empty store.
    fn new() -> Self {
        CacheStore { entries: HashMap::new(), order: VecDeque::new() }
    }

    /// Returns the entry for `key` if it exists and is younger than `ttl`.
    ///
    /// Expired entries are not removed here; `purge_expired` does that.
    fn get(&self, key: &str, ttl: Duration) -> Option<&CachedEntry> {
        self.entries.get(key).filter(|e| e.inserted_at.elapsed() < ttl)
    }

    /// Stores `entry` under `key`, keeping at most `capacity` entries.
    ///
    /// Eviction is FIFO by first insertion: replacing an existing key keeps its
    /// original position in the eviction queue. A `capacity` of 0 disables
    /// storage entirely. Previously it still kept one entry, because
    /// `len() >= 0` is always true.
    fn insert(&mut self, key: String, entry: CachedEntry, capacity: usize) {
        if capacity == 0 {
            return;
        }
        // Existing key: overwrite in place (single lookup), queue position unchanged.
        if let Some(slot) = self.entries.get_mut(&key) {
            *slot = entry;
            return;
        }
        // Make room for the new key. Stop if the queue is somehow empty rather
        // than looping forever.
        while self.entries.len() >= capacity {
            match self.order.pop_front() {
                Some(oldest) => {
                    self.entries.remove(&oldest);
                }
                None => break,
            }
        }
        self.order.push_back(key.clone());
        self.entries.insert(key, entry);
    }

    /// Removes every entry whose age is at least `ttl` from both the map and
    /// the eviction queue, in linear time.
    fn purge_expired(&mut self, ttl: Duration) {
        self.entries.retain(|_, e| e.inserted_at.elapsed() < ttl);
        // Disjoint field borrow: read `entries` while mutating `order`.
        let entries = &self.entries;
        self.order.retain(|k| entries.contains_key(k));
    }
}
/// Response-caching middleware backed by an in-memory store.
///
/// Cloning is cheap: the store and the hit/miss counters sit behind `Arc`,
/// so every clone shares the same cached data and statistics.
#[derive(Clone)]
pub struct CacheLayer {
    /// Maximum number of stored entries.
    capacity: usize,
    /// How long an entry counts as fresh after insertion.
    ttl: Duration,
    /// Lower-cased request header names whose values become part of the cache key.
    vary_headers: Vec<String>,
    /// Shared, lock-protected entry storage.
    store: Arc<Mutex<CacheStore>>,
    /// Requests answered from the cache.
    hits: Arc<AtomicU64>,
    /// Requests that had to be forwarded downstream.
    misses: Arc<AtomicU64>,
}
impl CacheLayer {
pub fn memory(capacity: usize) -> Self {
CacheLayer {
store: Arc::new(Mutex::new(CacheStore::new())),
hits: Arc::new(AtomicU64::new(0)),
misses: Arc::new(AtomicU64::new(0)),
capacity,
ttl: Duration::from_secs(60),
vary_headers: vec![],
}
}
pub fn hits(&self) -> u64 {
self.hits.load(Ordering::Relaxed)
}
pub fn misses(&self) -> u64 {
self.misses.load(Ordering::Relaxed)
}
pub fn size(&self) -> usize {
self.store.lock().unwrap().entries.len()
}
pub fn clear(&self) {
let mut guard = self.store.lock().unwrap();
guard.entries.clear();
guard.order.clear();
}
pub fn ttl(mut self, secs: u64) -> Self {
self.ttl = Duration::from_secs(secs);
self
}
pub fn vary_by_header(mut self, name: &str) -> Self {
self.vary_headers.push(name.to_ascii_lowercase());
self
}
fn store(&self) -> &Mutex<CacheStore> {
&self.store
}
fn cache_key(&self, request: &Request) -> String {
let mut key = request.request_uri.clone();
for vh in &self.vary_headers {
let val = request.headers.iter()
.find(|h| h.name.eq_ignore_ascii_case(vh))
.map(|h| h.value.as_str())
.unwrap_or("");
key.push('\x00');
key.push_str(val);
}
key
}
fn request_bypasses_cache(request: &Request) -> bool {
request.headers.iter().any(|h| {
h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL)
&& h.value.to_ascii_lowercase().contains("no-cache")
})
}
fn response_is_cacheable(response: &Response) -> bool {
if response.status_code < 200 || response.status_code >= 300 {
return false;
}
!response.headers.iter().any(|h| {
if !h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL) {
return false;
}
let v = h.value.to_ascii_lowercase();
v.contains("no-store") || v.contains("private")
})
}
fn age_secs(entry: &CachedEntry) -> u64 {
entry.inserted_at.elapsed().as_secs()
}
fn cached_response(entry: &CachedEntry) -> Response {
let mut response = entry.response.clone();
let age = Self::age_secs(entry);
if let Some(h) = response.headers.iter_mut().find(|h| h.name.eq_ignore_ascii_case("Age")) {
h.value = age.to_string();
} else {
response.headers.push(Header { name: "Age".to_string(), value: age.to_string() });
}
response
}
}
impl Middleware for CacheLayer {
    /// Answers `GET` requests from the cache when a fresh entry exists.
    ///
    /// Non-`GET` requests are forwarded untouched and change neither counter.
    /// A `GET` that is not served from the store counts as a miss. This
    /// includes a request that opted out via `Cache-Control: no-cache`. The
    /// request is then forwarded to `next`, and the response is stored if
    /// `response_is_cacheable` allows it. Errors from `next` are propagated
    /// and nothing is stored.
    fn handle(
        &self,
        request: &Request,
        connection: &ConnectionInfo,
        next: &dyn Application,
    ) -> Result<Response, String> {
        if request.method != METHOD.get {
            return next.execute(request, connection);
        }

        let key = self.cache_key(request);
        let may_read = !Self::request_bypasses_cache(request);

        if may_read {
            let store = self.store().lock().unwrap();
            if let Some(hit) = store.get(&key, self.ttl).map(Self::cached_response) {
                self.hits.fetch_add(1, Ordering::Relaxed);
                return Ok(hit);
            }
        }

        // Either the client asked to skip the cache, or nothing fresh was stored.
        self.misses.fetch_add(1, Ordering::Relaxed);
        let fresh = next.execute(request, connection)?;

        if Self::response_is_cacheable(&fresh) {
            let mut store = self.store().lock().unwrap();
            // Drop stale entries first so they do not count against capacity.
            store.purge_expired(self.ttl);
            let entry = CachedEntry { response: fresh.clone(), inserted_at: Instant::now() };
            store.insert(key, entry, self.capacity);
        }

        Ok(fresh)
    }
}