// apollo-router 2.13.1
//
// A configurable, high-performance routing runtime for Apollo Federation 🚀
// Documentation
use std::collections::HashSet;
use std::sync::Arc;

use futures::FutureExt;
use futures::StreamExt;
use futures::stream;
use itertools::Itertools;
use opentelemetry::StringValue;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use tower::BoxError;
use tracing::Instrument;

use super::plugin::StorageInterface;
use crate::plugins::response_cache::ErrorCode;
use crate::plugins::response_cache::plugin::INTERNAL_CACHE_TAG_PREFIX;
use crate::plugins::response_cache::plugin::RESPONSE_CACHE_VERSION;
use crate::plugins::response_cache::storage;
use crate::plugins::response_cache::storage::CacheStorage;
use crate::plugins::response_cache::storage::redis::Storage;

/// Entry point for invalidating response-cache entries.
///
/// Holds the shared [`StorageInterface`], which is queried by subgraph name to find
/// the storage each invalidation request must be sent to.
#[derive(Clone)]
pub(crate) struct Invalidation {
    /// Shared storage registry; looked up per subgraph when a batch is dispatched.
    pub(crate) storage: Arc<StorageInterface>,
}

/// Errors that can be produced while invalidating response-cache entries.
#[derive(Error, Debug)]
pub(super) enum InvalidationError {
    /// Any other failure, carried as an opaque [`anyhow::Error`].
    #[error("error")]
    Misc(#[from] anyhow::Error),
    /// Failure reported by the cache storage layer.
    #[error("caching database error")]
    Storage(#[from] storage::Error),
    /// Aggregate of the failures collected while processing a batch of requests.
    #[error("several errors")]
    Errors(#[from] InvalidationErrors),
}

impl ErrorCode for InvalidationError {
    /// Returns the static code used to label this error (e.g. in metrics).
    ///
    /// Storage failures defer to the code of the wrapped storage error; the other
    /// variants map to a fixed literal.
    fn code(&self) -> &'static str {
        match self {
            Self::Storage(storage_error) => storage_error.code(),
            Self::Errors(_) => "INVALIDATION_ERRORS",
            Self::Misc(_) => "MISC",
        }
    }
}

/// Collection of [`InvalidationError`]s gathered while processing one batch of requests.
#[derive(Debug)]
pub(crate) struct InvalidationErrors(Vec<InvalidationError>);

impl std::fmt::Display for InvalidationErrors {
    /// Renders the collection as `invalidation errors: [e1; e2; ...]`, where each
    /// element is the `Display` output of the corresponding error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalidation errors: [")?;
        for (position, error) in self.0.iter().enumerate() {
            // Separator goes between elements only, never before the first one.
            if position > 0 {
                f.write_str("; ")?;
            }
            write!(f, "{error}")?;
        }
        f.write_str("]")
    }
}

// Relies on the trait's default methods: no `source()` is provided, so the individual
// errors are surfaced through the `Display` output rather than through an error chain.
impl std::error::Error for InvalidationErrors {}

impl Invalidation {
    /// Builds an [`Invalidation`] around the shared storage interface.
    ///
    /// The body awaits nothing and always returns `Ok`; the `async` / `Result`
    /// signature is part of the interface callers rely on.
    pub(crate) async fn new(storage: Arc<StorageInterface>) -> Result<Self, BoxError> {
        let invalidation = Self { storage };
        Ok(invalidation)
    }

    /// Invalidates cache entries for a batch of requests.
    ///
    /// Increments the `apollo.router.operations.response_cache.invalidation.event`
    /// counter once per call, then processes the batch inside a
    /// `cache.invalidation.batch` tracing span.
    ///
    /// Returns the total number of invalidated entries reported for the batch.
    ///
    /// # Errors
    ///
    /// Any [`InvalidationError`] raised while handling the batch is converted into a
    /// [`BoxError`] and returned.
    pub(crate) async fn invalidate(
        &self,
        requests: Vec<InvalidationRequest>,
    ) -> Result<u64, BoxError> {
        u64_counter_with_unit!(
            "apollo.router.operations.response_cache.invalidation.event",
            "Response cache received a batch of invalidation requests",
            "{request}",
            1u64
        );

        let batch = self.handle_request_batch(requests);
        let span = tracing::info_span!("cache.invalidation.batch");
        match batch.instrument(span).await {
            Ok(invalidated) => Ok(invalidated),
            Err(error) => Err(error.into()),
        }
    }

    /// Executes one invalidation request against a single storage and records metrics.
    ///
    /// Depending on the request kind this calls `invalidate_by_subgraph` (for
    /// `Subgraph`) or `invalidate` (for `Type` and `CacheTag`) on `storage`:
    ///
    /// * on failure, the `...response_cache.invalidation.error` counter is incremented
    ///   (tagged with the error code and the request's attributes) and the error is
    ///   propagated;
    /// * on success, the `...response_cache.invalidation.entry` counter is recorded with
    ///   the invalidated count (per subgraph for `Type` / `CacheTag`), and then one
    ///   `...response_cache.invalidation.request.entry` histogram sample carrying the
    ///   request's total count is recorded for every subgraph named in the request.
    ///
    /// # Parameters
    ///
    /// * `storage` - the storage the invalidation is sent to.
    /// * `request` - the invalidation request to execute.
    ///
    /// # Returns
    ///
    /// The number of invalidated entries reported by the storage for this request.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidationError`] when the storage call fails.
    async fn handle_request(
        &self,
        storage: &Storage,
        request: &InvalidationRequest,
    ) -> Result<u64, InvalidationError> {
        let invalidation_key = request.invalidation_key();
        tracing::debug!(
            "got invalidation request: {request:?}, will invalidate: {}",
            invalidation_key
        );
        let (count, subgraphs) = match request {
            InvalidationRequest::Subgraph { subgraph } => {
                let count = storage
                    .invalidate_by_subgraph(subgraph, request.kind())
                    .await
                    .inspect_err(|err| {
                        u64_counter_with_unit!(
                            "apollo.router.operations.response_cache.invalidation.error",
                            "Errors when invalidating data in cache",
                            "{error}",
                            1,
                            "code" = err.code(),
                            "kind" = "subgraph",
                            "subgraph.name" = subgraph.clone()
                        );
                    })?;
                u64_counter_with_unit!(
                    "apollo.router.operations.response_cache.invalidation.entry",
                    "Response cache counter for invalidated entries",
                    "{entry}",
                    count,
                    "kind" = "subgraph",
                    "subgraph.name" = subgraph.clone()
                );
                (count, vec![subgraph.clone()])
            }
            InvalidationRequest::Type {
                subgraph,
                r#type: graphql_type,
            } => {
                // The computed invalidation key is the only tag sent for a type request.
                let subgraph_counts = storage
                    .invalidate(
                        vec![invalidation_key],
                        vec![subgraph.clone()],
                        request.kind(),
                    )
                    .await
                    .inspect_err(|err| {
                        u64_counter_with_unit!(
                            "apollo.router.operations.response_cache.invalidation.error",
                            "Errors when invalidating data in cache",
                            "{error}",
                            1,
                            "code" = err.code(),
                            "kind" = "type",
                            "subgraph.name" = subgraph.clone(),
                            "graphql.type" = graphql_type.clone()
                        );
                    })?;
                let mut total_count = 0;
                for (subgraph_name, count) in subgraph_counts {
                    total_count += count;
                    u64_counter_with_unit!(
                        "apollo.router.operations.response_cache.invalidation.entry",
                        "Response cache counter for invalidated entries",
                        "{entry}",
                        count,
                        "kind" = "type",
                        "subgraph.name" = subgraph_name,
                        "graphql.type" = graphql_type.clone()
                    );
                }

                (total_count, vec![subgraph.clone()])
            }
            InvalidationRequest::CacheTag {
                subgraphs,
                cache_tag,
            } => {
                // Clone only the names while iterating; the set itself is never copied.
                let subgraph_counts = storage
                    .invalidate(
                        vec![cache_tag.clone()],
                        subgraphs.iter().cloned().collect(),
                        request.kind(),
                    )
                    .await
                    .inspect_err(|err| {
                        let subgraphs: opentelemetry::Array = subgraphs
                            .iter()
                            .cloned()
                            .map(StringValue::from)
                            .collect::<Vec<StringValue>>()
                            .into();
                        u64_counter_with_unit!(
                            "apollo.router.operations.response_cache.invalidation.error",
                            "Errors when invalidating data in cache",
                            "{error}",
                            1,
                            "code" = err.code(),
                            "kind" = "cache_tag",
                            "subgraph.names" = opentelemetry::Value::Array(subgraphs),
                            "cache.tag" = cache_tag.clone()
                        );
                    })?;
                let mut total_count = 0;
                for (subgraph_name, count) in subgraph_counts {
                    total_count += count;
                    u64_counter_with_unit!(
                        "apollo.router.operations.response_cache.invalidation.entry",
                        "Response cache counter for invalidated entries",
                        "{entry}",
                        count,
                        "kind" = "cache_tag",
                        "subgraph.name" = subgraph_name,
                        "cache.tag" = cache_tag.clone()
                    );
                }

                (
                    total_count,
                    subgraphs.iter().cloned().collect::<Vec<String>>(),
                )
            }
        };

        // One histogram sample per subgraph named in the request, each carrying the
        // request's total count.
        for subgraph in subgraphs {
            u64_histogram_with_unit!(
                "apollo.router.operations.response_cache.invalidation.request.entry",
                "Number of invalidated entries per invalidation request.",
                "{entry}",
                count,
                "subgraph.name" = subgraph
            );
        }

        Ok(count)
    }

    /// Fans a batch of invalidation requests out to their storages and sums the results.
    ///
    /// For each request the target storages are resolved through `self.storage.get`:
    /// `Subgraph` and `Type` requests resolve a single storage and are skipped when none
    /// is found; `CacheTag` requests resolve one storage per listed subgraph, ignoring
    /// subgraphs without one. The request is then handled once per resolved storage,
    /// each inside its own `cache.invalidation.request` span, with all of these driven
    /// concurrently through a `FuturesUnordered`.
    ///
    /// # Returns
    ///
    /// The sum of the counts returned by every successful request.
    ///
    /// # Errors
    ///
    /// If at least one request failed, all failures are returned together wrapped in
    /// [`InvalidationErrors`]; the counts of the successful requests are discarded.
    async fn handle_request_batch(
        &self,
        requests: Vec<InvalidationRequest>,
    ) -> Result<u64, InvalidationError> {
        let mut in_flight = stream::FuturesUnordered::new();
        for request in requests {
            let targets = match &request {
                InvalidationRequest::CacheTag { subgraphs, .. } => subgraphs
                    .iter()
                    .filter_map(|name| self.storage.get(name))
                    .collect::<Vec<_>>(),
                InvalidationRequest::Subgraph { subgraph }
                | InvalidationRequest::Type { subgraph, .. } => {
                    // No storage for this subgraph: nothing to dispatch for this request.
                    let Some(target) = self.storage.get(subgraph) else {
                        continue;
                    };
                    vec![target]
                }
            };

            for target in targets {
                // Each task owns its own copy of the request.
                let request = request.clone();
                let task = async move {
                    self.handle_request(target, &request)
                        .instrument(tracing::info_span!("cache.invalidation.request"))
                        .await
                };
                in_flight.push(task.boxed());
            }
        }

        let mut invalidated = 0;
        let mut failures = Vec::new();
        while let Some(outcome) = in_flight.next().await {
            match outcome {
                Ok(entries) => invalidated += entries,
                Err(failure) => failures.push(failure),
            }
        }

        if failures.is_empty() {
            Ok(invalidated)
        } else {
            Err(InvalidationErrors(failures).into())
        }
    }
}

/// Static label naming the kind of an invalidation request
/// (`"subgraph"`, `"type"` or `"cache_tag"`, as returned by `InvalidationRequest::kind`).
pub(super) type InvalidationKind = &'static str;

/// A single response-cache invalidation request.
///
/// In its serde representation the variant is selected by a `kind` field holding the
/// snake_case variant name, and unknown fields are rejected.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
pub(crate) enum InvalidationRequest {
    /// Invalidation targeting a subgraph as a whole.
    Subgraph {
        /// Name of the subgraph.
        subgraph: String,
    },
    /// Invalidation targeting one GraphQL type within a subgraph.
    Type {
        /// Name of the subgraph.
        subgraph: String,
        /// Name of the GraphQL type.
        r#type: String,
    },
    /// Invalidation targeting a cache tag across a set of subgraphs.
    CacheTag {
        /// Names of the subgraphs the tag applies to.
        subgraphs: HashSet<String>,
        /// The cache tag to invalidate.
        cache_tag: String,
    },
}

impl InvalidationRequest {
    /// Returns the names of the subgraphs this request applies to.
    ///
    /// `Subgraph` and `Type` requests yield a single name; `CacheTag` requests yield
    /// every name in their set, in the set's (unspecified) iteration order.
    pub(crate) fn subgraph_names(&self) -> Vec<String> {
        match self {
            Self::CacheTag { subgraphs, .. } => subgraphs.iter().cloned().collect(),
            Self::Subgraph { subgraph } | Self::Type { subgraph, .. } => {
                vec![subgraph.to_owned()]
            }
        }
    }

    /// Builds the key associated with this request.
    ///
    /// `Subgraph` and `Type` requests get a versioned key built from
    /// `RESPONSE_CACHE_VERSION` (the `Type` key is additionally prefixed with
    /// `INTERNAL_CACHE_TAG_PREFIX`); `CacheTag` requests use the cache tag as-is.
    fn invalidation_key(&self) -> String {
        match self {
            Self::CacheTag { cache_tag, .. } => cache_tag.to_owned(),
            Self::Type { subgraph, r#type } => format!(
                "{INTERNAL_CACHE_TAG_PREFIX}version:{RESPONSE_CACHE_VERSION}:subgraph:{subgraph}:type:{type}"
            ),
            Self::Subgraph { subgraph } => {
                format!("version:{RESPONSE_CACHE_VERSION}:subgraph:{subgraph}")
            }
        }
    }

    /// Returns the static label identifying which kind of request this is.
    pub(super) fn kind(&self) -> InvalidationKind {
        match self {
            Self::CacheTag { .. } => "cache_tag",
            Self::Type { .. } => "type",
            Self::Subgraph { .. } => "subgraph",
        }
    }
}