1use {
2 convert_case::{
3 Case,
4 Casing,
5 },
6 good_ormning_core::{
7 pg::{
8 Version as PgVersion,
9 query::utils::{
10 PgFieldInfo,
11 PgTableInfo,
12 },
13 schema::{
14 field::FieldRef as PgFieldRef,
15 table::TableRef as PgTableRef,
16 },
17 },
18 sqlite::{
19 Version as SqliteVersion,
20 query::utils::{
21 SqliteFieldInfo,
22 SqliteTableInfo,
23 },
24 schema::{
25 field::FieldRef as SqliteFieldRef,
26 table::TableRef as SqliteTableRef,
27 },
28 },
29 utils::Errs,
30 },
31 proc_macro::TokenStream,
32 quote::{
33 format_ident,
34 quote,
35 },
36 std::{
37 collections::{
38 HashMap,
39 hash_map::DefaultHasher,
40 },
41 env,
42 fs,
43 hash::{
44 Hash,
45 Hasher,
46 },
47 },
48 syn::{
49 Ident,
50 LitStr,
51 Token,
52 parse::{
53 Parse,
54 ParseStream,
55 },
56 parse_macro_input,
57 },
58};
59
60mod convert;
61
/// Declared type of one query parameter, as written in the macro input
/// (`[arr] [opt] base`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParamType {
    /// Set when the `arr` modifier keyword was present.
    arr: bool,
    /// Set when the `opt` modifier keyword was present.
    opt: bool,
    /// The base type identifier (the first non-modifier identifier), as text.
    base: String,
}
67
68struct GoodQueryInput {
69 version: Option<usize>,
70 db_name: String,
71 sql: String,
72 param_types: Vec<(Ident, ParamType)>,
73 conn: syn::Expr,
74 params: Vec<syn::Expr>,
75}
76
77impl Parse for GoodQueryInput {
78 fn parse(input: ParseStream) -> syn::Result<Self> {
79 let mut literals = Vec::new();
80 while !input.peek(Token![;]) && !input.is_empty() {
81 literals.push(input.parse::<LitStr>()?);
82 if input.peek(Token![,]) {
83 input.parse::<Token![,]>()?;
84 } else if !input.peek(Token![;]) {
85 return Err(input.error("Expected , or ; after literal"));
86 }
87 }
88 input.parse::<Token![;]>()?;
89 let (db_name, version, sql) = match literals.len() {
90 1 => ("".to_string(), None, literals[0].value()),
91 2 => (literals[0].value(), None, literals[1].value()),
92 3 => {
93 let v_str = literals[1].value();
94 let v = v_str.parse::<usize>().map_err(|_| {
95 syn::Error::new(literals[1].span(), "Version must be a positive integer")
96 })?;
97 (literals[0].value(), Some(v), literals[2].value())
98 },
99 _ => {
100 return Err(
101 input.error(
102 "Expected 1, 2, or 3 literals before semicolon (sql, [db_name, sql], or [db_name, version, sql])",
103 ),
104 );
105 },
106 };
107 let conn: syn::Expr = input.parse()?;
108 let mut param_types = Vec::new();
109 let mut params = Vec::new();
110 while input.peek(Token![,]) {
111 input.parse::<Token![,]>()?;
112 if input.is_empty() {
113 break;
114 }
115 let name: Ident = input.parse()?;
116 input.parse::<Token![:]>()?;
117 let mut arr = false;
118 let mut opt = false;
119 let mut base = String::new();
120 while input.peek(Ident) {
121 let id: Ident = input.parse()?;
122 if id == "arr" {
123 arr = true;
124 } else if id == "opt" {
125 opt = true;
126 } else {
127 base = id.to_string();
128 break;
129 }
130 }
131 if base.is_empty() {
132 return Err(input.error("Expected parameter type"));
133 }
134 input.parse::<Token![=]>()?;
135 let val: syn::Expr = input.parse()?;
136 param_types.push((name, ParamType {
137 arr: arr,
138 opt: opt,
139 base: base,
140 }));
141 params.push(val);
142 }
143 let mut final_sql = String::new();
144 let mut last_end = 0;
145 while let Some(start) = sql[last_end..].find("${") {
146 final_sql.push_str(&sql[last_end .. last_end + start]);
147 let start = last_end + start;
148 if let Some(end) = sql[start..].find('}') {
149 let end = start + end;
150 let content = &sql[start + 2 .. end];
151 let mut split = None;
152 let bytes = content.as_bytes();
153 let mut i = 0;
154 while i < bytes.len() {
155 if bytes[i] == b'=' {
156 split = Some((&content[..i], &content[i + 1..]));
157 break;
158 }
159 i += 1;
160 }
161 let (type_str, val_str) =
162 split.ok_or_else(|| input.error("Invalid inline parameter format. Expected ${type = value}"))?;
163 let val: syn::Expr = syn::parse_str(val_str).map_err(|e| {
164 syn::Error::new(input.span(), format!("Failed to parse inline parameter value: {}", e))
165 })?;
166 let type_tokens: proc_macro2::TokenStream = type_str.parse().map_err(|e| {
167 syn::Error::new(input.span(), format!("Failed to parse inline parameter type tokens: {}", e))
168 })?;
169
170 use syn::parse::Parser;
171
172 let (arr_p, opt_p, base_p) = (|type_input: ParseStream| {
173 let mut arr = false;
174 let mut opt = false;
175 let mut base = String::new();
176 while type_input.peek(Ident) {
177 let id: Ident = type_input.parse()?;
178 if id == "arr" {
179 arr = true;
180 } else if id == "opt" {
181 opt = true;
182 } else {
183 base = id.to_string();
184 break;
185 }
186 }
187 Ok((arr, opt, base))
188 }).parse2(type_tokens).map_err(|e| {
189 syn::Error::new(input.span(), format!("Failed to parse inline parameter type: {}", e))
190 })?;
191 if base_p.is_empty() {
192 return Err(input.error("Expected base type in inline parameter"));
193 }
194 let param_idx = params.len() + 1;
195 let name = format_ident!("p{}", param_idx);
196 params.push(val);
197 param_types.push((name, ParamType {
198 arr: arr_p,
199 opt: opt_p,
200 base: base_p,
201 }));
202 final_sql.push_str(&format!("${}", param_idx));
203 last_end = end + 1;
204 } else {
205 return Err(input.error("Unclosed inline parameter ${"));
206 }
207 }
208 final_sql.push_str(&sql[last_end..]);
209 Ok(GoodQueryInput {
210 version: version,
211 db_name: db_name,
212 sql: final_sql,
213 param_types: param_types,
214 conn: conn,
215 params: params,
216 })
217 }
218}
219
/// Resolves the database name a query should use.
///
/// Currently a pass-through: the engine name is ignored and the caller-supplied
/// name (which may be empty) is returned unchanged.
fn get_db_info(_engine: &str, provided_db_name: String) -> String {
    std::convert::identity(provided_db_name)
}
223
224fn parse_and_generate_pg(
225 input: GoodQueryInput,
226 res_count: good_ormning_core::QueryResCount,
227) -> proc_macro2::TokenStream {
228 let db_name = get_db_info("pg", input.db_name.clone());
229 let dialect = sqlparser::dialect::PostgreSqlDialect {};
230 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
231 Ok(ast) => ast,
232 Err(e) => {
233 let e = e.to_string();
234 return quote!(compile_error!(#e));
235 },
236 };
237 if ast.is_empty() {
238 return quote!(compile_error!("Empty SQL statement"));
239 }
240 let statement = &ast[0];
241 let mut errs = Errs::new();
242 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
243 let path =
244 std::path::Path::new(&out_dir)
245 .join("good_ormning")
246 .join(good_ormning_core::utils::json_file_name(&db_name));
247 if !path.exists() {
248 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
249 return quote!(compile_error!(#e));
250 }
251 let versions_map: HashMap<usize, PgVersion> = match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
252 Ok(m) => m,
253 Err(e) => {
254 let e = e.to_string();
255 return quote!(compile_error!(#e));
256 },
257 };
258 let mut field_lookup = HashMap::new();
259 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
260 let version = match versions_map.get(&version_i) {
261 Some(v) => v,
262 None => {
263 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
264 return quote!(compile_error!(#e));
265 },
266 };
267 let custom_types = version.custom_types.clone();
268 for (table_id, table) in &version.tables {
269 let mut fields: HashMap<PgFieldRef, PgFieldInfo> = HashMap::new();
270 for (field_id, field) in &table.fields {
271 fields.insert(PgFieldRef {
272 table_id: table_id.clone(),
273 field_id: field_id.clone(),
274 }, PgFieldInfo {
275 sql_name: field.id.clone(),
276 type_: field.type_.type_.clone(),
277 });
278 }
279 field_lookup.insert(PgTableRef(table_id.clone()), PgTableInfo {
280 sql_name: table.id.clone(),
281 fields: fields,
282 });
283 }
284 let mut query = crate::convert::pg::convert_query(&input, statement, &custom_types, &field_lookup);
285 query.res_count = res_count;
286 let mut hasher = DefaultHasher::new();
287 input.sql.hash(&mut hasher);
288 let query_hash = hasher.finish();
289 let query_name = format_ident!("good_query_{}", query_hash);
290 query.name = query_name.to_string();
291 let pascal_db_name: String = db_name.to_case(Case::Pascal);
292 let db_type = if let Some(v) = input.version {
293 let name = format_ident!("Db{}{}", pascal_db_name, v);
294 quote!(dbm::#name <'_ >)
295 } else {
296 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
297 quote!(dbm::#name <'_ >)
298 };
299 let generated =
300 good_ormning_core::pg::query::generate::generate_query_functions(
301 &mut errs,
302 field_lookup,
303 vec![query],
304 "inline",
305 db_type,
306 );
307 let conn = &input.conn;
308 let args = &input.params;
309 quote!{
310 {
311 use ::good_ormning::runtime::GoodError;
312 use ::good_ormning::runtime::ToGoodError;
313 use ::good_ormning::runtime::pg::PgConnection;
314 #(#generated) * #query_name(& mut #conn, #(#args,) *)
315 }
316 }
317}
318
319fn parse_and_generate_sqlite(
320 input: GoodQueryInput,
321 res_count: good_ormning_core::QueryResCount,
322) -> proc_macro2::TokenStream {
323 let db_name = get_db_info("sqlite", input.db_name.clone());
324 let dialect = sqlparser::dialect::SQLiteDialect {};
325 let ast = match sqlparser::parser::Parser::parse_sql(&dialect, &input.sql) {
326 Ok(ast) => ast,
327 Err(e) => {
328 let e = e.to_string();
329 return quote!(compile_error!(#e));
330 },
331 };
332 if ast.is_empty() {
333 return quote!(compile_error!("Empty SQL statement"));
334 }
335 let statement = &ast[0];
336 let mut errs = Errs::new();
337 let out_dir = env::var("OUT_DIR").unwrap_or_else(|_| ".".to_string());
338 let path =
339 std::path::Path::new(&out_dir)
340 .join("good_ormning")
341 .join(good_ormning_core::utils::json_file_name(&db_name));
342 if !path.exists() {
343 let e = format!("Schema file not found at {:?}. Did you run the build script?", path.to_string_lossy());
344 return quote!(compile_error!(#e));
345 }
346 let versions_map: HashMap<usize, SqliteVersion> =
347 match serde_json::from_str(&fs::read_to_string(&path).unwrap()) {
348 Ok(m) => m,
349 Err(e) => {
350 let e = e.to_string();
351 return quote!(compile_error!(#e));
352 },
353 };
354 let mut field_lookup = HashMap::new();
355 let version_i = input.version.unwrap_or_else(|| versions_map.keys().max().copied().unwrap_or(0));
356 let version = match versions_map.get(&version_i) {
357 Some(v) => v,
358 None => {
359 let e = format!("Version {} not found in schema for db {}", version_i, db_name);
360 return quote!(compile_error!(#e));
361 },
362 };
363 let custom_types = version.custom_types.clone();
364 for (table_id, table) in &version.tables {
365 let mut fields: HashMap<SqliteFieldRef, SqliteFieldInfo> = HashMap::new();
366 for (field_id, field) in &table.fields {
367 fields.insert(SqliteFieldRef {
368 table_id: table_id.clone(),
369 field_id: field_id.clone(),
370 }, SqliteFieldInfo {
371 sql_name: field.id.clone(),
372 type_: field.type_.type_.clone(),
373 });
374 }
375 field_lookup.insert(SqliteTableRef(table_id.clone()), SqliteTableInfo {
376 sql_name: table.id.clone(),
377 fields: fields,
378 });
379 }
380 let mut query = crate::convert::sqlite::convert_query(&input, statement, &custom_types, &field_lookup);
381 query.res_count = res_count;
382 let mut hasher = DefaultHasher::new();
383 input.sql.hash(&mut hasher);
384 let query_hash = hasher.finish();
385 let query_name = format_ident!("good_query_{}", query_hash);
386 query.name = query_name.to_string();
387 let pascal_db_name: String = db_name.to_case(Case::Pascal);
388 let db_type = if let Some(v) = input.version {
389 let name = format_ident!("Db{}{}", pascal_db_name, v);
390 quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
391 } else {
392 let name = format_ident!("Db{}{}", pascal_db_name, version_i);
393 quote!(dbm::#name <'_, impl:: good_ormning:: runtime:: sqlite:: SqliteConnection >)
394 };
395 let generated =
396 good_ormning_core::sqlite::query::generate::generate_query_functions(
397 &mut errs,
398 field_lookup,
399 vec![query],
400 "inline",
401 db_type,
402 );
403 let conn = &input.conn;
404 let args = &input.params;
405 quote!{
406 {
407 use ::good_ormning::runtime::GoodError;
408 use ::good_ormning::runtime::ToGoodError;
409 use ::good_ormning::runtime::sqlite::SqliteConnection;
410 #(#generated) * #query_name(& mut #conn, #(#args,) *)
411 }
412 }
413}
414
415#[proc_macro]
417pub fn good_query_pg(input: TokenStream) -> TokenStream {
418 let input = parse_macro_input!(input as GoodQueryInput);
419 parse_and_generate_pg(input, good_ormning_core::QueryResCount::None).into()
420}
421
422#[proc_macro]
424pub fn good_query_one_pg(input: TokenStream) -> TokenStream {
425 let input = parse_macro_input!(input as GoodQueryInput);
426 parse_and_generate_pg(input, good_ormning_core::QueryResCount::One).into()
427}
428
429#[proc_macro]
431pub fn good_query_opt_pg(input: TokenStream) -> TokenStream {
432 let input = parse_macro_input!(input as GoodQueryInput);
433 parse_and_generate_pg(input, good_ormning_core::QueryResCount::MaybeOne).into()
434}
435
436#[proc_macro]
438pub fn good_query_many_pg(input: TokenStream) -> TokenStream {
439 let input = parse_macro_input!(input as GoodQueryInput);
440 parse_and_generate_pg(input, good_ormning_core::QueryResCount::Many).into()
441}
442
443#[proc_macro]
445pub fn good_query_sqlite(input: TokenStream) -> TokenStream {
446 let input = parse_macro_input!(input as GoodQueryInput);
447 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::None).into()
448}
449
450#[proc_macro]
452pub fn good_query_one_sqlite(input: TokenStream) -> TokenStream {
453 let input = parse_macro_input!(input as GoodQueryInput);
454 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::One).into()
455}
456
457#[proc_macro]
459pub fn good_query_opt_sqlite(input: TokenStream) -> TokenStream {
460 let input = parse_macro_input!(input as GoodQueryInput);
461 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::MaybeOne).into()
462}
463
464#[proc_macro]
466pub fn good_query_many_sqlite(input: TokenStream) -> TokenStream {
467 let input = parse_macro_input!(input as GoodQueryInput);
468 parse_and_generate_sqlite(input, good_ormning_core::QueryResCount::Many).into()
469}