1pub mod strategy;
10
11use axum::body::Body;
12use axum::extract::State;
13use axum::middleware::Next;
14use bytes::Bytes;
15use fred::interfaces::KeysInterface;
16use http::Method;
17use tracing::log::*;
18
19use crate::strategy::{CacheStrategy, RouteKey};
20
/// Application state that can optionally provide a cache client.
///
/// [`middleware`] is generic over this trait and asks the state for a cache
/// handle on every `GET` request it processes.
pub trait CacheState {
    /// Returns a client implementing [`KeysInterface`], or `None` when no
    /// cache is available.
    ///
    /// When this returns `None`, [`middleware`] forwards the request to the
    /// next handler without reading from or writing to any cache.
    fn cache(&self) -> Option<impl KeysInterface>;
}
26
27pub async fn middleware<T>(
31 State(state): State<T>,
32 request: axum::extract::Request,
33 next: Next,
34) -> axum::response::Response
35where
36 T: CacheState,
37{
38 if request.method() != Method::GET {
39 trace!("The request was not a GET, short-circuiting caching");
40 return next.run(request).await;
41 }
42
43 let strategy = RouteKey::default();
44 let (duration, key) = strategy.computed_key(&request);
45
46 match state.cache() {
47 None => {
48 info!("No cache configured, request will not be cached");
50 next.run(request).await
51 }
52 Some(cache) => {
53 debug!("This request should be cached");
54 match cache.get::<Option<Bytes>, _>(key.clone()).await {
55 Ok(cached) => {
56 if let Some(response) = cached {
57 debug!("cache hit! {response:?}");
58 return axum::response::Response::new(Body::from(response.clone()));
59 }
60 }
61 Err(err) => {
62 error!("Failed to query cache! {err:?}");
63 }
64 }
65
66 trace!("passing middleware along");
67 let response = next.run(request).await;
68 trace!("response fetched!");
69
70 if response.status().is_success() {
71 let mut modified = axum::response::Response::builder().status(response.status());
77 if let Some(headers) = modified.headers_mut() {
78 for (key, value) in response.headers() {
79 headers.insert(key, value.clone());
80 }
81 }
82 let expiration = duration.as_secs().try_into().unwrap();
83 let buffer = axum::body::to_bytes(response.into_body(), usize::MAX)
84 .await
85 .unwrap();
86 cache
87 .set::<Bytes, _, _>(
88 key,
89 buffer.clone(),
90 Some(fred::types::Expiration::EX(expiration)),
91 None,
92 false,
93 )
94 .await
95 .expect("Failed to set!");
96 return modified.body(Body::from(buffer)).unwrap();
97 }
98 response
100 }
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::StatusCode;
    use axum::middleware;
    use axum::{Router, routing::get};
    use fred::interfaces::ClientLike;
    use fred::mocks::SimpleMap;
    use fred::prelude::Client;
    use fred::types::Builder;
    use fred::types::config::Config;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tower::ServiceExt;

    use axum::extract::{Request, State};

    /// The first `GET /` must reach the handler (one miss); the second must be
    /// answered by the middleware from the cache, leaving the miss count at one.
    #[tokio::test]
    async fn test_axum_layer() -> anyhow::Result<()> {
        let config = Config {
            mocks: Some(Arc::new(SimpleMap::new())),
            ..Default::default()
        };
        let valkey = Builder::from_config(config).build()?;
        valkey.init().await?;
        let state = Arc::new(TestState::with_valkey(valkey));

        assert_eq!(state.misses(), 0);

        let response = send_get(&state).await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(state.misses(), 1);

        let response = send_get(&state).await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(state.misses(), 1);

        Ok(())
    }

    /// Builds a router with the caching middleware layered over `handler` and
    /// sends a single `GET /` through it.
    async fn send_get(state: &Arc<TestState>) -> anyhow::Result<axum::response::Response> {
        let response = Router::new()
            .route("/", get(handler))
            .with_state(state.clone())
            .layer(middleware::from_fn_with_state(
                state.clone(),
                middleware::<Arc<TestState>>,
            ))
            .oneshot(Request::builder().uri("/").body(Body::empty())?)
            .await?;
        Ok(response)
    }

    /// Shared test state: the mocked cache client plus a counter of how many
    /// times the handler actually ran (i.e. cache misses).
    #[derive(Clone)]
    struct TestState {
        // Atomic so that recording a miss can never be silently skipped.
        miss: Arc<AtomicU64>,
        valkey: Client,
    }

    impl TestState {
        /// Creates a state with a zeroed miss counter around the given client.
        fn with_valkey(valkey: Client) -> Self {
            Self {
                miss: Arc::new(AtomicU64::new(0)),
                valkey,
            }
        }

        /// Records one handler invocation.
        fn incr(&self) {
            self.miss.fetch_add(1, Ordering::SeqCst);
        }

        /// Number of handler invocations recorded so far.
        fn misses(&self) -> u64 {
            self.miss.load(Ordering::SeqCst)
        }
    }

    impl CacheState for Arc<TestState> {
        fn cache(&self) -> Option<impl KeysInterface> {
            Some(self.valkey.clone())
        }
    }

    /// Route handler that only records that it was invoked.
    async fn handler(State(s): State<Arc<TestState>>) {
        println!("Invoking the handler!");
        s.incr();
    }
}