bearings_proc/
lib.rs

1extern crate proc_macro;
2extern crate proc_macro2;
3
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8    parse_macro_input, Fields, FnArg, Ident, ImplItem, ItemImpl, ItemStruct, ItemTrait, Pat,
9    ReturnType, TraitItem, Type,
10};
11
12#[proc_macro_attribute]
13pub fn interface(attr: TokenStream, defn: TokenStream) -> TokenStream {
14    let mut input = parse_macro_input!(defn as ItemTrait);
15    let user_error = parse_macro_input!(attr as Type);
16
17    let mut client_functions = quote! {};
18    let mut dispatcher_cases = quote! {};
19
20    for mut item in &mut input.items {
21        match &mut item {
22            TraitItem::Method(ref mut method) => {
23                let signature = &mut method.sig;
24                let original_result = &signature.output;
25                match original_result {
26                    ReturnType::Type(_, original_result) => {
27                        let replacement = TokenStream::from(
28                            quote! { -> ::bearings::Result<#original_result, #user_error> },
29                        );
30                        signature.output = parse_macro_input!(replacement as ReturnType);
31                    }
32                    _ => {
33                        panic!("can't handle a function with default return type in a class");
34                    }
35                }
36
37                let mut arguments = quote! {};
38                let mut argument_tuple = quote! {};
39                let mut argument_expansion = quote! {};
40
41                let mut i: u32 = 0;
42
43                for arg in &signature.inputs {
44                    match arg {
45                        FnArg::Receiver(_) => (),
46                        FnArg::Typed(arg) => match &*arg.pat {
47                            Pat::Ident(name) => {
48                                arguments.extend(quote! {
49                                    &#name,
50                                });
51
52                                let ty = &arg.ty;
53                                argument_tuple.extend(quote! {
54                                    #ty,
55                                });
56
57                                let idx = syn::Index::from(i as usize);
58                                argument_expansion.extend(quote! {
59                                    arguments.#idx,
60                                });
61                                i += 1;
62                            }
63                            _ => {
64                                panic!("only an identifier is allowed as a name of a class method argument");
65                            }
66                        },
67                    }
68                }
69
70                let name = &signature.ident;
71                let return_type = match &signature.output {
72                    ReturnType::Default => {
73                        panic!("can't handle a method with a default return type");
74                    }
75                    ReturnType::Type(_, ty) => ty,
76                };
77
78                client_functions.extend(quote! {
79                    #signature {
80                        let id = {
81                            let mut id_guard = self.state.id.lock().await;
82                            let id = *id_guard;
83                            *id_guard += 1;
84                            id
85                        };
86
87                        let call = ::serde_json::to_string(&::bearings::Message::<_, #user_error>::Call(::bearings::FunctionCall{
88                            id: id,
89                            uuid: self.uuid.clone(),
90                            member: self.member.to_string(),
91                            method: stringify!(#name).to_string(),
92                            arguments: (#arguments),
93                        }))?;
94
95                        {
96                            let mut map = self.state.awaiters.lock().await;
97                            map.insert(id, ::tokio::sync::Mutex::from(::bearings::Awaiter::Empty));
98                        }
99
100                        let mut w = self.state.w.lock().await;
101                        use ::tokio::io::AsyncWriteExt;
102                        w.write_all(format!("{}\0", call).as_bytes()).await?;
103                        w.flush().await?;
104
105                        ::bearings::ReplyFuture::<
106                            <#return_type as ::std::iter::IntoIterator>::Item,
107                            T,
108                            #user_error
109                        >::new(self.state.clone(), id).await
110                    }
111                });
112
113                dispatcher_cases.extend(quote! {
114                    stringify!(#name) => {
115                        let arguments: (#argument_tuple) = ::serde_json::from_value(call.arguments)?;
116                        Ok(::bearings::Message::<(), #user_error>::Return(
117                            ::bearings::ReturnValue{
118                                id: call.id,
119                                result: ::serde_json::value::Value::from({
120                                    let mut result = object.lock().await;
121                                    let result = result.#name(#argument_expansion);
122                                    result.await?
123                                })
124                            }
125                        ))
126                    }
127                });
128            }
129            _ => {
130                panic!("only methods are allowed inside a class trait");
131            }
132        }
133    }
134
135    let name = &input.ident;
136    let error_name = syn::Ident::new(&format!("{}Error", name), Span::call_site());
137    let client_name = syn::Ident::new(&format!("{}Client", name), Span::call_site());
138    let dispatcher_name = syn::Ident::new(&format!("{}Dispatcher", name), Span::call_site());
139
140    let expanded = quote! {
141        #[::bearings::async_trait]
142        #input
143
144        struct #dispatcher_name {
145        }
146
147        impl #dispatcher_name {
148            async fn invoke_method<'a>(
149                object: &::tokio::sync::Mutex<Box<dyn #name + Send + 'a>>,
150                call: ::bearings::FunctionCall<serde_json::value::Value>,
151            ) -> ::bearings::Result<::bearings::Message<(), #user_error>, #user_error> {
152                match &call.method[..] {
153                    #dispatcher_cases
154                    _ => Err(::bearings::Error::UnknownMethod(call.member, call.method))
155                }
156            }
157        }
158
159        struct #client_name<T: Send + ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite> {
160            uuid: ::uuid::Uuid,
161            member: &'static str,
162            state: ::bearings::StatePtr<T, #user_error>,
163        }
164
165        #[::bearings::async_trait]
166        impl<T: Send + Unpin + ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite> #name for #client_name<T> {
167            #client_functions
168        }
169
170        type #error_name = #user_error;
171    };
172
173    TokenStream::from(expanded)
174}
175
176#[proc_macro_attribute]
177pub fn class(attr: TokenStream, defn: TokenStream) -> TokenStream {
178    let mut input = parse_macro_input!(defn as ItemImpl);
179    let user_error = parse_macro_input!(attr as Type);
180
181    for item in &mut input.items {
182        match item {
183            ImplItem::Method(ref mut method) => {
184                let signature = &mut method.sig;
185                let original_result = &signature.output;
186                match original_result {
187                    ReturnType::Type(_, original_result) => {
188                        let replacement = TokenStream::from(
189                            quote! { -> ::bearings::Result<#original_result, #user_error> },
190                        );
191                        signature.output = parse_macro_input!(replacement as ReturnType);
192                    }
193                    _ => {
194                        panic!("can't handle a function with default return type in a class");
195                    }
196                }
197            }
198
199            _ => panic!("unsupported class element"),
200        }
201    }
202
203    let expanded = quote! {
204        #[::bearings::async_trait]
205        #input
206    };
207
208    TokenStream::from(expanded)
209}
210
211#[proc_macro_attribute]
212pub fn object(attr: TokenStream, defn: TokenStream) -> TokenStream {
213    let input = parse_macro_input!(defn as ItemStruct);
214    let user_error = parse_macro_input!(attr as Type);
215
216    let mut fields = quote!();
217    let mut parameters = quote!();
218    let mut arguments = quote!();
219    let mut init = quote!();
220    let mut client_init = quote!();
221    let mut member_dispatch = quote!();
222
223    let mut i: u32 = 0;
224
225    match input.fields {
226        Fields::Named(ref named) => {
227            for field in named.named.iter() {
228                let name = field.ident.as_ref().unwrap();
229                let ty = &field.ty;
230
231                fields.extend(quote! {
232                    #name: ::tokio::sync::Mutex<Box<dyn #ty + Send + 'a>>,
233                });
234
235                let param_type = syn::Ident::new(&format!("T{}", i), Span::call_site());
236                i += 1;
237
238                parameters.extend(quote! {
239                    #param_type: #ty + Send + 'a,
240                });
241
242                arguments.extend(quote! {
243                    #name: #param_type,
244                });
245
246                init.extend(quote! {
247                    #name: ::tokio::sync::Mutex::from(Box::from(#name) as Box<dyn #ty + Send + 'a>),
248                });
249
250                let (client_type, dispatcher_type) = match ty {
251                    Type::Path(path) => {
252                        let mut client = path.clone();
253                        let mut dispatcher = path.clone();
254
255                        let mut last = client.path.segments.pop().unwrap().into_value();
256                        last.ident =
257                            Ident::new(&format!("{}Client", last.ident), last.ident.span());
258                        client.path.segments.push_value(last);
259
260                        let mut last = dispatcher.path.segments.pop().unwrap().into_value();
261                        last.ident =
262                            Ident::new(&format!("{}Dispatcher", last.ident), last.ident.span());
263                        dispatcher.path.segments.push_value(last);
264
265                        (client, dispatcher)
266                    }
267                    _ => {
268                        panic!("the type of a field of an object structure must be a previously defined class");
269                    }
270                };
271
272                client_init.extend(quote! {
273                    #name: ::tokio::sync::Mutex::from(Box::from(#client_type {
274                        uuid: uuid.clone(),
275                        member: stringify!(#name),
276                        state: state.clone()
277                    }) as Box<dyn #ty + Send + 'a>),
278                });
279
280                member_dispatch.extend(quote! {
281                    stringify!(#name) => #dispatcher_type::invoke_method(&self.#name, call).await,
282                });
283            }
284        }
285
286        _ => unimplemented!(),
287    }
288
289    let name = &input.ident;
290    let expanded = quote! {
291        struct #name<'a> {
292            __: std::marker::PhantomData<&'a ()>,
293
294            #fields
295        }
296
297        impl<'a> #name<'a> {
298            pub fn new<#parameters>(#arguments) -> Self {
299                Self{
300                    __: <_>::default(),
301                    #init
302                }
303            }
304
305            fn uuid() -> ::uuid::Uuid {
306                ::uuid::Uuid::new_v5(&::uuid::Uuid::nil(), stringify!(#name).as_bytes())
307            }
308        }
309
310        #[::bearings::async_trait]
311        impl<'a> ::bearings::Object<#user_error> for #name<'a> {
312            fn uuid() -> ::uuid::Uuid {
313                Self::uuid()
314            }
315
316            async fn invoke(
317                &self,
318                call: ::bearings::FunctionCall<::serde_json::value::Value>,
319            ) -> ::bearings::Result<::bearings::Message<(), #user_error>, #user_error> {
320                assert_eq!(Self::uuid(), call.uuid);
321
322                match &call.member[..] {
323                    #member_dispatch
324                    _ => Err(::bearings::Error::UnknownMember(call.member))
325                }
326            }
327        }
328
329        impl<'a> ::bearings::ObjectClient<'a, #user_error> for #name<'a> {
330            fn build<T: 'a + Send + Unpin + ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite>(
331                state: ::bearings::StatePtr<T, #user_error>
332            ) -> Self {
333                let uuid = Self::uuid();
334                Self {
335                    __: <_>::default(),
336                    #client_init
337                }
338            }
339        }
340
341        unsafe impl Sync for #name<'_> {}
342    };
343
344    TokenStream::from(expanded)
345}