1use convert_case::{Case, Casing};
2use proc_macro::{self, TokenStream};
3use proc_macro2::Span;
4use quote::{format_ident, quote};
5use syn::{punctuated::Punctuated, token::Comma, Token};
6use uuid::Uuid;
7
8#[cfg(test)]
9mod tests;
10
11fn skip_self(
12 arguments: &Punctuated<syn::FnArg, syn::token::Comma>,
13) -> Punctuated<syn::FnArg, syn::token::Comma> {
14 let mut output = Punctuated::new();
15 for arg in arguments {
16 if let syn::FnArg::Typed(pat_type) = &arg {
17 if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &*pat_type.pat {
18 if ident != "self" {
19 output.push(arg.clone());
20 }
21 }
22 }
23 }
24
25 output
26}
27
28fn request(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
29 let items = item.items.iter().filter_map(|item| match item {
30 syn::TraitItem::Fn(method) => {
31 let ident = format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
32 let args = skip_self(&method.sig.inputs);
33 let item = quote! {
34 #ident { #args },
35 };
36 Some(item)
37 }
38 _ => None,
39 });
40
41 let output = quote! {
42 #[derive(Debug, serde::Serialize, serde::Deserialize)]
43 pub enum Request {
44 #(#items)*
45 }
46 };
47 output
48}
49
50fn extract_return_type(return_type: &syn::ReturnType) -> proc_macro2::TokenStream {
51 match return_type {
52 syn::ReturnType::Default => {
53 quote! {
54 ()
55 }
56 }
57 syn::ReturnType::Type(_, return_type) => match *return_type.to_owned() {
58 syn::Type::ImplTrait(impl_trait) => {
59 let return_type = extract_stream_item_type(&impl_trait);
60 quote! {
61 #return_type
62 }
63 }
64 _ => {
65 quote! {
66 #return_type
67 }
68 }
69 },
70 }
71}
72
73fn is_stream(return_type: &syn::ReturnType) -> bool {
74 match return_type {
75 syn::ReturnType::Default => false,
76 syn::ReturnType::Type(_, return_type) => {
77 matches!(*return_type.to_owned(), syn::Type::ImplTrait(_))
78 }
79 }
80}
81
82fn response(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
83 let items = item.items.iter().filter_map(|item| match item {
84 syn::TraitItem::Fn(method) => {
85 let ident = format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
86 let return_type = extract_return_type(&method.sig.output);
87 let item = if is_stream(&method.sig.output) {
88 quote! { #ident(zzrpc::producer::StreamResponse<#return_type>), }
89 } else {
90 quote! { #ident(#return_type), }
91 };
92 Some(item)
93 }
94 _ => None,
95 });
96
97 let output = quote! {
98 #[derive(Debug, serde::Serialize, serde::Deserialize)]
99 pub enum Response {
100 #(#items)*
101 }
102 };
103 output
104}
105
106fn consumer_senders(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
107 let items = item.items.iter().filter_map(|item| match item {
108 syn::TraitItem::Fn(method) => {
109 let ident = format_ident!(
110 "{}__sender",
111 method.sig.ident.to_string().to_case(Case::Snake)
112 );
113 let return_type = extract_return_type(&method.sig.output);
114 let item = quote! {
115 #ident: std::sync::Arc<
116 zzrpc::futures::channel::mpsc::UnboundedSender<(
117 zzrpc::consumer::Message<Request>,
118 zzrpc::consumer::ResultSender<#return_type, Error>,
119 )>,
120 >,
121 };
122 Some(item)
123 }
124 _ => None,
125 });
126
127 let output = quote! {
128 #[derive(Debug)]
129 struct Senders<Error> {
130 #(#items)*
131 }
132 };
133 output
134}
135
136fn impl_consumer_state(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
137 let mut channels = vec![];
138 let mut senders = vec![];
139 let mut items = vec![];
140 let mut handlers = vec![];
141 let mut drainers = vec![];
142
143 for method in item.items.iter().filter_map(|item| match item {
144 syn::TraitItem::Fn(method) => Some(method),
145 _ => None,
146 }) {
147 let ident = format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
148 let ident_requests = format_ident!(
149 "{}__requests",
150 method.sig.ident.to_string().to_case(Case::Snake)
151 );
152 let ident_sender = format_ident!(
153 "{}__sender",
154 method.sig.ident.to_string().to_case(Case::Snake)
155 );
156 let ident_receiver = format_ident!(
157 "{}__receiver",
158 method.sig.ident.to_string().to_case(Case::Snake)
159 );
160 let return_type = extract_return_type(&method.sig.output);
161
162 channels.push(quote! {
163 let (#ident_sender, #ident_receiver) = zzrpc::futures::channel::mpsc::unbounded::<(
164 zzrpc::consumer::Message<Request>,
165 zzrpc::consumer::ResultSender<#return_type, Error>,
166 )>();
167 });
168
169 senders.push(quote! {
170 #ident_sender: std::sync::Arc::new(#ident_sender),
171 });
172
173 items.push(quote! {
174 #ident_requests: std::collections::HashMap::new(),
175 #ident_receiver,
176 });
177
178 handlers.push(if is_stream(&method.sig.output) {
179 quote! {
180 Response::#ident(result) => {
181 match result {
182 zzrpc::producer::StreamResponse::Open => {
183 if let Some((sender, _)) = self.#ident_requests.get_mut(&id) {
184 if let Some(sender) = sender.take() {
185 let _ = sender.send(Ok(()));
186 self.pending -= 1;
187 }
188 }
189 },
190 zzrpc::producer::StreamResponse::Item(item) => {
191 if let Some((_, sender)) = self.#ident_requests.get_mut(&id) {
192 let _ = sender.unbounded_send(item);
193 }
194 },
195 zzrpc::producer::StreamResponse::Closed => {
196 self.#ident_requests.remove(&id);
197 },
198 }
199 }
200 }
201 } else {
202 quote! {
203 Response::#ident(result) => {
204 if let Some(sender) = self.#ident_requests.remove(&id) {
205 let _ = sender.send(Ok(result));
206 self.pending -= 1;
207 }
208 }
209 }
210 });
211
212 if is_stream(&method.sig.output) {
213 drainers.push(quote! {
214 for (_id, (mut sender, _)) in self.#ident_requests.drain() {
215 if let Some(sender) = sender.take() {
216 let _ = sender.send(Err(shutdown_type.into()));
217 }
218 }
219 });
220 } else {
221 drainers.push(quote! {
222 for (_id, sender) in self.#ident_requests.drain() {
223 let _ = sender.send(Err(shutdown_type.into()));
224 }
225 });
226 }
227 }
228
229 let output = quote! {
230 impl<Error> ConsumerState<Error> {
231 fn new() -> (Senders<Error>, Self) {
232 #(#channels)*
233
234 let senders = Senders {
235 #(#senders)*
236 };
237
238 let state = ConsumerState {
239 pending: 0,
240 #(#items)*
241 };
242
243 (senders, state)
244 }
245
246 fn handle_message(
247 &mut self,
248 message: zzrpc::producer::Message<Response>,
249 ) -> Option<zzrpc::ShutdownType> {
250 match message {
251 zzrpc::producer::Message::Response { id, response } => {
252 match response {
253 #(#handlers)*
254 }
255 None
256 }
257 zzrpc::producer::Message::Aborted => Some(zzrpc::ShutdownType::Aborted),
258 zzrpc::producer::Message::Shutdown => Some(zzrpc::ShutdownType::Shutdown),
259 }
260 }
261
262 fn idle(&self) -> bool {
263 self.pending == 0
264 }
265
266 fn shutdown(&mut self, shutdown_type: zzrpc::ShutdownType) {
267 #(#drainers)*
268 }
269 }
270 };
271 output
272}
273
274fn consumer_state(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
275 let items = item.items.iter().filter_map(|item| match item {
276 syn::TraitItem::Fn(method) => {
277 let ident_requests = format_ident!("{}__requests", method.sig.ident.to_string().to_case(Case::Snake));
278 let ident_receiver = format_ident!("{}__receiver", method.sig.ident.to_string().to_case(Case::Snake));
279 let return_type = extract_return_type(&method.sig.output);
280
281 let sender = if is_stream(&method.sig.output) {
282 quote! {
283 (Option<zzrpc::futures::channel::oneshot::Sender<zzrpc::consumer::Result<(), Error>>>,
284 zzrpc::futures::channel::mpsc::UnboundedSender<#return_type>)
285 }
286 } else {
287 quote! {
288 zzrpc::futures::channel::oneshot::Sender<zzrpc::consumer::Result<#return_type, Error>>
289 }
290 };
291
292 let item = quote! {
293 #ident_requests: std::collections::HashMap<usize, #sender>,
294 #ident_receiver: zzrpc::futures::channel::mpsc::UnboundedReceiver<(
295 zzrpc::consumer::Message<Request>,
296 zzrpc::consumer::ResultSender<#return_type, Error>,
297 )>,
298 };
299 Some(item)
300 }
301 _ => None,
302 });
303
304 let impl_consumer_state = impl_consumer_state(item);
305
306 let output = quote! {
307 #[derive(Debug)]
308 struct ConsumerState<Error> {
309 pending: usize,
310 #(#items)*
311 }
312
313 #impl_consumer_state
314 };
315 output
316}
317
318fn impl_consume(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
319 let items = item.items.iter().filter_map(|item| match item {
320 syn::TraitItem::Fn(method) => {
321 let ident_option = format_ident!(
322 "{}__option",
323 method.sig.ident.to_string().to_case(Case::Snake)
324 );
325 let ident_receiver = format_ident!(
326 "{}__receiver",
327 method.sig.ident.to_string().to_case(Case::Snake)
328 );
329 let ident_requests = format_ident!(
330 "{}__requests",
331 method.sig.ident.to_string().to_case(Case::Snake)
332 );
333
334 let patterns = if is_stream(&method.sig.output) {
335 quote! {
336 zzrpc::consumer::ResultSender::Stream { result_sender, values_sender } => {
337 if let Err(error) = result {
338 let _ = result_sender.send(Err(error));
339 } else {
340 state.#ident_requests.insert(id, (Some(result_sender), values_sender));
341 state.pending += 1;
342 }
343 },
344 zzrpc::consumer::ResultSender::Abort => {
345 if let Some((sender, _)) = state.#ident_requests.remove(&id) {
346 if sender.is_some() {
347 state.pending -= 1;
348 }
349 }
350 },
351 _ => unreachable!("value sender got when stream sender expected"),
352 }
353 } else {
354 quote! {
355 zzrpc::consumer::ResultSender::Value(sender) => {
356 if let Err(error) = result {
357 let _ = sender.send(Err(error));
358 } else {
359 state.#ident_requests.insert(id, sender);
360 state.pending += 1;
361 }
362 },
363 zzrpc::consumer::ResultSender::Abort => {
364 if state.#ident_requests.remove(&id).is_some() {
365 state.pending -= 1;
366 }
367 },
368 _ => unreachable!("stream sender got when value sender expected"),
369 }
370 };
371
372 let item = quote! {
373 #ident_option = zzrpc::futures::StreamExt::next(&mut state.#ident_receiver) => {
374 if let Some((message, result_sender)) = #ident_option {
375 if timeout.is_some() {
376 timeout_future.reset();
377 }
378
379 let id = message.id;
380 let result = sender.send(message).await;
381 let result = result.map_err(|error| {
382 match error {
383 mezzenger::Error::Closed => zzrpc::Error::Closed,
384 mezzenger::Error::Other(error) => zzrpc::Error::Transport(error),
385 }
386 });
387 match result_sender {
388 #patterns
389 }
390 }
391 },
392 };
393 Some(item)
394 }
395 _ => None,
396 });
397
398 let implementation = quote! {
399 use zzrpc::futures::{FutureExt, SinkExt, StreamExt};
400
401 let zzrpc::consumer::Configuration {
402 shutdown,
403 mut receive_error_callback,
404 timeout,
405 ..
406 } = configuration;
407
408 let (drop_sender, mut drop_receiver) = zzrpc::futures::channel::oneshot::channel::<()>();
409 let drop_sender = Some(drop_sender);
410
411 let (senders, mut state) = ConsumerState::new();
412
413 zzrpc::spawn(async move {
414 let (mut sender, receiver) = transport.split();
415 let receiver = zzrpc::futures::StreamExt::fuse(receiver);
416
417 let timeout_future = if let Some(duration) = timeout {
418 zzrpc::Timeout::new(duration)
419 } else {
420 zzrpc::Timeout::never()
421 };
422 let shutdown = zzrpc::futures::FutureExt::fuse(shutdown);
423 zzrpc::futures::pin_mut!(receiver, timeout_future, shutdown);
424
425 loop {
426 zzrpc::futures::select! {
427 receive_option = zzrpc::futures::StreamExt::next(&mut receiver) => {
428 if let Some(receive_result) = receive_option {
429 if timeout.is_some() {
430 timeout_future.reset();
431 }
432
433 match receive_result {
434 Ok(message) => {
435 if let Some(shutdown_type) = state.handle_message(message) {
436 state.shutdown(shutdown_type);
437 break;
438 }
439 },
440 Err(error) => {
441 if let zzrpc::HandlingStrategy::Stop(shutdown_type) = receive_error_callback.on_receive_error(error) {
442 state.shutdown(shutdown_type);
443 break;
444 }
445 },
446 }
447 } else {
448 state.shutdown(zzrpc::ShutdownType::Closed);
449 break;
450 }
451 },
452 #(#items)*
453 _ = &mut timeout_future => {
454 if timeout.is_some() && !state.idle() {
455 state.shutdown(zzrpc::ShutdownType::Timeout);
456 break;
457 }
458 },
459 shutdown_type = &mut shutdown => {
460 state.shutdown(shutdown_type);
461 break;
462 }
463 _ = &mut drop_receiver => { break; }
464 }
465 }
466 });
467
468 Consumer {
469 id_counter: zzrpc::atomic_counter::ConsistentCounter::new(0),
470 senders,
471 drop_sender,
472 }
473 };
474
475 let output = quote! {
476 impl<Error> zzrpc::consumer::Consume<Consumer<Error>, Error> for Consumer<Error> {
477 type Request = Request;
478 type Response = Response;
479
480 #[cfg(not(target_arch = "wasm32"))]
481 fn consume_unreliable<Transport, Shutdown, ReceiveErrorCallback>(
482 transport: Transport,
483 configuration: zzrpc::consumer::Configuration<Shutdown, Error, ReceiveErrorCallback>,
484 ) -> Consumer<Error>
485 where
486 Transport: mezzenger::Transport<
487 zzrpc::producer::Message<Self::Response>,
488 zzrpc::consumer::Message<Self::Request>,
489 Error,
490 > + mezzenger::Reliable
491 + mezzenger::Order
492 + Send
493 + 'static,
494 Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + Send + 'static,
495 ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + Send + 'static,
496 Error: Send + 'static, {
497 #implementation
498 }
499
500 #[cfg(target_arch = "wasm32")]
501 fn consume_unreliable<Transport, Shutdown, ReceiveErrorCallback>(
502 transport: Transport,
503 configuration: zzrpc::consumer::Configuration<Shutdown, Error, ReceiveErrorCallback>,
504 ) -> Consumer<Error>
505 where
506 Transport: mezzenger::Transport<
507 zzrpc::producer::Message<Self::Response>,
508 zzrpc::consumer::Message<Self::Request>,
509 Error,
510 > + mezzenger::Reliable
511 + mezzenger::Order
512 + 'static,
513 Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + 'static,
514 ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + 'static,
515 Error: 'static, {
516 #implementation
517 }
518 }
519 };
520 output
521}
522
523fn pattern_arguments(
524 arguments: &Punctuated<syn::FnArg, syn::token::Comma>,
525) -> Punctuated<syn::Ident, Comma> {
526 let mut output = Punctuated::new();
527 for arg in arguments {
528 if let syn::FnArg::Typed(pat_type) = &arg {
529 if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &*pat_type.pat {
530 if ident != "self" {
531 output.push(ident.clone());
532 }
533 }
534 }
535 }
536
537 output
538}
539
540fn impl_api(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
541 let ident = &item.ident;
542
543 let items = item.items.iter().filter_map(|item| match item {
544 syn::TraitItem::Fn(method) => {
545 let ident_request =
546 format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
547 let ident_sender = format_ident!(
548 "{}__sender",
549 method.sig.ident.to_string().to_case(Case::Snake)
550 );
551
552 let mut signature = method.sig.clone();
553 signature.asyncness = None;
554 match &signature.output {
555 syn::ReturnType::Default => {
556 signature.output = syn::parse2::<syn::ReturnType>(
557 quote!(-> zzrpc::ValueRequest<(), Request, Self::Error>),
558 )
559 .unwrap();
560 }
561 syn::ReturnType::Type(_, return_type) => match *return_type.to_owned() {
562 syn::Type::ImplTrait(impl_trait) => {
563 let return_type = extract_stream_item_type(&impl_trait);
564 signature.output = syn::parse2::<syn::ReturnType>(
565 quote!(-> zzrpc::StreamRequest<#return_type, Request, Self::Error>),
566 )
567 .unwrap();
568 }
569 _ => {
570 signature.output = syn::parse2::<syn::ReturnType>(
571 quote!(-> zzrpc::ValueRequest<#return_type, Request, Self::Error>),
572 )
573 .unwrap();
574 }
575 },
576 };
577
578 let ident_request_future = if is_stream(&method.sig.output) {
579 quote!(zzrpc::StreamRequest)
580 } else {
581 quote!(zzrpc::ValueRequest)
582 };
583
584 let arguments = pattern_arguments(&method.sig.inputs);
585
586 let item = quote! {
587 #signature {
588 use zzrpc::atomic_counter::AtomicCounter;
589 let request = Request::#ident_request { #arguments };
590 #ident_request_future::new(
591 self.senders.#ident_sender.clone(),
592 self.id_counter.inc(),
593 request,
594 )
595 }
596 };
597 Some(item)
598 }
599 _ => None,
600 });
601
602 let output = quote! {
603 impl<Error> #ident for Consumer<Error> {
604 #(#items)*
605
606 type Request = Request;
607 type Error = Error;
608 }
609 };
610 output
611}
612
613fn consumer(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
614 let senders = consumer_senders(item);
615 let state = consumer_state(item);
616
617 let impl_consume = impl_consume(item);
618 let impl_api = impl_api(item);
619
620 let output = quote! {
621 #senders
622
623 #state
624
625 #[derive(Debug)]
626 pub struct Consumer<Error> {
627 id_counter: zzrpc::atomic_counter::ConsistentCounter,
628 senders: Senders<Error>,
629 drop_sender: Option<zzrpc::futures::channel::oneshot::Sender<()>>,
630 }
631
632 #impl_consume
633
634 #impl_api
635
636 impl<Error> Drop for Consumer<Error> {
637 fn drop(&mut self) {
638 if let Some(sender) = self.drop_sender.take() {
639 let _ = sender.send(());
640 }
641 }
642 }
643 };
644 output
645}
646
647fn impl_produce(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
648 let items = item.items.iter().filter_map(|item| match item {
649 syn::TraitItem::Fn(method) => {
650 let ident = method.sig.ident.clone();
651 let ident_request =
652 format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
653
654 let arguments = pattern_arguments(&method.sig.inputs);
655
656 let item = if is_stream(&method.sig.output) {
657 quote! {
658 Request::#ident_request { #arguments } => {
659 let me = me.clone();
660 let reply_sender = reply_sender.clone();
661 let remove_aborter_sender = remove_aborter_sender.clone();
662 spawn(async move {
663 let mut stream = zzrpc::futures::StreamExt::fuse(me.#ident(#arguments).await);
664
665 let response = zzrpc::producer::StreamResponse::Open;
666 let response = Response::#ident_request(response);
667 let message = zzrpc::producer::Message::Response { id, response };
668 let _ = reply_sender.unbounded_send(message);
669
670 loop {
671 select! {
672 result = zzrpc::futures::StreamExt::next(&mut stream) => {
673 if let Some(result) = result {
674 let response = zzrpc::producer::StreamResponse::Item(result);
675 let response = Response::#ident_request(response);
676 let message = zzrpc::producer::Message::Response { id, response };
677 let _ = reply_sender.unbounded_send(message);
678 } else {
679 let response = zzrpc::producer::StreamResponse::Closed;
680 let response = Response::#ident_request(response);
681 let message = zzrpc::producer::Message::Response { id, response };
682 let _ = reply_sender.unbounded_send(message);
683
684 let _ = remove_aborter_sender.unbounded_send(id);
685 break;
686 }
687 },
688 _ = abort_receiver => {
689 break;
690 },
691 };
692 }
693 });
694 },
695 }
696 } else {
697 quote! {
698 Request::#ident_request { #arguments } => {
699 let me = me.clone();
700 let reply_sender = reply_sender.clone();
701 let remove_aborter_sender = remove_aborter_sender.clone();
702 spawn(async move {
703 let task = zzrpc::futures::FutureExt::fuse(me.#ident(#arguments));
704 pin_mut!(task);
705 select! {
706 result = task => {
707 let response = Response::#ident_request(result);
708 let message = zzrpc::producer::Message::Response { id, response };
709 let _ = reply_sender.unbounded_send(message);
710 let _ = remove_aborter_sender.unbounded_send(id);
711 },
712 _ = abort_receiver => (),
713 };
714 });
715 },
716 }
717 };
718
719 Some(item)
720 }
721 _ => None,
722 });
723
724 let uuid = Uuid::new_v4();
725 let ident = format_ident!("__impl_produce_{}", uuid.simple().to_string());
726
727 let output = quote! {
728 #[macro_export]
729 macro_rules! #ident {
730 ($self:ident, $transport:ident, $configuration:ident) => {
731 zzrpc::spawn(async move {
732 use std::sync::Arc;
733 use std::collections::HashMap;
734 use futures::{
735 channel::{
736 mpsc::unbounded,
737 oneshot,
738 },
739 pin_mut, select, SinkExt, StreamExt,
740 };
741 use zzrpc::{ShutdownType, Timeout};
742
743 use zzrpc::spawn;
744
745 let me = Arc::new($self);
746
747 let (mut sender, receiver) = $transport.split();
748 let mut receiver = zzrpc::futures::StreamExt::fuse(receiver);
749
750 let (reply_sender, mut reply_receiver) = unbounded::<zzrpc::producer::Message<Response>>();
751
752 let mut aborters: HashMap<usize, oneshot::Sender<()>> = HashMap::new();
753 let (remove_aborter_sender, mut remove_aborter_receiver) = unbounded::<usize>();
754
755 let (stop_sender, stop_receiver) = oneshot::channel::<ShutdownType>();
756 let mut stop_sender = Some(stop_sender);
757
758 let zzrpc::producer::Configuration {
759 shutdown,
760 mut send_error_callback,
761 mut receive_error_callback,
762 timeout,
763 ..
764 } = $configuration;
765
766 spawn(async move {
767 while let Some(message) = zzrpc::futures::StreamExt::next(&mut reply_receiver).await {
768 match message {
769 zzrpc::producer::Message::Response { id, .. } => {
770 let result = sender.send(message).await;
771 if let Err(error) = result {
772 if let zzrpc::HandlingStrategy::Stop(shutdown_type) =
773 send_error_callback.on_send_error(id, error)
774 {
775 if let Some(stop_sender) = stop_sender.take() {
776 let _ = stop_sender.send(shutdown_type);
777 }
778 }
779 }
780 }
781 _ => {
782 let _ = sender.send(message).await;
783 }
784 }
785 }
786 });
787
788 let handle_message = |aborters: &mut HashMap<usize, oneshot::Sender<()>>, message: zzrpc::consumer::Message<Request>| {
789 let zzrpc::consumer::Message { id, payload } = message;
790 match payload {
791 zzrpc::consumer::Payload::Request(request) => {
792 let (abort_sender, mut abort_receiver) = oneshot::channel::<()>();
793 aborters.insert(id, abort_sender);
794 match request {
795 #(#items)*
796 }
797 }
798 zzrpc::consumer::Payload::Abort => {
799 if let Some(abort_sender) = aborters.remove(&id) {
800 let _ = abort_sender.send(());
801 }
802 }
803 }
804 };
805
806 let mut return_value = ShutdownType::Closed;
807 let mut handle_shutdown = |shutdown_type: ShutdownType| {
808 return_value = shutdown_type;
809 let message = match shutdown_type {
810 ShutdownType::Shutdown => zzrpc::producer::Message::Shutdown,
811 ShutdownType::Aborted => zzrpc::producer::Message::Aborted,
812 _ => {
813 return;
814 }
815 };
816 let _ = reply_sender.unbounded_send(message);
817 };
818
819 let timeout_future = if let Some(duration) = timeout {
820 Timeout::new(duration)
821 } else {
822 Timeout::never()
823 };
824 let shutdown = zzrpc::futures::FutureExt::fuse(shutdown);
825 pin_mut!(shutdown, timeout_future, stop_receiver);
826 loop {
827 select! {
828 receive_option = zzrpc::futures::StreamExt::next(&mut receiver) => {
829 if timeout.is_some() {
830 timeout_future.reset();
831 }
832
833 if let Some(receive_result) = receive_option {
834 match receive_result {
835 Ok(message) => handle_message(&mut aborters, message),
836 Err(error) => {
837 if let zzrpc::HandlingStrategy::Stop(shutdown_type) = receive_error_callback.on_receive_error(error) {
838 handle_shutdown(shutdown_type);
839 break;
840 }
841 },
842 }
843 } else {
844 handle_shutdown(ShutdownType::Closed);
845 break;
846 }
847 },
848 id_option = zzrpc::futures::StreamExt::next(&mut remove_aborter_receiver) => {
849 if let Some(id) = id_option {
850 aborters.remove(&id);
851 }
852 },
853 _ = &mut timeout_future => {
854 if timeout.is_some() {
855 if aborters.is_empty() {
856 handle_shutdown(ShutdownType::Timeout);
857 break;
858 } else {
859 timeout_future.reset();
860 }
861 }
862 },
863 shutdown_type = &mut stop_receiver => {
864 if let Ok(shutdown_type) = shutdown_type {
865 handle_shutdown(shutdown_type);
866 break;
867 }
868 },
869 shutdown_type = &mut shutdown => {
870 handle_shutdown(shutdown_type);
871 break;
872 },
873 }
874 }
875
876 for (_, aborter) in aborters.drain() {
877 let _ = aborter.send(());
878 }
879
880 return_value
881 })
882 }
883 }
884
885 pub use #ident as impl_produce;
886 };
887 output
888}
889
890fn extract_stream_item_type(impl_trait: &syn::TypeImplTrait) -> syn::Type {
891 if impl_trait.bounds.len() == 1 {
892 if let syn::TypeParamBound::Trait(bound) = &impl_trait.bounds[0] {
893 if bound.path.segments.len() == 1 {
894 let stream = &bound.path.segments[0];
895 if stream.ident == "Stream" {
896 if let syn::PathArguments::AngleBracketed(arguments) = &stream.arguments {
897 if arguments.args.len() == 1 {
898 let argument = &arguments.args[0];
899 if let syn::GenericArgument::AssocType(binding) = argument {
900 if binding.ident == "Item" {
901 return binding.ty.clone();
902 }
903 }
904 }
905 }
906 }
907 }
908 }
909 }
910 panic!("invalid stream request method return type");
911}
912
913fn modify_trait(mut item: syn::ItemTrait) -> syn::ItemTrait {
914 if item.generics.lt_token.is_some() {
915 panic!("generic traits are not supported");
916 }
917
918 item.items.push(
919 syn::parse2::<syn::TraitItem>(quote! {
920 type Request;
922 })
923 .unwrap(),
924 );
925
926 item.items.push(
927 syn::parse2::<syn::TraitItem>(quote! {
928 type Error;
930 })
931 .unwrap(),
932 );
933
934 let must_use = format_ident!("must_use");
935 let must_use = syn::Attribute {
936 pound_token: Token!(#)(Span::call_site()),
937 style: syn::AttrStyle::Outer,
938 bracket_token: syn::token::Bracket(Span::call_site()),
939 meta: syn::Meta::Path(syn::Path {
940 leading_colon: None,
941 segments: syn::punctuated::Punctuated::from_iter([syn::PathSegment::from(must_use)]),
942 }),
943 };
944
945 for method in item.items.iter_mut().filter_map(|item| {
946 if let syn::TraitItem::Fn(func) = item {
947 Some(func)
948 } else {
949 None
950 }
951 }) {
952 if method.sig.asyncness.is_none() {
953 panic!("all api methods should be marked as \"async\"")
954 }
955
956 if method.sig.generics.lt_token.is_some() {
957 panic!("generic methods are not supported")
958 }
959
960 method.sig.asyncness = None;
961 method.attrs.push(must_use.clone());
962
963 match &method.sig.output {
964 syn::ReturnType::Default => {
965 method.sig.output = syn::parse2::<syn::ReturnType>(
966 quote!(-> zzrpc::ValueRequest<(), Request, Self::Error>),
967 )
968 .unwrap();
969 }
970 syn::ReturnType::Type(_, return_type) => match *return_type.to_owned() {
971 syn::Type::ImplTrait(impl_trait) => {
972 let return_type = extract_stream_item_type(&impl_trait);
973 method.sig.output = syn::parse2::<syn::ReturnType>(
974 quote!(-> zzrpc::StreamRequest<#return_type, Request, Self::Error>),
975 )
976 .unwrap();
977 }
978 _ => {
979 method.sig.output = syn::parse2::<syn::ReturnType>(
980 quote!(-> zzrpc::ValueRequest<#return_type, Request, Self::Error>),
981 )
982 .unwrap();
983 }
984 },
985 };
986 }
987
988 item
989}
990
991#[proc_macro_attribute]
992pub fn api(_attr: TokenStream, item: TokenStream) -> TokenStream {
993 if let Ok(item) = syn::parse2::<syn::ItemTrait>(item.into()) {
994 let item_modified = modify_trait(item.clone());
995
996 let request = request(&item);
997 let response = response(&item);
998 let consumer = consumer(&item);
999 let impl_produce = impl_produce(&item);
1000
1001 let output = quote! {
1002 #item_modified
1003
1004 #request
1005
1006 #response
1007
1008 #consumer
1009
1010 #impl_produce
1011 };
1012
1013 output.into()
1014 } else {
1015 panic!("expected a trait")
1016 }
1017}
1018
1019#[proc_macro_derive(Produce)]
1020pub fn produce(input: TokenStream) -> TokenStream {
1021 let syn::DeriveInput { ident, .. } = syn::parse_macro_input!(input);
1022 let output = quote! {
1023 impl zzrpc::producer::Produce for #ident {
1024 type Request = Request;
1025 type Response = Response;
1026
1027 #[cfg(not(target_arch = "wasm32"))]
1028 fn produce_unreliable<Transport, Error, Shutdown, SendErrorCallback, ReceiveErrorCallback>(
1029 self,
1030 transport: Transport,
1031 configuration: zzrpc::producer::Configuration<
1032 Shutdown,
1033 Error,
1034 SendErrorCallback,
1035 ReceiveErrorCallback,
1036 >,
1037 ) -> zzrpc::JoinHandle<zzrpc::ShutdownType> where
1038 Transport: mezzenger::Transport<
1039 zzrpc::consumer::Message<Self::Request>,
1040 zzrpc::producer::Message<Self::Response>,
1041 Error,
1042 > + mezzenger::Reliable
1043 + mezzenger::Order
1044 + Send
1045 + 'static,
1046 Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + Send + 'static,
1047 SendErrorCallback: zzrpc::SendErrorCallback<Error> + Send + 'static,
1048 ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + Send + 'static {
1049 impl_produce!(self, transport, configuration)
1050 }
1051
1052 #[cfg(target_arch = "wasm32")]
1053 fn produce_unreliable<Transport, Error, Shutdown, SendErrorCallback, ReceiveErrorCallback>(
1054 self,
1055 transport: Transport,
1056 configuration: zzrpc::producer::Configuration<
1057 Shutdown,
1058 Error,
1059 SendErrorCallback,
1060 ReceiveErrorCallback,
1061 >,
1062 ) -> zzrpc::JoinHandle<zzrpc::ShutdownType> where
1063 Transport: mezzenger::Transport<
1064 zzrpc::consumer::Message<Self::Request>,
1065 zzrpc::producer::Message<Self::Response>,
1066 Error,
1067 > + mezzenger::Reliable
1068 + mezzenger::Order
1069 + 'static,
1070 Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + 'static,
1071 SendErrorCallback: zzrpc::SendErrorCallback<Error> + 'static,
1072 ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + 'static {
1073 impl_produce!(self, transport, configuration)
1074 }
1075 }
1076 };
1077 output.into()
1078}