use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Hashable key identifying one set of content-negotiation inputs.
///
/// Two keys are equal exactly when their encoded header strings are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    accept_header: String,
}

impl CacheKey {
    /// Builds a key from a raw `Accept` header value, stored verbatim.
    ///
    /// Note: the string is not escaped, so `new("Accept:x")` equals
    /// `from_headers(&[("Accept", "x")])`.
    pub fn new(accept_header: impl Into<String>) -> Self {
        Self {
            accept_header: accept_header.into(),
        }
    }

    /// Builds a key from an ordered list of `(name, value)` header pairs.
    ///
    /// Pairs are encoded as `name:value` and joined with `;`. Every `\`, `:` or `;`
    /// inside a name or value is backslash-escaped. As a result, two different header
    /// lists never encode to the same key (e.g. `[("Accept", "a;B:c")]` vs
    /// `[("Accept", "a"), ("B", "c")]`). Header order and name case are significant:
    /// the same headers in another order or case yield a different key.
    pub fn from_headers(headers: &[(&str, &str)]) -> Self {
        let mut encoded = String::new();
        for (index, (name, value)) in headers.iter().enumerate() {
            if index > 0 {
                encoded.push(';');
            }
            Self::push_escaped(&mut encoded, name);
            encoded.push(':');
            Self::push_escaped(&mut encoded, value);
        }
        Self {
            accept_header: encoded,
        }
    }

    /// Appends `raw` to `out`, prefixing each separator character (`\`, `:`, `;`) with `\`.
    fn push_escaped(out: &mut String, raw: &str) {
        for ch in raw.chars() {
            if matches!(ch, '\\' | ':' | ';') {
                out.push('\\');
            }
            out.push(ch);
        }
    }
}
/// A cached value paired with the instant after which it is considered stale.
#[derive(Debug, Clone)]
pub struct CacheEntry<T> {
    value: T,
    expires_at: Instant,
}

impl<T> CacheEntry<T> {
    /// Fallback lifetime, roughly 100 years, used when `now + ttl` cannot be represented
    /// as an `Instant` (e.g. `ttl == Duration::MAX`). Effectively "never expires".
    const FAR_FUTURE: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);

    /// Wraps `value` so that it expires `ttl` from now.
    ///
    /// Unlike a plain `Instant::now() + ttl`, this never panics on overflow. An
    /// unrepresentable deadline is clamped to `FAR_FUTURE` from now.
    fn new(value: T, ttl: Duration) -> Self {
        let now = Instant::now();
        let expires_at = now
            .checked_add(ttl)
            .or_else(|| now.checked_add(Self::FAR_FUTURE))
            // Last resort if the clock cannot represent even the fallback:
            // treat the entry as already due rather than panicking.
            .unwrap_or(now);
        Self { value, expires_at }
    }

    /// Returns `true` once the current time is strictly past the expiry instant.
    fn is_expired(&self) -> bool {
        Instant::now() > self.expires_at
    }
}
/// A bounded TTL cache for content-negotiation results, keyed by [`CacheKey`].
///
/// Every entry lives for the cache-wide `ttl` chosen at construction. The cache holds
/// at most `max_entries` entries. When a new key arrives at capacity, expired entries
/// are purged first. If that frees nothing, the entry with the earliest expiry is
/// evicted; because all entries share one TTL, that is the least recently written one.
#[derive(Debug)]
pub struct NegotiationCache<T> {
    cache: HashMap<CacheKey, CacheEntry<T>>,
    ttl: Duration,
    max_entries: usize,
}

impl<T> NegotiationCache<T>
where
    T: Clone,
{
    /// TTL used by [`new`](Self::new).
    const DEFAULT_TTL: Duration = Duration::from_secs(300);
    /// Capacity used by [`new`](Self::new) and [`with_ttl`](Self::with_ttl).
    const DEFAULT_MAX_ENTRIES: usize = 1000;

    /// Creates a cache with a 300-second TTL and room for 1000 entries.
    pub fn new() -> Self {
        Self::with_config(Self::DEFAULT_TTL, Self::DEFAULT_MAX_ENTRIES)
    }

    /// Creates a cache with the given TTL and the default capacity of 1000 entries.
    pub fn with_ttl(ttl: Duration) -> Self {
        Self::with_config(ttl, Self::DEFAULT_MAX_ENTRIES)
    }

    /// Creates a cache with an explicit TTL and capacity.
    ///
    /// A `max_entries` of `0` disables caching: [`set`](Self::set) stores nothing.
    pub fn with_config(ttl: Duration, max_entries: usize) -> Self {
        Self {
            cache: HashMap::new(),
            ttl,
            max_entries,
        }
    }

    /// Returns a clone of the value cached for `key`, or `None` if absent or expired.
    ///
    /// An expired entry found here is removed eagerly.
    pub fn get(&mut self, key: &CacheKey) -> Option<T> {
        let entry = self.cache.get(key)?;
        if entry.is_expired() {
            self.cache.remove(key);
            return None;
        }
        Some(entry.value.clone())
    }

    /// Stores `value` under `key`, replacing any previous entry and restarting its TTL.
    ///
    /// Eviction only happens when a *new* key is inserted into a full cache. Overwriting
    /// an existing key never displaces another entry. Eviction scans all entries (O(n)).
    pub fn set(&mut self, key: CacheKey, value: T) {
        if self.max_entries == 0 {
            // Zero capacity means "do not cache"; inserting would exceed the limit.
            return;
        }
        if self.cache.len() >= self.max_entries && !self.cache.contains_key(&key) {
            // Reclaim stale slots first; only evict a live entry if that freed nothing.
            self.clear_expired();
            if self.cache.len() >= self.max_entries {
                self.evict_oldest();
            }
        }
        self.cache.insert(key, CacheEntry::new(value, self.ttl));
    }

    /// Returns the cached value for `key`, or runs `compute`, caches its result, and
    /// returns it. `compute` is only called on a miss (absent or expired entry).
    pub fn get_or_compute<F>(&mut self, key: &CacheKey, compute: F) -> T
    where
        F: FnOnce() -> T,
    {
        if let Some(cached) = self.get(key) {
            return cached;
        }
        let value = compute();
        self.set(key.clone(), value.clone());
        value
    }

    /// Removes every entry whose TTL has elapsed.
    pub fn clear_expired(&mut self) {
        self.cache.retain(|_, entry| !entry.is_expired());
    }

    /// Removes all entries, expired or not.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Number of stored entries, including expired ones not yet purged.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` if no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Removes the entry with the earliest expiry instant, if any.
    ///
    /// All entries share `self.ttl`, so the earliest expiry is also the least
    /// recently written entry.
    fn evict_oldest(&mut self) {
        let oldest = self
            .cache
            .iter()
            .min_by_key(|(_, entry)| entry.expires_at)
            .map(|(key, _)| key.clone());
        if let Some(key) = oldest {
            self.cache.remove(&key);
        }
    }
}
impl<T> Default for NegotiationCache<T>
where
T: Clone,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::negotiation::MediaType;

    #[test]
    fn test_cache_key_new() {
        let key = CacheKey::new("application/json");
        assert_eq!(key.accept_header, "application/json");
    }

    #[test]
    fn test_cache_key_from_headers() {
        let key =
            CacheKey::from_headers(&[("Accept", "application/json"), ("Accept-Language", "en-US")]);
        assert!(key.accept_header.contains("Accept:application/json"));
        assert!(key.accept_header.contains("Accept-Language:en-US"));
    }

    #[test]
    fn test_cache_get_set() {
        let mut cache: NegotiationCache<MediaType> = NegotiationCache::new();
        let key = CacheKey::new("application/json");
        let media_type = MediaType::new("application", "json");
        cache.set(key.clone(), media_type);
        let result = cache.get(&key);
        assert!(result.is_some());
        assert_eq!(result.unwrap().subtype, "json");
    }

    #[test]
    fn test_cache_get_or_compute() {
        let mut cache: NegotiationCache<MediaType> = NegotiationCache::new();
        let key = CacheKey::new("application/json");
        let result = cache.get_or_compute(&key, || MediaType::new("application", "json"));
        assert_eq!(result.subtype, "json");
        let mut called = false;
        let result2 = cache.get_or_compute(&key, || {
            called = true;
            MediaType::new("application", "xml")
        });
        assert!(!called);
        assert_eq!(result2.subtype, "json");
    }

    #[test]
    fn test_cache_expiration() {
        // Keep a generous margin between the TTL and the first read. With a 10 ms TTL,
        // the "still fresh" assertion can flake on a slow or heavily loaded test machine.
        let ttl = Duration::from_millis(50);
        let mut cache: NegotiationCache<MediaType> = NegotiationCache::with_ttl(ttl);
        let key = CacheKey::new("application/json");
        cache.set(key.clone(), MediaType::new("application", "json"));
        assert!(cache.get(&key).is_some());
        std::thread::sleep(ttl * 2);
        assert!(cache.get(&key).is_none());
        // `get` drops the expired entry it found.
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_clear_expired() {
        let mut cache: NegotiationCache<MediaType> = NegotiationCache::with_ttl(Duration::ZERO);
        cache.set(CacheKey::new("application/json"), MediaType::new("application", "json"));
        // Ensure the clock has moved strictly past the zero-TTL deadline.
        std::thread::sleep(Duration::from_millis(2));
        cache.clear_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_clear() {
        let mut cache: NegotiationCache<MediaType> = NegotiationCache::new();
        let key = CacheKey::new("application/json");
        cache.set(key, MediaType::new("application", "json"));
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_max_entries() {
        let mut cache: NegotiationCache<MediaType> =
            NegotiationCache::with_config(Duration::from_secs(300), 2);
        cache.set(CacheKey::new("key1"), MediaType::new("application", "json"));
        cache.set(CacheKey::new("key2"), MediaType::new("text", "html"));
        assert_eq!(cache.len(), 2);
        cache.set(CacheKey::new("key3"), MediaType::new("application", "xml"));
        assert_eq!(cache.len(), 2);
    }
}