use super::ProofExpr;
use crate::{
base::{
database::{Column, ColumnField, ColumnRef, ColumnType, LiteralValue, Table},
map::{IndexMap, IndexSet},
proof::{PlaceholderResult, ProofError},
scalar::Scalar,
},
sql::proof::{FinalRoundBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use serde::{Deserialize, Serialize};
use sqlparser::ast::Ident;
/// A proof expression that evaluates to one column of a table, identified by a [`ColumnRef`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct ColumnExpr {
    /// The column this expression reads.
    column_ref: ColumnRef,
}
impl ColumnExpr {
#[must_use]
pub fn new(column_ref: ColumnRef) -> Self {
Self { column_ref }
}
#[must_use]
pub fn get_column_reference(&self) -> ColumnRef {
self.column_ref.clone()
}
#[must_use]
pub fn column_ref(&self) -> &ColumnRef {
&self.column_ref
}
#[must_use]
pub fn get_column_field(&self) -> ColumnField {
ColumnField::new(self.column_ref.column_id(), *self.column_ref.column_type())
}
#[must_use]
pub fn column_id(&self) -> Ident {
self.column_ref.column_id()
}
#[must_use]
pub fn fetch_column<'a, S: Scalar>(&self, table: &Table<'a, S>) -> Column<'a, S> {
*table
.inner_table()
.get(&self.column_ref.column_id())
.expect("Column not found")
}
}
impl ProofExpr for ColumnExpr {
    /// Returns the type of the referenced column.
    fn data_type(&self) -> ColumnType {
        // Read the `Copy` type straight off the borrowed `ColumnRef` instead of going
        // through `get_column_reference`, which clones the whole reference first.
        *self.column_ref.column_type()
    }

    /// Returns the referenced column from `table` unchanged; no proof work is performed.
    ///
    /// # Panics
    ///
    /// Panics if `table` does not contain the referenced column (see `fetch_column`).
    fn first_round_evaluate<'a, S: Scalar>(
        &self,
        _alloc: &'a Bump,
        table: &Table<'a, S>,
        _params: &[LiteralValue],
    ) -> PlaceholderResult<Column<'a, S>> {
        Ok(self.fetch_column(table))
    }

    /// Returns the referenced column from `table` unchanged; nothing is added to `_builder`.
    ///
    /// # Panics
    ///
    /// Panics if `table` does not contain the referenced column (see `fetch_column`).
    fn final_round_evaluate<'a, S: Scalar>(
        &self,
        _builder: &mut FinalRoundBuilder<'a, S>,
        _alloc: &'a Bump,
        table: &Table<'a, S>,
        _params: &[LiteralValue],
    ) -> PlaceholderResult<Column<'a, S>> {
        Ok(self.fetch_column(table))
    }

    /// Returns the value stored in `accessor` under the referenced column's identifier.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::VerificationError`] with `"Column Not Found"` when `accessor`
    /// has no entry for the referenced column.
    fn verifier_evaluate<S: Scalar>(
        &self,
        _builder: &mut impl VerificationBuilder<S>,
        accessor: &IndexMap<Ident, S>,
        _chi_eval: S,
        _params: &[LiteralValue],
    ) -> Result<S, ProofError> {
        accessor
            .get(&self.column_ref.column_id())
            .copied()
            .ok_or(ProofError::VerificationError {
                error: "Column Not Found",
            })
    }

    /// Records the referenced column in `columns`.
    fn get_column_references(&self, columns: &mut IndexSet<ColumnRef>) {
        columns.insert(self.column_ref.clone());
    }
}