Skip to main content

resolute_derive/
lib.rs

1//! Derive macros for resolute: `FromRow`, `PgEnum`, `PgComposite`, `PgDomain`.
2
3mod pg_composite;
4mod pg_domain;
5mod pg_enum;
6
7/// Consume the value (if any) of a meta entry whose path the caller does not
8/// recognize. Allows a per-helper `parse_nested_meta` pass to ignore keys that
9/// belong to other helpers without leaving an unparsed `= "..."` in the stream.
10pub(crate) fn consume_unknown_meta_value(meta: &syn::meta::ParseNestedMeta) -> syn::Result<()> {
11    if meta.input.peek(syn::Token![=]) {
12        let _ = meta.value()?.parse::<syn::Expr>()?;
13    }
14    Ok(())
15}
16
17use proc_macro::TokenStream;
18use quote::quote;
19use syn::{parse_macro_input, Data, DeriveInput, Fields, LitStr};
20
21/// Derive `FromRow` for structs with named fields.
22///
23/// # Field attributes
24///
25/// - `#[from_row(rename = "col")]` — use a different column name
26/// - `#[from_row(skip)]` — skip the field, use `Default::default()`
27/// - `#[from_row(default)]` — use `Default::default()` if column is NULL or missing
28/// - `#[from_row(json)]` — deserialize a JSON/JSONB column via serde
29/// - `#[from_row(try_from = "SourceType")]` — decode as SourceType, then `TryFrom` convert
30/// - `#[from_row(flatten)]` — call `FromRow::from_row` on a nested struct
31#[proc_macro_derive(FromRow, attributes(from_row))]
32pub fn derive_from_row(input: TokenStream) -> TokenStream {
33    let input = parse_macro_input!(input as DeriveInput);
34    match derive_from_row_inner(input) {
35        Ok(tokens) => tokens.into(),
36        Err(err) => err.to_compile_error().into(),
37    }
38}
39
40fn derive_from_row_inner(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
41    let name = &input.ident;
42    let generics = &input.generics;
43    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
44
45    let fields = match &input.data {
46        Data::Struct(data) => match &data.fields {
47            Fields::Named(fields) => &fields.named,
48            _ => {
49                return Err(syn::Error::new_spanned(
50                    &input,
51                    "FromRow only supports structs with named fields",
52                ));
53            }
54        },
55        _ => {
56            return Err(syn::Error::new_spanned(
57                &input,
58                "FromRow only supports structs",
59            ));
60        }
61    };
62
63    let field_extractions = fields
64        .iter()
65        .map(|field| {
66            let field_name = field.ident.as_ref().unwrap();
67            let field_type = &field.ty;
68            let attrs = FromRowFieldAttrs::parse(field)?;
69
70            let col_name = attrs.rename.unwrap_or_else(|| field_name.to_string());
71
72            if attrs.skip {
73                return Ok(quote! { #field_name: Default::default() });
74            }
75
76            if attrs.flatten {
77                return Ok(quote! {
78                    #field_name: <#field_type as resolute::FromRow>::from_row(row)?
79                });
80            }
81
82            if let Some(ref source_type) = attrs.try_from {
83                if is_option_type(field_type) {
84                    return Ok(quote! {
85                        #field_name: {
86                            let __opt: Option<#source_type> = row.get_opt_by_name(#col_name)?;
87                            match __opt {
88                                Some(__src) => Some(
89                                    <_ as std::convert::TryFrom<#source_type>>::try_from(__src)
90                                        .map_err(|e| resolute::TypedError::Decode {
91                                            column: 0,
92                                            message: format!("try_from({}): {}", #col_name, e),
93                                        })?
94                                ),
95                                None => None,
96                            }
97                        }
98                    });
99                } else {
100                    return Ok(quote! {
101                        #field_name: {
102                            let __src: #source_type = row.get_by_name(#col_name)?;
103                            <#field_type as std::convert::TryFrom<#source_type>>::try_from(__src)
104                                .map_err(|e| resolute::TypedError::Decode {
105                                    column: 0,
106                                    message: format!("try_from({}): {}", #col_name, e),
107                                })?
108                        }
109                    });
110                }
111            }
112
113            if attrs.json {
114                if is_option_type(field_type) {
115                    return Ok(quote! {
116                        #field_name: {
117                            let __opt: Option<serde_json::Value> = row.get_opt_by_name(#col_name)?;
118                            match __opt {
119                                Some(__v) => Some(
120                                    serde_json::from_value(__v).map_err(|e| resolute::TypedError::Decode {
121                                        column: 0,
122                                        message: format!("json({}): {}", #col_name, e),
123                                    })?
124                                ),
125                                None => None,
126                            }
127                        }
128                    });
129                } else {
130                    return Ok(quote! {
131                        #field_name: {
132                            let __v: serde_json::Value = row.get_by_name(#col_name)?;
133                            serde_json::from_value(__v).map_err(|e| resolute::TypedError::Decode {
134                                column: 0,
135                                message: format!("json({}): {}", #col_name, e),
136                            })?
137                        }
138                    });
139                }
140            }
141
142            if attrs.default {
143                if is_option_type(field_type) {
144                    return Ok(quote! {
145                        #field_name: if row.has_column(#col_name) {
146                            row.get_opt_by_name(#col_name)?
147                        } else {
148                            None
149                        }
150                    });
151                } else {
152                    return Ok(quote! {
153                        #field_name: if row.has_column(#col_name) {
154                            match row.get_by_name(#col_name) {
155                                Ok(v) => v,
156                                Err(resolute::TypedError::UnexpectedNull(_)) => Default::default(),
157                                Err(e) => return Err(e),
158                            }
159                        } else {
160                            Default::default()
161                        }
162                    });
163                }
164            }
165
166            // Normal field, no special attributes.
167            if is_option_type(field_type) {
168                Ok(quote! { #field_name: row.get_opt_by_name(#col_name)? })
169            } else {
170                Ok(quote! { #field_name: row.get_by_name(#col_name)? })
171            }
172        })
173        .collect::<syn::Result<Vec<_>>>()?;
174
175    Ok(quote! {
176        impl #impl_generics resolute::FromRow for #name #ty_generics #where_clause {
177            fn from_row(row: &resolute::Row) -> Result<Self, resolute::TypedError> {
178                Ok(Self {
179                    #(#field_extractions,)*
180                })
181            }
182        }
183    })
184}
185
186// ---------------------------------------------------------------------------
187// FromRow attribute parsing
188// ---------------------------------------------------------------------------
189
190/// Parsed attributes for a single field in `#[derive(FromRow)]`.
191struct FromRowFieldAttrs {
192    rename: Option<String>,
193    skip: bool,
194    default: bool,
195    json: bool,
196    try_from: Option<syn::Type>,
197    flatten: bool,
198}
199
200impl FromRowFieldAttrs {
201    fn parse(field: &syn::Field) -> syn::Result<Self> {
202        let mut attrs = Self {
203            rename: None,
204            skip: false,
205            default: false,
206            json: false,
207            try_from: None,
208            flatten: false,
209        };
210
211        for attr in &field.attrs {
212            if !attr.path().is_ident("from_row") {
213                continue;
214            }
215            attr.parse_nested_meta(|meta| {
216                if meta.path.is_ident("rename") {
217                    let value = meta.value()?;
218                    let s: LitStr = value.parse()?;
219                    attrs.rename = Some(s.value());
220                } else if meta.path.is_ident("skip") {
221                    attrs.skip = true;
222                } else if meta.path.is_ident("default") {
223                    attrs.default = true;
224                } else if meta.path.is_ident("json") {
225                    attrs.json = true;
226                } else if meta.path.is_ident("try_from") {
227                    let value = meta.value()?;
228                    let s: LitStr = value.parse()?;
229                    let ty: syn::Type = syn::parse_str(&s.value()).map_err(|e| {
230                        syn::Error::new(
231                            s.span(),
232                            format!("from_row(try_from = \"...\") must be a valid Rust type: {e}"),
233                        )
234                    })?;
235                    attrs.try_from = Some(ty);
236                } else if meta.path.is_ident("flatten") {
237                    attrs.flatten = true;
238                } else {
239                    return Err(meta.error("unknown from_row attribute"));
240                }
241                Ok(())
242            })?;
243        }
244
245        // Validate incompatible combinations.
246        if attrs.skip
247            && (attrs.rename.is_some()
248                || attrs.default
249                || attrs.json
250                || attrs.try_from.is_some()
251                || attrs.flatten)
252        {
253            return Err(syn::Error::new_spanned(
254                field,
255                "from_row(skip) cannot be combined with other attributes",
256            ));
257        }
258        if attrs.flatten && (attrs.rename.is_some() || attrs.json || attrs.try_from.is_some()) {
259            return Err(syn::Error::new_spanned(
260                field,
261                "from_row(flatten) cannot be combined with rename, json, or try_from",
262            ));
263        }
264        if attrs.json && attrs.try_from.is_some() {
265            return Err(syn::Error::new_spanned(
266                field,
267                "from_row(json) cannot be combined with try_from",
268            ));
269        }
270
271        Ok(attrs)
272    }
273}
274
275/// Check if a type is `Option<T>`.
276fn is_option_type(ty: &syn::Type) -> bool {
277    if let syn::Type::Path(type_path) = ty {
278        if let Some(seg) = type_path.path.segments.last() {
279            return seg.ident == "Option";
280        }
281    }
282    false
283}
284
285/// Derive `Encode`, `Decode`, `DecodeText`, and `PgType` for a Rust enum
286/// representing a PostgreSQL enum type.
287#[proc_macro_derive(PgEnum, attributes(pg_type))]
288pub fn derive_pg_enum(input: TokenStream) -> TokenStream {
289    let input = parse_macro_input!(input as DeriveInput);
290    pg_enum::derive(input)
291}
292
293/// Derive `Encode`, `Decode`, `DecodeText`, and `PgType` for a Rust struct
294/// representing a PostgreSQL composite type.
295#[proc_macro_derive(PgComposite, attributes(pg_type))]
296pub fn derive_pg_composite(input: TokenStream) -> TokenStream {
297    let input = parse_macro_input!(input as DeriveInput);
298    pg_composite::derive(input)
299}
300
301/// Derive `Encode`, `Decode`, `DecodeText`, and `PgType` for a newtype struct
302/// representing a PostgreSQL domain type.
303#[proc_macro_derive(PgDomain, attributes(pg_type))]
304pub fn derive_pg_domain(input: TokenStream) -> TokenStream {
305    let input = parse_macro_input!(input as DeriveInput);
306    pg_domain::derive(input)
307}
308
309/// Attribute macro for database-backed tests.
310///
311/// Creates a temporary database, optionally runs migrations, provides a
312/// `Client` argument, and drops the database after the test completes.
313///
314/// ```ignore
315/// #[resolute::test]
316/// async fn my_test(client: resolute::Client) {
317///     client.simple_query("CREATE TABLE t (id int)").await.unwrap();
318///     client.execute("INSERT INTO t VALUES ($1)", &[&1i32]).await.unwrap();
319/// }
320///
321/// #[resolute::test(migrations = "migrations")]
322/// async fn with_migrations(client: resolute::Client) {
323///     // migrations have already been applied
324/// }
325/// ```
326#[proc_macro_attribute]
327pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
328    let input_fn = parse_macro_input!(item as syn::ItemFn);
329
330    let mut migrations: Option<String> = None;
331    let attr_parser = syn::meta::parser(|meta| {
332        if meta.path.is_ident("migrations") {
333            let value = meta.value()?;
334            let s: LitStr = value.parse()?;
335            migrations = Some(s.value());
336            Ok(())
337        } else {
338            Err(meta.error("unknown resolute::test attribute"))
339        }
340    });
341    parse_macro_input!(attr with attr_parser);
342
343    let fn_name = &input_fn.sig.ident;
344    let fn_block = &input_fn.block;
345    let fn_vis = &input_fn.vis;
346    let fn_attrs = &input_fn.attrs;
347
348    let create_db = if let Some(mig_path) = &migrations {
349        quote! {
350            let __test_db = resolute::test_db::TestDb::create_with_migrations(
351                &__addr, &__user, &__pass, #mig_path,
352            ).await.expect("failed to create test database");
353        }
354    } else {
355        quote! {
356            let __test_db = resolute::test_db::TestDb::create(
357                &__addr, &__user, &__pass,
358            ).await.expect("failed to create test database");
359        }
360    };
361
362    let expanded = quote! {
363        #(#fn_attrs)*
364        #[tokio::test]
365        #fn_vis async fn #fn_name() {
366            // Read RESOLUTE_TEST_{ADDR,USER,PASSWORD} via the test-db helper
367            // so the macro matches `TestDb::create` and the documented env
368            // var names. Defaults: 127.0.0.1:54322 / postgres / postgres.
369            let __addr = resolute::test_db::test_addr().to_string();
370            let __user = resolute::test_db::test_user().to_string();
371            let __pass = resolute::test_db::test_password().to_string();
372
373            #create_db
374
375            let client = __test_db.client().await.expect("failed to connect to test database");
376
377            // Run the user's test body.
378            let __result = async { #fn_block }.await;
379
380            // Cleanup: drop the test database.
381            drop(client);
382            let _ = __test_db.drop_db().await;
383        }
384    };
385
386    TokenStream::from(expanded)
387}