miden_thiserror_impl/
attr.rs

1use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree};
2use quote::{format_ident, quote, ToTokens};
3use std::collections::BTreeSet as Set;
4use syn::parse::discouraged::Speculative;
5use syn::parse::ParseStream;
6use syn::{
7    braced, bracketed, parenthesized, token, Attribute, Error, Ident, Index, LitInt, LitStr, Meta,
8    Result, Token,
9};
10
11pub struct Attrs<'a> {
12    pub display: Option<Display<'a>>,
13    pub source: Option<&'a Attribute>,
14    pub backtrace: Option<&'a Attribute>,
15    pub from: Option<&'a Attribute>,
16    pub transparent: Option<Transparent<'a>>,
17}
18
19#[derive(Clone)]
20pub struct Display<'a> {
21    pub original: &'a Attribute,
22    pub fmt: LitStr,
23    pub args: TokenStream,
24    pub requires_fmt_machinery: bool,
25    pub has_bonus_display: bool,
26    pub implied_bounds: Set<(usize, Trait)>,
27}
28
29#[derive(Copy, Clone)]
30pub struct Transparent<'a> {
31    pub original: &'a Attribute,
32    pub span: Span,
33}
34
/// A `core::fmt` formatting trait. Its `ToTokens` impl emits the path
/// `::core::fmt::<Name>`.
///
/// Variant order matters: the derived `Ord` determines the iteration order
/// of `Set<(usize, Trait)>` (a `BTreeSet`), so do not reorder the variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Trait {
    Debug,
    Display,
    Octal,
    LowerHex,
    UpperHex,
    Pointer,
    Binary,
    LowerExp,
    UpperExp,
}
47
48pub fn get(input: &[Attribute]) -> Result<Attrs> {
49    let mut attrs = Attrs {
50        display: None,
51        source: None,
52        backtrace: None,
53        from: None,
54        transparent: None,
55    };
56
57    for attr in input {
58        if attr.path().is_ident("error") {
59            parse_error_attribute(&mut attrs, attr)?;
60        } else if attr.path().is_ident("source") {
61            attr.meta.require_path_only()?;
62            if attrs.source.is_some() {
63                return Err(Error::new_spanned(attr, "duplicate #[source] attribute"));
64            }
65            attrs.source = Some(attr);
66        } else if attr.path().is_ident("backtrace") {
67            attr.meta.require_path_only()?;
68            if attrs.backtrace.is_some() {
69                return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute"));
70            }
71            attrs.backtrace = Some(attr);
72        } else if attr.path().is_ident("from") {
73            match attr.meta {
74                Meta::Path(_) => {}
75                Meta::List(_) | Meta::NameValue(_) => {
76                    // Assume this is meant for derive_more crate or something.
77                    continue;
78                }
79            }
80            if attrs.from.is_some() {
81                return Err(Error::new_spanned(attr, "duplicate #[from] attribute"));
82            }
83            attrs.from = Some(attr);
84        }
85    }
86
87    Ok(attrs)
88}
89
90fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
91    syn::custom_keyword!(transparent);
92
93    attr.parse_args_with(|input: ParseStream| {
94        if let Some(kw) = input.parse::<Option<transparent>>()? {
95            if attrs.transparent.is_some() {
96                return Err(Error::new_spanned(
97                    attr,
98                    "duplicate #[error(transparent)] attribute",
99                ));
100            }
101            attrs.transparent = Some(Transparent {
102                original: attr,
103                span: kw.span,
104            });
105            return Ok(());
106        }
107
108        let fmt: LitStr = input.parse()?;
109
110        let ahead = input.fork();
111        ahead.parse::<Option<Token![,]>>()?;
112        let args = if ahead.is_empty() {
113            input.advance_to(&ahead);
114            TokenStream::new()
115        } else {
116            parse_token_expr(input, false)?
117        };
118
119        let requires_fmt_machinery = !args.is_empty();
120
121        let display = Display {
122            original: attr,
123            fmt,
124            args,
125            requires_fmt_machinery,
126            has_bonus_display: false,
127            implied_bounds: Set::new(),
128        };
129        if attrs.display.is_some() {
130            return Err(Error::new_spanned(
131                attr,
132                "only one #[error(...)] attribute is allowed",
133            ));
134        }
135        attrs.display = Some(display);
136        Ok(())
137    })
138}
139
140fn parse_token_expr(input: ParseStream, mut begin_expr: bool) -> Result<TokenStream> {
141    let mut tokens = Vec::new();
142    while !input.is_empty() {
143        if begin_expr && input.peek(Token![.]) {
144            if input.peek2(Ident) {
145                input.parse::<Token![.]>()?;
146                begin_expr = false;
147                continue;
148            }
149            if input.peek2(LitInt) {
150                input.parse::<Token![.]>()?;
151                let int: Index = input.parse()?;
152                let ident = format_ident!("_{}", int.index, span = int.span);
153                tokens.push(TokenTree::Ident(ident));
154                begin_expr = false;
155                continue;
156            }
157        }
158
159        begin_expr = input.peek(Token![break])
160            || input.peek(Token![continue])
161            || input.peek(Token![if])
162            || input.peek(Token![in])
163            || input.peek(Token![match])
164            || input.peek(Token![mut])
165            || input.peek(Token![return])
166            || input.peek(Token![while])
167            || input.peek(Token![+])
168            || input.peek(Token![&])
169            || input.peek(Token![!])
170            || input.peek(Token![^])
171            || input.peek(Token![,])
172            || input.peek(Token![/])
173            || input.peek(Token![=])
174            || input.peek(Token![>])
175            || input.peek(Token![<])
176            || input.peek(Token![|])
177            || input.peek(Token![%])
178            || input.peek(Token![;])
179            || input.peek(Token![*])
180            || input.peek(Token![-]);
181
182        let token: TokenTree = if input.peek(token::Paren) {
183            let content;
184            let delimiter = parenthesized!(content in input);
185            let nested = parse_token_expr(&content, true)?;
186            let mut group = Group::new(Delimiter::Parenthesis, nested);
187            group.set_span(delimiter.span.join());
188            TokenTree::Group(group)
189        } else if input.peek(token::Brace) {
190            let content;
191            let delimiter = braced!(content in input);
192            let nested = parse_token_expr(&content, true)?;
193            let mut group = Group::new(Delimiter::Brace, nested);
194            group.set_span(delimiter.span.join());
195            TokenTree::Group(group)
196        } else if input.peek(token::Bracket) {
197            let content;
198            let delimiter = bracketed!(content in input);
199            let nested = parse_token_expr(&content, true)?;
200            let mut group = Group::new(Delimiter::Bracket, nested);
201            group.set_span(delimiter.span.join());
202            TokenTree::Group(group)
203        } else {
204            input.parse()?
205        };
206        tokens.push(token);
207    }
208    Ok(TokenStream::from_iter(tokens))
209}
210
211impl ToTokens for Display<'_> {
212    fn to_tokens(&self, tokens: &mut TokenStream) {
213        let fmt = &self.fmt;
214        let args = &self.args;
215
216        // Currently `write!(f, "text")` produces less efficient code than
217        // `f.write_str("text")`. We recognize the case when the format string
218        // has no braces and no interpolated values, and generate simpler code.
219        tokens.extend(if self.requires_fmt_machinery {
220            quote! {
221                ::core::write!(__formatter, #fmt #args)
222            }
223        } else {
224            quote! {
225                __formatter.write_str(#fmt)
226            }
227        });
228    }
229}
230
231impl ToTokens for Trait {
232    fn to_tokens(&self, tokens: &mut TokenStream) {
233        let trait_name = match self {
234            Trait::Debug => "Debug",
235            Trait::Display => "Display",
236            Trait::Octal => "Octal",
237            Trait::LowerHex => "LowerHex",
238            Trait::UpperHex => "UpperHex",
239            Trait::Pointer => "Pointer",
240            Trait::Binary => "Binary",
241            Trait::LowerExp => "LowerExp",
242            Trait::UpperExp => "UpperExp",
243        };
244        let ident = Ident::new(trait_name, Span::call_site());
245        tokens.extend(quote!(::core::fmt::#ident));
246    }
247}