use super::DynProofPlan;
use crate::{
base::{
database::{
filter_util::filter_columns, ColumnField, ColumnRef, LiteralValue, Table,
TableEvaluation, TableOptions, TableRef,
},
map::{IndexMap, IndexSet},
proof::{PlaceholderResult, ProofError},
scalar::Scalar,
},
sql::{
proof::{
FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, VerificationBuilder,
},
proof_gadgets::{final_round_evaluate_filter, verify_evaluate_filter},
proof_plans::fold_vals,
},
utils::log,
};
use alloc::{boxed::Box, vec::Vec};
use bumpalo::Bump;
use core::iter::repeat;
use itertools::repeat_n;
use serde::{Deserialize, Serialize};
use sqlparser::ast::Ident;
/// Proof plan that keeps a contiguous window of its input's rows.
///
/// A row with index `i` is kept when `skip <= i < skip + fetch`. There is no upper bound
/// when `fetch` is `None`, and the window is clamped to the input's row count. This is
/// analogous to SQL `OFFSET skip LIMIT fetch`.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub struct SliceExec {
    /// Plan whose output rows are sliced.
    pub(super) input: Box<DynProofPlan>,
    /// Number of leading input rows to drop.
    pub(super) skip: usize,
    /// Maximum number of rows kept after skipping; `None` keeps every remaining row.
    pub(super) fetch: Option<usize>,
}
/// Builds the row-selection mask for slicing `num_rows` rows.
///
/// Entry `i` is `true` exactly when `skip <= i < skip + fetch`. The upper bound is dropped
/// when `fetch` is `None`. The returned vector always has length `num_rows`.
fn get_slice_select(num_rows: usize, skip: usize, fetch: Option<usize>) -> Vec<bool> {
    // Exclusive end of the kept window. Without a `fetch`, everything from `skip` onwards
    // is kept. Saturating keeps an enormous `skip + fetch` from overflowing; the result
    // still behaves as "no upper bound".
    let end = match fetch {
        Some(limit) => skip.saturating_add(limit),
        None => num_rows,
    };
    let kept = skip..end;
    (0..num_rows).map(|row| kept.contains(&row)).collect()
}
impl SliceExec {
    /// Creates a slice plan over `input`.
    ///
    /// The plan drops the first `skip` rows and then keeps at most `fetch` rows. When
    /// `fetch` is `None`, all remaining rows are kept.
    pub fn new(input: Box<DynProofPlan>, skip: usize, fetch: Option<usize>) -> Self {
        Self { skip, fetch, input }
    }

    /// Returns the plan whose output is sliced.
    pub fn input(&self) -> &DynProofPlan {
        self.input.as_ref()
    }

    /// Returns the number of leading input rows that are dropped.
    pub fn skip(&self) -> usize {
        self.skip
    }

    /// Returns the maximum number of rows kept after skipping, or `None` if unbounded.
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }
}
impl ProofPlan for SliceExec
where
    SliceExec: ProverEvaluate,
{
    /// Verifies the slice applied to the input plan's output.
    ///
    /// The builder is consumed in this order:
    /// 1. the input plan's verification data;
    /// 2. three chi evaluations: output, offset and max;
    /// 3. one first-round MLE evaluation per input column (the sliced columns);
    /// 4. two post-result challenges.
    ///
    /// The filter gadget is then checked with the selection evaluation
    /// `chi_max - chi_offset`.
    ///
    /// # Errors
    /// Returns [`ProofError::VerificationError`] when the lengths carried by the proof are
    /// inconsistent with each other or with the plan's `skip`/`fetch`. Any error from the
    /// input plan, the builder, or [`verify_evaluate_filter`] is propagated.
    fn verifier_evaluate<S: Scalar>(
        &self,
        builder: &mut impl VerificationBuilder<S>,
        accessor: &IndexMap<TableRef, IndexMap<Ident, S>>,
        chi_eval_map: &IndexMap<TableRef, (S, usize)>,
        params: &[LiteralValue],
    ) -> Result<TableEvaluation<S>, ProofError> {
        let input_table_eval =
            self.input
                .verifier_evaluate(builder, accessor, chi_eval_map, params)?;
        let (input_eval, input_length) = input_table_eval.chi();
        let (output_chi_eval, output_length) = builder.try_consume_chi_evaluation()?;
        let columns_evals = input_table_eval.column_evals();
        let (offset_chi_eval, offset) = builder.try_consume_chi_evaluation()?;
        let (max_chi_eval, max) = builder.try_consume_chi_evaluation()?;
        // `offset` and `max` are read from the proof, so a dishonest prover may send
        // `max < offset`. A checked subtraction rejects that case instead of panicking
        // (debug) or wrapping around (release).
        if max.checked_sub(offset) != Some(output_length) {
            return Err(ProofError::VerificationError {
                error: "output length does not match selection length",
            });
        }
        if self.skip.min(input_length) != offset {
            return Err(ProofError::VerificationError {
                error: "offset length does not match plan value",
            });
        }
        // Huge LIMIT/OFFSET values can make `skip + fetch` exceed `usize::MAX`. Saturating
        // makes the expected bound clamp to the input length instead of overflowing.
        let expected_max = self.fetch.map_or(input_length, |fetch| {
            fetch.saturating_add(self.skip).min(input_length)
        });
        if max != expected_max {
            return Err(ProofError::VerificationError {
                error: "max length does not match expected value",
            });
        }
        // Assuming a chi evaluation of length `n` is the indicator of the first `n` rows,
        // this difference is the indicator of rows `[offset, max)`.
        let selection_eval = max_chi_eval - offset_chi_eval;
        let filtered_columns_evals =
            builder.try_consume_first_round_mle_evaluations(columns_evals.len())?;
        let alpha = builder.try_consume_post_result_challenge()?;
        let beta = builder.try_consume_post_result_challenge()?;
        let c_fold_eval = alpha * fold_vals(beta, columns_evals);
        let d_fold_eval = alpha * fold_vals(beta, &filtered_columns_evals);
        verify_evaluate_filter(
            builder,
            c_fold_eval,
            d_fold_eval,
            input_eval,
            output_chi_eval,
            selection_eval,
        )?;
        Ok(TableEvaluation::new(
            filtered_columns_evals,
            (output_chi_eval, output_length),
        ))
    }

    /// Slicing keeps the input's schema unchanged.
    fn get_column_result_fields(&self) -> Vec<ColumnField> {
        self.input.get_column_result_fields()
    }

    /// Delegates to the input plan; slicing references no extra columns.
    fn get_column_references(&self) -> IndexSet<ColumnRef> {
        self.input.get_column_references()
    }

    /// Delegates to the input plan; slicing references no extra tables.
    fn get_table_references(&self) -> IndexSet<TableRef> {
        self.input.get_table_references()
    }
}
impl ProverEvaluate for SliceExec {
#[tracing::instrument(name = "SliceExec::first_round_evaluate", level = "debug", skip_all)]
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder<'a, S>,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
params: &[LiteralValue],
) -> PlaceholderResult<Table<'a, S>> {
log::log_memory_usage("Start");
let input = self
.input
.first_round_evaluate(builder, alloc, table_map, params)?;
let input_length = input.num_rows();
let columns = input.columns().copied().collect::<Vec<_>>();
let select = get_slice_select(input_length, self.skip, self.fetch);
let offset_index = self.skip.min(input_length);
let max_index = if let Some(fetch) = self.fetch {
(self.skip + fetch).min(input_length)
} else {
input_length
};
let output_length = max_index - offset_index;
let (filtered_columns, _) = filter_columns(alloc, &columns, &select);
filtered_columns.iter().copied().for_each(|column| {
builder.produce_intermediate_mle(column);
});
let res = Table::<'a, S>::try_from_iter_with_options(
self.get_column_result_fields()
.into_iter()
.map(|expr| expr.name())
.zip(filtered_columns),
TableOptions::new(Some(output_length)),
)
.expect("Failed to create table from iterator");
builder.request_post_result_challenges(2);
builder.produce_chi_evaluation_length(output_length);
builder.produce_chi_evaluation_length(offset_index);
builder.produce_chi_evaluation_length(max_index);
log::log_memory_usage("End");
Ok(res)
}
#[tracing::instrument(name = "SliceExec::final_round_evaluate", level = "debug", skip_all)]
fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
params: &[LiteralValue],
) -> PlaceholderResult<Table<'a, S>> {
log::log_memory_usage("Start");
let input = self
.input
.final_round_evaluate(builder, alloc, table_map, params)?;
let columns = input.columns().copied().collect::<Vec<_>>();
let select = get_slice_select(input.num_rows(), self.skip, self.fetch);
let select_ref: &'a [_] = alloc.alloc_slice_copy(&select);
let output_length = select.iter().filter(|b| **b).count();
let (filtered_columns, result_len) = filter_columns(alloc, &columns, &select);
let alpha = builder.consume_post_result_challenge();
let beta = builder.consume_post_result_challenge();
final_round_evaluate_filter::<S>(
builder,
alloc,
alpha,
beta,
&columns,
select_ref,
&filtered_columns,
input.num_rows(),
result_len,
);
let res = Table::<'a, S>::try_from_iter_with_options(
self.get_column_result_fields()
.into_iter()
.map(|expr| expr.name())
.zip(filtered_columns),
TableOptions::new(Some(output_length)),
)
.expect("Failed to create table from iterator");
log::log_memory_usage("End");
Ok(res)
}
}