1mod attr;
2mod errors;
3mod format;
4mod method;
5mod utils;
6
7use crate::attr::PretendAttr;
8use crate::errors::{
9 ErrorsExt, Report, CODEGEN_FAILURE, INCONSISTENT_ASYNC, INCONSISTENT_ASYNC_ASYNC_HINT,
10 INCONSISTENT_ASYNC_NON_ASYNC_HINT, NO_METHOD, UNSUPPORTED_ATTR_SYNC,
11};
12use crate::method::{trait_item, trait_item_implem};
13use crate::utils::WithTokens;
14use proc_macro::TokenStream;
15use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
16use quote::quote;
17use syn::{parse_macro_input, Error, ItemTrait, Result, Signature, TraitItem};
18
19#[proc_macro_attribute]
20pub fn pretend(attr: TokenStream, item: TokenStream) -> TokenStream {
21 let attr = parse_macro_input!(attr as PretendAttr);
22 let item = parse_macro_input!(item as ItemTrait);
23 implement_pretend(attr, item)
24 .unwrap_or_else(Error::into_compile_error)
25 .into()
26}
27
28fn implement_pretend(attr: PretendAttr, item: ItemTrait) -> Result<TokenStream2> {
29 let name = &item.ident;
30 let vis = &item.vis;
31 let items = &item.items;
32 let attrs = &item.attrs;
33 let trait_items = items.iter().map(trait_item).collect::<Vec<_>>();
34
35 let kind = parse_client_kind(name, attr, items)?;
36 let methods = items
37 .iter()
38 .map(|item| trait_item_implem(item, &kind))
39 .collect::<Report<_>>()
40 .into_result(|| Error::new(Span::call_site(), CODEGEN_FAILURE))?;
41
42 let attr = async_trait_attr(&kind);
43 let client = client_implem(&kind);
44 let send_sync = send_sync_traits_impl(&kind);
45 let tokens = quote! {
46 #attr
47 #(#attrs)*
48 #vis trait #name {
49 #(#trait_items)*
50 }
51
52 #attr
53 impl<C, R, I> #name for pretend::Pretend<C, R, I>
54 where C: #client #send_sync,
55 R: pretend::resolver::ResolveUrl #send_sync,
56 I: pretend::interceptor::InterceptRequest #send_sync,
57 {
58 #(#methods)*
59 }
60 };
61 Ok(tokens)
62}
63
/// Flavour of client the generated `pretend::Pretend<C, R, I>` implementation
/// is bound to.
///
/// Selected by `parse_client_kind` from the `async`-ness of the trait's methods
/// and the `local` flag of the attribute. It is a plain fieldless enum, so it is
/// `Copy` and comparable, and `Debug` helps when diagnosing codegen.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ClientKind {
    /// Every method is `async`; `C` is bound on `pretend::client::Client`,
    /// and `C`, `R`, `I` additionally on `Send + Sync`.
    Async,
    /// Every method is `async` and `local` was requested; `C` is bound on
    /// `pretend::client::LocalClient` without `Send + Sync`.
    AsyncLocal,
    /// No method is `async`; `C` is bound on `pretend::client::BlockingClient`.
    Blocking,
}
69
70fn parse_client_kind(name: &Ident, attr: PretendAttr, items: &[TraitItem]) -> Result<ClientKind> {
71 let asyncs = items.iter().filter_map(is_method_async).collect::<Vec<_>>();
72 let is_async = asyncs.iter().all(|item| item.value);
73 let is_not_async = asyncs.iter().all(|item| !item.value);
74
75 match (is_async, is_not_async) {
76 (true, false) => {
77 if attr.local {
78 Ok(ClientKind::AsyncLocal)
79 } else {
80 Ok(ClientKind::Async)
81 }
82 }
83 (false, true) => {
84 if attr.local {
85 Err(Error::new(Span::call_site(), UNSUPPORTED_ATTR_SYNC))
86 } else {
87 Ok(ClientKind::Blocking)
88 }
89 }
90 _ => {
91 if asyncs.is_empty() {
92 Err(Error::new_spanned(name, NO_METHOD))
93 } else {
94 let async_hints = asyncs
95 .iter()
96 .filter(|item| item.value)
97 .map(|item| Error::new_spanned(item.tokens, INCONSISTENT_ASYNC_ASYNC_HINT));
98
99 let non_async_hints = asyncs
100 .iter()
101 .filter(|item| !item.value)
102 .map(|item| Error::new_spanned(item.tokens, INCONSISTENT_ASYNC_NON_ASYNC_HINT));
103
104 let errors = async_hints.chain(non_async_hints).collect::<Vec<_>>();
105 errors.into_result(|| Error::new_spanned(name, INCONSISTENT_ASYNC))
106 }
107 }
108 }
109}
110
111fn is_method_async(item: &TraitItem) -> Option<WithTokens<bool, Signature>> {
112 match item {
113 TraitItem::Method(method) => {
114 let is_async = method.sig.asyncness.is_some();
115 Some(WithTokens::new(is_async, &method.sig))
116 }
117 _ => None,
118 }
119}
120
121fn async_trait_attr(kind: &ClientKind) -> TokenStream2 {
122 match kind {
123 ClientKind::Async => quote! {
124 #[pretend::client::async_trait]
125 },
126 ClientKind::AsyncLocal => quote! {
127 #[pretend::client::async_trait(?Send)]
128 },
129 ClientKind::Blocking => TokenStream2::new(),
130 }
131}
132
133fn client_implem(kind: &ClientKind) -> TokenStream2 {
134 match kind {
135 ClientKind::Async => quote! {
136 pretend::client::Client
137 },
138 ClientKind::AsyncLocal => quote! {
139 pretend::client::LocalClient
140 },
141 ClientKind::Blocking => quote! {
142 pretend::client::BlockingClient
143 },
144 }
145}
146
147fn send_sync_traits_impl(kind: &ClientKind) -> TokenStream2 {
148 match kind {
149 ClientKind::Async => quote! {
150 + Send + Sync
151 },
152 ClientKind::AsyncLocal => TokenStream2::new(),
153 ClientKind::Blocking => TokenStream2::new(),
154 }
155}