Skip to main content

good_ormning_macros/
lib.rs

1use {
2    convert_case::{
3        Case,
4        Casing,
5    },
6    good_ormning_core::{
7        pg::{
8            Version as PgVersion,
9            query::utils::{
10                PgFieldInfo,
11                PgTableInfo,
12            },
13            schema::{
14                field::FieldRef as PgFieldRef,
15                table::TableRef as PgTableRef,
16            },
17        },
18        sqlite::{
19            Version as SqliteVersion,
20            query::utils::{
21                SqliteFieldInfo,
22                SqliteTableInfo,
23            },
24            schema::{
25                field::FieldRef as SqliteFieldRef,
26                table::TableRef as SqliteTableRef,
27            },
28        },
29        utils::Errs,
30    },
31    proc_macro::TokenStream,
32    quote::{
33        format_ident,
34        quote,
35    },
36    std::{
37        collections::{
38            HashMap,
39            hash_map::DefaultHasher,
40        },
41        env,
42        fs,
43        hash::{
44            Hash,
45            Hasher,
46        },
47    },
48    syn::{
49        Ident,
50        LitInt,
51        LitStr,
52        Token,
53        parse::{
54            Parse,
55            ParseStream,
56        },
57        parse_macro_input,
58    },
59};
60
61mod convert;
62
/// Declared type of one query parameter, written in the macro input as `[arr] [opt] base`
/// (modifiers in any order, followed by the base type name).
struct ParamType {
    /// Set when the `arr` modifier was present.
    arr: bool,
    /// Set when the `opt` modifier was present.
    opt: bool,
    /// The first identifier that is neither `arr` nor `opt`, stored as text.
    base: String,
}
68
69struct GoodQueryInput {
70    version: Option<usize>,
71    db_name: String,
72    sql: String,
73    param_types: Vec<(Ident, ParamType)>,
74    conn: syn::Expr,
75    params: Vec<syn::Expr>,
76}
77
78impl Parse for GoodQueryInput {
79    fn parse(input: ParseStream) -> syn::Result<Self> {
80        let (db_name, version, sql) = {
81            let first: LitStr = input.parse()?;
82
83            if input.peek(Token![;]) {
84                input.parse::<Token![;]>()?;
85                ("".to_string(), None, first.value())
86            } else {
87                input.parse::<Token![,]>()?;
88
89                let lookahead = input.lookahead1();
90                if lookahead.peek(LitInt) {
91                    let version_lit: LitInt = input.parse()?;
92                    let version = version_lit.base10_parse::<usize>()?;
93                    input.parse::<Token![,]>()?;
94                    let sql_lit: LitStr = input.parse()?;
95                    let sql = sql_lit.value();
96                    input.parse::<Token![;]>()?;
97                    (first.value(), Some(version), sql)
98                } else if lookahead.peek(LitStr) {
99                    let sql_lit: LitStr = input.parse()?;
100                    let sql = sql_lit.value();
101                    input.parse::<Token![;]>()?;
102                    (first.value(), None, sql)
103                } else {
104                    return Err(lookahead.error());
105                }
106            }
107        };
108        let conn: syn::Expr = input.parse()?;
109        let mut param_types = Vec::new();
110        let mut params = Vec::new();
111        while input.peek(Token![,]) {
112            input.parse::<Token![,]>()?;
113            if input.is_empty() {
114                break;
115            }
116            let name: Ident = input.parse()?;
117            input.parse::<Token![:]>()?;
118            let mut arr = false;
119            let mut opt = false;
120            let mut base = String::new();
121            while input.peek(Ident) {
122                let id: Ident = input.parse()?;
123                if id == "arr" {
124                    arr = true;
125                } else if id == "opt" {
126                    opt = true;
127                } else {
128                    base = id.to_string();
129                    break;
130                }
131            }
132            if base.is_empty() {
133                return Err(input.error("Expected parameter type"));
134            }
135            input.parse::<Token![=]>()?;
136            let val: syn::Expr = input.parse()?;
137            param_types.push((name, ParamType {
138                arr: arr,
139                opt: opt,
140                base: base,
141            }));
142            params.push(val);
143        }
144        let mut final_sql = String::new();
145        let mut last_end = 0;
146        let bytes = sql.as_bytes();
147        let mut i = 0;
148        while i < bytes.len() {
149            if bytes[i] == b'$' {
150                if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
151                    final_sql.push_str(&sql[last_end .. i]);
152                    i += 2;
153                    let content_start = i;
154                    while i < bytes.len() && bytes[i] != b'}' {
155                        i += 1;
156                    }
157                    if i >= bytes.len() {
158                        return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
159                    }
160                    let content = &sql[content_start .. i];
161                    i += 1;
162                    let mut split = None;
163                    for (idx, b) in content.as_bytes().iter().enumerate() {
164                        if *b == b'=' {
165                            split = Some((&content[..idx], &content[idx + 1..]));
166                            break;
167                        }
168                    }
169                    let (type_str, val_str) = split.ok_or_else(|| {
170                        syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
171                    })?;
172                    let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
173                    params.push(val);
174                    param_types.push((name, pt));
175                    final_sql.push_str(&format!("${}", param_idx));
176                    last_end = i;
177                    continue;
178                }
179            }
180            i += 1;
181        }
182        final_sql.push_str(&sql[last_end..]);
183        Ok(GoodQueryInput {
184            version: version,
185            db_name: db_name,
186            sql: final_sql,
187            param_types: param_types,
188            conn: conn,
189            params: params,
190        })
191    }
192}
193
194fn parse_inline_param(
195    input: ParseStream,
196    type_str: &str,
197    val_str: &str,
198    current_params_len: usize,
199) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
200    let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
201        syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
202    })?;
203    let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
204        syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
205    })?;
206
207    use syn::parse::Parser;
208
209    let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
210        let mut arr = false;
211        let mut opt = false;
212        let mut base = String::new();
213        while type_input.peek(Ident) {
214            let id: Ident = type_input.parse()?;
215            if id == "arr" {
216                arr = true;
217            } else if id == "opt" {
218                opt = true;
219            } else {
220                base = id.to_string();
221                break;
222            }
223        }
224        Ok((arr, opt, base))
225    }).parse2(type_tokens).map_err(|e| {
226        syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
227    })?;
228    if base_p.is_empty() {
229        return Err(input.error("Expected base type in inline parameter"));
230    }
231    let param_idx = current_params_len + 1;
232    let name = format_ident!("p{}", param_idx);
233    Ok((param_idx, name, ParamType {
234        arr: arr_p,
235        opt: opt_p,
236        base: base_p,
237    }, val))
238}
239
/// Resolves the database name used to locate the schema JSON file.
///
/// `_engine` is currently ignored; the call sites pass `"pg"` or `"sqlite"`. The name parsed
/// from the macro input is returned unchanged, and is an empty string when the invocation gave
/// none.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    std::convert::identity(provided_db_name)
}
243
244fn parse_and_generate_pg(
245    input: GoodQueryInput,
246    res_count: good_ormning_core::QueryResCount,
247) -> proc_macro2::TokenStream {
248    let db_name = get_db_info("pg", input.db_name.clone());
249    let dialect = sqlparser::dialect::PostgreSqlDialect {};
250    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
251        Ok(ast) => ast,
252        Err(e) => {
253            let e = e.to_string();
254            return quote!(compile_error!(#e));
255        },
256    };
257    if ast.is_empty() {
258        return quote!(compile_error!("Empty SQL statement"));
259    }
260    let statement = &ast[0];
261    let mut errs = Errs::new();
262    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
263    let path =
264        std::path::Path::new(&out_dir)
265            .join("good_ormning")
266            .join(good_ormning_core::utils::json_file_name(&db_name));
267    if !path.exists() {
268        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
269        return quote!(compile_error!(#e));
270    }
271    let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
272        Ok(m) => m,
273        Err(e) => {
274            let e = e.to_string();
275            return quote!(compile_error!(#e));
276        },
277    };
278    let mut field_lookup = HashMap::new();
279    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
280    let version = match versions_map.get(&version_i) {
281        Some(v) => v,
282        None => {
283            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
284            return quote!(compile_error!(#e));
285        },
286    };
287    let custom_types = version.custom_types.clone();
288    for (table_id, table) in &version.tables {
289        let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
290        for (field_id, field) in &table.fields {
291            fields.insert(PgFieldRef {
292                table_id: table_id.clone(),
293                field_id: field_id.clone(),
294            }, PgFieldInfo {
295                sql_name: field.id.clone(),
296                type_: field.type_.type_.clone(),
297            });
298        }
299        field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
300            sql_name: table.id.clone(),
301            fields: fields,
302        });
303    }
304    let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
305    query.res_count = res_count;
306    let mut hasher = DefaultHasher::new();
307    input.sql.hash(&mut hasher);
308    let query_hash = hasher.finish();
309    let query_name = format_ident!("good_query_{}", query_hash);
310    query.name = query_name.to_string();
311    let pascal_db_name: String = db_name.to_case(Case::Pascal);
312    let db_type = if let Some(v) = input.version {
313        let name = format_ident!("Db{}{}", pascal_db_name, v);
314        quote!(dbm::#name <impl ::good_ormning::runtime::pg::PgConnection>)
315    } else {
316        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
317        quote!(dbm::#name <impl ::good_ormning::runtime::pg::PgConnection>)
318    };
319    let generated =
320        good_ormning_core::pg::query::generate::generate_query_functions(
321            &mut errs,
322            field_lookup,
323            vec![query],
324            "inline",
325            db_type,
326        );
327    let conn = &input.conn;
328    let args = &input.params;
329    quote!{
330        {
331            use ::good_ormning::runtime::GoodError;
332            use ::good_ormning::runtime::ToGoodError;
333            use ::good_ormning::runtime::pg::PgConnection;
334            #(#generated) * #query_name(#conn, #(#args,) *)
335        }
336    }
337}
338
339fn parse_and_generate_sqlite(
340    input: GoodQueryInput,
341    res_count: good_ormning_core::QueryResCount,
342) -> proc_macro2::TokenStream {
343    let db_name = get_db_info("sqlite", input.db_name.clone());
344    let dialect = sqlparser::dialect::SQLiteDialect {};
345    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
346        Ok(ast) => ast,
347        Err(e) => {
348            let e = e.to_string();
349            return quote!(compile_error!(#e));
350        },
351    };
352    if ast.is_empty() {
353        return quote!(compile_error!("Empty SQL statement"));
354    }
355    let statement = &ast[0];
356    let mut errs = Errs::new();
357    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
358    let path =
359        std::path::Path::new(&out_dir)
360            .join("good_ormning")
361            .join(good_ormning_core::utils::json_file_name(&db_name));
362    if !path.exists() {
363        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
364        return quote!(compile_error!(#e));
365    }
366    let versions_map: HashMap<usize, SqliteVersion> =
367        match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
368            Ok(m) => m,
369            Err(e) => {
370                let e = e.to_string();
371                return quote!(compile_error!(#e));
372            },
373        };
374    let mut field_lookup = HashMap::new();
375    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
376    let version = match versions_map.get(&version_i) {
377        Some(v) => v,
378        None => {
379            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
380            return quote!(compile_error!(#e));
381        },
382    };
383    let custom_types = version.custom_types.clone();
384    for (table_id, table) in &version.tables {
385        let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
386        for (field_id, field) in &table.fields {
387            fields.insert(SqliteFieldRef {
388                table_id: table_id.clone(),
389                field_id: field_id.clone(),
390            }, SqliteFieldInfo {
391                sql_name: field.id.clone(),
392                type_: field.type_.type_.clone(),
393            });
394        }
395        field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
396            sql_name: table.id.clone(),
397            fields: fields,
398        });
399    }
400    let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
401    query.res_count = res_count;
402    let mut hasher = DefaultHasher::new();
403    input.sql.hash(&mut hasher);
404    let query_hash = hasher.finish();
405    let query_name = format_ident!("good_query_{}", query_hash);
406    query.name = query_name.to_string();
407    let pascal_db_name: String = db_name.to_case(Case::Pascal);
408    let db_type = if let Some(v) = input.version {
409        let name = format_ident!("Db{}{}", pascal_db_name, v);
410        quote!(dbm::#name <impl ::good_ormning::runtime::sqlite::SqliteConnection>)
411    } else {
412        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
413        quote!(dbm::#name <impl ::good_ormning::runtime::sqlite::SqliteConnection>)
414    };
415    let generated =
416        good_ormning_core::sqlite::query::generate::generate_query_functions(
417            &mut errs,
418            field_lookup,
419            vec![query],
420            "inline",
421            db_type,
422        );
423    let conn = &input.conn;
424    let args = &input.params;
425    quote!{
426        {
427            use ::good_ormning::runtime::GoodError;
428            use ::good_ormning::runtime::ToGoodError;
429            use ::good_ormning::runtime::sqlite::SqliteConnection;
430            #(#generated) * #query_name(#conn, #(#args,) *)
431        }
432    }
433}
434
435/// See the `good_query` macro help in the readme.
436#[proc_macro]
437pub fn good_query_pg(input: TokenStream) -> TokenStream {
438    let input = parse_macro_input!(input as GoodQueryInput);
439    parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
440}
441
442/// See the `good_query` macro help in the readme.
443#[proc_macro]
444pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
445    let input = parse_macro_input!(input as GoodQueryInput);
446    parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
447}
448
449/// See the `good_query` macro help in the readme.
450#[proc_macro]
451pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
452    let input = parse_macro_input!(input as GoodQueryInput);
453    parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
454}
455
456/// See the `good_query` macro help in the readme.
457#[proc_macro]
458pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
459    let input = parse_macro_input!(input as GoodQueryInput);
460    parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
461}
462
463/// See the `good_query` macro help in the readme.
464#[proc_macro]
465pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
466    let input = parse_macro_input!(input as GoodQueryInput);
467    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
468}
469
470/// See the `good_query` macro help in the readme.
471#[proc_macro]
472pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
473    let input = parse_macro_input!(input as GoodQueryInput);
474    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
475}
476
477/// See the `good_query` macro help in the readme.
478#[proc_macro]
479pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
480    let input = parse_macro_input!(input as GoodQueryInput);
481    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
482}
483
484/// See the `good_query` macro help in the readme.
485#[proc_macro]
486pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
487    let input = parse_macro_input!(input as GoodQueryInput);
488    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
489}