1#![forbid(unsafe_code)]
2
3extern crate proc_macro;
8
9mod codegen;
10#[cfg(feature = "sqlite")]
11mod codegen_sqlite;
12mod connection;
13mod dynamic;
14#[cfg(feature = "explain")]
15mod explain;
16mod offline;
17mod parse;
18mod pg_enum;
19mod sort_enum;
20mod sql_norm;
21mod stmt_name;
22mod suggest;
23mod test_macro;
24pub(crate) mod types;
25#[cfg(feature = "sqlite")]
26mod types_sqlite;
27mod validate;
28#[cfg(feature = "sqlite")]
29mod validate_sqlite;
30
31use proc_macro::TokenStream;
32
33#[proc_macro]
65pub fn query(input: TokenStream) -> TokenStream {
66 let input2: proc_macro2::TokenStream = input.into();
67 match query_impl(input2) {
68 Ok(output) => output.into(),
69 Err(err) => err.to_compile_error().into(),
70 }
71}
72
73fn query_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
74 let sql = extract_sql(input)?;
78
79 let parsed = parse::parse_query(&sql)
81 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
82
83 #[cfg(feature = "sqlite")]
85 {
86 let backend = connection::detect_backend()
87 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
88 if backend == Some(connection::Backend::Sqlite) {
89 return query_impl_sqlite(parsed);
90 }
91 }
92
93 query_impl_postgres(parsed)
95}
96
97fn query_impl_postgres(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
99 if parsed.sort_placeholder.is_some() {
101 return query_impl_sort(parsed);
102 }
103
104 if parsed.optional_clauses.is_empty() {
105 let validation = if offline::is_offline() {
107 offline::lookup_cached_validation(&parsed)
109 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
110 } else {
111 let result = connection::with_connection(|conn| {
113 validate::validate_query_with_suggestions(&parsed, conn)
114 })?;
115
116 offline::write_cache(&parsed, &result);
118
119 result
120 };
121
122 validate::check_param_types(&parsed, &validation)
124 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
125
126 Ok(codegen::generate_query_code(&parsed, &validation))
128 } else {
129 let validation = if offline::is_offline() {
134 offline::lookup_cached_validation(&parsed)
142 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
143 } else {
144 let result = connection::with_connection(|conn| {
147 let variants = dynamic::expand_variants(&parsed)?;
148 validate::validate_variants(&variants, &parsed, conn)
149 })?;
150
151 offline::write_cache(&parsed, &result);
153
154 result
155 };
156
157 Ok(codegen::generate_dynamic_query_code(&parsed, &validation))
159 }
160}
161
/// SQLite pipeline for `query!`.
///
/// * Queries with a `$[sort: ...]` placeholder are handed to
///   `query_impl_sqlite_sort`.
/// * Queries with optional clauses are expanded with `dynamic::expand_variants`,
///   validated together, and passed to the dynamic SQLite code generator.
/// * Every other query is validated once and passed to the plain SQLite code
///   generator.
///
/// In offline mode the validation is read from the offline cache; otherwise
/// it is computed over a live SQLite connection and written back to the cache.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
    if parsed.sort_placeholder.is_some() {
        return query_impl_sqlite_sort(parsed);
    }

    if !parsed.optional_clauses.is_empty() {
        let validation = if offline::is_offline() {
            offline::lookup_cached_validation(&parsed)
                .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
        } else {
            let fresh = connection::with_sqlite_connection(|conn| {
                let variants = dynamic::expand_variants(&parsed)?;
                validate_sqlite::validate_variants_sqlite(&variants, &parsed, conn)
            })?;
            offline::write_cache(&parsed, &fresh);
            fresh
        };
        return Ok(codegen_sqlite::generate_dynamic_sqlite_query_code(
            &parsed,
            &validation,
        ));
    }

    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
    } else {
        let fresh = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&parsed, conn)
        })?;
        offline::write_cache(&parsed, &fresh);
        fresh
    };

    Ok(codegen_sqlite::generate_sqlite_query_code(&parsed, &validation))
}
222
/// SQLite pipeline for `query!` statements containing a `$[sort: ...]`
/// placeholder.
///
/// A stand-in statement, with the sort placeholder replaced by `1`, is
/// validated. The offline cache lookup/write and code generation use the
/// original `parsed` query together with its sort enum name.
///
/// # Errors
/// Returns a `syn::Error` if `parsed` carries no sort placeholder (an internal
/// invariant violation), or if the offline-cache lookup or live validation
/// fails.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite_sort(
    parsed: parse::ParsedQuery,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    // Callers only route here when a placeholder is present. If that ever
    // stops holding, report a compile error instead of panicking inside the
    // proc macro.
    let sort_placeholder = parsed.sort_placeholder.as_ref().ok_or_else(|| {
        syn::Error::new(
            proc_macro2::Span::call_site(),
            "internal error: sort query has no $[sort: ...] placeholder",
        )
    })?;
    let sort_enum_name = &sort_placeholder.enum_name;

    // Validate a copy of the query with the sort placeholder replaced by `1`.
    let dummy_parsed = parse::ParsedQuery {
        normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
        positional_sql: parsed.positional_sql.replace("{SORT}", "1"),
        params: parsed.params.clone(),
        kind: parsed.kind,
        statement_name: parsed.statement_name.clone(),
        optional_clauses: parsed.optional_clauses.clone(),
        sort_placeholder: None,
    };

    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
    } else {
        let result = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&dummy_parsed, conn)
        })?;

        offline::write_cache(&parsed, &result);
        result
    };

    Ok(codegen_sqlite::generate_sort_sqlite_query_code(
        &parsed,
        &validation,
        sort_enum_name,
    ))
}
262
263fn query_impl_sort(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
281 let sort_placeholder = parsed.sort_placeholder.as_ref().unwrap();
282 let sort_enum_name = &sort_placeholder.enum_name;
283
284 let dummy_sql = parsed.positional_sql.replace("{SORT}", "1");
286
287 let dummy_parsed = parse::ParsedQuery {
289 normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
290 positional_sql: dummy_sql,
291 params: parsed.params.clone(),
292 kind: parsed.kind,
293 statement_name: parsed.statement_name.clone(),
294 optional_clauses: parsed.optional_clauses.clone(),
295 sort_placeholder: None,
296 };
297
298 let validation = if offline::is_offline() {
299 offline::lookup_cached_validation(&parsed)
300 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
301 } else {
302 let result = connection::with_connection(|conn| {
303 validate::validate_query_with_suggestions(&dummy_parsed, conn)
304 })?;
305
306 let sorts_dir = std::env::var("CARGO_MANIFEST_DIR")
309 .map(|d| std::path::PathBuf::from(d).join(".bsql").join("sorts"))
310 .ok();
311 if let Some(sorts_dir) = sorts_dir {
312 let path = sorts_dir.join(format!("{}.txt", sort_enum_name));
313 if let Ok(content) = std::fs::read_to_string(&path) {
314 connection::with_connection(|conn| {
315 for fragment in content.lines().filter(|l| !l.is_empty()) {
316 let test_sql = parsed.positional_sql.replace("{SORT}", fragment);
317 let prepare = format!("PREPARE __bsql_sort_check AS {}", test_sql);
318 if let Err(e) = conn.simple_query(&prepare) {
319 return Err(format!("sort fragment '{}' is invalid: {}", fragment, e));
320 }
321 let _ = conn.simple_query("DEALLOCATE __bsql_sort_check");
322 }
323 Ok(())
324 })?;
325 }
326 }
327
328 offline::write_cache(&parsed, &result);
329 result
330 };
331
332 validate::check_param_types(&parsed, &validation)
333 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
334
335 Ok(codegen::generate_sort_query_code(
337 &parsed,
338 &validation,
339 sort_enum_name,
340 ))
341}
342
343#[proc_macro]
363pub fn query_as(input: TokenStream) -> TokenStream {
364 let input2: proc_macro2::TokenStream = input.into();
365 match query_as_impl(input2) {
366 Ok(output) => output.into(),
367 Err(err) => err.to_compile_error().into(),
368 }
369}
370
371struct QueryAsArgs {
373 target_type: syn::Path,
374 _comma: syn::Token![,],
375 sql: syn::LitStr,
376}
377
378impl syn::parse::Parse for QueryAsArgs {
379 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
380 Ok(QueryAsArgs {
381 target_type: input.parse()?,
382 _comma: input.parse()?,
383 sql: input.parse()?,
384 })
385 }
386}
387
388fn extract_type_and_sql(
389 input: proc_macro2::TokenStream,
390) -> Result<(syn::Path, String), syn::Error> {
391 let args: QueryAsArgs = syn::parse2(input)?;
392 Ok((args.target_type, args.sql.value()))
393}
394
395fn query_as_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
396 let (target_type, sql) = extract_type_and_sql(input)?;
397
398 let parsed = parse::parse_query(&sql)
399 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
400
401 if parsed.sort_placeholder.is_some() {
403 return Err(syn::Error::new(
404 proc_macro2::Span::call_site(),
405 "query_as! does not support $[sort: ...] placeholders; use query! instead",
406 ));
407 }
408
409 if !parsed.optional_clauses.is_empty() {
411 return Err(syn::Error::new(
412 proc_macro2::Span::call_site(),
413 "query_as! does not support optional clauses; use query! instead",
414 ));
415 }
416
417 #[cfg(feature = "sqlite")]
419 {
420 let backend = connection::detect_backend()
421 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
422 if backend == Some(connection::Backend::Sqlite) {
423 return query_as_impl_sqlite(parsed, target_type);
424 }
425 }
426
427 query_as_impl_postgres(parsed, target_type)
429}
430
431fn query_as_impl_postgres(
432 parsed: parse::ParsedQuery,
433 target_type: syn::Path,
434) -> Result<proc_macro2::TokenStream, syn::Error> {
435 let validation = if offline::is_offline() {
436 offline::lookup_cached_validation(&parsed)
437 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
438 } else {
439 let result = connection::with_connection(|conn| {
440 validate::validate_query_with_suggestions(&parsed, conn)
441 })?;
442
443 offline::write_cache(&parsed, &result);
444 result
445 };
446
447 validate::check_param_types(&parsed, &validation)
448 .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
449
450 Ok(codegen::generate_query_as_code(
451 &parsed,
452 &validation,
453 &target_type,
454 ))
455}
456
/// SQLite pipeline for `query_as!`.
///
/// Obtains a validation for `parsed`, from a live SQLite connection
/// (refreshing the offline cache) or from the cache when offline. It then
/// generates the SQLite `query_as!` code for `target_type`.
#[cfg(feature = "sqlite")]
fn query_as_impl_sqlite(
    parsed: parse::ParsedQuery,
    target_type: syn::Path,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    let validation = if !offline::is_offline() {
        let checked = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&parsed, conn)
        })?;
        offline::write_cache(&parsed, &checked);
        checked
    } else {
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
    };

    let generated =
        codegen_sqlite::generate_sqlite_query_as_code(&parsed, &validation, &target_type);
    Ok(generated)
}
480
481fn extract_sql(input: proc_macro2::TokenStream) -> Result<String, syn::Error> {
485 let lit: syn::LitStr = syn::parse2(input)?;
486 Ok(lit.value())
487}
488
489#[proc_macro_attribute]
517pub fn pg_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
518 let attr2: proc_macro2::TokenStream = attr.into();
519 let item2: proc_macro2::TokenStream = item.into();
520 match pg_enum::expand_pg_enum(attr2, item2) {
521 Ok(output) => output.into(),
522 Err(err) => err.to_compile_error().into(),
523 }
524}
525
526#[proc_macro_attribute]
559pub fn sort(attr: TokenStream, item: TokenStream) -> TokenStream {
560 let attr2: proc_macro2::TokenStream = attr.into();
561 let item2: proc_macro2::TokenStream = item.into();
562 match sort_enum::expand_sort_enum(attr2, item2) {
563 Ok(output) => output.into(),
564 Err(err) => err.to_compile_error().into(),
565 }
566}
567
568#[proc_macro_attribute]
602pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
603 let attr2: proc_macro2::TokenStream = attr.into();
604 let item2: proc_macro2::TokenStream = item.into();
605 match test_macro::expand_test(attr2, item2) {
606 Ok(output) => output.into(),
607 Err(err) => err.to_compile_error().into(),
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::{extract_type_and_sql, QueryAsArgs};

    /// Lexes a snippet of Rust source into tokens; panics if it does not lex.
    fn lex(src: &str) -> proc_macro2::TokenStream {
        src.parse().unwrap()
    }

    /// Collects the identifier of every segment of `path` as a `String`.
    fn segment_names(path: &syn::Path) -> Vec<String> {
        path.segments.iter().map(|seg| seg.ident.to_string()).collect()
    }

    #[test]
    fn parse_query_as_args() {
        let args: QueryAsArgs = syn::parse2(lex("User, \"SELECT id FROM users\"")).unwrap();
        assert_eq!(args.sql.value(), "SELECT id FROM users");
        let names = segment_names(&args.target_type);
        assert_eq!(names.last().map(String::as_str), Some("User"));
    }

    #[test]
    fn parse_query_as_args_module_path() {
        let args: QueryAsArgs =
            syn::parse2(lex("crate::models::User, \"SELECT id FROM users\"")).unwrap();
        assert_eq!(args.sql.value(), "SELECT id FROM users");
        assert_eq!(segment_names(&args.target_type), ["crate", "models", "User"]);
    }

    #[test]
    fn extract_type_and_sql_basic() {
        let (path, sql) =
            extract_type_and_sql(lex("Row, \"SELECT name FROM t WHERE id = $id: i32\""))
                .unwrap();
        assert_eq!(sql, "SELECT name FROM t WHERE id = $id: i32");
        let names = segment_names(&path);
        assert_eq!(names.last().map(String::as_str), Some("Row"));
    }

    #[test]
    fn extract_type_and_sql_missing_comma_fails() {
        assert!(extract_type_and_sql(lex("User \"SELECT id FROM t\"")).is_err());
    }

    #[test]
    fn extract_type_and_sql_missing_sql_fails() {
        assert!(extract_type_and_sql(lex("User,")).is_err());
    }
}