1use {
2 convert_case::{
3 Case,
4 Casing,
5 },
6 good_ormning_core::{
7 pg::{
8 Version as PgVersion,
9 query::utils::{
10 PgFieldInfo,
11 PgTableInfo,
12 },
13 schema::{
14 field::FieldRef as PgFieldRef,
15 table::TableRef as PgTableRef,
16 },
17 },
18 sqlite::{
19 Version as SqliteVersion,
20 query::utils::{
21 SqliteFieldInfo,
22 SqliteTableInfo,
23 },
24 schema::{
25 field::FieldRef as SqliteFieldRef,
26 table::TableRef as SqliteTableRef,
27 },
28 },
29 utils::Errs,
30 },
31 proc_macro::TokenStream,
32 quote::{
33 format_ident,
34 quote,
35 },
36 std::{
37 collections::{
38 HashMap,
39 hash_map::DefaultHasher,
40 },
41 env,
42 fs,
43 hash::{
44 Hash,
45 Hasher,
46 },
47 },
48 syn::{
49 Ident,
50 LitInt,
51 LitStr,
52 Token,
53 parse::{
54 Parse,
55 ParseStream,
56 },
57 parse_macro_input,
58 },
59};
60
61mod convert;
62
/// Type annotation of a single query parameter, written in macro input as `[arr] [opt] base`.
///
/// Built both for explicit macro parameters (`name: opt i32 = expr`) and for inline
/// `${type = value}` parameters embedded in the SQL text.
struct ParamType {
    /// `true` when the `arr` modifier was written (presumably an array parameter — confirm
    /// against the `convert` module).
    arr: bool,
    /// `true` when the `opt` modifier was written (presumably a nullable parameter — confirm
    /// against the `convert` module).
    opt: bool,
    /// The first non-modifier identifier, stored verbatim (e.g. `i32`).
    base: String,
}
68
69struct GoodQueryInput {
70 version: Option<usize>,
71 db_name: String,
72 sql: String,
73 param_types: Vec<(Ident, ParamType)>,
74 conn: syn::Expr,
75 params: Vec<syn::Expr>,
76}
77
78impl Parse for GoodQueryInput {
79 fn parse(input: ParseStream) -> syn::Result<Self> {
80 let (db_name, version, sql) = {
81 let first: LitStr = input.parse()?;
82 if input.peek(Token![;]) {
83 input.parse::<Token![;]>()?;
84 ("".to_string(), None, first.value())
85 } else {
86 input.parse::<Token![,]>()?;
87 let lookahead = input.lookahead1();
88 if lookahead.peek(LitInt) {
89 let version_lit: LitInt = input.parse()?;
90 let version = version_lit.base10_parse::<usize>()?;
91 input.parse::<Token![,]>()?;
92 let sql_lit: LitStr = input.parse()?;
93 let sql = sql_lit.value();
94 input.parse::<Token![;]>()?;
95 (first.value(), Some(version), sql)
96 } else if lookahead.peek(LitStr) {
97 let sql_lit: LitStr = input.parse()?;
98 let sql = sql_lit.value();
99 input.parse::<Token![;]>()?;
100 (first.value(), None, sql)
101 } else {
102 return Err(lookahead.error());
103 }
104 }
105 };
106 let conn: syn::Expr = input.parse()?;
107 let mut param_types = Vec::new();
108 let mut params = Vec::new();
109 while input.peek(Token![,]) {
110 input.parse::<Token![,]>()?;
111 if input.is_empty() {
112 break;
113 }
114 let name: Ident = input.parse()?;
115 input.parse::<Token![:]>()?;
116 let mut arr = false;
117 let mut opt = false;
118 let mut base = String::new();
119 while input.peek(Ident) {
120 let id: Ident = input.parse()?;
121 if id == "arr" {
122 arr = true;
123 } else if id == "opt" {
124 opt = true;
125 } else {
126 base = id.to_string();
127 break;
128 }
129 }
130 if base.is_empty() {
131 return Err(input.error("Expected parameter type"));
132 }
133 input.parse::<Token![=]>()?;
134 let val: syn::Expr = input.parse()?;
135 param_types.push((name, ParamType {
136 arr: arr,
137 opt: opt,
138 base: base,
139 }));
140 params.push(val);
141 }
142 let mut final_sql = String::new();
143 let mut last_end = 0;
144 let bytes = sql.as_bytes();
145 let mut i = 0;
146 while i < bytes.len() {
147 if bytes[i] == b'$' {
148 if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
149 final_sql.push_str(&sql[last_end .. i]);
150 i += 2;
151 let content_start = i;
152 while i < bytes.len() && bytes[i] != b'}' {
153 i += 1;
154 }
155 if i >= bytes.len() {
156 return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
157 }
158 let content = &sql[content_start .. i];
159 i += 1;
160 let mut split = None;
161 for (idx, b) in content.as_bytes().iter().enumerate() {
162 if *b == b'=' {
163 split = Some((&content[..idx], &content[idx + 1..]));
164 break;
165 }
166 }
167 let (type_str, val_str) = split.ok_or_else(|| {
168 syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
169 })?;
170 let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
171 params.push(val);
172 param_types.push((name, pt));
173 final_sql.push_str(&format!("${}", param_idx));
174 last_end = i;
175 continue;
176 }
177 }
178 i += 1;
179 }
180 final_sql.push_str(&sql[last_end..]);
181 Ok(GoodQueryInput {
182 version: version,
183 db_name: db_name,
184 sql: final_sql,
185 param_types: param_types,
186 conn: conn,
187 params: params,
188 })
189 }
190}
191
192fn parse_inline_param(
193 input: ParseStream,
194 type_str: &str,
195 val_str: &str,
196 current_params_len: usize,
197) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
198 let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
199 syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
200 })?;
201 let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
202 syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
203 })?;
204
205 use syn::parse::Parser;
206
207 let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
208 let mut arr = false;
209 let mut opt = false;
210 let mut base = String::new();
211 while type_input.peek(Ident) {
212 let id: Ident = type_input.parse()?;
213 if id == "arr" {
214 arr = true;
215 } else if id == "opt" {
216 opt = true;
217 } else {
218 base = id.to_string();
219 break;
220 }
221 }
222 Ok((arr, opt, base))
223 }).parse2(type_tokens).map_err(|e| {
224 syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
225 })?;
226 if base_p.is_empty() {
227 return Err(input.error("Expected base type in inline parameter"));
228 }
229 let param_idx = current_params_len + 1;
230 let name = format_ident!("p{}", param_idx);
231 Ok((param_idx, name, ParamType {
232 arr: arr_p,
233 opt: opt_p,
234 base: base_p,
235 }, val))
236}
237
/// Resolves the database name used to locate the schema file and to name the `Db…` type.
///
/// The engine argument (`"pg"` / `"sqlite"`) is currently ignored and the caller-supplied name
/// is returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    std::convert::identity(provided_db_name)
}
241
242fn parse_and_generate_pg(
243 input: GoodQueryInput,
244 res_count: good_ormning_core::QueryResCount,
245) -> proc_macro2::TokenStream {
246 let db_name = get_db_info("pg", input.db_name.clone());
247 let dialect = sqlparser::dialect::PostgreSqlDialect {};
248 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
249 Ok(ast) => ast,
250 Err(e) => {
251 let e = e.to_string();
252 return quote!(compile_error!(#e));
253 },
254 };
255 if ast.is_empty() {
256 return quote!(compile_error!("Empty SQL statement"));
257 }
258 let statement = &ast[0];
259 let mut errs = Errs::new();
260 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
261 let path =
262 std::path::Path::new(&out_dir)
263 .join("good_ormning")
264 .join(good_ormning_core::utils::json_file_name(&db_name));
265 if !path.exists() {
266 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
267 return quote!(compile_error!(#e));
268 }
269 let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
270 Ok(m) => m,
271 Err(e) => {
272 let e = e.to_string();
273 return quote!(compile_error!(#e));
274 },
275 };
276 let mut field_lookup = HashMap::new();
277 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
278 let version = match versions_map.get(&version_i) {
279 Some(v) => v,
280 None => {
281 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
282 return quote!(compile_error!(#e));
283 },
284 };
285 let custom_types = version.custom_types.clone();
286 for (table_id, table) in &version.tables {
287 let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
288 for (field_id, field) in &table.fields {
289 fields.insert(PgFieldRef {
290 table_id: table_id.clone(),
291 field_id: field_id.clone(),
292 }, PgFieldInfo {
293 sql_name: field.id.clone(),
294 type_: field.type_.type_.clone(),
295 });
296 }
297 field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
298 sql_name: table.id.clone(),
299 fields: fields,
300 });
301 }
302 let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
303 query.res_count = res_count;
304 let mut hasher = DefaultHasher::new();
305 input.sql.hash(&mut hasher);
306 let query_hash = hasher.finish();
307 let query_name = format_ident!("good_query_{}", query_hash);
308 query.name = query_name.to_string();
309 let pascal_db_name: String = db_name.to_case(Case::Pascal);
310 let db_type = if let Some(v) = input.version {
311 let name = format_ident!("Db{}{}", pascal_db_name, v);
312 quote!(dbm::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
313 } else {
314 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
315 quote!(dbm::#name < impl:: good_ormning:: runtime:: pg:: PgConnection >)
316 };
317 let generated =
318 good_ormning_core::pg::query::generate::generate_query_functions(
319 &mut errs,
320 field_lookup,
321 vec![query],
322 "inline",
323 db_type,
324 );
325 let conn = &input.conn;
326 let args = &input.params;
327 quote!{
328 {
329 use ::good_ormning::runtime::GoodError;
330 use ::good_ormning::runtime::ToGoodError;
331 use ::good_ormning::runtime::pg::PgConnection;
332 #(#generated) * #query_name(#conn, #(#args,) *)
333 }
334 }
335}
336
337fn parse_and_generate_sqlite(
338 input: GoodQueryInput,
339 res_count: good_ormning_core::QueryResCount,
340) -> proc_macro2::TokenStream {
341 let db_name = get_db_info("sqlite", input.db_name.clone());
342 let dialect = sqlparser::dialect::SQLiteDialect {};
343 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
344 Ok(ast) => ast,
345 Err(e) => {
346 let e = e.to_string();
347 return quote!(compile_error!(#e));
348 },
349 };
350 if ast.is_empty() {
351 return quote!(compile_error!("Empty SQL statement"));
352 }
353 let statement = &ast[0];
354 let mut errs = Errs::new();
355 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
356 let path =
357 std::path::Path::new(&out_dir)
358 .join("good_ormning")
359 .join(good_ormning_core::utils::json_file_name(&db_name));
360 if !path.exists() {
361 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
362 return quote!(compile_error!(#e));
363 }
364 let versions_map: HashMap<usize, SqliteVersion> =
365 match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
366 Ok(m) => m,
367 Err(e) => {
368 let e = e.to_string();
369 return quote!(compile_error!(#e));
370 },
371 };
372 let mut field_lookup = HashMap::new();
373 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
374 let version = match versions_map.get(&version_i) {
375 Some(v) => v,
376 None => {
377 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
378 return quote!(compile_error!(#e));
379 },
380 };
381 let custom_types = version.custom_types.clone();
382 for (table_id, table) in &version.tables {
383 let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
384 for (field_id, field) in &table.fields {
385 fields.insert(SqliteFieldRef {
386 table_id: table_id.clone(),
387 field_id: field_id.clone(),
388 }, SqliteFieldInfo {
389 sql_name: field.id.clone(),
390 type_: field.type_.type_.clone(),
391 });
392 }
393 field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
394 sql_name: table.id.clone(),
395 fields: fields,
396 });
397 }
398 let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
399 query.res_count = res_count;
400 let mut hasher = DefaultHasher::new();
401 input.sql.hash(&mut hasher);
402 let query_hash = hasher.finish();
403 let query_name = format_ident!("good_query_{}", query_hash);
404 query.name = query_name.to_string();
405 let pascal_db_name: String = db_name.to_case(Case::Pascal);
406 let db_type = if let Some(v) = input.version {
407 let name = format_ident!("Db{}{}", pascal_db_name, v);
408 quote!(dbm::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
409 } else {
410 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
411 quote!(dbm::#name < impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
412 };
413 let generated =
414 good_ormning_core::sqlite::query::generate::generate_query_functions(
415 &mut errs,
416 field_lookup,
417 vec![query],
418 "inline",
419 db_type,
420 );
421 let conn = &input.conn;
422 let args = &input.params;
423 quote!{
424 {
425 use ::good_ormning::runtime::GoodError;
426 use ::good_ormning::runtime::ToGoodError;
427 use ::good_ormning::runtime::sqlite::SqliteConnection;
428 #(#generated) * #query_name(#conn, #(#args,) *)
429 }
430 }
431}
432
433#[proc_macro]
435pub fn good_query_pg(input: TokenStream) -> TokenStream {
436 let input = parse_macro_input!(input as GoodQueryInput);
437 parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
438}
439
440#[proc_macro]
442pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
443 let input = parse_macro_input!(input as GoodQueryInput);
444 parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
445}
446
447#[proc_macro]
449pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
450 let input = parse_macro_input!(input as GoodQueryInput);
451 parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
452}
453
454#[proc_macro]
456pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
457 let input = parse_macro_input!(input as GoodQueryInput);
458 parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
459}
460
461#[proc_macro]
463pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
464 let input = parse_macro_input!(input as GoodQueryInput);
465 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
466}
467
468#[proc_macro]
470pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
471 let input = parse_macro_input!(input as GoodQueryInput);
472 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
473}
474
475#[proc_macro]
477pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
478 let input = parse_macro_input!(input as GoodQueryInput);
479 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
480}
481
482#[proc_macro]
484pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
485 let input = parse_macro_input!(input as GoodQueryInput);
486 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
487}