Skip to main content

reflow_actor_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::punctuated::Punctuated;
4use syn::{
5    Expr, ItemFn, LitBool, LitInt, LitStr, Token, parse::Parse, parse::ParseStream,
6    parse_macro_input,
7};
8
/// How messages travel across one port connection.
///
/// Selected per port with a `name: kind` annotation inside `inports(...)` /
/// `outports(...)`; a port without an annotation is [`PortDelivery::Reliable`].
#[derive(Clone, Debug, Default, PartialEq)]
enum PortDelivery {
    /// Blocks while the channel is full, so no message is ever dropped.
    /// This is the default.
    #[default]
    Reliable,
    /// Sends with `try_send` and drops the message when the channel is full;
    /// intended for ticks and signals.
    Latest,
    /// Writes the payload into the named shared FramePool and sends only the
    /// slot index; intended for large binary data.
    Pool(String),
}
20
21#[derive(Debug, Clone)]
22struct PortDef {
23    name: String,
24    delivery: PortDelivery,
25}
26
27#[derive(Debug, Default)]
28struct PortsDefinition {
29    capacity: Option<usize>,
30    ports: Vec<String>,
31    port_defs: Vec<PortDef>,
32}
33
34/// Parse a single port entry: `name` or `name: latest` or `name: pool("pool_name")`
35fn parse_port_entry(input: ParseStream) -> syn::Result<PortDef> {
36    let name = input.parse::<syn::Ident>()?.to_string();
37
38    // Check for `: delivery` annotation
39    let delivery = if input.peek(Token![:]) && !input.peek2(Token![:]) {
40        input.parse::<Token![:]>()?;
41        let kind = input.parse::<syn::Ident>()?;
42        match kind.to_string().as_str() {
43            "latest" => PortDelivery::Latest,
44            "reliable" => PortDelivery::Reliable,
45            "pool" => {
46                // Parse pool("name")
47                let content;
48                syn::parenthesized!(content in input);
49                let pool_name = content.parse::<syn::LitStr>()?;
50                PortDelivery::Pool(pool_name.value())
51            }
52            other => {
53                return Err(syn::Error::new(
54                    kind.span(),
55                    format!(
56                        "Unknown port delivery kind '{}'. Expected 'latest', 'reliable', or 'pool(\"name\")'",
57                        other
58                    ),
59                ));
60            }
61        }
62    } else {
63        PortDelivery::Reliable
64    };
65
66    Ok(PortDef { name, delivery })
67}
68
69impl Parse for PortsDefinition {
70    fn parse(input: ParseStream) -> syn::Result<Self> {
71        // Parse the capacity in angle brackets, default to 0 if not provided
72        let mut capacity = None;
73        if input.peek(syn::token::Colon) {
74            input.parse::<syn::token::Colon>()?;
75            input.parse::<syn::token::Colon>()?;
76
77            let _lt = input.parse::<Token![<]>()?;
78            capacity = Some(input.parse::<LitInt>()?.base10_parse()?);
79            let _gt = input.parse::<Token![>]>()?;
80        }
81
82        // Parse port entries in parentheses
83        let content;
84        syn::parenthesized!(content in input);
85
86        let mut port_defs = Vec::new();
87        while !content.is_empty() {
88            port_defs.push(parse_port_entry(&content)?);
89            if !content.is_empty() {
90                content.parse::<Token![,]>()?;
91            }
92        }
93
94        let ports = port_defs.iter().map(|p| p.name.clone()).collect();
95
96        Ok(PortsDefinition {
97            capacity,
98            ports,
99            port_defs,
100        })
101    }
102}
103
104struct ActorArgs {
105    name: Option<syn::Ident>,
106    _state: Option<syn::Ident>,
107    inports: PortsDefinition,
108    outports: PortsDefinition,
109    await_all_inports: bool,
110    await_inports: Vec<String>,
111}
112
113impl Parse for ActorArgs {
114    fn parse(input: ParseStream) -> syn::Result<Self> {
115        let mut name = None;
116        let mut inports = PortsDefinition::default();
117        let mut outports = PortsDefinition::default();
118        let mut _state = None;
119        let mut await_all_inports = false;
120        let mut await_inports: Vec<String> = Vec::new();
121
122        // Parse optional struct name
123        if !input.peek(syn::token::Paren) {
124            name = Some(input.parse::<syn::Ident>()?);
125            if !input.is_empty() {
126                input.parse::<Token![,]>()?;
127            }
128        }
129
130        // Parse inports and outports
131        while !input.is_empty() {
132            let ident = input.parse::<syn::Ident>()?;
133
134            match ident.to_string().as_str() {
135                "state" => {
136                    let content;
137                    syn::parenthesized!(content in input);
138                    let state_ident = content.parse::<syn::Ident>()?;
139                    _state = Some(state_ident);
140                }
141                "inports" => {
142                    let port_def = input.parse::<PortsDefinition>()?;
143                    inports = port_def;
144                }
145                "outports" => {
146                    let port_def = input.parse::<PortsDefinition>()?;
147                    outports = port_def;
148                }
149                "await_all_inports" => {
150                    await_all_inports = true;
151                }
152                "await_inports" => {
153                    // Parse: await_inports(port1, port2, port3)
154                    let content;
155                    syn::parenthesized!(content in input);
156                    let ports = Punctuated::<syn::Ident, Token![,]>::parse_terminated(&content)?;
157                    await_inports = ports.into_iter().map(|i| i.to_string()).collect();
158                }
159                _ => {
160                    return Err(syn::Error::new(
161                        ident.span(),
162                        "Expected 'inports', 'outports', 'await_all_inports', or 'await_inports'",
163                    ));
164                }
165            }
166
167            if !input.is_empty() {
168                input.parse::<Token![,]>()?;
169            }
170        }
171
172        Ok(ActorArgs {
173            name,
174            _state,
175            inports,
176            outports,
177            await_all_inports,
178            await_inports,
179        })
180    }
181}
182
/// One `name = "data_type"` entry of an `inputs(...)` / `outputs(...)` list.
#[derive(Clone, Debug)]
struct DisplayPort {
    /// Port id; also turned into the human-readable label.
    name: String,
    /// Data type string copied into the generated port's `data_type`.
    data_type: String,
}
188
189#[derive(Debug, Default)]
190struct DisplayPortList {
191    ports: Vec<DisplayPort>,
192}
193
194impl Parse for DisplayPortList {
195    fn parse(input: ParseStream) -> syn::Result<Self> {
196        let content;
197        syn::parenthesized!(content in input);
198        let mut ports = Vec::new();
199
200        while !content.is_empty() {
201            let name = content.parse::<syn::Ident>()?.to_string();
202            content.parse::<Token![=]>()?;
203            let data_type = content.parse::<LitStr>()?.value();
204            ports.push(DisplayPort { name, data_type });
205            if !content.is_empty() {
206                content.parse::<Token![,]>()?;
207            }
208        }
209
210        Ok(Self { ports })
211    }
212}
213
214#[derive(Default)]
215struct DisplayComponentArgs {
216    element: Option<String>,
217    bundle_id: Option<String>,
218    source: Option<Expr>,
219    shadow: Option<bool>,
220    observed_props: Vec<String>,
221    width: Option<String>,
222}
223
224impl Parse for DisplayComponentArgs {
225    fn parse(input: ParseStream) -> syn::Result<Self> {
226        let content;
227        syn::parenthesized!(content in input);
228        let mut display = Self::default();
229
230        while !content.is_empty() {
231            let key = content.parse::<syn::Ident>()?;
232            match key.to_string().as_str() {
233                "element" => {
234                    content.parse::<Token![=]>()?;
235                    display.element = Some(content.parse::<LitStr>()?.value());
236                }
237                "bundle_id" => {
238                    content.parse::<Token![=]>()?;
239                    display.bundle_id = Some(content.parse::<LitStr>()?.value());
240                }
241                "source" => {
242                    content.parse::<Token![=]>()?;
243                    display.source = Some(content.parse::<Expr>()?);
244                }
245                "shadow" => {
246                    content.parse::<Token![=]>()?;
247                    display.shadow = Some(content.parse::<LitBool>()?.value);
248                }
249                "observed_props" => {
250                    let props;
251                    syn::parenthesized!(props in content);
252                    let parsed = Punctuated::<LitStr, Token![,]>::parse_terminated(&props)?;
253                    display.observed_props = parsed.into_iter().map(|prop| prop.value()).collect();
254                }
255                "width" => {
256                    content.parse::<Token![=]>()?;
257                    display.width = Some(content.parse::<LitStr>()?.value());
258                }
259                other => {
260                    return Err(syn::Error::new(
261                        key.span(),
262                        format!(
263                            "Unknown display key '{}'. Expected element, bundle_id, source, shadow, observed_props, or width",
264                            other
265                        ),
266                    ));
267                }
268            }
269
270            if !content.is_empty() {
271                content.parse::<Token![,]>()?;
272            }
273        }
274
275        Ok(display)
276    }
277}
278
279struct ActorDisplayArgs {
280    actor: Option<syn::Ident>,
281    id: String,
282    title: String,
283    subtitle: Option<String>,
284    category: String,
285    subcategory: Option<String>,
286    description: String,
287    icon: String,
288    variant: Option<String>,
289    inputs: DisplayPortList,
290    outputs: DisplayPortList,
291    display: Option<DisplayComponentArgs>,
292}
293
294impl Default for ActorDisplayArgs {
295    fn default() -> Self {
296        Self {
297            actor: None,
298            id: String::new(),
299            title: String::new(),
300            subtitle: None,
301            category: "reflow".to_string(),
302            subcategory: None,
303            description: String::new(),
304            icon: "cpu".to_string(),
305            variant: None,
306            inputs: DisplayPortList::default(),
307            outputs: DisplayPortList::default(),
308            display: None,
309        }
310    }
311}
312
313impl Parse for ActorDisplayArgs {
314    fn parse(input: ParseStream) -> syn::Result<Self> {
315        let mut args = Self::default();
316
317        while !input.is_empty() {
318            let key = input.parse::<syn::Ident>()?;
319            match key.to_string().as_str() {
320                "actor" => {
321                    input.parse::<Token![=]>()?;
322                    args.actor = Some(input.parse::<syn::Ident>()?);
323                }
324                "id" | "template_id" => {
325                    input.parse::<Token![=]>()?;
326                    args.id = input.parse::<LitStr>()?.value();
327                }
328                "title" => {
329                    input.parse::<Token![=]>()?;
330                    args.title = input.parse::<LitStr>()?.value();
331                }
332                "subtitle" => {
333                    input.parse::<Token![=]>()?;
334                    args.subtitle = Some(input.parse::<LitStr>()?.value());
335                }
336                "category" => {
337                    input.parse::<Token![=]>()?;
338                    args.category = input.parse::<LitStr>()?.value();
339                }
340                "subcategory" => {
341                    input.parse::<Token![=]>()?;
342                    args.subcategory = Some(input.parse::<LitStr>()?.value());
343                }
344                "description" => {
345                    input.parse::<Token![=]>()?;
346                    args.description = input.parse::<LitStr>()?.value();
347                }
348                "icon" => {
349                    input.parse::<Token![=]>()?;
350                    args.icon = input.parse::<LitStr>()?.value();
351                }
352                "variant" => {
353                    input.parse::<Token![=]>()?;
354                    args.variant = Some(input.parse::<LitStr>()?.value());
355                }
356                "inputs" => {
357                    args.inputs = input.parse::<DisplayPortList>()?;
358                }
359                "outputs" => {
360                    args.outputs = input.parse::<DisplayPortList>()?;
361                }
362                "display" => {
363                    args.display = Some(input.parse::<DisplayComponentArgs>()?);
364                }
365                other => {
366                    return Err(syn::Error::new(
367                        key.span(),
368                        format!(
369                            "Unknown actor_display key '{}'. Expected actor, id, title, subtitle, category, subcategory, description, icon, variant, inputs, outputs, or display",
370                            other
371                        ),
372                    ));
373                }
374            }
375
376            if !input.is_empty() {
377                input.parse::<Token![,]>()?;
378            }
379        }
380
381        if args.id.is_empty() {
382            return Err(input.error("actor_display requires id = \"tpl_...\""));
383        }
384        if args.title.is_empty() {
385            return Err(input.error("actor_display requires title = \"...\""));
386        }
387        if args.description.is_empty() {
388            return Err(input.error("actor_display requires description = \"...\""));
389        }
390
391        Ok(args)
392    }
393}
394
395#[proc_macro_attribute]
396pub fn actor(attr: TokenStream, item: TokenStream) -> TokenStream {
397    let args = parse_macro_input!(attr as ActorArgs);
398    let input_fn = parse_macro_input!(item as ItemFn);
399    let fn_name = &input_fn.sig.ident;
400    let fn_vis = &input_fn.vis;
401
402    // Create struct name from either provided name or function name
403    let struct_name = match args.name {
404        Some(name) => name,
405        None => format_ident!(
406            "{}Actor",
407            fn_name
408                .to_string()
409                .chars()
410                .next()
411                .unwrap()
412                .to_uppercase()
413                .to_string()
414                + &fn_name.to_string()[1..]
415        ),
416    };
417    // Generate port initialization code
418    let init_inports = args.inports.ports.iter().map(|port| {
419        let name = port;
420        quote! {
421            String::from(#name)
422        }
423    });
424
425    let init_outports = args.outports.ports.iter().map(|port| {
426        let name = port;
427        quote! {
428            String::from(#name)
429        }
430    });
431
432    let out_ports_cap = args.outports.capacity;
433    let _in_ports_cap = args.inports.capacity;
434    let await_all_inports = args.await_all_inports;
435    let await_inports_list = &args.await_inports;
436    let _has_selective_await = !await_inports_list.is_empty();
437
438    let out_ports_channel = if let Some(out_ports_cap) = out_ports_cap {
439        if out_ports_cap < 1 {
440            panic!("Outports capacity must be greater than 0");
441        }
442        quote! {flume::bounded(#out_ports_cap)}
443    } else {
444        quote! {flume::unbounded()}
445    };
446    // Actor inport channel is always unbounded — per-connector forwarder
447    // channels handle backpressure via bounded(64) + delivery semantics.
448    // The inport is just a merge point for all connectors, not a throttle.
449    let in_ports_channel = quote! {flume::unbounded()};
450
451    // Re-generate port name iterators for trait methods
452    let inport_names_iter = args.inports.ports.iter().map(|port| {
453        quote! { String::from(#port) }
454    });
455    let outport_names_iter = args.outports.ports.iter().map(|port| {
456        quote! { String::from(#port) }
457    });
458
459    // Generate port delivery metadata entries
460    let all_port_defs: Vec<&PortDef> = args
461        .inports
462        .port_defs
463        .iter()
464        .chain(args.outports.port_defs.iter())
465        .collect();
466    let port_delivery_entries = all_port_defs.iter().filter_map(|pd| {
467        match &pd.delivery {
468            PortDelivery::Reliable => None, // default, no entry needed
469            PortDelivery::Latest => {
470                let name = &pd.name;
471                Some(quote! { m.insert(#name.to_string(), "latest".to_string()); })
472            }
473            PortDelivery::Pool(pool_name) => {
474                let name = &pd.name;
475                let pool = pool_name.as_str();
476                Some(quote! { m.insert(#name.to_string(), format!("pool:{}", #pool)); })
477            }
478        }
479    });
480
481    let expanded = quote! {
482
483        // Keep the original function
484        #input_fn
485
486        #fn_vis struct #struct_name {
487            inports_channel: Port,
488            outports_channel: Port,
489        }
490
491        impl #struct_name {
492            pub fn new() -> Self {
493                Self {
494                    // NOTE: channels are intentionally cross-assigned —
495                    // inports get the outport-declared capacity and vice
496                    // versa.  Fixing this swap requires updating every
497                    // actor's capacity declarations first (many use
498                    // outports::<1> which would deadlock with bounded(1)).
499                    inports_channel: #out_ports_channel,
500                    outports_channel: #in_ports_channel,
501                }
502            }
503
504            /// Get a list of available input ports
505            pub fn input_ports(&self) -> Vec<String> {
506                vec![#(#init_inports),*]
507            }
508
509            /// Get a list of available output ports
510            pub fn output_ports(&self) -> Vec<String> {
511                vec![#(#init_outports),*]
512            }
513        }
514
515        impl Clone for #struct_name {
516            fn clone(&self) -> Self {
517                Self {
518                    inports_channel: self.inports_channel.clone(),
519                    outports_channel: self.outports_channel.clone(),
520                }
521            }
522        }
523
524        impl Actor for #struct_name {
525
526            fn get_behavior(&self) -> ActorBehavior {
527                Box::new(|context: ActorContext| {
528                    Box::pin(async move {
529                        #fn_name(context).await
530                    })
531                })
532            }
533
534            fn get_outports(&self) -> Port {
535                self.outports_channel.clone()
536            }
537
538            fn get_inports(&self) -> Port {
539                self.inports_channel.clone()
540            }
541
542            fn inport_names(&self) -> Vec<String> {
543                vec![#(#inport_names_iter),*]
544            }
545
546            fn outport_names(&self) -> Vec<String> {
547                vec![#(#outport_names_iter),*]
548            }
549
550            fn await_all_inports(&self) -> bool {
551                #await_all_inports
552            }
553
554            fn required_inports(&self) -> Vec<String> {
555                vec![#(String::from(#await_inports_list)),*]
556            }
557
558            fn port_delivery(&self) -> std::collections::HashMap<String, String> {
559                let mut m = std::collections::HashMap::new();
560                #(#port_delivery_entries)*
561                m
562            }
563
564            fn create_instance(&self) -> std::sync::Arc<dyn Actor> {
565                std::sync::Arc::new(Self::new())
566            }
567
568            // create_process() and create_state() use the trait defaults
569            // via ActorProcess. Override only for non-MemoryState state types.
570        }
571    };
572
573    TokenStream::from(expanded)
574}
575
576#[proc_macro_attribute]
577pub fn actor_display(attr: TokenStream, item: TokenStream) -> TokenStream {
578    let args = parse_macro_input!(attr as ActorDisplayArgs);
579    let input_fn = parse_macro_input!(item as ItemFn);
580    let fn_name = &input_fn.sig.ident;
581    let fn_vis = &input_fn.vis;
582    let template_fn = format_ident!("{}_template", fn_name);
583
584    let actor_struct = args.actor.unwrap_or_else(|| {
585        format_ident!(
586            "{}Actor",
587            fn_name
588                .to_string()
589                .chars()
590                .next()
591                .unwrap()
592                .to_uppercase()
593                .to_string()
594                + &fn_name.to_string()[1..]
595        )
596    });
597
598    let id = args.id;
599    let title = args.title;
600    let subtitle = args.subtitle;
601    let category = args.category;
602    let subcategory = args.subcategory;
603    let description = args.description;
604    let icon = args.icon;
605    let variant = args.variant;
606
607    let subtitle_tokens = option_string_tokens(subtitle);
608    let subcategory_tokens = option_string_tokens(subcategory);
609    let variant_tokens = option_string_tokens(variant);
610
611    let input_ports = args.inputs.ports.iter().map(|port| {
612        let name = &port.name;
613        let label = label_from_port_name(name);
614        let data_type = &port.data_type;
615        quote! {
616            ::reflow_network::template::Port {
617                id: #name.to_string(),
618                label: #label.to_string(),
619                port_type: ::reflow_network::template::PortType::Input,
620                position: ::reflow_network::template::PortPosition::Left,
621                data_type: Some(#data_type.to_string()),
622                required: None,
623                multiple: None,
624            }
625        }
626    });
627
628    let output_ports = args.outputs.ports.iter().map(|port| {
629        let name = &port.name;
630        let label = label_from_port_name(name);
631        let data_type = &port.data_type;
632        quote! {
633            ::reflow_network::template::Port {
634                id: #name.to_string(),
635                label: #label.to_string(),
636                port_type: ::reflow_network::template::PortType::Output,
637                position: ::reflow_network::template::PortPosition::Right,
638                data_type: Some(#data_type.to_string()),
639                required: None,
640                multiple: None,
641            }
642        }
643    });
644
645    let ports = input_ports.chain(output_ports).collect::<Vec<_>>();
646
647    let display = match args.display {
648        Some(display) => {
649            let element = display.element.unwrap_or_default();
650            let bundle_id_tokens = option_string_tokens(display.bundle_id);
651            let source = display.source;
652            let shadow_tokens = match display.shadow {
653                Some(value) => quote! { Some(#value) },
654                None => quote! { None },
655            };
656            let observed_props = display.observed_props;
657            let width_tokens = option_string_tokens(display.width);
658
659            let source_tokens = match source {
660                Some(source) => quote! { Some((#source).to_string()) },
661                None => quote! { None },
662            };
663
664            quote! {
665                Some(::reflow_network::template::DisplayComponent {
666                    element: #element.to_string(),
667                    bundle_id: #bundle_id_tokens,
668                    source: #source_tokens,
669                    shadow: #shadow_tokens,
670                    observed_props: Some(vec![#(#observed_props.to_string()),*]),
671                    width: #width_tokens,
672                })
673            }
674        }
675        None => quote! { None },
676    };
677
678    let expanded = quote! {
679        #input_fn
680
681        #fn_vis fn #template_fn(
682            version: &Option<String>,
683            capabilities: &Option<Vec<String>>,
684        ) -> ::reflow_network::template::NodeTemplate {
685            ::reflow_network::template::NodeTemplate {
686                id: #id.to_string(),
687                type_name: #id.to_string(),
688                title: #title.to_string(),
689                subtitle: #subtitle_tokens,
690                category: #category.to_string(),
691                subcategory: #subcategory_tokens,
692                description: #description.to_string(),
693                icon: #icon.to_string(),
694                variant: #variant_tokens,
695                shape: Some(::reflow_network::template::NodeShape::Rectangle),
696                size: Some(::reflow_network::template::NodeSize::Medium),
697                ports: vec![#(#ports),*],
698                properties: None,
699                property_rules: None,
700                runtime: Some(::reflow_network::template::RuntimeRequirements {
701                    executor: "reflow".to_string(),
702                    version: version.clone(),
703                    required_env_vars: None,
704                    capabilities: capabilities.clone(),
705                }),
706                display: #display,
707            }
708        }
709
710        impl #actor_struct {
711            pub fn actor_template(
712                version: &Option<String>,
713                capabilities: &Option<Vec<String>>,
714            ) -> ::reflow_network::template::NodeTemplate {
715                #template_fn(version, capabilities)
716            }
717        }
718    };
719
720    TokenStream::from(expanded)
721}
722
/// Turns a snake_case port name into a space-separated, capitalised label.
///
/// Example: `frame_count` becomes `Frame Count`. Each `_`-separated segment
/// has only its first character upper-cased, and the rest is kept as-is.
/// Empty segments (from leading, trailing or doubled underscores) contribute
/// empty words, so their separating spaces remain.
fn label_from_port_name(name: &str) -> String {
    let mut label = String::with_capacity(name.len());
    for (index, segment) in name.split('_').enumerate() {
        if index > 0 {
            label.push(' ');
        }
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            label.extend(head.to_uppercase());
            label.push_str(rest.as_str());
        }
    }
    label
}
735
736fn option_string_tokens(value: Option<String>) -> proc_macro2::TokenStream {
737    match value {
738        Some(value) => quote! { Some(#value.to_string()) },
739        None => quote! { None },
740    }
741}