// proof_of_sql/sql/proof_exprs/
dyn_proof_expr.rs1use super::{
2 AddSubtractExpr, AggregateExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr,
3 MultiplyExpr, NotExpr, OrExpr, ProofExpr,
4};
5use crate::{
6 base::{
7 database::{Column, ColumnRef, ColumnType, LiteralValue, Table},
8 map::{IndexMap, IndexSet},
9 proof::ProofError,
10 scalar::Scalar,
11 },
12 sql::{
13 proof::{FinalRoundBuilder, VerificationBuilder},
14 util::type_check_binary_operation,
15 AnalyzeError, AnalyzeResult,
16 },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use proof_of_sql_parser::intermediate_ast::AggregationOperator;
22use serde::{Deserialize, Serialize};
23use sqlparser::ast::BinaryOperator;
24
25#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
27#[enum_dispatch::enum_dispatch]
28pub enum DynProofExpr {
29 Column(ColumnExpr),
31 And(AndExpr),
33 Or(OrExpr),
35 Not(NotExpr),
37 Literal(LiteralExpr),
39 Equals(EqualsExpr),
41 Inequality(InequalityExpr),
43 AddSubtract(AddSubtractExpr),
45 Multiply(MultiplyExpr),
47 Aggregate(AggregateExpr),
49}
50impl DynProofExpr {
51 #[must_use]
53 pub fn new_column(column_ref: ColumnRef) -> Self {
54 Self::Column(ColumnExpr::new(column_ref))
55 }
56 pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
58 lhs.check_data_type(ColumnType::Boolean)?;
59 rhs.check_data_type(ColumnType::Boolean)?;
60 Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
61 }
62 pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
64 lhs.check_data_type(ColumnType::Boolean)?;
65 rhs.check_data_type(ColumnType::Boolean)?;
66 Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
67 }
68 pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
70 expr.check_data_type(ColumnType::Boolean)?;
71 Ok(Self::Not(NotExpr::new(Box::new(expr))))
72 }
73 #[must_use]
75 pub fn new_literal(value: LiteralValue) -> Self {
76 Self::Literal(LiteralExpr::new(value))
77 }
78 pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
80 let lhs_datatype = lhs.data_type();
81 let rhs_datatype = rhs.data_type();
82 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Eq) {
83 Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
84 } else {
85 Err(AnalyzeError::DataTypeMismatch {
86 left_type: lhs_datatype.to_string(),
87 right_type: rhs_datatype.to_string(),
88 })
89 }
90 }
91 pub fn try_new_inequality(
93 lhs: DynProofExpr,
94 rhs: DynProofExpr,
95 is_lt: bool,
96 ) -> AnalyzeResult<Self> {
97 let lhs_datatype = lhs.data_type();
98 let rhs_datatype = rhs.data_type();
99 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) {
100 Ok(Self::Inequality(InequalityExpr::new(
101 Box::new(lhs),
102 Box::new(rhs),
103 is_lt,
104 )))
105 } else {
106 Err(AnalyzeError::DataTypeMismatch {
107 left_type: lhs_datatype.to_string(),
108 right_type: rhs_datatype.to_string(),
109 })
110 }
111 }
112
113 pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
115 let lhs_datatype = lhs.data_type();
116 let rhs_datatype = rhs.data_type();
117 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Plus) {
118 Ok(Self::AddSubtract(AddSubtractExpr::new(
119 Box::new(lhs),
120 Box::new(rhs),
121 false,
122 )))
123 } else {
124 Err(AnalyzeError::DataTypeMismatch {
125 left_type: lhs_datatype.to_string(),
126 right_type: rhs_datatype.to_string(),
127 })
128 }
129 }
130
131 pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
133 let lhs_datatype = lhs.data_type();
134 let rhs_datatype = rhs.data_type();
135 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Minus) {
136 Ok(Self::AddSubtract(AddSubtractExpr::new(
137 Box::new(lhs),
138 Box::new(rhs),
139 true,
140 )))
141 } else {
142 Err(AnalyzeError::DataTypeMismatch {
143 left_type: lhs_datatype.to_string(),
144 right_type: rhs_datatype.to_string(),
145 })
146 }
147 }
148
149 pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
151 let lhs_datatype = lhs.data_type();
152 let rhs_datatype = rhs.data_type();
153 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply) {
154 Ok(Self::Multiply(MultiplyExpr::new(
155 Box::new(lhs),
156 Box::new(rhs),
157 )))
158 } else {
159 Err(AnalyzeError::DataTypeMismatch {
160 left_type: lhs_datatype.to_string(),
161 right_type: rhs_datatype.to_string(),
162 })
163 }
164 }
165
166 #[must_use]
168 pub fn new_aggregate(op: AggregationOperator, expr: DynProofExpr) -> Self {
169 Self::Aggregate(AggregateExpr::new(op, Box::new(expr)))
170 }
171
172 fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
174 if self.data_type() == data_type {
175 Ok(())
176 } else {
177 Err(AnalyzeError::InvalidDataType {
178 actual: self.data_type(),
179 expected: data_type,
180 })
181 }
182 }
183}