use std::iter::once;
use apollo_compiler::Name;
use apollo_compiler::schema;
use serde::Deserialize;
use serde::Serialize;
use serde::de::Error as _;
use super::query::parse_hir_value;
use crate::configuration::mode::Mode;
use crate::json_ext::Value;
use crate::json_ext::ValueExt;
use crate::spec::Schema;
/// Unit error type carrying no detail about what was invalid.
///
/// NOTE(review): nothing in this file constructs or returns it; presumably it is used by
/// sibling modules — confirm before removing.
#[derive(Debug)]
pub(crate) struct InvalidValue;
// Error describing why a JSON input value was rejected. The wrapped string is the
// complete, human-readable message built by `validate_input_value`.
//
// NOTE: `displaydoc::Display` generates the `Display` impl from the doc comment below,
// so that doc comment is a format string rather than free-form documentation. Without
// it no `Display` impl is produced and the `thiserror::Error` derive cannot be
// satisfied. Keep it to this single `{0}` line: additional `///` lines are rejected by
// displaydoc unless the type opts in with `#[ignore_extra_doc_attributes]`.
/// {0}
#[derive(thiserror::Error, displaydoc::Display, Debug, Clone, Serialize, Eq, PartialEq)]
pub(crate) struct InvalidInputValue(pub(crate) String);
/// Names the JSON kind of `value` for use in validation error messages.
///
/// Objects are reported as `"map"`; every other kind uses its usual JSON name.
fn describe_json_value(value: &Value) -> &'static str {
    // No payload is bound, so matching through the reference moves nothing.
    match *value {
        Value::Object(..) => "map",
        Value::Array(..) => "array",
        Value::String(..) => "string",
        Value::Number(..) => "number",
        Value::Bool(..) => "boolean",
        Value::Null => "null",
    }
}
/// Newtype around an `apollo_compiler` [`schema::Type`].
///
/// The wrapper exists so this crate can attach its own `Serialize`/`Deserialize`
/// implementations and the input-value validation entry point to the schema type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct FieldType(pub(crate) schema::Type);
/// Location of a JSON value inside a variable, expressed as a chain of borrowed nodes.
///
/// Each non-root node borrows its parent, so extending a path only needs a new
/// stack value. The `Display` impl renders the chain root-first, e.g. `$var.key[0]`.
pub(crate) enum JsonValuePath<'a> {
    /// Root of a path: the variable itself, rendered as `$name`.
    Variable {
        /// Variable name, without the leading `$`.
        name: &'a str,
    },
    /// A member of a JSON object, rendered as `<parent>.key`.
    ObjectKey {
        /// Key of the member within its parent object.
        key: &'a str,
        /// Path of the enclosing object.
        parent: &'a JsonValuePath<'a>,
    },
    /// An element of a JSON array, rendered as `<parent>[index]`.
    ArrayItem {
        /// Zero-based position within the parent array.
        index: usize,
        /// Path of the enclosing array.
        parent: &'a JsonValuePath<'a>,
    },
}
impl Serialize for FieldType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
struct BorrowedFieldType<'a>(&'a schema::Type);
impl Serialize for BorrowedFieldType<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
#[derive(Serialize)]
enum NestedBorrowed<'a> {
Named(&'a str),
NonNullNamed(&'a str),
List(BorrowedFieldType<'a>),
NonNullList(BorrowedFieldType<'a>),
}
match &self.0 {
schema::Type::Named(name) => NestedBorrowed::Named(name),
schema::Type::NonNullNamed(name) => NestedBorrowed::NonNullNamed(name),
schema::Type::List(ty) => NestedBorrowed::List(BorrowedFieldType(ty)),
schema::Type::NonNullList(ty) => {
NestedBorrowed::NonNullList(BorrowedFieldType(ty))
}
}
.serialize(serializer)
}
}
BorrowedFieldType(&self.0).serialize(serializer)
}
}
impl<'de> Deserialize<'de> for FieldType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
enum WithoutLocation {
Named(String),
NonNullNamed(String),
List(FieldType),
NonNullList(FieldType),
}
Ok(match WithoutLocation::deserialize(deserializer)? {
WithoutLocation::Named(name) => FieldType(schema::Type::Named(
name.try_into().map_err(D::Error::custom)?,
)),
WithoutLocation::NonNullNamed(name) => FieldType(
schema::Type::Named(name.try_into().map_err(D::Error::custom)?).non_null(),
),
WithoutLocation::List(ty) => FieldType(ty.0.list()),
WithoutLocation::NonNullList(ty) => FieldType(ty.0.list().non_null()),
})
}
}
impl std::fmt::Display for FieldType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
/// Recursively checks that a JSON `value` is acceptable input for the GraphQL type `ty`.
///
/// * A missing or `null` value is accepted only when `ty` is nullable.
/// * For list types, a JSON array has every element validated against the item type;
///   any other value is validated against the item type directly.
/// * `String`, `Int`, `Float`, `ID` and `Boolean` are checked with the matching JSON
///   predicates; other names are looked up in the supergraph schema.
/// * Custom scalars accept anything, enums require a string naming one of their values,
///   and input objects require a JSON object whose fields are validated one by one
///   (falling back to the field's default when the key is absent or `null`).
///
/// `path` locates `value` for error messages. `strict_variable_validation` decides what
/// happens when an input object carries keys the schema does not define: `Mode::Enforce`
/// rejects the value, `Mode::Measure` only logs a warning.
///
/// # Errors
///
/// Returns an [`InvalidInputValue`] whose message names the offending path and type.
fn validate_input_value(
    ty: &schema::Type,
    value: Option<&Value>,
    schema: &Schema,
    path: &JsonValuePath<'_>,
    strict_variable_validation: Mode,
) -> Result<(), InvalidInputValue> {
    // Label used as the subject of every error message: the variable itself for a
    // root path, otherwise the nested location.
    fn fmt_path(var_path: &JsonValuePath<'_>) -> String {
        if matches!(var_path, JsonValuePath::Variable { .. }) {
            format!("variable `{var_path}`")
        } else {
            format!("input value at `{var_path}`")
        }
    }

    // An absent value is only a problem for non-null types.
    let value = match value {
        Some(present) => present,
        None if ty.is_non_null() => {
            return Err(InvalidInputValue(format!(
                "missing {}: for required GraphQL type `{ty}`",
                fmt_path(path),
            )));
        }
        None => return Ok(()),
    };

    // Builds the generic "wrong kind of JSON" error for the current value.
    let invalid = || {
        InvalidInputValue(format!(
            "invalid {}: found JSON {} for GraphQL type `{ty}`",
            fmt_path(path),
            describe_json_value(value)
        ))
    };
    // Turns a predicate result into the function's result type.
    let check = |accepted: bool| -> Result<(), InvalidInputValue> {
        if accepted { Ok(()) } else { Err(invalid()) }
    };

    // An explicit null follows the same rule as a missing value, but reports "invalid".
    if value.is_null() {
        return check(!ty.is_non_null());
    }

    let type_name = match ty {
        schema::Type::List(item_type) | schema::Type::NonNullList(item_type) => {
            let Value::Array(items) = value else {
                // A non-array value is validated against the item type itself.
                return validate_input_value(
                    item_type,
                    Some(value),
                    schema,
                    path,
                    strict_variable_validation,
                );
            };
            return items.iter().enumerate().try_for_each(|(index, item)| {
                let item_path = JsonValuePath::ArrayItem {
                    index,
                    parent: path,
                };
                validate_input_value(
                    item_type,
                    Some(item),
                    schema,
                    &item_path,
                    strict_variable_validation,
                )
            });
        }
        schema::Type::Named(name) | schema::Type::NonNullNamed(name) => name,
    };

    // Built-in scalars are decided without consulting the schema.
    let builtin_verdict = match type_name.as_str() {
        "String" => Some(value.is_string()),
        "Int" => Some(value.is_valid_int_input()),
        "Float" => Some(value.is_valid_float_input()),
        "ID" => Some(value.is_valid_id_input()),
        "Boolean" => Some(value.is_boolean()),
        _ => None,
    };
    if let Some(accepted) = builtin_verdict {
        return check(accepted);
    }

    // A name the schema does not know is reported as an invalid value.
    let Some(type_def) = schema.supergraph_schema().types.get(type_name) else {
        return Err(invalid());
    };

    match type_def {
        // Custom scalars accept any JSON value.
        schema::ExtendedType::Scalar(_) => Ok(()),
        schema::ExtendedType::Enum(def) => match value {
            Value::String(s) => check(def.values.contains_key(s.as_str())),
            _ => Err(invalid()),
        },
        schema::ExtendedType::InputObject(def) => {
            let Value::Object(obj) = value else {
                return Err(invalid());
            };

            // Keys present in the JSON object but not declared on the input type.
            let mut undeclared_keys = obj
                .keys()
                .map(|k| k.as_str())
                .filter(|&k| !def.fields.contains_key(k));
            if let Some(first_undeclared) = undeclared_keys.next() {
                match strict_variable_validation {
                    Mode::Enforce => {
                        let offending = JsonValuePath::ObjectKey {
                            key: first_undeclared,
                            parent: path,
                        };
                        return Err(InvalidInputValue(format!(
                            "unknown field {} found for GraphQL type `{def}`",
                            fmt_path(&offending),
                        )));
                    }
                    Mode::Measure => {
                        // Report every undeclared key, not just the first one found.
                        let all_undeclared: Vec<&str> =
                            once(first_undeclared).chain(undeclared_keys).collect();
                        tracing::warn!(variables = ?all_undeclared, "encountered unexpected variable(s)");
                    }
                }
            }

            // Every declared field is validated, whether or not the object carries it.
            for field in def.fields.values() {
                let field_path = JsonValuePath::ObjectKey {
                    key: &field.name,
                    parent: path,
                };
                match obj.get(field.name.as_str()) {
                    Some(given) if !given.is_null() => validate_input_value(
                        &field.ty,
                        Some(given),
                        schema,
                        &field_path,
                        strict_variable_validation,
                    )?,
                    // Absent or null: validate the schema default instead, if any.
                    _ => {
                        let default = field
                            .default_value
                            .as_ref()
                            .and_then(|v| parse_hir_value(v));
                        validate_input_value(
                            &field.ty,
                            default.as_ref(),
                            schema,
                            &field_path,
                            strict_variable_validation,
                        )?
                    }
                }
            }
            Ok(())
        }
        // Any other kind of type definition is not valid input.
        _ => Err(invalid()),
    }
}
impl FieldType {
pub(crate) fn new_named(name: Name) -> Self {
Self(schema::Type::Named(name))
}
pub(crate) fn validate_input_value(
&self,
value: Option<&Value>,
schema: &Schema,
path: &JsonValuePath<'_>,
strict_variable_validation: Mode,
) -> Result<(), InvalidInputValue> {
validate_input_value(&self.0, value, schema, path, strict_variable_validation)
}
pub(crate) fn is_non_null(&self) -> bool {
self.0.is_non_null()
}
}
impl From<&'_ schema::Type> for FieldType {
fn from(ty: &'_ schema::Type) -> Self {
Self(ty.clone())
}
}
impl std::fmt::Display for JsonValuePath<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Variable { name } => {
f.write_str("$")?;
f.write_str(name)
}
Self::ObjectKey { key, parent } => {
parent.fmt(f)?;
f.write_str(".")?;
f.write_str(key)
}
Self::ArrayItem { index, parent } => {
parent.fmt(f)?;
write!(f, "[{index}]")
}
}
}
}
/// A non-null list type must survive a JSON serialize/deserialize round trip unchanged.
#[test]
fn test_field_type_serialization() {
    let ty = FieldType(apollo_compiler::ty!([ID]!));

    let json = serde_json::to_string(&ty).unwrap();
    let round_tripped = serde_json::from_str::<FieldType>(&json).unwrap();

    assert_eq!(round_tripped, ty)
}