ez_impl/
proc_macros.rs

1use {
2    proc_macro2::{Delimiter, Ident, TokenStream, TokenTree},
3    quote::{quote_spanned, ToTokens},
4    syn::{
5        fold::Fold, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, Block,
6        ExprAsync, ExprClosure, ExprReturn, ImplItemMethod, ItemFn, Path, ReturnType, Visibility,
7    },
8};
9
/// General-purpose function representation that macro input is parsed into.
///
/// We use `ImplItemMethod` because its supported syntax seems to be a
/// superset of other function types.
type Function = ImplItemMethod;
13
14/// Returns the training block of this token stream, if it has one.
15fn trailing_block(tokens: &TokenStream) -> Result<Option<Block>, eyre::Report> {
16    let mut tokens = Vec::from_iter(tokens.clone());
17
18    if let Some(trailing) = tokens.last_mut() {
19        if let TokenTree::Group(group) = &trailing {
20            if group.delimiter() == Delimiter::Brace {
21                return Ok(Some(syn::parse2(trailing.into_token_stream())?));
22            }
23        }
24    }
25
26    Ok(None)
27}
28
29/// Wrap every return statement in `Ok`, but don't recur into nested
30/// functions/closures/async blocks.
31fn wrap_returns_in_ok(block: Block) -> Block {
32    struct Folder;
33    impl Fold for Folder {
34        fn fold_expr_return(&mut self, expr: ExprReturn) -> ExprReturn {
35            let inner = expr.expr.clone();
36            parse_quote_spanned! { expr.span() =>
37                return ::ez::__::Ok(#inner)
38            }
39        }
40
41        fn fold_item_fn(&mut self, item_fn: ItemFn) -> ItemFn {
42            item_fn
43        }
44
45        fn fold_expr_closure(&mut self, expr_closure: ExprClosure) -> ExprClosure {
46            expr_closure
47        }
48
49        fn fold_expr_async(&mut self, expr_async: ExprAsync) -> ExprAsync {
50            expr_async
51        }
52    }
53
54    Folder.fold_block(block)
55}
56
57/// If this token stream has a trailing block, import `throw!` and wrap every
58/// return value in `Ok`.
59fn tryify_trailing_block(tokens: TokenStream) -> Result<TokenStream, eyre::Report> {
60    let mut tokens = Vec::from_iter(tokens);
61
62    if let Some(last) = tokens.last_mut() {
63        if let proc_macro2::TokenTree::Group(group) = last {
64            if group.delimiter() == proc_macro2::Delimiter::Brace {
65                let block: syn::Block = syn::parse2(last.clone().into_token_stream())?;
66                let block = wrap_returns_in_ok(block);
67                *last = parse_quote_spanned! { block.span() => {
68                    #[allow(unused_imports)]
69                    use ::ez::throw;
70                    let _ez_inner = #block;
71                    #[allow(unreachable_code)]
72                    ::ez::__::Ok(_ez_inner)
73                } };
74            }
75        };
76    }
77
78    Ok(tokens.into_iter().collect())
79}
80
81// Wraps a `ReturnType` in a `Result` with the indicated `error_type`.
82fn wrap_return_with_result(return_type: ReturnType, error_type: Path) -> ReturnType {
83    match &return_type {
84        ReturnType::Default => {
85            parse_quote_spanned! { return_type.span() => -> ::ez::__::Result<(), #error_type> }
86        },
87        ReturnType::Type(_, t) => {
88            parse_quote_spanned! { return_type.span() => -> ::ez::__::Result<#t, #error_type> }
89        },
90    }
91}
92
93pub fn throws(
94    attribute_tokens: TokenStream,
95    function_tokens: TokenStream,
96) -> Result<TokenStream, eyre::Report> {
97    let error_type: Path = if attribute_tokens.is_empty() {
98        parse_quote_spanned! { attribute_tokens.span() => ::ez::Error }
99    } else {
100        syn::parse2(attribute_tokens)?
101    };
102
103    let function_tokens = tryify_trailing_block(function_tokens)?;
104
105    let mut function: Function = syn::parse2(function_tokens.into_iter().collect())?;
106
107    function.sig.output = wrap_return_with_result(function.sig.output, error_type);
108
109    Ok(function.into_token_stream())
110}
111
112fn panics(function_tokens: TokenStream) -> Result<TokenStream, eyre::Report> {
113    let function_tokens = tryify_trailing_block(function_tokens)?;
114
115    let mut function: Function = syn::parse2(function_tokens.into_iter().collect())?;
116
117    let block = function.block.clone();
118    function.block = parse_quote_spanned! {
119        function.block.span() => {
120            #[allow(unused_mut, clippy::needless_late_init)]
121            let mut _ez_inner;
122            _ez_inner = move || -> ::ez::__::Result<_, ::ez::__::ErrorPanicker> #block;
123            _ez_inner().unwrap()
124        }
125    };
126
127    Ok(function.into_token_stream())
128}
129
130pub fn try_throws(
131    attribute_tokens: TokenStream,
132    function_tokens: TokenStream,
133) -> Result<TokenStream, eyre::Report> {
134    let has_block = trailing_block(&function_tokens)?.is_some();
135    let source: Function = syn::parse2(function_tokens.clone())?;
136    let args = parameters_to_arguments(&source.sig.inputs);
137
138    let panicking_ident = source.sig.ident;
139    let throwing_ident = format!("try_{}", panicking_ident);
140    let throwing_ident = Ident::new(&throwing_ident, panicking_ident.span());
141
142    let throwing = throws(
143        attribute_tokens,
144        if !has_block {
145            let mut panicking: Function = syn::parse2(function_tokens.clone())?;
146            panicking.block = parse_quote_spanned! { function_tokens.span() => {
147                Self::#panicking_ident(#args)
148            } };
149            panicking.into_token_stream()
150        } else {
151            function_tokens.clone()
152        },
153    )?;
154    let mut throwing: Function = syn::parse2(throwing)?;
155    throwing.sig.ident = throwing_ident.clone();
156
157    let panicking = panics(if !has_block {
158        let mut panicking: Function = syn::parse2(function_tokens.clone())?;
159        panicking.block = parse_quote_spanned! { function_tokens.span() => {
160            Self::#throwing_ident(#args)?
161        } };
162        panicking.into_token_stream()
163    } else {
164        function_tokens
165    })?;
166
167    Ok(parse_quote_spanned! {
168        panicking.span() =>
169        #throwing
170        #panicking
171    })
172}
173
174fn parameters_to_arguments(
175    parameters: &Punctuated<syn::FnArg, syn::Token![,]>,
176) -> Punctuated<syn::Ident, syn::Token![,]> {
177    parameters
178        .iter()
179        .map(|arg| match arg {
180            syn::FnArg::Receiver(receiver) => syn::Ident::new("self", receiver.span()),
181            syn::FnArg::Typed(arg) => match &*arg.pat {
182                syn::Pat::Ident(pat) => pat.ident.clone(),
183                _ => panic!("unsupported pattern in arguments"),
184            },
185        })
186        .collect()
187}
188
189pub fn main(
190    attribute_tokens: TokenStream,
191    function_tokens: TokenStream,
192) -> Result<TokenStream, eyre::Report> {
193    if !attribute_tokens.is_empty() {
194        eyre::bail!("#[ez::main] macro takes no arguments");
195    };
196
197    let function_tokens = tryify_trailing_block(function_tokens)?;
198    let mut inner_function: ItemFn = syn::parse2(function_tokens)?;
199    let mut outer_function = inner_function.clone();
200
201    // inner function must always take two arguments
202    match inner_function.sig.inputs.len() {
203        0 => {
204            inner_function
205                .sig
206                .inputs
207                .push(parse_quote_spanned! { inner_function.sig.inputs.span() => _: ::ez::__::IteratorDropper });
208            inner_function
209                .sig
210                .inputs
211                .push(parse_quote_spanned! { inner_function.sig.inputs.span() => _: ::ez::__::IteratorDropper });
212        },
213        1 => {
214            inner_function
215                .sig
216                .inputs
217                .push(parse_quote_spanned! { inner_function.sig.inputs.span() => _: ::ez::__::IteratorDropper });
218        },
219        2 => {},
220        _ => {
221            return Ok(quote_spanned! {inner_function.sig.inputs.span()=>
222                compile_error!("#[ez::main] function must have at most 2 arguments (for example, `fn main(args: Vec<String>, env: Vec<(String, String)>)`).");
223            }.into_token_stream())
224        },
225    }
226
227    inner_function.sig.output = wrap_return_with_result(
228        inner_function.sig.output.clone(),
229        parse_quote_spanned! { inner_function.sig.output.span() => ::ez::Error },
230    );
231
232    outer_function.sig.inputs = Punctuated::new();
233    outer_function.sig.output =
234        parse_quote_spanned! { outer_function.sig.output.span() => -> Result<(), ::ez::Error> };
235    outer_function.sig.asyncness = None;
236
237    if inner_function.sig.asyncness.is_some() {
238        let block = inner_function.block.clone();
239        inner_function.block = parse_quote_spanned! { inner_function.block.span() => {
240            ::ez::__::tokio::runtime::Builder::new_current_thread()
241                .enable_all()
242                .build()?
243                .block_on(async #block)
244        } };
245
246        inner_function.sig.asyncness = None;
247    }
248
249    inner_function.vis = Visibility::Inherited;
250    let ident = inner_function.sig.ident.clone();
251
252    outer_function.block = parse_quote_spanned! { outer_function.block.span() => {
253        #inner_function
254        ::ez::__::entry_point(env!("CARGO_CRATE_NAME"), #ident)
255    } };
256
257    Ok(outer_function.to_token_stream())
258}