Skip to main content

ironstone_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, Variant};
4
5#[proc_macro_derive(HasId)]
6pub fn has_id_derive(input: TokenStream) -> TokenStream {
7    // Parse the input tokens into a syntax tree
8    let ast = parse_macro_input!(input as DeriveInput);
9
10    // Get the name of the struct we're deriving for (e.g., "AccountRow")
11    let name = &ast.ident;
12
13    // Find the type of the `id` field
14    let id_type = match &ast.data {
15        Data::Struct(s) => match &s.fields {
16            Fields::Named(fields) => {
17                // Look for a field with the identifier "id"
18                let id_field = fields
19                    .named
20                    .iter()
21                    .find(|f| f.ident.as_ref().unwrap() == "id");
22                match id_field {
23                    Some(field) => &field.ty, // Get the type of the field
24                    None => panic!("Struct must have a field named `id` to derive `HasId`"),
25                }
26            }
27            _ => panic!("`HasId` can only be derived for structs with named fields"),
28        },
29        _ => panic!("`HasId` can only be derived for structs"),
30    };
31
32    // Generate the implementation of the HasId trait
33    let code = quote! {
34        impl HasId for #name {
35            type Id = #id_type;
36        }
37    };
38
39    // Return the generated code as a TokenStream
40    code.into()
41}
42
43#[proc_macro_derive(EnumTextType)]
44pub fn enum_text_type_derive(input: TokenStream) -> TokenStream {
45    // Parse the input tokens into a syntax tree
46    let ast = parse_macro_input!(input as DeriveInput);
47    // Get the name of the struct we're deriving for (e.g., "AccountRow")
48    let name = &ast.ident;
49
50    // Extract the enum variants from the AST
51    let variants = if let syn::Data::Enum(data) = ast.data {
52        data.variants
53    } else {
54        // This macro only works on enums, so we'll panic if it's not an enum.
55        unimplemented!("EnumTextType can only be used on enums");
56    };
57
58    // --- Logic to generate match arms for Display and FromStr ---
59
60    // A helper function to find the `#[serde(rename = "...")]` string
61    fn get_string_repr(variant: &Variant) -> String {
62        for attr in &variant.attrs {
63            if attr.path().is_ident("serde") {
64                if let Meta::List(meta_list) = &attr.meta {
65                    if let Ok(expr) = meta_list.parse_args::<syn::MetaNameValue>() {
66                        if expr.path.is_ident("rename") {
67                            if let syn::Expr::Lit(expr_lit) = expr.value {
68                                if let Lit::Str(lit_str) = expr_lit.lit {
69                                    return lit_str.value();
70                                }
71                            }
72                        }
73                    }
74                }
75            }
76        }
77
78        // Default to the lowercase version of the variant name
79        variant.ident.to_string().to_lowercase()
80    }
81
82    // Create the match arms for the `Display` implementation
83    let display_arms = variants.iter().map(|variant| {
84        let variant_ident = &variant.ident;
85        let string_repr = get_string_repr(variant);
86        quote! { Self::#variant_ident => write!(f, #string_repr) }
87    });
88
89    // Create the match arms for the `FromStr` implementation
90    let from_str_arms = variants.iter().map(|variant| {
91        let variant_ident = &variant.ident;
92        let string_repr = get_string_repr(variant);
93        quote! { #string_repr => Ok(Self::#variant_ident) }
94    });
95
96    // Generate the implementation of the HasId trait
97    let code = quote! {
98      // --- Display Implementation ---
99        impl std::fmt::Display for #name {
100            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101                match self {
102                    #(#display_arms),*
103                }
104            }
105        }
106
107        // This now adds the necessary trailing comma
108        impl std::str::FromStr for #name {
109            // Use a generic boxed error to avoid defining a new struct.
110            type Err = Box<dyn std::error::Error + Send + Sync + 'static>;
111            fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
112               match s {
113                    #(#from_str_arms,)*
114
115                    // For the wildcard case, create the error from a formatted string.
116                    // The `.into()` at the end converts the String into the Box<dyn Error>.
117                    _ => Err(format!("Invalid variant `{}` for enum `{}`", s, stringify!(#name)).into()),
118                }
119            }
120        }
121
122            // --- DECODE (From database TEXT to Rust Enum) ---
123        impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #name {
124            fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
125                let value_str = <&str as sqlx::decode::Decode<sqlx::Postgres>>::decode(value)?;
126                Ok(#name::from_str(value_str)?)
127            }
128        }
129
130        // --- ENCODE (From Rust Enum to database TEXT) ---
131        impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #name {
132            fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> std::result::Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
133                let s = self.to_string();
134                <&str as sqlx::encode::Encode<sqlx::Postgres>>::encode(&s, buf)
135            }
136        }
137
138        impl sqlx::Type<sqlx::Postgres> for #name {
139            fn type_info() -> sqlx::postgres::PgTypeInfo {
140                // This tells sqlx that our `$ty` enum corresponds to the `TEXT` type in PostgreSQL.
141                sqlx::postgres::PgTypeInfo::with_name("TEXT")
142            }
143        }
144    };
145
146    // Return the generated code as a TokenStream
147    code.into()
148}
149
150#[proc_macro_derive(HasActiveFilter)]
151pub fn has_active_filter_derive(input: TokenStream) -> TokenStream {
152    // 1. Parse the input tokens into a syntax tree
153    let input = parse_macro_input!(input as DeriveInput);
154    let name = input.ident;
155
156    // 2. Extract fields (assuming it's a struct with named fields)
157    let fields = match input.data {
158        Data::Struct(data) => match data.fields {
159            Fields::Named(fields) => fields.named,
160            _ => panic!("HasActiveFilter can only be derived for structs with named fields."),
161        },
162        _ => panic!("HasActiveFilter can only be derived for structs."),
163    };
164
165    // 3. Generate the OR'd checks for each field.
166    // The macro iterates over the fields and creates tokens like: self.field1.is_some() || self.field2.is_some()
167    let checks = fields
168        .iter()
169        .map(|f| {
170            let field_name = &f.ident;
171            quote! {
172                self.#field_name.is_some()
173            }
174        })
175        .collect::<Vec<_>>();
176
177    // 4. Combine into the final impl block
178    let expanded = quote! {
179        impl HasActiveFilter for #name {
180            fn has_active_filter(&self) -> bool {
181                // Combine all checks with '||'
182                #(#checks)||*
183            }
184        }
185    };
186
187    TokenStream::from(expanded)
188}