1use crate::utils::{pluralize, to_pascal_case, to_snake_case};
2use proc_macro2::TokenStream;
3use protograph_core::{EntityType, FieldType, ProtographSchema, Relationship};
4use quote::{format_ident, quote};
5
6pub fn generate_rust(schema: &ProtographSchema) -> String {
7 let traits = generate_traits(schema);
8 let dataloaders = generate_dataloaders(schema);
9 let graphql_types = generate_graphql_types(schema);
10 let input_types = generate_input_types(schema);
11 let query_type = generate_query_type(schema);
12 let mutation_type = generate_mutation_type(schema);
13 let schema_builder = generate_schema_builder(schema);
14
15 let output = quote! {
16 use async_graphql::*;
17 use async_graphql::dataloader::{DataLoader, Loader};
18 use std::collections::HashMap;
19 use std::sync::Arc;
20
21 #traits
22 #dataloaders
23 #graphql_types
24 #input_types
25 #query_type
26 #mutation_type
27 #schema_builder
28 };
29
30 output.to_string()
31}
32
33fn generate_input_types(schema: &ProtographSchema) -> TokenStream {
34 let input_types: Vec<TokenStream> = schema
35 .input_types
36 .iter()
37 .map(|(_, input)| {
38 let name = format_ident!("{}", &input.name);
39 let fields: Vec<TokenStream> = input
40 .fields
41 .iter()
42 .map(|f| {
43 let field_name = format_ident!("{}", to_snake_case(&f.name));
44 let field_type = graphql_type_to_rust(&f.field_type);
45 quote! { pub #field_name: #field_type }
46 })
47 .collect();
48
49 quote! {
50 #[derive(Clone, Debug, InputObject)]
51 pub struct #name {
52 #(#fields),*
53 }
54 }
55 })
56 .collect();
57
58 quote! { #(#input_types)* }
59}
60
61fn generate_traits(schema: &ProtographSchema) -> TokenStream {
62 let traits: Vec<TokenStream> = schema
63 .types
64 .iter()
65 .filter(|(_, t)| t.is_entity && !t.is_private)
66 .map(|(_, entity)| generate_service_trait(entity, schema))
67 .collect();
68
69 quote! {
70 #(#traits)*
71 }
72}
73
74fn generate_service_trait(entity: &EntityType, schema: &ProtographSchema) -> TokenStream {
75 let name = &entity.name;
76 let trait_name = format_ident!("{}Service", name);
77 let entity_type = format_ident!("{}", name);
78 let plural_name = pluralize(name);
79
80 let relationship_methods: Vec<TokenStream> = entity
81 .fields
82 .iter()
83 .filter_map(|f| generate_relationship_method(f, entity, schema))
84 .collect();
85
86 quote! {
87 #[async_trait::async_trait]
88 pub trait #trait_name: Send + Sync {
89 async fn get(&self, id: String) -> Result<Option<#entity_type>, Box<dyn std::error::Error + Send + Sync>>;
90
91 async fn batch_get(&self, ids: Vec<String>) -> Result<Vec<#entity_type>, Box<dyn std::error::Error + Send + Sync>>;
92
93 #(#relationship_methods)*
94 }
95 }
96}
97
98fn generate_relationship_method(
99 field: &protograph_core::Field,
100 parent: &EntityType,
101 schema: &ProtographSchema,
102) -> Option<TokenStream> {
103 match &field.relationship {
104 Some(Relationship::HasMany { foreign_key }) => {
105 let related_type_name = field.field_type.base_type();
106 let method_name = format_ident!("batch_get_by_{}", to_snake_case(foreign_key));
107 let entity_type = format_ident!("{}", related_type_name);
108 let fk_name = format_ident!("{}s", to_snake_case(foreign_key));
109
110 Some(quote! {
111 async fn #method_name(
112 &self,
113 #fk_name: Vec<String>
114 ) -> Result<HashMap<String, Vec<#entity_type>>, Box<dyn std::error::Error + Send + Sync>>;
115 })
116 }
117 Some(Relationship::ManyToMany { junction_table, .. }) => {
118 let related_type_name = field.field_type.base_type();
119 let method_name = format_ident!("batch_get_via_{}", to_snake_case(junction_table));
120 let entity_type = format_ident!("{}", related_type_name);
121
122 Some(quote! {
123 async fn #method_name(
124 &self,
125 parent_ids: Vec<String>
126 ) -> Result<HashMap<String, Vec<#entity_type>>, Box<dyn std::error::Error + Send + Sync>>;
127 })
128 }
129 _ => None,
130 }
131}
132
133fn generate_dataloaders(schema: &ProtographSchema) -> TokenStream {
134 let entity_loaders: Vec<TokenStream> = schema
135 .types
136 .iter()
137 .filter(|(_, t)| t.is_entity && !t.is_private)
138 .map(|(_, entity)| generate_entity_loader(entity))
139 .collect();
140
141 let relationship_loaders: Vec<TokenStream> = schema
142 .types
143 .iter()
144 .filter(|(_, t)| t.is_entity)
145 .flat_map(|(_, entity)| {
146 entity
147 .fields
148 .iter()
149 .filter_map(|f| generate_relationship_loader(f, entity, schema))
150 })
151 .collect();
152
153 quote! {
154 #(#entity_loaders)*
155 #(#relationship_loaders)*
156 }
157}
158
159fn generate_entity_loader(entity: &EntityType) -> TokenStream {
160 let name = &entity.name;
161 let loader_name = format_ident!("{}Loader", name);
162 let service_trait = format_ident!("{}Service", name);
163 let entity_type = format_ident!("{}", name);
164
165 quote! {
166 pub struct #loader_name {
167 service: Arc<dyn #service_trait>,
168 }
169
170 impl #loader_name {
171 pub fn new(service: Arc<dyn #service_trait>) -> Self {
172 Self { service }
173 }
174 }
175
176 impl Loader<String> for #loader_name {
177 type Value = #entity_type;
178 type Error = Arc<dyn std::error::Error + Send + Sync>;
179
180 fn load(
181 &self,
182 keys: &[String]
183 ) -> impl std::future::Future<Output = Result<HashMap<String, Self::Value>, Self::Error>> + Send {
184 let service = self.service.clone();
185 let keys = keys.to_vec();
186 async move {
187 let entities = service.batch_get(keys).await
188 .map_err(|e| Arc::from(e) as Arc<dyn std::error::Error + Send + Sync>)?;
189
190 Ok(entities.into_iter()
191 .map(|e| (e.id.clone(), e))
192 .collect())
193 }
194 }
195 }
196 }
197}
198
199fn generate_relationship_loader(
200 field: &protograph_core::Field,
201 parent: &EntityType,
202 _schema: &ProtographSchema,
203) -> Option<TokenStream> {
204 match &field.relationship {
205 Some(Relationship::HasMany { foreign_key }) => {
206 let related_type_name = field.field_type.base_type();
207 let loader_name = format_ident!(
208 "{}By{}Loader",
209 pluralize(related_type_name),
210 to_pascal_case(foreign_key)
211 );
212 let service_trait = format_ident!("{}Service", parent.name);
213 let entity_type = format_ident!("{}", related_type_name);
214 let method_name = format_ident!("batch_get_by_{}", to_snake_case(foreign_key));
215
216 Some(quote! {
217 pub struct #loader_name {
218 service: Arc<dyn #service_trait>,
219 }
220
221 impl #loader_name {
222 pub fn new(service: Arc<dyn #service_trait>) -> Self {
223 Self { service }
224 }
225 }
226
227 impl Loader<String> for #loader_name {
228 type Value = Vec<#entity_type>;
229 type Error = Arc<dyn std::error::Error + Send + Sync>;
230
231 fn load(
232 &self,
233 keys: &[String]
234 ) -> impl std::future::Future<Output = Result<HashMap<String, Self::Value>, Self::Error>> + Send {
235 let service = self.service.clone();
236 let keys = keys.to_vec();
237 async move {
238 service.#method_name(keys).await
239 .map_err(|e| Arc::from(e) as Arc<dyn std::error::Error + Send + Sync>)
240 }
241 }
242 }
243 })
244 }
245 _ => None,
246 }
247}
248
249fn generate_graphql_types(schema: &ProtographSchema) -> TokenStream {
250 let types: Vec<TokenStream> = schema
251 .types
252 .iter()
253 .filter(|(_, t)| !t.is_private)
254 .map(|(_, entity)| generate_graphql_type(entity, schema))
255 .collect();
256
257 quote! { #(#types)* }
258}
259
260fn generate_graphql_type(entity: &EntityType, schema: &ProtographSchema) -> TokenStream {
261 let name = format_ident!("{}", &entity.name);
262
263 let scalar_fields: Vec<TokenStream> = entity
264 .fields
265 .iter()
266 .filter(|f| !f.is_private && f.relationship.is_none())
267 .map(|f| generate_scalar_field(f))
268 .collect();
269
270 let relationship_fields: Vec<TokenStream> = entity
271 .fields
272 .iter()
273 .filter(|f| !f.is_private && f.relationship.is_some())
274 .map(|f| generate_relationship_field(f, entity, schema))
275 .collect();
276
277 quote! {
278 #[derive(Clone, Debug)]
279 pub struct #name {
280 pub id: String,
281 inner: HashMap<String, String>,
282 }
283
284 impl #name {
285 pub fn new(id: String) -> Self {
286 Self { id, inner: HashMap::new() }
287 }
288
289 pub fn with_field(mut self, key: &str, value: String) -> Self {
290 self.inner.insert(key.to_string(), value);
291 self
292 }
293 }
294
295 #[Object]
296 impl #name {
297 async fn id(&self) -> &str {
298 &self.id
299 }
300
301 #(#scalar_fields)*
302 #(#relationship_fields)*
303 }
304 }
305}
306
307fn generate_scalar_field(field: &protograph_core::Field) -> TokenStream {
308 let field_name = format_ident!("{}", to_snake_case(&field.name));
309 let graphql_name = &field.name;
310 let return_type = graphql_type_to_rust(&field.field_type);
311
312 if field.name == "id" {
313 return quote! {};
314 }
315
316 quote! {
317 #[graphql(name = #graphql_name)]
318 async fn #field_name(&self) -> #return_type {
319 self.inner.get(#graphql_name).cloned().unwrap_or_default()
320 }
321 }
322}
323
324fn generate_relationship_field(
325 field: &protograph_core::Field,
326 parent: &EntityType,
327 schema: &ProtographSchema,
328) -> TokenStream {
329 let field_name = format_ident!("{}", to_snake_case(&field.name));
330 let graphql_name = &field.name;
331
332 match &field.relationship {
333 Some(Relationship::BelongsTo { foreign_key }) => {
334 let related_type = format_ident!("{}", field.field_type.base_type());
335 let loader_name = format_ident!("{}Loader", field.field_type.base_type());
336 let fk_field = to_snake_case(foreign_key);
337
338 quote! {
339 #[graphql(name = #graphql_name)]
340 async fn #field_name(&self, ctx: &Context<'_>) -> Result<Option<#related_type>> {
341 let loader = ctx.data::<DataLoader<#loader_name>>()?;
342 let fk = self.inner.get(#fk_field).cloned().unwrap_or_default();
343 loader.load_one(fk).await.map_err(|e| Error::new(e.to_string()))
344 }
345 }
346 }
347 Some(Relationship::HasMany { foreign_key }) => {
348 let related_type = format_ident!("{}", field.field_type.base_type());
349 let loader_name = format_ident!(
350 "{}By{}Loader",
351 pluralize(field.field_type.base_type()),
352 to_pascal_case(foreign_key)
353 );
354
355 quote! {
356 #[graphql(name = #graphql_name)]
357 async fn #field_name(&self, ctx: &Context<'_>) -> Result<Vec<#related_type>> {
358 let loader = ctx.data::<DataLoader<#loader_name>>()?;
359 loader.load_one(self.id.clone()).await
360 .map_err(|e| Error::new(e.to_string()))?
361 .ok_or_else(|| Error::new("Not found"))
362 }
363 }
364 }
365 Some(Relationship::ManyToMany { junction_table, .. }) => {
366 let related_type = format_ident!("{}", field.field_type.base_type());
367 let loader_name = format_ident!(
368 "{}Via{}Loader",
369 pluralize(field.field_type.base_type()),
370 junction_table
371 );
372
373 quote! {
374 #[graphql(name = #graphql_name)]
375 async fn #field_name(&self, ctx: &Context<'_>) -> Result<Vec<#related_type>> {
376 let loader = ctx.data::<DataLoader<#loader_name>>()?;
377 loader.load_one(self.id.clone()).await
378 .map_err(|e| Error::new(e.to_string()))?
379 .ok_or_else(|| Error::new("Not found"))
380 }
381 }
382 }
383 None => quote! {},
384 }
385}
386
387fn generate_query_type(schema: &ProtographSchema) -> TokenStream {
388 let query_methods: Vec<TokenStream> = schema
389 .query_fields
390 .iter()
391 .map(|f| generate_query_method(f, schema))
392 .collect();
393
394 quote! {
395 pub struct QueryRoot;
396
397 #[Object]
398 impl QueryRoot {
399 #(#query_methods)*
400 }
401 }
402}
403
404fn generate_query_method(
405 field: &protograph_core::QueryField,
406 schema: &ProtographSchema,
407) -> TokenStream {
408 let method_name = format_ident!("{}", to_snake_case(&field.name));
409 let graphql_name = &field.name;
410 let return_type = graphql_type_to_rust(&field.return_type);
411 let base_type = field.return_type.base_type();
412
413 let args: Vec<TokenStream> = field
414 .arguments
415 .iter()
416 .map(|a| {
417 let arg_name = format_ident!("{}", to_snake_case(&a.name));
418 let arg_type = graphql_type_to_rust(&a.field_type);
419 quote! { #arg_name: #arg_type }
420 })
421 .collect();
422
423 let loader_name = format_ident!("{}Loader", base_type);
424
425 if field.return_type.is_list() {
426 quote! {
427 #[graphql(name = #graphql_name)]
428 async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<#return_type> {
429 todo!("Implement query")
430 }
431 }
432 } else {
433 let id_arg = field.arguments.iter().find(|a| a.name == "id");
434 if id_arg.is_some() {
435 quote! {
436 #[graphql(name = #graphql_name)]
437 async fn #method_name(&self, ctx: &Context<'_>, id: ID) -> Result<Option<#return_type>> {
438 let loader = ctx.data::<DataLoader<#loader_name>>()?;
439 loader.load_one(id.to_string()).await.map_err(|e| Error::new(e.to_string()))
440 }
441 }
442 } else {
443 quote! {
444 #[graphql(name = #graphql_name)]
445 async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<#return_type> {
446 todo!("Implement query")
447 }
448 }
449 }
450 }
451}
452
453fn generate_mutation_type(schema: &ProtographSchema) -> TokenStream {
454 if schema.mutation_fields.is_empty() {
455 return quote! {
456 pub struct MutationRoot;
457
458 #[Object]
459 impl MutationRoot {
460 async fn _placeholder(&self) -> bool {
461 true
462 }
463 }
464 };
465 }
466
467 let mutation_methods: Vec<TokenStream> = schema
468 .mutation_fields
469 .iter()
470 .map(|f| generate_mutation_method(f))
471 .collect();
472
473 quote! {
474 pub struct MutationRoot;
475
476 #[Object]
477 impl MutationRoot {
478 #(#mutation_methods)*
479 }
480 }
481}
482
483fn generate_mutation_method(field: &protograph_core::MutationField) -> TokenStream {
484 let method_name = format_ident!("{}", to_snake_case(&field.name));
485 let graphql_name = &field.name;
486 let return_type = graphql_type_to_rust(&field.return_type);
487
488 let args: Vec<TokenStream> = field
489 .arguments
490 .iter()
491 .map(|a| {
492 let arg_name = format_ident!("{}", to_snake_case(&a.name));
493 let arg_type = graphql_type_to_rust(&a.field_type);
494 quote! { #arg_name: #arg_type }
495 })
496 .collect();
497
498 quote! {
499 #[graphql(name = #graphql_name)]
500 async fn #method_name(&self, ctx: &Context<'_>, #(#args),*) -> Result<#return_type> {
501 todo!("Implement mutation")
502 }
503 }
504}
505
506fn generate_schema_builder(schema: &ProtographSchema) -> TokenStream {
507 let loader_registrations: Vec<TokenStream> = schema
508 .types
509 .iter()
510 .filter(|(_, t)| t.is_entity && !t.is_private)
511 .map(|(name, _)| {
512 let loader_name = format_ident!("{}Loader", name);
513 let service_trait = format_ident!("{}Service", name);
514 let method_name = format_ident!("with_{}_loader", to_snake_case(name));
515
516 quote! {
517 pub fn #method_name(mut self, service: Arc<dyn #service_trait>) -> Self {
518 let loader = DataLoader::new(
519 #loader_name::new(service),
520 tokio::spawn
521 );
522 self.0 = self.0.data(loader);
523 self
524 }
525 }
526 })
527 .collect();
528
529 quote! {
530 pub struct ProtographSchemaBuilder(SchemaBuilder<QueryRoot, MutationRoot, EmptySubscription>);
531
532 impl ProtographSchemaBuilder {
533 pub fn new() -> Self {
534 Self(Schema::build(QueryRoot, MutationRoot, EmptySubscription))
535 }
536
537 #(#loader_registrations)*
538
539 pub fn finish(self) -> Schema<QueryRoot, MutationRoot, EmptySubscription> {
540 self.0.finish()
541 }
542 }
543
544 impl Default for ProtographSchemaBuilder {
545 fn default() -> Self {
546 Self::new()
547 }
548 }
549 }
550}
551
552fn graphql_type_to_rust(gql_type: &FieldType) -> TokenStream {
553 match gql_type {
554 FieldType::Named(name) => {
555 let ident = format_ident!(
556 "{}",
557 match name.as_str() {
558 "ID" => "ID",
559 "String" => "String",
560 "Int" => "i32",
561 "Float" => "f64",
562 "Boolean" => "bool",
563 other => other,
564 }
565 );
566 quote! { #ident }
567 }
568 FieldType::NonNull(inner) => graphql_type_to_rust(inner),
569 FieldType::List(inner) => {
570 let inner_type = graphql_type_to_rust(inner);
571 quote! { Vec<#inner_type> }
572 }
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use protograph_core::parse_schema_file;

    /// Two entities linked by a has-many / belongs-to pair through a private foreign key,
    /// plus a single-entity and a list query.
    const SCHEMA: &str = r#"
        type User @entity {
            id: ID!
            name: String!
            posts: [Post!]! @hasMany(field: "authorId")
        }

        type Post @entity {
            id: ID!
            title: String!
            author: User! @belongsTo(field: "authorId")
            authorId: ID! @private
        }

        type Query {
            user(id: ID!): User
            users: [User!]!
        }
    "#;

    /// Fragments that must appear in the generated module.
    const EXPECTED: &[&str] = &[
        "pub trait UserService",
        "pub trait PostService",
        "pub struct UserLoader",
        "pub struct QueryRoot",
    ];

    #[test]
    fn test_generate_rust() {
        let parsed = parse_schema_file(SCHEMA).unwrap();
        let rust = generate_rust(&parsed);

        for expected in EXPECTED {
            assert!(rust.contains(*expected), "generated code is missing `{}`", expected);
        }
    }
}