kompact_component_derive/
lib.rs1#![recursion_limit = "128"]
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::quote;
7use syn::{parse_macro_input, DeriveInput};
8
9use std::{collections::HashMap, iter::Iterator};
10
11#[proc_macro_derive(ComponentDefinition)]
19pub fn component_definition(input: TokenStream) -> TokenStream {
20 let ast = parse_macro_input!(input as DeriveInput);
22
23 let gen = impl_component_definition(&ast);
25
26 gen.into()
30}
31
32type PortEntry<'a> = (&'a syn::Field, PortField);
33
34#[allow(clippy::map_entry)]
35fn impl_component_definition(ast: &syn::DeriveInput) -> TokenStream2 {
36 let name = &ast.ident;
37 let name_str = format!("{}", name);
38 if let syn::Data::Struct(ref vdata) = ast.data {
39 let generics = &ast.generics;
40 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42 let fields = &vdata.fields;
43 let mut ports: Vec<PortEntry> = Vec::new();
44 let mut ctx_field: Option<&syn::Field> = None;
45 for field in fields.iter() {
46 let cf = identify_field(field);
47 match cf {
48 ComponentField::Ctx => {
49 ctx_field = Some(field);
50 }
51 ComponentField::Port(pf) => ports.push((field, pf)),
52 ComponentField::Other => (),
53 }
54 }
55 let (ctx_setup, ctx_access) = match ctx_field {
56 Some(f) => {
57 let id = &f.ident;
58 let setup = quote! { self.#id.initialise(self_component.clone()); };
59 let access = quote! { self.#id };
60 (setup, access)
61 }
62 None => panic!("No ComponentContext found for {:?}!", name),
63 };
64 let port_setup = ports
65 .iter()
66 .map(|&(f, _)| {
67 let id = &f.ident;
68 quote! { self.#id.set_parent(self_component.clone()); }
69 })
70 .collect::<Vec<_>>();
71 let port_handles_skip = ports
72 .iter()
73 .enumerate()
74 .map(|(i, &(f, ref t))| {
75 let id = &f.ident;
76 let handle = t.as_handle();
78 quote! {
79 if skip <= #i {
80 if count >= max_events {
81 return ExecuteResult::new(false, count, #i);
82 }
83 #[allow(unreachable_code)]
84 { if let Some(event) = self.#id.dequeue() {
86 let res = #handle
87 count += 1;
88 done_work = true;
89 if let Handled::BlockOn(blocking_future) = res {
90 self.ctx_mut().set_blocking(blocking_future);
91 return ExecuteResult::new(true, count, #i);
92 }
93 }
94 }
95 }
96 }
97 })
98 .collect::<Vec<_>>();
99 let port_handles = ports
100 .iter()
101 .enumerate()
102 .map(|(i, &(f, ref t))| {
103 let id = &f.ident;
104 let handle = t.as_handle();
106 quote! {
107 if count >= max_events {
108 return ExecuteResult::new(false, count, #i);
109 }
110 #[allow(unreachable_code)]
111 { if let Some(event) = self.#id.dequeue() {
113 let res = #handle
114 count += 1;
115 done_work = true;
116 if let Handled::BlockOn(blocking_future) = res {
117 self.ctx_mut().set_blocking(blocking_future);
118 return ExecuteResult::new(true, count, #i);
119 }
120 }
121 }
122 }
123 })
124 .collect::<Vec<_>>();
125 let exec = if port_handles.is_empty() {
126 quote! {
127 fn execute(&mut self, _max_events: usize, _skip: usize) -> ExecuteResult {
128 ExecuteResult::new(false, 0, 0)
129 }
130 }
131 } else {
132 quote! {
133 fn execute(&mut self, max_events: usize, skip: usize) -> ExecuteResult {
134 let mut count: usize = 0;
135 let mut done_work = true; #(#port_handles_skip)*
137 while done_work {
138 done_work = false;
139 #(#port_handles)*
140 }
141 ExecuteResult::new(false, count, 0)
142 }
143 }
144 };
145
146 let mut provided_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
147 let mut provided_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
148 let mut required_ports_unique: HashMap<syn::Type, PortEntry> = HashMap::new();
149 let mut required_ports_non_unique: HashMap<syn::Type, Vec<PortEntry>> = HashMap::new();
150 for port in ports {
151 let port_field = port.1.clone();
152 match port_field {
153 PortField::Required(ty) => {
154 if let Some(port_list) = required_ports_non_unique.get_mut(&ty) {
155 port_list.push(port);
156 } else if required_ports_unique.contains_key(&ty) {
157 let other_entry = required_ports_unique.remove(&ty).unwrap();
158 required_ports_non_unique.insert(ty, vec![other_entry, port]);
159 } else {
160 required_ports_unique.insert(ty, port);
161 }
162 }
163 PortField::Provided(ty) => {
164 if let Some(port_list) = provided_ports_non_unique.get_mut(&ty) {
165 port_list.push(port);
166 } else if provided_ports_unique.contains_key(&ty) {
167 let other_entry = provided_ports_unique.remove(&ty).unwrap();
168 provided_ports_non_unique.insert(ty, vec![other_entry, port]);
169 } else {
170 provided_ports_unique.insert(ty, port);
171 }
172 }
173 }
174 }
175
176 let generate_provided_ref_impl = |p: &PortEntry| {
177 let (field, port_field) = p;
178 let id = &field.ident;
179 let ty = port_field.port_type();
180 quote! {
181 impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
182 fn provided_ref(&mut self) -> ProvidedRef< #ty > {
183 self.#id.share()
184 }
185 fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
186 self.#id.connect(req);
187 }
188 fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
189 self.#id.disconnect_port(req);
190 }
191 }
192 }
193 };
194 let generate_required_ref_impl = |p: &PortEntry| {
195 let (field, port_field) = p;
196 let id = &field.ident;
197 let ty = port_field.port_type();
198 quote! {
199 impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
200 fn required_ref(&mut self) -> RequiredRef< #ty > {
201 self.#id.share()
202 }
203 fn connect_to_provided(&mut self, prov: ProvidedRef< #ty >) -> () {
204 self.#id.connect(prov);
205 }
206 fn disconnect(&mut self, prov: ProvidedRef< #ty >) -> () {
207 self.#id.disconnect_port(prov);
208 }
209 }
210 }
211 };
212 let generate_ambiguous_provided_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
213 let ids: Vec<String> = port_entries
214 .iter()
215 .map(|(field, _port_field)| {
216 let id = field.ident.as_ref().unwrap();
217 format!("{}", quote! {#id})
218 })
219 .collect();
220 let error_msg = format!("Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.", quote!{#ty}, ids);
221 quote! {
222 impl #impl_generics ProvideRef< #ty > for #name #ty_generics #where_clause {
223 fn provided_ref(&mut self) -> ProvidedRef< #ty > {
224 compile_error!(#error_msg);
225 }
226 fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
227 compile_error!(#error_msg);
228 }
229 fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
230 compile_error!(#error_msg);
231 }
232 }
233 }
234 };
235 let generate_ambiguous_required_ref_impl = |ty: &syn::Type, port_entries: &[PortEntry]| {
236 let ids: Vec<String> = port_entries
237 .iter()
238 .map(|(field, _port_field)| {
239 let id = field.ident.as_ref().unwrap();
240 format!("{}", quote! {#id})
241 })
242 .collect();
243 let error_msg = format!("Ambiguous port type: There are multiple fields with type {} ({:?}). You cannot derive ComponentDefinition in these cases, as you must resolve the ambiguity manually.", quote!{#ty}, ids);
244 quote! {
245 impl #impl_generics RequireRef< #ty > for #name #ty_generics #where_clause {
246 fn provided_ref(&mut self) -> ProvidedRef< #ty > {
247 compile_error!(#error_msg);
248 }
249 fn connect_to_required(&mut self, req: RequiredRef< #ty >) -> () {
250 compile_error!(#error_msg);
251 }
252 fn disconnect(&mut self, req: RequiredRef< #ty >) -> () {
253 compile_error!(#error_msg);
254 }
255 }
256 }
257 };
258
259 let port_ref_impls = provided_ports_unique
260 .values()
261 .map(generate_provided_ref_impl)
262 .chain(
263 required_ports_unique
264 .values()
265 .map(generate_required_ref_impl),
266 )
267 .chain(
268 provided_ports_non_unique
269 .iter()
270 .map(|pair| generate_ambiguous_provided_ref_impl(pair.0, pair.1)),
271 )
272 .chain(
273 required_ports_non_unique
274 .iter()
275 .map(|pair| generate_ambiguous_required_ref_impl(pair.0, pair.1)),
276 )
277 .collect::<Vec<_>>();
278
279 fn make_match(f: &syn::Field, t: &syn::Type) -> TokenStream2 {
280 let f = &f.ident;
281 quote! {
282 id if id == ::std::any::TypeId::of::<#t>() =>
283 Some(&mut self.#f as &mut dyn ::std::any::Any),
284 }
285 }
286
287 let provided_matches: Vec<_> = provided_ports_unique
288 .iter()
289 .map(|(t, p)| make_match(p.0, t))
290 .collect();
291
292 let required_matches: Vec<_> = required_ports_unique
293 .iter()
294 .map(|(t, p)| make_match(p.0, t))
295 .collect();
296
297 quote! {
298 impl #impl_generics ComponentDefinition for #name #ty_generics #where_clause {
299 fn setup(&mut self, self_component: ::std::sync::Arc<Component<Self>>) -> () {
300 #ctx_setup
301 #(#port_setup)*
303 }
304 #exec
305 fn ctx_mut(&mut self) -> &mut ComponentContext<Self> {
306 &mut #ctx_access
307 }
308 fn ctx(&self) -> &ComponentContext<Self> {
309 &#ctx_access
310 }
311 fn type_name() -> &'static str {
312 #name_str
313 }
314 }
315 impl #impl_generics DynamicPortAccess for #name #ty_generics #where_clause {
316 fn get_provided_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
317 match port_id {
318 #(#provided_matches)*
319 _ => None,
320 }
321 }
322
323 fn get_required_port_as_any(&mut self, port_id: ::std::any::TypeId) -> Option<&mut dyn ::std::any::Any> {
324 match port_id {
325 #(#required_matches)*
326 _ => None,
327 }
328 }
329 }
330 #(#port_ref_impls)*
331 }
332 } else {
333 panic!("#[derive(ComponentDefinition)] is only defined for structs, not for enums!");
335 }
336}
337
338#[allow(clippy::large_enum_variant)]
339#[derive(Debug)]
340enum ComponentField {
341 Ctx,
342 Port(PortField),
343 Other,
344}
345
346#[derive(Debug, Clone)]
347enum PortField {
348 Required(syn::Type),
349 Provided(syn::Type),
350}
351
352impl PortField {
353 fn as_handle(&self) -> TokenStream2 {
354 match *self {
355 PortField::Provided(ref ty) => quote! { Provide::<#ty>::handle(self, event); },
356 PortField::Required(ref ty) => quote! { Require::<#ty>::handle(self, event); },
357 }
358 }
359
360 fn port_type(&self) -> &syn::Type {
361 match self {
362 PortField::Provided(ref ty) => ty,
363 PortField::Required(ref ty) => ty,
364 }
365 }
366}
367
/// Crate name a multi-segment field type path must start with to be recognised.
const KOMPICS: &str = "kompact";
/// Type name marking the component-context field.
const CTX: &str = "ComponentContext";
/// Type name marking a provided-port field.
const PROVP: &str = "ProvidedPort";
/// Type name marking a required-port field.
const REQP: &str = "RequiredPort";
372
373fn identify_field(f: &syn::Field) -> ComponentField {
374 if let syn::Type::Path(ref patht) = f.ty {
375 let path = &patht.path;
376 let port_seg_opt = if path.segments.len() == 1 {
377 Some(&path.segments[0])
378 } else if path.segments.len() == 2 {
379 if path.segments[0].ident == KOMPICS {
380 Some(&path.segments[1])
381 } else {
382 None
384 }
385 } else {
386 None
388 };
389 if let Some(seg) = port_seg_opt {
390 if seg.ident == REQP {
391 ComponentField::Port(PortField::Required(extract_port_type(seg)))
392 } else if seg.ident == PROVP {
393 ComponentField::Port(PortField::Provided(extract_port_type(seg)))
394 } else if seg.ident == CTX {
395 ComponentField::Ctx
396 } else {
397 ComponentField::Other
399 }
400 } else {
401 ComponentField::Other
402 }
403 } else {
404 ComponentField::Other
405 }
406}
407
408fn extract_port_type(seg: &syn::PathSegment) -> syn::Type {
409 match seg.arguments {
410 syn::PathArguments::AngleBracketed(ref abppd) => {
411 match abppd.args.first().expect("Invalid type argument!") {
412 syn::GenericArgument::Type(ty) => ty.clone(),
413 _ => panic!("Wrong generic argument type in {:?}", seg),
414 }
415 }
416 _ => panic!("Wrong path parameter type! {:?}", seg),
417 }
418}