1use super::{
2 column_to_column_ref, placeholder_to_placeholder_expr, scalar_value_to_literal_value,
3 PlannerError, PlannerResult,
4};
5use datafusion::logical_expr::{
6 expr::{Alias, Placeholder},
7 BinaryExpr, Expr, Operator,
8};
9use proof_of_sql::{
10 base::database::ColumnType,
11 sql::{proof_exprs::DynProofExpr, scale_cast_binary_op},
12};
13use sqlparser::ast::Ident;
14
15#[expect(
17 clippy::missing_panics_doc,
18 reason = "Output of comparisons is always boolean"
19)]
20fn binary_expr_to_proof_expr(
21 left: &Expr,
22 right: &Expr,
23 op: Operator,
24 schema: &[(Ident, ColumnType)],
25) -> PlannerResult<DynProofExpr> {
26 let left_proof_expr = expr_to_proof_expr(left, schema)?;
27 let right_proof_expr = expr_to_proof_expr(right, schema)?;
28
29 let (left_proof_expr, right_proof_expr) = match op {
30 Operator::Eq
31 | Operator::Lt
32 | Operator::Gt
33 | Operator::LtEq
34 | Operator::GtEq
35 | Operator::Plus
36 | Operator::Minus => scale_cast_binary_op(left_proof_expr, right_proof_expr)?,
37 _ => (left_proof_expr, right_proof_expr),
38 };
39
40 match op {
41 Operator::And => Ok(DynProofExpr::try_new_and(
42 left_proof_expr,
43 right_proof_expr,
44 )?),
45 Operator::Or => Ok(DynProofExpr::try_new_or(left_proof_expr, right_proof_expr)?),
46 Operator::Multiply => Ok(DynProofExpr::try_new_multiply(
47 left_proof_expr,
48 right_proof_expr,
49 )?),
50 Operator::Eq => Ok(DynProofExpr::try_new_equals(
51 left_proof_expr,
52 right_proof_expr,
53 )?),
54 Operator::Lt => Ok(DynProofExpr::try_new_inequality(
55 left_proof_expr,
56 right_proof_expr,
57 true,
58 )?),
59 Operator::Gt => Ok(DynProofExpr::try_new_inequality(
60 left_proof_expr,
61 right_proof_expr,
62 false,
63 )?),
64 Operator::LtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
65 left_proof_expr,
66 right_proof_expr,
67 false,
68 )?)
69 .expect("An inequality expression must have a boolean data type...")),
70 Operator::GtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
71 left_proof_expr,
72 right_proof_expr,
73 true,
74 )?)
75 .expect("An inequality expression must have a boolean data type...")),
76 Operator::Plus => Ok(DynProofExpr::try_new_add(
77 left_proof_expr,
78 right_proof_expr,
79 )?),
80 Operator::Minus => Ok(DynProofExpr::try_new_subtract(
81 left_proof_expr,
82 right_proof_expr,
83 )?),
84 _ => Err(PlannerError::UnsupportedBinaryOperator { op }),
86 }
87}
88
89pub fn expr_to_proof_expr(
94 expr: &Expr,
95 schema: &[(Ident, ColumnType)],
96) -> PlannerResult<DynProofExpr> {
97 match expr {
98 Expr::Alias(Alias { expr, .. }) => expr_to_proof_expr(expr, schema),
99 Expr::Column(col) => Ok(DynProofExpr::new_column(column_to_column_ref(col, schema)?)),
100 Expr::Placeholder(placeholder) => placeholder_to_placeholder_expr(placeholder),
101 Expr::BinaryExpr(BinaryExpr { left, right, op }) => {
102 binary_expr_to_proof_expr(left, right, *op, schema)
103 }
104 Expr::Literal(val) => Ok(DynProofExpr::new_literal(scalar_value_to_literal_value(
105 val.clone(),
106 )?)),
107 Expr::Not(expr) => {
108 let proof_expr = expr_to_proof_expr(expr, schema)?;
109 Ok(DynProofExpr::try_new_not(proof_expr)?)
110 }
111 Expr::Cast(cast) => {
112 match &*cast.expr {
113 Expr::Placeholder(placeholder) if placeholder.data_type.is_none() => {
115 let typed_placeholder =
116 Placeholder::new(placeholder.id.clone(), Some(cast.data_type.clone()));
117 placeholder_to_placeholder_expr(&typed_placeholder)
118 }
119 _ => {
120 let from_expr = expr_to_proof_expr(&cast.expr, schema)?;
121 let to_type = cast.data_type.clone().try_into().map_err(|_| {
122 PlannerError::UnsupportedDataType {
123 data_type: cast.data_type.clone(),
124 }
125 })?;
126 Ok(
127 DynProofExpr::try_new_cast(from_expr.clone(), to_type).map_or_else(
128 |_| DynProofExpr::try_new_scaling_cast(from_expr, to_type),
129 Ok,
130 )?,
131 )
132 }
133 }
134 }
135 _ => Err(PlannerError::UnsupportedLogicalExpression {
136 expr: Box::new(expr.clone()),
137 }),
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::df_util::*;
    use arrow::datatypes::DataType;
    use core::ops::{Add, Mul, Sub};
    use datafusion::{
        common::ScalarValue,
        logical_expr::{
            expr::{Placeholder, Unnest},
            Cast,
        },
    };
    use proof_of_sql::base::{
        database::{ColumnRef, ColumnType, LiteralValue, TableRef},
        math::decimal::Precision,
    };

    /// Proof-expression column `namespace.table_name.<name>` with the given type.
    fn column(name: &str, column_type: ColumnType) -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            name.into(),
            column_type,
        ))
    }

    /// `Decimal75` column type with the given precision and scale.
    fn decimal75(precision: u8, scale: i8) -> ColumnType {
        ColumnType::Decimal75(
            Precision::new(precision).expect("Precision is definitely valid"),
            scale,
        )
    }

    /// The `SMALLINT` column `column1` scaling-cast to `Decimal75(10, 5)`; the scale-cast tests
    /// expect this on the left side when `column1` meets a scale-5 decimal column.
    fn column1_scaled_to_decimal_10_5() -> DynProofExpr {
        DynProofExpr::try_new_scaling_cast(
            column("column1", ColumnType::SmallInt),
            decimal75(10, 5),
        )
        .unwrap()
    }

    #[test]
    fn we_can_convert_alias_to_proof_expr() {
        let schema = vec![("column".into(), ColumnType::Int)];
        let aliased = df_column("namespace.table_name", "column").alias("alias");
        let converted = expr_to_proof_expr(&aliased, &schema).unwrap();
        assert_eq!(converted, column("column", ColumnType::Int));
    }

    #[test]
    fn we_can_convert_column_expr_to_proof_expr() {
        let schema = vec![("column".into(), ColumnType::Int)];
        let plain = df_column("namespace.table_name", "column");
        let converted = expr_to_proof_expr(&plain, &schema).unwrap();
        assert_eq!(converted, column("column", ColumnType::Int));
    }

    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), ColumnType::BigInt),
        ];
        let lhs = || df_column("namespace.table_name", "column1");
        let rhs = || df_column("namespace.table_name", "column2");
        let smallint = || column("column1", ColumnType::SmallInt);
        let bigint = || column("column2", ColumnType::BigInt);

        let cases = [
            (
                lhs().eq(rhs()),
                DynProofExpr::try_new_equals(smallint(), bigint()).unwrap(),
            ),
            (
                lhs().lt(rhs()),
                DynProofExpr::try_new_inequality(smallint(), bigint(), true).unwrap(),
            ),
            (
                lhs().gt(rhs()),
                DynProofExpr::try_new_inequality(smallint(), bigint(), false).unwrap(),
            ),
            (
                lhs().lt_eq(rhs()),
                DynProofExpr::try_new_not(
                    DynProofExpr::try_new_inequality(smallint(), bigint(), false).unwrap(),
                )
                .unwrap(),
            ),
            (
                lhs().gt_eq(rhs()),
                DynProofExpr::try_new_not(
                    DynProofExpr::try_new_inequality(smallint(), bigint(), true).unwrap(),
                )
                .unwrap(),
            ),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr_to_proof_expr(&expr, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), decimal75(25, 5)),
            ("column3".into(), decimal75(75, 5)),
        ];
        let df_column1 = || df_column("namespace.table_name", "column1");
        let df_column2 = || df_column("namespace.table_name", "column2");
        let df_column3 = || df_column("namespace.table_name", "column3");
        let decimal_25_5 = || column("column2", decimal75(25, 5));
        let decimal_75_5 = || column("column3", decimal75(75, 5));

        let cases = [
            (
                df_column1().eq(df_column3()),
                DynProofExpr::try_new_equals(column1_scaled_to_decimal_10_5(), decimal_75_5())
                    .unwrap(),
            ),
            (
                df_column1().lt(df_column2()),
                DynProofExpr::try_new_inequality(
                    column1_scaled_to_decimal_10_5(),
                    decimal_25_5(),
                    true,
                )
                .unwrap(),
            ),
            (
                df_column1().gt(df_column2()),
                DynProofExpr::try_new_inequality(
                    column1_scaled_to_decimal_10_5(),
                    decimal_25_5(),
                    false,
                )
                .unwrap(),
            ),
            (
                df_column1().lt_eq(df_column2()),
                DynProofExpr::try_new_not(
                    DynProofExpr::try_new_inequality(
                        column1_scaled_to_decimal_10_5(),
                        decimal_25_5(),
                        false,
                    )
                    .unwrap(),
                )
                .unwrap(),
            ),
            (
                df_column1().gt_eq(df_column2()),
                DynProofExpr::try_new_not(
                    DynProofExpr::try_new_inequality(
                        column1_scaled_to_decimal_10_5(),
                        decimal_25_5(),
                        true,
                    )
                    .unwrap(),
                )
                .unwrap(),
            ),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr_to_proof_expr(&expr, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), ColumnType::BigInt),
        ];
        let smallint = || column("column1", ColumnType::SmallInt);
        let bigint = || column("column2", ColumnType::BigInt);

        let cases = [
            (
                Operator::Plus,
                DynProofExpr::try_new_add(smallint(), bigint()).unwrap(),
            ),
            (
                Operator::Minus,
                DynProofExpr::try_new_subtract(smallint(), bigint()).unwrap(),
            ),
            (
                Operator::Multiply,
                DynProofExpr::try_new_multiply(smallint(), bigint()).unwrap(),
            ),
        ];
        for (op, expected) in cases {
            let expr = Expr::BinaryExpr(BinaryExpr {
                left: Box::new(df_column("namespace.table_name", "column1")),
                right: Box::new(df_column("namespace.table_name", "column2")),
                op,
            });
            assert_eq!(expr_to_proof_expr(&expr, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), decimal75(25, 5)),
            ("column3".into(), decimal75(75, 5)),
        ];
        let df_column1 = || df_column("namespace.table_name", "column1");
        let df_column2 = || df_column("namespace.table_name", "column2");
        let decimal_25_5 = || column("column2", decimal75(25, 5));

        let cases = [
            (
                df_column1().add(df_column2()),
                DynProofExpr::try_new_add(column1_scaled_to_decimal_10_5(), decimal_25_5())
                    .unwrap(),
            ),
            (
                df_column1().sub(df_column2()),
                DynProofExpr::try_new_subtract(column1_scaled_to_decimal_10_5(), decimal_25_5())
                    .unwrap(),
            ),
            // Multiplication is not scale-cast, so the raw SMALLINT column is expected.
            (
                df_column1().mul(df_column2()),
                DynProofExpr::try_new_multiply(
                    column("column1", ColumnType::SmallInt),
                    decimal_25_5(),
                )
                .unwrap(),
            ),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr_to_proof_expr(&expr, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn we_can_convert_logical_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::Boolean),
            ("column2".into(), ColumnType::Boolean),
        ];
        let lhs = || df_column("namespace.table_name", "column1");
        let rhs = || df_column("namespace.table_name", "column2");
        let first = || column("column1", ColumnType::Boolean);
        let second = || column("column2", ColumnType::Boolean);

        let cases = [
            (
                lhs().and(rhs()),
                DynProofExpr::try_new_and(first(), second()).unwrap(),
            ),
            (
                lhs().or(rhs()),
                DynProofExpr::try_new_or(first(), second()).unwrap(),
            ),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr_to_proof_expr(&expr, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn we_cannot_convert_unsupported_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::Boolean),
            ("column2".into(), ColumnType::Boolean),
        ];
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(df_column("namespace.table_name", "column1")),
            right: Box::new(df_column("namespace.table_name", "column2")),
            op: Operator::AtArrow,
        });
        let result = expr_to_proof_expr(&expr, &schema);
        assert!(matches!(
            result,
            Err(PlannerError::UnsupportedBinaryOperator { .. })
        ));
    }

    #[test]
    fn we_can_convert_literal_expr_to_proof_expr() {
        let literal = Expr::Literal(ScalarValue::Int32(Some(1)));
        let converted = expr_to_proof_expr(&literal, &[]).unwrap();
        assert_eq!(converted, DynProofExpr::new_literal(LiteralValue::Int(1)));
    }

    #[test]
    fn we_can_convert_not_expr_to_proof_expr() {
        let schema = vec![("column".into(), ColumnType::Boolean)];
        let negated = Expr::Not(Box::new(df_column("table_name", "column")));
        let operand = DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(None, "table_name"),
            "column".into(),
            ColumnType::Boolean,
        ));
        assert_eq!(
            expr_to_proof_expr(&negated, &schema).unwrap(),
            DynProofExpr::try_new_not(operand).unwrap()
        );
    }

    #[test]
    fn we_can_convert_cast_expr_to_proof_expr() {
        let source = Expr::Literal(ScalarValue::Boolean(Some(true)));
        let expr = Expr::Cast(Cast::new(Box::new(source), DataType::Int32));
        let expected = DynProofExpr::try_new_cast(
            DynProofExpr::new_literal(LiteralValue::Boolean(true)),
            ColumnType::Int,
        )
        .unwrap();
        assert_eq!(expr_to_proof_expr(&expr, &[]).unwrap(), expected);
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_when_inner_expr_to_proof_expr_fails() {
        let source = Expr::Literal(ScalarValue::UInt64(Some(100)));
        let expr = Expr::Cast(Cast::new(Box::new(source), DataType::Int16));
        let error = expr_to_proof_expr(&expr, &[]).unwrap_err();
        assert!(matches!(error, PlannerError::UnsupportedDataType { .. }));
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_unsupported_datatypes() {
        let source = Expr::Literal(ScalarValue::Boolean(Some(true)));
        let expr = Expr::Cast(Cast::new(Box::new(source), DataType::UInt16));
        let error = expr_to_proof_expr(&expr, &[]).unwrap_err();
        assert!(matches!(error, PlannerError::UnsupportedDataType { .. }));
    }

    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_datatypes_for_which_casting_is_not_supported()
    {
        let source = Expr::Literal(ScalarValue::Int16(Some(100)));
        let expr = Expr::Cast(Cast::new(Box::new(source), DataType::Boolean));
        let error = expr_to_proof_expr(&expr, &[]).unwrap_err();
        assert!(matches!(error, PlannerError::AnalyzeError { .. }));
    }

    #[test]
    fn we_can_convert_placeholder_to_proof_expr() {
        let expr = Expr::Placeholder(Placeholder::new(
            "$1".to_string(),
            Some(DataType::Int32),
        ));
        let converted = expr_to_proof_expr(&expr, &[]).unwrap();
        assert_eq!(
            converted,
            DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap()
        );
    }

    #[test]
    fn we_can_convert_placeholder_with_data_type_specified_by_cast_to_proof_expr() {
        let untyped = Expr::Placeholder(Placeholder::new("$1".to_string(), None));
        let expr = Expr::Cast(Cast::new(Box::new(untyped), DataType::Int32));
        let converted = expr_to_proof_expr(&expr, &[]).unwrap();
        assert_eq!(
            converted,
            DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap()
        );
    }

    #[test]
    fn we_cannot_convert_unsupported_expr_to_proof_expr() {
        let expr = Expr::Unnest(Unnest::new(Expr::Literal(ScalarValue::Int32(Some(100)))));
        let result = expr_to_proof_expr(&expr, &[]);
        assert!(matches!(
            result,
            Err(PlannerError::UnsupportedLogicalExpression { .. })
        ));
    }

    #[test]
    fn we_can_get_proof_expr_for_timestamps_of_different_scale() {
        let seconds = Expr::Literal(ScalarValue::TimestampSecond(Some(1), None));
        let nanoseconds = Expr::Literal(ScalarValue::TimestampNanosecond(Some(1), None));
        binary_expr_to_proof_expr(&seconds, &nanoseconds, Operator::Gt, &[]).unwrap();
    }
}
702}