Skip to main content

axum_cache_fred/
lib.rs

//!
//! The tower-cache-fred middleware is intended to be a route-layer middleware that can cache
//! responses for specific routes, such as those in a RESTful API.
//!
//! This should not be used to cache routes whose responses depend on the user, or on data
//! associated with a cookie or session attached to the request.
7
8/// Different strategies for generating cache keys
9pub mod strategy;
10
11use axum::body::Body;
12use axum::extract::State;
13use axum::middleware::Next;
14use bytes::Bytes;
15use fred::interfaces::KeysInterface;
16use http::Method;
17use tracing::log::*;
18
19use crate::strategy::{CacheStrategy, RouteKey};
20
21/// Simple trait which the [axum::extract::State] should implement for using the middleware for
22/// cache retrieval
23pub trait CacheState<T>
24where
25    T: KeysInterface,
26{
27    fn cache(&self) -> Option<T>;
28}
29
30/// The caching middleware is applied to all the routes necessary which can support caching. This
31/// only will cache requests on GET HTTP methods to ensure that data modifications are not cached
32/// calls
33pub async fn middleware<T, P>(
34    State(state): State<T>,
35    request: axum::extract::Request,
36    next: Next,
37) -> axum::response::Response
38where
39    T: CacheState<P>,
40    P: KeysInterface,
41{
42    if request.method() != Method::GET {
43        trace!("The request was not a GET, short-circuiting caching");
44        return next.run(request).await;
45    }
46
47    let strategy = RouteKey::default();
48    let (duration, key) = strategy.computed_key(&request);
49
50    match state.cache() {
51        None => {
52            // When there is no cache, pass straight through!
53            info!("No cache configured, request will not be cached");
54            next.run(request).await
55        }
56        Some(cache) => {
57            debug!("This request should be cached");
58            match cache.get::<Option<Bytes>, _>(key.clone()).await {
59                Ok(cached) => {
60                    if let Some(response) = cached {
61                        debug!("cache hit! {response:?}");
62                        return axum::response::Response::new(Body::from(response.clone()));
63                    }
64                }
65                Err(err) => {
66                    error!("Failed to query cache! {err:?}");
67                }
68            }
69
70            trace!("passing middleware along");
71            let response = next.run(request).await;
72            trace!("response fetched!");
73
74            if response.status().is_success() {
75                // cache it!
76                //
77                // The response be consumed from the [Response] which means this code must retrieve the
78                // body and _then_ recreated the [Response] in order to return something sensible to
79                // the caller
80                let mut modified = axum::response::Response::builder().status(response.status());
81                if let Some(headers) = modified.headers_mut() {
82                    for (key, value) in response.headers() {
83                        headers.insert(key, value.clone());
84                    }
85                }
86                let expiration = duration.as_secs().try_into().unwrap();
87                let buffer = axum::body::to_bytes(response.into_body(), usize::MAX)
88                    .await
89                    .unwrap();
90                cache
91                    .set::<Bytes, _, _>(
92                        key,
93                        buffer.clone(),
94                        Some(fred::types::Expiration::EX(expiration)),
95                        None,
96                        false,
97                    )
98                    .await
99                    .expect("Failed to set!");
100                return modified.body(Body::from(buffer)).unwrap();
101            }
102            // All other responses fall through
103            response
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::StatusCode;
    use axum::middleware;
    use axum::{Router, routing::get};
    use fred::interfaces::ClientLike;
    use fred::mocks::SimpleMap;
    use fred::prelude::Client;
    use fred::types::Builder;
    use fred::types::config::Config;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tower::ServiceExt;

    use axum::extract::{Request, State};

    /// Builds a [`TestState`] backed by fred's in-memory `SimpleMap` mock.
    async fn mock_state() -> anyhow::Result<Arc<TestState<Client>>> {
        let config = Config {
            mocks: Some(Arc::new(SimpleMap::new())),
            ..Default::default()
        };
        let valkey = Builder::from_config(config).build()?;
        valkey.init().await?;
        Ok(Arc::new(TestState::with_valkey(valkey)))
    }

    /// Router serving [`handler`] on `/` for GET and POST, wrapped in the caching middleware.
    fn app(state: Arc<TestState<Client>>) -> Router {
        Router::new()
            .route("/", get(handler).post(handler))
            .with_state(state.clone())
            .layer(middleware::from_fn_with_state(
                state,
                middleware::<Arc<TestState<Client>>, Client>,
            ))
    }

    #[tokio::test]
    async fn test_axum_layer() -> anyhow::Result<()> {
        let state = mock_state().await?;
        assert_eq!(state.misses(), 0);

        let response = app(state.clone())
            .oneshot(Request::builder().uri("/").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(state.misses(), 1);

        let response = app(state.clone())
            .oneshot(Request::builder().uri("/").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        // The second GET is served from the cache, so the handler must not run again
        assert_eq!(state.misses(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_non_get_is_not_cached() -> anyhow::Result<()> {
        let state = mock_state().await?;

        for expected in 1..=2u64 {
            let response = app(state.clone())
                .oneshot(
                    Request::builder()
                        .method(Method::POST)
                        .uri("/")
                        .body(Body::empty())?,
                )
                .await?;
            assert_eq!(response.status(), StatusCode::OK);
            // Every POST must reach the handler; nothing is served from the cache
            assert_eq!(state.misses(), expected);
        }

        Ok(())
    }

    /// Handler state: counts handler invocations (cache misses) and carries the cache
    /// client exposed to the middleware through [`CacheState`].
    #[derive(Clone)]
    struct TestState<P>
    where
        P: KeysInterface,
    {
        miss: Arc<AtomicU64>,
        valkey: P,
    }

    impl<P> TestState<P>
    where
        P: KeysInterface,
    {
        fn with_valkey(valkey: P) -> Self {
            Self {
                miss: Arc::new(AtomicU64::new(0)),
                valkey,
            }
        }

        /// Records one handler invocation.
        fn incr(&self) {
            self.miss.fetch_add(1, Ordering::SeqCst);
        }

        /// Number of times the handler has run so far.
        fn misses(&self) -> u64 {
            self.miss.load(Ordering::SeqCst)
        }
    }

    impl<P> CacheState<P> for Arc<TestState<P>>
    where
        P: KeysInterface,
    {
        fn cache(&self) -> Option<P> {
            Some(self.valkey.clone())
        }
    }

    async fn handler<P>(State(s): State<Arc<TestState<P>>>)
    where
        P: KeysInterface,
    {
        println!("Invoking the handler!");
        s.incr();
    }
}