1use proc_macro::TokenStream;
19use quote::{format_ident, quote};
20use syn::parse::{Parse, ParseStream};
21use syn::{parse_macro_input, Expr, LitStr, Token};
22
23mod cache;
24
/// Parsed arguments of the macros that take no target type (`query!`,
/// `query_scalar!`, `query_unchecked!`, `query_file!`, `query_file_scalar!`).
struct QueryInput {
    /// The leading string literal: SQL text, or a file path for the
    /// `query_file*` macros.
    sql: LitStr,
    /// Positional parameter expressions, in the order they were written.
    params: Vec<Expr>,
    /// `name = expr` bindings, in the order they were written.
    named: Vec<(syn::Ident, Expr)>,
}
33
34impl Parse for QueryInput {
35 fn parse(input: ParseStream) -> syn::Result<Self> {
36 let sql: LitStr = input.parse()?;
37 let mut params = Vec::new();
38 let mut named = Vec::new();
39 let mut mode: Option<bool> = None; while input.peek(Token![,]) {
42 input.parse::<Token![,]>()?;
43 if input.is_empty() {
44 break;
45 }
46
47 let is_named_param = {
49 let fork = input.fork();
50 fork.parse::<syn::Ident>().is_ok()
51 && fork.parse::<Token![=]>().is_ok()
52 && !fork.peek(Token![=])
53 };
54
55 if is_named_param && mode != Some(false) {
56 let name: syn::Ident = input.parse()?;
57 input.parse::<Token![=]>()?;
58 let expr: Expr = input.parse()?;
59 named.push((name, expr));
60 mode = Some(true);
61 } else if mode != Some(true) {
62 params.push(input.parse()?);
63 mode = Some(false);
64 } else {
65 return Err(input.error("cannot mix positional and named parameters"));
66 }
67 }
68
69 Ok(QueryInput { sql, params, named })
70 }
71}
72
73fn resolve_named(
76 sql_str: String,
77 params: Vec<Expr>,
78 named: &[(syn::Ident, Expr)],
79 sql_span: &LitStr,
80) -> Result<(String, Vec<Expr>), TokenStream> {
81 if named.is_empty() {
82 return Ok((sql_str, params));
83 }
84 let (rewritten, names) = rewrite_named_params(&sql_str);
85 let mut ordered = Vec::with_capacity(names.len());
86 for name in &names {
87 match named.iter().find(|(n, _)| n == name) {
88 Some((_, expr)) => ordered.push(expr.clone()),
89 None => {
90 let msg = format!("named parameter `:{name}` in SQL has no binding");
91 return Err(syn::Error::new_spanned(sql_span, msg)
92 .to_compile_error()
93 .into());
94 }
95 }
96 }
97 for (n, _) in named {
98 if !names.iter().any(|name| name == &n.to_string()) {
99 let msg = format!("binding `{}` does not match any `:{}` in SQL", n, n);
100 return Err(syn::Error::new_spanned(n, msg).to_compile_error().into());
101 }
102 }
103 Ok((rewritten, ordered))
104}
105
106fn resolve_metadata(sql: &str) -> Result<(Vec<u32>, Vec<cache::CachedColumn>), String> {
108 let sql_hash = hash_sql(sql);
109 let offline = std::env::var("RESOLUTE_OFFLINE")
110 .map(|v| v == "true" || v == "1")
111 .unwrap_or(false);
112
113 if let Some(cached) = cache::read_cache(sql_hash) {
115 return Ok((cached.param_oids, cached.columns));
116 }
117
118 if offline {
120 return Err(format!(
121 "RESOLUTE_OFFLINE=true but no cached metadata for query (hash {sql_hash:x}). \
122 Run `resolute-cli prepare` to populate the cache."
123 ));
124 }
125
126 let (param_oids, columns) = describe_live(sql)?;
128
129 let entry = cache::CacheEntry {
131 sql: sql.to_string(),
132 hash: sql_hash,
133 param_oids: param_oids.clone(),
134 columns: columns.clone(),
135 };
136 if let Err(e) = cache::write_cache(&entry) {
137 eprintln!("resolute: warning: failed to write cache: {e}");
139 }
140
141 Ok((param_oids, columns))
142}
143
144fn describe_live(sql: &str) -> Result<(Vec<u32>, Vec<cache::CachedColumn>), String> {
146 let db_url = std::env::var("DATABASE_URL").map_err(|_| {
147 "DATABASE_URL not set and no cached metadata found. \
148 Set DATABASE_URL or run `resolute-cli prepare`."
149 .to_string()
150 })?;
151
152 let (user, password, host, port, database) = parse_pg_uri(&db_url)
153 .ok_or_else(|| "Invalid DATABASE_URL (could not parse as postgres:// URI)".to_string())?;
154 let addr = format!("{host}:{port}");
155
156 let rt = tokio::runtime::Builder::new_current_thread()
157 .enable_all()
158 .build()
159 .map_err(|e| format!("Failed to create tokio runtime: {e}"))?;
160
161 rt.block_on(async {
162 let mut conn = pg_wired::WireConn::connect(&addr, &user, &password, &database)
163 .await
164 .map_err(|e| format!("Failed to connect to database: {e}"))?;
165
166 let (param_oids, fields) = conn
167 .describe_statement(sql)
168 .await
169 .map_err(|e| format!("SQL error: {e}"))?;
170
171 let mut columns: Vec<cache::CachedColumn> = fields
174 .iter()
175 .map(|f| cache::CachedColumn {
176 name: f.name.clone(),
177 type_oid: f.type_oid,
178 nullable: true, })
180 .collect();
181
182 let table_cols: Vec<(usize, u32, i16)> = fields
184 .iter()
185 .enumerate()
186 .filter(|(_, f)| f.table_oid != 0 && f.column_id > 0)
187 .map(|(i, f)| (i, f.table_oid, f.column_id))
188 .collect();
189
190 if !table_cols.is_empty() {
191 let conditions: Vec<String> = table_cols
193 .iter()
194 .map(|(_, oid, col)| format!("(attrelid={oid} AND attnum={col})"))
195 .collect();
196 let null_sql = format!(
197 "SELECT attrelid, attnum, attnotnull FROM pg_attribute WHERE {}",
198 conditions.join(" OR ")
199 );
200
201 let mut buf = bytes::BytesMut::new();
203 pg_wired::protocol::frontend::encode_message(
204 &pg_wired::protocol::types::FrontendMsg::Query(null_sql.as_bytes()),
205 &mut buf,
206 );
207 if conn.send_raw(&buf).await.is_ok() {
208 if let Ok((rows, _)) = conn.collect_rows().await {
209 for row in &rows {
210 let oid: u32 = row
211 .cell(0)
212 .and_then(|b| std::str::from_utf8(b).ok())
213 .and_then(|s| s.parse().ok())
214 .unwrap_or(0);
215 let col: i16 = row
216 .cell(1)
217 .and_then(|b| std::str::from_utf8(b).ok())
218 .and_then(|s| s.parse().ok())
219 .unwrap_or(0);
220 let notnull: bool =
221 row.cell(2).map(|b| b == b"t".as_ref()).unwrap_or(false);
222
223 for &(idx, t_oid, t_col) in &table_cols {
225 if t_oid == oid && t_col == col && notnull {
226 columns[idx].nullable = false;
227 }
228 }
229 }
230 }
231 }
232 }
233
234 Ok((param_oids, columns))
235 })
236}
237
238fn oid_to_rust_type(oid: u32) -> Result<proc_macro2::TokenStream, String> {
244 let ty = match oid {
245 16 => quote! { bool },
247 18 | 19 | 25 | 1042 | 1043 => quote! { String },
248 20 => quote! { i64 },
249 21 => quote! { i16 },
250 23 | 26 => quote! { i32 },
251 700 => quote! { f32 },
252 701 => quote! { f64 },
253 17 => quote! { Vec<u8> },
254 869 => quote! { resolute::PgInet },
255 1700 => quote! { resolute::PgNumeric },
256 1000 => quote! { Vec<bool> },
258 1005 => quote! { Vec<i16> },
259 1007 => quote! { Vec<i32> },
260 1009 | 1015 => quote! { Vec<String> },
261 1016 => quote! { Vec<i64> },
262 1021 => quote! { Vec<f32> },
263 1022 => quote! { Vec<f64> },
264 1041 => quote! { Vec<resolute::PgInet> },
265 1231 => quote! { Vec<resolute::PgNumeric> },
266 3904 => quote! { resolute::PgRange<i32> },
268 3926 => quote! { resolute::PgRange<i64> },
269 3906 => quote! { resolute::PgRange<resolute::PgNumeric> },
270 #[cfg(feature = "json")]
272 114 | 3802 => quote! { serde_json::Value },
273 #[cfg(feature = "json")]
274 3807 => quote! { Vec<serde_json::Value> },
275 #[cfg(not(feature = "json"))]
276 114 | 3802 | 3807 => {
277 return Err(format!(
278 "column type `{}` requires the `json` feature, which is disabled. \
279 Enable `resolute/json` in your Cargo.toml to use JSON/JSONB columns.",
280 oid_to_type_name(oid)
281 ));
282 }
283 #[cfg(feature = "chrono")]
285 1082 => quote! { chrono::NaiveDate },
286 #[cfg(feature = "chrono")]
287 1083 => quote! { chrono::NaiveTime },
288 #[cfg(feature = "chrono")]
289 1114 => quote! { chrono::NaiveDateTime },
290 #[cfg(feature = "chrono")]
291 1184 => quote! { chrono::DateTime<chrono::Utc> },
292 #[cfg(feature = "chrono")]
293 1115 => quote! { Vec<chrono::NaiveDateTime> },
294 #[cfg(feature = "chrono")]
295 1182 => quote! { Vec<chrono::NaiveDate> },
296 #[cfg(feature = "chrono")]
297 1183 => quote! { Vec<chrono::NaiveTime> },
298 #[cfg(feature = "chrono")]
299 1185 => quote! { Vec<chrono::DateTime<chrono::Utc>> },
300 #[cfg(feature = "chrono")]
301 3912 => quote! { resolute::PgRange<chrono::NaiveDate> },
302 #[cfg(feature = "chrono")]
303 3908 => quote! { resolute::PgRange<chrono::NaiveDateTime> },
304 #[cfg(feature = "chrono")]
305 3910 => quote! { resolute::PgRange<chrono::DateTime<chrono::Utc>> },
306 #[cfg(not(feature = "chrono"))]
307 1082 | 1083 | 1114 | 1184 | 1115 | 1182 | 1183 | 1185 | 3912 | 3908 | 3910 => {
308 return Err(format!(
309 "column type `{}` requires the `chrono` feature, which is disabled. \
310 Enable `resolute/chrono` in your Cargo.toml to use date/time columns.",
311 oid_to_type_name(oid)
312 ));
313 }
314 #[cfg(feature = "uuid")]
316 2950 => quote! { uuid::Uuid },
317 #[cfg(feature = "uuid")]
318 2951 => quote! { Vec<uuid::Uuid> },
319 #[cfg(not(feature = "uuid"))]
320 2950 | 2951 => {
321 return Err(format!(
322 "column type `{}` requires the `uuid` feature, which is disabled. \
323 Enable `resolute/uuid` in your Cargo.toml to use UUID columns.",
324 oid_to_type_name(oid)
325 ));
326 }
327 _ => quote! { Vec<u8> },
328 };
329 Ok(ty)
330}
331
332#[proc_macro]
334pub fn query(input: TokenStream) -> TokenStream {
335 let parsed = parse_macro_input!(input as QueryInput);
336 query_impl(parsed)
337}
338
339fn query_impl(input: QueryInput) -> TokenStream {
340 let QueryInput { sql, params, named } = input;
341 let sql_str = sql.value();
342
343 let (sql_str, params) = match resolve_named(sql_str, params, &named, &sql) {
344 Ok(v) => v,
345 Err(ts) => return ts,
346 };
347
348 let (param_oids, column_infos) = match resolve_metadata(&sql_str) {
349 Ok(result) => result,
350 Err(e) => {
351 return syn::Error::new_spanned(&sql, e).to_compile_error().into();
352 }
353 };
354
355 if params.len() != param_oids.len() {
356 let msg = format!(
357 "expected {} parameter(s), got {}",
358 param_oids.len(),
359 params.len()
360 );
361 return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
362 }
363
364 let param_type_checks: Vec<_> = params
368 .iter()
369 .map(|param| {
370 quote! {
371 {
372 fn __resolute_check_param<T: resolute::Encode + Sync>(_: &T) {}
373 __resolute_check_param(&#param);
374 let _ = &#param as &dyn resolute::SqlParam;
375 }
376 }
377 })
378 .collect();
379
380 let overrides: Vec<_> = column_infos
382 .iter()
383 .map(|c| parse_type_override(&c.name))
384 .collect();
385
386 let field_names: Vec<_> = overrides
387 .iter()
388 .map(|(name, _)| format_ident!("{}", sanitize_ident(name)))
389 .collect();
390 let field_types: Vec<_> = match column_infos
391 .iter()
392 .zip(overrides.iter())
393 .map(
394 |(c, (_, type_override))| -> Result<proc_macro2::TokenStream, String> {
395 let base = if let Some(ref custom) = type_override {
396 custom.clone()
397 } else {
398 oid_to_rust_type(c.type_oid)?
399 };
400 Ok(if c.nullable {
401 quote! { Option<#base> }
402 } else {
403 base
404 })
405 },
406 )
407 .collect::<Result<Vec<_>, String>>()
408 {
409 Ok(v) => v,
410 Err(e) => return syn::Error::new_spanned(&sql, e).to_compile_error().into(),
411 };
412 let _field_indices: Vec<_> = (0..column_infos.len()).collect::<Vec<_>>();
413 let field_getters: Vec<_> = column_infos
414 .iter()
415 .enumerate()
416 .map(|(i, c)| {
417 if c.nullable {
418 quote! { row.get_opt(#i)? }
419 } else {
420 quote! { row.get(#i)? }
421 }
422 })
423 .collect();
424
425 let struct_name = format_ident!("__QueryResult_{}", hash_sql(&sql_str));
426
427 let param_refs: Vec<_> = params
428 .iter()
429 .map(|p| quote! { &#p as &dyn resolute::SqlParam })
430 .collect();
431
432 let sql_lit_rewritten = LitStr::new(&sql_str, sql.span());
434
435 let expanded = quote! {
436 {
437 #(#param_type_checks)*
439
440 #[allow(non_camel_case_types)]
441 #[derive(Debug)]
442 struct #struct_name {
443 #(pub #field_names: #field_types,)*
444 }
445
446 resolute::CheckedQuery::<#struct_name> {
447 sql: #sql_lit_rewritten,
448 params: vec![#(#param_refs),*],
449 _marker: std::marker::PhantomData,
450 mapper: |row: &resolute::Row| -> Result<#struct_name, resolute::TypedError> {
451 Ok(#struct_name {
452 #(#field_names: #field_getters,)*
453 })
454 },
455 }
456 }
457 };
458
459 TokenStream::from(expanded)
460}
461
462#[proc_macro]
465pub fn query_as(input: TokenStream) -> TokenStream {
466 let parsed = parse_macro_input!(input as QueryAsInput);
467 query_as_impl(parsed)
468}
469
470fn query_as_impl(input: QueryAsInput) -> TokenStream {
471 let QueryAsInput {
472 target_type,
473 sql,
474 params,
475 named,
476 } = input;
477 let sql_str = sql.value();
478
479 let (sql_str, params) = match resolve_named(sql_str, params, &named, &sql) {
480 Ok(v) => v,
481 Err(ts) => return ts,
482 };
483
484 let (param_oids, _column_infos) = match resolve_metadata(&sql_str) {
485 Ok(result) => result,
486 Err(e) => {
487 return syn::Error::new_spanned(&sql, e).to_compile_error().into();
488 }
489 };
490
491 if params.len() != param_oids.len() {
492 let msg = format!(
493 "expected {} parameter(s), got {}",
494 param_oids.len(),
495 params.len()
496 );
497 return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
498 }
499
500 let param_refs: Vec<_> = params
501 .iter()
502 .map(|p| quote! { &#p as &dyn resolute::SqlParam })
503 .collect();
504 let sql_lit_rewritten = LitStr::new(&sql_str, sql.span());
505
506 let expanded = quote! {
507 {
508 resolute::CheckedQuery::<#target_type> {
509 sql: #sql_lit_rewritten,
510 params: vec![#(#param_refs),*],
511 _marker: std::marker::PhantomData,
512 mapper: |row: &resolute::Row| -> Result<#target_type, resolute::TypedError> {
513 <#target_type as resolute::FromRow>::from_row(row)
514 },
515 }
516 }
517 };
518
519 TokenStream::from(expanded)
520}
521
522#[proc_macro]
524pub fn query_scalar(input: TokenStream) -> TokenStream {
525 let parsed = parse_macro_input!(input as QueryInput);
526 query_scalar_impl(parsed)
527}
528
529fn query_scalar_impl(input: QueryInput) -> TokenStream {
530 let QueryInput { sql, params, named } = input;
531 let sql_str = sql.value();
532
533 let (sql_str, params) = match resolve_named(sql_str, params, &named, &sql) {
534 Ok(v) => v,
535 Err(ts) => return ts,
536 };
537
538 let (param_oids, column_infos) = match resolve_metadata(&sql_str) {
539 Ok(result) => result,
540 Err(e) => {
541 return syn::Error::new_spanned(&sql, e).to_compile_error().into();
542 }
543 };
544
545 if params.len() != param_oids.len() {
546 let msg = format!(
547 "expected {} parameter(s), got {}",
548 param_oids.len(),
549 params.len()
550 );
551 return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
552 }
553
554 if column_infos.len() != 1 {
555 let msg = format!(
556 "query_scalar! requires exactly 1 column, got {}",
557 column_infos.len()
558 );
559 return syn::Error::new_spanned(&sql, msg).to_compile_error().into();
560 }
561
562 let scalar_type = {
563 let (_, type_override) = parse_type_override(&column_infos[0].name);
564 match type_override {
565 Some(ty) => ty,
566 None => match oid_to_rust_type(column_infos[0].type_oid) {
567 Ok(ty) => ty,
568 Err(e) => return syn::Error::new_spanned(&sql, e).to_compile_error().into(),
569 },
570 }
571 };
572 let param_refs: Vec<_> = params
573 .iter()
574 .map(|p| quote! { &#p as &dyn resolute::SqlParam })
575 .collect();
576 let sql_lit_rewritten = LitStr::new(&sql_str, sql.span());
577
578 let expanded = quote! {
579 {
580 resolute::CheckedQuery::<#scalar_type> {
581 sql: #sql_lit_rewritten,
582 params: vec![#(#param_refs),*],
583 _marker: std::marker::PhantomData,
584 mapper: |row: &resolute::Row| -> Result<#scalar_type, resolute::TypedError> {
585 row.get(0)
586 },
587 }
588 }
589 };
590
591 TokenStream::from(expanded)
592}
593
/// Parsed arguments of the macros that take a target type first
/// (`query_as!`, `query_file_as!`).
struct QueryAsInput {
    /// The Rust type rows are mapped into.
    target_type: syn::Type,
    /// The string literal following the type: SQL text, or a file path for
    /// `query_file_as!`.
    sql: LitStr,
    /// Positional parameter expressions, in the order they were written.
    params: Vec<Expr>,
    /// `name = expr` bindings, in the order they were written.
    named: Vec<(syn::Ident, Expr)>,
}
601
602impl Parse for QueryAsInput {
603 fn parse(input: ParseStream) -> syn::Result<Self> {
604 let target_type: syn::Type = input.parse()?;
605 input.parse::<Token![,]>()?;
606 let sql: LitStr = input.parse()?;
607 let mut params = Vec::new();
608 let mut named = Vec::new();
609 let mut mode: Option<bool> = None;
610 while input.peek(Token![,]) {
611 input.parse::<Token![,]>()?;
612 if input.is_empty() {
613 break;
614 }
615 let is_named_param = {
616 let fork = input.fork();
617 fork.parse::<syn::Ident>().is_ok()
618 && fork.parse::<Token![=]>().is_ok()
619 && !fork.peek(Token![=])
620 };
621 if is_named_param && mode != Some(false) {
622 let name: syn::Ident = input.parse()?;
623 input.parse::<Token![=]>()?;
624 let expr: Expr = input.parse()?;
625 named.push((name, expr));
626 mode = Some(true);
627 } else if mode != Some(true) {
628 params.push(input.parse()?);
629 mode = Some(false);
630 } else {
631 return Err(input.error("cannot mix positional and named parameters"));
632 }
633 }
634 Ok(QueryAsInput {
635 target_type,
636 sql,
637 params,
638 named,
639 })
640 }
641}
642
643#[proc_macro]
645pub fn query_file(input: TokenStream) -> TokenStream {
646 let QueryInput {
647 sql: path_lit,
648 params,
649 named: _,
650 } = parse_macro_input!(input as QueryInput);
651 let file_path = path_lit.value();
652
653 let sql_str = match read_sql_file(&file_path) {
654 Ok(s) => s,
655 Err(e) => {
656 return syn::Error::new_spanned(&path_lit, e)
657 .to_compile_error()
658 .into();
659 }
660 };
661
662 let sql_lit = LitStr::new(&sql_str, path_lit.span());
664 let inner = QueryInput {
665 sql: sql_lit,
666 params,
667 named: Vec::new(),
668 };
669 query_impl(inner)
670}
671
672#[proc_macro]
674pub fn query_file_as(input: TokenStream) -> TokenStream {
675 let QueryAsInput {
676 target_type,
677 sql: path_lit,
678 params,
679 named: _,
680 } = parse_macro_input!(input as QueryAsInput);
681 let file_path = path_lit.value();
682
683 let sql_str = match read_sql_file(&file_path) {
684 Ok(s) => s,
685 Err(e) => {
686 return syn::Error::new_spanned(&path_lit, e)
687 .to_compile_error()
688 .into();
689 }
690 };
691
692 let sql_lit = LitStr::new(&sql_str, path_lit.span());
693 let inner = QueryAsInput {
694 target_type,
695 sql: sql_lit,
696 params,
697 named: Vec::new(),
698 };
699 query_as_impl(inner)
700}
701
702#[proc_macro]
704pub fn query_file_scalar(input: TokenStream) -> TokenStream {
705 let QueryInput {
706 sql: path_lit,
707 params,
708 named: _,
709 } = parse_macro_input!(input as QueryInput);
710 let file_path = path_lit.value();
711
712 let sql_str = match read_sql_file(&file_path) {
713 Ok(s) => s,
714 Err(e) => {
715 return syn::Error::new_spanned(&path_lit, e)
716 .to_compile_error()
717 .into();
718 }
719 };
720
721 let sql_lit = LitStr::new(&sql_str, path_lit.span());
722 let inner = QueryInput {
723 sql: sql_lit,
724 params,
725 named: Vec::new(),
726 };
727 query_scalar_impl(inner)
728}
729
730#[proc_macro]
734pub fn query_unchecked(input: TokenStream) -> TokenStream {
735 let QueryInput {
736 sql,
737 params,
738 named: _,
739 } = parse_macro_input!(input as QueryInput);
740
741 let param_refs: Vec<_> = params
742 .iter()
743 .map(|p| quote! { &#p as &dyn resolute::SqlParam })
744 .collect();
745 let sql_literal = &sql;
746
747 let expanded = quote! {
748 {
749 resolute::UncheckedQuery {
750 sql: #sql_literal,
751 params: vec![#(#param_refs),*],
752 }
753 }
754 };
755
756 TokenStream::from(expanded)
757}
758
/// Reads a SQL file and returns its contents with surrounding whitespace
/// trimmed.
///
/// `path` is joined onto `CARGO_MANIFEST_DIR` (or `.` when that variable is
/// not set); an absolute `path` therefore replaces the base entirely.
///
/// # Errors
/// Returns a message containing the resolved path and the I/O error when the
/// file cannot be read.
fn read_sql_file(path: &str) -> Result<String, String> {
    let base_dir = match std::env::var("CARGO_MANIFEST_DIR") {
        Ok(dir) => dir,
        Err(_) => String::from("."),
    };
    let full_path = std::path::PathBuf::from(base_dir).join(path);

    match std::fs::read_to_string(&full_path) {
        Ok(contents) => Ok(contents.trim().to_owned()),
        Err(e) => Err(format!(
            "Failed to read SQL file {}: {e}",
            full_path.display()
        )),
    }
}
767
/// Returns the PostgreSQL type name used in diagnostics for `oid`, or
/// `"unknown"` when the OID is not in the table.
///
/// The only callers in this file are the feature-disabled arms of
/// `oid_to_rust_type`, hence the lint allowance.
#[allow(dead_code)]
fn oid_to_type_name(oid: u32) -> &'static str {
    /// `(oid, name)` pairs, ordered by OID.
    const NAMES: &[(u32, &str)] = &[
        (16, "bool"),
        (17, "bytea"),
        (18, "text"),
        (19, "text"),
        (20, "int8"),
        (21, "int2"),
        (23, "int4"),
        (25, "text"),
        (26, "oid"),
        (114, "json"),
        (700, "float4"),
        (701, "float8"),
        (869, "inet"),
        (1000, "bool[]"),
        (1005, "int2[]"),
        (1007, "int4[]"),
        (1009, "text[]"),
        (1015, "text[]"),
        (1016, "int8[]"),
        (1021, "float4[]"),
        (1022, "float8[]"),
        (1041, "inet[]"),
        (1042, "text"),
        (1043, "text"),
        (1082, "date"),
        (1083, "time"),
        (1114, "timestamp"),
        (1115, "timestamp[]"),
        (1182, "date[]"),
        (1183, "time[]"),
        (1184, "timestamptz"),
        (1185, "timestamptz[]"),
        (1231, "numeric[]"),
        (1700, "numeric"),
        (2950, "uuid"),
        (2951, "uuid[]"),
        (3802, "jsonb"),
        (3807, "jsonb[]"),
        (3904, "int4range"),
        (3906, "numrange"),
        (3908, "tsrange"),
        (3910, "tstzrange"),
        (3912, "daterange"),
        (3926, "int8range"),
    ];

    NAMES
        .iter()
        .find(|&&(candidate, _)| candidate == oid)
        .map_or("unknown", |&(_, name)| name)
}
816
/// Turns a result-column name into text usable as a struct field name.
///
/// * every character that is neither alphanumeric nor `_` becomes `_`;
/// * a result starting with an ASCII digit gets a leading `_`;
/// * an empty result, or a result that is just `_`, becomes `column`.
///
/// The lone-underscore case matters because `_` is a reserved identifier and
/// cannot name a field; a column called `?`, for example, used to produce it.
fn sanitize_ident(name: &str) -> String {
    let cleaned: String = name
        .chars()
        .map(|c| {
            if c.is_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();

    match cleaned.chars().next() {
        // Nothing usable is left: fall back to a fixed name.
        None => "column".to_string(),
        Some(_) if cleaned == "_" => "column".to_string(),
        Some(first) if first.is_ascii_digit() => format!("_{cleaned}"),
        Some(_) => cleaned,
    }
}
836
837fn parse_type_override(column_name: &str) -> (String, Option<proc_macro2::TokenStream>) {
843 let bytes = column_name.as_bytes();
844 for (i, &b) in bytes.iter().enumerate() {
845 if b == b':' {
846 let prev_colon = i > 0 && bytes[i - 1] == b':';
848 let next_colon = i + 1 < bytes.len() && bytes[i + 1] == b':';
849 if prev_colon || next_colon {
850 continue;
851 }
852 let name = column_name[..i].trim();
853 let type_str = column_name[i + 1..].trim();
854 if !type_str.is_empty() {
855 if let Ok(ty) = syn::parse_str::<syn::Type>(type_str) {
856 return (name.to_string(), Some(quote! { #ty }));
857 }
858 }
859 }
860 }
861 (column_name.to_string(), None)
862}
863
/// Hashes the bytes of `sql` with 64-bit FNV-1a (xor the byte in, then
/// multiply by the prime, with wrapping arithmetic).
///
/// The empty string hashes to the offset basis.
pub(crate) fn hash_sql(sql: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    sql.bytes()
        .fold(OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}
873
/// Rewrites `:name` placeholders in `sql` to positional `$n` placeholders.
///
/// Returns the rewritten SQL and the distinct parameter names in order of
/// first appearance; a name that occurs several times reuses its number.
///
/// Text inside the following constructs is copied verbatim and never
/// rewritten:
/// * `--` line comments (up to, not including, the newline);
/// * `/* ... */` block comments, including nested ones, as PostgreSQL nests
///   block comments;
/// * single-quoted literals (`''` is an escaped quote);
/// * double-quoted identifiers;
/// * dollar-quoted bodies (`$$...$$` or `$tag$...$tag$`).
///
/// `::` (the cast operator) is copied as-is, and a `$` that does not open a
/// dollar quote (for example `$1`) is passed through. Unterminated constructs
/// are copied to the end of the input.
fn rewrite_named_params(sql: &str) -> (String, Vec<String>) {
    let chars: Vec<char> = sql.chars().collect();
    let len = chars.len();

    let mut result = String::with_capacity(sql.len());
    let mut names: Vec<String> = Vec::new();
    // name -> 1-based placeholder number
    let mut positions: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    let mut i = 0;

    while i < len {
        let c = chars[i];
        let next = chars.get(i + 1).copied();

        // Line comment: copy up to (not including) the newline.
        if c == '-' && next == Some('-') {
            while i < len && chars[i] != '\n' {
                result.push(chars[i]);
                i += 1;
            }
            continue;
        }

        // Block comment: track nesting depth so that an inner `*/` does not
        // end the outer comment.
        if c == '/' && next == Some('*') {
            result.push_str("/*");
            i += 2;
            let mut depth = 1usize;
            while i < len && depth > 0 {
                let pair_follows = i + 1 < len;
                if pair_follows && chars[i] == '/' && chars[i + 1] == '*' {
                    depth += 1;
                    result.push_str("/*");
                    i += 2;
                } else if pair_follows && chars[i] == '*' && chars[i + 1] == '/' {
                    depth -= 1;
                    result.push_str("*/");
                    i += 2;
                } else {
                    result.push(chars[i]);
                    i += 1;
                }
            }
            continue;
        }

        // Single-quoted literal; a doubled quote is an escaped quote.
        if c == '\'' {
            result.push('\'');
            i += 1;
            while i < len {
                let q = chars[i];
                result.push(q);
                i += 1;
                if q == '\'' {
                    if i < len && chars[i] == '\'' {
                        result.push('\'');
                        i += 1;
                    } else {
                        break;
                    }
                }
            }
            continue;
        }

        // Double-quoted identifier.
        if c == '"' {
            result.push('"');
            i += 1;
            while i < len {
                let q = chars[i];
                result.push(q);
                i += 1;
                if q == '"' {
                    break;
                }
            }
            continue;
        }

        // `$tag$ ... $tag$` (tag may be empty), or a plain `$`.
        if c == '$' {
            let mut tag_end = i + 1;
            while tag_end < len && (chars[tag_end].is_alphanumeric() || chars[tag_end] == '_') {
                tag_end += 1;
            }

            if tag_end < len && chars[tag_end] == '$' {
                // Opening delimiter, both `$` included.
                let tag = &chars[i..=tag_end];
                result.extend(tag.iter());
                i = tag_end + 1;

                // Copy the body until the same delimiter shows up again.
                while i < len {
                    if chars[i..].starts_with(tag) {
                        result.extend(tag.iter());
                        i += tag.len();
                        break;
                    }
                    result.push(chars[i]);
                    i += 1;
                }
            } else {
                // Not a dollar quote (e.g. `$1`): emit the `$` and let the
                // following characters be handled normally.
                result.push('$');
                i += 1;
            }
            continue;
        }

        // Cast operator: must be checked before the placeholder rule.
        if c == ':' && next == Some(':') {
            result.push_str("::");
            i += 2;
            continue;
        }

        // `:name` placeholder.
        if c == ':' && next.map_or(false, |n| n.is_alphabetic() || n == '_') {
            let start = i + 1;
            let mut end = start;
            while end < len && (chars[end].is_alphanumeric() || chars[end] == '_') {
                end += 1;
            }
            let name: String = chars[start..end].iter().collect();

            let pos = match positions.get(&name) {
                Some(&existing) => existing,
                None => {
                    names.push(name.clone());
                    let assigned = names.len();
                    positions.insert(name, assigned);
                    assigned
                }
            };

            result.push('$');
            result.push_str(&pos.to_string());
            i = end;
            continue;
        }

        result.push(c);
        i += 1;
    }

    (result, names)
}
1032
/// Splits a `postgres://` or `postgresql://` URI into
/// `(user, password, host, port, database)`.
///
/// Defaults:
/// * no `user[:password]@` part: user and password are both `postgres`;
/// * no `:password`: empty password;
/// * no port, or an empty one (`host:`): 5432;
/// * no database path, or an empty one (`host/`): `postgres`.
///
/// A trailing `?query` or `#fragment` is dropped instead of leaking into the
/// database or host name. No percent-decoding is performed.
///
/// Returns `None` when the scheme is not recognised or when a port is present
/// but is not a valid `u16`.
fn parse_pg_uri(uri: &str) -> Option<(String, String, String, u16, String)> {
    let rest = uri
        .strip_prefix("postgres://")
        .or_else(|| uri.strip_prefix("postgresql://"))?;

    let (auth, hostdb) = rest.split_once('@').unwrap_or(("postgres:postgres", rest));
    let (user, password) = auth.split_once(':').unwrap_or((auth, ""));

    // Connection options and fragments are not part of the host or database.
    let hostdb = hostdb
        .split(|c| c == '?' || c == '#')
        .next()
        .unwrap_or(hostdb);

    let (hostport, path) = hostdb.split_once('/').unwrap_or((hostdb, ""));
    // Treat `host/` like `host`: both mean "no database given".
    let database = if path.is_empty() { "postgres" } else { path };

    let (host, port) = match hostport.split_once(':') {
        None | Some((_, "")) => (hostport.trim_end_matches(':'), 5432),
        // A garbled port is a malformed URI, not a request for the default.
        Some((host, port_str)) => (host, port_str.parse::<u16>().ok()?),
    };

    Some((
        user.to_string(),
        password.to_string(),
        host.to_string(),
        port,
        database.to_string(),
    ))
}
1050
/// Unit tests for the pure helpers of this crate: column type overrides,
/// named-parameter rewriting, connection-URI parsing, SQL hashing, and the
/// JSON shape of cache entries. None of them needs a database.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_type_override ----

    /// `name: Type` yields the bare name and a type.
    #[test]
    fn test_type_override_basic() {
        let (name, ty) = parse_type_override("id: UserId");
        assert_eq!(name, "id");
        assert!(ty.is_some());
    }

    /// `::` inside the type path does not confuse the separator search.
    #[test]
    fn test_type_override_with_module_path() {
        let (name, ty) = parse_type_override("id: crate::types::UserId");
        assert_eq!(name, "id");
        assert!(ty.is_some());
    }

    /// A name without any colon is returned unchanged.
    #[test]
    fn test_type_override_no_override() {
        let (name, ty) = parse_type_override("user_name");
        assert_eq!(name, "user_name");
        assert!(ty.is_none());
    }

    /// A `::` in the column name is not an override separator.
    #[test]
    fn test_type_override_skips_double_colon_cast() {
        let (name, ty) = parse_type_override("created_at::text");
        assert_eq!(name, "created_at::text");
        assert!(ty.is_none(), ":: should not trigger type override");
    }

    /// Text after the colon that is not a Rust type leaves the name intact.
    #[test]
    fn test_type_override_invalid_type_string() {
        let (name, ty) = parse_type_override("col: 123invalid");
        assert_eq!(name, "col: 123invalid");
        assert!(ty.is_none(), "invalid Rust type should fall back");
    }

    /// A trailing colon with nothing after it is not an override.
    #[test]
    fn test_type_override_empty_after_colon() {
        let (name, ty) = parse_type_override("col:");
        assert_eq!(name, "col:");
        assert!(ty.is_none());
    }

    /// Whitespace around both the name and the type is trimmed.
    #[test]
    fn test_type_override_with_spaces() {
        let (name, ty) = parse_type_override("  id  :  UserId  ");
        assert_eq!(name, "id");
        assert!(ty.is_some());
    }

    /// Generic types are accepted as overrides.
    #[test]
    fn test_type_override_option_type() {
        let (name, ty) = parse_type_override("email: Option<String>");
        assert_eq!(name, "email");
        assert!(ty.is_some());
    }

    #[test]
    fn test_type_override_vec_type() {
        let (name, ty) = parse_type_override("tags: Vec<String>");
        assert_eq!(name, "tags");
        assert!(ty.is_some());
    }

    // ---- rewrite_named_params: basics ----

    /// Distinct names are numbered in order of first appearance.
    #[test]
    fn test_named_params_basic() {
        let (sql, names) = rewrite_named_params("SELECT :id, :name");
        assert_eq!(sql, "SELECT $1, $2");
        assert_eq!(names, vec!["id", "name"]);
    }

    /// A repeated name reuses its number and is listed once.
    #[test]
    fn test_named_params_duplicate() {
        let (sql, names) = rewrite_named_params("SELECT :id WHERE :id > 0");
        assert_eq!(sql, "SELECT $1 WHERE $1 > 0");
        assert_eq!(names, vec!["id"]);
    }

    /// A `::` cast directly after a placeholder is preserved.
    #[test]
    fn test_named_params_with_cast() {
        let (sql, names) = rewrite_named_params("SELECT :val::int4");
        assert_eq!(sql, "SELECT $1::int4");
        assert_eq!(names, vec!["val"]);
    }

    /// Placeholders inside a string literal are left alone.
    #[test]
    fn test_named_params_in_string_literal() {
        let (sql, names) = rewrite_named_params("SELECT ':not_a_param'");
        assert_eq!(sql, "SELECT ':not_a_param'");
        assert!(names.is_empty());
    }

    /// SQL without placeholders passes through unchanged.
    #[test]
    fn test_named_params_empty() {
        let (sql, names) = rewrite_named_params("SELECT 1");
        assert_eq!(sql, "SELECT 1");
        assert!(names.is_empty());
    }

    /// A name may start with an underscore.
    #[test]
    fn test_named_params_underscore_prefix() {
        let (sql, names) = rewrite_named_params("SELECT :_private");
        assert_eq!(sql, "SELECT $1");
        assert_eq!(names, vec!["_private"]);
    }

    // ---- parse_pg_uri ----

    /// All five components are extracted from a fully specified URI.
    #[test]
    fn test_parse_uri_full() {
        let (u, p, h, port, db) = parse_pg_uri("postgres://user:pass@host:1234/mydb").unwrap();
        assert_eq!(u, "user");
        assert_eq!(p, "pass");
        assert_eq!(h, "host");
        assert_eq!(port, 1234);
        assert_eq!(db, "mydb");
    }

    /// A missing port defaults to 5432.
    #[test]
    fn test_parse_uri_defaults() {
        let (u, p, h, port, db) = parse_pg_uri("postgres://user:pass@localhost/mydb").unwrap();
        assert_eq!(h, "localhost");
        assert_eq!(port, 5432);
        assert_eq!(u, "user");
        assert_eq!(p, "pass");
        assert_eq!(db, "mydb");
    }

    /// Other schemes are rejected.
    #[test]
    fn test_parse_uri_invalid() {
        assert!(parse_pg_uri("mysql://user:pass@host/db").is_none());
    }

    /// The `postgresql://` spelling of the scheme is accepted too.
    #[test]
    fn test_parse_uri_postgresql_scheme() {
        let parsed = parse_pg_uri("postgresql://user:pass@host:5433/mydb").unwrap();
        assert_eq!(parsed.0, "user");
        assert_eq!(parsed.3, 5433);
        assert_eq!(parsed.4, "mydb");
    }

    /// A user without `:password` gets an empty password.
    #[test]
    fn test_parse_uri_empty_password() {
        let parsed = parse_pg_uri("postgres://user@host/db").unwrap();
        assert_eq!(parsed.0, "user");
        assert_eq!(parsed.1, "");
        assert_eq!(parsed.2, "host");
    }

    /// A URI without a database path falls back to `postgres`.
    #[test]
    fn test_parse_uri_unset_database_defaults_to_postgres() {
        let parsed = parse_pg_uri("postgres://user:pass@host").unwrap();
        assert_eq!(parsed.4, "postgres");
    }

    // ---- rewrite_named_params: regions that must not be rewritten ----

    /// Placeholders inside a `--` comment are ignored.
    #[test]
    fn test_named_params_line_comment_skipped() {
        let (sql, names) = rewrite_named_params("SELECT :id -- :bogus\nFROM t");
        assert_eq!(sql, "SELECT $1 -- :bogus\nFROM t");
        assert_eq!(names, vec!["id"]);
    }

    /// Placeholders inside a `/* */` comment are ignored.
    #[test]
    fn test_named_params_block_comment_skipped() {
        let (sql, names) = rewrite_named_params("SELECT :id /* :bogus */ FROM t");
        assert_eq!(sql, "SELECT $1 /* :bogus */ FROM t");
        assert_eq!(names, vec!["id"]);
    }

    /// Placeholders inside a `$$ ... $$` body are ignored.
    #[test]
    fn test_named_params_dollar_quoted_body_skipped() {
        let (sql, names) = rewrite_named_params("SELECT $$ :ignored $$ WHERE id = :id");
        assert_eq!(sql, "SELECT $$ :ignored $$ WHERE id = $1");
        assert_eq!(names, vec!["id"]);
    }

    /// Placeholders inside a `$tag$ ... $tag$` body are ignored.
    #[test]
    fn test_named_params_tagged_dollar_quote_skipped() {
        let (sql, names) = rewrite_named_params("SELECT $tag$ :ignored $tag$ WHERE id = :id");
        assert_eq!(sql, "SELECT $tag$ :ignored $tag$ WHERE id = $1");
        assert_eq!(names, vec!["id"]);
    }

    /// Placeholders inside a double-quoted identifier are ignored.
    #[test]
    fn test_named_params_quoted_identifier_skipped() {
        let (sql, names) = rewrite_named_params(r#"SELECT ":col" FROM t WHERE id = :id"#);
        assert_eq!(sql, r#"SELECT ":col" FROM t WHERE id = $1"#);
        assert_eq!(names, vec!["id"]);
    }

    /// An existing `$1` is copied through untouched; note that the first
    /// named placeholder is then numbered `$1` as well.
    #[test]
    fn test_named_params_positional_dollar_param_passthrough() {
        let (sql, names) = rewrite_named_params("SELECT $1, :id FROM t");
        assert_eq!(sql, "SELECT $1, $1 FROM t");
        assert_eq!(names, vec!["id"]);
    }

    /// A doubled quote inside a literal does not end the literal.
    #[test]
    fn test_named_params_escaped_single_quote_inside_literal() {
        let (sql, names) = rewrite_named_params("SELECT 'it''s :nothing' , :real");
        assert_eq!(sql, "SELECT 'it''s :nothing' , $1");
        assert_eq!(names, vec!["real"]);
    }

    // ---- hash_sql ----

    /// Hashing the same text twice gives the same value.
    #[test]
    fn test_hash_sql_stable() {
        let sql = "SELECT id FROM t WHERE x = $1";
        assert_eq!(hash_sql(sql), hash_sql(sql));
    }

    /// Different text gives a different hash.
    #[test]
    fn test_hash_sql_differs_by_content() {
        assert_ne!(hash_sql("SELECT 1"), hash_sql("SELECT 2"));
    }

    /// The empty string hashes to the initial value of the accumulator.
    #[test]
    fn test_hash_sql_empty() {
        assert_eq!(hash_sql(""), 0xcbf29ce484222325);
    }

    // ---- cache entry (de)serialisation ----

    /// A `CacheEntry` survives a JSON write/read cycle through a temp file.
    #[test]
    fn test_cache_roundtrip() {
        // Nanosecond timestamp keeps the scratch directory name unique.
        let tmp = std::env::temp_dir().join(format!(
            "resolute-macros-cache-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        std::fs::create_dir_all(&tmp).unwrap();
        let path = tmp.join("query.json");

        let entry = cache::CacheEntry {
            sql: "SELECT 1::int4 AS n".into(),
            hash: 0xdeadbeef_cafebabe,
            param_oids: vec![23, 25],
            columns: vec![cache::CachedColumn {
                name: "n".into(),
                type_oid: 23,
                nullable: true,
            }],
        };

        let json = serde_json::to_string_pretty(&entry).unwrap();
        std::fs::write(&path, &json).unwrap();
        let raw = std::fs::read_to_string(&path).unwrap();
        let decoded: cache::CacheEntry = serde_json::from_str(&raw).unwrap();

        assert_eq!(decoded.sql, entry.sql);
        assert_eq!(decoded.hash, entry.hash);
        assert_eq!(decoded.param_oids, entry.param_oids);
        assert_eq!(decoded.columns.len(), 1);
        assert_eq!(decoded.columns[0].name, "n");
        assert_eq!(decoded.columns[0].type_oid, 23);
        assert!(decoded.columns[0].nullable);

        // Best-effort cleanup; a leftover directory is harmless.
        std::fs::remove_dir_all(&tmp).ok();
    }

    /// A column object without a `nullable` key deserialises with
    /// `nullable == false`.
    #[test]
    fn test_cache_entry_missing_nullable_defaults_to_false() {
        let legacy = r#"{
            "sql": "SELECT 1",
            "hash": 1,
            "param_oids": [],
            "columns": [{"name": "n", "type_oid": 23}]
        }"#;
        let entry: cache::CacheEntry = serde_json::from_str(legacy).unwrap();
        assert_eq!(entry.columns.len(), 1);
        assert!(!entry.columns[0].nullable);
    }
}