//! axum-cache-fred 0.2.0
//!
//! Axum middleware for response caching with fred.
//!
//! This middleware is intended to be applied as a route layer, caching the responses of specific
//! routes such as those of a RESTful API.
//!
//! It should not be used to cache routes whose responses depend on the user, or on data tied to
//! the cookie/session of the request.

/// Different strategies for generating cache keys
pub mod strategy;

use axum::body::Body;
use axum::extract::State;
use axum::middleware::Next;
use bytes::Bytes;
use fred::interfaces::KeysInterface;
use http::Method;
use tracing::log::*;

use crate::strategy::{CacheStrategy, RouteKey};

/// Trait that the application state handed to [`middleware`] (through [axum::extract::State])
/// must implement, giving the middleware access to the cache.
pub trait CacheState {
    /// Returns a handle to the cache, or `None` when no cache is configured.
    ///
    /// [`middleware`] calls this once for every `GET` request; when it yields `None`, the request
    /// is passed straight through to the inner service and nothing is cached.
    fn cache(&self) -> Option<impl KeysInterface>;
}

/// The caching middleware is applied to all the routes necessary which can support caching. This
/// only will cache requests on GET HTTP methods to ensure that data modifications are not cached
/// calls
pub async fn middleware<T>(
    State(state): State<T>,
    request: axum::extract::Request,
    next: Next,
) -> axum::response::Response
where
    T: CacheState,
{
    if request.method() != Method::GET {
        trace!("The request was not a GET, short-circuiting caching");
        return next.run(request).await;
    }

    let strategy = RouteKey::default();
    let (duration, key) = strategy.computed_key(&request);

    match state.cache() {
        None => {
            // When there is no cache, pass straight through!
            info!("No cache configured, request will not be cached");
            next.run(request).await
        }
        Some(cache) => {
            debug!("This request should be cached");
            match cache.get::<Option<Bytes>, _>(key.clone()).await {
                Ok(cached) => {
                    if let Some(response) = cached {
                        debug!("cache hit! {response:?}");
                        return axum::response::Response::new(Body::from(response.clone()));
                    }
                }
                Err(err) => {
                    error!("Failed to query cache! {err:?}");
                }
            }

            trace!("passing middleware along");
            let response = next.run(request).await;
            trace!("response fetched!");

            if response.status().is_success() {
                // cache it!
                //
                // The response be consumed from the [Response] which means this code must retrieve the
                // body and _then_ recreated the [Response] in order to return something sensible to
                // the caller
                let mut modified = axum::response::Response::builder().status(response.status());
                if let Some(headers) = modified.headers_mut() {
                    for (key, value) in response.headers() {
                        headers.insert(key, value.clone());
                    }
                }
                let expiration = duration.as_secs().try_into().unwrap();
                let buffer = axum::body::to_bytes(response.into_body(), usize::MAX)
                    .await
                    .unwrap();
                cache
                    .set::<Bytes, _, _>(
                        key,
                        buffer.clone(),
                        Some(fred::types::Expiration::EX(expiration)),
                        None,
                        false,
                    )
                    .await
                    .expect("Failed to set!");
                return modified.body(Body::from(buffer)).unwrap();
            }
            // All other responses fall through
            response
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::StatusCode;
    use axum::middleware;
    use axum::{Router, routing::get};
    use fred::interfaces::ClientLike;
    use fred::mocks::SimpleMap;
    use fred::prelude::Client;
    use fred::types::Builder;
    use fred::types::config::Config;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tower::ServiceExt;

    use axum::extract::{Request, State};

    /// Builds a [TestState] whose cache is fred's in-memory [SimpleMap] mock, so no server is
    /// required.
    async fn test_state() -> anyhow::Result<Arc<TestState>> {
        let config = Config {
            mocks: Some(Arc::new(SimpleMap::new())),
            ..Default::default()
        };
        let valkey = Builder::from_config(config).build()?;
        valkey.init().await?;
        Ok(Arc::new(TestState::with_valkey(valkey)))
    }

    /// Builds a router serving [handler] on `/` for both GET and POST, wrapped in the caching
    /// middleware under test.
    fn app(state: Arc<TestState>) -> Router {
        Router::new()
            .route("/", get(handler).post(handler))
            .with_state(state.clone())
            .layer(middleware::from_fn_with_state(
                state,
                middleware::<Arc<TestState>>,
            ))
    }

    #[tokio::test]
    async fn test_axum_layer() -> anyhow::Result<()> {
        let state = test_state().await?;

        assert_eq!(state.misses(), 0);

        let response = app(state.clone())
            .oneshot(Request::builder().uri("/").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(state.misses(), 1);

        let response = app(state.clone())
            .oneshot(Request::builder().uri("/").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        // The second GET must be answered from the cache without reaching the handler
        assert_eq!(state.misses(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_non_get_bypasses_cache() -> anyhow::Result<()> {
        let state = test_state().await?;

        for expected_misses in 1..=2 {
            let response = app(state.clone())
                .oneshot(
                    Request::builder()
                        .method(axum::http::Method::POST)
                        .uri("/")
                        .body(Body::empty())?,
                )
                .await?;
            assert_eq!(response.status(), StatusCode::OK);
            // Non-GET requests are never served from the cache, so every call hits the handler
            assert_eq!(state.misses(), expected_misses);
        }

        Ok(())
    }

    /// Application state for the tests: counts handler invocations (i.e. cache misses) and holds
    /// the mocked fred client used as the cache.
    struct TestState {
        miss: AtomicU64,
        valkey: Client,
    }

    impl TestState {
        fn with_valkey(valkey: Client) -> Self {
            Self {
                miss: AtomicU64::new(0),
                valkey,
            }
        }

        /// Records one invocation of the handler.
        fn incr(&self) {
            self.miss.fetch_add(1, Ordering::SeqCst);
        }

        /// Number of times the handler has been invoked so far.
        fn misses(&self) -> u64 {
            self.miss.load(Ordering::SeqCst)
        }
    }

    impl CacheState for Arc<TestState> {
        fn cache(&self) -> Option<impl KeysInterface> {
            Some(self.valkey.clone())
        }
    }

    /// Route handler that only records that it was reached.
    async fn handler(State(s): State<Arc<TestState>>) {
        println!("Invoking the handler!");
        s.incr();
    }
}