bluejay_typegen_codegen/
lib.rs

1use bluejay_core::{
2    definition::{prelude::*, SchemaDefinition, TypeDefinitionReference},
3    BuiltinScalarDefinition,
4};
5use bluejay_parser::{
6    ast::{
7        definition::{DefinitionDocument, SchemaDefinition as ParserSchemaDefinition},
8        Parse,
9    },
10    Error as ParserError,
11};
12use bluejay_validator::definition::BuiltinRulesValidator;
13use std::collections::{HashMap, HashSet};
14use syn::{parse_quote, spanned::Spanned};
15
16mod attributes;
17mod builtin_scalar;
18mod code_generator;
19mod enum_type_definition;
20mod executable_definition;
21mod input;
22mod input_object_type_definition;
23pub mod names;
24mod types;
25mod validation;
26
27use attributes::doc_string;
28pub use code_generator::CodeGenerator;
29use enum_type_definition::EnumTypeDefinitionBuilder;
30use executable_definition::generate_executable_definition;
31pub use executable_definition::{
32    ExecutableEnum, ExecutableField, ExecutableStruct, ExecutableType, WrappedExecutableType,
33};
34use input::DocumentInput;
35pub use input::Input;
36use input_object_type_definition::InputObjectTypeDefinitionBuilder;
37
38pub(crate) struct Config<'a, S: SchemaDefinition, C: CodeGenerator> {
39    borrow: bool,
40    schema_definition: &'a S,
41    custom_scalar_borrows: HashMap<String, bool>,
42    enums_as_str: HashSet<String>,
43    code_generator: &'a C,
44}
45
46impl<'a, S: SchemaDefinition, C: CodeGenerator> Config<'a, S, C> {
47    pub(crate) fn schema_definition(&self) -> &'a S {
48        self.schema_definition
49    }
50
51    pub(crate) fn borrow(&self) -> bool {
52        self.borrow
53    }
54
55    pub(crate) fn custom_scalar_borrows(&self, cstd: &S::CustomScalarTypeDefinition) -> bool {
56        *self
57            .custom_scalar_borrows
58            .get(&names::type_name(cstd.name()))
59            .expect("No type alias for custom scalar")
60    }
61
62    pub(crate) fn builtin_scalar_borrows(&self, bstd: BuiltinScalarDefinition) -> bool {
63        self.borrow && builtin_scalar::scalar_is_reference(bstd)
64    }
65
66    pub(crate) fn enum_as_str(&self, etd: &S::EnumTypeDefinition) -> bool {
67        self.enums_as_str.contains(etd.name())
68    }
69
70    pub(crate) fn code_generator(&self) -> &C {
71        self.code_generator
72    }
73}
74
75pub fn generate_schema(
76    input: Input,
77    module: &mut syn::ItemMod,
78    known_custom_scalar_types: HashMap<String, KnownCustomScalarType>,
79    code_generator: impl CodeGenerator,
80) -> syn::Result<()> {
81    let Input {
82        ref schema,
83        borrow,
84        enums_as_str,
85    } = input;
86
87    let borrow = borrow.is_some_and(|lit| lit.value());
88
89    let (schema_contents, schema_path) = schema.read_to_string_and_path()?;
90
91    let definition_document: DefinitionDocument = DefinitionDocument::parse(&schema_contents)
92        .map_err(|errors| {
93            map_parser_errors(schema, &schema_contents, schema_path.as_deref(), errors)
94        })?;
95    let schema_definition =
96        ParserSchemaDefinition::try_from(&definition_document).map_err(|errors| {
97            map_parser_errors(schema, &schema_contents, schema_path.as_deref(), errors)
98        })?;
99    let schema_errors: Vec<_> = BuiltinRulesValidator::validate(&schema_definition).collect();
100    if !schema_errors.is_empty() {
101        return Err(map_parser_errors(
102            schema,
103            &schema_contents,
104            schema_path.as_deref(),
105            schema_errors,
106        ));
107    }
108
109    let custom_scalar_borrows = custom_scalar_borrows(
110        module,
111        &schema_definition,
112        borrow,
113        known_custom_scalar_types,
114    )?;
115
116    let enums_as_str = validate_enums_as_str(enums_as_str, &schema_definition)?;
117
118    let config = Config {
119        schema_definition: &schema_definition,
120        borrow,
121        custom_scalar_borrows,
122        enums_as_str,
123        code_generator: &code_generator,
124    };
125
126    if let Some((_, items)) = module.content.take() {
127        let new_items = process_module_items(&config, items)?;
128        module.content = Some((syn::token::Brace::default(), new_items));
129    } else {
130        let new_items = process_module_items(&config, Vec::new())?;
131        module.content = Some((syn::token::Brace::default(), new_items));
132    }
133
134    if let Some(description) = schema_definition.description() {
135        module.attrs.push(doc_string(description));
136    }
137
138    Ok(())
139}
140
141fn custom_scalar_borrows(
142    module: &mut syn::ItemMod,
143    schema_definition: &impl SchemaDefinition,
144    borrow: bool,
145    known_custom_scalar_types: HashMap<String, KnownCustomScalarType>,
146) -> syn::Result<HashMap<String, bool>> {
147    let items = module
148        .content
149        .as_ref()
150        .map(|(_, items)| items.as_slice())
151        .unwrap_or_default();
152
153    let type_aliases = items
154        .iter()
155        .filter_map(|item| match item {
156            syn::Item::Type(ty) => Some(ty),
157            _ => None,
158        })
159        .collect::<Vec<_>>();
160
161    type_aliases.iter().try_for_each(|type_alias| {
162        let generics = &type_alias.generics;
163
164        if let Some(type_param) = generics.type_params().next() {
165            return Err(syn::Error::new(
166                type_param.span(),
167                "Type aliases for custom scalars must not contain type parameters",
168            ));
169        }
170
171        if let Some(const_param) = generics.const_params().next() {
172            return Err(syn::Error::new(
173                const_param.span(),
174                "Type aliases for custom scalars must not contain const parameters",
175            ));
176        }
177
178        if !borrow {
179            if let Some(lifetime_param) = generics.lifetimes().next() {
180                return Err(syn::Error::new(
181                    lifetime_param.span(),
182                    "Type aliases for custom scalars cannot contain lifetime parameters when `borrow` is set to true",
183                ));
184            }
185        } else if let Some(lifetime_param) = generics.lifetimes().nth(1) {
186            return Err(syn::Error::new(
187                lifetime_param.span(),
188                "Type aliases for custom scalars must contain at most one lifetime parameter",
189            ));
190        }
191
192        let name = type_alias.ident.to_string();
193
194        if !schema_definition.type_definitions().any(|type_definition| {
195            matches!(type_definition, TypeDefinitionReference::CustomScalar(cstd) if names::type_name(cstd.name()) == name)
196        }) {
197            return Err(syn::Error::new(
198                type_alias.ident.span(),
199                format!("No custom scalar definition named {name}"),
200            ));
201        }
202
203        Ok(())
204    })?;
205
206    let mut custom_scalars: HashMap<String, bool> = type_aliases
207        .into_iter()
208        .map(|type_alias| {
209            (
210                type_alias.ident.to_string(),
211                type_alias.generics.lifetimes().next().is_some(),
212            )
213        })
214        .collect();
215
216    schema_definition
217        .type_definitions()
218        .try_for_each(|td| match td {
219            TypeDefinitionReference::CustomScalar(cstd) => {
220                let name = names::type_name(cstd.name());
221                #[allow(clippy::map_entry)]
222                if custom_scalars.contains_key(&name) {
223                    Ok(())
224                } else if let Some(known_custom_scalar_type) = known_custom_scalar_types.get(&name)
225                {
226                    let (ty, lifetime): (_, Option<syn::Generics>) =
227                        match known_custom_scalar_type.type_for_borrowed.as_ref() {
228                            Some(ty) if borrow => (ty, Some(parse_quote! { <'a> })),
229                            _ => (&known_custom_scalar_type.type_for_owned, None),
230                        };
231                    let ident = quote::format_ident!("{}", name);
232                    let alias: syn::ItemType = parse_quote! {
233                        pub type #ident #lifetime = #ty;
234                    };
235                    if let Some((_, items)) = module.content.as_mut() {
236                        items.push(alias.into());
237                    }
238                    custom_scalars.insert(
239                        name,
240                        borrow && known_custom_scalar_type.type_for_borrowed.is_some(),
241                    );
242                    Ok(())
243                } else {
244                    Err(syn::Error::new(
245                        module.span(),
246                        format!("Missing type alias for custom scalar {name}"),
247                    ))
248                }
249            }
250            _ => Ok(()),
251        })?;
252
253    Ok(custom_scalars)
254}
255
256fn validate_enums_as_str(
257    enums_as_str: syn::punctuated::Punctuated<syn::LitStr, syn::Token![,]>,
258    schema_definition: &impl SchemaDefinition,
259) -> syn::Result<HashSet<String>> {
260    let mut enum_names = HashSet::new();
261    enums_as_str.iter().try_for_each(|lit| {
262        let name: String = lit.value();
263        if matches!(
264            schema_definition.get_type_definition(&name),
265            Some(TypeDefinitionReference::Enum(_))
266        ) {
267            if enum_names.insert(name.clone()) {
268                Ok(())
269            } else {
270                Err(syn::Error::new(
271                    lit.span(),
272                    format!("Duplicate enum definition named {name}"),
273                ))
274            }
275        } else {
276            Err(syn::Error::new(
277                lit.span(),
278                format!("No enum definition named {name}"),
279            ))
280        }
281    })?;
282    Ok(enum_names)
283}
284
285fn process_module_items<S: SchemaDefinition, C: CodeGenerator>(
286    config: &Config<S, C>,
287    items: Vec<syn::Item>,
288) -> syn::Result<Vec<syn::Item>> {
289    config
290        .schema_definition
291        .type_definitions()
292        .filter_map(|type_definition| match type_definition {
293            TypeDefinitionReference::Enum(etd) if !config.enum_as_str(etd) => Some(
294                EnumTypeDefinitionBuilder::<S, C>::build(etd, config.code_generator()),
295            ),
296            TypeDefinitionReference::InputObject(iotd) => {
297                Some(InputObjectTypeDefinitionBuilder::build(iotd, config))
298            }
299            _ => None,
300        })
301        .flatten()
302        .map(Ok)
303        .chain(
304            items
305                .into_iter()
306                .map(|item| process_module_item(config, item)),
307        )
308        .collect()
309}
310
311fn process_module_item<S: SchemaDefinition, C: CodeGenerator>(
312    config: &Config<S, C>,
313    item: syn::Item,
314) -> syn::Result<syn::Item> {
315    if let syn::Item::Mod(mut module) = item {
316        if let Some((attribute, &mut [])) = module.attrs.split_first_mut() {
317            if matches!(attribute.style, syn::AttrStyle::Inner(_)) {
318                Err(syn::Error::new(
319                    attribute.span(),
320                    "Expected an outer attribute",
321                ))
322            } else if let syn::Meta::List(list) = &mut attribute.meta {
323                if list.path.is_ident("query") {
324                    if !matches!(list.delimiter, syn::MacroDelimiter::Bracket(_)) {
325                        let items = generate_executable_definition(
326                            config,
327                            std::mem::take(&mut list.tokens),
328                        )?;
329                        module.content = Some((syn::token::Brace::default(), items));
330                        module.attrs = Vec::new();
331                        Ok(module.into())
332                    } else {
333                        Err(syn::Error::new(
334                            list.delimiter.span().open(),
335                            "Expected brackets",
336                        ))
337                    }
338                } else {
339                    Err(syn::Error::new(list.path.span(), "Expected `query`"))
340                }
341            } else {
342                Err(syn::Error::new(
343                    attribute.meta.span(),
344                    "Expected a list meta attribute, e.g. `#[query(...)]`",
345                ))
346            }
347        } else {
348            Err(syn::Error::new(
349                module.span(),
350                "Expected a single `#[query(...)]` attribute",
351            ))
352        }
353    } else if matches!(item, syn::Item::Type(_)) {
354        Ok(item)
355    } else {
356        Err(syn::Error::new(item.span(), "Expected a module"))
357    }
358}
359
360fn map_parser_errors<E: Into<ParserError>>(
361    span: &impl syn::spanned::Spanned,
362    schema_contents: &str,
363    schema_path: Option<&str>,
364    errors: impl IntoIterator<Item = E>,
365) -> syn::Error {
366    syn::Error::new(
367        span.span(),
368        ParserError::format_errors(schema_contents, schema_path, errors),
369    )
370}
371
/// Rust types to fall back on for a custom scalar that has no user-written type alias.
#[derive(Clone)]
pub struct KnownCustomScalarType {
    /// Type the generated alias points to when `borrow` is off or no borrowed type is given.
    pub type_for_owned: syn::Type,
    /// Type the generated alias points to when `borrow` is on; the alias is then declared
    /// with a single `'a` lifetime parameter.
    pub type_for_borrowed: Option<syn::Type>,
}