proof_of_sql/sql/proof_exprs/
dyn_proof_expr.rs1use super::{
2 AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr,
3 NotExpr, OrExpr, ProofExpr,
4};
5use crate::{
6 base::{
7 database::{Column, ColumnRef, ColumnType, LiteralValue, Table},
8 map::{IndexMap, IndexSet},
9 proof::ProofError,
10 scalar::Scalar,
11 },
12 sql::{
13 proof::{FinalRoundBuilder, VerificationBuilder},
14 util::type_check_binary_operation,
15 AnalyzeError, AnalyzeResult,
16 },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use serde::{Deserialize, Serialize};
22use sqlparser::ast::BinaryOperator;
23
/// Closed set of proof expression node types, wrapped in a single enum.
///
/// The `enum_dispatch` attribute is presumably paired with the `ProofExpr` trait so that
/// trait methods (e.g. `data_type`, which the constructors in this file call on
/// `DynProofExpr`) are forwarded to the wrapped expression — confirm against the trait
/// definition.
///
/// NOTE: this type derives `Serialize`/`Deserialize`; reordering or renaming variants may
/// change the serialized representation (serde encodes variant indices/names), so treat
/// the variant list as append-only unless a format break is intended.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[enum_dispatch::enum_dispatch]
pub enum DynProofExpr {
    /// Column reference expression.
    Column(ColumnExpr),
    /// Logical `AND` of two boolean expressions.
    And(AndExpr),
    /// Logical `OR` of two boolean expressions.
    Or(OrExpr),
    /// Logical `NOT` of a boolean expression.
    Not(NotExpr),
    /// Literal (constant) value expression.
    Literal(LiteralExpr),
    /// Equality (`=`) comparison expression.
    Equals(EqualsExpr),
    /// Inequality comparison expression (direction chosen by its `is_lt` flag).
    Inequality(InequalityExpr),
    /// Addition or subtraction expression (shared node; a flag selects subtraction).
    AddSubtract(AddSubtractExpr),
    /// Multiplication expression.
    Multiply(MultiplyExpr),
}
47impl DynProofExpr {
48 #[must_use]
50 pub fn new_column(column_ref: ColumnRef) -> Self {
51 Self::Column(ColumnExpr::new(column_ref))
52 }
53 pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
55 lhs.check_data_type(ColumnType::Boolean)?;
56 rhs.check_data_type(ColumnType::Boolean)?;
57 Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
58 }
59 pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
61 lhs.check_data_type(ColumnType::Boolean)?;
62 rhs.check_data_type(ColumnType::Boolean)?;
63 Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
64 }
65 pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
67 expr.check_data_type(ColumnType::Boolean)?;
68 Ok(Self::Not(NotExpr::new(Box::new(expr))))
69 }
70 #[must_use]
72 pub fn new_literal(value: LiteralValue) -> Self {
73 Self::Literal(LiteralExpr::new(value))
74 }
75 pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
77 let lhs_datatype = lhs.data_type();
78 let rhs_datatype = rhs.data_type();
79 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Eq) {
80 Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
81 } else {
82 Err(AnalyzeError::DataTypeMismatch {
83 left_type: lhs_datatype.to_string(),
84 right_type: rhs_datatype.to_string(),
85 })
86 }
87 }
88 pub fn try_new_inequality(
90 lhs: DynProofExpr,
91 rhs: DynProofExpr,
92 is_lt: bool,
93 ) -> AnalyzeResult<Self> {
94 let lhs_datatype = lhs.data_type();
95 let rhs_datatype = rhs.data_type();
96 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) {
97 Ok(Self::Inequality(InequalityExpr::new(
98 Box::new(lhs),
99 Box::new(rhs),
100 is_lt,
101 )))
102 } else {
103 Err(AnalyzeError::DataTypeMismatch {
104 left_type: lhs_datatype.to_string(),
105 right_type: rhs_datatype.to_string(),
106 })
107 }
108 }
109
110 pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
112 let lhs_datatype = lhs.data_type();
113 let rhs_datatype = rhs.data_type();
114 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Plus) {
115 Ok(Self::AddSubtract(AddSubtractExpr::new(
116 Box::new(lhs),
117 Box::new(rhs),
118 false,
119 )))
120 } else {
121 Err(AnalyzeError::DataTypeMismatch {
122 left_type: lhs_datatype.to_string(),
123 right_type: rhs_datatype.to_string(),
124 })
125 }
126 }
127
128 pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
130 let lhs_datatype = lhs.data_type();
131 let rhs_datatype = rhs.data_type();
132 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Minus) {
133 Ok(Self::AddSubtract(AddSubtractExpr::new(
134 Box::new(lhs),
135 Box::new(rhs),
136 true,
137 )))
138 } else {
139 Err(AnalyzeError::DataTypeMismatch {
140 left_type: lhs_datatype.to_string(),
141 right_type: rhs_datatype.to_string(),
142 })
143 }
144 }
145
146 pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
148 let lhs_datatype = lhs.data_type();
149 let rhs_datatype = rhs.data_type();
150 if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply) {
151 Ok(Self::Multiply(MultiplyExpr::new(
152 Box::new(lhs),
153 Box::new(rhs),
154 )))
155 } else {
156 Err(AnalyzeError::DataTypeMismatch {
157 left_type: lhs_datatype.to_string(),
158 right_type: rhs_datatype.to_string(),
159 })
160 }
161 }
162
163 fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
165 if self.data_type() == data_type {
166 Ok(())
167 } else {
168 Err(AnalyzeError::InvalidDataType {
169 actual: self.data_type(),
170 expected: data_type,
171 })
172 }
173 }
174}