hugsqlx_core/
lib.rs

1extern crate proc_macro;
2
3mod condblock;
4mod parser;
5
6use parser::{Kind, Method, Query};
7use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
8use quote::quote;
9use std::{
10    collections::BTreeSet,
11    env, fs,
12    path::{Path, PathBuf},
13};
14use syn::{parse_str, Lit, Meta, MetaNameValue, Type};
15
16pub struct Context(Type, Type, Type, Type);
17pub enum ContextType {
18    Postgres,
19    Sqlite,
20    Mysql,
21    Default,
22}
23impl Context {
24    pub fn new(context_type: ContextType) -> Self {
25        match context_type {
26            ContextType::Postgres => Context(
27                parse_str::<Type>("sqlx::postgres::Postgres").unwrap(),
28                parse_str::<Type>("sqlx::postgres::PgArguments").unwrap(),
29                parse_str::<Type>("sqlx::postgres::PgRow").unwrap(),
30                parse_str::<Type>("sqlx::postgres::PgQueryResult").unwrap(),
31            ),
32            ContextType::Sqlite => Context(
33                parse_str::<Type>("sqlx::sqlite::Sqlite").unwrap(),
34                parse_str::<Type>("sqlx::sqlite::SqliteArguments<'q>").unwrap(),
35                parse_str::<Type>("sqlx::sqlite::SqliteRow").unwrap(),
36                parse_str::<Type>("sqlx::sqlite::SqliteQueryResult").unwrap(),
37            ),
38            ContextType::Mysql => Context(
39                parse_str::<Type>("sqlx::mysql::MySql").unwrap(),
40                parse_str::<Type>("sqlx::mysql::MySqlArguments").unwrap(),
41                parse_str::<Type>("sqlx::mysql::MySqlRow").unwrap(),
42                parse_str::<Type>("sqlx::mysql::MySqlQueryResult").unwrap(),
43            ),
44            _ => panic!("None of [postgres, sqlite, mysql] feature enabled"),
45        }
46    }
47}
48
49/// Find all pairs of the `name = "value"` attribute from the derive input
50fn find_attribute_values(ast: &syn::DeriveInput, attr_name: &str) -> Vec<String> {
51    ast.attrs
52        .iter()
53        .filter(|value| value.path.is_ident(attr_name))
54        .filter_map(|attr| attr.parse_meta().ok())
55        .filter_map(|meta| match meta {
56            Meta::NameValue(MetaNameValue {
57                lit: Lit::Str(val), ..
58            }) => Some(val.value()),
59            _ => None,
60        })
61        .collect()
62}
63
/// Returns the canonicalized directory that contains the workspace root `Cargo.toml`.
///
/// Runs `cargo locate-project --workspace --message-format=plain` and takes the parent
/// directory of the manifest path it prints.
///
/// # Panics
/// Panics if cargo cannot be spawned or exits unsuccessfully, if its output is not UTF-8,
/// if the reported manifest has no parent directory, or if that directory cannot be
/// canonicalized.
fn workspace_dir() -> PathBuf {
    // Prefer the cargo driving the current build (CARGO is set in rustc's environment),
    // then the one that built this crate, then whatever `cargo` is on PATH. Using
    // `option_env!` instead of `env!` keeps this crate buildable outside of cargo.
    let cargo = env::var_os("CARGO")
        .or_else(|| option_env!("CARGO").map(std::ffi::OsString::from))
        .unwrap_or_else(|| "cargo".into());

    let output = std::process::Command::new(&cargo)
        .args(["locate-project", "--workspace", "--message-format=plain"])
        .output()
        .unwrap_or_else(|err| panic!("failed to run `cargo locate-project`: {}", err));
    // Without this check a failing cargo yields empty stdout and an opaque panic below.
    if !output.status.success() {
        panic!(
            "`cargo locate-project --workspace` failed: {}",
            String::from_utf8_lossy(&output.stderr).trim()
        );
    }

    let stdout = std::str::from_utf8(&output.stdout)
        .expect("`cargo locate-project` printed a non-UTF-8 manifest path");
    let manifest_path = Path::new(stdout.trim());
    manifest_path
        .parent()
        .unwrap_or_else(|| {
            panic!(
                "cargo reported a manifest path without a parent directory: {:?}",
                manifest_path
            )
        })
        .canonicalize()
        .unwrap_or_else(|err| {
            panic!(
                "workspace dir path must resolve to an absolute path: {}",
                err
            )
        })
}
85
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Underscores are dropped. The first character of each underscore-separated word is
/// upper-cased, which may expand to several characters (e.g. `ß` becomes `SS`). All
/// remaining characters are kept unchanged.
fn snake_to_pascal(snake: &str) -> String {
    let mut pascal = String::with_capacity(snake.len());
    for word in snake.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
102
103/// Find a suitable candidate queries path by both the local crate's CARGO_MANIFEST_DIR
104/// as well as the workspace root.
105pub fn find_queries_path(queries_path: String) -> PathBuf {
106    // The directory of the crate's cargo dir. This may be different from the workspace root's directory.
107    let cargo_dir = env::var("CARGO_MANIFEST_DIR").expect("Could not locate Cargo.toml");
108    let cargo_dir_canonical_path = Path::new(&cargo_dir)
109        .canonicalize()
110        .unwrap_or_else(|err| panic!("cargo dir path must resolve to an absolute path: {}", err));
111
112    let mut seen = BTreeSet::new();
113    let candidate_path = cargo_dir_canonical_path.join(&queries_path);
114    if candidate_path.exists() {
115        return candidate_path;
116    }
117    seen.insert(cargo_dir_canonical_path);
118
119    let workspace_root = workspace_dir();
120    let candidate_path = workspace_root.join(&queries_path);
121
122    if candidate_path.exists() {
123        return candidate_path;
124    }
125
126    seen.insert(workspace_root);
127    panic!("Queries path must be relative to the crate's Cargo.toml location or the workspace root. Tried the following folders: {seen:?}");
128}
129
130pub fn impl_hug_sqlx(ast: &syn::DeriveInput, ctx: Context) -> TokenStream2 {
131    let mut queries_paths = find_attribute_values(ast, "queries");
132    if queries_paths.len() != 1 {
133        panic!(
134            "#[derive(HugSql)] must contain one attribute like this #[queries = \"db/queries/\"]"
135        );
136    }
137    let canonical_path = find_queries_path(queries_paths.remove(0));
138
139    let files = if canonical_path.is_dir() {
140        walkdir::WalkDir::new(canonical_path)
141            .follow_links(true)
142            .sort_by_file_name()
143            .into_iter()
144            .filter_map(|e| e.ok())
145            .filter(|e| e.file_type().is_file())
146            .map(move |e| std::fs::canonicalize(e.path()).expect("Could not get canonical path"))
147            .collect()
148    } else {
149        vec![canonical_path]
150    };
151
152    let name = &ast.ident;
153    let mut output_ts = TokenStream2::new();
154    let mut functions = TokenStream2::new();
155    let mut enums = TokenStream2::new();
156
157    for f in files {
158        if let Ok(input) = fs::read_to_string(f) {
159            match parser::parse_queries(input) {
160                Ok(ast) => {
161                    generate_impl_fns(ast, &ctx, &mut functions, &mut enums);
162                }
163                Err(parse_errs) => parse_errs
164                    .into_iter()
165                    .for_each(|e| eprintln!("Parse error: {}", e)),
166            }
167        }
168    }
169
170    output_ts.extend(quote! {
171        #enums
172
173        pub trait HugSql {
174            #functions
175        }
176        impl HugSql for #name {
177        }
178    });
179    output_ts
180}
181
182fn generate_cond_block_resolver_fn(query: &Query) -> (TokenStream2, TokenStream2, TokenStream2) {
183    let sql_blocks = &query.sql;
184    let cond_blocks = sql_blocks
185        .iter()
186        .filter(|b| matches!(b, condblock::SqlBlock::Conditional(_, _)))
187        .count();
188
189    if cond_blocks > 0 {
190        let enumeration = Ident::new(&snake_to_pascal(&query.name), Span::call_site());
191        let mut variants = Vec::with_capacity(cond_blocks);
192
193        // Generate compile-time code that builds the SQL string at runtime
194        let block_processing: Vec<_> = sql_blocks
195            .iter()
196            .map(|block| match block {
197                condblock::SqlBlock::Conditional(id, sql) => {
198                    let variant = Ident::new(&snake_to_pascal(id), Span::call_site());
199                    let quot = quote! {
200                        if block_resolver(#enumeration::#variant) {
201                            result.push_str(#sql);
202                            result.push('\n');
203                        }
204                    };
205                    variants.push(variant);
206                    quot
207                }
208                condblock::SqlBlock::Literal(sql) => {
209                    quote! {
210                        result.push_str(#sql);
211                        result.push('\n');
212                    }
213                }
214            })
215            .collect();
216
217        // Generate Enums that will be passed to block resolving function
218        let variant_tokens = variants.into_iter().map(|variant| {
219            quote! {
220                #variant,
221            }
222        });
223
224        return (
225            quote! { block_resolver: impl Fn(#enumeration) -> bool + Send, },
226            quote! {
227                &{
228                    let mut result = String::new();
229                    #(#block_processing)*
230                    result
231                }
232            },
233            quote! {
234                pub enum #enumeration {
235                    #(#variant_tokens)*
236                }
237            },
238        );
239    }
240    let sql = match sql_blocks.first() {
241        Some(condblock::SqlBlock::Literal(sql))
242        | Some(condblock::SqlBlock::Conditional(_, sql)) => sql,
243        None => "",
244    };
245    (TokenStream2::new(), quote! { #sql }, TokenStream2::new())
246}
247
248fn generate_impl_fns(
249    queries: Vec<Query>,
250    ctx: &Context,
251    functions_ts: &mut TokenStream2,
252    enums_ts: &mut TokenStream2,
253) {
254    for q in queries {
255        if let Some(doc) = &q.doc {
256            functions_ts.extend(quote! { #[doc = #doc] });
257        }
258        match q.kind {
259            Kind::Typed => generate_typed_fn(q, ctx, functions_ts, enums_ts),
260            Kind::Untyped => generate_untyped_fn(q, ctx, functions_ts, enums_ts),
261            Kind::Mapped => generate_mapped_fn(q, ctx, functions_ts, enums_ts),
262        }
263    }
264}
265
266fn generate_typed_fn(
267    q: Query,
268    Context(db, args, row, result): &Context,
269    functions_ts: &mut TokenStream2,
270    enums_ts: &mut TokenStream2,
271) {
272    let name = Ident::new(&q.name, Span::call_site());
273    let (block_resolver, sql, enums) = generate_cond_block_resolver_fn(&q);
274
275    enums_ts.extend(enums);
276
277    functions_ts.extend(match q.method {
278        Method::FetchMany => {
279            quote! {
280                async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> futures_core::stream::BoxStream<'e, Result<T, sqlx::Error>>
281                where
282                      'q: 'e,
283                      'c: 'e,
284                      E: sqlx::Executor<'c, Database = #db> + 'e,
285                      T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
286                    sqlx::query_as_with(#sql, params).fetch(executor)
287                }
288            }
289        },
290        Method::FetchOne => {
291            quote! {
292                async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> Result<T, sqlx::Error>
293                where
294                      'q: 'e,
295                      'c: 'e,
296                      E: sqlx::Executor<'c, Database = #db> + 'e,
297                      T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
298                    sqlx::query_as_with(#sql, params).fetch_one(executor).await
299                }
300            }
301        },
302        Method::FetchOptional => {
303            quote! {
304                async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> Result<Option<T>, sqlx::Error>
305                where
306                      'q: 'e,
307                      'c: 'e,
308                      E: sqlx::Executor<'c, Database = #db> + 'e,
309                      T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
310                    sqlx::query_as_with(#sql, params).fetch_optional(executor).await
311                }
312            }
313        },
314        Method::FetchAll => {
315            quote! {
316                async fn #name<'q, 'e, 'c, E, T> (executor: E, #block_resolver params: #args) -> Result<Vec<T>, sqlx::Error>
317                where
318                     'q: 'e,
319                     'c: 'e,
320                      E: sqlx::Executor<'c, Database = #db> + 'e,
321                      T: Send + Unpin + for<'r> sqlx::FromRow<'r, #row> + 'e {
322                    sqlx::query_as_with(#sql, params).fetch_all(executor).await
323                }
324            }
325        },
326        Method::Execute => {
327            quote! {
328                async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<#result, sqlx::Error>
329                where
330                 'q: 'e,
331                 'c: 'e,
332                 E: sqlx::Executor<'c, Database = #db> + 'e {
333                    sqlx::query_with(#sql, params).execute(executor).await
334                }
335            }
336        },
337    });
338}
339
340fn generate_untyped_fn(
341    q: Query,
342    Context(db, args, row, result): &Context,
343    functions_ts: &mut TokenStream2,
344    enums_ts: &mut TokenStream2,
345) {
346    let name = Ident::new(&q.name, Span::call_site());
347    let (block_resolver, sql, enums) = generate_cond_block_resolver_fn(&q);
348
349    enums_ts.extend(enums);
350
351    functions_ts.extend(match q.method {
352        Method::FetchMany => {
353            quote! {
354                async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> futures_core::stream::BoxStream<'e, Result<#row, sqlx::Error>>
355                where
356                 'q: 'e,
357                 'c: 'e,
358                 E: sqlx::Executor<'c, Database = #db> + 'e {
359                    sqlx::query_with(#sql, params).fetch(executor)
360                }
361            }
362        },
363        Method::FetchOne => {
364            quote! {
365                async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<#row, sqlx::Error>
366                where
367                 'q: 'e,
368                 'c: 'e,
369                 E: sqlx::Executor<'c, Database = #db> + 'e {
370                    sqlx::query_with(#sql, params).fetch_one(executor).await
371                }
372            }
373        },
374        Method::FetchOptional => {
375            quote! {
376                async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<Option<#row>, sqlx::Error>
377                where
378                 'q: 'e,
379                 'c: 'e,
380                 E: sqlx::Executor<'c, Database = #db> + 'e {
381                    sqlx::query_with(#sql, params).fetch_optional(executor).await
382                }
383            }
384        },
385        Method::FetchAll => {
386            quote! {
387                async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<Vec<#row>, sqlx::Error>
388                where
389                 'q: 'e,
390                 'c: 'e,
391                 E: sqlx::Executor<'c, Database = #db> + 'e {
392                    sqlx::query_with(#sql, params).fetch_all(executor).await
393                }
394            }
395        },
396        Method::Execute => {
397            quote! {
398                async fn #name<'q, 'e, 'c, E> (executor: E, #block_resolver params: #args) -> Result<#result, sqlx::Error>
399                where
400                 'q: 'e,
401                 'c: 'e,
402                 E: sqlx::Executor<'c, Database = #db> + 'e {
403                    sqlx::query_with(#sql, params).execute(executor).await
404                }
405            }
406        },
407    });
408}
409
410fn generate_mapped_fn(
411    q: Query,
412    Context(db, args, row, result): &Context,
413    functions_ts: &mut TokenStream2,
414    enums_ts: &mut TokenStream2,
415) {
416    let name = Ident::new(&q.name, Span::call_site());
417    let (block_resolver, sql, enums) = generate_cond_block_resolver_fn(&q);
418
419    enums_ts.extend(enums);
420
421    functions_ts.extend(match q.method {
422        Method::FetchMany => {
423            quote! {
424                async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> futures_core::stream::BoxStream<'e, Result<T, sqlx::Error>>
425                where
426                      'q: 'e,
427                      'c: 'e,
428                      E: sqlx::Executor<'c, Database = #db> + 'e,
429                      F: FnMut(#row) -> T + Send + 'e,
430                      T: Send + Unpin + 'e {
431                    sqlx::query_with(#sql, params)
432                        .map(mapper)
433                        .fetch(executor)
434                }
435            }
436        },
437        Method::FetchOne => {
438            quote! {
439                async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> Result<T, sqlx::Error>
440                where
441                      'q: 'e,
442                      'c: 'e,
443                      E: sqlx::Executor<'c, Database = #db> + 'e,
444                      F: FnMut(#row) -> T + Send + 'e,
445                      T: Send + Unpin + 'e {
446                    sqlx::query_with(#sql, params)
447                        .map(mapper)
448                        .fetch_one(executor)
449                        .await
450                }
451            }
452        },
453        Method::FetchOptional => {
454            quote! {
455                async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> Result<Option<T>, sqlx::Error>
456                where
457                      'q: 'e,
458                      'c: 'e,
459                      E: sqlx::Executor<'c, Database = #db> + 'e,
460                      F: FnMut(#row) -> T + Send + 'e,
461                      T: Send + Unpin + 'e {
462                    sqlx::query_with(#sql, params)
463                        .map(mapper)
464                        .fetch_optional(executor)
465                        .await
466                }
467            }
468        },
469        Method::FetchAll => {
470            quote! {
471                async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args, mapper: F) -> Result<Vec<T>, sqlx::Error>
472                where
473                      'q: 'e,
474                      'c: 'e,
475                      E: sqlx::Executor<'c, Database = #db> + 'e,
476                      F: FnMut(#row) -> T + Send + 'e,
477                      T: Send + Unpin + 'e {
478                    sqlx::query_with(#sql, params)
479                        .map(mapper)
480                        .fetch_all(executor)
481                        .await
482                }
483            }
484        },
485        Method::Execute => {
486            quote! {
487                async fn #name<'q, 'e, 'c, E, F, T> (executor: E, #block_resolver params: #args) -> Result<#result, sqlx::Error>
488                where
489                      'q: 'e,
490                      'c: 'e,
491                      E: sqlx::Executor<'c, Database = #db> + 'e {
492                    sqlx::query_with(#sql, params).execute(executor).await
493                }
494            }
495        },
496    });
497}
498
499cfg_if::cfg_if! {
500    if #[cfg(feature = "postgres")] {
501        #[macro_export]
502        macro_rules! params {
503            ($($arg:expr),*) => {
504                {
505                    use sqlx::Arguments;
506                    let mut args = sqlx::postgres::PgArguments::default();
507                    $( args.add($arg).unwrap(); )*
508                    args
509                }
510            };
511        }
512    } else if #[cfg(feature = "mysql")] {
513        #[macro_export]
514        macro_rules! params {
515            ($($arg:expr),*) => {
516                {
517                    use sqlx::Arguments;
518                    let mut args = sqlx::mysql::MySqlArguments::default();
519                    $( args.add($arg).unwrap(); )*
520                    args
521                }
522            };
523        }
524    } else {
525        #[macro_export]
526        macro_rules! params {
527            ($($arg:expr),*) => {
528                {
529                    use sqlx::Arguments;
530                    let mut args = sqlx::sqlite::SqliteArguments::default();
531                    $( args.add($arg).unwrap(); )*
532                    args
533                }
534            };
535        }
536    }
537}
538
#[cfg(test)]
mod test {
    use crate::parser::{query_parser, Kind, Method};
    use chumsky::Parser;

    /// Checks name, doc, kind and call method of one parsed query in a single call.
    macro_rules! assert_query {
        ($query:expr, $name:expr, $doc:expr, $kind:expr, $method:expr) => {{
            let query = &$query;
            assert_eq!(query.name, $name);
            assert_eq!(query.doc, $doc);
            assert_eq!(query.kind, $kind);
            assert_eq!(query.method, $method);
        }};
    }

    #[test]
    fn parsing_defaults() {
        let input = r#"
-- :name fetch_users
-- :doc Returns all the users from DB
SELECT user_id, email, name, picture FROM users
"#;

        let queries = query_parser().parse(input).unwrap();
        assert_eq!(queries.len(), 1);
        assert_query!(
            queries[0],
            "fetch_users",
            Some("Returns all the users from DB".to_string()),
            Kind::Untyped,
            Method::Execute
        );
    }

    #[test]
    fn parsing_default_type() {
        let input = r#"
-- :name fetch_users :^
SELECT user_id, email, name, picture FROM users
"#;

        let queries = query_parser().parse(input).unwrap();
        assert_eq!(queries.len(), 1);
        assert_query!(queries[0], "fetch_users", None, Kind::Untyped, Method::FetchMany);
    }

    #[test]
    fn parsing_type_aliases() {
        let input = r#"
-- :name fetch_users :<> :^
SELECT user_id, email, name, picture FROM users
"#;

        let queries = query_parser().parse(input).unwrap();
        assert_eq!(queries.len(), 1);
        assert_query!(queries[0], "fetch_users", None, Kind::Typed, Method::FetchMany);
    }

    #[test]
    fn parsing_default_call_method() {
        let input = r#"
-- :name fetch_users :mapped
SELECT user_id, email, name, picture FROM users
"#;

        let queries = query_parser().parse(input).unwrap();
        assert_eq!(queries.len(), 1);
        assert_query!(queries[0], "fetch_users", None, Kind::Mapped, Method::Execute);
    }

    #[test]
    fn parsing_multiple() {
        let input = r#"
-- :name fetch_users
-- :doc Returns all the users from DB
SELECT user_id, email, name, picture FROM users

-- :name fetch_user_by_id :untyped :1
-- :doc Fetches user by its identifier
SELECT user_id, email, name, picture
  FROM users
 WHERE user_id = $1

-- :name set_picture :typed :1
-- :doc Sets user's picture.
-- Picture is expected to be a valid URL.
UPDATE users
   -- expected URL to the picture
   SET picture = ?
 WHERE user_id = ?

-- :name delete_user :typed :1
DELETE FROM users
 WHERE user_id = ?
"#;

        let queries = query_parser().parse(input).unwrap();
        assert_eq!(queries.len(), 4);

        assert_query!(
            queries[0],
            "fetch_users",
            Some("Returns all the users from DB".to_string()),
            Kind::Untyped,
            Method::Execute
        );
        assert_query!(
            queries[1],
            "fetch_user_by_id",
            Some("Fetches user by its identifier".to_string()),
            Kind::Untyped,
            Method::FetchOne
        );
        // Consecutive `--` lines after `:doc` are joined into one multi-line doc string.
        assert_query!(
            queries[2],
            "set_picture",
            Some("Sets user's picture.\nPicture is expected to be a valid URL.".to_string()),
            Kind::Typed,
            Method::FetchOne
        );
        assert_query!(queries[3], "delete_user", None, Kind::Typed, Method::FetchOne);
    }
}