use crate::utils::{pluralize, to_pascal_case, to_snake_case};
use proc_macro2::TokenStream;
use protograph_core::{EntityType, FieldType, ProtographSchema, Relationship};
use quote::{format_ident, quote};
pub fn generate_rust(schema: &ProtographSchema) -> String {
let traits = generate_traits(schema);
let dataloaders = generate_dataloaders(schema);
let graphql_types = generate_graphql_types(schema);
let input_types = generate_input_types(schema);
let query_type = generate_query_type(schema);
let mutation_type = generate_mutation_type(schema);
let schema_builder = generate_schema_builder(schema);
let output = quote! {
use async_graphql::*;
use async_graphql::dataloader::{DataLoader, Loader};
use std::collections::HashMap;
use std::sync::Arc;
#traits
#dataloaders
#graphql_types
#input_types
#query_type
#mutation_type
#schema_builder
};
output.to_string()
}
fn generate_input_types(schema: &ProtographSchema) -> TokenStream {
let input_types: Vec<TokenStream> = schema
.input_types
.iter()
.map(|(_, input)| {
let name = format_ident!("{}", &input.name);
let fields: Vec<TokenStream> = input
.fields
.iter()
.map(|f| {
let field_name = format_ident!("{}", to_snake_case(&f.name));
let field_type = graphql_type_to_rust(&f.field_type);
quote! { pub #field_name: #field_type }
})
.collect();
quote! {
#[derive(Clone, Debug, InputObject)]
pub struct #name {
#(#fields),*
}
}
})
.collect();
quote! { #(#input_types)* }
}
/// Emits a `<Entity>Service` trait for every entity type that is not marked private.
fn generate_traits(schema: &ProtographSchema) -> TokenStream {
    let public_entities = schema
        .types
        .iter()
        .map(|(_, ty)| ty)
        .filter(|ty| ty.is_entity && !ty.is_private);
    let service_traits = public_entities.map(|entity| generate_service_trait(entity, schema));
    quote! { #(#service_traits)* }
}
/// Emits the `<Entity>Service` trait that the application implements to back `entity`.
///
/// Every trait declares `get` (single id) and `batch_get` (many ids). It also gets one
/// extra batch method per field with a `HasMany` or `ManyToMany` relationship (see
/// `generate_relationship_method`); private fields are included. All methods are async
/// via `#[async_trait]` and report failures as `Box<dyn Error + Send + Sync>`.
fn generate_service_trait(entity: &EntityType, schema: &ProtographSchema) -> TokenStream {
    let name = &entity.name;
    let trait_name = format_ident!("{}Service", name);
    let entity_type = format_ident!("{}", name);
    // Fields without a supported relationship contribute no method (filter_map drops them).
    let relationship_methods: Vec<TokenStream> = entity
        .fields
        .iter()
        .filter_map(|f| generate_relationship_method(f, entity, schema))
        .collect();
    quote! {
        #[async_trait::async_trait]
        pub trait #trait_name: Send + Sync {
            async fn get(&self, id: String) -> Result<Option<#entity_type>, Box<dyn std::error::Error + Send + Sync>>;
            async fn batch_get(&self, ids: Vec<String>) -> Result<Vec<#entity_type>, Box<dyn std::error::Error + Send + Sync>>;
            #(#relationship_methods)*
        }
    }
}
fn generate_relationship_method(
field: &protograph_core::Field,
parent: &EntityType,
schema: &ProtographSchema,
) -> Option<TokenStream> {
match &field.relationship {
Some(Relationship::HasMany { foreign_key }) => {
let related_type_name = field.field_type.base_type();
let method_name = format_ident!("batch_get_by_{}", to_snake_case(foreign_key));
let entity_type = format_ident!("{}", related_type_name);
let fk_name = format_ident!("{}s", to_snake_case(foreign_key));
Some(quote! {
async fn #method_name(
&self,
#fk_name: Vec<String>
) -> Result<HashMap<String, Vec<#entity_type>>, Box<dyn std::error::Error + Send + Sync>>;
})
}
Some(Relationship::ManyToMany { junction_table, .. }) => {
let related_type_name = field.field_type.base_type();
let method_name = format_ident!("batch_get_via_{}", to_snake_case(junction_table));
let entity_type = format_ident!("{}", related_type_name);
Some(quote! {
async fn #method_name(
&self,
parent_ids: Vec<String>
) -> Result<HashMap<String, Vec<#entity_type>>, Box<dyn std::error::Error + Send + Sync>>;
})
}
_ => None,
}
}
fn generate_dataloaders(schema: &ProtographSchema) -> TokenStream {
let entity_loaders: Vec<TokenStream> = schema
.types
.iter()
.filter(|(_, t)| t.is_entity && !t.is_private)
.map(|(_, entity)| generate_entity_loader(entity))
.collect();
let relationship_loaders: Vec<TokenStream> = schema
.types
.iter()
.filter(|(_, t)| t.is_entity)
.flat_map(|(_, entity)| {
entity
.fields
.iter()
.filter_map(|f| generate_relationship_loader(f, entity, schema))
})
.collect();
quote! {
#(#entity_loaders)*
#(#relationship_loaders)*
}
}
/// Emits `<Entity>Loader`, a DataLoader that resolves entities by id through
/// `<Entity>Service::batch_get`.
///
/// The generated `load` keys results by each returned entity's `id`, so ids the service
/// does not return are absent from the map. Service errors are converted into the
/// loader's `Arc<dyn Error + Send + Sync>` error type.
fn generate_entity_loader(entity: &EntityType) -> TokenStream {
    let entity_name = &entity.name;
    let loader_ident = format_ident!("{}Loader", entity_name);
    let service_ident = format_ident!("{}Service", entity_name);
    let value_ident = format_ident!("{}", entity_name);

    let constructor = quote! {
        impl #loader_ident {
            pub fn new(service: Arc<dyn #service_ident>) -> Self {
                Self { service }
            }
        }
    };
    let loader_impl = quote! {
        impl Loader<String> for #loader_ident {
            type Value = #value_ident;
            type Error = Arc<dyn std::error::Error + Send + Sync>;
            fn load(
                &self,
                keys: &[String]
            ) -> impl std::future::Future<Output = Result<HashMap<String, Self::Value>, Self::Error>> + Send {
                let service = self.service.clone();
                let keys = keys.to_vec();
                async move {
                    let entities = service.batch_get(keys).await
                        .map_err(|e| Arc::from(e) as Arc<dyn std::error::Error + Send + Sync>)?;
                    Ok(entities.into_iter()
                        .map(|e| (e.id.clone(), e))
                        .collect())
                }
            }
        }
    };
    quote! {
        pub struct #loader_ident {
            service: Arc<dyn #service_ident>,
        }
        #constructor
        #loader_impl
    }
}
/// Emits the DataLoader that backs a to-many relationship field of `parent`, if any.
///
/// * `HasMany { foreign_key }` → `<PluralRelated>By<ForeignKey>Loader`, which calls
///   `<Parent>Service::batch_get_by_<foreign_key>`.
/// * `ManyToMany { junction_table, .. }` → `<PluralRelated>Via<junction_table>Loader`,
///   which calls `<Parent>Service::batch_get_via_<junction_table>`.
///
/// Both loaders map a parent id to `Vec<Related>` and pass through the service's map
/// unchanged. Every other relationship returns `None`.
fn generate_relationship_loader(
    field: &protograph_core::Field,
    parent: &EntityType,
    _schema: &ProtographSchema,
) -> Option<TokenStream> {
    let related_type_name = field.field_type.base_type();
    let (loader_name, method_name) = match &field.relationship {
        Some(Relationship::HasMany { foreign_key }) => (
            format_ident!(
                "{}By{}Loader",
                pluralize(related_type_name),
                to_pascal_case(foreign_key)
            ),
            format_ident!("batch_get_by_{}", to_snake_case(foreign_key)),
        ),
        // The ManyToMany field resolver fetches `DataLoader<<Plural>Via<junction>Loader>`
        // from the context, so the loader must exist under exactly this name (junction
        // table name used verbatim, as the resolver does).
        Some(Relationship::ManyToMany { junction_table, .. }) => (
            format_ident!(
                "{}Via{}Loader",
                pluralize(related_type_name),
                junction_table
            ),
            format_ident!("batch_get_via_{}", to_snake_case(junction_table)),
        ),
        _ => return None,
    };
    let service_trait = format_ident!("{}Service", parent.name);
    let entity_type = format_ident!("{}", related_type_name);
    Some(quote! {
        pub struct #loader_name {
            service: Arc<dyn #service_trait>,
        }
        impl #loader_name {
            pub fn new(service: Arc<dyn #service_trait>) -> Self {
                Self { service }
            }
        }
        impl Loader<String> for #loader_name {
            type Value = Vec<#entity_type>;
            type Error = Arc<dyn std::error::Error + Send + Sync>;
            fn load(
                &self,
                keys: &[String]
            ) -> impl std::future::Future<Output = Result<HashMap<String, Self::Value>, Self::Error>> + Send {
                let service = self.service.clone();
                let keys = keys.to_vec();
                async move {
                    service.#method_name(keys).await
                        .map_err(|e| Arc::from(e) as Arc<dyn std::error::Error + Send + Sync>)
                }
            }
        }
    })
}
fn generate_graphql_types(schema: &ProtographSchema) -> TokenStream {
let types: Vec<TokenStream> = schema
.types
.iter()
.filter(|(_, t)| !t.is_private)
.map(|(_, entity)| generate_graphql_type(entity, schema))
.collect();
quote! { #(#types)* }
}
fn generate_graphql_type(entity: &EntityType, schema: &ProtographSchema) -> TokenStream {
let name = format_ident!("{}", &entity.name);
let scalar_fields: Vec<TokenStream> = entity
.fields
.iter()
.filter(|f| !f.is_private && f.relationship.is_none())
.map(|f| generate_scalar_field(f))
.collect();
let relationship_fields: Vec<TokenStream> = entity
.fields
.iter()
.filter(|f| !f.is_private && f.relationship.is_some())
.map(|f| generate_relationship_field(f, entity, schema))
.collect();
quote! {
#[derive(Clone, Debug)]
pub struct #name {
pub id: String,
inner: HashMap<String, String>,
}
impl #name {
pub fn new(id: String) -> Self {
Self { id, inner: HashMap::new() }
}
pub fn with_field(mut self, key: &str, value: String) -> Self {
self.inner.insert(key.to_string(), value);
self
}
}
#[Object]
impl #name {
async fn id(&self) -> &str {
&self.id
}
#(#scalar_fields)*
#(#relationship_fields)*
}
}
}
/// Emits the `#[Object]` resolver for a plain (non-relationship) field.
///
/// Field values live in the object's `inner: HashMap<String, String>`, keyed by the
/// GraphQL field name. Because `inner` only holds strings, the stored value is adapted
/// to the resolver's Rust return type:
/// * `i32` / `f64` / `bool` (GraphQL `Int` / `Float` / `Boolean`) are parsed. A missing
///   or unparsable value falls back to the type's default.
/// * `ID` is converted from the stored string (default `ID` when missing).
/// * Everything else returns the stored `String` clone (empty string when missing).
///
/// The `id` field yields no tokens because every generated object already has a
/// dedicated `id` resolver.
fn generate_scalar_field(field: &protograph_core::Field) -> TokenStream {
    if field.name == "id" {
        return quote! {};
    }
    let field_name = format_ident!("{}", to_snake_case(&field.name));
    let graphql_name = &field.name;
    let return_type = graphql_type_to_rust(&field.field_type);
    // Returning the raw `String` for a non-`String` return type would not type-check in
    // the generated code, so convert according to the Rust type we just mapped to.
    let value = match return_type.to_string().as_str() {
        "i32" | "f64" | "bool" => quote! {
            self.inner
                .get(#graphql_name)
                .and_then(|v| v.parse().ok())
                .unwrap_or_default()
        },
        "ID" => quote! {
            self.inner
                .get(#graphql_name)
                .cloned()
                .map(ID::from)
                .unwrap_or_default()
        },
        _ => quote! {
            self.inner.get(#graphql_name).cloned().unwrap_or_default()
        },
    };
    quote! {
        #[graphql(name = #graphql_name)]
        async fn #field_name(&self) -> #return_type {
            #value
        }
    }
}
fn generate_relationship_field(
field: &protograph_core::Field,
parent: &EntityType,
schema: &ProtographSchema,
) -> TokenStream {
let field_name = format_ident!("{}", to_snake_case(&field.name));
let graphql_name = &field.name;
match &field.relationship {
Some(Relationship::BelongsTo { foreign_key }) => {
let related_type = format_ident!("{}", field.field_type.base_type());
let loader_name = format_ident!("{}Loader", field.field_type.base_type());
let fk_field = to_snake_case(foreign_key);
quote! {
#[graphql(name = #graphql_name)]
async fn #field_name(&self, ctx: &Context<'_>) -> Result<Option<#related_type>> {
let loader = ctx.data::<DataLoader<#loader_name>>()?;
let fk = self.inner.get(#fk_field).cloned().unwrap_or_default();
loader.load_one(fk).await.map_err(|e| Error::new(e.to_string()))
}
}
}
Some(Relationship::HasMany { foreign_key }) => {
let related_type = format_ident!("{}", field.field_type.base_type());
let loader_name = format_ident!(
"{}By{}Loader",
pluralize(field.field_type.base_type()),
to_pascal_case(foreign_key)
);
quote! {
#[graphql(name = #graphql_name)]
async fn #field_name(&self, ctx: &Context<'_>) -> Result<Vec<#related_type>> {
let loader = ctx.data::<DataLoader<#loader_name>>()?;
loader.load_one(self.id.clone()).await
.map_err(|e| Error::new(e.to_string()))?
.ok_or_else(|| Error::new("Not found"))
}
}
}
Some(Relationship::ManyToMany { junction_table, .. }) => {
let related_type = format_ident!("{}", field.field_type.base_type());
let loader_name = format_ident!(
"{}Via{}Loader",
pluralize(field.field_type.base_type()),
junction_table
);
quote! {
#[graphql(name = #graphql_name)]
async fn #field_name(&self, ctx: &Context<'_>) -> Result<Vec<#related_type>> {
let loader = ctx.data::<DataLoader<#loader_name>>()?;
loader.load_one(self.id.clone()).await
.map_err(|e| Error::new(e.to_string()))?
.ok_or_else(|| Error::new("Not found"))
}
}
}
None => quote! {},
}
}
/// Emits `QueryRoot` with one resolver per field of the schema's `Query` type.
fn generate_query_type(schema: &ProtographSchema) -> TokenStream {
    let resolvers = schema
        .query_fields
        .iter()
        .map(|query_field| generate_query_method(query_field, schema));
    quote! {
        pub struct QueryRoot;
        #[Object]
        impl QueryRoot {
            #(#resolvers)*
        }
    }
}
/// Emits the `QueryRoot` resolver for one field of the schema's `Query` type.
///
/// A non-list query that declares an `id` argument is wired to `<Type>Loader` and
/// returns `Result<Option<Type>>`. Every other query becomes a
/// `todo!("Implement query")` stub returning `Result<Type>`.
///
/// In both shapes every declared argument is kept, snake_cased and typed through
/// `graphql_type_to_rust`, so the generated GraphQL signature matches the SDL.
/// `schema` is currently unused.
fn generate_query_method(
    field: &protograph_core::QueryField,
    schema: &ProtographSchema,
) -> TokenStream {
    let method_name = format_ident!("{}", to_snake_case(&field.name));
    let graphql_name = &field.name;
    let return_type = graphql_type_to_rust(&field.return_type);
    let args: Vec<TokenStream> = field
        .arguments
        .iter()
        .map(|a| {
            let arg_name = format_ident!("{}", to_snake_case(&a.name));
            let arg_type = graphql_type_to_rust(&a.field_type);
            quote! { #arg_name: #arg_type }
        })
        .collect();
    let loads_by_id =
        !field.return_type.is_list() && field.arguments.iter().any(|a| a.name == "id");
    if loads_by_id {
        let loader_name = format_ident!("{}Loader", field.return_type.base_type());
        // `id` is bound by the declared argument list; `to_string` turns it into the
        // loader's `String` key whatever its declared type.
        quote! {
            #[graphql(name = #graphql_name)]
            async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<Option<#return_type>> {
                let loader = ctx.data::<DataLoader<#loader_name>>()?;
                loader.load_one(id.to_string()).await.map_err(|e| Error::new(e.to_string()))
            }
        }
    } else {
        quote! {
            #[graphql(name = #graphql_name)]
            async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<#return_type> {
                todo!("Implement query")
            }
        }
    }
}
fn generate_mutation_type(schema: &ProtographSchema) -> TokenStream {
if schema.mutation_fields.is_empty() {
return quote! {
pub struct MutationRoot;
#[Object]
impl MutationRoot {
async fn _placeholder(&self) -> bool {
true
}
}
};
}
let mutation_methods: Vec<TokenStream> = schema
.mutation_fields
.iter()
.map(|f| generate_mutation_method(f))
.collect();
quote! {
pub struct MutationRoot;
#[Object]
impl MutationRoot {
#(#mutation_methods)*
}
}
}
/// Emits a stub `MutationRoot` resolver for one mutation field.
///
/// Arguments are snake_cased and typed through `graphql_type_to_rust`. The body is
/// `todo!("Implement mutation")`, so the generated resolver panics until it is
/// implemented by hand.
fn generate_mutation_method(field: &protograph_core::MutationField) -> TokenStream {
    let graphql_name = &field.name;
    let resolver_ident = format_ident!("{}", to_snake_case(graphql_name));
    let output_ty = graphql_type_to_rust(&field.return_type);
    let mut params = Vec::new();
    for argument in field.arguments.iter() {
        let param_ident = format_ident!("{}", to_snake_case(&argument.name));
        let param_ty = graphql_type_to_rust(&argument.field_type);
        params.push(quote! { #param_ident: #param_ty });
    }
    quote! {
        #[graphql(name = #graphql_name)]
        async fn #resolver_ident(&self, ctx: &Context<'_>, #(#params),*) -> Result<#output_ty> {
            todo!("Implement mutation")
        }
    }
}
fn generate_schema_builder(schema: &ProtographSchema) -> TokenStream {
let loader_registrations: Vec<TokenStream> = schema
.types
.iter()
.filter(|(_, t)| t.is_entity && !t.is_private)
.map(|(name, _)| {
let loader_name = format_ident!("{}Loader", name);
let service_trait = format_ident!("{}Service", name);
let method_name = format_ident!("with_{}_loader", to_snake_case(name));
quote! {
pub fn #method_name(mut self, service: Arc<dyn #service_trait>) -> Self {
let loader = DataLoader::new(
#loader_name::new(service),
tokio::spawn
);
self.0 = self.0.data(loader);
self
}
}
})
.collect();
quote! {
pub struct ProtographSchemaBuilder(SchemaBuilder<QueryRoot, MutationRoot, EmptySubscription>);
impl ProtographSchemaBuilder {
pub fn new() -> Self {
Self(Schema::build(QueryRoot, MutationRoot, EmptySubscription))
}
#(#loader_registrations)*
pub fn finish(self) -> Schema<QueryRoot, MutationRoot, EmptySubscription> {
self.0.finish()
}
}
impl Default for ProtographSchemaBuilder {
fn default() -> Self {
Self::new()
}
}
}
}
/// Maps a GraphQL type reference to the Rust type tokens used in generated code.
///
/// `Int`, `Float` and `Boolean` become `i32`, `f64` and `bool`. Lists become `Vec<_>` of
/// the mapped element type. Every other named type (including `ID` and `String`) keeps
/// its GraphQL name. Non-null wrappers are stripped, and nullable types are NOT wrapped
/// in `Option`, so nullability is not reflected in the result.
fn graphql_type_to_rust(gql_type: &FieldType) -> TokenStream {
    match gql_type {
        FieldType::NonNull(inner) => graphql_type_to_rust(inner),
        FieldType::List(inner) => {
            let element = graphql_type_to_rust(inner);
            quote! { Vec<#element> }
        }
        FieldType::Named(name) => {
            let rust_name = match name.as_str() {
                "Int" => "i32",
                "Float" => "f64",
                "Boolean" => "bool",
                // `ID`, `String` and user-defined types keep their GraphQL name.
                other => other,
            };
            let ident = format_ident!("{}", rust_name);
            quote! { #ident }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use protograph_core::parse_schema_file;

    /// Smoke test: a two-entity schema with HasMany/BelongsTo relationships and a query
    /// type produces the core generated items.
    #[test]
    fn test_generate_rust() {
        let sdl = r#"
type User @entity {
id: ID!
name: String!
posts: [Post!]! @hasMany(field: "authorId")
}
type Post @entity {
id: ID!
title: String!
author: User! @belongsTo(field: "authorId")
authorId: ID! @private
}
type Query {
user(id: ID!): User
users: [User!]!
}
"#;
        let rust = generate_rust(&parse_schema_file(sdl).unwrap());
        assert!(rust.contains("pub trait UserService"));
        assert!(rust.contains("pub trait PostService"));
        assert!(rust.contains("pub struct UserLoader"));
        assert!(rust.contains("pub struct QueryRoot"));
    }
}