1use {
2 convert_case::{
3 Case,
4 Casing,
5 },
6 good_ormning_core::{
7 pg::{
8 Version as PgVersion,
9 query::utils::{
10 PgFieldInfo,
11 PgTableInfo,
12 },
13 schema::{
14 field::FieldRef as PgFieldRef,
15 table::TableRef as PgTableRef,
16 },
17 },
18 sqlite::{
19 Version as SqliteVersion,
20 query::utils::{
21 SqliteFieldInfo,
22 SqliteTableInfo,
23 },
24 schema::{
25 field::FieldRef as SqliteFieldRef,
26 table::TableRef as SqliteTableRef,
27 },
28 },
29 utils::Errs,
30 },
31 proc_macro::TokenStream,
32 quote::{
33 format_ident,
34 quote,
35 },
36 std::{
37 collections::{
38 HashMap,
39 hash_map::DefaultHasher,
40 },
41 env,
42 fs,
43 hash::{
44 Hash,
45 Hasher,
46 },
47 },
48 syn::{
49 Ident,
50 LitInt,
51 LitStr,
52 Token,
53 parse::{
54 Parse,
55 ParseStream,
56 },
57 parse_macro_input,
58 },
59};
60
61mod convert;
62
/// Declared type of one query parameter, written in the macro invocation as
/// `[arr] [opt] Base` (e.g. `opt i32`).
struct ParamType {
    /// `true` when the `arr` modifier was written (presumably an array parameter).
    arr: bool,
    /// `true` when the `opt` modifier was written (presumably a nullable parameter).
    opt: bool,
    /// The base type identifier, stored as text (e.g. `"i32"`).
    base: String,
}
68
69struct GoodQueryInput {
70 db_mod: Ident,
71 version: Option<usize>,
72 db_name: String,
73 sql: String,
74 param_types: Vec<(Ident, ParamType)>,
75 conn: syn::Expr,
76 params: Vec<syn::Expr>,
77}
78
79impl Parse for GoodQueryInput {
80 fn parse(input: ParseStream) -> syn::Result<Self> {
81 let db_mod: Ident = input.parse()?;
82 input.parse::<Token![,]>()?;
83 let (db_name, version, sql) = {
84 if input.peek(LitInt) {
85 let version_lit: LitInt = input.parse()?;
86 let version = version_lit.base10_parse::<usize>()?;
87 input.parse::<Token![,]>()?;
88 let sql_lit: LitStr = input.parse()?;
89 let sql = sql_lit.value();
90 input.parse::<Token![;]>()?;
91 ("".to_string(), Some(version), sql)
92 } else {
93 let first: LitStr = input.parse()?;
94 if input.peek(Token![;]) {
95 input.parse::<Token![;]>()?;
96 ("".to_string(), None, first.value())
97 } else {
98 input.parse::<Token![,]>()?;
99 let lookahead = input.lookahead1();
100 if lookahead.peek(LitInt) {
101 let version_lit: LitInt = input.parse()?;
102 let version = version_lit.base10_parse::<usize>()?;
103 input.parse::<Token![,]>()?;
104 let sql_lit: LitStr = input.parse()?;
105 let sql = sql_lit.value();
106 input.parse::<Token![;]>()?;
107 (first.value(), Some(version), sql)
108 } else if lookahead.peek(LitStr) {
109 let sql_lit: LitStr = input.parse()?;
110 let sql = sql_lit.value();
111 input.parse::<Token![;]>()?;
112 (first.value(), None, sql)
113 } else {
114 return Err(lookahead.error());
115 }
116 }
117 }
118 };
119 let conn: syn::Expr = input.parse()?;
120 let mut param_types = Vec::new();
121 let mut params = Vec::new();
122 while input.peek(Token![,]) {
123 input.parse::<Token![,]>()?;
124 if input.is_empty() {
125 break;
126 }
127 let name: Ident = input.parse()?;
128 input.parse::<Token![:]>()?;
129 let mut arr = false;
130 let mut opt = false;
131 let mut base = String::new();
132 while input.peek(Ident) {
133 let id: Ident = input.parse()?;
134 if id == "arr" {
135 arr = true;
136 } else if id == "opt" {
137 opt = true;
138 } else {
139 base = id.to_string();
140 break;
141 }
142 }
143 if base.is_empty() {
144 return Err(input.error("Expected parameter type"));
145 }
146 input.parse::<Token![=]>()?;
147 let val: syn::Expr = input.parse()?;
148 param_types.push((name, ParamType {
149 arr: arr,
150 opt: opt,
151 base: base,
152 }));
153 params.push(val);
154 }
155 let mut final_sql = String::new();
156 let mut last_end = 0;
157 let bytes = sql.as_bytes();
158 let mut i = 0;
159 while i < bytes.len() {
160 if bytes[i] == b'$' {
161 if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
162 final_sql.push_str(&sql[last_end .. i]);
163 i += 2;
164 let content_start = i;
165 while i < bytes.len() && bytes[i] != b'}' {
166 i += 1;
167 }
168 if i >= bytes.len() {
169 return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
170 }
171 let content = &sql[content_start .. i];
172 i += 1;
173 let mut split = None;
174 for (idx, b) in content.as_bytes().iter().enumerate() {
175 if *b == b'=' {
176 split = Some((&content[..idx], &content[idx + 1..]));
177 break;
178 }
179 }
180 let (type_str, val_str) = split.ok_or_else(|| {
181 syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
182 })?;
183 let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
184 params.push(val);
185 param_types.push((name, pt));
186 final_sql.push_str(&format!("${}", param_idx));
187 last_end = i;
188 continue;
189 }
190 }
191 i += 1;
192 }
193 final_sql.push_str(&sql[last_end..]);
194 Ok(GoodQueryInput {
195 db_mod: db_mod,
196 version: version,
197 db_name: db_name,
198 sql: final_sql,
199 param_types: param_types,
200 conn: conn,
201 params: params,
202 })
203 }
204}
205
206fn parse_inline_param(
207 input: ParseStream,
208 type_str: &str,
209 val_str: &str,
210 current_params_len: usize,
211) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
212 let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
213 syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
214 })?;
215 let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
216 syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
217 })?;
218
219 use syn::parse::Parser;
220
221 let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
222 let mut arr = false;
223 let mut opt = false;
224 let mut base = String::new();
225 while type_input.peek(Ident) {
226 let id: Ident = type_input.parse()?;
227 if id == "arr" {
228 arr = true;
229 } else if id == "opt" {
230 opt = true;
231 } else {
232 base = id.to_string();
233 break;
234 }
235 }
236 Ok((arr, opt, base))
237 }).parse2(type_tokens).map_err(|e| {
238 syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
239 })?;
240 if base_p.is_empty() {
241 return Err(input.error("Expected base type in inline parameter"));
242 }
243 let param_idx = current_params_len + 1;
244 let name = format_ident!("p{}", param_idx);
245 Ok((param_idx, name, ParamType {
246 arr: arr_p,
247 opt: opt_p,
248 base: base_p,
249 }, val))
250}
251
/// Resolves the database name used to locate the schema file.
///
/// The engine argument is currently unused. The name given in the macro
/// invocation (possibly empty) is returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    // No engine-specific resolution yet: pure pass-through.
    provided_db_name
}
255
256fn parse_and_generate_pg(
257 input: GoodQueryInput,
258 res_count: good_ormning_core::QueryResCount,
259) -> proc_macro2::TokenStream {
260 let db_name = get_db_info("pg", input.db_name.clone());
261 let dialect = sqlparser::dialect::PostgreSqlDialect {};
262 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
263 Ok(ast) => ast,
264 Err(e) => {
265 let e = e.to_string();
266 return quote!(compile_error!(#e));
267 },
268 };
269 if ast.is_empty() {
270 return quote!(compile_error!("Empty SQL statement"));
271 }
272 let statement = &ast[0];
273 let mut errs = Errs::new();
274 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
275 let path =
276 std::path::Path::new(&out_dir)
277 .join("good_ormning")
278 .join(good_ormning_core::utils::json_file_name(&db_name));
279 if !path.exists() {
280 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
281 return quote!(compile_error!(#e));
282 }
283 let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
284 Ok(m) => m,
285 Err(e) => {
286 let e = e.to_string();
287 return quote!(compile_error!(#e));
288 },
289 };
290 let mut field_lookup = HashMap::new();
291 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
292 let version = match versions_map.get(&version_i) {
293 Some(v) => v,
294 None => {
295 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
296 return quote!(compile_error!(#e));
297 },
298 };
299 let custom_types = version.custom_types.clone();
300 for (table_id, table) in &version.tables {
301 let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
302 for (field_id, field) in &table.fields {
303 fields.insert(PgFieldRef {
304 table_id: table_id.clone(),
305 field_id: field_id.clone(),
306 }, PgFieldInfo {
307 sql_name: field.id.clone(),
308 type_: field.type_.type_.clone(),
309 });
310 }
311 field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
312 sql_name: table.id.clone(),
313 fields: fields,
314 });
315 }
316 let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
317 query.res_count = res_count;
318 let mut hasher = DefaultHasher::new();
319 input.sql.hash(&mut hasher);
320 let query_hash = hasher.finish();
321 let query_name = format_ident!("good_query_{}", query_hash);
322 query.name = query_name.to_string();
323 let pascal_db_name: String = db_name.to_case(Case::Pascal);
324 let db_mod = &input.db_mod;
325 let db_type = if let Some(v) = input.version {
326 let name = format_ident!("Db{}{}", pascal_db_name, v);
327 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
328 } else {
329 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
330 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
331 };
332 let generated =
333 good_ormning_core::pg::query::generate::generate_query_functions(
334 &mut errs,
335 field_lookup,
336 vec![query],
337 "inline",
338 db_type,
339 );
340 let conn = &input.conn;
341 let args = &input.params;
342 let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
343 let db_mod_str = input.db_mod.to_string();
344 let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
345 quote!{
346 {
347 const _:() = {
348 if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
349 #[allow(unconditional_panic)]
350 let _ = ["Database name mismatch"][1];
351 }
352 };
353 use ::good_ormning::runtime::GoodError;
354 use ::good_ormning::runtime::ToGoodError;
355 use ::good_ormning::runtime::pg::PgConnection;
356 #(#generated) * #query_name(#conn, #(#args,) *)
357 }
358 }
359}
360
361fn parse_and_generate_sqlite(
362 input: GoodQueryInput,
363 res_count: good_ormning_core::QueryResCount,
364) -> proc_macro2::TokenStream {
365 let db_name = get_db_info("sqlite", input.db_name.clone());
366 let dialect = sqlparser::dialect::SQLiteDialect {};
367 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
368 Ok(ast) => ast,
369 Err(e) => {
370 let e = e.to_string();
371 return quote!(compile_error!(#e));
372 },
373 };
374 if ast.is_empty() {
375 return quote!(compile_error!("Empty SQL statement"));
376 }
377 let statement = &ast[0];
378 let mut errs = Errs::new();
379 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
380 let path =
381 std::path::Path::new(&out_dir)
382 .join("good_ormning")
383 .join(good_ormning_core::utils::json_file_name(&db_name));
384 if !path.exists() {
385 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
386 return quote!(compile_error!(#e));
387 }
388 let versions_map: HashMap<usize, SqliteVersion> =
389 match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
390 Ok(m) => m,
391 Err(e) => {
392 let e = e.to_string();
393 return quote!(compile_error!(#e));
394 },
395 };
396 let mut field_lookup = HashMap::new();
397 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
398 let version = match versions_map.get(&version_i) {
399 Some(v) => v,
400 None => {
401 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
402 return quote!(compile_error!(#e));
403 },
404 };
405 let custom_types = version.custom_types.clone();
406 for (table_id, table) in &version.tables {
407 let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
408 for (field_id, field) in &table.fields {
409 fields.insert(SqliteFieldRef {
410 table_id: table_id.clone(),
411 field_id: field_id.clone(),
412 }, SqliteFieldInfo {
413 sql_name: field.id.clone(),
414 type_: field.type_.type_.clone(),
415 });
416 }
417 field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
418 sql_name: table.id.clone(),
419 fields: fields,
420 });
421 }
422 let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
423 query.res_count = res_count;
424 let mut hasher = DefaultHasher::new();
425 input.sql.hash(&mut hasher);
426 let query_hash = hasher.finish();
427 let query_name = format_ident!("good_query_{}", query_hash);
428 query.name = query_name.to_string();
429 let pascal_db_name: String = db_name.to_case(Case::Pascal);
430 let db_mod = &input.db_mod;
431 let db_type = if let Some(v) = input.version {
432 let name = format_ident!("Db{}{}", pascal_db_name, v);
433 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
434 } else {
435 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
436 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
437 };
438 let generated =
439 good_ormning_core::sqlite::query::generate::generate_query_functions(
440 &mut errs,
441 field_lookup,
442 vec![query],
443 "inline",
444 db_type,
445 );
446 let conn = &input.conn;
447 let args = &input.params;
448 let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
449 let db_mod_str = input.db_mod.to_string();
450 let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
451 quote!{
452 {
453 const _:() = {
454 if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
455 #[allow(unconditional_panic)]
456 let _ = ["Database name mismatch"][1];
457 }
458 };
459 use ::good_ormning::runtime::GoodError;
460 use ::good_ormning::runtime::ToGoodError;
461 use ::good_ormning::runtime::sqlite::SqliteConnection;
462 #(#generated) * #query_name(#conn, #(#args,) *)
463 }
464 }
465}
466
467#[proc_macro]
469pub fn good_query_pg(input: TokenStream) -> TokenStream {
470 let input = parse_macro_input!(input as GoodQueryInput);
471 parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
472}
473
474#[proc_macro]
476pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
477 let input = parse_macro_input!(input as GoodQueryInput);
478 parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
479}
480
481#[proc_macro]
483pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
484 let input = parse_macro_input!(input as GoodQueryInput);
485 parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
486}
487
488#[proc_macro]
490pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
491 let input = parse_macro_input!(input as GoodQueryInput);
492 parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
493}
494
495#[proc_macro]
497pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
498 let input = parse_macro_input!(input as GoodQueryInput);
499 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
500}
501
502#[proc_macro]
504pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
505 let input = parse_macro_input!(input as GoodQueryInput);
506 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
507}
508
509#[proc_macro]
511pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
512 let input = parse_macro_input!(input as GoodQueryInput);
513 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
514}
515
516#[proc_macro]
518pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
519 let input = parse_macro_input!(input as GoodQueryInput);
520 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
521}