axum_extra/extract/cached.rs
1use axum::extract::FromRequestParts;
2use http::request::Parts;
3
4/// Cache results of other extractors.
5///
6/// `Cached` wraps another extractor and caches its result in [request extensions].
7///
8/// This is useful if you have a tree of extractors that share common sub-extractors that
9/// you only want to run once, perhaps because they're expensive.
10///
11/// The cache is purely type based, so you can only cache one value of each type. The cache is also
12/// local to the current request and not reused across requests.
///
/// Every extraction returns a clone of the cached value (the first one stores a clone and
/// returns the original), so `T` must implement [`Clone`].
13///
14/// # Example
15///
16/// ```rust
17/// use axum_extra::extract::Cached;
18/// use axum::{
19/// extract::FromRequestParts,
20/// response::{IntoResponse, Response},
21/// http::{StatusCode, request::Parts},
22/// };
23///
24/// #[derive(Clone)]
25/// struct Session { /* ... */ }
26///
27/// impl<S> FromRequestParts<S> for Session
28/// where
29/// S: Send + Sync,
30/// {
31/// type Rejection = (StatusCode, String);
32///
33/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
34/// // load session...
35/// # unimplemented!()
36/// }
37/// }
38///
39/// struct CurrentUser { /* ... */ }
40///
41/// impl<S> FromRequestParts<S> for CurrentUser
42/// where
43/// S: Send + Sync,
44/// {
45/// type Rejection = Response;
46///
47/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
48/// // loading a `CurrentUser` requires first loading the `Session`
49/// //
50/// // by using `Cached<Session>` we avoid extracting the session more than
51/// // once, in case other extractors for the same request also load the session
52/// let session: Session = Cached::<Session>::from_request_parts(parts, state)
53/// .await
54/// .map_err(|err| err.into_response())?
55/// .0;
56///
57/// // load user from session...
58/// # unimplemented!()
59/// }
60/// }
61///
62/// // handler that extracts the current user and the session
63/// //
64/// // the session will only be loaded once, even though `CurrentUser`
65/// // also loads it
66/// async fn handler(
67/// current_user: CurrentUser,
68/// // we have to use `Cached<Session>` here otherwise the
69/// // cached session would not be used
70/// Cached(session): Cached<Session>,
71/// ) {
72/// // ...
73/// }
74/// ```
75///
76/// [request extensions]: http::Extensions
77#[derive(Debug, Clone, Default)]
78pub struct Cached<T>(pub T);
79
// Private wrapper under which `Cached<T>` keeps its value in the request extensions:
// the `FromRequestParts` impl for `Cached<T>` looks up and inserts `CachedEntry<T>`,
// never a bare `T`.
// NOTE(review): presumably this keeps the cache slot separate from any plain `T` that
// other code stores in the extensions — confirm against `http::Extensions` semantics.
80#[derive(Clone)]
81struct CachedEntry<T>(T);
82
83impl<S, T> FromRequestParts<S> for Cached<T>
84where
85 S: Send + Sync,
86 T: FromRequestParts<S> + Clone + Send + Sync + 'static,
87{
88 type Rejection = T::Rejection;
89
90 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
91 if let Some(value) = parts.extensions.get::<CachedEntry<T>>() {
92 Ok(Self(value.0.clone()))
93 } else {
94 let value = T::from_request_parts(parts, state).await?;
95 parts.extensions.insert(CachedEntry(value.clone()));
96 Ok(Self(value))
97 }
98 }
99}
100
// NOTE(review): `__impl_deref!` is defined in `axum_core`, not in this file. It presumably
// generates `Deref`/`DerefMut` impls so a `Cached<T>` can be used like its inner `T` —
// confirm against the macro definition.
101axum_core::__impl_deref!(Cached);
102
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    // Extracting the same type twice through `Cached` must run the inner extractor only
    // once; the second extraction has to return the value produced by the first.
    #[tokio::test]
    async fn works() {
        // How many times the inner extractor has actually executed.
        static RUNS: AtomicU32 = AtomicU32::new(0);

        // Carries the creation time so two separate runs would compare unequal.
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Stamp(Instant);

        impl<S> FromRequestParts<S> for Stamp
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                RUNS.fetch_add(1, Ordering::SeqCst);
                Ok(Self(Instant::now()))
            }
        }

        let (mut parts, _) = Request::new(()).into_parts();

        // First extraction: cache is empty, so the inner extractor runs.
        let Cached(first) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        // Second extraction: served from the cache, the run count must not move.
        let Cached(second) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        assert_eq!(first, second);
    }

    // Not a #[test]: this only has to type-check, proving `Cached<_>` is accepted as the
    // final argument of a handler.
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _router: Router = Router::new().route("/", get(handler));
    }
}