// pywr_v1_schema_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3
4/// A derive macro for Pywr nodes that implements `parameters` and `parameters_mut` methods.
5#[proc_macro_derive(PywrNode)]
6pub fn pywr_node_macro(input: TokenStream) -> TokenStream {
7    // Parse the input tokens into a syntax tree
8    let input = syn::parse_macro_input!(input as syn::DeriveInput);
9    impl_parameter_references_derive(&input)
10}
11
12/// A derive macro for Pywr parameters that implements `parameters`, `parameters_mut`,
13/// `resource_paths` and `update_resource_paths` methods.
14#[proc_macro_derive(PywrParameter)]
15pub fn pywr_parameter_macro(input: TokenStream) -> TokenStream {
16    // Parse the input tokens into a syntax tree
17    let input = syn::parse_macro_input!(input as syn::DeriveInput);
18
19    let mut expanded = impl_parameter_references_derive(&input);
20    expanded.extend(impl_parameter_resource_paths_derive(&input));
21
22    expanded
23}
24
25/// Generates a [`TokenStream`] containing the implementation of two methods, `parameters`
26/// and `parameters_mut`, for the given struct.
27///
28/// The `parameters` method returns a [`HashMap`] of parameter names to [`ParameterValueType`],
29/// and the `parameters_mut` method returns a [`HashMap`] of parameter names to [`ParameterValueTypeMut`].
30/// This is intended to be used for nodes and parameter structs in the Pywr schema.
31///
32/// Currently the implementation is limited to simple type definitions such as `Option<ParameterValue>` or `ParameterValue`.
33fn impl_parameter_references_derive(ast: &syn::DeriveInput) -> TokenStream {
34    // Name of the node type
35    let name = &ast.ident;
36
37    if let syn::Data::Struct(data) = &ast.data {
38        // Only apply this to structs
39
40        // Help struct for capturing parameter fields and whether they are optional.
41        struct ParamField {
42            field_name: syn::Ident,
43            optional: bool,
44        }
45
46        // Iterate through all fields of the struct. Try to find fields that reference
47        // parameters (e.g. `Option<ParameterValue>` or `ParameterValue`).
48        let parameter_fields: Vec<ParamField> = data
49            .fields
50            .iter()
51            .filter_map(|field| {
52                let field_ident = field.ident.as_ref()?;
53                // Identify optional fields
54                match type_to_ident(&field.ty) {
55                    Some(PywrField::Optional(ident)) => {
56                        // If optional and a parameter identifier then add to the list
57                        is_parameter_ident(&ident).then_some(ParamField {
58                            field_name: field_ident.clone(),
59                            optional: true,
60                        })
61                    }
62                    Some(PywrField::Required(ident)) => {
63                        // Otherwise, if a parameter identifier then add to the list
64                        is_parameter_ident(&ident).then_some(ParamField {
65                            field_name: field_ident.clone(),
66                            optional: false,
67                        })
68                    }
69                    None => None, // All other fields are ignored.
70                }
71            })
72            .collect();
73
74        // Insert statements for non-mutable version
75        let inserts = parameter_fields
76            .iter()
77            .map(|param_field| {
78                let ident = &param_field.field_name;
79                let key = ident.to_string();
80                if param_field.optional {
81                    quote! {
82                        if let Some(p) = &self.#ident {
83                            attributes.insert(#key, p.into());
84                        }
85                    }
86                } else {
87                    quote! {
88                        let #ident = &self.#ident;
89                        attributes.insert(#key, #ident.into());
90                    }
91                }
92            })
93            .collect::<Vec<_>>();
94
95        // Insert statements for mutable version
96        let inserts_mut = parameter_fields
97            .iter()
98            .map(|param_field| {
99                let ident = &param_field.field_name;
100                let key = ident.to_string();
101                if param_field.optional {
102                    quote! {
103                        if let Some(p) = &mut self.#ident {
104                            attributes.insert(#key, p.into());
105                        }
106                    }
107                } else {
108                    quote! {
109                        let #ident = &mut self.#ident;
110                        attributes.insert(#key, #ident.into());
111                    }
112                }
113            })
114            .collect::<Vec<_>>();
115
116        // Create the two parameter methods using the insert statements
117        let expanded = quote! {
118            impl #name {
119                pub fn parameters(&self) -> HashMap<&str, ParameterValueType> {
120                    let mut attributes = HashMap::new();
121                    #(
122                        #inserts
123                    )*
124                    attributes
125                }
126
127                pub fn parameters_mut(&mut self) -> HashMap<&str, ParameterValueTypeMut> {
128                    let mut attributes = HashMap::new();
129                    #(
130                        #inserts_mut
131                    )*
132                    attributes
133                }
134            }
135        };
136
137        // Hand the output tokens back to the compiler.
138        TokenStream::from(expanded)
139    } else {
140        panic!("Only structs are supported for #[derive(PywrNode)] or #[derive(PywrParameter)]")
141    }
142}
143
144/// Generates a [`TokenStream`] containing the implementation `resource_paths`
145/// and `update_resource_paths` methods.
146fn impl_parameter_resource_paths_derive(ast: &syn::DeriveInput) -> TokenStream {
147    // Name of the node type
148    let name = &ast.ident;
149
150    if let syn::Data::Struct(data) = &ast.data {
151        // Helper struct to capture PathBuf fields
152        struct PathField {
153            field_name: syn::Ident,
154            ty: PathFieldType,
155            optional: bool,
156        }
157
158        let path_fields: Vec<PathField> = data
159            .fields
160            .iter()
161            .filter_map(|field| {
162                let field_ident = field.ident.as_ref()?;
163
164                // Identify optional fields
165                match type_to_ident(&field.ty) {
166                    Some(PywrField::Optional(ident)) => {
167                        // If optional and a path identifier then add to the list
168                        ident_to_path_type(&ident).map(|field_type| PathField {
169                            field_name: field_ident.clone(),
170                            ty: field_type,
171                            optional: true,
172                        })
173                    }
174                    Some(PywrField::Required(ident)) => {
175                        // If required, and a path identifier then add to the list
176                        ident_to_path_type(&ident).map(|field_type| PathField {
177                            field_name: field_ident.clone(),
178                            ty: field_type,
179                            optional: false,
180                        })
181                    }
182                    None => None, // All other field types are ignored
183                }
184            })
185            .collect();
186
187        // Insert statements for non-mutable version
188        let inserts = path_fields
189            .iter()
190            .map(|param_field| {
191                let ident = &param_field.field_name;
192
193                match &param_field.ty {
194                    PathFieldType::ExternalDataRef => {
195                        if param_field.optional {
196                            quote! {
197                                if let Some(external) = &self.#ident {
198                                    resource_paths.push(external.url.clone());
199                                }
200                            }
201                        } else {
202                            quote! {
203                                resource_paths.push(self.#ident.url.clone());
204                            }
205                        }
206                    }
207                    PathFieldType::PathBuf => {
208                        if param_field.optional {
209                            quote! {
210                                if let Some(p) = &self.#ident {
211                                    resource_paths.push(p.clone());
212                                }
213                            }
214                        } else {
215                            quote! {
216                                resource_paths.push(self.#ident.clone());
217                            }
218                        }
219                    }
220                }
221            })
222            .collect::<Vec<_>>();
223
224        // Update statements for the `update_resource_paths` method
225        let updates = path_fields
226            .iter()
227            .map(|param_field| {
228                let ident = &param_field.field_name;
229
230                match &param_field.ty {
231                    PathFieldType::ExternalDataRef => {
232                        if param_field.optional {
233                            quote! {
234                                if let Some(external) = &mut self.#ident {
235                                    if let Some(new_path) = new_paths.get(&external.url) {
236                                        external.url = new_path.clone();
237                                    }
238                                }
239                            }
240                        } else {
241                            quote! {
242                                if let Some(new_path) = new_paths.get(&self.#ident.url) {
243                                    self.#ident.url = new_path.clone();
244                                }
245                            }
246                        }
247                    }
248                    PathFieldType::PathBuf => {
249                        if param_field.optional {
250                            quote! {
251                                if let Some(path) = &mut self.#ident {
252                                    if let Some(new_path) = new_paths.get(path) {
253                                        *path = new_path.clone();
254                                    }
255                                }
256                            }
257                        } else {
258                            quote! {
259                                if let Some(new_path) = new_paths.get(&self.#ident) {
260                                    self.#ident = new_path.clone();
261                                }
262                            }
263                        }
264                    }
265                }
266            })
267            .collect::<Vec<_>>();
268
269        // Create the two parameter methods using the insert statements
270        let expanded = quote! {
271            impl #name {
272                pub fn resource_paths(&self) -> Vec<PathBuf> {
273                    let mut resource_paths = Vec::new();
274                    #(
275                        #inserts
276                    )*
277                    resource_paths
278                }
279
280                pub fn update_resource_paths(&mut self, new_paths: &HashMap<PathBuf, PathBuf>) {
281                    #(
282                        #updates
283                    )*
284                }
285            }
286        };
287
288        // Hand the output tokens back to the compiler.
289        TokenStream::from(expanded)
290    } else {
291        panic!("Only structs are supported for #[derive(PywrNode)] or #[derive(PywrParameter)]")
292    }
293}
294
295enum PywrField {
296    Optional(syn::Ident),
297    Required(syn::Ident),
298}
299
300/// Returns the last segment of a type path as an identifier
301fn type_to_ident(ty: &syn::Type) -> Option<PywrField> {
302    match ty {
303        // Match type's that are a path and not a self type.
304        syn::Type::Path(type_path) if type_path.qself.is_none() => {
305            // Match on the last segment
306            match type_path.path.segments.last() {
307                Some(last_segment) => {
308                    let ident = &last_segment.ident;
309
310                    if ident == "Option" {
311                        // The last segment is an Option, now we need to parse the argument
312                        // I.e. the bit in inside the angle brackets.
313                        let first_arg = match &last_segment.arguments {
314                            syn::PathArguments::AngleBracketed(params) => params.args.first(),
315                            _ => None,
316                        };
317
318                        // Find type arguments; ignore others
319                        let arg_ty = match first_arg {
320                            Some(syn::GenericArgument::Type(ty)) => Some(ty),
321                            _ => None,
322                        };
323
324                        // Match on path types that are no self types.
325                        let arg_type_path = match arg_ty {
326                            Some(ty) => match ty {
327                                syn::Type::Path(type_path) if type_path.qself.is_none() => {
328                                    Some(type_path)
329                                }
330                                _ => None,
331                            },
332                            None => None,
333                        };
334
335                        // Get the last segment of the path
336                        let last_segment = match arg_type_path {
337                            Some(type_path) => type_path.path.segments.last(),
338                            None => None,
339                        };
340
341                        // Finally, if there's a last segment return this as an optional `PywrField`
342                        match last_segment {
343                            Some(last_segment) => {
344                                let ident = &last_segment.ident;
345                                Some(PywrField::Optional(ident.clone()))
346                            }
347                            None => None,
348                        }
349                    } else {
350                        // Otherwise, assume this a simple required field
351                        Some(PywrField::Required(ident.clone()))
352                    }
353                }
354                None => None,
355            }
356        }
357        _ => None,
358    }
359}
360
361fn is_parameter_ident(ident: &syn::Ident) -> bool {
362    (ident == "ParameterValue") || (ident == "ParameterValues")
363}
364
/// The kinds of field that carry a resource path.
enum PathFieldType {
    /// A `PathBuf` field; the field value itself is the path.
    PathBuf,
    /// An `ExternalDataRef` field; the path is stored in its `url` member.
    ExternalDataRef,
}
369
370fn ident_to_path_type(ident: &syn::Ident) -> Option<PathFieldType> {
371    if ident == "ExternalDataRef" {
372        Some(PathFieldType::ExternalDataRef)
373    } else if ident == "PathBuf" {
374        Some(PathFieldType::PathBuf)
375    } else {
376        None
377    }
378}