Skip to main content

http_cache_stream_reqwest/
lib.rs

1//! An implementation of a [`reqwest`][reqwest] middleware that uses
2//! [`http-cache-stream`][http-cache-stream].
3//!
4//! ```no_run
5//! use http_cache_stream_reqwest::Cache;
6//! use http_cache_stream_reqwest::storage::DefaultCacheStorage;
7//! use reqwest::Client;
8//! use reqwest_middleware::ClientBuilder;
9//! use reqwest_middleware::Result;
10//!
11//! #[tokio::main]
12//! async fn main() -> Result<()> {
13//!     let client = ClientBuilder::new(Client::new())
14//!         .with(Cache::new(DefaultCacheStorage::new("./cache")))
15//!         .build();
16//!     client.get("https://example.com").send().await?;
17//!     Ok(())
18//! }
19//! ```
20//!
21//! [reqwest]: https://github.com/seanmonstar/reqwest
22//! [http-cache-stream]: https://github.com/stjude-rust-labs/http-cache-stream
23
24#![warn(missing_docs)]
25#![warn(rust_2018_idioms)]
26#![warn(rust_2021_compatibility)]
27#![warn(clippy::missing_docs_in_private_items)]
28#![warn(rustdoc::broken_intra_doc_links)]
29
30use anyhow::Context as _;
31use anyhow::Result;
32use futures::FutureExt;
33use futures::future::BoxFuture;
34use http_body_util::BodyDataStream;
35pub use http_cache_stream::X_CACHE;
36pub use http_cache_stream::X_CACHE_DIGEST;
37pub use http_cache_stream::X_CACHE_LOOKUP;
38use http_cache_stream::http::Extensions;
39use http_cache_stream::http::Uri;
40pub use http_cache_stream::semantics;
41pub use http_cache_stream::semantics::CacheOptions;
42pub use http_cache_stream::storage;
43pub use http_cache_stream::storage::CacheStorage;
44use reqwest::Body;
45use reqwest::Request;
46use reqwest::Response;
47use reqwest::ResponseBuilderExt;
48use reqwest::header::HeaderMap;
49use reqwest_middleware::Next;
50
51/// Represents a request flowing through the cache middleware.
52struct MiddlewareRequest<'a, 'b> {
53    /// The request URI.
54    uri: Uri,
55    /// The request sent to the middleware.
56    request: Request,
57    /// The next middleware to run.
58    next: Next<'a>,
59    /// The request extensions.
60    extensions: &'b mut Extensions,
61}
62
63impl http_cache_stream::Request<Body> for MiddlewareRequest<'_, '_> {
64    fn version(&self) -> http_cache_stream::http::Version {
65        self.request.version()
66    }
67
68    fn method(&self) -> &http_cache_stream::http::Method {
69        self.request.method()
70    }
71
72    fn uri(&self) -> &http_cache_stream::http::Uri {
73        &self.uri
74    }
75
76    fn headers(&self) -> &http_cache_stream::http::HeaderMap {
77        self.request.headers()
78    }
79
80    async fn send(
81        mut self,
82        headers: Option<http_cache_stream::http::HeaderMap>,
83    ) -> anyhow::Result<http_cache_stream::http::Response<Body>> {
84        // Override the specified headers
85        if let Some(headers) = headers {
86            self.request.headers_mut().extend(headers);
87        }
88
89        // Send the response to the next middleware
90        let mut response = self.next.run(self.request, self.extensions).await?;
91
92        // Build a response
93        let mut builder =
94            http_cache_stream::http::Response::builder()
95                .version(response.version())
96                .status(response.status())
97                .url(response.url().as_str().parse().with_context(|| {
98                    format!("invalid response URL `{url}`", url = response.url())
99                })?);
100
101        let headers = std::mem::take(response.headers_mut());
102        builder
103            .headers_mut()
104            .expect("should have headers")
105            .extend(headers);
106        builder
107            .body(response.into())
108            .context("failed to create response")
109    }
110}
111
112/// Implements a caching middleware for [`reqwest`].
113pub struct Cache<S>(http_cache_stream::Cache<S>);
114
115impl<S: CacheStorage> Cache<S> {
116    /// Constructs a new caching middleware with the given storage.
117    pub fn new(storage: S) -> Self {
118        Self(http_cache_stream::Cache::new(storage))
119    }
120
121    /// Construct a new caching middleware with the given storage and options.
122    pub fn new_with_options(storage: S, options: CacheOptions) -> Self {
123        Self(http_cache_stream::Cache::new_with_options(storage, options))
124    }
125
126    /// Sets the revalidation hook to use.
127    ///
128    /// The hook is provided the original request and a mutable header map
129    /// containing headers explicitly set for the revalidation request.
130    ///
131    /// For example, a hook may alter the revalidation headers to update an
132    /// `Authorization` header based on the headers used for revalidation.
133    ///
134    /// If the hook returns an error, the error is propagated out as the result
135    /// of the original request.
136    pub fn with_revalidation_hook(
137        mut self,
138        hook: impl Fn(&dyn semantics::RequestLike, &mut HeaderMap) -> Result<()> + Send + Sync + 'static,
139    ) -> Self {
140        self.0 = self.0.with_revalidation_hook(hook);
141        self
142    }
143
144    /// Gets the underlying storage of the cache.
145    pub fn storage(&self) -> &S {
146        self.0.storage()
147    }
148}
149
150impl<S: CacheStorage> reqwest_middleware::Middleware for Cache<S> {
151    fn handle<'a, 'b, 'c, 'd>(
152        &'a self,
153        req: Request,
154        extensions: &'b mut Extensions,
155        next: Next<'c>,
156    ) -> BoxFuture<'d, reqwest_middleware::Result<Response>>
157    where
158        'a: 'd,
159        'b: 'd,
160        'c: 'd,
161        Self: 'd,
162    {
163        async {
164            let request = MiddlewareRequest {
165                uri: req.url().as_str().parse().map_err(|e| {
166                    anyhow::anyhow!("URL `{url}` is not valid: {e}", url = req.url())
167                })?,
168                request: req,
169                next,
170                extensions,
171            };
172
173            let response = self
174                .0
175                .send(request)
176                .await
177                .map(|r| r.map(|b| Body::wrap_stream(BodyDataStream::new(b))).into())?;
178            Ok(response)
179        }
180        .boxed()
181    }
182}
183
184#[cfg(test)]
185mod test {
186    use std::sync::Arc;
187    use std::sync::Mutex;
188
189    use http_cache_stream::http;
190    use http_cache_stream::storage::DefaultCacheStorage;
191    use reqwest::Response;
192    use reqwest::StatusCode;
193    use reqwest::header;
194    use reqwest_middleware::ClientWithMiddleware;
195    use reqwest_middleware::Middleware;
196    use tempfile::tempdir;
197
198    use super::*;
199
200    struct MockMiddlewareState {
201        responses: Vec<Option<Response>>,
202        current: usize,
203    }
204
205    struct MockMiddleware(Mutex<MockMiddlewareState>);
206
207    impl MockMiddleware {
208        fn new<R>(responses: impl IntoIterator<Item = R>) -> Self
209        where
210            R: Into<Response>,
211        {
212            Self(Mutex::new(MockMiddlewareState {
213                responses: responses.into_iter().map(|r| Some(r.into())).collect(),
214                current: 0,
215            }))
216        }
217    }
218
219    impl Middleware for MockMiddleware {
220        fn handle<'a, 'b, 'c, 'd>(
221            &'a self,
222            _: Request,
223            _: &'b mut Extensions,
224            _: Next<'c>,
225        ) -> BoxFuture<'d, reqwest_middleware::Result<Response>>
226        where
227            'a: 'd,
228            'b: 'd,
229            'c: 'd,
230            Self: 'd,
231        {
232            async {
233                let mut state = self.0.lock().unwrap();
234
235                let current = state.current;
236                state.current += 1;
237
238                Ok(state
239                    .responses
240                    .get_mut(current)
241                    .expect("unexpected client request: not enough responses defined")
242                    .take()
243                    .unwrap())
244            }
245            .boxed()
246        }
247    }
248
249    #[tokio::test]
250    async fn no_store() {
251        const BODY: &str = "hello world!";
252        // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/)
253        const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d";
254
255        let dir = tempdir().unwrap();
256        let cache = Arc::new(Cache::new(DefaultCacheStorage::new(dir.path())));
257        let mock = Arc::new(MockMiddleware::new([
258            http::Response::builder()
259                .header(header::CACHE_CONTROL, "no-store")
260                .body(BODY)
261                .unwrap(),
262            http::Response::builder()
263                .header(header::CACHE_CONTROL, "no-store")
264                .body(BODY)
265                .unwrap(),
266        ]));
267        let client = ClientWithMiddleware::new(
268            Default::default(),
269            vec![cache.clone() as Arc<dyn Middleware>, mock.clone()],
270        );
271
272        // Response should not be served from the cache or stored
273        let response = client.get("http://test.local/").send().await.unwrap();
274        assert_eq!(
275            response.headers().get(header::CACHE_CONTROL).unwrap(),
276            "no-store"
277        );
278        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS");
279        assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS");
280        assert!(response.headers().get(X_CACHE_DIGEST).is_none());
281        assert_eq!(response.text().await.unwrap(), BODY);
282
283        // Ensure the body wasn't stored in the cache
284        assert!(!cache.storage().body_path(DIGEST).is_file());
285
286        // Response should *still* not be served from the cache or stored
287        let response = client.get("http://test.local/").send().await.unwrap();
288        assert_eq!(
289            response.headers().get(header::CACHE_CONTROL).unwrap(),
290            "no-store"
291        );
292        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS");
293        assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS");
294        assert!(response.headers().get(X_CACHE_DIGEST).is_none());
295        assert_eq!(response.text().await.unwrap(), BODY);
296
297        // Ensure the body wasn't stored in the cache
298        assert!(!cache.storage().body_path(DIGEST).is_file());
299    }
300
301    #[tokio::test]
302    async fn max_age() {
303        const BODY: &str = "hello world!";
304        // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/)
305        const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d";
306
307        let dir = tempdir().unwrap();
308        let cache = Arc::new(
309            Cache::new(DefaultCacheStorage::new(dir.path()))
310                .with_revalidation_hook(|_, _| panic!("a revalidation should not take place")),
311        );
312        let mock = Arc::new(MockMiddleware::new([http::Response::builder()
313            .header(header::CACHE_CONTROL, "max-age=1000")
314            .body(BODY)
315            .unwrap()]));
316        let client = ClientWithMiddleware::new(
317            Default::default(),
318            vec![cache.clone() as Arc<dyn Middleware>, mock.clone()],
319        );
320
321        // First response should not be served from the cache
322        let response = client.get("http://test.local/").send().await.unwrap();
323        assert_eq!(
324            response.headers().get(header::CACHE_CONTROL).unwrap(),
325            "max-age=1000"
326        );
327        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS");
328        assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS");
329        assert!(response.headers().get(X_CACHE_DIGEST).is_none());
330        assert_eq!(response.text().await.unwrap(), BODY);
331
332        // Ensure the body was stored in the cache
333        assert!(cache.storage().body_path(DIGEST).is_file());
334
335        // Second response should be served from the cache without revalidation
336        // If a revalidation is made, the mock middleware will panic since there was
337        // only one response defined
338        let response = client.get("http://test.local/").send().await.unwrap();
339        assert_eq!(
340            response.headers().get(header::CACHE_CONTROL).unwrap(),
341            "max-age=1000"
342        );
343        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT");
344        assert_eq!(response.headers().get(X_CACHE).unwrap(), "HIT");
345        assert_eq!(
346            response
347                .headers()
348                .get(X_CACHE_DIGEST)
349                .map(|v| v.to_str().unwrap())
350                .unwrap(),
351            DIGEST
352        );
353        assert_eq!(response.text().await.unwrap(), BODY);
354    }
355
356    #[tokio::test]
357    async fn cache_hit_unmodified() {
358        const BODY: &str = "hello world!";
359        // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/)
360        const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d";
361
362        #[derive(Default)]
363        struct State {
364            revalidated: bool,
365        }
366
367        let dir = tempdir().unwrap();
368        let state = Arc::new(Mutex::new(State::default()));
369        let state_clone = state.clone();
370        let cache = Arc::new(
371            Cache::new(DefaultCacheStorage::new(dir.path())).with_revalidation_hook(move |_, _| {
372                state_clone.lock().unwrap().revalidated = true;
373                Ok(())
374            }),
375        );
376        let mock = Arc::new(MockMiddleware::new([
377            http::Response::builder().body(BODY).unwrap(),
378            http::Response::builder()
379                .status(StatusCode::NOT_MODIFIED)
380                .body("")
381                .unwrap(),
382        ]));
383        let client = ClientWithMiddleware::new(
384            Default::default(),
385            vec![cache.clone() as Arc<dyn Middleware>, mock.clone()],
386        );
387
388        // First response should be a miss
389        let response = client.get("http://test.local/").send().await.unwrap();
390        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS");
391        assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS");
392        assert!(response.headers().get(X_CACHE_DIGEST).is_none());
393        assert_eq!(response.text().await.unwrap(), BODY);
394
395        // Ensure the body was stored in the cache
396        assert!(cache.storage().body_path(DIGEST).is_file());
397
398        // Assert no revalidation took place
399        assert!(!state.lock().unwrap().revalidated);
400
401        // Second response should be served from the cache
402        let response = client.get("http://test.local/").send().await.unwrap();
403        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT");
404        assert_eq!(response.headers().get(X_CACHE).unwrap(), "HIT");
405        assert_eq!(
406            response
407                .headers()
408                .get(X_CACHE_DIGEST)
409                .map(|v| v.to_str().unwrap())
410                .unwrap(),
411            DIGEST
412        );
413        assert_eq!(response.text().await.unwrap(), BODY);
414
415        // Assert a revalidation took place
416        assert!(state.lock().unwrap().revalidated);
417    }
418
419    #[tokio::test]
420    async fn cache_hit_modified() {
421        const BODY: &str = "hello world!";
422        const MODIFIED_BODY: &str = "hello world!!!";
423        // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/)
424        const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d";
425        // Blake3 digest of the modified body (from https://emn178.github.io/online-tools/blake3/)
426        const MODIFIED_DIGEST: &str =
427            "22b8d362b2e8064356915b1451f630d1d920b427d3b2f9b3432fbf4c03d94184";
428
429        #[derive(Default)]
430        struct State {
431            revalidated: bool,
432        }
433
434        let dir = tempdir().unwrap();
435        let state = Arc::new(Mutex::new(State::default()));
436        let state_clone = state.clone();
437        let cache = Arc::new(
438            Cache::new(DefaultCacheStorage::new(dir.path())).with_revalidation_hook(move |_, _| {
439                state_clone.lock().unwrap().revalidated = true;
440                Ok(())
441            }),
442        );
443        let mock = Arc::new(MockMiddleware::new([
444            http::Response::builder().body(BODY).unwrap(),
445            http::Response::builder().body(MODIFIED_BODY).unwrap(),
446            http::Response::builder()
447                .status(StatusCode::NOT_MODIFIED)
448                .body("")
449                .unwrap(),
450        ]));
451        let client = ClientWithMiddleware::new(
452            Default::default(),
453            vec![cache.clone() as Arc<dyn Middleware>, mock.clone()],
454        );
455
456        // First response should be a miss
457        let response = client.get("http://test.local/").send().await.unwrap();
458        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS");
459        assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS");
460        assert!(response.headers().get(X_CACHE_DIGEST).is_none());
461        assert_eq!(response.text().await.unwrap(), BODY);
462
463        // Ensure the body was stored in the cache
464        assert!(cache.storage().body_path(DIGEST).is_file());
465
466        // Assert no revalidation took place
467        assert!(!state.lock().unwrap().revalidated);
468
469        // Second response should not be served from the cache (was modified)
470        let response = client.get("http://test.local/").send().await.unwrap();
471        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT");
472        assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS");
473        assert!(response.headers().get(X_CACHE_DIGEST).is_none());
474        assert_eq!(response.text().await.unwrap(), MODIFIED_BODY);
475
476        // Ensure the body was stored in the cache
477        assert!(cache.storage().body_path(MODIFIED_DIGEST).is_file());
478
479        // Assert a revalidation took place and reset the flag back to false
480        assert!(std::mem::take(&mut state.lock().unwrap().revalidated));
481
482        // Second response should be served from the cache (not modified)
483        let response = client.get("http://test.local/").send().await.unwrap();
484        assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT");
485        assert_eq!(response.headers().get(X_CACHE).unwrap(), "HIT");
486        assert_eq!(
487            response
488                .headers()
489                .get(X_CACHE_DIGEST)
490                .map(|v| v.to_str().unwrap())
491                .unwrap(),
492            MODIFIED_DIGEST
493        );
494        assert_eq!(response.text().await.unwrap(), MODIFIED_BODY);
495
496        // Assert a revalidation took place
497        assert!(state.lock().unwrap().revalidated);
498    }
499}