1use {
2 convert_case::{
3 Case,
4 Casing,
5 },
6 good_ormning_core::{
7 pg::{
8 Version as PgVersion,
9 query::utils::{
10 PgFieldInfo,
11 PgTableInfo,
12 },
13 schema::{
14 field::FieldRef as PgFieldRef,
15 table::TableRef as PgTableRef,
16 },
17 },
18 sqlite::{
19 Version as SqliteVersion,
20 query::utils::{
21 SqliteFieldInfo,
22 SqliteTableInfo,
23 },
24 schema::{
25 field::FieldRef as SqliteFieldRef,
26 table::TableRef as SqliteTableRef,
27 },
28 },
29 utils::Errs,
30 },
31 proc_macro::TokenStream,
32 quote::{
33 format_ident,
34 quote,
35 },
36 std::{
37 collections::{
38 HashMap,
39 hash_map::DefaultHasher,
40 },
41 env,
42 fs,
43 hash::{
44 Hash,
45 Hasher,
46 },
47 },
48 syn::{
49 Ident,
50 LitInt,
51 LitStr,
52 Token,
53 parse::{
54 Parse,
55 ParseStream,
56 },
57 parse_macro_input,
58 },
59};
60
// Submodule providing `convert::pg::convert_query` and `convert::sqlite::convert_query`,
// which this file calls to turn a parsed SQL statement into a query description.
mod convert;
62
/// Type of a query parameter as spelled in the macro input: zero or more
/// `arr` / `opt` modifiers followed by a base type identifier.
struct ParamType {
    /// Set when the `arr` modifier was written before the base type.
    arr: bool,
    /// Set when the `opt` modifier was written before the base type.
    opt: bool,
    /// Text of the first identifier that is neither `arr` nor `opt`.
    base: String,
}
68
/// Parsed arguments of one inline query macro invocation (see the `Parse` impl).
struct GoodQueryInput {
    /// Leading identifier of the macro input; the generators reference
    /// `#db_mod::DB_NAME` and `#db_mod::Db<Name><Version>` through it.
    db_mod: Ident,
    /// Schema version given as an integer literal, or `None` when omitted
    /// (the generators then fall back to the highest version in the schema file).
    version: Option<usize>,
    /// Database name literal; the empty string when the macro input omits it.
    db_name: String,
    /// SQL text with every inline `${type = value}` replaced by a `$N` placeholder.
    sql: String,
    /// Name and declared type of each parameter, in the same order as `params`.
    /// Inline parameters are named `p<N>`.
    param_types: Vec<(Ident, ParamType)>,
    /// Expression passed as the first argument of the generated query function.
    conn: syn::Expr,
    /// Value expression of each parameter (explicit ones first, then inline ones).
    params: Vec<syn::Expr>,
}
78
79impl Parse for GoodQueryInput {
80 fn parse(input: ParseStream) -> syn::Result<Self> {
81 let db_mod: Ident = input.parse()?;
82 input.parse::<Token![,]>()?;
83 let (db_name, version, sql) = {
84 let first: LitStr = input.parse()?;
85 if input.peek(Token![;]) {
86 input.parse::<Token![;]>()?;
87 ("".to_string(), None, first.value())
88 } else {
89 input.parse::<Token![,]>()?;
90 let lookahead = input.lookahead1();
91 if lookahead.peek(LitInt) {
92 let version_lit: LitInt = input.parse()?;
93 let version = version_lit.base10_parse::<usize>()?;
94 input.parse::<Token![,]>()?;
95 let sql_lit: LitStr = input.parse()?;
96 let sql = sql_lit.value();
97 input.parse::<Token![;]>()?;
98 (first.value(), Some(version), sql)
99 } else if lookahead.peek(LitStr) {
100 let sql_lit: LitStr = input.parse()?;
101 let sql = sql_lit.value();
102 input.parse::<Token![;]>()?;
103 (first.value(), None, sql)
104 } else {
105 return Err(lookahead.error());
106 }
107 }
108 };
109 let conn: syn::Expr = input.parse()?;
110 let mut param_types = Vec::new();
111 let mut params = Vec::new();
112 while input.peek(Token![,]) {
113 input.parse::<Token![,]>()?;
114 if input.is_empty() {
115 break;
116 }
117 let name: Ident = input.parse()?;
118 input.parse::<Token![:]>()?;
119 let mut arr = false;
120 let mut opt = false;
121 let mut base = String::new();
122 while input.peek(Ident) {
123 let id: Ident = input.parse()?;
124 if id == "arr" {
125 arr = true;
126 } else if id == "opt" {
127 opt = true;
128 } else {
129 base = id.to_string();
130 break;
131 }
132 }
133 if base.is_empty() {
134 return Err(input.error("Expected parameter type"));
135 }
136 input.parse::<Token![=]>()?;
137 let val: syn::Expr = input.parse()?;
138 param_types.push((name, ParamType {
139 arr: arr,
140 opt: opt,
141 base: base,
142 }));
143 params.push(val);
144 }
145 let mut final_sql = String::new();
146 let mut last_end = 0;
147 let bytes = sql.as_bytes();
148 let mut i = 0;
149 while i < bytes.len() {
150 if bytes[i] == b'$' {
151 if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
152 final_sql.push_str(&sql[last_end .. i]);
153 i += 2;
154 let content_start = i;
155 while i < bytes.len() && bytes[i] != b'}' {
156 i += 1;
157 }
158 if i >= bytes.len() {
159 return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
160 }
161 let content = &sql[content_start .. i];
162 i += 1;
163 let mut split = None;
164 for (idx, b) in content.as_bytes().iter().enumerate() {
165 if *b == b'=' {
166 split = Some((&content[..idx], &content[idx + 1..]));
167 break;
168 }
169 }
170 let (type_str, val_str) = split.ok_or_else(|| {
171 syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
172 })?;
173 let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
174 params.push(val);
175 param_types.push((name, pt));
176 final_sql.push_str(&format!("${}", param_idx));
177 last_end = i;
178 continue;
179 }
180 }
181 i += 1;
182 }
183 final_sql.push_str(&sql[last_end..]);
184 Ok(GoodQueryInput {
185 db_mod: db_mod,
186 version: version,
187 db_name: db_name,
188 sql: final_sql,
189 param_types: param_types,
190 conn: conn,
191 params: params,
192 })
193 }
194}
195
196fn parse_inline_param(
197 input: ParseStream,
198 type_str: &str,
199 val_str: &str,
200 current_params_len: usize,
201) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
202 let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
203 syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
204 })?;
205 let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
206 syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
207 })?;
208
209 use syn::parse::Parser;
210
211 let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
212 let mut arr = false;
213 let mut opt = false;
214 let mut base = String::new();
215 while type_input.peek(Ident) {
216 let id: Ident = type_input.parse()?;
217 if id == "arr" {
218 arr = true;
219 } else if id == "opt" {
220 opt = true;
221 } else {
222 base = id.to_string();
223 break;
224 }
225 }
226 Ok((arr, opt, base))
227 }).parse2(type_tokens).map_err(|e| {
228 syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
229 })?;
230 if base_p.is_empty() {
231 return Err(input.error("Expected base type in inline parameter"));
232 }
233 let param_idx = current_params_len + 1;
234 let name = format_ident!("p{}", param_idx);
235 Ok((param_idx, name, ParamType {
236 arr: arr_p,
237 opt: opt_p,
238 base: base_p,
239 }, val))
240}
241
/// Resolves the database name used to locate the schema file.
///
/// Currently returns `provided_db_name` unchanged; `_engine` is ignored.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    provided_db_name
}
245
246fn parse_and_generate_pg(
247 input: GoodQueryInput,
248 res_count: good_ormning_core::QueryResCount,
249) -> proc_macro2::TokenStream {
250 let db_name = get_db_info("pg", input.db_name.clone());
251 let dialect = sqlparser::dialect::PostgreSqlDialect {};
252 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
253 Ok(ast) => ast,
254 Err(e) => {
255 let e = e.to_string();
256 return quote!(compile_error!(#e));
257 },
258 };
259 if ast.is_empty() {
260 return quote!(compile_error!("Empty SQL statement"));
261 }
262 let statement = &ast[0];
263 let mut errs = Errs::new();
264 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
265 let path =
266 std::path::Path::new(&out_dir)
267 .join("good_ormning")
268 .join(good_ormning_core::utils::json_file_name(&db_name));
269 if !path.exists() {
270 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
271 return quote!(compile_error!(#e));
272 }
273 let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
274 Ok(m) => m,
275 Err(e) => {
276 let e = e.to_string();
277 return quote!(compile_error!(#e));
278 },
279 };
280 let mut field_lookup = HashMap::new();
281 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
282 let version = match versions_map.get(&version_i) {
283 Some(v) => v,
284 None => {
285 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
286 return quote!(compile_error!(#e));
287 },
288 };
289 let custom_types = version.custom_types.clone();
290 for (table_id, table) in &version.tables {
291 let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
292 for (field_id, field) in &table.fields {
293 fields.insert(PgFieldRef {
294 table_id: table_id.clone(),
295 field_id: field_id.clone(),
296 }, PgFieldInfo {
297 sql_name: field.id.clone(),
298 type_: field.type_.type_.clone(),
299 });
300 }
301 field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
302 sql_name: table.id.clone(),
303 fields: fields,
304 });
305 }
306 let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
307 query.res_count = res_count;
308 let mut hasher = DefaultHasher::new();
309 input.sql.hash(&mut hasher);
310 let query_hash = hasher.finish();
311 let query_name = format_ident!("good_query_{}", query_hash);
312 query.name = query_name.to_string();
313 let pascal_db_name: String = db_name.to_case(Case::Pascal);
314 let db_mod = &input.db_mod;
315 let db_type = if let Some(v) = input.version {
316 let name = format_ident!("Db{}{}", pascal_db_name, v);
317 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
318 } else {
319 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
320 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
321 };
322 let generated =
323 good_ormning_core::pg::query::generate::generate_query_functions(
324 &mut errs,
325 field_lookup,
326 vec![query],
327 "inline",
328 db_type,
329 );
330 let conn = &input.conn;
331 let args = &input.params;
332 let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
333 let db_mod_str = input.db_mod.to_string();
334 let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
335 quote!{
336 {
337 const _:() = {
338 if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
339 #[allow(unconditional_panic)]
340 let _ = ["Database name mismatch"][1];
341 }
342 };
343 use ::good_ormning::runtime::GoodError;
344 use ::good_ormning::runtime::ToGoodError;
345 use ::good_ormning::runtime::pg::PgConnection;
346 #(#generated) * #query_name(#conn, #(#args,) *)
347 }
348 }
349}
350
351fn parse_and_generate_sqlite(
352 input: GoodQueryInput,
353 res_count: good_ormning_core::QueryResCount,
354) -> proc_macro2::TokenStream {
355 let db_name = get_db_info("sqlite", input.db_name.clone());
356 let dialect = sqlparser::dialect::SQLiteDialect {};
357 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
358 Ok(ast) => ast,
359 Err(e) => {
360 let e = e.to_string();
361 return quote!(compile_error!(#e));
362 },
363 };
364 if ast.is_empty() {
365 return quote!(compile_error!("Empty SQL statement"));
366 }
367 let statement = &ast[0];
368 let mut errs = Errs::new();
369 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
370 let path =
371 std::path::Path::new(&out_dir)
372 .join("good_ormning")
373 .join(good_ormning_core::utils::json_file_name(&db_name));
374 if !path.exists() {
375 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
376 return quote!(compile_error!(#e));
377 }
378 let versions_map: HashMap<usize, SqliteVersion> =
379 match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
380 Ok(m) => m,
381 Err(e) => {
382 let e = e.to_string();
383 return quote!(compile_error!(#e));
384 },
385 };
386 let mut field_lookup = HashMap::new();
387 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
388 let version = match versions_map.get(&version_i) {
389 Some(v) => v,
390 None => {
391 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
392 return quote!(compile_error!(#e));
393 },
394 };
395 let custom_types = version.custom_types.clone();
396 for (table_id, table) in &version.tables {
397 let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
398 for (field_id, field) in &table.fields {
399 fields.insert(SqliteFieldRef {
400 table_id: table_id.clone(),
401 field_id: field_id.clone(),
402 }, SqliteFieldInfo {
403 sql_name: field.id.clone(),
404 type_: field.type_.type_.clone(),
405 });
406 }
407 field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
408 sql_name: table.id.clone(),
409 fields: fields,
410 });
411 }
412 let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
413 query.res_count = res_count;
414 let mut hasher = DefaultHasher::new();
415 input.sql.hash(&mut hasher);
416 let query_hash = hasher.finish();
417 let query_name = format_ident!("good_query_{}", query_hash);
418 query.name = query_name.to_string();
419 let pascal_db_name: String = db_name.to_case(Case::Pascal);
420 let db_mod = &input.db_mod;
421 let db_type = if let Some(v) = input.version {
422 let name = format_ident!("Db{}{}", pascal_db_name, v);
423 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
424 } else {
425 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
426 quote!(#db_mod::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
427 };
428 let generated =
429 good_ormning_core::sqlite::query::generate::generate_query_functions(
430 &mut errs,
431 field_lookup,
432 vec![query],
433 "inline",
434 db_type,
435 );
436 let conn = &input.conn;
437 let args = &input.params;
438 let db_name_lit = LitStr::new(&db_name, input.db_mod.span());
439 let db_mod_str = input.db_mod.to_string();
440 let _db_mod_lit = LitStr::new(&db_mod_str, input.db_mod.span());
441 quote!{
442 {
443 const _:() = {
444 if !:: good_ormning:: runtime:: utils:: str_eq(#db_mod:: DB_NAME, #db_name_lit) {
445 #[allow(unconditional_panic)]
446 let _ = ["Database name mismatch"][1];
447 }
448 };
449 use ::good_ormning::runtime::GoodError;
450 use ::good_ormning::runtime::ToGoodError;
451 use ::good_ormning::runtime::sqlite::SqliteConnection;
452 #(#generated) * #query_name(#conn, #(#args,) *)
453 }
454 }
455}
456
457#[proc_macro]
459pub fn good_query_pg(input: TokenStream) -> TokenStream {
460 let input = parse_macro_input!(input as GoodQueryInput);
461 parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
462}
463
464#[proc_macro]
466pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
467 let input = parse_macro_input!(input as GoodQueryInput);
468 parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
469}
470
471#[proc_macro]
473pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
474 let input = parse_macro_input!(input as GoodQueryInput);
475 parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
476}
477
478#[proc_macro]
480pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
481 let input = parse_macro_input!(input as GoodQueryInput);
482 parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
483}
484
485#[proc_macro]
487pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
488 let input = parse_macro_input!(input as GoodQueryInput);
489 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
490}
491
492#[proc_macro]
494pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
495 let input = parse_macro_input!(input as GoodQueryInput);
496 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
497}
498
499#[proc_macro]
501pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
502 let input = parse_macro_input!(input as GoodQueryInput);
503 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
504}
505
506#[proc_macro]
508pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
509 let input = parse_macro_input!(input as GoodQueryInput);
510 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
511}