protograph_codegen/
rust_gen.rs

1use crate::utils::{pluralize, to_pascal_case, to_snake_case};
2use proc_macro2::TokenStream;
3use protograph_core::{EntityType, FieldType, ProtographSchema, Relationship};
4use quote::{format_ident, quote};
5
6pub fn generate_rust(schema: &ProtographSchema) -> String {
7    let traits = generate_traits(schema);
8    let dataloaders = generate_dataloaders(schema);
9    let graphql_types = generate_graphql_types(schema);
10    let input_types = generate_input_types(schema);
11    let query_type = generate_query_type(schema);
12    let mutation_type = generate_mutation_type(schema);
13    let schema_builder = generate_schema_builder(schema);
14
15    let output = quote! {
16        use async_graphql::*;
17        use async_graphql::dataloader::{DataLoader, Loader};
18        use std::collections::HashMap;
19        use std::sync::Arc;
20
21        #traits
22        #dataloaders
23        #graphql_types
24        #input_types
25        #query_type
26        #mutation_type
27        #schema_builder
28    };
29
30    output.to_string()
31}
32
33fn generate_input_types(schema: &ProtographSchema) -> TokenStream {
34    let input_types: Vec<TokenStream> = schema
35        .input_types
36        .iter()
37        .map(|(_, input)| {
38            let name = format_ident!("{}", &input.name);
39            let fields: Vec<TokenStream> = input
40                .fields
41                .iter()
42                .map(|f| {
43                    let field_name = format_ident!("{}", to_snake_case(&f.name));
44                    let field_type = graphql_type_to_rust(&f.field_type);
45                    quote! { pub #field_name: #field_type }
46                })
47                .collect();
48
49            quote! {
50                #[derive(Clone, Debug, InputObject)]
51                pub struct #name {
52                    #(#fields),*
53                }
54            }
55        })
56        .collect();
57
58    quote! { #(#input_types)* }
59}
60
61fn generate_traits(schema: &ProtographSchema) -> TokenStream {
62    let traits: Vec<TokenStream> = schema
63        .types
64        .iter()
65        .filter(|(_, t)| t.is_entity && !t.is_private)
66        .map(|(_, entity)| generate_service_trait(entity, schema))
67        .collect();
68
69    quote! {
70        #(#traits)*
71    }
72}
73
74fn generate_service_trait(entity: &EntityType, schema: &ProtographSchema) -> TokenStream {
75    let name = &entity.name;
76    let trait_name = format_ident!("{}Service", name);
77    let entity_type = format_ident!("{}", name);
78    let plural_name = pluralize(name);
79
80    let relationship_methods: Vec<TokenStream> = entity
81        .fields
82        .iter()
83        .filter_map(|f| generate_relationship_method(f, entity, schema))
84        .collect();
85
86    quote! {
87        #[async_trait::async_trait]
88        pub trait #trait_name: Send + Sync {
89            async fn get(&self, id: String) -> Result<Option<#entity_type>, Box<dyn std::error::Error + Send + Sync>>;
90
91            async fn batch_get(&self, ids: Vec<String>) -> Result<Vec<#entity_type>, Box<dyn std::error::Error + Send + Sync>>;
92
93            #(#relationship_methods)*
94        }
95    }
96}
97
98fn generate_relationship_method(
99    field: &protograph_core::Field,
100    parent: &EntityType,
101    schema: &ProtographSchema,
102) -> Option<TokenStream> {
103    match &field.relationship {
104        Some(Relationship::HasMany { foreign_key }) => {
105            let related_type_name = field.field_type.base_type();
106            let method_name = format_ident!("batch_get_by_{}", to_snake_case(foreign_key));
107            let entity_type = format_ident!("{}", related_type_name);
108            let fk_name = format_ident!("{}s", to_snake_case(foreign_key));
109
110            Some(quote! {
111                async fn #method_name(
112                    &self,
113                    #fk_name: Vec<String>
114                ) -> Result<HashMap<String, Vec<#entity_type>>, Box<dyn std::error::Error + Send + Sync>>;
115            })
116        }
117        Some(Relationship::ManyToMany { junction_table, .. }) => {
118            let related_type_name = field.field_type.base_type();
119            let method_name = format_ident!("batch_get_via_{}", to_snake_case(junction_table));
120            let entity_type = format_ident!("{}", related_type_name);
121
122            Some(quote! {
123                async fn #method_name(
124                    &self,
125                    parent_ids: Vec<String>
126                ) -> Result<HashMap<String, Vec<#entity_type>>, Box<dyn std::error::Error + Send + Sync>>;
127            })
128        }
129        _ => None,
130    }
131}
132
133fn generate_dataloaders(schema: &ProtographSchema) -> TokenStream {
134    let entity_loaders: Vec<TokenStream> = schema
135        .types
136        .iter()
137        .filter(|(_, t)| t.is_entity && !t.is_private)
138        .map(|(_, entity)| generate_entity_loader(entity))
139        .collect();
140
141    let relationship_loaders: Vec<TokenStream> = schema
142        .types
143        .iter()
144        .filter(|(_, t)| t.is_entity)
145        .flat_map(|(_, entity)| {
146            entity
147                .fields
148                .iter()
149                .filter_map(|f| generate_relationship_loader(f, entity, schema))
150        })
151        .collect();
152
153    quote! {
154        #(#entity_loaders)*
155        #(#relationship_loaders)*
156    }
157}
158
159fn generate_entity_loader(entity: &EntityType) -> TokenStream {
160    let name = &entity.name;
161    let loader_name = format_ident!("{}Loader", name);
162    let service_trait = format_ident!("{}Service", name);
163    let entity_type = format_ident!("{}", name);
164
165    quote! {
166        pub struct #loader_name {
167            service: Arc<dyn #service_trait>,
168        }
169
170        impl #loader_name {
171            pub fn new(service: Arc<dyn #service_trait>) -> Self {
172                Self { service }
173            }
174        }
175
176        impl Loader<String> for #loader_name {
177            type Value = #entity_type;
178            type Error = Arc<dyn std::error::Error + Send + Sync>;
179
180            fn load(
181                &self,
182                keys: &[String]
183            ) -> impl std::future::Future<Output = Result<HashMap<String, Self::Value>, Self::Error>> + Send {
184                let service = self.service.clone();
185                let keys = keys.to_vec();
186                async move {
187                    let entities = service.batch_get(keys).await
188                        .map_err(|e| Arc::from(e) as Arc<dyn std::error::Error + Send + Sync>)?;
189
190                    Ok(entities.into_iter()
191                        .map(|e| (e.id.clone(), e))
192                        .collect())
193                }
194            }
195        }
196    }
197}
198
199fn generate_relationship_loader(
200    field: &protograph_core::Field,
201    parent: &EntityType,
202    _schema: &ProtographSchema,
203) -> Option<TokenStream> {
204    match &field.relationship {
205        Some(Relationship::HasMany { foreign_key }) => {
206            let related_type_name = field.field_type.base_type();
207            let loader_name = format_ident!(
208                "{}By{}Loader",
209                pluralize(related_type_name),
210                to_pascal_case(foreign_key)
211            );
212            let service_trait = format_ident!("{}Service", parent.name);
213            let entity_type = format_ident!("{}", related_type_name);
214            let method_name = format_ident!("batch_get_by_{}", to_snake_case(foreign_key));
215
216            Some(quote! {
217                pub struct #loader_name {
218                    service: Arc<dyn #service_trait>,
219                }
220
221                impl #loader_name {
222                    pub fn new(service: Arc<dyn #service_trait>) -> Self {
223                        Self { service }
224                    }
225                }
226
227                impl Loader<String> for #loader_name {
228                    type Value = Vec<#entity_type>;
229                    type Error = Arc<dyn std::error::Error + Send + Sync>;
230
231                    fn load(
232                        &self,
233                        keys: &[String]
234                    ) -> impl std::future::Future<Output = Result<HashMap<String, Self::Value>, Self::Error>> + Send {
235                        let service = self.service.clone();
236                        let keys = keys.to_vec();
237                        async move {
238                            service.#method_name(keys).await
239                                .map_err(|e| Arc::from(e) as Arc<dyn std::error::Error + Send + Sync>)
240                        }
241                    }
242                }
243            })
244        }
245        _ => None,
246    }
247}
248
249fn generate_graphql_types(schema: &ProtographSchema) -> TokenStream {
250    let types: Vec<TokenStream> = schema
251        .types
252        .iter()
253        .filter(|(_, t)| !t.is_private)
254        .map(|(_, entity)| generate_graphql_type(entity, schema))
255        .collect();
256
257    quote! { #(#types)* }
258}
259
260fn generate_graphql_type(entity: &EntityType, schema: &ProtographSchema) -> TokenStream {
261    let name = format_ident!("{}", &entity.name);
262
263    let scalar_fields: Vec<TokenStream> = entity
264        .fields
265        .iter()
266        .filter(|f| !f.is_private && f.relationship.is_none())
267        .map(|f| generate_scalar_field(f))
268        .collect();
269
270    let relationship_fields: Vec<TokenStream> = entity
271        .fields
272        .iter()
273        .filter(|f| !f.is_private && f.relationship.is_some())
274        .map(|f| generate_relationship_field(f, entity, schema))
275        .collect();
276
277    quote! {
278        #[derive(Clone, Debug)]
279        pub struct #name {
280            pub id: String,
281            inner: HashMap<String, String>,
282        }
283
284        impl #name {
285            pub fn new(id: String) -> Self {
286                Self { id, inner: HashMap::new() }
287            }
288
289            pub fn with_field(mut self, key: &str, value: String) -> Self {
290                self.inner.insert(key.to_string(), value);
291                self
292            }
293        }
294
295        #[Object]
296        impl #name {
297            async fn id(&self) -> &str {
298                &self.id
299            }
300
301            #(#scalar_fields)*
302            #(#relationship_fields)*
303        }
304    }
305}
306
307fn generate_scalar_field(field: &protograph_core::Field) -> TokenStream {
308    let field_name = format_ident!("{}", to_snake_case(&field.name));
309    let graphql_name = &field.name;
310    let return_type = graphql_type_to_rust(&field.field_type);
311
312    if field.name == "id" {
313        return quote! {};
314    }
315
316    quote! {
317        #[graphql(name = #graphql_name)]
318        async fn #field_name(&self) -> #return_type {
319            self.inner.get(#graphql_name).cloned().unwrap_or_default()
320        }
321    }
322}
323
324fn generate_relationship_field(
325    field: &protograph_core::Field,
326    parent: &EntityType,
327    schema: &ProtographSchema,
328) -> TokenStream {
329    let field_name = format_ident!("{}", to_snake_case(&field.name));
330    let graphql_name = &field.name;
331
332    match &field.relationship {
333        Some(Relationship::BelongsTo { foreign_key }) => {
334            let related_type = format_ident!("{}", field.field_type.base_type());
335            let loader_name = format_ident!("{}Loader", field.field_type.base_type());
336            let fk_field = to_snake_case(foreign_key);
337
338            quote! {
339                #[graphql(name = #graphql_name)]
340                async fn #field_name(&self, ctx: &Context<'_>) -> Result<Option<#related_type>> {
341                    let loader = ctx.data::<DataLoader<#loader_name>>()?;
342                    let fk = self.inner.get(#fk_field).cloned().unwrap_or_default();
343                    loader.load_one(fk).await.map_err(|e| Error::new(e.to_string()))
344                }
345            }
346        }
347        Some(Relationship::HasMany { foreign_key }) => {
348            let related_type = format_ident!("{}", field.field_type.base_type());
349            let loader_name = format_ident!(
350                "{}By{}Loader",
351                pluralize(field.field_type.base_type()),
352                to_pascal_case(foreign_key)
353            );
354
355            quote! {
356                #[graphql(name = #graphql_name)]
357                async fn #field_name(&self, ctx: &Context<'_>) -> Result<Vec<#related_type>> {
358                    let loader = ctx.data::<DataLoader<#loader_name>>()?;
359                    loader.load_one(self.id.clone()).await
360                        .map_err(|e| Error::new(e.to_string()))?
361                        .ok_or_else(|| Error::new("Not found"))
362                }
363            }
364        }
365        Some(Relationship::ManyToMany { junction_table, .. }) => {
366            let related_type = format_ident!("{}", field.field_type.base_type());
367            let loader_name = format_ident!(
368                "{}Via{}Loader",
369                pluralize(field.field_type.base_type()),
370                junction_table
371            );
372
373            quote! {
374                #[graphql(name = #graphql_name)]
375                async fn #field_name(&self, ctx: &Context<'_>) -> Result<Vec<#related_type>> {
376                    let loader = ctx.data::<DataLoader<#loader_name>>()?;
377                    loader.load_one(self.id.clone()).await
378                        .map_err(|e| Error::new(e.to_string()))?
379                        .ok_or_else(|| Error::new("Not found"))
380                }
381            }
382        }
383        None => quote! {},
384    }
385}
386
387fn generate_query_type(schema: &ProtographSchema) -> TokenStream {
388    let query_methods: Vec<TokenStream> = schema
389        .query_fields
390        .iter()
391        .map(|f| generate_query_method(f, schema))
392        .collect();
393
394    quote! {
395        pub struct QueryRoot;
396
397        #[Object]
398        impl QueryRoot {
399            #(#query_methods)*
400        }
401    }
402}
403
404fn generate_query_method(
405    field: &protograph_core::QueryField,
406    schema: &ProtographSchema,
407) -> TokenStream {
408    let method_name = format_ident!("{}", to_snake_case(&field.name));
409    let graphql_name = &field.name;
410    let return_type = graphql_type_to_rust(&field.return_type);
411    let base_type = field.return_type.base_type();
412
413    let args: Vec<TokenStream> = field
414        .arguments
415        .iter()
416        .map(|a| {
417            let arg_name = format_ident!("{}", to_snake_case(&a.name));
418            let arg_type = graphql_type_to_rust(&a.field_type);
419            quote! { #arg_name: #arg_type }
420        })
421        .collect();
422
423    let loader_name = format_ident!("{}Loader", base_type);
424
425    if field.return_type.is_list() {
426        quote! {
427            #[graphql(name = #graphql_name)]
428            async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<#return_type> {
429                todo!("Implement query")
430            }
431        }
432    } else {
433        let id_arg = field.arguments.iter().find(|a| a.name == "id");
434        if id_arg.is_some() {
435            quote! {
436                #[graphql(name = #graphql_name)]
437                async fn #method_name(&self, ctx: &Context<'_>, id: ID) -> Result<Option<#return_type>> {
438                    let loader = ctx.data::<DataLoader<#loader_name>>()?;
439                    loader.load_one(id.to_string()).await.map_err(|e| Error::new(e.to_string()))
440                }
441            }
442        } else {
443            quote! {
444                #[graphql(name = #graphql_name)]
445                async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<#return_type> {
446                    todo!("Implement query")
447                }
448            }
449        }
450    }
451}
452
453fn generate_mutation_type(schema: &ProtographSchema) -> TokenStream {
454    if schema.mutation_fields.is_empty() {
455        return quote! {
456            pub struct MutationRoot;
457
458            #[Object]
459            impl MutationRoot {
460                async fn _placeholder(&self) -> bool {
461                    true
462                }
463            }
464        };
465    }
466
467    let mutation_methods: Vec<TokenStream> = schema
468        .mutation_fields
469        .iter()
470        .map(|f| generate_mutation_method(f))
471        .collect();
472
473    quote! {
474        pub struct MutationRoot;
475
476        #[Object]
477        impl MutationRoot {
478            #(#mutation_methods)*
479        }
480    }
481}
482
483fn generate_mutation_method(field: &protograph_core::MutationField) -> TokenStream {
484    let method_name = format_ident!("{}", to_snake_case(&field.name));
485    let graphql_name = &field.name;
486    let return_type = graphql_type_to_rust(&field.return_type);
487
488    let args: Vec<TokenStream> = field
489        .arguments
490        .iter()
491        .map(|a| {
492            let arg_name = format_ident!("{}", to_snake_case(&a.name));
493            let arg_type = graphql_type_to_rust(&a.field_type);
494            quote! { #arg_name: #arg_type }
495        })
496        .collect();
497
498    quote! {
499        #[graphql(name = #graphql_name)]
500        async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<#return_type> {
501            todo!("Implement mutation")
502        }
503    }
504}
505
506fn generate_schema_builder(schema: &ProtographSchema) -> TokenStream {
507    let loader_registrations: Vec<TokenStream> = schema
508        .types
509        .iter()
510        .filter(|(_, t)| t.is_entity && !t.is_private)
511        .map(|(name, _)| {
512            let loader_name = format_ident!("{}Loader", name);
513            let service_trait = format_ident!("{}Service", name);
514            let method_name = format_ident!("with_{}_loader", to_snake_case(name));
515
516            quote! {
517                pub fn #method_name(mut self, service: Arc<dyn #service_trait>) -> Self {
518                    let loader = DataLoader::new(
519                        #loader_name::new(service),
520                        tokio::spawn
521                    );
522                    self.0 = self.0.data(loader);
523                    self
524                }
525            }
526        })
527        .collect();
528
529    quote! {
530        pub struct ProtographSchemaBuilder(SchemaBuilder<QueryRoot, MutationRoot, EmptySubscription>);
531
532        impl ProtographSchemaBuilder {
533            pub fn new() -> Self {
534                Self(Schema::build(QueryRoot, MutationRoot, EmptySubscription))
535            }
536
537            #(#loader_registrations)*
538
539            pub fn finish(self) -> Schema<QueryRoot, MutationRoot, EmptySubscription> {
540                self.0.finish()
541            }
542        }
543
544        impl Default for ProtographSchemaBuilder {
545            fn default() -> Self {
546                Self::new()
547            }
548        }
549    }
550}
551
552fn graphql_type_to_rust(gql_type: &FieldType) -> TokenStream {
553    match gql_type {
554        FieldType::Named(name) => {
555            let ident = format_ident!(
556                "{}",
557                match name.as_str() {
558                    "ID" => "ID",
559                    "String" => "String",
560                    "Int" => "i32",
561                    "Float" => "f64",
562                    "Boolean" => "bool",
563                    other => other,
564                }
565            );
566            quote! { #ident }
567        }
568        FieldType::NonNull(inner) => graphql_type_to_rust(inner),
569        FieldType::List(inner) => {
570            let inner_type = graphql_type_to_rust(inner);
571            quote! { Vec<#inner_type> }
572        }
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use protograph_core::parse_schema_file;

    // Smoke test: a two-entity schema with a hasMany / belongsTo pair must
    // produce both service traits, the `UserLoader` type and the query root.
    #[test]
    fn test_generate_rust() {
        let sdl = r#"
            type User @entity {
                id: ID!
                name: String!
                posts: [Post!]! @hasMany(field: "authorId")
            }

            type Post @entity {
                id: ID!
                title: String!
                author: User! @belongsTo(field: "authorId")
                authorId: ID! @private
            }

            type Query {
                user(id: ID!): User
                users: [User!]!
            }
        "#;

        let generated = generate_rust(&parse_schema_file(sdl).unwrap());

        assert!(generated.contains("pub trait UserService"));
        assert!(generated.contains("pub trait PostService"));
        assert!(generated.contains("pub struct UserLoader"));
        assert!(generated.contains("pub struct QueryRoot"));
    }
}
611}