kompact_component_derive/
lib.rs

1#![recursion_limit = "128"]
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::quote;
7use syn::{parse_macro_input, DeriveInput};
8
9use std::{collections::HashMap, iter::Iterator};
10
11/// A macro to derive fair [ComponentDefinition](ComponentDefinition) implementations
12///
13/// Implementations will set up ports and the component context correctly, and
14/// during execution check ports in a fair round-robin manner.
15///
16/// Using this macro will also derive implementations of [ProvideRef](ProvideRef)
17/// or [RequireRef](RequireRef) for each declared port.
18#[proc_macro_derive(ComponentDefinition)]
19pub fn component_definition(input: TokenStream) -> TokenStream {
20    // Parse the input stream
21    let ast = parse_macro_input!(input as DeriveInput);
22
23    // Build the impl
24    let gen = impl_component_definition(&ast);
25
26    //println!("Derived code:\n{}", gen);
27
28    // Return the generated impl
29    gen.into()
30}
31
32type PortEntry<'a> = (&'a syn::Field, PortField);
33
34#[allow(clippy::map_entry)]
35fn impl_component_definition(ast: &syn::DeriveInput) -> TokenStream2 {
36    let name = &ast.ident;
37    let name_str = format!("{}", name);
38    if let syn::Data::Struct(ref vdata) = ast.data {
39        let generics = &ast.generics;
40        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42        let fields = &vdata.fields;
43        let mut ports: Vec<PortEntry> = Vec::new();
44        let mut ctx_field: Option<&syn::Field> = None;
45        for field in fields.iter() {
46            let cf = identify_field(field);
47            match cf {
48                ComponentField::Ctx => {
49                    ctx_field = Some(field);
50                }
51                ComponentField::Port(pf) => ports.push((field, pf)),
52                ComponentField::Other => (),
53            }
54        }
55        let (ctx_setup, ctx_access) = match ctx_field {
56            Some(f) => {
57                let id = &f.ident;
58                let setup = quote! { self.#id.initialise(self_component.clone()); };
59                let access = quote! { self.#id };
60                (setup, access)
61            }
62            None => panic!("No ComponentContext found for {:?}!", name),
63        };
64        let port_setup = ports
65            .iter()
66            .map(|&(f, _)| {
67                let id = &f.ident;
68                quote! { self.#id.set_parent(self_component.clone()); }
69            })
70            .collect::<Vec<_>>();
71        let port_handles_skip = ports
72            .iter()
73            .enumerate()
74            .map(|(i, &(f, ref t))| {
75                let id = &f.ident;
76                //let ref ty = f.ty;
77                let handle = t.as_handle();
78                quote! {
79                    if skip <= #i {
80                        if count >= max_events {
81                            return ExecuteResult::new(false, count, #i);
82                        }
83                        #[allow(unreachable_code)]
84                        { // can be Never type, which is ok, so just suppress this warning
85                            if let Some(event) = self.#id.dequeue() {
86                                let res = #handle
87                                count += 1;
88                                done_work = true;
89                                if let Handled::BlockOn(blocking_future) = res {
90                                    self.ctx_mut().set_blocking(blocking_future);
91                                    return ExecuteResult::new(true, count, #i);
92                                }
93                            }
94                        }
95                    }
96                }
97            })
98            .collect::<Vec<_>>();
99        let port_handles = ports
100            .iter()
101            .enumerate()
102            .map(|(i, &(f, ref t))| {
103                let id = &f.ident;
104                //let ref ty = f.ty;
105                let handle = t.as_handle();
106                quote! {
107                    if count >= max_events {
108                        return ExecuteResult::new(false, count, #i);
109                    }
110                    #[allow(unreachable_code)]
111                    { // can be Never type, which is ok, so just suppress this warning
112                        if let Some(event) = self.#id.dequeue() {
113                            let res = #handle
114                            count += 1;
115                            done_work = true;
116                            if let Handled::BlockOn(blocking_future) = res {
117                                self.ctx_mut().set_blocking(blocking_future);
118                                return ExecuteResult::new(true, count, #i);
119                            }
120                        }
121                    }
122                }
123            })
124            .collect::<Vec<_>>();
125        let exec = if port_handles.is_empty() {
126            quote! {
127                fn execute(&mut self, _max_events: usize, _skip: usize) -> ExecuteResult {
128                    ExecuteResult::new(false, 0, 0)
129                }
130            }
131        } else {
132            quote! {
133                fn execute(&mut self, max_events: usize, skip: usize) -> ExecuteResult {
134                    let mut count: usize = 0;
135                    let mut done_work = true; // might skip queues that have work
136                    #(#port_handles_skip)*
137                    while done_work {
138                        done_work = false;
139                        #(#port_handles)*
140                    }
141                    ExecuteResult::new(false, count, 0)
142                }
143            }
144        };
145
146        let mut provided_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
147        let mut provided_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
148        let mut required_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
149        let mut required_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
150        for port in ports {
151            let port_field = port.1.clone();
152            match port_field {
153                PortField::Required(ty) => {
154                    if let Some(port_list) = required_ports_non_unique.get_mut(&ty) {
155                        port_list.push(port);
156                    } else if required_ports_unique.contains_key(&ty) {
157                        let other_entry = required_ports_unique.remove(&ty).unwrap();
158                        required_ports_non_unique.insert(ty, vec![other_entry, port]);
159                    } else {
160                        required_ports_unique.insert(ty, port);
161                    }
162                }
163                PortField::Provided(ty) => {
164                    if let Some(port_list) = provided_ports_non_unique.get_mut(&ty) {
165                        port_list.push(port);
166                    } else if provided_ports_unique.contains_key(&ty) {
167                        let other_entry = provided_ports_unique.remove(&ty).unwrap();
168                        provided_ports_non_unique.insert(ty, vec![other_entry, port]);
169                    } else {
170                        provided_ports_unique.insert(ty, port);
171                    }
172                }
173            }
174        }
175
176        let generate_provided_ref_impl = |p: &PortEntry| {
177            let (field, port_field) = p;
178            let id = &field.ident;
179            let ty = port_field.port_type();
180            quote! {
181                impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
182                    fn provided_ref(&mut self) -> ProvidedRef< #ty > {
183                        self.#id.share()
184                    }
185                    fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
186                        self.#id.connect(req);
187                    }
188                    fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
189                        self.#id.disconnect_port(req);
190                    }
191                }
192            }
193        };
194        let generate_required_ref_impl = |p: &PortEntry| {
195            let (field, port_field) = p;
196            let id = &field.ident;
197            let ty = port_field.port_type();
198            quote! {
199                impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
200                    fn required_ref(&mut self) -> RequiredRef< #ty > {
201                        self.#id.share()
202                    }
203                    fn connect_to_provided(&mut self, prov: ProvidedRef< #ty >) -> () {
204                        self.#id.connect(prov);
205                    }
206                    fn disconnect(&mut self, prov: ProvidedRef< #ty >) -> () {
207                        self.#id.disconnect_port(prov);
208                    }
209                }
210            }
211        };
212        let generate_ambiguous_provided_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
213            let ids: Vec<String> = port_entries
214                .iter()
215                .map(|(field, _port_field)| {
216                    let id = field.ident.as_ref().unwrap();
217                    format!("{}", quote! {#id})
218                })
219                .collect();
220            let error_msg = format!("Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.", quote!{#ty}, ids);
221            quote! {
222                impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
223                    fn provided_ref(&mut self) -> ProvidedRef< #ty > {
224                        compile_error!(#error_msg);
225                    }
226                    fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
227                        compile_error!(#error_msg);
228                    }
229                    fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
230                        compile_error!(#error_msg);
231                    }
232                }
233            }
234        };
235        let generate_ambiguous_required_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
236            let ids: Vec<String> = port_entries
237                .iter()
238                .map(|(field, _port_field)| {
239                    let id = field.ident.as_ref().unwrap();
240                    format!("{}", quote! {#id})
241                })
242                .collect();
243            let error_msg = format!("Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.", quote!{#ty}, ids);
244            quote! {
245                impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
246                    fn provided_ref(&mut self) -> ProvidedRef< #ty > {
247                        compile_error!(#error_msg);
248                    }
249                    fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
250                        compile_error!(#error_msg);
251                    }
252                    fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
253                        compile_error!(#error_msg);
254                    }
255                }
256            }
257        };
258
259        let port_ref_impls = provided_ports_unique
260            .values()
261            .map(generate_provided_ref_impl)
262            .chain(
263                required_ports_unique
264                    .values()
265                    .map(generate_required_ref_impl),
266            )
267            .chain(
268                provided_ports_non_unique
269                    .iter()
270                    .map(|pair| generate_ambiguous_provided_ref_impl(pair.0, pair.1)),
271            )
272            .chain(
273                required_ports_non_unique
274                    .iter()
275                    .map(|pair| generate_ambiguous_required_ref_impl(pair.0, pair.1)),
276            )
277            .collect::<Vec<_>>();
278
279        fn make_match(f: &syn::Field, t: &syn::Type) -> TokenStream2 {
280            let f = &f.ident;
281            quote! {
282                id if id == ::std::any::TypeId::of::<#t>() =>
283                    Some(&mut self.#f as &mut dyn ::std::any::Any),
284            }
285        }
286
287        let provided_matches: Vec<_> = provided_ports_unique
288            .iter()
289            .map(|(t, p)| make_match(p.0, t))
290            .collect();
291
292        let required_matches: Vec<_> = required_ports_unique
293            .iter()
294            .map(|(t, p)| make_match(p.0, t))
295            .collect();
296
297        quote! {
298            impl #impl_generics ComponentDefinition for #name #ty_generics #where_clause {
299                fn setup(&mut self, self_component: ::std::sync::Arc<Component<Self>>) -> () {
300                    #ctx_setup
301                    //println!("Setting up ports");
302                    #(#port_setup)*
303                }
304                #exec
305                fn ctx_mut(&mut self) -> &mut ComponentContext<Self> {
306                    &mut #ctx_access
307                }
308                fn ctx(&self) -> &ComponentContext<Self> {
309                    &#ctx_access
310                }
311                fn type_name() -> &'static str {
312                    #name_str
313                }
314            }
315            impl #impl_generics DynamicPortAccess for #name #ty_generics #where_clause {
316                fn get_provided_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
317                    match port_id {
318                        #(#provided_matches)*
319                        _ => None,
320                    }
321                }
322
323                fn get_required_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
324                    match port_id {
325                        #(#required_matches)*
326                        _ => None,
327                    }
328                }
329            }
330            #(#port_ref_impls)*
331        }
332    } else {
333        //Nope. This is an Enum. We cannot handle these!
334        panic!("#[derive(ComponentDefinition)] is only defined for structs, not for enums!");
335    }
336}
337
338#[allow(clippy::large_enum_variant)]
339#[derive(Debug)]
340enum ComponentField {
341    Ctx,
342    Port(PortField),
343    Other,
344}
345
346#[derive(Debug, Clone)]
347enum PortField {
348    Required(syn::Type),
349    Provided(syn::Type),
350}
351
352impl PortField {
353    fn as_handle(&self) -> TokenStream2 {
354        match *self {
355            PortField::Provided(ref ty) => quote! { Provide::<#ty>::handle(self, event); },
356            PortField::Required(ref ty) => quote! { Require::<#ty>::handle(self, event); },
357        }
358    }
359
360    fn port_type(&self) -> &syn::Type {
361        match self {
362            PortField::Provided(ref ty) => ty,
363            PortField::Required(ref ty) => ty,
364        }
365    }
366}
367
/// Type name marking the component-context field (`ComponentContext<Self>`).
const CTX: &str = "ComponentContext";
/// Type name marking a provided-port field (`ProvidedPort<P>`).
const PROVP: &str = "ProvidedPort";
/// Type name marking a required-port field (`RequiredPort<P>`).
const REQP: &str = "RequiredPort";
/// Crate name that may qualify the recognised type names
/// (e.g. `kompact::ComponentContext<Self>`).
const KOMPICS: &str = "kompact";
372
373fn identify_field(f: &syn::Field) -> ComponentField {
374    if let syn::Type::Path(ref patht) = f.ty {
375        let path = &patht.path;
376        let port_seg_opt = if path.segments.len() == 1 {
377            Some(&path.segments[0])
378        } else if path.segments.len() == 2 {
379            if path.segments[0].ident == KOMPICS {
380                Some(&path.segments[1])
381            } else {
382                //println!("Module is not 'kompact': {:?}", path);
383                None
384            }
385        } else {
386            //println!("Path too long for port: {:?}", path);
387            None
388        };
389        if let Some(seg) = port_seg_opt {
390            if seg.ident == REQP {
391                ComponentField::Port(PortField::Required(extract_port_type(seg)))
392            } else if seg.ident == PROVP {
393                ComponentField::Port(PortField::Provided(extract_port_type(seg)))
394            } else if seg.ident == CTX {
395                ComponentField::Ctx
396            } else {
397                //println!("Not a port: {:?}", path);
398                ComponentField::Other
399            }
400        } else {
401            ComponentField::Other
402        }
403    } else {
404        ComponentField::Other
405    }
406}
407
408fn extract_port_type(seg: &syn::PathSegment) -> syn::Type {
409    match seg.arguments {
410        syn::PathArguments::AngleBracketed(ref abppd) => {
411            match abppd.args.first().expect("Invalid type argument!") {
412                syn::GenericArgument::Type(ty) => ty.clone(),
413                _ => panic!("Wrong generic argument type in {:?}", seg),
414            }
415        }
416        _ => panic!("Wrong path parameter type! {:?}", seg),
417    }
418}