proof_of_sql/sql/proof_exprs/
dyn_proof_expr.rs

1use super::{
2    cast_expr::CastExpr, AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr,
3    LiteralExpr, MultiplyExpr, NotExpr, OrExpr, ProofExpr,
4};
5use crate::{
6    base::{
7        database::{try_cast_types, Column, ColumnRef, ColumnType, LiteralValue, Table},
8        map::{IndexMap, IndexSet},
9        proof::ProofError,
10        scalar::Scalar,
11    },
12    sql::{
13        proof::{FinalRoundBuilder, VerificationBuilder},
14        util::type_check_binary_operation,
15        AnalyzeError, AnalyzeResult,
16    },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use serde::{Deserialize, Serialize};
22use sqlparser::ast::BinaryOperator;
23
/// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
///
/// Each variant wraps exactly one concrete expression type. Values are normally built
/// through the `new_*` / `try_new_*` constructors on this type, which type-check the
/// operands before wrapping them.
///
/// NOTE(review): `Serialize` / `Deserialize` are derived, so reordering or removing
/// variants may change the encoded form for index-based formats — confirm before
/// touching the variant order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[enum_dispatch::enum_dispatch]
pub enum DynProofExpr {
    /// Column reference expression
    Column(ColumnExpr),
    /// Provable logical AND expression
    And(AndExpr),
    /// Provable logical OR expression
    Or(OrExpr),
    /// Provable logical NOT expression
    Not(NotExpr),
    /// Provable constant (literal value) expression
    Literal(LiteralExpr),
    /// Provable AST expression for an equals expression
    Equals(EqualsExpr),
    /// Provable AST expression for an inequality expression
    Inequality(InequalityExpr),
    /// Provable numeric `+` / `-` expression (one variant backs both operations)
    AddSubtract(AddSubtractExpr),
    /// Provable numeric `*` expression
    Multiply(MultiplyExpr),
    /// Provable CAST expression
    Cast(CastExpr),
}
49impl DynProofExpr {
50    /// Create column expression
51    #[must_use]
52    pub fn new_column(column_ref: ColumnRef) -> Self {
53        Self::Column(ColumnExpr::new(column_ref))
54    }
55    /// Create logical AND expression
56    pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
57        lhs.check_data_type(ColumnType::Boolean)?;
58        rhs.check_data_type(ColumnType::Boolean)?;
59        Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
60    }
61    /// Create logical OR expression
62    pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
63        lhs.check_data_type(ColumnType::Boolean)?;
64        rhs.check_data_type(ColumnType::Boolean)?;
65        Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
66    }
67    /// Create logical NOT expression
68    pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
69        expr.check_data_type(ColumnType::Boolean)?;
70        Ok(Self::Not(NotExpr::new(Box::new(expr))))
71    }
72    /// Create CONST expression
73    #[must_use]
74    pub fn new_literal(value: LiteralValue) -> Self {
75        Self::Literal(LiteralExpr::new(value))
76    }
77    /// Create a new equals expression
78    pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
79        let lhs_datatype = lhs.data_type();
80        let rhs_datatype = rhs.data_type();
81        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Eq) {
82            Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
83        } else {
84            Err(AnalyzeError::DataTypeMismatch {
85                left_type: lhs_datatype.to_string(),
86                right_type: rhs_datatype.to_string(),
87            })
88        }
89    }
90    /// Create a new inequality expression
91    pub fn try_new_inequality(
92        lhs: DynProofExpr,
93        rhs: DynProofExpr,
94        is_lt: bool,
95    ) -> AnalyzeResult<Self> {
96        let lhs_datatype = lhs.data_type();
97        let rhs_datatype = rhs.data_type();
98        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) {
99            Ok(Self::Inequality(InequalityExpr::new(
100                Box::new(lhs),
101                Box::new(rhs),
102                is_lt,
103            )))
104        } else {
105            Err(AnalyzeError::DataTypeMismatch {
106                left_type: lhs_datatype.to_string(),
107                right_type: rhs_datatype.to_string(),
108            })
109        }
110    }
111
112    /// Create a new add expression
113    pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
114        let lhs_datatype = lhs.data_type();
115        let rhs_datatype = rhs.data_type();
116        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Plus) {
117            Ok(Self::AddSubtract(AddSubtractExpr::new(
118                Box::new(lhs),
119                Box::new(rhs),
120                false,
121            )))
122        } else {
123            Err(AnalyzeError::DataTypeMismatch {
124                left_type: lhs_datatype.to_string(),
125                right_type: rhs_datatype.to_string(),
126            })
127        }
128    }
129
130    /// Create a new subtract expression
131    pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
132        let lhs_datatype = lhs.data_type();
133        let rhs_datatype = rhs.data_type();
134        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Minus) {
135            Ok(Self::AddSubtract(AddSubtractExpr::new(
136                Box::new(lhs),
137                Box::new(rhs),
138                true,
139            )))
140        } else {
141            Err(AnalyzeError::DataTypeMismatch {
142                left_type: lhs_datatype.to_string(),
143                right_type: rhs_datatype.to_string(),
144            })
145        }
146    }
147
148    /// Create a new multiply expression
149    pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
150        let lhs_datatype = lhs.data_type();
151        let rhs_datatype = rhs.data_type();
152        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply) {
153            Ok(Self::Multiply(MultiplyExpr::new(
154                Box::new(lhs),
155                Box::new(rhs),
156            )))
157        } else {
158            Err(AnalyzeError::DataTypeMismatch {
159                left_type: lhs_datatype.to_string(),
160                right_type: rhs_datatype.to_string(),
161            })
162        }
163    }
164
165    /// Create a new cast expression
166    pub fn try_new_cast(from_column: DynProofExpr, to_datatype: ColumnType) -> AnalyzeResult<Self> {
167        let from_datatype = from_column.data_type();
168        try_cast_types(from_datatype, to_datatype)
169            .map(|()| Self::Cast(CastExpr::new(Box::new(from_column), to_datatype)))
170            .map_err(|_| AnalyzeError::DataTypeMismatch {
171                left_type: from_datatype.to_string(),
172                right_type: to_datatype.to_string(),
173            })
174    }
175
176    /// Check that the plan has the correct data type
177    fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
178        if self.data_type() == data_type {
179            Ok(())
180        } else {
181            Err(AnalyzeError::InvalidDataType {
182                actual: self.data_type(),
183                expected: data_type,
184            })
185        }
186    }
187}