reqwest_cache/
lib.rs

//! **DEPRECATION NOTICE**: This crate is unmaintained. Use `http-cache-reqwest` instead.
2
3use std::{sync::Arc, time::SystemTime};
4
5use bytes::Bytes;
6use chashmap_async::CHashMap;
7pub use http_cache_semantics::CacheOptions;
8use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy, RequestLike};
9use reqwest::Url;
10use reqwest_middleware::Middleware;
11
12/// Data about an entry in the cache
13#[derive(Debug)]
14struct CacheEntry {
15	/// The cache policy used to check the freshness of the cache.
16	policy: CachePolicy,
17	/// The body of the cached response
18	response: Bytes,
19}
20
21impl CacheEntry {
22	/// Constructs a new `CacheEntry`.
23	pub fn new(policy: CachePolicy, response: Bytes) -> Self {
24		Self { policy, response }
25	}
26}
27
28/// Middleware that caches responses based on [HTTP cache headers].
29///
30/// [HTTP cache heders]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching
31#[derive(Default)]
32pub struct CacheMiddleware {
33	/// The cached responses.
34	cache: Arc<CHashMap<Url, CacheEntry>>,
35	/// Configuration of caching behavior
36	options: CacheOptions,
37}
38
39impl CacheMiddleware {
40	/// Constructs a new `CacheMiddleware`
41	#[must_use]
42	pub fn new() -> Self {
43		Self::default()
44	}
45
46	/// Constructs a new `CacheMiddleware` with the given caching options.
47	#[must_use]
48	pub fn with_options(options: CacheOptions) -> Self {
49		Self { cache: Arc::new(CHashMap::new()), options }
50	}
51}
52
53impl std::fmt::Debug for CacheMiddleware {
54	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55		f.debug_struct("CacheMiddleware")
56			.field("cache", &format!("<{} entries>", self.cache.len()))
57			.field("options", &self.options)
58			.finish()
59	}
60}
61
62#[async_trait::async_trait]
63impl Middleware for CacheMiddleware {
64	async fn handle(
65		&self,
66		mut req: reqwest::Request,
67		extensions: &mut task_local_extensions::Extensions,
68		next: reqwest_middleware::Next<'_>,
69	) -> reqwest_middleware::Result<reqwest::Response> {
70		// Strip the fragment part (the stuff after #) of the URL since is exclusively
71		// client-side and has no bearing on caching
72		let mut url = req.url().clone();
73		url.set_fragment(None);
74
75		if let Some(mut cache) = self.cache.get_mut(&url).await {
76			// Check freshness of the cached response
77			let before = cache.policy.before_request(&req, SystemTime::now());
78			match before {
79				BeforeRequest::Fresh(parts) => {
80					// Cache is fresh, no need to hit the server
81					let response = http::Response::from_parts(parts, cache.response.clone());
82					return Ok(response.into());
83				}
84				BeforeRequest::Stale { request: parts, matches } => {
85					// Cache is stale, validate it.
86					*req.headers_mut() = parts.headers.clone();
87					let response = next.run(req, extensions).await?;
88					let after = cache.policy.after_response(&parts, &response, SystemTime::now());
89					match after {
90						AfterResponse::NotModified(policy, parts) => {
91							// Cached body is still valid.
92							if matches {
93								cache.policy = policy;
94							}
95							let response =
96								http::Response::from_parts(parts, cache.response.clone());
97							return Ok(response.into());
98						}
99						AfterResponse::Modified(policy, parts) => {
100							// Cached body is not valid, update it.
101							if matches {
102								cache.policy = policy;
103							}
104							let body = response.bytes().await?;
105							cache.response = body;
106							let response =
107								http::Response::from_parts(parts, cache.response.clone());
108							return Ok(response.into());
109						}
110					}
111				}
112			}
113		}
114		// Make a `Parts` so that we have something to give the `CachePolicy`
115		// constructor
116		#[allow(clippy::expect_used)]
117		let (mut parts, _) = http::Request::builder()
118			.uri(req.uri())
119			.method(req.method().clone())
120			.version(req.version())
121			.body(())
122			.expect("Builder used correctly")
123			.into_parts();
124		// TODO: Cloning the full header map can get expensive, find a way to avoid
125		// doing this.
126		parts.headers = req.headers().clone();
127		let response = next.run(req, extensions).await?;
128		let policy = CachePolicy::new_options(&parts, &response, SystemTime::now(), self.options);
129		if policy.is_storable() {
130			let response = reqwest_to_http(response).await?;
131			let cache = CacheEntry::new(policy, response.body().clone());
132			self.cache
133				.alter(url, |entry| async move {
134					match entry {
135						None => Some(cache),
136						Some(entry) => {
137							// If a cache entry got added while fetching body, pick the newest.
138							let time = SystemTime::now();
139							if entry.policy.age(time) > cache.policy.age(time) {
140								Some(cache)
141							} else {
142								Some(entry)
143							}
144						}
145					}
146				})
147				.await;
148			return Ok(response.into());
149		}
150		Ok(response)
151	}
152}
153
154/// Convert a [`reqwest::Response`] to an [`http::Response`]
155async fn reqwest_to_http(
156	mut response: reqwest::Response,
157) -> reqwest::Result<http::Response<Bytes>> {
158	let mut http = http::Response::new(Bytes::new());
159	*http.status_mut() = response.status();
160	*http.version_mut() = response.version();
161	std::mem::swap(http.headers_mut(), response.headers_mut());
162	*http.body_mut() = response.bytes().await?;
163	Ok(http)
164}