// proof_of_sql/sql/proof_exprs/dyn_proof_expr.rs

1use super::{
2    AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr,
3    NotExpr, OrExpr, ProofExpr,
4};
5use crate::{
6    base::{
7        database::{Column, ColumnRef, ColumnType, LiteralValue, Table},
8        map::{IndexMap, IndexSet},
9        proof::ProofError,
10        scalar::Scalar,
11    },
12    sql::{
13        proof::{FinalRoundBuilder, VerificationBuilder},
14        util::type_check_binary_operation,
15        AnalyzeError, AnalyzeResult,
16    },
17};
18use alloc::{boxed::Box, string::ToString};
19use bumpalo::Bump;
20use core::fmt::Debug;
21use serde::{Deserialize, Serialize};
22use sqlparser::ast::BinaryOperator;
23
24/// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
25#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
26#[enum_dispatch::enum_dispatch]
27pub enum DynProofExpr {
28    /// Column
29    Column(ColumnExpr),
30    /// Provable logical AND expression
31    And(AndExpr),
32    /// Provable logical OR expression
33    Or(OrExpr),
34    /// Provable logical NOT expression
35    Not(NotExpr),
36    /// Provable CONST expression
37    Literal(LiteralExpr),
38    /// Provable AST expression for an equals expression
39    Equals(EqualsExpr),
40    /// Provable AST expression for an inequality expression
41    Inequality(InequalityExpr),
42    /// Provable numeric `+` / `-` expression
43    AddSubtract(AddSubtractExpr),
44    /// Provable numeric `*` expression
45    Multiply(MultiplyExpr),
46}
47impl DynProofExpr {
48    /// Create column expression
49    #[must_use]
50    pub fn new_column(column_ref: ColumnRef) -> Self {
51        Self::Column(ColumnExpr::new(column_ref))
52    }
53    /// Create logical AND expression
54    pub fn try_new_and(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
55        lhs.check_data_type(ColumnType::Boolean)?;
56        rhs.check_data_type(ColumnType::Boolean)?;
57        Ok(Self::And(AndExpr::new(Box::new(lhs), Box::new(rhs))))
58    }
59    /// Create logical OR expression
60    pub fn try_new_or(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
61        lhs.check_data_type(ColumnType::Boolean)?;
62        rhs.check_data_type(ColumnType::Boolean)?;
63        Ok(Self::Or(OrExpr::new(Box::new(lhs), Box::new(rhs))))
64    }
65    /// Create logical NOT expression
66    pub fn try_new_not(expr: DynProofExpr) -> AnalyzeResult<Self> {
67        expr.check_data_type(ColumnType::Boolean)?;
68        Ok(Self::Not(NotExpr::new(Box::new(expr))))
69    }
70    /// Create CONST expression
71    #[must_use]
72    pub fn new_literal(value: LiteralValue) -> Self {
73        Self::Literal(LiteralExpr::new(value))
74    }
75    /// Create a new equals expression
76    pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
77        let lhs_datatype = lhs.data_type();
78        let rhs_datatype = rhs.data_type();
79        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Eq) {
80            Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
81        } else {
82            Err(AnalyzeError::DataTypeMismatch {
83                left_type: lhs_datatype.to_string(),
84                right_type: rhs_datatype.to_string(),
85            })
86        }
87    }
88    /// Create a new inequality expression
89    pub fn try_new_inequality(
90        lhs: DynProofExpr,
91        rhs: DynProofExpr,
92        is_lt: bool,
93    ) -> AnalyzeResult<Self> {
94        let lhs_datatype = lhs.data_type();
95        let rhs_datatype = rhs.data_type();
96        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) {
97            Ok(Self::Inequality(InequalityExpr::new(
98                Box::new(lhs),
99                Box::new(rhs),
100                is_lt,
101            )))
102        } else {
103            Err(AnalyzeError::DataTypeMismatch {
104                left_type: lhs_datatype.to_string(),
105                right_type: rhs_datatype.to_string(),
106            })
107        }
108    }
109
110    /// Create a new add expression
111    pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
112        let lhs_datatype = lhs.data_type();
113        let rhs_datatype = rhs.data_type();
114        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Plus) {
115            Ok(Self::AddSubtract(AddSubtractExpr::new(
116                Box::new(lhs),
117                Box::new(rhs),
118                false,
119            )))
120        } else {
121            Err(AnalyzeError::DataTypeMismatch {
122                left_type: lhs_datatype.to_string(),
123                right_type: rhs_datatype.to_string(),
124            })
125        }
126    }
127
128    /// Create a new subtract expression
129    pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
130        let lhs_datatype = lhs.data_type();
131        let rhs_datatype = rhs.data_type();
132        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Minus) {
133            Ok(Self::AddSubtract(AddSubtractExpr::new(
134                Box::new(lhs),
135                Box::new(rhs),
136                true,
137            )))
138        } else {
139            Err(AnalyzeError::DataTypeMismatch {
140                left_type: lhs_datatype.to_string(),
141                right_type: rhs_datatype.to_string(),
142            })
143        }
144    }
145
146    /// Create a new multiply expression
147    pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> AnalyzeResult<Self> {
148        let lhs_datatype = lhs.data_type();
149        let rhs_datatype = rhs.data_type();
150        if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Multiply) {
151            Ok(Self::Multiply(MultiplyExpr::new(
152                Box::new(lhs),
153                Box::new(rhs),
154            )))
155        } else {
156            Err(AnalyzeError::DataTypeMismatch {
157                left_type: lhs_datatype.to_string(),
158                right_type: rhs_datatype.to_string(),
159            })
160        }
161    }
162
163    /// Check that the plan has the correct data type
164    fn check_data_type(&self, data_type: ColumnType) -> AnalyzeResult<()> {
165        if self.data_type() == data_type {
166            Ok(())
167        } else {
168            Err(AnalyzeError::InvalidDataType {
169                actual: self.data_type(),
170                expected: data_type,
171            })
172        }
173    }
174}