Skip to main content

hyperdb_api_derive/
lib.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Procedural macros for `hyperdb-api`.
5//!
//! Exposes `#[derive(FromRow)]`, which generates an
//! [`hyperdb_api::FromRow`] impl for a struct by mapping each field
//! to a column with the matching name, along with `#[derive(Table)]`
//! and the `query_as!` / `query_scalar!` function-like macros.
9//!
10//! Re-exported by `hyperdb-api` so callers don't need to add this
11//! crate as a direct dependency. Use it as `use hyperdb_api::FromRow;`
12//! and `#[derive(FromRow)]` on a struct.
13//!
14//! # Example
15//!
16//! ```ignore
17//! use hyperdb_api::FromRow;
18//!
19//! #[derive(FromRow)]
20//! struct User {
21//!     id: i32,
22//!     name: String,
23//!     // Map to a different column name with `rename`:
24//!     #[hyperdb(rename = "email_address")]
25//!     email: Option<String>,
26//! }
27//! ```
28//!
29//! # Attributes
30//!
31//! - `#[hyperdb(rename = "...")]` on a field uses the given column
32//!   name instead of the field name.
33//! - `#[hyperdb(index = N)]` on a field uses positional access
34//!   ([`RowAccessor::position`] / [`RowAccessor::position_opt`]) at
35//!   column index `N` instead of name-based lookup. Mutually exclusive
36//!   with `rename`.
37//! - Field types of `Option<T>` use [`RowAccessor::get_opt`] /
38//!   [`RowAccessor::position_opt`] (NULL → `None`); other field types
39//!   use [`RowAccessor::get`] / [`RowAccessor::position`] (NULL →
40//!   error).
41//!
42//! [`hyperdb_api::FromRow`]: https://docs.rs/hyperdb-api
43//! [`RowAccessor::get_opt`]: https://docs.rs/hyperdb-api
44//! [`RowAccessor::get`]: https://docs.rs/hyperdb-api
45//! [`RowAccessor::position`]: https://docs.rs/hyperdb-api
46//! [`RowAccessor::position_opt`]: https://docs.rs/hyperdb-api
47
48mod table_derive;
49
50use proc_macro::TokenStream;
51use proc_macro2::TokenStream as TokenStream2;
52use quote::quote;
53use syn::{
54    parse_macro_input, spanned::Spanned, Data, DataStruct, DeriveInput, Field, Fields,
55    GenericArgument, LitInt, LitStr, PathArguments, Type, TypePath,
56};
57
/// How a field maps to a column. Either by name (the default or
/// `#[hyperdb(rename = "...")]`) or by ordinal position
/// (`#[hyperdb(index = N)]`).
enum FieldSource {
    /// Column looked up by this name; the generated code calls
    /// `row.get(name)` / `row.get_opt(name)`.
    Name(String),
    /// Column read at this position (the `N` of `#[hyperdb(index = N)]`);
    /// the generated code calls `row.position(N)` / `row.position_opt(N)`.
    Index(usize),
}
65
66/// Derives `hyperdb_api::Table` for a struct.
67///
68/// Generates `impl Table` with `NAME` and `CREATE_SQL` consts. When the
69/// `compile-time` cargo feature is enabled and `#[hyperdb(register)]` is
70/// present, also registers the table with the compile-time validator.
71///
72/// # Attributes (struct level)
73///
74/// - `#[hyperdb(table = "name")]` — override the SQL table name (default:
75///   lower_snake_case of the struct ident).
76/// - `#[hyperdb(register)]` — register for compile-time `query_as!` validation.
77///
78/// # Attributes (field level)
79///
80/// - `#[hyperdb(primary_key)]` — marks the column as NOT NULL (always true
81///   for non-`Option` fields, but documents intent).
82/// - `#[hyperdb(rename = "col")]` — use a different SQL column name.
83#[proc_macro_derive(Table, attributes(hyperdb))]
84pub fn table_derive(input: TokenStream) -> TokenStream {
85    let input = parse_macro_input!(input as DeriveInput);
86    match table_derive::expand(&input) {
87        Ok(ts) => ts.into(),
88        Err(e) => e.to_compile_error().into(),
89    }
90}
91
92/// Compile-time validated typed query macro.
93///
94/// Syntax: `query_as!(Type, "SQL")` or `query_as!(Type, "SQL", arg1, arg2, …)`
95///
96/// Returns a [`hyperdb_api::QueryAs<Type>`] builder. `Type` must implement
97/// [`hyperdb_api::FromRow`] and must be registered via
98/// `#[derive(Table)] #[hyperdb(register)]`.
99///
100/// With the `compile-time` cargo feature enabled, validates at build time that
101/// the SQL is syntactically valid, all referenced tables are registered, and
102/// all struct fields appear in the projected columns.
103///
104/// # Module ordering constraint (`compile-time` feature)
105///
106/// Registration happens at proc-macro expansion time in the proc-macro host
107/// process. Rust expands macros in the order modules are declared in `mod`
108/// statements (top-to-bottom in `lib.rs`/`main.rs`). If `derive(Table)` and
109/// `query_as!` are in different modules, the module containing `derive(Table)`
110/// structs **must be declared (via `mod`) before** the module containing
111/// `query_as!` calls, otherwise a false `StructNotRegistered` compile error
112/// is emitted.
113///
114/// Within a single file, struct-level derives always expand before
115/// function-body macros, so ordering within a file is not a concern.
116#[proc_macro]
117pub fn query_as(input: TokenStream) -> TokenStream {
118    match expand_query_as(&input.into()) {
119        Ok(ts) => ts.into(),
120        Err(e) => e.to_compile_error().into(),
121    }
122}
123
124fn expand_query_as(input: &TokenStream2) -> syn::Result<TokenStream2> {
125    use syn::{parse::Parser, punctuated::Punctuated, Expr, Token};
126
127    // Parse: Type, "sql_literal" [, expr, expr, ...]
128    let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
129    let args = parser.parse2(input.clone())?;
130    let mut iter = args.iter();
131
132    let ty_expr = iter.next().ok_or_else(|| {
133        syn::Error::new_spanned(
134            input,
135            "query_as! expects at least two arguments: query_as!(Type, \"SQL\")",
136        )
137    })?;
138
139    // Re-parse the first token as a type (not an expression).
140    let ty: Type = syn::parse2(quote!(#ty_expr))?;
141
142    let sql_expr = iter.next().ok_or_else(|| {
143        syn::Error::new_spanned(
144            ty_expr,
145            "query_as! expects a SQL string literal as the second argument",
146        )
147    })?;
148
149    // Remaining args are the bind parameters.
150    let rest: Vec<&Expr> = iter.collect();
151
152    // Compile-time validation: runs inside the proc-macro host at expansion time.
153    // The `compile-time` feature gates this — without it the macro is a
154    // pure pass-through with zero overhead. The variables are extracted inside
155    // the cfg block to avoid unused-variable warnings in the feature-off build.
156    #[cfg(feature = "compile-time")]
157    {
158        let struct_name = last_type_ident(&ty).map(ToString::to_string);
159        let sql_lit: Option<LitStr> = syn::parse2(quote!(#sql_expr)).ok();
160        if let (Some(struct_name), Some(sql_lit)) = (struct_name, sql_lit) {
161            let sql_str = sql_lit.value();
162            if let Err(e) = hyperdb_compile_check::validate_query_as(&struct_name, &sql_str) {
163                let msg = e.to_diagnostic();
164                return Ok(quote! {
165                    ::std::compile_error!(#msg)
166                });
167            }
168        }
169    }
170
171    Ok(quote! {
172        ::hyperdb_api::QueryAs::<#ty>::new(#sql_expr, &[#(&#rest),*])
173    })
174}
175
/// Returns the identifier of the last path segment of `ty` (e.g. `User` for
/// `crate::models::User`). `expand_query_as` uses it as the registry key for
/// compile-time validation.
///
/// Invisible `Type::Group` wrappers and parentheses are looked through first.
/// `syn` produces `Type::Group` when a type was passed through a
/// `macro_rules!` `$t:ty` fragment; looking through it means such invocations
/// are validated rather than silently skipped.
///
/// Returns `None` for anything that is not a plain path type, e.g.
/// references, tuples, or qualified `<T as Trait>::X` paths.
///
/// Only compiled with the `compile-time` feature, the only configuration in
/// which it is needed.
#[cfg(feature = "compile-time")]
fn last_type_ident(ty: &Type) -> Option<&syn::Ident> {
    match ty {
        Type::Path(syn::TypePath { path, qself: None }) => {
            path.segments.last().map(|segment| &segment.ident)
        }
        Type::Group(group) => last_type_ident(&group.elem),
        Type::Paren(paren) => last_type_ident(&paren.elem),
        _ => None,
    }
}
185
186/// Validated single-column query macro.
187///
188/// Syntax: `query_scalar!(Type, "SQL")` or `query_scalar!(Type, "SQL", arg1, …)`
189///
190/// Returns a [`hyperdb_api::QueryScalar<Type>`] builder. `Type` must implement
191/// [`hyperdb_api::RowValue`]. No `derive(Table)` is required — scalars project
192/// a single column and don't map to a struct.
193///
194/// With the `compile-time` feature enabled, validates at build time that the
195/// SQL returns exactly one column.
196#[proc_macro]
197pub fn query_scalar(input: TokenStream) -> TokenStream {
198    match expand_query_scalar(&input.into()) {
199        Ok(ts) => ts.into(),
200        Err(e) => e.to_compile_error().into(),
201    }
202}
203
204fn expand_query_scalar(input: &TokenStream2) -> syn::Result<TokenStream2> {
205    use syn::{parse::Parser, punctuated::Punctuated, Expr, Token};
206
207    let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
208    let args = parser.parse2(input.clone())?;
209    let mut iter = args.iter();
210
211    let ty_expr = iter.next().ok_or_else(|| {
212        syn::Error::new_spanned(
213            input,
214            "query_scalar! expects at least two arguments: query_scalar!(Type, \"SQL\")",
215        )
216    })?;
217
218    let ty: Type = syn::parse2(quote!(#ty_expr))?;
219
220    let sql_expr = iter.next().ok_or_else(|| {
221        syn::Error::new_spanned(
222            ty_expr,
223            "query_scalar! expects a SQL string literal as the second argument",
224        )
225    })?;
226
227    let rest: Vec<&Expr> = iter.collect();
228
229    // Compile-time validation: verify the SQL returns exactly one column.
230    #[cfg(feature = "compile-time")]
231    {
232        let sql_lit: Option<LitStr> = syn::parse2(quote!(#sql_expr)).ok();
233        if let Some(sql_lit) = sql_lit {
234            let sql_str = sql_lit.value();
235            // Validate SQL structure (syntax + table existence) using a dummy
236            // struct name that won't be in the registry — we only care about
237            // one-column check, not struct-field matching.
238            match hyperdb_compile_check::validate_scalar_sql(&sql_str) {
239                Ok(()) => {}
240                Err(e) => {
241                    let msg = e.to_diagnostic();
242                    return Ok(quote! { ::std::compile_error!(#msg) });
243                }
244            }
245        }
246    }
247
248    Ok(quote! {
249        ::hyperdb_api::QueryScalar::<#ty>::new(#sql_expr, &[#(&#rest),*])
250    })
251}
252
253/// Derives `hyperdb_api::FromRow` for a struct.
254///
255/// See the crate-level documentation for the full feature list.
256#[proc_macro_derive(FromRow, attributes(hyperdb))]
257pub fn from_row_derive(input: TokenStream) -> TokenStream {
258    let input = parse_macro_input!(input as DeriveInput);
259    match expand(&input) {
260        Ok(ts) => ts.into(),
261        Err(e) => e.to_compile_error().into(),
262    }
263}
264
265fn expand(input: &DeriveInput) -> syn::Result<TokenStream2> {
266    let name = &input.ident;
267    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
268
269    let fields = match &input.data {
270        Data::Struct(DataStruct {
271            fields: Fields::Named(named),
272            ..
273        }) => &named.named,
274        Data::Struct(_) => {
275            return Err(syn::Error::new_spanned(
276                &input.ident,
277                "FromRow can only be derived on structs with named fields",
278            ));
279        }
280        Data::Enum(_) => {
281            return Err(syn::Error::new_spanned(
282                &input.ident,
283                "FromRow cannot be derived on enums",
284            ));
285        }
286        Data::Union(_) => {
287            return Err(syn::Error::new_spanned(
288                &input.ident,
289                "FromRow cannot be derived on unions",
290            ));
291        }
292    };
293
294    let assignments = fields
295        .iter()
296        .map(field_assignment)
297        .collect::<syn::Result<Vec<_>>>()?;
298
299    Ok(quote! {
300        #[automatically_derived]
301        impl #impl_generics ::hyperdb_api::FromRow for #name #ty_generics #where_clause {
302            fn from_row(
303                row: ::hyperdb_api::RowAccessor<'_>,
304            ) -> ::hyperdb_api::Result<Self> {
305                Ok(Self {
306                    #(#assignments),*
307                })
308            }
309        }
310    })
311}
312
313/// Generates `field_name: row.get("col")?` (or `get_opt`/`position`/`position_opt`
314/// for `Option<T>` fields and/or `#[hyperdb(index = N)]`).
315fn field_assignment(field: &Field) -> syn::Result<TokenStream2> {
316    let ident = field
317        .ident
318        .as_ref()
319        .ok_or_else(|| syn::Error::new_spanned(field, "tuple-struct fields are not supported"))?;
320    let source = field_source_for(field, ident)?;
321    let is_opt = is_option_type(&field.ty);
322
323    let getter = match (source, is_opt) {
324        (FieldSource::Name(name), true) => {
325            let lit = LitStr::new(&name, ident.span());
326            quote!(row.get_opt(#lit)?)
327        }
328        (FieldSource::Name(name), false) => {
329            let lit = LitStr::new(&name, ident.span());
330            quote!(row.get(#lit)?)
331        }
332        (FieldSource::Index(idx), true) => quote!(row.position_opt(#idx)?),
333        (FieldSource::Index(idx), false) => quote!(row.position(#idx)?),
334    };
335
336    Ok(quote! { #ident: #getter })
337}
338
339/// Reads `#[hyperdb(rename = "...")]` or `#[hyperdb(index = N)]` from a field's
340/// attributes. Falls back to a name-based source using the field's identifier.
341/// `rename` and `index` are mutually exclusive.
342fn field_source_for(field: &Field, default: &syn::Ident) -> syn::Result<FieldSource> {
343    let mut rename: Option<(String, proc_macro2::Span)> = None;
344    let mut index: Option<(usize, proc_macro2::Span)> = None;
345
346    for attr in &field.attrs {
347        if !attr.path().is_ident("hyperdb") {
348            continue;
349        }
350        attr.parse_nested_meta(|meta| {
351            if meta.path.is_ident("rename") {
352                let s: LitStr = meta.value()?.parse()?;
353                rename = Some((s.value(), meta.path.span()));
354                Ok(())
355            } else if meta.path.is_ident("index") {
356                let n: LitInt = meta.value()?.parse()?;
357                let parsed: usize = n.base10_parse()?;
358                index = Some((parsed, meta.path.span()));
359                Ok(())
360            } else if meta.path.is_ident("primary_key") {
361                // Table-derive attribute; silently ignored by FromRow.
362                Ok(())
363            } else {
364                Err(meta.error(format!(
365                    "unrecognized hyperdb attribute `{}`; supported attributes: rename, index",
366                    meta.path
367                        .get_ident()
368                        .map_or_else(|| "?".to_string(), ToString::to_string)
369                )))
370            }
371        })?;
372    }
373
374    match (rename, index) {
375        (Some(_), Some((_, idx_span))) => Err(syn::Error::new(
376            idx_span,
377            "`#[hyperdb(rename = ...)]` and `#[hyperdb(index = N)]` are mutually exclusive",
378        )),
379        (Some((name, _)), None) => Ok(FieldSource::Name(name)),
380        (None, Some((idx, _))) => Ok(FieldSource::Index(idx)),
381        (None, None) => Ok(FieldSource::Name(default.to_string())),
382    }
383}
384
385/// Detects `Option<T>` (any path ending in `Option<T>`).
386fn is_option_type(ty: &Type) -> bool {
387    let Type::Path(TypePath { path, qself: None }) = ty else {
388        return false;
389    };
390    let Some(last) = path.segments.last() else {
391        return false;
392    };
393    if last.ident != "Option" {
394        return false;
395    }
396    matches!(
397        last.arguments,
398        PathArguments::AngleBracketed(ref args)
399            if matches!(args.args.first(), Some(GenericArgument::Type(_)))
400    )
401}