Skip to main content

resolute_macros/
lib.rs

1//! Compile-time checked SQL query macros with offline cache support.
2//!
3//! Connects to PostgreSQL at compile time via pg-wired to validate SQL
4//! and generate typed result structs.
5//!
6//! # Modes
7//!
//! - **Online (default):** Uses cached metadata from `.resolute/` when present; otherwise connects to the DB via `DATABASE_URL` and caches the result there
9//! - **Offline:** Set `RESOLUTE_OFFLINE=true` to use cached metadata only (no DB needed)
10//! - **Prepare:** Run `resolute-cli prepare` to populate the cache from source files
11//!
12//! ```ignore
13//! let rows = resolute::query!("SELECT id, name FROM users WHERE id = $1", user_id)
14//!     .fetch_all(&client)
15//!     .await?;
16//! ```
17
18use proc_macro::TokenStream;
19use quote::{format_ident, quote};
20use syn::parse::{Parse, ParseStream};
21use syn::{parse_macro_input, Expr, LitStr, Token};
22
23mod cache;
24
25/// Input to the query! macro: SQL literal + optional comma-separated params.
26/// Supports both positional (`query!("... $1", val)`) and named (`query!("... :name", name = val)`) params.
27struct QueryInput {
28    sql: LitStr,
29    params: Vec<Expr>,
30    /// If named params were used, this holds (name, expr) pairs.
31    named: Vec<(syn::Ident, Expr)>,
32}
33
34impl Parse for QueryInput {
35    fn parse(input: ParseStream) -> syn::Result<Self> {
36        let sql: LitStr = input.parse()?;
37        let mut params = Vec::new();
38        let mut named = Vec::new();
39        let mut mode: Option<bool> = None; // None=unknown, Some(true)=named, Some(false)=positional
40
41        while input.peek(Token![,]) {
42            input.parse::<Token![,]>()?;
43            if input.is_empty() {
44                break;
45            }
46
47            // Try to detect named param: `ident = expr` (but not `ident == expr`)
48            let is_named_param = {
49                let fork = input.fork();
50                fork.parse::<syn::Ident>().is_ok()
51                    && fork.parse::<Token![=]>().is_ok()
52                    && !fork.peek(Token![=])
53            };
54
55            if is_named_param && mode != Some(false) {
56                let name: syn::Ident = input.parse()?;
57                input.parse::<Token![=]>()?;
58                let expr: Expr = input.parse()?;
59                named.push((name, expr));
60                mode = Some(true);
61            } else if mode != Some(true) {
62                params.push(input.parse()?);
63                mode = Some(false);
64            } else {
65                return Err(input.error("cannot mix positional and named parameters"));
66            }
67        }
68
69        Ok(QueryInput { sql, params, named })
70    }
71}
72
73/// If named params are present, rewrite SQL and reorder params.
74/// Otherwise pass through unchanged.
75fn resolve_named(
76    sql_str: String,
77    params: Vec<Expr>,
78    named: &[(syn::Ident, Expr)],
79    sql_span: &LitStr,
80) -> Result<(String, Vec<Expr>), TokenStream> {
81    if named.is_empty() {
82        return Ok((sql_str, params));
83    }
84    let (rewritten, names) = rewrite_named_params(&sql_str);
85    let mut ordered = Vec::with_capacity(names.len());
86    for name in &names {
87        match named.iter().find(|(n, _)| n == name) {
88            Some((_, expr)) => ordered.push(expr.clone()),
89            None => {
90                let msg = format!("named parameter `:{name}` in SQL has no binding");
91                return Err(syn::Error::new_spanned(sql_span, msg)
92                    .to_compile_error()
93                    .into());
94            }
95        }
96    }
97    for (n, _) in named {
98        if !names.iter().any(|name| name == &n.to_string()) {
99            let msg = format!("binding `{}` does not match any `:{}` in SQL", n, n);
100            return Err(syn::Error::new_spanned(n, msg).to_compile_error().into());
101        }
102    }
103    Ok((rewritten, ordered))
104}
105
106/// Resolve query metadata: try cache first, then live DB, then update cache.
107fn resolve_metadata(sql: &str) -> Result<(Vec<u32>, Vec<cache::CachedColumn>), String> {
108    let sql_hash = hash_sql(sql);
109    let offline = std::env::var("RESOLUTE_OFFLINE")
110        .map(|v| v == "true" || v == "1")
111        .unwrap_or(false);
112
113    // 1. Try the cache first.
114    if let Some(cached) = cache::read_cache(sql_hash) {
115        return Ok((cached.param_oids, cached.columns));
116    }
117
118    // 2. If offline mode, fail — cache is required.
119    if offline {
120        return Err(format!(
121            "RESOLUTE_OFFLINE=true but no cached metadata for query (hash {sql_hash:x}). \
122             Run `resolute-cli prepare` to populate the cache."
123        ));
124    }
125
126    // 3. Connect to PG and describe.
127    let (param_oids, columns) = describe_live(sql)?;
128
129    // 4. Write to cache for future offline builds.
130    let entry = cache::CacheEntry {
131        sql: sql.to_string(),
132        hash: sql_hash,
133        param_oids: param_oids.clone(),
134        columns: columns.clone(),
135    };
136    if let Err(e) = cache::write_cache(&entry) {
137        // Cache write failure is non-fatal — just warn.
138        eprintln!("resolute: warning: failed to write cache: {e}");
139    }
140
141    Ok((param_oids, columns))
142}
143
144/// Connect to PG via pg-wired and describe the statement.
145fn describe_live(sql: &str) -> Result<(Vec<u32>, Vec<cache::CachedColumn>), String> {
146    let db_url = std::env::var("DATABASE_URL").map_err(|_| {
147        "DATABASE_URL not set and no cached metadata found. \
148         Set DATABASE_URL or run `resolute-cli prepare`."
149            .to_string()
150    })?;
151
152    let (user, password, host, port, database) = parse_pg_uri(&db_url)
153        .ok_or_else(|| "Invalid DATABASE_URL (could not parse as postgres:// URI)".to_string())?;
154    let addr = format!("{host}:{port}");
155
156    let rt = tokio::runtime::Builder::new_current_thread()
157        .enable_all()
158        .build()
159        .map_err(|e| format!("Failed to create tokio runtime: {e}"))?;
160
161    rt.block_on(async {
162        let mut conn = pg_wired::WireConn::connect(&addr, &user, &password, &database)
163            .await
164            .map_err(|e| format!("Failed to connect to database: {e}"))?;
165
166        let (param_oids, fields) = conn
167            .describe_statement(sql)
168            .await
169            .map_err(|e| format!("SQL error: {e}"))?;
170
171        // Detect nullable columns by querying pg_attribute for real table columns.
172        // Batch all table_oid/column_id pairs into one query.
173        let mut columns: Vec<cache::CachedColumn> = fields
174            .iter()
175            .map(|f| cache::CachedColumn {
176                name: f.name.clone(),
177                type_oid: f.type_oid,
178                nullable: true, // Default: assume nullable.
179            })
180            .collect();
181
182        // Collect non-null info for columns that come from real tables.
183        let table_cols: Vec<(usize, u32, i16)> = fields
184            .iter()
185            .enumerate()
186            .filter(|(_, f)| f.table_oid != 0 && f.column_id > 0)
187            .map(|(i, f)| (i, f.table_oid, f.column_id))
188            .collect();
189
190        if !table_cols.is_empty() {
191            // Build a single query to check all columns at once.
192            let conditions: Vec<String> = table_cols
193                .iter()
194                .map(|(_, oid, col)| format!("(attrelid={oid} AND attnum={col})"))
195                .collect();
196            let null_sql = format!(
197                "SELECT attrelid, attnum, attnotnull FROM pg_attribute WHERE {}",
198                conditions.join(" OR ")
199            );
200
201            // Send as simple query and collect rows.
202            let mut buf = bytes::BytesMut::new();
203            pg_wired::protocol::frontend::encode_message(
204                &pg_wired::protocol::types::FrontendMsg::Query(null_sql.as_bytes()),
205                &mut buf,
206            );
207            if conn.send_raw(&buf).await.is_ok() {
208                if let Ok((rows, _)) = conn.collect_rows().await {
209                    for row in &rows {
210                        let oid: u32 = row
211                            .cell(0)
212                            .and_then(|b| std::str::from_utf8(b).ok())
213                            .and_then(|s| s.parse().ok())
214                            .unwrap_or(0);
215                        let col: i16 = row
216                            .cell(1)
217                            .and_then(|b| std::str::from_utf8(b).ok())
218                            .and_then(|s| s.parse().ok())
219                            .unwrap_or(0);
220                        let notnull: bool =
221                            row.cell(2).map(|b| b == b"t".as_ref()).unwrap_or(false);
222
223                        // Find the matching column and mark it non-nullable.
224                        for &(idx, t_oid, t_col) in &table_cols {
225                            if t_oid == oid && t_col == col && notnull {
226                                columns[idx].nullable = false;
227                            }
228                        }
229                    }
230                }
231            }
232        }
233
234        Ok((param_oids, columns))
235    })
236}
237
238/// Map a PostgreSQL type OID to a Rust type token.
239///
240/// Returns `Err` if the OID maps to a Rust type behind a disabled feature
241/// (`chrono`, `json`, or `uuid`). Callers surface the error as a
242/// `syn::Error` pointing at the SQL literal.
243fn oid_to_rust_type(oid: u32) -> Result<proc_macro2::TokenStream, String> {
244    let ty = match oid {
245        // Scalar types
246        16 => quote! { bool },
247        18 | 19 | 25 | 1042 | 1043 => quote! { String },
248        20 => quote! { i64 },
249        21 => quote! { i16 },
250        23 | 26 => quote! { i32 },
251        700 => quote! { f32 },
252        701 => quote! { f64 },
253        17 => quote! { Vec<u8> },
254        869 => quote! { resolute::PgInet },
255        1700 => quote! { resolute::PgNumeric },
256        // Array types
257        1000 => quote! { Vec<bool> },
258        1005 => quote! { Vec<i16> },
259        1007 => quote! { Vec<i32> },
260        1009 | 1015 => quote! { Vec<String> },
261        1016 => quote! { Vec<i64> },
262        1021 => quote! { Vec<f32> },
263        1022 => quote! { Vec<f64> },
264        1041 => quote! { Vec<resolute::PgInet> },
265        1231 => quote! { Vec<resolute::PgNumeric> },
266        // Range types
267        3904 => quote! { resolute::PgRange<i32> },
268        3926 => quote! { resolute::PgRange<i64> },
269        3906 => quote! { resolute::PgRange<resolute::PgNumeric> },
270        // JSON (feature = "json")
271        #[cfg(feature = "json")]
272        114 | 3802 => quote! { serde_json::Value },
273        #[cfg(feature = "json")]
274        3807 => quote! { Vec<serde_json::Value> },
275        #[cfg(not(feature = "json"))]
276        114 | 3802 | 3807 => {
277            return Err(format!(
278                "column type `{}` requires the `json` feature, which is disabled. \
279                 Enable `resolute/json` in your Cargo.toml to use JSON/JSONB columns.",
280                oid_to_type_name(oid)
281            ));
282        }
283        // chrono (feature = "chrono")
284        #[cfg(feature = "chrono")]
285        1082 => quote! { chrono::NaiveDate },
286        #[cfg(feature = "chrono")]
287        1083 => quote! { chrono::NaiveTime },
288        #[cfg(feature = "chrono")]
289        1114 => quote! { chrono::NaiveDateTime },
290        #[cfg(feature = "chrono")]
291        1184 => quote! { chrono::DateTime<chrono::Utc> },
292        #[cfg(feature = "chrono")]
293        1115 => quote! { Vec<chrono::NaiveDateTime> },
294        #[cfg(feature = "chrono")]
295        1182 => quote! { Vec<chrono::NaiveDate> },
296        #[cfg(feature = "chrono")]
297        1183 => quote! { Vec<chrono::NaiveTime> },
298        #[cfg(feature = "chrono")]
299        1185 => quote! { Vec<chrono::DateTime<chrono::Utc>> },
300        #[cfg(feature = "chrono")]
301        3912 => quote! { resolute::PgRange<chrono::NaiveDate> },
302        #[cfg(feature = "chrono")]
303        3908 => quote! { resolute::PgRange<chrono::NaiveDateTime> },
304        #[cfg(feature = "chrono")]
305        3910 => quote! { resolute::PgRange<chrono::DateTime<chrono::Utc>> },
306        #[cfg(not(feature = "chrono"))]
307        1082 | 1083 | 1114 | 1184 | 1115 | 1182 | 1183 | 1185 | 3912 | 3908 | 3910 => {
308            return Err(format!(
309                "column type `{}` requires the `chrono` feature, which is disabled. \
310                 Enable `resolute/chrono` in your Cargo.toml to use date/time columns.",
311                oid_to_type_name(oid)
312            ));
313        }
314        // uuid (feature = "uuid")
315        #[cfg(feature = "uuid")]
316        2950 => quote! { uuid::Uuid },
317        #[cfg(feature = "uuid")]
318        2951 => quote! { Vec<uuid::Uuid> },
319        #[cfg(not(feature = "uuid"))]
320        2950 | 2951 => {
321            return Err(format!(
322                "column type `{}` requires the `uuid` feature, which is disabled. \
323                 Enable `resolute/uuid` in your Cargo.toml to use UUID columns.",
324                oid_to_type_name(oid)
325            ));
326        }
327        _ => quote! { Vec<u8> },
328    };
329    Ok(ty)
330}
331
332/// `query!("SQL", param1, param2, ...)` — compile-time checked SQL query.
333#[proc_macro]
334pub fn query(input: TokenStream) -> TokenStream {
335    let parsed = parse_macro_input!(input as QueryInput);
336    query_impl(parsed)
337}
338
339fn query_impl(input: QueryInput) -> TokenStream {
340    let QueryInput { sql, params, named } = input;
341    let sql_str = sql.value();
342
343    let (sql_str, params) = match resolve_named(sql_str, params, &named, &sql) {
344        Ok(v) => v,
345        Err(ts) => return ts,
346    };
347
348    let (param_oids, column_infos) = match resolve_metadata(&sql_str) {
349        Ok(result) => result,
350        Err(e) => {
351            return syn::Error::new_spanned(&sql, e).to_compile_error().into();
352        }
353    };
354
355    if params.len() != param_oids.len() {
356        let msg = format!(
357            "expected {} parameter(s), got {}",
358            param_oids.len(),
359            params.len()
360        );
361        return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
362    }
363
364    // Generate compile-time param type checks. The check verifies the param
365    // implements `Encode`/`SqlParam`; downstream trait impls enforce that the
366    // Rust type matches the PostgreSQL OID expected by the server.
367    let param_type_checks: Vec<_> = params
368        .iter()
369        .map(|param| {
370            quote! {
371                {
372                    fn __resolute_check_param<T: resolute::Encode + Sync>(_: &T) {}
373                    __resolute_check_param(&#param);
374                    let _ = &#param as &dyn resolute::SqlParam;
375                }
376            }
377        })
378        .collect();
379
380    // Parse type overrides from column names (e.g., "id: UserId").
381    let overrides: Vec<_> = column_infos
382        .iter()
383        .map(|c| parse_type_override(&c.name))
384        .collect();
385
386    let field_names: Vec<_> = overrides
387        .iter()
388        .map(|(name, _)| format_ident!("{}", sanitize_ident(name)))
389        .collect();
390    let field_types: Vec<_> = match column_infos
391        .iter()
392        .zip(overrides.iter())
393        .map(
394            |(c, (_, type_override))| -> Result<proc_macro2::TokenStream, String> {
395                let base = if let Some(ref custom) = type_override {
396                    custom.clone()
397                } else {
398                    oid_to_rust_type(c.type_oid)?
399                };
400                Ok(if c.nullable {
401                    quote! { Option<#base> }
402                } else {
403                    base
404                })
405            },
406        )
407        .collect::<Result<Vec<_>, String>>()
408    {
409        Ok(v) => v,
410        Err(e) => return syn::Error::new_spanned(&sql, e).to_compile_error().into(),
411    };
412    let _field_indices: Vec<_> = (0..column_infos.len()).collect::<Vec<_>>();
413    let field_getters: Vec<_> = column_infos
414        .iter()
415        .enumerate()
416        .map(|(i, c)| {
417            if c.nullable {
418                quote! { row.get_opt(#i)? }
419            } else {
420                quote! { row.get(#i)? }
421            }
422        })
423        .collect();
424
425    let struct_name = format_ident!("__QueryResult_{}", hash_sql(&sql_str));
426
427    let param_refs: Vec<_> = params
428        .iter()
429        .map(|p| quote! { &#p as &dyn resolute::SqlParam })
430        .collect();
431
432    // Use rewritten SQL (with $1,$2) in generated code, not the original :name SQL.
433    let sql_lit_rewritten = LitStr::new(&sql_str, sql.span());
434
435    let expanded = quote! {
436        {
437            // Compile-time parameter type assertions.
438            #(#param_type_checks)*
439
440            #[allow(non_camel_case_types)]
441            #[derive(Debug)]
442            struct #struct_name {
443                #(pub #field_names: #field_types,)*
444            }
445
446            resolute::CheckedQuery::<#struct_name> {
447                sql: #sql_lit_rewritten,
448                params: vec![#(#param_refs),*],
449                _marker: std::marker::PhantomData,
450                mapper: |row: &resolute::Row| -> Result<#struct_name, resolute::TypedError> {
451                    Ok(#struct_name {
452                        #(#field_names: #field_getters,)*
453                    })
454                },
455            }
456        }
457    };
458
459    TokenStream::from(expanded)
460}
461
462/// `query_as!(Type, "SQL", param1, param2, ...)` — compile-time checked query
463/// mapped to an existing struct via FromRow.
464#[proc_macro]
465pub fn query_as(input: TokenStream) -> TokenStream {
466    let parsed = parse_macro_input!(input as QueryAsInput);
467    query_as_impl(parsed)
468}
469
470fn query_as_impl(input: QueryAsInput) -> TokenStream {
471    let QueryAsInput {
472        target_type,
473        sql,
474        params,
475        named,
476    } = input;
477    let sql_str = sql.value();
478
479    let (sql_str, params) = match resolve_named(sql_str, params, &named, &sql) {
480        Ok(v) => v,
481        Err(ts) => return ts,
482    };
483
484    let (param_oids, _column_infos) = match resolve_metadata(&sql_str) {
485        Ok(result) => result,
486        Err(e) => {
487            return syn::Error::new_spanned(&sql, e).to_compile_error().into();
488        }
489    };
490
491    if params.len() != param_oids.len() {
492        let msg = format!(
493            "expected {} parameter(s), got {}",
494            param_oids.len(),
495            params.len()
496        );
497        return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
498    }
499
500    let param_refs: Vec<_> = params
501        .iter()
502        .map(|p| quote! { &#p as &dyn resolute::SqlParam })
503        .collect();
504    let sql_lit_rewritten = LitStr::new(&sql_str, sql.span());
505
506    let expanded = quote! {
507        {
508            resolute::CheckedQuery::<#target_type> {
509                sql: #sql_lit_rewritten,
510                params: vec![#(#param_refs),*],
511                _marker: std::marker::PhantomData,
512                mapper: |row: &resolute::Row| -> Result<#target_type, resolute::TypedError> {
513                    <#target_type as resolute::FromRow>::from_row(row)
514                },
515            }
516        }
517    };
518
519    TokenStream::from(expanded)
520}
521
522/// `query_scalar!("SQL", param1, ...)` — compile-time checked single-value query.
523#[proc_macro]
524pub fn query_scalar(input: TokenStream) -> TokenStream {
525    let parsed = parse_macro_input!(input as QueryInput);
526    query_scalar_impl(parsed)
527}
528
529fn query_scalar_impl(input: QueryInput) -> TokenStream {
530    let QueryInput { sql, params, named } = input;
531    let sql_str = sql.value();
532
533    let (sql_str, params) = match resolve_named(sql_str, params, &named, &sql) {
534        Ok(v) => v,
535        Err(ts) => return ts,
536    };
537
538    let (param_oids, column_infos) = match resolve_metadata(&sql_str) {
539        Ok(result) => result,
540        Err(e) => {
541            return syn::Error::new_spanned(&sql, e).to_compile_error().into();
542        }
543    };
544
545    if params.len() != param_oids.len() {
546        let msg = format!(
547            "expected {} parameter(s), got {}",
548            param_oids.len(),
549            params.len()
550        );
551        return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
552    }
553
554    if column_infos.len() != 1 {
555        let msg = format!(
556            "query_scalar! requires exactly 1 column, got {}",
557            column_infos.len()
558        );
559        return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
560    }
561
562    let scalar_type = {
563        let (_, type_override) = parse_type_override(&column_infos[0].name);
564        match type_override {
565            Some(ty) => ty,
566            None => match oid_to_rust_type(column_infos[0].type_oid) {
567                Ok(ty) => ty,
568                Err(e) => return syn::Error::new_spanned(&sql, e).to_compile_error().into(),
569            },
570        }
571    };
572    let param_refs: Vec<_> = params
573        .iter()
574        .map(|p| quote! { &#p as &dyn resolute::SqlParam })
575        .collect();
576    let sql_lit_rewritten = LitStr::new(&sql_str, sql.span());
577
578    let expanded = quote! {
579        {
580            resolute::CheckedQuery::<#scalar_type> {
581                sql: #sql_lit_rewritten,
582                params: vec![#(#param_refs),*],
583                _marker: std::marker::PhantomData,
584                mapper: |row: &resolute::Row| -> Result<#scalar_type, resolute::TypedError> {
585                    row.get(0)
586                },
587            }
588        }
589    };
590
591    TokenStream::from(expanded)
592}
593
594/// Input to query_as!: Type, "SQL", params...
595struct QueryAsInput {
596    target_type: syn::Type,
597    sql: LitStr,
598    params: Vec<Expr>,
599    named: Vec<(syn::Ident, Expr)>,
600}
601
602impl Parse for QueryAsInput {
603    fn parse(input: ParseStream) -> syn::Result<Self> {
604        let target_type: syn::Type = input.parse()?;
605        input.parse::<Token![,]>()?;
606        let sql: LitStr = input.parse()?;
607        let mut params = Vec::new();
608        let mut named = Vec::new();
609        let mut mode: Option<bool> = None;
610        while input.peek(Token![,]) {
611            input.parse::<Token![,]>()?;
612            if input.is_empty() {
613                break;
614            }
615            let is_named_param = {
616                let fork = input.fork();
617                fork.parse::<syn::Ident>().is_ok()
618                    && fork.parse::<Token![=]>().is_ok()
619                    && !fork.peek(Token![=])
620            };
621            if is_named_param && mode != Some(false) {
622                let name: syn::Ident = input.parse()?;
623                input.parse::<Token![=]>()?;
624                let expr: Expr = input.parse()?;
625                named.push((name, expr));
626                mode = Some(true);
627            } else if mode != Some(true) {
628                params.push(input.parse()?);
629                mode = Some(false);
630            } else {
631                return Err(input.error("cannot mix positional and named parameters"));
632            }
633        }
634        Ok(QueryAsInput {
635            target_type,
636            sql,
637            params,
638            named,
639        })
640    }
641}
642
643/// `query_file!("path/to/query.sql", param1, param2, ...)` — like query! but reads SQL from a file.
644#[proc_macro]
645pub fn query_file(input: TokenStream) -> TokenStream {
646    let QueryInput {
647        sql: path_lit,
648        params,
649        named: _,
650    } = parse_macro_input!(input as QueryInput);
651    let file_path = path_lit.value();
652
653    let sql_str = match read_sql_file(&file_path) {
654        Ok(s) => s,
655        Err(e) => {
656            return syn::Error::new_spanned(&path_lit, e)
657                .to_compile_error()
658                .into();
659        }
660    };
661
662    // Reuse the query! logic with the file contents.
663    let sql_lit = LitStr::new(&sql_str, path_lit.span());
664    let inner = QueryInput {
665        sql: sql_lit,
666        params,
667        named: Vec::new(),
668    };
669    query_impl(inner)
670}
671
672/// `query_file_as!(Type, "path/to/query.sql", param1, ...)` — like query_as! but reads SQL from a file.
673#[proc_macro]
674pub fn query_file_as(input: TokenStream) -> TokenStream {
675    let QueryAsInput {
676        target_type,
677        sql: path_lit,
678        params,
679        named: _,
680    } = parse_macro_input!(input as QueryAsInput);
681    let file_path = path_lit.value();
682
683    let sql_str = match read_sql_file(&file_path) {
684        Ok(s) => s,
685        Err(e) => {
686            return syn::Error::new_spanned(&path_lit, e)
687                .to_compile_error()
688                .into();
689        }
690    };
691
692    let sql_lit = LitStr::new(&sql_str, path_lit.span());
693    let inner = QueryAsInput {
694        target_type,
695        sql: sql_lit,
696        params,
697        named: Vec::new(),
698    };
699    query_as_impl(inner)
700}
701
702/// `query_file_scalar!("path/to/query.sql", param1, ...)` — file-based scalar query.
703#[proc_macro]
704pub fn query_file_scalar(input: TokenStream) -> TokenStream {
705    let QueryInput {
706        sql: path_lit,
707        params,
708        named: _,
709    } = parse_macro_input!(input as QueryInput);
710    let file_path = path_lit.value();
711
712    let sql_str = match read_sql_file(&file_path) {
713        Ok(s) => s,
714        Err(e) => {
715            return syn::Error::new_spanned(&path_lit, e)
716                .to_compile_error()
717                .into();
718        }
719    };
720
721    let sql_lit = LitStr::new(&sql_str, path_lit.span());
722    let inner = QueryInput {
723        sql: sql_lit,
724        params,
725        named: Vec::new(),
726    };
727    query_scalar_impl(inner)
728}
729
730/// `query_unchecked!("SQL", param1, ...)` — skip compile-time validation.
731/// Useful when DATABASE_URL is unavailable and no cache exists.
732/// Params are passed as-is; no type or count checking.
733#[proc_macro]
734pub fn query_unchecked(input: TokenStream) -> TokenStream {
735    let QueryInput {
736        sql,
737        params,
738        named: _,
739    } = parse_macro_input!(input as QueryInput);
740
741    let param_refs: Vec<_> = params
742        .iter()
743        .map(|p| quote! { &#p as &dyn resolute::SqlParam })
744        .collect();
745    let sql_literal = &sql;
746
747    let expanded = quote! {
748        {
749            resolute::UncheckedQuery {
750                sql: #sql_literal,
751                params: vec![#(#param_refs),*],
752            }
753        }
754    };
755
756    TokenStream::from(expanded)
757}
758
/// Load a `.sql` file and return its contents with surrounding whitespace trimmed.
///
/// `path` is joined onto `CARGO_MANIFEST_DIR`, or onto `.` when that variable
/// is unset. As with `Path::join`, an absolute `path` is used as-is.
///
/// # Errors
/// Returns `"Failed to read SQL file <full path>: <io error>"` if the file
/// cannot be read as UTF-8 text.
fn read_sql_file(path: &str) -> Result<String, String> {
    let base = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| String::from("."));
    let full_path = std::path::PathBuf::from(base).join(path);
    match std::fs::read_to_string(&full_path) {
        Ok(contents) => Ok(contents.trim().to_owned()),
        Err(e) => Err(format!(
            "Failed to read SQL file {}: {e}",
            full_path.display()
        )),
    }
}
767
/// Human-readable PostgreSQL type name for an OID, used in error messages.
///
/// `char`, `name`, `bpchar` and `varchar` are all reported as `text`, and both
/// `text[]` OIDs as `text[]`. OIDs not in the table yield `"unknown"`.
// Only referenced from `#[cfg(not(feature = ...))]` arms of `oid_to_rust_type`,
// so it is dead code when every optional feature is enabled.
#[allow(dead_code)]
fn oid_to_type_name(oid: u32) -> &'static str {
    const NAMES: &[(u32, &str)] = &[
        (16, "bool"),
        (17, "bytea"),
        (18, "text"),
        (19, "text"),
        (20, "int8"),
        (21, "int2"),
        (23, "int4"),
        (25, "text"),
        (26, "oid"),
        (114, "json"),
        (700, "float4"),
        (701, "float8"),
        (869, "inet"),
        (1000, "bool[]"),
        (1005, "int2[]"),
        (1007, "int4[]"),
        (1009, "text[]"),
        (1015, "text[]"),
        (1016, "int8[]"),
        (1021, "float4[]"),
        (1022, "float8[]"),
        (1041, "inet[]"),
        (1042, "text"),
        (1043, "text"),
        (1082, "date"),
        (1083, "time"),
        (1114, "timestamp"),
        (1115, "timestamp[]"),
        (1182, "date[]"),
        (1183, "time[]"),
        (1184, "timestamptz"),
        (1185, "timestamptz[]"),
        (1231, "numeric[]"),
        (1700, "numeric"),
        (2950, "uuid"),
        (2951, "uuid[]"),
        (3802, "jsonb"),
        (3807, "jsonb[]"),
        (3904, "int4range"),
        (3906, "numrange"),
        (3908, "tsrange"),
        (3910, "tstzrange"),
        (3912, "daterange"),
        (3926, "int8range"),
    ];

    NAMES
        .iter()
        .find(|&&(key, _)| key == oid)
        .map_or("unknown", |&(_, name)| name)
}
816
/// Turn a result-column name into a string usable as a Rust field identifier.
///
/// Rules, applied in order:
/// 1. Every character that is not alphanumeric or `_` becomes `_`.
/// 2. A leading numeric character gets a `_` prefix.
/// 3. An empty name becomes `column`.
/// 4. Rust keywords are escaped. Most become raw identifiers
///    (`type` → `r#type`; `format_ident!` turns the `r#` prefix into a raw
///    ident). The words that can never be raw (`_`, `crate`, `self`, `Self`,
///    `super`) get a trailing `_` instead.
fn sanitize_ident(name: &str) -> String {
    // Strict and reserved keywords across editions; all are legal after `r#`.
    const RAW_KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if",
        "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "static", "struct", "trait", "true", "try", "type",
        "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
    ];
    // Words that cannot be written as raw identifiers at all.
    const NON_RAW_KEYWORDS: &[&str] = &["_", "crate", "self", "Self", "super"];

    let mapped: String = name
        .chars()
        .map(|c| {
            if c.is_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();

    // Any numeric first character (not just ASCII digits) is invalid at the
    // start of an identifier and would make `format_ident!` panic.
    let ident = if mapped.starts_with(|c: char| c.is_numeric()) {
        format!("_{mapped}")
    } else if mapped.is_empty() {
        "column".to_string()
    } else {
        mapped
    };

    if NON_RAW_KEYWORDS.contains(&ident.as_str()) {
        format!("{ident}_")
    } else if RAW_KEYWORDS.contains(&ident.as_str()) {
        format!("r#{ident}")
    } else {
        ident
    }
}
836
837/// Parse a type override from a column name.
838///
839/// Supports `"col_name: RustType"` syntax (e.g., `"id: UserId"`).
840/// Skips `::` (PostgreSQL casts) to avoid false positives.
841/// Returns `(actual_name, Some(type))` if override found, or `(original_name, None)`.
842fn parse_type_override(column_name: &str) -> (String, Option<proc_macro2::TokenStream>) {
843    let bytes = column_name.as_bytes();
844    for (i, &b) in bytes.iter().enumerate() {
845        if b == b':' {
846            // Skip `::` (PostgreSQL cast syntax).
847            let prev_colon = i > 0 && bytes[i - 1] == b':';
848            let next_colon = i + 1 < bytes.len() && bytes[i + 1] == b':';
849            if prev_colon || next_colon {
850                continue;
851            }
852            let name = column_name[..i].trim();
853            let type_str = column_name[i + 1..].trim();
854            if !type_str.is_empty() {
855                if let Ok(ty) = syn::parse_str::<syn::Type>(type_str) {
856                    return (name.to_string(), Some(quote! { #ty }));
857                }
858            }
859        }
860    }
861    (column_name.to_string(), None)
862}
863
/// 64-bit FNV-1a hash of the SQL text's UTF-8 bytes.
///
/// Deterministic across builds and platforms. Used to key cached metadata and
/// to name generated result structs.
pub(crate) fn hash_sql(sql: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    sql.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
873
/// Rewrite `:name` named params to `$N` positional params.
///
/// Scans the SQL following PostgreSQL's lexical rules. A `:name` sequence
/// inside any of the following constructs is copied through untouched:
///
/// - `-- line comments`, up to the end of the line.
/// - `/* block comments */`. These nest in PostgreSQL, so in
///   `/* a /* b */ :x */` the `:x` is still comment text.
/// - `'string literals'`, where `''` is an embedded quote. Escape-string
///   literals (`E'...'` / `e'...'`) also treat `\` as an escape, so
///   `E'it\'s :x'` is a single literal.
/// - `"quoted identifiers"`.
/// - Dollar-quoted bodies, `$$ ... $$` or `$tag$ ... $tag$`. The tag follows
///   identifier rules, so it cannot start with a digit. A `$` that does not
///   open a dollar quote (e.g. in `$1`) is copied as ordinary text.
///
/// A `::` cast is never mistaken for a named parameter. A repeated name
/// reuses the index assigned at its first appearance. Existing `$N`
/// placeholders are not renumbered, so mixing them with `:name` params can
/// make both refer to the same index.
///
/// Known limitation: an array slice with a bare-identifier bound
/// (`arr[1:n]`) is lexically identical to a named parameter and is
/// rewritten.
///
/// Returns `(rewritten_sql, ordered_param_names)`.
fn rewrite_named_params(sql: &str) -> (String, Vec<String>) {
    let mut result = String::with_capacity(sql.len());
    let mut names: Vec<String> = Vec::new();
    let mut positions: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    let chars: Vec<char> = sql.chars().collect();
    let len = chars.len();
    let mut i = 0;

    // Characters that can continue an unquoted identifier or keyword.
    let is_ident_char = |ch: char| ch.is_alphanumeric() || ch == '_' || ch == '$';

    while i < len {
        let c = chars[i];
        let next = chars.get(i + 1).copied();

        // -- line comment: copied verbatim up to the newline. The newline
        // itself is then handled by the main loop.
        if c == '-' && next == Some('-') {
            while i < len && chars[i] != '\n' {
                result.push(chars[i]);
                i += 1;
            }
            continue;
        }

        // /* block comment */: PostgreSQL block comments nest, so track the
        // depth and leave only once every opener has been closed. An
        // unterminated comment consumes the rest of the input.
        if c == '/' && next == Some('*') {
            // The first iteration always sees the opening `/*`, so depth is
            // at least 1 before any `*/` is matched.
            let mut depth = 0usize;
            while i < len {
                let here = chars[i];
                let after = chars.get(i + 1).copied();
                if here == '/' && after == Some('*') {
                    depth += 1;
                    result.push_str("/*");
                    i += 2;
                } else if here == '*' && after == Some('/') {
                    depth -= 1;
                    result.push_str("*/");
                    i += 2;
                    if depth == 0 {
                        break;
                    }
                } else {
                    result.push(here);
                    i += 1;
                }
            }
            continue;
        }

        // 'string literal': `''` is an embedded quote. An `E`/`e` prefix that
        // forms its own token (not the last letter of a longer word) makes
        // this an escape string, where a backslash escapes the next character.
        if c == '\'' {
            let backslash_escapes = i >= 1
                && matches!(chars[i - 1], 'e' | 'E')
                && (i < 2 || !is_ident_char(chars[i - 2]));
            result.push('\'');
            i += 1;
            while i < len {
                let ch = chars[i];
                result.push(ch);
                i += 1;
                if backslash_escapes && ch == '\\' {
                    // Copy the escaped character so `\'` cannot end the literal.
                    if i < len {
                        result.push(chars[i]);
                        i += 1;
                    }
                } else if ch == '\'' {
                    if i < len && chars[i] == '\'' {
                        result.push('\'');
                        i += 1;
                    } else {
                        break;
                    }
                }
            }
            continue;
        }

        // "quoted identifier": copied verbatim. A doubled `""` closes and
        // immediately reopens here, which copies the same text.
        if c == '"' {
            result.push('"');
            i += 1;
            while i < len {
                let ch = chars[i];
                result.push(ch);
                i += 1;
                if ch == '"' {
                    break;
                }
            }
            continue;
        }

        // $tag$ dollar-quoted body $tag$ (the tag may be empty: $$ ... $$).
        if c == '$' {
            let mut end = i + 1;
            // A tag cannot start with a digit, so `$1...` never opens a quote.
            if !matches!(next, Some(d) if d.is_ascii_digit()) {
                while end < len && (chars[end].is_alphanumeric() || chars[end] == '_') {
                    end += 1;
                }
            }
            if end < len && chars[end] == '$' {
                let tag = &chars[i..=end];
                result.extend(tag);
                i = end + 1;
                // Copy the body and the closing tag. An unterminated quote
                // consumes the rest of the input.
                while i < len {
                    if chars[i..].starts_with(tag) {
                        result.extend(tag);
                        i += tag.len();
                        break;
                    }
                    result.push(chars[i]);
                    i += 1;
                }
            } else {
                // Not a dollar-quote opener (e.g. the `$1` positional
                // placeholder). Emit the `$` and let the main loop handle
                // whatever follows.
                result.push('$');
                i += 1;
            }
            continue;
        }

        // `::` cast: copied through so its second colon is never read as the
        // start of a named parameter.
        if c == ':' && next == Some(':') {
            result.push_str("::");
            i += 2;
            continue;
        }

        // `:name` named parameter -> `$N`, numbered by first appearance.
        if c == ':' && matches!(next, Some(n) if n.is_alphabetic() || n == '_') {
            let start = i + 1;
            let mut end = start;
            while end < len && (chars[end].is_alphanumeric() || chars[end] == '_') {
                end += 1;
            }
            let name: String = chars[start..end].iter().collect();
            let pos = match positions.get(&name).copied() {
                Some(existing) => existing,
                None => {
                    names.push(name.clone());
                    let pos = names.len();
                    positions.insert(name, pos);
                    pos
                }
            };
            result.push('$');
            result.push_str(&pos.to_string());
            i = end;
            continue;
        }

        result.push(c);
        i += 1;
    }

    (result, names)
}
1032
/// Parse a `postgres://` or `postgresql://` connection URI.
///
/// Accepts `scheme://[user[:password]@]host[:port][/database][?params]` and
/// returns `(user, password, host, port, database)`. Returns `None` only when
/// the scheme is not one of the two PostgreSQL schemes.
///
/// - Without an `@`, the credentials default to `postgres` / `postgres`. A
///   user given without `:password` gets an empty password.
/// - A missing or unparsable port defaults to `5432`.
/// - A missing or empty database defaults to `postgres`.
/// - `?key=value` connection parameters (e.g. `sslmode`) are dropped instead
///   of leaking into the host or database name.
/// - User, password and database are percent-decoded (`%40` -> `@`), as in
///   libpq, so credentials containing reserved characters work. A malformed
///   `%` sequence is kept literally.
///
/// The userinfo ends at the first `@` in the URI.
fn parse_pg_uri(uri: &str) -> Option<(String, String, String, u16, String)> {
    // Decode `%XX` escapes. Anything that is not a valid escape is copied
    // as-is, and invalid UTF-8 in the result is replaced lossily.
    fn percent_decode(component: &str) -> String {
        if !component.contains('%') {
            return component.to_string();
        }
        let bytes = component.as_bytes();
        let mut out = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'%' {
                if let Some(&[hi, lo]) = bytes.get(i + 1..i + 3) {
                    if let (Some(h), Some(l)) = (char::from(hi).to_digit(16), char::from(lo).to_digit(16)) {
                        // h and l are both <= 15, so the value fits in a u8.
                        out.push((h * 16 + l) as u8);
                        i += 3;
                        continue;
                    }
                }
            }
            out.push(bytes[i]);
            i += 1;
        }
        String::from_utf8_lossy(&out).into_owned()
    }

    let rest = uri
        .strip_prefix("postgres://")
        .or_else(|| uri.strip_prefix("postgresql://"))?;
    let (auth, hostdb) = rest.split_once('@').unwrap_or(("postgres:postgres", rest));
    let (user, password) = auth.split_once(':').unwrap_or((auth, ""));
    // Drop `?key=value` connection parameters before splitting host and path.
    let hostdb = hostdb.split_once('?').map_or(hostdb, |(before, _params)| before);
    let (hostport, database) = hostdb.split_once('/').unwrap_or((hostdb, "postgres"));
    let database = if database.is_empty() { "postgres" } else { database };
    let (host, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
    let port: u16 = port_str.parse().unwrap_or(5432);
    Some((
        percent_decode(user),
        percent_decode(password),
        host.to_string(),
        port,
        percent_decode(database),
    ))
}
1050
1051// ---------------------------------------------------------------------------
1052// Tests
1053// ---------------------------------------------------------------------------
1054
// Unit tests for the helpers above: column type overrides, named-parameter
// rewriting, connection-URI parsing, SQL hashing, and the JSON shape of
// offline cache entries. None of these tests connects to a database; the
// cache round-trip test only touches a temporary directory.
#[cfg(test)]
mod tests {
    use super::*;

    // -- parse_type_override tests --

    // `"name: Type"` splits into the trimmed name plus a parsed type.
    #[test]
    fn test_type_override_basic() {
        let (name, ty) = parse_type_override("id: UserId");
        assert_eq!(name, "id");
        assert!(ty.is_some());
    }

    // The `::` inside a Rust path must not be taken as the separator.
    #[test]
    fn test_type_override_with_module_path() {
        let (name, ty) = parse_type_override("id: crate::types::UserId");
        assert_eq!(name, "id");
        assert!(ty.is_some());
    }

    #[test]
    fn test_type_override_no_override() {
        let (name, ty) = parse_type_override("user_name");
        assert_eq!(name, "user_name");
        assert!(ty.is_none());
    }

    #[test]
    fn test_type_override_skips_double_colon_cast() {
        let (name, ty) = parse_type_override("created_at::text");
        assert_eq!(name, "created_at::text");
        assert!(ty.is_none(), ":: should not trigger type override");
    }

    // An unparsable type leaves the whole column name untouched.
    #[test]
    fn test_type_override_invalid_type_string() {
        let (name, ty) = parse_type_override("col: 123invalid");
        assert_eq!(name, "col: 123invalid");
        assert!(ty.is_none(), "invalid Rust type should fall back");
    }

    #[test]
    fn test_type_override_empty_after_colon() {
        let (name, ty) = parse_type_override("col:");
        assert_eq!(name, "col:");
        assert!(ty.is_none());
    }

    // Whitespace around both the name and the type is trimmed.
    #[test]
    fn test_type_override_with_spaces() {
        let (name, ty) = parse_type_override("  id  :  UserId  ");
        assert_eq!(name, "id");
        assert!(ty.is_some());
    }

    #[test]
    fn test_type_override_option_type() {
        let (name, ty) = parse_type_override("email: Option<String>");
        assert_eq!(name, "email");
        assert!(ty.is_some());
    }

    #[test]
    fn test_type_override_vec_type() {
        let (name, ty) = parse_type_override("tags: Vec<String>");
        assert_eq!(name, "tags");
        assert!(ty.is_some());
    }

    // -- rewrite_named_params tests --

    #[test]
    fn test_named_params_basic() {
        let (sql, names) = rewrite_named_params("SELECT :id, :name");
        assert_eq!(sql, "SELECT $1, $2");
        assert_eq!(names, vec!["id", "name"]);
    }

    // A repeated name maps to the index of its first occurrence.
    #[test]
    fn test_named_params_duplicate() {
        let (sql, names) = rewrite_named_params("SELECT :id WHERE :id > 0");
        assert_eq!(sql, "SELECT $1 WHERE $1 > 0");
        assert_eq!(names, vec!["id"]);
    }

    // The name stops before `::`, and the cast is kept.
    #[test]
    fn test_named_params_with_cast() {
        let (sql, names) = rewrite_named_params("SELECT :val::int4");
        assert_eq!(sql, "SELECT $1::int4");
        assert_eq!(names, vec!["val"]);
    }

    #[test]
    fn test_named_params_in_string_literal() {
        let (sql, names) = rewrite_named_params("SELECT ':not_a_param'");
        assert_eq!(sql, "SELECT ':not_a_param'");
        assert!(names.is_empty());
    }

    #[test]
    fn test_named_params_empty() {
        let (sql, names) = rewrite_named_params("SELECT 1");
        assert_eq!(sql, "SELECT 1");
        assert!(names.is_empty());
    }

    #[test]
    fn test_named_params_underscore_prefix() {
        let (sql, names) = rewrite_named_params("SELECT :_private");
        assert_eq!(sql, "SELECT $1");
        assert_eq!(names, vec!["_private"]);
    }

    // -- parse_pg_uri tests --
    // The tuple is (user, password, host, port, database).

    #[test]
    fn test_parse_uri_full() {
        let (u, p, h, port, db) = parse_pg_uri("postgres://user:pass@host:1234/mydb").unwrap();
        assert_eq!(u, "user");
        assert_eq!(p, "pass");
        assert_eq!(h, "host");
        assert_eq!(port, 1234);
        assert_eq!(db, "mydb");
    }

    // With no `:port` in the authority, the port falls back to 5432.
    #[test]
    fn test_parse_uri_defaults() {
        let (u, p, h, port, db) = parse_pg_uri("postgres://user:pass@localhost/mydb").unwrap();
        assert_eq!(h, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(u, "user");
        assert_eq!(p, "pass");
        assert_eq!(db, "mydb");
    }

    // Any scheme other than postgres:// or postgresql:// is rejected.
    #[test]
    fn test_parse_uri_invalid() {
        assert!(parse_pg_uri("mysql://user:pass@host/db").is_none());
    }

    #[test]
    fn test_parse_uri_postgresql_scheme() {
        let parsed = parse_pg_uri("postgresql://user:pass@host:5433/mydb").unwrap();
        assert_eq!(parsed.0, "user");
        assert_eq!(parsed.3, 5433);
        assert_eq!(parsed.4, "mydb");
    }

    // Userinfo without a `:` yields an empty password.
    #[test]
    fn test_parse_uri_empty_password() {
        let parsed = parse_pg_uri("postgres://user@host/db").unwrap();
        assert_eq!(parsed.0, "user");
        assert_eq!(parsed.1, "");
        assert_eq!(parsed.2, "host");
    }

    #[test]
    fn test_parse_uri_unset_database_defaults_to_postgres() {
        let parsed = parse_pg_uri("postgres://user:pass@host").unwrap();
        assert_eq!(parsed.4, "postgres");
    }

    // -- rewrite_named_params: comments and dollar quotes --

    #[test]
    fn test_named_params_line_comment_skipped() {
        let (sql, names) = rewrite_named_params("SELECT :id -- :bogus\nFROM t");
        assert_eq!(sql, "SELECT $1 -- :bogus\nFROM t");
        assert_eq!(names, vec!["id"]);
    }

    #[test]
    fn test_named_params_block_comment_skipped() {
        let (sql, names) = rewrite_named_params("SELECT :id /* :bogus */ FROM t");
        assert_eq!(sql, "SELECT $1 /* :bogus */ FROM t");
        assert_eq!(names, vec!["id"]);
    }

    #[test]
    fn test_named_params_dollar_quoted_body_skipped() {
        let (sql, names) = rewrite_named_params("SELECT $$ :ignored $$ WHERE id = :id");
        assert_eq!(sql, "SELECT $$ :ignored $$ WHERE id = $1");
        assert_eq!(names, vec!["id"]);
    }

    #[test]
    fn test_named_params_tagged_dollar_quote_skipped() {
        let (sql, names) = rewrite_named_params("SELECT $tag$ :ignored $tag$ WHERE id = :id");
        assert_eq!(sql, "SELECT $tag$ :ignored $tag$ WHERE id = $1");
        assert_eq!(names, vec!["id"]);
    }

    #[test]
    fn test_named_params_quoted_identifier_skipped() {
        let (sql, names) = rewrite_named_params(r#"SELECT ":col" FROM t WHERE id = :id"#);
        assert_eq!(sql, r#"SELECT ":col" FROM t WHERE id = $1"#);
        assert_eq!(names, vec!["id"]);
    }

    // An existing `$1` is passed through unchanged, and named params are
    // numbered independently of it, so `:id` also becomes `$1`.
    // NOTE(review): this pins the current behaviour of mixing positional and
    // named placeholders; confirm that sharing the index is intended.
    #[test]
    fn test_named_params_positional_dollar_param_passthrough() {
        let (sql, names) = rewrite_named_params("SELECT $1, :id FROM t");
        assert_eq!(sql, "SELECT $1, $1 FROM t");
        assert_eq!(names, vec!["id"]);
    }

    // `''` inside a literal is an escaped quote and does not end the string.
    #[test]
    fn test_named_params_escaped_single_quote_inside_literal() {
        let (sql, names) = rewrite_named_params("SELECT 'it''s :nothing' , :real");
        assert_eq!(sql, "SELECT 'it''s :nothing' , $1");
        assert_eq!(names, vec!["real"]);
    }

    // -- hash_sql --

    #[test]
    fn test_hash_sql_stable() {
        let sql = "SELECT id FROM t WHERE x = $1";
        assert_eq!(hash_sql(sql), hash_sql(sql));
    }

    #[test]
    fn test_hash_sql_differs_by_content() {
        assert_ne!(hash_sql("SELECT 1"), hash_sql("SELECT 2"));
    }

    #[test]
    fn test_hash_sql_empty() {
        // FNV-1a offset basis for the empty input.
        assert_eq!(hash_sql(""), 0xcbf29ce484222325);
    }

    // -- cache round-trip --

    // Serialise a CacheEntry to pretty JSON on disk, read it back, and check
    // that every field survives.
    #[test]
    fn test_cache_roundtrip() {
        // Per-run directory under the system temp dir, named with the current
        // time in nanoseconds so repeated runs do not collide.
        let tmp = std::env::temp_dir().join(format!(
            "resolute-macros-cache-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        std::fs::create_dir_all(&tmp).unwrap();
        let path = tmp.join("query.json");

        let entry = cache::CacheEntry {
            sql: "SELECT 1::int4 AS n".into(),
            hash: 0xdeadbeef_cafebabe,
            param_oids: vec![23, 25],
            columns: vec![cache::CachedColumn {
                name: "n".into(),
                type_oid: 23,
                nullable: true,
            }],
        };

        let json = serde_json::to_string_pretty(&entry).unwrap();
        std::fs::write(&path, &json).unwrap();
        let raw = std::fs::read_to_string(&path).unwrap();
        let decoded: cache::CacheEntry = serde_json::from_str(&raw).unwrap();

        assert_eq!(decoded.sql, entry.sql);
        assert_eq!(decoded.hash, entry.hash);
        assert_eq!(decoded.param_oids, entry.param_oids);
        assert_eq!(decoded.columns.len(), 1);
        assert_eq!(decoded.columns[0].name, "n");
        assert_eq!(decoded.columns[0].type_oid, 23);
        assert!(decoded.columns[0].nullable);

        // Best-effort cleanup. It is skipped if an assertion above panics.
        std::fs::remove_dir_all(&tmp).ok();
    }

    #[test]
    fn test_cache_entry_missing_nullable_defaults_to_false() {
        // Old cache files written before the `nullable` field existed must
        // still deserialize — the field is `#[serde(default)]`.
        let legacy = r#"{
            "sql": "SELECT 1",
            "hash": 1,
            "param_oids": [],
            "columns": [{"name": "n", "type_oid": 23}]
        }"#;
        let entry: cache::CacheEntry = serde_json::from_str(legacy).unwrap();
        assert_eq!(entry.columns.len(), 1);
        assert!(!entry.columns[0].nullable);
    }
}