proof_of_sql/sql/proof_exprs/
dyn_proof_expr.rs1use super::{
2 AddExpr, AndExpr, CastExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr,
3 NotExpr, OrExpr, PlaceholderExpr, ProofExpr, ScalingCastExpr, SubtractExpr,
4};
5use crate::{
6 base::{
7 database::{try_cast_types, Column, ColumnRef, ColumnType, LiteralValue, Table},
8 map::{IndexMap, IndexSet},
9 proof::{PlaceholderResult, ProofError},
10 scalar::Scalar,
11 },
12 sql::{
13 proof::{FinalRoundBuilder, VerificationBuilder},
14 util::try_binary_operation_type,
15 AnalyzeError, AnalyzeResult,
16 },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use serde::{Deserialize, Serialize};
22use sqlparser::ast::BinaryOperator;
23
// NOTE(review): the trait dispatched by `enum_dispatch` is presumably `ProofExpr`
// (imported at the top of this file) — confirm against the trait definition.
// Variant order is deliberately left untouched: `Serialize`/`Deserialize` are derived,
// so reordering variants may change the encoded form.
/// Closed set of expression node kinds that can appear in a proof expression tree.
///
/// Every variant wraps the concrete expression struct of the same kind; calls on the
/// enum are forwarded to the wrapped value through `enum_dispatch`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[enum_dispatch::enum_dispatch]
pub enum DynProofExpr {
    /// Reference to a table column (see [`ColumnExpr`]).
    Column(ColumnExpr),
    /// Logical AND of two boolean expressions (see [`AndExpr`]).
    And(AndExpr),
    /// Logical OR of two boolean expressions (see [`OrExpr`]).
    Or(OrExpr),
    /// Logical NOT of a boolean expression (see [`NotExpr`]).
    Not(NotExpr),
    /// Literal value (see [`LiteralExpr`]).
    Literal(LiteralExpr),
    /// Typed placeholder identified by a numeric id (see [`PlaceholderExpr`]).
    Placeholder(PlaceholderExpr),
    /// Equality comparison of two expressions (see [`EqualsExpr`]).
    Equals(EqualsExpr),
    /// Inequality comparison of two expressions (see [`InequalityExpr`]).
    Inequality(InequalityExpr),
    /// Addition of two expressions (see [`AddExpr`]).
    Add(AddExpr),
    /// Subtraction of two expressions (see [`SubtractExpr`]).
    Subtract(SubtractExpr),
    /// Multiplication of two expressions (see [`MultiplyExpr`]).
    Multiply(MultiplyExpr),
    /// Cast of an expression to another data type (see [`CastExpr`]).
    Cast(CastExpr),
    /// Scaling cast of an expression to another data type (see [`ScalingCastExpr`]).
    ScalingCast(ScalingCastExpr),
}
55impl DynProofExpr {
56 #[must_use]
58 pub fn new_column(column_ref: ColumnRef) -> Self {
59 Self::Column(ColumnExpr::new(column_ref))
60 }
61 pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
63 lhs.check_data_type(ColumnType::Boolean)?;
64 rhs.check_data_type(ColumnType::Boolean)?;
65 Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
66 }
67 pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
69 lhs.check_data_type(ColumnType::Boolean)?;
70 rhs.check_data_type(ColumnType::Boolean)?;
71 Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
72 }
73 pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
75 expr.check_data_type(ColumnType::Boolean)?;
76 Ok(Self::Not(NotExpr::new(Box::new(expr))))
77 }
78 #[must_use]
80 pub fn new_literal(value: LiteralValue) -> Self {
81 Self::Literal(LiteralExpr::new(value))
82 }
83 pub fn try_new_placeholder(id: usize, column_type: ColumnType) -> AnalyzeResult<Self> {
85 Ok(Self::Placeholder(PlaceholderExpr::try_new(
86 id,
87 column_type,
88 )?))
89 }
90 pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
92 let lhs_datatype = lhs.data_type();
93 let rhs_datatype = rhs.data_type();
94 if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Eq).is_some() {
95 Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
96 } else {
97 Err(AnalyzeError::DataTypeMismatch {
98 left_type: lhs_datatype.to_string(),
99 right_type: rhs_datatype.to_string(),
100 })
101 }
102 }
103 pub fn try_new_inequality(
105 lhs: DynProofExpr,
106 rhs: DynProofExpr,
107 is_lt: bool,
108 ) -> AnalyzeResult<Self> {
109 let lhs_datatype = lhs.data_type();
110 let rhs_datatype = rhs.data_type();
111 if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Lt).is_some() {
112 Ok(Self::Inequality(InequalityExpr::new(
113 Box::new(lhs),
114 Box::new(rhs),
115 is_lt,
116 )))
117 } else {
118 Err(AnalyzeError::DataTypeMismatch {
119 left_type: lhs_datatype.to_string(),
120 right_type: rhs_datatype.to_string(),
121 })
122 }
123 }
124
125 pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
127 let lhs_datatype = lhs.data_type();
128 let rhs_datatype = rhs.data_type();
129 if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Plus).is_some() {
130 Ok(Self::Add(AddExpr::new(Box::new(lhs), Box::new(rhs))))
131 } else {
132 Err(AnalyzeError::DataTypeMismatch {
133 left_type: lhs_datatype.to_string(),
134 right_type: rhs_datatype.to_string(),
135 })
136 }
137 }
138
139 pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
141 let lhs_datatype = lhs.data_type();
142 let rhs_datatype = rhs.data_type();
143 if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Minus).is_some() {
144 Ok(Self::Subtract(SubtractExpr::new(
145 Box::new(lhs),
146 Box::new(rhs),
147 )))
148 } else {
149 Err(AnalyzeError::DataTypeMismatch {
150 left_type: lhs_datatype.to_string(),
151 right_type: rhs_datatype.to_string(),
152 })
153 }
154 }
155
156 pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
158 let lhs_datatype = lhs.data_type();
159 let rhs_datatype = rhs.data_type();
160 if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply)
161 .is_some()
162 {
163 Ok(Self::Multiply(MultiplyExpr::new(
164 Box::new(lhs),
165 Box::new(rhs),
166 )))
167 } else {
168 Err(AnalyzeError::DataTypeMismatch {
169 left_type: lhs_datatype.to_string(),
170 right_type: rhs_datatype.to_string(),
171 })
172 }
173 }
174
175 pub fn try_new_cast(from_column: DynProofExpr, to_datatype: ColumnType) -> AnalyzeResult<Self> {
177 let from_datatype = from_column.data_type();
178 try_cast_types(from_datatype, to_datatype)
179 .map(|()| Self::Cast(CastExpr::new(Box::new(from_column), to_datatype)))
180 .map_err(|_| AnalyzeError::DataTypeMismatch {
181 left_type: from_datatype.to_string(),
182 right_type: to_datatype.to_string(),
183 })
184 }
185
186 pub fn try_new_scaling_cast(
188 from_expr: DynProofExpr,
189 to_datatype: ColumnType,
190 ) -> AnalyzeResult<Self> {
191 let from_datatype = from_expr.data_type();
192 ScalingCastExpr::try_new(Box::new(from_expr), to_datatype)
193 .map(DynProofExpr::ScalingCast)
194 .map_err(|_| AnalyzeError::DataTypeMismatch {
195 left_type: from_datatype.to_string(),
196 right_type: to_datatype.to_string(),
197 })
198 }
199
200 fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
202 if self.data_type() == data_type {
203 Ok(())
204 } else {
205 Err(AnalyzeError::InvalidDataType {
206 actual: self.data_type(),
207 expected: data_type,
208 })
209 }
210 }
211}