reqwest-cache 0.1.2

reqwest-middleware based in-memory HTTP caching middleware
Documentation
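//! In-memory HTTP caching middleware for `reqwest-middleware`.
//!
//! [`CacheMiddleware`] keeps response bodies in memory, keyed by request URL
//! (ignoring any `#fragment`). It uses the HTTP caching rules implemented by the
//! `http-cache-semantics` crate to decide when a stored response can be reused
//! as-is, when it must be revalidated with the server, and when a response must
//! not be stored at all. Caching behavior can be tuned with [`CacheOptions`].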

#![deny(trivial_casts, trivial_numeric_casts, unused_extern_crates, unused_qualifications)]
#![warn(
	missing_debug_implementations,
	missing_docs,
	unused_import_braces,
	dead_code,
	clippy::unwrap_used,
	clippy::expect_used,
	clippy::missing_docs_in_private_items
)]

use std::{sync::Arc, time::SystemTime};

use bytes::Bytes;
use chashmap_async::CHashMap;
pub use http_cache_semantics::CacheOptions;
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy, RequestLike};
use reqwest::Url;
use reqwest_middleware::Middleware;

/// A single stored response: its body plus the policy that governs reuse.
#[derive(Debug)]
struct CacheEntry {
	/// Freshness/revalidation rules derived from the request and response headers.
	policy: CachePolicy,
	/// Buffered body bytes of the stored response.
	response: Bytes,
}

impl CacheEntry {
	/// Constructs a new `CacheEntry`.
	pub fn new(policy: CachePolicy, response: Bytes) -> Self {
		Self { policy, response }
	}
}

/// Middleware that caches responses based on [HTTP cache headers].
///
/// Responses are kept in memory, keyed by request URL. Cloning the middleware is
/// cheap, and all clones share the same underlying cache. One cache can
/// therefore back several clients.
///
/// [HTTP cache headers]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching
#[derive(Clone, Default)]
pub struct CacheMiddleware {
	/// The cached responses, shared between all clones of this middleware.
	cache: Arc<CHashMap<Url, CacheEntry>>,
	/// Configuration of caching behavior
	options: CacheOptions,
}

impl CacheMiddleware {
	/// Constructs a new `CacheMiddleware`
	pub fn new() -> Self {
		Self::default()
	}

	/// Constructs a new `CacheMiddleware` with the given caching options.
	pub fn with_options(options: CacheOptions) -> Self {
		Self { cache: Arc::new(CHashMap::new()), options }
	}
}

impl std::fmt::Debug for CacheMiddleware {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		// Summarize the cache by its size rather than dumping every stored entry.
		let entries = format!("<{} entries>", self.cache.len());
		let mut dbg = f.debug_struct("CacheMiddleware");
		dbg.field("cache", &entries);
		dbg.field("options", &self.options);
		dbg.finish()
	}
}

#[async_trait::async_trait]
impl Middleware for CacheMiddleware {
	async fn handle(
		&self,
		mut req: reqwest::Request,
		extensions: &mut task_local_extensions::Extensions,
		next: reqwest_middleware::Next<'_>,
	) -> reqwest_middleware::Result<reqwest::Response> {
		// Strip the fragment part (the stuff after #) of the URL since is exclusively
		// client-side and has no bearing on caching
		let mut url = req.url().clone();
		url.set_fragment(None);

		if let Some(mut cache) = self.cache.get_mut(&url).await {
			// Check freshness of the cached response
			let before = cache.policy.before_request(&req, SystemTime::now());
			match before {
				BeforeRequest::Fresh(parts) => {
					// Cache is fresh, no need to hit the server
					let response = http::Response::from_parts(parts, cache.response.clone());
					return Ok(response.into());
				}
				BeforeRequest::Stale { request: parts, matches } => {
					// Cache is stale, validate it.
					*req.headers_mut() = parts.headers.clone();
					let response = next.run(req, extensions).await?;
					let after = cache.policy.after_response(&parts, &response, SystemTime::now());
					match after {
						AfterResponse::NotModified(policy, parts) => {
							// Cached body is still valid.
							if matches {
								cache.policy = policy;
							}
							let response =
								http::Response::from_parts(parts, cache.response.clone());
							return Ok(response.into());
						}
						AfterResponse::Modified(policy, parts) => {
							// Cached body is not valid, update it.
							if matches {
								cache.policy = policy;
							}
							let body = response.bytes().await?;
							cache.response = body;
							let response =
								http::Response::from_parts(parts, cache.response.clone());
							return Ok(response.into());
						}
					}
				}
			}
		}
		// Make a `Parts` so that we have something to give the `CachePolicy`
		// constructor
		#[allow(clippy::expect_used)]
		let (mut parts, _) = http::Request::builder()
			.uri(req.uri())
			.method(req.method().clone())
			.version(req.version())
			.body(())
			.expect("Builder used correctly")
			.into_parts();
		// TODO: Cloning the full header map can get expensive, find a way to avoid
		// doing this.
		parts.headers = req.headers().clone();
		let response = next.run(req, extensions).await?;
		let policy = CachePolicy::new_options(&parts, &response, SystemTime::now(), self.options);
		if policy.is_storable() {
			let response = reqwest_to_http(response).await?;
			let cache = CacheEntry::new(policy, response.body().clone());
			self.cache
				.alter(url, |entry| async move {
					match entry {
						None => Some(cache),
						Some(entry) => {
							// If a cache entry got added while fetching body, pick the newest.
							let time = SystemTime::now();
							if entry.policy.age(time) > cache.policy.age(time) {
								Some(cache)
							} else {
								Some(entry)
							}
						}
					}
				})
				.await;
			return Ok(response.into());
		}
		Ok(response)
	}
}

/// Buffers a [`reqwest::Response`] into an [`http::Response`] whose body is held
/// fully in memory.
///
/// Only the status, HTTP version, headers and body are carried over. The headers
/// are moved out of `response` rather than copied.
///
/// # Errors
///
/// Returns the error from reading the response body, if any.
async fn reqwest_to_http(
	mut response: reqwest::Response,
) -> reqwest::Result<http::Response<Bytes>> {
	let status = response.status();
	let version = response.version();
	let headers = std::mem::take(response.headers_mut());
	let body = response.bytes().await?;

	let mut converted = http::Response::new(body);
	*converted.status_mut() = status;
	*converted.version_mut() = version;
	*converted.headers_mut() = headers;
	Ok(converted)
}