//! nestrs-cache-manager 0.1.0
//!
//! Rust port of `@nestjs/cache-manager`, backed by `moka`.
//!
//! Port map for upstream `lib/interceptors/cache.interceptor.ts`.

use std::collections::BTreeMap;
use std::sync::Arc;

use serde::Serialize;
use serde_json::Value;

use crate::cache_constants::{CACHE_KEY_METADATA, CACHE_TTL_METADATA};
use crate::cache_providers::{CacheManager, CacheManagerError, CacheValue};
use crate::decorators::{CacheKeyMetadata, CacheTTLMetadata};

/// Minimal HTTP request view carrying only the method and URL.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Request {
    /// HTTP method as received, e.g. `"GET"`.
    pub method: String,
    /// Request URL; the interceptor uses it as the default cache key.
    pub url: String,
}

/// Minimal HTTP response view; only its headers are modelled.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Response {
    /// Header name → value, kept in a `BTreeMap` so iteration order is deterministic.
    pub headers: BTreeMap<String, String>,
}

/// Stateless HTTP adapter; its presence on an `ExecutionContext` is what marks
/// that context as an HTTP one for the cache interceptor.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct HttpAdapter;

impl HttpAdapter {
    /// Returns an owned copy of the request's HTTP method.
    pub fn getRequestMethod(&self, request: &Request) -> String {
        String::from(request.method.as_str())
    }

    /// Returns an owned copy of the request URL.
    pub fn getRequestUrl(&self, request: &Request) -> String {
        request.url.to_owned()
    }

    /// Sets header `name` to `value`, overwriting any value previously stored
    /// under exactly the same (case-sensitive) name.
    pub fn setHeader(&self, response: &mut Response, name: &str, value: &str) {
        let (header_name, header_value) = (name.to_owned(), value.to_owned());
        response.headers.insert(header_name, header_value);
    }
}

/// Per-invocation context handed to the cache interceptor.
///
/// Public fields describe the call being executed. The cache metadata is
/// private and is filled in through `set_metadata` / `set_class_cache_ttl`.
#[derive(Default, Clone)]
pub struct ExecutionContext {
    /// Name of the handler being executed.
    pub handler: String,
    /// Name of the class the handler belongs to.
    pub class_name: String,
    /// Incoming HTTP request, if any.
    pub request: Option<Request>,
    /// Outgoing HTTP response the interceptor may add headers to.
    pub response: Option<Response>,
    /// Set only for HTTP contexts; without it, keys come solely from metadata.
    pub http_adapter: Option<HttpAdapter>,
    /// Handler-level cache key metadata (stored under `CACHE_KEY_METADATA`).
    cache_key_metadata: Option<CacheKeyMetadata>,
    /// Handler-level TTL metadata (stored under `CACHE_TTL_METADATA`).
    cache_ttl_metadata: Option<CacheTTLMetadata>,
    /// Class-level TTL used when no handler-level TTL is present.
    class_cache_ttl_metadata: Option<CacheTTLMetadata>,
}

impl ExecutionContext {
    pub fn http(
        handler: impl Into<String>,
        class_name: impl Into<String>,
        request: Request,
    ) -> Self {
        Self {
            handler: handler.into(),
            class_name: class_name.into(),
            request: Some(request),
            response: Some(Response::default()),
            http_adapter: Some(HttpAdapter),
            ..Self::default()
        }
    }

    pub fn set_metadata(&mut self, key: &str, value: Metadata) {
        match (key, value) {
            (CACHE_KEY_METADATA, Metadata::CacheKey(value)) => {
                self.cache_key_metadata = Some(value)
            }
            (CACHE_TTL_METADATA, Metadata::CacheTTL(value)) => {
                self.cache_ttl_metadata = Some(value)
            }
            _ => {}
        }
    }

    pub fn set_class_cache_ttl(&mut self, value: CacheTTLMetadata) {
        self.class_cache_ttl_metadata = Some(value);
    }

    pub fn response(&self) -> Option<&Response> {
        self.response.as_ref()
    }
}

/// Typed payload for `ExecutionContext::set_metadata`.
///
/// The variant must match the key it is stored under, otherwise the call has no effect.
#[derive(Clone)]
pub enum Metadata {
    /// Payload for `CACHE_KEY_METADATA`.
    CacheKey(CacheKeyMetadata),
    /// Payload for `CACHE_TTL_METADATA`.
    CacheTTL(CacheTTLMetadata),
}

/// Result of a route handler or of the cache interceptor.
#[derive(Clone, Debug, PartialEq)]
pub enum InterceptorResponse {
    /// Value served from the cache without running the handler.
    Cached(CacheValue),
    /// Value produced by running the handler.
    Handled(CacheValue),
    /// Streamed file body; the interceptor never stores it in the cache.
    StreamableFile,
}

/// Deferred route handler.
///
/// The interceptor invokes it only when the response cannot be served from the cache.
pub type CallHandler = Arc<
    dyn Fn() -> Result<InterceptorResponse, CacheManagerError> + Sync + Send + 'static,
>;

/// Response-caching interceptor.
///
/// Field names intentionally keep the upstream camelCase spelling.
#[derive(Clone)]
pub struct CacheInterceptor {
    /// Cache store that lookups and writes go through.
    pub cacheManager: CacheManager,
    /// HTTP methods whose requests may be cached by URL.
    pub allowedMethods: Vec<String>,
}

impl CacheInterceptor {
    /// Creates an interceptor that, by default, caches only `GET` requests.
    ///
    /// Push more verbs into `allowedMethods` to widen the set of requests
    /// that may be cached by URL.
    pub fn new(cacheManager: CacheManager) -> Self {
        Self {
            cacheManager,
            allowedMethods: vec!["GET".to_string()],
        }
    }

    /// Serves the handler's result from the cache when possible.
    ///
    /// 1. Derive a cache key with `trackBy`. Without a usable key (none, or
    ///    an empty string), the handler runs uncached.
    /// 2. Look the key up. A lookup error falls back to running the handler.
    /// 3. On a hit, set `X-Cache: HIT` (HTTP only) and return
    ///    `InterceptorResponse::Cached`. A cached JSON `null` counts as a miss.
    /// 4. On a miss, set `X-Cache: MISS`, run the handler, and store its
    ///    value (streamed files excluded). The TTL comes from the
    ///    handler-level metadata, falling back to the class-level one.
    ///
    /// A failed cache write is reported on stderr and does not fail the call.
    ///
    /// # Errors
    /// Propagates any error returned by `next`.
    pub async fn intercept(
        &self,
        context: &mut ExecutionContext,
        next: CallHandler,
    ) -> Result<InterceptorResponse, CacheManagerError> {
        // Upstream bypasses the cache for any falsy key (`if (!key)`). The empty
        // string is falsy in JS, so an empty key also means "don't cache".
        let Some(key) = self.trackBy(context).filter(|key| !key.is_empty()) else {
            return next();
        };

        // A handler-level TTL takes precedence over the class-level default.
        let ttl_metadata = context
            .cache_ttl_metadata
            .clone()
            .or_else(|| context.class_cache_ttl_metadata.clone());

        let cached = match self.cacheManager.get(&key).await {
            Ok(cached) => cached,
            // A failing cache lookup degrades to an uncached call.
            Err(_) => return next(),
        };
        // Upstream treats a nil cached value (`isNil`) as a miss. JSON `null` is
        // the equivalent here, so it is neither reported as HIT nor served.
        let cached = cached.filter(|value| !value.is_null());
        self.setHeadersWhenHttp(context, cached.as_ref());
        if let Some(value) = cached {
            return Ok(InterceptorResponse::Cached(value));
        }

        let response = next()?;
        // `response_value` yields nothing for streamed files, so those are never cached.
        if let Some(value) = response_value(&response) {
            let ttl = ttl_metadata
                .map(|ttl| ttl.resolve(context))
                .map(std::time::Duration::from_millis);
            if let Err(err) = self.cacheManager.set_value(&key, value.clone(), ttl).await {
                // Best effort: a failed write is reported but must not turn a
                // successful handler result into an error.
                eprintln!(
                    "An error has occurred when inserting \"key: {key}\", \"value: {value}\""
                );
                eprintln!("{err}");
            }
        }
        Ok(response)
    }

    /// Derives the cache key for `context`, or `None` when the call must not be cached.
    ///
    /// Explicit key metadata always wins, and outside HTTP it is the only key
    /// source. Otherwise the request URL is used, but only when the request
    /// method is listed in `allowedMethods`.
    pub fn trackBy(&self, context: &ExecutionContext) -> Option<String> {
        let is_http_app = context.http_adapter.is_some();
        let key_metadata = context.cache_key_metadata.clone();

        if !is_http_app || key_metadata.is_some() {
            return key_metadata.and_then(|metadata| metadata.resolve(context));
        }

        if !self.isRequestCacheable(context) {
            return None;
        }

        let request = context.request.as_ref()?;
        let http_adapter = context.http_adapter.as_ref()?;
        Some(http_adapter.getRequestUrl(request))
    }

    /// Returns `true` when the context carries a request whose method is in
    /// `allowedMethods`. The comparison is exact and case-sensitive.
    pub fn isRequestCacheable(&self, context: &ExecutionContext) -> bool {
        context
            .request
            .as_ref()
            .map(|request| self.allowedMethods.contains(&request.method))
            .unwrap_or(false)
    }

    /// Sets `X-Cache: HIT` when `value` is present and `X-Cache: MISS`
    /// otherwise.
    ///
    /// This does nothing unless the context has both an HTTP adapter and a response.
    pub fn setHeadersWhenHttp(&self, context: &mut ExecutionContext, value: Option<&CacheValue>) {
        let Some(http_adapter) = context.http_adapter.as_ref() else {
            return;
        };
        let Some(response) = context.response.as_mut() else {
            return;
        };
        http_adapter.setHeader(
            response,
            "X-Cache",
            if value.is_some() { "HIT" } else { "MISS" },
        );
    }
}

/// Borrows the JSON payload carried by `response`.
///
/// Returns `None` for a streamed file, which has no cacheable value.
fn response_value(response: &InterceptorResponse) -> Option<&Value> {
    let payload = match response {
        InterceptorResponse::StreamableFile => return None,
        InterceptorResponse::Cached(payload) => payload,
        InterceptorResponse::Handled(payload) => payload,
    };
    Some(payload)
}

pub fn handled<T: Serialize>(value: T) -> Result<InterceptorResponse, CacheManagerError> {
    Ok(InterceptorResponse::Handled(serde_json::to_value(value)?))
}