use axum::extract::FromRequestParts;
use http::request::Parts;
/// Extractor wrapper that caches the result of another extractor `T`.
///
/// The first time a `Cached<T>` is extracted for a request, `T` runs normally. On success, a
/// clone of its value is stored in the request's extensions. Any later `Cached<T>` extraction
/// on the same request returns a clone of that stored value and does not run `T` again.
///
/// Rejections are not stored. If `T` fails, the error is returned and the next `Cached<T>`
/// runs `T` again. Only `Cached<T>` reads the cache, so a plain `T` extractor on the same
/// request still runs every time.
#[derive(Default, Clone, Debug)]
pub struct Cached<T>(pub T);
/// Private key type for the cached value in the request extensions.
///
/// The value is wrapped instead of stored as a bare `T`. This way the cache only interacts
/// with other `Cached<T>` extractions and never with a `T` that other code inserted into
/// the extensions.
#[derive(Clone)]
struct CachedEntry<T>(T);
impl<S, T> FromRequestParts<S> for Cached<T>
where
S: Send + Sync,
T: FromRequestParts<S> + Clone + Send + Sync + 'static,
{
type Rejection = T::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
if let Some(value) = parts.extensions.get::<CachedEntry<T>>() {
Ok(Self(value.0.clone()))
} else {
let value = T::from_request_parts(parts, state).await?;
parts.extensions.insert(CachedEntry(value.clone()));
Ok(Self(value))
}
}
}
// Generates the deref impls so `Cached<T>` can be used like the `T` it wraps. Judging by the
// macro's name this emits `Deref<Target = T>`, and presumably `DerefMut` too; confirm against
// the axum_core version in use. NOTE(review): `__impl_deref!` is a double-underscore
// (hidden) helper in `axum_core`, not a stable public API, so it may change between versions.
axum_core::__impl_deref!(Cached);
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    /// Two `Cached<Extractor>` extractions on the same request parts must run the inner
    /// extractor only once and return equal values.
    #[tokio::test]
    async fn works() {
        // Counts how many times the inner extractor actually executes.
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        /// Test extractor that bumps `COUNTER` and records the moment it ran.
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Extractor(Instant);

        impl<S> FromRequestParts<S> for Extractor
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                COUNTER.fetch_add(1, Ordering::SeqCst);
                Ok(Self(Instant::now()))
            }
        }

        let (mut parts, _) = Request::new(()).into_parts();

        // First extraction: cache miss, so the inner extractor runs.
        let Cached(first) = Cached::<Extractor>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        // Second extraction: served from the extensions, so the counter must not move.
        let Cached(second) = Cached::<Extractor>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        assert_eq!(first, second);
    }

    // Compile-only check: this function is never called and exists only to type-check that
    // `Cached<HeaderMap>` is accepted as a handler's last argument. The leading `_` keeps it
    // from being reported as dead code.
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _r: Router = Router::new().route("/", get(handler));
    }
}