Skip to main content

apollo_router/spec/
field_type.rs

1use std::iter::once;
2
3use apollo_compiler::Name;
4use apollo_compiler::schema;
5use serde::Deserialize;
6use serde::Serialize;
7use serde::de::Error as _;
8
9use super::query::parse_hir_value;
10use crate::configuration::mode::Mode;
11use crate::json_ext::Value;
12use crate::json_ext::ValueExt;
13use crate::spec::Schema;
14
/// Payload-free marker type signalling that a value was invalid.
// NOTE(review): not referenced in the visible part of this file — confirm it is still used.
#[derive(Debug)]
pub(crate) struct InvalidValue;
17
// Error produced by `validate_input_value` when a variable or nested input value does not
// match its GraphQL input type; the wrapped `String` is the complete, human-readable message.
// The `Display` impl is generated by `displaydoc` from the `///` doc comment right below, so
// that doc comment must stay exactly `{0}` (print the wrapped message verbatim).
/// {0}
#[derive(thiserror::Error, displaydoc::Display, Debug, Clone, Serialize, Eq, PartialEq)]
pub(crate) struct InvalidInputValue(pub(crate) String);
21
22fn describe_json_value(value: &Value) -> &'static str {
23    match value {
24        Value::Null => "null",
25        Value::Bool(_) => "boolean",
26        Value::Number(_) => "number",
27        Value::String(_) => "string",
28        Value::Array(_) => "array",
29        Value::Object(_) => "map",
30    }
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub(crate) struct FieldType(pub(crate) schema::Type);
35
/// A path to a value nested inside a variable, e.g. `$input.items[2]`.
///
/// Every segment borrows its parent instead of owning it. The path is therefore built as a
/// chain of references to locals while recursing, so no heap allocation is needed in the
/// happy path.
pub(crate) enum JsonValuePath<'a> {
    /// Root segment: the variable called `name` (rendered with a `$` prefix).
    Variable { name: &'a str },
    /// The entry stored under `key` in the object located at `parent`.
    ObjectKey {
        key: &'a str,
        parent: &'a JsonValuePath<'a>,
    },
    /// The element at position `index` in the array located at `parent`.
    ArrayItem {
        index: usize,
        parent: &'a JsonValuePath<'a>,
    },
}
50
// schema::Type does not implement Serialize or Deserialize,
// and <https://serde.rs/remote-derive.html> does not seem to work for recursive types.
// Instead, the explicit `impl`s below delegate to the derived impls of purpose-built mirror types.
54
55impl Serialize for FieldType {
56    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
57    where
58        S: serde::Serializer,
59    {
60        struct BorrowedFieldType<'a>(&'a schema::Type);
61
62        impl Serialize for BorrowedFieldType<'_> {
63            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
64            where
65                S: serde::Serializer,
66            {
67                #[derive(Serialize)]
68                enum NestedBorrowed<'a> {
69                    Named(&'a str),
70                    NonNullNamed(&'a str),
71                    List(BorrowedFieldType<'a>),
72                    NonNullList(BorrowedFieldType<'a>),
73                }
74                match &self.0 {
75                    schema::Type::Named(name) => NestedBorrowed::Named(name),
76                    schema::Type::NonNullNamed(name) => NestedBorrowed::NonNullNamed(name),
77                    schema::Type::List(ty) => NestedBorrowed::List(BorrowedFieldType(ty)),
78                    schema::Type::NonNullList(ty) => {
79                        NestedBorrowed::NonNullList(BorrowedFieldType(ty))
80                    }
81                }
82                .serialize(serializer)
83            }
84        }
85
86        BorrowedFieldType(&self.0).serialize(serializer)
87    }
88}
89
90impl<'de> Deserialize<'de> for FieldType {
91    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
92    where
93        D: serde::Deserializer<'de>,
94    {
95        #[derive(Deserialize)]
96        enum WithoutLocation {
97            Named(String),
98            NonNullNamed(String),
99            List(FieldType),
100            NonNullList(FieldType),
101        }
102        Ok(match WithoutLocation::deserialize(deserializer)? {
103            WithoutLocation::Named(name) => FieldType(schema::Type::Named(
104                name.try_into().map_err(D::Error::custom)?,
105            )),
106            WithoutLocation::NonNullNamed(name) => FieldType(
107                schema::Type::Named(name.try_into().map_err(D::Error::custom)?).non_null(),
108            ),
109            WithoutLocation::List(ty) => FieldType(ty.0.list()),
110            WithoutLocation::NonNullList(ty) => FieldType(ty.0.list().non_null()),
111        })
112    }
113}
114
115impl std::fmt::Display for FieldType {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        self.0.fmt(f)
118    }
119}
120
/// Checks that `value` is acceptable for the GraphQL input type `ty`, following the
/// spec's input coercion rules.
///
/// * `value` is `None` when nothing was provided, and `Some(null)` for an explicit JSON null.
/// * `path` locates the value inside the variables; it is only used in error messages.
/// * `strict_variable_validation` decides what happens with unknown input-object fields:
///   `Mode::Enforce` rejects them, while `Mode::Measure` logs a warning and keeps validating.
///
/// Returns an [`InvalidInputValue`] describing the problem when validation fails.
///
/// This function currently stops at the first error it finds.
/// It may be nicer to return a `Vec` of errors, but its size should be limited
/// in case e.g. every item of a large array is invalid.
fn validate_input_value(
    ty: &schema::Type,
    value: Option<&Value>,
    schema: &Schema,
    path: &JsonValuePath<'_>,
    strict_variable_validation: Mode,
) -> Result<(), InvalidInputValue> {
    // Error messages describe top-level variables and nested values differently.
    let fmt_path = |var_path: &JsonValuePath<'_>| match var_path {
        JsonValuePath::Variable { .. } => format!("variable `{var_path}`"),
        _ => format!("input value at `{var_path}`"),
    };
    // An omitted value is only an error for non-null types.
    let Some(value) = value else {
        if ty.is_non_null() {
            return Err(InvalidInputValue(format!(
                "missing {}: for required GraphQL type `{ty}`",
                fmt_path(path),
            )));
        } else {
            return Ok(());
        }
    };
    // Builds the generic "wrong JSON kind for this GraphQL type" error for the current value.
    let invalid = || {
        InvalidInputValue(format!(
            "invalid {}: found JSON {} for GraphQL type `{ty}`",
            fmt_path(path),
            describe_json_value(value)
        ))
    };
    // An explicit JSON null is only an error for non-null types.
    if value.is_null() {
        if ty.is_non_null() {
            return Err(invalid());
        } else {
            return Ok(());
        }
    }
    // List types are fully handled (and returned from) inside this match. Only named
    // types fall through to the checks below.
    let type_name = match ty {
        schema::Type::Named(name) | schema::Type::NonNullNamed(name) => name,
        schema::Type::List(inner_type) | schema::Type::NonNullList(inner_type) => {
            if let Value::Array(vec) = value {
                // Validate each item against the item type, extending the path with its index.
                for (i, x) in vec.iter().enumerate() {
                    let path = JsonValuePath::ArrayItem {
                        index: i,
                        parent: path,
                    };
                    validate_input_value(
                        inner_type,
                        Some(x),
                        schema,
                        &path,
                        strict_variable_validation,
                    )?
                }
                return Ok(());
            } else {
                // For coercion from single value to list
                return validate_input_value(
                    inner_type,
                    Some(value),
                    schema,
                    path,
                    strict_variable_validation,
                );
            }
        }
    };
    let from_bool = |condition| {
        if condition { Ok(()) } else { Err(invalid()) }
    };
    // Built-in scalars are checked directly against the JSON value.
    match type_name.as_str() {
        "String" => return from_bool(value.is_string()),
        // Spec: https://spec.graphql.org/June2018/#sec-Int
        "Int" => return from_bool(value.is_valid_int_input()),
        // Spec: https://spec.graphql.org/draft/#sec-Float.Input-Coercion
        "Float" => return from_bool(value.is_valid_float_input()),
        // "The ID scalar type represents a unique identifier, often used to refetch an object
        // or as the key for a cache. The ID type is serialized in the same way as a String;
        // however, it is not intended to be human-readable. While it is often numeric, it
        // should always serialize as a String."
        //
        // In practice it seems Int works too
        "ID" => return from_bool(value.is_valid_id_input()),
        "Boolean" => return from_bool(value.is_boolean()),
        _ => {}
    }
    // Any other named type is looked up in the supergraph schema.
    let type_def = schema
        .supergraph_schema()
        .types
        .get(type_name)
        // Should never happen in a valid schema
        .ok_or_else(invalid)?;
    match (type_def, value) {
        // Custom scalar: accept any JSON value
        (schema::ExtendedType::Scalar(_), _) => Ok(()),

        // Enum values must be JSON strings naming one of the enum's declared values.
        (schema::ExtendedType::Enum(def), Value::String(s)) => {
            from_bool(def.values.contains_key(s.as_str()))
        }
        (schema::ExtendedType::Enum(_), _) => Err(invalid()),

        (schema::ExtendedType::InputObject(def), Value::Object(obj)) => {
            // Check for extra/unknown fields in obj vs def
            let unknown_field = |field_name| {
                let path_string = JsonValuePath::ObjectKey {
                    key: field_name,
                    parent: path,
                };
                InvalidInputValue(format!(
                    "unknown field {} found for GraphQL type `{def}`",
                    fmt_path(&path_string),
                ))
            };

            // Lazy iterator: Enforce mode only needs the first unknown key, while Measure
            // mode drains the rest so that all unknown keys are logged together.
            let mut unknown_input_fields = obj
                .keys()
                .map(|k| k.as_str())
                .filter(|&k| !def.fields.contains_key(k));
            if let Some(unknown_input_field) = unknown_input_fields.next() {
                match strict_variable_validation {
                    Mode::Enforce => {
                        return Err(unknown_field(unknown_input_field));
                    }
                    Mode::Measure => {
                        let unknown_fields: Vec<&str> = once(unknown_input_field)
                            .chain(unknown_input_fields)
                            .collect();
                        // NB: warning will be attached to the span via trace id, so you can figure out
                        //  operation name from parent span
                        tracing::warn!(variables = ?unknown_fields, "encountered unexpected variable(s)");
                    }
                }
            }

            // Validate all fields present on def
            def.fields.values().try_for_each(|field| {
                let path = JsonValuePath::ObjectKey {
                    key: &field.name,
                    parent: path,
                };
                match obj.get(field.name.as_str()) {
                    // An explicit null is treated like an omitted field: the schema default
                    // (if any) is validated in its place.
                    // NOTE(review): per the GraphQL spec, an explicit null for a non-null field
                    // is an error even when the field has a default value. Confirm this
                    // leniency is intended.
                    Some(&Value::Null) | None => {
                        let default = field
                            .default_value
                            .as_ref()
                            .and_then(|v| parse_hir_value(v));
                        validate_input_value(
                            &field.ty,
                            default.as_ref(),
                            schema,
                            &path,
                            strict_variable_validation,
                        )
                    }
                    value => validate_input_value(
                        &field.ty,
                        value,
                        schema,
                        &path,
                        strict_variable_validation,
                    ),
                }
            })
        }
        // Input objects given a non-object value, and output-only types (objects,
        // interfaces, unions), are rejected.
        _ => Err(invalid()),
    }
}
289
290impl FieldType {
291    pub(crate) fn new_named(name: Name) -> Self {
292        Self(schema::Type::Named(name))
293    }
294
295    // This function validates input values according to the graphql specification.
296    // Each of the values are validated against the "input coercion" rules.
297    pub(crate) fn validate_input_value(
298        &self,
299        value: Option<&Value>,
300        schema: &Schema,
301        path: &JsonValuePath<'_>,
302        strict_variable_validation: Mode,
303    ) -> Result<(), InvalidInputValue> {
304        validate_input_value(&self.0, value, schema, path, strict_variable_validation)
305    }
306
307    pub(crate) fn is_non_null(&self) -> bool {
308        self.0.is_non_null()
309    }
310}
311
312impl From<&'_ schema::Type> for FieldType {
313    fn from(ty: &'_ schema::Type) -> Self {
314        Self(ty.clone())
315    }
316}
317
318impl std::fmt::Display for JsonValuePath<'_> {
319    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        match self {
321            Self::Variable { name } => {
322                f.write_str("$")?;
323                f.write_str(name)
324            }
325            Self::ObjectKey { key, parent } => {
326                parent.fmt(f)?;
327                f.write_str(".")?;
328                f.write_str(key)
329            }
330            Self::ArrayItem { index, parent } => {
331                parent.fmt(f)?;
332                write!(f, "[{index}]")
333            }
334        }
335    }
336}
337
/// Make sure custom Serialize and Deserialize impls are compatible with each other
/// for every `schema::Type` shape: named, non-null named, list, non-null list, and nested lists.
#[test]
fn test_field_type_serialization() {
    let types = [
        apollo_compiler::ty!(ID),
        apollo_compiler::ty!(ID!),
        apollo_compiler::ty!([ID]),
        apollo_compiler::ty!([ID]!),
        apollo_compiler::ty!([[Int!]!]),
    ];
    for ty in types {
        let ty = FieldType(ty);
        let json = serde_json::to_string(&ty).unwrap();
        let round_tripped: FieldType = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, ty, "round trip through {json}");
    }
}