proof_of_sql_planner/
conversion.rs1use crate::{logical_plan_to_proof_plan, PlannerResult, PoSqlContextProvider};
2use alloc::{sync::Arc, vec::Vec};
3use datafusion::{
4 config::ConfigOptions,
5 logical_expr::LogicalPlan,
6 optimizer::{Analyzer, Optimizer, OptimizerContext, OptimizerRule},
7 sql::planner::{ParserOptions, SqlToRel},
8};
9use indexmap::IndexSet;
10use proof_of_sql::{
11 base::database::{ParseError, SchemaAccessor, TableRef},
12 sql::proof_plans::DynProofPlan,
13};
14use sqlparser::ast::{visit_relations, Statement};
15use std::ops::ControlFlow;
16
17pub fn optimizer() -> Optimizer {
23 let recommended_rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = Optimizer::new().rules;
25
26 let filtered_rules = recommended_rules
28 .into_iter()
29 .filter(|rule| rule.name() != "common_sub_expression_eliminate")
30 .collect::<Vec<_>>();
31
32 Optimizer::with_rules(filtered_rules)
34}
35
36fn sql_to_posql_plans<T, F, A>(
45 statements: &[Statement],
46 schemas: &A,
47 config: &ConfigOptions,
48 planner_converter: F,
49) -> PlannerResult<Vec<T>>
50where
51 F: Fn(&LogicalPlan, &A) -> PlannerResult<T>,
52 A: SchemaAccessor + Clone,
53{
54 let context_provider = PoSqlContextProvider::new(schemas.clone());
55 statements
57 .iter()
58 .map(|ast| -> PlannerResult<T> {
59 let raw_logical_plan = SqlToRel::new_with_options(
61 &context_provider,
62 ParserOptions {
63 parse_float_as_decimal: config.sql_parser.parse_float_as_decimal,
64 enable_ident_normalization: config.sql_parser.enable_ident_normalization,
65 },
66 )
67 .sql_statement_to_plan(ast.clone())?;
68 let analyzer = Analyzer::new();
70 let analyzed_logical_plan =
71 analyzer.execute_and_check(raw_logical_plan, config, |_, _| {})?;
72 let optimizer = optimizer();
74 let optimizer_context = OptimizerContext::default();
75 let optimized_logical_plan =
76 optimizer.optimize(analyzed_logical_plan, &optimizer_context, |_, _| {})?;
77 planner_converter(&optimized_logical_plan, schemas)
79 })
80 .collect::<PlannerResult<Vec<_>>>()
81}
82
83pub fn sql_to_proof_plans<A: SchemaAccessor + Clone>(
87 statements: &[Statement],
88 schemas: &A,
89 config: &ConfigOptions,
90) -> PlannerResult<Vec<DynProofPlan>> {
91 sql_to_posql_plans(statements, schemas, config, logical_plan_to_proof_plan)
92}
93
94pub fn get_table_refs_from_statement(
96 statement: &Statement,
97) -> Result<IndexSet<TableRef>, ParseError> {
98 let mut table_refs: IndexSet<TableRef> = IndexSet::<TableRef>::new();
99 visit_relations(statement, |object_name| {
100 match object_name.to_string().as_str().try_into() {
101 Ok(table_ref) => {
102 table_refs.insert(table_ref);
103 ControlFlow::Continue(())
104 }
105 e => ControlFlow::Break(e),
106 }
107 })
108 .break_value()
109 .transpose()?;
110 Ok(table_refs)
111}
112
#[cfg(test)]
mod tests {
    use super::get_table_refs_from_statement;
    use indexmap::IndexSet;
    use proof_of_sql::base::database::TableRef;
    use sqlparser::{dialect::GenericDialect, parser::Parser};

    /// Table references are collected from joins and nested subqueries, and a
    /// table mentioned more than once shows up a single time in the result.
    #[test]
    fn we_can_get_table_references() {
        let sql = "SELECT e.employee_id, e.employee_name, d.department_name, p.project_name, s.salary
FROM employees e
JOIN departments d ON e.department_id = d.department_id
JOIN management.projects p ON e.employee_id = p.employee_id
JOIN internal.salaries s ON e.employee_id = s.employee_id
WHERE e.department_id IN (
    SELECT department_id
    FROM departments
    WHERE department_name = 'Sales'
)
AND p.project_id IN (
    SELECT project_id
    FROM project_assignments
    WHERE employee_id = e.employee_id
)
AND s.salary > (
    SELECT AVG(salary)
    FROM internal.salaries
    WHERE department_id = e.department_id
);
";
        let parsed_statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        let first_statement = &parsed_statements[0];

        let actual = get_table_refs_from_statement(first_statement).unwrap();

        let expected: IndexSet<TableRef> = [
            ("", "departments"),
            ("", "employees"),
            ("management", "projects"),
            ("", "project_assignments"),
            ("internal", "salaries"),
        ]
        .into_iter()
        .map(|(schema, table)| TableRef::new(schema, table))
        .collect();

        assert_eq!(actual, expected);
    }
}