1use std::convert::Infallible;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use axum::body::Body;
27use axum::http::{Method, StatusCode};
28use http::Request;
29use http_body_util::BodyExt;
30use tower::{Layer, Service};
31
32use super::Cache;
33
34#[derive(Clone, serde::Deserialize, serde::Serialize)]
36struct CachedResponse {
37 status: u16,
38 headers: Vec<CachedHeader>,
39 body: Vec<u8>,
40}
41
42#[derive(Clone, serde::Deserialize, serde::Serialize)]
43struct CachedHeader {
44 name: String,
45 value: Vec<u8>,
46}
47
48fn cached_response_from_parts(
49 parts: &http::response::Parts,
50 body: &bytes::Bytes,
51) -> CachedResponse {
52 let headers = parts
53 .headers
54 .iter()
55 .map(|(name, value)| CachedHeader {
56 name: name.as_str().to_owned(),
57 value: value.as_bytes().to_vec(),
58 })
59 .collect();
60
61 CachedResponse {
62 status: parts.status.as_u16(),
63 headers,
64 body: body.to_vec(),
65 }
66}
67
68fn cached_response_into_response(cached: CachedResponse) -> Option<axum::response::Response> {
69 let status = StatusCode::from_u16(cached.status).ok()?;
70 let mut builder = axum::response::Response::builder().status(status);
71 let headers = builder.headers_mut()?;
72
73 for cached_header in cached.headers {
74 let name = http::HeaderName::from_bytes(cached_header.name.as_bytes()).ok()?;
75 let value = http::HeaderValue::from_bytes(&cached_header.value).ok()?;
76 headers.append(name, value);
77 }
78
79 builder.body(Body::from(cached.body)).ok()
80}
81
82#[derive(Clone)]
92pub struct CacheResponseLayer {
93 store: Arc<dyn Cache>,
94}
95
96impl CacheResponseLayer {
97 pub fn from_cache(store: impl Cache + 'static) -> Self {
99 Self {
100 store: Arc::new(store),
101 }
102 }
103
104 pub fn from_shared(store: Arc<dyn Cache>) -> Self {
106 Self { store }
107 }
108
109 #[must_use]
114 pub fn from_app(state: &crate::state::AppState) -> Option<Self> {
115 state.cache().map(Self::from_shared)
116 }
117}
118
119impl<S> Layer<S> for CacheResponseLayer {
120 type Service = CacheResponseService<S>;
121
122 fn layer(&self, inner: S) -> Self::Service {
123 CacheResponseService {
124 inner,
125 store: self.store.clone(),
126 }
127 }
128}
129
130#[derive(Clone)]
132pub struct CacheResponseService<S> {
133 inner: S,
134 store: Arc<dyn Cache>,
135}
136
137impl<S> Service<Request<Body>> for CacheResponseService<S>
138where
139 S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>
140 + Clone
141 + Send
142 + 'static,
143 S::Future: Send,
144{
145 type Response = axum::response::Response;
146 type Error = Infallible;
147 type Future = std::pin::Pin<
148 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
149 >;
150
151 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
152 self.inner.poll_ready(cx)
153 }
154
155 fn call(&mut self, req: Request<Body>) -> Self::Future {
156 if req.method() != Method::GET {
158 return Box::pin(self.inner.call(req));
159 }
160
161 let mut buf = [0u8; 512];
166 let cache_key_str = {
167 let mut cursor = &mut buf[..];
168 if std::io::Write::write_fmt(&mut cursor, format_args!("http:{}", req.uri())).is_ok() {
169 let len = 512 - cursor.len();
170 std::str::from_utf8(&buf[..len]).unwrap_or_default()
171 } else {
172 ""
173 }
174 };
175
176 let store = self.store.clone();
177
178 let cache_hit = if cache_key_str.is_empty() {
179 super::get_cached::<CachedResponse>(store.as_ref(), &format!("http:{}", req.uri()))
181 } else {
182 super::get_cached::<CachedResponse>(store.as_ref(), cache_key_str)
183 };
184
185 if let Some(cached) = cache_hit
187 && let Some(resp) = cached_response_into_response(cached)
188 {
189 return Box::pin(async move { Ok(resp) });
190 }
191
192 let mut inner = self.inner.clone();
194 let cache_key = if cache_key_str.is_empty() {
195 format!("http:{}", req.uri())
196 } else {
197 cache_key_str.to_owned()
198 };
199
200 Box::pin(async move {
201 let response = inner.call(req).await?;
202
203 if response.status() != StatusCode::OK {
205 return Ok(response);
206 }
207
208 let (parts, body) = response.into_parts();
209
210 let Ok(collected) = body.collect().await else {
212 let resp = axum::response::Response::builder()
213 .status(StatusCode::INTERNAL_SERVER_ERROR)
214 .body(Body::empty())
215 .expect("infallible response builder");
216 return Ok(resp);
217 };
218 let body_bytes = collected.to_bytes();
219
220 let cached = cached_response_from_parts(&parts, &body_bytes);
222 super::insert_cached(store.as_ref(), &cache_key, cached, None);
223
224 let response = axum::response::Response::from_parts(parts, Body::from(body_bytes));
226 Ok(response)
227 })
228 }
229}
230
#[cfg(all(test, feature = "cache-moka"))]
mod tests {
    use super::*;
    use crate::cache::RawCacheBytes;
    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::{ServiceBuilder, ServiceExt};

    /// Test backend that keeps only entries written through `insert_raw_bytes`.
    ///
    /// `insert_value` is deliberately a no-op, so a cache hit can only come from
    /// the raw-byte path.
    #[derive(Default)]
    struct RawOnlyCache {
        entries: Mutex<HashMap<String, Vec<u8>>>,
    }

    impl RawOnlyCache {
        /// Locks the entry map, panicking if a previous holder panicked.
        fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, Vec<u8>>> {
            self.entries.lock().expect("raw cache lock poisoned")
        }
    }

    impl Cache for RawOnlyCache {
        fn get_value(&self, key: &str) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
            let bytes = self.lock().get(key).cloned()?;
            Some(Arc::new(RawCacheBytes(bytes)) as Arc<dyn std::any::Any + Send + Sync>)
        }

        fn insert_value(&self, _key: &str, _value: Arc<dyn std::any::Any + Send + Sync>) {}

        fn insert_raw_bytes(&self, key: &str, bytes: Vec<u8>, _ttl: Option<std::time::Duration>) {
            self.lock().insert(key.to_owned(), bytes);
        }

        fn invalidate(&self, key: &str) {
            self.lock().remove(key);
        }

        fn clear(&self) {
            self.lock().clear();
        }
    }

    /// Inner service that counts its invocations and always answers `200 OK` with `body`.
    fn counting_service(
        counter: Arc<AtomicUsize>,
        body: &'static str,
    ) -> impl Service<
        Request<Body>,
        Response = axum::response::Response,
        Error = Infallible,
        Future = impl std::future::Future<Output = Result<axum::response::Response, Infallible>> + Send,
    > + Clone
    + Send
    + 'static {
        tower::service_fn(move |_req: Request<Body>| {
            let counter = counter.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(axum::response::Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from(body))
                    .expect("static status and body always build a response"))
            }
        })
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get_req(uri: &str) -> Request<Body> {
        Request::get(uri)
            .body(Body::empty())
            .expect("test URI must form a valid request")
    }

    /// Builds a body-less `POST` request for `uri`.
    fn post_req(uri: &str) -> Request<Body> {
        Request::post(uri)
            .body(Body::empty())
            .expect("test URI must form a valid request")
    }

    /// Waits for `svc` to become ready, then sends `req` and returns the response.
    async fn send<S>(svc: &mut S, req: Request<Body>) -> axum::response::Response
    where
        S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>,
    {
        svc.ready()
            .await
            .expect("service readiness cannot fail: error type is Infallible")
            .call(req)
            .await
            .expect("service call cannot fail: error type is Infallible")
    }

    /// Collects the whole response body into memory.
    async fn body_of(resp: axum::response::Response) -> bytes::Bytes {
        http_body_util::BodyExt::collect(resp.into_body())
            .await
            .expect("in-memory response body must collect")
            .to_bytes()
    }

    #[tokio::test]
    async fn caches_get_responses() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "hello"));

        let resp = send(&mut svc, get_req("/test")).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_of(resp).await;
        assert_eq!(body.as_ref(), b"hello");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let resp = send(&mut svc, get_req("/test")).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_of(resp).await;
        assert_eq!(body.as_ref(), b"hello");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "inner should not be called again"
        );
    }

    #[tokio::test]
    async fn caches_get_responses_with_raw_byte_backends() {
        let store = Arc::new(RawOnlyCache::default());
        let counter = Arc::new(AtomicUsize::new(0));

        let inner = {
            let counter = counter.clone();
            tower::service_fn(move |_req: Request<Body>| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, Infallible>(
                        axum::response::Response::builder()
                            .status(StatusCode::OK)
                            .header("x-cache-test", "persisted")
                            .body(Body::from("redis-like"))
                            .expect("static status, header and body always build a response"),
                    )
                }
            })
        };

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_shared(store))
            .service(inner);

        let resp = send(&mut svc, get_req("/redis-backed")).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let resp = send(&mut svc, get_req("/redis-backed")).await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers()
                .get("x-cache-test")
                .and_then(|v| v.to_str().ok()),
            Some("persisted")
        );
        let body = body_of(resp).await;
        assert_eq!(body.as_ref(), b"redis-like");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "raw-byte backends should cache HTTP responses"
        );
    }

    #[tokio::test]
    async fn does_not_cache_post_requests() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "created"));

        let _resp = send(&mut svc, post_req("/items")).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let _resp = send(&mut svc, post_req("/items")).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "POST should not be cached"
        );
    }

    #[tokio::test]
    async fn does_not_cache_non_200_responses() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let svc_inner = {
            let counter = counter.clone();
            tower::service_fn(move |_req: Request<Body>| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, Infallible>(
                        axum::response::Response::builder()
                            .status(StatusCode::NOT_FOUND)
                            .body(Body::from("not found"))
                            .expect("static status and body always build a response"),
                    )
                }
            })
        };

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(svc_inner);

        let resp = send(&mut svc, get_req("/missing")).await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let resp = send(&mut svc, get_req("/missing")).await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "404 should not be cached"
        );
    }

    #[tokio::test]
    async fn different_uris_cached_separately() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "ok"));

        let _resp = send(&mut svc, get_req("/a")).await;
        let _resp = send(&mut svc, get_req("/b")).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "different URIs should miss"
        );

        let _resp = send(&mut svc, get_req("/a")).await;
        assert_eq!(counter.load(Ordering::SeqCst), 2, "/a should be cached");
    }

    #[test]
    fn from_shared_accepts_arc() {
        let store = Arc::new(super::super::MokaCache::new(100, None));
        let _layer = CacheResponseLayer::from_shared(store);
    }

    /// A URI longer than the on-stack key buffer must still be cached.
    #[tokio::test]
    async fn caches_get_responses_very_long_uri() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "hello"));

        let long_uri = format!("/test/{}", "a".repeat(1000));

        let resp1 = send(&mut svc, get_req(&long_uri)).await;

        assert_eq!(resp1.status(), StatusCode::OK);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let resp2 = send(&mut svc, get_req(&long_uri)).await;

        assert_eq!(resp2.status(), StatusCode::OK);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "Should be cached despite long URI"
        );
    }
}