kompact_component_derive/
lib.rs1#![recursion_limit = "128"]
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::quote;
7use syn::{DeriveInput, parse_macro_input};
8
9use std::{collections::HashMap, iter::Iterator};
10
11#[proc_macro_derive(ComponentDefinition)]
19pub fn component_definition(input: TokenStream) -> TokenStream {
20 let ast = parse_macro_input!(input as DeriveInput);
22
23 let generated = impl_component_definition(&ast);
25
26 generated.into()
30}
31
/// A port-typed struct field, paired with its classification (direction and
/// port type `P`) as produced by `identify_field`.
type PortEntry<'a> = (&'a syn::Field, PortField);
33
34#[allow(clippy::map_entry)]
35fn impl_component_definition(ast: &syn::DeriveInput) -> TokenStream2 {
36 let name = &ast.ident;
37 let name_str = format!("{}", name);
38 if let syn::Data::Struct(ref vdata) = ast.data {
39 let generics = &ast.generics;
40 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42 let fields = &vdata.fields;
43 let mut ports: Vec<PortEntry> = Vec::new();
44 let mut ctx_field: Option<&syn::Field> = None;
45 for field in fields.iter() {
46 let cf = identify_field(field);
47 match cf {
48 ComponentField::Ctx => {
49 ctx_field = Some(field);
50 }
51 ComponentField::Port(pf) => ports.push((field, pf)),
52 ComponentField::Other => (),
53 }
54 }
55 let (ctx_setup, ctx_access) = match ctx_field {
56 Some(f) => {
57 let id = &f.ident;
58 let setup = quote! { self.#id.initialise(self_component.clone()); };
59 let access = quote! { self.#id };
60 (setup, access)
61 }
62 None => panic!("No ComponentContext found for {:?}!", name),
63 };
64 let port_setup = ports
65 .iter()
66 .map(|&(f, _)| {
67 let id = &f.ident;
68 quote! { self.#id.set_parent(self_component.clone()); }
69 })
70 .collect::<Vec<_>>();
71 let port_handles_skip = ports
72 .iter()
73 .enumerate()
74 .map(|(i, &(f, ref t))| {
75 let id = &f.ident;
76 let handle = t.as_handle();
78 quote! {
79 if skip <= #i {
80 if count >= max_events {
81 return ExecuteResult::new(false, count, #i);
82 }
83 #[allow(unreachable_code)]
84 { if let Some(event) = self.#id.dequeue() {
86 let res = #handle
87 count += 1;
88 done_work = true;
89 match res {
90 Ok(Handled::Ok) => (),
91 Ok(Handled::BlockOn(blocking_future)) => {
92 self.ctx_mut().set_blocking(blocking_future);
93 return ExecuteResult::new(true, count, #i);
94 }
95 Ok(Handled::Shutdown) => {
96 return ExecuteResult::shutdown(count, #i);
97 }
98 Err(error) => {
99 return ExecuteResult::error(error, count, #i);
100 }
101 }
102 }
103 }
104 }
105 }
106 })
107 .collect::<Vec<_>>();
108 let port_handles = ports
109 .iter()
110 .enumerate()
111 .map(|(i, &(f, ref t))| {
112 let id = &f.ident;
113 let handle = t.as_handle();
115 quote! {
116 if count >= max_events {
117 return ExecuteResult::new(false, count, #i);
118 }
119 #[allow(unreachable_code)]
120 { if let Some(event) = self.#id.dequeue() {
122 let res = #handle
123 count += 1;
124 done_work = true;
125 match res {
126 Ok(Handled::Ok) => (),
127 Ok(Handled::BlockOn(blocking_future)) => {
128 self.ctx_mut().set_blocking(blocking_future);
129 return ExecuteResult::new(true, count, #i);
130 }
131 Ok(Handled::Shutdown) => {
132 return ExecuteResult::shutdown(count, #i);
133 }
134 Err(error) => {
135 return ExecuteResult::error(error, count, #i);
136 }
137 }
138 }
139 }
140 }
141 })
142 .collect::<Vec<_>>();
143 let exec = if port_handles.is_empty() {
144 quote! {
145 fn execute(&mut self, _max_events: usize, _skip: usize) -> ExecuteResult {
146 ExecuteResult::new(false, 0, 0)
147 }
148 }
149 } else {
150 quote! {
151 fn execute(&mut self, max_events: usize, skip: usize) -> ExecuteResult {
152 let mut count: usize = 0;
153 let mut done_work = true; #(#port_handles_skip)*
155 while done_work {
156 done_work = false;
157 #(#port_handles)*
158 }
159 ExecuteResult::new(false, count, 0)
160 }
161 }
162 };
163
164 let mut provided_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
165 let mut provided_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
166 let mut required_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
167 let mut required_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
168 for port in ports {
169 let port_field = port.1.clone();
170 match port_field {
171 PortField::Required(ty) => {
172 if let Some(port_list) = required_ports_non_unique.get_mut(&ty) {
173 port_list.push(port);
174 } else if required_ports_unique.contains_key(&ty) {
175 let other_entry = required_ports_unique.remove(&ty).unwrap();
176 required_ports_non_unique.insert(ty, vec![other_entry, port]);
177 } else {
178 required_ports_unique.insert(ty, port);
179 }
180 }
181 PortField::Provided(ty) => {
182 if let Some(port_list) = provided_ports_non_unique.get_mut(&ty) {
183 port_list.push(port);
184 } else if provided_ports_unique.contains_key(&ty) {
185 let other_entry = provided_ports_unique.remove(&ty).unwrap();
186 provided_ports_non_unique.insert(ty, vec![other_entry, port]);
187 } else {
188 provided_ports_unique.insert(ty, port);
189 }
190 }
191 }
192 }
193
194 let generate_provided_ref_impl = |p: &PortEntry| {
195 let (field, port_field) = p;
196 let id = &field.ident;
197 let ty = port_field.port_type();
198 quote! {
199 impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
200 fn provided_ref(&mut self) -> ProvidedRef< #ty > {
201 self.#id.share()
202 }
203 fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
204 self.#id.connect(req);
205 }
206 fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
207 self.#id.disconnect_port(req);
208 }
209 }
210 }
211 };
212 let generate_required_ref_impl = |p: &PortEntry| {
213 let (field, port_field) = p;
214 let id = &field.ident;
215 let ty = port_field.port_type();
216 quote! {
217 impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
218 fn required_ref(&mut self) -> RequiredRef< #ty > {
219 self.#id.share()
220 }
221 fn connect_to_provided(&mut self, prov: ProvidedRef< #ty >) -> () {
222 self.#id.connect(prov);
223 }
224 fn disconnect(&mut self, prov: ProvidedRef< #ty >) -> () {
225 self.#id.disconnect_port(prov);
226 }
227 }
228 }
229 };
230 let generate_ambiguous_provided_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
231 let ids: Vec<String> = port_entries
232 .iter()
233 .map(|(field, _port_field)| {
234 let id = field.ident.as_ref().unwrap();
235 format!("{}", quote! {#id})
236 })
237 .collect();
238 let error_msg = format!(
239 "Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.",
240 quote! {#ty},
241 ids
242 );
243 quote! {
244 impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
245 fn provided_ref(&mut self) -> ProvidedRef< #ty > {
246 compile_error!(#error_msg);
247 }
248 fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
249 compile_error!(#error_msg);
250 }
251 fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
252 compile_error!(#error_msg);
253 }
254 }
255 }
256 };
257 let generate_ambiguous_required_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
258 let ids: Vec<String> = port_entries
259 .iter()
260 .map(|(field, _port_field)| {
261 let id = field.ident.as_ref().unwrap();
262 format!("{}", quote! {#id})
263 })
264 .collect();
265 let error_msg = format!(
266 "Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.",
267 quote! {#ty},
268 ids
269 );
270 quote! {
271 impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
272 fn provided_ref(&mut self) -> ProvidedRef< #ty > {
273 compile_error!(#error_msg);
274 }
275 fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
276 compile_error!(#error_msg);
277 }
278 fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
279 compile_error!(#error_msg);
280 }
281 }
282 }
283 };
284
285 let port_ref_impls = provided_ports_unique
286 .values()
287 .map(generate_provided_ref_impl)
288 .chain(
289 required_ports_unique
290 .values()
291 .map(generate_required_ref_impl),
292 )
293 .chain(
294 provided_ports_non_unique
295 .iter()
296 .map(|pair| generate_ambiguous_provided_ref_impl(pair.0, pair.1)),
297 )
298 .chain(
299 required_ports_non_unique
300 .iter()
301 .map(|pair| generate_ambiguous_required_ref_impl(pair.0, pair.1)),
302 )
303 .collect::<Vec<_>>();
304
305 fn make_match(f: &syn::Field, t: &syn::Type) -> TokenStream2 {
306 let f = &f.ident;
307 quote! {
308 id if id == ::std::any::TypeId::of::<#t>() =>
309 Some(&mut self.#f as &mut dyn ::std::any::Any),
310 }
311 }
312
313 let provided_matches: Vec<_> = provided_ports_unique
314 .iter()
315 .map(|(t, p)| make_match(p.0, t))
316 .collect();
317
318 let required_matches: Vec<_> = required_ports_unique
319 .iter()
320 .map(|(t, p)| make_match(p.0, t))
321 .collect();
322
323 quote! {
324 impl #impl_generics ComponentDefinition for #name #ty_generics #where_clause {
325 fn setup(&mut self, self_component: ::std::sync::Arc<Component<Self>>) -> () {
326 #ctx_setup
327 #(#port_setup)*
329 }
330 #exec
331 fn ctx_mut(&mut self) -> &mut ComponentContext<Self> {
332 &mut #ctx_access
333 }
334 fn ctx(&self) -> &ComponentContext<Self> {
335 &#ctx_access
336 }
337 fn type_name() -> &'static str {
338 #name_str
339 }
340 }
341 impl #impl_generics DynamicPortAccess for #name #ty_generics #where_clause {
342 fn get_provided_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
343 match port_id {
344 #(#provided_matches)*
345 _ => None,
346 }
347 }
348
349 fn get_required_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
350 match port_id {
351 #(#required_matches)*
352 _ => None,
353 }
354 }
355 }
356 #(#port_ref_impls)*
357 }
358 } else {
359 panic!("#[derive(ComponentDefinition)] is only defined for structs, not for enums!");
361 }
362}
363
/// Classification of a single struct field, as decided by `identify_field`.
// NOTE(review): the lint is presumably silenced because `Port` carries a
// `syn::Type`, which is much larger than the unit variants — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum ComponentField {
    /// The field's type is `ComponentContext` (optionally `kompact::`-qualified).
    Ctx,
    /// The field's type is `RequiredPort<P>` or `ProvidedPort<P>`.
    Port(PortField),
    /// Any other field; the derive ignores it.
    Other,
}
371
/// Direction of a port field, together with its port type `P` (the first
/// generic argument of `RequiredPort<P>` / `ProvidedPort<P>`).
#[derive(Debug, Clone)]
enum PortField {
    /// A `RequiredPort<P>` field; holds `P`.
    Required(syn::Type),
    /// A `ProvidedPort<P>` field; holds `P`.
    Provided(syn::Type),
}
377
378impl PortField {
379 fn as_handle(&self) -> TokenStream2 {
380 match self {
381 PortField::Provided(ty) => quote! { Provide::<#ty>::handle(self, event); },
382 PortField::Required(ty) => quote! { Require::<#ty>::handle(self, event); },
383 }
384 }
385
386 fn port_type(&self) -> &syn::Type {
387 match self {
388 PortField::Provided(ty) => ty,
389 PortField::Required(ty) => ty,
390 }
391 }
392}
393
/// Last path segment identifying a required-port field (`RequiredPort<P>`).
const REQP: &str = "RequiredPort";
/// Last path segment identifying a provided-port field (`ProvidedPort<P>`).
const PROVP: &str = "ProvidedPort";
/// Last path segment identifying the component-context field.
const CTX: &str = "ComponentContext";
/// The only accepted qualifier for two-segment field types,
/// e.g. `kompact::RequiredPort<P>`.
// NOTE(review): despite the name, this holds the `kompact` crate prefix.
const KOMPICS: &str = "kompact";
398
399fn identify_field(f: &syn::Field) -> ComponentField {
400 if let syn::Type::Path(ref patht) = f.ty {
401 let path = &patht.path;
402 let port_seg_opt = if path.segments.len() == 1 {
403 Some(&path.segments[0])
404 } else if path.segments.len() == 2 {
405 if path.segments[0].ident == KOMPICS {
406 Some(&path.segments[1])
407 } else {
408 None
410 }
411 } else {
412 None
414 };
415 if let Some(seg) = port_seg_opt {
416 if seg.ident == REQP {
417 ComponentField::Port(PortField::Required(extract_port_type(seg)))
418 } else if seg.ident == PROVP {
419 ComponentField::Port(PortField::Provided(extract_port_type(seg)))
420 } else if seg.ident == CTX {
421 ComponentField::Ctx
422 } else {
423 ComponentField::Other
425 }
426 } else {
427 ComponentField::Other
428 }
429 } else {
430 ComponentField::Other
431 }
432}
433
434fn extract_port_type(seg: &syn::PathSegment) -> syn::Type {
435 match seg.arguments {
436 syn::PathArguments::AngleBracketed(ref abppd) => {
437 match abppd.args.first().expect("Invalid type argument!") {
438 syn::GenericArgument::Type(ty) => ty.clone(),
439 _ => panic!("Wrong generic argument type in {:?}", seg),
440 }
441 }
442 _ => panic!("Wrong path parameter type! {:?}", seg),
443 }
444}