bsql_macros/lib.rs
1#![forbid(unsafe_code)]
2
3//! Proc macros for bsql.
4//!
5//! This crate is an implementation detail. Use [`bsql`] instead.
6
7extern crate proc_macro;
8
9mod codegen;
10#[cfg(feature = "sqlite")]
11mod codegen_sqlite;
12mod connection;
13mod dynamic;
14mod offline;
15mod parse;
16mod pg_enum;
17mod sort_enum;
18mod sql_norm;
19mod stmt_name;
20mod suggest;
21pub(crate) mod types;
22#[cfg(feature = "sqlite")]
23mod types_sqlite;
24mod validate;
25#[cfg(feature = "sqlite")]
26mod validate_sqlite;
27
28use proc_macro::TokenStream;
29
30/// Validate a SQL query against PostgreSQL at compile time and generate
31/// typed Rust code for executing it.
32///
33/// # Syntax
34///
35/// ```text
36/// bsql::query! {
37/// SELECT column1, column2
38/// FROM table
39/// WHERE column1 = $param_name: RustType
40/// }
41/// ```
42///
43/// Parameters are declared inline as `$name: Type`. The macro replaces them
44/// with positional `$1`, `$2`, ... and verifies type compatibility against
45/// the database schema.
46///
47/// # Execution methods
48///
49/// The macro returns an executor with these methods:
50/// - `.fetch_one(executor)` — returns exactly one row (errors on 0 or 2+)
51/// - `.fetch_all(executor)` — returns all rows as `Vec<T>`
52/// - `.fetch_optional(executor)` — returns `Option<T>` (errors on 2+)
53/// - `.execute(executor)` — returns affected row count (`u64`)
54///
55/// # Compile-time guarantees
56///
57/// - Table and column names are verified against the live database
58/// - Parameter types are checked against PostgreSQL's expected types
59/// - Nullable columns are automatically mapped to `Option<T>`
60/// - Invalid SQL produces a compile error, not a runtime error
61#[proc_macro]
62pub fn query(input: TokenStream) -> TokenStream {
63 let input2: proc_macro2::TokenStream = input.into();
64 match query_impl(input2) {
65 Ok(output) => output.into(),
66 Err(err) => err.to_compile_error().into(),
67 }
68}
69
70fn query_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
71 // Extract the SQL string from the input.
72 // Accepts either a string literal: query!("SELECT ...")
73 // or raw tokens: query! { SELECT ... } converted to string.
74 let sql = extract_sql(input)?;
75
76 // 1. Parse: extract params, query kind, normalize SQL, optional clauses, sort placeholder
77 let parsed = parse::parse_query(&sql)
78 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
79
80 // Detect backend from database URL (if not offline)
81 #[cfg(feature = "sqlite")]
82 {
83 let backend = connection::detect_backend()
84 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
85 if backend == Some(connection::Backend::Sqlite) {
86 return query_impl_sqlite(parsed);
87 }
88 }
89
90 // PostgreSQL path (default)
91 query_impl_postgres(parsed)
92}
93
94/// PostgreSQL query implementation (the original path).
95fn query_impl_postgres(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
96 // 2. Sort query path — $[sort: EnumType] present
97 if parsed.sort_placeholder.is_some() {
98 return query_impl_sort(parsed);
99 }
100
101 // 3. Expand dynamic query variants (if any optional clauses)
102 let variants = dynamic::expand_variants(&parsed)
103 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
104
105 if parsed.optional_clauses.is_empty() {
106 // Static query path — no optional clauses
107 let validation = if offline::is_offline() {
108 // OFFLINE: read cached validation result
109 offline::lookup_cached_validation(&parsed)
110 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
111 } else {
112 // ONLINE: validate against PostgreSQL via PREPARE with suggestions
113 let result = connection::with_connection(|conn| {
114 validate::validate_query_with_suggestions(&parsed, conn)
115 })?;
116
117 // Write to offline cache for future use
118 offline::write_cache(&parsed, &result);
119
120 result
121 };
122
123 // Check parameter type compatibility
124 validate::check_param_types(&parsed, &validation)
125 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
126
127 // Generate Rust code
128 Ok(codegen::generate_query_code(&parsed, &validation))
129 } else {
130 // Dynamic query path — has optional clauses
131 let validation = if offline::is_offline() {
132 // OFFLINE: read cached validation result for the base variant.
133 //
134 // The cache stores variant 0's param_pg_oids, which only covers
135 // the base params (not optional clause params). Param type
136 // checking is skipped here because:
137 // 1. The online build already validated ALL variants' param types.
138 // 2. The cached columns are identical across all variants (the
139 // SELECT list never changes, only WHERE clauses differ).
140 // 3. Codegen only needs the column info, not per-variant param OIDs.
141 offline::lookup_cached_validation(&parsed)
142 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
143 } else {
144 // ONLINE: validate ALL variants against PostgreSQL and check param types
145 let result = connection::with_connection(|conn| {
146 validate::validate_variants(&variants, &parsed, conn)
147 })?;
148
149 // Write to offline cache for future use
150 offline::write_cache(&parsed, &result);
151
152 result
153 };
154
155 // Generate dynamic Rust code with match dispatcher
156 Ok(codegen::generate_dynamic_query_code(
157 &parsed,
158 &validation,
159 &variants,
160 ))
161 }
162}
163
/// SQLite query implementation.
///
/// Validates against a live SQLite database at compile time (or reads the
/// offline cache in offline mode), then generates code that executes via
/// `bsql_core::SqlitePool`.
///
/// Queries carrying a `$[sort: EnumType]` placeholder are delegated to
/// `query_impl_sqlite_sort`.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
    // Sort queries have their own pipeline.
    if parsed.sort_placeholder.is_some() {
        return query_impl_sqlite_sort(parsed);
    }

    // Expand the dynamic query variants (driven by the optional clauses).
    let variants = dynamic::expand_variants(&parsed)
        .map_err(|reason| syn::Error::new(proc_macro2::Span::call_site(), reason))?;

    if !parsed.optional_clauses.is_empty() {
        // Dynamic query path — has optional clauses.
        let validation = if offline::is_offline() {
            offline::lookup_cached_validation(&parsed)
                .map_err(|reason| syn::Error::new(proc_macro2::Span::call_site(), reason))?
        } else {
            // Validate ALL variants against SQLite.
            let fresh = connection::with_sqlite_connection(|conn| {
                validate_sqlite::validate_variants_sqlite(&variants, &parsed, conn)
            })?;

            offline::write_cache(&parsed, &fresh);

            fresh
        };

        return Ok(codegen_sqlite::generate_dynamic_sqlite_query_code(
            &parsed,
            &validation,
            &variants,
        ));
    }

    // Static query path — no optional clauses.
    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed)
            .map_err(|reason| syn::Error::new(proc_macro2::Span::call_site(), reason))?
    } else {
        let fresh = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&parsed, conn)
        })?;

        // Populate the offline cache for future builds.
        offline::write_cache(&parsed, &fresh);

        fresh
    };

    // SQLite doesn't type parameters at prepare time, so the PG-style param
    // type check is skipped here. Parameter types are verified at runtime by
    // the SqliteEncode trait.

    Ok(codegen_sqlite::generate_sqlite_query_code(
        &parsed,
        &validation,
    ))
}
226
/// SQLite sort query implementation.
///
/// Handles queries containing a `$[sort: EnumType]` placeholder. The query
/// shape is validated with the sort placeholder replaced by the constant
/// `1`; code generation then receives the original parsed query together
/// with the sort enum's name.
///
/// # Panics
///
/// Panics if `parsed.sort_placeholder` is `None`; callers check this first.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite_sort(
    parsed: parse::ParsedQuery,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    let sort_enum_name = &parsed.sort_placeholder.as_ref().unwrap().enum_name;

    // Stand-in query used only for validation: same params and clauses, but
    // with the sort sentinel swapped for "1" and no sort placeholder.
    let validation_probe = parse::ParsedQuery {
        sort_placeholder: None,
        kind: parsed.kind,
        params: parsed.params.clone(),
        statement_name: parsed.statement_name.clone(),
        optional_clauses: parsed.optional_clauses.clone(),
        positional_sql: parsed.positional_sql.replace("{SORT}", "1"),
        normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
    };

    // Cache reads and writes are keyed on the original query, not the probe.
    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed)
            .map_err(|reason| syn::Error::new(proc_macro2::Span::call_site(), reason))?
    } else {
        let fresh = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&validation_probe, conn)
        })?;

        offline::write_cache(&parsed, &fresh);
        fresh
    };

    Ok(codegen_sqlite::generate_sort_sqlite_query_code(
        &parsed,
        &validation,
        sort_enum_name,
    ))
}
266
267/// Handle sort queries — queries with `$[sort: EnumType]`.
268///
269/// The sort enum is NOT resolved at macro expansion time (we don't have access
270/// to the enum definition from within the proc macro). Instead, we generate code
271/// that takes the sort enum as a parameter and uses `match` to select the SQL.
272///
273/// Validation: we validate each sort variant's expanded SQL at compile time
274/// by reading sort variant info. However, since the sort enum is defined via
275/// `#[bsql::sort]` in user code, we cannot read its variants from within
276/// the `query!` macro. Instead, the generated code uses the enum's `sql()`
277/// method at runtime. Validation of individual sort fragments happens when
278/// the user compiles — the sort enum's SQL fragments are checked by the user
279/// running their tests or by a separate validation step.
280///
281/// For now: generate code that takes a `sort` parameter with a `sql() -> &str`
282/// method, and splices the SQL at runtime via string replacement + pre-hashed
283/// dispatch.
284fn query_impl_sort(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
285 let sort_placeholder = parsed.sort_placeholder.as_ref().unwrap();
286 let sort_enum_name = &sort_placeholder.enum_name;
287
288 // We can't validate sort variants at proc-macro time because we don't have
289 // the enum definition. Instead, generate code that does runtime SQL dispatch.
290 // The `{SORT}` in positional_sql will be a sentinel that codegen handles.
291
292 // For validation, we need at least the base query structure. Use a dummy
293 // ORDER BY to validate the query shape (columns, params) — replace {SORT}
294 // with "1" (which is always valid in ORDER BY).
295 let dummy_sql = parsed.positional_sql.replace("{SORT}", "1");
296
297 // Create a temporary ParsedQuery with the dummy SQL for validation
298 let dummy_parsed = parse::ParsedQuery {
299 normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
300 positional_sql: dummy_sql,
301 params: parsed.params.clone(),
302 kind: parsed.kind,
303 statement_name: parsed.statement_name.clone(),
304 optional_clauses: parsed.optional_clauses.clone(),
305 sort_placeholder: None,
306 };
307
308 let validation = if offline::is_offline() {
309 offline::lookup_cached_validation(&parsed)
310 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
311 } else {
312 let result = connection::with_connection(|conn| {
313 validate::validate_query_with_suggestions(&dummy_parsed, conn)
314 })?;
315
316 offline::write_cache(&parsed, &result);
317 result
318 };
319
320 validate::check_param_types(&parsed, &validation)
321 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
322
323 // Generate sort-aware code
324 Ok(codegen::generate_sort_query_code(
325 &parsed,
326 &validation,
327 sort_enum_name,
328 ))
329}
330
331/// Extract the SQL text from the macro input.
332///
333/// Accepts a string literal: `query!("SELECT ...")`
334fn extract_sql(input: proc_macro2::TokenStream) -> Result<String, syn::Error> {
335 let lit: syn::LitStr = syn::parse2(input)?;
336 Ok(lit.value())
337}
338
339/// Derive PostgreSQL enum <-> Rust enum mapping with `FromSql` and `ToSql`.
340///
341/// # Usage
342///
343/// ```rust,ignore
344/// #[bsql::pg_enum]
345/// pub enum TicketStatus {
346/// #[sql("new")]
347/// New,
348/// #[sql("in_progress")]
349/// InProgress,
350/// #[sql("resolved")]
351/// Resolved,
352/// #[sql("closed")]
353/// Closed,
354/// }
355/// ```
356///
357/// Each variant must have a `#[sql("label")]` attribute mapping it to the
358/// exact PostgreSQL enum label. The macro generates:
359/// - `FromSql` — deserializes from PostgreSQL text representation
360/// - `ToSql` — serializes to PostgreSQL text representation
361/// - `Display` — formats as the SQL label
362/// - Derives: `Debug, Clone, Copy, PartialEq, Eq, Hash`
363///
364/// If PostgreSQL sends a variant not present in the Rust enum, `FromSql`
365/// returns an error describing the schema mismatch.
366#[proc_macro_attribute]
367pub fn pg_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
368 let attr2: proc_macro2::TokenStream = attr.into();
369 let item2: proc_macro2::TokenStream = item.into();
370 match pg_enum::expand_pg_enum(attr2, item2) {
371 Ok(output) => output.into(),
372 Err(err) => err.to_compile_error().into(),
373 }
374}
375
376/// Define a sort enum for compile-time verified dynamic `ORDER BY` clauses.
377///
378/// # Usage
379///
380/// ```rust,ignore
381/// #[bsql::sort]
382/// pub enum TicketSort {
383/// #[sql("t.updated_at DESC, t.id DESC")]
384/// UpdatedAt,
385/// #[sql("t.deadline ASC NULLS LAST, t.id ASC")]
386/// Deadline,
387/// #[sql("t.id DESC")]
388/// Id,
389/// }
390/// ```
391///
392/// Use with the `$[sort: EnumType]` placeholder in `bsql::query!`:
393///
394/// ```rust,ignore
395/// let tickets = bsql::query!(
396/// "SELECT id, title FROM tickets ORDER BY $[sort: TicketSort] LIMIT $limit: i64"
397/// ).fetch_all(&pool)?;
398/// ```
399///
400/// Each variant must have a `#[sql("...")]` attribute mapping it to the
401/// SQL `ORDER BY` fragment. The macro generates:
402/// - The enum with `Debug, Clone, Copy, PartialEq, Eq, Hash`
403/// - A `sql(&self) -> &'static str` method returning the SQL fragment
404/// - `Display` — formats as the SQL fragment
405///
406/// Unlike `#[bsql::pg_enum]`, sort enums are NOT parameterized values.
407/// The SQL fragment is spliced directly into the query string.
408#[proc_macro_attribute]
409pub fn sort(attr: TokenStream, item: TokenStream) -> TokenStream {
410 let attr2: proc_macro2::TokenStream = attr.into();
411 let item2: proc_macro2::TokenStream = item.into();
412 match sort_enum::expand_sort_enum(attr2, item2) {
413 Ok(output) => output.into(),
414 Err(err) => err.to_compile_error().into(),
415 }
416}