1use super::{
2 column_to_column_ref, placeholder_to_placeholder_expr, scalar_value_to_literal_value,
3 PlannerError, PlannerResult,
4};
5use datafusion::logical_expr::{
6 expr::{Alias, Cast, Placeholder},
7 BinaryExpr, Expr, Operator,
8};
9use indexmap::IndexSet;
10use proof_of_sql::{
11 base::database::ColumnType,
12 sql::{proof_exprs::DynProofExpr, scale_cast_binary_op},
13};
14use sqlparser::ast::Ident;
15
16pub(crate) fn get_column_idents_from_expr(expr: &Expr) -> IndexSet<Ident> {
18 match expr {
19 Expr::Column(col) => {
20 let mut set = IndexSet::new();
21 set.insert(col.name.as_str().into());
22 set
23 }
24 Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
25 let mut left_idents = get_column_idents_from_expr(left);
26 left_idents.extend(get_column_idents_from_expr(right));
27 left_idents
28 }
29 Expr::Not(inner) => get_column_idents_from_expr(inner),
30 Expr::Alias(Alias { expr, .. }) | Expr::Cast(Cast { expr, .. }) => {
31 get_column_idents_from_expr(expr)
32 }
33 Expr::AggregateFunction(agg) => agg
34 .args
35 .iter()
36 .flat_map(get_column_idents_from_expr)
37 .collect(),
38 _ => IndexSet::new(),
39 }
40}
41
42#[expect(
44 clippy::missing_panics_doc,
45 reason = "Output of comparisons is always boolean"
46)]
47fn binary_expr_to_proof_expr(
48 left: &Expr,
49 right: &Expr,
50 op: Operator,
51 schema: &[(Ident, ColumnType)],
52) -> PlannerResult<DynProofExpr> {
53 let left_proof_expr = expr_to_proof_expr(left, schema)?;
54 let right_proof_expr = expr_to_proof_expr(right, schema)?;
55
56 let (left_proof_expr, right_proof_expr) = match op {
57 Operator::Eq
58 | Operator::NotEq
59 | Operator::Lt
60 | Operator::Gt
61 | Operator::LtEq
62 | Operator::GtEq
63 | Operator::Plus
64 | Operator::Minus => scale_cast_binary_op(left_proof_expr, right_proof_expr)?,
65 _ => (left_proof_expr, right_proof_expr),
66 };
67
68 match op {
69 Operator::And => Ok(DynProofExpr::try_new_and(
70 left_proof_expr,
71 right_proof_expr,
72 )?),
73 Operator::Or => Ok(DynProofExpr::try_new_or(left_proof_expr, right_proof_expr)?),
74 Operator::Multiply => Ok(DynProofExpr::try_new_multiply(
75 left_proof_expr,
76 right_proof_expr,
77 )?),
78 Operator::Eq => Ok(DynProofExpr::try_new_equals(
79 left_proof_expr,
80 right_proof_expr,
81 )?),
82 Operator::NotEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_equals(
83 left_proof_expr,
84 right_proof_expr,
85 )?)
86 .expect("An equality expression must have a boolean data type...")),
87 Operator::Lt => Ok(DynProofExpr::try_new_inequality(
88 left_proof_expr,
89 right_proof_expr,
90 true,
91 )?),
92 Operator::Gt => Ok(DynProofExpr::try_new_inequality(
93 left_proof_expr,
94 right_proof_expr,
95 false,
96 )?),
97 Operator::LtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
98 left_proof_expr,
99 right_proof_expr,
100 false,
101 )?)
102 .expect("An inequality expression must have a boolean data type...")),
103 Operator::GtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
104 left_proof_expr,
105 right_proof_expr,
106 true,
107 )?)
108 .expect("An inequality expression must have a boolean data type...")),
109 Operator::Plus => Ok(DynProofExpr::try_new_add(
110 left_proof_expr,
111 right_proof_expr,
112 )?),
113 Operator::Minus => Ok(DynProofExpr::try_new_subtract(
114 left_proof_expr,
115 right_proof_expr,
116 )?),
117 _ => Err(PlannerError::UnsupportedBinaryOperator { op }),
119 }
120}
121
122pub fn expr_to_proof_expr(
127 expr: &Expr,
128 schema: &[(Ident, ColumnType)],
129) -> PlannerResult<DynProofExpr> {
130 match expr {
131 Expr::Alias(Alias { expr, .. }) => expr_to_proof_expr(expr, schema),
132 Expr::Column(col) => Ok(DynProofExpr::new_column(column_to_column_ref(col, schema)?)),
133 Expr::Placeholder(placeholder) => placeholder_to_placeholder_expr(placeholder),
134 Expr::BinaryExpr(BinaryExpr { left, right, op }) => {
135 binary_expr_to_proof_expr(left, right, *op, schema)
136 }
137 Expr::Literal(val) => Ok(DynProofExpr::new_literal(scalar_value_to_literal_value(
138 val.clone(),
139 )?)),
140 Expr::Not(expr) => {
141 let proof_expr = expr_to_proof_expr(expr, schema)?;
142 Ok(DynProofExpr::try_new_not(proof_expr)?)
143 }
144 Expr::Cast(cast) => {
145 match &*cast.expr {
146 Expr::Placeholder(placeholder) if placeholder.data_type.is_none() => {
148 let typed_placeholder =
149 Placeholder::new(placeholder.id.clone(), Some(cast.data_type.clone()));
150 placeholder_to_placeholder_expr(&typed_placeholder)
151 }
152 _ => {
153 let from_expr = expr_to_proof_expr(&cast.expr, schema)?;
154 let to_type = cast.data_type.clone().try_into().map_err(|_| {
155 PlannerError::UnsupportedDataType {
156 data_type: cast.data_type.clone(),
157 }
158 })?;
159 Ok(
160 DynProofExpr::try_new_cast(from_expr.clone(), to_type).map_or_else(
161 |_| DynProofExpr::try_new_scaling_cast(from_expr, to_type),
162 Ok,
163 )?,
164 )
165 }
166 }
167 }
168 _ => Err(PlannerError::UnsupportedLogicalExpression {
169 expr: Box::new(expr.clone()),
170 }),
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::df_util::*;
    use arrow::datatypes::DataType;
    use core::ops::{Add, Mul, Sub};
    use datafusion::{
        catalog::TableReference,
        common::{Column, ScalarValue},
        logical_expr::{expr::Placeholder, Cast},
    };
    use proof_of_sql::base::{
        database::{ColumnRef, ColumnType, LiteralValue, TableRef},
        math::decimal::Precision,
    };

    // ---------------------------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------------------------

    /// DataFusion column `namespace.table_name.<name>`.
    fn plan_column(name: &str) -> Expr {
        df_column("namespace.table_name", name)
    }

    /// Proof-side column `namespace.table_name.<name>` of type `column_type`.
    fn proof_column(name: &str, column_type: ColumnType) -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            name.into(),
            column_type,
        ))
    }

    /// Decimal column type with the given precision and a scale of 5.
    fn decimal_scale_5(precision: u8) -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(precision).expect("Precision is definitely valid"),
            5,
        )
    }

    /// Builds a schema from `(column name, column type)` pairs.
    fn schema_of(columns: &[(&str, ColumnType)]) -> Vec<(Ident, ColumnType)> {
        columns
            .iter()
            .map(|&(name, column_type)| (Ident::from(name), column_type))
            .collect()
    }

    /// Schema shared by the scale-cast tests: a small integer and two decimals of scale 5.
    fn mixed_decimal_schema() -> Vec<(Ident, ColumnType)> {
        schema_of(&[
            ("column1", ColumnType::SmallInt),
            ("column2", decimal_scale_5(25)),
            ("column3", decimal_scale_5(75)),
        ])
    }

    /// `column1` (small integer) wrapped in a scaling cast to a decimal of precision 10 and
    /// scale 5; this is the left operand the scale-cast tests expect.
    fn scaled_smallint_column1() -> DynProofExpr {
        DynProofExpr::try_new_scaling_cast(
            proof_column("column1", ColumnType::SmallInt),
            decimal_scale_5(10),
        )
        .unwrap()
    }

    /// DataFusion binary expression `left op right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            right: Box::new(right),
            op,
        })
    }

    /// DataFusion cast of `inner` to `data_type`.
    fn cast_to(inner: Expr, data_type: DataType) -> Expr {
        Expr::Cast(Cast::new(Box::new(inner), data_type))
    }

    /// Built-in `SUM` aggregate over `args`, with no filter, ordering or null treatment.
    fn sum_over(args: Vec<Expr>) -> Expr {
        Expr::AggregateFunction(datafusion::logical_expr::expr::AggregateFunction {
            func_def: datafusion::logical_expr::expr::AggregateFunctionDefinition::BuiltIn(
                datafusion::physical_plan::aggregates::AggregateFunction::Sum,
            ),
            args,
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        })
    }

    /// Identifier set containing `names`, inserted in the given order.
    fn ident_set(names: &[&str]) -> IndexSet<Ident> {
        names.iter().map(|&name| Ident::from(name)).collect()
    }

    /// Asserts that `expr` converts, under `schema`, to exactly `expected`.
    #[track_caller]
    fn assert_converts_to(expr: &Expr, schema: &[(Ident, ColumnType)], expected: &DynProofExpr) {
        assert_eq!(expr_to_proof_expr(expr, schema).unwrap(), *expected);
    }

    // ---------------------------------------------------------------------------------------
    // expr_to_proof_expr: aliases and columns
    // ---------------------------------------------------------------------------------------

    /// An alias converts to the expression it names.
    #[test]
    fn we_can_convert_alias_to_proof_expr() {
        let aliased = plan_column("column").alias("alias");
        let schema = schema_of(&[("column", ColumnType::Int)]);
        assert_converts_to(&aliased, &schema, &proof_column("column", ColumnType::Int));
    }

    /// A column converts to a column reference typed from the schema.
    #[test]
    fn we_can_convert_column_expr_to_proof_expr() {
        let schema = schema_of(&[("column", ColumnType::Int)]);
        assert_converts_to(
            &plan_column("column"),
            &schema,
            &proof_column("column", ColumnType::Int),
        );
    }

    // ---------------------------------------------------------------------------------------
    // expr_to_proof_expr: binary expressions
    // ---------------------------------------------------------------------------------------

    /// Comparisons between integer columns convert without any scaling cast.
    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr() {
        let schema = schema_of(&[
            ("column1", ColumnType::SmallInt),
            ("column2", ColumnType::BigInt),
        ]);
        let smallint = || proof_column("column1", ColumnType::SmallInt);
        let bigint = || proof_column("column2", ColumnType::BigInt);
        let inequality = |is_lt: bool| {
            DynProofExpr::try_new_inequality(smallint(), bigint(), is_lt).unwrap()
        };

        // column1 = column2
        assert_converts_to(
            &plan_column("column1").eq(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_equals(smallint(), bigint()).unwrap(),
        );
        // column1 < column2
        assert_converts_to(
            &plan_column("column1").lt(plan_column("column2")),
            &schema,
            &inequality(true),
        );
        // column1 > column2
        assert_converts_to(
            &plan_column("column1").gt(plan_column("column2")),
            &schema,
            &inequality(false),
        );
        // column1 <= column2
        assert_converts_to(
            &plan_column("column1").lt_eq(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_not(inequality(false)).unwrap(),
        );
        // column1 >= column2
        assert_converts_to(
            &plan_column("column1").gt_eq(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_not(inequality(true)).unwrap(),
        );
    }

    /// Comparing a small integer with a decimal wraps the integer in a scaling cast.
    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = mixed_decimal_schema();
        let decimal_25 = || proof_column("column2", decimal_scale_5(25));
        let inequality = |is_lt: bool| {
            DynProofExpr::try_new_inequality(scaled_smallint_column1(), decimal_25(), is_lt)
                .unwrap()
        };

        // column1 = column3
        assert_converts_to(
            &plan_column("column1").eq(plan_column("column3")),
            &schema,
            &DynProofExpr::try_new_equals(
                scaled_smallint_column1(),
                proof_column("column3", decimal_scale_5(75)),
            )
            .unwrap(),
        );
        // column1 < column2
        assert_converts_to(
            &plan_column("column1").lt(plan_column("column2")),
            &schema,
            &inequality(true),
        );
        // column1 > column2
        assert_converts_to(
            &plan_column("column1").gt(plan_column("column2")),
            &schema,
            &inequality(false),
        );
        // column1 <= column2
        assert_converts_to(
            &plan_column("column1").lt_eq(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_not(inequality(false)).unwrap(),
        );
        // column1 >= column2
        assert_converts_to(
            &plan_column("column1").gt_eq(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_not(inequality(true)).unwrap(),
        );
    }

    /// Addition, subtraction and multiplication of integer columns convert directly.
    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr() {
        let schema = schema_of(&[
            ("column1", ColumnType::SmallInt),
            ("column2", ColumnType::BigInt),
        ]);
        let smallint = || proof_column("column1", ColumnType::SmallInt);
        let bigint = || proof_column("column2", ColumnType::BigInt);
        let plan = |op: Operator| binary(plan_column("column1"), op, plan_column("column2"));

        // column1 + column2
        assert_converts_to(
            &plan(Operator::Plus),
            &schema,
            &DynProofExpr::try_new_add(smallint(), bigint()).unwrap(),
        );
        // column1 - column2
        assert_converts_to(
            &plan(Operator::Minus),
            &schema,
            &DynProofExpr::try_new_subtract(smallint(), bigint()).unwrap(),
        );
        // column1 * column2
        assert_converts_to(
            &plan(Operator::Multiply),
            &schema,
            &DynProofExpr::try_new_multiply(smallint(), bigint()).unwrap(),
        );
    }

    /// Addition and subtraction scale the integer operand; multiplication does not.
    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = mixed_decimal_schema();
        let decimal_25 = || proof_column("column2", decimal_scale_5(25));

        // column1 + column2
        assert_converts_to(
            &plan_column("column1").add(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_add(scaled_smallint_column1(), decimal_25()).unwrap(),
        );
        // column1 - column2
        assert_converts_to(
            &plan_column("column1").sub(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_subtract(scaled_smallint_column1(), decimal_25()).unwrap(),
        );
        // column1 * column2
        assert_converts_to(
            &plan_column("column1").mul(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_multiply(
                proof_column("column1", ColumnType::SmallInt),
                decimal_25(),
            )
            .unwrap(),
        );
    }

    /// `AND` and `OR` over boolean columns convert to the matching proof expressions.
    #[test]
    fn we_can_convert_logical_binary_expr_to_proof_expr() {
        let schema = schema_of(&[
            ("column1", ColumnType::Boolean),
            ("column2", ColumnType::Boolean),
        ]);
        let first = || proof_column("column1", ColumnType::Boolean);
        let second = || proof_column("column2", ColumnType::Boolean);

        // column1 AND column2
        assert_converts_to(
            &plan_column("column1").and(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_and(first(), second()).unwrap(),
        );
        // column1 OR column2
        assert_converts_to(
            &plan_column("column1").or(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_or(first(), second()).unwrap(),
        );
    }

    /// `!=` converts to the negation of an equality.
    #[test]
    fn we_can_convert_logical_not_eq_to_proof_expr() {
        let schema = schema_of(&[
            ("column1", ColumnType::BigInt),
            ("column2", ColumnType::BigInt),
        ]);
        let equality = DynProofExpr::try_new_equals(
            proof_column("column1", ColumnType::BigInt),
            proof_column("column2", ColumnType::BigInt),
        )
        .unwrap();

        assert_converts_to(
            &plan_column("column1").not_eq(plan_column("column2")),
            &schema,
            &DynProofExpr::try_new_not(equality).unwrap(),
        );
    }

    /// An operator without a proof counterpart is rejected.
    #[test]
    fn we_cannot_convert_unsupported_binary_expr_to_proof_expr() {
        let expr = binary(
            plan_column("column1"),
            Operator::AtArrow,
            plan_column("column2"),
        );
        let schema = schema_of(&[
            ("column1", ColumnType::Boolean),
            ("column2", ColumnType::Boolean),
        ]);
        assert!(matches!(
            expr_to_proof_expr(&expr, &schema),
            Err(PlannerError::UnsupportedBinaryOperator { .. })
        ));
    }

    // ---------------------------------------------------------------------------------------
    // expr_to_proof_expr: literals and NOT
    // ---------------------------------------------------------------------------------------

    /// A 32-bit integer literal converts to an `Int` literal.
    #[test]
    fn we_can_convert_literal_expr_to_proof_expr() {
        assert_converts_to(
            &Expr::Literal(ScalarValue::Int32(Some(1))),
            &[],
            &DynProofExpr::new_literal(LiteralValue::Int(1)),
        );
    }

    /// `NOT column` converts to a negated column reference.
    #[test]
    fn we_can_convert_not_expr_to_proof_expr() {
        let expr = Expr::Not(Box::new(df_column("table_name", "column")));
        let schema = schema_of(&[("column", ColumnType::Boolean)]);
        let unqualified_column = DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(None, "table_name"),
            "column".into(),
            ColumnType::Boolean,
        ));
        assert_converts_to(
            &expr,
            &schema,
            &DynProofExpr::try_new_not(unqualified_column).unwrap(),
        );
    }

    // ---------------------------------------------------------------------------------------
    // expr_to_proof_expr: casts
    // ---------------------------------------------------------------------------------------

    /// Casting a boolean literal to a 32-bit integer yields a plain cast.
    #[test]
    fn we_can_convert_cast_expr_to_proof_expr() {
        let expr = cast_to(
            Expr::Literal(ScalarValue::Boolean(Some(true))),
            DataType::Int32,
        );
        let expected = DynProofExpr::try_new_cast(
            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
            ColumnType::Int,
        )
        .unwrap();
        assert_converts_to(&expr, &[], &expected);
    }

    /// A cast fails when the expression being cast cannot itself be converted.
    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_when_inner_expr_to_proof_expr_fails() {
        let expr = cast_to(
            Expr::Literal(ScalarValue::UInt64(Some(100))),
            DataType::Int16,
        );
        assert!(matches!(
            expr_to_proof_expr(&expr, &[]),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    /// A cast fails when the target data type is not supported.
    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_unsupported_datatypes() {
        let expr = cast_to(
            Expr::Literal(ScalarValue::Boolean(Some(true))),
            DataType::UInt16,
        );
        assert!(matches!(
            expr_to_proof_expr(&expr, &[]),
            Err(PlannerError::UnsupportedDataType { .. })
        ));
    }

    /// A cast between supported types fails when that particular cast is not allowed.
    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_datatypes_for_which_casting_is_not_supported()
    {
        let expr = cast_to(
            Expr::Literal(ScalarValue::Int16(Some(100))),
            DataType::Boolean,
        );
        assert!(matches!(
            expr_to_proof_expr(&expr, &[]),
            Err(PlannerError::AnalyzeError { .. })
        ));
    }

    // ---------------------------------------------------------------------------------------
    // expr_to_proof_expr: placeholders and unsupported expressions
    // ---------------------------------------------------------------------------------------

    /// A typed placeholder converts to a placeholder of the matching column type.
    #[test]
    fn we_can_convert_placeholder_to_proof_expr() {
        let expr = Expr::Placeholder(Placeholder::new(
            "$1".to_string(),
            Some(DataType::Int32),
        ));
        assert_converts_to(
            &expr,
            &[],
            &DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap(),
        );
    }

    /// An untyped placeholder takes its type from the cast wrapped around it.
    #[test]
    fn we_can_convert_placeholder_with_data_type_specified_by_cast_to_proof_expr() {
        let untyped = Expr::Placeholder(Placeholder::new("$1".to_string(), None));
        assert_converts_to(
            &cast_to(untyped, DataType::Int32),
            &[],
            &DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap(),
        );
    }

    /// Expression kinds without a conversion are rejected.
    #[test]
    fn we_cannot_convert_unsupported_expr_to_proof_expr() {
        let expr = Expr::OuterReferenceColumn(
            DataType::Int32,
            Column::new(None::<TableReference>, "column"),
        );
        assert!(matches!(
            expr_to_proof_expr(&expr, &[]),
            Err(PlannerError::UnsupportedLogicalExpression { .. })
        ));
    }

    /// Timestamps with different time units can be compared.
    #[test]
    fn we_can_get_proof_expr_for_timestamps_of_different_scale() {
        let seconds = Expr::Literal(ScalarValue::TimestampSecond(Some(1), None));
        let nanoseconds = Expr::Literal(ScalarValue::TimestampNanosecond(Some(1), None));
        binary_expr_to_proof_expr(&seconds, &nanoseconds, Operator::Gt, &[]).unwrap();
    }

    // ---------------------------------------------------------------------------------------
    // get_column_idents_from_expr
    // ---------------------------------------------------------------------------------------

    /// A bare column yields its own name.
    #[test]
    fn we_can_extract_single_column_ident() {
        let found = get_column_idents_from_expr(&df_column("table", "column_a"));
        assert_eq!(found, ident_set(&["column_a"]));
    }

    /// Both operands of a binary expression are visited.
    #[test]
    fn we_can_extract_column_idents_from_binary_expr() {
        let sum = df_column("table", "a").add(df_column("table", "b"));
        assert_eq!(get_column_idents_from_expr(&sum), ident_set(&["a", "b"]));
    }

    /// Nested binary expressions are visited recursively.
    #[test]
    fn we_can_extract_column_idents_from_nested_binary_expr() {
        let product = df_column("table", "a")
            .add(df_column("table", "b"))
            .mul(df_column("table", "c"));
        assert_eq!(
            get_column_idents_from_expr(&product),
            ident_set(&["a", "b", "c"])
        );
    }

    /// The operand of `NOT` is visited.
    #[test]
    fn we_can_extract_column_idents_from_not_expr() {
        let negated = Expr::Not(Box::new(df_column("table", "bool_col")));
        assert_eq!(
            get_column_idents_from_expr(&negated),
            ident_set(&["bool_col"])
        );
    }

    /// An alias yields the aliased column's name, not the alias.
    #[test]
    fn we_can_extract_column_idents_from_alias_expr() {
        let aliased = df_column("table", "col_x").alias("alias_name");
        assert_eq!(get_column_idents_from_expr(&aliased), ident_set(&["col_x"]));
    }

    /// The operand of a cast is visited.
    #[test]
    fn we_can_extract_column_idents_from_cast_expr() {
        let cast = cast_to(df_column("table", "num_col"), DataType::Int64);
        assert_eq!(get_column_idents_from_expr(&cast), ident_set(&["num_col"]));
    }

    /// The argument of an aggregate function is visited.
    #[test]
    fn we_can_extract_column_idents_from_aggregate_function() {
        let aggregate = sum_over(vec![df_column("table", "value")]);
        assert_eq!(
            get_column_idents_from_expr(&aggregate),
            ident_set(&["value"])
        );
    }

    /// Every argument of an aggregate function is visited.
    #[test]
    fn we_can_extract_column_idents_from_aggregate_function_with_multiple_args() {
        let aggregate = sum_over(vec![
            df_column("table", "col1"),
            df_column("table", "col2"),
            df_column("table", "col3"),
        ]);
        assert_eq!(
            get_column_idents_from_expr(&aggregate),
            ident_set(&["col1", "col2", "col3"])
        );
    }

    /// A literal references no columns.
    #[test]
    fn we_can_extract_no_column_idents_from_literal() {
        let literal = Expr::Literal(ScalarValue::Int32(Some(42)));
        assert!(get_column_idents_from_expr(&literal).is_empty());
    }

    /// `NOT ((a > b) AND (c < d))` yields all four columns.
    #[test]
    fn we_can_extract_column_idents_from_complex_nested_expr() {
        let conjunction = df_column("table", "a")
            .gt(df_column("table", "b"))
            .and(df_column("table", "c").lt(df_column("table", "d")));
        let negated = Expr::Not(Box::new(conjunction));
        assert_eq!(
            get_column_idents_from_expr(&negated),
            ident_set(&["a", "b", "c", "d"])
        );
    }

    /// Identifiers come back in first-encounter order rather than sorted.
    #[test]
    fn we_can_extract_column_idents_preserving_order() {
        let sum = df_column("table", "z")
            .add(df_column("table", "a"))
            .add(df_column("table", "m"));
        let ordered: Vec<Ident> = get_column_idents_from_expr(&sum).into_iter().collect();
        let expected: Vec<Ident> = ["z", "a", "m"]
            .iter()
            .map(|&name| Ident::from(name))
            .collect();
        assert_eq!(ordered, expected);
    }

    /// A column referenced twice is reported once.
    #[test]
    fn we_can_handle_duplicate_column_references() {
        let doubled = df_column("table", "a").add(df_column("table", "a"));
        assert_eq!(get_column_idents_from_expr(&doubled), ident_set(&["a"]));
    }

    /// Comparisons against literals contribute only their column operands.
    #[test]
    fn we_can_extract_columns_from_comparison_operations() {
        let is_active =
            df_column("table", "active").eq(Expr::Literal(ScalarValue::Boolean(Some(true))));
        let filter = df_column("table", "price")
            .gt(df_column("table", "threshold"))
            .and(is_active);
        assert_eq!(
            get_column_idents_from_expr(&filter),
            ident_set(&["price", "threshold", "active"])
        );
    }
}