1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
79#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
80#![cfg_attr(docsrs, feature(doc_cfg))]
81
82use std::borrow::Borrow;
83use std::collections::VecDeque;
84use std::error::Error as StdError;
85use std::fmt::{self, Debug, Formatter};
86use std::hash::Hash;
87
88use bytes::Bytes;
89use salvo_core::handler::Skipper;
90use salvo_core::http::{HeaderMap, ResBody, StatusCode};
91use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
92
93mod skipper;
94pub use skipper::MethodSkipper;
95
96#[macro_use]
97mod cfg;
98
99cfg_feature! {
100 #![feature = "moka-store"]
101
102 pub mod moka_store;
103 pub use moka_store::{MokaStore};
104}
105
106pub trait CacheIssuer: Send + Sync + 'static {
108 type Key: Hash + Eq + Send + Sync + 'static;
110 fn issue(
112 &self,
113 req: &mut Request,
114 depot: &Depot,
115 ) -> impl Future<Output = Option<Self::Key>> + Send;
116}
117impl<F, K> CacheIssuer for F
118where
119 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
120 K: Hash + Eq + Send + Sync + 'static,
121{
122 type Key = K;
123 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
124 (self)(req, depot)
125 }
126}
127
/// Cache issuer that derives the key from parts of the request URI and its method.
///
/// Every component is included by default; each builder method toggles one of them.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    use_scheme: bool,
    use_authority: bool,
    use_path: bool,
    use_query: bool,
    use_method: bool,
}

impl Default for RequestIssuer {
    /// Equivalent to [`RequestIssuer::new`]: all components enabled.
    fn default() -> Self {
        Self::new()
    }
}

impl RequestIssuer {
    /// Creates an issuer that includes scheme, authority, path, query and method.
    #[must_use]
    pub fn new() -> Self {
        Self {
            use_scheme: true,
            use_authority: true,
            use_path: true,
            use_query: true,
            use_method: true,
        }
    }

    /// Includes (or excludes) the URI scheme in the key.
    #[must_use]
    pub fn use_scheme(self, value: bool) -> Self {
        Self { use_scheme: value, ..self }
    }

    /// Includes (or excludes) the URI authority in the key.
    #[must_use]
    pub fn use_authority(self, value: bool) -> Self {
        Self { use_authority: value, ..self }
    }

    /// Includes (or excludes) the URI path in the key.
    #[must_use]
    pub fn use_path(self, value: bool) -> Self {
        Self { use_path: value, ..self }
    }

    /// Includes (or excludes) the URI query string in the key.
    #[must_use]
    pub fn use_query(self, value: bool) -> Self {
        Self { use_query: value, ..self }
    }

    /// Includes (or excludes) the HTTP method in the key.
    #[must_use]
    pub fn use_method(self, value: bool) -> Self {
        Self { use_method: value, ..self }
    }
}
185
186impl CacheIssuer for RequestIssuer {
187 type Key = String;
188 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
189 let mut key = String::new();
190 if self.use_scheme
191 && let Some(scheme) = req.uri().scheme_str()
192 {
193 key.push_str(scheme);
194 key.push_str("://");
195 }
196 if self.use_authority
197 && let Some(authority) = req.uri().authority()
198 {
199 key.push_str(authority.as_str());
200 }
201 if self.use_path {
202 key.push_str(req.uri().path());
203 }
204 if self.use_query
205 && let Some(query) = req.uri().query()
206 {
207 key.push('?');
208 key.push_str(query);
209 }
210 if self.use_method {
211 key.push('|');
212 key.push_str(req.method().as_str());
213 }
214 Some(key)
215 }
216}
217
218pub trait CacheStore: Send + Sync + 'static {
220 type Error: StdError + Sync + Send + 'static;
222 type Key: Hash + Eq + Send + Clone + 'static;
224 fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
226 where
227 Self::Key: Borrow<Q>,
228 Q: Hash + Eq + Sync;
229 fn save_entry(
231 &self,
232 key: Self::Key,
233 data: CachedEntry,
234 ) -> impl Future<Output = Result<(), Self::Error>> + Send;
235}
236
237#[derive(Clone, Debug, PartialEq)]
242#[non_exhaustive]
243pub enum CachedBody {
244 None,
246 Once(Bytes),
248 Chunks(VecDeque<Bytes>),
250}
251impl TryFrom<&ResBody> for CachedBody {
252 type Error = Error;
253 fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
254 match body {
255 ResBody::None => Ok(Self::None),
256 ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
257 ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
258 _ => Err(Error::other("unsupported body type")),
259 }
260 }
261}
262impl From<CachedBody> for ResBody {
263 fn from(body: CachedBody) -> Self {
264 match body {
265 CachedBody::None => Self::None,
266 CachedBody::Once(bytes) => Self::Once(bytes),
267 CachedBody::Chunks(chunks) => Self::Chunks(chunks),
268 }
269 }
270}
271
272#[derive(Clone, Debug)]
274#[non_exhaustive]
275pub struct CachedEntry {
276 pub status: Option<StatusCode>,
278 pub headers: HeaderMap,
280 pub body: CachedBody,
284}
285impl CachedEntry {
286 pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
288 Self {
289 status,
290 headers,
291 body,
292 }
293 }
294
295 pub fn status(&self) -> Option<StatusCode> {
297 self.status
298 }
299
300 pub fn headers(&self) -> &HeaderMap {
302 &self.headers
303 }
304
305 pub fn body(&self) -> &CachedBody {
309 &self.body
310 }
311}
312
313#[non_exhaustive]
332pub struct Cache<S, I> {
333 pub store: S,
335 pub issuer: I,
337 pub skipper: Box<dyn Skipper>,
339}
340impl<S, I> Debug for Cache<S, I>
341where
342 S: Debug,
343 I: Debug,
344{
345 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
346 f.debug_struct("Cache")
347 .field("store", &self.store)
348 .field("issuer", &self.issuer)
349 .finish()
350 }
351}
352
353impl<S, I> Cache<S, I> {
354 #[inline]
356 #[must_use]
357 pub fn new(store: S, issuer: I) -> Self {
358 let skipper = MethodSkipper::new().skip_all().skip_get(false);
359 Self {
360 store,
361 issuer,
362 skipper: Box::new(skipper),
363 }
364 }
365 #[inline]
367 #[must_use]
368 pub fn skipper(mut self, skipper: impl Skipper) -> Self {
369 self.skipper = Box::new(skipper);
370 self
371 }
372}
373
374#[async_trait]
375impl<S, I> Handler for Cache<S, I>
376where
377 S: CacheStore<Key = I::Key>,
378 I: CacheIssuer,
379{
380 async fn handle(
381 &self,
382 req: &mut Request,
383 depot: &mut Depot,
384 res: &mut Response,
385 ctrl: &mut FlowCtrl,
386 ) {
387 if self.skipper.skipped(req, depot) {
388 return;
389 }
390 let Some(key) = self.issuer.issue(req, depot).await else {
391 return;
392 };
393 let Some(cache) = self.store.load_entry(&key).await else {
394 ctrl.call_next(req, depot, res).await;
395 if !res.body.is_stream() && !res.body.is_error() {
396 let headers = res.headers().clone();
397 let body = TryInto::<CachedBody>::try_into(&res.body);
398 match body {
399 Ok(body) => {
400 let cached_data = CachedEntry::new(res.status_code, headers, body);
401 if let Err(e) = self.store.save_entry(key, cached_data).await {
402 tracing::error!(error = ?e, "cache failed");
403 }
404 }
405 Err(e) => tracing::error!(error = ?e, "cache failed"),
406 }
407 }
408 return;
409 };
410 let CachedEntry {
411 status,
412 headers,
413 body,
414 } = cache;
415 if let Some(status) = status {
416 res.status_code(status);
417 }
418 *res.headers_mut() = headers;
419 *res.body_mut() = body.into();
420 ctrl.skip_rest();
421 }
422}
423
// `MokaStore` only exists with the `moka-store` feature, so these tests need it too.
#[cfg(all(test, feature = "moka-store"))]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    use super::*;

    /// Produces a different body on every call, so cache hits are observable.
    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(5))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        // First GET misses: the handler runs and its response is stored.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content0 = res.take_string().await.unwrap();

        // Second GET within the TTL is served from the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // POST is skipped by the default skipper, so it always reaches the handler...
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content_post = res.take_string().await.unwrap();
        assert_ne!(content0, content_post);

        // ...and leaves the cached GET entry untouched.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.take_string().await.unwrap(), content0);

        // After the TTL has elapsed, a GET must miss again and get fresh content.
        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content2 = res.take_string().await.unwrap();
        assert_ne!(content0, content2);
    }
}