axum_extra/extract/
cached.rs

1use axum::extract::FromRequestParts;
2use http::request::Parts;
3
/// Cache results of other extractors.
///
/// `Cached` wraps another extractor and caches its result in [request extensions].
///
/// This is useful if you have a tree of extractors that share common sub-extractors that
/// you only want to run once, perhaps because they're expensive.
///
/// The cache is purely type-based, so you can only cache one value of each type. The cache is
/// also local to the current request and not reused across requests.
///
/// Every cache hit hands back a clone of the stored value, which is why the wrapped type must
/// implement [`Clone`]. If the wrapped extractor rejects the request, nothing is cached, so a
/// later `Cached<T>` extraction on the same request runs that extractor again.
///
/// # Example
///
/// ```rust
/// use axum_extra::extract::Cached;
/// use axum::{
///     extract::FromRequestParts,
///     response::{IntoResponse, Response},
///     http::{StatusCode, request::Parts},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// impl<S> FromRequestParts<S> for Session
/// where
///     S: Send + Sync,
/// {
///     type Rejection = (StatusCode, String);
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // load session...
///         # unimplemented!()
///     }
/// }
///
/// struct CurrentUser { /* ... */ }
///
/// impl<S> FromRequestParts<S> for CurrentUser
/// where
///     S: Send + Sync,
/// {
///     type Rejection = Response;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // loading a `CurrentUser` requires first loading the `Session`
///         //
///         // by using `Cached<Session>` we avoid extracting the session more than
///         // once, in case other extractors for the same request also load the session
///         let session: Session = Cached::<Session>::from_request_parts(parts, state)
///             .await
///             .map_err(|err| err.into_response())?
///             .0;
///
///         // load user from session...
///         # unimplemented!()
///     }
/// }
///
/// // handler that extracts the current user and the session
/// //
/// // the session will only be loaded once, even though `CurrentUser`
/// // also loads it
/// async fn handler(
///     current_user: CurrentUser,
///     // we have to use `Cached<Session>` here, otherwise the
///     // cached session would not be used
///     Cached(session): Cached<Session>,
/// ) {
///     // ...
/// }
/// ```
///
/// [request extensions]: http::Extensions
#[derive(Debug, Clone, Default)]
pub struct Cached<T>(pub T);
79
/// Private wrapper under which `Cached<T>` stores its value in the request extensions.
///
/// Using a dedicated type (rather than `T` itself) as the lookup key keeps the cached copy
/// separate from any plain `T` that other code may have placed in the extensions.
struct CachedEntry<T>(T);

impl<T: Clone> Clone for CachedEntry<T> {
    fn clone(&self) -> Self {
        let CachedEntry(inner) = self;
        CachedEntry(inner.clone())
    }
}
82
83impl<S, T> FromRequestParts<S> for Cached<T>
84where
85    S: Send + Sync,
86    T: FromRequestParts<S> + Clone + Send + Sync + 'static,
87{
88    type Rejection = T::Rejection;
89
90    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
91        if let Some(value) = parts.extensions.get::<CachedEntry<T>>() {
92            Ok(Self(value.0.clone()))
93        } else {
94            let value = T::from_request_parts(parts, state).await?;
95            parts.extensions.insert(CachedEntry(value.clone()));
96            Ok(Self(value))
97        }
98    }
99}
100
// Delegate to axum_core's internal `__impl_deref!` helper macro for `Cached`.
// NOTE(review): presumably this expands to `Deref`/`DerefMut` impls that forward to
// the wrapped `T` (field `.0`) — confirm against the axum_core macro definition.
axum_core::__impl_deref!(Cached);
102
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    /// Two `Cached<T>` extractions from one request must invoke `T`'s extractor
    /// only once and yield equal values.
    #[tokio::test]
    async fn works() {
        static INVOCATIONS: AtomicU32 = AtomicU32::new(0);

        /// Captures the instant it was produced, so a second run of the
        /// extractor would yield a distinguishable value.
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Stamp(Instant);

        impl<S> FromRequestParts<S> for Stamp
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                INVOCATIONS.fetch_add(1, Ordering::SeqCst);
                Ok(Stamp(Instant::now()))
            }
        }

        let invocations = || INVOCATIONS.load(Ordering::SeqCst);
        let (mut parts, ()) = Request::new(()).into_parts();

        let Cached(initial) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(invocations(), 1);

        // The second extraction should be answered from the request extensions.
        let Cached(repeated) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(invocations(), 1);

        assert_eq!(initial, repeated);
    }

    // Compile-only check (deliberately not a `#[test]`): `Cached<_>` must be
    // accepted as the final argument of a handler.
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _router: Router = Router::new().route("/", get(handler));
    }
}