Skip to main content

combadge_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6    parse, parse_macro_input, Field, Fields, FnArg, GenericArgument, Ident, ImplItem, Index,
7    ItemImpl, ItemStruct, ItemTrait, LitInt, Pat, PathArguments, ReturnType, TraitItem, Type,
8    TypeParamBound, Visibility,
9};
10
11fn parse_count(item: TokenStream) -> usize {
12    let Ok(count) = parse::<LitInt>(item.clone().into()) else {
13        panic!("expected an integer literal");
14    };
15
16    let Ok(count) = count.base10_parse::<usize>() else {
17        panic!("failed to parse {count} as usize");
18    };
19
20    if count == 0 {
21        panic!("must generate at least 1 variable");
22    }
23
24    if count > 26 {
25        panic!("can only generate up to 26 variables without running out of letters");
26    }
27
28    count
29}
30
31fn build_variables(count: usize) -> (Vec<Ident>, Vec<Ident>) {
32    let type_name = (0..count)
33        .map(|i| char::from(b'A' + i as u8))
34        .collect::<Vec<_>>();
35
36    let variable_name = type_name
37        .iter()
38        .map(|t| format_ident!("{}", t.to_ascii_lowercase()))
39        .collect::<Vec<_>>();
40
41    let type_name = type_name
42        .iter()
43        .map(|t| format_ident!("{}", t))
44        .collect::<Vec<_>>();
45
46    (type_name, variable_name)
47}
48
49#[proc_macro]
50pub fn build_call_traits(item: TokenStream) -> TokenStream {
51    let max_count = parse_count(item);
52
53    let mut call_traits = quote! {};
54    for count in 1..=max_count {
55        let (type_name, variable_name) = build_variables(count);
56        let trait_name = format_ident!("Call{}", count);
57
58        call_traits = quote! {
59            #call_traits
60
61            pub trait #trait_name<#(#type_name),*, Return> {
62                fn call(&self, #(#variable_name: #type_name),*) -> AsyncReturnWithError<Return>;
63            }
64
65            impl<#(#type_name),*, Return: 'static> #trait_name<#(#type_name),*, Return> for Callback<(#(#type_name),*,), Return>
66            where
67                <((#(#type_name),*,), Return) as CallbackTypes>::Local: Fn(#(#type_name),*) -> Return,
68                <((#(#type_name),*,), Return) as CallbackTypes>::AsyncLocal: Fn(#(#type_name),*) -> AsyncReturn<Return>,
69                <((#(#type_name),*,), Return) as CallbackTypes>::Remote: Fn(#(#type_name),*) -> AsyncReturnWithError<Return>,
70            {
71                fn call(&self, #(#variable_name: #type_name),*) -> AsyncReturnWithError<Return> {
72                    if let Some(remote) = &self.remote {
73                        remote(#(#variable_name),*)
74                    } else if let Some(local) = &self.local {
75                        let response = local(#(#variable_name),*);
76                        Box::pin(async { Ok(response) })
77                    } else if let Some(async_local) = &self.async_local {
78                        let result = async_local(#(#variable_name),*);
79                        Box::pin(async move {
80                            let result = result.await;
81                            Ok(result)
82                        })
83                    } else {
84                        Box::pin(async {
85                            Err(Error::CallbackFailed {
86                                error: String::from("callbacks (both remote and local) not found"),
87                            })
88                        })
89                    }
90                }
91            }
92        }
93    }
94
95    call_traits.into()
96}
97
98#[proc_macro]
99pub fn build_callback_from_closure(item: TokenStream) -> TokenStream {
100    let max_count = parse_count(item);
101
102    let mut callback_from_closure = quote! {};
103    for count in 1..=max_count {
104        let (type_name, _) = build_variables(count);
105
106        callback_from_closure = quote! {
107            #callback_from_closure
108
109            impl<#(#type_name),*, Return> From<Box<dyn Fn(#(#type_name),*) -> Return>> for Callback<(#(#type_name),*,), Return> {
110                fn from(callback: Box<dyn Fn(#(#type_name),*) -> Return>) -> Self {
111                    Self {
112                        local: Some(callback),
113                        async_local: None,
114                        remote: None,
115                    }
116                }
117            }
118
119            impl<#(#type_name),*, Return> From<Box<dyn Fn(#(#type_name),*) -> AsyncReturn<Return>>> for Callback<(#(#type_name),*,), Return> {
120                fn from(callback: Box<dyn Fn(#(#type_name),*) -> AsyncReturn<Return>>) -> Self {
121                    Self {
122                        local: None,
123                        async_local: Some(callback),
124                        remote: None,
125                    }
126                }
127            }
128
129            impl<#(#type_name),*, Return> From<Box<dyn Fn(#(#type_name),*) -> AsyncReturnWithError<Return>>> for Callback<(#(#type_name),*,), Return> {
130                fn from(callback: Box<dyn Fn(#(#type_name),*) -> AsyncReturnWithError<Return>>) -> Self {
131                    Self {
132                        local: None,
133                        async_local: None,
134                        remote: Some(callback),
135                    }
136                }
137            }
138        }
139    }
140
141    callback_from_closure.into()
142}
143
144#[proc_macro]
145pub fn build_callback_types(item: TokenStream) -> TokenStream {
146    let max_count = parse_count(item);
147
148    let mut callback_types = quote! {};
149    for count in 1..=max_count {
150        let (type_name, _) = build_variables(count);
151
152        callback_types = quote! {
153            #callback_types
154
155            impl<#(#type_name),*, Return> CallbackTypes for ((#(#type_name),*,), Return) {
156                type Local = Box<dyn Fn(#(#type_name),*) -> Return>;
157                type AsyncLocal = Box<dyn Fn(#(#type_name),*) -> AsyncReturn<Return>>;
158                type Remote = Box<dyn Fn(#(#type_name),*) -> AsyncReturnWithError<Return>>;
159            }
160        }
161    }
162
163    callback_types.into()
164}
165
166#[proc_macro]
167pub fn build_post_tuple(item: TokenStream) -> TokenStream {
168    let max_count = parse_count(item);
169
170    let mut post_tuple = quote! {};
171    for count in 1..=max_count {
172        let (type_name, _) = build_variables(count);
173        let index = (0..count).map(Index::from).collect::<Vec<_>>();
174
175        post_tuple = quote! {
176            #post_tuple
177
178            impl<#(#type_name),*> PostTuple<(#(#type_name),*,)> for Message {
179                fn post_tuple(&mut self, tuple: (#(#type_name),*,)) -> Result<(), Error> {
180                    #(
181                        self.post(tuple.#index)?;
182                    )*
183                    Ok(())
184                }
185            }
186        }
187    }
188
189    post_tuple.into()
190}
191
192#[proc_macro]
193pub fn build_post_for_tuple(item: TokenStream) -> TokenStream {
194    let max_count = parse_count(item);
195
196    let mut post = quote! {};
197    for count in 1..=max_count {
198        let (type_name, _) = build_variables(count);
199        let index = (0..count).map(Index::from).collect::<Vec<_>>();
200
201        post = quote! {
202            #post
203
204            impl<#(#type_name),*> Post for (#(#type_name),*,)
205            where
206                #(#type_name: Post),*
207            {
208                const POSTABLE: bool = true #(&& <#type_name as Post>::POSTABLE)*;
209
210                fn from_js_value(value: JsValue) -> Result<Self, Error> {
211                    let array: Array = value.dyn_into().map_err(|error| Error::DeserializeFailed {
212                        type_name: String::from(type_name::<(#(#type_name),*,)>()),
213                        error: format!("{error:?}"),
214                    })?;
215                    Ok((#(
216                        #type_name::from_js_value(array.get(#index))?
217                    ),*,))
218                }
219
220                fn to_js_value(self) -> Result<JsValue, Error> {
221                    let array = Array::from_iter([
222                        #(#type_name::to_js_value(self.#index)?),*
223                    ].into_iter());
224                    Ok(array.into())
225                }
226            }
227        }
228    }
229
230    post.into()
231}
232
233#[proc_macro]
234pub fn build_transfer_for_tuple(item: TokenStream) -> TokenStream {
235    let max_count = parse_count(item);
236
237    let mut transfer = quote! {};
238    for count in 1..=max_count {
239        let (type_name, _) = build_variables(count);
240        let index = (0..count).map(Index::from).collect::<Vec<_>>();
241
242        transfer = quote! {
243            #transfer
244
245            impl<#(#type_name),*> Transfer for (#(#type_name),*,)
246            where
247                #(#type_name: Transfer),*
248            {
249                fn get_transferable(js_value: &JsValue) -> Option<Array> {
250                    let as_array: &Array = js_value.dyn_ref()?;
251                    let mut transferable = Array::new();
252                    #(
253                        if let Some(array) = #type_name::get_transferable(&as_array.get(#index)) {
254                            transferable.extend(array.into_iter());
255                        }
256                    )*
257
258                    if transferable.length() == 0 {
259                        None
260                    } else {
261                        Some(transferable)
262                    }
263                }
264            }
265        }
266    }
267
268    transfer.into()
269}
270
271#[proc_macro]
272pub fn build_responder(item: TokenStream) -> TokenStream {
273    let max_count = parse_count(item);
274
275    let mut responder = quote! {};
276    for count in 1..=max_count {
277        let (type_name, variable_name) = build_variables(count);
278
279        responder = quote! {
280            #responder
281
282            impl<#(#type_name),*, Return> Responder for Box<dyn Fn(#(#type_name),*) -> Return> {
283                default fn respond(&self, arguments_: Array, port_: MessagePort) -> Result<(), Error> {
284                    #(
285                        let #variable_name: #type_name = Post::from_js_value(arguments_.shift())?;
286                    )*
287                    let result = Post::to_js_value(self(#(#variable_name),*))?;
288
289                    if let Some(transferable) = <Return as Transfer>::get_transferable(&result) {
290                        port_.post_message_with_transferable(&result, &transferable)
291                            .map_err(|error| Error::PostFailed {
292                                error: format!("failed to respond in Responder: {error:?}"),
293                            })?;
294                    } else {
295                        port_.post_message(&result)
296                            .map_err(|error| Error::PostFailed {
297                                error: format!("failed to respond in Responder: {error:?}"),
298                            })?;
299                    }
300
301                    Ok(())
302                }
303            }
304
305            impl<#(#type_name),*, Return: 'static> Responder for Box<dyn Fn(#(#type_name),*) -> Box<dyn Future<Output = Return>>> {
306                fn respond(&self, arguments_: Array, port_: MessagePort) -> Result<(), Error> {
307                    #(
308                        let #variable_name: #type_name = Post::from_js_value(arguments_.shift())?;
309                    )*
310                    let result = self(#(#variable_name),*);
311                    let future_result = async move {
312                        let result = Box::into_pin(result).await;
313                        let value = match Post::to_js_value(result) {
314                            Ok(value) => value,
315                            Err(error) => {
316                                crate::log_error!("error while converting to JsValue in future: {error:?}");
317                                return;
318                            }
319                        };
320
321                        if let Err(error) = Return::get_transferable(&value).map_or_else(
322                            || port_.post_message(&value),
323                            |transferable| port_.post_message_with_transferable(&value, &Array::of1(&value))
324                        ) {
325                            crate::log_error!("error while posting async: {error:?}");
326                        }
327                    };
328                    spawn_local(future_result);
329                    Ok(())
330                }
331            }
332        }
333    }
334
335    responder.into()
336}
337
338#[proc_macro]
339pub fn build_to_closure(item: TokenStream) -> TokenStream {
340    let max_count = parse_count(item);
341
342    let mut to_closure = quote! {};
343    for count in 1..=max_count {
344        let (type_name, variable_name) = build_variables(count);
345
346        to_closure = quote! {
347            #to_closure
348
349            impl<#(#type_name: 'static),*, Return: 'static> ToClosure for CallbackClient<(#(#type_name),*,), Return> {
350                type Output = Box<dyn Fn(#(#type_name),*) -> AsyncReturnWithError<Return>>;
351                fn to_closure(self) -> Box<dyn Fn(#(#type_name),*) -> AsyncReturnWithError<Return>> {
352                    Box::new(move |#(#variable_name),*| self.call((#(#variable_name),*,)))
353                }
354            }
355        }
356    }
357
358    to_closure.into()
359}
360
361fn parse_named_fields<'a>(fields: impl Iterator<Item = &'a Field>) -> (Vec<Ident>, Vec<Type>) {
362    fields
363        .map(|field| (field.ident.clone().unwrap(), field.ty.clone()))
364        .unzip()
365}
366
367#[proc_macro_derive(Post)]
368pub fn derive_post(item: TokenStream) -> TokenStream {
369    let item_struct: ItemStruct = parse_macro_input!(item);
370
371    let postable = match &item_struct.fields {
372        Fields::Named(fields) => {
373            let (_, field_type) = parse_named_fields(fields.named.iter());
374            quote! {
375                const POSTABLE: bool = true #(&& <#field_type as Post>::POSTABLE)*;
376            }
377        }
378        _ => unimplemented!(),
379    };
380
381    let from_js_value = match &item_struct.fields {
382        Fields::Named(fields) => {
383            let (field_name, field_type) = parse_named_fields(fields.named.iter());
384            let index = (0..field_name.len()).map(Index::from).collect::<Vec<_>>();
385            quote! {
386                fn from_js_value(value: combadge::reexports::wasm_bindgen::JsValue) -> std::result::Result<Self, combadge::Error> {
387                    let array: combadge::reexports::js_sys::Array = value.dyn_into().map_err(|error| combadge::Error::DeserializeFailed {
388                        type_name: String::from(std::any::type_name::<Self>()),
389                        error: format!("{error:?}"),
390                    })?;
391                    Ok(Self {
392                        #(
393                            #field_name: <#field_type as Post>::from_js_value(array.get(#index))?
394                        ),*
395                    })
396                }
397            }
398        }
399        _ => unimplemented!(),
400    };
401
402    let to_js_value = match item_struct.fields {
403        Fields::Named(fields) => {
404            let (field_name, field_type) = parse_named_fields(fields.named.iter());
405            quote! {
406                fn to_js_value(self) -> std::result::Result<combadge::reexports::wasm_bindgen::JsValue, combadge::Error> {
407                    Ok(combadge::reexports::js_sys::Array::from_iter([
408                        #(<#field_type as Post>::to_js_value(self.#field_name)?),*
409                    ].into_iter()).into())
410                }
411            }
412        }
413        _ => unimplemented!(),
414    };
415
416    let struct_name = item_struct.ident;
417    quote! {
418        impl Post for #struct_name {
419            #postable
420            #from_js_value
421            #to_js_value
422        }
423    }
424    .into()
425}
426
427#[proc_macro_derive(Transfer)]
428pub fn derive_transfer(item: TokenStream) -> TokenStream {
429    let item_struct: ItemStruct = parse_macro_input!(item);
430
431    let get_transferable = match item_struct.fields {
432        Fields::Named(fields) => {
433            let (field_name, field_type) = parse_named_fields(fields.named.iter());
434            let index = (0..field_name.len()).map(Index::from).collect::<Vec<_>>();
435            quote! {
436                fn get_transferable(value: &combadge::reexports::wasm_bindgen::JsValue) -> Option<combadge::reexports::js_sys::Array> {
437                    let as_array: &combadge::reexports::js_sys::Array = value.dyn_ref()?;
438                    let mut transferable = combadge::reexports::js_sys::Array::new();
439                    #(
440                        if let Some(array) = #field_type::get_transferable(&as_array.get(#index)) {
441                            transferable.extend(array)
442                        }
443                    )*
444                    if transferable.length() == 0 {
445                        None
446                    } else {
447                        Some(transferable)
448                    }
449                }
450            }
451        }
452        _ => unimplemented!(),
453    };
454
455    let struct_name = item_struct.ident;
456    quote! {
457        impl Transfer for #struct_name {
458            #get_transferable
459        }
460    }
461    .into()
462}
463
464#[proc_macro_attribute]
465pub fn combadge(_attr: TokenStream, item: TokenStream) -> TokenStream {
466    let item: ItemTrait = parse_macro_input!(item);
467    let trait_name = item.ident.clone();
468
469    let functions = item
470        .items
471        .iter()
472        .filter_map(|item| match item {
473            TraitItem::Fn(f) => Some(f),
474            _ => None,
475        })
476        .collect::<Vec<_>>();
477
478    let name = functions
479        .iter()
480        .map(|function| function.sig.ident.clone())
481        .collect::<Vec<_>>();
482
483    let name_string = name.iter().map(|name| name.to_string()).collect::<Vec<_>>();
484
485    let argument = functions
486        .iter()
487        .map(|function| function.sig.inputs.iter().collect::<Vec<_>>())
488        .collect::<Vec<_>>();
489
490    let non_receiver = argument
491        .iter()
492        .enumerate()
493        .map(|(index, arguments)| {
494            let non_receiver = arguments
495                .iter()
496                .filter_map(|arg| match arg {
497                    FnArg::Receiver(_) => None,
498                    FnArg::Typed(typed) => Some(typed.clone()),
499                })
500                .collect::<Vec<_>>();
501
502            if non_receiver.len() == arguments.len() {
503                panic!(
504                    "expected {} to have a receiver (self parameter)",
505                    name[index]
506                )
507            }
508
509            non_receiver
510        })
511        .collect::<Vec<_>>();
512
513    let non_receiver_name = non_receiver
514        .iter()
515        .map(|non_receiver| {
516            non_receiver
517                .iter()
518                .filter_map(|item| match item.pat.as_ref() {
519                    Pat::Ident(ident) => Some(ident.clone()),
520                    _ => None,
521                })
522                .collect::<Vec<_>>()
523        })
524        .collect::<Vec<_>>();
525
526    let non_receiver_type = non_receiver
527        .iter()
528        .map(|non_receiver| {
529            non_receiver
530                .iter()
531                .map(|item| item.ty.clone())
532                .collect::<Vec<_>>()
533        })
534        .collect::<Vec<_>>();
535
536    let output = functions
537        .iter()
538        .map(|function| function.sig.output.clone())
539        .collect::<Vec<_>>();
540
541    let internal_type = output
542        .iter()
543        .map(|output| match output {
544            ReturnType::Default => quote! { () },
545            ReturnType::Type(_, t) => match t.as_ref() {
546                Type::Path(path) => {
547                    if path.path.segments.len() > 1
548                        || path.path.segments.get(0).unwrap().ident != "Box"
549                    {
550                        return quote! { #t };
551                    }
552                    let segment = path.path.segments.get(0).unwrap();
553                    match &segment.arguments {
554                        PathArguments::AngleBracketed(arguments) => {
555                            if arguments.args.len() > 1 {
556                                return quote! { #t };
557                            }
558                            let argument = arguments.args.get(0).unwrap();
559                            match argument {
560                                GenericArgument::Type(generic_type) => match generic_type {
561                                    Type::TraitObject(trait_) => {
562                                        if trait_.dyn_token.is_none() || trait_.bounds.len() > 1 {
563                                            return quote! { #t };
564                                        }
565
566                                        match trait_.bounds.get(0).unwrap() {
567                                            TypeParamBound::Trait(bound) => {
568                                                if bound.path.segments.len() > 1 {
569                                                    return quote! { #t };
570                                                }
571
572                                                let segment = bound.path.segments.get(0).unwrap();
573                                                if segment.ident != "Future" {
574                                                    return quote! { #t };
575                                                }
576
577                                                if let PathArguments::AngleBracketed(arguments) =
578                                                    &segment.arguments
579                                                {
580                                                    if arguments.args.len() > 1 {
581                                                        return quote! { #t };
582                                                    }
583
584                                                    match arguments.args.get(0).unwrap() {
585                                                        GenericArgument::AssocType(assoc) => {
586                                                            if assoc.ident != "Output" {
587                                                                return quote! { #t };
588                                                            }
589
590                                                            let generic_type = &assoc.ty;
591                                                            quote! { #generic_type }
592                                                        }
593                                                        _ => quote! { #t },
594                                                    }
595                                                } else {
596                                                    quote! { #t }
597                                                }
598                                            }
599                                            _ => quote! { #t },
600                                        }
601                                    }
602                                    _ => quote! { #t },
603                                },
604                                _ => quote! { #t},
605                            }
606                        }
607                        _ => quote! { #t },
608                    }
609                }
610                _ => quote! { #t },
611            },
612        })
613        .collect::<Vec<_>>();
614
615    let client_name = format_ident!("{}Client", item.ident);
616    let client = quote! {
617        #[derive(Clone, Debug)]
618        pub struct #client_name<P: ::combadge::Port + 'static> {
619            client: std::rc::Rc<std::cell::RefCell<::combadge::Client::<P>>>,
620        }
621
622        impl<P: ::combadge::Port + 'static> #client_name<P> {
623            pub fn new(port: P) -> Self {
624                Self { client: ::combadge::Client::new(port) }
625            }
626
627            #(
628                #[expect(clippy::future_not_send)]
629                pub fn #name(#(#argument),*) -> impl std::future::Future<Output = Result<#internal_type, ::combadge::Error>> {
630                    use ::combadge::reexports::futures::future::FutureExt;
631                    use ::combadge::reexports::futures::future::TryFutureExt;
632                    const _: () = assert!(<#internal_type as ::combadge::Post>::POSTABLE);
633
634                    let message_ = Ok(::combadge::Message::new(#name_string));
635                    #(
636                        const _: () = assert!(<#non_receiver_type as ::combadge::Post>::POSTABLE);
637                        let message_ = message_.and_then(|mut message_| {
638                            message_.post(#non_receiver_name)?;
639                            Ok(message_)
640                        });
641                    )*
642
643                    let server_ready_ = match self
644                        .client
645                        .try_borrow_mut()
646                        .map_err(|_| ::combadge::Error::ClientUnavailable)
647                    {
648                        Ok(mut client) => client.wait_for_server().map(|()| Ok(())).left_future(),
649                        Err(error) => async { Err(error) }.right_future(),
650                    };
651
652                    let client_clone = self.client.clone();
653                    server_ready_.then(move |result_| {
654                        let message_ = result_.and(message_);
655                        async { message_ }.and_then(move |message_| {
656                            let client = client_clone
657                                .try_borrow_mut()
658                                .map_err(|_| ::combadge::Error::ClientUnavailable);
659                            let message_ = client.map(|mut client| client.send_message::<#internal_type>(message_));
660                            async { message_ }.try_flatten().map(|result_| {
661                                let result_: Result<#internal_type, ::combadge::Error> = result_.map(std::convert::Into::into);
662                                result_
663                            })
664                        })
665                    })
666                }
667            )*
668        }
669    };
670
671    let server_name = format_ident!("{}Server", item.ident);
672    let server = quote! {
673        pub struct #server_name<P: ::combadge::Port + 'static> {
674            server: std::rc::Rc<std::cell::RefCell<::combadge::Server<P>>>,
675        }
676
677        impl<P: ::combadge::Port + 'static> #server_name<P> {
678            pub fn create<L: #trait_name + 'static>(mut local: L, port: P) {
679                let dispatch = Box::new(move |procedure: &str, data| {
680                    match procedure {
681                        #(
682                            #name_string => Self::#name(&mut local, data),
683                        )*
684                        _ => Err(::combadge::Error::UnknownProcedure{ name: String::from(procedure) })
685                    }
686                });
687
688                ::combadge::Server::create(port, dispatch);
689            }
690
691            #(
692                fn #name(local_: &mut dyn #trait_name, data_: ::combadge::reexports::js_sys::Array) -> Result<(), ::combadge::Error> {
693                    use ::combadge::reexports::wasm_bindgen_futures::spawn_local;
694
695                    #(
696                        const _: () = assert!(<#non_receiver_type as ::combadge::Post>::POSTABLE);
697                        let #non_receiver = ::combadge::Post::from_js_value(data_.shift())?;
698                    )*
699                    let result_ = local_.#name(#(#non_receiver_name),*);
700                    let port_: ::combadge::reexports::web_sys::MessagePort = data_.shift().into();
701                    let async_result_ = ::combadge::MaybeAsync::to_maybe_async(result_);
702                    let future_result_ = async move {
703                        let result_: #internal_type = Box::into_pin(async_result_).await;
704                        let value_ = match ::combadge::Post::to_js_value(result_) {
705                            Ok(value_) => value_,
706                            Err(error_) => {
707                                ::combadge::log_error!("error while converting to JsValue in future: {error_:?}");
708                                return;
709                            }
710                        };
711
712                        if let Err(error_) = <#internal_type as ::combadge::Transfer>::get_transferable(&value_).map_or_else(
713                            || port_.post_message(&value_),
714                            |transferable| port_.post_message_with_transferable(&value_, &transferable))
715                        {
716                            ::combadge::log_error!("error while posting {value_:?} {} in {} async: {error_:?}", std::any::type_name::<#internal_type>(), #name_string);
717                        }
718                    };
719                    spawn_local(future_result_);
720                    Ok(())
721                }
722            )*
723        }
724    };
725
726    let result: TokenStream = quote! {
727        #item
728        #client
729        #server
730    }
731    .into();
732
733    // println!("{}", prettyplease::unparse(&parse(result.clone()).unwrap()));
734
735    result
736}
737
738#[proc_macro_attribute]
739pub fn proxy(_attr: TokenStream, item: TokenStream) -> TokenStream {
740    let item_impl: ItemImpl = parse_macro_input!(item);
741    let Type::Path(path) = &*item_impl.self_ty else {
742        panic!("proxy expected to find a path in impl");
743    };
744
745    if path.qself.is_some() {
746        panic!("can't proxy an impl with a qualified type");
747    }
748
749    if path.path.segments.len() > 1 {
750        panic!("can't proxy an impl with a multi-segment path")
751    }
752
753    let struct_name = path.path.segments.get(0).unwrap().ident.clone();
754    let trait_name = format_ident!("{}Proxy", struct_name);
755    let local_name = format_ident!("{}Local", struct_name);
756    let client_name = format_ident!("{}Client", trait_name);
757    let server_name = format_ident!("{}Server", trait_name);
758
759    let functions = item_impl
760        .items
761        .iter()
762        .filter_map(|item| match item {
763            ImplItem::Fn(f) => {
764                if matches!(f.vis, Visibility::Public(_)) {
765                    Some(f)
766                } else {
767                    None
768                }
769            }
770            _ => None,
771        })
772        .collect::<Vec<_>>();
773
774    let argument = functions
775        .iter()
776        .map(|function| function.sig.inputs.iter().collect::<Vec<_>>())
777        .collect::<Vec<_>>();
778
779    let name = functions
780        .iter()
781        .map(|function| function.sig.ident.clone())
782        .collect::<Vec<_>>();
783
784    let non_receiver = argument
785        .iter()
786        .enumerate()
787        .map(|(index, arguments)| {
788            let non_receiver = arguments
789                .iter()
790                .filter_map(|arg| match arg {
791                    FnArg::Receiver(_) => None,
792                    FnArg::Typed(typed) => Some(typed.clone()),
793                })
794                .collect::<Vec<_>>();
795
796            if non_receiver.len() == arguments.len() {
797                panic!(
798                    "expected {} to have a receiver (self parameter)",
799                    name[index]
800                )
801            }
802
803            non_receiver
804        })
805        .collect::<Vec<_>>();
806
807    let non_receiver_name = non_receiver
808        .iter()
809        .map(|non_receiver| {
810            non_receiver
811                .iter()
812                .filter_map(|item| match item.pat.as_ref() {
813                    Pat::Ident(ident) => Some(ident.clone()),
814                    _ => None,
815                })
816                .collect::<Vec<_>>()
817        })
818        .collect::<Vec<_>>();
819
820    let output = functions
821        .iter()
822        .map(|function| function.sig.output.clone())
823        .collect::<Vec<_>>();
824
825    let return_type = output
826        .iter()
827        .map(|output| match output {
828            ReturnType::Default => quote! { () },
829            ReturnType::Type(_, t) => quote! { #t },
830        })
831        .collect::<Vec<_>>();
832
833    let name = functions
834        .iter()
835        .map(|function| function.sig.ident.clone())
836        .collect::<Vec<_>>();
837
838    let argument = functions
839        .iter()
840        .map(|function| function.sig.inputs.iter().collect::<Vec<_>>())
841        .collect::<Vec<_>>();
842
843    quote! {
844        #item_impl
845
846        #[combadge]
847        trait #trait_name {
848            #(
849                fn #name(#(#argument),*) -> #return_type;
850            )*
851        }
852
853        struct #local_name {
854            local: #struct_name
855        }
856
857        impl #local_name {
858            fn new(local: #struct_name) -> Self {
859                Self { local }
860            }
861        }
862
863        impl #trait_name for #local_name {
864            #(
865                fn #name(#(#argument),*) -> #return_type {
866                    self.local.#name(#(#non_receiver_name),*)
867                }
868            )*
869        }
870
871        impl ::combadge::AsHandle<#struct_name> for #struct_name {
872            type Client = #client_name<::combadge::reexports::web_sys::MessagePort>;
873            type Server = #server_name<::combadge::reexports::web_sys::MessagePort>;
874
875            fn into_client(port: ::combadge::reexports::web_sys::MessagePort) -> Self::Client {
876                Self::Client::new(port)
877            }
878
879            fn create_server(local: #struct_name, port: ::combadge::reexports::web_sys::MessagePort)  {
880                Self::Server::create(#local_name::new(local), port);
881            }
882        }
883
884    }
885    .into()
886}