mik_sql_macros/
lib.rs

1//! Proc-macros for mik-sql - SQL query builder with Mongo-style filters.
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{
7    Expr, LitBool, LitFloat, LitInt, LitStr, Result, Token, braced, bracketed,
8    ext::IdentExt,
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    punctuated::Punctuated,
12    token,
13};
14
15// ============================================================================
16// SQL CRUD MACROS - Query builder with JSON-like syntax
17// ============================================================================
18
/// SQL dialect for query generation.
///
/// Selects which `::mik_sql` constructor functions the macros emit.
/// `Postgres` is the default when no dialect prefix is written.
#[derive(Clone, Copy, Default)]
enum SqlDialect {
    /// Chosen by the `postgres` or `pg` keyword; also the default.
    #[default]
    Postgres,
    /// Chosen by the `sqlite` keyword.
    Sqlite,
}
26
27impl SqlDialect {
28    /// Parse dialect from identifier, returns None if not a dialect keyword.
29    fn from_ident(ident: &syn::Ident) -> Option<Self> {
30        match ident.to_string().as_str() {
31            "postgres" | "pg" => Some(SqlDialect::Postgres),
32            "sqlite" => Some(SqlDialect::Sqlite),
33            _ => None,
34        }
35    }
36
37    /// Generate the query builder constructor call.
38    fn builder_tokens(self, table: &str) -> TokenStream2 {
39        match self {
40            SqlDialect::Postgres => quote! { ::mik_sql::postgres(#table) },
41            SqlDialect::Sqlite => quote! { ::mik_sql::sqlite(#table) },
42        }
43    }
44
45    /// Generate the insert builder constructor call.
46    fn insert_tokens(self, table: &str) -> TokenStream2 {
47        match self {
48            SqlDialect::Postgres => quote! { ::mik_sql::insert(#table) },
49            SqlDialect::Sqlite => quote! { ::mik_sql::insert_sqlite(#table) },
50        }
51    }
52
53    /// Generate the update builder constructor call.
54    fn update_tokens(self, table: &str) -> TokenStream2 {
55        match self {
56            SqlDialect::Postgres => quote! { ::mik_sql::update(#table) },
57            SqlDialect::Sqlite => quote! { ::mik_sql::update_sqlite(#table) },
58        }
59    }
60
61    /// Generate the delete builder constructor call.
62    fn delete_tokens(self, table: &str) -> TokenStream2 {
63        match self {
64            SqlDialect::Postgres => quote! { ::mik_sql::delete(#table) },
65            SqlDialect::Sqlite => quote! { ::mik_sql::delete_sqlite(#table) },
66        }
67    }
68}
69
70/// Parse optional dialect prefix: `sql_read!(sqlite, users { ... })`
71fn parse_optional_dialect(input: ParseStream) -> Result<SqlDialect> {
72    let fork = input.fork();
73    if let Ok(ident) = fork.parse::<syn::Ident>()
74        && let Some(dialect) = SqlDialect::from_ident(&ident)
75        && fork.peek(Token![,])
76    {
77        input.parse::<syn::Ident>()?;
78        input.parse::<Token![,]>()?;
79        return Ok(dialect);
80    }
81    Ok(SqlDialect::default())
82}
83
/// Input for the [`sql_read!`] macro.
///
/// Each field is filled from the option key of the same purpose inside the
/// `table { key: value, ... }` block; options that are not written keep
/// their empty / `None` default.
struct SqlInput {
    /// Dialect from the optional `sqlite,` / `pg,` / `postgres,` prefix.
    dialect: SqlDialect,
    /// Table identifier that precedes the braces.
    table: syn::Ident,
    /// `select: [a, b]`
    select_fields: Vec<syn::Ident>,
    /// `compute: { alias: expr, ... }`
    computed: Vec<SqlCompute>,
    /// `aggregate: { ... }` (alias `agg`)
    aggregates: Vec<SqlAggregate>,
    /// `filter: { ... }`
    filter_expr: Option<SqlFilterExpr>,
    /// `group_by: [a, b]` (alias `groupBy`)
    group_by: Vec<syn::Ident>,
    /// `having: { ... }`, parsed with the same grammar as `filter`.
    having: Option<SqlFilterExpr>,
    /// Static sort fields from `order:`.
    sorts: Vec<SqlSort>,
    /// Expression given to `order:` when it is not parsed as static fields.
    dynamic_sort: Option<Expr>,
    /// `allow_sort: [a, b]` (alias `allowSort`)
    allow_sort: Vec<syn::Ident>,
    /// `merge: expr`
    merge_filters: Option<Expr>,
    /// `allow: [a, b]`
    allow_fields: Vec<syn::Ident>,
    /// `deny_ops: [$op, ...]` (alias `denyOps`); stored without the `$`.
    deny_ops: Vec<syn::Ident>,
    /// `max_depth: expr` (alias `maxDepth`)
    max_depth: Option<Expr>,
    /// `page: expr`
    page: Option<Expr>,
    /// `limit: expr`
    limit: Option<Expr>,
    /// `offset: expr`
    offset: Option<Expr>,
    /// `after: expr`
    after: Option<Expr>,
    /// `before: expr`
    before: Option<Expr>,
}
107
/// One entry of the `aggregate: { ... }` block.
struct SqlAggregate {
    /// Which aggregate function to apply.
    func: SqlAggregateFunc,
    /// Column the function is applied to; `None` for `count: *`.
    field: Option<syn::Ident>,
    /// Result alias; the parser only sets it (to `count`) for `count: *`.
    alias: Option<syn::Ident>,
}
114
/// Aggregation functions accepted inside the `aggregate` block.
#[derive(Clone, Copy)]
enum SqlAggregateFunc {
    /// `count: *` or `count: field`
    Count,
    /// `count_distinct: field` (alias `countDistinct`)
    CountDistinct,
    /// `sum: field`
    Sum,
    /// `avg: field`
    Avg,
    /// `min: field`
    Min,
    /// `max: field`
    Max,
}
125
/// A filter expression - can be simple or compound.
enum SqlFilterExpr {
    /// A single `field: value` / `field: { $op: value }` condition.
    Simple(SqlFilter),
    /// Several sub-expressions joined by one logical operator.
    Compound {
        /// Operator that combines `filters`.
        op: SqlLogicalOp,
        /// Operand expressions, in source order.
        filters: Vec<SqlFilterExpr>,
    },
}
134
/// Logical operators for compound filters.
#[derive(Clone, Copy)]
enum SqlLogicalOp {
    /// `$and`; also used implicitly when a block lists several conditions.
    And,
    /// `$or`
    Or,
    /// `$not`
    Not,
}
142
/// A single filter condition: `field <op> value`.
struct SqlFilter {
    /// Column being tested.
    field: syn::Ident,
    /// Comparison operator; `Eq` for the shorthand `field: value` form.
    op: SqlOperator,
    /// Right-hand side of the comparison.
    value: SqlValue,
}
149
/// SQL comparison operators, written as `$name` inside a filter.
#[derive(Clone)]
enum SqlOperator {
    /// `$eq`, and the implicit operator of `field: value`.
    Eq,
    /// `$ne`
    Ne,
    /// `$gt`
    Gt,
    /// `$gte`
    Gte,
    /// `$lt`
    Lt,
    /// `$lte`
    Lte,
    /// `$in`
    In,
    /// `$nin`
    NotIn,
    /// `$like`
    Like,
    /// `$ilike`
    ILike,
    /// `$regex`
    Regex,
    /// `$startsWith` (alias `$starts_with`)
    StartsWith,
    /// `$endsWith` (alias `$ends_with`)
    EndsWith,
    /// `$contains`
    Contains,
    /// `$between`
    Between,
}
169
/// Value on the right-hand side of a filter condition.
enum SqlValue {
    /// The bare identifier `null`.
    Null,
    /// A boolean literal.
    Bool(bool),
    /// An integer literal.
    Int(LitInt),
    /// A floating-point literal.
    Float(LitFloat),
    /// A string literal.
    String(LitStr),
    /// A bracketed, comma-separated list of values.
    Array(Vec<SqlValue>),
    /// `int(expr)` type hint.
    IntHint(Expr),
    /// `str(expr)` type hint.
    StrHint(Expr),
    /// `float(expr)` type hint.
    FloatHint(Expr),
    /// `bool(expr)` type hint.
    BoolHint(Expr),
    /// Any other identifier-led Rust expression without a type hint.
    Expr(Expr),
}
184
/// Sort field with direction.
struct SqlSort {
    /// Column to sort by.
    field: syn::Ident,
    /// `true` when the field was written with a leading `-`.
    desc: bool,
}
190
/// One `alias: expression` entry of the `compute: { ... }` block.
struct SqlCompute {
    /// Name written before the colon.
    alias: syn::Ident,
    /// Parsed expression written after the colon.
    expr: SqlComputeExpr,
}
196
/// A compute expression (arithmetic, function call, or literal).
enum SqlComputeExpr {
    /// A bare identifier, i.e. a column reference.
    Column(syn::Ident),
    /// A string literal.
    LitStr(LitStr),
    /// An integer literal.
    LitInt(LitInt),
    /// A floating-point literal.
    LitFloat(LitFloat),
    /// `left <op> right`; the parser builds these left-associatively.
    BinOp {
        left: Box<SqlComputeExpr>,
        op: SqlComputeBinOp,
        right: Box<SqlComputeExpr>,
    },
    /// A call to one of the whitelisted functions.
    Func {
        name: SqlComputeFunc,
        args: Vec<SqlComputeExpr>,
    },
    /// An explicitly parenthesized sub-expression.
    Paren(Box<SqlComputeExpr>),
}
214
/// Binary operators for compute expressions.
#[derive(Clone, Copy)]
enum SqlComputeBinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
}
223
/// Whitelisted compute functions; any other name is rejected at parse time.
#[derive(Clone, Copy)]
enum SqlComputeFunc {
    /// `concat(...)`
    Concat,
    /// `coalesce(...)`
    Coalesce,
    /// `upper(...)`
    Upper,
    /// `lower(...)`
    Lower,
    /// `round(...)`
    Round,
    /// `abs(...)`
    Abs,
    /// `length(...)` (alias `len`)
    Length,
}
235
236impl Parse for SqlInput {
237    // Complex DSL parsing requires seeing full flow in one place
238    #[allow(clippy::too_many_lines)]
239    fn parse(input: ParseStream) -> Result<Self> {
240        let dialect = parse_optional_dialect(input)?;
241        let table: syn::Ident = input.parse().map_err(|e| {
242            syn::Error::new(
243                e.span(),
244                format!(
245                    "Expected table name.\n\
246                     Usage: sql_read!(table_name {{ ... }}) or sql_read!(sqlite, table_name {{ ... }})\n\
247                     Original error: {e}"
248                ),
249            )
250        })?;
251
252        let content;
253        braced!(content in input);
254
255        let mut select_fields = Vec::new();
256        let mut computed = Vec::new();
257        let mut aggregates = Vec::new();
258        let mut filter_expr = None;
259        let mut group_by = Vec::new();
260        let mut having = None;
261        let mut sorts = Vec::new();
262        let mut dynamic_sort = None;
263        let mut allow_sort = Vec::new();
264        let mut merge_filters = None;
265        let mut allow_fields = Vec::new();
266        let mut deny_ops = Vec::new();
267        let mut max_depth = None;
268        let mut page = None;
269        let mut limit = None;
270        let mut offset = None;
271        let mut after = None;
272        let mut before = None;
273
274        while !content.is_empty() {
275            let key: syn::Ident = content.parse()?;
276            content.parse::<Token![:]>()?;
277
278            match key.to_string().as_str() {
279                "select" => {
280                    let fields_content;
281                    bracketed!(fields_content in content);
282                    let fields: Punctuated<syn::Ident, Token![,]> =
283                        fields_content.parse_terminated(syn::Ident::parse, Token![,])?;
284                    select_fields = fields.into_iter().collect();
285                },
286                "compute" => {
287                    let compute_content;
288                    braced!(compute_content in content);
289                    computed = parse_compute_fields(&compute_content)?;
290                },
291                "aggregate" | "agg" => {
292                    let agg_content;
293                    braced!(agg_content in content);
294                    aggregates = parse_aggregates(&agg_content)?;
295                },
296                "filter" => {
297                    let filter_content;
298                    braced!(filter_content in content);
299                    filter_expr = Some(parse_filter_block(&filter_content)?);
300                },
301                "group_by" | "groupBy" => {
302                    let group_content;
303                    bracketed!(group_content in content);
304                    let fields: Punctuated<syn::Ident, Token![,]> =
305                        group_content.parse_terminated(syn::Ident::parse, Token![,])?;
306                    group_by = fields.into_iter().collect();
307                },
308                "having" => {
309                    let having_content;
310                    braced!(having_content in content);
311                    having = Some(parse_filter_block(&having_content)?);
312                },
313                "order" => {
314                    if content.peek(token::Bracket) {
315                        let order_content;
316                        bracketed!(order_content in content);
317                        let sort_items: Punctuated<SqlSort, Token![,]> =
318                            order_content.parse_terminated(SqlSort::parse, Token![,])?;
319                        sorts = sort_items.into_iter().collect();
320                    } else if content.peek(Token![-]) {
321                        let sort: SqlSort = content.parse()?;
322                        sorts.push(sort);
323                    } else if content.peek(syn::Ident)
324                        && !content.peek2(Token![,])
325                        && !content.peek2(token::Brace)
326                    {
327                        let fork = content.fork();
328                        let ident: syn::Ident = fork.parse()?;
329                        if fork.peek(Token![,]) && fork.peek2(syn::Ident) {
330                            fork.parse::<Token![,]>().ok();
331                            if let Ok(_next_ident) = fork.parse::<syn::Ident>() {
332                                if fork.peek(Token![:]) {
333                                    dynamic_sort = Some(syn::Expr::Path(syn::ExprPath {
334                                        attrs: vec![],
335                                        qself: None,
336                                        path: ident.clone().into(),
337                                    }));
338                                    content.parse::<syn::Ident>()?;
339                                } else {
340                                    let sort: SqlSort = content.parse()?;
341                                    sorts.push(sort);
342                                }
343                            } else {
344                                let sort: SqlSort = content.parse()?;
345                                sorts.push(sort);
346                            }
347                        } else if fork.is_empty()
348                            || (fork.peek(Token![,]) && !fork.peek2(syn::Ident))
349                        {
350                            dynamic_sort = Some(content.parse()?);
351                        } else {
352                            let sort: SqlSort = content.parse()?;
353                            sorts.push(sort);
354                        }
355                    } else {
356                        dynamic_sort = Some(content.parse()?);
357                    }
358                },
359                "allow_sort" | "allowSort" => {
360                    let allow_content;
361                    bracketed!(allow_content in content);
362                    let fields: Punctuated<syn::Ident, Token![,]> =
363                        allow_content.parse_terminated(syn::Ident::parse, Token![,])?;
364                    allow_sort = fields.into_iter().collect();
365                },
366                "merge" => {
367                    merge_filters = Some(content.parse()?);
368                },
369                "allow" => {
370                    let allow_content;
371                    bracketed!(allow_content in content);
372                    let fields: Punctuated<syn::Ident, Token![,]> =
373                        allow_content.parse_terminated(syn::Ident::parse, Token![,])?;
374                    allow_fields = fields.into_iter().collect();
375                },
376                "deny_ops" | "denyOps" => {
377                    let deny_content;
378                    bracketed!(deny_content in content);
379                    let mut ops = Vec::new();
380                    while !deny_content.is_empty() {
381                        deny_content.parse::<Token![$]>()?;
382                        let op: syn::Ident = deny_content.call(syn::Ident::parse_any)?;
383                        ops.push(op);
384                        if deny_content.peek(Token![,]) {
385                            deny_content.parse::<Token![,]>()?;
386                        }
387                    }
388                    deny_ops = ops;
389                },
390                "max_depth" | "maxDepth" => {
391                    max_depth = Some(content.parse()?);
392                },
393                "page" => {
394                    page = Some(content.parse()?);
395                },
396                "limit" => {
397                    limit = Some(content.parse()?);
398                },
399                "offset" => {
400                    offset = Some(content.parse()?);
401                },
402                "after" => {
403                    after = Some(content.parse()?);
404                },
405                "before" => {
406                    before = Some(content.parse()?);
407                },
408                other => {
409                    return Err(syn::Error::new(
410                        key.span(),
411                        format!(
412                            "Unknown option '{other}'. Valid options: select, compute, aggregate, filter, merge, allow, deny_ops, max_depth, group_by, having, order, allow_sort, page, limit, offset, after, before"
413                        ),
414                    ));
415                },
416            }
417
418            if content.peek(Token![,]) {
419                content.parse::<Token![,]>()?;
420            }
421        }
422
423        Ok(SqlInput {
424            dialect,
425            table,
426            select_fields,
427            computed,
428            aggregates,
429            filter_expr,
430            group_by,
431            having,
432            sorts,
433            dynamic_sort,
434            allow_sort,
435            merge_filters,
436            allow_fields,
437            deny_ops,
438            max_depth,
439            page,
440            limit,
441            offset,
442            after,
443            before,
444        })
445    }
446}
447
448/// Parse the aggregate block.
449fn parse_aggregates(input: ParseStream) -> Result<Vec<SqlAggregate>> {
450    let mut aggregates = Vec::new();
451
452    while !input.is_empty() {
453        let func_name: syn::Ident = input.parse()?;
454        input.parse::<Token![:]>()?;
455
456        let func_str = func_name.to_string();
457        let (func, field, alias) = match func_str.as_str() {
458            "count" => {
459                if input.peek(Token![*]) {
460                    input.parse::<Token![*]>()?;
461                    (
462                        SqlAggregateFunc::Count,
463                        None,
464                        Some(syn::Ident::new("count", func_name.span())),
465                    )
466                } else {
467                    let field: syn::Ident = input.parse()?;
468                    (SqlAggregateFunc::Count, Some(field), None)
469                }
470            },
471            "count_distinct" | "countDistinct" => {
472                let field: syn::Ident = input.parse()?;
473                (SqlAggregateFunc::CountDistinct, Some(field), None)
474            },
475            "sum" => {
476                let field: syn::Ident = input.parse()?;
477                (SqlAggregateFunc::Sum, Some(field), None)
478            },
479            "avg" => {
480                let field: syn::Ident = input.parse()?;
481                (SqlAggregateFunc::Avg, Some(field), None)
482            },
483            "min" => {
484                let field: syn::Ident = input.parse()?;
485                (SqlAggregateFunc::Min, Some(field), None)
486            },
487            "max" => {
488                let field: syn::Ident = input.parse()?;
489                (SqlAggregateFunc::Max, Some(field), None)
490            },
491            other => {
492                return Err(syn::Error::new(
493                    func_name.span(),
494                    format!(
495                        "Unknown aggregate function '{other}'. Valid: count, count_distinct, sum, avg, min, max"
496                    ),
497                ));
498            },
499        };
500
501        aggregates.push(SqlAggregate { func, field, alias });
502
503        if input.peek(Token![,]) {
504            input.parse::<Token![,]>()?;
505        }
506    }
507
508    Ok(aggregates)
509}
510
511/// Parse the compute block.
512fn parse_compute_fields(input: ParseStream) -> Result<Vec<SqlCompute>> {
513    let mut computed = Vec::new();
514
515    while !input.is_empty() {
516        let alias: syn::Ident = input.parse()?;
517        input.parse::<Token![:]>()?;
518        let expr = parse_compute_expr(input)?;
519        computed.push(SqlCompute { alias, expr });
520
521        if input.peek(Token![,]) {
522            input.parse::<Token![,]>()?;
523        }
524    }
525
526    Ok(computed)
527}
528
/// Entry point for compute expressions; starts at the lowest-precedence
/// level (`+` / `-`).
fn parse_compute_expr(input: ParseStream) -> Result<SqlComputeExpr> {
    parse_compute_additive(input)
}
532
533fn parse_compute_additive(input: ParseStream) -> Result<SqlComputeExpr> {
534    let mut left = parse_compute_multiplicative(input)?;
535
536    while input.peek(Token![+]) || input.peek(Token![-]) {
537        let op = if input.peek(Token![+]) {
538            input.parse::<Token![+]>()?;
539            SqlComputeBinOp::Add
540        } else {
541            input.parse::<Token![-]>()?;
542            SqlComputeBinOp::Sub
543        };
544
545        let right = parse_compute_multiplicative(input)?;
546        left = SqlComputeExpr::BinOp {
547            left: Box::new(left),
548            op,
549            right: Box::new(right),
550        };
551    }
552
553    Ok(left)
554}
555
556fn parse_compute_multiplicative(input: ParseStream) -> Result<SqlComputeExpr> {
557    let mut left = parse_compute_primary(input)?;
558
559    while input.peek(Token![*]) || input.peek(Token![/]) {
560        let op = if input.peek(Token![*]) {
561            input.parse::<Token![*]>()?;
562            SqlComputeBinOp::Mul
563        } else {
564            input.parse::<Token![/]>()?;
565            SqlComputeBinOp::Div
566        };
567
568        let right = parse_compute_primary(input)?;
569        left = SqlComputeExpr::BinOp {
570            left: Box::new(left),
571            op,
572            right: Box::new(right),
573        };
574    }
575
576    Ok(left)
577}
578
579fn parse_compute_primary(input: ParseStream) -> Result<SqlComputeExpr> {
580    if input.peek(token::Paren) {
581        let content;
582        syn::parenthesized!(content in input);
583        let inner = parse_compute_expr(&content)?;
584        return Ok(SqlComputeExpr::Paren(Box::new(inner)));
585    }
586
587    if input.peek(LitStr) {
588        return Ok(SqlComputeExpr::LitStr(input.parse()?));
589    }
590
591    if input.peek(LitFloat) {
592        return Ok(SqlComputeExpr::LitFloat(input.parse()?));
593    }
594
595    if input.peek(LitInt) {
596        return Ok(SqlComputeExpr::LitInt(input.parse()?));
597    }
598
599    if input.peek(syn::Ident) {
600        let ident: syn::Ident = input.parse()?;
601
602        if input.peek(token::Paren) {
603            let func_name = ident.to_string();
604            let func = match func_name.as_str() {
605                "concat" => SqlComputeFunc::Concat,
606                "coalesce" => SqlComputeFunc::Coalesce,
607                "upper" => SqlComputeFunc::Upper,
608                "lower" => SqlComputeFunc::Lower,
609                "round" => SqlComputeFunc::Round,
610                "abs" => SqlComputeFunc::Abs,
611                "length" | "len" => SqlComputeFunc::Length,
612                other => {
613                    return Err(syn::Error::new(
614                        ident.span(),
615                        format!(
616                            "Unknown compute function '{other}'. Valid: concat, coalesce, upper, lower, round, abs, length"
617                        ),
618                    ));
619                },
620            };
621
622            let args_content;
623            syn::parenthesized!(args_content in input);
624            let args: Punctuated<SqlComputeExpr, Token![,]> =
625                args_content.parse_terminated(parse_compute_expr, Token![,])?;
626
627            return Ok(SqlComputeExpr::Func {
628                name: func,
629                args: args.into_iter().collect(),
630            });
631        }
632
633        return Ok(SqlComputeExpr::Column(ident));
634    }
635
636    Err(syn::Error::new(
637        input.span(),
638        "Expected a compute expression: column, literal, function call, or (expression)",
639    ))
640}
641
642fn parse_filter_block(input: ParseStream) -> Result<SqlFilterExpr> {
643    let mut simple_filters = Vec::new();
644
645    while !input.is_empty() {
646        if input.peek(Token![$]) {
647            input.parse::<Token![$]>()?;
648            let op_name: syn::Ident = input.call(syn::Ident::parse_any)?;
649            input.parse::<Token![:]>()?;
650
651            let logical_op = match op_name.to_string().as_str() {
652                "and" => SqlLogicalOp::And,
653                "or" => SqlLogicalOp::Or,
654                "not" => SqlLogicalOp::Not,
655                other => {
656                    return Err(syn::Error::new(
657                        op_name.span(),
658                        format!("Unknown logical operator '${other}'. Valid: $and, $or, $not"),
659                    ));
660                },
661            };
662
663            let filters = parse_filter_array(input)?;
664
665            if !simple_filters.is_empty() {
666                let mut all_filters: Vec<SqlFilterExpr> = simple_filters
667                    .into_iter()
668                    .map(SqlFilterExpr::Simple)
669                    .collect();
670                all_filters.push(SqlFilterExpr::Compound {
671                    op: logical_op,
672                    filters,
673                });
674                return Ok(SqlFilterExpr::Compound {
675                    op: SqlLogicalOp::And,
676                    filters: all_filters,
677                });
678            }
679
680            if input.peek(Token![,]) {
681                input.parse::<Token![,]>()?;
682            }
683
684            if !input.is_empty() {
685                let remaining = parse_filter_block(input)?;
686                return Ok(SqlFilterExpr::Compound {
687                    op: SqlLogicalOp::And,
688                    filters: vec![
689                        SqlFilterExpr::Compound {
690                            op: logical_op,
691                            filters,
692                        },
693                        remaining,
694                    ],
695                });
696            }
697
698            return Ok(SqlFilterExpr::Compound {
699                op: logical_op,
700                filters,
701            });
702        }
703
704        let filter = parse_sql_filter(input)?;
705        simple_filters.push(filter);
706
707        if input.peek(Token![,]) {
708            input.parse::<Token![,]>()?;
709        }
710    }
711
712    match simple_filters.len() {
713        0 => Err(syn::Error::new(input.span(), "Empty filter block")),
714        1 => Ok(SqlFilterExpr::Simple(simple_filters.remove(0))),
715        _ => Ok(SqlFilterExpr::Compound {
716            op: SqlLogicalOp::And,
717            filters: simple_filters
718                .into_iter()
719                .map(SqlFilterExpr::Simple)
720                .collect(),
721        }),
722    }
723}
724
725fn parse_filter_array(input: ParseStream) -> Result<Vec<SqlFilterExpr>> {
726    let content;
727    bracketed!(content in input);
728
729    let mut filters = Vec::new();
730    while !content.is_empty() {
731        let filter_content;
732        braced!(filter_content in content);
733        let filter_expr = parse_filter_block(&filter_content)?;
734        filters.push(filter_expr);
735
736        if content.peek(Token![,]) {
737            content.parse::<Token![,]>()?;
738        }
739    }
740
741    Ok(filters)
742}
743
744fn parse_sql_filter(input: ParseStream) -> Result<SqlFilter> {
745    let field: syn::Ident = input.parse()?;
746    input.parse::<Token![:]>()?;
747
748    if input.peek(token::Brace) {
749        let op_content;
750        braced!(op_content in input);
751
752        op_content.parse::<Token![$]>()?;
753        let op_name: syn::Ident = op_content.call(syn::Ident::parse_any)?;
754        op_content.parse::<Token![:]>()?;
755
756        let op = match op_name.to_string().as_str() {
757            "eq" => SqlOperator::Eq,
758            "ne" => SqlOperator::Ne,
759            "gt" => SqlOperator::Gt,
760            "gte" => SqlOperator::Gte,
761            "lt" => SqlOperator::Lt,
762            "lte" => SqlOperator::Lte,
763            "in" => SqlOperator::In,
764            "nin" => SqlOperator::NotIn,
765            "like" => SqlOperator::Like,
766            "ilike" => SqlOperator::ILike,
767            "regex" => SqlOperator::Regex,
768            "startsWith" | "starts_with" => SqlOperator::StartsWith,
769            "endsWith" | "ends_with" => SqlOperator::EndsWith,
770            "contains" => SqlOperator::Contains,
771            "between" => SqlOperator::Between,
772            other => {
773                return Err(syn::Error::new(
774                    op_name.span(),
775                    format!(
776                        "Unknown operator '${other}'. Valid operators: $eq, $ne, $gt, $gte, $lt, $lte, $in, $nin, $like, $ilike, $regex, $startsWith, $endsWith, $contains, $between"
777                    ),
778                ));
779            },
780        };
781
782        let value = parse_sql_value(&op_content)?;
783        Ok(SqlFilter { field, op, value })
784    } else {
785        let value = parse_sql_value(input)?;
786        Ok(SqlFilter {
787            field,
788            op: SqlOperator::Eq,
789            value,
790        })
791    }
792}
793
794fn parse_sql_value(input: ParseStream) -> Result<SqlValue> {
795    let lookahead = input.lookahead1();
796
797    if lookahead.peek(token::Bracket) {
798        let content;
799        bracketed!(content in input);
800        let elements: Punctuated<SqlValue, Token![,]> =
801            content.parse_terminated(|inner| parse_sql_value(inner), Token![,])?;
802        Ok(SqlValue::Array(elements.into_iter().collect()))
803    } else if lookahead.peek(LitStr) {
804        Ok(SqlValue::String(input.parse()?))
805    } else if lookahead.peek(LitInt) {
806        Ok(SqlValue::Int(input.parse()?))
807    } else if lookahead.peek(LitFloat) {
808        Ok(SqlValue::Float(input.parse()?))
809    } else if lookahead.peek(LitBool) {
810        let lit: LitBool = input.parse()?;
811        Ok(SqlValue::Bool(lit.value))
812    } else if input.peek(syn::Ident) && input.peek2(token::Paren) {
813        let fork = input.fork();
814        let ident: syn::Ident = fork.parse()?;
815        match ident.to_string().as_str() {
816            "int" => {
817                input.parse::<syn::Ident>()?;
818                let content;
819                syn::parenthesized!(content in input);
820                Ok(SqlValue::IntHint(content.parse()?))
821            },
822            "str" => {
823                input.parse::<syn::Ident>()?;
824                let content;
825                syn::parenthesized!(content in input);
826                Ok(SqlValue::StrHint(content.parse()?))
827            },
828            "float" => {
829                input.parse::<syn::Ident>()?;
830                let content;
831                syn::parenthesized!(content in input);
832                Ok(SqlValue::FloatHint(content.parse()?))
833            },
834            "bool" => {
835                input.parse::<syn::Ident>()?;
836                let content;
837                syn::parenthesized!(content in input);
838                Ok(SqlValue::BoolHint(content.parse()?))
839            },
840            _ => Ok(SqlValue::Expr(input.parse()?)),
841        }
842    } else if input.peek(syn::Ident) {
843        let fork = input.fork();
844        let ident: syn::Ident = fork.parse()?;
845        match ident.to_string().as_str() {
846            "null" => {
847                input.parse::<syn::Ident>()?;
848                Ok(SqlValue::Null)
849            },
850            "true" => {
851                input.parse::<syn::Ident>()?;
852                Ok(SqlValue::Bool(true))
853            },
854            "false" => {
855                input.parse::<syn::Ident>()?;
856                Ok(SqlValue::Bool(false))
857            },
858            _ => Ok(SqlValue::Expr(input.parse()?)),
859        }
860    } else {
861        Err(syn::Error::new(
862            input.span(),
863            "Expected a value: string, number, boolean, null, array, or type hint (int(), str(), etc.)",
864        ))
865    }
866}
867
868impl Parse for SqlSort {
869    fn parse(input: ParseStream) -> Result<Self> {
870        let desc = if input.peek(Token![-]) {
871            input.parse::<Token![-]>()?;
872            true
873        } else {
874            false
875        };
876        let field: syn::Ident = input.parse()?;
877        Ok(SqlSort { field, desc })
878    }
879}
880
881fn sql_value_to_tokens(value: &SqlValue) -> TokenStream2 {
882    match value {
883        SqlValue::Null => quote! { ::mik_sql::Value::Null },
884        SqlValue::Bool(b) => quote! { ::mik_sql::Value::Bool(#b) },
885        SqlValue::Int(i) => quote! { ::mik_sql::Value::Int(#i as i64) },
886        SqlValue::Float(f) => quote! { ::mik_sql::Value::Float(#f as f64) },
887        SqlValue::String(s) => quote! { ::mik_sql::Value::String(#s.to_string()) },
888        SqlValue::Array(arr) => {
889            let elements: Vec<_> = arr.iter().map(sql_value_to_tokens).collect();
890            quote! { ::mik_sql::Value::Array(vec![#(#elements),*]) }
891        },
892        SqlValue::IntHint(e) => quote! { ::mik_sql::Value::Int(#e as i64) },
893        SqlValue::StrHint(e) | SqlValue::Expr(e) => {
894            quote! { ::mik_sql::Value::String((#e).to_string()) }
895        },
896        SqlValue::FloatHint(e) => quote! { ::mik_sql::Value::Float(#e as f64) },
897        SqlValue::BoolHint(e) => quote! { ::mik_sql::Value::Bool(#e) },
898    }
899}
900
901fn sql_operator_to_tokens(op: &SqlOperator) -> TokenStream2 {
902    match op {
903        SqlOperator::Eq => quote! { ::mik_sql::Operator::Eq },
904        SqlOperator::Ne => quote! { ::mik_sql::Operator::Ne },
905        SqlOperator::Gt => quote! { ::mik_sql::Operator::Gt },
906        SqlOperator::Gte => quote! { ::mik_sql::Operator::Gte },
907        SqlOperator::Lt => quote! { ::mik_sql::Operator::Lt },
908        SqlOperator::Lte => quote! { ::mik_sql::Operator::Lte },
909        SqlOperator::In => quote! { ::mik_sql::Operator::In },
910        SqlOperator::NotIn => quote! { ::mik_sql::Operator::NotIn },
911        SqlOperator::Like => quote! { ::mik_sql::Operator::Like },
912        SqlOperator::ILike => quote! { ::mik_sql::Operator::ILike },
913        SqlOperator::Regex => quote! { ::mik_sql::Operator::Regex },
914        SqlOperator::StartsWith => quote! { ::mik_sql::Operator::StartsWith },
915        SqlOperator::EndsWith => quote! { ::mik_sql::Operator::EndsWith },
916        SqlOperator::Contains => quote! { ::mik_sql::Operator::Contains },
917        SqlOperator::Between => quote! { ::mik_sql::Operator::Between },
918    }
919}
920
921/// Build a SELECT query using the query builder (CRUD: Read).
922#[proc_macro]
923#[allow(clippy::too_many_lines)] // Query building has many options to handle
924pub fn sql_read(input: TokenStream) -> TokenStream {
925    let SqlInput {
926        dialect,
927        table,
928        select_fields,
929        computed,
930        aggregates,
931        filter_expr,
932        group_by,
933        having,
934        sorts,
935        dynamic_sort,
936        allow_sort,
937        merge_filters,
938        allow_fields,
939        deny_ops,
940        max_depth,
941        page,
942        limit,
943        offset,
944        after,
945        before,
946    } = parse_macro_input!(input as SqlInput);
947
948    let (sorts, dynamic_sort) = if let Some(ref expr) = dynamic_sort {
949        if allow_sort.is_empty() {
950            if let syn::Expr::Path(syn::ExprPath { path, .. }) = expr {
951                if path.segments.len() == 1 && path.segments[0].arguments.is_empty() {
952                    let field_name = path.segments[0].ident.clone();
953                    let mut new_sorts = sorts;
954                    new_sorts.push(SqlSort {
955                        field: field_name,
956                        desc: false,
957                    });
958                    (new_sorts, None)
959                } else {
960                    (sorts, dynamic_sort)
961                }
962            } else {
963                (sorts, dynamic_sort)
964            }
965        } else {
966            (sorts, dynamic_sort)
967        }
968    } else {
969        (sorts, dynamic_sort)
970    };
971
972    let table_str = table.to_string();
973
974    let fields_chain = if select_fields.is_empty() {
975        quote! {}
976    } else {
977        let field_strs: Vec<String> = select_fields
978            .iter()
979            .map(std::string::ToString::to_string)
980            .collect();
981        quote! { .fields(&[#(#field_strs),*]) }
982    };
983
984    let computed_chain: Vec<TokenStream2> = computed
985        .iter()
986        .map(|c| {
987            let alias = c.alias.to_string();
988            let expr_sql = compute_expr_to_sql(&c.expr);
989            quote! { .computed(#alias, #expr_sql) }
990        })
991        .collect();
992
993    let aggregate_chain: Vec<TokenStream2> = aggregates
994        .iter()
995        .map(|agg| {
996            let agg_tokens = sql_aggregate_to_tokens(agg);
997            quote! { .aggregate(#agg_tokens) }
998        })
999        .collect();
1000
1001    let filter_chain = if let Some(expr) = filter_expr {
1002        let expr_tokens = sql_filter_expr_to_tokens(&expr);
1003        quote! { .filter_expr(#expr_tokens) }
1004    } else {
1005        quote! {}
1006    };
1007
1008    let group_by_chain = if group_by.is_empty() {
1009        quote! {}
1010    } else {
1011        let field_strs: Vec<String> = group_by
1012            .iter()
1013            .map(std::string::ToString::to_string)
1014            .collect();
1015        quote! { .group_by(&[#(#field_strs),*]) }
1016    };
1017
1018    let having_chain = if let Some(expr) = having {
1019        let expr_tokens = sql_filter_expr_to_tokens(&expr);
1020        quote! { .having(#expr_tokens) }
1021    } else {
1022        quote! {}
1023    };
1024
1025    let sort_chain: Vec<TokenStream2> = sorts
1026        .iter()
1027        .map(|s| {
1028            let field_str = s.field.to_string();
1029            let dir = if s.desc {
1030                quote! { ::mik_sql::SortDir::Desc }
1031            } else {
1032                quote! { ::mik_sql::SortDir::Asc }
1033            };
1034            quote! { .sort(#field_str, #dir) }
1035        })
1036        .collect();
1037
1038    let dynamic_sort_setup = if let Some(ref sort_expr) = dynamic_sort {
1039        let allow_strs: Vec<String> = allow_sort
1040            .iter()
1041            .map(std::string::ToString::to_string)
1042            .collect();
1043        if allow_strs.is_empty() {
1044            quote! {
1045                let __dynamic_sorts = ::mik_sql::SortField::parse_sort_string(
1046                    &#sort_expr,
1047                    &[]
1048                ).map_err(|e| e)?;
1049            }
1050        } else {
1051            quote! {
1052                let __dynamic_sorts = ::mik_sql::SortField::parse_sort_string(
1053                    &#sort_expr,
1054                    &[#(#allow_strs),*]
1055                ).map_err(|e| e)?;
1056            }
1057        }
1058    } else {
1059        quote! {}
1060    };
1061
1062    let dynamic_sort_chain = if dynamic_sort.is_some() {
1063        quote! { .sorts(&__dynamic_sorts) }
1064    } else {
1065        quote! {}
1066    };
1067
1068    let (merge_setup, merge_chain) = if let Some(ref merge_expr) = merge_filters {
1069        let allow_strs: Vec<String> = allow_fields
1070            .iter()
1071            .map(std::string::ToString::to_string)
1072            .collect();
1073        let deny_op_tokens: Vec<TokenStream2> = deny_ops
1074            .iter()
1075            .map(|op| {
1076                let op_str = op.to_string();
1077                match op_str.as_str() {
1078                    "ne" => quote! { ::mik_sql::Operator::Ne },
1079                    "gt" => quote! { ::mik_sql::Operator::Gt },
1080                    "gte" => quote! { ::mik_sql::Operator::Gte },
1081                    "lt" => quote! { ::mik_sql::Operator::Lt },
1082                    "lte" => quote! { ::mik_sql::Operator::Lte },
1083                    "in" => quote! { ::mik_sql::Operator::In },
1084                    "nin" | "notIn" => quote! { ::mik_sql::Operator::NotIn },
1085                    "like" => quote! { ::mik_sql::Operator::Like },
1086                    "ilike" => quote! { ::mik_sql::Operator::ILike },
1087                    "regex" => quote! { ::mik_sql::Operator::Regex },
1088                    "startsWith" | "starts_with" => quote! { ::mik_sql::Operator::StartsWith },
1089                    "endsWith" | "ends_with" => quote! { ::mik_sql::Operator::EndsWith },
1090                    "contains" => quote! { ::mik_sql::Operator::Contains },
1091                    "between" => quote! { ::mik_sql::Operator::Between },
1092                    // "eq" and unknown operators default to Eq
1093                    _ => quote! { ::mik_sql::Operator::Eq },
1094                }
1095            })
1096            .collect();
1097
1098        let max_depth_val = max_depth
1099            .map(|d| quote! { #d as usize })
1100            .unwrap_or(quote! { 5 });
1101
1102        let setup = quote! {
1103            let __validator = ::mik_sql::FilterValidator::new()
1104                .allow_fields(&[#(#allow_strs),*])
1105                .deny_operators(&[#(#deny_op_tokens),*])
1106                .max_depth(#max_depth_val);
1107
1108            for __user_filter in &#merge_expr {
1109                __validator.validate(__user_filter).map_err(|e| e.to_string())?;
1110            }
1111        };
1112
1113        let chain = quote! {
1114            for __f in &#merge_expr {
1115                __builder = __builder.filter(__f.field.clone(), __f.op, __f.value.clone());
1116            }
1117        };
1118
1119        (setup, chain)
1120    } else {
1121        (quote! {}, quote! {})
1122    };
1123
1124    let needs_result = dynamic_sort.is_some() || merge_filters.is_some();
1125
1126    let pagination_chain = match (page, limit, offset) {
1127        (Some(p), Some(l), None) => quote! { .page(#p as u32, #l as u32) },
1128        (None, Some(l), Some(o)) => quote! { .limit_offset(#l as u32, #o as u32) },
1129        (None, Some(l), None) => quote! { .limit_offset(#l as u32, 0) },
1130        _ => quote! {},
1131    };
1132
1133    let after_chain = if let Some(ref expr) = after {
1134        quote! { .after_cursor(#expr) }
1135    } else {
1136        quote! {}
1137    };
1138
1139    let before_chain = if let Some(ref expr) = before {
1140        quote! { .before_cursor(#expr) }
1141    } else {
1142        quote! {}
1143    };
1144
1145    let builder_constructor = dialect.builder_tokens(&table_str);
1146
1147    let tokens = if needs_result {
1148        quote! {
1149            (|| -> ::std::result::Result<(String, Vec<::mik_sql::Value>), String> {
1150                #dynamic_sort_setup
1151                #merge_setup
1152
1153                let mut __builder = #builder_constructor
1154                    #fields_chain
1155                    #(#computed_chain)*
1156                    #(#aggregate_chain)*
1157                    #filter_chain;
1158
1159                #merge_chain
1160
1161                let __sql_result = __builder
1162                    #group_by_chain
1163                    #having_chain
1164                    #(#sort_chain)*
1165                    #dynamic_sort_chain
1166                    #after_chain
1167                    #before_chain
1168                    #pagination_chain
1169                    .build();
1170
1171                Ok((__sql_result.sql, __sql_result.params))
1172            })()
1173        }
1174    } else {
1175        quote! {
1176            {
1177                let __sql_result = #builder_constructor
1178                    #fields_chain
1179                    #(#computed_chain)*
1180                    #(#aggregate_chain)*
1181                    #filter_chain
1182                    #group_by_chain
1183                    #having_chain
1184                    #(#sort_chain)*
1185                    #after_chain
1186                    #before_chain
1187                    #pagination_chain
1188                    .build();
1189                (__sql_result.sql, __sql_result.params)
1190            }
1191        }
1192    };
1193
1194    TokenStream::from(tokens)
1195}
1196
1197fn sql_aggregate_to_tokens(agg: &SqlAggregate) -> TokenStream2 {
1198    let field_str = agg.field.as_ref().map(std::string::ToString::to_string);
1199    let alias_str = agg.alias.as_ref().map(std::string::ToString::to_string);
1200
1201    let base = match (&agg.func, &field_str) {
1202        (SqlAggregateFunc::Count, Some(f)) => quote! { ::mik_sql::Aggregate::count_field(#f) },
1203        (SqlAggregateFunc::CountDistinct, Some(f)) => {
1204            quote! { ::mik_sql::Aggregate::count_distinct(#f) }
1205        },
1206        (SqlAggregateFunc::Sum, Some(f)) => quote! { ::mik_sql::Aggregate::sum(#f) },
1207        (SqlAggregateFunc::Avg, Some(f)) => quote! { ::mik_sql::Aggregate::avg(#f) },
1208        (SqlAggregateFunc::Min, Some(f)) => quote! { ::mik_sql::Aggregate::min(#f) },
1209        (SqlAggregateFunc::Max, Some(f)) => quote! { ::mik_sql::Aggregate::max(#f) },
1210        // Count without field, or missing required field - default to count()
1211        _ => quote! { ::mik_sql::Aggregate::count() },
1212    };
1213
1214    if let Some(alias) = alias_str {
1215        quote! { #base.as_alias(#alias) }
1216    } else {
1217        base
1218    }
1219}
1220
1221fn sql_filter_expr_to_tokens(expr: &SqlFilterExpr) -> TokenStream2 {
1222    match expr {
1223        SqlFilterExpr::Simple(filter) => {
1224            let field_str = filter.field.to_string();
1225            let op = sql_operator_to_tokens(&filter.op);
1226            let value = sql_value_to_tokens(&filter.value);
1227            quote! { ::mik_sql::simple(#field_str, #op, #value) }
1228        },
1229        SqlFilterExpr::Compound { op, filters } => {
1230            let filter_tokens: Vec<TokenStream2> =
1231                filters.iter().map(sql_filter_expr_to_tokens).collect();
1232
1233            match op {
1234                SqlLogicalOp::And => quote! { ::mik_sql::and(vec![#(#filter_tokens),*]) },
1235                SqlLogicalOp::Or => quote! { ::mik_sql::or(vec![#(#filter_tokens),*]) },
1236                SqlLogicalOp::Not => {
1237                    let inner = filter_tokens.into_iter().next().unwrap_or_default();
1238                    quote! { ::mik_sql::not(#inner) }
1239                },
1240            }
1241        },
1242    }
1243}
1244
1245fn compute_expr_to_sql(expr: &SqlComputeExpr) -> String {
1246    match expr {
1247        SqlComputeExpr::Column(ident) => ident.to_string(),
1248        SqlComputeExpr::LitStr(lit) => {
1249            let s = lit.value();
1250            format!("'{}'", s.replace('\'', "''"))
1251        },
1252        SqlComputeExpr::LitInt(lit) => lit.to_string(),
1253        SqlComputeExpr::LitFloat(lit) => lit.to_string(),
1254        SqlComputeExpr::BinOp { left, op, right } => {
1255            let left_sql = compute_expr_to_sql(left);
1256            let right_sql = compute_expr_to_sql(right);
1257            let op_str = match op {
1258                SqlComputeBinOp::Add => "+",
1259                SqlComputeBinOp::Sub => "-",
1260                SqlComputeBinOp::Mul => "*",
1261                SqlComputeBinOp::Div => "/",
1262            };
1263            format!("{left_sql} {op_str} {right_sql}")
1264        },
1265        SqlComputeExpr::Func { name, args } => {
1266            let args_sql: Vec<String> = args.iter().map(compute_expr_to_sql).collect();
1267            match name {
1268                SqlComputeFunc::Concat => args_sql.join(" || "),
1269                SqlComputeFunc::Coalesce => format!("COALESCE({})", args_sql.join(", ")),
1270                SqlComputeFunc::Upper => format!("UPPER({})", args_sql.join(", ")),
1271                SqlComputeFunc::Lower => format!("LOWER({})", args_sql.join(", ")),
1272                SqlComputeFunc::Round => format!("ROUND({})", args_sql.join(", ")),
1273                SqlComputeFunc::Abs => format!("ABS({})", args_sql.join(", ")),
1274                SqlComputeFunc::Length => format!("LENGTH({})", args_sql.join(", ")),
1275            }
1276        },
1277        SqlComputeExpr::Paren(inner) => format!("({})", compute_expr_to_sql(inner)),
1278    }
1279}
1280
1281/// Collect field values from a list for batched loading.
1282#[proc_macro]
1283pub fn ids(input: TokenStream) -> TokenStream {
1284    let input = parse_macro_input!(input as IdsInput);
1285
1286    let list = &input.list;
1287    let field = &input.field;
1288
1289    let tokens = quote! {
1290        #list.iter().map(|__item| __item.#field.clone()).collect::<Vec<_>>()
1291    };
1292
1293    TokenStream::from(tokens)
1294}
1295
1296struct IdsInput {
1297    list: Expr,
1298    field: syn::Ident,
1299}
1300
1301impl Parse for IdsInput {
1302    fn parse(input: ParseStream) -> Result<Self> {
1303        let list: Expr = input.parse()?;
1304
1305        let field = if input.peek(Token![,]) {
1306            input.parse::<Token![,]>()?;
1307            input.parse()?
1308        } else {
1309            syn::Ident::new("id", proc_macro2::Span::call_site())
1310        };
1311
1312        Ok(IdsInput { list, field })
1313    }
1314}
1315
1316/// Build an INSERT query using object-like syntax.
1317#[proc_macro]
1318pub fn sql_create(input: TokenStream) -> TokenStream {
1319    let InsertInput {
1320        dialect,
1321        table,
1322        columns,
1323        returning,
1324    } = parse_macro_input!(input as InsertInput);
1325
1326    let table_str = table.to_string();
1327    let builder_constructor = dialect.insert_tokens(&table_str);
1328
1329    let col_strs: Vec<String> = columns.iter().map(|(c, _)| c.to_string()).collect();
1330
1331    let value_tokens: Vec<TokenStream2> = columns
1332        .iter()
1333        .map(|(_, v)| sql_value_to_tokens(v))
1334        .collect();
1335
1336    let returning_chain = if returning.is_empty() {
1337        quote! {}
1338    } else {
1339        let ret_strs: Vec<String> = returning
1340            .iter()
1341            .map(std::string::ToString::to_string)
1342            .collect();
1343        quote! { .returning(&[#(#ret_strs),*]) }
1344    };
1345
1346    let tokens = quote! {
1347        {
1348            let __result = #builder_constructor
1349                .columns(&[#(#col_strs),*])
1350                .values(vec![#(#value_tokens),*])
1351                #returning_chain
1352                .build();
1353            (__result.sql, __result.params)
1354        }
1355    };
1356
1357    TokenStream::from(tokens)
1358}
1359
1360struct InsertInput {
1361    dialect: SqlDialect,
1362    table: syn::Ident,
1363    columns: Vec<(syn::Ident, SqlValue)>,
1364    returning: Vec<syn::Ident>,
1365}
1366
1367impl Parse for InsertInput {
1368    fn parse(input: ParseStream) -> Result<Self> {
1369        let dialect = parse_optional_dialect(input)?;
1370        let table: syn::Ident = input.parse()?;
1371
1372        let content;
1373        braced!(content in input);
1374
1375        let mut columns = Vec::new();
1376        let mut returning = Vec::new();
1377
1378        while !content.is_empty() {
1379            let key: syn::Ident = content.parse()?;
1380            content.parse::<Token![:]>()?;
1381
1382            if key.to_string().as_str() == "returning" {
1383                let ret_content;
1384                bracketed!(ret_content in content);
1385                let fields: Punctuated<syn::Ident, Token![,]> =
1386                    ret_content.parse_terminated(syn::Ident::parse, Token![,])?;
1387                returning = fields.into_iter().collect();
1388            } else {
1389                let value = parse_sql_value(&content)?;
1390                columns.push((key, value));
1391            }
1392
1393            if content.peek(Token![,]) {
1394                content.parse::<Token![,]>()?;
1395            }
1396        }
1397
1398        Ok(InsertInput {
1399            dialect,
1400            table,
1401            columns,
1402            returning,
1403        })
1404    }
1405}
1406
1407/// Build an UPDATE query using object-like syntax.
1408#[proc_macro]
1409pub fn sql_update(input: TokenStream) -> TokenStream {
1410    let UpdateInput {
1411        dialect,
1412        table,
1413        sets,
1414        where_expr,
1415        returning,
1416    } = parse_macro_input!(input as UpdateInput);
1417
1418    let table_str = table.to_string();
1419    let builder_constructor = dialect.update_tokens(&table_str);
1420
1421    let set_chain: Vec<TokenStream2> = sets
1422        .iter()
1423        .map(|(col, val)| {
1424            let col_str = col.to_string();
1425            let val_tokens = sql_value_to_tokens(val);
1426            quote! { .set(#col_str, #val_tokens) }
1427        })
1428        .collect();
1429
1430    let filter_chain = if let Some(expr) = where_expr {
1431        let expr_tokens = sql_filter_expr_to_tokens(&expr);
1432        quote! { .filter_expr(#expr_tokens) }
1433    } else {
1434        quote! {}
1435    };
1436
1437    let returning_chain = if returning.is_empty() {
1438        quote! {}
1439    } else {
1440        let ret_strs: Vec<String> = returning
1441            .iter()
1442            .map(std::string::ToString::to_string)
1443            .collect();
1444        quote! { .returning(&[#(#ret_strs),*]) }
1445    };
1446
1447    let tokens = quote! {
1448        {
1449            let __result = #builder_constructor
1450                #(#set_chain)*
1451                #filter_chain
1452                #returning_chain
1453                .build();
1454            (__result.sql, __result.params)
1455        }
1456    };
1457
1458    TokenStream::from(tokens)
1459}
1460
1461struct UpdateInput {
1462    dialect: SqlDialect,
1463    table: syn::Ident,
1464    sets: Vec<(syn::Ident, SqlValue)>,
1465    where_expr: Option<SqlFilterExpr>,
1466    returning: Vec<syn::Ident>,
1467}
1468
1469impl Parse for UpdateInput {
1470    fn parse(input: ParseStream) -> Result<Self> {
1471        let dialect = parse_optional_dialect(input)?;
1472        let table: syn::Ident = input.parse()?;
1473
1474        let content;
1475        braced!(content in input);
1476
1477        let mut sets = Vec::new();
1478        let mut where_expr = None;
1479        let mut returning = Vec::new();
1480
1481        while !content.is_empty() {
1482            let key: syn::Ident = content.parse()?;
1483            content.parse::<Token![:]>()?;
1484
1485            match key.to_string().as_str() {
1486                "set" => {
1487                    let set_content;
1488                    braced!(set_content in content);
1489                    sets = parse_column_values(&set_content)?;
1490                },
1491                "where" | "filter" => {
1492                    let where_content;
1493                    braced!(where_content in content);
1494                    where_expr = Some(parse_filter_block(&where_content)?);
1495                },
1496                "returning" => {
1497                    let ret_content;
1498                    bracketed!(ret_content in content);
1499                    let fields: Punctuated<syn::Ident, Token![,]> =
1500                        ret_content.parse_terminated(syn::Ident::parse, Token![,])?;
1501                    returning = fields.into_iter().collect();
1502                },
1503                _ => {
1504                    return Err(syn::Error::new(
1505                        key.span(),
1506                        format!("Unknown option '{key}'. Expected 'set', 'where', or 'returning'"),
1507                    ));
1508                },
1509            }
1510
1511            if content.peek(Token![,]) {
1512                content.parse::<Token![,]>()?;
1513            }
1514        }
1515
1516        Ok(UpdateInput {
1517            dialect,
1518            table,
1519            sets,
1520            where_expr,
1521            returning,
1522        })
1523    }
1524}
1525
1526/// Build a DELETE query using object-like syntax.
1527#[proc_macro]
1528pub fn sql_delete(input: TokenStream) -> TokenStream {
1529    let DeleteInput {
1530        dialect,
1531        table,
1532        where_expr,
1533        returning,
1534    } = parse_macro_input!(input as DeleteInput);
1535
1536    let table_str = table.to_string();
1537    let builder_constructor = dialect.delete_tokens(&table_str);
1538
1539    let filter_chain = if let Some(expr) = where_expr {
1540        let expr_tokens = sql_filter_expr_to_tokens(&expr);
1541        quote! { .filter_expr(#expr_tokens) }
1542    } else {
1543        quote! {}
1544    };
1545
1546    let returning_chain = if returning.is_empty() {
1547        quote! {}
1548    } else {
1549        let ret_strs: Vec<String> = returning
1550            .iter()
1551            .map(std::string::ToString::to_string)
1552            .collect();
1553        quote! { .returning(&[#(#ret_strs),*]) }
1554    };
1555
1556    let tokens = quote! {
1557        {
1558            let __result = #builder_constructor
1559                #filter_chain
1560                #returning_chain
1561                .build();
1562            (__result.sql, __result.params)
1563        }
1564    };
1565
1566    TokenStream::from(tokens)
1567}
1568
1569struct DeleteInput {
1570    dialect: SqlDialect,
1571    table: syn::Ident,
1572    where_expr: Option<SqlFilterExpr>,
1573    returning: Vec<syn::Ident>,
1574}
1575
1576impl Parse for DeleteInput {
1577    fn parse(input: ParseStream) -> Result<Self> {
1578        let dialect = parse_optional_dialect(input)?;
1579        let table: syn::Ident = input.parse()?;
1580
1581        let content;
1582        braced!(content in input);
1583
1584        let mut where_expr = None;
1585        let mut returning = Vec::new();
1586
1587        while !content.is_empty() {
1588            let key: syn::Ident = content.parse()?;
1589            content.parse::<Token![:]>()?;
1590
1591            match key.to_string().as_str() {
1592                "where" | "filter" => {
1593                    let where_content;
1594                    braced!(where_content in content);
1595                    where_expr = Some(parse_filter_block(&where_content)?);
1596                },
1597                "returning" => {
1598                    let ret_content;
1599                    bracketed!(ret_content in content);
1600                    let fields: Punctuated<syn::Ident, Token![,]> =
1601                        ret_content.parse_terminated(syn::Ident::parse, Token![,])?;
1602                    returning = fields.into_iter().collect();
1603                },
1604                _ => {
1605                    return Err(syn::Error::new(
1606                        key.span(),
1607                        format!("Unknown option '{key}'. Expected 'where' or 'returning'"),
1608                    ));
1609                },
1610            }
1611
1612            if content.peek(Token![,]) {
1613                content.parse::<Token![,]>()?;
1614            }
1615        }
1616
1617        Ok(DeleteInput {
1618            dialect,
1619            table,
1620            where_expr,
1621            returning,
1622        })
1623    }
1624}
1625
1626fn parse_column_values(input: ParseStream) -> Result<Vec<(syn::Ident, SqlValue)>> {
1627    let mut result = Vec::new();
1628
1629    while !input.is_empty() {
1630        let key: syn::Ident = input.parse()?;
1631        input.parse::<Token![:]>()?;
1632        let value = parse_sql_value(input)?;
1633        result.push((key, value));
1634
1635        if input.peek(Token![,]) {
1636            input.parse::<Token![,]>()?;
1637        }
1638    }
1639
1640    Ok(result)
1641}