Skip to main content

good_ormning_macros/
lib.rs

1use {
2    convert_case::{
3        Case,
4        Casing,
5    },
6    good_ormning_core::{
7        pg::{
8            Version as PgVersion,
9            query::utils::{
10                PgFieldInfo,
11                PgTableInfo,
12            },
13            schema::{
14                field::FieldRef as PgFieldRef,
15                table::TableRef as PgTableRef,
16            },
17        },
18        sqlite::{
19            Version as SqliteVersion,
20            query::utils::{
21                SqliteFieldInfo,
22                SqliteTableInfo,
23            },
24            schema::{
25                field::FieldRef as SqliteFieldRef,
26                table::TableRef as SqliteTableRef,
27            },
28        },
29        utils::Errs,
30    },
31    proc_macro::TokenStream,
32    quote::{
33        format_ident,
34        quote,
35    },
36    std::{
37        collections::{
38            HashMap,
39            hash_map::DefaultHasher,
40        },
41        env,
42        fs,
43        hash::{
44            Hash,
45            Hasher,
46        },
47    },
48    syn::{
49        Ident,
50        LitInt,
51        LitStr,
52        Token,
53        parse::{
54            Parse,
55            ParseStream,
56        },
57        parse_macro_input,
58    },
59};
60
61mod convert;
62
63struct ParamType {
64    arr: bool,
65    opt: bool,
66    base: String,
67}
68
69struct GoodQueryInput {
70    db_mod: Ident,
71    version: Option<usize>,
72    db_name: String,
73    sql: String,
74    param_types: Vec<(Ident, ParamType)>,
75    conn: syn::Expr,
76    params: Vec<syn::Expr>,
77}
78
79impl Parse for GoodQueryInput {
80    fn parse(input: ParseStream) -> syn::Result<Self> {
81        let db_mod: Ident = input.parse()?;
82        input.parse::<Token![,]>()?;
83        let (db_name, version, sql) = {
84            let first: LitStr = input.parse()?;
85            if input.peek(Token![;]) {
86                input.parse::<Token![;]>()?;
87                ("".to_string(), None, first.value())
88            } else {
89                input.parse::<Token![,]>()?;
90                let lookahead = input.lookahead1();
91                if lookahead.peek(LitInt) {
92                    let version_lit: LitInt = input.parse()?;
93                    let version = version_lit.base10_parse::<usize>()?;
94                    input.parse::<Token![,]>()?;
95                    let sql_lit: LitStr = input.parse()?;
96                    let sql = sql_lit.value();
97                    input.parse::<Token![;]>()?;
98                    (first.value(), Some(version), sql)
99                } else if lookahead.peek(LitStr) {
100                    let sql_lit: LitStr = input.parse()?;
101                    let sql = sql_lit.value();
102                    input.parse::<Token![;]>()?;
103                    (first.value(), None, sql)
104                } else {
105                    return Err(lookahead.error());
106                }
107            }
108        };
109        let conn: syn::Expr = input.parse()?;
110        let mut param_types = Vec::new();
111        let mut params = Vec::new();
112        while input.peek(Token![,]) {
113            input.parse::<Token![,]>()?;
114            if input.is_empty() {
115                break;
116            }
117            let name: Ident = input.parse()?;
118            input.parse::<Token![:]>()?;
119            let mut arr = false;
120            let mut opt = false;
121            let mut base = String::new();
122            while input.peek(Ident) {
123                let id: Ident = input.parse()?;
124                if id == "arr" {
125                    arr = true;
126                } else if id == "opt" {
127                    opt = true;
128                } else {
129                    base = id.to_string();
130                    break;
131                }
132            }
133            if base.is_empty() {
134                return Err(input.error("Expected parameter type"));
135            }
136            input.parse::<Token![=]>()?;
137            let val: syn::Expr = input.parse()?;
138            param_types.push((name, ParamType {
139                arr: arr,
140                opt: opt,
141                base: base,
142            }));
143            params.push(val);
144        }
145        let mut final_sql = String::new();
146        let mut last_end = 0;
147        let bytes = sql.as_bytes();
148        let mut i = 0;
149        while i < bytes.len() {
150            if bytes[i] == b'$' {
151                if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
152                    final_sql.push_str(&sql[last_end .. i]);
153                    i += 2;
154                    let content_start = i;
155                    while i < bytes.len() && bytes[i] != b'}' {
156                        i += 1;
157                    }
158                    if i >= bytes.len() {
159                        return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
160                    }
161                    let content = &sql[content_start .. i];
162                    i += 1;
163                    let mut split = None;
164                    for (idx, b) in content.as_bytes().iter().enumerate() {
165                        if *b == b'=' {
166                            split = Some((&content[..idx], &content[idx + 1..]));
167                            break;
168                        }
169                    }
170                    let (type_str, val_str) = split.ok_or_else(|| {
171                        syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
172                    })?;
173                    let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
174                    params.push(val);
175                    param_types.push((name, pt));
176                    final_sql.push_str(&format!("${}", param_idx));
177                    last_end = i;
178                    continue;
179                }
180            }
181            i += 1;
182        }
183        final_sql.push_str(&sql[last_end..]);
184        Ok(GoodQueryInput {
185            db_mod: db_mod,
186            version: version,
187            db_name: db_name,
188            sql: final_sql,
189            param_types: param_types,
190            conn: conn,
191            params: params,
192        })
193    }
194}
195
196fn parse_inline_param(
197    input: ParseStream,
198    type_str: &str,
199    val_str: &str,
200    current_params_len: usize,
201) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
202    let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
203        syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
204    })?;
205    let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
206        syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
207    })?;
208
209    use syn::parse::Parser;
210
211    let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
212        let mut arr = false;
213        let mut opt = false;
214        let mut base = String::new();
215        while type_input.peek(Ident) {
216            let id: Ident = type_input.parse()?;
217            if id == "arr" {
218                arr = true;
219            } else if id == "opt" {
220                opt = true;
221            } else {
222                base = id.to_string();
223                break;
224            }
225        }
226        Ok((arr, opt, base))
227    }).parse2(type_tokens).map_err(|e| {
228        syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
229    })?;
230    if base_p.is_empty() {
231        return Err(input.error("Expected base type in inline parameter"));
232    }
233    let param_idx = current_params_len + 1;
234    let name = format_ident!("p{}", param_idx);
235    Ok((param_idx, name, ParamType {
236        arr: arr_p,
237        opt: opt_p,
238        base: base_p,
239    }, val))
240}
241
/// Resolves the database name used to locate the schema file.
///
/// No engine-specific resolution happens: `_engine` is ignored and the caller-supplied name
/// is returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    std::convert::identity(provided_db_name)
}
245
246fn parse_and_generate_pg(
247    input: GoodQueryInput,
248    res_count: good_ormning_core::QueryResCount,
249) -> proc_macro2::TokenStream {
250    let db_name = get_db_info("pg", input.db_name.clone());
251    let dialect = sqlparser::dialect::PostgreSqlDialect {};
252    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
253        Ok(ast) => ast,
254        Err(e) => {
255            let e = e.to_string();
256            return quote!(compile_error!(#e));
257        },
258    };
259    if ast.is_empty() {
260        return quote!(compile_error!("Empty SQL statement"));
261    }
262    let statement = &ast[0];
263    let mut errs = Errs::new();
264    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
265    let path =
266        std::path::Path::new(&out_dir)
267            .join("good_ormning")
268            .join(good_ormning_core::utils::json_file_name(&db_name));
269    if !path.exists() {
270        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
271        return quote!(compile_error!(#e));
272    }
273    let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
274        Ok(m) => m,
275        Err(e) => {
276            let e = e.to_string();
277            return quote!(compile_error!(#e));
278        },
279    };
280    let mut field_lookup = HashMap::new();
281    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
282    let version = match versions_map.get(&version_i) {
283        Some(v) => v,
284        None => {
285            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
286            return quote!(compile_error!(#e));
287        },
288    };
289    let custom_types = version.custom_types.clone();
290    for (table_id, table) in &version.tables {
291        let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
292        for (field_id, field) in &table.fields {
293            fields.insert(PgFieldRef {
294                table_id: table_id.clone(),
295                field_id: field_id.clone(),
296            }, PgFieldInfo {
297                sql_name: field.id.clone(),
298                type_: field.type_.type_.clone(),
299            });
300        }
301        field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
302            sql_name: table.id.clone(),
303            fields: fields,
304        });
305    }
306    let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
307    query.res_count = res_count;
308    let mut hasher = DefaultHasher::new();
309    input.sql.hash(&mut hasher);
310    let query_hash = hasher.finish();
311    let query_name = format_ident!("good_query_{}", query_hash);
312    query.name = query_name.to_string();
313    let pascal_db_name: String = db_name.to_case(Case::Pascal);
314    let db_mod = &input.db_mod;
315    let db_type = if let Some(v) = input.version {
316        let name = format_ident!("Db{}{}", pascal_db_name, v);
317        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
318    } else {
319        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
320        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
321    };
322    let generated =
323        good_ormning_core::pg::query::generate::generate_query_functions(
324            &mut errs,
325            field_lookup,
326            vec![query],
327            "inline",
328            db_type,
329        );
330    let conn = &input.conn;
331    let args = &input.params;
332    let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
333    let db_mod_str = input.db_mod.to_string();
334    let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
335    quote!{
336        {
337            const _:() = {
338                if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
339                    #[allow(unconditional_panic)]
340                    let _ = ["Database name mismatch"][1];
341                }
342            };
343            use ::good_ormning::runtime::GoodError;
344            use ::good_ormning::runtime::ToGoodError;
345            use ::good_ormning::runtime::pg::PgConnection;
346            #(#generated) * #query_name(#conn, #(#args,) *)
347        }
348    }
349}
350
351fn parse_and_generate_sqlite(
352    input: GoodQueryInput,
353    res_count: good_ormning_core::QueryResCount,
354) -> proc_macro2::TokenStream {
355    let db_name = get_db_info("sqlite", input.db_name.clone());
356    let dialect = sqlparser::dialect::SQLiteDialect {};
357    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
358        Ok(ast) => ast,
359        Err(e) => {
360            let e = e.to_string();
361            return quote!(compile_error!(#e));
362        },
363    };
364    if ast.is_empty() {
365        return quote!(compile_error!("Empty SQL statement"));
366    }
367    let statement = &ast[0];
368    let mut errs = Errs::new();
369    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
370    let path =
371        std::path::Path::new(&out_dir)
372            .join("good_ormning")
373            .join(good_ormning_core::utils::json_file_name(&db_name));
374    if !path.exists() {
375        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
376        return quote!(compile_error!(#e));
377    }
378    let versions_map: HashMap<usize, SqliteVersion> =
379        match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
380            Ok(m) => m,
381            Err(e) => {
382                let e = e.to_string();
383                return quote!(compile_error!(#e));
384            },
385        };
386    let mut field_lookup = HashMap::new();
387    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
388    let version = match versions_map.get(&version_i) {
389        Some(v) => v,
390        None => {
391            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
392            return quote!(compile_error!(#e));
393        },
394    };
395    let custom_types = version.custom_types.clone();
396    for (table_id, table) in &version.tables {
397        let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
398        for (field_id, field) in &table.fields {
399            fields.insert(SqliteFieldRef {
400                table_id: table_id.clone(),
401                field_id: field_id.clone(),
402            }, SqliteFieldInfo {
403                sql_name: field.id.clone(),
404                type_: field.type_.type_.clone(),
405            });
406        }
407        field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
408            sql_name: table.id.clone(),
409            fields: fields,
410        });
411    }
412    let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
413    query.res_count = res_count;
414    let mut hasher = DefaultHasher::new();
415    input.sql.hash(&mut hasher);
416    let query_hash = hasher.finish();
417    let query_name = format_ident!("good_query_{}", query_hash);
418    query.name = query_name.to_string();
419    let pascal_db_name: String = db_name.to_case(Case::Pascal);
420    let db_mod = &input.db_mod;
421    let db_type = if let Some(v) = input.version {
422        let name = format_ident!("Db{}{}", pascal_db_name, v);
423        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
424    } else {
425        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
426        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
427    };
428    let generated =
429        good_ormning_core::sqlite::query::generate::generate_query_functions(
430            &mut errs,
431            field_lookup,
432            vec![query],
433            "inline",
434            db_type,
435        );
436    let conn = &input.conn;
437    let args = &input.params;
438    let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
439    let db_mod_str = input.db_mod.to_string();
440    let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
441    quote!{
442        {
443            const _:() = {
444                if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
445                    #[allow(unconditional_panic)]
446                    let _ = ["Database name mismatch"][1];
447                }
448            };
449            use ::good_ormning::runtime::GoodError;
450            use ::good_ormning::runtime::ToGoodError;
451            use ::good_ormning::runtime::sqlite::SqliteConnection;
452            #(#generated) * #query_name(#conn, #(#args,) *)
453        }
454    }
455}
456
457/// See the `good_query` macro help in the readme.
458#[proc_macro]
459pub fn good_query_pg(input: TokenStream) -> TokenStream {
460    let input = parse_macro_input!(input as GoodQueryInput);
461    parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
462}
463
464/// See the `good_query` macro help in the readme.
465#[proc_macro]
466pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
467    let input = parse_macro_input!(input as GoodQueryInput);
468    parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
469}
470
471/// See the `good_query` macro help in the readme.
472#[proc_macro]
473pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
474    let input = parse_macro_input!(input as GoodQueryInput);
475    parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
476}
477
478/// See the `good_query` macro help in the readme.
479#[proc_macro]
480pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
481    let input = parse_macro_input!(input as GoodQueryInput);
482    parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
483}
484
485/// See the `good_query` macro help in the readme.
486#[proc_macro]
487pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
488    let input = parse_macro_input!(input as GoodQueryInput);
489    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
490}
491
492/// See the `good_query` macro help in the readme.
493#[proc_macro]
494pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
495    let input = parse_macro_input!(input as GoodQueryInput);
496    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
497}
498
499/// See the `good_query` macro help in the readme.
500#[proc_macro]
501pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
502    let input = parse_macro_input!(input as GoodQueryInput);
503    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
504}
505
506/// See the `good_query` macro help in the readme.
507#[proc_macro]
508pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
509    let input = parse_macro_input!(input as GoodQueryInput);
510    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
511}