boomerang_derive/reaction/
mod.rs

1use std::{collections::HashMap, hash::Hash};
2
3use darling::{
4    ast::{self},
5    util, FromDeriveInput, FromField, FromMeta,
6};
7use quote::{quote, ToTokens};
8use syn::{Expr, GenericParam, Generics, Ident, Type};
9
10mod reaction_field_inner;
11mod trigger_inner;
12
13use reaction_field_inner::ReactionFieldInner;
14use trigger_inner::TriggerInner;
15
// Type-name identifiers for the runtime handle types a reaction field can have.
// NOTE(review): none of these are referenced in this file; presumably the
// `reaction_field_inner` / `trigger_inner` submodules compare field type names against
// them — confirm there before relying on this.
const INPUT_REF: &str = "InputRef";
const OUTPUT_REF: &str = "OutputRef";
const ACTION: &str = "Action";
const ACTION_REF: &str = "ActionRef";
const PHYSICAL_ACTION_REF: &str = "PhysicalActionRef";
21
/// One `triggers(...)` clause from the struct-level `#[reaction(...)]` attribute.
///
/// Clauses are collected in declaration order into `ReactionReceiver::triggers`.
#[derive(Debug, Eq, PartialEq, Hash)]
pub enum TriggerAttr {
    /// `triggers(startup)`: the generated `build` wires the builder's startup action to the
    /// reaction with `TriggerMode::TriggersOnly`.
    Startup,
    /// `triggers(shutdown)`: the generated `build` wires the builder's shutdown action to the
    /// reaction with `TriggerMode::TriggersOnly`.
    Shutdown,
    /// `triggers(action = "<expr>")`: the action at `<expr>` triggers the reaction.
    Action(Expr),
    /// `triggers(port = "<expr>")`: the port at `<expr>` (e.g. `child.y`) triggers the reaction.
    Port(Expr),
}
29
30impl FromMeta for TriggerAttr {
31    fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
32        (match *item {
33            syn::Meta::Path(ref path) => path.segments.first().map_or_else(
34                || Err(darling::Error::unsupported_shape("something wierd")),
35                |path| match path.ident.to_string().as_ref() {
36                    "startup" => Ok(TriggerAttr::Startup),
37                    "shutdown" => Ok(TriggerAttr::Shutdown),
38                    __other => Err(darling::Error::unknown_field_with_alts(
39                        __other,
40                        &["startup", "shutdown"],
41                    )
42                    .with_span(&path.ident)),
43                },
44            ),
45            syn::Meta::List(ref value) => {
46                let meta: syn::Meta = syn::parse2(value.tokens.clone())?;
47                Self::from_meta(&meta)
48            }
49            syn::Meta::NameValue(ref value) => value
50                .path
51                .segments
52                .first()
53                .map(|path| match path.ident.to_string().as_ref() {
54                    "action" => {
55                        let value = darling::FromMeta::from_expr(&value.value)?;
56                        Ok(TriggerAttr::Action(value))
57                    }
58                    "port" => {
59                        let value = darling::FromMeta::from_expr(&value.value)?;
60                        Ok(TriggerAttr::Port(value))
61                    }
62                    __other => Err(darling::Error::unknown_field_with_alts(
63                        __other,
64                        &["action", "timer", "port"],
65                    )
66                    .with_span(&path.ident)),
67                })
68                .expect("oopsie"),
69        })
70        .map_err(|e| e.with_span(item))
71    }
72
73    fn from_string(value: &str) -> darling::Result<Self> {
74        let value = darling::FromMeta::from_string(value)?;
75        Ok(TriggerAttr::Port(value))
76    }
77}
78
/// One field of a `#[derive(Reaction)]` struct, as parsed by `darling`.
///
/// The optional settings come from a field-level `#[reaction(...)]` attribute, e.g.
/// `#[reaction(triggers)]`, `#[reaction(uses)]` or `#[reaction(path = "child.y.z")]`.
/// `None` means the attribute did not set that key. The final values are resolved when the
/// field is converted into a `ReactionFieldInner` (that conversion is not in this file).
#[derive(Clone, Debug, FromField)]
#[darling(attributes(reaction), forward_attrs(doc, cfg, allow))]
pub struct ReactionField {
    /// Field name. An `Option` because darling also models unnamed fields; the receiver only
    /// accepts named-field or unit structs.
    ident: Option<Ident>,
    /// Declared field type, e.g. `runtime::InputRef<'a, ()>`.
    ty: Type,
    /// `#[reaction(triggers)]`: explicitly mark the field as a trigger of the reaction.
    triggers: Option<bool>,
    /// `#[reaction(effects)]`: explicitly mark the field as an effect of the reaction.
    effects: Option<bool>,
    /// `#[reaction(uses)]`: explicitly mark the field as used (read) by the reaction.
    uses: Option<bool>,
    /// `#[reaction(path = "...")]`: reactor path to bind instead of the field name
    /// (the tests show the field name being used as the path when this is absent).
    path: Option<Expr>,
}
89
90fn parse_bound(item: &syn::Meta) -> Result<syn::GenericParam, darling::Error> {
91    match item {
92        syn::Meta::NameValue(syn::MetaNameValue { value, .. }) => match value {
93            syn::Expr::Lit(syn::ExprLit {
94                lit: syn::Lit::Str(lit_str),
95                ..
96            }) => syn::parse_str(lit_str.value().as_str())
97                .map_err(|e| darling::Error::custom(format!("Failed to parse bound: {}", e))),
98
99            _ => Err(darling::Error::unsupported_shape(
100                "Only string literals are supported",
101            )),
102        },
103        _ => Err(darling::Error::unsupported_shape(
104            "Only name value pairs are supported",
105        )),
106    }
107}
108
/// The `#[derive(Reaction)]` input as parsed by `darling`.
///
/// Struct-level options come from `#[reaction(...)]`, for example
/// `reactor = "Inner::Count<T>"`, `bound = "T: runtime::PortData"`,
/// `triggers(port = "child.y")` or `triggers(startup)`.
/// Only named-field and unit structs are accepted.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(reaction), supports(struct_named, struct_unit))]
pub struct ReactionReceiver {
    /// Name of the reaction struct.
    ident: Ident,
    /// Generics declared on the reaction struct itself.
    generics: Generics,
    /// The struct's fields (enum variants are ignored; enums are rejected by `supports`).
    data: ast::Data<util::Ignored, ReactionField>,

    /// Type of the reactor (`reactor = "..."`; no `default`, so darling requires it).
    reactor: syn::Type,

    /// Extra generic parameters from repeated `bound = "..."` entries, parsed by `parse_bound`.
    #[darling(default, multiple, rename = "bound", with = "parse_bound")]
    bounds: Vec<syn::GenericParam>,

    /// Trigger clauses from repeated `triggers(...)` entries, in declaration order.
    #[darling(default, multiple)]
    triggers: Vec<TriggerAttr>,
}
126
/// Validated reaction model, built from a [`ReactionReceiver`] and rendered by `ToTokens`
/// into an `impl ::boomerang::builder::Reaction<Reactor>` block.
pub struct Reaction {
    /// Name of the reaction struct.
    ident: Ident,
    /// The struct's own generics; supplies the type generics and `where` clause of the impl.
    generics: Generics,
    /// `generics` plus the `bound = "..."` parameters; supplies the impl's generic list.
    combined_generics: Generics,
    /// Reactor type the reaction is implemented for.
    reactor: Type,
    /// Entries emitted as statements in the generated `build`, in struct-field order followed
    /// by entries introduced only through `triggers(action/port = ...)` clauses.
    fields: Vec<ReactionFieldInner>,
    /// Tokens emitted at the top of `build`; presumably they define the `__trigger_inner`
    /// function boxed into `add_reaction` — confirm in `trigger_inner`.
    inner: TriggerInner,
    /// Whether the reaction has a startup trigger
    trigger_startup: bool,
    /// Whether the reaction has a shutdown trigger
    trigger_shutdown: bool,
}
139
140impl TryFrom<ReactionReceiver> for Reaction {
141    type Error = darling::Error;
142
143    fn try_from(value: ReactionReceiver) -> Result<Self, Self::Error> {
144        // Combine the bounds with the generics
145        let mut combined_generics = value.generics.clone();
146        combined_generics
147            .params
148            .extend(value.bounds.iter().cloned().map(GenericParam::from));
149
150        let inner = TriggerInner::new(&value, &combined_generics)?;
151
152        let fields = value
153            .data
154            .take_struct()
155            .ok_or(darling::Error::unsupported_shape(
156                "Only structs are supported",
157            ))?;
158
159        let inner_fields: Vec<ReactionFieldInner> = fields
160            .into_iter()
161            .map(TryFrom::try_from)
162            .collect::<Result<_, _>>()?;
163
164        let mut fields_map: HashMap<_, (usize, ReactionFieldInner)> = inner_fields
165            .into_iter()
166            .enumerate()
167            .map(|(idx, mut field)| {
168                if let ReactionFieldInner::FieldDefined {
169                    ref mut uses,
170                    triggers,
171                    path,
172                    ..
173                } = &mut field
174                {
175                    // If the field is a trigger, then it implies use
176                    if *triggers {
177                        *uses = true;
178                    }
179                    (path.clone(), (idx, field))
180                } else {
181                    panic!("Unexpected reaction field");
182                }
183            })
184            .collect();
185
186        let mut last_idx = fields_map.len();
187
188        // Update/apply the struct_fields with any triggers clauses
189        for trigger in value.triggers.iter() {
190            match trigger {
191                TriggerAttr::Action(path) => {
192                    fields_map
193                        .entry(path.clone())
194                        .and_modify(|(_idx, field)| {
195                            if let ReactionFieldInner::FieldDefined {
196                                ref mut triggers, ..
197                            } = field
198                            {
199                                *triggers = true;
200                            } else {
201                                panic!("Trigger action path already used");
202                            }
203                        })
204                        .or_insert_with(|| {
205                            last_idx += 1;
206                            (
207                                last_idx,
208                                ReactionFieldInner::TriggerAction {
209                                    action: path.clone(),
210                                },
211                            )
212                        });
213                }
214
215                TriggerAttr::Port(path) => {
216                    fields_map
217                        .entry(path.clone())
218                        .and_modify(|(_idx, field)| {
219                            if let ReactionFieldInner::FieldDefined {
220                                ref mut triggers, ..
221                            } = field
222                            {
223                                *triggers = true;
224                            } else {
225                                panic!("Trigger port path already used");
226                            }
227                        })
228                        .or_insert_with(|| {
229                            last_idx += 1;
230                            (
231                                last_idx,
232                                ReactionFieldInner::TriggerPort { port: path.clone() },
233                            )
234                        });
235                }
236
237                _ => {}
238            }
239        }
240
241        let trigger_startup = value
242            .triggers
243            .iter()
244            .any(|t| matches!(t, TriggerAttr::Startup));
245        let trigger_shutdown = value
246            .triggers
247            .iter()
248            .any(|t| matches!(t, TriggerAttr::Shutdown));
249
250        let mut idx_fields: Vec<_> = fields_map.into_values().collect();
251        idx_fields.sort_by_key(|(idx, _)| *idx);
252        let fields = idx_fields.into_iter().map(|(_, field)| field).collect();
253
254        Ok(Self {
255            ident: value.ident,
256            generics: value.generics,
257            combined_generics,
258            reactor: value.reactor,
259            fields,
260            inner,
261            trigger_startup,
262            trigger_shutdown,
263        })
264    }
265}
266
267impl ToTokens for Reaction {
268    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
269        let ident = &self.ident;
270        let reactor = &self.reactor;
271        let struct_fields = &self.fields;
272        let trigger_inner = &self.inner;
273
274        // We use impl_generics from `combined_generics` to allow additional bounds to be added, but type and where come
275        // from the original generics
276        let (impl_generics, _, _) = self.combined_generics.split_for_impl();
277        let (_, type_generics, where_clause) = self.generics.split_for_impl();
278        let inner_type_generics = {
279            let g = self
280                .combined_generics
281                .const_params()
282                .map(|ty| &ty.ident)
283                .chain(self.combined_generics.type_params().map(|ty| &ty.ident));
284            quote! { ::<#(#g),*> }
285        };
286
287        let trigger_startup = if self.trigger_startup {
288            quote! {
289                let mut __reaction = __reaction.with_action(
290                    __startup_action,
291                    0,
292                    ::boomerang::builder::TriggerMode::TriggersOnly
293                )?;
294            }
295        } else {
296            quote! {}
297        };
298
299        let trigger_shutdown = if self.trigger_shutdown {
300            quote! {
301                let mut __reaction = __reaction.with_action(
302                    __shutdown_action,
303                    0,
304                    ::boomerang::builder::TriggerMode::TriggersOnly
305                )?;
306            }
307        } else {
308            quote! {}
309        };
310
311        tokens.extend(quote! {
312            #[automatically_derived]
313            impl #impl_generics ::boomerang::builder::Reaction<#reactor> for #ident #type_generics #where_clause {
314                fn build<'builder>(
315                    name: &str,
316                    reactor: &#reactor,
317                    builder: &'builder mut ::boomerang::builder::ReactorBuilderState,
318                ) -> Result<
319                    ::boomerang::builder::ReactionBuilderState<'builder>,
320                    ::boomerang::builder::BuilderError
321                >
322                {
323                    #trigger_inner
324                    let __startup_action = builder.get_startup_action();
325                let __shutdown_action = builder.get_shutdown_action();
326                    let mut __reaction = builder.add_reaction(
327                        name,
328                        Box::new(__trigger_inner #inner_type_generics)
329                    );
330
331                    #trigger_startup
332                    #trigger_shutdown
333                    #(#struct_fields;)*
334                    Ok(__reaction)
335                }
336            }
337        });
338    }
339}
340
#[cfg(test)]
mod tests {
    //! Unit tests for parsing `#[derive(Reaction)]` input into `ReactionReceiver` and
    //! converting it into `Reaction`.

    use syn::{parse_quote, DeriveInput};

    use super::*;

    /// Struct-level attributes: `reactor`, repeated `bound`s and every `triggers(...)` form are
    /// parsed, with bounds and triggers kept in declaration order.
    #[test]
    fn test_struct_attrs() {
        let input = r#"
#[derive(Reaction)]
#[reaction(
    reactor = "Inner::Count<T>",
    bound = "T: runtime::PortData",
    bound = "const N: usize",
    triggers(action = "x"),
    triggers(port = "child.y"),
    triggers(startup),
    triggers(shutdown),
)]
struct ReactionT;"#;
        let parsed: DeriveInput = syn::parse_str(input).unwrap();
        let receiver = ReactionReceiver::from_derive_input(&parsed).unwrap();
        assert_eq!(receiver.reactor, parse_quote! {Inner::Count<T>});
        assert_eq!(
            receiver.bounds,
            vec![
                parse_quote! {T: runtime::PortData},
                parse_quote! {const N: usize}
            ]
        );
        assert_eq!(
            receiver.triggers.iter().collect::<Vec<_>>(),
            vec![
                &TriggerAttr::Action(parse_quote! {x}),
                &TriggerAttr::Port(parse_quote! {child.y}),
                &TriggerAttr::Startup,
                &TriggerAttr::Shutdown
            ]
        );
    }

    /// Port fields: an unannotated `InputRef` triggers and uses, and an unannotated
    /// `OutputRef` is an effect. `#[reaction(uses)]` makes an input use-only, and
    /// `#[reaction(path = "...")]` replaces the field name as the bound path.
    #[test]
    fn test_port_fields() {
        let input = r#"
#[derive(Reaction)]
#[reaction(reactor = "Foo")]
struct ReactionT<'a> {
    ref_port: runtime::InputRef<'a, ()>,
    mut_port: runtime::OutputRef<'a, ()>,
    #[reaction(uses)]
    uses_only_port: runtime::InputRef<'a, ()>,
    #[reaction(path = "child.y.z")]
    renamed_port: runtime::OutputRef<'a, u32>,
}"#;

        let parsed = syn::parse_str(input).unwrap();
        let receiver = ReactionReceiver::from_derive_input(&parsed).unwrap();
        let reaction = Reaction::try_from(receiver).unwrap();
        assert_eq!(
            reaction.fields[0],
            ReactionFieldInner::FieldDefined {
                elem: parse_quote! {runtime::InputRef<'a, ()>},
                triggers: true,
                effects: false,
                uses: true,
                path: parse_quote! {ref_port},
            },
        );
        assert_eq!(
            reaction.fields[1],
            ReactionFieldInner::FieldDefined {
                elem: parse_quote! {runtime::OutputRef<'a, ()>},
                triggers: false,
                effects: true,
                uses: false,
                path: parse_quote! {mut_port},
            },
        );
        assert_eq!(
            reaction.fields[2],
            ReactionFieldInner::FieldDefined {
                elem: parse_quote! {runtime::InputRef<'a, ()>},
                triggers: false,
                effects: false,
                uses: true,
                path: parse_quote! {uses_only_port},
            },
        );
        assert_eq!(
            reaction.fields[3],
            ReactionFieldInner::FieldDefined {
                elem: parse_quote! {runtime::OutputRef<'a, u32>},
                triggers: false,
                effects: true,
                uses: false,
                path: parse_quote! {child.y.z},
            }
        );
    }

    /// An action field marked `#[reaction(triggers)]` is a trigger and also a use. Its
    /// recorded element type has the reference stripped (`&'a runtime::Action` becomes
    /// `runtime::Action`).
    #[test]
    fn test_action_fields() {
        let input = r#"
#[derive(Reaction)]
#[reaction(reactor = "Foo")]
struct ReactionT<'a> {
    #[reaction(triggers)]
    raw_action: &'a runtime::Action,
}"#;
        let parsed = syn::parse_str(input).unwrap();
        let receiver = ReactionReceiver::from_derive_input(&parsed).unwrap();
        let reaction = Reaction::try_from(receiver).unwrap();
        assert_eq!(
            reaction.fields[0],
            ReactionFieldInner::FieldDefined {
                elem: parse_quote! {runtime::Action},
                triggers: true,
                effects: false,
                uses: true,
                path: parse_quote! {raw_action},
            }
        );
    }
}
462        );
463    }
464}