actix_error_derive/
lib.rs

1use darling::FromVariant;
2use syn::{parse_macro_input, DeriveInput};
3
4use proc_macro::TokenStream;
5/// 
6#[derive(FromVariant, Default)]
7#[darling(default, attributes(error))]
8struct Opts {
9    code: Option<u16>,
10    status: Option<String>,
11    kind: Option<String>,
12    msg: Option<String>,
13    ignore: bool,
14    group: bool,
15}
16
17
18/// This derive macro is used to convert an enum into an ApiError.  
19/// You can use it by adding the ```#[derive(AsApiError)]``` attribute to your enum.  
20/// By default, the kind is ```snake case```.  
21/// ```#[error(kind = "your_message_id")]``` attribute to the variant.  
22/// You can also add a custom code to the error by adding the ```#[error(code = 400)]``` attribute to the variant.  
23/// The following status are available and return the corresponding status code: 
24/// ``` rust
25/// fn get_status_code(error_kind: &str) -> u16 {
26///     match error_kind {
27///         "BadRequest" => 400,
28///         "Unauthorized" => 401,
29///         "Forbidden" => 403,
30///         "NotFound" => 404,
31///         "MethodNotAllowed" => 405,
32///         "Conflict" => 409,
33///         "Gone" => 410,
34///         "PayloadTooLarge" => 413,
35///         "UnsupportedMediaType" => 415,
36///         "UnprocessableEntity" => 422,
37///         "TooManyRequests" => 429,
38///         "InternalServerError" => 500,
39///         "NotImplemented" => 501,
40///         "BadGateway" => 502,
41///         "ServiceUnavailable" => 503,
42///         "GatewayTimeout" => 504,
43///         _ => 0, // Or some other default/error handling
44///     }
45/// }
46///
47/// // Example usage:
48/// let code = get_status_code("NotFound");
49/// assert_eq!(code, 404);
50/// let default_code = get_status_code("SomeOtherError");
51/// assert_eq!(default_code, 0);
52/// ```
53#[proc_macro_derive(AsApiError, attributes(error))]
54pub fn derive(input: TokenStream) -> TokenStream {
55    use convert_case::{Case, Casing};
56
57    // Parse the input tokens into a syntax tree
58    let ast = parse_macro_input!(input as DeriveInput); 
59    let ident_name = ast.ident;
60
61    // Get the variants
62    let enum_data = match ast.data {
63        syn::Data::Enum(data) => data,
64        _ => panic!("ApiError can only be derived for enums"),
65    };
66    let variants = enum_data.variants;
67
68    // Generate the variant's code 
69    let variants = variants.iter().map(|v| {
70        let ident = &v.ident;
71        let matching_wrapped = if let syn::Fields::Unnamed(u) = &v.fields {
72            let mut fields = String::new();
73            for (i, _) in u.unnamed.iter().enumerate() {
74                fields.push_str(&format!("a{}", i));
75                if i < u.unnamed.len() - 1 {
76                    fields.push_str(", ");
77                }
78            }
79            format!("({})", fields)
80        } else if let syn::Fields::Named(u) = &v.fields {
81            let mut fields = String::new();
82            for (i, field) in u.named.iter().enumerate() {
83                fields.push_str(field.ident.as_ref().unwrap().to_string().as_str());
84                if i < u.named.len() - 1 {
85                    fields.push_str(", ");
86                }
87            }
88            format!("{{ {} }}", fields)
89        } else {
90            String::new()
91        };
92
93        
94        // Get the tuple if it exists
95        let tuple = match &v.fields {
96            syn::Fields::Unnamed(u) => Some(u),
97            _ => None,
98        };
99        let struc = if let syn::Fields::Named(n) = &v.fields {
100            Some(n)
101        } else {
102            None
103        };
104            
105        let opts = Opts::from_variant(&v).expect("Couldn't get the options for the variant");
106        let code = if let Some(code) = opts.code {
107            code
108        } else {
109            if let Some(ref error_kind) = opts.status {
110                match error_kind.as_str() {
111                    "BadRequest" => 400,
112                    "Unauthorized" => 401,
113                    "Forbidden" => 403,
114                    "NotFound" => 404,
115                    "MethodNotAllowed" => 405,
116                    "Conflict" => 409,
117                    "Gone" => 410,
118                    "PayloadTooLarge" => 413,
119                    "UnsupportedMediaType" => 415,
120                    "UnprocessableEntity" => 422,
121                    "TooManyRequests" => 429,
122                    "InternalServerError" => 500,
123                    "NotImplemented" => 501,
124                    "BadGateway" => 502,
125                    "ServiceUnavailable" => 503,
126                    "GatewayTimeout" => 504,
127                    _ => panic!("Invalid kind for variant {}: {}", ident, error_kind),
128                }
129            } else {
130                500
131            }
132        };
133
134        
135        use actix_web::http::StatusCode;
136        if let Err(e) = StatusCode::from_u16(code) {
137            panic!("Invalid status code for variant {}: {}", ident, e);
138        }
139        let kind = opts.kind.unwrap_or_else(|| ident.to_string().to_case(Case::Snake));
140
141        // Get the messages for the variant
142        let mut message = "String::new()".to_owned();
143        if let Some(msg) = opts.msg {
144            message = if let Some(tuple) = tuple  {
145                // genrate a string like "format!(\"message\", self.0, self.1)"
146                // Where message is the msg attribute of the variant
147                // and self.0, self.1 are the tuple fields
148                let mut fields = String::new();
149                for (i, _) in tuple.unnamed.iter().enumerate() {
150                    fields.push_str(&format!("a{}", i));
151                    if i < tuple.unnamed.len() - 1 {
152                        fields.push_str(", ");
153                    }
154                }
155                format!("format!(\"{}\", {})", msg, fields)
156            } else if let Some(_) = struc {
157                format!("format!(\"{}\")", msg)
158            } else {
159                format!("\"{}\".to_owned()", msg)
160            };
161
162            if opts.ignore {
163                message = format!("\"{}\".to_owned()", msg);
164            }
165        }
166
167        // The list_vars variable and its associated logic can be removed because list_vars is never populated.
168
169        let api_error = if opts.group {
170            String::from("a0.as_api_error()")
171        } else {
172            format!("ApiError::new({code}, \"{kind}\", {message})", code = code, kind = kind, message = message)
173        };
174        
175        format!("
176            {ident_name}::{ident} {matching_wrapped} => {{ 
177                {api_error}
178            }},
179        ", )
180    });
181
182    // Implement the ApiError trait
183    let mut code = String::new();
184    code.push_str(&format!("impl AsApiErrorTrait for {ident_name} {{\n"));
185    code.push_str(" fn as_api_error(&self) -> ApiError {\n");
186    code.push_str("     match &self {\n");
187    for v in variants {
188        code.push_str(&v.to_string());
189    }
190    code.push_str("\n    }\n");
191    code.push_str("   }\n");
192    code.push_str("}\n");
193
194    code.push_str(&format!(r#"
195        impl std::fmt::Debug for {ident_name} {{
196            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
197                let api_error = self.as_api_error();
198                write!(f, "{{:?}}", api_error)
199            }}
200        }}
201    
202        impl std::fmt::Display for {ident_name} {{
203            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
204                let api_error = self.as_api_error();
205                write!(f, "{{}}", api_error)
206            }}
207        }}
208    
209        impl actix_web::ResponseError for {ident_name} {{
210            fn status_code(&self) -> actix_web::http::StatusCode {{
211                let api_error = self.as_api_error();
212                actix_web::http::StatusCode::from_u16(api_error.code).unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR)
213            }}
214        
215            fn error_response(&self) -> actix_web::HttpResponse {{
216                let api_error = self.as_api_error();
217                let status = actix_web::http::StatusCode::from_u16(api_error.code).unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
218                actix_web::HttpResponse::build(status).json(api_error)
219            }}
220        }}
221    "#));
222    code.parse().expect("Couldn't parse the code")
223}