Skip to main content

good_ormning_macros/
lib.rs

1use {
2    convert_case::{
3        Case,
4        Casing,
5    },
6    good_ormning_core::{
7        pg::{
8            Version as PgVersion,
9            query::utils::{
10                PgFieldInfo,
11                PgTableInfo,
12            },
13            schema::{
14                field::FieldRef as PgFieldRef,
15                table::TableRef as PgTableRef,
16            },
17        },
18        sqlite::{
19            Version as SqliteVersion,
20            query::utils::{
21                SqliteFieldInfo,
22                SqliteTableInfo,
23            },
24            schema::{
25                field::FieldRef as SqliteFieldRef,
26                table::TableRef as SqliteTableRef,
27            },
28        },
29        utils::Errs,
30    },
31    proc_macro::TokenStream,
32    quote::{
33        format_ident,
34        quote,
35    },
36    std::{
37        collections::{
38            HashMap,
39            hash_map::DefaultHasher,
40        },
41        env,
42        fs,
43        hash::{
44            Hash,
45            Hasher,
46        },
47    },
48    syn::{
49        Ident,
50        LitInt,
51        LitStr,
52        Token,
53        parse::{
54            Parse,
55            ParseStream,
56        },
57        parse_macro_input,
58    },
59};
60
61mod convert;
62
/// Type annotation attached to a query parameter, written as zero or more modifier
/// words (`arr`, `opt`) followed by a base type identifier, e.g. `opt arr i32`.
///
/// Produced both for explicit `name: type = value` parameters and for inline
/// `${type = value}` SQL placeholders. How each flag maps onto a generated Rust type is
/// decided by the `convert` module (not visible here).
struct ParamType {
    /// The base type identifier as written (e.g. `"i32"`); both parsers in this file
    /// reject an empty base.
    base: String,
    /// `true` when the `arr` modifier word was present (presumably an array/list
    /// parameter — confirm against `convert`).
    arr: bool,
    /// `true` when the `opt` modifier word was present (presumably a nullable
    /// parameter — confirm against `convert`).
    opt: bool,
}
68
69struct GoodQueryInput {
70    db_mod: Ident,
71    version: Option<usize>,
72    db_name: String,
73    sql: String,
74    param_types: Vec<(Ident, ParamType)>,
75    conn: syn::Expr,
76    params: Vec<syn::Expr>,
77}
78
79impl Parse for GoodQueryInput {
80    fn parse(input: ParseStream) -> syn::Result<Self> {
81        let db_mod: Ident = input.parse()?;
82        input.parse::<Token![,]>()?;
83        let (db_name, version, sql) = {
84            if input.peek(LitInt) {
85                let version_lit: LitInt = input.parse()?;
86                let version = version_lit.base10_parse::<usize>()?;
87                input.parse::<Token![,]>()?;
88                let sql_lit: LitStr = input.parse()?;
89                let sql = sql_lit.value();
90                input.parse::<Token![;]>()?;
91                ("".to_string(), Some(version), sql)
92            } else {
93                let first: LitStr = input.parse()?;
94                if input.peek(Token![;]) {
95                    input.parse::<Token![;]>()?;
96                    ("".to_string(), None, first.value())
97                } else {
98                    input.parse::<Token![,]>()?;
99                    let lookahead = input.lookahead1();
100                    if lookahead.peek(LitInt) {
101                        let version_lit: LitInt = input.parse()?;
102                        let version = version_lit.base10_parse::<usize>()?;
103                        input.parse::<Token![,]>()?;
104                        let sql_lit: LitStr = input.parse()?;
105                        let sql = sql_lit.value();
106                        input.parse::<Token![;]>()?;
107                        (first.value(), Some(version), sql)
108                    } else if lookahead.peek(LitStr) {
109                        let sql_lit: LitStr = input.parse()?;
110                        let sql = sql_lit.value();
111                        input.parse::<Token![;]>()?;
112                        (first.value(), None, sql)
113                    } else {
114                        return Err(lookahead.error());
115                    }
116                }
117            }
118        };
119        let conn: syn::Expr = input.parse()?;
120        let mut param_types = Vec::new();
121        let mut params = Vec::new();
122        while input.peek(Token![,]) {
123            input.parse::<Token![,]>()?;
124            if input.is_empty() {
125                break;
126            }
127            let name: Ident = input.parse()?;
128            input.parse::<Token![:]>()?;
129            let mut arr = false;
130            let mut opt = false;
131            let mut base = String::new();
132            while input.peek(Ident) {
133                let id: Ident = input.parse()?;
134                if id == "arr" {
135                    arr = true;
136                } else if id == "opt" {
137                    opt = true;
138                } else {
139                    base = id.to_string();
140                    break;
141                }
142            }
143            if base.is_empty() {
144                return Err(input.error("Expected parameter type"));
145            }
146            input.parse::<Token![=]>()?;
147            let val: syn::Expr = input.parse()?;
148            param_types.push((name, ParamType {
149                arr: arr,
150                opt: opt,
151                base: base,
152            }));
153            params.push(val);
154        }
155        let mut final_sql = String::new();
156        let mut last_end = 0;
157        let bytes = sql.as_bytes();
158        let mut i = 0;
159        while i < bytes.len() {
160            if bytes[i] == b'$' {
161                if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
162                    final_sql.push_str(&sql[last_end .. i]);
163                    i += 2;
164                    let content_start = i;
165                    while i < bytes.len() && bytes[i] != b'}' {
166                        i += 1;
167                    }
168                    if i >= bytes.len() {
169                        return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
170                    }
171                    let content = &sql[content_start .. i];
172                    i += 1;
173                    let mut split = None;
174                    for (idx, b) in content.as_bytes().iter().enumerate() {
175                        if *b == b'=' {
176                            split = Some((&content[..idx], &content[idx + 1..]));
177                            break;
178                        }
179                    }
180                    let (type_str, val_str) = split.ok_or_else(|| {
181                        syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
182                    })?;
183                    let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
184                    params.push(val);
185                    param_types.push((name, pt));
186                    final_sql.push_str(&format!("${}", param_idx));
187                    last_end = i;
188                    continue;
189                }
190            }
191            i += 1;
192        }
193        final_sql.push_str(&sql[last_end..]);
194        Ok(GoodQueryInput {
195            db_mod: db_mod,
196            version: version,
197            db_name: db_name,
198            sql: final_sql,
199            param_types: param_types,
200            conn: conn,
201            params: params,
202        })
203    }
204}
205
206fn parse_inline_param(
207    input: ParseStream,
208    type_str: &str,
209    val_str: &str,
210    current_params_len: usize,
211) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
212    let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
213        syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
214    })?;
215    let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
216        syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
217    })?;
218
219    use syn::parse::Parser;
220
221    let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
222        let mut arr = false;
223        let mut opt = false;
224        let mut base = String::new();
225        while type_input.peek(Ident) {
226            let id: Ident = type_input.parse()?;
227            if id == "arr" {
228                arr = true;
229            } else if id == "opt" {
230                opt = true;
231            } else {
232                base = id.to_string();
233                break;
234            }
235        }
236        Ok((arr, opt, base))
237    }).parse2(type_tokens).map_err(|e| {
238        syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
239    })?;
240    if base_p.is_empty() {
241        return Err(input.error("Expected base type in inline parameter"));
242    }
243    let param_idx = current_params_len + 1;
244    let name = format_ident!("p{}", param_idx);
245    Ok((param_idx, name, ParamType {
246        arr: arr_p,
247        opt: opt_p,
248        base: base_p,
249    }, val))
250}
251
/// Resolves the database name used to locate the schema file and to name the
/// generated `Db...` type.
///
/// `_engine` (`"pg"` or `"sqlite"` at the call sites in this file) is currently
/// ignored: the name supplied in the macro invocation is returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    std::convert::identity(provided_db_name)
}
255
256fn parse_and_generate_pg(
257    input: GoodQueryInput,
258    res_count: good_ormning_core::QueryResCount,
259) -> proc_macro2::TokenStream {
260    let db_name = get_db_info("pg", input.db_name.clone());
261    let dialect = sqlparser::dialect::PostgreSqlDialect {};
262    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
263        Ok(ast) => ast,
264        Err(e) => {
265            let e = e.to_string();
266            return quote!(compile_error!(#e));
267        },
268    };
269    if ast.is_empty() {
270        return quote!(compile_error!("Empty SQL statement"));
271    }
272    let statement = &ast[0];
273    let mut errs = Errs::new();
274    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
275    let path =
276        std::path::Path::new(&out_dir)
277            .join("good_ormning")
278            .join(good_ormning_core::utils::json_file_name(&db_name));
279    if !path.exists() {
280        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
281        return quote!(compile_error!(#e));
282    }
283    let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
284        Ok(m) => m,
285        Err(e) => {
286            let e = e.to_string();
287            return quote!(compile_error!(#e));
288        },
289    };
290    let mut field_lookup = HashMap::new();
291    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
292    let version = match versions_map.get(&version_i) {
293        Some(v) => v,
294        None => {
295            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
296            return quote!(compile_error!(#e));
297        },
298    };
299    let custom_types = version.custom_types.clone();
300    for (table_id, table) in &version.tables {
301        let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
302        for (field_id, field) in &table.fields {
303            fields.insert(PgFieldRef {
304                table_id: table_id.clone(),
305                field_id: field_id.clone(),
306            }, PgFieldInfo {
307                sql_name: field.id.clone(),
308                type_: field.type_.type_.clone(),
309            });
310        }
311        field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
312            sql_name: table.id.clone(),
313            fields: fields,
314        });
315    }
316    let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
317    query.res_count = res_count;
318    let mut hasher = DefaultHasher::new();
319    input.sql.hash(&mut hasher);
320    let query_hash = hasher.finish();
321    let query_name = format_ident!("good_query_{}", query_hash);
322    query.name = query_name.to_string();
323    let pascal_db_name: String = db_name.to_case(Case::Pascal);
324    let db_mod = &input.db_mod;
325    let db_type = if let Some(v) = input.version {
326        let name = format_ident!("Db{}{}", pascal_db_name, v);
327        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
328    } else {
329        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
330        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
331    };
332    let generated =
333        good_ormning_core::pg::query::generate::generate_query_functions(
334            &mut errs,
335            field_lookup,
336            vec![query],
337            "inline",
338            db_type,
339        );
340    let conn = &input.conn;
341    let args = &input.params;
342    let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
343    let db_mod_str = input.db_mod.to_string();
344    let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
345    quote!{
346        {
347            const _:() = {
348                if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
349                    #[allow(unconditional_panic)]
350                    let _ = ["Database name mismatch"][1];
351                }
352            };
353            use ::good_ormning::runtime::GoodError;
354            use ::good_ormning::runtime::ToGoodError;
355            use ::good_ormning::runtime::pg::PgConnection;
356            #(#generated) * #query_name(#conn, #(#args,) *)
357        }
358    }
359}
360
361fn parse_and_generate_sqlite(
362    input: GoodQueryInput,
363    res_count: good_ormning_core::QueryResCount,
364) -> proc_macro2::TokenStream {
365    let db_name = get_db_info("sqlite", input.db_name.clone());
366    let dialect = sqlparser::dialect::SQLiteDialect {};
367    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
368        Ok(ast) => ast,
369        Err(e) => {
370            let e = e.to_string();
371            return quote!(compile_error!(#e));
372        },
373    };
374    if ast.is_empty() {
375        return quote!(compile_error!("Empty SQL statement"));
376    }
377    let statement = &ast[0];
378    let mut errs = Errs::new();
379    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
380    let path =
381        std::path::Path::new(&out_dir)
382            .join("good_ormning")
383            .join(good_ormning_core::utils::json_file_name(&db_name));
384    if !path.exists() {
385        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
386        return quote!(compile_error!(#e));
387    }
388    let versions_map: HashMap<usize, SqliteVersion> =
389        match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
390            Ok(m) => m,
391            Err(e) => {
392                let e = e.to_string();
393                return quote!(compile_error!(#e));
394            },
395        };
396    let mut field_lookup = HashMap::new();
397    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
398    let version = match versions_map.get(&version_i) {
399        Some(v) => v,
400        None => {
401            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
402            return quote!(compile_error!(#e));
403        },
404    };
405    let custom_types = version.custom_types.clone();
406    for (table_id, table) in &version.tables {
407        let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
408        for (field_id, field) in &table.fields {
409            fields.insert(SqliteFieldRef {
410                table_id: table_id.clone(),
411                field_id: field_id.clone(),
412            }, SqliteFieldInfo {
413                sql_name: field.id.clone(),
414                type_: field.type_.type_.clone(),
415            });
416        }
417        field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
418            sql_name: table.id.clone(),
419            fields: fields,
420        });
421    }
422    let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
423    query.res_count = res_count;
424    let mut hasher = DefaultHasher::new();
425    input.sql.hash(&mut hasher);
426    let query_hash = hasher.finish();
427    let query_name = format_ident!("good_query_{}", query_hash);
428    query.name = query_name.to_string();
429    let pascal_db_name: String = db_name.to_case(Case::Pascal);
430    let db_mod = &input.db_mod;
431    let db_type = if let Some(v) = input.version {
432        let name = format_ident!("Db{}{}", pascal_db_name, v);
433        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
434    } else {
435        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
436        quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
437    };
438    let generated =
439        good_ormning_core::sqlite::query::generate::generate_query_functions(
440            &mut errs,
441            field_lookup,
442            vec![query],
443            "inline",
444            db_type,
445        );
446    let conn = &input.conn;
447    let args = &input.params;
448    let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
449    let db_mod_str = input.db_mod.to_string();
450    let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
451    quote!{
452        {
453            const _:() = {
454                if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
455                    #[allow(unconditional_panic)]
456                    let _ = ["Database name mismatch"][1];
457                }
458            };
459            use ::good_ormning::runtime::GoodError;
460            use ::good_ormning::runtime::ToGoodError;
461            use ::good_ormning::runtime::sqlite::SqliteConnection;
462            #(#generated) * #query_name(#conn, #(#args,) *)
463        }
464    }
465}
466
467/// See the `good_query` macro help in the readme.
468#[proc_macro]
469pub fn good_query_pg(input: TokenStream) -> TokenStream {
470    let input = parse_macro_input!(input as GoodQueryInput);
471    parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
472}
473
474/// See the `good_query` macro help in the readme.
475#[proc_macro]
476pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
477    let input = parse_macro_input!(input as GoodQueryInput);
478    parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
479}
480
481/// See the `good_query` macro help in the readme.
482#[proc_macro]
483pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
484    let input = parse_macro_input!(input as GoodQueryInput);
485    parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
486}
487
488/// See the `good_query` macro help in the readme.
489#[proc_macro]
490pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
491    let input = parse_macro_input!(input as GoodQueryInput);
492    parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
493}
494
495/// See the `good_query` macro help in the readme.
496#[proc_macro]
497pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
498    let input = parse_macro_input!(input as GoodQueryInput);
499    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
500}
501
502/// See the `good_query` macro help in the readme.
503#[proc_macro]
504pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
505    let input = parse_macro_input!(input as GoodQueryInput);
506    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
507}
508
509/// See the `good_query` macro help in the readme.
510#[proc_macro]
511pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
512    let input = parse_macro_input!(input as GoodQueryInput);
513    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
514}
515
516/// See the `good_query` macro help in the readme.
517#[proc_macro]
518pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
519    let input = parse_macro_input!(input as GoodQueryInput);
520    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
521}