proof_of_sql/sql/proof_exprs/
dyn_proof_expr.rs

1use super::{
2    AddExpr, AndExpr, CastExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr,
3    NotExpr, OrExpr, PlaceholderExpr, ProofExpr, ScalingCastExpr, SubtractExpr,
4};
5use crate::{
6    base::{
7        database::{try_cast_types, Column, ColumnRef, ColumnType, LiteralValue, Table},
8        map::{IndexMap, IndexSet},
9        proof::{PlaceholderResult, ProofError},
10        scalar::Scalar,
11    },
12    sql::{
13        proof::{FinalRoundBuilder, VerificationBuilder},
14        util::try_binary_operation_type,
15        AnalyzeError, AnalyzeResult,
16    },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use serde::{Deserialize, Serialize};
22use sqlparser::ast::BinaryOperator;
23
24/// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
25#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
26#[enum_dispatch::enum_dispatch]
27pub enum DynProofExpr {
28    /// Column
29    Column(ColumnExpr),
30    /// Provable logical AND expression
31    And(AndExpr),
32    /// Provable logical OR expression
33    Or(OrExpr),
34    /// Provable logical NOT expression
35    Not(NotExpr),
36    /// Provable CONST expression
37    Literal(LiteralExpr),
38    /// Provable placeholder expression
39    Placeholder(PlaceholderExpr),
40    /// Provable AST expression for an equals expression
41    Equals(EqualsExpr),
42    /// Provable AST expression for an inequality expression
43    Inequality(InequalityExpr),
44    /// Provable numeric `+` expression
45    Add(AddExpr),
46    /// Provable numeric `-` expression
47    Subtract(SubtractExpr),
48    /// Provable numeric `*` expression
49    Multiply(MultiplyExpr),
50    /// Provable CAST expression
51    Cast(CastExpr),
52    /// Provable expression for casting numeric expressions to decimal expressions
53    ScalingCast(ScalingCastExpr),
54}
55impl DynProofExpr {
56    /// Create column expression
57    #[must_use]
58    pub fn new_column(column_ref: ColumnRef) -> Self {
59        Self::Column(ColumnExpr::new(column_ref))
60    }
61    /// Create logical AND expression
62    pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
63        lhs.check_data_type(ColumnType::Boolean)?;
64        rhs.check_data_type(ColumnType::Boolean)?;
65        Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
66    }
67    /// Create logical OR expression
68    pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
69        lhs.check_data_type(ColumnType::Boolean)?;
70        rhs.check_data_type(ColumnType::Boolean)?;
71        Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
72    }
73    /// Create logical NOT expression
74    pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
75        expr.check_data_type(ColumnType::Boolean)?;
76        Ok(Self::Not(NotExpr::new(Box::new(expr))))
77    }
78    /// Create CONST expression
79    #[must_use]
80    pub fn new_literal(value: LiteralValue) -> Self {
81        Self::Literal(LiteralExpr::new(value))
82    }
83    /// Create placeholder expression
84    pub fn try_new_placeholder(id: usize, column_type: ColumnType) -> AnalyzeResult<Self> {
85        Ok(Self::Placeholder(PlaceholderExpr::try_new(
86            id,
87            column_type,
88        )?))
89    }
90    /// Create a new equals expression
91    pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
92        let lhs_datatype = lhs.data_type();
93        let rhs_datatype = rhs.data_type();
94        if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Eq).is_some() {
95            Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
96        } else {
97            Err(AnalyzeError::DataTypeMismatch {
98                left_type: lhs_datatype.to_string(),
99                right_type: rhs_datatype.to_string(),
100            })
101        }
102    }
103    /// Create a new inequality expression
104    pub fn try_new_inequality(
105        lhs: DynProofExpr,
106        rhs: DynProofExpr,
107        is_lt: bool,
108    ) -> AnalyzeResult<Self> {
109        let lhs_datatype = lhs.data_type();
110        let rhs_datatype = rhs.data_type();
111        if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Lt).is_some() {
112            Ok(Self::Inequality(InequalityExpr::new(
113                Box::new(lhs),
114                Box::new(rhs),
115                is_lt,
116            )))
117        } else {
118            Err(AnalyzeError::DataTypeMismatch {
119                left_type: lhs_datatype.to_string(),
120                right_type: rhs_datatype.to_string(),
121            })
122        }
123    }
124
125    /// Create a new add expression
126    pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
127        let lhs_datatype = lhs.data_type();
128        let rhs_datatype = rhs.data_type();
129        if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Plus).is_some() {
130            Ok(Self::Add(AddExpr::new(Box::new(lhs), Box::new(rhs))))
131        } else {
132            Err(AnalyzeError::DataTypeMismatch {
133                left_type: lhs_datatype.to_string(),
134                right_type: rhs_datatype.to_string(),
135            })
136        }
137    }
138
139    /// Create a new subtract expression
140    pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
141        let lhs_datatype = lhs.data_type();
142        let rhs_datatype = rhs.data_type();
143        if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Minus).is_some() {
144            Ok(Self::Subtract(SubtractExpr::new(
145                Box::new(lhs),
146                Box::new(rhs),
147            )))
148        } else {
149            Err(AnalyzeError::DataTypeMismatch {
150                left_type: lhs_datatype.to_string(),
151                right_type: rhs_datatype.to_string(),
152            })
153        }
154    }
155
156    /// Create a new multiply expression
157    pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
158        let lhs_datatype = lhs.data_type();
159        let rhs_datatype = rhs.data_type();
160        if try_binary_operation_type(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply)
161            .is_some()
162        {
163            Ok(Self::Multiply(MultiplyExpr::new(
164                Box::new(lhs),
165                Box::new(rhs),
166            )))
167        } else {
168            Err(AnalyzeError::DataTypeMismatch {
169                left_type: lhs_datatype.to_string(),
170                right_type: rhs_datatype.to_string(),
171            })
172        }
173    }
174
175    /// Create a new cast expression
176    pub fn try_new_cast(from_column: DynProofExpr, to_datatype: ColumnType) -> AnalyzeResult<Self> {
177        let from_datatype = from_column.data_type();
178        try_cast_types(from_datatype, to_datatype)
179            .map(|()| Self::Cast(CastExpr::new(Box::new(from_column), to_datatype)))
180            .map_err(|_| AnalyzeError::DataTypeMismatch {
181                left_type: from_datatype.to_string(),
182                right_type: to_datatype.to_string(),
183            })
184    }
185
186    /// Create a new decimal scale cast expression
187    pub fn try_new_scaling_cast(
188        from_expr: DynProofExpr,
189        to_datatype: ColumnType,
190    ) -> AnalyzeResult<Self> {
191        let from_datatype = from_expr.data_type();
192        ScalingCastExpr::try_new(Box::new(from_expr), to_datatype)
193            .map(DynProofExpr::ScalingCast)
194            .map_err(|_| AnalyzeError::DataTypeMismatch {
195                left_type: from_datatype.to_string(),
196                right_type: to_datatype.to_string(),
197            })
198    }
199
200    /// Check that the plan has the correct data type
201    fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
202        if self.data_type() == data_type {
203            Ok(())
204        } else {
205            Err(AnalyzeError::InvalidDataType {
206                actual: self.data_type(),
207                expected: data_type,
208            })
209        }
210    }
211}