// actix_error_derive/lib.rs
use proc_macro::TokenStream;

use darling::FromVariant;
use syn::{parse_macro_input, DeriveInput};
5#[derive(FromVariant, Default)]
7#[darling(default, attributes(error))]
8struct Opts {
9 code: Option<u16>,
10 status: Option<String>,
11 kind: Option<String>,
12 msg: Option<String>,
13 ignore: bool,
14 group: bool,
15}
16
17
18#[proc_macro_derive(AsApiError, attributes(error))]
54pub fn derive(input: TokenStream) -> TokenStream {
55 use convert_case::{Case, Casing};
56
57 let ast = parse_macro_input!(input as DeriveInput);
59 let ident_name = ast.ident;
60
61 let enum_data = match ast.data {
63 syn::Data::Enum(data) => data,
64 _ => panic!("ApiError can only be derived for enums"),
65 };
66 let variants = enum_data.variants;
67
68 let variants = variants.iter().map(|v| {
70 let ident = &v.ident;
71 let matching_wrapped = if let syn::Fields::Unnamed(u) = &v.fields {
72 let mut fields = String::new();
73 for (i, _) in u.unnamed.iter().enumerate() {
74 fields.push_str(&format!("a{}", i));
75 if i < u.unnamed.len() - 1 {
76 fields.push_str(", ");
77 }
78 }
79 format!("({})", fields)
80 } else if let syn::Fields::Named(u) = &v.fields {
81 let mut fields = String::new();
82 for (i, field) in u.named.iter().enumerate() {
83 fields.push_str(field.ident.as_ref().unwrap().to_string().as_str());
84 if i < u.named.len() - 1 {
85 fields.push_str(", ");
86 }
87 }
88 format!("{{ {} }}", fields)
89 } else {
90 String::new()
91 };
92
93
94 let tuple = match &v.fields {
96 syn::Fields::Unnamed(u) => Some(u),
97 _ => None,
98 };
99 let struc = if let syn::Fields::Named(n) = &v.fields {
100 Some(n)
101 } else {
102 None
103 };
104
105 let opts = Opts::from_variant(&v).expect("Couldn't get the options for the variant");
106 let code = if let Some(code) = opts.code {
107 code
108 } else {
109 if let Some(ref error_kind) = opts.status {
110 match error_kind.as_str() {
111 "BadRequest" => 400,
112 "Unauthorized" => 401,
113 "Forbidden" => 403,
114 "NotFound" => 404,
115 "MethodNotAllowed" => 405,
116 "Conflict" => 409,
117 "Gone" => 410,
118 "PayloadTooLarge" => 413,
119 "UnsupportedMediaType" => 415,
120 "UnprocessableEntity" => 422,
121 "TooManyRequests" => 429,
122 "InternalServerError" => 500,
123 "NotImplemented" => 501,
124 "BadGateway" => 502,
125 "ServiceUnavailable" => 503,
126 "GatewayTimeout" => 504,
127 _ => panic!("Invalid kind for variant {}: {}", ident, error_kind),
128 }
129 } else {
130 500
131 }
132 };
133
134
135 use actix_web::http::StatusCode;
136 if let Err(e) = StatusCode::from_u16(code) {
137 panic!("Invalid status code for variant {}: {}", ident, e);
138 }
139 let kind = opts.kind.unwrap_or_else(|| ident.to_string().to_case(Case::Snake));
140
141 let mut message = "String::new()".to_owned();
143 if let Some(msg) = opts.msg {
144 message = if let Some(tuple) = tuple {
145 let mut fields = String::new();
149 for (i, _) in tuple.unnamed.iter().enumerate() {
150 fields.push_str(&format!("a{}", i));
151 if i < tuple.unnamed.len() - 1 {
152 fields.push_str(", ");
153 }
154 }
155 format!("format!(\"{}\", {})", msg, fields)
156 } else if let Some(_) = struc {
157 format!("format!(\"{}\")", msg)
158 } else {
159 format!("\"{}\".to_owned()", msg)
160 };
161
162 if opts.ignore {
163 message = format!("\"{}\".to_owned()", msg);
164 }
165 }
166
167 let api_error = if opts.group {
170 String::from("a0.as_api_error()")
171 } else {
172 format!("ApiError::new({code}, \"{kind}\", {message})", code = code, kind = kind, message = message)
173 };
174
175 format!("
176 {ident_name}::{ident} {matching_wrapped} => {{
177 {api_error}
178 }},
179 ", )
180 });
181
182 let mut code = String::new();
184 code.push_str(&format!("impl AsApiErrorTrait for {ident_name} {{\n"));
185 code.push_str(" fn as_api_error(&self) -> ApiError {\n");
186 code.push_str(" match &self {\n");
187 for v in variants {
188 code.push_str(&v.to_string());
189 }
190 code.push_str("\n }\n");
191 code.push_str(" }\n");
192 code.push_str("}\n");
193
194 code.push_str(&format!(r#"
195 impl std::fmt::Debug for {ident_name} {{
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
197 let api_error = self.as_api_error();
198 write!(f, "{{:?}}", api_error)
199 }}
200 }}
201
202 impl std::fmt::Display for {ident_name} {{
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
204 let api_error = self.as_api_error();
205 write!(f, "{{}}", api_error)
206 }}
207 }}
208
209 impl actix_web::ResponseError for {ident_name} {{
210 fn status_code(&self) -> actix_web::http::StatusCode {{
211 let api_error = self.as_api_error();
212 actix_web::http::StatusCode::from_u16(api_error.code).unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR)
213 }}
214
215 fn error_response(&self) -> actix_web::HttpResponse {{
216 let api_error = self.as_api_error();
217 let status = actix_web::http::StatusCode::from_u16(api_error.code).unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
218 actix_web::HttpResponse::build(status).json(api_error)
219 }}
220 }}
221 "#));
222 code.parse().expect("Couldn't parse the code")
223}