//! protograph-core 0.1.0
//!
//! Core types and SDL parsing for protograph.
use crate::ast::*;
use crate::directives::*;
use async_graphql_parser::types::{
    BaseType, FieldDefinition, InputValueDefinition, ObjectType, Type, TypeDefinition,
    TypeKind, TypeSystemDefinition,
};
use async_graphql_parser::{parse_schema, Positioned};
use async_graphql_value::ConstValue;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum ParseError {
    #[error("Failed to parse GraphQL schema: {0}")]
    GraphQL(String),
    #[error("Invalid directive argument: {0}")]
    InvalidDirective(String),
}

pub fn parse_schema_file(content: &str) -> Result<ProtographSchema, ParseError> {
    let document =
        parse_schema(content).map_err(|e| ParseError::GraphQL(e.to_string()))?;

    let mut schema = ProtographSchema::default();

    for definition in document.definitions {
        match definition {
            TypeSystemDefinition::Type(type_def) => {
                let type_name = type_def.node.name.node.to_string();
                let directives = &type_def.node.directives;

                match &type_def.node.kind {
                    TypeKind::Object(obj) => {
                        if type_name == "Query" {
                            schema.query_fields = parse_query_fields(obj);
                        } else if type_name == "Mutation" {
                            schema.mutation_fields = parse_mutation_fields(obj);
                        } else {
                            let entity = parse_object_type(&type_name, directives, obj)?;
                            schema.types.insert(entity.name.clone(), entity);
                        }
                    }
                    TypeKind::InputObject(input) => {
                        let input_type = parse_input_type(&type_name, input);
                        schema.input_types.insert(input_type.name.clone(), input_type);
                    }
                    TypeKind::Enum(en) => {
                        let enum_type = EnumType {
                            name: type_name.clone(),
                            values: en
                                .values
                                .iter()
                                .map(|v| v.node.value.node.to_string())
                                .collect(),
                        };
                        schema.enums.insert(enum_type.name.clone(), enum_type);
                    }
                    _ => {}
                }
            }
            _ => {}
        }
    }

    Ok(schema)
}

/// Builds an [`EntityType`] named `name` from an object type definition.
///
/// The type-level `directives` decide the `is_entity` / `is_private` flags.
/// Each field is converted with [`parse_field`].
///
/// # Errors
///
/// Propagates the first [`ParseError`] returned by [`parse_field`].
fn parse_object_type(
    name: &str,
    directives: &[Positioned<async_graphql_parser::types::ConstDirective>],
    obj: &ObjectType,
) -> Result<EntityType, ParseError> {
    // Collecting into `Result` stops at the first failing field.
    let fields = obj
        .fields
        .iter()
        .map(|positioned| parse_field(&positioned.node))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(EntityType {
        name: name.to_string(),
        is_entity: has_directive(directives, DIRECTIVE_ENTITY),
        is_private: has_directive(directives, DIRECTIVE_PRIVATE),
        fields,
    })
}

/// Converts one object-type field definition into a [`Field`].
///
/// The field is marked private when it carries the private directive. Any
/// relationship directive on it is parsed into `relationship`.
///
/// # Errors
///
/// Propagates errors from [`parse_relationship_directive`].
fn parse_field(field: &FieldDefinition) -> Result<Field, ParseError> {
    let field_directives = &field.directives;

    Ok(Field {
        name: field.name.node.to_string(),
        field_type: convert_type(&field.ty.node),
        is_private: has_directive(field_directives, DIRECTIVE_PRIVATE),
        relationship: parse_relationship_directive(field_directives)?,
    })
}

/// Converts an input object definition named `name` into an [`InputType`].
///
/// Each field's name and converted type are kept, in declaration order.
fn parse_input_type(
    name: &str,
    input: &async_graphql_parser::types::InputObjectType,
) -> InputType {
    let fields = input
        .fields
        .iter()
        .map(|positioned| {
            let definition = &positioned.node;
            InputField {
                name: definition.name.node.to_string(),
                field_type: convert_type(&definition.ty.node),
            }
        })
        .collect();

    InputType {
        name: name.to_owned(),
        fields,
    }
}

/// Converts the fields of the root `Query` object type into [`QueryField`]s.
///
/// Each result records the field's name, its arguments and its return type,
/// in declaration order.
fn parse_query_fields(obj: &ObjectType) -> Vec<QueryField> {
    let mut query_fields = Vec::with_capacity(obj.fields.len());
    for positioned in &obj.fields {
        let definition = &positioned.node;
        query_fields.push(QueryField {
            name: definition.name.node.to_string(),
            arguments: parse_arguments(&definition.arguments),
            return_type: convert_type(&definition.ty.node),
        });
    }
    query_fields
}

/// Converts the fields of the root `Mutation` object type into
/// [`MutationField`]s.
///
/// Each result records the field's name, its arguments and its return type,
/// in declaration order.
fn parse_mutation_fields(obj: &ObjectType) -> Vec<MutationField> {
    let mut mutation_fields = Vec::with_capacity(obj.fields.len());
    for positioned in &obj.fields {
        let definition = &positioned.node;
        mutation_fields.push(MutationField {
            name: definition.name.node.to_string(),
            arguments: parse_arguments(&definition.arguments),
            return_type: convert_type(&definition.ty.node),
        });
    }
    mutation_fields
}

/// Converts field argument definitions into [`InputField`]s (name plus
/// converted type), preserving their order.
fn parse_arguments(args: &[Positioned<InputValueDefinition>]) -> Vec<InputField> {
    let mut converted = Vec::with_capacity(args.len());
    for arg in args {
        let definition = &arg.node;
        converted.push(InputField {
            name: definition.name.node.to_string(),
            field_type: convert_type(&definition.ty.node),
        });
    }
    converted
}

/// Converts a parsed GraphQL type reference into a [`FieldType`].
///
/// Named types become `FieldType::Named` and list types become
/// `FieldType::List`, converting their element type recursively. Any level
/// that is not nullable (`!` in SDL) is wrapped in `FieldType::NonNull`.
fn convert_type(ty: &Type) -> FieldType {
    let unwrapped = match &ty.base {
        BaseType::Named(name) => FieldType::Named(name.to_string()),
        BaseType::List(element) => FieldType::List(Box::new(convert_type(element))),
    };

    if ty.nullable {
        unwrapped
    } else {
        FieldType::NonNull(Box::new(unwrapped))
    }
}

/// Returns `true` if any directive in `directives` has exactly the given
/// `name`.
fn has_directive(
    directives: &[Positioned<async_graphql_parser::types::ConstDirective>],
    name: &str,
) -> bool {
    directives
        .iter()
        .map(|directive| directive.node.name.node.as_str())
        .any(|directive_name| directive_name == name)
}

/// Extracts the relationship declared on a field, if any.
///
/// Only the first belongs-to / has-many / many-to-many directive in
/// `directives` is considered; later relationship directives are ignored.
/// Returns `Ok(None)` when none is present.
///
/// # Errors
///
/// Returns [`ParseError::InvalidDirective`] when the first relationship
/// directive lacks a required argument. For many-to-many, the `through`
/// argument is checked before `field`.
fn parse_relationship_directive(
    directives: &[Positioned<async_graphql_parser::types::ConstDirective>],
) -> Result<Option<Relationship>, ParseError> {
    directives
        .iter()
        .map(|positioned| &positioned.node)
        .find_map(|directive| match directive.name.node.as_str() {
            DIRECTIVE_BELONGS_TO => Some(
                get_directive_arg(directive, ARG_FIELD)
                    .map(|foreign_key| Relationship::BelongsTo { foreign_key }),
            ),
            DIRECTIVE_HAS_MANY => Some(
                get_directive_arg(directive, ARG_FIELD)
                    .map(|foreign_key| Relationship::HasMany { foreign_key }),
            ),
            DIRECTIVE_MANY_TO_MANY => Some(
                get_directive_arg(directive, ARG_THROUGH).and_then(|junction_table| {
                    get_directive_arg(directive, ARG_FIELD).map(|foreign_key| {
                        Relationship::ManyToMany {
                            junction_table,
                            foreign_key,
                        }
                    })
                }),
            ),
            _ => None,
        })
        // Option<Result<_>> -> Result<Option<_>>
        .transpose()
}

/// Reads the string-valued argument `arg_name` from `directive`.
///
/// The first argument with a matching name is used.
///
/// # Errors
///
/// Returns [`ParseError::InvalidDirective`] in two cases:
/// - the argument is absent;
/// - the argument is present but its value is not a string (e.g.
///   `@belongsTo(field: 42)`). This case gets its own message instead of
///   being misreported as a missing argument.
fn get_directive_arg(
    directive: &async_graphql_parser::types::ConstDirective,
    arg_name: &str,
) -> Result<String, ParseError> {
    let value = directive
        .arguments
        .iter()
        .find(|(name, _)| name.node == arg_name)
        .map(|(_, value)| &value.node);

    match value {
        Some(ConstValue::String(s)) => Ok(s.clone()),
        Some(_) => Err(ParseError::InvalidDirective(format!(
            "Argument '{}' on @{} must be a string",
            arg_name, directive.name.node
        ))),
        None => Err(ParseError::InvalidDirective(format!(
            "Missing required argument '{}' on @{}",
            arg_name, directive.name.node
        ))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a two-entity schema with a root `Query` and checks the entity
    /// types, the number of query fields, and the `@belongsTo` relationship
    /// on `Post.author`.
    #[test]
    fn test_parse_simple_schema() {
        let sdl = r#"
            type User @entity {
                id: ID!
                name: String!
                email: String! @private
            }

            type Post @entity {
                id: ID!
                title: String!
                author: User! @belongsTo(field: "authorId")
                authorId: ID! @private
            }

            type Query {
                user(id: ID!): User
                users: [User!]!
            }
        "#;

        let parsed = parse_schema_file(sdl).unwrap();

        assert!(parsed.types.contains_key("User"));
        assert!(parsed.types.contains_key("Post"));

        let user = parsed.types.get("User").unwrap();
        assert!(user.is_entity);
        assert_eq!(parsed.query_fields.len(), 2);

        let author = parsed
            .types
            .get("Post")
            .unwrap()
            .fields
            .iter()
            .find(|field| field.name == "author")
            .unwrap();
        let belongs_to_author_id = matches!(
            &author.relationship,
            Some(Relationship::BelongsTo { foreign_key }) if foreign_key == "authorId"
        );
        assert!(belongs_to_author_id);
    }
}