1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
//! TODO: crate documentation

#![deny(trivial_casts, trivial_numeric_casts, unused_extern_crates, unused_qualifications)]
#![warn(
	missing_debug_implementations,
	missing_docs,
	unused_import_braces,
	dead_code,
	clippy::unwrap_used,
	clippy::expect_used,
	clippy::missing_docs_in_private_items
)]

use std::{sync::Arc, time::SystemTime};

use bytes::Bytes;
use chashmap::CHashMap;
pub use http_cache_semantics::CacheOptions;
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy, RequestLike};
use reqwest::Url;
use reqwest_middleware::Middleware;

/// Data about an entry in the cache: a stored response body together with the
/// policy that decides whether that body may still be served.
#[derive(Debug)]
struct CacheEntry {
	/// The cache policy used to check the freshness of the cache.
	policy: CachePolicy,
	/// The body of the cached response.
	response: Bytes,
}

impl CacheEntry {
	/// Constructs a new `CacheEntry`.
	pub fn new(policy: CachePolicy, response: Bytes) -> Self {
		Self { policy, response }
	}
}

/// Middleware that caches responses based on [HTTP cache headers].
///
/// [HTTP cache headers]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching
#[derive(Debug, Default)]
pub struct CacheMiddleware {
	/// The cached responses, keyed by request URL with the fragment removed.
	cache: Arc<CHashMap<Url, CacheEntry>>,
	/// Configuration of caching behavior, handed to each newly created
	/// [`CachePolicy`].
	options: CacheOptions,
}

impl CacheMiddleware {
	/// Creates a `CacheMiddleware` in its [`Default`] state.
	pub fn new() -> Self {
		Default::default()
	}

	/// Creates a `CacheMiddleware` backed by a newly created cache map that
	/// applies the given caching `options`.
	pub fn with_options(options: CacheOptions) -> Self {
		let cache = Arc::new(CHashMap::new());
		Self { cache, options }
	}
}

#[async_trait::async_trait]
impl Middleware for CacheMiddleware {
	/// Answers `req` from the cache when possible, otherwise forwards it to
	/// `next` and stores the response if its cache policy says it is storable.
	///
	/// The flow is:
	///
	/// 1. Entry found and fresh: the cached body is returned without running
	///    `next`.
	/// 2. Entry found but stale: the request is re-sent with the headers
	///    produced by the cache policy, and the entry is updated from the
	///    outcome (`NotModified` keeps the cached body, `Modified` replaces it).
	/// 3. No entry: the request is forwarded unchanged and the response is
	///    cached when storable.
	///
	/// No cache lock guard is held across an `.await`. Guards are only taken
	/// for short synchronous sections, so a slow upstream request cannot block
	/// other tasks that look up the same URL. Concurrent revalidations of one
	/// URL are therefore not serialized; whichever finishes last writes last.
	///
	/// # Errors
	///
	/// Propagates any error from `next` or from reading a response body.
	async fn handle(
		&self,
		mut req: reqwest::Request,
		extensions: &mut task_local_extensions::Extensions,
		next: reqwest_middleware::Next<'_>,
	) -> reqwest_middleware::Result<reqwest::Response> {
		// Strip the fragment part (the stuff after #) of the URL since it is
		// exclusively client-side and has no bearing on caching.
		let mut url = req.url().clone();
		url.set_fragment(None);

		// Consult the cache and copy out everything revalidation needs, so the
		// map guard is dropped before the first `.await` below. Holding a
		// synchronous lock guard while suspended would block other tasks that
		// touch the same entry and can deadlock a single-threaded runtime.
		let stale = match self.cache.get(&url) {
			None => None,
			Some(entry) => match entry.policy.before_request(&req, SystemTime::now()) {
				BeforeRequest::Fresh(parts) => {
					// Cache is fresh, no need to hit the server.
					let response = http::Response::from_parts(parts, entry.response.clone());
					return Ok(response.into());
				}
				BeforeRequest::Stale { request, matches } => {
					Some((request, matches, entry.policy.clone(), entry.response.clone()))
				}
			},
		};

		if let Some((validation, matches, snapshot, cached_body)) = stale {
			// Cache is stale, validate it using the headers the policy produced.
			*req.headers_mut() = validation.headers.clone();
			let response = next.run(req, extensions).await?;
			let after = snapshot.after_response(&validation, &response, SystemTime::now());
			let (updated, parts, body, modified) = match after {
				// Cached body is still valid.
				AfterResponse::NotModified(updated, parts) => (updated, parts, cached_body, false),
				// Cached body is not valid, read the replacement.
				AfterResponse::Modified(updated, parts) => {
					(updated, parts, response.bytes().await?, true)
				}
			};

			// Re-acquire the entry only for the synchronous write-back. If the
			// entry is no longer present, nothing is stored.
			if let Some(mut entry) = self.cache.get_mut(&url) {
				if matches {
					entry.policy = updated;
				}
				// NOTE(review): as before, a modified body is stored even when
				// `matches` is false and the policy is left untouched, which can
				// pair a body with a policy computed for a different response.
				// Confirm whether that is intended.
				if modified {
					entry.response = body.clone();
				}
			}
			let response = http::Response::from_parts(parts, body);
			return Ok(response.into());
		}

		// Make a `Parts` so that we have something to give the `CachePolicy`
		// constructor. Every value handed to the builder is copied from `req`,
		// so building is not expected to fail.
		#[allow(clippy::expect_used)]
		let (mut parts, _) = http::Request::builder()
			.uri(req.uri())
			.method(req.method().clone())
			.version(req.version())
			.body(())
			.expect("Builder used correctly")
			.into_parts();
		// TODO: Cloning the full header map can get expensive, find a way to avoid
		// doing this.
		parts.headers = req.headers().clone();
		let response = next.run(req, extensions).await?;
		let policy = CachePolicy::new_options(&parts, &response, SystemTime::now(), self.options);
		if policy.is_storable() {
			let response = reqwest_to_http(response).await?;
			let cache = CacheEntry::new(policy, response.body().clone());
			self.cache.alter(url, move |entry| match entry {
				None => Some(cache),
				Some(entry) => {
					// If a cache entry got added while fetching body, pick the newest.
					let time = SystemTime::now();
					if entry.policy.age(time) > cache.policy.age(time) {
						Some(cache)
					} else {
						Some(entry)
					}
				}
			});
			return Ok(response.into());
		}
		// Not storable: hand the response back untouched.
		Ok(response)
	}
}

/// Turns a [`reqwest::Response`] into an [`http::Response`] whose body has
/// been read completely into memory.
///
/// Status, version and headers are carried over; the headers are moved out of
/// `response` rather than cloned.
///
/// # Errors
///
/// Returns the [`reqwest::Error`] produced while reading the body.
async fn reqwest_to_http(
	mut response: reqwest::Response,
) -> reqwest::Result<http::Response<Bytes>> {
	let status = response.status();
	let version = response.version();
	// Leaves an empty header map behind in `response`.
	let headers = std::mem::take(response.headers_mut());
	let body = response.bytes().await?;

	let mut converted = http::Response::new(body);
	*converted.status_mut() = status;
	*converted.version_mut() = version;
	*converted.headers_mut() = headers;
	Ok(converted)
}