// spider_middleware/http_cache.rs

//! HTTP cache middleware for caching web responses.
//!
//! This module provides [`HttpCacheMiddleware`], which intercepts HTTP requests and
//! responses to implement a disk-backed cache. It stores successful (2xx) HTTP responses
//! in a local directory and, for subsequent identical requests, serves the cached response
//! instead of making a new network request. This can significantly reduce network traffic,
//! improve crawling speed, and enable offline processing or replay of crawls.
//!
//! The cache uses request fingerprints to identify unique requests and associates each
//! one with its cached response file. Responses are serialized and deserialized
//! using `bincode`.
12
13use async_trait::async_trait;
14use reqwest::StatusCode;
15use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
16use tokio::fs;
17use std::path::PathBuf;
18use tracing::{debug, info, trace, warn};
19
20use spider_util::error::SpiderError;
21use crate::middleware::{Middleware, MiddlewareAction};
22use spider_util::request::Request;
23use spider_util::response::Response;
24use bytes::Bytes;
25use serde::{Deserialize, Deserializer, Serialize, Serializer};
26use url::Url;
27
28fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
29where
30    S: Serializer,
31{
32    let mut map = std::collections::HashMap::<String, String>::new();
33    for (name, value) in headers.iter() {
34        map.insert(
35            name.to_string(),
36            value.to_str().unwrap_or_default().to_string(),
37        );
38    }
39    map.serialize(serializer)
40}
41
42fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
43where
44    D: Deserializer<'de>,
45{
46    let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
47    let mut headers = HeaderMap::new();
48    for (name, value) in map {
49        if let (Ok(header_name), Ok(header_value)) =
50            (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
51        {
52            headers.insert(header_name, header_value);
53        } else {
54            warn!("Failed to parse header: {} = {}", name, value);
55        }
56    }
57    Ok(headers)
58}
59
60fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
61where
62    S: Serializer,
63{
64    status.as_u16().serialize(serializer)
65}
66
67fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
68where
69    D: Deserializer<'de>,
70{
71    let status_u16 = u16::deserialize(deserializer)?;
72    StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
73}
74
75fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
76where
77    S: Serializer,
78{
79    url.to_string().serialize(serializer)
80}
81
82fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
83where
84    D: Deserializer<'de>,
85{
86    let s = String::deserialize(deserializer)?;
87    Url::parse(&s).map_err(serde::de::Error::custom)
88}
89
90/// Represents a cached response, including enough information to reconstruct a `Response` object.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92struct CachedResponse {
93    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
94    url: Url,
95    #[serde(
96        serialize_with = "serialize_statuscode",
97        deserialize_with = "deserialize_statuscode"
98    )]
99    status: StatusCode,
100    #[serde(
101        serialize_with = "serialize_headermap",
102        deserialize_with = "deserialize_headermap"
103    )]
104    headers: HeaderMap,
105    body: Vec<u8>,
106    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
107    request_url: Url,
108}
109
110impl From<Response> for CachedResponse {
111    fn from(response: Response) -> Self {
112        CachedResponse {
113            url: response.url,
114            status: response.status,
115            headers: response.headers,
116            body: response.body.to_vec(),
117            request_url: response.request_url,
118        }
119    }
120}
121
122impl From<CachedResponse> for Response {
123    fn from(cached_response: CachedResponse) -> Self {
124        Response {
125            url: cached_response.url,
126            status: cached_response.status,
127            headers: cached_response.headers,
128            body: Bytes::from(cached_response.body),
129            request_url: cached_response.request_url,
130            meta: Default::default(),
131            cached: true,
132        }
133    }
134}
135
136/// Builder for `HttpCacheMiddleware`.
137#[derive(Default)]
138pub struct HttpCacheMiddlewareBuilder {
139    cache_dir: Option<PathBuf>,
140}
141
142impl HttpCacheMiddlewareBuilder {
143    /// Sets the directory where cache files will be stored.
144    pub fn cache_dir(mut self, path: PathBuf) -> Self {
145        self.cache_dir = Some(path);
146        self
147    }
148
149    /// Builds the `HttpCacheMiddleware`.
150    /// This can fail if the cache directory cannot be created or determined.
151    pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
152        let cache_dir = if let Some(path) = self.cache_dir {
153            path
154        } else {
155            dirs::cache_dir()
156                .ok_or_else(|| {
157                    SpiderError::ConfigurationError(
158                        "Could not determine cache directory".to_string(),
159                    )
160                })?
161                .join("spider-lib")
162                .join("http_cache")
163        };
164
165        std::fs::create_dir_all(&cache_dir)?;
166
167        let middleware = HttpCacheMiddleware { cache_dir };
168        info!(
169            "Initializing HttpCacheMiddleware with config: {:?}",
170            middleware
171        );
172
173        Ok(middleware)
174    }
175}
176
177#[derive(Debug)]
178pub struct HttpCacheMiddleware {
179    cache_dir: PathBuf,
180}
181
182impl HttpCacheMiddleware {
183    /// Creates a new `HttpCacheMiddlewareBuilder` to start building an `HttpCacheMiddleware`.
184    pub fn builder() -> HttpCacheMiddlewareBuilder {
185        HttpCacheMiddlewareBuilder::default()
186    }
187
188    fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
189        self.cache_dir.join(format!("{}.bin", fingerprint))
190    }
191}
192
193#[async_trait]
194impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
195    fn name(&self) -> &str {
196        "HttpCacheMiddleware"
197    }
198
199    async fn process_request(
200        &mut self,
201        _client: &C,
202        request: Request,
203    ) -> Result<MiddlewareAction<Request>, SpiderError> {
204        let fingerprint = request.fingerprint();
205        let cache_file_path = self.get_cache_file_path(&fingerprint);
206
207        trace!(
208            "Checking cache for request: {} (fingerprint: {})",
209            request.url, fingerprint
210        );
211        if fs::metadata(&cache_file_path).await.is_ok() {
212            debug!("Cache hit for request: {}", request.url);
213            match fs::read(&cache_file_path).await {
214                Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
215                    Ok(cached_resp) => {
216                        trace!(
217                            "Successfully deserialized cached response for {}",
218                            request.url
219                        );
220                        let mut response: Response = cached_resp.into();
221                        response.meta = request.meta;
222                        debug!("Returning cached response for {}", response.url);
223                        return Ok(MiddlewareAction::ReturnResponse(response));
224                    }
225                    Err(e) => {
226                        warn!(
227                            "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
228                            cache_file_path.display(),
229                            e
230                        );
231                        fs::remove_file(&cache_file_path).await.ok();
232                    }
233                },
234                Err(e) => {
235                    warn!(
236                        "Failed to read cache file {}: {}. Deleting invalid cache file.",
237                        cache_file_path.display(),
238                        e
239                    );
240                    fs::remove_file(&cache_file_path).await.ok();
241                }
242            }
243        } else {
244            trace!(
245                "Cache miss for request: {} (no cache file found)",
246                request.url
247            );
248        }
249
250        trace!("Continuing request to downloader: {}", request.url);
251        Ok(MiddlewareAction::Continue(request))
252    }
253
254    async fn process_response(
255        &mut self,
256        response: Response,
257    ) -> Result<MiddlewareAction<Response>, SpiderError> {
258        trace!(
259            "Processing response for caching: {} with status: {}",
260            response.url, response.status
261        );
262
263        // Only cache successful responses (e.g., 200 OK)
264        if response.status.is_success() {
265            let original_request_fingerprint = response.request_from_response().fingerprint();
266            let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
267
268            trace!(
269                "Serializing response for caching to: {}",
270                cache_file_path.display()
271            );
272            let cached_response: CachedResponse = response.clone().into();
273            match bincode::serialize(&cached_response) {
274                Ok(serialized_bytes) => {
275                    let bytes_count = serialized_bytes.len();
276                    trace!(
277                        "Writing {} bytes to cache file: {}",
278                        bytes_count,
279                        cache_file_path.display()
280                    );
281                    fs::write(&cache_file_path, serialized_bytes)
282                        .await
283                        .map_err(|e| SpiderError::IoError(e.to_string()))?;
284                    debug!(
285                        "Cached response for {} ({} bytes)",
286                        response.url, bytes_count
287                    );
288                }
289                Err(e) => {
290                    warn!(
291                        "Failed to serialize response for caching {}: {}",
292                        response.url, e
293                    );
294                }
295            }
296        } else {
297            trace!(
298                "Response status {} is not successful, skipping cache for: {}",
299                response.status, response.url
300            );
301        }
302
303        trace!("Continuing response: {}", response.url);
304        Ok(MiddlewareAction::Continue(response))
305    }
306}