proof_of_sql_planner/
conversion.rs

1use crate::{
2    logical_plan_to_proof_plan, logical_plan_to_proof_plan_with_postprocessing, PlannerResult,
3    PoSqlContextProvider, ProofPlanWithPostprocessing,
4};
5use alloc::{sync::Arc, vec::Vec};
6use datafusion::{
7    config::ConfigOptions,
8    logical_expr::LogicalPlan,
9    optimizer::{Analyzer, Optimizer, OptimizerContext, OptimizerRule},
10    sql::planner::{ParserOptions, SqlToRel},
11};
12use indexmap::IndexSet;
13use proof_of_sql::{
14    base::database::{ParseError, SchemaAccessor, TableRef},
15    sql::proof_plans::DynProofPlan,
16};
17use sqlparser::ast::{visit_relations, Statement};
18use std::ops::ControlFlow;
19
20/// Get [`Optimizer`]
21///
22/// In order to support queries such as `select $1::varchar;` we have to temporarily disable
23/// [`CommonSubexprEliminate`] rule in the optimizer in `DataFusion` 38. Once we upgrade to
24/// `DataFusion` 46 we can remove this function and use `Optimizer::new()` directly.
25pub fn optimizer() -> Optimizer {
26    // Step 1: Grab the recommended set
27    let recommended_rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = Optimizer::new().rules;
28
29    // Step 2: Filter out [`CommonSubexprEliminate`]
30    let filtered_rules = recommended_rules
31        .into_iter()
32        .filter(|rule| rule.name() != "common_sub_expression_eliminate")
33        .collect::<Vec<_>>();
34
35    // Step 3: Build an optimizer with the new list
36    Optimizer::with_rules(filtered_rules)
37}
38
39/// Convert a SQL query to a Proof of SQL plan using schema from provided tables
40///
41/// This function does the following
42/// 1. Parse the SQL query into AST using sqlparser
43/// 2. Convert the AST into a `LogicalPlan` using `SqlToRel`
44/// 3. Analyze the `LogicalPlan` using `Analyzer`
45/// 4. Optimize the `LogicalPlan` using `Optimizer`
46/// 5. Convert the optimized `LogicalPlan` into a Proof of SQL plan
47fn sql_to_posql_plans<T, F, A>(
48    statements: &[Statement],
49    schemas: &A,
50    config: &ConfigOptions,
51    planner_converter: F,
52) -> PlannerResult<Vec<T>>
53where
54    F: Fn(&LogicalPlan, &A) -> PlannerResult<T>,
55    A: SchemaAccessor + Clone,
56{
57    let context_provider = PoSqlContextProvider::new(schemas.clone());
58    // 1. Parse the SQL query into AST using sqlparser
59    statements
60        .iter()
61        .map(|ast| -> PlannerResult<T> {
62            // 2. Convert the AST into a `LogicalPlan` using `SqlToRel`
63            let raw_logical_plan = SqlToRel::new_with_options(
64                &context_provider,
65                ParserOptions {
66                    parse_float_as_decimal: config.sql_parser.parse_float_as_decimal,
67                    enable_ident_normalization: config.sql_parser.enable_ident_normalization,
68                },
69            )
70            .sql_statement_to_plan(ast.clone())?;
71            // 3. Analyze the `LogicalPlan` using `Analyzer`
72            let analyzer = Analyzer::new();
73            let analyzed_logical_plan =
74                analyzer.execute_and_check(raw_logical_plan, config, |_, _| {})?;
75            // 4. Optimize the `LogicalPlan` using `Optimizer`
76            let optimizer = optimizer();
77            let optimizer_context = OptimizerContext::default();
78            let optimized_logical_plan =
79                optimizer.optimize(analyzed_logical_plan, &optimizer_context, |_, _| {})?;
80            // 5. Convert the optimized `LogicalPlan` into a Proof of SQL plan
81            planner_converter(&optimized_logical_plan, schemas)
82        })
83        .collect::<PlannerResult<Vec<_>>>()
84}
85
86/// Convert a SQL query to a `DynProofPlan` using schema from provided tables
87///
88/// See `sql_to_posql_plans` for more details
89pub fn sql_to_proof_plans<A: SchemaAccessor + Clone>(
90    statements: &[Statement],
91    schemas: &A,
92    config: &ConfigOptions,
93) -> PlannerResult<Vec<DynProofPlan>> {
94    sql_to_posql_plans(statements, schemas, config, logical_plan_to_proof_plan)
95}
96
97/// Convert a SQL query to a `ProofPlanWithPostprocessing` using schema from provided tables
98///
99/// See `sql_to_posql_plans` for more details
100pub fn sql_to_proof_plans_with_postprocessing<A: SchemaAccessor + Clone>(
101    statements: &[Statement],
102    schemas: &A,
103    config: &ConfigOptions,
104) -> PlannerResult<Vec<ProofPlanWithPostprocessing>> {
105    sql_to_posql_plans(
106        statements,
107        schemas,
108        config,
109        logical_plan_to_proof_plan_with_postprocessing,
110    )
111}
112
113/// Given a `Statement` retrieves all unique tables in the query
114pub fn get_table_refs_from_statement(
115    statement: &Statement,
116) -> Result<IndexSet<TableRef>, ParseError> {
117    let mut table_refs: IndexSet<TableRef> = IndexSet::<TableRef>::new();
118    visit_relations(statement, |object_name| {
119        match object_name.to_string().as_str().try_into() {
120            Ok(table_ref) => {
121                table_refs.insert(table_ref);
122                ControlFlow::Continue(())
123            }
124            e => ControlFlow::Break(e),
125        }
126    })
127    .break_value()
128    .transpose()?;
129    Ok(table_refs)
130}
131
#[cfg(test)]
mod tests {
    use super::get_table_refs_from_statement;
    use indexmap::IndexSet;
    use proof_of_sql::base::database::TableRef;
    use sqlparser::{dialect::GenericDialect, parser::Parser};

    /// Tables referenced in joins and in subqueries are all reported, and a table referenced
    /// several times (`departments`, `internal.salaries`) appears once in the result.
    #[test]
    fn we_can_get_table_references() {
        let sql = "SELECT e.employee_id, e.employee_name, d.department_name, p.project_name, s.salary
FROM employees e
JOIN departments d ON e.department_id = d.department_id
JOIN management.projects p ON e.employee_id = p.employee_id
JOIN internal.salaries s ON e.employee_id = s.employee_id
WHERE e.department_id IN (
    SELECT department_id
    FROM departments
    WHERE department_name = 'Sales'
)
AND p.project_id IN (
    SELECT project_id
    FROM project_assignments
    WHERE employee_id = e.employee_id
)
AND s.salary > (
    SELECT AVG(salary)
    FROM internal.salaries
    WHERE department_id = e.department_id
);
";
        let statement = Parser::parse_sql(&GenericDialect {}, sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();

        let actual = get_table_refs_from_statement(&statement).unwrap();

        // An empty schema string stands for an unqualified table name.
        let expected = [
            ("", "departments"),
            ("", "employees"),
            ("management", "projects"),
            ("", "project_assignments"),
            ("internal", "salaries"),
        ]
        .into_iter()
        .map(|(schema, table)| TableRef::new(schema, table))
        .collect::<IndexSet<TableRef>>();

        assert_eq!(actual, expected);
    }
}