Skip to main content

good_ormning_macros/
lib.rs

1use {
2    convert_case::{
3        Case,
4        Casing,
5    },
6    good_ormning_core::{
7        pg::{
8            Version as PgVersion,
9            query::utils::{
10                PgFieldInfo,
11                PgTableInfo,
12            },
13            schema::{
14                field::FieldRef as PgFieldRef,
15                table::TableRef as PgTableRef,
16            },
17        },
18        sqlite::{
19            Version as SqliteVersion,
20            query::utils::{
21                SqliteFieldInfo,
22                SqliteTableInfo,
23            },
24            schema::{
25                field::FieldRef as SqliteFieldRef,
26                table::TableRef as SqliteTableRef,
27            },
28        },
29        utils::Errs,
30    },
31    proc_macro::TokenStream,
32    quote::{
33        format_ident,
34        quote,
35    },
36    std::{
37        collections::{
38            HashMap,
39            hash_map::DefaultHasher,
40        },
41        env,
42        fs,
43        hash::{
44            Hash,
45            Hasher,
46        },
47    },
48    syn::{
49        Ident,
50        LitStr,
51        Token,
52        parse::{
53            Parse,
54            ParseStream,
55        },
56        parse_macro_input,
57    },
58};
59
60mod convert;
61
/// Type of one query parameter as written in the macro input: `[arr] [opt] base`.
struct ParamType {
    /// Set when the `arr` modifier appeared before the base type.
    arr: bool,
    /// Set when the `opt` modifier appeared before the base type.
    opt: bool,
    /// The base type identifier, as the text of the first non-modifier identifier.
    base: String,
}
67
68struct GoodQueryInput {
69    version: Option<usize>,
70    db_name: String,
71    sql: String,
72    param_types: Vec<(Ident, ParamType)>,
73    conn: syn::Expr,
74    params: Vec<syn::Expr>,
75}
76
77impl Parse for GoodQueryInput {
78    fn parse(input: ParseStream) -> syn::Result<Self> {
79        let mut literals = Vec::new();
80        while !input.peek(Token![;]) && !input.is_empty() {
81            literals.push(input.parse::<LitStr>()?);
82            if input.peek(Token![,]) {
83                input.parse::<Token![,]>()?;
84            } else if !input.peek(Token![;]) {
85                return Err(input.error("Expected , or ; after literal"));
86            }
87        }
88        input.parse::<Token![;]>()?;
89        let (db_name, version, sql) = match literals.len() {
90            1 => ("".to_string(), None, literals[0].value()),
91            2 => (literals[0].value(), None, literals[1].value()),
92            3 => {
93                let v_str = literals[1].value();
94                let v = v_str.parse::<usize>().map_err(|_| {
95                    syn::Error::new(literals[1].span(), "Version must be a positive integer")
96                })?;
97                (literals[0].value(), Some(v), literals[2].value())
98            },
99            _ => {
100                return Err(
101                    input.error(
102                        "Expected 1, 2, or 3 literals before semicolon (sql, [db_name, sql], or [db_name, version, sql])",
103                    ),
104                );
105            },
106        };
107        let conn: syn::Expr = input.parse()?;
108        let mut param_types = Vec::new();
109        let mut params = Vec::new();
110        while input.peek(Token![,]) {
111            input.parse::<Token![,]>()?;
112            if input.is_empty() {
113                break;
114            }
115            let name: Ident = input.parse()?;
116            input.parse::<Token![:]>()?;
117            let mut arr = false;
118            let mut opt = false;
119            let mut base = String::new();
120            while input.peek(Ident) {
121                let id: Ident = input.parse()?;
122                if id == "arr" {
123                    arr = true;
124                } else if id == "opt" {
125                    opt = true;
126                } else {
127                    base = id.to_string();
128                    break;
129                }
130            }
131            if base.is_empty() {
132                return Err(input.error("Expected parameter type"));
133            }
134            input.parse::<Token![=]>()?;
135            let val: syn::Expr = input.parse()?;
136            param_types.push((name, ParamType {
137                arr: arr,
138                opt: opt,
139                base: base,
140            }));
141            params.push(val);
142        }
143        let mut final_sql = String::new();
144        let mut last_end = 0;
145        let bytes = sql.as_bytes();
146        let mut i = 0;
147        while i < bytes.len() {
148            if bytes[i] == b'$' {
149                if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
150                    final_sql.push_str(&sql[last_end .. i]);
151                    i += 2;
152                    let content_start = i;
153                    while i < bytes.len() && bytes[i] != b'}' {
154                        i += 1;
155                    }
156                    if i >= bytes.len() {
157                        return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
158                    }
159                    let content = &sql[content_start .. i];
160                    i += 1;
161                    let mut split = None;
162                    for (idx, b) in content.as_bytes().iter().enumerate() {
163                        if *b == b'=' {
164                            split = Some((&content[..idx], &content[idx + 1..]));
165                            break;
166                        }
167                    }
168                    let (type_str, val_str) = split.ok_or_else(|| {
169                        syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
170                    })?;
171                    let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
172                    params.push(val);
173                    param_types.push((name, pt));
174                    final_sql.push_str(&format!("${}", param_idx));
175                    last_end = i;
176                    continue;
177                }
178            }
179            i += 1;
180        }
181        final_sql.push_str(&sql[last_end..]);
182        Ok(GoodQueryInput {
183            version: version,
184            db_name: db_name,
185            sql: final_sql,
186            param_types: param_types,
187            conn: conn,
188            params: params,
189        })
190    }
191}
192
193fn parse_inline_param(
194    input: ParseStream,
195    type_str: &str,
196    val_str: &str,
197    current_params_len: usize,
198) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
199    let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
200        syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
201    })?;
202    let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
203        syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
204    })?;
205
206    use syn::parse::Parser;
207
208    let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
209        let mut arr = false;
210        let mut opt = false;
211        let mut base = String::new();
212        while type_input.peek(Ident) {
213            let id: Ident = type_input.parse()?;
214            if id == "arr" {
215                arr = true;
216            } else if id == "opt" {
217                opt = true;
218            } else {
219                base = id.to_string();
220                break;
221            }
222        }
223        Ok((arr, opt, base))
224    }).parse2(type_tokens).map_err(|e| {
225        syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
226    })?;
227    if base_p.is_empty() {
228        return Err(input.error("Expected base type in inline parameter"));
229    }
230    let param_idx = current_params_len + 1;
231    let name = format_ident!("p{}", param_idx);
232    Ok((param_idx, name, ParamType {
233        arr: arr_p,
234        opt: opt_p,
235        base: base_p,
236    }, val))
237}
238
/// Resolves the schema database name used to locate the schema JSON file.
///
/// The engine argument (`"pg"` or `"sqlite"` at the visible call sites) is currently unused.
/// The caller-supplied name, possibly empty, is returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    // No per-engine resolution yet: hand the caller's name straight back.
    provided_db_name
}
242
243fn parse_and_generate_pg(
244    input: GoodQueryInput,
245    res_count: good_ormning_core::QueryResCount,
246) -> proc_macro2::TokenStream {
247    let db_name = get_db_info("pg", input.db_name.clone());
248    let dialect = sqlparser::dialect::PostgreSqlDialect {};
249    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
250        Ok(ast) => ast,
251        Err(e) => {
252            let e = e.to_string();
253            return quote!(compile_error!(#e));
254        },
255    };
256    if ast.is_empty() {
257        return quote!(compile_error!("Empty SQL statement"));
258    }
259    let statement = &ast[0];
260    let mut errs = Errs::new();
261    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
262    let path =
263        std::path::Path::new(&out_dir)
264            .join("good_ormning")
265            .join(good_ormning_core::utils::json_file_name(&db_name));
266    if !path.exists() {
267        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
268        return quote!(compile_error!(#e));
269    }
270    let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
271        Ok(m) => m,
272        Err(e) => {
273            let e = e.to_string();
274            return quote!(compile_error!(#e));
275        },
276    };
277    let mut field_lookup = HashMap::new();
278    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
279    let version = match versions_map.get(&version_i) {
280        Some(v) => v,
281        None => {
282            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
283            return quote!(compile_error!(#e));
284        },
285    };
286    let custom_types = version.custom_types.clone();
287    for (table_id, table) in &version.tables {
288        let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
289        for (field_id, field) in &table.fields {
290            fields.insert(PgFieldRef {
291                table_id: table_id.clone(),
292                field_id: field_id.clone(),
293            }, PgFieldInfo {
294                sql_name: field.id.clone(),
295                type_: field.type_.type_.clone(),
296            });
297        }
298        field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
299            sql_name: table.id.clone(),
300            fields: fields,
301        });
302    }
303    let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
304    query.res_count = res_count;
305    let mut hasher = DefaultHasher::new();
306    input.sql.hash(&mut hasher);
307    let query_hash = hasher.finish();
308    let query_name = format_ident!("good_query_{}", query_hash);
309    query.name = query_name.to_string();
310    let pascal_db_name: String = db_name.to_case(Case::Pascal);
311    let db_type = if let Some(v) = input.version {
312        let name = format_ident!("Db{}{}", pascal_db_name, v);
313        quote!(dbm::#name <'_ >)
314    } else {
315        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
316        quote!(dbm::#name <'_ >)
317    };
318    let generated =
319        good_ormning_core::pg::query::generate::generate_query_functions(
320            &mut errs,
321            field_lookup,
322            vec![query],
323            "inline",
324            db_type,
325        );
326    let conn = &input.conn;
327    let args = &input.params;
328    quote!{
329        {
330            use ::good_ormning::runtime::GoodError;
331            use ::good_ormning::runtime::ToGoodError;
332            use ::good_ormning::runtime::pg::PgConnection;
333            #(#generated) * #query_name(& mut #conn, #(#args,) *)
334        }
335    }
336}
337
338fn parse_and_generate_sqlite(
339    input: GoodQueryInput,
340    res_count: good_ormning_core::QueryResCount,
341) -> proc_macro2::TokenStream {
342    let db_name = get_db_info("sqlite", input.db_name.clone());
343    let dialect = sqlparser::dialect::SQLiteDialect {};
344    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
345        Ok(ast) => ast,
346        Err(e) => {
347            let e = e.to_string();
348            return quote!(compile_error!(#e));
349        },
350    };
351    if ast.is_empty() {
352        return quote!(compile_error!("Empty SQL statement"));
353    }
354    let statement = &ast[0];
355    let mut errs = Errs::new();
356    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
357    let path =
358        std::path::Path::new(&out_dir)
359            .join("good_ormning")
360            .join(good_ormning_core::utils::json_file_name(&db_name));
361    if !path.exists() {
362        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
363        return quote!(compile_error!(#e));
364    }
365    let versions_map: HashMap<usize, SqliteVersion> =
366        match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
367            Ok(m) => m,
368            Err(e) => {
369                let e = e.to_string();
370                return quote!(compile_error!(#e));
371            },
372        };
373    let mut field_lookup = HashMap::new();
374    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
375    let version = match versions_map.get(&version_i) {
376        Some(v) => v,
377        None => {
378            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
379            return quote!(compile_error!(#e));
380        },
381    };
382    let custom_types = version.custom_types.clone();
383    for (table_id, table) in &version.tables {
384        let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
385        for (field_id, field) in &table.fields {
386            fields.insert(SqliteFieldRef {
387                table_id: table_id.clone(),
388                field_id: field_id.clone(),
389            }, SqliteFieldInfo {
390                sql_name: field.id.clone(),
391                type_: field.type_.type_.clone(),
392            });
393        }
394        field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
395            sql_name: table.id.clone(),
396            fields: fields,
397        });
398    }
399    let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
400    query.res_count = res_count;
401    let mut hasher = DefaultHasher::new();
402    input.sql.hash(&mut hasher);
403    let query_hash = hasher.finish();
404    let query_name = format_ident!("good_query_{}", query_hash);
405    query.name = query_name.to_string();
406    let pascal_db_name: String = db_name.to_case(Case::Pascal);
407    let db_type = if let Some(v) = input.version {
408        let name = format_ident!("Db{}{}", pascal_db_name, v);
409        quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
410    } else {
411        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
412        quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
413    };
414    let generated =
415        good_ormning_core::sqlite::query::generate::generate_query_functions(
416            &mut errs,
417            field_lookup,
418            vec![query],
419            "inline",
420            db_type,
421        );
422    let conn = &input.conn;
423    let args = &input.params;
424    quote!{
425        {
426            use ::good_ormning::runtime::GoodError;
427            use ::good_ormning::runtime::ToGoodError;
428            use ::good_ormning::runtime::sqlite::SqliteConnection;
429            #(#generated) * #query_name(& mut #conn, #(#args,) *)
430        }
431    }
432}
433
434/// See the `good_query` macro help in the readme.
435#[proc_macro]
436pub fn good_query_pg(input: TokenStream) -> TokenStream {
437    let input = parse_macro_input!(input as GoodQueryInput);
438    parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
439}
440
441/// See the `good_query` macro help in the readme.
442#[proc_macro]
443pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
444    let input = parse_macro_input!(input as GoodQueryInput);
445    parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
446}
447
448/// See the `good_query` macro help in the readme.
449#[proc_macro]
450pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
451    let input = parse_macro_input!(input as GoodQueryInput);
452    parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
453}
454
455/// See the `good_query` macro help in the readme.
456#[proc_macro]
457pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
458    let input = parse_macro_input!(input as GoodQueryInput);
459    parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
460}
461
462/// See the `good_query` macro help in the readme.
463#[proc_macro]
464pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
465    let input = parse_macro_input!(input as GoodQueryInput);
466    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
467}
468
469/// See the `good_query` macro help in the readme.
470#[proc_macro]
471pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
472    let input = parse_macro_input!(input as GoodQueryInput);
473    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
474}
475
476/// See the `good_query` macro help in the readme.
477#[proc_macro]
478pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
479    let input = parse_macro_input!(input as GoodQueryInput);
480    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
481}
482
483/// See the `good_query` macro help in the readme.
484#[proc_macro]
485pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
486    let input = parse_macro_input!(input as GoodQueryInput);
487    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
488}