1use async_trait::async_trait;
14use reqwest::StatusCode;
15use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
16use tokio::fs;
17use std::path::PathBuf;
18use tracing::{debug, info, trace, warn};
19
20use spider_util::error::SpiderError;
21use crate::middleware::{Middleware, MiddlewareAction};
22use spider_util::request::Request;
23use spider_util::response::Response;
24use bytes::Bytes;
25use serde::{Deserialize, Deserializer, Serialize, Serializer};
26use url::Url;
27
28fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
29where
30 S: Serializer,
31{
32 let mut map = std::collections::HashMap::<String, String>::new();
33 for (name, value) in headers.iter() {
34 map.insert(
35 name.to_string(),
36 value.to_str().unwrap_or_default().to_string(),
37 );
38 }
39 map.serialize(serializer)
40}
41
42fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
43where
44 D: Deserializer<'de>,
45{
46 let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
47 let mut headers = HeaderMap::new();
48 for (name, value) in map {
49 if let (Ok(header_name), Ok(header_value)) =
50 (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
51 {
52 headers.insert(header_name, header_value);
53 } else {
54 warn!("Failed to parse header: {} = {}", name, value);
55 }
56 }
57 Ok(headers)
58}
59
60fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
61where
62 S: Serializer,
63{
64 status.as_u16().serialize(serializer)
65}
66
67fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
68where
69 D: Deserializer<'de>,
70{
71 let status_u16 = u16::deserialize(deserializer)?;
72 StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
73}
74
75fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
76where
77 S: Serializer,
78{
79 url.to_string().serialize(serializer)
80}
81
82fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
83where
84 D: Deserializer<'de>,
85{
86 let s = String::deserialize(deserializer)?;
87 Url::parse(&s).map_err(serde::de::Error::custom)
88}
89
/// Serializable snapshot of a [`Response`], as stored on disk by [`HttpCacheMiddleware`].
///
/// The `Url`, `StatusCode` and `HeaderMap` fields are (de)serialized through the
/// `serialize_*` / `deserialize_*` helper functions in this file. The response's `meta`
/// is not stored: converting back into a `Response` resets it to its default, and a cache
/// hit then replaces it with the current request's `meta`.
// NOTE: entries are encoded with bincode (see `process_response`). Its serde format
// (bincode 1.x) writes struct fields positionally in declaration order. Reordering,
// adding or removing fields therefore makes existing cache files fail to decode;
// `process_request` then deletes them and the request is fetched again.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedResponse {
    /// URL of the response (copied from `Response::url`).
    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
    url: Url,
    /// HTTP status, encoded via `serialize_statuscode` / `deserialize_statuscode`.
    #[serde(
        serialize_with = "serialize_statuscode",
        deserialize_with = "deserialize_statuscode"
    )]
    status: StatusCode,
    /// Response headers, encoded via `serialize_headermap` / `deserialize_headermap`.
    #[serde(
        serialize_with = "serialize_headermap",
        deserialize_with = "deserialize_headermap"
    )]
    headers: HeaderMap,
    /// Raw response body bytes (copied out of `Response::body`).
    body: Vec<u8>,
    /// URL of the originating request (copied from `Response::request_url`).
    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
    request_url: Url,
}
109
110impl From<Response> for CachedResponse {
111 fn from(response: Response) -> Self {
112 CachedResponse {
113 url: response.url,
114 status: response.status,
115 headers: response.headers,
116 body: response.body.to_vec(),
117 request_url: response.request_url,
118 }
119 }
120}
121
122impl From<CachedResponse> for Response {
123 fn from(cached_response: CachedResponse) -> Self {
124 Response {
125 url: cached_response.url,
126 status: cached_response.status,
127 headers: cached_response.headers,
128 body: Bytes::from(cached_response.body),
129 request_url: cached_response.request_url,
130 meta: Default::default(),
131 cached: true,
132 }
133 }
134}
135
136#[derive(Default)]
138pub struct HttpCacheMiddlewareBuilder {
139 cache_dir: Option<PathBuf>,
140}
141
142impl HttpCacheMiddlewareBuilder {
143 pub fn cache_dir(mut self, path: PathBuf) -> Self {
145 self.cache_dir = Some(path);
146 self
147 }
148
149 pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
152 let cache_dir = if let Some(path) = self.cache_dir {
153 path
154 } else {
155 dirs::cache_dir()
156 .ok_or_else(|| {
157 SpiderError::ConfigurationError(
158 "Could not determine cache directory".to_string(),
159 )
160 })?
161 .join("spider-lib")
162 .join("http_cache")
163 };
164
165 std::fs::create_dir_all(&cache_dir)?;
166
167 let middleware = HttpCacheMiddleware { cache_dir };
168 info!(
169 "Initializing HttpCacheMiddleware with config: {:?}",
170 middleware
171 );
172
173 Ok(middleware)
174 }
175}
176
177#[derive(Debug)]
178pub struct HttpCacheMiddleware {
179 cache_dir: PathBuf,
180}
181
182impl HttpCacheMiddleware {
183 pub fn builder() -> HttpCacheMiddlewareBuilder {
185 HttpCacheMiddlewareBuilder::default()
186 }
187
188 fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
189 self.cache_dir.join(format!("{}.bin", fingerprint))
190 }
191}
192
193#[async_trait]
194impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
195 fn name(&self) -> &str {
196 "HttpCacheMiddleware"
197 }
198
199 async fn process_request(
200 &mut self,
201 _client: &C,
202 request: Request,
203 ) -> Result<MiddlewareAction<Request>, SpiderError> {
204 let fingerprint = request.fingerprint();
205 let cache_file_path = self.get_cache_file_path(&fingerprint);
206
207 trace!(
208 "Checking cache for request: {} (fingerprint: {})",
209 request.url, fingerprint
210 );
211 if fs::metadata(&cache_file_path).await.is_ok() {
212 debug!("Cache hit for request: {}", request.url);
213 match fs::read(&cache_file_path).await {
214 Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
215 Ok(cached_resp) => {
216 trace!(
217 "Successfully deserialized cached response for {}",
218 request.url
219 );
220 let mut response: Response = cached_resp.into();
221 response.meta = request.meta;
222 debug!("Returning cached response for {}", response.url);
223 return Ok(MiddlewareAction::ReturnResponse(response));
224 }
225 Err(e) => {
226 warn!(
227 "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
228 cache_file_path.display(),
229 e
230 );
231 fs::remove_file(&cache_file_path).await.ok();
232 }
233 },
234 Err(e) => {
235 warn!(
236 "Failed to read cache file {}: {}. Deleting invalid cache file.",
237 cache_file_path.display(),
238 e
239 );
240 fs::remove_file(&cache_file_path).await.ok();
241 }
242 }
243 } else {
244 trace!(
245 "Cache miss for request: {} (no cache file found)",
246 request.url
247 );
248 }
249
250 trace!("Continuing request to downloader: {}", request.url);
251 Ok(MiddlewareAction::Continue(request))
252 }
253
254 async fn process_response(
255 &mut self,
256 response: Response,
257 ) -> Result<MiddlewareAction<Response>, SpiderError> {
258 trace!(
259 "Processing response for caching: {} with status: {}",
260 response.url, response.status
261 );
262
263 if response.status.is_success() {
265 let original_request_fingerprint = response.request_from_response().fingerprint();
266 let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
267
268 trace!(
269 "Serializing response for caching to: {}",
270 cache_file_path.display()
271 );
272 let cached_response: CachedResponse = response.clone().into();
273 match bincode::serialize(&cached_response) {
274 Ok(serialized_bytes) => {
275 let bytes_count = serialized_bytes.len();
276 trace!(
277 "Writing {} bytes to cache file: {}",
278 bytes_count,
279 cache_file_path.display()
280 );
281 fs::write(&cache_file_path, serialized_bytes)
282 .await
283 .map_err(|e| SpiderError::IoError(e.to_string()))?;
284 debug!(
285 "Cached response for {} ({} bytes)",
286 response.url, bytes_count
287 );
288 }
289 Err(e) => {
290 warn!(
291 "Failed to serialize response for caching {}: {}",
292 response.url, e
293 );
294 }
295 }
296 } else {
297 trace!(
298 "Response status {} is not successful, skipping cache for: {}",
299 response.status, response.url
300 );
301 }
302
303 trace!("Continuing response: {}", response.url);
304 Ok(MiddlewareAction::Continue(response))
305 }
306}