1use std::iter::once;
2
3use apollo_compiler::Name;
4use apollo_compiler::schema;
5use serde::Deserialize;
6use serde::Serialize;
7use serde::de::Error as _;
8
9use super::query::parse_hir_value;
10use crate::configuration::mode::Mode;
11use crate::json_ext::Value;
12use crate::json_ext::ValueExt;
13use crate::spec::Schema;
14
/// Detail-free error marker: carries no message or payload.
///
/// NOTE(review): not referenced by the code visible in this file; presumably used by other
/// modules of the crate — confirm before changing or removing it.
#[derive(Debug)]
pub(crate) struct InvalidValue;
17
18#[derive(thiserror::Error, displaydoc::Display, Debug, Clone, Serialize, Eq, PartialEq)]
20pub(crate) struct InvalidInputValue(pub(crate) String);
21
22fn describe_json_value(value: &Value) -> &'static str {
23 match value {
24 Value::Null => "null",
25 Value::Bool(_) => "boolean",
26 Value::Number(_) => "number",
27 Value::String(_) => "string",
28 Value::Array(_) => "array",
29 Value::Object(_) => "map",
30 }
31}
32
/// Newtype around an `apollo_compiler` schema type.
///
/// It has its own serde representation that records only names and list/non-null nesting (no
/// source locations), and it exposes input-value validation helpers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct FieldType(pub(crate) schema::Type);
35
/// Location of a value inside a request variable, stored as a chain of borrowed links that ends
/// at the top-level variable. Used to point validation errors at the offending value.
pub(crate) enum JsonValuePath<'a> {
    /// Root of the chain: the variable called `name` (the `$` sigil is added when displayed).
    Variable { name: &'a str },
    /// Field `key` of the object found at `parent`.
    ObjectKey { key: &'a str, parent: &'a JsonValuePath<'a> },
    /// Element number `index` of the array found at `parent`.
    ArrayItem { index: usize, parent: &'a JsonValuePath<'a> },
}
50
51impl Serialize for FieldType {
56 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
57 where
58 S: serde::Serializer,
59 {
60 struct BorrowedFieldType<'a>(&'a schema::Type);
61
62 impl Serialize for BorrowedFieldType<'_> {
63 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
64 where
65 S: serde::Serializer,
66 {
67 #[derive(Serialize)]
68 enum NestedBorrowed<'a> {
69 Named(&'a str),
70 NonNullNamed(&'a str),
71 List(BorrowedFieldType<'a>),
72 NonNullList(BorrowedFieldType<'a>),
73 }
74 match &self.0 {
75 schema::Type::Named(name) => NestedBorrowed::Named(name),
76 schema::Type::NonNullNamed(name) => NestedBorrowed::NonNullNamed(name),
77 schema::Type::List(ty) => NestedBorrowed::List(BorrowedFieldType(ty)),
78 schema::Type::NonNullList(ty) => {
79 NestedBorrowed::NonNullList(BorrowedFieldType(ty))
80 }
81 }
82 .serialize(serializer)
83 }
84 }
85
86 BorrowedFieldType(&self.0).serialize(serializer)
87 }
88}
89
90impl<'de> Deserialize<'de> for FieldType {
91 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
92 where
93 D: serde::Deserializer<'de>,
94 {
95 #[derive(Deserialize)]
96 enum WithoutLocation {
97 Named(String),
98 NonNullNamed(String),
99 List(FieldType),
100 NonNullList(FieldType),
101 }
102 Ok(match WithoutLocation::deserialize(deserializer)? {
103 WithoutLocation::Named(name) => FieldType(schema::Type::Named(
104 name.try_into().map_err(D::Error::custom)?,
105 )),
106 WithoutLocation::NonNullNamed(name) => FieldType(
107 schema::Type::Named(name.try_into().map_err(D::Error::custom)?).non_null(),
108 ),
109 WithoutLocation::List(ty) => FieldType(ty.0.list()),
110 WithoutLocation::NonNullList(ty) => FieldType(ty.0.list().non_null()),
111 })
112 }
113}
114
115impl std::fmt::Display for FieldType {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 self.0.fmt(f)
118 }
119}
120
121fn validate_input_value(
125 ty: &schema::Type,
126 value: Option<&Value>,
127 schema: &Schema,
128 path: &JsonValuePath<'_>,
129 strict_variable_validation: Mode,
130) -> Result<(), InvalidInputValue> {
131 let fmt_path = |var_path: &JsonValuePath<'_>| match var_path {
132 JsonValuePath::Variable { .. } => format!("variable `{var_path}`"),
133 _ => format!("input value at `{var_path}`"),
134 };
135 let Some(value) = value else {
136 if ty.is_non_null() {
137 return Err(InvalidInputValue(format!(
138 "missing {}: for required GraphQL type `{ty}`",
139 fmt_path(path),
140 )));
141 } else {
142 return Ok(());
143 }
144 };
145 let invalid = || {
146 InvalidInputValue(format!(
147 "invalid {}: found JSON {} for GraphQL type `{ty}`",
148 fmt_path(path),
149 describe_json_value(value)
150 ))
151 };
152 if value.is_null() {
153 if ty.is_non_null() {
154 return Err(invalid());
155 } else {
156 return Ok(());
157 }
158 }
159 let type_name = match ty {
160 schema::Type::Named(name) | schema::Type::NonNullNamed(name) => name,
161 schema::Type::List(inner_type) | schema::Type::NonNullList(inner_type) => {
162 if let Value::Array(vec) = value {
163 for (i, x) in vec.iter().enumerate() {
164 let path = JsonValuePath::ArrayItem {
165 index: i,
166 parent: path,
167 };
168 validate_input_value(
169 inner_type,
170 Some(x),
171 schema,
172 &path,
173 strict_variable_validation,
174 )?
175 }
176 return Ok(());
177 } else {
178 return validate_input_value(
180 inner_type,
181 Some(value),
182 schema,
183 path,
184 strict_variable_validation,
185 );
186 }
187 }
188 };
189 let from_bool = |condition| {
190 if condition { Ok(()) } else { Err(invalid()) }
191 };
192 match type_name.as_str() {
193 "String" => return from_bool(value.is_string()),
194 "Int" => return from_bool(value.is_valid_int_input()),
196 "Float" => return from_bool(value.is_valid_float_input()),
198 "ID" => return from_bool(value.is_valid_id_input()),
205 "Boolean" => return from_bool(value.is_boolean()),
206 _ => {}
207 }
208 let type_def = schema
209 .supergraph_schema()
210 .types
211 .get(type_name)
212 .ok_or_else(invalid)?;
214 match (type_def, value) {
215 (schema::ExtendedType::Scalar(_), _) => Ok(()),
217
218 (schema::ExtendedType::Enum(def), Value::String(s)) => {
219 from_bool(def.values.contains_key(s.as_str()))
220 }
221 (schema::ExtendedType::Enum(_), _) => Err(invalid()),
222
223 (schema::ExtendedType::InputObject(def), Value::Object(obj)) => {
224 let unknown_field = |field_name| {
226 let path_string = JsonValuePath::ObjectKey {
227 key: field_name,
228 parent: path,
229 };
230 InvalidInputValue(format!(
231 "unknown field {} found for GraphQL type `{def}`",
232 fmt_path(&path_string),
233 ))
234 };
235
236 let mut unknown_input_fields = obj
237 .keys()
238 .map(|k| k.as_str())
239 .filter(|&k| !def.fields.contains_key(k));
240 if let Some(unknown_input_field) = unknown_input_fields.next() {
241 match strict_variable_validation {
242 Mode::Enforce => {
243 return Err(unknown_field(unknown_input_field));
244 }
245 Mode::Measure => {
246 let unknown_fields: Vec<&str> = once(unknown_input_field)
247 .chain(unknown_input_fields)
248 .collect();
249 tracing::warn!(variables = ?unknown_fields, "encountered unexpected variable(s)");
252 }
253 }
254 }
255
256 def.fields.values().try_for_each(|field| {
258 let path = JsonValuePath::ObjectKey {
259 key: &field.name,
260 parent: path,
261 };
262 match obj.get(field.name.as_str()) {
263 Some(&Value::Null) | None => {
264 let default = field
265 .default_value
266 .as_ref()
267 .and_then(|v| parse_hir_value(v));
268 validate_input_value(
269 &field.ty,
270 default.as_ref(),
271 schema,
272 &path,
273 strict_variable_validation,
274 )
275 }
276 value => validate_input_value(
277 &field.ty,
278 value,
279 schema,
280 &path,
281 strict_variable_validation,
282 ),
283 }
284 })
285 }
286 _ => Err(invalid()),
287 }
288}
289
290impl FieldType {
291 pub(crate) fn new_named(name: Name) -> Self {
292 Self(schema::Type::Named(name))
293 }
294
295 pub(crate) fn validate_input_value(
298 &self,
299 value: Option<&Value>,
300 schema: &Schema,
301 path: &JsonValuePath<'_>,
302 strict_variable_validation: Mode,
303 ) -> Result<(), InvalidInputValue> {
304 validate_input_value(&self.0, value, schema, path, strict_variable_validation)
305 }
306
307 pub(crate) fn is_non_null(&self) -> bool {
308 self.0.is_non_null()
309 }
310}
311
312impl From<&'_ schema::Type> for FieldType {
313 fn from(ty: &'_ schema::Type) -> Self {
314 Self(ty.clone())
315 }
316}
317
318impl std::fmt::Display for JsonValuePath<'_> {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 match self {
321 Self::Variable { name } => {
322 f.write_str("$")?;
323 f.write_str(name)
324 }
325 Self::ObjectKey { key, parent } => {
326 parent.fmt(f)?;
327 f.write_str(".")?;
328 f.write_str(key)
329 }
330 Self::ArrayItem { index, parent } => {
331 parent.fmt(f)?;
332 write!(f, "[{index}]")
333 }
334 }
335 }
336}
337
#[test]
fn test_field_type_serialization() {
    // Every `schema::Type` shape (named, non-null named, list, non-null list, nested lists)
    // must survive a JSON round trip unchanged.
    for ty in [
        apollo_compiler::ty!(ID),
        apollo_compiler::ty!(String!),
        apollo_compiler::ty!([Int]),
        apollo_compiler::ty!([ID]!),
        apollo_compiler::ty!([[Boolean!]]!),
    ] {
        let ty = FieldType(ty);
        let json = serde_json::to_string(&ty).unwrap();
        assert_eq!(
            serde_json::from_str::<FieldType>(&json).unwrap(),
            ty,
            "round trip of `{ty}` via {json}"
        );
    }

    // Pin the wire format: externally tagged variants, nested for list types.
    assert_eq!(
        serde_json::to_value(FieldType(apollo_compiler::ty!([ID]!))).unwrap(),
        serde_json::json!({ "NonNullList": { "Named": "ID" } })
    );
}