1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::SystemTime;
9use thiserror::Error;
10
/// Errors produced by the cache types in this file.
///
/// `Display` text comes from the `#[error(...)]` attributes; the two wrapping
/// variants also get `From` conversions through `#[from]`, so `?` can be used
/// on `serde_json` and I/O results inside functions returning `CacheError`.
#[derive(Error, Debug)]
pub enum CacheError {
    /// No entry is stored under the requested key.
    #[error("Cache entry not found")]
    NotFound,
    /// The entry exists but is past its lifetime.
    // NOTE(review): nothing in the visible code constructs this variant —
    // confirm whether a caller elsewhere relies on it.
    #[error("Cache entry expired")]
    Expired,
    /// The entry is malformed; the payload is a human-readable description.
    // NOTE(review): not constructed in the visible code either.
    #[error("Invalid cache entry: {0}")]
    InvalidEntry(String),
    /// A `serde_json` (de)serialization failure.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
    /// An underlying I/O failure.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
24
25#[derive(Debug, Clone, Serialize, Deserialize, Default)]
26pub struct CacheControl {
27 pub max_age: Option<u64>,
28 pub no_cache: bool,
29 pub no_store: bool,
30 pub must_revalidate: bool,
31 pub private: bool,
32 pub public: bool,
33 pub proxy_revalidate: bool,
34 pub s_maxage: Option<u64>,
35}
36
37impl CacheControl {
38 pub fn parse(header: &str) -> Self {
39 let mut control = Self::default();
40
41 for directive in header.split(',') {
42 let directive = directive.trim();
43
44 match directive {
45 "no-cache" => control.no_cache = true,
46 "no-store" => control.no_store = true,
47 "must-revalidate" => control.must_revalidate = true,
48 "private" => control.private = true,
49 "public" => control.public = true,
50 "proxy-revalidate" => control.proxy_revalidate = true,
51 _ => {
52 if let Some(max_age) = directive.strip_prefix("max-age=") {
53 if let Ok(seconds) = max_age.parse::<u64>() {
54 control.max_age = Some(seconds);
55 }
56 } else if let Some(s_maxage) = directive.strip_prefix("s-maxage=")
57 && let Ok(seconds) = s_maxage.parse::<u64>()
58 {
59 control.s_maxage = Some(seconds);
60 }
61 }
62 }
63 }
64
65 control
66 }
67
68 pub fn is_cacheable(&self) -> bool {
69 !self.no_store && !self.no_cache
70 }
71
72 pub fn is_fresh(&self, stored_time: SystemTime) -> bool {
73 if self.must_revalidate {
74 return false;
75 }
76
77 let max_age = self.s_maxage.or(self.max_age);
78 if let Some(max_age) = max_age {
79 let elapsed = stored_time.elapsed().unwrap_or_default().as_secs();
80 elapsed < max_age
81 } else {
82 true
83 }
84 }
85}
86
/// Metadata recorded for one cached HTTP response.
///
/// Only header-derived metadata is kept here; no field holds a response body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// URL the entry belongs to; `HttpCache` uses it as the map key.
    pub url: String,
    /// Value of the `ETag` response header, if one was recorded.
    pub etag: Option<String>,
    /// Parsed `Last-Modified` response header, if one was recorded.
    pub last_modified: Option<SystemTime>,
    /// Parsed `Cache-Control` directives (all defaults when the header was absent).
    pub cache_control: CacheControl,
    /// When the entry was created; `CacheEntry::new` sets it to the current
    /// time and freshness is measured from it.
    pub stored_at: SystemTime,
    /// `Content-Length` header value, when it parsed as an integer.
    pub content_length: Option<u64>,
    /// `Content-Type` header value, if one was recorded.
    pub content_type: Option<String>,
    /// Response headers not captured by one of the typed fields above,
    /// keyed by header name.
    pub headers: HashMap<String, String>,
    /// Raw `Vary` header value, if one was recorded.
    pub vary: Option<String>,
}
99
100impl CacheEntry {
101 pub fn new(url: String) -> Self {
102 Self {
103 url,
104 etag: None,
105 last_modified: None,
106 cache_control: CacheControl::default(),
107 stored_at: SystemTime::now(),
108 content_length: None,
109 content_type: None,
110 headers: HashMap::new(),
111 vary: None,
112 }
113 }
114
115 pub fn is_valid(&self) -> bool {
116 if !self.cache_control.is_fresh(self.stored_at) {
117 return false;
118 }
119
120 let max_age = self.cache_control.max_age.unwrap_or(86400);
121 let elapsed = self.stored_at.elapsed().unwrap_or_default().as_secs();
122 elapsed < max_age
123 }
124
125 pub fn cache_key(&self) -> String {
126 use std::collections::hash_map::DefaultHasher;
127 use std::hash::{Hash, Hasher};
128
129 let mut hasher = DefaultHasher::new();
130 self.url.hash(&mut hasher);
131
132 if let Some(vary) = &self.vary {
133 vary.hash(&mut hasher);
134 }
135
136 format!("cache_{}", hasher.finish())
137 }
138}
139
/// Validators to attach to a conditional (revalidation) request.
///
/// The `Default` value carries no validators at all. `PartialEq`/`Eq` are
/// derived so two sets of validators can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConditionalHeaders {
    /// Entity tag to send as `If-None-Match`.
    pub if_none_match: Option<String>,
    /// Timestamp to send as `If-Modified-Since`.
    pub if_modified_since: Option<SystemTime>,
}
145
146impl ConditionalHeaders {
147 pub fn from_cache_entry(entry: &CacheEntry) -> Self {
148 Self {
149 if_none_match: entry.etag.clone(),
150 if_modified_since: entry.last_modified,
151 }
152 }
153
154 pub fn to_header_map(&self) -> reqwest::header::HeaderMap {
155 let mut headers = reqwest::header::HeaderMap::new();
156
157 if let Some(if_none_match) = &self.if_none_match {
158 headers.insert(
159 reqwest::header::IF_NONE_MATCH,
160 if_none_match.parse().unwrap(),
161 );
162 }
163
164 if let Some(if_modified_since) = self.if_modified_since
165 && let Ok(since_str) = httpdate::fmt_http_date(if_modified_since)
166 {
167 headers.insert(
168 reqwest::header::IF_MODIFIED_SINCE,
169 since_str.parse().unwrap(),
170 );
171 }
172
173 headers
174 }
175}
176
/// Outcome of checking a stored entry, as returned by `HttpCache::validate`.
///
/// The type is a plain tag, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheValidation {
    /// The entry can be used as it is.
    Fresh,
    /// The entry is out of date and has to be revalidated before it is used.
    StaleNeedsValidation,
    /// The entry must not be used.
    Invalid,
}
183
/// In-memory store of [`CacheEntry`] values with a bounded number of entries.
#[derive(Debug)]
pub struct HttpCache {
    /// Stored entries, keyed by the entry's URL.
    entries: HashMap<String, CacheEntry>,
    /// Entry count at which `put` evicts the oldest entry before inserting.
    max_entries: usize,
}
189
190impl HttpCache {
191 pub fn new(max_entries: usize) -> Self {
192 Self {
193 entries: HashMap::new(),
194 max_entries,
195 }
196 }
197
198 pub fn get(&self, url: &str) -> Result<&CacheEntry, CacheError> {
199 self.entries.get(url).ok_or(CacheError::NotFound)
200 }
201
202 pub fn put(&mut self, entry: CacheEntry) -> Result<(), CacheError> {
203 if self.entries.len() >= self.max_entries {
204 self.evict_oldest();
205 }
206
207 self.entries.insert(entry.url.clone(), entry);
208 Ok(())
209 }
210
211 pub fn remove(&mut self, url: &str) -> Result<(), CacheError> {
212 self.entries.remove(url).ok_or(CacheError::NotFound)?;
213 Ok(())
214 }
215
216 pub fn validate(&self, url: &str) -> Result<CacheValidation, CacheError> {
217 let entry = self.get(url)?;
218
219 if !entry.is_valid() {
220 return Ok(CacheValidation::Invalid);
221 }
222
223 if entry.cache_control.is_fresh(entry.stored_at) {
224 Ok(CacheValidation::Fresh)
225 } else {
226 Ok(CacheValidation::StaleNeedsValidation)
227 }
228 }
229
230 pub fn get_conditional_headers(&self, url: &str) -> Result<ConditionalHeaders, CacheError> {
231 let entry = self.get(url)?;
232 Ok(ConditionalHeaders::from_cache_entry(entry))
233 }
234
235 pub fn update_from_response(
236 &mut self,
237 url: &str,
238 response: &reqwest::Response,
239 ) -> Result<(), CacheError> {
240 let mut entry = CacheEntry::new(url.to_string());
241
242 if let Some(etag) = response.headers().get(reqwest::header::ETAG) {
243 entry.etag = Some(etag.to_str().unwrap_or_default().to_string());
244 }
245
246 if let Some(last_modified) = response.headers().get(reqwest::header::LAST_MODIFIED)
247 && let Ok(parsed) =
248 httpdate::parse_http_date(last_modified.to_str().unwrap_or_default())
249 {
250 entry.last_modified = Some(parsed);
251 }
252
253 if let Some(cache_control) = response.headers().get(reqwest::header::CACHE_CONTROL) {
254 entry.cache_control = CacheControl::parse(cache_control.to_str().unwrap_or_default());
255 }
256
257 if let Some(content_length) = response.headers().get(reqwest::header::CONTENT_LENGTH)
258 && let Ok(length) = content_length.to_str().unwrap_or_default().parse::<u64>()
259 {
260 entry.content_length = Some(length);
261 }
262
263 if let Some(content_type) = response.headers().get(reqwest::header::CONTENT_TYPE) {
264 entry.content_type = Some(content_type.to_str().unwrap_or_default().to_string());
265 }
266
267 if let Some(vary) = response.headers().get(reqwest::header::VARY) {
268 entry.vary = Some(vary.to_str().unwrap_or_default().to_string());
269 }
270
271 for (name, value) in response.headers() {
272 let name_str = name.as_str();
273 if !matches!(
274 name,
275 &reqwest::header::ETAG
276 | &reqwest::header::LAST_MODIFIED
277 | &reqwest::header::CACHE_CONTROL
278 | &reqwest::header::CONTENT_LENGTH
279 | &reqwest::header::CONTENT_TYPE
280 | &reqwest::header::VARY
281 ) {
282 entry.headers.insert(
283 name_str.to_string(),
284 value.to_str().unwrap_or_default().to_string(),
285 );
286 }
287 }
288
289 self.put(entry)
290 }
291
292 pub fn clear(&mut self) {
293 self.entries.clear();
294 }
295
296 pub fn stats(&self) -> CacheStats {
297 CacheStats {
298 total_entries: self.entries.len(),
299 max_entries: self.max_entries,
300 fresh_entries: self
301 .entries
302 .values()
303 .filter(|e| e.is_valid() && e.cache_control.is_fresh(e.stored_at))
304 .count(),
305 stale_entries: self
306 .entries
307 .values()
308 .filter(|e| e.is_valid() && !e.cache_control.is_fresh(e.stored_at))
309 .count(),
310 expired_entries: self.entries.values().filter(|e| !e.is_valid()).count(),
311 }
312 }
313
314 fn evict_oldest(&mut self) {
315 if self.entries.is_empty() {
316 return;
317 }
318
319 let oldest_url = self
320 .entries
321 .iter()
322 .min_by_key(|(_, entry)| entry.stored_at)
323 .map(|(url, _)| url.clone());
324
325 if let Some(url) = oldest_url {
326 self.entries.remove(&url);
327 }
328 }
329}
330
/// Snapshot of cache occupancy, as returned by `HttpCache::stats`.
///
/// The `Default` value is all zeros. `PartialEq`/`Eq` are derived so a
/// snapshot can be compared against an expected value in one assertion.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub total_entries: usize,
    /// Configured capacity of the cache.
    pub max_entries: usize,
    /// Entries counted as fresh when the snapshot was taken.
    pub fresh_entries: usize,
    /// Entries counted as stale when the snapshot was taken.
    pub stale_entries: usize,
    /// Entries counted as expired when the snapshot was taken.
    pub expired_entries: usize,
}
339
340mod httpdate {
341 use std::time::{SystemTime, UNIX_EPOCH};
342
343 pub fn parse_http_date(date_str: &str) -> Result<SystemTime, Box<dyn std::error::Error>> {
344 let timestamp = chrono::DateTime::parse_from_rfc2822(date_str)?;
345 let system_time = UNIX_EPOCH + std::time::Duration::from_secs(timestamp.timestamp() as u64);
346 Ok(system_time)
347 }
348
349 pub fn fmt_http_date(time: SystemTime) -> Result<String, Box<dyn std::error::Error>> {
350 let datetime = chrono::DateTime::<chrono::Utc>::from(time);
351 Ok(datetime.to_rfc2822())
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an entry for `url` whose `stored_at` lies `age_secs` seconds in
    /// the past, so tests never depend on how fast the clock ticks.
    fn entry_aged(url: &str, age_secs: u64) -> CacheEntry {
        let mut entry = CacheEntry::new(url.to_string());
        entry.stored_at = SystemTime::now() - Duration::from_secs(age_secs);
        entry
    }

    #[test]
    fn test_cache_control_parsing() {
        let control = CacheControl::parse("max-age=3600, public");
        assert_eq!(control.max_age, Some(3600));
        assert!(control.public);
        assert!(!control.no_cache);

        let control = CacheControl::parse("no-cache, must-revalidate");
        assert!(control.no_cache);
        assert!(control.must_revalidate);
        assert!(control.max_age.is_none());
    }

    #[test]
    fn test_cache_control_freshness() {
        let control = CacheControl {
            max_age: Some(3600),
            ..Default::default()
        };

        let past_time = SystemTime::now() - Duration::from_secs(1800);
        assert!(control.is_fresh(past_time));

        let too_old_time = SystemTime::now() - Duration::from_secs(7200);
        assert!(!control.is_fresh(too_old_time));
    }

    #[test]
    fn test_cache_entry_validation() {
        let mut entry = entry_aged("https://example.com", 1800);
        entry.cache_control.max_age = Some(3600);
        assert!(entry.is_valid());

        entry.stored_at = SystemTime::now() - Duration::from_secs(7200);
        assert!(!entry.is_valid());
    }

    #[test]
    fn test_http_cache_operations() {
        let mut cache = HttpCache::new(10);

        let mut entry = CacheEntry::new("https://example.com".to_string());
        entry.etag = Some("\"12345\"".to_string());
        entry.cache_control.max_age = Some(3600);

        cache.put(entry.clone()).unwrap();

        let retrieved = cache.get("https://example.com").unwrap();
        assert_eq!(retrieved.etag, entry.etag);

        let validation = cache.validate("https://example.com").unwrap();
        assert!(matches!(validation, CacheValidation::Fresh));

        let stats = cache.stats();
        assert_eq!(stats.total_entries, 1);
        assert_eq!(stats.fresh_entries, 1);
    }

    #[test]
    fn test_conditional_headers() {
        let mut entry = CacheEntry::new("https://example.com".to_string());
        entry.etag = Some("\"12345\"".to_string());
        entry.last_modified = Some(SystemTime::now() - Duration::from_secs(3600));

        let headers = ConditionalHeaders::from_cache_entry(&entry);
        assert_eq!(headers.if_none_match, Some("\"12345\"".to_string()));
        assert!(headers.if_modified_since.is_some());
    }

    #[test]
    fn test_cache_eviction() {
        let mut cache = HttpCache::new(2);

        // Ages are pinned explicitly: three back-to-back `SystemTime::now()`
        // calls are not guaranteed to be strictly increasing, which would make
        // "the oldest entry" ambiguous and the test flaky.
        cache.put(entry_aged("https://example1.com", 30)).unwrap();
        cache.put(entry_aged("https://example2.com", 20)).unwrap();
        cache.put(entry_aged("https://example3.com", 10)).unwrap();

        assert!(cache.get("https://example1.com").is_err());
        assert!(cache.get("https://example2.com").is_ok());
        assert!(cache.get("https://example3.com").is_ok());
        assert_eq!(cache.stats().total_entries, 2);
    }

    #[test]
    fn test_remove_and_clear() {
        let mut cache = HttpCache::new(4);
        cache.put(entry_aged("https://a.example", 2)).unwrap();
        cache.put(entry_aged("https://b.example", 1)).unwrap();

        cache.remove("https://a.example").unwrap();
        assert!(matches!(
            cache.remove("https://a.example"),
            Err(CacheError::NotFound)
        ));
        assert!(matches!(
            cache.validate("https://a.example"),
            Err(CacheError::NotFound)
        ));

        cache.clear();
        assert_eq!(cache.stats().total_entries, 0);
        assert!(cache.get("https://b.example").is_err());
    }
}