proof_of_sql_planner/
conversion.rs

1use crate::{logical_plan_to_proof_plan, PlannerResult, PoSqlContextProvider};
2use alloc::{sync::Arc, vec::Vec};
3use datafusion::{
4    config::ConfigOptions,
5    logical_expr::LogicalPlan,
6    optimizer::{Analyzer, Optimizer, OptimizerContext, OptimizerRule},
7    sql::planner::{ParserOptions, SqlToRel},
8};
9use indexmap::IndexSet;
10use proof_of_sql::{
11    base::database::{ParseError, SchemaAccessor, TableRef},
12    sql::proof_plans::DynProofPlan,
13};
14use sqlparser::ast::{visit_relations, Statement};
15use std::ops::ControlFlow;
16
17/// Get [`Optimizer`]
18///
19/// In order to support queries such as `select $1::varchar;` we have to temporarily disable
20/// [`CommonSubexprEliminate`] rule in the optimizer in `DataFusion` 38. Once we upgrade to
21/// `DataFusion` 46 we can remove this function and use `Optimizer::new()` directly.
22pub fn optimizer() -> Optimizer {
23    // Step 1: Grab the recommended set
24    let recommended_rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = Optimizer::new().rules;
25
26    // Step 2: Filter out [`CommonSubexprEliminate`]
27    let filtered_rules = recommended_rules
28        .into_iter()
29        .filter(|rule| rule.name() != "common_sub_expression_eliminate")
30        .collect::<Vec<_>>();
31
32    // Step 3: Build an optimizer with the new list
33    Optimizer::with_rules(filtered_rules)
34}
35
36/// Convert a SQL query to a Proof of SQL plan using schema from provided tables
37///
38/// This function does the following
39/// 1. Parse the SQL query into AST using sqlparser
40/// 2. Convert the AST into a `LogicalPlan` using `SqlToRel`
41/// 3. Analyze the `LogicalPlan` using `Analyzer`
42/// 4. Optimize the `LogicalPlan` using `Optimizer`
43/// 5. Convert the optimized `LogicalPlan` into a Proof of SQL plan
44fn sql_to_posql_plans<T, F, A>(
45    statements: &[Statement],
46    schemas: &A,
47    config: &ConfigOptions,
48    planner_converter: F,
49) -> PlannerResult<Vec<T>>
50where
51    F: Fn(&LogicalPlan, &A) -> PlannerResult<T>,
52    A: SchemaAccessor + Clone,
53{
54    let context_provider = PoSqlContextProvider::new(schemas.clone());
55    // 1. Parse the SQL query into AST using sqlparser
56    statements
57        .iter()
58        .map(|ast| -> PlannerResult<T> {
59            // 2. Convert the AST into a `LogicalPlan` using `SqlToRel`
60            let raw_logical_plan = SqlToRel::new_with_options(
61                &context_provider,
62                ParserOptions {
63                    parse_float_as_decimal: config.sql_parser.parse_float_as_decimal,
64                    enable_ident_normalization: config.sql_parser.enable_ident_normalization,
65                },
66            )
67            .sql_statement_to_plan(ast.clone())?;
68            // 3. Analyze the `LogicalPlan` using `Analyzer`
69            let analyzer = Analyzer::new();
70            let analyzed_logical_plan =
71                analyzer.execute_and_check(raw_logical_plan, config, |_, _| {})?;
72            // 4. Optimize the `LogicalPlan` using `Optimizer`
73            let optimizer = optimizer();
74            let optimizer_context = OptimizerContext::default();
75            let optimized_logical_plan =
76                optimizer.optimize(analyzed_logical_plan, &optimizer_context, |_, _| {})?;
77            // 5. Convert the optimized `LogicalPlan` into a Proof of SQL plan
78            planner_converter(&optimized_logical_plan, schemas)
79        })
80        .collect::<PlannerResult<Vec<_>>>()
81}
82
83/// Convert a SQL query to a `DynProofPlan` using schema from provided tables
84///
85/// See `sql_to_posql_plans` for more details
86pub fn sql_to_proof_plans<A: SchemaAccessor + Clone>(
87    statements: &[Statement],
88    schemas: &A,
89    config: &ConfigOptions,
90) -> PlannerResult<Vec<DynProofPlan>> {
91    sql_to_posql_plans(statements, schemas, config, logical_plan_to_proof_plan)
92}
93
94/// Given a `Statement` retrieves all unique tables in the query
95pub fn get_table_refs_from_statement(
96    statement: &Statement,
97) -> Result<IndexSet<TableRef>, ParseError> {
98    let mut table_refs: IndexSet<TableRef> = IndexSet::<TableRef>::new();
99    visit_relations(statement, |object_name| {
100        match object_name.to_string().as_str().try_into() {
101            Ok(table_ref) => {
102                table_refs.insert(table_ref);
103                ControlFlow::Continue(())
104            }
105            e => ControlFlow::Break(e),
106        }
107    })
108    .break_value()
109    .transpose()?;
110    Ok(table_refs)
111}
112
#[cfg(test)]
mod tests {
    use super::get_table_refs_from_statement;
    use indexmap::IndexSet;
    use proof_of_sql::base::database::TableRef;
    use sqlparser::{dialect::GenericDialect, parser::Parser};

    /// Every relation in the query, including ones reached only through subqueries, is
    /// reported exactly once, and schema-qualified names keep their schema.
    #[test]
    fn we_can_get_table_references() {
        let sql = "SELECT e.employee_id, e.employee_name, d.department_name, p.project_name, s.salary
FROM employees e
JOIN departments d ON e.department_id = d.department_id
JOIN management.projects p ON e.employee_id = p.employee_id
JOIN internal.salaries s ON e.employee_id = s.employee_id
WHERE e.department_id IN (
    SELECT department_id
    FROM departments
    WHERE department_name = 'Sales'
)
AND p.project_id IN (
    SELECT project_id
    FROM project_assignments
    WHERE employee_id = e.employee_id
)
AND s.salary > (
    SELECT AVG(salary)
    FROM internal.salaries
    WHERE department_id = e.department_id
);
";
        let mut parsed = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        let statement = parsed.remove(0);

        let actual = get_table_refs_from_statement(&statement).unwrap();

        // `IndexSet` equality ignores ordering, so the listing order here is irrelevant.
        let expected: IndexSet<TableRef> = [
            ("", "departments"),
            ("", "employees"),
            ("management", "projects"),
            ("", "project_assignments"),
            ("internal", "salaries"),
        ]
        .into_iter()
        .map(|(schema, table)| TableRef::new(schema, table))
        .collect();
        assert_eq!(actual, expected);
    }
}