salvo_cache/
lib.rs

1//! Cache middleware for the Salvo web framework.
2//!
3//! Cache middleware for Salvo designed to intercept responses and cache them.
4//! This middleware will cache the response's StatusCode, Headers, and Body.
5//!
6//! You can define your custom [`CacheIssuer`] to determine which responses should be cached,
7//! or you can use the default [`RequestIssuer`].
8//!
9//! The default cache store is [`MokaStore`], which is a wrapper of [`moka`].
10//! You can define your own cache store by implementing [`CacheStore`].
11//!
12//! Example: [cache-simple](https://github.com/salvo-rs/salvo/tree/main/examples/cache-simple)
13//! Read more: <https://salvo.rs>
14#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::collections::VecDeque;
20use std::error::Error as StdError;
21use std::hash::Hash;
22
23use bytes::Bytes;
24use salvo_core::handler::Skipper;
25use salvo_core::http::{HeaderMap, ResBody, StatusCode};
26use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
27
28mod skipper;
29pub use skipper::MethodSkipper;
30
31#[macro_use]
32mod cfg;
33
34cfg_feature! {
35    #![feature = "moka-store"]
36
37    pub mod moka_store;
38    pub use moka_store::{MokaStore};
39}
40
41/// Issuer
42pub trait CacheIssuer: Send + Sync + 'static {
43    /// The key is used to identify the rate limit.
44    type Key: Hash + Eq + Send + Sync + 'static;
45    /// Issue a new key for the request. If it returns `None`, the request will not be cached.
46    fn issue(
47        &self,
48        req: &mut Request,
49        depot: &Depot,
50    ) -> impl Future<Output = Option<Self::Key>> + Send;
51}
52impl<F, K> CacheIssuer for F
53where
54    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
55    K: Hash + Eq + Send + Sync + 'static,
56{
57    type Key = K;
58    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
59        (self)(req, depot)
60    }
61}
62
/// Builds cache keys from the request URI and method.
///
/// With every part enabled (the default), a request produces a key shaped like
/// `https://example.com/path?query|GET`; scheme, authority and query are only included when
/// the URI actually carries them. Use the builder methods to leave parts out.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    /// Include the URI scheme, followed by `://`.
    use_scheme: bool,
    /// Include the URI authority (host and optional port).
    use_authority: bool,
    /// Include the URI path.
    use_path: bool,
    /// Include the URI query, prefixed with `?`.
    use_query: bool,
    /// Append the request method, prefixed with `|`.
    use_method: bool,
}
72impl Default for RequestIssuer {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77impl RequestIssuer {
78    /// Create a new `RequestIssuer`.
79    pub fn new() -> Self {
80        Self {
81            use_scheme: true,
82            use_authority: true,
83            use_path: true,
84            use_query: true,
85            use_method: true,
86        }
87    }
88    /// Whether to use the request's URI scheme when generating the key.
89    pub fn use_scheme(mut self, value: bool) -> Self {
90        self.use_scheme = value;
91        self
92    }
93    /// Whether to use the request's URI authority when generating the key.
94    pub fn use_authority(mut self, value: bool) -> Self {
95        self.use_authority = value;
96        self
97    }
98    /// Whether to use the request's URI path when generating the key.
99    pub fn use_path(mut self, value: bool) -> Self {
100        self.use_path = value;
101        self
102    }
103    /// Whether to use the request's URI query when generating the key.
104    pub fn use_query(mut self, value: bool) -> Self {
105        self.use_query = value;
106        self
107    }
108    /// Whether to use the request method when generating the key.
109    pub fn use_method(mut self, value: bool) -> Self {
110        self.use_method = value;
111        self
112    }
113}
114
115impl CacheIssuer for RequestIssuer {
116    type Key = String;
117    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
118        let mut key = String::new();
119        if self.use_scheme {
120            if let Some(scheme) = req.uri().scheme_str() {
121                key.push_str(scheme);
122                key.push_str("://");
123            }
124        }
125        if self.use_authority {
126            if let Some(authority) = req.uri().authority() {
127                key.push_str(authority.as_str());
128            }
129        }
130        if self.use_path {
131            key.push_str(req.uri().path());
132        }
133        if self.use_query {
134            if let Some(query) = req.uri().query() {
135                key.push('?');
136                key.push_str(query);
137            }
138        }
139        if self.use_method {
140            key.push('|');
141            key.push_str(req.method().as_str());
142        }
143        Some(key)
144    }
145}
146
147/// Store cache.
148pub trait CacheStore: Send + Sync + 'static {
149    /// Error type for CacheStore.
150    type Error: StdError + Sync + Send + 'static;
151    /// Key
152    type Key: Hash + Eq + Send + Clone + 'static;
153    /// Get the cache item from the store.
154    fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
155    where
156        Self::Key: Borrow<Q>,
157        Q: Hash + Eq + Sync;
158    /// Save the cache item to the store.
159    fn save_entry(
160        &self,
161        key: Self::Key,
162        data: CachedEntry,
163    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
164}
165
166/// `CachedBody` is used to save the response body to `CacheStore`.
167///
168/// [`ResBody`] has a Stream type, which is not `Send + Sync`, so we need to convert it to `CachedBody`.
169/// If the response's body is [`ResBody::Stream`], it will not be cached.
170#[derive(Clone, Debug)]
171#[non_exhaustive]
172pub enum CachedBody {
173    /// No body.
174    None,
175    /// Single bytes body.
176    Once(Bytes),
177    /// Chunks body.
178    Chunks(VecDeque<Bytes>),
179}
180impl TryFrom<&ResBody> for CachedBody {
181    type Error = Error;
182    fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
183        match body {
184            ResBody::None => Ok(Self::None),
185            ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
186            ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
187            _ => Err(Error::other("unsupported body type")),
188        }
189    }
190}
191impl From<CachedBody> for ResBody {
192    fn from(body: CachedBody) -> Self {
193        match body {
194            CachedBody::None => Self::None,
195            CachedBody::Once(bytes) => Self::Once(bytes),
196            CachedBody::Chunks(chunks) => Self::Chunks(chunks),
197        }
198    }
199}
200
201/// Cached entry which will be stored in the cache store.
202#[derive(Clone, Debug)]
203#[non_exhaustive]
204pub struct CachedEntry {
205    /// Response status.
206    pub status: Option<StatusCode>,
207    /// Response headers.
208    pub headers: HeaderMap,
209    /// Response body.
210    ///
211    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
212    pub body: CachedBody,
213}
214impl CachedEntry {
215    /// Create a new `CachedEntry`.
216    pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
217        Self {
218            status,
219            headers,
220            body,
221        }
222    }
223
224    /// Get the response status.
225    pub fn status(&self) -> Option<StatusCode> {
226        self.status
227    }
228
229    /// Get the response headers.
230    pub fn headers(&self) -> &HeaderMap {
231        &self.headers
232    }
233
234    /// Get the response body.
235    ///
236    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
237    pub fn body(&self) -> &CachedBody {
238        &self.body
239    }
240}
241
242/// Cache middleware.
243///
244/// # Example
245///
246/// ```
247/// use std::time::Duration;
248///
249/// use salvo_core::Router;
250/// use salvo_cache::{Cache, MokaStore, RequestIssuer};
251///
252/// let cache = Cache::new(
253///     MokaStore::builder().time_to_live(Duration::from_secs(60)).build(),
254///     RequestIssuer::default(),
255/// );
256/// let router = Router::new().hoop(cache);
257/// ```
258#[non_exhaustive]
259pub struct Cache<S, I> {
260    /// Cache store.
261    pub store: S,
262    /// Cache issuer.
263    pub issuer: I,
264    /// Skipper.
265    pub skipper: Box<dyn Skipper>,
266}
267
268impl<S, I> Cache<S, I> {
269    /// Create a new `Cache`.
270    #[inline]
271    pub fn new(store: S, issuer: I) -> Self {
272        let skipper = MethodSkipper::new().skip_all().skip_get(false);
273        Cache {
274            store,
275            issuer,
276            skipper: Box::new(skipper),
277        }
278    }
279    /// Sets skipper and returns a new `Cache`.
280    #[inline]
281    pub fn skipper(mut self, skipper: impl Skipper) -> Self {
282        self.skipper = Box::new(skipper);
283        self
284    }
285}
286
287#[async_trait]
288impl<S, I> Handler for Cache<S, I>
289where
290    S: CacheStore<Key = I::Key>,
291    I: CacheIssuer,
292{
293    async fn handle(
294        &self,
295        req: &mut Request,
296        depot: &mut Depot,
297        res: &mut Response,
298        ctrl: &mut FlowCtrl,
299    ) {
300        if self.skipper.skipped(req, depot) {
301            return;
302        }
303        let key = match self.issuer.issue(req, depot).await {
304            Some(key) => key,
305            None => {
306                return;
307            }
308        };
309        let cache = match self.store.load_entry(&key).await {
310            Some(cache) => cache,
311            None => {
312                ctrl.call_next(req, depot, res).await;
313                if !res.body.is_stream() && !res.body.is_error() {
314                    let headers = res.headers().clone();
315                    let body = TryInto::<CachedBody>::try_into(&res.body);
316                    match body {
317                        Ok(body) => {
318                            let cached_data = CachedEntry::new(res.status_code, headers, body);
319                            if let Err(e) = self.store.save_entry(key, cached_data).await {
320                                tracing::error!(error = ?e, "cache failed");
321                            }
322                        }
323                        Err(e) => tracing::error!(error = ?e, "cache failed"),
324                    }
325                }
326                return;
327            }
328        };
329        let CachedEntry {
330            status,
331            headers,
332            body,
333        } = cache;
334        if let Some(status) = status {
335            res.status_code(status);
336        }
337        *res.headers_mut() = headers;
338        *res.body_mut() = body.into();
339        ctrl.skip_rest();
340    }
341}
342
// `MokaStore` only exists with the `moka-store` feature, so the tests are gated on it too.
#[cfg(all(test, feature = "moka-store"))]
mod tests {
    use super::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    /// Returns a body that differs on every call, so cache hits are detectable.
    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(5))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content0 = res.take_string().await.unwrap();

        // A second GET within the TTL must be served from the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // POST is skipped by the default skipper, so it always reaches the handler.
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content_post = res.take_string().await.unwrap();
        assert_ne!(content0, content_post);

        // After the TTL has elapsed, a GET must miss the cache and get fresh content.
        // (The original test used POST here, which bypasses the cache and so never
        // exercised expiry.)
        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content2 = res.take_string().await.unwrap();
        assert_ne!(content0, content2);
    }
}