proof_of_sql_planner/
conversion.rs1use crate::{
2 logical_plan_to_proof_plan, logical_plan_to_proof_plan_with_postprocessing, PlannerResult,
3 PoSqlContextProvider, ProofPlanWithPostprocessing,
4};
5use alloc::{sync::Arc, vec::Vec};
6use datafusion::{
7 config::ConfigOptions,
8 logical_expr::LogicalPlan,
9 optimizer::{Analyzer, Optimizer, OptimizerContext, OptimizerRule},
10 sql::planner::{ParserOptions, SqlToRel},
11};
12use indexmap::IndexSet;
13use proof_of_sql::{
14 base::database::{ParseError, SchemaAccessor, TableRef},
15 sql::proof_plans::DynProofPlan,
16};
17use sqlparser::ast::{visit_relations, Statement};
18use std::ops::ControlFlow;
19
20pub fn optimizer() -> Optimizer {
26 let recommended_rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = Optimizer::new().rules;
28
29 let filtered_rules = recommended_rules
31 .into_iter()
32 .filter(|rule| rule.name() != "common_sub_expression_eliminate")
33 .collect::<Vec<_>>();
34
35 Optimizer::with_rules(filtered_rules)
37}
38
39fn sql_to_posql_plans<T, F, A>(
48 statements: &[Statement],
49 schemas: &A,
50 config: &ConfigOptions,
51 planner_converter: F,
52) -> PlannerResult<Vec<T>>
53where
54 F: Fn(&LogicalPlan, &A) -> PlannerResult<T>,
55 A: SchemaAccessor + Clone,
56{
57 let context_provider = PoSqlContextProvider::new(schemas.clone());
58 statements
60 .iter()
61 .map(|ast| -> PlannerResult<T> {
62 let raw_logical_plan = SqlToRel::new_with_options(
64 &context_provider,
65 ParserOptions {
66 parse_float_as_decimal: config.sql_parser.parse_float_as_decimal,
67 enable_ident_normalization: config.sql_parser.enable_ident_normalization,
68 },
69 )
70 .sql_statement_to_plan(ast.clone())?;
71 let analyzer = Analyzer::new();
73 let analyzed_logical_plan =
74 analyzer.execute_and_check(raw_logical_plan, config, |_, _| {})?;
75 let optimizer = optimizer();
77 let optimizer_context = OptimizerContext::default();
78 let optimized_logical_plan =
79 optimizer.optimize(analyzed_logical_plan, &optimizer_context, |_, _| {})?;
80 planner_converter(&optimized_logical_plan, schemas)
82 })
83 .collect::<PlannerResult<Vec<_>>>()
84}
85
86pub fn sql_to_proof_plans<A: SchemaAccessor + Clone>(
90 statements: &[Statement],
91 schemas: &A,
92 config: &ConfigOptions,
93) -> PlannerResult<Vec<DynProofPlan>> {
94 sql_to_posql_plans(statements, schemas, config, logical_plan_to_proof_plan)
95}
96
97pub fn sql_to_proof_plans_with_postprocessing<A: SchemaAccessor + Clone>(
101 statements: &[Statement],
102 schemas: &A,
103 config: &ConfigOptions,
104) -> PlannerResult<Vec<ProofPlanWithPostprocessing>> {
105 sql_to_posql_plans(
106 statements,
107 schemas,
108 config,
109 logical_plan_to_proof_plan_with_postprocessing,
110 )
111}
112
113pub fn get_table_refs_from_statement(
115 statement: &Statement,
116) -> Result<IndexSet<TableRef>, ParseError> {
117 let mut table_refs: IndexSet<TableRef> = IndexSet::<TableRef>::new();
118 visit_relations(statement, |object_name| {
119 match object_name.to_string().as_str().try_into() {
120 Ok(table_ref) => {
121 table_refs.insert(table_ref);
122 ControlFlow::Continue(())
123 }
124 e => ControlFlow::Break(e),
125 }
126 })
127 .break_value()
128 .transpose()?;
129 Ok(table_refs)
130}
131
#[cfg(test)]
mod tests {
    use super::get_table_refs_from_statement;
    use indexmap::IndexSet;
    use proof_of_sql::base::database::TableRef;
    use sqlparser::{dialect::GenericDialect, parser::Parser};

    /// Query that mentions unqualified and schema-qualified tables, in joins as
    /// well as in subqueries, with some tables referenced more than once.
    const QUERY_WITH_MANY_TABLES: &str =
        "SELECT e.employee_id, e.employee_name, d.department_name, p.project_name, s.salary
FROM employees e
JOIN departments d ON e.department_id = d.department_id
JOIN management.projects p ON e.employee_id = p.employee_id
JOIN internal.salaries s ON e.employee_id = s.employee_id
WHERE e.department_id IN (
    SELECT department_id
    FROM departments
    WHERE department_name = 'Sales'
)
AND p.project_id IN (
    SELECT project_id
    FROM project_assignments
    WHERE employee_id = e.employee_id
)
AND s.salary > (
    SELECT AVG(salary)
    FROM internal.salaries
    WHERE department_id = e.department_id
);
";

    /// Every distinct table in the query is reported exactly once.
    #[test]
    fn we_can_get_table_references() {
        let statements = Parser::parse_sql(&GenericDialect {}, QUERY_WITH_MANY_TABLES).unwrap();
        let statement = &statements[0];

        let table_refs = get_table_refs_from_statement(statement).unwrap();

        // NOTE(review): the expected entries are not listed in the order the
        // tables appear in the query; this relies on `IndexSet` equality being
        // order-insensitive — confirm against the indexmap documentation.
        let expected_table_refs: IndexSet<TableRef> = [
            ("", "departments"),
            ("", "employees"),
            ("management", "projects"),
            ("", "project_assignments"),
            ("internal", "salaries"),
        ]
        .into_iter()
        .map(|(schema, table)| TableRef::new(schema, table))
        .collect();
        assert_eq!(table_refs, expected_table_refs);
    }
}