// good_ormning_macros/lib.rs

1use {
2    convert_case::{
3        Case,
4        Casing,
5    },
6    good_ormning_core::{
7        pg::{
8            Version as PgVersion,
9            query::utils::{
10                PgFieldInfo,
11                PgTableInfo,
12            },
13            schema::{
14                field::FieldRef as PgFieldRef,
15                table::TableRef as PgTableRef,
16            },
17        },
18        sqlite::{
19            Version as SqliteVersion,
20            query::utils::{
21                SqliteFieldInfo,
22                SqliteTableInfo,
23            },
24            schema::{
25                field::FieldRef as SqliteFieldRef,
26                table::TableRef as SqliteTableRef,
27            },
28        },
29        utils::Errs,
30    },
31    proc_macro::TokenStream,
32    quote::{
33        format_ident,
34        quote,
35    },
36    std::{
37        collections::{
38            HashMap,
39            hash_map::DefaultHasher,
40        },
41        env,
42        fs,
43        hash::{
44            Hash,
45            Hasher,
46        },
47    },
48    syn::{
49        Ident,
50        LitInt,
51        LitStr,
52        Token,
53        parse::{
54            Parse,
55            ParseStream,
56        },
57        parse_macro_input,
58    },
59};
60
61mod convert;
62
/// Declared type of one query parameter, written in the macro input as
/// `[arr] [opt] base`.
struct ParamType {
    /// Set when the `arr` modifier was present.
    arr: bool,
    /// Set when the `opt` modifier was present.
    opt: bool,
    /// The first identifier that is neither `arr` nor `opt`, stored as text.
    base: String,
}
68
69struct GoodQueryInput {
70    version: Option<usize>,
71    db_name: String,
72    sql: String,
73    param_types: Vec<(Ident, ParamType)>,
74    conn: syn::Expr,
75    params: Vec<syn::Expr>,
76}
77
78impl Parse for GoodQueryInput {
79    fn parse(input: ParseStream) -> syn::Result<Self> {
80        let (db_name, version, sql) = {
81            let first: LitStr = input.parse()?;
82            if input.peek(Token![;]) {
83                input.parse::<Token![;]>()?;
84                ("".to_string(), None, first.value())
85            } else {
86                input.parse::<Token![,]>()?;
87                let lookahead = input.lookahead1();
88                if lookahead.peek(LitInt) {
89                    let version_lit: LitInt = input.parse()?;
90                    let version = version_lit.base10_parse::<usize>()?;
91                    input.parse::<Token![,]>()?;
92                    let sql_lit: LitStr = input.parse()?;
93                    let sql = sql_lit.value();
94                    input.parse::<Token![;]>()?;
95                    (first.value(), Some(version), sql)
96                } else if lookahead.peek(LitStr) {
97                    let sql_lit: LitStr = input.parse()?;
98                    let sql = sql_lit.value();
99                    input.parse::<Token![;]>()?;
100                    (first.value(), None, sql)
101                } else {
102                    return Err(lookahead.error());
103                }
104            }
105        };
106        let conn: syn::Expr = input.parse()?;
107        let mut param_types = Vec::new();
108        let mut params = Vec::new();
109        while input.peek(Token![,]) {
110            input.parse::<Token![,]>()?;
111            if input.is_empty() {
112                break;
113            }
114            let name: Ident = input.parse()?;
115            input.parse::<Token![:]>()?;
116            let mut arr = false;
117            let mut opt = false;
118            let mut base = String::new();
119            while input.peek(Ident) {
120                let id: Ident = input.parse()?;
121                if id == "arr" {
122                    arr = true;
123                } else if id == "opt" {
124                    opt = true;
125                } else {
126                    base = id.to_string();
127                    break;
128                }
129            }
130            if base.is_empty() {
131                return Err(input.error("Expected parameter type"));
132            }
133            input.parse::<Token![=]>()?;
134            let val: syn::Expr = input.parse()?;
135            param_types.push((name, ParamType {
136                arr: arr,
137                opt: opt,
138                base: base,
139            }));
140            params.push(val);
141        }
142        let mut final_sql = String::new();
143        let mut last_end = 0;
144        let bytes = sql.as_bytes();
145        let mut i = 0;
146        while i < bytes.len() {
147            if bytes[i] == b'$' {
148                if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
149                    final_sql.push_str(&sql[last_end .. i]);
150                    i += 2;
151                    let content_start = i;
152                    while i < bytes.len() && bytes[i] != b'}' {
153                        i += 1;
154                    }
155                    if i >= bytes.len() {
156                        return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
157                    }
158                    let content = &sql[content_start .. i];
159                    i += 1;
160                    let mut split = None;
161                    for (idx, b) in content.as_bytes().iter().enumerate() {
162                        if *b == b'=' {
163                            split = Some((&content[..idx], &content[idx + 1..]));
164                            break;
165                        }
166                    }
167                    let (type_str, val_str) = split.ok_or_else(|| {
168                        syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
169                    })?;
170                    let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
171                    params.push(val);
172                    param_types.push((name, pt));
173                    final_sql.push_str(&format!("${}", param_idx));
174                    last_end = i;
175                    continue;
176                }
177            }
178            i += 1;
179        }
180        final_sql.push_str(&sql[last_end..]);
181        Ok(GoodQueryInput {
182            version: version,
183            db_name: db_name,
184            sql: final_sql,
185            param_types: param_types,
186            conn: conn,
187            params: params,
188        })
189    }
190}
191
192fn parse_inline_param(
193    input: ParseStream,
194    type_str: &str,
195    val_str: &str,
196    current_params_len: usize,
197) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
198    let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
199        syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
200    })?;
201    let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
202        syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
203    })?;
204
205    use syn::parse::Parser;
206
207    let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
208        let mut arr = false;
209        let mut opt = false;
210        let mut base = String::new();
211        while type_input.peek(Ident) {
212            let id: Ident = type_input.parse()?;
213            if id == "arr" {
214                arr = true;
215            } else if id == "opt" {
216                opt = true;
217            } else {
218                base = id.to_string();
219                break;
220            }
221        }
222        Ok((arr, opt, base))
223    }).parse2(type_tokens).map_err(|e| {
224        syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
225    })?;
226    if base_p.is_empty() {
227        return Err(input.error("Expected base type in inline parameter"));
228    }
229    let param_idx = current_params_len + 1;
230    let name = format_ident!("p{}", param_idx);
231    Ok((param_idx, name, ParamType {
232        arr: arr_p,
233        opt: opt_p,
234        base: base_p,
235    }, val))
236}
237
/// Resolves the database name used to locate the schema file.
///
/// The engine argument is currently ignored and the supplied name is handed
/// back unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    provided_db_name
}
241
242fn parse_and_generate_pg(
243    input: GoodQueryInput,
244    res_count: good_ormning_core::QueryResCount,
245) -> proc_macro2::TokenStream {
246    let db_name = get_db_info("pg", input.db_name.clone());
247    let dialect = sqlparser::dialect::PostgreSqlDialect {};
248    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
249        Ok(ast) => ast,
250        Err(e) => {
251            let e = e.to_string();
252            return quote!(compile_error!(#e));
253        },
254    };
255    if ast.is_empty() {
256        return quote!(compile_error!("Empty SQL statement"));
257    }
258    let statement = &ast[0];
259    let mut errs = Errs::new();
260    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
261    let path =
262        std::path::Path::new(&out_dir)
263            .join("good_ormning")
264            .join(good_ormning_core::utils::json_file_name(&db_name));
265    if !path.exists() {
266        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
267        return quote!(compile_error!(#e));
268    }
269    let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
270        Ok(m) => m,
271        Err(e) => {
272            let e = e.to_string();
273            return quote!(compile_error!(#e));
274        },
275    };
276    let mut field_lookup = HashMap::new();
277    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
278    let version = match versions_map.get(&version_i) {
279        Some(v) => v,
280        None => {
281            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
282            return quote!(compile_error!(#e));
283        },
284    };
285    let custom_types = version.custom_types.clone();
286    for (table_id, table) in &version.tables {
287        let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
288        for (field_id, field) in &table.fields {
289            fields.insert(PgFieldRef {
290                table_id: table_id.clone(),
291                field_id: field_id.clone(),
292            }, PgFieldInfo {
293                sql_name: field.id.clone(),
294                type_: field.type_.type_.clone(),
295            });
296        }
297        field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
298            sql_name: table.id.clone(),
299            fields: fields,
300        });
301    }
302    let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
303    query.res_count = res_count;
304    let mut hasher = DefaultHasher::new();
305    input.sql.hash(&mut hasher);
306    let query_hash = hasher.finish();
307    let query_name = format_ident!("good_query_{}", query_hash);
308    query.name = query_name.to_string();
309    let pascal_db_name: String = db_name.to_case(Case::Pascal);
310    let db_type = if let Some(v) = input.version {
311        let name = format_ident!("Db{}{}", pascal_db_name, v);
312        quote!(dbm::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
313    } else {
314        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
315        quote!(dbm::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
316    };
317    let generated =
318        good_ormning_core::pg::query::generate::generate_query_functions(
319            &mut errs,
320            field_lookup,
321            vec![query],
322            "inline",
323            db_type,
324        );
325    let conn = &input.conn;
326    let args = &input.params;
327    quote!{
328        {
329            use ::good_ormning::runtime::GoodError;
330            use ::good_ormning::runtime::ToGoodError;
331            use ::good_ormning::runtime::pg::PgConnection;
332            #(#generated) * #query_name(#conn, #(#args,) *)
333        }
334    }
335}
336
337fn parse_and_generate_sqlite(
338    input: GoodQueryInput,
339    res_count: good_ormning_core::QueryResCount,
340) -> proc_macro2::TokenStream {
341    let db_name = get_db_info("sqlite", input.db_name.clone());
342    let dialect = sqlparser::dialect::SQLiteDialect {};
343    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
344        Ok(ast) => ast,
345        Err(e) => {
346            let e = e.to_string();
347            return quote!(compile_error!(#e));
348        },
349    };
350    if ast.is_empty() {
351        return quote!(compile_error!("Empty SQL statement"));
352    }
353    let statement = &ast[0];
354    let mut errs = Errs::new();
355    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
356    let path =
357        std::path::Path::new(&out_dir)
358            .join("good_ormning")
359            .join(good_ormning_core::utils::json_file_name(&db_name));
360    if !path.exists() {
361        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
362        return quote!(compile_error!(#e));
363    }
364    let versions_map: HashMap<usize, SqliteVersion> =
365        match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
366            Ok(m) => m,
367            Err(e) => {
368                let e = e.to_string();
369                return quote!(compile_error!(#e));
370            },
371        };
372    let mut field_lookup = HashMap::new();
373    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
374    let version = match versions_map.get(&version_i) {
375        Some(v) => v,
376        None => {
377            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
378            return quote!(compile_error!(#e));
379        },
380    };
381    let custom_types = version.custom_types.clone();
382    for (table_id, table) in &version.tables {
383        let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
384        for (field_id, field) in &table.fields {
385            fields.insert(SqliteFieldRef {
386                table_id: table_id.clone(),
387                field_id: field_id.clone(),
388            }, SqliteFieldInfo {
389                sql_name: field.id.clone(),
390                type_: field.type_.type_.clone(),
391            });
392        }
393        field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
394            sql_name: table.id.clone(),
395            fields: fields,
396        });
397    }
398    let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
399    query.res_count = res_count;
400    let mut hasher = DefaultHasher::new();
401    input.sql.hash(&mut hasher);
402    let query_hash = hasher.finish();
403    let query_name = format_ident!("good_query_{}", query_hash);
404    query.name = query_name.to_string();
405    let pascal_db_name: String = db_name.to_case(Case::Pascal);
406    let db_type = if let Some(v) = input.version {
407        let name = format_ident!("Db{}{}", pascal_db_name, v);
408        quote!(dbm::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
409    } else {
410        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
411        quote!(dbm::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
412    };
413    let generated =
414        good_ormning_core::sqlite::query::generate::generate_query_functions(
415            &mut errs,
416            field_lookup,
417            vec![query],
418            "inline",
419            db_type,
420        );
421    let conn = &input.conn;
422    let args = &input.params;
423    quote!{
424        {
425            use ::good_ormning::runtime::GoodError;
426            use ::good_ormning::runtime::ToGoodError;
427            use ::good_ormning::runtime::sqlite::SqliteConnection;
428            #(#generated) * #query_name(#conn, #(#args,) *)
429        }
430    }
431}
432
433/// See the `good_query` macro help in the readme.
434#[proc_macro]
435pub fn good_query_pg(input: TokenStream) -> TokenStream {
436    let input = parse_macro_input!(input as GoodQueryInput);
437    parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
438}
439
440/// See the `good_query` macro help in the readme.
441#[proc_macro]
442pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
443    let input = parse_macro_input!(input as GoodQueryInput);
444    parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
445}
446
447/// See the `good_query` macro help in the readme.
448#[proc_macro]
449pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
450    let input = parse_macro_input!(input as GoodQueryInput);
451    parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
452}
453
454/// See the `good_query` macro help in the readme.
455#[proc_macro]
456pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
457    let input = parse_macro_input!(input as GoodQueryInput);
458    parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
459}
460
461/// See the `good_query` macro help in the readme.
462#[proc_macro]
463pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
464    let input = parse_macro_input!(input as GoodQueryInput);
465    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
466}
467
468/// See the `good_query` macro help in the readme.
469#[proc_macro]
470pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
471    let input = parse_macro_input!(input as GoodQueryInput);
472    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
473}
474
475/// See the `good_query` macro help in the readme.
476#[proc_macro]
477pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
478    let input = parse_macro_input!(input as GoodQueryInput);
479    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
480}
481
482/// See the `good_query` macro help in the readme.
483#[proc_macro]
484pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
485    let input = parse_macro_input!(input as GoodQueryInput);
486    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
487}