optd_sqlplannertest/
lib.rs

1// Copyright (c) 2023-2024 CMU Database Group
2//
3// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
4// https://opensource.org/licenses/MIT.
5
6use std::collections::HashSet;
7use std::sync::Arc;
8
9use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions};
10use datafusion::catalog::CatalogList;
11use datafusion::execution::context::{SessionConfig, SessionState};
12use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
13use datafusion::prelude::SessionContext;
14use datafusion::sql::parser::DFParser;
15use datafusion::sql::sqlparser::dialect::GenericDialect;
16use datafusion_optd_cli::helper::unescape_input;
17use itertools::Itertools;
18use lazy_static::lazy_static;
19use mimalloc::MiMalloc;
20use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
21use optd_datafusion_repr::DatafusionOptimizer;
22use optd_datafusion_repr_adv_cost::adv_stats::stats::DataFusionBaseTableStats;
23use optd_datafusion_repr_adv_cost::new_physical_adv_cost;
24use regex::Regex;
25
// Registers mimalloc as the global allocator for any binary or test that links this crate.
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
28
29use anyhow::{bail, Result};
30use async_trait::async_trait;
31
/// Planner-test runner state: two datafusion sessions plus a handle to an optd planner.
///
/// The constructors (`new`, `new_advanced_cost`) build both sessions through
/// `new_session_ctx` and hand the first session's catalog list to the second one.
///
/// NOTE(review): `Default` leaves `optd_optimizer` as `None`, while `execute` unwraps it, so a
/// default-constructed value cannot run queries — confirm whether `Default` is still needed.
#[derive(Default)]
pub struct DatafusionDBMS {
    /// Context created with `use_df_logical = false`, i.e. with datafusion's logical optimizer
    /// passes and rules turned off.
    ctx: SessionContext,
    /// Context enabling datafusion's logical optimizer.
    use_df_logical_ctx: SessionContext,
    /// Shared optd optimizer (for tweaking config). The constructors store the planner that was
    /// installed in `ctx`; the planner created for `use_df_logical_ctx` is not retained.
    optd_optimizer: Option<Arc<OptdQueryPlanner>>,
}
40
41impl DatafusionDBMS {
42    pub async fn new() -> Result<Self> {
43        let (ctx, optd_optimizer) = DatafusionDBMS::new_session_ctx(false, None, false).await?;
44        let (use_df_logical_ctx, _) =
45            Self::new_session_ctx(true, Some(ctx.state().catalog_list().clone()), false).await?;
46        Ok(Self {
47            ctx,
48            use_df_logical_ctx,
49            optd_optimizer: Some(optd_optimizer),
50        })
51    }
52
53    pub async fn new_advanced_cost() -> Result<Self> {
54        let (ctx, optd_optimizer) = DatafusionDBMS::new_session_ctx(false, None, true).await?;
55        let (use_df_logical_ctx, _) =
56            Self::new_session_ctx(true, Some(ctx.state().catalog_list().clone()), true).await?;
57        Ok(Self {
58            ctx,
59            use_df_logical_ctx,
60            optd_optimizer: Some(optd_optimizer),
61        })
62    }
63
64    /// Creates a new session context. If the `use_df_logical` flag is set, datafusion's logical
65    /// optimizer will be used.
66    async fn new_session_ctx(
67        use_df_logical: bool,
68        catalog: Option<Arc<dyn CatalogList>>,
69        with_advanced_cost: bool,
70    ) -> Result<(SessionContext, Arc<OptdQueryPlanner>)> {
71        let mut session_config = SessionConfig::from_env()?.with_information_schema(true);
72        if !use_df_logical {
73            session_config.options_mut().optimizer.max_passes = 0;
74        }
75
76        let rn_config = RuntimeConfig::new();
77        let runtime_env = RuntimeEnv::new(rn_config.clone())?;
78        let optd_optimizer;
79
80        let ctx = {
81            let mut state = if let Some(catalog) = catalog {
82                SessionState::new_with_config_rt_and_catalog_list(
83                    session_config.clone(),
84                    Arc::new(runtime_env),
85                    catalog,
86                )
87            } else {
88                SessionState::new_with_config_rt(session_config.clone(), Arc::new(runtime_env))
89            };
90            let optimizer = if with_advanced_cost {
91                new_physical_adv_cost(
92                    Arc::new(DatafusionCatalog::new(state.catalog_list())),
93                    DataFusionBaseTableStats::default(),
94                    false,
95                )
96            } else {
97                DatafusionOptimizer::new_physical(
98                    Arc::new(DatafusionCatalog::new(state.catalog_list())),
99                    false,
100                )
101            };
102            if !use_df_logical {
103                // clean up optimizer rules so that we can plug in our own optimizer
104                state = state.with_optimizer_rules(vec![]);
105            }
106            state = state.with_physical_optimizer_rules(vec![]);
107            // use optd-bridge query planner
108            optd_optimizer = Arc::new(OptdQueryPlanner::new(optimizer));
109            state = state.with_query_planner(optd_optimizer.clone());
110            SessionContext::new_with_state(state)
111        };
112        ctx.refresh_catalogs().await?;
113        Ok((ctx, optd_optimizer))
114    }
115
116    pub(crate) async fn execute(&self, sql: &str, flags: &TestFlags) -> Result<Vec<Vec<String>>> {
117        {
118            let mut guard = self
119                .optd_optimizer
120                .as_ref()
121                .unwrap()
122                .optimizer
123                .lock()
124                .unwrap();
125            let optimizer = guard.as_mut().unwrap().optd_optimizer_mut();
126            if flags.panic_on_budget {
127                optimizer.panic_on_explore_limit(true);
128            } else {
129                optimizer.panic_on_explore_limit(false);
130            }
131            if flags.disable_pruning {
132                optimizer.disable_pruning(true);
133            } else {
134                optimizer.disable_pruning(false);
135            }
136            let rules = optimizer.rules();
137            if flags.enable_logical_rules.is_empty() {
138                for r in 0..rules.len() {
139                    optimizer.enable_rule(r);
140                }
141                guard.as_mut().unwrap().enable_heuristic(true);
142            } else {
143                for (rule_id, rule) in rules.as_ref().iter().enumerate() {
144                    if rule.is_impl_rule() {
145                        optimizer.enable_rule(rule_id);
146                    } else {
147                        optimizer.disable_rule(rule_id);
148                    }
149                }
150                let mut rules_to_enable = flags
151                    .enable_logical_rules
152                    .iter()
153                    .map(|x| x.as_str())
154                    .collect::<HashSet<_>>();
155                for (rule_id, rule) in rules.as_ref().iter().enumerate() {
156                    if rules_to_enable.remove(rule.name()) {
157                        optimizer.enable_rule(rule_id);
158                    }
159                }
160                if !rules_to_enable.is_empty() {
161                    bail!("Unknown logical rule: {:?}", rules_to_enable);
162                }
163                guard.as_mut().unwrap().enable_heuristic(false);
164            }
165        }
166        let sql = unescape_input(sql)?;
167        let dialect = Box::new(GenericDialect);
168        let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
169        let mut result = Vec::new();
170        for statement in statements {
171            let df = if flags.enable_df_logical {
172                let plan = self
173                    .use_df_logical_ctx
174                    .state()
175                    .statement_to_plan(statement)
176                    .await?;
177                self.use_df_logical_ctx.execute_logical_plan(plan).await?
178            } else {
179                let plan = self.ctx.state().statement_to_plan(statement).await?;
180
181                self.ctx.execute_logical_plan(plan).await?
182            };
183
184            let batches = df.collect().await?;
185
186            let options = FormatOptions::default();
187
188            for batch in batches {
189                let converters = batch
190                    .columns()
191                    .iter()
192                    .map(|a| ArrayFormatter::try_new(a.as_ref(), &options))
193                    .collect::<Result<Vec<_>, _>>()?;
194                for row_idx in 0..batch.num_rows() {
195                    let mut row = Vec::with_capacity(batch.num_columns());
196                    for converter in converters.iter() {
197                        let mut buffer = String::with_capacity(8);
198                        converter.value(row_idx).write(&mut buffer)?;
199                        row.push(buffer);
200                    }
201                    result.push(row);
202                }
203            }
204        }
205        if flags.dump_memo_table {
206            let mut guard = self
207                .optd_optimizer
208                .as_ref()
209                .unwrap()
210                .optimizer
211                .lock()
212                .unwrap();
213            let optimizer = guard.as_mut().unwrap().optd_optimizer_mut();
214            optimizer.dump();
215        }
216        Ok(result)
217    }
218
219    /// Executes the `execute` task.
220    async fn task_execute(&mut self, r: &mut String, sql: &str, flags: &TestFlags) -> Result<()> {
221        use std::fmt::Write;
222        if flags.verbose {
223            bail!("Verbose flag is not supported for execute task");
224        }
225        let result = self.execute(sql, flags).await?;
226        writeln!(r, "{}", result.into_iter().map(|x| x.join(" ")).join("\n"))?;
227        writeln!(r)?;
228        Ok(())
229    }
230
231    /// Executes the `explain` task.
232    async fn task_explain(
233        &mut self,
234        r: &mut String,
235        sql: &str,
236        task: &str,
237        flags: &TestFlags,
238    ) -> Result<()> {
239        use std::fmt::Write;
240
241        let verbose = flags.verbose;
242        let explain_sql = if verbose {
243            format!("explain verbose {}", &sql)
244        } else {
245            format!("explain {}", &sql)
246        };
247        let result = self.execute(&explain_sql, flags).await?;
248        let subtask_start_pos = task.rfind(':').unwrap() + 1;
249        for subtask in task[subtask_start_pos..].split(',') {
250            let subtask = subtask.trim();
251            if subtask == "logical_datafusion" {
252                writeln!(
253                    r,
254                    "{}",
255                    result
256                        .iter()
257                        .find(|x| x[0] == "logical_plan after datafusion")
258                        .map(|x| &x[1])
259                        .unwrap()
260                )?;
261            } else if subtask == "logical_optd_heuristic" || subtask == "optimized_logical_optd" {
262                writeln!(
263                    r,
264                    "{}",
265                    result
266                        .iter()
267                        .find(|x| x[0] == "logical_plan after optd-heuristic")
268                        .map(|x| &x[1])
269                        .unwrap()
270                )?;
271            } else if subtask == "logical_optd" {
272                writeln!(
273                    r,
274                    "{}",
275                    result
276                        .iter()
277                        .find(|x| x[0] == "logical_plan after optd")
278                        .map(|x| &x[1])
279                        .unwrap()
280                )?;
281            } else if subtask == "physical_optd" {
282                writeln!(
283                    r,
284                    "{}",
285                    result
286                        .iter()
287                        .find(|x| x[0] == "physical_plan after optd")
288                        .map(|x| &x[1])
289                        .unwrap()
290                )?;
291            } else if subtask == "logical_join_orders" {
292                writeln!(
293                    r,
294                    "{}",
295                    result
296                        .iter()
297                        .find(|x| x[0] == "physical_plan after optd-all-logical-join-orders")
298                        .map(|x| &x[1])
299                        .unwrap()
300                )?;
301                writeln!(r)?;
302            } else if subtask == "physical_datafusion" {
303                writeln!(
304                    r,
305                    "{}",
306                    result
307                        .iter()
308                        .find(|x| x[0] == "physical_plan")
309                        .map(|x| &x[1])
310                        .unwrap()
311                )?;
312            } else {
313                bail!("Unknown subtask: {}", subtask);
314            }
315        }
316
317        Ok(())
318    }
319}
320
321#[async_trait]
322impl sqlplannertest::PlannerTestRunner for DatafusionDBMS {
323    async fn run(&mut self, test_case: &sqlplannertest::ParsedTestCase) -> Result<String> {
324        if !test_case.before_sql.is_empty() {
325            panic!("before is not supported in optd-sqlplannertest, always specify the task type to run");
326        }
327
328        let mut result = String::new();
329        let r = &mut result;
330        for task in &test_case.tasks {
331            let flags = extract_flags(task)?;
332            if task.starts_with("execute") {
333                self.task_execute(r, &test_case.sql, &flags).await?;
334            } else if task.starts_with("explain") {
335                self.task_explain(r, &test_case.sql, task, &flags).await?;
336            }
337        }
338        Ok(result)
339    }
340}
341
// Matches the bracketed flag list of a task string. Capture group 1 is the text between the
// brackets; because `.*` is greedy it spans from the first `[` to the last `]` it can reach.
lazy_static! {
    static ref FLAGS_REGEX: Regex = Regex::new(r"\[(.*)\]").unwrap();
}
345
/// Options parsed from the bracketed flag list of a task (see `extract_flags`).
/// All options are off / empty by default.
#[derive(Default, Debug)]
struct TestFlags {
    /// Set by `verbose`: the explain task runs `explain verbose`; the execute task rejects it.
    verbose: bool,
    /// Set by `use_df_logical`: statements run on the context that keeps datafusion's logical
    /// optimizer enabled.
    enable_df_logical: bool,
    /// Rule names from `logical_rules:<a>+<b>+...`. When non-empty, only implementation rules
    /// and these rules are enabled and the heuristic pass is turned off.
    enable_logical_rules: Vec<String>,
    /// Set by `panic_on_budget`: forwarded to the optimizer's `panic_on_explore_limit`.
    panic_on_budget: bool,
    /// Set by `dump_memo_table`: the optimizer's `dump()` is called after execution.
    dump_memo_table: bool,
    /// Set by `disable_pruning`: forwarded to the optimizer's `disable_pruning`.
    disable_pruning: bool,
}
355
356/// Extract the flags from a task. The flags are specified in square brackets.
357/// For example, the flags for the task `explain[use_df_logical, verbose]` are `["use_df_logical",
358/// "verbose"]`.
359fn extract_flags(task: &str) -> Result<TestFlags> {
360    if let Some(captures) = FLAGS_REGEX.captures(task) {
361        let flags = captures
362            .get(1)
363            .unwrap()
364            .as_str()
365            .split(',')
366            .map(|x| x.trim().to_string())
367            .collect_vec();
368        let mut options = TestFlags::default();
369        for flag in flags {
370            if flag == "verbose" {
371                options.verbose = true;
372            } else if flag == "use_df_logical" {
373                options.enable_df_logical = true;
374            } else if flag.starts_with("logical_rules") {
375                if let Some((_, flag)) = flag.split_once(':') {
376                    options.enable_logical_rules = flag.split('+').map(|x| x.to_string()).collect();
377                } else {
378                    bail!("Failed to parse logical_rules flag: {}", flag);
379                }
380            } else if flag == "panic_on_budget" {
381                options.panic_on_budget = true;
382            } else if flag == "dump_memo_table" {
383                options.dump_memo_table = true;
384            } else if flag == "disable_pruning" {
385                options.disable_pruning = true;
386            } else {
387                bail!("Unknown flag: {}", flag);
388            }
389        }
390        Ok(options)
391    } else {
392        Ok(TestFlags::default())
393    }
394}