Skip to main content

anchor_syn/codegen/
error.rs

1use {crate::Error, quote::quote};
2
3pub fn generate(error: Error) -> proc_macro2::TokenStream {
4    let error_enum = &error.raw_enum;
5    let enum_name = &error.ident;
6    // Each arm of the `match` statement for implementing `std::fmt::Display`
7    // on the user defined error code.
8    let display_variant_dispatch: Vec<proc_macro2::TokenStream> = error
9        .raw_enum
10        .variants
11        .iter()
12        .enumerate()
13        .map(|(idx, variant)| {
14            let ident = &variant.ident;
15            let error_code = &error.codes[idx];
16            let display_msg = match &error_code.msg {
17                None => {
18                    quote! {
19                        <Self as std::fmt::Debug>::fmt(self, fmt)
20                    }
21                }
22                Some(msg) => {
23                    quote! {
24                        write!(fmt, #msg)
25                    }
26                }
27            };
28            quote! {
29                #enum_name::#ident => #display_msg
30            }
31        })
32        .collect();
33
34    // Each arm of the `match` statement for implementing the `name` function
35    // on the user defined error code.
36    let name_variant_dispatch: Vec<proc_macro2::TokenStream> = error
37        .raw_enum
38        .variants
39        .iter()
40        .map(|variant| {
41            let ident = &variant.ident;
42            let ident_name = ident.to_string();
43            quote! {
44                #enum_name::#ident => #ident_name.to_string()
45            }
46        })
47        .collect();
48
49    let offset = match &error.args {
50        None => quote! { anchor_lang::error::ERROR_CODE_OFFSET},
51        Some(args) => {
52            let offset = &args.offset;
53            quote! { #offset }
54        }
55    };
56
57    let ret = quote! {
58        #[derive(std::fmt::Debug, Clone, Copy)]
59        #[repr(u32)]
60        #error_enum
61
62        impl #enum_name {
63            /// Gets the name of this [#enum_name].
64            pub fn name(&self) -> String {
65                match self {
66                    #(#name_variant_dispatch),*
67                }
68            }
69        }
70
71        impl From<#enum_name> for u32 {
72            fn from(e: #enum_name) -> u32 {
73                e as u32 + #offset
74            }
75        }
76
77        impl From<#enum_name> for anchor_lang::error::Error {
78            fn from(error_code: #enum_name) -> anchor_lang::error::Error {
79                anchor_lang::error::Error::from(
80                    anchor_lang::error::AnchorError {
81                        error_name: error_code.name(),
82                        error_code_number: error_code.into(),
83                        error_msg: error_code.to_string(),
84                        error_origin: None,
85                        compared_values: None
86                    }
87                )
88            }
89        }
90
91        impl std::fmt::Display for #enum_name {
92            fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
93                match self {
94                    #(#display_variant_dispatch),*
95                }
96            }
97        }
98    };
99
100    #[cfg(feature = "idl-build")]
101    {
102        let idl_print = crate::idl::gen_idl_print_fn_error(&error);
103        return quote! {
104            #ret
105            #idl_print
106        };
107    };
108
109    #[allow(unreachable_code)]
110    ret
111}