1extern crate proc_macro;
2
3mod condblock;
4mod parser;
5
6use parser::{Kind, Method, Query};
7use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
8use quote::quote;
9use std::{
10 collections::BTreeSet,
11 env, fs,
12 path::{Path, PathBuf},
13};
14use syn::{parse_str, Lit, Meta, MetaNameValue, Type};
15
16pub struct Context(Type, Type, Type, Type);
17pub enum ContextType {
18 Postgres,
19 Sqlite,
20 Mysql,
21 Default,
22}
23impl Context {
24 pub fn new(context_type: ContextType) -> Self {
25 match context_type {
26 ContextType::Postgres => Context(
27 parse_str::<Type>("sqlx::postgres::Postgres").unwrap(),
28 parse_str::<Type>("sqlx::postgres::PgArguments").unwrap(),
29 parse_str::<Type>("sqlx::postgres::PgRow").unwrap(),
30 parse_str::<Type>("sqlx::postgres::PgQueryResult").unwrap(),
31 ),
32 ContextType::Sqlite => Context(
33 parse_str::<Type>("sqlx::sqlite::Sqlite").unwrap(),
34 parse_str::<Type>("sqlx::sqlite::SqliteArguments<'q>").unwrap(),
35 parse_str::<Type>("sqlx::sqlite::SqliteRow").unwrap(),
36 parse_str::<Type>("sqlx::sqlite::SqliteQueryResult").unwrap(),
37 ),
38 ContextType::Mysql => Context(
39 parse_str::<Type>("sqlx::mysql::MySql").unwrap(),
40 parse_str::<Type>("sqlx::mysql::MySqlArguments").unwrap(),
41 parse_str::<Type>("sqlx::mysql::MySqlRow").unwrap(),
42 parse_str::<Type>("sqlx::mysql::MySqlQueryResult").unwrap(),
43 ),
44 _ => panic!("None of [postgres, sqlite, mysql] feature enabled"),
45 }
46 }
47}
48
49fn find_attribute_values(ast: &syn::DeriveInput, attr_name: &str) -> Vec<String> {
51 ast.attrs
52 .iter()
53 .filter(|value| value.path.is_ident(attr_name))
54 .filter_map(|attr| attr.parse_meta().ok())
55 .filter_map(|meta| match meta {
56 Meta::NameValue(MetaNameValue {
57 lit: Lit::Str(val), ..
58 }) => Some(val.value()),
59 _ => None,
60 })
61 .collect()
62}
63
64fn workspace_dir() -> PathBuf {
65 let output = std::process::Command::new(env!("CARGO"))
66 .arg("locate-project")
67 .arg("--workspace")
68 .arg("--message-format=plain")
69 .output()
70 .unwrap()
71 .stdout;
72 let cargo_path = Path::new(std::str::from_utf8(&output).unwrap().trim());
73 cargo_path
74 .parent()
75 .unwrap()
76 .to_path_buf()
77 .canonicalize()
78 .unwrap_or_else(|err| {
79 panic!(
80 "workspace dir path must resolve to an absolute path: {}",
81 err
82 )
83 })
84}
85
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Underscores are removed. The first character of every underscore-separated segment is
/// upper-cased (possibly expanding to several characters, e.g. `ß` -> `SS`); all other
/// characters are copied unchanged. Empty segments from leading, trailing or repeated
/// underscores contribute nothing.
fn snake_to_pascal(snake: &str) -> String {
    let mut pascal = String::with_capacity(snake.len());
    for word in snake.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
102
103pub fn find_queries_path(queries_path: String) -> PathBuf {
106 let cargo_dir = env::var("CARGO_MANIFEST_DIR").expect("Could not locate Cargo.toml");
108 let cargo_dir_canonical_path = Path::new(&cargo_dir)
109 .canonicalize()
110 .unwrap_or_else(|err| panic!("cargo dir path must resolve to an absolute path: {}", err));
111
112 let mut seen = BTreeSet::new();
113 let candidate_path = cargo_dir_canonical_path.join(&queries_path);
114 if candidate_path.exists() {
115 return candidate_path;
116 }
117 seen.insert(cargo_dir_canonical_path);
118
119 let workspace_root = workspace_dir();
120 let candidate_path = workspace_root.join(&queries_path);
121
122 if candidate_path.exists() {
123 return candidate_path;
124 }
125
126 seen.insert(workspace_root);
127 panic!("Queries path must be relative to the crate's Cargo.toml location or the workspace root. Tried the following folders: {seen:?}");
128}
129
130pub fn impl_hug_sqlx(ast: &syn::DeriveInput, ctx: Context) -> TokenStream2 {
131 let mut queries_paths = find_attribute_values(ast, "queries");
132 if queries_paths.len() != 1 {
133 panic!(
134 "#[derive(HugSql)] must contain one attribute like this #[queries = \"db/queries/\"]"
135 );
136 }
137 let canonical_path = find_queries_path(queries_paths.remove(0));
138
139 let files = if canonical_path.is_dir() {
140 walkdir::WalkDir::new(canonical_path)
141 .follow_links(true)
142 .sort_by_file_name()
143 .into_iter()
144 .filter_map(|e| e.ok())
145 .filter(|e| e.file_type().is_file())
146 .map(move |e| std::fs::canonicalize(e.path()).expect("Could not get canonical path"))
147 .collect()
148 } else {
149 vec![canonical_path]
150 };
151
152 let name = &ast.ident;
153 let mut output_ts = TokenStream2::new();
154 let mut functions = TokenStream2::new();
155 let mut enums = TokenStream2::new();
156
157 for f in files {
158 if let Ok(input) = fs::read_to_string(f) {
159 match parser::parse_queries(input) {
160 Ok(ast) => {
161 generate_impl_fns(ast, &ctx, &mut functions, &mut enums);
162 }
163 Err(parse_errs) => parse_errs
164 .into_iter()
165 .for_each(|e| eprintln!("Parse error: {}", e)),
166 }
167 }
168 }
169
170 output_ts.extend(quote! {
171 #enums
172
173 pub trait HugSql {
174 #functions
175 }
176 impl HugSql for #name {
177 }
178 });
179 output_ts
180}
181
182fn generate_cond_block_resolver_fn(query: &Query) -> (TokenStream2, TokenStream2, TokenStream2) {
183 let sql_blocks = &query.sql;
184 let cond_blocks = sql_blocks
185 .iter()
186 .filter(|b| matches!(b, condblock::SqlBlock::Conditional(_, _)))
187 .count();
188
189 if cond_blocks > 0 {
190 let enumeration = Ident::new(&snake_to_pascal(&query.name), Span::call_site());
191 let mut variants = Vec::with_capacity(cond_blocks);
192
193 let block_processing: Vec<_> = sql_blocks
195 .iter()
196 .map(|block| match block {
197 condblock::SqlBlock::Conditional(id, sql) => {
198 let variant = Ident::new(&snake_to_pascal(id), Span::call_site());
199 let quot = quote! {
200 if block_resolver(#enumeration::#variant) {
201 result.push_str(#sql);
202 result.push('\n');
203 }
204 };
205 variants.push(variant);
206 quot
207 }
208 condblock::SqlBlock::Literal(sql) => {
209 quote! {
210 result.push_str(#sql);
211 result.push('\n');
212 }
213 }
214 })
215 .collect();
216
217 let variant_tokens = variants.into_iter().map(|variant| {
219 quote! {
220 #variant,
221 }
222 });
223
224 return (
225 quote! { block_resolver: impl Fn(#enumeration) -> bool + Send, },
226 quote! {
227 &{
228 let mut result = String::new();
229 #(#block_processing)*
230 result
231 }
232 },
233 quote! {
234 pub enum #enumeration {
235 #(#variant_tokens)*
236 }
237 },
238 );
239 }
240 let sql = match sql_blocks.first() {
241 Some(condblock::SqlBlock::Literal(sql))
242 | Some(condblock::SqlBlock::Conditional(_, sql)) => sql,
243 None => "",
244 };
245 (TokenStream2::new(), quote! { #sql }, TokenStream2::new())
246}
247
248fn generate_impl_fns(
249 queries: Vec<Query>,
250 ctx: &Context,
251 functions_ts: &mut TokenStream2,
252 enums_ts: &mut TokenStream2,
253) {
254 for q in queries {
255 if let Some(doc) = &q.doc {
256 functions_ts.extend(quote! { #[doc = #doc] });
257 }
258 match q.kind {
259 Kind::Typed => generate_typed_fn(q, ctx, functions_ts, enums_ts),
260 Kind::Untyped => generate_untyped_fn(q, ctx, functions_ts, enums_ts),
261 Kind::Mapped => generate_mapped_fn(q, ctx, functions_ts, enums_ts),
262 }
263 }
264}
265
266fn generate_typed_fn(
267 q: Query,
268 Context(db, args, row, result): &Context,
269 functions_ts: &mut TokenStream2,
270 enums_ts: &mut TokenStream2,
271) {
272 let name = Ident::new(&q.name, Span::call_site());
273 let (block_resolver, sql, enums) = generate_cond_block_resolver_fn(&q);
274
275 enums_ts.extend(enums);
276
277 functions_ts.extend(match q.method {
278 Method::FetchMany => {
279 quote! {
280 async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> futures_core::stream::BoxStream<'e, Result<T, sqlx::Error>>
281 where
282 'q: 'e,
283 'c: 'e,
284 E: sqlx::Executor<'c, Database = #db> + 'e,
285 T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
286 sqlx::query_as_with(#sql, params).fetch(executor)
287 }
288 }
289 },
290 Method::FetchOne => {
291 quote! {
292 async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> Result<T, sqlx::Error>
293 where
294 'q: 'e,
295 'c: 'e,
296 E: sqlx::Executor<'c, Database = #db> + 'e,
297 T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
298 sqlx::query_as_with(#sql, params).fetch_one(executor).await
299 }
300 }
301 },
302 Method::FetchOptional => {
303 quote! {
304 async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> Result<Option<T>, sqlx::Error>
305 where
306 'q: 'e,
307 'c: 'e,
308 E: sqlx::Executor<'c, Database = #db> + 'e,
309 T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
310 sqlx::query_as_with(#sql, params).fetch_optional(executor).await
311 }
312 }
313 },
314 Method::FetchAll => {
315 quote! {
316 async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> Result<Vec<T>, sqlx::Error>
317 where
318 'q: 'e,
319 'c: 'e,
320 E: sqlx::Executor<'c, Database = #db> + 'e,
321 T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
322 sqlx::query_as_with(#sql, params).fetch_all(executor).await
323 }
324 }
325 },
326 Method::Execute => {
327 quote! {
328 async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<#result, sqlx::Error>
329 where
330 'q: 'e,
331 'c: 'e,
332 E: sqlx::Executor<'c, Database = #db> + 'e {
333 sqlx::query_with(#sql, params).execute(executor).await
334 }
335 }
336 },
337 });
338}
339
340fn generate_untyped_fn(
341 q: Query,
342 Context(db, args, row, result): &Context,
343 functions_ts: &mut TokenStream2,
344 enums_ts: &mut TokenStream2,
345) {
346 let name = Ident::new(&q.name, Span::call_site());
347 let (block_resolver, sql, enums) = generate_cond_block_resolver_fn(&q);
348
349 enums_ts.extend(enums);
350
351 functions_ts.extend(match q.method {
352 Method::FetchMany => {
353 quote! {
354 async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> futures_core::stream::BoxStream<'e, Result<#row, sqlx::Error>>
355 where
356 'q: 'e,
357 'c: 'e,
358 E: sqlx::Executor<'c, Database = #db> + 'e {
359 sqlx::query_with(#sql, params).fetch(executor)
360 }
361 }
362 },
363 Method::FetchOne => {
364 quote! {
365 async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<#row, sqlx::Error>
366 where
367 'q: 'e,
368 'c: 'e,
369 E: sqlx::Executor<'c, Database = #db> + 'e {
370 sqlx::query_with(#sql, params).fetch_one(executor).await
371 }
372 }
373 },
374 Method::FetchOptional => {
375 quote! {
376 async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<Option<#row>, sqlx::Error>
377 where
378 'q: 'e,
379 'c: 'e,
380 E: sqlx::Executor<'c, Database = #db> + 'e {
381 sqlx::query_with(#sql, params).fetch_optional(executor).await
382 }
383 }
384 },
385 Method::FetchAll => {
386 quote! {
387 async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<Vec<#row>, sqlx::Error>
388 where
389 'q: 'e,
390 'c: 'e,
391 E: sqlx::Executor<'c, Database = #db> + 'e {
392 sqlx::query_with(#sql, params).fetch_all(executor).await
393 }
394 }
395 },
396 Method::Execute => {
397 quote! {
398 async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<#result, sqlx::Error>
399 where
400 'q: 'e,
401 'c: 'e,
402 E: sqlx::Executor<'c, Database = #db> + 'e {
403 sqlx::query_with(#sql, params).execute(executor).await
404 }
405 }
406 },
407 });
408}
409
410fn generate_mapped_fn(
411 q: Query,
412 Context(db, args, row, result): &Context,
413 functions_ts: &mut TokenStream2,
414 enums_ts: &mut TokenStream2,
415) {
416 let name = Ident::new(&q.name, Span::call_site());
417 let (block_resolver, sql, enums) = generate_cond_block_resolver_fn(&q);
418
419 enums_ts.extend(enums);
420
421 functions_ts.extend(match q.method {
422 Method::FetchMany => {
423 quote! {
424 async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> futures_core::stream::BoxStream<'e, Result<T, sqlx::Error>>
425 where
426 'q: 'e,
427 'c: 'e,
428 E: sqlx::Executor<'c, Database = #db> + 'e,
429 F: FnMut(#row) -> T + Send + 'e,
430 T: Send + Unpin + 'e {
431 sqlx::query_with(#sql, params)
432 .map(mapper)
433 .fetch(executor)
434 }
435 }
436 },
437 Method::FetchOne => {
438 quote! {
439 async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> Result<T, sqlx::Error>
440 where
441 'q: 'e,
442 'c: 'e,
443 E: sqlx::Executor<'c, Database = #db> + 'e,
444 F: FnMut(#row) -> T + Send + 'e,
445 T: Send + Unpin + 'e {
446 sqlx::query_with(#sql, params)
447 .map(mapper)
448 .fetch_one(executor)
449 .await
450 }
451 }
452 },
453 Method::FetchOptional => {
454 quote! {
455 async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> Result<Option<T>, sqlx::Error>
456 where
457 'q: 'e,
458 'c: 'e,
459 E: sqlx::Executor<'c, Database = #db> + 'e,
460 F: FnMut(#row) -> T + Send + 'e,
461 T: Send + Unpin + 'e {
462 sqlx::query_with(#sql, params)
463 .map(mapper)
464 .fetch_optional(executor)
465 .await
466 }
467 }
468 },
469 Method::FetchAll => {
470 quote! {
471 async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> Result<Vec<T>, sqlx::Error>
472 where
473 'q: 'e,
474 'c: 'e,
475 E: sqlx::Executor<'c, Database = #db> + 'e,
476 F: FnMut(#row) -> T + Send + 'e,
477 T: Send + Unpin + 'e {
478 sqlx::query_with(#sql, params)
479 .map(mapper)
480 .fetch_all(executor)
481 .await
482 }
483 }
484 },
485 Method::Execute => {
486 quote! {
487 async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args) -> Result<#result, sqlx::Error>
488 where
489 'q: 'e,
490 'c: 'e,
491 E: sqlx::Executor<'c, Database = #db> + 'e {
492 sqlx::query_with(#sql, params).execute(executor).await
493 }
494 }
495 },
496 });
497}
498
499cfg_if::cfg_if! {
500 if #[cfg(feature = "postgres")] {
501 #[macro_export]
502 macro_rules! params {
503 ($($arg:expr),*) => {
504 {
505 use sqlx::Arguments;
506 let mut args = sqlx::postgres::PgArguments::default();
507 $( args.add($arg).unwrap(); )*
508 args
509 }
510 };
511 }
512 } else if #[cfg(feature = "mysql")] {
513 #[macro_export]
514 macro_rules! params {
515 ($($arg:expr),*) => {
516 {
517 use sqlx::Arguments;
518 let mut args = sqlx::mysql::MySqlArguments::default();
519 $( args.add($arg).unwrap(); )*
520 args
521 }
522 };
523 }
524 } else {
525 #[macro_export]
526 macro_rules! params {
527 ($($arg:expr),*) => {
528 {
529 use sqlx::Arguments;
530 let mut args = sqlx::sqlite::SqliteArguments::default();
531 $( args.add($arg).unwrap(); )*
532 args
533 }
534 };
535 }
536 }
537}
538
#[cfg(test)]
mod test {
    use crate::parser::{query_parser, Kind, Method, Query};
    use chumsky::Parser;

    /// Asserts that a parsed query has the expected name, doc text, kind and call method.
    fn assert_query(query: &Query, name: &str, doc: Option<&str>, kind: Kind, method: Method) {
        assert_eq!(query.name, name);
        assert_eq!(query.doc.as_deref(), doc);
        assert_eq!(query.kind, kind);
        assert_eq!(query.method, method);
    }

    /// Without flags a query is untyped and uses `execute`; `:doc` is captured.
    #[test]
    fn parsing_defaults() {
        let input = r#"
-- :name fetch_users
-- :doc Returns all the users from DB
SELECT user_id, email, name, picture FROM users
"#;

        let parsed = query_parser().parse(input).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_query(
            &parsed[0],
            "fetch_users",
            Some("Returns all the users from DB"),
            Kind::Untyped,
            Method::Execute,
        );
    }

    /// A method flag alone (`:^`) leaves the kind at its untyped default.
    #[test]
    fn parsing_default_type() {
        let input = r#"
-- :name fetch_users :^
SELECT user_id, email, name, picture FROM users
"#;

        let parsed = query_parser().parse(input).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_query(&parsed[0], "fetch_users", None, Kind::Untyped, Method::FetchMany);
    }

    /// `:<>` is the short alias for a typed query.
    #[test]
    fn parsing_type_aliases() {
        let input = r#"
-- :name fetch_users :<> :^
SELECT user_id, email, name, picture FROM users
"#;

        let parsed = query_parser().parse(input).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_query(&parsed[0], "fetch_users", None, Kind::Typed, Method::FetchMany);
    }

    /// A kind flag alone (`:mapped`) leaves the method at its `execute` default.
    #[test]
    fn parsing_default_call_method() {
        let input = r#"
-- :name fetch_users :mapped
SELECT user_id, email, name, picture FROM users
"#;

        let parsed = query_parser().parse(input).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_query(&parsed[0], "fetch_users", None, Kind::Mapped, Method::Execute);
    }

    /// Several queries in one input, including a multi-line `:doc` and an in-SQL comment.
    #[test]
    fn parsing_multiple() {
        let input = r#"
-- :name fetch_users
-- :doc Returns all the users from DB
SELECT user_id, email, name, picture FROM users

-- :name fetch_user_by_id :untyped :1
-- :doc Fetches user by its identifier
SELECT user_id, email, name, picture
    FROM users
    WHERE user_id = $1

-- :name set_picture :typed :1
-- :doc Sets user's picture.
-- Picture is expected to be a valid URL.
UPDATE users
    -- expected URL to the picture
    SET picture = ?
    WHERE user_id = ?

-- :name delete_user :typed :1
DELETE FROM users
    WHERE user_id = ?
"#;

        let parsed = query_parser().parse(input).unwrap();
        assert_eq!(parsed.len(), 4);

        assert_query(
            &parsed[0],
            "fetch_users",
            Some("Returns all the users from DB"),
            Kind::Untyped,
            Method::Execute,
        );
        assert_query(
            &parsed[1],
            "fetch_user_by_id",
            Some("Fetches user by its identifier"),
            Kind::Untyped,
            Method::FetchOne,
        );
        assert_query(
            &parsed[2],
            "set_picture",
            Some("Sets user's picture.\nPicture is expected to be a valid URL."),
            Kind::Typed,
            Method::FetchOne,
        );
        assert_query(&parsed[3], "delete_user", None, Kind::Typed, Method::FetchOne);
    }
}