Skip to main content

axum_cache_fred/
lib.rs

//!
//! The tower-cache-fred middleware is intended to be a route-layer middleware that can cache
//! responses for specific routes, such as those in a RESTful API.
//!
//! It should not be used to cache routes whose responses depend on a user scope or on data tied
//! to a cookie/session associated with the request.
7
8/// Different strategies for generating cache keys
9pub mod strategy;
10
11use axum::body::Body;
12use axum::extract::State;
13use axum::middleware::Next;
14use bytes::Bytes;
15use fred::interfaces::KeysInterface;
16use http::Method;
17use tracing::log::*;
18
19use crate::strategy::{CacheStrategy, RouteKey};
20
21/// Simple trait which the [axum::extract::State] should implement for using the middleware for
22/// cache retrieval
23pub trait CacheState {
24    fn cache(&self) -> Option<impl KeysInterface>;
25}
26
27/// The caching middleware is applied to all the routes necessary which can support caching. This
28/// only will cache requests on GET HTTP methods to ensure that data modifications are not cached
29/// calls
30pub async fn middleware<T>(
31    State(state): State<T>,
32    request: axum::extract::Request,
33    next: Next,
34) -> axum::response::Response
35where
36    T: CacheState,
37{
38    if request.method() != Method::GET {
39        trace!("The request was not a GET, short-circuiting caching");
40        return next.run(request).await;
41    }
42
43    let strategy = RouteKey::default();
44    let (duration, key) = strategy.computed_key(&request);
45
46    match state.cache() {
47        None => {
48            // When there is no cache, pass straight through!
49            info!("No cache configured, request will not be cached");
50            next.run(request).await
51        }
52        Some(cache) => {
53            debug!("This request should be cached");
54            match cache.get::<Option<Bytes>, _>(key.clone()).await {
55                Ok(cached) => {
56                    if let Some(response) = cached {
57                        debug!("cache hit! {response:?}");
58                        return axum::response::Response::new(Body::from(response.clone()));
59                    }
60                }
61                Err(err) => {
62                    error!("Failed to query cache! {err:?}");
63                }
64            }
65
66            trace!("passing middleware along");
67            let response = next.run(request).await;
68            trace!("response fetched!");
69
70            if response.status().is_success() {
71                // cache it!
72                //
73                // The response be consumed from the [Response] which means this code must retrieve the
74                // body and _then_ recreated the [Response] in order to return something sensible to
75                // the caller
76                let mut modified = axum::response::Response::builder().status(response.status());
77                if let Some(headers) = modified.headers_mut() {
78                    for (key, value) in response.headers() {
79                        headers.insert(key, value.clone());
80                    }
81                }
82                let expiration = duration.as_secs().try_into().unwrap();
83                let buffer = axum::body::to_bytes(response.into_body(), usize::MAX)
84                    .await
85                    .unwrap();
86                cache
87                    .set::<Bytes, _, _>(
88                        key,
89                        buffer.clone(),
90                        Some(fred::types::Expiration::EX(expiration)),
91                        None,
92                        false,
93                    )
94                    .await
95                    .expect("Failed to set!");
96                return modified.body(Body::from(buffer)).unwrap();
97            }
98            // All other responses fall through
99            response
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::StatusCode;
    use axum::middleware;
    use axum::{Router, routing::get};
    use fred::interfaces::ClientLike;
    use fred::mocks::SimpleMap;
    use fred::prelude::Client;
    use fred::types::Builder;
    use fred::types::config::Config;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tower::ServiceExt;

    use axum::extract::{Request, State};

    /// Sends the same `GET /` twice through the middleware and checks that the handler only runs
    /// for the first request.
    #[tokio::test]
    async fn test_axum_layer() -> anyhow::Result<()> {
        // Back the client with fred's in-memory mock so no server is needed.
        let config = Config {
            mocks: Some(Arc::new(SimpleMap::new())),
            ..Default::default()
        };
        let valkey = Builder::from_config(config).build()?;
        valkey.init().await?;
        let state = Arc::new(TestState::with_valkey(valkey));

        assert_eq!(state.misses(), 0);

        // The first request misses the cache and reaches the handler; the second one must not
        // miss again, so the counter stays at one.
        for expected_misses in [1, 1] {
            let response = Router::new()
                .route("/", get(handler))
                .with_state(state.clone())
                .layer(middleware::from_fn_with_state(
                    state.clone(),
                    middleware::<Arc<TestState>>,
                ))
                .oneshot(Request::builder().uri("/").body(Body::empty())?)
                .await?;
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(state.misses(), expected_misses);
        }

        Ok(())
    }

    /// Shared state for the test router: the cache client plus a count of handler invocations.
    #[derive(Clone)]
    struct TestState {
        /// Number of times the handler actually ran, i.e. the cache was missed.
        miss: Arc<AtomicU64>,
        valkey: Client,
    }

    impl TestState {
        /// Creates a state with a zeroed miss counter around the given client.
        fn with_valkey(valkey: Client) -> Self {
            Self {
                miss: Arc::new(AtomicU64::new(0)),
                valkey,
            }
        }

        /// Records one handler invocation.
        fn incr(&self) {
            self.miss.fetch_add(1, Ordering::SeqCst);
        }

        /// Returns how many times the handler has been invoked so far.
        fn misses(&self) -> u64 {
            self.miss.load(Ordering::SeqCst)
        }
    }

    impl CacheState for Arc<TestState> {
        fn cache(&self) -> Option<impl KeysInterface> {
            Some(self.valkey.clone())
        }
    }

    /// Route handler that only records that it was reached.
    async fn handler(State(s): State<Arc<TestState>>) {
        println!("Invoking the handler!");
        s.incr();
    }
}