1use super::{
2 column_to_column_ref, placeholder_to_placeholder_expr, scalar_value_to_literal_value,
3 PlannerError, PlannerResult,
4};
5use datafusion::logical_expr::{
6 expr::{Alias, Placeholder},
7 BinaryExpr, Expr, Operator,
8};
9use proof_of_sql::{
10 base::database::ColumnType,
11 sql::{proof_exprs::DynProofExpr, scale_cast_binary_op},
12};
13use sqlparser::ast::Ident;
14
15#[expect(
17 clippy::missing_panics_doc,
18 reason = "Output of comparisons is always boolean"
19)]
20fn binary_expr_to_proof_expr(
21 left: &Expr,
22 right: &Expr,
23 op: Operator,
24 schema: &[(Ident, ColumnType)],
25) -> PlannerResult<DynProofExpr> {
26 let left_proof_expr = expr_to_proof_expr(left, schema)?;
27 let right_proof_expr = expr_to_proof_expr(right, schema)?;
28
29 let (left_proof_expr, right_proof_expr) = match op {
30 Operator::Eq
31 | Operator::Lt
32 | Operator::Gt
33 | Operator::LtEq
34 | Operator::GtEq
35 | Operator::Plus
36 | Operator::Minus => scale_cast_binary_op(left_proof_expr, right_proof_expr)?,
37 _ => (left_proof_expr, right_proof_expr),
38 };
39
40 match op {
41 Operator::And => Ok(DynProofExpr::try_new_and(
42 left_proof_expr,
43 right_proof_expr,
44 )?),
45 Operator::Or => Ok(DynProofExpr::try_new_or(left_proof_expr, right_proof_expr)?),
46 Operator::Multiply => Ok(DynProofExpr::try_new_multiply(
47 left_proof_expr,
48 right_proof_expr,
49 )?),
50 Operator::Eq => Ok(DynProofExpr::try_new_equals(
51 left_proof_expr,
52 right_proof_expr,
53 )?),
54 Operator::Lt => Ok(DynProofExpr::try_new_inequality(
55 left_proof_expr,
56 right_proof_expr,
57 true,
58 )?),
59 Operator::Gt => Ok(DynProofExpr::try_new_inequality(
60 left_proof_expr,
61 right_proof_expr,
62 false,
63 )?),
64 Operator::LtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
65 left_proof_expr,
66 right_proof_expr,
67 false,
68 )?)
69 .expect("An inequality expression must have a boolean data type...")),
70 Operator::GtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
71 left_proof_expr,
72 right_proof_expr,
73 true,
74 )?)
75 .expect("An inequality expression must have a boolean data type...")),
76 Operator::Plus => Ok(DynProofExpr::try_new_add(
77 left_proof_expr,
78 right_proof_expr,
79 )?),
80 Operator::Minus => Ok(DynProofExpr::try_new_subtract(
81 left_proof_expr,
82 right_proof_expr,
83 )?),
84 _ => Err(PlannerError::UnsupportedBinaryOperator { op }),
86 }
87}
88
89pub fn expr_to_proof_expr(
94 expr: &Expr,
95 schema: &[(Ident, ColumnType)],
96) -> PlannerResult<DynProofExpr> {
97 match expr {
98 Expr::Alias(Alias { expr, .. }) => expr_to_proof_expr(expr, schema),
99 Expr::Column(col) => Ok(DynProofExpr::new_column(column_to_column_ref(col, schema)?)),
100 Expr::Placeholder(placeholder) => placeholder_to_placeholder_expr(placeholder),
101 Expr::BinaryExpr(BinaryExpr { left, right, op }) => {
102 binary_expr_to_proof_expr(left, right, *op, schema)
103 }
104 Expr::Literal(val) => Ok(DynProofExpr::new_literal(scalar_value_to_literal_value(
105 val.clone(),
106 )?)),
107 Expr::Not(expr) => {
108 let proof_expr = expr_to_proof_expr(expr, schema)?;
109 Ok(DynProofExpr::try_new_not(proof_expr)?)
110 }
111 Expr::Cast(cast) => {
112 match &*cast.expr {
113 Expr::Placeholder(placeholder) if placeholder.data_type.is_none() => {
115 let typed_placeholder =
116 Placeholder::new(placeholder.id.clone(), Some(cast.data_type.clone()));
117 placeholder_to_placeholder_expr(&typed_placeholder)
118 }
119 _ => {
120 let from_expr = expr_to_proof_expr(&cast.expr, schema)?;
121 let to_type = cast.data_type.clone().try_into().map_err(|_| {
122 PlannerError::UnsupportedDataType {
123 data_type: cast.data_type.clone(),
124 }
125 })?;
126 Ok(
127 DynProofExpr::try_new_cast(from_expr.clone(), to_type).map_or_else(
128 |_| DynProofExpr::try_new_scaling_cast(from_expr, to_type),
129 Ok,
130 )?,
131 )
132 }
133 }
134 }
135 _ => Err(PlannerError::UnsupportedLogicalExpression { expr: expr.clone() }),
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::df_util::*;
    use arrow::datatypes::DataType;
    use core::ops::{Add, Mul, Sub};
    use datafusion::{
        common::ScalarValue,
        logical_expr::{
            expr::{Placeholder, Unnest},
            Cast,
        },
    };
    use proof_of_sql::base::{
        database::{ColumnRef, ColumnType, LiteralValue, TableRef},
        math::decimal::Precision,
    };

    // ---------------------------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------------------------

    /// DataFusion column expression `namespace.table_name.<name>`.
    fn df_col(name: &str) -> Expr {
        df_column("namespace.table_name", name)
    }

    /// Proof column expression `namespace.table_name.<name>` with the given type.
    fn proof_col(name: &str, column_type: ColumnType) -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            name.into(),
            column_type,
        ))
    }

    fn decimal_10_5() -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(10).expect("Precision is definitely valid"),
            5,
        )
    }

    fn decimal_25_5() -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(25).expect("Precision is definitely valid"),
            5,
        )
    }

    fn decimal_75_5() -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(75).expect("Precision is definitely valid"),
            5,
        )
    }

    fn column_int() -> DynProofExpr {
        proof_col("column", ColumnType::Int)
    }

    fn column1_smallint() -> DynProofExpr {
        proof_col("column1", ColumnType::SmallInt)
    }

    fn column2_bigint() -> DynProofExpr {
        proof_col("column2", ColumnType::BigInt)
    }

    fn column1_boolean() -> DynProofExpr {
        proof_col("column1", ColumnType::Boolean)
    }

    fn column2_boolean() -> DynProofExpr {
        proof_col("column2", ColumnType::Boolean)
    }

    fn column3_decimal_75_5() -> DynProofExpr {
        proof_col("column3", decimal_75_5())
    }

    fn column2_decimal_25_5() -> DynProofExpr {
        proof_col("column2", decimal_25_5())
    }

    /// `column1` (SmallInt) wrapped in a scaling cast to a precision-10, scale-5 decimal.
    fn column1_scaled_to_decimal_10_5() -> DynProofExpr {
        DynProofExpr::try_new_scaling_cast(column1_smallint(), decimal_10_5()).unwrap()
    }

    /// Build a schema from `(column name, column type)` pairs.
    fn schema_of(columns: &[(&str, ColumnType)]) -> Vec<(Ident, ColumnType)> {
        columns
            .iter()
            .map(|&(name, column_type)| (name.into(), column_type))
            .collect()
    }

    /// Schema with `column1: SmallInt` and `column2: BigInt`.
    fn integer_schema() -> Vec<(Ident, ColumnType)> {
        schema_of(&[
            ("column1", ColumnType::SmallInt),
            ("column2", ColumnType::BigInt),
        ])
    }

    /// Schema with `column1: SmallInt`, `column2: Decimal(25, 5)`, `column3: Decimal(75, 5)`.
    fn decimal_schema() -> Vec<(Ident, ColumnType)> {
        schema_of(&[
            ("column1", ColumnType::SmallInt),
            ("column2", decimal_25_5()),
            ("column3", decimal_75_5()),
        ])
    }

    /// Schema with `column1` and `column2` both boolean.
    fn boolean_schema() -> Vec<(Ident, ColumnType)> {
        schema_of(&[
            ("column1", ColumnType::Boolean),
            ("column2", ColumnType::Boolean),
        ])
    }

    /// DataFusion expression `column1 <op> column2`.
    fn column1_op_column2(op: Operator) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(df_col("column1")),
            right: Box::new(df_col("column2")),
            op,
        })
    }

    /// DataFusion expression `CAST(inner AS target)`.
    fn cast_expr(inner: Expr, target: DataType) -> Expr {
        Expr::Cast(Cast::new(Box::new(inner), target))
    }

    /// DataFusion placeholder `$1` with an optional data type.
    fn placeholder_one(data_type: Option<DataType>) -> Expr {
        Expr::Placeholder(Placeholder::new("$1".to_string(), data_type))
    }

    /// Inequality proof expression; `is_lt` is `true` for `<` and `false` for `>`.
    fn inequality(lhs: DynProofExpr, rhs: DynProofExpr, is_lt: bool) -> DynProofExpr {
        DynProofExpr::try_new_inequality(lhs, rhs, is_lt).unwrap()
    }

    /// Logical negation of a proof expression.
    fn negated(expr: DynProofExpr) -> DynProofExpr {
        DynProofExpr::try_new_not(expr).unwrap()
    }

    /// Assert that `expr` converts successfully to exactly `expected`.
    #[track_caller]
    fn assert_converts_to(expr: &Expr, schema: &[(Ident, ColumnType)], expected: &DynProofExpr) {
        assert_eq!(&expr_to_proof_expr(expr, schema).unwrap(), expected);
    }

    // ---------------------------------------------------------------------------------------
    // Alias and column
    // ---------------------------------------------------------------------------------------

    #[test]
    fn we_can_convert_alias_to_proof_expr() {
        let aliased = df_col("column").alias("alias");
        let schema = schema_of(&[("column", ColumnType::Int)]);
        assert_converts_to(&aliased, &schema, &column_int());
    }

    #[test]
    fn we_can_convert_column_expr_to_proof_expr() {
        let column = df_col("column");
        let schema = schema_of(&[("column", ColumnType::Int)]);
        assert_converts_to(&column, &schema, &column_int());
    }

    // ---------------------------------------------------------------------------------------
    // Binary expressions
    // ---------------------------------------------------------------------------------------

    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr() {
        let schema = integer_schema();

        // column1 = column2
        assert_converts_to(
            &df_col("column1").eq(df_col("column2")),
            &schema,
            &DynProofExpr::try_new_equals(column1_smallint(), column2_bigint()).unwrap(),
        );

        // column1 < column2
        assert_converts_to(
            &df_col("column1").lt(df_col("column2")),
            &schema,
            &inequality(column1_smallint(), column2_bigint(), true),
        );

        // column1 > column2
        assert_converts_to(
            &df_col("column1").gt(df_col("column2")),
            &schema,
            &inequality(column1_smallint(), column2_bigint(), false),
        );

        // column1 <= column2  ==>  NOT (column1 > column2)
        assert_converts_to(
            &df_col("column1").lt_eq(df_col("column2")),
            &schema,
            &negated(inequality(column1_smallint(), column2_bigint(), false)),
        );

        // column1 >= column2  ==>  NOT (column1 < column2)
        assert_converts_to(
            &df_col("column1").gt_eq(df_col("column2")),
            &schema,
            &negated(inequality(column1_smallint(), column2_bigint(), true)),
        );
    }

    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = decimal_schema();
        let scaled = column1_scaled_to_decimal_10_5;
        let decimal_25 = column2_decimal_25_5;

        // column1 = column3
        assert_converts_to(
            &df_col("column1").eq(df_col("column3")),
            &schema,
            &DynProofExpr::try_new_equals(scaled(), column3_decimal_75_5()).unwrap(),
        );

        // column1 < column2
        assert_converts_to(
            &df_col("column1").lt(df_col("column2")),
            &schema,
            &inequality(scaled(), decimal_25(), true),
        );

        // column1 > column2
        assert_converts_to(
            &df_col("column1").gt(df_col("column2")),
            &schema,
            &inequality(scaled(), decimal_25(), false),
        );

        // column1 <= column2  ==>  NOT (column1 > column2)
        assert_converts_to(
            &df_col("column1").lt_eq(df_col("column2")),
            &schema,
            &negated(inequality(scaled(), decimal_25(), false)),
        );

        // column1 >= column2  ==>  NOT (column1 < column2)
        assert_converts_to(
            &df_col("column1").gt_eq(df_col("column2")),
            &schema,
            &negated(inequality(scaled(), decimal_25(), true)),
        );
    }

    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr() {
        let schema = integer_schema();

        assert_converts_to(
            &column1_op_column2(Operator::Plus),
            &schema,
            &DynProofExpr::try_new_add(column1_smallint(), column2_bigint()).unwrap(),
        );

        assert_converts_to(
            &column1_op_column2(Operator::Minus),
            &schema,
            &DynProofExpr::try_new_subtract(column1_smallint(), column2_bigint()).unwrap(),
        );

        assert_converts_to(
            &column1_op_column2(Operator::Multiply),
            &schema,
            &DynProofExpr::try_new_multiply(column1_smallint(), column2_bigint()).unwrap(),
        );
    }

    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = decimal_schema();

        // Addition and subtraction scale-cast the SmallInt operand.
        assert_converts_to(
            &df_col("column1").add(df_col("column2")),
            &schema,
            &DynProofExpr::try_new_add(column1_scaled_to_decimal_10_5(), column2_decimal_25_5())
                .unwrap(),
        );

        assert_converts_to(
            &df_col("column1").sub(df_col("column2")),
            &schema,
            &DynProofExpr::try_new_subtract(
                column1_scaled_to_decimal_10_5(),
                column2_decimal_25_5(),
            )
            .unwrap(),
        );

        // Multiplication leaves both operands as they are.
        assert_converts_to(
            &df_col("column1").mul(df_col("column2")),
            &schema,
            &DynProofExpr::try_new_multiply(column1_smallint(), column2_decimal_25_5()).unwrap(),
        );
    }

    #[test]
    fn we_can_convert_logical_binary_expr_to_proof_expr() {
        let schema = boolean_schema();

        assert_converts_to(
            &df_col("column1").and(df_col("column2")),
            &schema,
            &DynProofExpr::try_new_and(column1_boolean(), column2_boolean()).unwrap(),
        );

        assert_converts_to(
            &df_col("column1").or(df_col("column2")),
            &schema,
            &DynProofExpr::try_new_or(column1_boolean(), column2_boolean()).unwrap(),
        );
    }

    #[test]
    fn we_cannot_convert_unsupported_binary_expr_to_proof_expr() {
        let unsupported = column1_op_column2(Operator::AtArrow);
        let schema = boolean_schema();
        assert!(matches!(
            expr_to_proof_expr(&unsupported, &schema),
            Err(PlannerError::UnsupportedBinaryOperator { .. })
        ));
    }

    // ---------------------------------------------------------------------------------------
    // Literal and NOT
    // ---------------------------------------------------------------------------------------

    #[test]
    fn we_can_convert_literal_expr_to_proof_expr() {
        let literal = Expr::Literal(ScalarValue::Int32(Some(1)));
        assert_converts_to(
            &literal,
            &[],
            &DynProofExpr::new_literal(LiteralValue::Int(1)),
        );
    }

    #[test]
    fn we_can_convert_not_expr_to_proof_expr() {
        // This column is deliberately referenced without a namespace.
        let not_column = Expr::Not(Box::new(df_column("table_name", "column")));
        let schema = schema_of(&[("column", ColumnType::Boolean)]);
        let boolean_column = DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(None, "table_name"),
            "column".into(),
            ColumnType::Boolean,
        ));
        assert_converts_to(&not_column, &schema, &negated(boolean_column));
    }

    // ---------------------------------------------------------------------------------------
    // Casts
    // ---------------------------------------------------------------------------------------

    #[test]
    fn we_can_convert_cast_expr_to_proof_expr() {
        let cast = cast_expr(
            Expr::Literal(ScalarValue::Boolean(Some(true))),
            DataType::Int32,
        );
        let expected = DynProofExpr::try_new_cast(
            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
            ColumnType::Int,
        )
        .unwrap();
        assert_converts_to(&cast, &[], &expected);
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_when_inner_expr_to_proof_expr_fails() {
        let cast = cast_expr(
            Expr::Literal(ScalarValue::UInt64(Some(100))),
            DataType::Int16,
        );
        let error = expr_to_proof_expr(&cast, &[]).unwrap_err();
        assert!(matches!(
            error,
            PlannerError::UnsupportedDataType { data_type: _ }
        ));
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_unsupported_datatypes() {
        let cast = cast_expr(
            Expr::Literal(ScalarValue::Boolean(Some(true))),
            DataType::UInt16,
        );
        let error = expr_to_proof_expr(&cast, &[]).unwrap_err();
        assert!(matches!(
            error,
            PlannerError::UnsupportedDataType { data_type: _ }
        ));
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_datatypes_for_which_casting_is_not_supported()
    {
        let cast = cast_expr(
            Expr::Literal(ScalarValue::Int16(Some(100))),
            DataType::Boolean,
        );
        let error = expr_to_proof_expr(&cast, &[]).unwrap_err();
        assert!(matches!(error, PlannerError::AnalyzeError { source: _ }));
    }

    // ---------------------------------------------------------------------------------------
    // Placeholders
    // ---------------------------------------------------------------------------------------

    #[test]
    fn we_can_convert_placeholder_to_proof_expr() {
        let typed_placeholder = placeholder_one(Some(DataType::Int32));
        assert_converts_to(
            &typed_placeholder,
            &[],
            &DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap(),
        );
    }

    #[test]
    fn we_can_convert_placeholder_with_data_type_specified_by_cast_to_proof_expr() {
        let cast_placeholder = cast_expr(placeholder_one(None), DataType::Int32);
        assert_converts_to(
            &cast_placeholder,
            &[],
            &DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap(),
        );
    }

    // ---------------------------------------------------------------------------------------
    // Unsupported expressions and mixed timestamp units
    // ---------------------------------------------------------------------------------------

    #[test]
    fn we_cannot_convert_unsupported_expr_to_proof_expr() {
        let unnest = Expr::Unnest(Unnest::new(Expr::Literal(ScalarValue::Int32(Some(100)))));
        assert!(matches!(
            expr_to_proof_expr(&unnest, &[]),
            Err(PlannerError::UnsupportedLogicalExpression { .. })
        ));
    }

    #[test]
    fn we_can_get_proof_expr_for_timestamps_of_different_scale() {
        let seconds = Expr::Literal(ScalarValue::TimestampSecond(Some(1), None));
        let nanoseconds = Expr::Literal(ScalarValue::TimestampNanosecond(Some(1), None));
        // Only successful conversion is checked here; the resulting expression is not inspected.
        binary_expr_to_proof_expr(&seconds, &nanoseconds, Operator::Gt, &[]).unwrap();
    }
}