Skip to main content

proof_of_sql_planner/
expr.rs

1use super::{
2    column_to_column_ref, placeholder_to_placeholder_expr, scalar_value_to_literal_value,
3    PlannerError, PlannerResult,
4};
5use datafusion::logical_expr::{
6    expr::{Alias, Cast, Placeholder},
7    BinaryExpr, Expr, Operator,
8};
9use indexmap::IndexSet;
10use proof_of_sql::{
11    base::database::ColumnType,
12    sql::{proof_exprs::DynProofExpr, scale_cast_binary_op},
13};
14use sqlparser::ast::Ident;
15
16/// Recursively extract all column identifiers referenced in an expression
17pub(crate) fn get_column_idents_from_expr(expr: &Expr) -> IndexSet<Ident> {
18    match expr {
19        Expr::Column(col) => {
20            let mut set = IndexSet::new();
21            set.insert(col.name.as_str().into());
22            set
23        }
24        Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
25            let mut left_idents = get_column_idents_from_expr(left);
26            left_idents.extend(get_column_idents_from_expr(right));
27            left_idents
28        }
29        Expr::Not(inner) => get_column_idents_from_expr(inner),
30        Expr::Alias(Alias { expr, .. }) | Expr::Cast(Cast { expr, .. }) => {
31            get_column_idents_from_expr(expr)
32        }
33        Expr::AggregateFunction(agg) => agg
34            .args
35            .iter()
36            .flat_map(get_column_idents_from_expr)
37            .collect(),
38        _ => IndexSet::new(),
39    }
40}
41
42/// Convert a [`BinaryExpr`] to [`DynProofExpr`]
43#[expect(
44    clippy::missing_panics_doc,
45    reason = "Output of comparisons is always boolean"
46)]
47fn binary_expr_to_proof_expr(
48    left: &Expr,
49    right: &Expr,
50    op: Operator,
51    schema: &[(Ident, ColumnType)],
52) -> PlannerResult<DynProofExpr> {
53    let left_proof_expr = expr_to_proof_expr(left, schema)?;
54    let right_proof_expr = expr_to_proof_expr(right, schema)?;
55
56    let (left_proof_expr, right_proof_expr) = match op {
57        Operator::Eq
58        | Operator::NotEq
59        | Operator::Lt
60        | Operator::Gt
61        | Operator::LtEq
62        | Operator::GtEq
63        | Operator::Plus
64        | Operator::Minus => scale_cast_binary_op(left_proof_expr, right_proof_expr)?,
65        _ => (left_proof_expr, right_proof_expr),
66    };
67
68    match op {
69        Operator::And => Ok(DynProofExpr::try_new_and(
70            left_proof_expr,
71            right_proof_expr,
72        )?),
73        Operator::Or => Ok(DynProofExpr::try_new_or(left_proof_expr, right_proof_expr)?),
74        Operator::Multiply => Ok(DynProofExpr::try_new_multiply(
75            left_proof_expr,
76            right_proof_expr,
77        )?),
78        Operator::Eq => Ok(DynProofExpr::try_new_equals(
79            left_proof_expr,
80            right_proof_expr,
81        )?),
82        Operator::NotEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_equals(
83            left_proof_expr,
84            right_proof_expr,
85        )?)
86        .expect("An equality expression must have a boolean data type...")),
87        Operator::Lt => Ok(DynProofExpr::try_new_inequality(
88            left_proof_expr,
89            right_proof_expr,
90            true,
91        )?),
92        Operator::Gt => Ok(DynProofExpr::try_new_inequality(
93            left_proof_expr,
94            right_proof_expr,
95            false,
96        )?),
97        Operator::LtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
98            left_proof_expr,
99            right_proof_expr,
100            false,
101        )?)
102        .expect("An inequality expression must have a boolean data type...")),
103        Operator::GtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
104            left_proof_expr,
105            right_proof_expr,
106            true,
107        )?)
108        .expect("An inequality expression must have a boolean data type...")),
109        Operator::Plus => Ok(DynProofExpr::try_new_add(
110            left_proof_expr,
111            right_proof_expr,
112        )?),
113        Operator::Minus => Ok(DynProofExpr::try_new_subtract(
114            left_proof_expr,
115            right_proof_expr,
116        )?),
117        // Any other operator is unsupported
118        _ => Err(PlannerError::UnsupportedBinaryOperator { op }),
119    }
120}
121
122/// Convert an [`datafusion::expr::Expr`] to [`DynProofExpr`]
123///
124/// # Panics
125/// The function should not panic if Proof of SQL is working correctly
126pub fn expr_to_proof_expr(
127    expr: &Expr,
128    schema: &[(Ident, ColumnType)],
129) -> PlannerResult<DynProofExpr> {
130    match expr {
131        Expr::Alias(Alias { expr, .. }) => expr_to_proof_expr(expr, schema),
132        Expr::Column(col) => Ok(DynProofExpr::new_column(column_to_column_ref(col, schema)?)),
133        Expr::Placeholder(placeholder) => placeholder_to_placeholder_expr(placeholder),
134        Expr::BinaryExpr(BinaryExpr { left, right, op }) => {
135            binary_expr_to_proof_expr(left, right, *op, schema)
136        }
137        Expr::Literal(val) => Ok(DynProofExpr::new_literal(scalar_value_to_literal_value(
138            val.clone(),
139        )?)),
140        Expr::Not(expr) => {
141            let proof_expr = expr_to_proof_expr(expr, schema)?;
142            Ok(DynProofExpr::try_new_not(proof_expr)?)
143        }
144        Expr::Cast(cast) => {
145            match &*cast.expr {
146                // handle cases such as `$1::int`
147                Expr::Placeholder(placeholder) if placeholder.data_type.is_none() => {
148                    let typed_placeholder =
149                        Placeholder::new(placeholder.id.clone(), Some(cast.data_type.clone()));
150                    placeholder_to_placeholder_expr(&typed_placeholder)
151                }
152                _ => {
153                    let from_expr = expr_to_proof_expr(&cast.expr, schema)?;
154                    let to_type = cast.data_type.clone().try_into().map_err(|_| {
155                        PlannerError::UnsupportedDataType {
156                            data_type: cast.data_type.clone(),
157                        }
158                    })?;
159                    Ok(
160                        DynProofExpr::try_new_cast(from_expr.clone(), to_type).map_or_else(
161                            |_| DynProofExpr::try_new_scaling_cast(from_expr, to_type),
162                            Ok,
163                        )?,
164                    )
165                }
166            }
167        }
168        _ => Err(PlannerError::UnsupportedLogicalExpression {
169            expr: Box::new(expr.clone()),
170        }),
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::df_util::*;
    use arrow::datatypes::DataType;
    use core::ops::{Add, Mul, Sub};
    use datafusion::{
        catalog::TableReference,
        common::{Column, ScalarValue},
        logical_expr::{expr::Placeholder, Cast},
    };
    use proof_of_sql::base::{
        database::{ColumnRef, ColumnType, LiteralValue, TableRef},
        math::decimal::Precision,
    };

    /// Qualified name of the table that most fixtures refer to.
    const TABLE: &str = "namespace.table_name";

    /// Logical-plan column `name` of [`TABLE`].
    fn df(name: &str) -> Expr {
        df_column(TABLE, name)
    }

    /// Proof column `name` of `namespace.table_name` with the given type.
    fn proof_column(name: &str, column_type: ColumnType) -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            name.into(),
            column_type,
        ))
    }

    /// Decimal column type with the given precision and a scale of 5.
    fn decimal_with_scale_5(precision: u8) -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(precision).expect("Precision is definitely valid"),
            5,
        )
    }

    /// The smallint `column1` scaled to `decimal(10, 5)`, i.e. the left operand
    /// expected whenever `column1` meets a decimal with scale 5.
    fn scaled_column1() -> DynProofExpr {
        DynProofExpr::try_new_scaling_cast(
            proof_column("column1", ColumnType::SmallInt),
            decimal_with_scale_5(10),
        )
        .unwrap()
    }

    /// `SUM` aggregate over `args` without `DISTINCT`, filter or ordering.
    fn sum_of(args: Vec<Expr>) -> Expr {
        Expr::AggregateFunction(datafusion::logical_expr::expr::AggregateFunction {
            func_def: datafusion::logical_expr::expr::AggregateFunctionDefinition::BuiltIn(
                datafusion::physical_plan::aggregates::AggregateFunction::Sum,
            ),
            args,
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        })
    }

    /// Identifier set containing `names` in the given order.
    fn ident_set(names: &[&str]) -> IndexSet<Ident> {
        names.iter().map(|&name| Ident::from(name)).collect()
    }

    // Alias
    #[test]
    fn we_can_convert_alias_to_proof_expr() {
        let expr = df("column").alias("alias");
        let schema = vec![("column".into(), ColumnType::Int)];
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            proof_column("column", ColumnType::Int)
        );
    }

    // Column
    #[test]
    fn we_can_convert_column_expr_to_proof_expr() {
        let expr = df("column");
        let schema = vec![("column".into(), ColumnType::Int)];
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            proof_column("column", ColumnType::Int)
        );
    }

    // BinaryExpr
    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), ColumnType::BigInt),
        ];
        let column1 = || proof_column("column1", ColumnType::SmallInt);
        let column2 = || proof_column("column2", ColumnType::BigInt);

        // Eq
        let expr = df("column1").eq(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_equals(column1(), column2()).unwrap()
        );

        // Lt
        let expr = df("column1").lt(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(column1(), column2(), true).unwrap()
        );

        // Gt
        let expr = df("column1").gt(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(column1(), column2(), false).unwrap()
        );

        // LtEq
        let expr = df("column1").lt_eq(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(column1(), column2(), false).unwrap()
            )
            .unwrap()
        );

        // GtEq
        let expr = df("column1").gt_eq(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(column1(), column2(), true).unwrap()
            )
            .unwrap()
        );
    }

    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), decimal_with_scale_5(25)),
            ("column3".into(), decimal_with_scale_5(75)),
        ];
        let column2 = || proof_column("column2", decimal_with_scale_5(25));
        let column3 = || proof_column("column3", decimal_with_scale_5(75));

        // Eq
        let expr = df("column1").eq(df("column3"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_equals(scaled_column1(), column3()).unwrap()
        );

        // Lt
        let expr = df("column1").lt(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(scaled_column1(), column2(), true).unwrap()
        );

        // Gt
        let expr = df("column1").gt(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(scaled_column1(), column2(), false).unwrap()
        );

        // LtEq
        let expr = df("column1").lt_eq(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(scaled_column1(), column2(), false).unwrap()
            )
            .unwrap()
        );

        // GtEq
        let expr = df("column1").gt_eq(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(scaled_column1(), column2(), true).unwrap()
            )
            .unwrap()
        );
    }

    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), ColumnType::BigInt),
        ];
        let column1 = || proof_column("column1", ColumnType::SmallInt);
        let column2 = || proof_column("column2", ColumnType::BigInt);
        // `column1 <op> column2` built directly from the struct
        let binary = |op| {
            Expr::BinaryExpr(BinaryExpr {
                left: Box::new(df("column1")),
                right: Box::new(df("column2")),
                op,
            })
        };

        // Plus
        assert_eq!(
            expr_to_proof_expr(&binary(Operator::Plus), &schema).unwrap(),
            DynProofExpr::try_new_add(column1(), column2()).unwrap()
        );

        // Minus
        assert_eq!(
            expr_to_proof_expr(&binary(Operator::Minus), &schema).unwrap(),
            DynProofExpr::try_new_subtract(column1(), column2()).unwrap()
        );

        // Multiply
        assert_eq!(
            expr_to_proof_expr(&binary(Operator::Multiply), &schema).unwrap(),
            DynProofExpr::try_new_multiply(column1(), column2()).unwrap()
        );
    }

    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), decimal_with_scale_5(25)),
            ("column3".into(), decimal_with_scale_5(75)),
        ];
        let column2 = || proof_column("column2", decimal_with_scale_5(25));

        // Add
        let expr = df("column1").add(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_add(scaled_column1(), column2()).unwrap()
        );

        // Subtract
        let expr = df("column1").sub(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_subtract(scaled_column1(), column2()).unwrap()
        );

        // Multiply - No scale cast!
        let expr = df("column1").mul(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_multiply(
                proof_column("column1", ColumnType::SmallInt),
                column2()
            )
            .unwrap()
        );
    }

    #[test]
    fn we_can_convert_logical_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::Boolean),
            ("column2".into(), ColumnType::Boolean),
        ];
        let column1 = || proof_column("column1", ColumnType::Boolean);
        let column2 = || proof_column("column2", ColumnType::Boolean);

        // And
        let expr = df("column1").and(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_and(column1(), column2()).unwrap()
        );

        // Or
        let expr = df("column1").or(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_or(column1(), column2()).unwrap()
        );
    }

    #[test]
    fn we_can_convert_logical_not_eq_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::BigInt),
            ("column2".into(), ColumnType::BigInt),
        ];

        let expr = df("column1").not_eq(df("column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_equals(
                    proof_column("column1", ColumnType::BigInt),
                    proof_column("column2", ColumnType::BigInt)
                )
                .unwrap()
            )
            .unwrap()
        );
    }

    #[test]
    fn we_cannot_convert_unsupported_binary_expr_to_proof_expr() {
        // Unsupported binary operator
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(df("column1")),
            right: Box::new(df("column2")),
            op: Operator::AtArrow,
        });
        let schema = vec![
            ("column1".into(), ColumnType::Boolean),
            ("column2".into(), ColumnType::Boolean),
        ];
        assert!(matches!(
            expr_to_proof_expr(&expr, &schema),
            Err(PlannerError::UnsupportedBinaryOperator { .. })
        ));
    }

    // Literal
    #[test]
    fn we_can_convert_literal_expr_to_proof_expr() {
        let expr = Expr::Literal(ScalarValue::Int32(Some(1)));
        assert_eq!(
            expr_to_proof_expr(&expr, &[]).unwrap(),
            DynProofExpr::new_literal(LiteralValue::Int(1))
        );
    }

    // Not
    #[test]
    fn we_can_convert_not_expr_to_proof_expr() {
        // This table deliberately has no namespace.
        let expr = Expr::Not(Box::new(df_column("table_name", "column")));
        let schema = vec![("column".into(), ColumnType::Boolean)];
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(DynProofExpr::new_column(ColumnRef::new(
                TableRef::from_names(None, "table_name"),
                "column".into(),
                ColumnType::Boolean
            )))
            .unwrap()
        );
    }

    // Cast
    #[test]
    fn we_can_convert_cast_expr_to_proof_expr() {
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::Boolean(Some(true)))),
            DataType::Int32,
        ));
        assert_eq!(
            expr_to_proof_expr(&expr, &[]).unwrap(),
            DynProofExpr::try_new_cast(
                DynProofExpr::new_literal(LiteralValue::Boolean(true)),
                ColumnType::Int
            )
            .unwrap()
        );
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_when_inner_expr_to_proof_expr_fails() {
        // The inner literal cannot be converted
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::UInt64(Some(100)))),
            DataType::Int16,
        ));
        let error = expr_to_proof_expr(&expr, &[]).unwrap_err();
        assert!(matches!(
            error,
            PlannerError::UnsupportedDataType { data_type: _ }
        ));
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_unsupported_datatypes() {
        // The target type of the cast is not supported
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::Boolean(Some(true)))),
            DataType::UInt16,
        ));
        let error = expr_to_proof_expr(&expr, &[]).unwrap_err();
        assert!(matches!(
            error,
            PlannerError::UnsupportedDataType { data_type: _ }
        ));
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_datatypes_for_which_casting_is_not_supported()
    {
        // Both types are supported, but the cast between them is not
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::Int16(Some(100)))),
            DataType::Boolean,
        ));
        let error = expr_to_proof_expr(&expr, &[]).unwrap_err();
        assert!(matches!(error, PlannerError::AnalyzeError { source: _ }));
    }

    // Placeholder
    #[test]
    fn we_can_convert_placeholder_to_proof_expr() {
        let expr = Expr::Placeholder(Placeholder {
            id: "$1".to_string(),
            data_type: Some(DataType::Int32),
        });
        assert_eq!(
            expr_to_proof_expr(&expr, &[]).unwrap(),
            DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap()
        );
    }

    // Placeholder with data type specified by cast
    #[test]
    fn we_can_convert_placeholder_with_data_type_specified_by_cast_to_proof_expr() {
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Placeholder(Placeholder {
                id: "$1".to_string(),
                data_type: None,
            })),
            DataType::Int32,
        ));
        assert_eq!(
            expr_to_proof_expr(&expr, &[]).unwrap(),
            DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap()
        );
    }

    // Unsupported logical expression
    #[test]
    fn we_cannot_convert_unsupported_expr_to_proof_expr() {
        let expr = Expr::OuterReferenceColumn(
            DataType::Int32,
            Column::new(None::<TableReference>, "column"),
        );
        assert!(matches!(
            expr_to_proof_expr(&expr, &[]),
            Err(PlannerError::UnsupportedLogicalExpression { .. })
        ));
    }

    #[test]
    fn we_can_get_proof_expr_for_timestamps_of_different_scale() {
        let lhs = Expr::Literal(ScalarValue::TimestampSecond(Some(1), None));
        let rhs = Expr::Literal(ScalarValue::TimestampNanosecond(Some(1), None));
        binary_expr_to_proof_expr(&lhs, &rhs, Operator::Gt, &[]).unwrap();
    }

    // get_column_idents_from_expr tests
    #[test]
    fn we_can_extract_single_column_ident() {
        let expr = df_column("table", "column_a");
        assert_eq!(
            get_column_idents_from_expr(&expr),
            ident_set(&["column_a"])
        );
    }

    #[test]
    fn we_can_extract_column_idents_from_binary_expr() {
        let expr = df_column("table", "a").add(df_column("table", "b"));
        assert_eq!(get_column_idents_from_expr(&expr), ident_set(&["a", "b"]));
    }

    #[test]
    fn we_can_extract_column_idents_from_nested_binary_expr() {
        // (a + b) * c
        let expr = df_column("table", "a")
            .add(df_column("table", "b"))
            .mul(df_column("table", "c"));
        assert_eq!(
            get_column_idents_from_expr(&expr),
            ident_set(&["a", "b", "c"])
        );
    }

    #[test]
    fn we_can_extract_column_idents_from_not_expr() {
        let expr = Expr::Not(Box::new(df_column("table", "bool_col")));
        assert_eq!(
            get_column_idents_from_expr(&expr),
            ident_set(&["bool_col"])
        );
    }

    #[test]
    fn we_can_extract_column_idents_from_alias_expr() {
        let expr = df_column("table", "col_x").alias("alias_name");
        assert_eq!(get_column_idents_from_expr(&expr), ident_set(&["col_x"]));
    }

    #[test]
    fn we_can_extract_column_idents_from_cast_expr() {
        let expr = Expr::Cast(Cast::new(
            Box::new(df_column("table", "num_col")),
            DataType::Int64,
        ));
        assert_eq!(
            get_column_idents_from_expr(&expr),
            ident_set(&["num_col"])
        );
    }

    #[test]
    fn we_can_extract_column_idents_from_aggregate_function() {
        let expr = sum_of(vec![df_column("table", "value")]);
        assert_eq!(get_column_idents_from_expr(&expr), ident_set(&["value"]));
    }

    #[test]
    fn we_can_extract_column_idents_from_aggregate_function_with_multiple_args() {
        let expr = sum_of(vec![
            df_column("table", "col1"),
            df_column("table", "col2"),
            df_column("table", "col3"),
        ]);
        assert_eq!(
            get_column_idents_from_expr(&expr),
            ident_set(&["col1", "col2", "col3"])
        );
    }

    #[test]
    fn we_can_extract_no_column_idents_from_literal() {
        let expr = Expr::Literal(ScalarValue::Int32(Some(42)));
        assert!(get_column_idents_from_expr(&expr).is_empty());
    }

    #[test]
    fn we_can_extract_column_idents_from_complex_nested_expr() {
        // NOT (a > b AND c < d)
        let inner = df_column("table", "a")
            .gt(df_column("table", "b"))
            .and(df_column("table", "c").lt(df_column("table", "d")));
        let expr = Expr::Not(Box::new(inner));
        assert_eq!(
            get_column_idents_from_expr(&expr),
            ident_set(&["a", "b", "c", "d"])
        );
    }

    #[test]
    fn we_can_extract_column_idents_preserving_order() {
        // IndexSet should preserve insertion order
        let expr = df_column("table", "z")
            .add(df_column("table", "a"))
            .add(df_column("table", "m"));
        let idents: Vec<Ident> = get_column_idents_from_expr(&expr).into_iter().collect();
        assert_eq!(
            idents,
            vec![Ident::from("z"), Ident::from("a"), Ident::from("m")]
        );
    }

    #[test]
    fn we_can_handle_duplicate_column_references() {
        // a + a should only have 'a' once
        let expr = df_column("table", "a").add(df_column("table", "a"));
        assert_eq!(get_column_idents_from_expr(&expr), ident_set(&["a"]));
    }

    #[test]
    fn we_can_extract_columns_from_comparison_operations() {
        let expr = df_column("table", "price")
            .gt(df_column("table", "threshold"))
            .and(df_column("table", "active").eq(Expr::Literal(ScalarValue::Boolean(Some(true)))));
        assert_eq!(
            get_column_idents_from_expr(&expr),
            ident_set(&["price", "threshold", "active"])
        );
    }
}