1use proc_macro::TokenStream;
23use quote::quote;
24use syn::{parse_macro_input, Attribute, Data, DeriveInput};
25
26#[proc_macro_derive(AlienErrorData, attributes(error))]
27pub fn derive_alien_error(input: TokenStream) -> TokenStream {
28 let input = parse_macro_input!(input as DeriveInput);
29 let name = input.ident;
30
31 let (
32 code_match_arms,
33 retryable_match_arms,
34 internal_match_arms,
35 http_status_code_match_arms,
36 context_match_arms,
37 message_match_arms,
38 retryable_inherit_match_arms,
39 internal_inherit_match_arms,
40 http_status_code_inherit_match_arms,
41 ) = match input.data {
42 Data::Enum(ref data_enum) => {
43 let mut code_arms = Vec::new();
44 let mut retryable_arms = Vec::new();
45 let mut internal_arms = Vec::new();
46 let mut http_status_code_arms = Vec::new();
47 let mut context_arms = Vec::new();
48 let mut message_arms = Vec::new();
49 let mut retryable_inherit_arms = Vec::new();
50 let mut internal_inherit_arms = Vec::new();
51 let mut http_status_code_inherit_arms = Vec::new();
52
53 for variant in &data_enum.variants {
54 let ident = &variant.ident;
55
56 let (
57 code_val,
58 retryable_val,
59 internal_val,
60 http_status_code_val,
61 message_val,
62 retryable_inherit,
63 internal_inherit,
64 http_status_code_inherit,
65 ) = parse_error_attrs(&variant.attrs, ident.to_string());
66
67 let matcher = if variant.fields.is_empty() {
68 quote! { #name::#ident }
69 } else {
70 quote! { #name::#ident { .. } }
71 };
72
73 let code_lit = code_val;
74 let retry_bool = retryable_val;
75 let internal_bool = internal_val;
76 let http_status_code_u16 = http_status_code_val;
77
78 code_arms.push(quote! { #matcher => #code_lit });
79 retryable_arms.push(quote! { #matcher => #retry_bool });
80 internal_arms.push(quote! { #matcher => #internal_bool });
81 http_status_code_arms.push(quote! { #matcher => #http_status_code_u16 });
82 retryable_inherit_arms.push(quote! { #matcher => #retryable_inherit });
83 internal_inherit_arms.push(quote! { #matcher => #internal_inherit });
84 http_status_code_inherit_arms
85 .push(quote! { #matcher => #http_status_code_inherit });
86
87 match &variant.fields {
89 syn::Fields::Named(fields_named) if !fields_named.named.is_empty() => {
90 let field_idents: Vec<_> = fields_named
93 .named
94 .iter()
95 .map(|f| f.ident.as_ref().unwrap())
96 .collect();
97 let matcher = quote! { #name::#ident { #( ref #field_idents ),* } };
98
99 let interpolated_message =
101 generate_message_interpolation(&message_val, &field_idents);
102 message_arms.push(quote! { #matcher => #interpolated_message });
103
104 context_arms.push(quote! { #matcher => {
105 let mut map = serde_json::Map::new();
106 #( map.insert(
107 stringify!(#field_idents).to_string(),
108 serde_json::to_value(#field_idents)
109 .expect(&format!("Failed to serialize field '{}' to JSON. This field must implement Serialize correctly.", stringify!(#field_idents)))
110 ); )*
111 Some(serde_json::Value::Object(map))
112 } });
113 }
114 _ => {
115 let matcher = if variant.fields.is_empty() {
116 quote! { #name::#ident }
117 } else {
118 quote! { #name::#ident { .. } }
119 };
120 message_arms.push(quote! { #matcher => #message_val.to_string() });
121 context_arms.push(quote! { #matcher => None });
122 }
123 }
124 }
125 (
126 code_arms,
127 retryable_arms,
128 internal_arms,
129 http_status_code_arms,
130 context_arms,
131 message_arms,
132 retryable_inherit_arms,
133 internal_inherit_arms,
134 http_status_code_inherit_arms,
135 )
136 }
137 _ => {
138 return quote! { compile_error!("AlienErrorData can only be derived for enums"); }
139 .into();
140 }
141 };
142
143 let expanded = quote! {
144 impl alien_error::AlienErrorData for #name {
145 fn code(&self) -> &'static str {
146 match self {
147 #(#code_match_arms),*
148 }
149 }
150 fn retryable(&self) -> bool {
151 match self {
152 #(#retryable_match_arms),*
153 }
154 }
155 fn internal(&self) -> bool {
156 match self {
157 #(#internal_match_arms),*
158 }
159 }
160 fn http_status_code(&self) -> u16 {
161 match self {
162 #(#http_status_code_match_arms),*
163 }
164 }
165 fn message(&self) -> String {
166 match self {
167 #(#message_match_arms),*
168 }
169 }
170 fn context(&self) -> Option<serde_json::Value> {
171 match self {
172 #(#context_match_arms),*
173 }
174 }
175 fn retryable_inherit(&self) -> Option<bool> {
176 match self {
177 #(#retryable_inherit_match_arms),*
178 }
179 }
180 fn internal_inherit(&self) -> Option<bool> {
181 match self {
182 #(#internal_inherit_match_arms),*
183 }
184 }
185 fn http_status_code_inherit(&self) -> Option<u16> {
186 match self {
187 #(#http_status_code_inherit_match_arms),*
188 }
189 }
190 }
191
192 impl std::fmt::Display for #name {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(f, "{}", self.message())
195 }
196 }
197 };
198
199 TokenStream::from(expanded)
200}
201
202fn parse_error_attrs(
203 attrs: &[Attribute],
204 default_code: String,
205) -> (
206 proc_macro2::TokenStream,
207 proc_macro2::TokenStream,
208 proc_macro2::TokenStream,
209 proc_macro2::TokenStream,
210 String,
211 proc_macro2::TokenStream,
212 proc_macro2::TokenStream,
213 proc_macro2::TokenStream,
214) {
215 let mut code = default_code;
216 let mut retryable: Option<String> = None;
217 let mut internal: Option<String> = None;
218 let mut http_status_code: Option<String> = None;
219 let mut message: Option<String> = None;
220
221 for attr in attrs {
222 if !attr.path().is_ident("error") {
223 continue;
224 }
225 if let Err(e) = attr.parse_nested_meta(|meta| {
226 if meta.path.is_ident("code") {
227 let lit: syn::LitStr = meta.value()?.parse()?;
228 code = lit.value();
229 Ok(())
230 } else if meta.path.is_ident("retryable") {
231 let lit: syn::LitStr = meta.value()?.parse()?;
232 retryable = Some(lit.value());
233 Ok(())
234 } else if meta.path.is_ident("internal") {
235 let lit: syn::LitStr = meta.value()?.parse()?;
236 internal = Some(lit.value());
237 Ok(())
238 } else if meta.path.is_ident("http_status_code") {
239 let value = meta.value()?;
241
242 let lit: syn::Lit = value.parse()?;
244
245 match lit {
246 syn::Lit::Str(lit_str) => {
247 http_status_code = Some(lit_str.value());
249 }
250 syn::Lit::Int(lit_int) => {
251 let parsed_value = lit_int.base10_parse::<u16>()?;
253 http_status_code = Some(parsed_value.to_string());
254 }
255 _ => {
256 return Err(
257 meta.error("http_status_code must be a string or integer literal")
258 );
259 }
260 }
261 Ok(())
262 } else if meta.path.is_ident("message") {
263 let lit: syn::LitStr = meta.value()?.parse()?;
264 message = Some(lit.value());
265 Ok(())
266 } else {
267 Err(meta.error("unsupported error attribute key"))
268 }
269 }) {
270 return (
272 syn::Error::new(e.span(), e.to_string()).to_compile_error(),
273 syn::Error::new(e.span(), e.to_string()).to_compile_error(),
274 syn::Error::new(e.span(), e.to_string()).to_compile_error(),
275 syn::Error::new(e.span(), e.to_string()).to_compile_error(),
276 String::new(),
277 syn::Error::new(e.span(), e.to_string()).to_compile_error(),
278 syn::Error::new(e.span(), e.to_string()).to_compile_error(),
279 syn::Error::new(e.span(), e.to_string()).to_compile_error(),
280 );
281 }
282 }
283
284 macro_rules! parse_flag {
286 ($val:expr,$name:expr) => {
287 match $val {
288 Some(ref s) if s == "true" => quote! { true },
289 Some(ref s) if s == "false" => quote! { false },
290 Some(ref s) if s == "inherit" => quote! { false }, Some(ref _other) => syn::Error::new(proc_macro2::Span::call_site(), format!("{} must be \"true\", \"false\" or \"inherit\"", $name)).to_compile_error(),
292 None => syn::Error::new(proc_macro2::Span::call_site(), format!("{}=\"...\" is required in #[error(...)]", $name)).to_compile_error(),
293 }
294 };
295 }
296
297 macro_rules! parse_inherit_flag {
299 ($val:expr) => {
300 match $val {
301 Some(ref s) if s == "inherit" => quote! { None },
302 Some(ref s) if s == "true" => quote! { Some(true) },
303 Some(ref s) if s == "false" => quote! { Some(false) },
304 Some(_) => quote! { Some(false) }, None => syn::Error::new(proc_macro2::Span::call_site(), "flag is required")
306 .to_compile_error(),
307 }
308 };
309 }
310
311 let retry_ts = parse_flag!(retryable.clone(), "retryable");
312 let internal_ts = parse_flag!(internal.clone(), "internal");
313 let retryable_inherit_ts = parse_inherit_flag!(retryable);
314 let internal_inherit_ts = parse_inherit_flag!(internal);
315
316 let code_ts = {
317 let lit = syn::LitStr::new(&code, proc_macro2::Span::call_site());
318 quote! { #lit }
319 };
320
321 let (http_status_code_ts, http_status_code_inherit_ts) = match http_status_code {
323 Some(ref s) if s == "inherit" => {
324 (quote! { 500 }, quote! { None })
326 }
327 Some(ref s) => {
328 match s.parse::<u16>() {
330 Ok(status_code) => (quote! { #status_code }, quote! { Some(#status_code) }),
331 Err(_) => (
332 syn::Error::new(
333 proc_macro2::Span::call_site(),
334 "http_status_code must be a number or \"inherit\"",
335 )
336 .to_compile_error(),
337 syn::Error::new(
338 proc_macro2::Span::call_site(),
339 "http_status_code must be a number or \"inherit\"",
340 )
341 .to_compile_error(),
342 ),
343 }
344 }
345 None => {
346 (quote! { 500 }, quote! { Some(500) })
348 }
349 };
350
351 let message_str = message.unwrap_or_else(|| code.clone());
352
353 (
354 code_ts,
355 retry_ts,
356 internal_ts,
357 http_status_code_ts,
358 message_str,
359 retryable_inherit_ts,
360 internal_inherit_ts,
361 http_status_code_inherit_ts,
362 )
363}
364
365fn generate_message_interpolation(
366 message_template: &str,
367 field_idents: &[&syn::Ident],
368) -> proc_macro2::TokenStream {
369 if field_idents.is_empty() {
373 quote! { #message_template.to_string() }
374 } else {
375 let used_fields: Vec<&syn::Ident> = field_idents
377 .iter()
378 .filter(|field| {
379 let field_name = field.to_string();
380 message_template.contains(&format!("{{{}", field_name))
381 })
382 .cloned()
383 .collect();
384
385 if used_fields.is_empty() {
386 quote! { #message_template.to_string() }
388 } else {
389 quote! { format!(#message_template, #(#used_fields = #used_fields),*) }
391 }
392 }
393}