1extern crate proc_macro;
2extern crate proc_macro2;
3
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::quote;
7use syn::{
8 parse_macro_input, Fields, FnArg, Ident, ImplItem, ItemImpl, ItemStruct, ItemTrait, Pat,
9 ReturnType, TraitItem, Type,
10};
11
12#[proc_macro_attribute]
13pub fn interface(attr: TokenStream, defn: TokenStream) -> TokenStream {
14 let mut input = parse_macro_input!(defn as ItemTrait);
15 let user_error = parse_macro_input!(attr as Type);
16
17 let mut client_functions = quote! {};
18 let mut dispatcher_cases = quote! {};
19
20 for mut item in &mut input.items {
21 match &mut item {
22 TraitItem::Method(ref mut method) => {
23 let signature = &mut method.sig;
24 let original_result = &signature.output;
25 match original_result {
26 ReturnType::Type(_, original_result) => {
27 let replacement = TokenStream::from(
28 quote! { -> ::bearings::Result<#original_result, #user_error> },
29 );
30 signature.output = parse_macro_input!(replacement as ReturnType);
31 }
32 _ => {
33 panic!("can't handle a function with default return type in a class");
34 }
35 }
36
37 let mut arguments = quote! {};
38 let mut argument_tuple = quote! {};
39 let mut argument_expansion = quote! {};
40
41 let mut i: u32 = 0;
42
43 for arg in &signature.inputs {
44 match arg {
45 FnArg::Receiver(_) => (),
46 FnArg::Typed(arg) => match &*arg.pat {
47 Pat::Ident(name) => {
48 arguments.extend(quote! {
49 &#name,
50 });
51
52 let ty = &arg.ty;
53 argument_tuple.extend(quote! {
54 #ty,
55 });
56
57 let idx = syn::Index::from(i as usize);
58 argument_expansion.extend(quote! {
59 arguments.#idx,
60 });
61 i += 1;
62 }
63 _ => {
64 panic!("only an identifier is allowed as a name of a class method argument");
65 }
66 },
67 }
68 }
69
70 let name = &signature.ident;
71 let return_type = match &signature.output {
72 ReturnType::Default => {
73 panic!("can't handle a method with a default return type");
74 }
75 ReturnType::Type(_, ty) => ty,
76 };
77
78 client_functions.extend(quote! {
79 #signature {
80 let id = {
81 let mut id_guard = self.state.id.lock().await;
82 let id = *id_guard;
83 *id_guard += 1;
84 id
85 };
86
87 let call = ::serde_json::to_string(&::bearings::Message::<_, #user_error>::Call(::bearings::FunctionCall{
88 id: id,
89 uuid: self.uuid.clone(),
90 member: self.member.to_string(),
91 method: stringify!(#name).to_string(),
92 arguments: (#arguments),
93 }))?;
94
95 {
96 let mut map = self.state.awaiters.lock().await;
97 map.insert(id, ::tokio::sync::Mutex::from(::bearings::Awaiter::Empty));
98 }
99
100 let mut w = self.state.w.lock().await;
101 use ::tokio::io::AsyncWriteExt;
102 w.write_all(format!("{}\0", call).as_bytes()).await?;
103 w.flush().await?;
104
105 ::bearings::ReplyFuture::<
106 <#return_type as ::std::iter::IntoIterator>::Item,
107 T,
108 #user_error
109 >::new(self.state.clone(), id).await
110 }
111 });
112
113 dispatcher_cases.extend(quote! {
114 stringify!(#name) => {
115 let arguments: (#argument_tuple) = ::serde_json::from_value(call.arguments)?;
116 Ok(::bearings::Message::<(), #user_error>::Return(
117 ::bearings::ReturnValue{
118 id: call.id,
119 result: ::serde_json::value::Value::from({
120 let mut result = object.lock().await;
121 let result = result.#name(#argument_expansion);
122 result.await?
123 })
124 }
125 ))
126 }
127 });
128 }
129 _ => {
130 panic!("only methods are allowed inside a class trait");
131 }
132 }
133 }
134
135 let name = &input.ident;
136 let error_name = syn::Ident::new(&format!("{}Error", name), Span::call_site());
137 let client_name = syn::Ident::new(&format!("{}Client", name), Span::call_site());
138 let dispatcher_name = syn::Ident::new(&format!("{}Dispatcher", name), Span::call_site());
139
140 let expanded = quote! {
141 #[::bearings::async_trait]
142 #input
143
144 struct #dispatcher_name {
145 }
146
147 impl #dispatcher_name {
148 async fn invoke_method<'a>(
149 object: &::tokio::sync::Mutex<Box<dyn #name + Send + 'a>>,
150 call: ::bearings::FunctionCall<serde_json::value::Value>,
151 ) -> ::bearings::Result<::bearings::Message<(), #user_error>, #user_error> {
152 match &call.method[..] {
153 #dispatcher_cases
154 _ => Err(::bearings::Error::UnknownMethod(call.member, call.method))
155 }
156 }
157 }
158
159 struct #client_name<T: Send + ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite> {
160 uuid: ::uuid::Uuid,
161 member: &'static str,
162 state: ::bearings::StatePtr<T, #user_error>,
163 }
164
165 #[::bearings::async_trait]
166 impl<T: Send + Unpin + ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite> #name for #client_name<T> {
167 #client_functions
168 }
169
170 type #error_name = #user_error;
171 };
172
173 TokenStream::from(expanded)
174}
175
176#[proc_macro_attribute]
177pub fn class(attr: TokenStream, defn: TokenStream) -> TokenStream {
178 let mut input = parse_macro_input!(defn as ItemImpl);
179 let user_error = parse_macro_input!(attr as Type);
180
181 for item in &mut input.items {
182 match item {
183 ImplItem::Method(ref mut method) => {
184 let signature = &mut method.sig;
185 let original_result = &signature.output;
186 match original_result {
187 ReturnType::Type(_, original_result) => {
188 let replacement = TokenStream::from(
189 quote! { -> ::bearings::Result<#original_result, #user_error> },
190 );
191 signature.output = parse_macro_input!(replacement as ReturnType);
192 }
193 _ => {
194 panic!("can't handle a function with default return type in a class");
195 }
196 }
197 }
198
199 _ => panic!("unsupported class element"),
200 }
201 }
202
203 let expanded = quote! {
204 #[::bearings::async_trait]
205 #input
206 };
207
208 TokenStream::from(expanded)
209}
210
211#[proc_macro_attribute]
212pub fn object(attr: TokenStream, defn: TokenStream) -> TokenStream {
213 let input = parse_macro_input!(defn as ItemStruct);
214 let user_error = parse_macro_input!(attr as Type);
215
216 let mut fields = quote!();
217 let mut parameters = quote!();
218 let mut arguments = quote!();
219 let mut init = quote!();
220 let mut client_init = quote!();
221 let mut member_dispatch = quote!();
222
223 let mut i: u32 = 0;
224
225 match input.fields {
226 Fields::Named(ref named) => {
227 for field in named.named.iter() {
228 let name = field.ident.as_ref().unwrap();
229 let ty = &field.ty;
230
231 fields.extend(quote! {
232 #name: ::tokio::sync::Mutex<Box<dyn #ty + Send + 'a>>,
233 });
234
235 let param_type = syn::Ident::new(&format!("T{}", i), Span::call_site());
236 i += 1;
237
238 parameters.extend(quote! {
239 #param_type: #ty + Send + 'a,
240 });
241
242 arguments.extend(quote! {
243 #name: #param_type,
244 });
245
246 init.extend(quote! {
247 #name: ::tokio::sync::Mutex::from(Box::from(#name) as Box<dyn #ty + Send + 'a>),
248 });
249
250 let (client_type, dispatcher_type) = match ty {
251 Type::Path(path) => {
252 let mut client = path.clone();
253 let mut dispatcher = path.clone();
254
255 let mut last = client.path.segments.pop().unwrap().into_value();
256 last.ident =
257 Ident::new(&format!("{}Client", last.ident), last.ident.span());
258 client.path.segments.push_value(last);
259
260 let mut last = dispatcher.path.segments.pop().unwrap().into_value();
261 last.ident =
262 Ident::new(&format!("{}Dispatcher", last.ident), last.ident.span());
263 dispatcher.path.segments.push_value(last);
264
265 (client, dispatcher)
266 }
267 _ => {
268 panic!("the type of a field of an object structure must be a previously defined class");
269 }
270 };
271
272 client_init.extend(quote! {
273 #name: ::tokio::sync::Mutex::from(Box::from(#client_type {
274 uuid: uuid.clone(),
275 member: stringify!(#name),
276 state: state.clone()
277 }) as Box<dyn #ty + Send + 'a>),
278 });
279
280 member_dispatch.extend(quote! {
281 stringify!(#name) => #dispatcher_type::invoke_method(&self.#name, call).await,
282 });
283 }
284 }
285
286 _ => unimplemented!(),
287 }
288
289 let name = &input.ident;
290 let expanded = quote! {
291 struct #name<'a> {
292 __: std::marker::PhantomData<&'a ()>,
293
294 #fields
295 }
296
297 impl<'a> #name<'a> {
298 pub fn new<#parameters>(#arguments) -> Self {
299 Self{
300 __: <_>::default(),
301 #init
302 }
303 }
304
305 fn uuid() -> ::uuid::Uuid {
306 ::uuid::Uuid::new_v5(&::uuid::Uuid::nil(), stringify!(#name).as_bytes())
307 }
308 }
309
310 #[::bearings::async_trait]
311 impl<'a> ::bearings::Object<#user_error> for #name<'a> {
312 fn uuid() -> ::uuid::Uuid {
313 Self::uuid()
314 }
315
316 async fn invoke(
317 &self,
318 call: ::bearings::FunctionCall<::serde_json::value::Value>,
319 ) -> ::bearings::Result<::bearings::Message<(), #user_error>, #user_error> {
320 assert_eq!(Self::uuid(), call.uuid);
321
322 match &call.member[..] {
323 #member_dispatch
324 _ => Err(::bearings::Error::UnknownMember(call.member))
325 }
326 }
327 }
328
329 impl<'a> ::bearings::ObjectClient<'a, #user_error> for #name<'a> {
330 fn build<T: 'a + Send + Unpin + ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite>(
331 state: ::bearings::StatePtr<T, #user_error>
332 ) -> Self {
333 let uuid = Self::uuid();
334 Self {
335 __: <_>::default(),
336 #client_init
337 }
338 }
339 }
340
341 unsafe impl Sync for #name<'_> {}
342 };
343
344 TokenStream::from(expanded)
345}