bsql-macros 0.20.0

Proc macros for bsql — compile-time safe SQL for Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
#![forbid(unsafe_code)]

//! Proc macros for bsql.
//!
//! This crate is an implementation detail. Use [`bsql`] instead.

extern crate proc_macro;

mod codegen;
#[cfg(feature = "sqlite")]
mod codegen_sqlite;
mod connection;
mod dynamic;
mod offline;
mod parse;
mod pg_enum;
mod sort_enum;
mod sql_norm;
mod stmt_name;
mod suggest;
pub(crate) mod types;
#[cfg(feature = "sqlite")]
mod types_sqlite;
mod validate;
#[cfg(feature = "sqlite")]
mod validate_sqlite;

use proc_macro::TokenStream;

/// Check a SQL query against the database while compiling and emit typed
/// Rust code that runs it.
///
/// # Syntax
///
/// ```text
/// bsql::query! {
///     SELECT column1, column2
///     FROM table
///     WHERE column1 = $param_name: RustType
/// }
/// ```
///
/// Each parameter is written in place as `$name: Type`. During expansion the
/// named parameters are rewritten to positional `$1`, `$2`, ... and their
/// types are checked for compatibility with the database schema.
///
/// # Execution methods
///
/// The expansion evaluates to an executor offering:
/// - `.fetch_one(executor)` — exactly one row (0 rows or 2+ rows is an error)
/// - `.fetch_all(executor)` — every row, collected into a `Vec<T>`
/// - `.fetch_optional(executor)` — `Option<T>` (2+ rows is an error)
/// - `.execute(executor)` — the number of affected rows (`u64`)
///
/// # Compile-time guarantees
///
/// - Table and column names are checked against the live database
/// - Parameter types are compared with the types PostgreSQL expects
/// - Nullable columns become `Option<T>` automatically
/// - Invalid SQL is reported as a compile error rather than a runtime error
#[proc_macro]
pub fn query(input: TokenStream) -> TokenStream {
    // NOTE(review): `extract_sql` accepts only a single string literal, while
    // the Syntax section above shows brace-delimited raw SQL — confirm which
    // form callers actually reach this macro with.
    //
    // A failed expansion is rendered as `compile_error!` tokens so it is
    // reported by the compiler instead of panicking the macro.
    query_impl(proc_macro2::TokenStream::from(input))
        .map_or_else(|failure| failure.to_compile_error().into(), Into::into)
}

fn query_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
    // Extract the SQL string from the input.
    // Accepts either a string literal: query!("SELECT ...")
    // or raw tokens: query! { SELECT ... } converted to string.
    let sql = extract_sql(input)?;

    // 1. Parse: extract params, query kind, normalize SQL, optional clauses, sort placeholder
    let parsed = parse::parse_query(&sql)
        .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;

    // Detect backend from database URL (if not offline)
    #[cfg(feature = "sqlite")]
    {
        let backend = connection::detect_backend()
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
        if backend == Some(connection::Backend::Sqlite) {
            return query_impl_sqlite(parsed);
        }
    }

    // PostgreSQL path (default)
    query_impl_postgres(parsed)
}

/// PostgreSQL query implementation (the original path).
fn query_impl_postgres(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
    // 2. Sort query path — $[sort: EnumType] present
    if parsed.sort_placeholder.is_some() {
        return query_impl_sort(parsed);
    }

    if parsed.optional_clauses.is_empty() {
        // Static query path — no optional clauses
        let validation = if offline::is_offline() {
            // OFFLINE: read cached validation result
            offline::lookup_cached_validation(&parsed)
                .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
        } else {
            // ONLINE: validate against PostgreSQL via PREPARE with suggestions
            let result = connection::with_connection(|conn| {
                validate::validate_query_with_suggestions(&parsed, conn)
            })?;

            // Write to offline cache for future use
            offline::write_cache(&parsed, &result);

            result
        };

        // Check parameter type compatibility
        validate::check_param_types(&parsed, &validation)
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;

        // Generate Rust code
        Ok(codegen::generate_query_code(&parsed, &validation))
    } else {
        // Dynamic query path — has optional clauses.
        //
        // Validation: O(N+1) PREPAREs — base query + one per clause.
        // Codegen: O(N) runtime SQL builder (no 2^N match arms).
        let validation = if offline::is_offline() {
            // OFFLINE: read cached validation result for the base query.
            //
            // The cache stores the base query's param_pg_oids (not optional
            // clause params). Param type checking is skipped here because:
            //  1. The online build already validated all clauses' param types.
            //  2. The cached columns are identical (SELECT list never changes).
            //  3. Codegen only needs the column info, not per-clause param OIDs.
            offline::lookup_cached_validation(&parsed)
                .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
        } else {
            // ONLINE: full 2^N validation — every combination checked.
            // "If it compiles, the SQL is correct" — no exceptions.
            let result = connection::with_connection(|conn| {
                let variants = dynamic::expand_variants(&parsed)?;
                validate::validate_variants(&variants, &parsed, conn)
            })?;

            // Write to offline cache for future use
            offline::write_cache(&parsed, &result);

            result
        };

        // Generate dynamic Rust code with runtime SQL dispatcher
        Ok(codegen::generate_dynamic_query_code(&parsed, &validation))
    }
}

/// SQLite implementation of `query!`.
///
/// Mirrors the PostgreSQL dispatch: sort queries go to
/// `query_impl_sqlite_sort`, queries with optional clauses get dynamic code,
/// and everything else gets static code. Validation runs against a live
/// SQLite connection at compile time, or comes from the offline cache when
/// offline mode is active.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
    /// Wraps a plain message in a `syn::Error` located at the macro call site.
    fn at_call_site<M: std::fmt::Display>(message: M) -> syn::Error {
        syn::Error::new(proc_macro2::Span::call_site(), message)
    }

    // Sort queries — `$[sort: EnumType]` present — have their own implementation.
    if parsed.sort_placeholder.is_some() {
        return query_impl_sqlite_sort(parsed);
    }

    // Dynamic query path — at least one optional clause.
    if !parsed.optional_clauses.is_empty() {
        let checked = if offline::is_offline() {
            offline::lookup_cached_validation(&parsed).map_err(at_call_site)?
        } else {
            // Expand the optional clauses into variants and validate them on
            // a live connection. The number of statements checked is decided
            // by `dynamic::expand_variants` / `validate_variants_sqlite`, which
            // are not visible here — TODO confirm (per clause vs. per combination).
            let fresh = connection::with_sqlite_connection(|conn| {
                let expanded = dynamic::expand_variants(&parsed)?;
                validate_sqlite::validate_variants_sqlite(&expanded, &parsed, conn)
            })?;

            // Remember the result for future offline builds.
            offline::write_cache(&parsed, &fresh);

            fresh
        };

        return Ok(codegen_sqlite::generate_dynamic_sqlite_query_code(
            &parsed, &checked,
        ));
    }

    // Static query path — no optional clauses.
    let checked = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed).map_err(at_call_site)?
    } else {
        let fresh = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&parsed, conn)
        })?;

        // Remember the result for future offline builds.
        offline::write_cache(&parsed, &fresh);

        fresh
    };

    // No PG-style param type check here: SQLite does not type parameters at
    // prepare time, so parameter types are left to the `SqliteEncode` trait
    // at runtime.
    Ok(codegen_sqlite::generate_sqlite_query_code(&parsed, &checked))
}

/// SQLite implementation for queries that contain a `$[sort: EnumType]`
/// placeholder.
///
/// The sort fragment is not known here, so the query shape is validated with
/// the placeholder replaced by the literal `1`. The unmodified `parsed` query
/// is what gets cached and passed to
/// `codegen_sqlite::generate_sort_sqlite_query_code` together with the sort
/// enum's name.
///
/// # Errors
///
/// Returns a `syn::Error` at the macro call site when `parsed` carries no sort
/// placeholder or when the offline cache lookup fails. Errors from the online
/// validation are propagated as returned by `with_sqlite_connection`.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite_sort(
    parsed: parse::ParsedQuery,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    // The caller dispatches here only after seeing a sort placeholder. Should
    // that invariant ever break, report a compile error instead of panicking
    // inside the proc macro.
    let sort_placeholder = parsed.sort_placeholder.as_ref().ok_or_else(|| {
        syn::Error::new(
            proc_macro2::Span::call_site(),
            "bsql internal error: sort query path entered without a `$[sort: EnumType]` placeholder",
        )
    })?;
    let sort_enum_name = &sort_placeholder.enum_name;

    let validation = if offline::is_offline() {
        // OFFLINE: the cache is keyed on the real query, so no dummy query is
        // needed (and none is built).
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
    } else {
        // ONLINE: validate a stand-in query whose sort placeholder is replaced
        // by `1`. It is built only on this branch because nothing else reads it.
        let dummy_parsed = parse::ParsedQuery {
            normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
            positional_sql: parsed.positional_sql.replace("{SORT}", "1"),
            params: parsed.params.clone(),
            kind: parsed.kind,
            statement_name: parsed.statement_name.clone(),
            optional_clauses: parsed.optional_clauses.clone(),
            sort_placeholder: None,
        };

        let result = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&dummy_parsed, conn)
        })?;

        // Cache under the real query so offline builds can find it.
        offline::write_cache(&parsed, &result);
        result
    };

    Ok(codegen_sqlite::generate_sort_sqlite_query_code(
        &parsed,
        &validation,
        sort_enum_name,
    ))
}

/// Handle sort queries — queries with `$[sort: EnumType]`.
///
/// The sort enum is NOT resolved at macro expansion time (we don't have access
/// to the enum definition from within the proc macro). Instead, we generate code
/// that takes the sort enum as a parameter and uses `match` to select the SQL.
///
/// Validation: we validate each sort variant's expanded SQL at compile time
/// by reading sort variant info. However, since the sort enum is defined via
/// `#[bsql::sort]` in user code, we cannot read its variants from within
/// the `query!` macro. Instead, the generated code uses the enum's `sql()`
/// method at runtime. Validation of individual sort fragments happens when
/// the user compiles — the sort enum's SQL fragments are checked by the user
/// running their tests or by a separate validation step.
///
/// For now: generate code that takes a `sort` parameter with a `sql() -> &str`
/// method, and splices the SQL at runtime via string replacement + pre-hashed
/// dispatch.
fn query_impl_sort(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
    let sort_placeholder = parsed.sort_placeholder.as_ref().unwrap();
    let sort_enum_name = &sort_placeholder.enum_name;

    // We can't validate sort variants at proc-macro time because we don't have
    // the enum definition. Instead, generate code that does runtime SQL dispatch.
    // The `{SORT}` in positional_sql will be a sentinel that codegen handles.

    // For validation, we need at least the base query structure. Use a dummy
    // ORDER BY to validate the query shape (columns, params) — replace {SORT}
    // with "1" (which is always valid in ORDER BY).
    let dummy_sql = parsed.positional_sql.replace("{SORT}", "1");

    // Create a temporary ParsedQuery with the dummy SQL for validation
    let dummy_parsed = parse::ParsedQuery {
        normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
        positional_sql: dummy_sql,
        params: parsed.params.clone(),
        kind: parsed.kind,
        statement_name: parsed.statement_name.clone(),
        optional_clauses: parsed.optional_clauses.clone(),
        sort_placeholder: None,
    };

    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
    } else {
        let result = connection::with_connection(|conn| {
            validate::validate_query_with_suggestions(&dummy_parsed, conn)
        })?;

        offline::write_cache(&parsed, &result);
        result
    };

    validate::check_param_types(&parsed, &validation)
        .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;

    // Generate sort-aware code
    Ok(codegen::generate_sort_query_code(
        &parsed,
        &validation,
        sort_enum_name,
    ))
}

/// Pull the SQL text out of the macro input.
///
/// The input must parse as exactly one string literal —
/// `query!("SELECT ...")` — whose string value is returned. Any other input
/// is rejected with the `syn` parse error.
fn extract_sql(input: proc_macro2::TokenStream) -> Result<String, syn::Error> {
    syn::parse2::<syn::LitStr>(input).map(|literal| literal.value())
}

/// Map a Rust enum to a PostgreSQL enum, generating `FromSql` and `ToSql`.
///
/// # Usage
///
/// ```rust,ignore
/// #[bsql::pg_enum]
/// pub enum TicketStatus {
///     #[sql("new")]
///     New,
///     #[sql("in_progress")]
///     InProgress,
///     #[sql("resolved")]
///     Resolved,
///     #[sql("closed")]
///     Closed,
/// }
/// ```
///
/// Every variant needs a `#[sql("label")]` attribute naming the exact
/// PostgreSQL enum label it corresponds to. The macro generates:
/// - `FromSql` — reads the value from PostgreSQL's text representation
/// - `ToSql` — writes the value as PostgreSQL's text representation
/// - `Display` — prints the SQL label
/// - Derives: `Debug, Clone, Copy, PartialEq, Eq, Hash`
///
/// When PostgreSQL sends a label that has no matching Rust variant,
/// `FromSql` returns an error describing the schema mismatch.
#[proc_macro_attribute]
pub fn pg_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
    // A failed expansion is rendered as `compile_error!` tokens so it is
    // reported by the compiler instead of panicking the macro.
    pg_enum::expand_pg_enum(
        proc_macro2::TokenStream::from(attr),
        proc_macro2::TokenStream::from(item),
    )
    .map_or_else(|failure| failure.to_compile_error().into(), Into::into)
}

/// Declare a sort enum for compile-time verified dynamic `ORDER BY` clauses.
///
/// # Usage
///
/// ```rust,ignore
/// #[bsql::sort]
/// pub enum TicketSort {
///     #[sql("t.updated_at DESC, t.id DESC")]
///     UpdatedAt,
///     #[sql("t.deadline ASC NULLS LAST, t.id ASC")]
///     Deadline,
///     #[sql("t.id DESC")]
///     Id,
/// }
/// ```
///
/// Pair it with the `$[sort: EnumType]` placeholder in `bsql::query!`:
///
/// ```rust,ignore
/// let tickets = bsql::query!(
///     "SELECT id, title FROM tickets ORDER BY $[sort: TicketSort] LIMIT $limit: i64"
/// ).fetch_all(&pool)?;
/// ```
///
/// Every variant needs a `#[sql("...")]` attribute holding its SQL
/// `ORDER BY` fragment. The macro generates:
/// - The enum with `Debug, Clone, Copy, PartialEq, Eq, Hash`
/// - A `sql(&self) -> &'static str` method that returns the SQL fragment
/// - `Display` — prints the SQL fragment
///
/// In contrast to `#[bsql::pg_enum]`, a sort enum is NOT a parameterized
/// value: its SQL fragment is spliced directly into the query string.
#[proc_macro_attribute]
pub fn sort(attr: TokenStream, item: TokenStream) -> TokenStream {
    // A failed expansion is rendered as `compile_error!` tokens so it is
    // reported by the compiler instead of panicking the macro.
    sort_enum::expand_sort_enum(
        proc_macro2::TokenStream::from(attr),
        proc_macro2::TokenStream::from(item),
    )
    .map_or_else(|failure| failure.to_compile_error().into(), Into::into)
}