http_cache_stream_reqwest/
lib.rs

1//! An implementation of a [`reqwest`][reqwest] middleware that uses
2//! [`http-cache-stream`][http-cache-stream].
3//!
4//! ```no_run
5//! use http_cache_stream_reqwest::Cache;
6//! use http_cache_stream_reqwest::storage::DefaultCacheStorage;
7//! use reqwest::Client;
8//! use reqwest_middleware::ClientBuilder;
9//! use reqwest_middleware::Result;
10//!
11//! #[tokio::main]
12//! async fn main() -> Result<()> {
13//!     let client = ClientBuilder::new(Client::new())
14//!         .with(Cache::new(DefaultCacheStorage::new("./cache")))
15//!         .build();
16//!     client.get("https://example.com").send().await?;
17//!     Ok(())
18//! }
19//! ```
20//!
21//! [reqwest]: https://github.com/seanmonstar/reqwest
22//! [http-cache-stream]: https://github.com/stjude-rust-labs/http-cache-stream
23
24#![warn(missing_docs)]
25#![warn(rust_2018_idioms)]
26#![warn(rust_2021_compatibility)]
27#![warn(clippy::missing_docs_in_private_items)]
28#![warn(rustdoc::broken_intra_doc_links)]
29
30use std::io;
31use std::pin::Pin;
32use std::task::Context;
33use std::task::Poll;
34
35use anyhow::Context as _;
36use anyhow::Result;
37use bytes::Bytes;
38use futures::FutureExt;
39use futures::future::BoxFuture;
40pub use http_cache_stream::X_CACHE;
41pub use http_cache_stream::X_CACHE_DIGEST;
42pub use http_cache_stream::X_CACHE_LOOKUP;
43use http_cache_stream::http::Extensions;
44use http_cache_stream::http::Uri;
45use http_cache_stream::http_body::Frame;
46pub use http_cache_stream::semantics::CacheOptions;
47pub use http_cache_stream::storage;
48pub use http_cache_stream::storage::CacheStorage;
49use reqwest::Body;
50use reqwest::Request;
51use reqwest::Response;
52use reqwest::ResponseBuilderExt;
53use reqwest_middleware::Next;
54
55pin_project_lite::pin_project! {
56    /// Adapter for [`Body`] to implement `HttpBody`.
57    struct MiddlewareBody {
58        #[pin]
59        body: Body
60    }
61}
62
63impl http_cache_stream::http_body::Body for MiddlewareBody {
64    type Data = Bytes;
65    type Error = io::Error;
66
67    fn poll_frame(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70    ) -> Poll<Option<Result<Frame<Bytes>, Self::Error>>> {
71        // The two body implementations differ on error type, so map it here
72        self.project().body.poll_frame(cx).map_err(io::Error::other)
73    }
74}
75
76impl http_cache_stream::HttpBody for MiddlewareBody {}
77
78/// Represents a request flowing through the cache middleware.
79struct MiddlewareRequest<'a, 'b> {
80    /// The request URI.
81    uri: Uri,
82    /// The request sent to the middleware.
83    request: Request,
84    /// The next middleware to run.
85    next: Next<'a>,
86    /// The request extensions.
87    extensions: &'b mut Extensions,
88}
89
90impl http_cache_stream::Request<MiddlewareBody> for MiddlewareRequest<'_, '_> {
91    fn version(&self) -> http_cache_stream::http::Version {
92        self.request.version()
93    }
94
95    fn method(&self) -> &http_cache_stream::http::Method {
96        self.request.method()
97    }
98
99    fn uri(&self) -> &http_cache_stream::http::Uri {
100        &self.uri
101    }
102
103    fn headers(&self) -> &http_cache_stream::http::HeaderMap {
104        self.request.headers()
105    }
106
107    async fn send(
108        mut self,
109        headers: Option<http_cache_stream::http::HeaderMap>,
110    ) -> anyhow::Result<http_cache_stream::http::Response<MiddlewareBody>> {
111        // Override the specified headers
112        if let Some(headers) = headers {
113            self.request.headers_mut().extend(headers);
114        }
115
116        // Send the response to the next middleware
117        let mut response = self.next.run(self.request, self.extensions).await?;
118
119        // Build a response
120        let mut builder =
121            http_cache_stream::http::Response::builder()
122                .version(response.version())
123                .status(response.status())
124                .url(response.url().as_str().parse().with_context(|| {
125                    format!("invalid response URL `{url}`", url = response.url())
126                })?);
127
128        let headers = std::mem::take(response.headers_mut());
129        builder
130            .headers_mut()
131            .expect("should have headers")
132            .extend(headers);
133        builder
134            .body(MiddlewareBody {
135                body: Body::wrap_stream(response.bytes_stream()),
136            })
137            .context("failed to create response")
138    }
139}
140
141/// Implements a caching middleware for [`reqwest`].
142pub struct Cache<S>(http_cache_stream::Cache<S>);
143
144impl<S: CacheStorage> Cache<S> {
145    /// Constructs a new caching middleware with the given storage.
146    pub fn new(storage: S) -> Self {
147        Self(http_cache_stream::Cache::new(storage))
148    }
149
150    /// Construct a new caching middleware with the given storage and options.
151    pub fn new_with_options(storage: S, options: CacheOptions) -> Self {
152        Self(http_cache_stream::Cache::new_with_options(storage, options))
153    }
154
155    /// Gets the underlying storage of the cache.
156    pub fn storage(&self) -> &S {
157        self.0.storage()
158    }
159}
160
161impl<S: CacheStorage> reqwest_middleware::Middleware for Cache<S> {
162    fn handle<'a, 'b, 'c, 'd>(
163        &'a self,
164        req: Request,
165        extensions: &'b mut Extensions,
166        next: Next<'c>,
167    ) -> BoxFuture<'d, reqwest_middleware::Result<Response>>
168    where
169        'a: 'd,
170        'b: 'd,
171        'c: 'd,
172        Self: 'd,
173    {
174        async {
175            let request = MiddlewareRequest {
176                uri: req.url().as_str().parse().map_err(|e| {
177                    anyhow::anyhow!("URL `{url}` is not valid: {e}", url = req.url())
178                })?,
179                request: req,
180                next,
181                extensions,
182            };
183
184            let response = self
185                .0
186                .send(request)
187                .await
188                .map(|r| r.map(Body::wrap_stream).into())?;
189            Ok(response)
190        }
191        .boxed()
192    }
193}