Skip to main content

dbt_yaml_schemars_derive/
lib.rs

1#![forbid(unsafe_code)]
2
3#[macro_use]
4extern crate quote;
5#[macro_use]
6extern crate syn;
7extern crate proc_macro;
8
9mod ast;
10mod attr;
11mod metadata;
12mod regex_syntax;
13mod schema_exprs;
14
15use ast::*;
16use proc_macro2::TokenStream;
17use syn::spanned::Spanned;
18
19#[proc_macro_derive(JsonSchema, attributes(schemars, serde, validate))]
20pub fn derive_json_schema_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
21    let input = parse_macro_input!(input as syn::DeriveInput);
22    derive_json_schema(input, false, false)
23        .unwrap_or_else(syn::Error::into_compile_error)
24        .into()
25}
26
27#[proc_macro_derive(DbtSchema, attributes(schemars, serde, validate))]
28pub fn derive_dbt_schema_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29    let input = parse_macro_input!(input as syn::DeriveInput);
30    derive_json_schema(input, false, true)
31        .unwrap_or_else(syn::Error::into_compile_error)
32        .into()
33}
34
35#[proc_macro_derive(JsonSchema_repr, attributes(schemars, serde))]
36pub fn derive_json_schema_repr_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
37    let input = parse_macro_input!(input as syn::DeriveInput);
38    derive_json_schema(input, true, false)
39        .unwrap_or_else(syn::Error::into_compile_error)
40        .into()
41}
42
43fn derive_json_schema(
44    mut input: syn::DeriveInput,
45    repr: bool,
46    dbt: bool,
47) -> syn::Result<TokenStream> {
48    attr::process_serde_attrs(&mut input)?;
49
50    let mut cont = Container::from_ast(&input)?;
51    add_trait_bounds(&mut cont);
52
53    let crate_alias = cont.attrs.crate_name.as_ref().map(|path| {
54        quote_spanned! {path.span()=>
55            use #path as schemars;
56        }
57    });
58
59    let type_name = &cont.ident;
60    let (impl_generics, ty_generics, where_clause) = cont.generics.split_for_impl();
61
62    if let Some(transparent_field) = cont.transparent_field() {
63        let (ty, type_def) = schema_exprs::type_for_field_schema(transparent_field);
64        return Ok(quote! {
65            const _: () = {
66                #crate_alias
67                #type_def
68
69                #[automatically_derived]
70                impl #impl_generics schemars::JsonSchema for #type_name #ty_generics #where_clause {
71                    fn is_referenceable() -> bool {
72                        <#ty as schemars::JsonSchema>::is_referenceable()
73                    }
74
75                    fn schema_name() -> std::string::String {
76                        <#ty as schemars::JsonSchema>::schema_name()
77                    }
78
79                    fn schema_id() -> std::borrow::Cow<'static, str> {
80                        <#ty as schemars::JsonSchema>::schema_id()
81                    }
82
83                    fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
84                        <#ty as schemars::JsonSchema>::json_schema(generator)
85                    }
86
87                    fn _schemars_private_non_optional_json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
88                        <#ty as schemars::JsonSchema>::_schemars_private_non_optional_json_schema(generator)
89                    }
90
91                    fn _schemars_private_is_option() -> bool {
92                        <#ty as schemars::JsonSchema>::_schemars_private_is_option()
93                    }
94                };
95            };
96        });
97    }
98
99    let mut schema_base_name = cont.name().to_string();
100
101    if !cont.attrs.is_renamed {
102        if let Some(path) = cont.serde_attrs.remote() {
103            if let Some(segment) = path.segments.last() {
104                schema_base_name = segment.ident.to_string();
105            }
106        }
107    }
108
109    // FIXME improve handling of generic type params which may not implement JsonSchema
110    let type_params: Vec<_> = cont.generics.type_params().map(|ty| &ty.ident).collect();
111    let const_params: Vec<_> = cont.generics.const_params().map(|c| &c.ident).collect();
112    let params: Vec<_> = type_params.iter().chain(const_params.iter()).collect();
113
114    let (schema_name, schema_id) = if params.is_empty()
115        || (cont.attrs.is_renamed && !schema_base_name.contains('{'))
116    {
117        (
118            quote! {
119                #schema_base_name.to_owned()
120            },
121            quote! {
122                std::borrow::Cow::Borrowed(std::concat!(
123                    std::module_path!(),
124                    "::",
125                    #schema_base_name
126                ))
127            },
128        )
129    } else if cont.attrs.is_renamed {
130        let mut schema_name_fmt = schema_base_name;
131        for tp in &params {
132            schema_name_fmt.push_str(&format!("{{{tp}:.0}}"));
133        }
134        (
135            quote! {
136                format!(#schema_name_fmt #(,#type_params=#type_params::schema_name())* #(,#const_params=#const_params)*)
137            },
138            quote! {
139                std::borrow::Cow::Owned(
140                    format!(
141                        std::concat!(
142                            std::module_path!(),
143                            "::",
144                            #schema_name_fmt
145                        )
146                        #(,#type_params=#type_params::schema_id())*
147                        #(,#const_params=#const_params)*
148                    )
149                )
150            },
151        )
152    } else {
153        let mut schema_name_fmt = schema_base_name;
154        schema_name_fmt.push_str("_for_{}");
155        schema_name_fmt.push_str(&"_and_{}".repeat(params.len() - 1));
156        (
157            quote! {
158                format!(#schema_name_fmt #(,#type_params::schema_name())* #(,#const_params)*)
159            },
160            quote! {
161                std::borrow::Cow::Owned(
162                    format!(
163                        std::concat!(
164                            std::module_path!(),
165                            "::",
166                            #schema_name_fmt
167                        )
168                        #(,#type_params::schema_id())*
169                        #(,#const_params)*
170                    )
171                )
172            },
173        )
174    };
175
176    let schema_expr = if repr {
177        schema_exprs::expr_for_repr(&cont)?
178    } else {
179        schema_exprs::expr_for_container(&cont)
180    };
181    let schema_expr = if dbt {
182        quote! {
183            dbt_yaml::maybe_transformable::maybe_transformable(#schema_expr)
184        }
185    } else {
186        schema_expr
187    };
188
189    Ok(quote! {
190        const _: () = {
191            #crate_alias
192
193            #[automatically_derived]
194            #[allow(unused_braces)]
195            impl #impl_generics schemars::JsonSchema for #type_name #ty_generics #where_clause {
196                fn schema_name() -> std::string::String {
197                    #schema_name
198                }
199
200                fn schema_id() -> std::borrow::Cow<'static, str> {
201                    #schema_id
202                }
203
204                fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
205                    #schema_expr
206                }
207            };
208        };
209    })
210}
211
212fn add_trait_bounds(cont: &mut Container) {
213    if let Some(bounds) = cont.serde_attrs.ser_bound() {
214        let where_clause = cont.generics.make_where_clause();
215        where_clause.predicates.extend(bounds.iter().cloned());
216    } else {
217        // No explicit trait bounds specified, assume the Rust convention of adding the trait to each type parameter
218        // TODO consider also adding trait bound to associated types when used as fields - I think Serde does this?
219        for param in &mut cont.generics.params {
220            if let syn::GenericParam::Type(ref mut type_param) = *param {
221                type_param.bounds.push(parse_quote!(schemars::JsonSchema));
222            }
223        }
224    }
225}