reqwest_middleware_cache/managers/cacache.rs

1use std::collections::HashMap;
2use std::convert::{TryFrom, TryInto};
3
4use crate::CacheManager;
5
6use anyhow::{anyhow, Result};
7use http::version::Version;
8use http_cache_semantics::CachePolicy;
9use reqwest::{
10    header::{HeaderName, HeaderValue},
11    Request, Response, ResponseBuilderExt,
12};
13use serde::{Deserialize, Serialize};
14use url::Url;
15
/// Disk-backed [`CacheManager`] built on
/// [`cacache`](https://github.com/zkat/cacache-rs).
#[derive(Clone, Debug)]
pub struct CACacheManager {
    /// Root directory that holds the cache on disk.
    pub path: String,
}

23impl Default for CACacheManager {
24    fn default() -> Self {
25        CACacheManager {
26            path: "./reqwest-cacache".into(),
27        }
28    }
29}
30
31// HTTP version enum in the http crate does not support serde, hence the modified copy.
32#[derive(Debug, Copy, Clone, Deserialize, Serialize)]
33enum HttpVersion {
34    #[serde(rename = "HTTP/0.9")]
35    Http09,
36    #[serde(rename = "HTTP/1.0")]
37    Http10,
38    #[serde(rename = "HTTP/1.1")]
39    Http11,
40    #[serde(rename = "HTTP/2.0")]
41    H2,
42    #[serde(rename = "HTTP/3.0")]
43    H3,
44}
45
46impl TryFrom<Version> for HttpVersion {
47    type Error = anyhow::Error;
48
49    fn try_from(value: Version) -> Result<Self> {
50        Ok(match value {
51            Version::HTTP_09 => HttpVersion::Http09,
52            Version::HTTP_10 => HttpVersion::Http10,
53            Version::HTTP_11 => HttpVersion::Http11,
54            Version::HTTP_2 => HttpVersion::H2,
55            Version::HTTP_3 => HttpVersion::H3,
56            _ => return Err(anyhow!("Unknown HTTP version")),
57        })
58    }
59}
60
61impl From<HttpVersion> for Version {
62    fn from(value: HttpVersion) -> Self {
63        match value {
64            HttpVersion::Http09 => Version::HTTP_09,
65            HttpVersion::Http10 => Version::HTTP_10,
66            HttpVersion::Http11 => Version::HTTP_11,
67            HttpVersion::H2 => Version::HTTP_2,
68            HttpVersion::H3 => Version::HTTP_3,
69        }
70    }
71}
72
73#[derive(Debug, Deserialize, Serialize)]
74struct Store {
75    response: StoredResponse,
76    policy: CachePolicy,
77}
78
79#[derive(Debug, Deserialize, Serialize)]
80struct StoredResponse {
81    body: Vec<u8>,
82    headers: HashMap<String, String>,
83    status: u16,
84    url: Url,
85    version: HttpVersion,
86}
87
88async fn to_store(res: Response, policy: CachePolicy) -> Result<Store> {
89    let mut headers = HashMap::new();
90    for header in res.headers() {
91        headers.insert(header.0.as_str().to_owned(), header.1.to_str()?.to_owned());
92    }
93    let status = res.status().as_u16();
94    let url = res.url().clone();
95    let version = res.version().try_into()?;
96    let body: Vec<u8> = res.bytes().await?.to_vec();
97    Ok(Store {
98        response: StoredResponse {
99            body,
100            headers,
101            status,
102            url,
103            version,
104        },
105        policy,
106    })
107}
108
109fn from_store(store: &Store) -> Result<Response> {
110    let mut res = http::Response::builder()
111        .status(store.response.status)
112        .url(store.response.url.clone())
113        .version(store.response.version.into())
114        .body(store.response.body.clone())?;
115    for header in &store.response.headers {
116        res.headers_mut().insert(
117            HeaderName::from_lowercase(header.0.clone().as_str().to_lowercase().as_bytes())?,
118            HeaderValue::from_str(header.1.clone().as_str())?,
119        );
120    }
121    Ok(Response::from(res))
122}
123
124fn req_key(req: &Request) -> String {
125    format!("{}:{}", req.method(), req.url())
126}
127
128#[allow(dead_code)]
129impl CACacheManager {
130    /// Clears out the entire cache.
131    pub async fn clear(&self) -> Result<()> {
132        cacache::clear(&self.path).await?;
133        Ok(())
134    }
135}
136
137#[async_trait::async_trait]
138impl CacheManager for CACacheManager {
139    async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>> {
140        let store: Store = match cacache::read(&self.path, &req_key(req)).await {
141            Ok(d) => bincode::deserialize(&d)?,
142            Err(_e) => {
143                return Ok(None);
144            }
145        };
146        Ok(Some((from_store(&store)?, store.policy)))
147    }
148
149    // TODO - This needs some reviewing.
150    async fn put(&self, req: &Request, res: Response, policy: CachePolicy) -> Result<Response> {
151        let status = res.status();
152        let url = res.url().clone();
153        let version = res.version();
154        let headers = res.headers().clone();
155        let data = to_store(res, policy).await?;
156        let bytes = bincode::serialize(&data)?;
157        cacache::write(&self.path, &req_key(req), bytes).await?;
158        let mut ret_res = http::Response::builder()
159            .status(status)
160            .url(url)
161            .version(version)
162            .body(data.response.body)?;
163        for header in headers {
164            ret_res
165                .headers_mut()
166                .insert(header.0.unwrap(), header.1.clone());
167        }
168        *ret_res.version_mut() = version;
169        Ok(Response::from(ret_res))
170    }
171
172    async fn delete(&self, req: &Request) -> Result<()> {
173        cacache::remove(&self.path, &req_key(req)).await?;
174        Ok(())
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{Method, Response};
    use reqwest::Request;
    use std::str::FromStr;

    /// Manager rooted in a dedicated directory under the system temp dir. This keeps the
    /// test from writing to, or `clear`ing, the default `./reqwest-cacache` directory in
    /// the current working directory.
    fn test_manager() -> CACacheManager {
        let dir = std::env::temp_dir().join("reqwest-middleware-cache-cacache-test");
        CACacheManager {
            path: dir.to_string_lossy().into_owned(),
        }
    }

    /// Round-trips a response through `put`/`get`, then checks that `delete` removes it.
    #[tokio::test]
    async fn can_cache_response() -> Result<()> {
        let url = reqwest::Url::from_str("https://example.com")?;
        let res = reqwest::Response::from(Response::new("test"));
        let req = Request::new(Method::GET, url);
        let policy = CachePolicy::new(&req, &res);
        let manager = test_manager();
        manager.put(&req, res, policy).await?;

        // A miss should fail with a clear message rather than surface as a body mismatch.
        let (cached, _policy) = manager
            .get(&req)
            .await?
            .expect("response should be cached after put");
        assert_eq!(cached.text().await?, "test");

        manager.delete(&req).await?;
        assert!(manager.get(&req).await?.is_none());
        manager.clear().await?;
        Ok(())
    }
}