Skip to main content

bsql_macros/
lib.rs

1#![forbid(unsafe_code)]
2
3//! Proc macros for bsql.
4//!
5//! This crate is an implementation detail. Use [`bsql`] instead.
6
7extern crate proc_macro;
8
9mod codegen;
10#[cfg(feature = "sqlite")]
11mod codegen_sqlite;
12mod connection;
13mod dynamic;
14mod offline;
15mod parse;
16mod pg_enum;
17mod sort_enum;
18mod sql_norm;
19mod stmt_name;
20mod suggest;
21pub(crate) mod types;
22#[cfg(feature = "sqlite")]
23mod types_sqlite;
24mod validate;
25#[cfg(feature = "sqlite")]
26mod validate_sqlite;
27
28use proc_macro::TokenStream;
29
30/// Validate a SQL query against PostgreSQL at compile time and generate
31/// typed Rust code for executing it.
32///
33/// # Syntax
34///
35/// ```text
36/// bsql::query! {
37///     SELECT column1, column2
38///     FROM table
39///     WHERE column1 = $param_name: RustType
40/// }
41/// ```
42///
43/// Parameters are declared inline as `$name: Type`. The macro replaces them
44/// with positional `$1`, `$2`, ... and verifies type compatibility against
45/// the database schema.
46///
47/// # Execution methods
48///
49/// The macro returns an executor with these methods:
50/// - `.fetch_one(executor)` — returns exactly one row (errors on 0 or 2+)
51/// - `.fetch_all(executor)` — returns all rows as `Vec<T>`
52/// - `.fetch_optional(executor)` — returns `Option<T>` (errors on 2+)
53/// - `.execute(executor)` — returns affected row count (`u64`)
54///
55/// # Compile-time guarantees
56///
57/// - Table and column names are verified against the live database
58/// - Parameter types are checked against PostgreSQL's expected types
59/// - Nullable columns are automatically mapped to `Option<T>`
60/// - Invalid SQL produces a compile error, not a runtime error
61#[proc_macro]
62pub fn query(input: TokenStream) -> TokenStream {
63    let input2: proc_macro2::TokenStream = input.into();
64    match query_impl(input2) {
65        Ok(output) => output.into(),
66        Err(err) => err.to_compile_error().into(),
67    }
68}
69
70fn query_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
71    // Extract the SQL string from the input.
72    // Accepts either a string literal: query!("SELECT ...")
73    // or raw tokens: query! { SELECT ... } converted to string.
74    let sql = extract_sql(input)?;
75
76    // 1. Parse: extract params, query kind, normalize SQL, optional clauses, sort placeholder
77    let parsed = parse::parse_query(&sql)
78        .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
79
80    // Detect backend from database URL (if not offline)
81    #[cfg(feature = "sqlite")]
82    {
83        let backend = connection::detect_backend()
84            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
85        if backend == Some(connection::Backend::Sqlite) {
86            return query_impl_sqlite(parsed);
87        }
88    }
89
90    // PostgreSQL path (default)
91    query_impl_postgres(parsed)
92}
93
94/// PostgreSQL query implementation (the original path).
95fn query_impl_postgres(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
96    // 2. Sort query path — $[sort: EnumType] present
97    if parsed.sort_placeholder.is_some() {
98        return query_impl_sort(parsed);
99    }
100
101    // 3. Expand dynamic query variants (if any optional clauses)
102    let variants = dynamic::expand_variants(&parsed)
103        .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
104
105    if parsed.optional_clauses.is_empty() {
106        // Static query path — no optional clauses
107        let validation = if offline::is_offline() {
108            // OFFLINE: read cached validation result
109            offline::lookup_cached_validation(&parsed)
110                .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
111        } else {
112            // ONLINE: validate against PostgreSQL via PREPARE with suggestions
113            let result = connection::with_connection(|conn| {
114                validate::validate_query_with_suggestions(&parsed, conn)
115            })?;
116
117            // Write to offline cache for future use
118            offline::write_cache(&parsed, &result);
119
120            result
121        };
122
123        // Check parameter type compatibility
124        validate::check_param_types(&parsed, &validation)
125            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
126
127        // Generate Rust code
128        Ok(codegen::generate_query_code(&parsed, &validation))
129    } else {
130        // Dynamic query path — has optional clauses
131        let validation = if offline::is_offline() {
132            // OFFLINE: read cached validation result for the base variant.
133            //
134            // The cache stores variant 0's param_pg_oids, which only covers
135            // the base params (not optional clause params). Param type
136            // checking is skipped here because:
137            //  1. The online build already validated ALL variants' param types.
138            //  2. The cached columns are identical across all variants (the
139            //     SELECT list never changes, only WHERE clauses differ).
140            //  3. Codegen only needs the column info, not per-variant param OIDs.
141            offline::lookup_cached_validation(&parsed)
142                .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
143        } else {
144            // ONLINE: validate ALL variants against PostgreSQL and check param types
145            let result = connection::with_connection(|conn| {
146                validate::validate_variants(&variants, &parsed, conn)
147            })?;
148
149            // Write to offline cache for future use
150            offline::write_cache(&parsed, &result);
151
152            result
153        };
154
155        // Generate dynamic Rust code with match dispatcher
156        Ok(codegen::generate_dynamic_query_code(
157            &parsed,
158            &validation,
159            &variants,
160        ))
161    }
162}
163
/// SQLite implementation of `query!`.
///
/// Validates the query against a live SQLite database at compile time (or
/// reads the cached result when offline), then generates code that executes
/// via `bsql_core::SqlitePool`. Queries with a sort placeholder are
/// delegated to `query_impl_sqlite_sort`.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
    // Turns a plain message into a `syn::Error` located at the macro call site.
    fn call_site_err<M: std::fmt::Display>(msg: M) -> syn::Error {
        syn::Error::new(proc_macro2::Span::call_site(), msg)
    }

    // `$[sort: EnumType]` present: handled by the dedicated sort path.
    if parsed.sort_placeholder.is_some() {
        return query_impl_sqlite_sort(parsed);
    }

    // Expand the dynamic query variants (relevant when optional clauses exist).
    let variants = dynamic::expand_variants(&parsed).map_err(call_site_err)?;

    if !parsed.optional_clauses.is_empty() {
        // Dynamic query: at least one optional clause.
        let validation = if offline::is_offline() {
            offline::lookup_cached_validation(&parsed).map_err(call_site_err)?
        } else {
            // Every variant is validated against SQLite, then the result is
            // cached for offline builds.
            let fresh = connection::with_sqlite_connection(|conn| {
                validate_sqlite::validate_variants_sqlite(&variants, &parsed, conn)
            })?;
            offline::write_cache(&parsed, &fresh);
            fresh
        };

        let tokens =
            codegen_sqlite::generate_dynamic_sqlite_query_code(&parsed, &validation, &variants);
        return Ok(tokens);
    }

    // Static query: no optional clauses.
    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed).map_err(call_site_err)?
    } else {
        let fresh = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&parsed, conn)
        })?;
        // Keep the offline cache up to date for future builds.
        offline::write_cache(&parsed, &fresh);
        fresh
    };

    // No PG-style parameter type check here: SQLite does not type parameters
    // at prepare time, so parameter types are verified at runtime by the
    // SqliteEncode trait instead.
    Ok(codegen_sqlite::generate_sqlite_query_code(&parsed, &validation))
}
226
/// SQLite implementation for queries containing `$[sort: EnumType]`.
///
/// The query shape is validated with the sort sentinel replaced by `1`;
/// code generation still receives the original parsed query together with
/// the sort enum's name.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite_sort(
    parsed: parse::ParsedQuery,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    // Turns a plain message into a `syn::Error` located at the macro call site.
    fn call_site_err<M: std::fmt::Display>(msg: M) -> syn::Error {
        syn::Error::new(proc_macro2::Span::call_site(), msg)
    }

    // `query_impl_sqlite` only routes here after checking that the sort
    // placeholder is present, so a `None` here is an internal bug.
    let sort_enum_name = &parsed.sort_placeholder.as_ref().unwrap().enum_name;

    // Stand-in query used only for validation: the `{SORT}` / `{sort}`
    // sentinels are replaced with "1" so the query shape can be checked.
    let probe = parse::ParsedQuery {
        sort_placeholder: None,
        positional_sql: parsed.positional_sql.replace("{SORT}", "1"),
        normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
        statement_name: parsed.statement_name.clone(),
        optional_clauses: parsed.optional_clauses.clone(),
        params: parsed.params.clone(),
        kind: parsed.kind,
    };

    // Cache lookups and writes are keyed on the original query, not the probe.
    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed).map_err(call_site_err)?
    } else {
        let fresh = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&probe, conn)
        })?;
        offline::write_cache(&parsed, &fresh);
        fresh
    };

    let tokens =
        codegen_sqlite::generate_sort_sqlite_query_code(&parsed, &validation, sort_enum_name);
    Ok(tokens)
}
266
267/// Handle sort queries — queries with `$[sort: EnumType]`.
268///
269/// The sort enum is NOT resolved at macro expansion time (we don't have access
270/// to the enum definition from within the proc macro). Instead, we generate code
271/// that takes the sort enum as a parameter and uses `match` to select the SQL.
272///
273/// Validation: we validate each sort variant's expanded SQL at compile time
274/// by reading sort variant info. However, since the sort enum is defined via
275/// `#[bsql::sort]` in user code, we cannot read its variants from within
276/// the `query!` macro. Instead, the generated code uses the enum's `sql()`
277/// method at runtime. Validation of individual sort fragments happens when
278/// the user compiles — the sort enum's SQL fragments are checked by the user
279/// running their tests or by a separate validation step.
280///
281/// For now: generate code that takes a `sort` parameter with a `sql() -> &str`
282/// method, and splices the SQL at runtime via string replacement + pre-hashed
283/// dispatch.
284fn query_impl_sort(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
285    let sort_placeholder = parsed.sort_placeholder.as_ref().unwrap();
286    let sort_enum_name = &sort_placeholder.enum_name;
287
288    // We can't validate sort variants at proc-macro time because we don't have
289    // the enum definition. Instead, generate code that does runtime SQL dispatch.
290    // The `{SORT}` in positional_sql will be a sentinel that codegen handles.
291
292    // For validation, we need at least the base query structure. Use a dummy
293    // ORDER BY to validate the query shape (columns, params) — replace {SORT}
294    // with "1" (which is always valid in ORDER BY).
295    let dummy_sql = parsed.positional_sql.replace("{SORT}", "1");
296
297    // Create a temporary ParsedQuery with the dummy SQL for validation
298    let dummy_parsed = parse::ParsedQuery {
299        normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
300        positional_sql: dummy_sql,
301        params: parsed.params.clone(),
302        kind: parsed.kind,
303        statement_name: parsed.statement_name.clone(),
304        optional_clauses: parsed.optional_clauses.clone(),
305        sort_placeholder: None,
306    };
307
308    let validation = if offline::is_offline() {
309        offline::lookup_cached_validation(&parsed)
310            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
311    } else {
312        let result = connection::with_connection(|conn| {
313            validate::validate_query_with_suggestions(&dummy_parsed, conn)
314        })?;
315
316        offline::write_cache(&parsed, &result);
317        result
318    };
319
320    validate::check_param_types(&parsed, &validation)
321        .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
322
323    // Generate sort-aware code
324    Ok(codegen::generate_sort_query_code(
325        &parsed,
326        &validation,
327        sort_enum_name,
328    ))
329}
330
331/// Extract the SQL text from the macro input.
332///
333/// Accepts a string literal: `query!("SELECT ...")`
334fn extract_sql(input: proc_macro2::TokenStream) -> Result<String, syn::Error> {
335    let lit: syn::LitStr = syn::parse2(input)?;
336    Ok(lit.value())
337}
338
339/// Derive PostgreSQL enum <-> Rust enum mapping with `FromSql` and `ToSql`.
340///
341/// # Usage
342///
343/// ```rust,ignore
344/// #[bsql::pg_enum]
345/// pub enum TicketStatus {
346///     #[sql("new")]
347///     New,
348///     #[sql("in_progress")]
349///     InProgress,
350///     #[sql("resolved")]
351///     Resolved,
352///     #[sql("closed")]
353///     Closed,
354/// }
355/// ```
356///
357/// Each variant must have a `#[sql("label")]` attribute mapping it to the
358/// exact PostgreSQL enum label. The macro generates:
359/// - `FromSql` — deserializes from PostgreSQL text representation
360/// - `ToSql` — serializes to PostgreSQL text representation
361/// - `Display` — formats as the SQL label
362/// - Derives: `Debug, Clone, Copy, PartialEq, Eq, Hash`
363///
364/// If PostgreSQL sends a variant not present in the Rust enum, `FromSql`
365/// returns an error describing the schema mismatch.
366#[proc_macro_attribute]
367pub fn pg_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
368    let attr2: proc_macro2::TokenStream = attr.into();
369    let item2: proc_macro2::TokenStream = item.into();
370    match pg_enum::expand_pg_enum(attr2, item2) {
371        Ok(output) => output.into(),
372        Err(err) => err.to_compile_error().into(),
373    }
374}
375
376/// Define a sort enum for compile-time verified dynamic `ORDER BY` clauses.
377///
378/// # Usage
379///
380/// ```rust,ignore
381/// #[bsql::sort]
382/// pub enum TicketSort {
383///     #[sql("t.updated_at DESC, t.id DESC")]
384///     UpdatedAt,
385///     #[sql("t.deadline ASC NULLS LAST, t.id ASC")]
386///     Deadline,
387///     #[sql("t.id DESC")]
388///     Id,
389/// }
390/// ```
391///
392/// Use with the `$[sort: EnumType]` placeholder in `bsql::query!`:
393///
394/// ```rust,ignore
395/// let tickets = bsql::query!(
396///     "SELECT id, title FROM tickets ORDER BY $[sort: TicketSort] LIMIT $limit: i64"
397/// ).fetch_all(&pool)?;
398/// ```
399///
400/// Each variant must have a `#[sql("...")]` attribute mapping it to the
401/// SQL `ORDER BY` fragment. The macro generates:
402/// - The enum with `Debug, Clone, Copy, PartialEq, Eq, Hash`
403/// - A `sql(&self) -> &'static str` method returning the SQL fragment
404/// - `Display` — formats as the SQL fragment
405///
406/// Unlike `#[bsql::pg_enum]`, sort enums are NOT parameterized values.
407/// The SQL fragment is spliced directly into the query string.
408#[proc_macro_attribute]
409pub fn sort(attr: TokenStream, item: TokenStream) -> TokenStream {
410    let attr2: proc_macro2::TokenStream = attr.into();
411    let item2: proc_macro2::TokenStream = item.into();
412    match sort_enum::expand_sort_enum(attr2, item2) {
413        Ok(output) => output.into(),
414        Err(err) => err.to_compile_error().into(),
415    }
416}