1use std::collections::HashSet;
7use std::sync::Arc;
8
9use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions};
10use datafusion::catalog::CatalogList;
11use datafusion::execution::context::{SessionConfig, SessionState};
12use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
13use datafusion::prelude::SessionContext;
14use datafusion::sql::parser::DFParser;
15use datafusion::sql::sqlparser::dialect::GenericDialect;
16use datafusion_optd_cli::helper::unescape_input;
17use itertools::Itertools;
18use lazy_static::lazy_static;
19use mimalloc::MiMalloc;
20use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
21use optd_datafusion_repr::DatafusionOptimizer;
22use optd_datafusion_repr_adv_cost::adv_stats::stats::DataFusionBaseTableStats;
23use optd_datafusion_repr_adv_cost::new_physical_adv_cost;
24use regex::Regex;
25
26#[global_allocator]
27static GLOBAL: MiMalloc = MiMalloc;
28
29use anyhow::{bail, Result};
30use async_trait::async_trait;
31
32#[derive(Default)]
33pub struct DatafusionDBMS {
34 ctx: SessionContext,
35 use_df_logical_ctx: SessionContext,
37 optd_optimizer: Option<Arc<OptdQueryPlanner>>,
39}
40
41impl DatafusionDBMS {
42 pub async fn new() -> Result<Self> {
43 let (ctx, optd_optimizer) = DatafusionDBMS::new_session_ctx(false, None, false).await?;
44 let (use_df_logical_ctx, _) =
45 Self::new_session_ctx(true, Some(ctx.state().catalog_list().clone()), false).await?;
46 Ok(Self {
47 ctx,
48 use_df_logical_ctx,
49 optd_optimizer: Some(optd_optimizer),
50 })
51 }
52
53 pub async fn new_advanced_cost() -> Result<Self> {
54 let (ctx, optd_optimizer) = DatafusionDBMS::new_session_ctx(false, None, true).await?;
55 let (use_df_logical_ctx, _) =
56 Self::new_session_ctx(true, Some(ctx.state().catalog_list().clone()), true).await?;
57 Ok(Self {
58 ctx,
59 use_df_logical_ctx,
60 optd_optimizer: Some(optd_optimizer),
61 })
62 }
63
64 async fn new_session_ctx(
67 use_df_logical: bool,
68 catalog: Option<Arc<dyn CatalogList>>,
69 with_advanced_cost: bool,
70 ) -> Result<(SessionContext, Arc<OptdQueryPlanner>)> {
71 let mut session_config = SessionConfig::from_env()?.with_information_schema(true);
72 if !use_df_logical {
73 session_config.options_mut().optimizer.max_passes = 0;
74 }
75
76 let rn_config = RuntimeConfig::new();
77 let runtime_env = RuntimeEnv::new(rn_config.clone())?;
78 let optd_optimizer;
79
80 let ctx = {
81 let mut state = if let Some(catalog) = catalog {
82 SessionState::new_with_config_rt_and_catalog_list(
83 session_config.clone(),
84 Arc::new(runtime_env),
85 catalog,
86 )
87 } else {
88 SessionState::new_with_config_rt(session_config.clone(), Arc::new(runtime_env))
89 };
90 let optimizer = if with_advanced_cost {
91 new_physical_adv_cost(
92 Arc::new(DatafusionCatalog::new(state.catalog_list())),
93 DataFusionBaseTableStats::default(),
94 false,
95 )
96 } else {
97 DatafusionOptimizer::new_physical(
98 Arc::new(DatafusionCatalog::new(state.catalog_list())),
99 false,
100 )
101 };
102 if !use_df_logical {
103 state = state.with_optimizer_rules(vec![]);
105 }
106 state = state.with_physical_optimizer_rules(vec![]);
107 optd_optimizer = Arc::new(OptdQueryPlanner::new(optimizer));
109 state = state.with_query_planner(optd_optimizer.clone());
110 SessionContext::new_with_state(state)
111 };
112 ctx.refresh_catalogs().await?;
113 Ok((ctx, optd_optimizer))
114 }
115
116 pub(crate) async fn execute(&self, sql: &str, flags: &TestFlags) -> Result<Vec<Vec<String>>> {
117 {
118 let mut guard = self
119 .optd_optimizer
120 .as_ref()
121 .unwrap()
122 .optimizer
123 .lock()
124 .unwrap();
125 let optimizer = guard.as_mut().unwrap().optd_optimizer_mut();
126 if flags.panic_on_budget {
127 optimizer.panic_on_explore_limit(true);
128 } else {
129 optimizer.panic_on_explore_limit(false);
130 }
131 if flags.disable_pruning {
132 optimizer.disable_pruning(true);
133 } else {
134 optimizer.disable_pruning(false);
135 }
136 let rules = optimizer.rules();
137 if flags.enable_logical_rules.is_empty() {
138 for r in 0..rules.len() {
139 optimizer.enable_rule(r);
140 }
141 guard.as_mut().unwrap().enable_heuristic(true);
142 } else {
143 for (rule_id, rule) in rules.as_ref().iter().enumerate() {
144 if rule.is_impl_rule() {
145 optimizer.enable_rule(rule_id);
146 } else {
147 optimizer.disable_rule(rule_id);
148 }
149 }
150 let mut rules_to_enable = flags
151 .enable_logical_rules
152 .iter()
153 .map(|x| x.as_str())
154 .collect::<HashSet<_>>();
155 for (rule_id, rule) in rules.as_ref().iter().enumerate() {
156 if rules_to_enable.remove(rule.name()) {
157 optimizer.enable_rule(rule_id);
158 }
159 }
160 if !rules_to_enable.is_empty() {
161 bail!("Unknown logical rule: {:?}", rules_to_enable);
162 }
163 guard.as_mut().unwrap().enable_heuristic(false);
164 }
165 }
166 let sql = unescape_input(sql)?;
167 let dialect = Box::new(GenericDialect);
168 let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
169 let mut result = Vec::new();
170 for statement in statements {
171 let df = if flags.enable_df_logical {
172 let plan = self
173 .use_df_logical_ctx
174 .state()
175 .statement_to_plan(statement)
176 .await?;
177 self.use_df_logical_ctx.execute_logical_plan(plan).await?
178 } else {
179 let plan = self.ctx.state().statement_to_plan(statement).await?;
180
181 self.ctx.execute_logical_plan(plan).await?
182 };
183
184 let batches = df.collect().await?;
185
186 let options = FormatOptions::default();
187
188 for batch in batches {
189 let converters = batch
190 .columns()
191 .iter()
192 .map(|a| ArrayFormatter::try_new(a.as_ref(), &options))
193 .collect::<Result<Vec<_>, _>>()?;
194 for row_idx in 0..batch.num_rows() {
195 let mut row = Vec::with_capacity(batch.num_columns());
196 for converter in converters.iter() {
197 let mut buffer = String::with_capacity(8);
198 converter.value(row_idx).write(&mut buffer)?;
199 row.push(buffer);
200 }
201 result.push(row);
202 }
203 }
204 }
205 if flags.dump_memo_table {
206 let mut guard = self
207 .optd_optimizer
208 .as_ref()
209 .unwrap()
210 .optimizer
211 .lock()
212 .unwrap();
213 let optimizer = guard.as_mut().unwrap().optd_optimizer_mut();
214 optimizer.dump();
215 }
216 Ok(result)
217 }
218
219 async fn task_execute(&mut self, r: &mut String, sql: &str, flags: &TestFlags) -> Result<()> {
221 use std::fmt::Write;
222 if flags.verbose {
223 bail!("Verbose flag is not supported for execute task");
224 }
225 let result = self.execute(sql, flags).await?;
226 writeln!(r, "{}", result.into_iter().map(|x| x.join(" ")).join("\n"))?;
227 writeln!(r)?;
228 Ok(())
229 }
230
231 async fn task_explain(
233 &mut self,
234 r: &mut String,
235 sql: &str,
236 task: &str,
237 flags: &TestFlags,
238 ) -> Result<()> {
239 use std::fmt::Write;
240
241 let verbose = flags.verbose;
242 let explain_sql = if verbose {
243 format!("explain verbose {}", &sql)
244 } else {
245 format!("explain {}", &sql)
246 };
247 let result = self.execute(&explain_sql, flags).await?;
248 let subtask_start_pos = task.rfind(':').unwrap() + 1;
249 for subtask in task[subtask_start_pos..].split(',') {
250 let subtask = subtask.trim();
251 if subtask == "logical_datafusion" {
252 writeln!(
253 r,
254 "{}",
255 result
256 .iter()
257 .find(|x| x[0] == "logical_plan after datafusion")
258 .map(|x| &x[1])
259 .unwrap()
260 )?;
261 } else if subtask == "logical_optd_heuristic" || subtask == "optimized_logical_optd" {
262 writeln!(
263 r,
264 "{}",
265 result
266 .iter()
267 .find(|x| x[0] == "logical_plan after optd-heuristic")
268 .map(|x| &x[1])
269 .unwrap()
270 )?;
271 } else if subtask == "logical_optd" {
272 writeln!(
273 r,
274 "{}",
275 result
276 .iter()
277 .find(|x| x[0] == "logical_plan after optd")
278 .map(|x| &x[1])
279 .unwrap()
280 )?;
281 } else if subtask == "physical_optd" {
282 writeln!(
283 r,
284 "{}",
285 result
286 .iter()
287 .find(|x| x[0] == "physical_plan after optd")
288 .map(|x| &x[1])
289 .unwrap()
290 )?;
291 } else if subtask == "logical_join_orders" {
292 writeln!(
293 r,
294 "{}",
295 result
296 .iter()
297 .find(|x| x[0] == "physical_plan after optd-all-logical-join-orders")
298 .map(|x| &x[1])
299 .unwrap()
300 )?;
301 writeln!(r)?;
302 } else if subtask == "physical_datafusion" {
303 writeln!(
304 r,
305 "{}",
306 result
307 .iter()
308 .find(|x| x[0] == "physical_plan")
309 .map(|x| &x[1])
310 .unwrap()
311 )?;
312 } else {
313 bail!("Unknown subtask: {}", subtask);
314 }
315 }
316
317 Ok(())
318 }
319}
320
321#[async_trait]
322impl sqlplannertest::PlannerTestRunner for DatafusionDBMS {
323 async fn run(&mut self, test_case: &sqlplannertest::ParsedTestCase) -> Result<String> {
324 if !test_case.before_sql.is_empty() {
325 panic!("before is not supported in optd-sqlplannertest, always specify the task type to run");
326 }
327
328 let mut result = String::new();
329 let r = &mut result;
330 for task in &test_case.tasks {
331 let flags = extract_flags(task)?;
332 if task.starts_with("execute") {
333 self.task_execute(r, &test_case.sql, &flags).await?;
334 } else if task.starts_with("explain") {
335 self.task_explain(r, &test_case.sql, task, &flags).await?;
336 }
337 }
338 Ok(result)
339 }
340}
341
342lazy_static! {
343 static ref FLAGS_REGEX: Regex = Regex::new(r"\[(.*)\]").unwrap();
344}
345
/// Per-task options parsed from the `[...]` part of a task string.
#[derive(Default, Debug)]
struct TestFlags {
    /// `verbose`: explain tasks run `explain verbose`; execute tasks reject it.
    verbose: bool,
    /// `use_df_logical`: plan and run on the DataFusion-logical session.
    enable_df_logical: bool,
    /// `logical_rules:a+b`: rule names to enable in addition to
    /// implementation rules. Empty means every rule is enabled.
    enable_logical_rules: Vec<String>,
    /// `panic_on_budget`: forwarded to the optimizer's `panic_on_explore_limit`.
    panic_on_budget: bool,
    /// `dump_memo_table`: call the optimizer's `dump()` after execution.
    dump_memo_table: bool,
    /// `disable_pruning`: forwarded to the optimizer's `disable_pruning`.
    disable_pruning: bool,
}
355
356fn extract_flags(task: &str) -> Result<TestFlags> {
360 if let Some(captures) = FLAGS_REGEX.captures(task) {
361 let flags = captures
362 .get(1)
363 .unwrap()
364 .as_str()
365 .split(',')
366 .map(|x| x.trim().to_string())
367 .collect_vec();
368 let mut options = TestFlags::default();
369 for flag in flags {
370 if flag == "verbose" {
371 options.verbose = true;
372 } else if flag == "use_df_logical" {
373 options.enable_df_logical = true;
374 } else if flag.starts_with("logical_rules") {
375 if let Some((_, flag)) = flag.split_once(':') {
376 options.enable_logical_rules = flag.split('+').map(|x| x.to_string()).collect();
377 } else {
378 bail!("Failed to parse logical_rules flag: {}", flag);
379 }
380 } else if flag == "panic_on_budget" {
381 options.panic_on_budget = true;
382 } else if flag == "dump_memo_table" {
383 options.dump_memo_table = true;
384 } else if flag == "disable_pruning" {
385 options.disable_pruning = true;
386 } else {
387 bail!("Unknown flag: {}", flag);
388 }
389 }
390 Ok(options)
391 } else {
392 Ok(TestFlags::default())
393 }
394}