// hive-router 0.0.51
//
// GraphQL router/gateway for Federation.
use std::sync::Arc;

use graphql_tools::validation::utils::ValidationError;
use hive_router_internal::telemetry::TelemetryContext;
use hive_router_query_planner::planner::plan_nodes::QueryPlan;
use moka::future::Cache;
use moka::Entry;

use crate::pipeline::normalize::GraphQLNormalizationPayload;
use crate::pipeline::parser::ParseCacheEntry;

/// Caches used by the request pipeline, one per stage (parse, validate,
/// normalize, plan).
///
/// Every cache is keyed by a `u64`; presumably a hash of the stage's input
/// (e.g. the operation text / variables) — TODO confirm against the callers.
/// Values other than parse entries are wrapped in `Arc` so that handing a
/// cached value out of a moka cache (which clones `V`) stays cheap.
///
/// Cloning a `CacheState` clones the moka cache handles, which share their
/// underlying storage with the originals rather than copying entries.
#[derive(Clone)]
pub struct CacheState {
    /// Results of parsing incoming GraphQL documents.
    pub parse_cache: Cache<u64, ParseCacheEntry>,
    /// Validation errors produced for a document (empty when it is valid).
    pub validate_cache: Cache<u64, Arc<Vec<ValidationError>>>,
    /// Normalized operations produced by the normalization stage.
    pub normalize_cache: Cache<u64, Arc<GraphQLNormalizationPayload>>,
    /// Query plans produced by the query planner.
    pub plan_cache: Cache<u64, Arc<QueryPlan>>,
}

/// Outcome of a cache lookup, passed to the `on_hit_miss` callbacks of the
/// entry extension traits in this module.
///
/// Equality and hashing are derived so callers can compare outcomes or use
/// them as map keys (e.g. when tallying lookups per outcome).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CacheHitMiss {
    /// The entry was already present in the cache.
    Hit,
    /// The entry was absent; a value was freshly computed and inserted.
    Miss,
    /// Producing the value failed, so no cached value was returned.
    Error,
}

/// Extension for fallible cache lookups that yield `Result<entry, E>`.
///
/// It reports the lookup outcome ([`CacheHitMiss`]) to a callback before
/// unwrapping the entry into its value.
pub trait EntryResultHitMissExt<V, E> {
    /// Consumes the lookup result and passes its outcome to `on_hit_miss`
    /// (hit, miss, or error).
    ///
    /// Returns the cached value on success, or the original error unchanged.
    fn into_result_with_hit_miss(self, on_hit_miss: impl FnOnce(CacheHitMiss)) -> Result<V, E>;
}

/// Extension for infallible cache lookups that yield an entry directly.
///
/// It reports whether the entry was a hit or a miss to a callback, then
/// unwraps the entry into its value.
pub trait EntryValueHitMissExt<V> {
    /// Consumes the entry, passes its [`CacheHitMiss`] outcome to
    /// `on_hit_miss`, and returns the contained value.
    fn into_value_with_hit_miss(self, on_hit_miss: impl FnOnce(CacheHitMiss)) -> V;
}

/// Unwraps a fallible moka entry lookup while reporting its outcome.
///
/// - An `Err` is reported as [`CacheHitMiss::Error`] and returned unchanged.
/// - An `Ok` entry is classified exactly like
///   [`EntryValueHitMissExt::into_value_with_hit_miss`] does: freshly inserted
///   means a miss, otherwise a hit.
impl<K, V, E> EntryResultHitMissExt<V, E> for Result<Entry<K, V>, E> {
    fn into_result_with_hit_miss(self, on_hit_miss: impl FnOnce(CacheHitMiss)) -> Result<V, E> {
        let entry = match self {
            Ok(entry) => entry,
            Err(err) => {
                on_hit_miss(CacheHitMiss::Error);
                return Err(err);
            }
        };
        // Reuse the infallible entry path so hit/miss classification lives in one place.
        Ok(entry.into_value_with_hit_miss(on_hit_miss))
    }
}

/// Unwraps a moka entry while reporting whether the lookup hit the cache.
///
/// [`Entry::is_fresh`] is true when the value was just computed and inserted
/// by this lookup. That case is reported as [`CacheHitMiss::Miss`]; an entry
/// that was already present is reported as [`CacheHitMiss::Hit`].
impl<K, V> EntryValueHitMissExt<V> for Entry<K, V> {
    fn into_value_with_hit_miss(self, on_hit_miss: impl FnOnce(CacheHitMiss)) -> V {
        let outcome = match self.is_fresh() {
            true => CacheHitMiss::Miss,
            false => CacheHitMiss::Hit,
        };
        on_hit_miss(outcome);
        self.into_value()
    }
}

impl CacheState {
    /// Maximum number of entries each cache holds when built via [`CacheState::new`].
    const DEFAULT_CAPACITY: u64 = 1000;

    /// Creates the pipeline caches, each bounded to 1000 entries.
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }

    /// Creates the pipeline caches, each bounded to `max_capacity` entries.
    pub fn with_capacity(max_capacity: u64) -> Self {
        Self {
            parse_cache: Cache::new(max_capacity),
            validate_cache: Cache::new(max_capacity),
            normalize_cache: Cache::new(max_capacity),
            plan_cache: Cache::new(max_capacity),
        }
    }

    /// Drops cached results that must not outlive a schema change.
    ///
    /// Clears the plan, validation and normalization caches.
    ///
    /// The parse cache is intentionally kept. Presumably this is because
    /// parsing a document does not depend on the schema — confirm if
    /// `ParseCacheEntry` ever gains schema-derived data.
    ///
    /// moka's `invalidate_all` takes effect immediately for lookups. The
    /// entries themselves are removed lazily, so reported entry counts may
    /// lag briefly.
    pub fn on_schema_change(&self) {
        self.plan_cache.invalidate_all();
        self.validate_cache.invalidate_all();
        self.normalize_cache.invalidate_all();
    }
}

/// Same as [`CacheState::new`].
///
/// This impl satisfies clippy's `new_without_default` lint and lets
/// `CacheState` be used where `Default` is expected.
impl Default for CacheState {
    fn default() -> Self {
        Self::new()
    }
}

pub fn register_cache_size_observers(
    telemetry_context: Arc<TelemetryContext>,
    cache_state: Arc<CacheState>,
) {
    let metrics = &telemetry_context.metrics.cache;

    let parse_cache = Arc::clone(&cache_state);
    metrics
        .parse
        .observe_size_with(move || parse_cache.parse_cache.entry_count());

    let normalize_cache = Arc::clone(&cache_state);
    metrics
        .normalize
        .observe_size_with(move || normalize_cache.normalize_cache.entry_count());

    let validate_cache = Arc::clone(&cache_state);
    metrics
        .validate
        .observe_size_with(move || validate_cache.validate_cache.entry_count());

    let plan_cache = Arc::clone(&cache_state);
    metrics
        .plan
        .observe_size_with(move || plan_cache.plan_cache.entry_count());
}