proof_of_sql/sql/proof_exprs/
dyn_proof_expr.rs

1use super::{
2    cast_expr::CastExpr, AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr,
3    LiteralExpr, MultiplyExpr, NotExpr, OrExpr, PlaceholderExpr, ProofExpr,
4};
5use crate::{
6    base::{
7        database::{try_cast_types, Column, ColumnRef, ColumnType, LiteralValue, Table},
8        map::{IndexMap, IndexSet},
9        proof::{PlaceholderResult, ProofError},
10        scalar::Scalar,
11    },
12    sql::{
13        proof::{FinalRoundBuilder, VerificationBuilder},
14        util::type_check_binary_operation,
15        AnalyzeError, AnalyzeResult,
16    },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use serde::{Deserialize, Serialize};
22use sqlparser::ast::BinaryOperator;
23
24/// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
25#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
26#[enum_dispatch::enum_dispatch]
27pub enum DynProofExpr {
28    /// Column
29    Column(ColumnExpr),
30    /// Provable logical AND expression
31    And(AndExpr),
32    /// Provable logical OR expression
33    Or(OrExpr),
34    /// Provable logical NOT expression
35    Not(NotExpr),
36    /// Provable CONST expression
37    Literal(LiteralExpr),
38    /// Provable placeholder expression
39    Placeholder(PlaceholderExpr),
40    /// Provable AST expression for an equals expression
41    Equals(EqualsExpr),
42    /// Provable AST expression for an inequality expression
43    Inequality(InequalityExpr),
44    /// Provable numeric `+` / `-` expression
45    AddSubtract(AddSubtractExpr),
46    /// Provable numeric `*` expression
47    Multiply(MultiplyExpr),
48    /// Provable CAST expression
49    Cast(CastExpr),
50}
51impl DynProofExpr {
52    /// Create column expression
53    #[must_use]
54    pub fn new_column(column_ref: ColumnRef) -> Self {
55        Self::Column(ColumnExpr::new(column_ref))
56    }
57    /// Create logical AND expression
58    pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
59        lhs.check_data_type(ColumnType::Boolean)?;
60        rhs.check_data_type(ColumnType::Boolean)?;
61        Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
62    }
63    /// Create logical OR expression
64    pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
65        lhs.check_data_type(ColumnType::Boolean)?;
66        rhs.check_data_type(ColumnType::Boolean)?;
67        Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
68    }
69    /// Create logical NOT expression
70    pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
71        expr.check_data_type(ColumnType::Boolean)?;
72        Ok(Self::Not(NotExpr::new(Box::new(expr))))
73    }
74    /// Create CONST expression
75    #[must_use]
76    pub fn new_literal(value: LiteralValue) -> Self {
77        Self::Literal(LiteralExpr::new(value))
78    }
79    /// Create placeholder expression
80    pub fn try_new_placeholder(id: usize, column_type: ColumnType) -> AnalyzeResult<Self> {
81        Ok(Self::Placeholder(PlaceholderExpr::try_new(
82            id,
83            column_type,
84        )?))
85    }
86    /// Create a new equals expression
87    pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
88        let lhs_datatype = lhs.data_type();
89        let rhs_datatype = rhs.data_type();
90        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Eq) {
91            Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
92        } else {
93            Err(AnalyzeError::DataTypeMismatch {
94                left_type: lhs_datatype.to_string(),
95                right_type: rhs_datatype.to_string(),
96            })
97        }
98    }
99    /// Create a new inequality expression
100    pub fn try_new_inequality(
101        lhs: DynProofExpr,
102        rhs: DynProofExpr,
103        is_lt: bool,
104    ) -> AnalyzeResult<Self> {
105        let lhs_datatype = lhs.data_type();
106        let rhs_datatype = rhs.data_type();
107        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) {
108            Ok(Self::Inequality(InequalityExpr::new(
109                Box::new(lhs),
110                Box::new(rhs),
111                is_lt,
112            )))
113        } else {
114            Err(AnalyzeError::DataTypeMismatch {
115                left_type: lhs_datatype.to_string(),
116                right_type: rhs_datatype.to_string(),
117            })
118        }
119    }
120
121    /// Create a new add expression
122    pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
123        let lhs_datatype = lhs.data_type();
124        let rhs_datatype = rhs.data_type();
125        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Plus) {
126            Ok(Self::AddSubtract(AddSubtractExpr::new(
127                Box::new(lhs),
128                Box::new(rhs),
129                false,
130            )))
131        } else {
132            Err(AnalyzeError::DataTypeMismatch {
133                left_type: lhs_datatype.to_string(),
134                right_type: rhs_datatype.to_string(),
135            })
136        }
137    }
138
139    /// Create a new subtract expression
140    pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
141        let lhs_datatype = lhs.data_type();
142        let rhs_datatype = rhs.data_type();
143        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Minus) {
144            Ok(Self::AddSubtract(AddSubtractExpr::new(
145                Box::new(lhs),
146                Box::new(rhs),
147                true,
148            )))
149        } else {
150            Err(AnalyzeError::DataTypeMismatch {
151                left_type: lhs_datatype.to_string(),
152                right_type: rhs_datatype.to_string(),
153            })
154        }
155    }
156
157    /// Create a new multiply expression
158    pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
159        let lhs_datatype = lhs.data_type();
160        let rhs_datatype = rhs.data_type();
161        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply) {
162            Ok(Self::Multiply(MultiplyExpr::new(
163                Box::new(lhs),
164                Box::new(rhs),
165            )))
166        } else {
167            Err(AnalyzeError::DataTypeMismatch {
168                left_type: lhs_datatype.to_string(),
169                right_type: rhs_datatype.to_string(),
170            })
171        }
172    }
173
174    /// Create a new cast expression
175    pub fn try_new_cast(from_column: DynProofExpr, to_datatype: ColumnType) -> AnalyzeResult<Self> {
176        let from_datatype = from_column.data_type();
177        try_cast_types(from_datatype, to_datatype)
178            .map(|()| Self::Cast(CastExpr::new(Box::new(from_column), to_datatype)))
179            .map_err(|_| AnalyzeError::DataTypeMismatch {
180                left_type: from_datatype.to_string(),
181                right_type: to_datatype.to_string(),
182            })
183    }
184
185    /// Check that the plan has the correct data type
186    fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
187        if self.data_type() == data_type {
188            Ok(())
189        } else {
190            Err(AnalyzeError::InvalidDataType {
191                actual: self.data_type(),
192                expected: data_type,
193            })
194        }
195    }
196}