use http::header::{self, HeaderMap};
use http::{Method, StatusCode, Uri, Version};
use md5;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
pub mod axum;
pub use self::axum::{CacheMiddlewareConfig, CacheMiddlewareState};
/// A cached HTTP response together with the metadata used to revalidate and expire it.
///
/// Derives serde's `Serialize`/`Deserialize`, so field names and order may define the
/// persisted format — NOTE(review): do not reorder or rename fields without checking the
/// storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpCacheResponse {
/// Numeric HTTP status code (e.g. `200`, or `304` for not-modified replies).
pub status: u16,
/// Response headers keyed by name as stored; `HashMap` lookups are case-sensitive.
pub headers: HashMap<String, String>,
/// Raw response body bytes.
pub body: Vec<u8>,
/// UTC instant at which this entry was created; also consulted by
/// `ConditionalRequestHandler` for `If-Modified-Since` checks.
pub cached_at: chrono::DateTime<chrono::Utc>,
/// Time-to-live; presumably seconds (same unit as `Cache-Control: max-age`) — TODO confirm
/// with the storage adapter. The meaning of `None` is decided by the consumer.
pub ttl: Option<u64>,
/// Entity tag as it would appear in the `ETag` header (typically quoted, possibly `W/`-prefixed).
pub etag: Option<String>,
/// Raw `Last-Modified` header value, if any.
pub last_modified: Option<String>,
}
/// Builds cache keys from request properties; configure it with the `with_*` builders.
///
/// `Default` disables every option and excludes no headers.
#[derive(Clone, Debug, Default)]
pub struct HttpCacheKeyGenerator {
    /// Append the query string (when the URI has one) to the key.
    include_query: bool,
    /// Header names that must never contribute to the key.
    exclude_headers: Vec<String>,
    /// Put the request method at the front of the key.
    include_method: bool,
    /// Append the HTTP version to the key.
    include_version: bool,
}
impl HttpCacheKeyGenerator {
    /// Creates a generator with every option disabled (equivalent to `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Includes the raw query string in the key when `include` is true.
    pub fn with_include_query(mut self, include: bool) -> Self {
        self.include_query = include;
        self
    }

    /// Replaces the list of header names that must never contribute to the key.
    ///
    /// Names are compared ASCII-case-insensitively, as HTTP header names are, so both
    /// `"Authorization"` and `"authorization"` exclude the same header.
    pub fn with_exclude_headers(mut self, headers: Vec<String>) -> Self {
        self.exclude_headers = headers;
        self
    }

    /// Puts the request method at the front of the key when `include` is true.
    pub fn with_include_method(mut self, include: bool) -> Self {
        self.include_method = include;
        self
    }

    /// Appends the `Debug` form of the HTTP version to the key when `include` is true.
    pub fn with_include_version(mut self, include: bool) -> Self {
        self.include_version = include;
        self
    }

    /// Derives the cache key for `request`.
    ///
    /// The key is the lowercase-hex MD5 of these parts joined with `:`, in order:
    /// the method (optional), the path, the query (optional, only when present), the
    /// `Debug` form of the version (optional), and `name:value` for every
    /// `Accept-Encoding`, `Vary` and `Authorization` header value that is not excluded.
    /// Including `Authorization` keeps responses for different credentials apart.
    pub fn generate_key(&self, request: &HttpRequest) -> String {
        let mut key_parts = Vec::new();
        if self.include_method {
            key_parts.push(request.method.to_string());
        }
        key_parts.push(request.uri.path().to_string());
        if self.include_query {
            if let Some(query) = request.uri.query() {
                key_parts.push(query.to_string());
            }
        }
        if self.include_version {
            key_parts.push(format!("{:?}", request.version));
        }
        for (name, value) in &request.headers {
            // NOTE(review): `Vary` is normally a response header; kept for compatibility.
            let varies_response = name == header::ACCEPT_ENCODING
                || name == header::VARY
                || name == header::AUTHORIZATION;
            if varies_response && !self.is_excluded(name.as_str()) {
                // `to_str()` fails on any non-visible-ASCII byte, which used to collapse all
                // such values (e.g. distinct credentials) to "" and make their keys collide.
                // Lossy decoding is identical for visible ASCII, so existing keys are unchanged.
                let value = String::from_utf8_lossy(value.as_bytes());
                key_parts.push(format!("{}:{}", name, value));
            }
        }
        let key_string = key_parts.join(":");
        let hash = md5::compute(&key_string);
        format!("{:x}", hash)
    }

    /// True when `header_name` is listed in `exclude_headers` (case-insensitive).
    fn is_excluded(&self, header_name: &str) -> bool {
        self.exclude_headers
            .iter()
            .any(|excluded| excluded.eq_ignore_ascii_case(header_name))
    }
}
/// Owned snapshot of an incoming request, used as the input to cache-key generation.
#[derive(Clone, Debug)]
pub struct HttpRequest {
    /// Request method (e.g. `GET`).
    pub method: Method,
    /// Request target; `HttpCacheKeyGenerator` reads its path and query.
    pub uri: Uri,
    /// Protocol version of the request.
    pub version: Version,
    /// Request headers.
    pub headers: HeaderMap,
    /// Raw request body bytes.
    pub body: Vec<u8>,
}
/// Rules deciding which responses are cached and for how long.
///
/// `Default` leaves `cache_status_codes` empty, so `should_cache_response` rejects every
/// status until codes are configured.
#[derive(Clone, Debug, Default)]
pub struct HttpCachePolicy {
    /// Status codes eligible for caching.
    pub cache_status_codes: Vec<StatusCode>,
    /// Fallback TTL; presumably seconds — TODO confirm against the consumer.
    pub default_ttl: u64,
    /// When true, the TTL may be taken from the response's `Cache-Control` header.
    pub use_header_ttl: bool,
    /// Patterns of requests to skip. NOTE(review): not consulted in this file — confirm the consumer.
    pub ignore_patterns: Vec<String>,
    /// Prefix for generated keys. NOTE(review): not applied in this file — confirm the consumer.
    pub key_prefix: String,
}
impl HttpCachePolicy {
    /// Creates a policy with default (empty / zero / disabled) settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the status codes whose responses may be cached.
    pub fn with_cache_status_codes(mut self, codes: Vec<StatusCode>) -> Self {
        self.cache_status_codes = codes;
        self
    }

    /// Sets the fallback TTL used when no header-derived TTL applies.
    pub fn with_default_ttl(mut self, ttl: u64) -> Self {
        self.default_ttl = ttl;
        self
    }

    /// Enables or disables reading the TTL from `Cache-Control`.
    pub fn with_use_header_ttl(mut self, use_header: bool) -> Self {
        self.use_header_ttl = use_header;
        self
    }

    /// Sets the request patterns that should bypass the cache.
    pub fn with_ignore_patterns(mut self, patterns: Vec<String>) -> Self {
        self.ignore_patterns = patterns;
        self
    }

    /// Sets the key prefix (the only public field previously lacking a builder).
    pub fn with_key_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }

    /// Returns true when responses with `status` may be cached.
    pub fn should_cache_response(&self, status: StatusCode) -> bool {
        self.cache_status_codes.contains(&status)
    }

    /// Extracts a TTL in seconds from the `Cache-Control` header(s), if enabled.
    ///
    /// Returns `None` when `use_header_ttl` is false or no usable directive exists.
    /// Every `Cache-Control` header line is inspected, directive names are matched
    /// case-insensitively, and quoted arguments (`max-age="60"`) are accepted (RFC 9111 §5.2).
    /// `s-maxage` wins over `max-age` because this is a shared (server-side) cache
    /// (RFC 9111 §5.2.2.10). For each directive the first parseable occurrence is used.
    pub fn extract_ttl_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        if !self.use_header_ttl {
            return None;
        }
        let mut max_age = None;
        let mut s_maxage = None;
        for value in headers.get_all(header::CACHE_CONTROL).iter() {
            let value = match value.to_str() {
                Ok(value) => value,
                Err(_) => continue,
            };
            for directive in value.split(',') {
                let (name, argument) = match directive.split_once('=') {
                    Some(pair) => pair,
                    None => continue,
                };
                let name = name.trim();
                let seconds = argument.trim().trim_matches('"').parse::<u64>().ok();
                if name.eq_ignore_ascii_case("s-maxage") {
                    s_maxage = s_maxage.or(seconds);
                } else if name.eq_ignore_ascii_case("max-age") {
                    max_age = max_age.or(seconds);
                }
            }
        }
        s_maxage.or(max_age)
    }
}
/// Storage backend for cached HTTP responses.
///
/// Every method is async (via `async_trait`) and reports failures as
/// `crate::error::CacheError`.
#[async_trait::async_trait]
pub trait HttpCacheAdapter {
    /// Looks up the response stored under `key`; `Ok(None)` presumably means "not cached".
    async fn get_response(
        &self,
        key: &str,
    ) -> Result<Option<HttpCacheResponse>, crate::error::CacheError>;

    /// Batch lookup; the returned map presumably contains only the keys that were found.
    async fn get_responses(
        &self,
        keys: &[&str],
    ) -> Result<HashMap<String, HttpCacheResponse>, crate::error::CacheError>;

    /// Stores `response` under `key`.
    async fn set_response(
        &self,
        key: &str,
        response: &HttpCacheResponse,
    ) -> Result<(), crate::error::CacheError>;

    /// Removes the entry for `key`; the `bool` presumably reports whether one existed.
    async fn delete_response(&self, key: &str) -> Result<bool, crate::error::CacheError>;

    /// Removes every entry matching `pattern`, returning a count (pattern syntax is
    /// backend-defined — TODO confirm per implementation).
    async fn invalidate_by_pattern(&self, pattern: &str) -> Result<u64, crate::error::CacheError>;

    /// Path-oriented invalidation; by default this just forwards to `invalidate_by_pattern`.
    async fn invalidate_by_path_pattern(
        &self,
        path_pattern: &str,
    ) -> Result<u64, crate::error::CacheError> {
        self.invalidate_by_pattern(path_pattern).await
    }
}
/// One element of a compiled path pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GlobToken {
    /// A character that must appear verbatim (no regex meaning, e.g. `.` or `(`).
    Literal(char),
    /// `*`: any run of characters, possibly empty, that contains no `/`.
    SegmentWildcard,
    /// `**`: any run of characters, possibly empty, `/` included.
    DeepWildcard,
}

/// Matches request paths against simple glob patterns.
///
/// * No `*` in the pattern: exact string equality.
/// * `*` but no `**`: segment-wise match. Empty segments are ignored, segment counts must be
///   equal, and `*` matches any characters within one segment.
/// * Contains `**`: whole-path glob where `**` spans `/` and `*` does not. `/api/**` therefore
///   requires the trailing `/`, so it does not match `/api` itself.
///
/// Every non-`*` character is literal. Matching uses a small dynamic program instead of
/// building a regex per call, which avoids metacharacter injection and repeated compilation.
#[derive(Debug, Clone, Default)]
pub struct PathPatternMatcher;
impl PathPatternMatcher {
    /// Creates a matcher (it is stateless).
    pub fn new() -> Self {
        Self
    }

    /// Returns true when `path` matches `pattern` (see the type-level docs for the syntax).
    pub fn matches(&self, path: &str, pattern: &str) -> bool {
        if pattern.contains("**") {
            self.matches_double_star(path, pattern)
        } else if pattern.contains('*') {
            self.matches_single_star(path, pattern)
        } else {
            path == pattern
        }
    }

    /// Segment-by-segment match; both sides must have the same number of non-empty segments.
    fn matches_single_star(&self, path: &str, pattern: &str) -> bool {
        let path_parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
        let pattern_parts: Vec<&str> = pattern.split('/').filter(|s| !s.is_empty()).collect();
        path_parts.len() == pattern_parts.len()
            && path_parts
                .iter()
                .zip(&pattern_parts)
                .all(|(path_part, pattern_part)| self.matches_segment(pattern_part, path_part))
    }

    /// Whole-path glob where `**` may cross `/` boundaries.
    fn matches_double_star(&self, path: &str, pattern: &str) -> bool {
        Self::glob_match(&Self::tokenize(pattern), path)
    }

    /// Matches one segment; segments contain no `/`, so `*` matches anything here.
    fn matches_segment(&self, pattern: &str, segment: &str) -> bool {
        Self::glob_match(&Self::tokenize(pattern), segment)
    }

    /// Splits a pattern into tokens. The leftmost `**` pair is taken first, so `***` becomes
    /// `**` followed by `*`.
    fn tokenize(pattern: &str) -> Vec<GlobToken> {
        let mut tokens = Vec::with_capacity(pattern.len());
        let mut chars = pattern.chars().peekable();
        while let Some(c) = chars.next() {
            let token = if c != '*' {
                GlobToken::Literal(c)
            } else if chars.peek() == Some(&'*') {
                chars.next();
                GlobToken::DeepWildcard
            } else {
                GlobToken::SegmentWildcard
            };
            tokens.push(token);
        }
        tokens
    }

    /// Anchored glob match in O(tokens × chars) time.
    ///
    /// `reachable[j]` is true when the tokens processed so far can consume exactly the first
    /// `j` characters of `text`.
    fn glob_match(tokens: &[GlobToken], text: &str) -> bool {
        let chars: Vec<char> = text.chars().collect();
        let mut reachable = vec![false; chars.len() + 1];
        reachable[0] = true;
        for &token in tokens {
            let mut next = vec![false; chars.len() + 1];
            match token {
                GlobToken::Literal(expected) => {
                    for (i, &c) in chars.iter().enumerate() {
                        next[i + 1] = reachable[i] && c == expected;
                    }
                }
                GlobToken::SegmentWildcard | GlobToken::DeepWildcard => {
                    let crosses_slash = token == GlobToken::DeepWildcard;
                    // `open` tracks whether some reachable start lies at or before `j` with no
                    // forbidden `/` between that start and `j`.
                    let mut open = false;
                    for j in 0..=chars.len() {
                        if j > 0 && !crosses_slash && chars[j - 1] == '/' {
                            open = false;
                        }
                        open |= reachable[j];
                        next[j] = open;
                    }
                }
            }
            if !next.iter().any(|&r| r) {
                return false;
            }
            reachable = next;
        }
        reachable[chars.len()]
    }
}
/// Outcome of evaluating conditional request headers against a cached response.
#[derive(Clone, Debug)]
pub enum ConditionalRequestResult {
    /// No short-circuit applies; serve this cached response in full.
    FullResponse(HttpCacheResponse),
    /// The client's copy is current; answer `304 Not Modified`.
    NotModified,
    /// A precondition failed. NOTE(review): `ConditionalRequestHandler::check_conditional`
    /// never produces this variant — confirm whether any caller constructs it.
    PreconditionFailed,
}
/// Evaluates `If-None-Match` / `If-Modified-Since` against cached responses, builds
/// `304 Not Modified` replies, and generates ETags.
#[derive(Debug, Clone, Default)]
pub struct ConditionalRequestHandler;
impl ConditionalRequestHandler {
    /// Creates a handler (it is stateless).
    pub fn new() -> Self {
        Self
    }

    /// Decides whether `cached_response` can be answered with `304 Not Modified`.
    ///
    /// Follows RFC 9110 §13.1. When `if_none_match` is present, it alone decides and
    /// `if_modified_since` is ignored. It may be `*` or a comma-separated list of entity tags,
    /// compared weakly (the `W/` prefix and surrounding quotes are ignored). Otherwise
    /// `if_modified_since` (an RFC 2822 / IMF-fixdate string) is compared at one-second
    /// resolution against the response's `Last-Modified`, falling back to `cached_at` when that
    /// is missing or unparseable. An unparseable `If-Modified-Since` value is ignored.
    ///
    /// Never returns [`ConditionalRequestResult::PreconditionFailed`].
    pub fn check_conditional(
        &self,
        cached_response: &HttpCacheResponse,
        if_none_match: Option<&str>,
        if_modified_since: Option<&str>,
    ) -> ConditionalRequestResult {
        let not_modified = match (if_none_match, if_modified_since) {
            (Some(etags), _) => Self::etag_list_matches(etags, cached_response.etag.as_deref()),
            (None, Some(since)) => Self::unmodified_since(cached_response, since),
            (None, None) => false,
        };
        if not_modified {
            ConditionalRequestResult::NotModified
        } else {
            ConditionalRequestResult::FullResponse(cached_response.clone())
        }
    }

    /// True when an `If-None-Match` value matches the cached entity tag.
    ///
    /// `*` matches any existing representation. A cached response exists here, so `*`
    /// always matches.
    fn etag_list_matches(header_value: &str, cached_etag: Option<&str>) -> bool {
        let header_value = header_value.trim();
        if header_value == "*" {
            return true;
        }
        match cached_etag {
            Some(cached) => {
                let cached = Self::opaque_tag(cached);
                header_value
                    .split(',')
                    .any(|candidate| Self::opaque_tag(candidate) == cached)
            }
            None => false,
        }
    }

    /// Normalises an entity tag for weak comparison: `W/"abc"`, `"abc"` and `abc` all become `abc`.
    fn opaque_tag(etag: &str) -> &str {
        let etag = etag.trim();
        let etag = etag.strip_prefix("W/").unwrap_or(etag);
        etag.trim_matches('"')
    }

    /// True when the response has not changed since the `If-Modified-Since` instant.
    fn unmodified_since(cached_response: &HttpCacheResponse, if_modified_since: &str) -> bool {
        let since = match chrono::DateTime::parse_from_rfc2822(if_modified_since.trim()) {
            Ok(since) => since,
            // RFC 9110 §13.1.3: an invalid date means the header is ignored.
            Err(_) => return false,
        };
        // HTTP dates have one-second resolution, so compare whole seconds.
        let modified_at = cached_response
            .last_modified
            .as_deref()
            .and_then(|lm| chrono::DateTime::parse_from_rfc2822(lm.trim()).ok())
            .map(|lm| lm.timestamp())
            .unwrap_or_else(|| cached_response.cached_at.timestamp());
        modified_at <= since.timestamp()
    }

    /// Builds a bodiless `304 Not Modified` response derived from `cached_response`.
    ///
    /// It carries `ETag`, `Last-Modified` and a fresh `Date`. It also copies any
    /// `Cache-Control`, `Content-Location`, `Expires` and `Vary` headers from the cached
    /// response, matched case-insensitively, because RFC 9110 §15.4.5 requires them on a 304
    /// when a 200 would have sent them.
    pub fn create_not_modified_response(
        &self,
        cached_response: &HttpCacheResponse,
    ) -> HttpCacheResponse {
        const CARRIED_OVER: [&str; 4] = ["cache-control", "content-location", "expires", "vary"];
        let mut headers: HashMap<String, String> = cached_response
            .headers
            .iter()
            .filter(|(name, _)| {
                CARRIED_OVER
                    .iter()
                    .any(|carried| carried.eq_ignore_ascii_case(name.as_str()))
            })
            .map(|(name, value)| (name.clone(), value.clone()))
            .collect();
        if let Some(etag) = &cached_response.etag {
            headers.insert("ETag".to_string(), etag.clone());
        }
        if let Some(lm) = &cached_response.last_modified {
            headers.insert("Last-Modified".to_string(), lm.clone());
        }
        let now = chrono::Utc::now();
        headers.insert(
            "Date".to_string(),
            now.format("%a, %d %b %Y %H:%M:%S GMT").to_string(),
        );
        HttpCacheResponse {
            status: StatusCode::NOT_MODIFIED.as_u16(),
            headers,
            body: Vec::new(),
            cached_at: now,
            ttl: cached_response.ttl,
            etag: cached_response.etag.clone(),
            last_modified: cached_response.last_modified.clone(),
        }
    }

    /// Reads `(If-None-Match, If-Modified-Since)` from `headers`.
    ///
    /// Only the first value of each header is read, and values that are not visible ASCII
    /// become `None`.
    pub fn extract_conditionals(&self, headers: &HeaderMap) -> (Option<String>, Option<String>) {
        let if_none_match = headers
            .get(http::header::IF_NONE_MATCH)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
        let if_modified_since = headers
            .get(http::header::IF_MODIFIED_SINCE)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
        (if_none_match, if_modified_since)
    }

    /// Strong ETag: the quoted lowercase-hex MD5 of `body`.
    pub fn generate_strong_etag(&self, body: &[u8]) -> String {
        let digest = md5::compute(body);
        format!("\"{:x}\"", digest)
    }

    /// Weak ETag: `W/` followed by the quoted lowercase-hex MD5 of `body`.
    pub fn generate_weak_etag(&self, body: &[u8]) -> String {
        let digest = md5::compute(body);
        format!("W/\"{:x}\"", digest)
    }
}
/// Tracks which cache keys carry which tags, so groups of responses can be invalidated together.
///
/// The tag index is an in-memory `DashMap` behind an `Arc`; clones share the same index and
/// adapter.
#[derive(Clone)]
pub struct CacheTagManager {
    /// Maps each tag to the set of cache keys labelled with it.
    tag_mapping: Arc<dashmap::DashMap<String, dashmap::DashSet<String>>>,
    /// Backend that stores the responses themselves.
    adapter: Arc<dyn HttpCacheAdapter + Send + Sync>,
}
impl CacheTagManager {
    /// Creates a manager with an empty tag index over `adapter`.
    pub fn new(adapter: Arc<dyn HttpCacheAdapter + Send + Sync>) -> Self {
        Self {
            tag_mapping: Arc::new(dashmap::DashMap::new()),
            adapter,
        }
    }

    /// Associates `cache_key` with every tag in `tags`. This only updates the in-memory index.
    pub async fn add_tags(
        &self,
        cache_key: &str,
        tags: &[&str],
    ) -> Result<(), crate::error::CacheError> {
        for tag in tags {
            self.tag_mapping
                .entry(tag.to_string())
                .or_default()
                .insert(cache_key.to_string());
        }
        Ok(())
    }

    /// Deletes every response tagged with `tag` and forgets the tag.
    ///
    /// Returns the number of keys that were registered under the tag. All deletions are
    /// attempted. If any fails, the keys that could not be deleted are registered under the
    /// tag again, so a retry can still reach them, and the first error is returned.
    /// Previously such errors were discarded, and the affected entries were orphaned from
    /// the tag index.
    pub async fn invalidate_by_tag(&self, tag: &str) -> Result<u64, crate::error::CacheError> {
        let keys = match self.tag_mapping.remove(tag) {
            Some((_, keys)) => keys,
            None => return Ok(0),
        };
        let count = keys.len() as u64;
        let mut failed = Vec::new();
        let mut first_error = None;
        for key in keys {
            if let Err(err) = self.adapter.delete_response(&key).await {
                if first_error.is_none() {
                    first_error = Some(err);
                }
                failed.push(key);
            }
        }
        match first_error {
            None => Ok(count),
            Some(err) => {
                // No await happens while this map entry guard is held.
                let tag_set = self.tag_mapping.entry(tag.to_string()).or_default();
                for key in failed {
                    tag_set.insert(key);
                }
                Err(err)
            }
        }
    }

    /// Delegates pattern invalidation to the adapter; the tag index is not touched.
    pub async fn invalidate_by_pattern(
        &self,
        pattern: &str,
    ) -> Result<u64, crate::error::CacheError> {
        self.adapter.invalidate_by_pattern(pattern).await
    }

    /// Forgets all tag associations. Cached responses themselves are not deleted.
    pub fn clear(&self) {
        self.tag_mapping.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 200 response fixture with body `[1, 2, 3]` and a one-hour TTL.
    fn fixture(
        headers: HashMap<String, String>,
        cached_at: chrono::DateTime<chrono::Utc>,
        etag: Option<&str>,
        last_modified: Option<&str>,
    ) -> HttpCacheResponse {
        HttpCacheResponse {
            status: 200,
            headers,
            body: vec![1, 2, 3],
            cached_at,
            ttl: Some(3600),
            etag: etag.map(String::from),
            last_modified: last_modified.map(String::from),
        }
    }

    /// A header map holding only `Content-Type: application/json`.
    fn json_headers() -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers
    }

    #[test]
    fn test_http_cache_key_generator() {
        let request = HttpRequest {
            method: Method::GET,
            uri: "/api/users?id=123".parse().unwrap(),
            version: Version::HTTP_11,
            headers: HeaderMap::new(),
            body: Vec::new(),
        };
        let key = HttpCacheKeyGenerator::new()
            .with_include_method(true)
            .with_include_query(true)
            .generate_key(&request);
        assert!(!key.is_empty());
    }

    #[test]
    fn test_http_cache_policy_should_cache() {
        let policy = HttpCachePolicy::new()
            .with_cache_status_codes(vec![StatusCode::OK, StatusCode::NOT_FOUND]);
        for status in [StatusCode::OK, StatusCode::NOT_FOUND].iter() {
            assert!(policy.should_cache_response(*status));
        }
        assert!(!policy.should_cache_response(StatusCode::INTERNAL_SERVER_ERROR));
    }

    #[test]
    fn test_http_cache_policy_extract_ttl() {
        let mut headers = HeaderMap::new();
        headers.insert(header::CACHE_CONTROL, "max-age=3600".parse().unwrap());
        let ttl = HttpCachePolicy::new()
            .with_use_header_ttl(true)
            .extract_ttl_from_headers(&headers);
        assert_eq!(ttl, Some(3600));
    }

    #[test]
    fn test_http_cache_response() {
        let response = fixture(json_headers(), chrono::Utc::now(), Some("abc123"), None);
        assert_eq!(response.status, 200);
        assert_eq!(response.ttl, Some(3600));
    }

    #[test]
    fn test_conditional_request_handler_etag_match() {
        let cached = fixture(HashMap::new(), chrono::Utc::now(), Some("\"abc123\""), None);
        let result =
            ConditionalRequestHandler::new().check_conditional(&cached, Some("\"abc123\""), None);
        assert!(
            matches!(result, ConditionalRequestResult::NotModified),
            "Expected NotModified"
        );
    }

    #[test]
    fn test_conditional_request_handler_etag_mismatch() {
        let cached = fixture(HashMap::new(), chrono::Utc::now(), Some("\"abc123\""), None);
        let result = ConditionalRequestHandler::new().check_conditional(
            &cached,
            Some("\"different\""),
            None,
        );
        assert!(
            matches!(result, ConditionalRequestResult::FullResponse(_)),
            "Expected FullResponse"
        );
    }

    #[test]
    fn test_conditional_request_handler_if_modified_since() {
        let old_time = chrono::DateTime::from_timestamp(1000000, 0).expect("Invalid timestamp");
        let cached = fixture(
            HashMap::new(),
            old_time,
            None,
            Some("Mon, 01 Jan 2001 00:00:00 GMT"),
        );
        let result = ConditionalRequestHandler::new().check_conditional(
            &cached,
            None,
            Some("Tue, 01 Jan 2002 00:00:00 GMT"),
        );
        assert!(
            matches!(result, ConditionalRequestResult::NotModified),
            "Expected NotModified"
        );
    }

    #[test]
    fn test_generate_etag() {
        let handler = ConditionalRequestHandler::new();
        let body = b"hello world";
        let strong = handler.generate_strong_etag(body);
        assert!(strong.starts_with('"') && strong.ends_with('"'));
        let weak = handler.generate_weak_etag(body);
        assert!(weak.starts_with("W/\"") && weak.ends_with('"'));
    }

    #[test]
    fn test_create_not_modified_response() {
        let cached = fixture(
            json_headers(),
            chrono::Utc::now(),
            Some("\"abc123\""),
            Some("Mon, 01 Jan 2024 00:00:00 GMT"),
        );
        let not_modified = ConditionalRequestHandler::new().create_not_modified_response(&cached);
        assert_eq!(not_modified.status, 304);
        assert!(not_modified.body.is_empty());
        assert_eq!(not_modified.etag, cached.etag);
        assert_eq!(not_modified.last_modified, cached.last_modified);
    }

    #[test]
    fn test_path_pattern_matcher_exact() {
        let matcher = PathPatternMatcher::new();
        assert!(matcher.matches("/api/users", "/api/users"));
        assert!(!matcher.matches("/api/users", "/api/products"));
    }

    #[test]
    fn test_path_pattern_matcher_single_star() {
        let matcher = PathPatternMatcher::new();
        for path in ["/api/users/123", "/api/users/abc"].iter() {
            assert!(matcher.matches(path, "/api/users/*"));
        }
        assert!(!matcher.matches("/api/users/123/profile", "/api/users/*"));
    }

    #[test]
    fn test_path_pattern_matcher_double_star() {
        let matcher = PathPatternMatcher::new();
        for path in ["/api/users/123", "/api/users/123/profile", "/api/a/b/c/d"].iter() {
            assert!(matcher.matches(path, "/api/**"));
        }
        assert!(!matcher.matches("/other/users", "/api/**"));
    }

    #[test]
    fn test_path_pattern_matcher_mixed() {
        let matcher = PathPatternMatcher::new();
        let positive = [
            ("/api/users/123", "/api/users/*"),
            ("/api/products/456", "/api/products/*"),
            ("/api/users/123/profile", "/api/users/*/profile"),
        ];
        for (path, pattern) in positive.iter() {
            assert!(matcher.matches(path, pattern));
        }
        assert!(!matcher.matches("/api/users/123/extra", "/api/users/*/profile"));
    }
}