1use super::{
2 column_to_column_ref, placeholder_to_placeholder_expr, scalar_value_to_literal_value,
3 PlannerError, PlannerResult,
4};
5use datafusion::logical_expr::{
6 expr::{Alias, Placeholder},
7 BinaryExpr, Expr, Operator,
8};
9use proof_of_sql::{
10 base::database::ColumnType,
11 sql::{proof_exprs::DynProofExpr, scale_cast_binary_op},
12};
13use sqlparser::ast::Ident;
14
15#[expect(
17 clippy::missing_panics_doc,
18 reason = "Output of comparisons is always boolean"
19)]
20fn binary_expr_to_proof_expr(
21 left: &Expr,
22 right: &Expr,
23 op: Operator,
24 schema: &[(Ident, ColumnType)],
25) -> PlannerResult<DynProofExpr> {
26 let left_proof_expr = expr_to_proof_expr(left, schema)?;
27 let right_proof_expr = expr_to_proof_expr(right, schema)?;
28
29 let (left_proof_expr, right_proof_expr) = match op {
30 Operator::Eq
31 | Operator::Lt
32 | Operator::Gt
33 | Operator::LtEq
34 | Operator::GtEq
35 | Operator::Plus
36 | Operator::Minus => scale_cast_binary_op(left_proof_expr, right_proof_expr)?,
37 _ => (left_proof_expr, right_proof_expr),
38 };
39
40 match op {
41 Operator::And => Ok(DynProofExpr::try_new_and(
42 left_proof_expr,
43 right_proof_expr,
44 )?),
45 Operator::Or => Ok(DynProofExpr::try_new_or(left_proof_expr, right_proof_expr)?),
46 Operator::Multiply => Ok(DynProofExpr::try_new_multiply(
47 left_proof_expr,
48 right_proof_expr,
49 )?),
50 Operator::Eq => Ok(DynProofExpr::try_new_equals(
51 left_proof_expr,
52 right_proof_expr,
53 )?),
54 Operator::Lt => Ok(DynProofExpr::try_new_inequality(
55 left_proof_expr,
56 right_proof_expr,
57 true,
58 )?),
59 Operator::Gt => Ok(DynProofExpr::try_new_inequality(
60 left_proof_expr,
61 right_proof_expr,
62 false,
63 )?),
64 Operator::LtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
65 left_proof_expr,
66 right_proof_expr,
67 false,
68 )?)
69 .expect("An inequality expression must have a boolean data type...")),
70 Operator::GtEq => Ok(DynProofExpr::try_new_not(DynProofExpr::try_new_inequality(
71 left_proof_expr,
72 right_proof_expr,
73 true,
74 )?)
75 .expect("An inequality expression must have a boolean data type...")),
76 Operator::Plus => Ok(DynProofExpr::try_new_add(
77 left_proof_expr,
78 right_proof_expr,
79 )?),
80 Operator::Minus => Ok(DynProofExpr::try_new_subtract(
81 left_proof_expr,
82 right_proof_expr,
83 )?),
84 _ => Err(PlannerError::UnsupportedBinaryOperator { op }),
86 }
87}
88
89pub fn expr_to_proof_expr(
94 expr: &Expr,
95 schema: &[(Ident, ColumnType)],
96) -> PlannerResult<DynProofExpr> {
97 match expr {
98 Expr::Alias(Alias { expr, .. }) => expr_to_proof_expr(expr, schema),
99 Expr::Column(col) => Ok(DynProofExpr::new_column(column_to_column_ref(col, schema)?)),
100 Expr::Placeholder(placeholder) => placeholder_to_placeholder_expr(placeholder),
101 Expr::BinaryExpr(BinaryExpr { left, right, op }) => {
102 binary_expr_to_proof_expr(left, right, *op, schema)
103 }
104 Expr::Literal(val) => Ok(DynProofExpr::new_literal(scalar_value_to_literal_value(
105 val.clone(),
106 )?)),
107 Expr::Not(expr) => {
108 let proof_expr = expr_to_proof_expr(expr, schema)?;
109 Ok(DynProofExpr::try_new_not(proof_expr)?)
110 }
111 Expr::Cast(cast) => {
112 match &*cast.expr {
113 Expr::Placeholder(placeholder) if placeholder.data_type.is_none() => {
115 let typed_placeholder =
116 Placeholder::new(placeholder.id.clone(), Some(cast.data_type.clone()));
117 placeholder_to_placeholder_expr(&typed_placeholder)
118 }
119 _ => {
120 let from_expr = expr_to_proof_expr(&cast.expr, schema)?;
121 let to_type = cast.data_type.clone().try_into().map_err(|_| {
122 PlannerError::UnsupportedDataType {
123 data_type: cast.data_type.clone(),
124 }
125 })?;
126 Ok(DynProofExpr::try_new_cast(from_expr, to_type)?)
127 }
128 }
129 }
130 _ => Err(PlannerError::UnsupportedLogicalExpression { expr: expr.clone() }),
131 }
132}
133
// Unit tests for `expr_to_proof_expr` and `binary_expr_to_proof_expr`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::df_util::*;
    use arrow::datatypes::DataType;
    use core::ops::{Add, Mul, Sub};
    use datafusion::{
        common::ScalarValue,
        logical_expr::{
            expr::{Placeholder, Unnest},
            Cast,
        },
    };
    use proof_of_sql::base::{
        database::{ColumnRef, ColumnType, LiteralValue, TableRef},
        math::decimal::Precision,
    };

    // Expected-value fixtures.
    //
    // Each helper below builds the column proof expression that a test expects for a column of
    // `namespace.table_name`. They are functions (not constants) so every call yields a fresh
    // value; the upper-case names are why `non_snake_case` is expected on each of them.

    // `namespace.table_name.column` typed as `Int`.
    #[expect(non_snake_case)]
    fn COLUMN_INT() -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            "column".into(),
            ColumnType::Int,
        ))
    }

    // `namespace.table_name.column1` typed as `SmallInt`.
    #[expect(non_snake_case)]
    fn COLUMN1_SMALLINT() -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            "column1".into(),
            ColumnType::SmallInt,
        ))
    }

    // `namespace.table_name.column2` typed as `BigInt`.
    #[expect(non_snake_case)]
    fn COLUMN2_BIGINT() -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            "column2".into(),
            ColumnType::BigInt,
        ))
    }

    // `namespace.table_name.column1` typed as `Boolean`.
    #[expect(non_snake_case)]
    fn COLUMN1_BOOLEAN() -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            "column1".into(),
            ColumnType::Boolean,
        ))
    }

    // `namespace.table_name.column2` typed as `Boolean`.
    #[expect(non_snake_case)]
    fn COLUMN2_BOOLEAN() -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            "column2".into(),
            ColumnType::Boolean,
        ))
    }

    // `namespace.table_name.column3` typed as `Decimal75` with precision 75 and scale 5.
    #[expect(non_snake_case)]
    fn COLUMN3_DECIMAL_75_5() -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            "column3".into(),
            ColumnType::Decimal75(
                Precision::new(75).expect("Precision is definitely valid"),
                5,
            ),
        ))
    }

    // `namespace.table_name.column2` typed as `Decimal75` with precision 25 and scale 5.
    #[expect(non_snake_case)]
    fn COLUMN2_DECIMAL_25_5() -> DynProofExpr {
        DynProofExpr::new_column(ColumnRef::new(
            TableRef::from_names(Some("namespace"), "table_name"),
            "column2".into(),
            ColumnType::Decimal75(
                Precision::new(25).expect("Precision is definitely valid"),
                5,
            ),
        ))
    }

    // Alias: converting an aliased column yields the plain column expression.
    #[test]
    fn we_can_convert_alias_to_proof_expr() {
        let expr = df_column("namespace.table_name", "column").alias("alias");
        let schema = vec![("column".into(), ColumnType::Int)];
        assert_eq!(expr_to_proof_expr(&expr, &schema).unwrap(), COLUMN_INT());
    }

    // Column: the column type is taken from the supplied schema.
    #[test]
    fn we_can_convert_column_expr_to_proof_expr() {
        let expr = df_column("namespace.table_name", "column");
        let schema = vec![("column".into(), ColumnType::Int)];
        assert_eq!(expr_to_proof_expr(&expr, &schema).unwrap(), COLUMN_INT());
    }

    // Comparisons between two integer columns: the expected expressions contain no scaling
    // cast around either operand.
    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), ColumnType::BigInt),
        ];

        // column1 = column2
        let expr = df_column("namespace.table_name", "column1")
            .eq(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_equals(COLUMN1_SMALLINT(), COLUMN2_BIGINT()).unwrap()
        );

        // column1 < column2
        let expr = df_column("namespace.table_name", "column1")
            .lt(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(COLUMN1_SMALLINT(), COLUMN2_BIGINT(), true).unwrap()
        );

        // column1 > column2
        let expr = df_column("namespace.table_name", "column1")
            .gt(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(COLUMN1_SMALLINT(), COLUMN2_BIGINT(), false).unwrap()
        );

        // column1 <= column2, expected as NOT (column1 > column2)
        let expr = df_column("namespace.table_name", "column1")
            .lt_eq(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(COLUMN1_SMALLINT(), COLUMN2_BIGINT(), false)
                    .unwrap()
            )
            .unwrap()
        );

        // column1 >= column2, expected as NOT (column1 < column2)
        let expr = df_column("namespace.table_name", "column1")
            .gt_eq(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(COLUMN1_SMALLINT(), COLUMN2_BIGINT(), true)
                    .unwrap()
            )
            .unwrap()
        );
    }

    // Comparisons between a `SmallInt` column and a scale-5 decimal column: in every case the
    // expected left operand is the `SmallInt` column wrapped in a scaling cast to `Decimal75`
    // with precision 10 and scale 5, while the decimal operand is left as is.
    #[expect(clippy::too_many_lines)]
    #[test]
    fn we_can_convert_comparison_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            (
                "column2".into(),
                ColumnType::Decimal75(Precision::new(25).unwrap(), 5),
            ),
            (
                "column3".into(),
                ColumnType::Decimal75(Precision::new(75).unwrap(), 5),
            ),
        ];

        // column1 = column3
        let expr = df_column("namespace.table_name", "column1")
            .eq(df_column("namespace.table_name", "column3"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_equals(
                DynProofExpr::try_new_scaling_cast(
                    COLUMN1_SMALLINT(),
                    ColumnType::Decimal75(
                        Precision::new(10).expect("Precision is definitely valid"),
                        5
                    )
                )
                .unwrap(),
                COLUMN3_DECIMAL_75_5()
            )
            .unwrap()
        );

        // column1 < column2
        let expr = df_column("namespace.table_name", "column1")
            .lt(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(
                DynProofExpr::try_new_scaling_cast(
                    COLUMN1_SMALLINT(),
                    ColumnType::Decimal75(
                        Precision::new(10).expect("Precision is definitely valid"),
                        5
                    )
                )
                .unwrap(),
                COLUMN2_DECIMAL_25_5(),
                true
            )
            .unwrap()
        );

        // column1 > column2
        let expr = df_column("namespace.table_name", "column1")
            .gt(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_inequality(
                DynProofExpr::try_new_scaling_cast(
                    COLUMN1_SMALLINT(),
                    ColumnType::Decimal75(
                        Precision::new(10).expect("Precision is definitely valid"),
                        5
                    )
                )
                .unwrap(),
                COLUMN2_DECIMAL_25_5(),
                false
            )
            .unwrap()
        );

        // column1 <= column2, expected as NOT (column1 > column2)
        let expr = df_column("namespace.table_name", "column1")
            .lt_eq(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(
                    DynProofExpr::try_new_scaling_cast(
                        COLUMN1_SMALLINT(),
                        ColumnType::Decimal75(
                            Precision::new(10).expect("Precision is definitely valid"),
                            5
                        )
                    )
                    .unwrap(),
                    COLUMN2_DECIMAL_25_5(),
                    false
                )
                .unwrap()
            )
            .unwrap()
        );

        // column1 >= column2, expected as NOT (column1 < column2)
        let expr = df_column("namespace.table_name", "column1")
            .gt_eq(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(
                DynProofExpr::try_new_inequality(
                    DynProofExpr::try_new_scaling_cast(
                        COLUMN1_SMALLINT(),
                        ColumnType::Decimal75(
                            Precision::new(10).expect("Precision is definitely valid"),
                            5
                        )
                    )
                    .unwrap(),
                    COLUMN2_DECIMAL_25_5(),
                    true
                )
                .unwrap()
            )
            .unwrap()
        );
    }

    // Arithmetic between two integer columns: the expected expressions contain no scaling
    // cast around either operand.
    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            ("column2".into(), ColumnType::BigInt),
        ];

        // column1 + column2
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(df_column("namespace.table_name", "column1")),
            right: Box::new(df_column("namespace.table_name", "column2")),
            op: Operator::Plus,
        });
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_add(COLUMN1_SMALLINT(), COLUMN2_BIGINT(),).unwrap()
        );

        // column1 - column2
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(df_column("namespace.table_name", "column1")),
            right: Box::new(df_column("namespace.table_name", "column2")),
            op: Operator::Minus,
        });
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_subtract(COLUMN1_SMALLINT(), COLUMN2_BIGINT(),).unwrap()
        );

        // column1 * column2
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(df_column("namespace.table_name", "column1")),
            right: Box::new(df_column("namespace.table_name", "column2")),
            op: Operator::Multiply,
        });
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_multiply(COLUMN1_SMALLINT(), COLUMN2_BIGINT(),).unwrap()
        );
    }

    // Arithmetic between a `SmallInt` column and a scale-5 decimal column: addition and
    // subtraction expect the `SmallInt` operand wrapped in a scaling cast, multiplication
    // expects both operands unchanged.
    #[test]
    fn we_can_convert_arithmetic_binary_expr_to_proof_expr_with_scale_cast() {
        let schema = vec![
            ("column1".into(), ColumnType::SmallInt),
            (
                "column2".into(),
                ColumnType::Decimal75(Precision::new(25).unwrap(), 5),
            ),
            (
                "column3".into(),
                ColumnType::Decimal75(Precision::new(75).unwrap(), 5),
            ),
        ];

        // column1 + column2
        let expr = df_column("namespace.table_name", "column1")
            .add(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_add(
                DynProofExpr::try_new_scaling_cast(
                    COLUMN1_SMALLINT(),
                    ColumnType::Decimal75(
                        Precision::new(10).expect("Precision is definitely valid"),
                        5
                    )
                )
                .unwrap(),
                COLUMN2_DECIMAL_25_5()
            )
            .unwrap()
        );

        // column1 - column2
        let expr = df_column("namespace.table_name", "column1")
            .sub(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_subtract(
                DynProofExpr::try_new_scaling_cast(
                    COLUMN1_SMALLINT(),
                    ColumnType::Decimal75(
                        Precision::new(10).expect("Precision is definitely valid"),
                        5
                    )
                )
                .unwrap(),
                COLUMN2_DECIMAL_25_5()
            )
            .unwrap()
        );

        // column1 * column2 (no scaling cast expected)
        let expr = df_column("namespace.table_name", "column1")
            .mul(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_multiply(COLUMN1_SMALLINT(), COLUMN2_DECIMAL_25_5()).unwrap()
        );
    }

    // Logical `AND` / `OR` between two boolean columns.
    #[test]
    fn we_can_convert_logical_binary_expr_to_proof_expr() {
        let schema = vec![
            ("column1".into(), ColumnType::Boolean),
            ("column2".into(), ColumnType::Boolean),
        ];

        // column1 AND column2
        let expr = df_column("namespace.table_name", "column1")
            .and(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_and(COLUMN1_BOOLEAN(), COLUMN2_BOOLEAN()).unwrap()
        );

        // column1 OR column2
        let expr = df_column("namespace.table_name", "column1")
            .or(df_column("namespace.table_name", "column2"));
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_or(COLUMN1_BOOLEAN(), COLUMN2_BOOLEAN()).unwrap()
        );
    }

    // An operator with no dedicated match arm (`AtArrow` here) is rejected with
    // `UnsupportedBinaryOperator`; both operands are valid columns so that they convert fine.
    #[test]
    fn we_cannot_convert_unsupported_binary_expr_to_proof_expr() {
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(df_column("namespace.table_name", "column1")),
            right: Box::new(df_column("namespace.table_name", "column2")),
            op: Operator::AtArrow,
        });
        let schema = vec![
            ("column1".into(), ColumnType::Boolean),
            ("column2".into(), ColumnType::Boolean),
        ];
        assert!(matches!(
            expr_to_proof_expr(&expr, &schema),
            Err(PlannerError::UnsupportedBinaryOperator { .. })
        ));
    }

    // Literal: an `Int32` scalar becomes an `Int` literal; no schema is needed.
    #[test]
    fn we_can_convert_literal_expr_to_proof_expr() {
        let expr = Expr::Literal(ScalarValue::Int32(Some(1)));
        assert_eq!(
            expr_to_proof_expr(&expr, &Vec::new()).unwrap(),
            DynProofExpr::new_literal(LiteralValue::Int(1))
        );
    }

    // Not: negation of a boolean column. The column here has no namespace, hence the `None`
    // passed to `TableRef::from_names` in the expected value.
    #[test]
    fn we_can_convert_not_expr_to_proof_expr() {
        let expr = Expr::Not(Box::new(df_column("table_name", "column")));
        let schema = vec![("column".into(), ColumnType::Boolean)];
        assert_eq!(
            expr_to_proof_expr(&expr, &schema).unwrap(),
            DynProofExpr::try_new_not(DynProofExpr::new_column(ColumnRef::new(
                TableRef::from_names(None, "table_name"),
                "column".into(),
                ColumnType::Boolean
            )))
            .unwrap()
        );
    }

    // Cast: a boolean literal cast to `Int32` becomes a cast expression targeting `Int`.
    #[test]
    fn we_can_convert_cast_expr_to_proof_expr() {
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::Boolean(Some(true)))),
            DataType::Int32,
        ));
        let expression = expr_to_proof_expr(&expr, &Vec::new()).unwrap();
        assert_eq!(
            expression,
            DynProofExpr::try_new_cast(
                DynProofExpr::new_literal(LiteralValue::Boolean(true)),
                ColumnType::Int
            )
            .unwrap()
        );
    }

    // Cast around a `UInt64` literal: per the test name the failure is meant to come from
    // converting the inner expression; the assertion itself only checks the error variant.
    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_when_inner_expr_to_proof_expr_fails() {
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::UInt64(Some(100)))),
            DataType::Int16,
        ));
        let expression = expr_to_proof_expr(&expr, &Vec::new()).unwrap_err();
        assert!(matches!(
            expression,
            PlannerError::UnsupportedDataType { data_type: _ }
        ));
    }

    // Cast to `UInt16`: the inner boolean literal is convertible, so the expected
    // `UnsupportedDataType` error is attributed to the cast's target type.
    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_unsupported_datatypes() {
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::Boolean(Some(true)))),
            DataType::UInt16,
        ));
        let expression = expr_to_proof_expr(&expr, &Vec::new()).unwrap_err();
        assert!(matches!(
            expression,
            PlannerError::UnsupportedDataType { data_type: _ }
        ));
    }

    // Cast from an `Int16` literal to `Boolean`: expected to be rejected with an
    // `AnalyzeError`.
    #[test]
    fn we_cannot_convert_cast_expr_to_proof_expr_for_datatypes_for_which_casting_is_not_supported()
    {
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Literal(ScalarValue::Int16(Some(100)))),
            DataType::Boolean,
        ));
        let expression = expr_to_proof_expr(&expr, &Vec::new()).unwrap_err();
        assert!(matches!(
            expression,
            PlannerError::AnalyzeError { source: _ }
        ));
    }

    // Placeholder: `$1` with an explicit `Int32` type becomes placeholder 1 of type `Int`.
    #[test]
    fn we_can_convert_placeholder_to_proof_expr() {
        let expr = Expr::Placeholder(Placeholder {
            id: "$1".to_string(),
            data_type: Some(DataType::Int32),
        });
        let expression = expr_to_proof_expr(&expr, &Vec::new()).unwrap();
        assert_eq!(
            expression,
            DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap()
        );
    }

    // Placeholder without a type inside a cast: the result is a typed placeholder (type taken
    // from the cast), not a cast expression.
    #[test]
    fn we_can_convert_placeholder_with_data_type_specified_by_cast_to_proof_expr() {
        let expr = Expr::Cast(Cast::new(
            Box::new(Expr::Placeholder(Placeholder {
                id: "$1".to_string(),
                data_type: None,
            })),
            DataType::Int32,
        ));
        let expression = expr_to_proof_expr(&expr, &Vec::new()).unwrap();
        assert_eq!(
            expression,
            DynProofExpr::try_new_placeholder(1, ColumnType::Int).unwrap()
        );
    }

    // Unsupported: an expression kind with no dedicated match arm (`Unnest` here) is rejected
    // with `UnsupportedLogicalExpression`.
    #[test]
    fn we_cannot_convert_unsupported_expr_to_proof_expr() {
        let expr = Expr::Unnest(Unnest::new(Expr::Literal(ScalarValue::Int32(Some(100)))));
        assert!(matches!(
            expr_to_proof_expr(&expr, &Vec::new()),
            Err(PlannerError::UnsupportedLogicalExpression { .. })
        ));
    }

    // Comparing a second-resolution timestamp with a nanosecond-resolution timestamp must
    // succeed. Only success is checked (via `unwrap`); the resulting expression is not
    // inspected.
    #[test]
    fn we_can_get_proof_expr_for_timestamps_of_different_scale() {
        let lhs = Expr::Literal(ScalarValue::TimestampSecond(Some(1), None));
        let rhs = Expr::Literal(ScalarValue::TimestampNanosecond(Some(1), None));
        binary_expr_to_proof_expr(&lhs, &rhs, Operator::Gt, &Vec::new()).unwrap();
    }
}