injoint_macros/
lib.rs

//! This crate contains the implementation of the injoint codegen macros.
2extern crate proc_macro;
3use crate::utils::snake_to_camel;
4use proc_macro::TokenStream;
5use quote::{quote, ToTokens};
6use syn::punctuated::Punctuated;
7use syn::{
8    parse_macro_input, DeriveInput, FnArg, Ident, ImplItem, ImplItemFn, ItemImpl, ItemStruct,
9    PatType, Signature, Token, Type,
10};
11
12mod utils;
13
14/// This macro derives the `Broadcastable` trait for a struct.
15#[proc_macro_derive(Broadcastable)]
16pub fn derive_broadcastable(input: TokenStream) -> TokenStream {
17    let input = parse_macro_input!(input as DeriveInput);
18
19    let struct_name = input.ident;
20
21    let expanded = quote! {
22        impl injoint::utils::types::Broadcastable for #struct_name {}
23    };
24
25    TokenStream::from(expanded)
26}
27
28/// This macro is used to annotate a struct as a reducer.
29#[proc_macro_attribute]
30pub fn reducer_struct(attr: TokenStream, item: TokenStream) -> TokenStream {
31    let input: ItemStruct = parse_macro_input!(item);
32
33    let reducer_name = input.clone().ident;
34
35    let args: Vec<Ident> =
36        parse_macro_input!(attr with Punctuated::<Ident, Token![,]>::parse_terminated)
37            .into_iter()
38            .collect();
39
40    let state_struct = args[0].clone();
41
42    let expanded = quote! {
43        impl injoint::utils::types::Broadcastable for #state_struct {}
44        impl injoint::utils::types::Broadcastable for #reducer_name {}
45        #[derive(serde::Serialize)]
46        #input
47    };
48
49    TokenStream::from(expanded)
50}
51
52/// This macro is used to annotate a struct as a reducer and generate the necessary
53/// code for the reducer actions.
54///
55/// It generates an enum for the actions, implements the `Dispatchable` trait,
56/// and provides methods for dispatching actions.
57///
58/// The macro takes the state struct as an argument and generates the
59/// necessary code for the reducer actions.
60#[proc_macro_attribute]
61pub fn reducer_actions(attr: TokenStream, item: TokenStream) -> TokenStream {
62    let input: ItemImpl = parse_macro_input!(item);
63
64    let implementation = input.clone();
65
66    let args: Vec<Ident> =
67        parse_macro_input!(attr with Punctuated::<Ident, Token![,]>::parse_terminated)
68            .into_iter()
69            .collect();
70
71    let state_struct = args[0].clone();
72
73    let reducer_name = match *input.self_ty {
74        Type::Path(ref type_path) => &type_path.path.segments.last().unwrap().ident,
75        _ => panic!("Invalid impl"),
76    };
77    let reducer_span = reducer_name.span();
78
79    let methods = input
80        .items
81        .iter()
82        .filter_map(|item| match item {
83            ImplItem::Fn(x) => Some(x.clone()),
84            _ => None,
85        })
86        .collect::<Vec<ImplItemFn>>();
87
88    fn parse_action_name(sig: &Signature) -> proc_macro2::TokenStream {
89        let span = sig.ident.clone().span();
90        let name = Ident::new(
91            &format!("Action{}", snake_to_camel(&sig.ident.to_string())),
92            span,
93        );
94        let expanded = quote! {#name};
95
96        expanded
97    }
98
99    fn parse_action_args(sig: &Signature) -> Vec<&PatType> {
100        let mut args = sig
101            .inputs
102            .iter()
103            .filter_map(|arg| match arg {
104                FnArg::Typed(item) => Some(item),
105                _ => None,
106            })
107            .collect::<Vec<_>>();
108
109        args.remove(0); // remove "client_id: u64" arg
110
111        args
112    }
113
114    fn parse_action_arg_types(sig: &Signature) -> Vec<proc_macro2::TokenStream> {
115        let span = sig.ident.clone().span();
116        parse_action_args(sig)
117            .iter()
118            .map(|item| {
119                Ident::new(&item.ty.clone().to_token_stream().to_string(), span).to_token_stream()
120            })
121            .collect::<Vec<_>>()
122    }
123
124    fn parse_action_arg_names(sig: &Signature) -> Vec<proc_macro2::TokenStream> {
125        let span = sig.ident.clone().span();
126        parse_action_args(sig)
127            .iter()
128            .map(|item| {
129                Ident::new(&item.pat.clone().to_token_stream().to_string(), span).to_token_stream()
130            })
131            .collect::<Vec<_>>()
132    }
133
134    let actions = methods
135        .clone()
136        .iter()
137        .map(|&ref method| {
138            let name = parse_action_name(&method.sig);
139            let args = parse_action_arg_types(&method.sig);
140
141            quote! {
142                #name(#(#args),*)
143            }
144        })
145        .collect::<Vec<_>>();
146
147    let action_enum_name = Ident::new(&format!("Action{}", reducer_name), reducer_span);
148
149    let action_names = methods
150        .clone()
151        .iter()
152        .map(|method| {
153            let enum_name = &action_enum_name.clone();
154            let action_name = parse_action_name(&method.sig);
155            // let action_name_str =
156            //     Ident::new(&format!("{}", action_name), action_name.span()).to_token_stream();
157            let action_name_str = &format!("{}", action_name);
158            let args = parse_action_arg_names(&method.sig);
159
160            let result = quote! {
161                #enum_name::#action_name(#(#args),*) => String::from(#action_name_str)
162            };
163
164            result
165        })
166        .collect::<Vec<_>>();
167
168    let action_handlers = methods
169        .clone()
170        .iter()
171        .map(|method| {
172            let enum_name = &action_enum_name.clone();
173            let action_name = parse_action_name(&method.sig);
174            let method_name = method.sig.ident.clone();
175            let args = parse_action_arg_names(&method.sig);
176
177            let result = quote! {
178                #enum_name::#action_name(#(#args),*) => self.#method_name(client_id, #(#args),*).await?
179            };
180
181            result
182        })
183        .collect::<Vec<_>>();
184
185    let enum_name = &action_enum_name.clone();
186
187    let expanded = quote! {
188        #implementation
189
190        #[derive(serde::Deserialize, Debug)]
191        #[serde(tag = "type", content = "data")]
192        enum #enum_name {
193            #(#actions),*
194        }
195
196        impl injoint::utils::types::Receivable for #enum_name {}
197
198        impl injoint::dispatcher::Dispatchable for #reducer_name {
199            type Action = #enum_name;
200            type State = #state_struct;
201
202            fn get_state(&self) -> #state_struct {
203                self.state.clone()
204            }
205
206            async fn dispatch(
207                &mut self,
208                client_id: u64,
209                action:  #enum_name,
210            ) -> Result<injoint::dispatcher::ActionResponse<#state_struct>, String> {
211                let name = match &action {
212                    #(#action_names),*
213                };
214
215                let msg = match action {
216                    #(#action_handlers),*
217                };
218
219                Ok(injoint::dispatcher::ActionResponse {
220                    status: name,
221                    state: self.state.clone(),
222                    author: client_id,
223                    data: msg,
224                })
225            }
226
227            async fn extern_dispatch(
228                &mut self,
229                client_id: u64,
230                action: &str,
231            ) -> Result<injoint::dispatcher::ActionResponse<#state_struct>, String> {
232                let action: #enum_name = serde_json::from_str(action).unwrap();
233                self.dispatch(client_id, action).await
234            }
235        }
236    };
237
238    TokenStream::from(expanded)
239}