1#![recursion_limit = "512"]
2#![doc = include_str!("../README.md")]
3
4extern crate proc_macro;
5
6use core::panic;
7
8use convert_case::{Case, Casing};
9use proc_macro::TokenStream;
10use proc_macro2::Span;
11use quote::quote;
12use std::collections::HashSet;
13use syn::{
14 parse::{Parse, ParseStream, Result},
15 parse_macro_input,
16 punctuated::Punctuated,
17 token::PathSep,
18 Ident, ImplItemFn, ItemFn, ItemImpl, PathArguments, PathSegment, ReturnType, Token, Type, TypePath, Visibility,
19};
20
21struct ErrorsetArgs {
22 visibility: Visibility,
23 module: Option<Ident>,
24}
25
26impl Parse for ErrorsetArgs {
27 fn parse(input: ParseStream) -> Result<Self> {
28 let mut module = None;
29
30 let visibility: Visibility = input.parse()?;
32 let lookahead = input.lookahead1();
34 if lookahead.peek(Token![mod]) {
35 input.parse::<Token![mod]>()?;
36 let mod_name: Ident = input.parse()?;
37 module = Some(mod_name);
38 }
39
40 Ok(ErrorsetArgs { visibility, module })
41 }
42}
43
44#[proc_macro_attribute]
45pub fn errorset(attr: TokenStream, item: TokenStream) -> TokenStream {
46 let args = parse_macro_input!(attr as ErrorsetArgs);
47 let input = parse_macro_input!(item as syn::Item);
48
49 match input {
50 syn::Item::Fn(item_fn) => handle_function(&args, item_fn),
51 syn::Item::Impl(item_impl) => handle_impl_block(&args, item_impl),
52 _ => panic!("errorset can only be applied to functions or impl blocks"),
53 }
54}
55
56struct Output {
57 enum_def: proc_macro2::TokenStream,
58 fn_def: proc_macro2::TokenStream,
59}
60
61fn process_fn(args: &ErrorsetArgs, item_fn: &ItemFn) -> Result<Option<Output>> {
62 let fn_name = &item_fn.sig.ident;
64 let enum_name = Ident::new(
65 &format!("{}Errors", fn_name.to_string().to_case(Case::Pascal)),
66 Span::call_site(),
67 );
68
69 let output_type = match &item_fn.sig.output {
71 ReturnType::Type(_, ty) => ty,
72 _ => {
73 return Err(syn::Error::new_spanned(
74 &item_fn.sig.output,
75 "Function must have a valid return type",
76 ))
77 }
78 };
79
80 let (new_return_type, err_types) = if let Type::Path(TypePath { path, .. }) = &**output_type {
81 if let Some(last_segment) = path.segments.last() {
82 if let PathArguments::AngleBracketed(ref params) = last_segment.arguments {
83 if params.args.len() != 2 {
84 return Err(syn::Error::new_spanned(
85 ¶ms.args,
86 "Expected exactly 2 generic arguments",
87 ));
88 }
89
90 match params.args.iter().nth(1).unwrap() {
91 syn::GenericArgument::Type(Type::Tuple(tuple)) => {
92 let mut punctuated = Punctuated::<PathSegment, PathSep>::new();
93 for seg in path.segments.iter() {
94 punctuated.push_value(seg.ident.clone().into());
95 if punctuated.len() < path.segments.len() {
97 punctuated.push_punct(PathSep::default());
98 }
99 }
100 let new_path = syn::Path {
101 leading_colon: path.leading_colon.clone(),
102 segments: punctuated,
103 };
104
105 let first_generic_arg = params.args.iter().next().unwrap();
108 let new_return_type = if let Some(module) = &args.module {
109 quote! {
110 #new_path<#first_generic_arg, #module::#enum_name>
111 }
112 } else {
113 quote! {
114 #new_path<#first_generic_arg, #enum_name>
115 }
116 };
117 let err_types = tuple.elems.clone();
118 (new_return_type, err_types)
119 }
120 syn::GenericArgument::Type(Type::Paren(_)) | syn::GenericArgument::Type(Type::Path(_)) => {
121 return Ok(None);
124 }
125 other => {
126 return Err(syn::Error::new_spanned(
127 other,
128 "Expected the second generic argument to be a tuple",
129 ));
130 }
131 }
132 } else {
133 return Err(syn::Error::new_spanned(
134 last_segment,
135 "Expected angle-bracketed generic arguments",
136 ));
137 }
138 } else {
139 return Err(syn::Error::new_spanned(
140 path,
141 "Expected a valid type path for the generic type",
142 ));
143 }
144 } else {
145 return Err(syn::Error::new_spanned(
146 output_type,
147 "Function must return a generic type with 2 parameters",
148 ));
149 };
150
151 let mut seen = HashSet::new();
153 let enum_variants = err_types
154 .iter()
155 .filter(|ty| match ty {
156 Type::Path(TypePath { path, .. }) => seen.insert(path.segments.last().unwrap().ident.to_string()),
157 _ => true,
158 })
159 .map(|ty| {
160 let ty_name = match ty {
161 Type::Path(TypePath { path, .. }) => path.segments.last().unwrap().ident.clone(),
162 _ => return quote! {}, };
164 quote! {
165 #[error(transparent)]
166 #ty_name(#[from] #ty),
167 }
168 });
169
170 let enum_vis = if args.module.is_some() {
172 syn::Visibility::Public(Default::default())
174 } else {
175 item_fn.vis.clone()
176 };
177 let enum_def = quote! {
178 #[derive(::thiserror::Error, Debug)]
179 #enum_vis enum #enum_name {
180 #(#enum_variants)*
181 }
182 };
183
184 let fn_sig = &item_fn.sig;
185 let fn_attrs = &item_fn.attrs;
186 let fn_vis = &item_fn.vis;
187 let fn_body = &item_fn.block;
188
189 let mut new_sig = fn_sig.clone();
190 new_sig.output = syn::parse2(quote! { -> #new_return_type }).unwrap();
191
192 let new_fn = quote! {
194 #(#fn_attrs)*
195 #fn_vis #new_sig
196 #fn_body
197 };
198
199 Ok(Some(Output { enum_def, fn_def: new_fn }))
200}
201
202fn handle_function(args: &ErrorsetArgs, item_fn: ItemFn) -> TokenStream {
203 match process_fn(args, &item_fn) {
204 Ok(Some(Output { enum_def, fn_def })) => {
205 if let Some(module) = &args.module {
206 let vis = &args.visibility;
207 quote! {
208 #vis mod #module {
209 use super::*;
210 #enum_def
211 }
212 #fn_def
213 }
214 } else {
215 quote! {
216 #enum_def
217 #fn_def
218 }
219 }
220 }
221 Ok(None) => quote! { #item_fn },
222 Err(e) => e.to_compile_error(),
223 }
224 .into()
225}
226
227fn handle_impl_block(args: &ErrorsetArgs, item_impl: ItemImpl) -> TokenStream {
228 let mut new_items = Vec::new();
229 let mut new_enums = Vec::new();
230
231 for item in item_impl.items {
232 if let syn::ImplItem::Fn(method) = &item {
233 let mut new_attrs = Vec::new();
234 let mut marked = false;
235
236 for attr in &method.attrs {
237 if attr.path().is_ident("errorset") {
238 if attr.meta.require_path_only().is_err() {
239 return syn::Error::new_spanned(
240 attr,
241 "errorset attribute must not have arguments inside impl blocks",
242 )
243 .to_compile_error()
244 .into();
245 }
246 marked = true;
247 } else {
248 new_attrs.push(attr.clone());
249 }
250 }
251
252 if !marked {
253 new_items.push(item);
254 continue;
255 }
256
257 let item_fn = ItemFn {
258 attrs: new_attrs,
259 vis: method.vis.clone(),
260 sig: method.sig.clone(),
261 block: Box::new(method.block.clone()),
262 };
263
264 match process_fn(args, &item_fn) {
265 Ok(Some(Output { enum_def, fn_def })) => {
266 let impl_item = syn::parse2::<ImplItemFn>(fn_def).expect("Invalid method reparse");
267 new_items.push(impl_item.into());
268 new_enums.push(enum_def);
269 }
270 Ok(None) => new_items.push(item),
271 Err(e) => return e.to_compile_error().into(),
272 }
273 } else {
274 new_items.push(item);
275 }
276 }
277
278 let new_impl_block = ItemImpl { items: new_items, ..item_impl };
279
280 if let Some(module) = &args.module {
281 if new_enums.is_empty() {
284 quote! {
285 #new_impl_block
286 }
287 } else {
288 let vis = &args.visibility;
289 quote! {
290 #vis mod #module {
291 use super::*;
292 #(#new_enums)*
293 }
294 #new_impl_block
295 }
296 }
297 } else {
298 quote! {
299 #(#new_enums)*
300 #new_impl_block
301 }
302 }
303 .into()
304}