1use {
2 convert_case::{
3 Case,
4 Casing,
5 },
6 good_ormning_core::{
7 pg::{
8 Version as PgVersion,
9 query::utils::{
10 PgFieldInfo,
11 PgTableInfo,
12 },
13 schema::{
14 field::FieldRef as PgFieldRef,
15 table::TableRef as PgTableRef,
16 },
17 },
18 sqlite::{
19 Version as SqliteVersion,
20 query::utils::{
21 SqliteFieldInfo,
22 SqliteTableInfo,
23 },
24 schema::{
25 field::FieldRef as SqliteFieldRef,
26 table::TableRef as SqliteTableRef,
27 },
28 },
29 utils::Errs,
30 },
31 proc_macro::TokenStream,
32 quote::{
33 format_ident,
34 quote,
35 },
36 std::{
37 collections::{
38 HashMap,
39 hash_map::DefaultHasher,
40 },
41 env,
42 fs,
43 hash::{
44 Hash,
45 Hasher,
46 },
47 },
48 syn::{
49 Ident,
50 LitInt,
51 LitStr,
52 Token,
53 parse::{
54 Parse,
55 ParseStream,
56 },
57 parse_macro_input,
58 },
59};
60
61mod convert;
62
/// Declared type of a query parameter.
///
/// Written as a sequence of identifiers, e.g. `arr opt i64`, which parses to
/// `ParamType { arr: true, opt: true, base: "i64".into() }`.
struct ParamType {
    /// Base type name: the first identifier that is neither `arr` nor `opt`.
    base: String,
    /// Set when the `arr` modifier was given (presumably an array parameter —
    /// TODO confirm against `convert`).
    arr: bool,
    /// Set when the `opt` modifier was given (presumably a nullable parameter —
    /// TODO confirm against `convert`).
    opt: bool,
}
68
69struct GoodQueryInput {
70 version: Option<usize>,
71 db_name: String,
72 sql: String,
73 param_types: Vec<(Ident, ParamType)>,
74 conn: syn::Expr,
75 params: Vec<syn::Expr>,
76}
77
78impl Parse for GoodQueryInput {
79 fn parse(input: ParseStream) -> syn::Result<Self> {
80 let (db_name, version, sql) = {
81 let first: LitStr = input.parse()?;
82
83 if input.peek(Token![;]) {
84 input.parse::<Token![;]>()?;
85 ("".to_string(), None, first.value())
86 } else {
87 input.parse::<Token![,]>()?;
88
89 let lookahead = input.lookahead1();
90 if lookahead.peek(LitInt) {
91 let version_lit: LitInt = input.parse()?;
92 let version = version_lit.base10_parse::<usize>()?;
93 input.parse::<Token![,]>()?;
94 let sql_lit: LitStr = input.parse()?;
95 let sql = sql_lit.value();
96 input.parse::<Token![;]>()?;
97 (first.value(), Some(version), sql)
98 } else if lookahead.peek(LitStr) {
99 let sql_lit: LitStr = input.parse()?;
100 let sql = sql_lit.value();
101 input.parse::<Token![;]>()?;
102 (first.value(), None, sql)
103 } else {
104 return Err(lookahead.error());
105 }
106 }
107 };
108 let conn: syn::Expr = input.parse()?;
109 let mut param_types = Vec::new();
110 let mut params = Vec::new();
111 while input.peek(Token![,]) {
112 input.parse::<Token![,]>()?;
113 if input.is_empty() {
114 break;
115 }
116 let name: Ident = input.parse()?;
117 input.parse::<Token![:]>()?;
118 let mut arr = false;
119 let mut opt = false;
120 let mut base = String::new();
121 while input.peek(Ident) {
122 let id: Ident = input.parse()?;
123 if id == "arr" {
124 arr = true;
125 } else if id == "opt" {
126 opt = true;
127 } else {
128 base = id.to_string();
129 break;
130 }
131 }
132 if base.is_empty() {
133 return Err(input.error("Expected parameter type"));
134 }
135 input.parse::<Token![=]>()?;
136 let val: syn::Expr = input.parse()?;
137 param_types.push((name, ParamType {
138 arr: arr,
139 opt: opt,
140 base: base,
141 }));
142 params.push(val);
143 }
144 let mut final_sql = String::new();
145 let mut last_end = 0;
146 let bytes = sql.as_bytes();
147 let mut i = 0;
148 while i < bytes.len() {
149 if bytes[i] == b'$' {
150 if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
151 final_sql.push_str(&sql[last_end .. i]);
152 i += 2;
153 let content_start = i;
154 while i < bytes.len() && bytes[i] != b'}' {
155 i += 1;
156 }
157 if i >= bytes.len() {
158 return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
159 }
160 let content = &sql[content_start .. i];
161 i += 1;
162 let mut split = None;
163 for (idx, b) in content.as_bytes().iter().enumerate() {
164 if *b == b'=' {
165 split = Some((&content[..idx], &content[idx + 1..]));
166 break;
167 }
168 }
169 let (type_str, val_str) = split.ok_or_else(|| {
170 syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
171 })?;
172 let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
173 params.push(val);
174 param_types.push((name, pt));
175 final_sql.push_str(&format!("${}", param_idx));
176 last_end = i;
177 continue;
178 }
179 }
180 i += 1;
181 }
182 final_sql.push_str(&sql[last_end..]);
183 Ok(GoodQueryInput {
184 version: version,
185 db_name: db_name,
186 sql: final_sql,
187 param_types: param_types,
188 conn: conn,
189 params: params,
190 })
191 }
192}
193
194fn parse_inline_param(
195 input: ParseStream,
196 type_str: &str,
197 val_str: &str,
198 current_params_len: usize,
199) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
200 let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
201 syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
202 })?;
203 let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
204 syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
205 })?;
206
207 use syn::parse::Parser;
208
209 let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
210 let mut arr = false;
211 let mut opt = false;
212 let mut base = String::new();
213 while type_input.peek(Ident) {
214 let id: Ident = type_input.parse()?;
215 if id == "arr" {
216 arr = true;
217 } else if id == "opt" {
218 opt = true;
219 } else {
220 base = id.to_string();
221 break;
222 }
223 }
224 Ok((arr, opt, base))
225 }).parse2(type_tokens).map_err(|e| {
226 syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
227 })?;
228 if base_p.is_empty() {
229 return Err(input.error("Expected base type in inline parameter"));
230 }
231 let param_idx = current_params_len + 1;
232 let name = format_ident!("p{}", param_idx);
233 Ok((param_idx, name, ParamType {
234 arr: arr_p,
235 opt: opt_p,
236 base: base_p,
237 }, val))
238}
239
/// Resolves the database name used to locate the schema file.
///
/// There is currently no engine-specific resolution: `_engine` is ignored and the
/// caller-provided name is returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    let resolved = provided_db_name;
    resolved
}
243
244fn parse_and_generate_pg(
245 input: GoodQueryInput,
246 res_count: good_ormning_core::QueryResCount,
247) -> proc_macro2::TokenStream {
248 let db_name = get_db_info("pg", input.db_name.clone());
249 let dialect = sqlparser::dialect::PostgreSqlDialect {};
250 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
251 Ok(ast) => ast,
252 Err(e) => {
253 let e = e.to_string();
254 return quote!(compile_error!(#e));
255 },
256 };
257 if ast.is_empty() {
258 return quote!(compile_error!("Empty SQL statement"));
259 }
260 let statement = &ast[0];
261 let mut errs = Errs::new();
262 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
263 let path =
264 std::path::Path::new(&out_dir)
265 .join("good_ormning")
266 .join(good_ormning_core::utils::json_file_name(&db_name));
267 if !path.exists() {
268 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
269 return quote!(compile_error!(#e));
270 }
271 let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
272 Ok(m) => m,
273 Err(e) => {
274 let e = e.to_string();
275 return quote!(compile_error!(#e));
276 },
277 };
278 let mut field_lookup = HashMap::new();
279 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
280 let version = match versions_map.get(&version_i) {
281 Some(v) => v,
282 None => {
283 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
284 return quote!(compile_error!(#e));
285 },
286 };
287 let custom_types = version.custom_types.clone();
288 for (table_id, table) in &version.tables {
289 let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
290 for (field_id, field) in &table.fields {
291 fields.insert(PgFieldRef {
292 table_id: table_id.clone(),
293 field_id: field_id.clone(),
294 }, PgFieldInfo {
295 sql_name: field.id.clone(),
296 type_: field.type_.type_.clone(),
297 });
298 }
299 field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
300 sql_name: table.id.clone(),
301 fields: fields,
302 });
303 }
304 let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
305 query.res_count = res_count;
306 let mut hasher = DefaultHasher::new();
307 input.sql.hash(&mut hasher);
308 let query_hash = hasher.finish();
309 let query_name = format_ident!("good_query_{}", query_hash);
310 query.name = query_name.to_string();
311 let pascal_db_name: String = db_name.to_case(Case::Pascal);
312 let db_type = if let Some(v) = input.version {
313 let name = format_ident!("Db{}{}", pascal_db_name, v);
314 quote!(dbm::#name <impl ::good_ormning::runtime::pg::PgConnection>)
315 } else {
316 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
317 quote!(dbm::#name <impl ::good_ormning::runtime::pg::PgConnection>)
318 };
319 let generated =
320 good_ormning_core::pg::query::generate::generate_query_functions(
321 &mut errs,
322 field_lookup,
323 vec![query],
324 "inline",
325 db_type,
326 );
327 let conn = &input.conn;
328 let args = &input.params;
329 quote!{
330 {
331 use ::good_ormning::runtime::GoodError;
332 use ::good_ormning::runtime::ToGoodError;
333 use ::good_ormning::runtime::pg::PgConnection;
334 #(#generated) * #query_name(#conn, #(#args,) *)
335 }
336 }
337}
338
339fn parse_and_generate_sqlite(
340 input: GoodQueryInput,
341 res_count: good_ormning_core::QueryResCount,
342) -> proc_macro2::TokenStream {
343 let db_name = get_db_info("sqlite", input.db_name.clone());
344 let dialect = sqlparser::dialect::SQLiteDialect {};
345 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
346 Ok(ast) => ast,
347 Err(e) => {
348 let e = e.to_string();
349 return quote!(compile_error!(#e));
350 },
351 };
352 if ast.is_empty() {
353 return quote!(compile_error!("Empty SQL statement"));
354 }
355 let statement = &ast[0];
356 let mut errs = Errs::new();
357 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
358 let path =
359 std::path::Path::new(&out_dir)
360 .join("good_ormning")
361 .join(good_ormning_core::utils::json_file_name(&db_name));
362 if !path.exists() {
363 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
364 return quote!(compile_error!(#e));
365 }
366 let versions_map: HashMap<usize, SqliteVersion> =
367 match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
368 Ok(m) => m,
369 Err(e) => {
370 let e = e.to_string();
371 return quote!(compile_error!(#e));
372 },
373 };
374 let mut field_lookup = HashMap::new();
375 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
376 let version = match versions_map.get(&version_i) {
377 Some(v) => v,
378 None => {
379 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
380 return quote!(compile_error!(#e));
381 },
382 };
383 let custom_types = version.custom_types.clone();
384 for (table_id, table) in &version.tables {
385 let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
386 for (field_id, field) in &table.fields {
387 fields.insert(SqliteFieldRef {
388 table_id: table_id.clone(),
389 field_id: field_id.clone(),
390 }, SqliteFieldInfo {
391 sql_name: field.id.clone(),
392 type_: field.type_.type_.clone(),
393 });
394 }
395 field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
396 sql_name: table.id.clone(),
397 fields: fields,
398 });
399 }
400 let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
401 query.res_count = res_count;
402 let mut hasher = DefaultHasher::new();
403 input.sql.hash(&mut hasher);
404 let query_hash = hasher.finish();
405 let query_name = format_ident!("good_query_{}", query_hash);
406 query.name = query_name.to_string();
407 let pascal_db_name: String = db_name.to_case(Case::Pascal);
408 let db_type = if let Some(v) = input.version {
409 let name = format_ident!("Db{}{}", pascal_db_name, v);
410 quote!(dbm::#name <impl ::good_ormning::runtime::sqlite::SqliteConnection>)
411 } else {
412 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
413 quote!(dbm::#name <impl ::good_ormning::runtime::sqlite::SqliteConnection>)
414 };
415 let generated =
416 good_ormning_core::sqlite::query::generate::generate_query_functions(
417 &mut errs,
418 field_lookup,
419 vec![query],
420 "inline",
421 db_type,
422 );
423 let conn = &input.conn;
424 let args = &input.params;
425 quote!{
426 {
427 use ::good_ormning::runtime::GoodError;
428 use ::good_ormning::runtime::ToGoodError;
429 use ::good_ormning::runtime::sqlite::SqliteConnection;
430 #(#generated) * #query_name(#conn, #(#args,) *)
431 }
432 }
433}
434
435#[proc_macro]
437pub fn good_query_pg(input: TokenStream) -> TokenStream {
438 let input = parse_macro_input!(input as GoodQueryInput);
439 parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
440}
441
442#[proc_macro]
444pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
445 let input = parse_macro_input!(input as GoodQueryInput);
446 parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
447}
448
449#[proc_macro]
451pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
452 let input = parse_macro_input!(input as GoodQueryInput);
453 parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
454}
455
456#[proc_macro]
458pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
459 let input = parse_macro_input!(input as GoodQueryInput);
460 parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
461}
462
463#[proc_macro]
465pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
466 let input = parse_macro_input!(input as GoodQueryInput);
467 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
468}
469
470#[proc_macro]
472pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
473 let input = parse_macro_input!(input as GoodQueryInput);
474 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
475}
476
477#[proc_macro]
479pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
480 let input = parse_macro_input!(input as GoodQueryInput);
481 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
482}
483
484#[proc_macro]
486pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
487 let input = parse_macro_input!(input as GoodQueryInput);
488 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
489}