Skip to main content

good_ormning_macros/
lib.rs

1use {
2    convert_case::{
3        Case,
4        Casing,
5    },
6    good_ormning_core::{
7        pg::{
8            Version as PgVersion,
9            query::utils::{
10                PgFieldInfo,
11                PgTableInfo,
12            },
13            schema::{
14                field::FieldRef as PgFieldRef,
15                table::TableRef as PgTableRef,
16            },
17        },
18        sqlite::{
19            Version as SqliteVersion,
20            query::utils::{
21                SqliteFieldInfo,
22                SqliteTableInfo,
23            },
24            schema::{
25                field::FieldRef as SqliteFieldRef,
26                table::TableRef as SqliteTableRef,
27            },
28        },
29        utils::Errs,
30    },
31    proc_macro::TokenStream,
32    quote::{
33        format_ident,
34        quote,
35    },
36    std::{
37        collections::{
38            HashMap,
39            hash_map::DefaultHasher,
40        },
41        env,
42        fs,
43        hash::{
44            Hash,
45            Hasher,
46        },
47    },
48    syn::{
49        Ident,
50        LitStr,
51        Token,
52        parse::{
53            Parse,
54            ParseStream,
55        },
56        parse_macro_input,
57    },
58};
59
60mod convert;
61
/// Type annotation attached to a query parameter, as written in the macro
/// input: optional `arr` / `opt` modifier keywords followed by a base type name.
///
/// The standard derives make the value printable in diagnostics and comparable
/// in tests; they do not change how the struct is constructed or read.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParamType {
    /// Set when the `arr` modifier was present (presumably "array of base" — confirm in `convert`).
    arr: bool,
    /// Set when the `opt` modifier was present (presumably "optional/nullable" — confirm in `convert`).
    opt: bool,
    /// The first identifier that is neither `arr` nor `opt`, stored verbatim.
    base: String,
}
67
/// Parsed form of a `good_query*!` macro invocation.
///
/// Produced by the `Parse` impl below from input shaped like
/// `["db_name", ["version",]] "sql"; conn [, name: [arr] [opt] base = value]*`.
struct GoodQueryInput {
    // Schema version from the middle literal when three literals are given; `None` otherwise.
    version: Option<usize>,
    // First literal when two or three literals are given; empty string when only the SQL is given.
    db_name: String,
    // SQL text with every inline `${type = value}` placeholder replaced by `$N`.
    sql: String,
    // Name and type of each parameter; inline parameters get generated names of the form `pN`.
    param_types: Vec<(Ident, ParamType)>,
    // Expression that follows the `;` (the connection passed to the generated function).
    conn: syn::Expr,
    // Value expression of each parameter, in the same order as `param_types`.
    params: Vec<syn::Expr>,
}
76
77impl Parse for GoodQueryInput {
78    fn parse(input: ParseStream) -> syn::Result<Self> {
79        let mut literals = Vec::new();
80        while !input.peek(Token![;]) && !input.is_empty() {
81            literals.push(input.parse::<LitStr>()?);
82            if input.peek(Token![,]) {
83                input.parse::<Token![,]>()?;
84            } else if !input.peek(Token![;]) {
85                return Err(input.error("Expected , or ; after literal"));
86            }
87        }
88        input.parse::<Token![;]>()?;
89        let (db_name, version, sql) = match literals.len() {
90            1 => ("".to_string(), None, literals[0].value()),
91            2 => (literals[0].value(), None, literals[1].value()),
92            3 => {
93                let v_str = literals[1].value();
94                let v = v_str.parse::<usize>().map_err(|_| {
95                    syn::Error::new(literals[1].span(), "Version must be a positive integer")
96                })?;
97                (literals[0].value(), Some(v), literals[2].value())
98            },
99            _ => {
100                return Err(
101                    input.error(
102                        "Expected 1, 2, or 3 literals before semicolon (sql, [db_name, sql], or [db_name, version, sql])",
103                    ),
104                );
105            },
106        };
107        let conn: syn::Expr = input.parse()?;
108        let mut param_types = Vec::new();
109        let mut params = Vec::new();
110        while input.peek(Token![,]) {
111            input.parse::<Token![,]>()?;
112            if input.is_empty() {
113                break;
114            }
115            let name: Ident = input.parse()?;
116            input.parse::<Token![:]>()?;
117            let mut arr = false;
118            let mut opt = false;
119            let mut base = String::new();
120            while input.peek(Ident) {
121                let id: Ident = input.parse()?;
122                if id == "arr" {
123                    arr = true;
124                } else if id == "opt" {
125                    opt = true;
126                } else {
127                    base = id.to_string();
128                    break;
129                }
130            }
131            if base.is_empty() {
132                return Err(input.error("Expected parameter type"));
133            }
134            input.parse::<Token![=]>()?;
135            let val: syn::Expr = input.parse()?;
136            param_types.push((name, ParamType {
137                arr: arr,
138                opt: opt,
139                base: base,
140            }));
141            params.push(val);
142        }
143        let mut final_sql = String::new();
144        let mut last_end = 0;
145        while let Some(start) = sql[last_end..].find("${") {
146            final_sql.push_str(&sql[last_end .. last_end + start]);
147            let start = last_end + start;
148            if let Some(end) = sql[start..].find('}') {
149                let end = start + end;
150                let content = &sql[start + 2 .. end];
151                let mut split = None;
152                let bytes = content.as_bytes();
153                let mut i = 0;
154                while i < bytes.len() {
155                    if bytes[i] == b'=' {
156                        split = Some((&content[..i], &content[i + 1..]));
157                        break;
158                    }
159                    i += 1;
160                }
161                let (type_str, val_str) =
162                    split.ok_or_else(|| input.error("Invalid inline parameter format. Expected ${type = value}"))?;
163                let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
164                    syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
165                })?;
166                let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
167                    syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
168                })?;
169
170                use syn::parse::Parser;
171
172                let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
173                    let mut arr = false;
174                    let mut opt = false;
175                    let mut base = String::new();
176                    while type_input.peek(Ident) {
177                        let id: Ident = type_input.parse()?;
178                        if id == "arr" {
179                            arr = true;
180                        } else if id == "opt" {
181                            opt = true;
182                        } else {
183                            base = id.to_string();
184                            break;
185                        }
186                    }
187                    Ok((arr, opt, base))
188                }).parse2(type_tokens).map_err(|e| {
189                    syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
190                })?;
191                if base_p.is_empty() {
192                    return Err(input.error("Expected base type in inline parameter"));
193                }
194                let param_idx = params.len() + 1;
195                let name = format_ident!("p{}", param_idx);
196                params.push(val);
197                param_types.push((name, ParamType {
198                    arr: arr_p,
199                    opt: opt_p,
200                    base: base_p,
201                }));
202                final_sql.push_str(&format!("${}", param_idx));
203                last_end = end + 1;
204            } else {
205                return Err(input.error("Unclosed inline parameter ${"));
206            }
207        }
208        final_sql.push_str(&sql[last_end..]);
209        Ok(GoodQueryInput {
210            version: version,
211            db_name: db_name,
212            sql: final_sql,
213            param_types: param_types,
214            conn: conn,
215            params: params,
216        })
217    }
218}
219
/// Resolves the database name used to locate the schema file.
///
/// Currently a pass-through: `_engine` is ignored and `provided_db_name` is
/// returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    provided_db_name
}
223
224fn parse_and_generate_pg(
225    input: GoodQueryInput,
226    res_count: good_ormning_core::QueryResCount,
227) -> proc_macro2::TokenStream {
228    let db_name = get_db_info("pg", input.db_name.clone());
229    let dialect = sqlparser::dialect::PostgreSqlDialect {};
230    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
231        Ok(ast) => ast,
232        Err(e) => {
233            let e = e.to_string();
234            return quote!(compile_error!(#e));
235        },
236    };
237    if ast.is_empty() {
238        return quote!(compile_error!("Empty SQL statement"));
239    }
240    let statement = &ast[0];
241    let mut errs = Errs::new();
242    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
243    let path =
244        std::path::Path::new(&out_dir)
245            .join("good_ormning")
246            .join(good_ormning_core::utils::json_file_name(&db_name));
247    if !path.exists() {
248        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
249        return quote!(compile_error!(#e));
250    }
251    let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
252        Ok(m) => m,
253        Err(e) => {
254            let e = e.to_string();
255            return quote!(compile_error!(#e));
256        },
257    };
258    let mut field_lookup = HashMap::new();
259    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
260    let version = match versions_map.get(&version_i) {
261        Some(v) => v,
262        None => {
263            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
264            return quote!(compile_error!(#e));
265        },
266    };
267    let custom_types = version.custom_types.clone();
268    for (table_id, table) in &version.tables {
269        let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
270        for (field_id, field) in &table.fields {
271            fields.insert(PgFieldRef {
272                table_id: table_id.clone(),
273                field_id: field_id.clone(),
274            }, PgFieldInfo {
275                sql_name: field.id.clone(),
276                type_: field.type_.type_.clone(),
277            });
278        }
279        field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
280            sql_name: table.id.clone(),
281            fields: fields,
282        });
283    }
284    let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
285    query.res_count = res_count;
286    let mut hasher = DefaultHasher::new();
287    input.sql.hash(&mut hasher);
288    let query_hash = hasher.finish();
289    let query_name = format_ident!("good_query_{}", query_hash);
290    query.name = query_name.to_string();
291    let pascal_db_name: String = db_name.to_case(Case::Pascal);
292    let db_type = if let Some(v) = input.version {
293        let name = format_ident!("Db{}{}", pascal_db_name, v);
294        quote!(dbm::#name <'_ >)
295    } else {
296        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
297        quote!(dbm::#name <'_ >)
298    };
299    let generated =
300        good_ormning_core::pg::query::generate::generate_query_functions(
301            &mut errs,
302            field_lookup,
303            vec![query],
304            "inline",
305            db_type,
306        );
307    let conn = &input.conn;
308    let args = &input.params;
309    quote!{
310        {
311            use ::good_ormning::runtime::GoodError;
312            use ::good_ormning::runtime::ToGoodError;
313            use ::good_ormning::runtime::pg::PgConnection;
314            #(#generated) * #query_name(& mut #conn, #(#args,) *)
315        }
316    }
317}
318
319fn parse_and_generate_sqlite(
320    input: GoodQueryInput,
321    res_count: good_ormning_core::QueryResCount,
322) -> proc_macro2::TokenStream {
323    let db_name = get_db_info("sqlite", input.db_name.clone());
324    let dialect = sqlparser::dialect::SQLiteDialect {};
325    let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
326        Ok(ast) => ast,
327        Err(e) => {
328            let e = e.to_string();
329            return quote!(compile_error!(#e));
330        },
331    };
332    if ast.is_empty() {
333        return quote!(compile_error!("Empty SQL statement"));
334    }
335    let statement = &ast[0];
336    let mut errs = Errs::new();
337    let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
338    let path =
339        std::path::Path::new(&out_dir)
340            .join("good_ormning")
341            .join(good_ormning_core::utils::json_file_name(&db_name));
342    if !path.exists() {
343        let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
344        return quote!(compile_error!(#e));
345    }
346    let versions_map: HashMap<usize, SqliteVersion> =
347        match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
348            Ok(m) => m,
349            Err(e) => {
350                let e = e.to_string();
351                return quote!(compile_error!(#e));
352            },
353        };
354    let mut field_lookup = HashMap::new();
355    let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
356    let version = match versions_map.get(&version_i) {
357        Some(v) => v,
358        None => {
359            let e = format!("Version {} not found in schema for db {}", version_i, db_name);
360            return quote!(compile_error!(#e));
361        },
362    };
363    let custom_types = version.custom_types.clone();
364    for (table_id, table) in &version.tables {
365        let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
366        for (field_id, field) in &table.fields {
367            fields.insert(SqliteFieldRef {
368                table_id: table_id.clone(),
369                field_id: field_id.clone(),
370            }, SqliteFieldInfo {
371                sql_name: field.id.clone(),
372                type_: field.type_.type_.clone(),
373            });
374        }
375        field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
376            sql_name: table.id.clone(),
377            fields: fields,
378        });
379    }
380    let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
381    query.res_count = res_count;
382    let mut hasher = DefaultHasher::new();
383    input.sql.hash(&mut hasher);
384    let query_hash = hasher.finish();
385    let query_name = format_ident!("good_query_{}", query_hash);
386    query.name = query_name.to_string();
387    let pascal_db_name: String = db_name.to_case(Case::Pascal);
388    let db_type = if let Some(v) = input.version {
389        let name = format_ident!("Db{}{}", pascal_db_name, v);
390        quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
391    } else {
392        let name = format_ident!("Db{}{}", pascal_db_name, version_i);
393        quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
394    };
395    let generated =
396        good_ormning_core::sqlite::query::generate::generate_query_functions(
397            &mut errs,
398            field_lookup,
399            vec![query],
400            "inline",
401            db_type,
402        );
403    let conn = &input.conn;
404    let args = &input.params;
405    quote!{
406        {
407            use ::good_ormning::runtime::GoodError;
408            use ::good_ormning::runtime::ToGoodError;
409            use ::good_ormning::runtime::sqlite::SqliteConnection;
410            #(#generated) * #query_name(& mut #conn, #(#args,) *)
411        }
412    }
413}
414
415/// See the `good_query` macro help in the readme.
416#[proc_macro]
417pub fn good_query_pg(input: TokenStream) -> TokenStream {
418    let input = parse_macro_input!(input as GoodQueryInput);
419    parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
420}
421
422/// See the `good_query` macro help in the readme.
423#[proc_macro]
424pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
425    let input = parse_macro_input!(input as GoodQueryInput);
426    parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
427}
428
429/// See the `good_query` macro help in the readme.
430#[proc_macro]
431pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
432    let input = parse_macro_input!(input as GoodQueryInput);
433    parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
434}
435
436/// See the `good_query` macro help in the readme.
437#[proc_macro]
438pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
439    let input = parse_macro_input!(input as GoodQueryInput);
440    parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
441}
442
443/// See the `good_query` macro help in the readme.
444#[proc_macro]
445pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
446    let input = parse_macro_input!(input as GoodQueryInput);
447    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
448}
449
450/// See the `good_query` macro help in the readme.
451#[proc_macro]
452pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
453    let input = parse_macro_input!(input as GoodQueryInput);
454    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
455}
456
457/// See the `good_query` macro help in the readme.
458#[proc_macro]
459pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
460    let input = parse_macro_input!(input as GoodQueryInput);
461    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
462}
463
464/// See the `good_query` macro help in the readme.
465#[proc_macro]
466pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
467    let input = parse_macro_input!(input as GoodQueryInput);
468    parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
469}