1#![deny(missing_docs, rustdoc::all)]
2#![doc = include_str!("../README.md")]
3
4use proc_macro::TokenStream;
5use syn::__private::quote::quote;
6use syn::__private::{ToTokens, TokenStream2};
7use syn::parse::discouraged::Speculative;
8use syn::parse::{Parse, ParseStream};
9use syn::punctuated::Punctuated;
10use syn::token::{self, Colon, Comma};
11use syn::{bracketed, Attribute, Generics, Ident, LitStr, Result, Type, Visibility};
12
13#[proc_macro]
135pub fn error(tokens: TokenStream) -> TokenStream {
136 match error_impl(tokens.into()) {
137 Ok(toks) => toks.into(),
138 Err(err) => err.to_compile_error().into(),
139 }
140}
141
142fn error_impl(tokens: TokenStream2) -> Result<TokenStream2> {
143 let Error {
144 attrs,
145 vis,
146 name,
147 generics,
148 msg,
149 contents,
150 } = syn::parse2(tokens)?;
151
152 let (impl_gen, ty_gen, where_gen) = generics.split_for_impl();
153
154 let item_cfgs: Vec<&Attribute> = attrs
155 .iter()
156 .filter(|attr| attr.meta.path().is_ident("cfg"))
157 .collect();
158 let item_cfgs = quote! { #(#item_cfgs)* };
159
160 Ok(match contents {
161 ErrorContents::Unit => quote! {
162 #(#attrs)*
163 #[derive(Debug)]
164 #[non_exhaustive]
165 #vis struct #name #generics;
166
167 #item_cfgs
168 impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
169 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
170 f.write_str(#msg)
171 }
172 }
173
174 #item_cfgs
175 impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {}
176 },
177 ErrorContents::Struct { fields } => {
178 let cfgs: Vec<Vec<&Attribute>> = fields
179 .iter()
180 .map(|field| {
181 field
182 .attrs
183 .iter()
184 .filter(|attr| attr.meta.path().is_ident("cfg"))
185 .collect()
186 })
187 .collect();
188 let field_names: Vec<&Ident> = fields.iter().map(|field| &field.name).collect();
189 quote! {
190 #(#attrs)*
191 #[derive(Debug)]
192 #[non_exhaustive]
193 #vis struct #name #generics {
194 #fields
195 }
196
197 #item_cfgs
198 impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
199 #[allow(unused_variables)]
200 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
201 let Self {
202 #(
203 #(#cfgs)*
204 #field_names,
205 )*
206 } = self;
207 f.write_fmt(format_args!(#msg))
208 }
209 }
210
211 #item_cfgs
212 impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {}
213 }
214 }
215 ErrorContents::Enum { sources } => {
216 let source_attrs: Vec<&Vec<Attribute>> =
217 sources.iter().map(|source| &source.attrs).collect();
218 let cfgs: Vec<Vec<Attribute>> = source_attrs
219 .iter()
220 .map(|&attrs| {
221 let mut attrs = attrs.clone();
222 attrs.retain(|attr| attr.meta.path().is_ident("cfg"));
223 attrs
224 })
225 .collect();
226 let source_idents: Vec<&Ident> = sources.iter().map(|source| &source.ident).collect();
227 let write_msg = match &msg {
228 Some(msg) => quote! {
229 f.write_str(#msg)
230 },
231 None => {
232 quote! {
233 match self {
234 #(
235 #(#cfgs)*
236 Self::#source_idents(err) => ::std::fmt::Display::fmt(err, f),
237 )*
238 _ => unreachable!(),
239 }
240 }
241 }
242 };
243 quote! {
244 #(#attrs)*
245 #[derive(Debug)]
246 #[non_exhaustive]
247 #vis enum #name #generics {
248 #(
249 #(#source_attrs)*
250 #source_idents(#source_idents),
251 )*
252 }
253
254 #item_cfgs
255 impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
256 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
257 #write_msg
258 }
259 }
260
261 #item_cfgs
262 impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {
263 fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
264 Some(match self {
265 #(
266 #(#cfgs)*
267 #name::#source_idents(err) => err,
268 )*
269 _ => unreachable!(),
270 })
271 }
272 }
273
274 #(
275 #item_cfgs
276 #(#cfgs)*
277 impl #impl_gen ::std::convert::From<#source_idents> for #name #ty_gen #where_gen {
278 fn from(source: #source_idents) -> Self {
279 Self::#source_idents(source)
280 }
281 }
282 )*
283 }
284 }
285 ErrorContents::Array {
286 inner_attrs, inner, ..
287 } => quote! {
288 #(#attrs)*
289 #[derive(Debug)]
290 #[non_exhaustive]
291 #vis struct #name #generics (#(#inner_attrs)* pub Vec<#inner>);
292
293 #item_cfgs
294 impl #impl_gen ::std::fmt::Display for #name #ty_gen #where_gen {
295 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
296 f.write_str(#msg)?;
297 f.write_str(":")?;
298 for err in &self.0 {
299 f.write_str("\n")?;
300 f.write_fmt(format_args!("{}", err))?;
301 }
302 Ok(())
303 }
304 }
305
306 #item_cfgs
307 impl #impl_gen ::std::error::Error for #name #ty_gen #where_gen {}
308 },
309 })
310}
311
312struct Field {
313 attrs: Vec<Attribute>,
314 vis: Visibility,
315 name: Ident,
316 colon: Colon,
317 ty: Type,
318}
319
320impl Parse for Field {
321 fn parse(input: ParseStream) -> Result<Self> {
322 Ok(Self {
323 attrs: input.call(Attribute::parse_outer)?,
324 vis: input.parse()?,
325 name: input.parse()?,
326 colon: input.parse()?,
327 ty: input.parse()?,
328 })
329 }
330}
331
332impl ToTokens for Field {
333 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
334 for attr in &self.attrs {
335 attr.to_tokens(tokens);
336 }
337 self.vis.to_tokens(tokens);
338 self.name.to_tokens(tokens);
339 self.colon.to_tokens(tokens);
340 self.ty.to_tokens(tokens);
341 }
342}
343
344struct ErrorVariant {
345 attrs: Vec<Attribute>,
346 ident: Ident,
347}
348
349impl Parse for ErrorVariant {
350 fn parse(input: ParseStream) -> Result<Self> {
351 Ok(Self {
352 attrs: input.call(Attribute::parse_outer)?,
353 ident: input.parse()?,
354 })
355 }
356}
357
358enum ErrorContents {
359 Unit,
360 Struct {
361 fields: Punctuated<Field, Comma>,
362 },
363 Enum {
364 sources: Punctuated<ErrorVariant, Comma>,
365 },
366 Array {
367 inner_attrs: Vec<Attribute>,
368 inner: Type,
369 },
370}
371
372impl Parse for ErrorContents {
373 fn parse(input: ParseStream) -> Result<Self> {
374 if input.is_empty() {
375 return Ok(Self::Unit);
376 }
377
378 let fork = input.fork();
379 if let Ok(fields) = fork.call(Punctuated::parse_terminated) {
380 input.advance_to(&fork);
381 return Ok(Self::Struct { fields });
382 }
383
384 let fork = input.fork();
385 if let Ok(sources) = fork.call(Punctuated::parse_terminated) {
386 input.advance_to(&fork);
387 return Ok(Self::Enum { sources });
388 }
389
390 if input.peek(token::Bracket) {
391 let content;
392 let _ = bracketed!(content in input);
393 let attrs = content.call(Attribute::parse_outer)?;
394 let inner = content.parse::<Type>()?;
395 return Ok(Self::Array {
396 inner_attrs: attrs,
397 inner,
398 });
399 }
400
401 Err(input.error("invalid error contents"))
402 }
403}
404
405struct Error {
406 attrs: Vec<Attribute>,
407 vis: Visibility,
408 name: Ident,
409 generics: Generics,
410 msg: Option<LitStr>,
411 contents: ErrorContents,
412}
413
414impl Parse for Error {
415 fn parse(input: ParseStream) -> Result<Self> {
416 let attrs = input.call(Attribute::parse_outer)?;
417 let vis = input.parse::<Visibility>()?;
418 let name = input.parse::<Ident>()?;
419 let generics = input.parse::<Generics>()?;
420 let msg = input.parse::<LitStr>().ok();
421 let contents = input.parse::<ErrorContents>()?;
422
423 if msg.is_none() && !matches!(contents, ErrorContents::Enum { .. }) {
424 return Err(input.error("any non-enum error must have a display message"));
425 }
426
427 Ok(Self {
428 attrs,
429 vis,
430 name,
431 generics,
432 msg,
433 contents,
434 })
435 }
436}