mdmodels_macro/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use lazy_static::lazy_static;
5use mdmodels::datamodel::DataModel;
6use proc_macro::TokenStream;
7use quote::quote;
8use std::collections::{BTreeMap, HashMap};
9use std::{error::Error, path::Path};
10use syn::{parse_macro_input, LitStr};
11
// Names that must not be used for generated structs, enums or fields: the strict and
// reserved Rust keywords (2018 edition and later). `syn::Ident::new` accepts keywords
// and produces a keyword token, so without this check a model name such as `match`
// or `type` would only surface later as a confusing parse error in the generated code.
const FORBIDDEN_NAMES: &[&str] = &[
    // Strict keywords
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod",
    "move", "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super",
    "trait", "true", "type", "unsafe", "use", "where", "while",
    // Reserved for future use
    "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual", "yield",
];
16
17// Lazy static initialization for type mappings
18lazy_static! {
19    static ref TYPE_MAPPINGS: HashMap<&'static str, &'static str> = {
20        let mut m = HashMap::new();
21        m.insert("integer", "i64");
22        m.insert("float", "f64");
23        m.insert("string", "String");
24        m.insert("boolean", "bool");
25        m.insert("bytes", "Vec<u8>");
26        m.insert("date", "String");
27        m.insert("datetime", "String");
28        m
29    };
30}
31
32/// Procedural macro to generate structs from markdown models
33///
34/// # Arguments
35/// * `input` - A TokenStream representing the input markdown file path
36///
37/// # Returns
38/// A TokenStream containing the generated Rust code for the structs and enums
39#[proc_macro]
40pub fn parse_mdmodel(input: TokenStream) -> TokenStream {
41    // Get the current working directory
42    let dir = std::env::var("CARGO_MANIFEST_DIR").map_or_else(
43        |_| std::env::current_dir().unwrap(),
44        |s| Path::new(&s).to_path_buf(),
45    );
46
47    // Parse the input TokenStream as a literal string
48    let input = parse_macro_input!(input as LitStr).value();
49    let path = dir.join(input);
50
51    // Parse the DataModel from the specified path
52    let model = DataModel::from_markdown(&path)
53        .unwrap_or_else(|_| panic!("Failed to parse the markdown model at path: {:?}", path));
54    let mut structs = vec![];
55
56    // Iterate through the objects in the model
57    for object in model.objects {
58        if is_reserved(&object.name) {
59            panic!("Reserved keyword used as object name: {}", object.name);
60        }
61
62        let struct_name = syn::Ident::new(&object.name, proc_macro2::Span::call_site());
63        let mut fields = vec![quote! {
64            #[serde(skip_serializing_if = "Option::is_none")]
65            #[builder(default)]
66            pub additional_properties: Option<std::collections::HashMap<String, serde_json::Value>>
67        }];
68        let mut getters = vec![];
69        let mut setters = vec![];
70
71        // Iterate through the attributes of each object
72        for attribute in object.attributes {
73            let field_name = syn::Ident::new(&attribute.name, proc_macro2::Span::call_site());
74            let field_type = get_data_type(&attribute.dtypes[0])
75                .unwrap_or_else(|_| panic!("Unknown data type: {}", attribute.dtypes[0]));
76
77            let wrapped_type = wrap_dtype(attribute.is_array, attribute.required, field_type);
78            let builder_attr =
79                get_builder_attr(attribute.is_array, attribute.required, &attribute.name);
80            let serde_attr = get_serde_attr(attribute.is_array, attribute.required);
81
82            fields.push(quote! {
83                #builder_attr
84                #serde_attr
85                pub #field_name: #wrapped_type
86            });
87
88            let getter_name = syn::Ident::new(
89                format!("get_{}", attribute.name).as_str(),
90                proc_macro2::Span::call_site(),
91            );
92
93            let setter_name = syn::Ident::new(
94                format!("set_{}", attribute.name).as_str(),
95                proc_macro2::Span::call_site(),
96            );
97
98            getters.push(quote! {
99                pub fn #getter_name(&self) -> &#wrapped_type {
100                    &self.#field_name
101                }
102            });
103
104            setters.push(quote! {
105                pub fn #setter_name(&mut self, value: #wrapped_type) -> &mut Self {
106                    self.#field_name = value;
107                    self
108                }
109            });
110        }
111
112        // Generate the struct definition with pyclass and constructor
113        let struct_def = quote! {
114            #[derive(Builder, Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
115            pub struct #struct_name {
116                #(#fields),*
117            }
118
119            impl #struct_name {
120                pub fn new() -> Self {
121                    Self::default()
122                }
123
124                #(#getters)*
125                #(#setters)*
126
127            }
128        };
129
130        structs.push(struct_def);
131    }
132
133    // Iterate through enumerations
134    let mut enums = vec![];
135    for enum_ in model.enums {
136        if is_reserved(&enum_.name) {
137            panic!("Reserved keyword used as enum name: {}", enum_.name);
138        }
139        enums.push(generate_enum(&enum_.mappings, &enum_.name))
140    }
141
142    // Combine all generated structs into a single TokenStream
143    let expanded = quote! {
144        use derive_builder::Builder;
145        use serde;
146        use schemars;
147
148        #(#structs)*
149        #(#enums)*
150    };
151
152    TokenStream::from(expanded)
153}
154
155/// Enumeration for data types
156enum DataTypes {
157    BaseType(syn::Type),
158    ComplexType(syn::Ident),
159}
160
161/// Function to get the data type from the type mappings
162///
163/// # Arguments
164/// * `dtype` - A string slice representing the data type
165///
166/// # Returns
167/// A Result containing either a DataTypes enum or an error
168fn get_data_type(dtype: &str) -> Result<DataTypes, Box<dyn Error>> {
169    match TYPE_MAPPINGS.get(dtype) {
170        Some(t) => {
171            let field_type: syn::Type = syn::parse_str(t)?;
172            Ok(DataTypes::BaseType(field_type))
173        }
174        None => {
175            let field_type: syn::Ident = syn::Ident::new(dtype, proc_macro2::Span::call_site());
176            Ok(DataTypes::ComplexType(field_type))
177        }
178    }
179}
180
181/// Function to wrap data types based on their properties (array, required)
182///
183/// # Arguments
184/// * `is_array` - A boolean indicating if the type is an array
185/// * `required` - A boolean indicating if the type is required
186/// * `dtype` - A DataTypes enum representing the data type
187///
188/// # Returns
189/// A TokenStream representing the wrapped data type
190fn wrap_dtype(is_array: bool, required: bool, dtype: DataTypes) -> proc_macro2::TokenStream {
191    match dtype {
192        DataTypes::BaseType(base_type) => {
193            if required && !is_array {
194                quote! { #base_type }
195            } else if !required && !is_array {
196                quote! { Option<#base_type> }
197            } else if required && is_array {
198                quote! { Vec<#base_type> }
199            } else {
200                quote! { Option<Vec<#base_type>> }
201            }
202        }
203        DataTypes::ComplexType(complex_type) => {
204            if required && !is_array {
205                quote! { #complex_type }
206            } else if !required && !is_array {
207                quote! { Option<#complex_type> }
208            } else {
209                quote! { Vec<#complex_type> }
210            }
211        }
212    }
213}
214
215/// Function to generate builder attributes for struct fields
216///
217/// # Arguments
218/// * `is_array` - A boolean indicating if the field is an array
219/// * `required` - A boolean indicating if the field is required
220/// * `name` - A string slice representing the field name
221///
222/// # Returns
223/// A TokenStream representing the builder attributes
224fn get_builder_attr(is_array: bool, required: bool, name: &str) -> proc_macro2::TokenStream {
225    let mut setter_args = vec![];
226
227    if !required {
228        setter_args.push(quote! { strip_option });
229    }
230
231    if is_array {
232        let add_name = syn::Ident::new(&format!("to_{}", name), proc_macro2::Span::call_site());
233        setter_args.push(quote! { each(name = #add_name, into) });
234    }
235
236    let setter_args = quote! { #(#setter_args),* };
237
238    quote! {
239        #[builder(default, setter(into, #setter_args))]
240    }
241}
242
243/// Function to generate serde attributes for struct fields
244///
245/// # Arguments
246/// * `is_array` - A boolean indicating if the field is an array
247/// * `required` - A boolean indicating if the field is required
248///
249/// # Returns
250/// A TokenStream representing the serde attributes
251fn get_serde_attr(is_array: bool, required: bool) -> proc_macro2::TokenStream {
252    if !required && !is_array {
253        quote! { #[serde(skip_serializing_if = "Option::is_none")] }
254    } else if is_array {
255        quote! { #[serde(default)] }
256    } else {
257        quote! {}
258    }
259}
260
261/// Function to generate Rust code for enums
262///
263/// # Arguments
264/// * `mappings` - A reference to a BTreeMap of enum variant mappings
265/// * `name` - A string slice representing the enum name
266///
267/// # Returns
268/// A TokenStream containing the generated enum code
269fn generate_enum(mappings: &BTreeMap<String, String>, name: &str) -> proc_macro2::TokenStream {
270    let enum_name = syn::Ident::new(name, proc_macro2::Span::call_site());
271    let mut variants = vec![];
272    let mut index = 0;
273
274    for (key, value) in mappings {
275        let variant_name = syn::Ident::new(&to_camel(key), proc_macro2::Span::call_site());
276        let variant_value = syn::LitStr::new(value, proc_macro2::Span::call_site());
277
278        if index == 0 {
279            variants.push(quote! {
280                #[default]
281                #[serde(rename = #variant_value)]
282                #variant_name
283            });
284            index += 1;
285        } else {
286            variants.push(quote! {
287                #[serde(rename = #variant_value)]
288                #variant_name
289            });
290        }
291    }
292
293    quote! {
294        #[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
295        pub enum #enum_name {
296            #(#variants),*
297        }
298    }
299}
300
301/// Checks if an object or enum name is a reserved keyword
302fn is_reserved(name: &str) -> bool {
303    FORBIDDEN_NAMES.contains(&name)
304}
305
306/// Function to convert a string to upper camel case
307fn to_camel(name: &str) -> String {
308    name.to_case(Case::UpperCamel)
309}