pywr_v1_schema_macros/lib.rs
1use proc_macro::TokenStream;
2use quote::quote;
3
4/// A derive macro for Pywr nodes that implements `parameters` and `parameters_mut` methods.
5#[proc_macro_derive(PywrNode)]
6pub fn pywr_node_macro(input: TokenStream) -> TokenStream {
7 // Parse the input tokens into a syntax tree
8 let input = syn::parse_macro_input!(input as syn::DeriveInput);
9 impl_parameter_references_derive(&input)
10}
11
12/// A derive macro for Pywr parameters that implements `parameters`, `parameters_mut`,
13/// `resource_paths` and `update_resource_paths` methods.
14#[proc_macro_derive(PywrParameter)]
15pub fn pywr_parameter_macro(input: TokenStream) -> TokenStream {
16 // Parse the input tokens into a syntax tree
17 let input = syn::parse_macro_input!(input as syn::DeriveInput);
18
19 let mut expanded = impl_parameter_references_derive(&input);
20 expanded.extend(impl_parameter_resource_paths_derive(&input));
21
22 expanded
23}
24
25/// Generates a [`TokenStream`] containing the implementation of two methods, `parameters`
26/// and `parameters_mut`, for the given struct.
27///
28/// The `parameters` method returns a [`HashMap`] of parameter names to [`ParameterValueType`],
29/// and the `parameters_mut` method returns a [`HashMap`] of parameter names to [`ParameterValueTypeMut`].
30/// This is intended to be used for nodes and parameter structs in the Pywr schema.
31///
32/// Currently the implementation is limited to simple type definitions such as `Option<ParameterValue>` or `ParameterValue`.
33fn impl_parameter_references_derive(ast: &syn::DeriveInput) -> TokenStream {
34 // Name of the node type
35 let name = &ast.ident;
36
37 if let syn::Data::Struct(data) = &ast.data {
38 // Only apply this to structs
39
40 // Help struct for capturing parameter fields and whether they are optional.
41 struct ParamField {
42 field_name: syn::Ident,
43 optional: bool,
44 }
45
46 // Iterate through all fields of the struct. Try to find fields that reference
47 // parameters (e.g. `Option<ParameterValue>` or `ParameterValue`).
48 let parameter_fields: Vec<ParamField> = data
49 .fields
50 .iter()
51 .filter_map(|field| {
52 let field_ident = field.ident.as_ref()?;
53 // Identify optional fields
54 match type_to_ident(&field.ty) {
55 Some(PywrField::Optional(ident)) => {
56 // If optional and a parameter identifier then add to the list
57 is_parameter_ident(&ident).then_some(ParamField {
58 field_name: field_ident.clone(),
59 optional: true,
60 })
61 }
62 Some(PywrField::Required(ident)) => {
63 // Otherwise, if a parameter identifier then add to the list
64 is_parameter_ident(&ident).then_some(ParamField {
65 field_name: field_ident.clone(),
66 optional: false,
67 })
68 }
69 None => None, // All other fields are ignored.
70 }
71 })
72 .collect();
73
74 // Insert statements for non-mutable version
75 let inserts = parameter_fields
76 .iter()
77 .map(|param_field| {
78 let ident = ¶m_field.field_name;
79 let key = ident.to_string();
80 if param_field.optional {
81 quote! {
82 if let Some(p) = &self.#ident {
83 attributes.insert(#key, p.into());
84 }
85 }
86 } else {
87 quote! {
88 let #ident = &self.#ident;
89 attributes.insert(#key, #ident.into());
90 }
91 }
92 })
93 .collect::<Vec<_>>();
94
95 // Insert statements for mutable version
96 let inserts_mut = parameter_fields
97 .iter()
98 .map(|param_field| {
99 let ident = ¶m_field.field_name;
100 let key = ident.to_string();
101 if param_field.optional {
102 quote! {
103 if let Some(p) = &mut self.#ident {
104 attributes.insert(#key, p.into());
105 }
106 }
107 } else {
108 quote! {
109 let #ident = &mut self.#ident;
110 attributes.insert(#key, #ident.into());
111 }
112 }
113 })
114 .collect::<Vec<_>>();
115
116 // Create the two parameter methods using the insert statements
117 let expanded = quote! {
118 impl #name {
119 pub fn parameters(&self) -> HashMap<&str, ParameterValueType> {
120 let mut attributes = HashMap::new();
121 #(
122 #inserts
123 )*
124 attributes
125 }
126
127 pub fn parameters_mut(&mut self) -> HashMap<&str, ParameterValueTypeMut> {
128 let mut attributes = HashMap::new();
129 #(
130 #inserts_mut
131 )*
132 attributes
133 }
134 }
135 };
136
137 // Hand the output tokens back to the compiler.
138 TokenStream::from(expanded)
139 } else {
140 panic!("Only structs are supported for #[derive(PywrNode)] or #[derive(PywrParameter)]")
141 }
142}
143
144/// Generates a [`TokenStream`] containing the implementation `resource_paths`
145/// and `update_resource_paths` methods.
146fn impl_parameter_resource_paths_derive(ast: &syn::DeriveInput) -> TokenStream {
147 // Name of the node type
148 let name = &ast.ident;
149
150 if let syn::Data::Struct(data) = &ast.data {
151 // Helper struct to capture PathBuf fields
152 struct PathField {
153 field_name: syn::Ident,
154 ty: PathFieldType,
155 optional: bool,
156 }
157
158 let path_fields: Vec<PathField> = data
159 .fields
160 .iter()
161 .filter_map(|field| {
162 let field_ident = field.ident.as_ref()?;
163
164 // Identify optional fields
165 match type_to_ident(&field.ty) {
166 Some(PywrField::Optional(ident)) => {
167 // If optional and a path identifier then add to the list
168 ident_to_path_type(&ident).map(|field_type| PathField {
169 field_name: field_ident.clone(),
170 ty: field_type,
171 optional: true,
172 })
173 }
174 Some(PywrField::Required(ident)) => {
175 // If required, and a path identifier then add to the list
176 ident_to_path_type(&ident).map(|field_type| PathField {
177 field_name: field_ident.clone(),
178 ty: field_type,
179 optional: false,
180 })
181 }
182 None => None, // All other field types are ignored
183 }
184 })
185 .collect();
186
187 // Insert statements for non-mutable version
188 let inserts = path_fields
189 .iter()
190 .map(|param_field| {
191 let ident = ¶m_field.field_name;
192
193 match ¶m_field.ty {
194 PathFieldType::ExternalDataRef => {
195 if param_field.optional {
196 quote! {
197 if let Some(external) = &self.#ident {
198 resource_paths.push(external.url.clone());
199 }
200 }
201 } else {
202 quote! {
203 resource_paths.push(self.#ident.url.clone());
204 }
205 }
206 }
207 PathFieldType::PathBuf => {
208 if param_field.optional {
209 quote! {
210 if let Some(p) = &self.#ident {
211 resource_paths.push(p.clone());
212 }
213 }
214 } else {
215 quote! {
216 resource_paths.push(self.#ident.clone());
217 }
218 }
219 }
220 }
221 })
222 .collect::<Vec<_>>();
223
224 // Update statements for the `update_resource_paths` method
225 let updates = path_fields
226 .iter()
227 .map(|param_field| {
228 let ident = ¶m_field.field_name;
229
230 match ¶m_field.ty {
231 PathFieldType::ExternalDataRef => {
232 if param_field.optional {
233 quote! {
234 if let Some(external) = &mut self.#ident {
235 if let Some(new_path) = new_paths.get(&external.url) {
236 external.url = new_path.clone();
237 }
238 }
239 }
240 } else {
241 quote! {
242 if let Some(new_path) = new_paths.get(&self.#ident.url) {
243 self.#ident.url = new_path.clone();
244 }
245 }
246 }
247 }
248 PathFieldType::PathBuf => {
249 if param_field.optional {
250 quote! {
251 if let Some(path) = &mut self.#ident {
252 if let Some(new_path) = new_paths.get(path) {
253 *path = new_path.clone();
254 }
255 }
256 }
257 } else {
258 quote! {
259 if let Some(new_path) = new_paths.get(&self.#ident) {
260 self.#ident = new_path.clone();
261 }
262 }
263 }
264 }
265 }
266 })
267 .collect::<Vec<_>>();
268
269 // Create the two parameter methods using the insert statements
270 let expanded = quote! {
271 impl #name {
272 pub fn resource_paths(&self) -> Vec<PathBuf> {
273 let mut resource_paths = Vec::new();
274 #(
275 #inserts
276 )*
277 resource_paths
278 }
279
280 pub fn update_resource_paths(&mut self, new_paths: &HashMap<PathBuf, PathBuf>) {
281 #(
282 #updates
283 )*
284 }
285 }
286 };
287
288 // Hand the output tokens back to the compiler.
289 TokenStream::from(expanded)
290 } else {
291 panic!("Only structs are supported for #[derive(PywrNode)] or #[derive(PywrParameter)]")
292 }
293}
294
295enum PywrField {
296 Optional(syn::Ident),
297 Required(syn::Ident),
298}
299
300/// Returns the last segment of a type path as an identifier
301fn type_to_ident(ty: &syn::Type) -> Option<PywrField> {
302 match ty {
303 // Match type's that are a path and not a self type.
304 syn::Type::Path(type_path) if type_path.qself.is_none() => {
305 // Match on the last segment
306 match type_path.path.segments.last() {
307 Some(last_segment) => {
308 let ident = &last_segment.ident;
309
310 if ident == "Option" {
311 // The last segment is an Option, now we need to parse the argument
312 // I.e. the bit in inside the angle brackets.
313 let first_arg = match &last_segment.arguments {
314 syn::PathArguments::AngleBracketed(params) => params.args.first(),
315 _ => None,
316 };
317
318 // Find type arguments; ignore others
319 let arg_ty = match first_arg {
320 Some(syn::GenericArgument::Type(ty)) => Some(ty),
321 _ => None,
322 };
323
324 // Match on path types that are no self types.
325 let arg_type_path = match arg_ty {
326 Some(ty) => match ty {
327 syn::Type::Path(type_path) if type_path.qself.is_none() => {
328 Some(type_path)
329 }
330 _ => None,
331 },
332 None => None,
333 };
334
335 // Get the last segment of the path
336 let last_segment = match arg_type_path {
337 Some(type_path) => type_path.path.segments.last(),
338 None => None,
339 };
340
341 // Finally, if there's a last segment return this as an optional `PywrField`
342 match last_segment {
343 Some(last_segment) => {
344 let ident = &last_segment.ident;
345 Some(PywrField::Optional(ident.clone()))
346 }
347 None => None,
348 }
349 } else {
350 // Otherwise, assume this a simple required field
351 Some(PywrField::Required(ident.clone()))
352 }
353 }
354 None => None,
355 }
356 }
357 _ => None,
358 }
359}
360
361fn is_parameter_ident(ident: &syn::Ident) -> bool {
362 (ident == "ParameterValue") || (ident == "ParameterValues")
363}
364
/// The kinds of path-like field that the resource-path derive knows how to handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PathFieldType {
    /// An `ExternalDataRef` field, whose `url` member holds the path.
    ExternalDataRef,
    /// A bare `PathBuf` field.
    PathBuf,
}
369
370fn ident_to_path_type(ident: &syn::Ident) -> Option<PathFieldType> {
371 if ident == "ExternalDataRef" {
372 Some(PathFieldType::ExternalDataRef)
373 } else if ident == "PathBuf" {
374 Some(PathFieldType::PathBuf)
375 } else {
376 None
377 }
378}