Skip to main content

pulith_fetch/cache/
http_cache.rs

1//! HTTP caching support for conditional requests.
2//!
3//! This module provides types and functions for implementing HTTP caching
4//! based on ETags, Last-Modified timestamps, and cache control directives.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::SystemTime;
9use thiserror::Error;
10
/// Errors reported by the HTTP cache types in this module.
///
/// `Display` text comes from the `#[error(...)]` attributes; the
/// `#[from]` attributes let `?` convert the wrapped error types.
#[derive(Error, Debug)]
pub enum CacheError {
    /// No entry is stored for the requested URL (returned by
    /// `HttpCache::get` and `HttpCache::remove`).
    #[error("Cache entry not found")]
    NotFound,
    /// An entry exists but has expired. Not constructed anywhere in this
    /// file; presumably used by callers.
    #[error("Cache entry expired")]
    Expired,
    /// An entry is malformed; the payload is a human-readable reason.
    /// Not constructed anywhere in this file.
    #[error("Invalid cache entry: {0}")]
    InvalidEntry(String),
    /// Wraps a `serde_json::Error`.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
    /// Wraps a `std::io::Error`.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
24
/// Parsed representation of a `Cache-Control` header value.
///
/// `Default` yields a value with every flag `false` and every lifetime
/// `None`, which is what `CacheControl::parse` starts from.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CacheControl {
    /// Value of the `max-age` directive, in seconds.
    pub max_age: Option<u64>,
    /// Set when the `no-cache` directive is present.
    pub no_cache: bool,
    /// Set when the `no-store` directive is present.
    pub no_store: bool,
    /// Set when the `must-revalidate` directive is present.
    pub must_revalidate: bool,
    /// Set when the `private` directive is present.
    pub private: bool,
    /// Set when the `public` directive is present.
    pub public: bool,
    /// Set when the `proxy-revalidate` directive is present.
    pub proxy_revalidate: bool,
    /// Value of the `s-maxage` directive, in seconds.
    pub s_maxage: Option<u64>,
}
36
37impl CacheControl {
38    pub fn parse(header: &str) -> Self {
39        let mut control = Self::default();
40
41        for directive in header.split(',') {
42            let directive = directive.trim();
43
44            match directive {
45                "no-cache" => control.no_cache = true,
46                "no-store" => control.no_store = true,
47                "must-revalidate" => control.must_revalidate = true,
48                "private" => control.private = true,
49                "public" => control.public = true,
50                "proxy-revalidate" => control.proxy_revalidate = true,
51                _ => {
52                    if let Some(max_age) = directive.strip_prefix("max-age=") {
53                        if let Ok(seconds) = max_age.parse::<u64>() {
54                            control.max_age = Some(seconds);
55                        }
56                    } else if let Some(s_maxage) = directive.strip_prefix("s-maxage=")
57                        && let Ok(seconds) = s_maxage.parse::<u64>()
58                    {
59                        control.s_maxage = Some(seconds);
60                    }
61                }
62            }
63        }
64
65        control
66    }
67
68    pub fn is_cacheable(&self) -> bool {
69        !self.no_store && !self.no_cache
70    }
71
72    pub fn is_fresh(&self, stored_time: SystemTime) -> bool {
73        if self.must_revalidate {
74            return false;
75        }
76
77        let max_age = self.s_maxage.or(self.max_age);
78        if let Some(max_age) = max_age {
79            let elapsed = stored_time.elapsed().unwrap_or_default().as_secs();
80            elapsed < max_age
81        } else {
82            true
83        }
84    }
85}
86
/// Metadata remembered about one cached HTTP response.
///
/// Only header-derived metadata is kept here; no field holds the response
/// body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// URL the entry belongs to; `HttpCache` uses it as the map key.
    pub url: String,
    /// Value of the response's `ETag` header, if any.
    pub etag: Option<String>,
    /// Parsed value of the response's `Last-Modified` header, if any.
    pub last_modified: Option<SystemTime>,
    /// Parsed `Cache-Control` directives (defaults when the header is absent).
    pub cache_control: CacheControl,
    /// When the entry was created; `CacheEntry::new` sets it to the current time.
    pub stored_at: SystemTime,
    /// Parsed value of the `Content-Length` header, if any.
    pub content_length: Option<u64>,
    /// Value of the `Content-Type` header, if any.
    pub content_type: Option<String>,
    /// Remaining response headers that are not split out into their own fields.
    pub headers: HashMap<String, String>,
    /// Raw value of the `Vary` header, if any; mixed into `cache_key`.
    pub vary: Option<String>,
}
99
100impl CacheEntry {
101    pub fn new(url: String) -> Self {
102        Self {
103            url,
104            etag: None,
105            last_modified: None,
106            cache_control: CacheControl::default(),
107            stored_at: SystemTime::now(),
108            content_length: None,
109            content_type: None,
110            headers: HashMap::new(),
111            vary: None,
112        }
113    }
114
115    pub fn is_valid(&self) -> bool {
116        if !self.cache_control.is_fresh(self.stored_at) {
117            return false;
118        }
119
120        let max_age = self.cache_control.max_age.unwrap_or(86400);
121        let elapsed = self.stored_at.elapsed().unwrap_or_default().as_secs();
122        elapsed < max_age
123    }
124
125    pub fn cache_key(&self) -> String {
126        use std::collections::hash_map::DefaultHasher;
127        use std::hash::{Hash, Hasher};
128
129        let mut hasher = DefaultHasher::new();
130        self.url.hash(&mut hasher);
131
132        if let Some(vary) = &self.vary {
133            vary.hash(&mut hasher);
134        }
135
136        format!("cache_{}", hasher.finish())
137    }
138}
139
140#[derive(Debug, Clone)]
141pub struct ConditionalHeaders {
142    pub if_none_match: Option<String>,
143    pub if_modified_since: Option<SystemTime>,
144}
145
146impl ConditionalHeaders {
147    pub fn from_cache_entry(entry: &CacheEntry) -> Self {
148        Self {
149            if_none_match: entry.etag.clone(),
150            if_modified_since: entry.last_modified,
151        }
152    }
153
154    pub fn to_header_map(&self) -> reqwest::header::HeaderMap {
155        let mut headers = reqwest::header::HeaderMap::new();
156
157        if let Some(if_none_match) = &self.if_none_match {
158            headers.insert(
159                reqwest::header::IF_NONE_MATCH,
160                if_none_match.parse().unwrap(),
161            );
162        }
163
164        if let Some(if_modified_since) = self.if_modified_since
165            && let Ok(since_str) = httpdate::fmt_http_date(if_modified_since)
166        {
167            headers.insert(
168                reqwest::header::IF_MODIFIED_SINCE,
169                since_str.parse().unwrap(),
170            );
171        }
172
173        headers
174    }
175}
176
/// Outcome reported by `HttpCache::validate` for a stored entry.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly
/// (`assert_eq!`, `==`) instead of only through pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheValidation {
    /// The entry can be used without contacting the origin.
    Fresh,
    /// The entry should be revalidated with a conditional request before use.
    StaleNeedsValidation,
    /// The entry should not be used.
    Invalid,
}
183
184#[derive(Debug)]
185pub struct HttpCache {
186    entries: HashMap<String, CacheEntry>,
187    max_entries: usize,
188}
189
190impl HttpCache {
191    pub fn new(max_entries: usize) -> Self {
192        Self {
193            entries: HashMap::new(),
194            max_entries,
195        }
196    }
197
198    pub fn get(&self, url: &str) -> Result<&CacheEntry, CacheError> {
199        self.entries.get(url).ok_or(CacheError::NotFound)
200    }
201
202    pub fn put(&mut self, entry: CacheEntry) -> Result<(), CacheError> {
203        if self.entries.len() >= self.max_entries {
204            self.evict_oldest();
205        }
206
207        self.entries.insert(entry.url.clone(), entry);
208        Ok(())
209    }
210
211    pub fn remove(&mut self, url: &str) -> Result<(), CacheError> {
212        self.entries.remove(url).ok_or(CacheError::NotFound)?;
213        Ok(())
214    }
215
216    pub fn validate(&self, url: &str) -> Result<CacheValidation, CacheError> {
217        let entry = self.get(url)?;
218
219        if !entry.is_valid() {
220            return Ok(CacheValidation::Invalid);
221        }
222
223        if entry.cache_control.is_fresh(entry.stored_at) {
224            Ok(CacheValidation::Fresh)
225        } else {
226            Ok(CacheValidation::StaleNeedsValidation)
227        }
228    }
229
230    pub fn get_conditional_headers(&self, url: &str) -> Result<ConditionalHeaders, CacheError> {
231        let entry = self.get(url)?;
232        Ok(ConditionalHeaders::from_cache_entry(entry))
233    }
234
235    pub fn update_from_response(
236        &mut self,
237        url: &str,
238        response: &reqwest::Response,
239    ) -> Result<(), CacheError> {
240        let mut entry = CacheEntry::new(url.to_string());
241
242        if let Some(etag) = response.headers().get(reqwest::header::ETAG) {
243            entry.etag = Some(etag.to_str().unwrap_or_default().to_string());
244        }
245
246        if let Some(last_modified) = response.headers().get(reqwest::header::LAST_MODIFIED)
247            && let Ok(parsed) =
248                httpdate::parse_http_date(last_modified.to_str().unwrap_or_default())
249        {
250            entry.last_modified = Some(parsed);
251        }
252
253        if let Some(cache_control) = response.headers().get(reqwest::header::CACHE_CONTROL) {
254            entry.cache_control = CacheControl::parse(cache_control.to_str().unwrap_or_default());
255        }
256
257        if let Some(content_length) = response.headers().get(reqwest::header::CONTENT_LENGTH)
258            && let Ok(length) = content_length.to_str().unwrap_or_default().parse::<u64>()
259        {
260            entry.content_length = Some(length);
261        }
262
263        if let Some(content_type) = response.headers().get(reqwest::header::CONTENT_TYPE) {
264            entry.content_type = Some(content_type.to_str().unwrap_or_default().to_string());
265        }
266
267        if let Some(vary) = response.headers().get(reqwest::header::VARY) {
268            entry.vary = Some(vary.to_str().unwrap_or_default().to_string());
269        }
270
271        for (name, value) in response.headers() {
272            let name_str = name.as_str();
273            if !matches!(
274                name,
275                &reqwest::header::ETAG
276                    | &reqwest::header::LAST_MODIFIED
277                    | &reqwest::header::CACHE_CONTROL
278                    | &reqwest::header::CONTENT_LENGTH
279                    | &reqwest::header::CONTENT_TYPE
280                    | &reqwest::header::VARY
281            ) {
282                entry.headers.insert(
283                    name_str.to_string(),
284                    value.to_str().unwrap_or_default().to_string(),
285                );
286            }
287        }
288
289        self.put(entry)
290    }
291
292    pub fn clear(&mut self) {
293        self.entries.clear();
294    }
295
296    pub fn stats(&self) -> CacheStats {
297        CacheStats {
298            total_entries: self.entries.len(),
299            max_entries: self.max_entries,
300            fresh_entries: self
301                .entries
302                .values()
303                .filter(|e| e.is_valid() && e.cache_control.is_fresh(e.stored_at))
304                .count(),
305            stale_entries: self
306                .entries
307                .values()
308                .filter(|e| e.is_valid() && !e.cache_control.is_fresh(e.stored_at))
309                .count(),
310            expired_entries: self.entries.values().filter(|e| !e.is_valid()).count(),
311        }
312    }
313
314    fn evict_oldest(&mut self) {
315        if self.entries.is_empty() {
316            return;
317        }
318
319        let oldest_url = self
320            .entries
321            .iter()
322            .min_by_key(|(_, entry)| entry.stored_at)
323            .map(|(url, _)| url.clone());
324
325        if let Some(url) = oldest_url {
326            self.entries.remove(&url);
327        }
328    }
329}
330
/// Snapshot of cache occupancy returned by `HttpCache::stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly and
/// `Default` so an all-zero snapshot is available.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub total_entries: usize,
    /// Configured capacity of the cache.
    pub max_entries: usize,
    /// Number of stored entries counted as fresh.
    pub fresh_entries: usize,
    /// Number of stored entries counted as stale.
    pub stale_entries: usize,
    /// Number of stored entries counted as expired.
    pub expired_entries: usize,
}
339
340mod httpdate {
341    use std::time::{SystemTime, UNIX_EPOCH};
342
343    pub fn parse_http_date(date_str: &str) -> Result<SystemTime, Box<dyn std::error::Error>> {
344        let timestamp = chrono::DateTime::parse_from_rfc2822(date_str)?;
345        let system_time = UNIX_EPOCH + std::time::Duration::from_secs(timestamp.timestamp() as u64);
346        Ok(system_time)
347    }
348
349    pub fn fmt_http_date(time: SystemTime) -> Result<String, Box<dyn std::error::Error>> {
350        let datetime = chrono::DateTime::<chrono::Utc>::from(time);
351        Ok(datetime.to_rfc2822())
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Flags and lifetimes are picked out of a comma-separated header value.
    #[test]
    fn test_cache_control_parsing() {
        let control = CacheControl::parse("max-age=3600, public");
        assert_eq!(control.max_age, Some(3600));
        assert!(control.public);
        assert!(!control.no_cache);

        let control = CacheControl::parse("no-cache, must-revalidate");
        assert!(control.no_cache);
        assert!(control.must_revalidate);
        assert!(control.max_age.is_none());
    }

    /// A response is fresh only while it is younger than `max-age`.
    #[test]
    fn test_cache_control_freshness() {
        let control = CacheControl {
            max_age: Some(3600),
            ..Default::default()
        };

        let past_time = SystemTime::now() - Duration::from_secs(1800);
        assert!(control.is_fresh(past_time));

        let too_old_time = SystemTime::now() - Duration::from_secs(7200);
        assert!(!control.is_fresh(too_old_time));
    }

    /// Entry validity follows the age of the entry relative to `max-age`.
    #[test]
    fn test_cache_entry_validation() {
        let mut entry = CacheEntry::new("https://example.com".to_string());
        entry.cache_control.max_age = Some(3600);
        entry.stored_at = SystemTime::now() - Duration::from_secs(1800);

        assert!(entry.is_valid());

        entry.stored_at = SystemTime::now() - Duration::from_secs(7200);
        assert!(!entry.is_valid());
    }

    /// Round trip through `put`, `get`, `validate` and `stats`.
    #[test]
    fn test_http_cache_operations() {
        let mut cache = HttpCache::new(10);

        let mut entry = CacheEntry::new("https://example.com".to_string());
        entry.etag = Some("\"12345\"".to_string());
        entry.cache_control.max_age = Some(3600);

        cache.put(entry.clone()).unwrap();

        let retrieved = cache.get("https://example.com").unwrap();
        assert_eq!(retrieved.etag, entry.etag);

        let validation = cache.validate("https://example.com").unwrap();
        assert!(matches!(validation, CacheValidation::Fresh));

        let stats = cache.stats();
        assert_eq!(stats.total_entries, 1);
        assert_eq!(stats.fresh_entries, 1);
    }

    /// Validators are copied from the entry into the conditional headers.
    #[test]
    fn test_conditional_headers() {
        let mut entry = CacheEntry::new("https://example.com".to_string());
        entry.etag = Some("\"12345\"".to_string());
        entry.last_modified = Some(SystemTime::now() - Duration::from_secs(3600));

        let headers = ConditionalHeaders::from_cache_entry(&entry);
        assert_eq!(headers.if_none_match, Some("\"12345\"".to_string()));
        assert!(headers.if_modified_since.is_some());
    }

    /// Inserting beyond capacity evicts the entry with the oldest `stored_at`.
    #[test]
    fn test_cache_eviction() {
        let mut cache = HttpCache::new(2);

        // Give every entry a distinct, explicit timestamp. Relying on three
        // back-to-back `SystemTime::now()` calls being strictly increasing
        // makes the evicted entry depend on clock resolution and on hash-map
        // iteration order.
        let now = SystemTime::now();
        let mut entry1 = CacheEntry::new("https://example1.com".to_string());
        entry1.stored_at = now - Duration::from_secs(30);
        let mut entry2 = CacheEntry::new("https://example2.com".to_string());
        entry2.stored_at = now - Duration::from_secs(20);
        let mut entry3 = CacheEntry::new("https://example3.com".to_string());
        entry3.stored_at = now - Duration::from_secs(10);

        cache.put(entry1).unwrap();
        cache.put(entry2).unwrap();
        cache.put(entry3).unwrap();

        assert!(cache.get("https://example1.com").is_err());
        assert!(cache.get("https://example2.com").is_ok());
        assert!(cache.get("https://example3.com").is_ok());
    }
}
447}