cynic-codegen 3.13.2

Procedural macro code generation for cynic, a code-first GraphQL client for Rust.
Documentation
use std::collections::HashSet;

use proc_macro2::Span;

use crate::{
    RenameAll,
    error::Errors,
    idents::RenamableFieldIdent,
    input_object_derive::input::{InputObjectDeriveField, OneOfDeriveVariant},
    schema::types::{InputObjectType, InputValue},
    suggestions::{FieldSuggestionError, guess_field},
};

/// A Rust-side item (struct field or `oneOf` variant) that `pair_fields`
/// can match against a field of a GraphQL input object.
pub(super) trait PairingCandidate {
    /// The GraphQL field identifier for this item, computed under the given
    /// `RenameAll` rule.
    fn graphql_ident(&self, rename_rule: RenameAll) -> RenamableFieldIdent;
}

/// Struct fields pair by forwarding to their inherent `graphql_ident`.
impl PairingCandidate for InputObjectDeriveField {
    fn graphql_ident(&self, rename_rule: RenameAll) -> RenamableFieldIdent {
        // Path resolution prefers the inherent associated fn over this trait
        // method, so this forwards rather than recursing.
        Self::graphql_ident(self, rename_rule)
    }
}

/// `oneOf` enum variants pair by forwarding to their inherent `graphql_ident`.
impl PairingCandidate for OneOfDeriveVariant {
    fn graphql_ident(&self, rename_rule: RenameAll) -> RenamableFieldIdent {
        // Path resolution prefers the inherent associated fn over this trait
        // method, so this forwards rather than recursing.
        Self::graphql_ident(self, rename_rule)
    }
}

/// Pairs each Rust-side candidate (struct field or `oneOf` variant) with the
/// input value of the same GraphQL name on `input_object_def`.
///
/// Each candidate's GraphQL name is computed with `rename_all` and looked up
/// via `InputObjectType::field`. On success the returned vector holds one
/// `(candidate, input_value)` pair per entry of `fields`, in the order of
/// `fields`.
///
/// # Errors
///
/// All problems are collected into a single `Errors` value rather than
/// stopping at the first:
///
/// - one error per candidate whose GraphQL name is not a field of the input
///   object. It is spanned on that candidate's identifier and built as a
///   `FieldSuggestionError` with whatever suggestion `guess_field` produces.
/// - one error spanned on `struct_span` listing the schema fields that were
///   not provided. With `require_all_fields` every schema field counts;
///   otherwise only fields whose `is_required()` is true. Missing fields are
///   listed in schema declaration order, so the message is identical on every
///   compilation.
pub(super) fn pair_fields<'a, T>(
    fields: &'a [T],
    input_object_def: InputObjectType<'a>,
    rename_all: RenameAll,
    require_all_fields: bool,
    struct_span: &Span,
) -> Result<Vec<(&'a T, InputValue<'a>)>, Errors>
where
    T: PairingCandidate,
{
    let mut result = Vec::new();
    let mut unknown_fields = Vec::new();

    for field in fields {
        let ident = field.graphql_ident(rename_all);
        match input_object_def.field(&ident) {
            Some(schema_field) => result.push((field, schema_field)),
            None => unknown_fields.push(field),
        }
    }

    let provided_fields = result
        .iter()
        .map(|(_, field)| field)
        .cloned()
        .collect::<HashSet<_>>();

    // Walk the schema's own field list instead of diffing two HashSets.
    // HashSet iteration order is unspecified, which made the "missing fields"
    // message list names in a different order from one compilation to the
    // next. Iterating the schema keeps them in declaration order.
    let missing_fields = input_object_def
        .fields
        .iter()
        .filter(|field| require_all_fields || field.is_required())
        .filter(|field| !provided_fields.contains(field))
        .collect::<Vec<_>>();

    if missing_fields.is_empty() && unknown_fields.is_empty() {
        // `result` borrows from the local `input_object_def`, so the input
        // values are cloned to detach them before returning.
        return Ok(result.into_iter().map(|(l, r)| (l, r.clone())).collect());
    }

    let field_candidates = input_object_def
        .fields
        .iter()
        .map(|f| f.name.as_str())
        .collect::<Vec<_>>();

    let mut errors = unknown_fields
        .into_iter()
        .map(|field| {
            let field_name = &field.graphql_ident(rename_all);
            let graphql_name = field_name.graphql_name();
            let expected_field = graphql_name.as_str();
            let suggested_field = guess_field(field_candidates.iter().copied(), expected_field);
            syn::Error::new(
                field_name.span(),
                FieldSuggestionError {
                    expected_field,
                    graphql_type_name: input_object_def.name.as_ref(),
                    suggested_field,
                },
            )
        })
        .map(Errors::from)
        .collect::<Errors>();

    if !missing_fields.is_empty() {
        let missing_fields_string = missing_fields
            .into_iter()
            .map(|f| f.name.as_str().to_string())
            .collect::<Vec<_>>()
            .join(", ");

        errors.push(syn::Error::new(
            *struct_span,
            format!(
                "This InputObject is missing these fields: {}",
                missing_fields_string
            ),
        ));
    }

    Err(errors)
}

#[cfg(test)]
mod test {
    use assert_matches::assert_matches;

    use crate::{
        input_object_derive::input::InputObjectDeriveField,
        schema::{Schema, SchemaInput},
    };

    use super::*;

    static SCHEMA: &str = r#"
        input TestType {
            field_one: String!,
            field_two: String
        }
        "#;

    /// Builds a `String`-typed derive field named `name`, with no rename and
    /// no `skip_serializing_if`.
    fn string_field(name: &str) -> InputObjectDeriveField {
        InputObjectDeriveField {
            ident: Some(proc_macro2::Ident::new(name, Span::call_site())),
            ty: syn::parse_quote! { String },
            rename: None,
            skip_serializing_if: None,
        }
    }

    /// With `require_all_fields` set, leaving out the optional `field_two`
    /// must still be rejected.
    #[test]
    fn test_join_fields_when_all_required() {
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup("TestType").unwrap();
        let fields = [string_field("field_one")];

        let result = pair_fields(
            &fields,
            input_object,
            RenameAll::None,
            true,
            &Span::call_site(),
        );

        assert_matches!(result, Err(_))
    }

    /// Omitting the non-null `field_one` is rejected even when only required
    /// fields are enforced.
    #[test]
    fn test_join_fields_when_required_field_missing() {
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup("TestType").unwrap();
        let fields = [string_field("field_two")];

        let result = pair_fields(
            &fields,
            input_object,
            RenameAll::None,
            false,
            &Span::call_site(),
        );

        assert_matches!(result, Err(_))
    }

    /// Supplying only the required field succeeds and pairs it with the
    /// matching schema input value.
    #[test]
    fn test_join_fields_when_not_required() {
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup::<InputObjectType<'_>>("TestType").unwrap();
        let fields = [string_field("field_one")];

        let result = pair_fields(
            &fields,
            input_object.clone(),
            RenameAll::None,
            false,
            &Span::call_site(),
        );

        assert_matches!(result, Ok(_));

        let mut pairs = result.unwrap().into_iter();
        let (rust_field, input_field) = pairs.next().unwrap();
        assert!(std::ptr::eq(rust_field, &fields[0]));
        assert_eq!(&input_field, &input_object.fields[0]);
    }
}