use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
    Expr, LitBool, LitFloat, LitInt, LitStr, Result, Token, braced, bracketed,
    ext::IdentExt,
    parse::{Parse, ParseStream},
    parse_macro_input,
    punctuated::Punctuated,
    token,
};

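/// Target SQL dialect for the generated builder calls. Defaults to Postgres
/// when no dialect prefix is given.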
#[derive(Clone, Copy, Default)]
enum SqlDialect {
    #[default]
    Postgres,
    Sqlite,
}

impl SqlDialect {
    fn from_ident(ident: &syn::Ident) -> Option<Self> {
        match ident.to_string().as_str() {
            "postgres" | "pg" => Some(SqlDialect::Postgres),
            "sqlite" => Some(SqlDialect::Sqlite),
            _ => None,
        }
    }

    fn builder_tokens(self, table: &str) -> TokenStream2 {
        match self {
            SqlDialect::Postgres => quote! { ::mik_sql::postgres(#table) },
            SqlDialect::Sqlite => quote! { ::mik_sql::sqlite(#table) },
        }
    }

    fn insert_tokens(self, table: &str) -> TokenStream2 {
        match self {
            SqlDialect::Postgres => quote! { ::mik_sql::insert(#table) },
            SqlDialect::Sqlite => quote! { ::mik_sql::insert_sqlite(#table) },
        }
    }

    fn update_tokens(self, table: &str) -> TokenStream2 {
        match self {
            SqlDialect::Postgres => quote! { ::mik_sql::update(#table) },
            SqlDialect::Sqlite => quote! { ::mik_sql::update_sqlite(#table) },
        }
    }

    fn delete_tokens(self, table: &str) -> TokenStream2 {
        match self {
            SqlDialect::Postgres => quote! { ::mik_sql::delete(#table) },
            SqlDialect::Sqlite => quote! { ::mik_sql::delete_sqlite(#table) },
        }
    }
}

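/// Parses an optional leading `dialect,` prefix (e.g. `sqlite,` or `pg,`).
/// Falls back to the default dialect when the first token is not a known
/// dialect name followed by a comma.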
fn parse_optional_dialect(input: ParseStream) -> Result<SqlDialect> {
    let fork = input.fork();
    if let Ok(ident) = fork.parse::<syn::Ident>()
        && let Some(dialect) = SqlDialect::from_ident(&ident)
        && fork.peek(Token![,])
    {
        input.parse::<syn::Ident>()?;
        input.parse::<Token![,]>()?;
        return Ok(dialect);
    }
    Ok(SqlDialect::default())
}

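/// Everything accepted inside the `sql_read!` body, collected during parsing
/// and later lowered into a builder chain.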
struct SqlInput {
    dialect: SqlDialect,
    table: syn::Ident,
    select_fields: Vec<syn::Ident>,
    computed: Vec<SqlCompute>,
    aggregates: Vec<SqlAggregate>,
    filter_expr: Option<SqlFilterExpr>,
    group_by: Vec<syn::Ident>,
    having: Option<SqlFilterExpr>,
    sorts: Vec<SqlSort>,
    dynamic_sort: Option<Expr>,
    allow_sort: Vec<syn::Ident>,
    merge_filters: Option<Expr>,
    allow_fields: Vec<syn::Ident>,
    deny_ops: Vec<syn::Ident>,
    max_depth: Option<Expr>,
    page: Option<Expr>,
    limit: Option<Expr>,
    offset: Option<Expr>,
    after: Option<Expr>,
    before: Option<Expr>,
}

struct SqlAggregate {
    func: SqlAggregateFunc,
    field: Option<syn::Ident>,
    alias: Option<syn::Ident>,
}

#[derive(Clone, Copy)]
enum SqlAggregateFunc {
    Count,
    CountDistinct,
    Sum,
    Avg,
    Min,
    Max,
}

enum SqlFilterExpr {
    Simple(SqlFilter),
    Compound {
        op: SqlLogicalOp,
        filters: Vec<SqlFilterExpr>,
    },
}

#[derive(Clone, Copy)]
enum SqlLogicalOp {
    And,
    Or,
    Not,
}

struct SqlFilter {
    field: syn::Ident,
    op: SqlOperator,
    value: SqlValue,
}

#[derive(Clone)]
enum SqlOperator {
    Eq,
    Ne,
    Gt,
    Gte,
    Lt,
    Lte,
    In,
    NotIn,
    Like,
    ILike,
    Regex,
    StartsWith,
    EndsWith,
    Contains,
    Between,
}

enum SqlValue {
    Null,
    Bool(bool),
    Int(LitInt),
    Float(LitFloat),
    String(LitStr),
    Array(Vec<SqlValue>),
    IntHint(Expr),
    StrHint(Expr),
    FloatHint(Expr),
    BoolHint(Expr),
    Expr(Expr),
}

struct SqlSort {
    field: syn::Ident,
    desc: bool,
}

struct SqlCompute {
    alias: syn::Ident,
    expr: SqlComputeExpr,
}

enum SqlComputeExpr {
    Column(syn::Ident),
    LitStr(LitStr),
    LitInt(LitInt),
    LitFloat(LitFloat),
    BinOp {
        left: Box<SqlComputeExpr>,
        op: SqlComputeBinOp,
        right: Box<SqlComputeExpr>,
    },
    Func {
        name: SqlComputeFunc,
        args: Vec<SqlComputeExpr>,
    },
    Paren(Box<SqlComputeExpr>),
}

#[derive(Clone, Copy)]
enum SqlComputeBinOp {
    Add,
    Sub,
    Mul,
    Div,
}

#[derive(Clone, Copy)]
enum SqlComputeFunc {
    Concat,
    Coalesce,
    Upper,
    Lower,
    Round,
    Abs,
    Length,
}

impl Parse for SqlInput {
    #[allow(clippy::too_many_lines)]
    fn parse(input: ParseStream) -> Result<Self> {
        let dialect = parse_optional_dialect(input)?;
        let table: syn::Ident = input.parse().map_err(|e| {
            syn::Error::new(
                e.span(),
                format!(
                    "Expected table name.\n\
                     Usage: sql_read!(table_name {{ ... }}) or sql_read!(sqlite, table_name {{ ... }})\n\
                     Original error: {e}"
                ),
            )
        })?;

        let content;
        braced!(content in input);

        let mut select_fields = Vec::new();
        let mut computed = Vec::new();
        let mut aggregates = Vec::new();
        let mut filter_expr = None;
        let mut group_by = Vec::new();
        let mut having = None;
        let mut sorts = Vec::new();
        let mut dynamic_sort = None;
        let mut allow_sort = Vec::new();
        let mut merge_filters = None;
        let mut allow_fields = Vec::new();
        let mut deny_ops = Vec::new();
        let mut max_depth = None;
        let mut page = None;
        let mut limit = None;
        let mut offset = None;
        let mut after = None;
        let mut before = None;

        while !content.is_empty() {
            let key: syn::Ident = content.parse()?;
            content.parse::<Token![:]>()?;

            match key.to_string().as_str() {
                "select" => {
                    let fields_content;
                    bracketed!(fields_content in content);
                    let fields: Punctuated<syn::Ident, Token![,]> =
                        fields_content.parse_terminated(syn::Ident::parse, Token![,])?;
                    select_fields = fields.into_iter().collect();
                },
                "compute" => {
                    let compute_content;
                    braced!(compute_content in content);
                    computed = parse_compute_fields(&compute_content)?;
                },
                "aggregate" | "agg" => {
                    let agg_content;
                    braced!(agg_content in content);
                    aggregates = parse_aggregates(&agg_content)?;
                },
                "filter" => {
                    let filter_content;
                    braced!(filter_content in content);
                    filter_expr = Some(parse_filter_block(&filter_content)?);
                },
                "group_by" | "groupBy" => {
                    let group_content;
                    bracketed!(group_content in content);
                    let fields: Punctuated<syn::Ident, Token![,]> =
                        group_content.parse_terminated(syn::Ident::parse, Token![,])?;
                    group_by = fields.into_iter().collect();
                },
                "having" => {
                    let having_content;
                    braced!(having_content in content);
                    having = Some(parse_filter_block(&having_content)?);
                },
                "order" => {
                    if content.peek(token::Bracket) {
                        let order_content;
                        bracketed!(order_content in content);
                        let sort_items: Punctuated<SqlSort, Token![,]> =
                            order_content.parse_terminated(SqlSort::parse, Token![,])?;
                        sorts = sort_items.into_iter().collect();
                    } else if content.peek(Token![-]) {
                        let sort: SqlSort = content.parse()?;
                        sorts.push(sort);
                    } else if content.peek(syn::Ident)
                        && !content.peek2(Token![,])
                        && !content.peek2(token::Brace)
                    {
                        let fork = content.fork();
                        let ident: syn::Ident = fork.parse()?;
                        if fork.peek(Token![,]) && fork.peek2(syn::Ident) {
                            fork.parse::<Token![,]>().ok();
                            if let Ok(_next_ident) = fork.parse::<syn::Ident>() {
                                if fork.peek(Token![:]) {
                                    dynamic_sort = Some(syn::Expr::Path(syn::ExprPath {
                                        attrs: vec![],
                                        qself: None,
                                        path: ident.clone().into(),
                                    }));
                                    content.parse::<syn::Ident>()?;
                                } else {
                                    let sort: SqlSort = content.parse()?;
                                    sorts.push(sort);
                                }
                            } else {
                                let sort: SqlSort = content.parse()?;
                                sorts.push(sort);
                            }
                        } else if fork.is_empty()
                            || (fork.peek(Token![,]) && !fork.peek2(syn::Ident))
                        {
                            dynamic_sort = Some(content.parse()?);
                        } else {
                            let sort: SqlSort = content.parse()?;
                            sorts.push(sort);
                        }
                    } else {
                        dynamic_sort = Some(content.parse()?);
                    }
                },
                "allow_sort" | "allowSort" => {
                    let allow_content;
                    bracketed!(allow_content in content);
                    let fields: Punctuated<syn::Ident, Token![,]> =
                        allow_content.parse_terminated(syn::Ident::parse, Token![,])?;
                    allow_sort = fields.into_iter().collect();
                },
                "merge" => {
                    merge_filters = Some(content.parse()?);
                },
                "allow" => {
                    let allow_content;
                    bracketed!(allow_content in content);
                    let fields: Punctuated<syn::Ident, Token![,]> =
                        allow_content.parse_terminated(syn::Ident::parse, Token![,])?;
                    allow_fields = fields.into_iter().collect();
                },
                "deny_ops" | "denyOps" => {
                    let deny_content;
                    bracketed!(deny_content in content);
                    let mut ops = Vec::new();
                    while !deny_content.is_empty() {
                        deny_content.parse::<Token![$]>()?;
                        let op: syn::Ident = deny_content.call(syn::Ident::parse_any)?;
                        ops.push(op);
                        if deny_content.peek(Token![,]) {
                            deny_content.parse::<Token![,]>()?;
                        }
                    }
                    deny_ops = ops;
                },
                "max_depth" | "maxDepth" => {
                    max_depth = Some(content.parse()?);
                },
                "page" => {
                    page = Some(content.parse()?);
                },
                "limit" => {
                    limit = Some(content.parse()?);
                },
                "offset" => {
                    offset = Some(content.parse()?);
                },
                "after" => {
                    after = Some(content.parse()?);
                },
                "before" => {
                    before = Some(content.parse()?);
                },
                other => {
                    return Err(syn::Error::new(
                        key.span(),
                        format!(
                            "Unknown option '{other}'. Valid options: select, compute, aggregate, filter, merge, allow, deny_ops, max_depth, group_by, having, order, allow_sort, page, limit, offset, after, before"
                        ),
                    ));
                },
            }

            if content.peek(Token![,]) {
                content.parse::<Token![,]>()?;
            }
        }

        Ok(SqlInput {
            dialect,
            table,
            select_fields,
            computed,
            aggregates,
            filter_expr,
            group_by,
            having,
            sorts,
            dynamic_sort,
            allow_sort,
            merge_filters,
            allow_fields,
            deny_ops,
            max_depth,
            page,
            limit,
            offset,
            after,
            before,
        })
    }
}

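/// Parses the body of an `aggregate { ... }` block: `func: field` pairs,
/// with `count: *` as the only field-less form.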
fn parse_aggregates(input: ParseStream) -> Result<Vec<SqlAggregate>> {
    let mut aggregates = Vec::new();

    while !input.is_empty() {
        let func_name: syn::Ident = input.parse()?;
        input.parse::<Token![:]>()?;

        let func_str = func_name.to_string();
        let (func, field, alias) = match func_str.as_str() {
            "count" => {
                if input.peek(Token![*]) {
                    input.parse::<Token![*]>()?;
                    (
                        SqlAggregateFunc::Count,
                        None,
                        Some(syn::Ident::new("count", func_name.span())),
                    )
                } else {
                    let field: syn::Ident = input.parse()?;
                    (SqlAggregateFunc::Count, Some(field), None)
                }
            },
            "count_distinct" | "countDistinct" => {
                let field: syn::Ident = input.parse()?;
                (SqlAggregateFunc::CountDistinct, Some(field), None)
            },
            "sum" => {
                let field: syn::Ident = input.parse()?;
                (SqlAggregateFunc::Sum, Some(field), None)
            },
            "avg" => {
                let field: syn::Ident = input.parse()?;
                (SqlAggregateFunc::Avg, Some(field), None)
            },
            "min" => {
                let field: syn::Ident = input.parse()?;
                (SqlAggregateFunc::Min, Some(field), None)
            },
            "max" => {
                let field: syn::Ident = input.parse()?;
                (SqlAggregateFunc::Max, Some(field), None)
            },
            other => {
                return Err(syn::Error::new(
                    func_name.span(),
                    format!(
                        "Unknown aggregate function '{other}'. Valid: count, count_distinct, sum, avg, min, max"
                    ),
                ));
            },
        };

        aggregates.push(SqlAggregate { func, field, alias });

        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
        }
    }

    Ok(aggregates)
}

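/// Parses a `compute { alias: expression, ... }` block into named computed
/// columns.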
fn parse_compute_fields(input: ParseStream) -> Result<Vec<SqlCompute>> {
    let mut computed = Vec::new();

    while !input.is_empty() {
        let alias: syn::Ident = input.parse()?;
        input.parse::<Token![:]>()?;
        let expr = parse_compute_expr(input)?;
        computed.push(SqlCompute { alias, expr });

        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
        }
    }

    Ok(computed)
}

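// Recursive-descent parser for compute expressions with the usual precedence:
// `+`/`-` bind looser than `*`/`/`, and primaries are columns, literals,
// function calls, or parenthesized sub-expressions.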
fn parse_compute_expr(input: ParseStream) -> Result<SqlComputeExpr> {
    parse_compute_additive(input)
}

fn parse_compute_additive(input: ParseStream) -> Result<SqlComputeExpr> {
    let mut left = parse_compute_multiplicative(input)?;

    while input.peek(Token![+]) || input.peek(Token![-]) {
        let op = if input.peek(Token![+]) {
            input.parse::<Token![+]>()?;
            SqlComputeBinOp::Add
        } else {
            input.parse::<Token![-]>()?;
            SqlComputeBinOp::Sub
        };

        let right = parse_compute_multiplicative(input)?;
        left = SqlComputeExpr::BinOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        };
    }

    Ok(left)
}

fn parse_compute_multiplicative(input: ParseStream) -> Result<SqlComputeExpr> {
    let mut left = parse_compute_primary(input)?;

    while input.peek(Token![*]) || input.peek(Token![/]) {
        let op = if input.peek(Token![*]) {
            input.parse::<Token![*]>()?;
            SqlComputeBinOp::Mul
        } else {
            input.parse::<Token![/]>()?;
            SqlComputeBinOp::Div
        };

        let right = parse_compute_primary(input)?;
        left = SqlComputeExpr::BinOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        };
    }

    Ok(left)
}

fn parse_compute_primary(input: ParseStream) -> Result<SqlComputeExpr> {
    if input.peek(token::Paren) {
        let content;
        syn::parenthesized!(content in input);
        let inner = parse_compute_expr(&content)?;
        return Ok(SqlComputeExpr::Paren(Box::new(inner)));
    }

    if input.peek(LitStr) {
        return Ok(SqlComputeExpr::LitStr(input.parse()?));
    }

    if input.peek(LitFloat) {
        return Ok(SqlComputeExpr::LitFloat(input.parse()?));
    }

    if input.peek(LitInt) {
        return Ok(SqlComputeExpr::LitInt(input.parse()?));
    }

    if input.peek(syn::Ident) {
        let ident: syn::Ident = input.parse()?;

        if input.peek(token::Paren) {
            let func_name = ident.to_string();
            let func = match func_name.as_str() {
                "concat" => SqlComputeFunc::Concat,
                "coalesce" => SqlComputeFunc::Coalesce,
                "upper" => SqlComputeFunc::Upper,
                "lower" => SqlComputeFunc::Lower,
                "round" => SqlComputeFunc::Round,
                "abs" => SqlComputeFunc::Abs,
                "length" | "len" => SqlComputeFunc::Length,
                other => {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!(
                            "Unknown compute function '{other}'. Valid: concat, coalesce, upper, lower, round, abs, length"
                        ),
                    ));
                },
            };

            let args_content;
            syn::parenthesized!(args_content in input);
            let args: Punctuated<SqlComputeExpr, Token![,]> =
                args_content.parse_terminated(parse_compute_expr, Token![,])?;

            return Ok(SqlComputeExpr::Func {
                name: func,
                args: args.into_iter().collect(),
            });
        }

        return Ok(SqlComputeExpr::Column(ident));
    }

    Err(syn::Error::new(
        input.span(),
        "Expected a compute expression: column, literal, function call, or (expression)",
    ))
}

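/// Parses a filter block: plain `field: value` / `field: { $op: value }`
/// entries are ANDed together, while `$and`, `$or`, and `$not` introduce
/// nested compound expressions.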
fn parse_filter_block(input: ParseStream) -> Result<SqlFilterExpr> {
    let mut simple_filters = Vec::new();

    while !input.is_empty() {
        if input.peek(Token![$]) {
            input.parse::<Token![$]>()?;
            let op_name: syn::Ident = input.call(syn::Ident::parse_any)?;
            input.parse::<Token![:]>()?;

            let logical_op = match op_name.to_string().as_str() {
                "and" => SqlLogicalOp::And,
                "or" => SqlLogicalOp::Or,
                "not" => SqlLogicalOp::Not,
                other => {
                    return Err(syn::Error::new(
                        op_name.span(),
                        format!("Unknown logical operator '${other}'. Valid: $and, $or, $not"),
                    ));
                },
            };

            let filters = parse_filter_array(input)?;

            if !simple_filters.is_empty() {
                let mut all_filters: Vec<SqlFilterExpr> = simple_filters
                    .into_iter()
                    .map(SqlFilterExpr::Simple)
                    .collect();
                all_filters.push(SqlFilterExpr::Compound {
                    op: logical_op,
                    filters,
                });
                return Ok(SqlFilterExpr::Compound {
                    op: SqlLogicalOp::And,
                    filters: all_filters,
                });
            }

            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }

            if !input.is_empty() {
                let remaining = parse_filter_block(input)?;
                return Ok(SqlFilterExpr::Compound {
                    op: SqlLogicalOp::And,
                    filters: vec![
                        SqlFilterExpr::Compound {
                            op: logical_op,
                            filters,
                        },
                        remaining,
                    ],
                });
            }

            return Ok(SqlFilterExpr::Compound {
                op: logical_op,
                filters,
            });
        }

        let filter = parse_sql_filter(input)?;
        simple_filters.push(filter);

        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
        }
    }

    match simple_filters.len() {
        0 => Err(syn::Error::new(input.span(), "Empty filter block")),
        1 => Ok(SqlFilterExpr::Simple(simple_filters.remove(0))),
        _ => Ok(SqlFilterExpr::Compound {
            op: SqlLogicalOp::And,
            filters: simple_filters
                .into_iter()
                .map(SqlFilterExpr::Simple)
                .collect(),
        }),
    }
}

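/// Parses `[ { ... }, { ... } ]`: a bracketed list of braced filter blocks,
/// as used by `$and`, `$or`, and `$not`.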
fn parse_filter_array(input: ParseStream) -> Result<Vec<SqlFilterExpr>> {
    let content;
    bracketed!(content in input);

    let mut filters = Vec::new();
    while !content.is_empty() {
        let filter_content;
        braced!(filter_content in content);
        let filter_expr = parse_filter_block(&filter_content)?;
        filters.push(filter_expr);

        if content.peek(Token![,]) {
            content.parse::<Token![,]>()?;
        }
    }

    Ok(filters)
}

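/// Parses a single filter entry: either shorthand `field: value` (implicit
/// `$eq`) or the explicit `field: { $op: value }` form.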
fn parse_sql_filter(input: ParseStream) -> Result<SqlFilter> {
    let field: syn::Ident = input.parse()?;
    input.parse::<Token![:]>()?;

    if input.peek(token::Brace) {
        let op_content;
        braced!(op_content in input);

        op_content.parse::<Token![$]>()?;
        let op_name: syn::Ident = op_content.call(syn::Ident::parse_any)?;
        op_content.parse::<Token![:]>()?;

        let op = match op_name.to_string().as_str() {
            "eq" => SqlOperator::Eq,
            "ne" => SqlOperator::Ne,
            "gt" => SqlOperator::Gt,
            "gte" => SqlOperator::Gte,
            "lt" => SqlOperator::Lt,
            "lte" => SqlOperator::Lte,
            "in" => SqlOperator::In,
            "nin" => SqlOperator::NotIn,
            "like" => SqlOperator::Like,
            "ilike" => SqlOperator::ILike,
            "regex" => SqlOperator::Regex,
            "startsWith" | "starts_with" => SqlOperator::StartsWith,
            "endsWith" | "ends_with" => SqlOperator::EndsWith,
            "contains" => SqlOperator::Contains,
            "between" => SqlOperator::Between,
            other => {
                return Err(syn::Error::new(
                    op_name.span(),
                    format!(
                        "Unknown operator '${other}'. Valid operators: $eq, $ne, $gt, $gte, $lt, $lte, $in, $nin, $like, $ilike, $regex, $startsWith, $endsWith, $contains, $between"
                    ),
                ));
            },
        };

        let value = parse_sql_value(&op_content)?;
        Ok(SqlFilter { field, op, value })
    } else {
        let value = parse_sql_value(input)?;
        Ok(SqlFilter {
            field,
            op: SqlOperator::Eq,
            value,
        })
    }
}

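/// Parses a filter/insert value: literals, arrays, `null`/`true`/`false`,
/// type-hinted expressions such as `int(expr)` or `str(expr)`, or an
/// arbitrary Rust expression.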
fn parse_sql_value(input: ParseStream) -> Result<SqlValue> {
    let lookahead = input.lookahead1();

    if lookahead.peek(token::Bracket) {
        let content;
        bracketed!(content in input);
        let elements: Punctuated<SqlValue, Token![,]> =
            content.parse_terminated(parse_sql_value, Token![,])?;
        Ok(SqlValue::Array(elements.into_iter().collect()))
    } else if lookahead.peek(LitStr) {
        Ok(SqlValue::String(input.parse()?))
    } else if lookahead.peek(LitInt) {
        Ok(SqlValue::Int(input.parse()?))
    } else if lookahead.peek(LitFloat) {
        Ok(SqlValue::Float(input.parse()?))
    } else if lookahead.peek(LitBool) {
        let lit: LitBool = input.parse()?;
        Ok(SqlValue::Bool(lit.value))
    } else if input.peek(syn::Ident) && input.peek2(token::Paren) {
        let fork = input.fork();
        let ident: syn::Ident = fork.parse()?;
        match ident.to_string().as_str() {
            "int" => {
                input.parse::<syn::Ident>()?;
                let content;
                syn::parenthesized!(content in input);
                Ok(SqlValue::IntHint(content.parse()?))
            },
            "str" => {
                input.parse::<syn::Ident>()?;
                let content;
                syn::parenthesized!(content in input);
                Ok(SqlValue::StrHint(content.parse()?))
            },
            "float" => {
                input.parse::<syn::Ident>()?;
                let content;
                syn::parenthesized!(content in input);
                Ok(SqlValue::FloatHint(content.parse()?))
            },
            "bool" => {
                input.parse::<syn::Ident>()?;
                let content;
                syn::parenthesized!(content in input);
                Ok(SqlValue::BoolHint(content.parse()?))
            },
            _ => Ok(SqlValue::Expr(input.parse()?)),
        }
    } else if input.peek(syn::Ident) {
        let fork = input.fork();
        let ident: syn::Ident = fork.parse()?;
        match ident.to_string().as_str() {
            "null" => {
                input.parse::<syn::Ident>()?;
                Ok(SqlValue::Null)
            },
            "true" => {
                input.parse::<syn::Ident>()?;
                Ok(SqlValue::Bool(true))
            },
            "false" => {
                input.parse::<syn::Ident>()?;
                Ok(SqlValue::Bool(false))
            },
            _ => Ok(SqlValue::Expr(input.parse()?)),
        }
    } else {
        Err(syn::Error::new(
            input.span(),
            "Expected a value: string, number, boolean, null, array, or type hint (int(), str(), etc.)",
        ))
    }
}

impl Parse for SqlSort {
    fn parse(input: ParseStream) -> Result<Self> {
        let desc = if input.peek(Token![-]) {
            input.parse::<Token![-]>()?;
            true
        } else {
            false
        };
        let field: syn::Ident = input.parse()?;
        Ok(SqlSort { field, desc })
    }
}

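/// Lowers a parsed `SqlValue` into tokens that construct the matching
/// `::mik_sql::Value` at runtime. Untyped expressions are stringified, so the
/// `int()`/`float()`/`bool()` hints are needed for non-string bindings.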
fn sql_value_to_tokens(value: &SqlValue) -> TokenStream2 {
    match value {
        SqlValue::Null => quote! { ::mik_sql::Value::Null },
        SqlValue::Bool(b) => quote! { ::mik_sql::Value::Bool(#b) },
        SqlValue::Int(i) => quote! { ::mik_sql::Value::Int(#i as i64) },
        SqlValue::Float(f) => quote! { ::mik_sql::Value::Float(#f as f64) },
        SqlValue::String(s) => quote! { ::mik_sql::Value::String(#s.to_string()) },
        SqlValue::Array(arr) => {
            let elements: Vec<_> = arr.iter().map(sql_value_to_tokens).collect();
            quote! { ::mik_sql::Value::Array(vec![#(#elements),*]) }
        },
        SqlValue::IntHint(e) => quote! { ::mik_sql::Value::Int(#e as i64) },
        SqlValue::StrHint(e) | SqlValue::Expr(e) => {
            quote! { ::mik_sql::Value::String((#e).to_string()) }
        },
        SqlValue::FloatHint(e) => quote! { ::mik_sql::Value::Float(#e as f64) },
        SqlValue::BoolHint(e) => quote! { ::mik_sql::Value::Bool(#e) },
    }
}

fn sql_operator_to_tokens(op: &SqlOperator) -> TokenStream2 {
    match op {
        SqlOperator::Eq => quote! { ::mik_sql::Operator::Eq },
        SqlOperator::Ne => quote! { ::mik_sql::Operator::Ne },
        SqlOperator::Gt => quote! { ::mik_sql::Operator::Gt },
        SqlOperator::Gte => quote! { ::mik_sql::Operator::Gte },
        SqlOperator::Lt => quote! { ::mik_sql::Operator::Lt },
        SqlOperator::Lte => quote! { ::mik_sql::Operator::Lte },
        SqlOperator::In => quote! { ::mik_sql::Operator::In },
        SqlOperator::NotIn => quote! { ::mik_sql::Operator::NotIn },
        SqlOperator::Like => quote! { ::mik_sql::Operator::Like },
        SqlOperator::ILike => quote! { ::mik_sql::Operator::ILike },
        SqlOperator::Regex => quote! { ::mik_sql::Operator::Regex },
        SqlOperator::StartsWith => quote! { ::mik_sql::Operator::StartsWith },
        SqlOperator::EndsWith => quote! { ::mik_sql::Operator::EndsWith },
        SqlOperator::Contains => quote! { ::mik_sql::Operator::Contains },
        SqlOperator::Between => quote! { ::mik_sql::Operator::Between },
    }
}

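/// Builds a read (SELECT) query at compile time and expands to a
/// `(String, Vec<::mik_sql::Value>)` pair, or to a `Result` of that pair when
/// a dynamic `order` string or `merge`d user filters require runtime
/// validation.
///
/// Illustrative sketch (table and column names are placeholders):
///
/// ```ignore
/// let (sql, params) = sql_read!(users {
///     select: [id, name],
///     filter: { age: { $gte: 18 } },
///     order: -created_at,
///     limit: 10,
/// });
/// ```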
#[proc_macro]
#[allow(clippy::too_many_lines)]
pub fn sql_read(input: TokenStream) -> TokenStream {
    let SqlInput {
        dialect,
        table,
        select_fields,
        computed,
        aggregates,
        filter_expr,
        group_by,
        having,
        sorts,
        dynamic_sort,
        allow_sort,
        merge_filters,
        allow_fields,
        deny_ops,
        max_depth,
        page,
        limit,
        offset,
        after,
        before,
    } = parse_macro_input!(input as SqlInput);

    let (sorts, dynamic_sort) = if let Some(ref expr) = dynamic_sort {
        if allow_sort.is_empty() {
            if let syn::Expr::Path(syn::ExprPath { path, .. }) = expr {
                if path.segments.len() == 1 && path.segments[0].arguments.is_empty() {
                    let field_name = path.segments[0].ident.clone();
                    let mut new_sorts = sorts;
                    new_sorts.push(SqlSort {
                        field: field_name,
                        desc: false,
                    });
                    (new_sorts, None)
                } else {
                    (sorts, dynamic_sort)
                }
            } else {
                (sorts, dynamic_sort)
            }
        } else {
            (sorts, dynamic_sort)
        }
    } else {
        (sorts, dynamic_sort)
    };

    let table_str = table.to_string();

    let fields_chain = if select_fields.is_empty() {
        quote! {}
    } else {
        let field_strs: Vec<String> = select_fields
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        quote! { .fields(&[#(#field_strs),*]) }
    };

    let computed_chain: Vec<TokenStream2> = computed
        .iter()
        .map(|c| {
            let alias = c.alias.to_string();
            let expr_sql = compute_expr_to_sql(&c.expr);
            quote! { .computed(#alias, #expr_sql) }
        })
        .collect();

    let aggregate_chain: Vec<TokenStream2> = aggregates
        .iter()
        .map(|agg| {
            let agg_tokens = sql_aggregate_to_tokens(agg);
            quote! { .aggregate(#agg_tokens) }
        })
        .collect();

    let filter_chain = if let Some(expr) = filter_expr {
        let expr_tokens = sql_filter_expr_to_tokens(&expr);
        quote! { .filter_expr(#expr_tokens) }
    } else {
        quote! {}
    };

    let group_by_chain = if group_by.is_empty() {
        quote! {}
    } else {
        let field_strs: Vec<String> = group_by
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        quote! { .group_by(&[#(#field_strs),*]) }
    };

    let having_chain = if let Some(expr) = having {
        let expr_tokens = sql_filter_expr_to_tokens(&expr);
        quote! { .having(#expr_tokens) }
    } else {
        quote! {}
    };

    let sort_chain: Vec<TokenStream2> = sorts
        .iter()
        .map(|s| {
            let field_str = s.field.to_string();
            let dir = if s.desc {
                quote! { ::mik_sql::SortDir::Desc }
            } else {
                quote! { ::mik_sql::SortDir::Asc }
            };
            quote! { .sort(#field_str, #dir) }
        })
        .collect();

    let dynamic_sort_setup = if let Some(ref sort_expr) = dynamic_sort {
        let allow_strs: Vec<String> = allow_sort
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        if allow_strs.is_empty() {
            quote! {
                let __dynamic_sorts = ::mik_sql::SortField::parse_sort_string(
                    &#sort_expr,
                    &[]
                )?;
            }
        } else {
            quote! {
                let __dynamic_sorts = ::mik_sql::SortField::parse_sort_string(
                    &#sort_expr,
                    &[#(#allow_strs),*]
                )?;
            }
        }
    } else {
        quote! {}
    };

    let dynamic_sort_chain = if dynamic_sort.is_some() {
        quote! { .sorts(&__dynamic_sorts) }
    } else {
        quote! {}
    };

    let (merge_setup, merge_chain) = if let Some(ref merge_expr) = merge_filters {
        let allow_strs: Vec<String> = allow_fields
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        let deny_op_tokens: Vec<TokenStream2> = deny_ops
            .iter()
            .map(|op| {
                let op_str = op.to_string();
                match op_str.as_str() {
                    "ne" => quote! { ::mik_sql::Operator::Ne },
                    "gt" => quote! { ::mik_sql::Operator::Gt },
                    "gte" => quote! { ::mik_sql::Operator::Gte },
                    "lt" => quote! { ::mik_sql::Operator::Lt },
                    "lte" => quote! { ::mik_sql::Operator::Lte },
                    "in" => quote! { ::mik_sql::Operator::In },
                    "nin" | "notIn" => quote! { ::mik_sql::Operator::NotIn },
                    "like" => quote! { ::mik_sql::Operator::Like },
                    "ilike" => quote! { ::mik_sql::Operator::ILike },
                    "regex" => quote! { ::mik_sql::Operator::Regex },
                    "startsWith" | "starts_with" => quote! { ::mik_sql::Operator::StartsWith },
                    "endsWith" | "ends_with" => quote! { ::mik_sql::Operator::EndsWith },
                    "contains" => quote! { ::mik_sql::Operator::Contains },
                    "between" => quote! { ::mik_sql::Operator::Between },
                    _ => quote! { ::mik_sql::Operator::Eq },
                }
            })
            .collect();

        let max_depth_val = max_depth
            .map(|d| quote! { #d as usize })
            .unwrap_or(quote! { 5 });

        let setup = quote! {
            let __validator = ::mik_sql::FilterValidator::new()
                .allow_fields(&[#(#allow_strs),*])
                .deny_operators(&[#(#deny_op_tokens),*])
                .max_depth(#max_depth_val);

            for __user_filter in &#merge_expr {
                __validator.validate(__user_filter).map_err(|e| e.to_string())?;
            }
        };

        let chain = quote! {
            for __f in &#merge_expr {
                __builder = __builder.filter(__f.field.clone(), __f.op, __f.value.clone());
            }
        };

        (setup, chain)
    } else {
        (quote! {}, quote! {})
    };

    let needs_result = dynamic_sort.is_some() || merge_filters.is_some();

    let pagination_chain = match (page, limit, offset) {
        (Some(p), Some(l), None) => quote! { .page(#p as u32, #l as u32) },
        (None, Some(l), Some(o)) => quote! { .limit_offset(#l as u32, #o as u32) },
        (None, Some(l), None) => quote! { .limit_offset(#l as u32, 0) },
        _ => quote! {},
    };

    let after_chain = if let Some(ref expr) = after {
        quote! { .after_cursor(#expr) }
    } else {
        quote! {}
    };

    let before_chain = if let Some(ref expr) = before {
        quote! { .before_cursor(#expr) }
    } else {
        quote! {}
    };

    let builder_constructor = dialect.builder_tokens(&table_str);

    let tokens = if needs_result {
        quote! {
            (|| -> ::std::result::Result<(String, Vec<::mik_sql::Value>), String> {
                #dynamic_sort_setup
                #merge_setup

                let mut __builder = #builder_constructor
                    #fields_chain
                    #(#computed_chain)*
                    #(#aggregate_chain)*
                    #filter_chain;

                #merge_chain

                let __sql_result = __builder
                    #group_by_chain
                    #having_chain
                    #(#sort_chain)*
                    #dynamic_sort_chain
                    #after_chain
                    #before_chain
                    #pagination_chain
                    .build();

                Ok((__sql_result.sql, __sql_result.params))
            })()
        }
    } else {
        quote! {
            {
                let __sql_result = #builder_constructor
                    #fields_chain
                    #(#computed_chain)*
                    #(#aggregate_chain)*
                    #filter_chain
                    #group_by_chain
                    #having_chain
                    #(#sort_chain)*
                    #after_chain
                    #before_chain
                    #pagination_chain
                    .build();
                (__sql_result.sql, __sql_result.params)
            }
        }
    };

    TokenStream::from(tokens)
}

fn sql_aggregate_to_tokens(agg: &SqlAggregate) -> TokenStream2 {
    let field_str = agg.field.as_ref().map(std::string::ToString::to_string);
    let alias_str = agg.alias.as_ref().map(std::string::ToString::to_string);

    let base = match (&agg.func, &field_str) {
        (SqlAggregateFunc::Count, Some(f)) => quote! { ::mik_sql::Aggregate::count_field(#f) },
        (SqlAggregateFunc::CountDistinct, Some(f)) => {
            quote! { ::mik_sql::Aggregate::count_distinct(#f) }
        },
        (SqlAggregateFunc::Sum, Some(f)) => quote! { ::mik_sql::Aggregate::sum(#f) },
        (SqlAggregateFunc::Avg, Some(f)) => quote! { ::mik_sql::Aggregate::avg(#f) },
        (SqlAggregateFunc::Min, Some(f)) => quote! { ::mik_sql::Aggregate::min(#f) },
        (SqlAggregateFunc::Max, Some(f)) => quote! { ::mik_sql::Aggregate::max(#f) },
        _ => quote! { ::mik_sql::Aggregate::count() },
    };

    if let Some(alias) = alias_str {
        quote! { #base.as_alias(#alias) }
    } else {
        base
    }
}

fn sql_filter_expr_to_tokens(expr: &SqlFilterExpr) -> TokenStream2 {
    match expr {
        SqlFilterExpr::Simple(filter) => {
            let field_str = filter.field.to_string();
            let op = sql_operator_to_tokens(&filter.op);
            let value = sql_value_to_tokens(&filter.value);
            quote! { ::mik_sql::simple(#field_str, #op, #value) }
        },
        SqlFilterExpr::Compound { op, filters } => {
            let filter_tokens: Vec<TokenStream2> =
                filters.iter().map(sql_filter_expr_to_tokens).collect();

            match op {
                SqlLogicalOp::And => quote! { ::mik_sql::and(vec![#(#filter_tokens),*]) },
                SqlLogicalOp::Or => quote! { ::mik_sql::or(vec![#(#filter_tokens),*]) },
                SqlLogicalOp::Not => {
                    let inner = filter_tokens.into_iter().next().unwrap_or_default();
                    quote! { ::mik_sql::not(#inner) }
                },
            }
        },
    }
}

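/// Renders a compute expression tree as a SQL fragment. String literals are
/// single-quoted with embedded quotes doubled; `concat` is emitted with the
/// `||` operator.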
fn compute_expr_to_sql(expr: &SqlComputeExpr) -> String {
    match expr {
        SqlComputeExpr::Column(ident) => ident.to_string(),
        SqlComputeExpr::LitStr(lit) => {
            let s = lit.value();
            format!("'{}'", s.replace('\'', "''"))
        },
        SqlComputeExpr::LitInt(lit) => lit.to_string(),
        SqlComputeExpr::LitFloat(lit) => lit.to_string(),
        SqlComputeExpr::BinOp { left, op, right } => {
            let left_sql = compute_expr_to_sql(left);
            let right_sql = compute_expr_to_sql(right);
            let op_str = match op {
                SqlComputeBinOp::Add => "+",
                SqlComputeBinOp::Sub => "-",
                SqlComputeBinOp::Mul => "*",
                SqlComputeBinOp::Div => "/",
            };
            format!("{left_sql} {op_str} {right_sql}")
        },
        SqlComputeExpr::Func { name, args } => {
            let args_sql: Vec<String> = args.iter().map(compute_expr_to_sql).collect();
            match name {
                SqlComputeFunc::Concat => args_sql.join(" || "),
                SqlComputeFunc::Coalesce => format!("COALESCE({})", args_sql.join(", ")),
                SqlComputeFunc::Upper => format!("UPPER({})", args_sql.join(", ")),
                SqlComputeFunc::Lower => format!("LOWER({})", args_sql.join(", ")),
                SqlComputeFunc::Round => format!("ROUND({})", args_sql.join(", ")),
                SqlComputeFunc::Abs => format!("ABS({})", args_sql.join(", ")),
                SqlComputeFunc::Length => format!("LENGTH({})", args_sql.join(", ")),
            }
        },
        SqlComputeExpr::Paren(inner) => format!("({})", compute_expr_to_sql(inner)),
    }
}

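/// Collects one field (cloned) from every element of a collection,
/// defaulting to `id`.
///
/// Illustrative sketch (the `users` binding and `name` field are placeholders):
///
/// ```ignore
/// let user_ids = ids!(users);         // Vec of each item's `id`
/// let user_names = ids!(users, name); // Vec of each item's `name`
/// ```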
#[proc_macro]
pub fn ids(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as IdsInput);

    let list = &input.list;
    let field = &input.field;

    let tokens = quote! {
        #list.iter().map(|__item| __item.#field.clone()).collect::<Vec<_>>()
    };

    TokenStream::from(tokens)
}

struct IdsInput {
    list: Expr,
    field: syn::Ident,
}

impl Parse for IdsInput {
    fn parse(input: ParseStream) -> Result<Self> {
        let list: Expr = input.parse()?;

        let field = if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
            input.parse()?
        } else {
            syn::Ident::new("id", proc_macro2::Span::call_site())
        };

        Ok(IdsInput { list, field })
    }
}

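/// Builds an INSERT statement at compile time and expands to a
/// `(sql, params)` tuple.
///
/// Illustrative sketch (table and column names are placeholders):
///
/// ```ignore
/// let (sql, params) = sql_create!(users {
///     name: "Alice",
///     age: 30,
///     returning: [id],
/// });
/// ```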
#[proc_macro]
pub fn sql_create(input: TokenStream) -> TokenStream {
    let InsertInput {
        dialect,
        table,
        columns,
        returning,
    } = parse_macro_input!(input as InsertInput);

    let table_str = table.to_string();
    let builder_constructor = dialect.insert_tokens(&table_str);

    let col_strs: Vec<String> = columns.iter().map(|(c, _)| c.to_string()).collect();

    let value_tokens: Vec<TokenStream2> = columns
        .iter()
        .map(|(_, v)| sql_value_to_tokens(v))
        .collect();

    let returning_chain = if returning.is_empty() {
        quote! {}
    } else {
        let ret_strs: Vec<String> = returning
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        quote! { .returning(&[#(#ret_strs),*]) }
    };

    let tokens = quote! {
        {
            let __result = #builder_constructor
                .columns(&[#(#col_strs),*])
                .values(vec![#(#value_tokens),*])
                #returning_chain
                .build();
            (__result.sql, __result.params)
        }
    };

    TokenStream::from(tokens)
}

struct InsertInput {
    dialect: SqlDialect,
    table: syn::Ident,
    columns: Vec<(syn::Ident, SqlValue)>,
    returning: Vec<syn::Ident>,
}

impl Parse for InsertInput {
    fn parse(input: ParseStream) -> Result<Self> {
        let dialect = parse_optional_dialect(input)?;
        let table: syn::Ident = input.parse()?;

        let content;
        braced!(content in input);

        let mut columns = Vec::new();
        let mut returning = Vec::new();

        while !content.is_empty() {
            let key: syn::Ident = content.parse()?;
            content.parse::<Token![:]>()?;

            if key.to_string().as_str() == "returning" {
                let ret_content;
                bracketed!(ret_content in content);
                let fields: Punctuated<syn::Ident, Token![,]> =
                    ret_content.parse_terminated(syn::Ident::parse, Token![,])?;
                returning = fields.into_iter().collect();
            } else {
                let value = parse_sql_value(&content)?;
                columns.push((key, value));
            }

            if content.peek(Token![,]) {
                content.parse::<Token![,]>()?;
            }
        }

        Ok(InsertInput {
            dialect,
            table,
            columns,
            returning,
        })
    }
}

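/// Builds an UPDATE statement at compile time and expands to a
/// `(sql, params)` tuple.
///
/// Illustrative sketch (table and column names are placeholders):
///
/// ```ignore
/// let (sql, params) = sql_update!(users {
///     set: { name: "Bob" },
///     filter: { id: 1 },
///     returning: [id, name],
/// });
/// ```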
#[proc_macro]
pub fn sql_update(input: TokenStream) -> TokenStream {
    let UpdateInput {
        dialect,
        table,
        sets,
        where_expr,
        returning,
    } = parse_macro_input!(input as UpdateInput);

    let table_str = table.to_string();
    let builder_constructor = dialect.update_tokens(&table_str);

    let set_chain: Vec<TokenStream2> = sets
        .iter()
        .map(|(col, val)| {
            let col_str = col.to_string();
            let val_tokens = sql_value_to_tokens(val);
            quote! { .set(#col_str, #val_tokens) }
        })
        .collect();

    let filter_chain = if let Some(expr) = where_expr {
        let expr_tokens = sql_filter_expr_to_tokens(&expr);
        quote! { .filter_expr(#expr_tokens) }
    } else {
        quote! {}
    };

    let returning_chain = if returning.is_empty() {
        quote! {}
    } else {
        let ret_strs: Vec<String> = returning
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        quote! { .returning(&[#(#ret_strs),*]) }
    };

    let tokens = quote! {
        {
            let __result = #builder_constructor
                #(#set_chain)*
                #filter_chain
                #returning_chain
                .build();
            (__result.sql, __result.params)
        }
    };

    TokenStream::from(tokens)
}

struct UpdateInput {
    dialect: SqlDialect,
    table: syn::Ident,
    sets: Vec<(syn::Ident, SqlValue)>,
    where_expr: Option<SqlFilterExpr>,
    returning: Vec<syn::Ident>,
}

impl Parse for UpdateInput {
    fn parse(input: ParseStream) -> Result<Self> {
        let dialect = parse_optional_dialect(input)?;
        let table: syn::Ident = input.parse()?;

        let content;
        braced!(content in input);

        let mut sets = Vec::new();
        let mut where_expr = None;
        let mut returning = Vec::new();

        while !content.is_empty() {
            // `parse_any` so the keyword `where` is accepted as an option key.
            let key: syn::Ident = content.call(syn::Ident::parse_any)?;
            content.parse::<Token![:]>()?;

            match key.to_string().as_str() {
                "set" => {
                    let set_content;
                    braced!(set_content in content);
                    sets = parse_column_values(&set_content)?;
                },
                "where" | "filter" => {
                    let where_content;
                    braced!(where_content in content);
                    where_expr = Some(parse_filter_block(&where_content)?);
                },
                "returning" => {
                    let ret_content;
                    bracketed!(ret_content in content);
                    let fields: Punctuated<syn::Ident, Token![,]> =
                        ret_content.parse_terminated(syn::Ident::parse, Token![,])?;
                    returning = fields.into_iter().collect();
                },
                _ => {
                    return Err(syn::Error::new(
                        key.span(),
                        format!("Unknown option '{key}'. Expected 'set', 'where', or 'returning'"),
                    ));
                },
            }

            if content.peek(Token![,]) {
                content.parse::<Token![,]>()?;
            }
        }

        Ok(UpdateInput {
            dialect,
            table,
            sets,
            where_expr,
            returning,
        })
    }
}

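/// Builds a DELETE statement at compile time and expands to a
/// `(sql, params)` tuple.
///
/// Illustrative sketch (table and column names are placeholders):
///
/// ```ignore
/// let (sql, params) = sql_delete!(users {
///     filter: { id: 1 },
///     returning: [id],
/// });
/// ```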
#[proc_macro]
pub fn sql_delete(input: TokenStream) -> TokenStream {
    let DeleteInput {
        dialect,
        table,
        where_expr,
        returning,
    } = parse_macro_input!(input as DeleteInput);

    let table_str = table.to_string();
    let builder_constructor = dialect.delete_tokens(&table_str);

    let filter_chain = if let Some(expr) = where_expr {
        let expr_tokens = sql_filter_expr_to_tokens(&expr);
        quote! { .filter_expr(#expr_tokens) }
    } else {
        quote! {}
    };

    let returning_chain = if returning.is_empty() {
        quote! {}
    } else {
        let ret_strs: Vec<String> = returning
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        quote! { .returning(&[#(#ret_strs),*]) }
    };

    let tokens = quote! {
        {
            let __result = #builder_constructor
                #filter_chain
                #returning_chain
                .build();
            (__result.sql, __result.params)
        }
    };

    TokenStream::from(tokens)
}

struct DeleteInput {
    dialect: SqlDialect,
    table: syn::Ident,
    where_expr: Option<SqlFilterExpr>,
    returning: Vec<syn::Ident>,
}

impl Parse for DeleteInput {
    fn parse(input: ParseStream) -> Result<Self> {
        let dialect = parse_optional_dialect(input)?;
        let table: syn::Ident = input.parse()?;

        let content;
        braced!(content in input);

        let mut where_expr = None;
        let mut returning = Vec::new();

        while !content.is_empty() {
            // `parse_any` so the keyword `where` is accepted as an option key.
            let key: syn::Ident = content.call(syn::Ident::parse_any)?;
            content.parse::<Token![:]>()?;

            match key.to_string().as_str() {
                "where" | "filter" => {
                    let where_content;
                    braced!(where_content in content);
                    where_expr = Some(parse_filter_block(&where_content)?);
                },
                "returning" => {
                    let ret_content;
                    bracketed!(ret_content in content);
                    let fields: Punctuated<syn::Ident, Token![,]> =
                        ret_content.parse_terminated(syn::Ident::parse, Token![,])?;
                    returning = fields.into_iter().collect();
                },
                _ => {
                    return Err(syn::Error::new(
                        key.span(),
                        format!("Unknown option '{key}'. Expected 'where' or 'returning'"),
                    ));
                },
            }

            if content.peek(Token![,]) {
                content.parse::<Token![,]>()?;
            }
        }

        Ok(DeleteInput {
            dialect,
            table,
            where_expr,
            returning,
        })
    }
}

fn parse_column_values(input: ParseStream) -> Result<Vec<(syn::Ident, SqlValue)>> {
    let mut result = Vec::new();

    while !input.is_empty() {
        let key: syn::Ident = input.parse()?;
        input.parse::<Token![:]>()?;
        let value = parse_sql_value(input)?;
        result.push((key, value));

        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
        }
    }

    Ok(result)
}