1use {
2 proc_macro2::{Delimiter, Ident, TokenStream, TokenTree},
3 quote::{quote_spanned, ToTokens},
4 syn::{
5 fold::Fold, parse_quote_spanned, punctuated::Punctuated, spanned::Spanned, Block,
6 ExprAsync, ExprClosure, ExprReturn, ImplItemMethod, ItemFn, Path, ReturnType, Visibility,
7 },
8};
9
/// Parsed form of the function item handled by `throws`, `panics` and `try_throws`.
///
/// `ImplItemMethod` (attributes, visibility, signature, block) is used rather than `ItemFn`;
/// NOTE(review): presumably chosen so methods in `impl`/trait contexts parse — confirm.
type Function = ImplItemMethod;
13
14fn trailing_block(tokens: &TokenStream) -> Result<Option<Block>, eyre::Report> {
16 let mut tokens = Vec::from_iter(tokens.clone());
17
18 if let Some(trailing) = tokens.last_mut() {
19 if let TokenTree::Group(group) = &trailing {
20 if group.delimiter() == Delimiter::Brace {
21 return Ok(Some(syn::parse2(trailing.into_token_stream())?));
22 }
23 }
24 }
25
26 Ok(None)
27}
28
29fn wrap_returns_in_ok(block: Block) -> Block {
32 struct Folder;
33 impl Fold for Folder {
34 fn fold_expr_return(&mut self, expr: ExprReturn) -> ExprReturn {
35 let inner = expr.expr.clone();
36 parse_quote_spanned! { expr.span() =>
37 return ::ez::__::Ok(#inner)
38 }
39 }
40
41 fn fold_item_fn(&mut self, item_fn: ItemFn) -> ItemFn {
42 item_fn
43 }
44
45 fn fold_expr_closure(&mut self, expr_closure: ExprClosure) -> ExprClosure {
46 expr_closure
47 }
48
49 fn fold_expr_async(&mut self, expr_async: ExprAsync) -> ExprAsync {
50 expr_async
51 }
52 }
53
54 Folder.fold_block(block)
55}
56
57fn tryify_trailing_block(tokens: TokenStream) -> Result<TokenStream, eyre::Report> {
60 let mut tokens = Vec::from_iter(tokens);
61
62 if let Some(last) = tokens.last_mut() {
63 if let proc_macro2::TokenTree::Group(group) = last {
64 if group.delimiter() == proc_macro2::Delimiter::Brace {
65 let block: syn::Block = syn::parse2(last.clone().into_token_stream())?;
66 let block = wrap_returns_in_ok(block);
67 *last = parse_quote_spanned! { block.span() => {
68 #[allow(unused_imports)]
69 use ::ez::throw;
70 let _ez_inner = #block;
71 #[allow(unreachable_code)]
72 ::ez::__::Ok(_ez_inner)
73 } };
74 }
75 };
76 }
77
78 Ok(tokens.into_iter().collect())
79}
80
81fn wrap_return_with_result(return_type: ReturnType, error_type: Path) -> ReturnType {
83 match &return_type {
84 ReturnType::Default => {
85 parse_quote_spanned! { return_type.span() => -> ::ez::__::Result<(), #error_type> }
86 },
87 ReturnType::Type(_, t) => {
88 parse_quote_spanned! { return_type.span() => -> ::ez::__::Result<#t, #error_type> }
89 },
90 }
91}
92
93pub fn throws(
94 attribute_tokens: TokenStream,
95 function_tokens: TokenStream,
96) -> Result<TokenStream, eyre::Report> {
97 let error_type: Path = if attribute_tokens.is_empty() {
98 parse_quote_spanned! { attribute_tokens.span() => ::ez::Error }
99 } else {
100 syn::parse2(attribute_tokens)?
101 };
102
103 let function_tokens = tryify_trailing_block(function_tokens)?;
104
105 let mut function: Function = syn::parse2(function_tokens.into_iter().collect())?;
106
107 function.sig.output = wrap_return_with_result(function.sig.output, error_type);
108
109 Ok(function.into_token_stream())
110}
111
112fn panics(function_tokens: TokenStream) -> Result<TokenStream, eyre::Report> {
113 let function_tokens = tryify_trailing_block(function_tokens)?;
114
115 let mut function: Function = syn::parse2(function_tokens.into_iter().collect())?;
116
117 let block = function.block.clone();
118 function.block = parse_quote_spanned! {
119 function.block.span() => {
120 #[allow(unused_mut, clippy::needless_late_init)]
121 let mut _ez_inner;
122 _ez_inner = move || -> ::ez::__::Result<_, ::ez::__::ErrorPanicker> #block;
123 _ez_inner().unwrap()
124 }
125 };
126
127 Ok(function.into_token_stream())
128}
129
130pub fn try_throws(
131 attribute_tokens: TokenStream,
132 function_tokens: TokenStream,
133) -> Result<TokenStream, eyre::Report> {
134 let has_block = trailing_block(&function_tokens)?.is_some();
135 let source: Function = syn::parse2(function_tokens.clone())?;
136 let args = parameters_to_arguments(&source.sig.inputs);
137
138 let panicking_ident = source.sig.ident;
139 let throwing_ident = format!("try_{}", panicking_ident);
140 let throwing_ident = Ident::new(&throwing_ident, panicking_ident.span());
141
142 let throwing = throws(
143 attribute_tokens,
144 if !has_block {
145 let mut panicking: Function = syn::parse2(function_tokens.clone())?;
146 panicking.block = parse_quote_spanned! { function_tokens.span() => {
147 Self::#panicking_ident(#args)
148 } };
149 panicking.into_token_stream()
150 } else {
151 function_tokens.clone()
152 },
153 )?;
154 let mut throwing: Function = syn::parse2(throwing)?;
155 throwing.sig.ident = throwing_ident.clone();
156
157 let panicking = panics(if !has_block {
158 let mut panicking: Function = syn::parse2(function_tokens.clone())?;
159 panicking.block = parse_quote_spanned! { function_tokens.span() => {
160 Self::#throwing_ident(#args)?
161 } };
162 panicking.into_token_stream()
163 } else {
164 function_tokens
165 })?;
166
167 Ok(parse_quote_spanned! {
168 panicking.span() =>
169 #throwing
170 #panicking
171 })
172}
173
174fn parameters_to_arguments(
175 parameters: &Punctuated<syn::FnArg, syn::Token![,]>,
176) -> Punctuated<syn::Ident, syn::Token![,]> {
177 parameters
178 .iter()
179 .map(|arg| match arg {
180 syn::FnArg::Receiver(receiver) => syn::Ident::new("self", receiver.span()),
181 syn::FnArg::Typed(arg) => match &*arg.pat {
182 syn::Pat::Ident(pat) => pat.ident.clone(),
183 _ => panic!("unsupported pattern in arguments"),
184 },
185 })
186 .collect()
187}
188
189pub fn main(
190 attribute_tokens: TokenStream,
191 function_tokens: TokenStream,
192) -> Result<TokenStream, eyre::Report> {
193 if !attribute_tokens.is_empty() {
194 eyre::bail!("#[ez::main] macro takes no arguments");
195 };
196
197 let function_tokens = tryify_trailing_block(function_tokens)?;
198 let mut inner_function: ItemFn = syn::parse2(function_tokens)?;
199 let mut outer_function = inner_function.clone();
200
201 match inner_function.sig.inputs.len() {
203 0 => {
204 inner_function
205 .sig
206 .inputs
207 .push(parse_quote_spanned! { inner_function.sig.inputs.span() => _: ::ez::__::IteratorDropper });
208 inner_function
209 .sig
210 .inputs
211 .push(parse_quote_spanned! { inner_function.sig.inputs.span() => _: ::ez::__::IteratorDropper });
212 },
213 1 => {
214 inner_function
215 .sig
216 .inputs
217 .push(parse_quote_spanned! { inner_function.sig.inputs.span() => _: ::ez::__::IteratorDropper });
218 },
219 2 => {},
220 _ => {
221 return Ok(quote_spanned! {inner_function.sig.inputs.span()=>
222 compile_error!("#[ez::main] function must have at most 2 arguments (for example, `fn main(args: Vec<String>, env: Vec<(String, String)>)`).");
223 }.into_token_stream())
224 },
225 }
226
227 inner_function.sig.output = wrap_return_with_result(
228 inner_function.sig.output.clone(),
229 parse_quote_spanned! { inner_function.sig.output.span() => ::ez::Error },
230 );
231
232 outer_function.sig.inputs = Punctuated::new();
233 outer_function.sig.output =
234 parse_quote_spanned! { outer_function.sig.output.span() => -> Result<(), ::ez::Error> };
235 outer_function.sig.asyncness = None;
236
237 if inner_function.sig.asyncness.is_some() {
238 let block = inner_function.block.clone();
239 inner_function.block = parse_quote_spanned! { inner_function.block.span() => {
240 ::ez::__::tokio::runtime::Builder::new_current_thread()
241 .enable_all()
242 .build()?
243 .block_on(async #block)
244 } };
245
246 inner_function.sig.asyncness = None;
247 }
248
249 inner_function.vis = Visibility::Inherited;
250 let ident = inner_function.sig.ident.clone();
251
252 outer_function.block = parse_quote_spanned! { outer_function.block.span() => {
253 #inner_function
254 ::ez::__::entry_point(env!("CARGO_CRATE_NAME"), #ident)
255 } };
256
257 Ok(outer_function.to_token_stream())
258}