Skip to main content

aip_193_derive/
lib.rs

1use darling::{FromDeriveInput, FromField, FromVariant, ast};
2use proc_macro::TokenStream;
3use proc_macro_crate::{FoundCrate, crate_name};
4use quote::quote;
5use syn::{DeriveInput, Expr, Ident, parse_macro_input};
6
7#[derive(Debug)]
8enum DomainValue {
9    String(String),
10    Function(Expr),
11}
12
13impl darling::FromMeta for DomainValue {
14    fn from_string(value: &str) -> darling::Result<Self> {
15        Ok(DomainValue::String(value.to_string()))
16    }
17
18    fn from_expr(expr: &Expr) -> darling::Result<Self> {
19        // Check if this is a string literal expression
20        if let Expr::Lit(expr_lit) = expr {
21            if let syn::Lit::Str(lit_str) = &expr_lit.lit {
22                return Ok(DomainValue::String(lit_str.value()));
23            }
24        }
25
26        // Otherwise treat it as a function path
27        Ok(DomainValue::Function(expr.clone()))
28    }
29}
30
/// Container-level input for `#[derive(IntoStatus)]`, parsed by darling from the enum and its
/// `#[status(...)]` attributes. Only enums are accepted (`supports(enum_any)`).
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(status), supports(enum_any))]
struct StatusInput {
    /// Name of the enum the impls are generated for.
    ident: Ident,
    /// The enum's variants. The `()` struct-field parameter is never populated, because only
    /// enums are supported.
    data: ast::Data<StatusVariant, ()>,
    /// Required `domain = ...` value, returned by the generated `domain()` method.
    domain: DomainValue,
    /// When set, an `IntoResponse` impl is generated in addition to `IntoStatus`.
    #[darling(default)]
    into_response: bool,
    /// Default for variants without an explicit `message`. When `true`, the message is the
    /// enum's `ToString` output; otherwise it is the variant name. Defaults to `true`.
    #[darling(default = "default_true")]
    use_display: bool,
}
42
/// Default provider for `#[darling(default = "default_true")]`.
///
/// It keeps a flag enabled unless the attribute explicitly turns it off.
const fn default_true() -> bool {
    true
}
46
/// Per-variant `#[status(...)]` configuration.
#[derive(Debug, FromVariant)]
#[darling(attributes(status))]
struct StatusVariant {
    /// Variant name.
    ident: Ident,
    /// The variant's fields, each carrying its own field-level `#[status(...)]` options.
    fields: ast::Fields<StatusField>,
    /// Required `code = ...`: the name of a `Code` variant, emitted as `<krate>::Code::<code>`.
    code: Ident,
    /// Optional message template. `{field}` placeholders naming a field of a struct-like
    /// variant are replaced with that field's `Display` output.
    #[darling(default)]
    message: Option<String>,
    /// Per-variant override of the container's `use_display`. It is not consulted when
    /// `message` is set.
    #[darling(default)]
    use_display: Option<bool>,
}
58
/// Per-field `#[status(...)]` configuration.
#[derive(Debug, FromField)]
#[darling(attributes(status))]
struct StatusField {
    /// Field name; `None` for fields of tuple variants.
    ident: Option<Ident>,
    /// When set, the field's `to_string()` output is added to the generated `metadata()` map.
    #[darling(default)]
    metadata: bool,
    /// Map key for this field. When omitted for a named field, the field name is used.
    #[darling(default)]
    metadata_key: Option<String>,
}
68
69fn get_crate_path() -> proc_macro2::TokenStream {
70    if let Ok(found) = crate_name("aip") {
71        return match found {
72            FoundCrate::Itself => quote!(crate::__private::errors),
73            FoundCrate::Name(name) => {
74                let ident = Ident::new(&name, proc_macro2::Span::call_site());
75                quote!(::#ident::__private::errors)
76            }
77        };
78    }
79
80    if let Ok(found) = crate_name("aip-193") {
81        return match found {
82            FoundCrate::Itself => quote!(crate),
83            FoundCrate::Name(name) => {
84                let ident = Ident::new(&name, proc_macro2::Span::call_site());
85                quote!(::#ident)
86            }
87        };
88    }
89
90    quote!(::aip_193)
91}
92
93#[proc_macro_derive(IntoStatus, attributes(status))]
94pub fn derive_into_status(input: TokenStream) -> TokenStream {
95    let input = parse_macro_input!(input as DeriveInput);
96
97    let parsed = match StatusInput::from_derive_input(&input) {
98        Ok(v) => v,
99        Err(e) => return e.write_errors().into(),
100    };
101
102    let expanded = generate_impl(&parsed);
103    TokenStream::from(expanded)
104}
105
106fn generate_impl(input: &StatusInput) -> proc_macro2::TokenStream {
107    let name = &input.ident;
108    let krate = get_crate_path();
109
110    let variants = match &input.data {
111        ast::Data::Enum(variants) => variants,
112        _ => panic!("IntoStatus only supports enums"),
113    };
114
115    let code_arms = generate_code_arms(name, variants, &krate);
116    let message_arms = generate_message_arms(name, variants, input.use_display);
117    let metadata_arms = generate_metadata_arms(name, variants);
118
119    let domain_impl = match &input.domain {
120        DomainValue::String(s) => quote! { #s },
121        DomainValue::Function(expr) => quote! { (#expr)() },
122    };
123
124    let into_status_impl = quote! {
125        impl #krate::__private::IntoStatus for #name {
126            fn code(&self) -> #krate::Code {
127                match self {
128                    #(#code_arms),*
129                }
130            }
131
132            fn message(&self) -> ::std::string::String {
133                match self {
134                    #(#message_arms),*
135                }
136            }
137
138            fn reason(&self) -> &str {
139                self.as_ref()
140            }
141
142            fn domain(&self) -> &str {
143                #domain_impl
144            }
145
146            fn metadata(&self) -> #krate::__private::HashMap<::std::string::String, ::std::string::String> {
147                match self {
148                    #(#metadata_arms),*
149                }
150            }
151        }
152    };
153
154    let into_response_impl = if input.into_response {
155        generate_into_response_impl(name, &krate)
156    } else {
157        quote! {}
158    };
159
160    quote! {
161        #into_status_impl
162        #into_response_impl
163    }
164}
165
166fn get_axum_path() -> proc_macro2::TokenStream {
167    if let Ok(found) = crate_name("axum") {
168        return match found {
169            FoundCrate::Itself => quote!(crate),
170            FoundCrate::Name(name) => {
171                let ident = Ident::new(&name, proc_macro2::Span::call_site());
172                quote!(::#ident)
173            }
174        };
175    }
176
177    if let Ok(found) = crate_name("axum-core") {
178        return match found {
179            FoundCrate::Itself => quote!(crate),
180            FoundCrate::Name(name) => {
181                let ident = Ident::new(&name, proc_macro2::Span::call_site());
182                quote!(::#ident)
183            }
184        };
185    }
186
187    quote!(::axum)
188}
189
190fn generate_into_response_impl(
191    name: &Ident,
192    krate: &proc_macro2::TokenStream,
193) -> proc_macro2::TokenStream {
194    let axum = get_axum_path();
195
196    quote! {
197        impl #axum::response::IntoResponse for #name {
198            fn into_response(self) -> #axum::response::Response {
199                use #krate::__private::IntoStatus as _;
200                let status = Status::from(self);
201                <#krate::__private::Status as #axum::response::IntoResponse>::into_response(status)
202            }
203        }
204    }
205}
206
207fn generate_code_arms(
208    enum_name: &Ident,
209    variants: &[StatusVariant],
210    krate: &proc_macro2::TokenStream,
211) -> Vec<proc_macro2::TokenStream> {
212    variants
213        .iter()
214        .map(|v| {
215            let code = &v.code;
216            let pattern = generate_pattern_ignore_fields(enum_name, &v.ident, &v.fields);
217            quote! {
218                #pattern => #krate::Code::#code
219            }
220        })
221        .collect()
222}
223
224fn generate_message_arms(
225    enum_name: &Ident,
226    variants: &[StatusVariant],
227    use_display: bool,
228) -> Vec<proc_macro2::TokenStream> {
229    variants
230        .iter()
231        .map(|v| {
232            let message_expr = if let Some(template) = &v.message {
233                let pattern = generate_pattern(enum_name, v);
234                let message = parse_message_template(template, &v.fields);
235                quote! { #pattern => #message }
236            } else {
237                let should_use_display = v.use_display.unwrap_or(use_display);
238                let pattern = generate_pattern_ignore_fields(enum_name, &v.ident, &v.fields);
239                if should_use_display {
240                    quote! { #pattern => ::std::string::ToString::to_string(self) }
241                } else {
242                    let default_msg = format!("{}", v.ident);
243                    quote! { #pattern => #default_msg.to_string() }
244                }
245            };
246
247            message_expr
248        })
249        .collect()
250}
251
252fn parse_message_template(
253    template: &str,
254    fields: &ast::Fields<StatusField>,
255) -> proc_macro2::TokenStream {
256    let field_names: Vec<String> = fields
257        .iter()
258        .filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
259        .collect();
260
261    let mut format_str = String::new();
262    let mut args: Vec<proc_macro2::TokenStream> = Vec::new();
263
264    let mut chars = template.chars().peekable();
265    while let Some(c) = chars.next() {
266        if c == '{' {
267            let mut field_name = String::new();
268            while let Some(&next) = chars.peek() {
269                if next == '}' {
270                    chars.next();
271                    break;
272                }
273                field_name.push(chars.next().unwrap());
274            }
275
276            if field_names.contains(&field_name) {
277                format_str.push_str("{}");
278                let field_ident = Ident::new(&field_name, proc_macro2::Span::call_site());
279                args.push(quote! { #field_ident });
280            } else {
281                format_str.push('{');
282                format_str.push_str(&field_name);
283                format_str.push('}');
284            }
285        } else {
286            format_str.push(c);
287        }
288    }
289
290    if args.is_empty() {
291        quote! { #template.to_string() }
292    } else {
293        quote! { format!(#format_str, #(#args),*) }
294    }
295}
296
297fn generate_metadata_arms(
298    enum_name: &Ident,
299    variants: &[StatusVariant],
300) -> Vec<proc_macro2::TokenStream> {
301    variants
302        .iter()
303        .map(|v| {
304            let pattern = generate_pattern_for_metadata(enum_name, v);
305
306            let metadata_fields: Vec<_> = v
307                .fields
308                .iter()
309                .filter(|f| f.metadata)
310                .filter_map(|f| {
311                    let field_name = f.ident.as_ref()?;
312                    let key = f
313                        .metadata_key
314                        .clone()
315                        .unwrap_or_else(|| field_name.to_string());
316                    Some(quote! {
317                        map.insert(#key.to_string(), #field_name.to_string());
318                    })
319                })
320                .collect();
321
322            quote! {
323                #pattern => {
324                    #[allow(unused_mut)]
325                    let mut map = ::std::collections::HashMap::new();
326                    #(#metadata_fields)*
327                    map
328                }
329            }
330        })
331        .collect()
332}
333
334fn generate_pattern(enum_name: &Ident, variant: &StatusVariant) -> proc_macro2::TokenStream {
335    let variant_name = &variant.ident;
336
337    match &variant.fields.style {
338        ast::Style::Unit => {
339            quote! { #enum_name::#variant_name }
340        }
341        ast::Style::Struct => {
342            let field_names: Vec<_> = variant
343                .fields
344                .iter()
345                .filter_map(|f| f.ident.as_ref())
346                .collect();
347            quote! { #enum_name::#variant_name { #(#field_names),* } }
348        }
349        ast::Style::Tuple => {
350            let bindings: Vec<_> = (0..variant.fields.len())
351                .map(|i| {
352                    let ident = Ident::new(&format!("_{}", i), proc_macro2::Span::call_site());
353                    quote! { #ident }
354                })
355                .collect();
356            quote! { #enum_name::#variant_name(#(#bindings),*) }
357        }
358    }
359}
360
361fn generate_pattern_ignore_fields(
362    enum_name: &Ident,
363    variant_name: &Ident,
364    fields: &ast::Fields<StatusField>,
365) -> proc_macro2::TokenStream {
366    match fields.style {
367        ast::Style::Unit => {
368            quote! { #enum_name::#variant_name }
369        }
370        ast::Style::Struct => {
371            quote! { #enum_name::#variant_name { .. } }
372        }
373        ast::Style::Tuple => {
374            quote! { #enum_name::#variant_name(..) }
375        }
376    }
377}
378
379fn generate_pattern_for_metadata(
380    enum_name: &Ident,
381    variant: &StatusVariant,
382) -> proc_macro2::TokenStream {
383    let variant_name = &variant.ident;
384
385    match &variant.fields.style {
386        ast::Style::Unit => {
387            quote! { #enum_name::#variant_name }
388        }
389        ast::Style::Struct => {
390            let metadata_fields: Vec<_> = variant
391                .fields
392                .iter()
393                .filter(|f| f.metadata)
394                .filter_map(|f| f.ident.as_ref())
395                .collect();
396
397            if metadata_fields.is_empty() {
398                quote! { #enum_name::#variant_name { .. } }
399            } else {
400                quote! { #enum_name::#variant_name { #(#metadata_fields),*, .. } }
401            }
402        }
403        ast::Style::Tuple => {
404            let has_metadata = variant.fields.iter().any(|f| f.metadata);
405
406            if !has_metadata {
407                quote! { #enum_name::#variant_name(..) }
408            } else {
409                let bindings: Vec<_> = variant
410                    .fields
411                    .iter()
412                    .enumerate()
413                    .map(|(i, f)| {
414                        let binding =
415                            Ident::new(&format!("_{}", i), proc_macro2::Span::call_site());
416                        if f.metadata {
417                            quote! { #binding }
418                        } else {
419                            quote! { _ }
420                        }
421                    })
422                    .collect();
423                quote! { #enum_name::#variant_name(#(#bindings),*) }
424            }
425        }
426    }
427}