// proof_of_sql/sql/proof_exprs/dyn_proof_expr.rs

1use super::{
2    AddSubtractExpr, AggregateExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr,
3    MultiplyExpr, NotExpr, OrExpr, ProofExpr,
4};
5use crate::{
6    base::{
7        database::{Column, ColumnRef, ColumnType, LiteralValue, Table},
8        map::{IndexMap, IndexSet},
9        proof::ProofError,
10        scalar::Scalar,
11    },
12    sql::{
13        proof::{FinalRoundBuilder, VerificationBuilder},
14        util::type_check_binary_operation,
15        AnalyzeError, AnalyzeResult,
16    },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use proof_of_sql_parser::intermediate_ast::AggregationOperator;
22use serde::{Deserialize, Serialize};
23use sqlparser::ast::BinaryOperator;
24
25/// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
26#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
27#[enum_dispatch::enum_dispatch]
28pub enum DynProofExpr {
29    /// Column
30    Column(ColumnExpr),
31    /// Provable logical AND expression
32    And(AndExpr),
33    /// Provable logical OR expression
34    Or(OrExpr),
35    /// Provable logical NOT expression
36    Not(NotExpr),
37    /// Provable CONST expression
38    Literal(LiteralExpr),
39    /// Provable AST expression for an equals expression
40    Equals(EqualsExpr),
41    /// Provable AST expression for an inequality expression
42    Inequality(InequalityExpr),
43    /// Provable numeric `+` / `-` expression
44    AddSubtract(AddSubtractExpr),
45    /// Provable numeric `*` expression
46    Multiply(MultiplyExpr),
47    /// Provable aggregate expression
48    Aggregate(AggregateExpr),
49}
50impl DynProofExpr {
51    /// Create column expression
52    #[must_use]
53    pub fn new_column(column_ref: ColumnRef) -> Self {
54        Self::Column(ColumnExpr::new(column_ref))
55    }
56    /// Create logical AND expression
57    pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
58        lhs.check_data_type(ColumnType::Boolean)?;
59        rhs.check_data_type(ColumnType::Boolean)?;
60        Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
61    }
62    /// Create logical OR expression
63    pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
64        lhs.check_data_type(ColumnType::Boolean)?;
65        rhs.check_data_type(ColumnType::Boolean)?;
66        Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
67    }
68    /// Create logical NOT expression
69    pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
70        expr.check_data_type(ColumnType::Boolean)?;
71        Ok(Self::Not(NotExpr::new(Box::new(expr))))
72    }
73    /// Create CONST expression
74    #[must_use]
75    pub fn new_literal(value: LiteralValue) -> Self {
76        Self::Literal(LiteralExpr::new(value))
77    }
78    /// Create a new equals expression
79    pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
80        let lhs_datatype = lhs.data_type();
81        let rhs_datatype = rhs.data_type();
82        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Eq) {
83            Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
84        } else {
85            Err(AnalyzeError::DataTypeMismatch {
86                left_type: lhs_datatype.to_string(),
87                right_type: rhs_datatype.to_string(),
88            })
89        }
90    }
91    /// Create a new inequality expression
92    pub fn try_new_inequality(
93        lhs: DynProofExpr,
94        rhs: DynProofExpr,
95        is_lt: bool,
96    ) -> AnalyzeResult<Self> {
97        let lhs_datatype = lhs.data_type();
98        let rhs_datatype = rhs.data_type();
99        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) {
100            Ok(Self::Inequality(InequalityExpr::new(
101                Box::new(lhs),
102                Box::new(rhs),
103                is_lt,
104            )))
105        } else {
106            Err(AnalyzeError::DataTypeMismatch {
107                left_type: lhs_datatype.to_string(),
108                right_type: rhs_datatype.to_string(),
109            })
110        }
111    }
112
113    /// Create a new add expression
114    pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
115        let lhs_datatype = lhs.data_type();
116        let rhs_datatype = rhs.data_type();
117        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Plus) {
118            Ok(Self::AddSubtract(AddSubtractExpr::new(
119                Box::new(lhs),
120                Box::new(rhs),
121                false,
122            )))
123        } else {
124            Err(AnalyzeError::DataTypeMismatch {
125                left_type: lhs_datatype.to_string(),
126                right_type: rhs_datatype.to_string(),
127            })
128        }
129    }
130
131    /// Create a new subtract expression
132    pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
133        let lhs_datatype = lhs.data_type();
134        let rhs_datatype = rhs.data_type();
135        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Minus) {
136            Ok(Self::AddSubtract(AddSubtractExpr::new(
137                Box::new(lhs),
138                Box::new(rhs),
139                true,
140            )))
141        } else {
142            Err(AnalyzeError::DataTypeMismatch {
143                left_type: lhs_datatype.to_string(),
144                right_type: rhs_datatype.to_string(),
145            })
146        }
147    }
148
149    /// Create a new multiply expression
150    pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
151        let lhs_datatype = lhs.data_type();
152        let rhs_datatype = rhs.data_type();
153        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply) {
154            Ok(Self::Multiply(MultiplyExpr::new(
155                Box::new(lhs),
156                Box::new(rhs),
157            )))
158        } else {
159            Err(AnalyzeError::DataTypeMismatch {
160                left_type: lhs_datatype.to_string(),
161                right_type: rhs_datatype.to_string(),
162            })
163        }
164    }
165
166    /// Create a new aggregate expression
167    #[must_use]
168    pub fn new_aggregate(op: AggregationOperator, expr: DynProofExpr) -> Self {
169        Self::Aggregate(AggregateExpr::new(op, Box::new(expr)))
170    }
171
172    /// Check that the plan has the correct data type
173    fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
174        if self.data_type() == data_type {
175            Ok(())
176        } else {
177            Err(AnalyzeError::InvalidDataType {
178                actual: self.data_type(),
179                expected: data_type,
180            })
181        }
182    }
183}