1pub mod strategy;
10
11use axum::body::Body;
12use axum::extract::State;
13use axum::middleware::Next;
14use bytes::Bytes;
15use fred::interfaces::KeysInterface;
16use http::Method;
17use tracing::log::*;
18
19use crate::strategy::{CacheStrategy, RouteKey};
20
21pub trait CacheState<T>
24where
25 T: KeysInterface,
26{
27 fn cache(&self) -> Option<T>;
28}
29
30pub async fn middleware<T, P>(
34 State(state): State<T>,
35 request: axum::extract::Request,
36 next: Next,
37) -> axum::response::Response
38where
39 T: CacheState<P>,
40 P: KeysInterface,
41{
42 if request.method() != Method::GET {
43 trace!("The request was not a GET, short-circuiting caching");
44 return next.run(request).await;
45 }
46
47 let strategy = RouteKey::default();
48 let (duration, key) = strategy.computed_key(&request);
49
50 match state.cache() {
51 None => {
52 info!("No cache configured, request will not be cached");
54 next.run(request).await
55 }
56 Some(cache) => {
57 debug!("This request should be cached");
58 match cache.get::<Option<Bytes>, _>(key.clone()).await {
59 Ok(cached) => {
60 if let Some(response) = cached {
61 debug!("cache hit! {response:?}");
62 return axum::response::Response::new(Body::from(response.clone()));
63 }
64 }
65 Err(err) => {
66 error!("Failed to query cache! {err:?}");
67 }
68 }
69
70 trace!("passing middleware along");
71 let response = next.run(request).await;
72 trace!("response fetched!");
73
74 if response.status().is_success() {
75 let mut modified = axum::response::Response::builder().status(response.status());
81 if let Some(headers) = modified.headers_mut() {
82 for (key, value) in response.headers() {
83 headers.insert(key, value.clone());
84 }
85 }
86 let expiration = duration.as_secs().try_into().unwrap();
87 let buffer = axum::body::to_bytes(response.into_body(), usize::MAX)
88 .await
89 .unwrap();
90 cache
91 .set::<Bytes, _, _>(
92 key,
93 buffer.clone(),
94 Some(fred::types::Expiration::EX(expiration)),
95 None,
96 false,
97 )
98 .await
99 .expect("Failed to set!");
100 return modified.body(Body::from(buffer)).unwrap();
101 }
102 response
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::extract::{Request, State};
    use axum::http::StatusCode;
    use axum::middleware::from_fn_with_state;
    use axum::{Router, routing::get};
    use fred::interfaces::ClientLike;
    use fred::mocks::SimpleMap;
    use fred::prelude::Client;
    use fred::types::Builder;
    use fred::types::config::Config;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tower::ServiceExt;

    /// Router state shared by the handler and the caching middleware.
    type SharedState = Arc<TestState<Client>>;

    /// Creates an initialised client backed by fred's in-memory `SimpleMap` mock.
    async fn mock_valkey() -> anyhow::Result<Client> {
        let config = Config {
            mocks: Some(Arc::new(SimpleMap::new())),
            ..Default::default()
        };
        let valkey = Builder::from_config(config).build()?;
        valkey.init().await?;
        Ok(valkey)
    }

    /// Serves `handler` on `/` (GET and POST) behind the caching middleware.
    fn app(state: SharedState) -> Router {
        Router::new()
            .route("/", get(handler::<Client>).post(handler::<Client>))
            .with_state(state.clone())
            .layer(from_fn_with_state(
                state,
                middleware::<SharedState, Client>,
            ))
    }

    /// Sends an empty-bodied `method` request for `/` through a fresh router and
    /// returns the response status.
    async fn send(state: &SharedState, method: Method) -> anyhow::Result<StatusCode> {
        let response = app(state.clone())
            .oneshot(
                Request::builder()
                    .method(method)
                    .uri("/")
                    .body(Body::empty())?,
            )
            .await?;
        Ok(response.status())
    }

    #[tokio::test]
    async fn test_axum_layer() -> anyhow::Result<()> {
        let state = Arc::new(TestState::with_valkey(mock_valkey().await?));
        assert_eq!(state.misses(), 0);

        // The first GET misses the cache and reaches the handler.
        assert_eq!(send(&state, Method::GET).await?, StatusCode::OK);
        assert_eq!(state.misses(), 1);

        // An identical GET is answered from the cache; the handler does not run again.
        assert_eq!(send(&state, Method::GET).await?, StatusCode::OK);
        assert_eq!(state.misses(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_non_get_bypasses_cache() -> anyhow::Result<()> {
        let state = Arc::new(TestState::with_valkey(mock_valkey().await?));

        // Non-GET requests are never served from the cache, so every one reaches the handler.
        for expected in 1..=2 {
            assert_eq!(send(&state, Method::POST).await?, StatusCode::OK);
            assert_eq!(state.misses(), expected);
        }

        Ok(())
    }

    /// Test state that counts how often the handler actually runs (i.e. cache misses).
    #[derive(Clone)]
    struct TestState<P> {
        miss: Arc<AtomicU64>,
        valkey: P,
    }

    impl<P> TestState<P> {
        fn with_valkey(valkey: P) -> Self {
            Self {
                miss: Arc::new(AtomicU64::new(0)),
                valkey,
            }
        }

        fn incr(&self) {
            self.miss.fetch_add(1, Ordering::SeqCst);
        }

        fn misses(&self) -> u64 {
            self.miss.load(Ordering::SeqCst)
        }
    }

    impl<P> CacheState<P> for Arc<TestState<P>>
    where
        P: KeysInterface,
    {
        fn cache(&self) -> Option<P> {
            Some(self.valkey.clone())
        }
    }

    /// Handler whose only effect is recording that it was invoked.
    async fn handler<P>(State(s): State<Arc<TestState<P>>>)
    where
        P: KeysInterface,
    {
        println!("Invoking the handler!");
        s.incr();
    }
}