Skip to main content

kompact_component_derive/
lib.rs

1#![recursion_limit = "128"]
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::quote;
7use syn::{DeriveInput, parse_macro_input};
8
9use std::{collections::HashMap, iter::Iterator};
10
11/// A macro to derive fair [ComponentDefinition](ComponentDefinition) implementations
12///
13/// Implementations will set up ports and the component context correctly, and
14/// during execution check ports in a fair round-robin manner.
15///
16/// Using this macro will also derive implementations of [ProvideRef](ProvideRef)
17/// or [RequireRef](RequireRef) for each declared port.
18#[proc_macro_derive(ComponentDefinition)]
19pub fn component_definition(input: TokenStream) -> TokenStream {
20    // Parse the input stream
21    let ast = parse_macro_input!(input as DeriveInput);
22
23    // Build the impl
24    let generated = impl_component_definition(&ast);
25
26    //println!("Derived code:\n{}", generated);
27
28    // Return the generated impl
29    generated.into()
30}
31
32type PortEntry<'a> = (&'a syn::Field, PortField);
33
34#[allow(clippy::map_entry)]
35fn impl_component_definition(ast: &syn::DeriveInput) -> TokenStream2 {
36    let name = &ast.ident;
37    let name_str = format!("{}", name);
38    if let syn::Data::Struct(ref vdata) = ast.data {
39        let generics = &ast.generics;
40        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42        let fields = &vdata.fields;
43        let mut ports: Vec<PortEntry> = Vec::new();
44        let mut ctx_field: Option<&syn::Field> = None;
45        for field in fields.iter() {
46            let cf = identify_field(field);
47            match cf {
48                ComponentField::Ctx => {
49                    ctx_field = Some(field);
50                }
51                ComponentField::Port(pf) => ports.push((field, pf)),
52                ComponentField::Other => (),
53            }
54        }
55        let (ctx_setup, ctx_access) = match ctx_field {
56            Some(f) => {
57                let id = &f.ident;
58                let setup = quote! { self.#id.initialise(self_component.clone()); };
59                let access = quote! { self.#id };
60                (setup, access)
61            }
62            None => panic!("No ComponentContext found for {:?}!", name),
63        };
64        let port_setup = ports
65            .iter()
66            .map(|&(f, _)| {
67                let id = &f.ident;
68                quote! { self.#id.set_parent(self_component.clone()); }
69            })
70            .collect::<Vec<_>>();
71        let port_handles_skip = ports
72            .iter()
73            .enumerate()
74            .map(|(i, &(f, ref t))| {
75                let id = &f.ident;
76                //let ref ty = f.ty;
77                let handle = t.as_handle();
78                quote! {
79                    if skip <= #i {
80                        if count >= max_events {
81                            return ExecuteResult::new(false, count, #i);
82                        }
83                        #[allow(unreachable_code)]
84                        { // can be Never type, which is ok, so just suppress this warning
85                            if let Some(event) = self.#id.dequeue() {
86                                let res = #handle
87                                count += 1;
88                                done_work = true;
89                                match res {
90                                    Ok(Handled::Ok) => (),
91                                    Ok(Handled::BlockOn(blocking_future)) => {
92                                        self.ctx_mut().set_blocking(blocking_future);
93                                        return ExecuteResult::new(true, count, #i);
94                                    }
95                                    Ok(Handled::Shutdown) => {
96                                        return ExecuteResult::shutdown(count, #i);
97                                    }
98                                    Err(error) => {
99                                        return ExecuteResult::error(error, count, #i);
100                                    }
101                                }
102                            }
103                        }
104                    }
105                }
106            })
107            .collect::<Vec<_>>();
108        let port_handles = ports
109            .iter()
110            .enumerate()
111            .map(|(i, &(f, ref t))| {
112                let id = &f.ident;
113                //let ref ty = f.ty;
114                let handle = t.as_handle();
115                quote! {
116                    if count >= max_events {
117                        return ExecuteResult::new(false, count, #i);
118                    }
119                    #[allow(unreachable_code)]
120                    { // can be Never type, which is ok, so just suppress this warning
121                        if let Some(event) = self.#id.dequeue() {
122                            let res = #handle
123                            count += 1;
124                            done_work = true;
125                            match res {
126                                Ok(Handled::Ok) => (),
127                                Ok(Handled::BlockOn(blocking_future)) => {
128                                    self.ctx_mut().set_blocking(blocking_future);
129                                    return ExecuteResult::new(true, count, #i);
130                                }
131                                Ok(Handled::Shutdown) => {
132                                    return ExecuteResult::shutdown(count, #i);
133                                }
134                                Err(error) => {
135                                    return ExecuteResult::error(error, count, #i);
136                                }
137                            }
138                        }
139                    }
140                }
141            })
142            .collect::<Vec<_>>();
143        let exec = if port_handles.is_empty() {
144            quote! {
145                fn execute(&mut self, _max_events: usize, _skip: usize) -> ExecuteResult {
146                    ExecuteResult::new(false, 0, 0)
147                }
148            }
149        } else {
150            quote! {
151                fn execute(&mut self, max_events: usize, skip: usize) -> ExecuteResult {
152                    let mut count: usize = 0;
153                    let mut done_work = true; // might skip queues that have work
154                    #(#port_handles_skip)*
155                    while done_work {
156                        done_work = false;
157                        #(#port_handles)*
158                    }
159                    ExecuteResult::new(false, count, 0)
160                }
161            }
162        };
163
164        let mut provided_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
165        let mut provided_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
166        let mut required_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
167        let mut required_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
168        for port in ports {
169            let port_field = port.1.clone();
170            match port_field {
171                PortField::Required(ty) => {
172                    if let Some(port_list) = required_ports_non_unique.get_mut(&ty) {
173                        port_list.push(port);
174                    } else if required_ports_unique.contains_key(&ty) {
175                        let other_entry = required_ports_unique.remove(&ty).unwrap();
176                        required_ports_non_unique.insert(ty, vec![other_entry, port]);
177                    } else {
178                        required_ports_unique.insert(ty, port);
179                    }
180                }
181                PortField::Provided(ty) => {
182                    if let Some(port_list) = provided_ports_non_unique.get_mut(&ty) {
183                        port_list.push(port);
184                    } else if provided_ports_unique.contains_key(&ty) {
185                        let other_entry = provided_ports_unique.remove(&ty).unwrap();
186                        provided_ports_non_unique.insert(ty, vec![other_entry, port]);
187                    } else {
188                        provided_ports_unique.insert(ty, port);
189                    }
190                }
191            }
192        }
193
194        let generate_provided_ref_impl = |p: &PortEntry| {
195            let (field, port_field) = p;
196            let id = &field.ident;
197            let ty = port_field.port_type();
198            quote! {
199                impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
200                    fn provided_ref(&mut self) -> ProvidedRef< #ty > {
201                        self.#id.share()
202                    }
203                    fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
204                        self.#id.connect(req);
205                    }
206                    fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
207                        self.#id.disconnect_port(req);
208                    }
209                }
210            }
211        };
212        let generate_required_ref_impl = |p: &PortEntry| {
213            let (field, port_field) = p;
214            let id = &field.ident;
215            let ty = port_field.port_type();
216            quote! {
217                impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
218                    fn required_ref(&mut self) -> RequiredRef< #ty > {
219                        self.#id.share()
220                    }
221                    fn connect_to_provided(&mut self, prov: ProvidedRef< #ty >) -> () {
222                        self.#id.connect(prov);
223                    }
224                    fn disconnect(&mut self, prov: ProvidedRef< #ty >) -> () {
225                        self.#id.disconnect_port(prov);
226                    }
227                }
228            }
229        };
230        let generate_ambiguous_provided_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
231            let ids: Vec<String> = port_entries
232                .iter()
233                .map(|(field, _port_field)| {
234                    let id = field.ident.as_ref().unwrap();
235                    format!("{}", quote! {#id})
236                })
237                .collect();
238            let error_msg = format!(
239                "Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.",
240                quote! {#ty},
241                ids
242            );
243            quote! {
244                impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
245                    fn provided_ref(&mut self) -> ProvidedRef< #ty > {
246                        compile_error!(#error_msg);
247                    }
248                    fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
249                        compile_error!(#error_msg);
250                    }
251                    fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
252                        compile_error!(#error_msg);
253                    }
254                }
255            }
256        };
257        let generate_ambiguous_required_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
258            let ids: Vec<String> = port_entries
259                .iter()
260                .map(|(field, _port_field)| {
261                    let id = field.ident.as_ref().unwrap();
262                    format!("{}", quote! {#id})
263                })
264                .collect();
265            let error_msg = format!(
266                "Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.",
267                quote! {#ty},
268                ids
269            );
270            quote! {
271                impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
272                    fn provided_ref(&mut self) -> ProvidedRef< #ty > {
273                        compile_error!(#error_msg);
274                    }
275                    fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
276                        compile_error!(#error_msg);
277                    }
278                    fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
279                        compile_error!(#error_msg);
280                    }
281                }
282            }
283        };
284
285        let port_ref_impls = provided_ports_unique
286            .values()
287            .map(generate_provided_ref_impl)
288            .chain(
289                required_ports_unique
290                    .values()
291                    .map(generate_required_ref_impl),
292            )
293            .chain(
294                provided_ports_non_unique
295                    .iter()
296                    .map(|pair| generate_ambiguous_provided_ref_impl(pair.0, pair.1)),
297            )
298            .chain(
299                required_ports_non_unique
300                    .iter()
301                    .map(|pair| generate_ambiguous_required_ref_impl(pair.0, pair.1)),
302            )
303            .collect::<Vec<_>>();
304
305        fn make_match(f: &syn::Field, t: &syn::Type) -> TokenStream2 {
306            let f = &f.ident;
307            quote! {
308                id if id == ::std::any::TypeId::of::<#t>() =>
309                    Some(&mut self.#f as &mut dyn ::std::any::Any),
310            }
311        }
312
313        let provided_matches: Vec<_> = provided_ports_unique
314            .iter()
315            .map(|(t, p)| make_match(p.0, t))
316            .collect();
317
318        let required_matches: Vec<_> = required_ports_unique
319            .iter()
320            .map(|(t, p)| make_match(p.0, t))
321            .collect();
322
323        quote! {
324            impl #impl_generics ComponentDefinition for #name #ty_generics #where_clause {
325                fn setup(&mut self, self_component: ::std::sync::Arc<Component<Self>>) -> () {
326                    #ctx_setup
327                    //println!("Setting up ports");
328                    #(#port_setup)*
329                }
330                #exec
331                fn ctx_mut(&mut self) -> &mut ComponentContext<Self> {
332                    &mut #ctx_access
333                }
334                fn ctx(&self) -> &ComponentContext<Self> {
335                    &#ctx_access
336                }
337                fn type_name() -> &'static str {
338                    #name_str
339                }
340            }
341            impl #impl_generics DynamicPortAccess for #name #ty_generics #where_clause {
342                fn get_provided_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
343                    match port_id {
344                        #(#provided_matches)*
345                        _ => None,
346                    }
347                }
348
349                fn get_required_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
350                    match port_id {
351                        #(#required_matches)*
352                        _ => None,
353                    }
354                }
355            }
356            #(#port_ref_impls)*
357        }
358    } else {
359        //Nope. This is an Enum. We cannot handle these!
360        panic!("#[derive(ComponentDefinition)] is only defined for structs, not for enums!");
361    }
362}
363
364#[allow(clippy::large_enum_variant)]
365#[derive(Debug)]
366enum ComponentField {
367    Ctx,
368    Port(PortField),
369    Other,
370}
371
372#[derive(Debug, Clone)]
373enum PortField {
374    Required(syn::Type),
375    Provided(syn::Type),
376}
377
378impl PortField {
379    fn as_handle(&self) -> TokenStream2 {
380        match self {
381            PortField::Provided(ty) => quote! { Provide::<#ty>::handle(self, event); },
382            PortField::Required(ty) => quote! { Require::<#ty>::handle(self, event); },
383        }
384    }
385
386    fn port_type(&self) -> &syn::Type {
387        match self {
388            PortField::Provided(ty) => ty,
389            PortField::Required(ty) => ty,
390        }
391    }
392}
393
394const REQP: &str = "RequiredPort";
395const PROVP: &str = "ProvidedPort";
396const CTX: &str = "ComponentContext";
397const KOMPICS: &str = "kompact";
398
399fn identify_field(f: &syn::Field) -> ComponentField {
400    if let syn::Type::Path(ref patht) = f.ty {
401        let path = &patht.path;
402        let port_seg_opt = if path.segments.len() == 1 {
403            Some(&path.segments[0])
404        } else if path.segments.len() == 2 {
405            if path.segments[0].ident == KOMPICS {
406                Some(&path.segments[1])
407            } else {
408                //println!("Module is not 'kompact': {:?}", path);
409                None
410            }
411        } else {
412            //println!("Path too long for port: {:?}", path);
413            None
414        };
415        if let Some(seg) = port_seg_opt {
416            if seg.ident == REQP {
417                ComponentField::Port(PortField::Required(extract_port_type(seg)))
418            } else if seg.ident == PROVP {
419                ComponentField::Port(PortField::Provided(extract_port_type(seg)))
420            } else if seg.ident == CTX {
421                ComponentField::Ctx
422            } else {
423                //println!("Not a port: {:?}", path);
424                ComponentField::Other
425            }
426        } else {
427            ComponentField::Other
428        }
429    } else {
430        ComponentField::Other
431    }
432}
433
434fn extract_port_type(seg: &syn::PathSegment) -> syn::Type {
435    match seg.arguments {
436        syn::PathArguments::AngleBracketed(ref abppd) => {
437            match abppd.args.first().expect("Invalid type argument!") {
438                syn::GenericArgument::Type(ty) => ty.clone(),
439                _ => panic!("Wrong generic argument type in {:?}", seg),
440            }
441        }
442        _ => panic!("Wrong path parameter type! {:?}", seg),
443    }
444}