reqwest_middleware_cache/managers/
cacache.rs1use std::collections::HashMap;
2use std::convert::{TryFrom, TryInto};
3
4use crate::CacheManager;
5
6use anyhow::{anyhow, Result};
7use http::version::Version;
8use http_cache_semantics::CachePolicy;
9use reqwest::{
10 header::{HeaderName, HeaderValue},
11 Request, Response, ResponseBuilderExt,
12};
13use serde::{Deserialize, Serialize};
14use url::Url;
15
/// Cache manager that stores responses on disk using the `cacache` crate.
#[derive(Debug, Clone)]
pub struct CACacheManager {
    /// Cache directory handed to every `cacache` read/write/remove/clear call.
    pub path: String,
}

impl Default for CACacheManager {
    /// Points the cache at `./reqwest-cacache`, resolved relative to the
    /// process's current working directory.
    fn default() -> Self {
        let path = String::from("./reqwest-cacache");
        Self { path }
    }
}
30
/// Serializable mirror of `http::Version`, embedded in every cached response.
///
/// Converted from `http::Version` via `TryFrom` (which rejects versions with no
/// variant here) and back via `From`.
///
/// NOTE(review): cache entries are encoded with `bincode`, which as far as I know
/// records enum variants by index, so reordering or inserting variants would
/// likely make existing cache entries unreadable. Confirm before changing.
#[derive(Debug, Copy, Clone, Deserialize, Serialize)]
enum HttpVersion {
    #[serde(rename = "HTTP/0.9")]
    Http09,
    #[serde(rename = "HTTP/1.0")]
    Http10,
    #[serde(rename = "HTTP/1.1")]
    Http11,
    #[serde(rename = "HTTP/2.0")]
    H2,
    #[serde(rename = "HTTP/3.0")]
    H3,
}
45
46impl TryFrom<Version> for HttpVersion {
47 type Error = anyhow::Error;
48
49 fn try_from(value: Version) -> Result<Self> {
50 Ok(match value {
51 Version::HTTP_09 => HttpVersion::Http09,
52 Version::HTTP_10 => HttpVersion::Http10,
53 Version::HTTP_11 => HttpVersion::Http11,
54 Version::HTTP_2 => HttpVersion::H2,
55 Version::HTTP_3 => HttpVersion::H3,
56 _ => return Err(anyhow!("Unknown HTTP version")),
57 })
58 }
59}
60
61impl From<HttpVersion> for Version {
62 fn from(value: HttpVersion) -> Self {
63 match value {
64 HttpVersion::Http09 => Version::HTTP_09,
65 HttpVersion::Http10 => Version::HTTP_10,
66 HttpVersion::Http11 => Version::HTTP_11,
67 HttpVersion::H2 => Version::HTTP_2,
68 HttpVersion::H3 => Version::HTTP_3,
69 }
70 }
71}
72
/// One cache record: the captured response together with the `CachePolicy`
/// that `get` hands back alongside it. Serialized with `bincode` before being
/// written to `cacache`.
///
/// NOTE(review): bincode output depends on field order, so reordering fields
/// likely invalidates existing cache entries. Confirm before changing.
#[derive(Debug, Deserialize, Serialize)]
struct Store {
    response: StoredResponse,
    policy: CachePolicy,
}
78
/// Owned snapshot of a response, produced by `to_store` and turned back into a
/// `reqwest::Response` by `from_store`.
#[derive(Debug, Deserialize, Serialize)]
struct StoredResponse {
    // Full response body bytes.
    body: Vec<u8>,
    // Header name -> value. Keyed by name, so at most one string per name is stored.
    headers: HashMap<String, String>,
    // Numeric status code (e.g. 200).
    status: u16,
    url: Url,
    version: HttpVersion,
}
87
88async fn to_store(res: Response, policy: CachePolicy) -> Result<Store> {
89 let mut headers = HashMap::new();
90 for header in res.headers() {
91 headers.insert(header.0.as_str().to_owned(), header.1.to_str()?.to_owned());
92 }
93 let status = res.status().as_u16();
94 let url = res.url().clone();
95 let version = res.version().try_into()?;
96 let body: Vec<u8> = res.bytes().await?.to_vec();
97 Ok(Store {
98 response: StoredResponse {
99 body,
100 headers,
101 status,
102 url,
103 version,
104 },
105 policy,
106 })
107}
108
109fn from_store(store: &Store) -> Result<Response> {
110 let mut res = http::Response::builder()
111 .status(store.response.status)
112 .url(store.response.url.clone())
113 .version(store.response.version.into())
114 .body(store.response.body.clone())?;
115 for header in &store.response.headers {
116 res.headers_mut().insert(
117 HeaderName::from_lowercase(header.0.clone().as_str().to_lowercase().as_bytes())?,
118 HeaderValue::from_str(header.1.clone().as_str())?,
119 );
120 }
121 Ok(Response::from(res))
122}
123
124fn req_key(req: &Request) -> String {
125 format!("{}:{}", req.method(), req.url())
126}
127
128#[allow(dead_code)]
129impl CACacheManager {
130 pub async fn clear(&self) -> Result<()> {
132 cacache::clear(&self.path).await?;
133 Ok(())
134 }
135}
136
137#[async_trait::async_trait]
138impl CacheManager for CACacheManager {
139 async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>> {
140 let store: Store = match cacache::read(&self.path, &req_key(req)).await {
141 Ok(d) => bincode::deserialize(&d)?,
142 Err(_e) => {
143 return Ok(None);
144 }
145 };
146 Ok(Some((from_store(&store)?, store.policy)))
147 }
148
149 async fn put(&self, req: &Request, res: Response, policy: CachePolicy) -> Result<Response> {
151 let status = res.status();
152 let url = res.url().clone();
153 let version = res.version();
154 let headers = res.headers().clone();
155 let data = to_store(res, policy).await?;
156 let bytes = bincode::serialize(&data)?;
157 cacache::write(&self.path, &req_key(req), bytes).await?;
158 let mut ret_res = http::Response::builder()
159 .status(status)
160 .url(url)
161 .version(version)
162 .body(data.response.body)?;
163 for header in headers {
164 ret_res
165 .headers_mut()
166 .insert(header.0.unwrap(), header.1.clone());
167 }
168 *ret_res.version_mut() = version;
169 Ok(Response::from(ret_res))
170 }
171
172 async fn delete(&self, req: &Request) -> Result<()> {
173 cacache::remove(&self.path, &req_key(req)).await?;
174 Ok(())
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{Method, Response};
    use reqwest::Request;
    use std::str::FromStr;

    /// Round-trips a response through `put`/`get`, then checks that `delete`
    /// removes it. Uses the default cache directory and clears it at the end.
    #[tokio::test]
    async fn can_cache_response() -> Result<()> {
        let url = reqwest::Url::from_str("https://example.com")?;
        let res = reqwest::Response::from(Response::new("test"));
        let req = Request::new(Method::GET, url);
        let policy = CachePolicy::new(&req, &res);
        let manager = CACacheManager::default();
        manager.put(&req, res, policy).await?;
        // Fail with a clear message on a miss instead of comparing an empty
        // placeholder string against the expected body.
        let (cached, _policy) = manager
            .get(&req)
            .await?
            .expect("response should be cached after put");
        assert_eq!(cached.text().await?, "test");
        manager.delete(&req).await?;
        assert!(manager.get(&req).await?.is_none());
        manager.clear().await?;
        Ok(())
    }
}