proof_of_sql/sql/proof_exprs/
dyn_proof_expr.rs1use super::{
2 cast_expr::CastExpr, AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr,
3 LiteralExpr, MultiplyExpr, NotExpr, OrExpr, ProofExpr,
4};
5use crate::{
6 base::{
7 database::{try_cast_types, Column, ColumnRef, ColumnType, LiteralValue, Table},
8 map::{IndexMap, IndexSet},
9 proof::ProofError,
10 scalar::Scalar,
11 },
12 sql::{
13 proof::{FinalRoundBuilder, VerificationBuilder},
14 util::type_check_binary_operation,
15 AnalyzeError, AnalyzeResult,
16 },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use serde::{Deserialize, Serialize};
22use sqlparser::ast::BinaryOperator;
23
24#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
26#[enum_dispatch::enum_dispatch]
27pub enum DynProofExpr {
28 Column(ColumnExpr),
30 And(AndExpr),
32 Or(OrExpr),
34 Not(NotExpr),
36 Literal(LiteralExpr),
38 Equals(EqualsExpr),
40 Inequality(InequalityExpr),
42 AddSubtract(AddSubtractExpr),
44 Multiply(MultiplyExpr),
46 Cast(CastExpr),
48}
49impl DynProofExpr {
50 #[must_use]
52 pub fn new_column(column_ref: ColumnRef) -> Self {
53 Self::Column(ColumnExpr::new(column_ref))
54 }
55 pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
57 lhs.check_data_type(ColumnType::Boolean)?;
58 rhs.check_data_type(ColumnType::Boolean)?;
59 Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
60 }
61 pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
63 lhs.check_data_type(ColumnType::Boolean)?;
64 rhs.check_data_type(ColumnType::Boolean)?;
65 Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
66 }
67 pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
69 expr.check_data_type(ColumnType::Boolean)?;
70 Ok(Self::Not(NotExpr::new(Box::new(expr))))
71 }
72 #[must_use]
74 pub fn new_literal(value: LiteralValue) -> Self {
75 Self::Literal(LiteralExpr::new(value))
76 }
77 pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
79 let lhs_datatype = lhs.data_type();
80 let rhs_datatype = rhs.data_type();
81 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Eq) {
82 Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
83 } else {
84 Err(AnalyzeError::DataTypeMismatch {
85 left_type: lhs_datatype.to_string(),
86 right_type: rhs_datatype.to_string(),
87 })
88 }
89 }
90 pub fn try_new_inequality(
92 lhs: DynProofExpr,
93 rhs: DynProofExpr,
94 is_lt: bool,
95 ) -> AnalyzeResult<Self> {
96 let lhs_datatype = lhs.data_type();
97 let rhs_datatype = rhs.data_type();
98 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) {
99 Ok(Self::Inequality(InequalityExpr::new(
100 Box::new(lhs),
101 Box::new(rhs),
102 is_lt,
103 )))
104 } else {
105 Err(AnalyzeError::DataTypeMismatch {
106 left_type: lhs_datatype.to_string(),
107 right_type: rhs_datatype.to_string(),
108 })
109 }
110 }
111
112 pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
114 let lhs_datatype = lhs.data_type();
115 let rhs_datatype = rhs.data_type();
116 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Plus) {
117 Ok(Self::AddSubtract(AddSubtractExpr::new(
118 Box::new(lhs),
119 Box::new(rhs),
120 false,
121 )))
122 } else {
123 Err(AnalyzeError::DataTypeMismatch {
124 left_type: lhs_datatype.to_string(),
125 right_type: rhs_datatype.to_string(),
126 })
127 }
128 }
129
130 pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
132 let lhs_datatype = lhs.data_type();
133 let rhs_datatype = rhs.data_type();
134 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Minus) {
135 Ok(Self::AddSubtract(AddSubtractExpr::new(
136 Box::new(lhs),
137 Box::new(rhs),
138 true,
139 )))
140 } else {
141 Err(AnalyzeError::DataTypeMismatch {
142 left_type: lhs_datatype.to_string(),
143 right_type: rhs_datatype.to_string(),
144 })
145 }
146 }
147
148 pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
150 let lhs_datatype = lhs.data_type();
151 let rhs_datatype = rhs.data_type();
152 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply) {
153 Ok(Self::Multiply(MultiplyExpr::new(
154 Box::new(lhs),
155 Box::new(rhs),
156 )))
157 } else {
158 Err(AnalyzeError::DataTypeMismatch {
159 left_type: lhs_datatype.to_string(),
160 right_type: rhs_datatype.to_string(),
161 })
162 }
163 }
164
165 pub fn try_new_cast(from_column: DynProofExpr, to_datatype: ColumnType) -> AnalyzeResult<Self> {
167 let from_datatype = from_column.data_type();
168 try_cast_types(from_datatype, to_datatype)
169 .map(|()| Self::Cast(CastExpr::new(Box::new(from_column), to_datatype)))
170 .map_err(|_| AnalyzeError::DataTypeMismatch {
171 left_type: from_datatype.to_string(),
172 right_type: to_datatype.to_string(),
173 })
174 }
175
176 fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
178 if self.data_type() == data_type {
179 Ok(())
180 } else {
181 Err(AnalyzeError::InvalidDataType {
182 actual: self.data_type(),
183 expected: data_type,
184 })
185 }
186 }
187}