use std::borrow::Cow;
use std::collections::HashMap;
use std::ops::Deref;
use hive_router_config::persisted_documents::{
PersistedDocumentExtractorConfig, PersistedDocumentUrlTemplate, PersistedDocumentsConfig,
};
use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams;
use ntex::web::HttpRequest;
use sonic_rs::OwnedLazyValue;
use thiserror::Error;
use crate::pipeline::persisted_documents::extract::extractors::apollo::{
ApolloExtractor, APOLLO_HASH_PATH,
};
use crate::pipeline::persisted_documents::extract::extractors::document_id::{
DocumentIdExtractor, DOCUMENT_ID_FIELD,
};
use crate::pipeline::persisted_documents::extract::extractors::json_path::JsonPathExtractor;
use crate::pipeline::persisted_documents::extract::extractors::url_path_param::UrlPathParamExtractor;
use crate::pipeline::persisted_documents::extract::extractors::url_query_param::{
QueryParams, UrlQueryParamExtractor,
};
use super::super::types::PersistedDocumentId;
/// Borrowed view of the HTTP request parts that document-id extractors may inspect.
///
/// Construct it with `HttpRequestContext::from_parts` or `From<&HttpRequest>`.
pub struct HttpRequestContext<'a> {
    /// Wrapped query string; `None` when no query string was supplied.
    pub(crate) query: Option<QueryParams<'a>>,
    /// Request URI path.
    pub(crate) path: &'a str,
}
/// Per-request inputs handed to `DocumentIdResolver::resolve_document_id`.
pub struct DocumentIdResolverInput<'a> {
    /// HTTP-level view of the request (path and query string).
    pub request_context: &'a HttpRequestContext<'a>,
    /// Parsed GraphQL request parameters.
    pub graphql_params: &'a GraphQLParams,
    /// Document id supplied by the caller, if any.
    pub document_id: Option<&'a str>,
    /// Extra JSON body fields outside the standard GraphQL params, if the caller collected them.
    pub nonstandard_json_fields: Option<&'a HashMap<String, OwnedLazyValue>>,
}
impl<'a> From<&'a HttpRequest> for HttpRequestContext<'a> {
fn from(req: &'a HttpRequest) -> Self {
Self::from_parts(req.uri().path(), req.uri().query())
}
}
impl<'a> HttpRequestContext<'a> {
    /// Builds a context from a raw path and optional raw query string.
    ///
    /// A present query string is wrapped with `QueryParams::new`; an absent one stays `None`.
    pub fn from_parts(path: &'a str, query: Option<&'a str>) -> Self {
        let query = query.map(QueryParams::new);
        Self { query, path }
    }
}
/// Resolves the persisted-document id of a request by running the configured selectors.
///
/// Build it with `DocumentIdResolver::from_config`.
pub struct DocumentIdResolver {
// Endpoint path produced by `GraphQLEndpointPath::from`; used to compute request paths
// relative to the GraphQL endpoint.
graphql_endpoint: GraphQLEndpointPath,
// Either disabled, or the compiled selector plan.
state: ResolverState,
}
/// Errors raised while turning persisted-document selector configuration into extractors.
#[derive(Debug, Error)]
pub enum PersistedDocumentExtractError {
    /// The `url_path_param` template could not be compiled into a matcher.
    #[error("failed to compile url_path_param.template: {0}")]
    MatcherCompile(String),
    /// The `url_path_param` template has no `:id` segment.
    #[error("url_path_param.template must contain ':id' segment: {template}")]
    MissingIdParam { template: String },
}
/// Whether document-id resolution is active and, if so, with which plan.
enum ResolverState {
    /// Resolution is active and runs the contained selector plan.
    Enabled(ActivePlan),
    /// Persisted documents are disabled in configuration; nothing is resolved.
    Disabled,
}
/// Compiled selector chain used while persisted documents are enabled.
struct ActivePlan {
    /// True if at least one selector reads the URL path.
    depends_on_graphql_path: bool,
    /// True if at least one selector needs JSON body fields outside the standard GraphQL params.
    requires_nonstandard_json_fields: bool,
    /// Extractors in configuration order; the first one that yields an id wins.
    selectors: Vec<Box<dyn DocumentIdSourceExtractor>>,
}
/// One strategy for locating a persisted-document id in a request.
///
/// Implementors must be `Send + Sync`.
pub(super) trait DocumentIdSourceExtractor: Send + Sync {
    /// Returns the id found by this strategy, or `None` so the next configured selector is tried.
    fn extract(&self, context: &ExtractionContext<'_>) -> Option<PersistedDocumentId>;
}
/// GraphQL endpoint path, stored as an owned `String`.
///
/// Values are normally built through this module's `From<&str>` conversion, which adds a
/// missing leading `/` and strips trailing `/` characters. Borrow the path as `&str` through
/// `Deref` or `AsRef<str>`.
#[derive(Debug)]
struct GraphQLEndpointPath(String);

impl Deref for GraphQLEndpointPath {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}

impl AsRef<str> for GraphQLEndpointPath {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Per-request view passed to every `DocumentIdSourceExtractor`.
///
/// Created once per `DocumentIdResolver::resolve_document_id` call.
pub(super) struct ExtractionContext<'a> {
    /// HTTP-level view of the request.
    pub(crate) request_context: &'a HttpRequestContext<'a>,
    /// Parsed GraphQL request parameters.
    pub(crate) graphql_params: &'a GraphQLParams,
    /// Extra JSON body fields outside the standard GraphQL params, if collected.
    pub(crate) nonstandard_json_fields: Option<&'a HashMap<String, OwnedLazyValue>>,
    /// Document id supplied by the caller, if any.
    document_id: Option<&'a str>,
    /// Request path relative to the GraphQL endpoint; `None` when the request is outside it.
    relative_path: Option<&'a str>,
}
impl<'a> ExtractionContext<'a> {
fn new(input: DocumentIdResolverInput<'a>, graphql_endpoint: &GraphQLEndpointPath) -> Self {
Self {
graphql_params: input.graphql_params,
document_id: input.document_id,
nonstandard_json_fields: input.nonstandard_json_fields,
relative_path: graphql_endpoint.relative_path(input.request_context.path()),
request_context: input.request_context,
}
}
pub(super) fn document_id(&self) -> Option<&'a str> {
self.document_id
}
pub(super) fn relative_path(&self) -> Option<&'a str> {
self.relative_path
}
pub(super) fn query_param(&self, name: &str) -> Option<Cow<'a, str>> {
self.request_context.query_param(name)
}
}
impl DocumentIdResolver {
    /// Builds a resolver from the persisted-documents configuration.
    ///
    /// `graphql_endpoint` is normalized with `GraphQLEndpointPath::from`. Request paths are
    /// later made relative to it.
    ///
    /// When `config.enabled` is false, the resolver is created disabled and never resolves an
    /// id. Otherwise each configured selector is compiled, in order, into an extractor. When no
    /// selectors are configured, `PersistedDocumentsConfig::default_selectors()` is used instead.
    ///
    /// # Errors
    ///
    /// Returns [`PersistedDocumentExtractError`] if any selector fails to build, e.g. an
    /// invalid `url_path_param` template.
    pub fn from_config(
        config: &PersistedDocumentsConfig,
        graphql_endpoint: &str,
    ) -> Result<Self, PersistedDocumentExtractError> {
        let graphql_endpoint = GraphQLEndpointPath::from(graphql_endpoint);
        if !config.enabled {
            return Ok(Self {
                graphql_endpoint,
                state: ResolverState::Disabled,
            });
        }

        // Borrow the configured selectors instead of cloning the whole list. Only the
        // fallback defaults need an owned value, and that value lives in this local.
        let default_selectors;
        let configured_selectors = match config.selectors.as_ref() {
            Some(selectors) => selectors,
            None => {
                default_selectors = PersistedDocumentsConfig::default_selectors();
                &default_selectors
            }
        };

        let mut selectors = Vec::with_capacity(configured_selectors.len());
        let mut requires_nonstandard_json_fields = false;
        let mut depends_on_graphql_path = false;
        for selector_config in configured_selectors {
            let (extractor, requires_nonstandard_fields, depends_on_url_path) =
                build_extractor(selector_config)?;
            // One selector needing the extra data is enough to flag the whole plan.
            requires_nonstandard_json_fields |= requires_nonstandard_fields;
            depends_on_graphql_path |= depends_on_url_path;
            selectors.push(extractor);
        }

        Ok(Self {
            graphql_endpoint,
            state: ResolverState::Enabled(ActivePlan {
                selectors,
                requires_nonstandard_json_fields,
                depends_on_graphql_path,
            }),
        })
    }

    /// Returns `true` when persisted-document resolution is enabled.
    #[inline]
    pub fn is_enabled(&self) -> bool {
        matches!(self.state, ResolverState::Enabled(_))
    }

    /// Returns `true` when at least one selector needs JSON body fields outside the standard
    /// GraphQL params. Always `false` when disabled.
    #[inline]
    pub fn requires_nonstandard_json_fields(&self) -> bool {
        match &self.state {
            ResolverState::Disabled => false,
            ResolverState::Enabled(active_plan) => active_plan.requires_nonstandard_json_fields,
        }
    }

    /// Returns `true` when at least one selector reads the URL path. Always `false` when
    /// disabled.
    #[inline]
    pub fn depends_on_graphql_path(&self) -> bool {
        match &self.state {
            ResolverState::Disabled => false,
            ResolverState::Enabled(active_plan) => active_plan.depends_on_graphql_path,
        }
    }

    /// Runs the configured selectors in order and returns the first document id found.
    ///
    /// Returns `None` when the resolver is disabled or when no selector yields an id.
    pub fn resolve_document_id(
        &self,
        input: DocumentIdResolverInput<'_>,
    ) -> Option<PersistedDocumentId> {
        let active_plan = match &self.state {
            ResolverState::Disabled => return None,
            ResolverState::Enabled(active_plan) => active_plan,
        };
        let ctx = ExtractionContext::new(input, &self.graphql_endpoint);
        active_plan
            .selectors
            .iter()
            .find_map(|selector| selector.extract(&ctx))
    }
}
fn build_extractor(
extractor_config: &PersistedDocumentExtractorConfig,
) -> Result<(Box<dyn DocumentIdSourceExtractor>, bool, bool), PersistedDocumentExtractError> {
match extractor_config {
PersistedDocumentExtractorConfig::JsonPath { path } => {
if path.as_str() == DOCUMENT_ID_FIELD {
return Ok((Box::new(DocumentIdExtractor), false, false));
}
if path.as_str() == APOLLO_HASH_PATH {
return Ok((Box::new(ApolloExtractor), false, false));
}
let segments = path
.as_str()
.split('.')
.map(|s| s.to_string())
.collect::<Vec<_>>();
let requires_extra = JsonPathExtractor::requires_nonstandard_json_fields(&segments);
Ok((
Box::new(JsonPathExtractor { segments }),
requires_extra,
false,
))
}
PersistedDocumentExtractorConfig::UrlQueryParam { name } => Ok((
Box::new(UrlQueryParamExtractor {
name: name.as_str().to_string(),
}),
false,
false,
)),
PersistedDocumentExtractorConfig::UrlPathParam { template } => {
let extractor: UrlPathParamExtractor = template.try_into()?;
Ok((Box::new(extractor), false, true))
}
}
}
/// Normalizes a configured endpoint into a path with a leading `/` and no trailing `/`.
///
/// `""`, `"/"`, and any all-slash input such as `"//"` all become `"/"`. `"graphql/"` and
/// `"/graphql"` both become `"/graphql"`.
impl From<&str> for GraphQLEndpointPath {
    fn from(endpoint: &str) -> Self {
        // Trim first so inputs made only of slashes collapse to the root instead of an
        // empty path. This also avoids building an intermediate `String`.
        let trimmed = endpoint.trim_end_matches('/');
        if trimmed.is_empty() {
            return Self("/".to_string());
        }
        if trimmed.starts_with('/') {
            Self(trimmed.to_string())
        } else {
            Self(format!("/{trimmed}"))
        }
    }
}
impl GraphQLEndpointPath {
    /// Returns `request_path` relative to this endpoint, or `None` if the request is not
    /// addressed to it.
    ///
    /// For the root endpoint (`"/"`), every path is accepted and returned unchanged. For any
    /// other endpoint, it must be a prefix of `request_path` that ends on a segment boundary:
    /// the remainder is empty or starts with `/`. So `/graphqlx` does not match `/graphql`.
    fn relative_path<'a>(&self, request_path: &'a str) -> Option<&'a str> {
        let endpoint: &str = self.as_ref();
        if endpoint == "/" {
            return Some(request_path);
        }
        request_path
            .strip_prefix(endpoint)
            .filter(|remainder| remainder.is_empty() || remainder.starts_with('/'))
    }
}
impl TryFrom<&PersistedDocumentUrlTemplate> for UrlPathParamExtractor {
type Error = PersistedDocumentExtractError;
fn try_from(template: &PersistedDocumentUrlTemplate) -> Result<Self, Self::Error> {
UrlPathParamExtractor::try_from_template(template)
}
}