Skip to main content

anchor_syn/parser/
error.rs

1use {
2    crate::{Error, ErrorArgs, ErrorCode},
3    syn::{
4        parse::{Parse, Result as ParseResult},
5        spanned::Spanned,
6        Expr,
7    },
8};
9
10// Removes any internal #[msg] attributes, as they are inert.
11pub fn parse(error_enum: &mut syn::ItemEnum, args: Option<ErrorArgs>) -> Result<Error, syn::Error> {
12    let ident = error_enum.ident.clone();
13    let mut last_discriminant = 0;
14    let codes: Vec<ErrorCode> = error_enum
15        .variants
16        .iter_mut()
17        .map(|variant: &mut syn::Variant| {
18            let msg = parse_error_attribute(variant)?;
19            let ident = variant.ident.clone();
20            let id = match &variant.discriminant {
21                None => last_discriminant,
22                Some((_, disc)) => match disc {
23                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
24                        syn::Lit::Int(int) => int.base10_parse::<u32>().map_err(|_| {
25                            syn::Error::new(int.span(), "error discriminant must be a valid u32")
26                        })?,
27                        _ => {
28                            return Err(syn::Error::new(
29                                expr_lit.lit.span(),
30                                "error discriminant must be an integer literal",
31                            ))
32                        }
33                    },
34                    _ => {
35                        return Err(syn::Error::new(
36                            disc.span(),
37                            "error discriminant must be an integer literal",
38                        ))
39                    }
40                },
41            };
42            last_discriminant = id + 1;
43
44            // Remove any non-doc attributes on the error variant.
45            variant.attrs.retain(|attr| attr.path().is_ident("doc"));
46
47            Ok(ErrorCode { id, ident, msg })
48        })
49        .collect::<Result<Vec<_>, syn::Error>>()?;
50    Ok(Error {
51        name: error_enum.ident.to_string(),
52        raw_enum: error_enum.clone(),
53        ident,
54        codes,
55        args,
56    })
57}
58
59fn parse_error_attribute(variant: &syn::Variant) -> Result<Option<String>, syn::Error> {
60    let attrs = variant
61        .attrs
62        .iter()
63        .filter(|attr| !attr.path().is_ident("doc"))
64        .collect::<Vec<_>>();
65    match attrs.len() {
66        0 => Ok(None),
67        1 => {
68            #[allow(
69                clippy::indexing_slicing,
70                reason = "inside match arm where attrs.len() == 1"
71            )]
72            let attr = &attrs[0];
73            if !attr.path().is_ident("msg") {
74                return Err(syn::Error::new(
75                    attr.span(),
76                    "use `#[msg(\"...\")]` to specify error strings",
77                ));
78            }
79
80            let g_stream = match &attr.meta {
81                syn::Meta::List(list) => list.tokens.clone(),
82                _ => {
83                    return Err(syn::Error::new(
84                        attr.span(),
85                        "expected `#[msg(\"message\")]`",
86                    ))
87                }
88            };
89
90            let msg = match g_stream.into_iter().next() {
91                None => {
92                    return Err(syn::Error::new(
93                        attr.span(),
94                        "`#[msg]` requires a message string",
95                    ))
96                }
97                Some(msg) => msg.to_string().replace('\"', ""),
98            };
99
100            Ok(Some(msg))
101        }
102        _ => Err(syn::Error::new(
103            variant.span(),
104            "too many attributes; use `#[msg(\"...\")]` to specify error strings",
105        )),
106    }
107}
108
109pub struct ErrorInput {
110    pub error_code: Expr,
111}
112
113impl Parse for ErrorInput {
114    fn parse(stream: syn::parse::ParseStream) -> ParseResult<Self> {
115        let error_code = stream.call(Expr::parse)?;
116        Ok(Self { error_code })
117    }
118}