// orlando_macros/lib.rs

use std::collections::HashSet;

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Ident, ItemFn, ItemStruct, LitInt, LitStr, Path, Type};

// ── #[grain(state = T)] ────────────────────────────────────────

7/// Define a grain type with its associated state.
8///
9/// ```ignore
10/// #[grain(state = CounterState)]
11/// struct Counter;
12///
13/// #[grain(state = CounterState, idle_timeout_secs = 60)]
14/// struct Counter;
15/// ```
16#[proc_macro_attribute]
17pub fn grain(attr: TokenStream, item: TokenStream) -> TokenStream {
18    let args = parse_macro_input!(attr as GrainArgs);
19    let item_struct = parse_macro_input!(item as ItemStruct);
20
21    let name = &item_struct.ident;
22    let state_type = &args.state;
23
24    if args.stateless_worker && args.reentrant {
25        return syn::Error::new(
26            item_struct.ident.span(),
27            "`stateless_worker` and `reentrant` are mutually exclusive — a stateless worker pool already provides concurrency",
28        )
29        .to_compile_error()
30        .into();
31    }
32
33    let idle_timeout = args.idle_timeout_secs.map(|secs| {
34        quote! {
35            fn idle_timeout() -> ::std::time::Duration {
36                ::std::time::Duration::from_secs(#secs)
37            }
38        }
39    });
40
41    let ask_timeout = args.ask_timeout_secs.map(|secs| {
42        quote! {
43            fn ask_timeout() -> ::std::time::Duration {
44                ::std::time::Duration::from_secs(#secs)
45            }
46        }
47    });
48
49    let grain_type_name = args.grain_name.map(|n| {
50        quote! {
51            fn grain_type_name() -> &'static str { #n }
52        }
53    });
54
55    let reentrant = if args.reentrant {
56        quote! {
57            fn reentrant() -> bool { true }
58        }
59    } else {
60        quote! {}
61    };
62
63    let placement_hint = args.placement.as_ref().map(|p| {
64        quote! {
65            fn placement_hint() -> Option<&'static str> { Some(#p) }
66        }
67    });
68
69    let storage_provider = args.storage.as_ref().map(|s| {
70        quote! {
71            fn storage_provider() -> &'static str { #s }
72        }
73    });
74
75    let stateless_worker_impl = if args.stateless_worker {
76        let max_act = args.max_activations.map(|n| {
77            quote! {
78                fn max_activations() -> usize { #n }
79            }
80        });
81        quote! {
82            impl ::orlando_core::StatelessWorker for #name {
83                #max_act
84            }
85        }
86    } else {
87        quote! {}
88    };
89
90    quote! {
91        #item_struct
92
93        #[::async_trait::async_trait]
94        impl ::orlando_core::Grain for #name {
95            type State = #state_type;
96            #idle_timeout
97            #ask_timeout
98            #grain_type_name
99            #reentrant
100            #placement_hint
101            #storage_provider
102        }
103
104        #stateless_worker_impl
105    }
106    .into()
107}
108
109struct GrainArgs {
110    state: Type,
111    idle_timeout_secs: Option<u64>,
112    stateless_worker: bool,
113    max_activations: Option<usize>,
114    reentrant: bool,
115    grain_name: Option<String>,
116    ask_timeout_secs: Option<u64>,
117    placement: Option<String>,
118    storage: Option<String>,
119}
120
121impl syn::parse::Parse for GrainArgs {
122    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
123        let mut state = None;
124        let mut idle_timeout_secs = None;
125        let mut stateless_worker = false;
126        let mut max_activations = None;
127        let mut reentrant = false;
128        let mut grain_name = None;
129        let mut ask_timeout_secs = None;
130        let mut placement = None;
131        let mut storage = None;
132
133        while !input.is_empty() {
134            let key: Ident = input.parse()?;
135
136            match key.to_string().as_str() {
137                "state" => {
138                    input.parse::<syn::Token![=]>()?;
139                    state = Some(input.parse::<Type>()?);
140                }
141                "idle_timeout_secs" => {
142                    input.parse::<syn::Token![=]>()?;
143                    let lit: LitInt = input.parse()?;
144                    idle_timeout_secs = Some(lit.base10_parse::<u64>()?);
145                }
146                "stateless_worker" => {
147                    stateless_worker = true;
148                }
149                "max_activations" => {
150                    input.parse::<syn::Token![=]>()?;
151                    let lit: LitInt = input.parse()?;
152                    max_activations = Some(lit.base10_parse::<usize>()?);
153                }
154                "reentrant" => {
155                    reentrant = true;
156                }
157                "name" => {
158                    input.parse::<syn::Token![=]>()?;
159                    let lit: LitStr = input.parse()?;
160                    grain_name = Some(lit.value());
161                }
162                "ask_timeout_secs" => {
163                    input.parse::<syn::Token![=]>()?;
164                    let lit: LitInt = input.parse()?;
165                    ask_timeout_secs = Some(lit.base10_parse::<u64>()?);
166                }
167                "placement" => {
168                    input.parse::<syn::Token![=]>()?;
169                    let lit: LitStr = input.parse()?;
170                    placement = Some(lit.value());
171                }
172                "storage" => {
173                    input.parse::<syn::Token![=]>()?;
174                    let lit: LitStr = input.parse()?;
175                    storage = Some(lit.value());
176                }
177                _ => {
178                    return Err(syn::Error::new(
179                        key.span(),
180                        format!("unknown attribute `{}`", key),
181                    ));
182                }
183            }
184
185            if !input.is_empty() {
186                input.parse::<syn::Token![,]>()?;
187            }
188        }
189
190        let state = state.ok_or_else(|| input.error("missing required `state` attribute"))?;
191        Ok(GrainArgs {
192            state,
193            idle_timeout_secs,
194            stateless_worker,
195            max_activations,
196            reentrant,
197            grain_name,
198            ask_timeout_secs,
199            placement,
200            storage,
201        })
202    }
203}
204
// ── #[message(result = T)] ─────────────────────────────────────

207/// Define a message type with its result.
208///
209/// ```ignore
210/// #[message(result = i64)]
211/// struct GetCount;
212///
213/// // For cluster-capable messages (bincode only):
214/// #[message(result = i64, network)]
215/// #[derive(Serialize, Deserialize)]
216/// struct GetCount;
217///
218/// // For cluster + external client support (bincode + protobuf):
219/// #[message(result = CounterResult, network, proto)]
220/// #[derive(Serialize, Deserialize, prost::Message)]
221/// struct GetCount { ... }
222///
223/// // For versioned messages (rolling deploy safety):
224/// #[message(result = i64, network, version = 2)]
225/// #[derive(Serialize, Deserialize)]
226/// struct GetCountV2;
227/// ```
228#[proc_macro_attribute]
229pub fn message(attr: TokenStream, item: TokenStream) -> TokenStream {
230    let args = parse_macro_input!(attr as MessageArgs);
231    let item_struct = parse_macro_input!(item as ItemStruct);
232
233    let name = &item_struct.ident;
234    let result_type = &args.result;
235
236    if args.proto && !args.network {
237        return syn::Error::new(
238            item_struct.ident.span(),
239            "`proto` requires `network` — protobuf encoding is only available for network-capable messages",
240        )
241        .to_compile_error()
242        .into();
243    }
244
245    if args.version.is_some() && !args.network {
246        return syn::Error::new(
247            item_struct.ident.span(),
248            "`version` requires `network` — versioning is only meaningful for network-capable messages",
249        )
250        .to_compile_error()
251        .into();
252    }
253
254    let network_impl = if args.network {
255        let name_str = name.to_string();
256
257        let version_method = args.version.map(|v| {
258            quote! {
259                fn message_version() -> u32 { #v }
260            }
261        });
262
263        let proto_methods = if args.proto {
264            quote! {
265                fn supports_proto() -> bool { true }
266
267                fn encode_proto(&self) -> Option<Vec<u8>> {
268                    use ::prost::Message;
269                    Some(::prost::Message::encode_to_vec(self))
270                }
271
272                fn decode_proto(bytes: &[u8]) -> Option<Self> {
273                    use ::prost::Message;
274                    <Self as ::prost::Message>::decode(bytes).ok()
275                }
276
277                fn encode_result_proto(result: &<Self as ::orlando_core::Message>::Result) -> Option<Vec<u8>> {
278                    use ::prost::Message;
279                    Some(::prost::Message::encode_to_vec(result))
280                }
281
282                fn decode_result_proto(bytes: &[u8]) -> Option<<Self as ::orlando_core::Message>::Result> {
283                    use ::prost::Message;
284                    <Self as ::orlando_core::Message>::Result::decode(bytes).ok()
285                }
286            }
287        } else {
288            quote! {}
289        };
290
291        quote! {
292            impl ::orlando_cluster::NetworkMessage for #name {
293                fn message_type_name() -> &'static str {
294                    #name_str
295                }
296                #version_method
297                #proto_methods
298            }
299        }
300    } else {
301        quote! {}
302    };
303
304    quote! {
305        #item_struct
306
307        impl ::orlando_core::Message for #name {
308            type Result = #result_type;
309        }
310
311        #network_impl
312    }
313    .into()
314}
315
316struct MessageArgs {
317    result: Type,
318    network: bool,
319    proto: bool,
320    version: Option<u32>,
321}
322
323impl syn::parse::Parse for MessageArgs {
324    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
325        let mut result = None;
326        let mut network = false;
327        let mut proto = false;
328        let mut version = None;
329
330        while !input.is_empty() {
331            let key: Ident = input.parse()?;
332
333            match key.to_string().as_str() {
334                "result" => {
335                    input.parse::<syn::Token![=]>()?;
336                    result = Some(input.parse::<Type>()?);
337                }
338                "network" => {
339                    network = true;
340                }
341                "proto" => {
342                    proto = true;
343                }
344                "version" => {
345                    input.parse::<syn::Token![=]>()?;
346                    let lit: LitInt = input.parse()?;
347                    version = Some(lit.base10_parse::<u32>()?);
348                }
349                _ => {
350                    return Err(syn::Error::new(
351                        key.span(),
352                        format!("unknown attribute `{}`", key),
353                    ));
354                }
355            }
356
357            if !input.is_empty() {
358                input.parse::<syn::Token![,]>()?;
359            }
360        }
361
362        let result = result.ok_or_else(|| input.error("missing required `result` attribute"))?;
363        Ok(MessageArgs {
364            result,
365            network,
366            proto,
367            version,
368        })
369    }
370}
371
// ── #[grain_handler(GrainType)] ────────────────────────────────

374/// Define a handler for a grain + message combination.
375///
376/// ```ignore
377/// #[grain_handler(Counter)]
378/// async fn handle(state: &mut CounterState, msg: Increment, _ctx: &GrainContext) -> i64 {
379///     state.count += msg.amount;
380///     state.count
381/// }
382/// ```
383#[proc_macro_attribute]
384pub fn grain_handler(attr: TokenStream, item: TokenStream) -> TokenStream {
385    let grain_path = parse_macro_input!(attr as Path);
386    let func = parse_macro_input!(item as ItemFn);
387
388    let inputs = &func.sig.inputs;
389    let output = &func.sig.output;
390    let body = &func.block;
391    let attrs = &func.attrs;
392
393    let msg_type = match extract_msg_type(inputs) {
394        Some(ty) => ty,
395        None => {
396            return syn::Error::new_spanned(
397                &func.sig,
398                "grain_handler requires at least 2 parameters: (state: &mut State, msg: M, ...)",
399            )
400            .to_compile_error()
401            .into();
402        }
403    };
404
405    quote! {
406        #[::async_trait::async_trait]
407        impl ::orlando_core::GrainHandler<#msg_type> for #grain_path {
408            #(#attrs)*
409            async fn handle(#inputs) #output
410                #body
411        }
412    }
413    .into()
414}
415
416fn extract_msg_type(
417    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
418) -> Option<Type> {
419    let second = inputs.iter().nth(1)?;
420    match second {
421        syn::FnArg::Typed(pat_type) => Some((*pat_type.ty).clone()),
422        _ => None,
423    }
424}