1use {
2 convert_case::{
3 Case,
4 Casing,
5 },
6 good_ormning_core::{
7 pg::{
8 Version as PgVersion,
9 query::utils::{
10 PgFieldInfo,
11 PgTableInfo,
12 },
13 schema::{
14 field::FieldRef as PgFieldRef,
15 table::TableRef as PgTableRef,
16 },
17 },
18 sqlite::{
19 Version as SqliteVersion,
20 query::utils::{
21 SqliteFieldInfo,
22 SqliteTableInfo,
23 },
24 schema::{
25 field::FieldRef as SqliteFieldRef,
26 table::TableRef as SqliteTableRef,
27 },
28 },
29 utils::Errs,
30 },
31 proc_macro::TokenStream,
32 quote::{
33 format_ident,
34 quote,
35 },
36 std::{
37 collections::{
38 HashMap,
39 hash_map::DefaultHasher,
40 },
41 env,
42 fs,
43 hash::{
44 Hash,
45 Hasher,
46 },
47 },
48 syn::{
49 Ident,
50 LitStr,
51 Token,
52 parse::{
53 Parse,
54 ParseStream,
55 },
56 parse_macro_input,
57 },
58};
59
60mod convert;
61
/// Declared type of one query parameter, as written in the macro input
/// (`name: [arr] [opt] base = expr`) or in an inline `${[arr] [opt] base = expr}`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParamType {
    /// The `arr` modifier was present (presumably: the parameter is an array — confirm in `convert`).
    arr: bool,
    /// The `opt` modifier was present (presumably: the parameter is nullable — confirm in `convert`).
    opt: bool,
    /// The base type identifier exactly as written, e.g. `i64`.
    base: String,
}
67
68struct GoodQueryInput {
69 version: Option<usize>,
70 db_name: String,
71 sql: String,
72 param_types: Vec<(Ident, ParamType)>,
73 conn: syn::Expr,
74 params: Vec<syn::Expr>,
75}
76
77impl Parse for GoodQueryInput {
78 fn parse(input: ParseStream) -> syn::Result<Self> {
79 let mut literals = Vec::new();
80 while !input.peek(Token![;]) && !input.is_empty() {
81 literals.push(input.parse::<LitStr>()?);
82 if input.peek(Token![,]) {
83 input.parse::<Token![,]>()?;
84 } else if !input.peek(Token![;]) {
85 return Err(input.error("Expected , or ; after literal"));
86 }
87 }
88 input.parse::<Token![;]>()?;
89 let (db_name, version, sql) = match literals.len() {
90 1 => ("".to_string(), None, literals[0].value()),
91 2 => (literals[0].value(), None, literals[1].value()),
92 3 => {
93 let v_str = literals[1].value();
94 let v = v_str.parse::<usize>().map_err(|_| {
95 syn::Error::new(literals[1].span(), "Version must be a positive integer")
96 })?;
97 (literals[0].value(), Some(v), literals[2].value())
98 },
99 _ => {
100 return Err(
101 input.error(
102 "Expected 1, 2, or 3 literals before semicolon (sql, [db_name, sql], or [db_name, version, sql])",
103 ),
104 );
105 },
106 };
107 let conn: syn::Expr = input.parse()?;
108 let mut param_types = Vec::new();
109 let mut params = Vec::new();
110 while input.peek(Token![,]) {
111 input.parse::<Token![,]>()?;
112 if input.is_empty() {
113 break;
114 }
115 let name: Ident = input.parse()?;
116 input.parse::<Token![:]>()?;
117 let mut arr = false;
118 let mut opt = false;
119 let mut base = String::new();
120 while input.peek(Ident) {
121 let id: Ident = input.parse()?;
122 if id == "arr" {
123 arr = true;
124 } else if id == "opt" {
125 opt = true;
126 } else {
127 base = id.to_string();
128 break;
129 }
130 }
131 if base.is_empty() {
132 return Err(input.error("Expected parameter type"));
133 }
134 input.parse::<Token![=]>()?;
135 let val: syn::Expr = input.parse()?;
136 param_types.push((name, ParamType {
137 arr: arr,
138 opt: opt,
139 base: base,
140 }));
141 params.push(val);
142 }
143 let mut final_sql = String::new();
144 let mut last_end = 0;
145 let bytes = sql.as_bytes();
146 let mut i = 0;
147 while i < bytes.len() {
148 if bytes[i] == b'$' {
149 if i + 1 < bytes.len() && bytes[i + 1] == b'{' {
150 final_sql.push_str(&sql[last_end .. i]);
151 i += 2;
152 let content_start = i;
153 while i < bytes.len() && bytes[i] != b'}' {
154 i += 1;
155 }
156 if i >= bytes.len() {
157 return Err(syn::Error::new(input.span(), "Unclosed inline parameter ${"));
158 }
159 let content = &sql[content_start .. i];
160 i += 1;
161 let mut split = None;
162 for (idx, b) in content.as_bytes().iter().enumerate() {
163 if *b == b'=' {
164 split = Some((&content[..idx], &content[idx + 1..]));
165 break;
166 }
167 }
168 let (type_str, val_str) = split.ok_or_else(|| {
169 syn::Error::new(input.span(), "Invalid inline parameter format. Expected ${type = value}")
170 })?;
171 let (param_idx, name, pt, val) = parse_inline_param(input, type_str, val_str, params.len())?;
172 params.push(val);
173 param_types.push((name, pt));
174 final_sql.push_str(&format!("${}", param_idx));
175 last_end = i;
176 continue;
177 }
178 }
179 i += 1;
180 }
181 final_sql.push_str(&sql[last_end..]);
182 Ok(GoodQueryInput {
183 version: version,
184 db_name: db_name,
185 sql: final_sql,
186 param_types: param_types,
187 conn: conn,
188 params: params,
189 })
190 }
191}
192
193fn parse_inline_param(
194 input: ParseStream,
195 type_str: &str,
196 val_str: &str,
197 current_params_len: usize,
198) -> syn::Result<(usize, Ident, ParamType, syn::Expr)> {
199 let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
200 syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
201 })?;
202 let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
203 syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
204 })?;
205
206 use syn::parse::Parser;
207
208 let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
209 let mut arr = false;
210 let mut opt = false;
211 let mut base = String::new();
212 while type_input.peek(Ident) {
213 let id: Ident = type_input.parse()?;
214 if id == "arr" {
215 arr = true;
216 } else if id == "opt" {
217 opt = true;
218 } else {
219 base = id.to_string();
220 break;
221 }
222 }
223 Ok((arr, opt, base))
224 }).parse2(type_tokens).map_err(|e| {
225 syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
226 })?;
227 if base_p.is_empty() {
228 return Err(input.error("Expected base type in inline parameter"));
229 }
230 let param_idx = current_params_len + 1;
231 let name = format_ident!("p{}", param_idx);
232 Ok((param_idx, name, ParamType {
233 arr: arr_p,
234 opt: opt_p,
235 base: base_p,
236 }, val))
237}
238
/// Resolves which schema database a query targets.
///
/// The engine tag (`"pg"` or `"sqlite"` at the call sites) does not affect the result yet.
/// The name from the macro input is returned unchanged, and is empty when the invocation
/// gave only the SQL literal.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    std::convert::identity(provided_db_name)
}
242
243fn parse_and_generate_pg(
244 input: GoodQueryInput,
245 res_count: good_ormning_core::QueryResCount,
246) -> proc_macro2::TokenStream {
247 let db_name = get_db_info("pg", input.db_name.clone());
248 let dialect = sqlparser::dialect::PostgreSqlDialect {};
249 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
250 Ok(ast) => ast,
251 Err(e) => {
252 let e = e.to_string();
253 return quote!(compile_error!(#e));
254 },
255 };
256 if ast.is_empty() {
257 return quote!(compile_error!("Empty SQL statement"));
258 }
259 let statement = &ast[0];
260 let mut errs = Errs::new();
261 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
262 let path =
263 std::path::Path::new(&out_dir)
264 .join("good_ormning")
265 .join(good_ormning_core::utils::json_file_name(&db_name));
266 if !path.exists() {
267 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
268 return quote!(compile_error!(#e));
269 }
270 let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
271 Ok(m) => m,
272 Err(e) => {
273 let e = e.to_string();
274 return quote!(compile_error!(#e));
275 },
276 };
277 let mut field_lookup = HashMap::new();
278 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
279 let version = match versions_map.get(&version_i) {
280 Some(v) => v,
281 None => {
282 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
283 return quote!(compile_error!(#e));
284 },
285 };
286 let custom_types = version.custom_types.clone();
287 for (table_id, table) in &version.tables {
288 let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
289 for (field_id, field) in &table.fields {
290 fields.insert(PgFieldRef {
291 table_id: table_id.clone(),
292 field_id: field_id.clone(),
293 }, PgFieldInfo {
294 sql_name: field.id.clone(),
295 type_: field.type_.type_.clone(),
296 });
297 }
298 field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
299 sql_name: table.id.clone(),
300 fields: fields,
301 });
302 }
303 let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
304 query.res_count = res_count;
305 let mut hasher = DefaultHasher::new();
306 input.sql.hash(&mut hasher);
307 let query_hash = hasher.finish();
308 let query_name = format_ident!("good_query_{}", query_hash);
309 query.name = query_name.to_string();
310 let pascal_db_name: String = db_name.to_case(Case::Pascal);
311 let db_type = if let Some(v) = input.version {
312 let name = format_ident!("Db{}{}", pascal_db_name, v);
313 quote!(dbm::#name <'_ >)
314 } else {
315 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
316 quote!(dbm::#name <'_ >)
317 };
318 let generated =
319 good_ormning_core::pg::query::generate::generate_query_functions(
320 &mut errs,
321 field_lookup,
322 vec![query],
323 "inline",
324 db_type,
325 );
326 let conn = &input.conn;
327 let args = &input.params;
328 quote!{
329 {
330 use ::good_ormning::runtime::GoodError;
331 use ::good_ormning::runtime::ToGoodError;
332 use ::good_ormning::runtime::pg::PgConnection;
333 #(#generated) * #query_name(& mut #conn, #(#args,) *)
334 }
335 }
336}
337
338fn parse_and_generate_sqlite(
339 input: GoodQueryInput,
340 res_count: good_ormning_core::QueryResCount,
341) -> proc_macro2::TokenStream {
342 let db_name = get_db_info("sqlite", input.db_name.clone());
343 let dialect = sqlparser::dialect::SQLiteDialect {};
344 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
345 Ok(ast) => ast,
346 Err(e) => {
347 let e = e.to_string();
348 return quote!(compile_error!(#e));
349 },
350 };
351 if ast.is_empty() {
352 return quote!(compile_error!("Empty SQL statement"));
353 }
354 let statement = &ast[0];
355 let mut errs = Errs::new();
356 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
357 let path =
358 std::path::Path::new(&out_dir)
359 .join("good_ormning")
360 .join(good_ormning_core::utils::json_file_name(&db_name));
361 if !path.exists() {
362 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
363 return quote!(compile_error!(#e));
364 }
365 let versions_map: HashMap<usize, SqliteVersion> =
366 match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
367 Ok(m) => m,
368 Err(e) => {
369 let e = e.to_string();
370 return quote!(compile_error!(#e));
371 },
372 };
373 let mut field_lookup = HashMap::new();
374 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
375 let version = match versions_map.get(&version_i) {
376 Some(v) => v,
377 None => {
378 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
379 return quote!(compile_error!(#e));
380 },
381 };
382 let custom_types = version.custom_types.clone();
383 for (table_id, table) in &version.tables {
384 let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
385 for (field_id, field) in &table.fields {
386 fields.insert(SqliteFieldRef {
387 table_id: table_id.clone(),
388 field_id: field_id.clone(),
389 }, SqliteFieldInfo {
390 sql_name: field.id.clone(),
391 type_: field.type_.type_.clone(),
392 });
393 }
394 field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
395 sql_name: table.id.clone(),
396 fields: fields,
397 });
398 }
399 let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
400 query.res_count = res_count;
401 let mut hasher = DefaultHasher::new();
402 input.sql.hash(&mut hasher);
403 let query_hash = hasher.finish();
404 let query_name = format_ident!("good_query_{}", query_hash);
405 query.name = query_name.to_string();
406 let pascal_db_name: String = db_name.to_case(Case::Pascal);
407 let db_type = if let Some(v) = input.version {
408 let name = format_ident!("Db{}{}", pascal_db_name, v);
409 quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
410 } else {
411 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
412 quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
413 };
414 let generated =
415 good_ormning_core::sqlite::query::generate::generate_query_functions(
416 &mut errs,
417 field_lookup,
418 vec![query],
419 "inline",
420 db_type,
421 );
422 let conn = &input.conn;
423 let args = &input.params;
424 quote!{
425 {
426 use ::good_ormning::runtime::GoodError;
427 use ::good_ormning::runtime::ToGoodError;
428 use ::good_ormning::runtime::sqlite::SqliteConnection;
429 #(#generated) * #query_name(& mut #conn, #(#args,) *)
430 }
431 }
432}
433
434#[proc_macro]
436pub fn good_query_pg(input: TokenStream) -> TokenStream {
437 let input = parse_macro_input!(input as GoodQueryInput);
438 parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
439}
440
441#[proc_macro]
443pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
444 let input = parse_macro_input!(input as GoodQueryInput);
445 parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
446}
447
448#[proc_macro]
450pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
451 let input = parse_macro_input!(input as GoodQueryInput);
452 parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
453}
454
455#[proc_macro]
457pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
458 let input = parse_macro_input!(input as GoodQueryInput);
459 parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
460}
461
462#[proc_macro]
464pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
465 let input = parse_macro_input!(input as GoodQueryInput);
466 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
467}
468
469#[proc_macro]
471pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
472 let input = parse_macro_input!(input as GoodQueryInput);
473 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
474}
475
476#[proc_macro]
478pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
479 let input = parse_macro_input!(input as GoodQueryInput);
480 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
481}
482
483#[proc_macro]
485pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
486 let input = parse_macro_input!(input as GoodQueryInput);
487 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
488}