http_cache_stream_reqwest/
lib.rs1#![warn(missing_docs)]
25#![warn(rust_2018_idioms)]
26#![warn(rust_2021_compatibility)]
27#![warn(clippy::missing_docs_in_private_items)]
28#![warn(rustdoc::broken_intra_doc_links)]
29
30use anyhow::Context as _;
31use anyhow::Result;
32use futures::FutureExt;
33use futures::future::BoxFuture;
34use http_body_util::BodyDataStream;
35pub use http_cache_stream::X_CACHE;
36pub use http_cache_stream::X_CACHE_DIGEST;
37pub use http_cache_stream::X_CACHE_LOOKUP;
38use http_cache_stream::http::Extensions;
39use http_cache_stream::http::Uri;
40pub use http_cache_stream::semantics;
41pub use http_cache_stream::semantics::CacheOptions;
42pub use http_cache_stream::storage;
43pub use http_cache_stream::storage::CacheStorage;
44use reqwest::Body;
45use reqwest::Request;
46use reqwest::Response;
47use reqwest::ResponseBuilderExt;
48use reqwest::header::HeaderMap;
49use reqwest_middleware::Next;
50
51struct MiddlewareRequest<'a, 'b> {
53 uri: Uri,
55 request: Request,
57 next: Next<'a>,
59 extensions: &'b mut Extensions,
61}
62
63impl http_cache_stream::Request<Body> for MiddlewareRequest<'_, '_> {
64 fn version(&self) -> http_cache_stream::http::Version {
65 self.request.version()
66 }
67
68 fn method(&self) -> &http_cache_stream::http::Method {
69 self.request.method()
70 }
71
72 fn uri(&self) -> &http_cache_stream::http::Uri {
73 &self.uri
74 }
75
76 fn headers(&self) -> &http_cache_stream::http::HeaderMap {
77 self.request.headers()
78 }
79
80 async fn send(
81 mut self,
82 headers: Option<http_cache_stream::http::HeaderMap>,
83 ) -> anyhow::Result<http_cache_stream::http::Response<Body>> {
84 if let Some(headers) = headers {
86 self.request.headers_mut().extend(headers);
87 }
88
89 let mut response = self.next.run(self.request, self.extensions).await?;
91
92 let mut builder =
94 http_cache_stream::http::Response::builder()
95 .version(response.version())
96 .status(response.status())
97 .url(response.url().as_str().parse().with_context(|| {
98 format!("invalid response URL `{url}`", url = response.url())
99 })?);
100
101 let headers = std::mem::take(response.headers_mut());
102 builder
103 .headers_mut()
104 .expect("should have headers")
105 .extend(headers);
106 builder
107 .body(response.into())
108 .context("failed to create response")
109 }
110}
111
112pub struct Cache<S>(http_cache_stream::Cache<S>);
114
115impl<S: CacheStorage> Cache<S> {
116 pub fn new(storage: S) -> Self {
118 Self(http_cache_stream::Cache::new(storage))
119 }
120
121 pub fn new_with_options(storage: S, options: CacheOptions) -> Self {
123 Self(http_cache_stream::Cache::new_with_options(storage, options))
124 }
125
126 pub fn with_revalidation_hook(
137 mut self,
138 hook: impl Fn(&dyn semantics::RequestLike, &mut HeaderMap) -> Result<()> + Send + Sync + 'static,
139 ) -> Self {
140 self.0 = self.0.with_revalidation_hook(hook);
141 self
142 }
143
144 pub fn storage(&self) -> &S {
146 self.0.storage()
147 }
148}
149
150impl<S: CacheStorage> reqwest_middleware::Middleware for Cache<S> {
151 fn handle<'a, 'b, 'c, 'd>(
152 &'a self,
153 req: Request,
154 extensions: &'b mut Extensions,
155 next: Next<'c>,
156 ) -> BoxFuture<'d, reqwest_middleware::Result<Response>>
157 where
158 'a: 'd,
159 'b: 'd,
160 'c: 'd,
161 Self: 'd,
162 {
163 async {
164 let request = MiddlewareRequest {
165 uri: req.url().as_str().parse().map_err(|e| {
166 anyhow::anyhow!("URL `{url}` is not valid: {e}", url = req.url())
167 })?,
168 request: req,
169 next,
170 extensions,
171 };
172
173 let response = self
174 .0
175 .send(request)
176 .await
177 .map(|r| r.map(|b| Body::wrap_stream(BodyDataStream::new(b))).into())?;
178 Ok(response)
179 }
180 .boxed()
181 }
182}
183
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Mutex;

    use http_cache_stream::http;
    use http_cache_stream::storage::DefaultCacheStorage;
    use reqwest::Response;
    use reqwest::StatusCode;
    use reqwest::header;
    use reqwest_middleware::ClientWithMiddleware;
    use reqwest_middleware::Middleware;
    use tempfile::tempdir;

    use super::*;

    /// The URL requested by every test.
    const URL: &str = "http://test.local/";

    /// The body of the initial response.
    const BODY: &str = "hello world!";

    /// The digest under which the storage is expected to keep [`BODY`].
    const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d";

    /// The body of the response that replaces the initial one.
    const MODIFIED_BODY: &str = "hello world!!!";

    /// The digest under which the storage is expected to keep
    /// [`MODIFIED_BODY`].
    const MODIFIED_DIGEST: &str =
        "22b8d362b2e8064356915b1451f630d1d920b427d3b2f9b3432fbf4c03d94184";

    /// Asserts that a response has the given header with the given value.
    ///
    /// Panics if the header is not present.
    macro_rules! assert_header {
        ($response:expr, $name:expr, $expected:expr) => {
            assert_eq!($response.headers().get($name).unwrap(), $expected)
        };
    }

    /// Asserts that a response carries the given content digest header value.
    ///
    /// Panics if the header is not present or is not a valid string.
    macro_rules! assert_digest {
        ($response:expr, $expected:expr) => {
            assert_eq!(
                $response
                    .headers()
                    .get(X_CACHE_DIGEST)
                    .unwrap()
                    .to_str()
                    .unwrap(),
                $expected
            )
        };
    }

    /// Asserts that a response does not carry a content digest header.
    macro_rules! assert_no_digest {
        ($response:expr) => {
            assert!($response.headers().get(X_CACHE_DIGEST).is_none())
        };
    }

    /// The state of a [`MockMiddleware`].
    struct MockMiddlewareState {
        /// The canned responses; a slot is emptied once it has been served.
        responses: Vec<Option<Response>>,
        /// The index of the next response to serve.
        current: usize,
    }

    /// A terminal middleware that serves canned responses in order instead of
    /// calling the rest of the chain.
    struct MockMiddleware(Mutex<MockMiddlewareState>);

    impl MockMiddleware {
        /// Creates a mock that serves the given responses, one per request.
        fn new<R>(responses: impl IntoIterator<Item = R>) -> Self
        where
            R: Into<Response>,
        {
            let state = MockMiddlewareState {
                responses: responses
                    .into_iter()
                    .map(|item| Some(item.into()))
                    .collect(),
                current: 0,
            };
            Self(Mutex::new(state))
        }
    }

    impl Middleware for MockMiddleware {
        fn handle<'a, 'b, 'c, 'd>(
            &'a self,
            _: Request,
            _: &'b mut Extensions,
            _: Next<'c>,
        ) -> BoxFuture<'d, reqwest_middleware::Result<Response>>
        where
            'a: 'd,
            'b: 'd,
            'c: 'd,
            Self: 'd,
        {
            async {
                let mut guard = self.0.lock().unwrap();

                // Claim the next slot before handing out its response.
                let index = guard.current;
                guard.current += 1;

                let slot = guard
                    .responses
                    .get_mut(index)
                    .expect("unexpected client request: not enough responses defined");
                Ok(slot.take().unwrap())
            }
            .boxed()
        }
    }

    /// Builds a client whose middleware stack is the cache followed by the
    /// mock.
    fn client_with(cache: Arc<dyn Middleware>, mock: Arc<dyn Middleware>) -> ClientWithMiddleware {
        ClientWithMiddleware::new(Default::default(), vec![cache, mock])
    }

    /// A `no-store` response is a miss on every request and its body is never
    /// written to storage.
    #[tokio::test]
    async fn no_store() {
        let dir = tempdir().unwrap();
        let cache = Arc::new(Cache::new(DefaultCacheStorage::new(dir.path())));

        let uncacheable = || {
            http::Response::builder()
                .header(header::CACHE_CONTROL, "no-store")
                .body(BODY)
                .unwrap()
        };
        let mock = Arc::new(MockMiddleware::new([uncacheable(), uncacheable()]));
        let client = client_with(cache.clone(), mock.clone());

        // Both requests must behave identically: nothing is served from or
        // written to the cache.
        for _ in 0..2 {
            let response = client.get(URL).send().await.unwrap();
            assert_header!(response, header::CACHE_CONTROL, "no-store");
            assert_header!(response, X_CACHE_LOOKUP, "MISS");
            assert_header!(response, X_CACHE, "MISS");
            assert_no_digest!(response);
            assert_eq!(response.text().await.unwrap(), BODY);

            assert!(!cache.storage().body_path(DIGEST).is_file());
        }
    }

    /// A response with a `max-age` is stored on the first request and served
    /// from the cache on the second without the revalidation hook being called.
    #[tokio::test]
    async fn max_age() {
        let dir = tempdir().unwrap();
        let storage = DefaultCacheStorage::new(dir.path());
        let cache = Arc::new(
            Cache::new(storage)
                .with_revalidation_hook(|_, _| panic!("a revalidation should not take place")),
        );

        // Only one response is defined, so a second upstream request would
        // make the mock panic.
        let cacheable = http::Response::builder()
            .header(header::CACHE_CONTROL, "max-age=1000")
            .body(BODY)
            .unwrap();
        let mock = Arc::new(MockMiddleware::new([cacheable]));
        let client = client_with(cache.clone(), mock.clone());

        // First request: a miss that stores the body.
        let first = client.get(URL).send().await.unwrap();
        assert_header!(first, header::CACHE_CONTROL, "max-age=1000");
        assert_header!(first, X_CACHE_LOOKUP, "MISS");
        assert_header!(first, X_CACHE, "MISS");
        assert_no_digest!(first);
        assert_eq!(first.text().await.unwrap(), BODY);

        assert!(cache.storage().body_path(DIGEST).is_file());

        // Second request: served from the cache.
        let second = client.get(URL).send().await.unwrap();
        assert_header!(second, header::CACHE_CONTROL, "max-age=1000");
        assert_header!(second, X_CACHE_LOOKUP, "HIT");
        assert_header!(second, X_CACHE, "HIT");
        assert_digest!(second, DIGEST);
        assert_eq!(second.text().await.unwrap(), BODY);
    }

    /// A stored response that is revalidated with a `304 Not Modified` is
    /// served from the cache, and the revalidation hook is called.
    #[tokio::test]
    async fn cache_hit_unmodified() {
        let dir = tempdir().unwrap();

        // Records whether the revalidation hook has been called.
        let revalidated = Arc::new(Mutex::new(false));
        let flag = revalidated.clone();
        let cache = Arc::new(
            Cache::new(DefaultCacheStorage::new(dir.path())).with_revalidation_hook(move |_, _| {
                *flag.lock().unwrap() = true;
                Ok(())
            }),
        );

        let initial = http::Response::builder().body(BODY).unwrap();
        let not_modified = http::Response::builder()
            .status(StatusCode::NOT_MODIFIED)
            .body("")
            .unwrap();
        let mock = Arc::new(MockMiddleware::new([initial, not_modified]));
        let client = client_with(cache.clone(), mock.clone());

        // First request: a miss that stores the body without revalidating.
        let first = client.get(URL).send().await.unwrap();
        assert_header!(first, X_CACHE_LOOKUP, "MISS");
        assert_header!(first, X_CACHE, "MISS");
        assert_no_digest!(first);
        assert_eq!(first.text().await.unwrap(), BODY);

        assert!(cache.storage().body_path(DIGEST).is_file());

        assert!(!*revalidated.lock().unwrap());

        // Second request: revalidated and then served from the cache.
        let second = client.get(URL).send().await.unwrap();
        assert_header!(second, X_CACHE_LOOKUP, "HIT");
        assert_header!(second, X_CACHE, "HIT");
        assert_digest!(second, DIGEST);
        assert_eq!(second.text().await.unwrap(), BODY);

        assert!(*revalidated.lock().unwrap());
    }

    /// A stored response that is revalidated with a new body is replaced, and
    /// the replacement is served from the cache after its own revalidation.
    #[tokio::test]
    async fn cache_hit_modified() {
        let dir = tempdir().unwrap();

        // Records whether the revalidation hook has been called.
        let revalidated = Arc::new(Mutex::new(false));
        let flag = revalidated.clone();
        let cache = Arc::new(
            Cache::new(DefaultCacheStorage::new(dir.path())).with_revalidation_hook(move |_, _| {
                *flag.lock().unwrap() = true;
                Ok(())
            }),
        );

        let initial = http::Response::builder().body(BODY).unwrap();
        let modified = http::Response::builder().body(MODIFIED_BODY).unwrap();
        let not_modified = http::Response::builder()
            .status(StatusCode::NOT_MODIFIED)
            .body("")
            .unwrap();
        let mock = Arc::new(MockMiddleware::new([initial, modified, not_modified]));
        let client = client_with(cache.clone(), mock.clone());

        // First request: a miss that stores the body without revalidating.
        let first = client.get(URL).send().await.unwrap();
        assert_header!(first, X_CACHE_LOOKUP, "MISS");
        assert_header!(first, X_CACHE, "MISS");
        assert_no_digest!(first);
        assert_eq!(first.text().await.unwrap(), BODY);

        assert!(cache.storage().body_path(DIGEST).is_file());

        assert!(!*revalidated.lock().unwrap());

        // Second request: found in the cache, but revalidation returns a new
        // body, so the response is not served from the cache.
        let second = client.get(URL).send().await.unwrap();
        assert_header!(second, X_CACHE_LOOKUP, "HIT");
        assert_header!(second, X_CACHE, "MISS");
        assert_no_digest!(second);
        assert_eq!(second.text().await.unwrap(), MODIFIED_BODY);

        assert!(cache.storage().body_path(MODIFIED_DIGEST).is_file());

        // Check the flag and reset it for the next request.
        assert!(std::mem::take(&mut *revalidated.lock().unwrap()));

        // Third request: the modified body is revalidated and then served from
        // the cache.
        let third = client.get(URL).send().await.unwrap();
        assert_header!(third, X_CACHE_LOOKUP, "HIT");
        assert_header!(third, X_CACHE, "HIT");
        assert_digest!(third, MODIFIED_DIGEST);
        assert_eq!(third.text().await.unwrap(), MODIFIED_BODY);

        assert!(*revalidated.lock().unwrap());
    }
}