pub mod strategy;
use axum::body::Body;
use axum::extract::State;
use axum::middleware::Next;
use bytes::Bytes;
use fred::interfaces::KeysInterface;
use http::Method;
use tracing::log::*;
use crate::strategy::{CacheStrategy, RouteKey};
/// Application state that can hand the caching [`middleware`] a key/value client.
///
/// The middleware is generic over any state type implementing this trait and
/// calls [`CacheState::cache`] once per `GET` request.
pub trait CacheState {
/// Returns the client used to read and write cached responses.
///
/// When this returns `None`, the middleware forwards the request without
/// consulting or populating any cache.
fn cache(&self) -> Option<impl KeysInterface>;
}
pub async fn middleware<T>(
State(state): State<T>,
request: axum::extract::Request,
next: Next,
) -> axum::response::Response
where
T: CacheState,
{
if request.method() != Method::GET {
trace!("The request was not a GET, short-circuiting caching");
return next.run(request).await;
}
let strategy = RouteKey::default();
let (duration, key) = strategy.computed_key(&request);
match state.cache() {
None => {
info!("No cache configured, request will not be cached");
next.run(request).await
}
Some(cache) => {
debug!("This request should be cached");
match cache.get::<Option<Bytes>, _>(key.clone()).await {
Ok(cached) => {
if let Some(response) = cached {
debug!("cache hit! {response:?}");
return axum::response::Response::new(Body::from(response.clone()));
}
}
Err(err) => {
error!("Failed to query cache! {err:?}");
}
}
trace!("passing middleware along");
let response = next.run(request).await;
trace!("response fetched!");
if response.status().is_success() {
let mut modified = axum::response::Response::builder().status(response.status());
if let Some(headers) = modified.headers_mut() {
for (key, value) in response.headers() {
headers.insert(key, value.clone());
}
}
let expiration = duration.as_secs().try_into().unwrap();
let buffer = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
cache
.set::<Bytes, _, _>(
key,
buffer.clone(),
Some(fred::types::Expiration::EX(expiration)),
None,
false,
)
.await
.expect("Failed to set!");
return modified.body(Body::from(buffer)).unwrap();
}
response
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::{Request, State};
    use axum::http::StatusCode;
    use axum::middleware;
    use axum::{Router, routing::get};
    use fred::interfaces::ClientLike;
    use fred::mocks::SimpleMap;
    use fred::prelude::Client;
    use fred::types::Builder;
    use fred::types::config::Config;
    use pretty_assertions::assert_eq;
    use std::sync::{Arc, Mutex};
    use tower::ServiceExt;

    /// Builds a single-route router wrapped in the caching middleware, with
    /// both the handler and the middleware sharing `state`.
    fn app(state: Arc<TestState>) -> Router {
        Router::new()
            .route("/", get(handler))
            .with_state(state.clone())
            .layer(middleware::from_fn_with_state(
                state,
                middleware::<Arc<TestState>>,
            ))
    }

    /// The first request must reach the handler (one miss); the second
    /// identical request must be answered from the cache, leaving the miss
    /// counter unchanged.
    #[tokio::test]
    async fn test_axum_layer() -> anyhow::Result<()> {
        let config = Config {
            mocks: Some(Arc::new(SimpleMap::new())),
            ..Default::default()
        };
        let valkey = Builder::from_config(config).build()?;
        valkey.init().await?;

        let state = Arc::new(TestState::with_valkey(valkey));
        assert_eq!(state.misses(), 0);

        let response = app(state.clone())
            .oneshot(Request::builder().uri("/").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(state.misses(), 1);

        let response = app(state.clone())
            .oneshot(Request::builder().uri("/").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(state.misses(), 1);

        Ok(())
    }

    /// Test state holding the mocked cache client and a counter of how many
    /// times the handler was actually invoked.
    #[derive(Clone)]
    struct TestState {
        miss: Arc<Mutex<u64>>,
        valkey: Client,
    }

    impl TestState {
        /// Creates a state with a zeroed miss counter around `valkey`.
        fn with_valkey(valkey: Client) -> Self {
            Self {
                miss: Arc::new(Mutex::new(0)),
                valkey,
            }
        }

        /// Records one handler invocation.
        ///
        /// Panics on a poisoned lock rather than silently skipping the
        /// increment, which would let the miss assertions pass or fail for
        /// the wrong reason.
        fn incr(&self) {
            let mut count = self
                .miss
                .lock()
                .expect("Failed to lock the `miss` counter");
            *count += 1;
        }

        /// Returns how many times the handler has been invoked so far.
        fn misses(&self) -> u64 {
            *self
                .miss
                .lock()
                .expect("Failed to lock the `miss` counter")
        }
    }

    impl CacheState for Arc<TestState> {
        fn cache(&self) -> Option<impl KeysInterface> {
            Some(self.valkey.clone())
        }
    }

    /// Handler that only counts its invocations; it returns an empty `200 OK`.
    async fn handler(State(s): State<Arc<TestState>>) {
        println!("Invoking the handler!");
        s.incr();
    }
}