Skip to main content

pg_query_mapper_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Data, DeriveInput, GenericArgument, PathArguments, Type};
4
5#[proc_macro_derive(PgQueryMapper, attributes(pg_mapper))]
6pub fn pg_query_mapper_derive(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let struct_ident = input.ident;
9    let mapper_ident = format_ident!("{}Mapper", struct_ident);
10
11    let mut alias_prefix = String::new();
12
13    // Parse struct attributes
14    for attr in &input.attrs {
15        if attr.path().is_ident("pg_mapper") {
16            let res = attr.parse_nested_meta(|meta| {
17                if meta.path.is_ident("alias_prefix") {
18                    let value = meta.value()?;
19                    let s: syn::LitStr = value.parse()?;
20                    alias_prefix = s.value();
21                    Ok(())
22                } else {
23                    Err(meta.error("unsupported struct attribute for pg_mapper"))
24                }
25            });
26            if let Err(err) = res {
27                return TokenStream::from(err.into_compile_error());
28            }
29        }
30    }
31
32    let fields = match input.data {
33        Data::Struct(ref data) => match data.fields {
34            syn::Fields::Named(ref fields) => &fields.named,
35            _ => panic!("PgQueryMapper only supports structs with named fields"),
36        },
37        _ => panic!("PgQueryMapper only supports structs"),
38    };
39
40    let mut mapper_fields = Vec::new();
41    let mut new_vars = Vec::new();
42    let mut new_match_arms = Vec::new();
43    let mut new_struct_fields = Vec::new();
44    let mut map_fields = Vec::new();
45
46    let mut first_required_idx_ident = None;
47    let mut first_required_ty = None;
48
49    for field in fields {
50        let field_ident = field.ident.as_ref().unwrap();
51        let ty = &field.ty;
52        let idx_ident = format_ident!("{}_idx", field_ident);
53
54        let mut rename = None;
55        let mut is_json = false;
56        let mut is_flatten = false;
57        let mut is_skip = false;
58
59        for attr in &field.attrs {
60            if attr.path().is_ident("pg_mapper") {
61                let res = attr.parse_nested_meta(|meta| {
62                    if meta.path.is_ident("rename") {
63                        let value = meta.value()?;
64                        let s: syn::LitStr = value.parse()?;
65                        rename = Some(s.value());
66                        Ok(())
67                    } else if meta.path.is_ident("json") {
68                        is_json = true;
69                        Ok(())
70                    } else if meta.path.is_ident("flatten") {
71                        is_flatten = true;
72                        Ok(())
73                    } else if meta.path.is_ident("skip") {
74                        is_skip = true;
75                        Ok(())
76                    } else {
77                        Err(meta.error("unsupported field attribute for pg_mapper"))
78                    }
79                });
80                if let Err(err) = res {
81                    return TokenStream::from(err.into_compile_error());
82                }
83            }
84        }
85
86        if is_skip {
87            map_fields.push(quote! {
88                #field_ident: Default::default()
89            });
90            continue;
91        }
92
93        let is_opt = is_field_wrapper(ty);
94
95        if is_flatten {
96            let inner_ty = if is_opt {
97                extract_inner_type(ty).unwrap_or(ty)
98            } else {
99                ty
100            };
101
102            let nested_mapper_ident = format_ident!("{}Mapper", get_type_ident(inner_ty).unwrap());
103            let mapper_field_ident = format_ident!("{}_mapper", field_ident);
104
105            mapper_fields.push(quote! {
106                #mapper_field_ident: #nested_mapper_ident
107            });
108
109            new_struct_fields.push(quote! {
110                #mapper_field_ident: #nested_mapper_ident::new(columns)?
111            });
112
113            if is_opt {
114                map_fields.push(quote! {
115                    #field_ident: self.#mapper_field_ident.map_optional(row)?
116                });
117            } else {
118                map_fields.push(quote! {
119                    #field_ident: self.#mapper_field_ident.map(row)?
120                });
121            }
122            continue;
123        }
124
125        let col_name = rename
126            .clone()
127            .unwrap_or_else(|| format!("{}{}", alias_prefix, field_ident));
128
129        if !is_opt && !is_json {
130            // Required field
131            mapper_fields.push(quote! { #idx_ident: usize });
132            new_vars.push(quote! { let mut #idx_ident = None; });
133            new_match_arms.push(quote! { #col_name => #idx_ident = Some(idx), });
134            new_struct_fields.push(quote! {
135                #idx_ident: #idx_ident.ok_or_else(|| pg_query_mapper::MapperError::MissingColumn(#col_name.into()))?
136            });
137
138            map_fields.push(quote! {
139                #field_ident: row.try_get(self.#idx_ident)?
140            });
141
142            if first_required_idx_ident.is_none() {
143                first_required_idx_ident = Some(idx_ident.clone());
144                first_required_ty = Some(ty.clone());
145            }
146        } else {
147            // Optional or JSON field
148            mapper_fields.push(quote! { #idx_ident: Option<usize> });
149            new_vars.push(quote! { let mut #idx_ident = None; });
150            new_match_arms.push(quote! { #col_name => #idx_ident = Some(idx), });
151            new_struct_fields.push(quote! { #idx_ident });
152
153            if is_json {
154                // If it's a Field<T>, we need to wrap the deserialized object in Field::Present
155                if is_opt {
156                    map_fields.push(quote! {
157                        #field_ident: match self.#idx_ident {
158                            Some(idx) => {
159                                let raw: Option<serde_json::Value> = row.try_get(idx)?;
160                                match raw {
161                                    Some(val) => optional_field::Field::Present(Some(
162                                        serde_json::from_value(val).unwrap_or_else(|e| panic!("JSON parse error: {}", e))
163                                    )),
164                                    None => optional_field::Field::Present(None),
165                                }
166                            },
167                            None => optional_field::Field::Missing,
168                        }
169                    });
170                } else {
171                    map_fields.push(quote! {
172                        #field_ident: match self.#idx_ident {
173                            Some(idx) => {
174                                let raw: Option<serde_json::Value> = row.try_get(idx)?;
175                                match raw {
176                                    Some(val) => serde_json::from_value(val).unwrap_or_else(|e| panic!("JSON parse error: {}", e)),
177                                    None => panic!("Missing required JSON column data for {}", #col_name),
178                                }
179                            },
180                            None => panic!("Missing required JSON column index for {}", #col_name),
181                        }
182                    });
183                }
184            } else {
185                map_fields.push(quote! {
186                    #field_ident: match self.#idx_ident {
187                        Some(idx) => optional_field::Field::Present(row.try_get(idx)?),
188                        None => optional_field::Field::Missing,
189                    }
190                });
191            }
192        }
193    }
194
195    let map_optional_method = if let (Some(first_idx), Some(first_ty)) =
196        (first_required_idx_ident, first_required_ty)
197    {
198        quote! {
199            pub fn map_optional(&self, row: &tokio_postgres::Row) -> Result<optional_field::Field<#struct_ident>, tokio_postgres::Error> {
200                let first_val: Option<#first_ty> = row.try_get(self.#first_idx)?;
201                if first_val.is_none() {
202                    return Ok(optional_field::Field::Present(None));
203                }
204
205                let mapped = self.map(row)?;
206                Ok(optional_field::Field::Present(Some(mapped)))
207            }
208        }
209    } else {
210        quote! {
211            pub fn map_optional(&self, row: &tokio_postgres::Row) -> Result<optional_field::Field<#struct_ident>, tokio_postgres::Error> {
212                let mapped = self.map(row)?;
213                Ok(optional_field::Field::Present(Some(mapped)))
214            }
215        }
216    };
217
218    let expanded = quote! {
219        pub struct #mapper_ident {
220            #(#mapper_fields),*
221        }
222
223        impl #mapper_ident {
224            pub fn new(columns: &[tokio_postgres::Column]) -> Result<Self, pg_query_mapper::MapperError> {
225                #(#new_vars)*
226
227                for (idx, column) in columns.iter().enumerate() {
228                    match column.name() {
229                        #(#new_match_arms)*
230                        _ => {}
231                    }
232                }
233
234                Ok(Self {
235                    #(#new_struct_fields),*
236                })
237            }
238
239            pub fn map(&self, row: &tokio_postgres::Row) -> Result<#struct_ident, tokio_postgres::Error> {
240                Ok(#struct_ident {
241                    #(#map_fields),*
242                })
243            }
244
245            #map_optional_method
246        }
247    };
248
249    TokenStream::from(expanded)
250}
251
252fn is_field_wrapper(ty: &Type) -> bool {
253    if let Type::Path(type_path) = ty {
254        if let Some(segment) = type_path.path.segments.last() {
255            return segment.ident == "Field";
256        }
257    }
258    false
259}
260
261fn extract_inner_type(ty: &Type) -> Option<&Type> {
262    if let Type::Path(type_path) = ty {
263        if let Some(segment) = type_path.path.segments.last() {
264            if segment.ident == "Field" {
265                if let PathArguments::AngleBracketed(args) = &segment.arguments {
266                    if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
267                        return Some(inner_ty);
268                    }
269                }
270            }
271        }
272    }
273    None
274}
275
276fn get_type_ident(ty: &Type) -> Option<&syn::Ident> {
277    if let Type::Path(type_path) = ty {
278        if let Some(segment) = type_path.path.segments.last() {
279            return Some(&segment.ident);
280        }
281    }
282    None
283}