1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit, Meta, Variant};
4
5#[proc_macro_derive(HasId)]
6pub fn has_id_derive(input: TokenStream) -> TokenStream {
7 let ast = parse_macro_input!(input as DeriveInput);
9
10 let name = &ast.ident;
12
13 let id_type = match &ast.data {
15 Data::Struct(s) => match &s.fields {
16 Fields::Named(fields) => {
17 let id_field = fields
19 .named
20 .iter()
21 .find(|f| f.ident.as_ref().unwrap() == "id");
22 match id_field {
23 Some(field) => &field.ty, None => panic!("Struct must have a field named `id` to derive `HasId`"),
25 }
26 }
27 _ => panic!("`HasId` can only be derived for structs with named fields"),
28 },
29 _ => panic!("`HasId` can only be derived for structs"),
30 };
31
32 let code = quote! {
34 impl HasId for #name {
35 type Id = #id_type;
36 }
37 };
38
39 code.into()
41}
42
43#[proc_macro_derive(EnumTextType)]
44pub fn enum_text_type_derive(input: TokenStream) -> TokenStream {
45 let ast = parse_macro_input!(input as DeriveInput);
47 let name = &ast.ident;
49
50 let variants = if let syn::Data::Enum(data) = ast.data {
52 data.variants
53 } else {
54 unimplemented!("EnumTextType can only be used on enums");
56 };
57
58 fn get_string_repr(variant: &Variant) -> String {
62 for attr in &variant.attrs {
63 if attr.path().is_ident("serde") {
64 if let Meta::List(meta_list) = &attr.meta {
65 if let Ok(expr) = meta_list.parse_args::<syn::MetaNameValue>() {
66 if expr.path.is_ident("rename") {
67 if let syn::Expr::Lit(expr_lit) = expr.value {
68 if let Lit::Str(lit_str) = expr_lit.lit {
69 return lit_str.value();
70 }
71 }
72 }
73 }
74 }
75 }
76 }
77
78 variant.ident.to_string().to_lowercase()
80 }
81
82 let display_arms = variants.iter().map(|variant| {
84 let variant_ident = &variant.ident;
85 let string_repr = get_string_repr(variant);
86 quote! { Self::#variant_ident => write!(f, #string_repr) }
87 });
88
89 let from_str_arms = variants.iter().map(|variant| {
91 let variant_ident = &variant.ident;
92 let string_repr = get_string_repr(variant);
93 quote! { #string_repr => Ok(Self::#variant_ident) }
94 });
95
96 let code = quote! {
98 impl std::fmt::Display for #name {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 #(#display_arms),*
103 }
104 }
105 }
106
107 impl std::str::FromStr for #name {
109 type Err = Box<dyn std::error::Error + Send + Sync + 'static>;
111 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
112 match s {
113 #(#from_str_arms,)*
114
115 _ => Err(format!("Invalid variant `{}` for enum `{}`", s, stringify!(#name)).into()),
118 }
119 }
120 }
121
122 impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #name {
124 fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
125 let value_str = <&str as sqlx::decode::Decode<sqlx::Postgres>>::decode(value)?;
126 Ok(#name::from_str(value_str)?)
127 }
128 }
129
130 impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #name {
132 fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> std::result::Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
133 let s = self.to_string();
134 <&str as sqlx::encode::Encode<sqlx::Postgres>>::encode(&s, buf)
135 }
136 }
137
138 impl sqlx::Type<sqlx::Postgres> for #name {
139 fn type_info() -> sqlx::postgres::PgTypeInfo {
140 sqlx::postgres::PgTypeInfo::with_name("TEXT")
142 }
143 }
144 };
145
146 code.into()
148}
149
150#[proc_macro_derive(HasActiveFilter)]
151pub fn has_active_filter_derive(input: TokenStream) -> TokenStream {
152 let input = parse_macro_input!(input as DeriveInput);
154 let name = input.ident;
155
156 let fields = match input.data {
158 Data::Struct(data) => match data.fields {
159 Fields::Named(fields) => fields.named,
160 _ => panic!("HasActiveFilter can only be derived for structs with named fields."),
161 },
162 _ => panic!("HasActiveFilter can only be derived for structs."),
163 };
164
165 let checks = fields
168 .iter()
169 .map(|f| {
170 let field_name = &f.ident;
171 quote! {
172 self.#field_name.is_some()
173 }
174 })
175 .collect::<Vec<_>>();
176
177 let expanded = quote! {
179 impl HasActiveFilter for #name {
180 fn has_active_filter(&self) -> bool {
181 #(#checks)||*
183 }
184 }
185 };
186
187 TokenStream::from(expanded)
188}