1use std::io::Write;
21use std::path::PathBuf;
22use std::sync::Arc;
23
24use color_eyre::eyre::eyre;
25use datafusion::logical_expr::LogicalPlan;
26use futures::TryFutureExt;
27use log::{debug, error, info};
28
29use crate::catalog::create_app_catalog;
30use crate::config::ExecutionConfig;
31use crate::{ExecOptions, ExecResult};
32use color_eyre::eyre::{self, Result};
33use datafusion::common::Result as DFResult;
34use datafusion::execution::{SendableRecordBatchStream, SessionState};
35use datafusion::physical_plan::{execute_stream, ExecutionPlan};
36use datafusion::prelude::*;
37use datafusion::sql::parser::{DFParser, Statement};
38use tokio_stream::StreamExt;
39
40use super::executor::dedicated::DedicatedExecutor;
41use super::local_benchmarks::{BenchmarkMode, BenchmarkProgressReporter, LocalBenchmarkStats};
42use super::stats::{ExecutionDurationStats, ExecutionStats};
43#[cfg(feature = "udfs-wasm")]
44use super::wasm::create_wasm_udfs;
45#[cfg(feature = "observability")]
46use {crate::config::ObservabilityConfig, crate::observability::ObservabilityContext};
47
#[derive(Clone)]
/// Shared execution state for running SQL: wraps a DataFusion [`SessionContext`]
/// plus app-level configuration, optional DDL persistence, and an optional
/// dedicated CPU runtime for query work.
pub struct ExecutionContext {
    // Application-level execution configuration (catalog name, benchmark
    // iterations, dedicated-executor flag, ...).
    config: ExecutionConfig,
    // The underlying DataFusion session all queries run against.
    session_ctx: SessionContext,
    // Where DDL is loaded from / saved to; `None` means no DDL file configured.
    ddl_path: Option<PathBuf>,
    // Dedicated CPU runtime for planning/execution; `None` runs inline on the
    // caller's runtime (set from `config.dedicated_executor_enabled`).
    executor: Option<DedicatedExecutor>,
    #[cfg(feature = "observability")]
    // Observability schema/context, only present with the feature enabled.
    observability: ObservabilityContext,
}
74
75impl std::fmt::Debug for ExecutionContext {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("ExecutionContext").finish()
78 }
79}
80
81impl ExecutionContext {
    /// Build a fully-wired [`ExecutionContext`] from configuration.
    ///
    /// Sets up (in order): an optional dedicated CPU executor, the DataFusion
    /// session (with URL tables enabled), feature-gated JSON functions and WASM
    /// UDFs, the `parquet_metadata` UDTF, the application catalog, and — with
    /// the `observability` feature — the observability schema.
    ///
    /// # Errors
    /// Returns an error if catalog creation, UDF registration, or the
    /// observability context construction fails. Failure to *register* the
    /// observability schema is only logged, not propagated.
    pub fn try_new(
        config: &ExecutionConfig,
        session_state: SessionState,
        app_name: &str,
        app_version: &str,
    ) -> Result<Self> {
        let mut executor = None;
        if config.dedicated_executor_enabled {
            // Separate multi-threaded runtime so CPU-bound query work does not
            // starve the main runtime.
            let runtime_builder = tokio::runtime::Builder::new_multi_thread();
            let dedicated_executor =
                DedicatedExecutor::new("cpu_runtime", config.clone(), runtime_builder);
            executor = Some(dedicated_executor)
        }

        let mut session_ctx = SessionContext::new_with_state(session_state);
        // Allows querying URLs/paths directly as tables (e.g. SELECT * FROM 'f.parquet').
        session_ctx = session_ctx.enable_url_table();

        #[cfg(feature = "functions-json")]
        datafusion_functions_json::register_all(&mut session_ctx)?;

        #[cfg(feature = "udfs-wasm")]
        {
            // Register each user-configured WASM UDF with the session.
            let wasm_udfs = create_wasm_udfs(&config.wasm_udf)?;
            for wasm_udf in wasm_udfs {
                session_ctx.register_udf(wasm_udf);
            }
        }

        // Table function for inspecting Parquet file metadata via SQL.
        session_ctx.register_udtf(
            "parquet_metadata",
            Arc::new(datafusion_functions_parquet::ParquetMetadataFunc {}),
        );

        let catalog = create_app_catalog(config, app_name, app_version)?;
        session_ctx.register_catalog(&config.catalog.name, catalog);

        let ctx = {
            #[cfg(feature = "observability")]
            {
                let observability =
                    ObservabilityContext::try_new(config.observability.clone(), app_name)?;
                // Best-effort: schema registration problems are logged but do
                // not fail context construction.
                if let Some(cat) = session_ctx.catalog(&config.catalog.name) {
                    match cat
                        .register_schema(&config.observability.schema_name, observability.schema())
                    {
                        Ok(_) => {
                            info!("Registered observability schema")
                        }
                        Err(e) => {
                            error!("Error registering observability schema: {}", e)
                        }
                    }
                } else {
                    error!("Missing catalog to register observability schema")
                }
                Self {
                    config: config.clone(),
                    session_ctx,
                    ddl_path: config.ddl_path.as_ref().map(PathBuf::from),
                    executor,
                    observability,
                }
            }
            #[cfg(not(feature = "observability"))]
            {
                Self {
                    config: config.clone(),
                    session_ctx,
                    ddl_path: config.ddl_path.as_ref().map(PathBuf::from),
                    executor,
                }
            }
        };

        Ok(ctx)
    }
162
163 pub fn test() -> Self {
165 let cfg = SessionConfig::new().with_information_schema(true);
166 let session_ctx = SessionContext::new_with_config(cfg);
167 let exec_cfg = ExecutionConfig::default();
168 let app_catalog = create_app_catalog(&exec_cfg, "test", ".0.1.0").unwrap();
170 session_ctx.register_catalog("test", app_catalog);
171 #[cfg(feature = "observability")]
172 let observability =
173 ObservabilityContext::try_new(ObservabilityConfig::default(), "test").unwrap();
174 Self {
175 config: ExecutionConfig::default(),
176 session_ctx,
177 ddl_path: None,
178 executor: None,
179 #[cfg(feature = "observability")]
180 observability,
181 }
182 }
183
    /// Borrow the execution configuration this context was built with.
    pub fn config(&self) -> &ExecutionConfig {
        &self.config
    }
187
    /// No-op placeholder for table creation.
    ///
    /// NOTE(review): currently does nothing and always returns `Ok(())` —
    /// presumably a hook kept for API stability or future use; confirm with
    /// callers before removing.
    pub fn create_tables(&mut self) -> Result<()> {
        Ok(())
    }
191
    /// Borrow the underlying DataFusion [`SessionContext`].
    pub fn session_ctx(&self) -> &SessionContext {
        &self.session_ctx
    }
196
    /// Borrow the optional dedicated CPU executor (`None` when queries run
    /// inline on the caller's runtime).
    ///
    /// NOTE(review): returning `&Option<T>` is less idiomatic than
    /// `Option<&T>` (`self.executor.as_ref()`), but changing it would break
    /// existing callers — kept as-is.
    pub fn executor(&self) -> &Option<DedicatedExecutor> {
        &self.executor
    }
201
    /// Borrow the observability context (only available with the
    /// `observability` feature enabled).
    #[cfg(feature = "observability")]
    pub fn observability(&self) -> &ObservabilityContext {
        &self.observability
    }
207
208 pub async fn statement_to_logical_plan(&self, statement: Statement) -> Result<LogicalPlan> {
210 let ctx = self.session_ctx.clone();
211 let task = async move { ctx.state().statement_to_plan(statement).await };
212 if let Some(executor) = &self.executor {
213 let job = executor.spawn(task).map_err(|e| eyre::eyre!(e));
214 let job_res = job.await?;
215 job_res.map_err(|e| eyre!(e))
216 } else {
217 task.await.map_err(|e| eyre!(e))
218 }
219 }
220
221 pub async fn execute_logical_plan(
223 &self,
224 logical_plan: LogicalPlan,
225 ) -> Result<SendableRecordBatchStream> {
226 let ctx = self.session_ctx.clone();
227 let task = async move {
228 let df = ctx.execute_logical_plan(logical_plan).await?;
229 df.execute_stream().await
230 };
231 if let Some(executor) = &self.executor {
232 let job = executor.spawn(task).map_err(|e| eyre!(e));
233 let job_res = job.await?;
234 job_res.map_err(|e| eyre!(e))
235 } else {
236 task.await.map_err(|e| eyre!(e))
237 }
238 }
239
240 pub async fn execute_sql_and_discard_results(
242 &self,
243 sql: &str,
244 ) -> datafusion::error::Result<()> {
245 let mut stream = self.execute_sql(sql).await?;
246 while let Some(maybe_batch) = stream.next().await {
248 maybe_batch?; }
250 Ok(())
251 }
252
253 pub async fn create_physical_plan(
256 &self,
257 sql: &str,
258 ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
259 let df = self.session_ctx.sql(sql).await?;
260 df.create_physical_plan().await
261 }
262
263 pub async fn execute_sql(
266 &self,
267 sql: &str,
268 ) -> datafusion::error::Result<SendableRecordBatchStream> {
269 self.session_ctx.sql(sql).await?.execute_stream().await
270 }
271
272 pub async fn execute_statement(
275 &self,
276 statement: Statement,
277 ) -> datafusion::error::Result<SendableRecordBatchStream> {
278 let plan = self
279 .session_ctx
280 .state()
281 .statement_to_plan(statement)
282 .await?;
283 self.session_ctx
284 .execute_logical_plan(plan)
285 .await?
286 .execute_stream()
287 .await
288 }
289
290 pub fn load_ddl(&self) -> Option<String> {
292 info!("Loading DDL from: {:?}", &self.ddl_path);
293 if let Some(ddl_path) = &self.ddl_path {
294 if ddl_path.exists() {
295 let maybe_ddl = std::fs::read_to_string(ddl_path);
296 match maybe_ddl {
297 Ok(ddl) => Some(ddl),
298 Err(err) => {
299 error!("Error reading DDL: {:?}", err);
300 None
301 }
302 }
303 } else {
304 info!("DDL path ({:?}) does not exist", ddl_path);
305 None
306 }
307 } else {
308 info!("No DDL file configured");
309 None
310 }
311 }
312
313 pub fn save_ddl(&self, ddl: String) {
315 info!("Loading DDL from: {:?}", &self.ddl_path);
316 if let Some(ddl_path) = &self.ddl_path {
317 match std::fs::File::create(ddl_path) {
318 Ok(mut f) => match f.write_all(ddl.as_bytes()) {
319 Ok(_) => {
320 info!("Saved DDL file")
321 }
322 Err(e) => {
323 error!("Error writing DDL file: {e}")
324 }
325 },
326 Err(e) => {
327 error!("Error creating or opening DDL file: {e}")
328 }
329 }
330 } else {
331 info!("No DDL file configured");
332 }
333 }
334
335 pub async fn execute_ddl(&self) {
337 match self.load_ddl() {
338 Some(ddl) => {
339 let ddl_statements = ddl.split(';').collect::<Vec<&str>>();
340 for statement in ddl_statements {
341 if statement.trim().is_empty() {
342 continue;
343 }
344 if statement.trim().starts_with("--") {
345 continue;
346 }
347
348 debug!("Executing DDL statement: {:?}", statement);
349 match self.execute_sql_and_discard_results(statement).await {
350 Ok(_) => {
351 info!("DDL statement executed");
352 }
353 Err(e) => {
354 error!("Error executing DDL statement: {e}");
355 }
356 }
357 }
358 }
359 None => {
360 info!("No DDL to execute");
361 }
362 }
363 }
364
    /// Run one benchmark iteration for `statement` and time each phase.
    ///
    /// Returns `(rows, logical_planning, physical_planning, execution, total)`.
    /// Timing scheme: cumulative timestamps are taken from a single `start`
    /// instant, and per-phase durations are recovered by subtracting the
    /// previous phase's cumulative timestamp in the returned tuple.
    async fn benchmark_single_iteration(
        &self,
        statement: Statement,
    ) -> Result<(
        usize,
        std::time::Duration,
        std::time::Duration,
        std::time::Duration,
        std::time::Duration,
    )> {
        let start = std::time::Instant::now();
        let logical_plan = self
            .session_ctx()
            .state()
            .statement_to_plan(statement)
            .await?;
        // Cumulative: time from start through logical planning.
        let logical_planning_duration = start.elapsed();
        let physical_plan = self
            .session_ctx()
            .state()
            .create_physical_plan(&logical_plan)
            .await?;
        // Cumulative: time from start through physical planning.
        let physical_planning_duration = start.elapsed();
        let task_ctx = self.session_ctx().task_ctx();
        let mut stream = execute_stream(physical_plan, task_ctx)?;
        let mut rows = 0;
        // Drain the stream fully so execution time covers all batches.
        while let Some(b) = stream.next().await {
            rows += b?.num_rows();
        }
        let execution_duration = start.elapsed();
        // NOTE(review): total is sampled immediately after execution, so it is
        // effectively equal to `execution_duration` — reported separately for
        // symmetry with the stats structure.
        let total_duration = start.elapsed();
        Ok((
            rows,
            logical_planning_duration,
            // Convert cumulative timestamps into per-phase durations.
            physical_planning_duration - logical_planning_duration,
            execution_duration - physical_planning_duration,
            total_duration,
        ))
    }
405
    /// Benchmark a single SQL statement over many iterations, serially or
    /// concurrently, collecting per-phase duration samples.
    ///
    /// * `cli_iterations` overrides `config.benchmark_iterations` when set.
    /// * `concurrent` runs iterations in batches of up to `num_cpus` tasks;
    ///   note that concurrent iterations overlap, so individual durations
    ///   reflect contention.
    /// * `progress_reporter`, when provided, is notified after each completed
    ///   iteration and once at the end.
    ///
    /// # Errors
    /// Fails if `query` does not parse to exactly one statement, or if any
    /// iteration (or spawned task) fails.
    pub async fn benchmark_query(
        &self,
        query: &str,
        cli_iterations: Option<usize>,
        concurrent: bool,
        progress_reporter: Option<Arc<dyn BenchmarkProgressReporter>>,
    ) -> Result<LocalBenchmarkStats> {
        let iterations = cli_iterations.unwrap_or(self.config.benchmark_iterations);
        let dialect = datafusion::sql::sqlparser::dialect::GenericDialect {};
        let statements = DFParser::parse_sql_with_dialect(query, &dialect)?;

        if statements.len() != 1 {
            return Err(eyre::eyre!("Only a single statement can be benchmarked"));
        }

        let statement = statements[0].clone();
        // Cap concurrency at both the iteration count and available CPUs.
        let concurrency = if concurrent {
            std::cmp::min(iterations, num_cpus::get())
        } else {
            1
        };
        let mode = if concurrent {
            BenchmarkMode::Concurrent(concurrency)
        } else {
            BenchmarkMode::Serial
        };

        info!(
            "Benchmarking query with {} iterations (concurrency: {})",
            iterations, concurrency
        );

        // One sample per iteration for each measured dimension.
        let mut rows_returned = Vec::with_capacity(iterations);
        let mut logical_planning_durations = Vec::with_capacity(iterations);
        let mut physical_planning_durations = Vec::with_capacity(iterations);
        let mut execution_durations = Vec::with_capacity(iterations);
        let mut total_durations = Vec::with_capacity(iterations);

        if !concurrent {
            // Serial mode: iterations run back-to-back on this task.
            for i in 0..iterations {
                let (rows, lp_dur, pp_dur, exec_dur, total_dur) =
                    self.benchmark_single_iteration(statement.clone()).await?;
                rows_returned.push(rows);
                logical_planning_durations.push(lp_dur);
                physical_planning_durations.push(pp_dur);
                execution_durations.push(exec_dur);
                total_durations.push(total_dur);

                if let Some(ref reporter) = progress_reporter {
                    reporter.on_iteration_complete(i + 1, iterations, total_dur);
                }
            }
        } else {
            // Concurrent mode: run iterations in batches of `concurrency`
            // spawned tasks, waiting for each batch to finish before the next.
            let mut completed = 0;

            while completed < iterations {
                // Last batch may be smaller than `concurrency`.
                let batch_size = std::cmp::min(concurrency, iterations - completed);
                let mut join_set = tokio::task::JoinSet::new();

                for _ in 0..batch_size {
                    // Clone is cheap-ish: ExecutionContext derives Clone so each
                    // task owns its handle to the shared session.
                    let self_clone = self.clone();
                    let statement_clone = statement.clone();
                    join_set.spawn(async move {
                        self_clone.benchmark_single_iteration(statement_clone).await
                    });
                }

                while let Some(result) = join_set.join_next().await {
                    // First `?` is the JoinError (panic/cancel), second the
                    // iteration's own Result.
                    let (rows, lp_dur, pp_dur, exec_dur, total_dur) = result??;
                    rows_returned.push(rows);
                    logical_planning_durations.push(lp_dur);
                    physical_planning_durations.push(pp_dur);
                    execution_durations.push(exec_dur);
                    total_durations.push(total_dur);

                    completed += 1;
                    if let Some(ref reporter) = progress_reporter {
                        reporter.on_iteration_complete(completed, iterations, total_dur);
                    }
                }
            }
        }

        if let Some(ref reporter) = progress_reporter {
            reporter.finish();
        }

        Ok(LocalBenchmarkStats::new(
            query.to_string(),
            rows_returned,
            mode,
            logical_planning_durations,
            physical_planning_durations,
            execution_durations,
            total_durations,
        ))
    }
505
506 pub async fn analyze_query(&self, query: &str) -> Result<ExecutionStats> {
507 let dialect = datafusion::sql::sqlparser::dialect::GenericDialect {};
508 let start = std::time::Instant::now();
509 let statements = DFParser::parse_sql_with_dialect(query, &dialect)?;
510 let parsing_duration = start.elapsed();
511 if statements.len() == 1 {
512 let statement = statements[0].clone();
513 let logical_plan = self
514 .session_ctx()
515 .state()
516 .statement_to_plan(statement.clone())
517 .await?;
518 let logical_planning_duration = start.elapsed();
519 let physical_plan = self
520 .session_ctx()
521 .state()
522 .create_physical_plan(&logical_plan)
523 .await?;
524 let physical_planning_duration = start.elapsed();
525 let task_ctx = self.session_ctx().task_ctx();
526 let mut stream = execute_stream(Arc::clone(&physical_plan), task_ctx)?;
527 let mut rows = 0;
528 let mut batches = 0;
529 let mut bytes = 0;
530 while let Some(b) = stream.next().await {
531 let batch = b?;
532 rows += batch.num_rows();
533 batches += 1;
534 bytes += batch.get_array_memory_size();
535 }
536 let execution_duration = start.elapsed();
537 let durations = ExecutionDurationStats::new(
538 parsing_duration,
539 logical_planning_duration - parsing_duration,
540 physical_planning_duration - logical_planning_duration,
541 execution_duration - physical_planning_duration,
542 start.elapsed(),
543 );
544 ExecutionStats::try_new(
545 query.to_string(),
546 durations,
547 rows,
548 batches,
549 bytes,
550 physical_plan,
551 )
552 } else {
553 Err(eyre::eyre!("Only a single statement can be benchmarked"))
554 }
555 }
556
557 pub async fn execute_sql_with_opts(
558 &self,
559 sql: &str,
560 opts: ExecOptions,
561 ) -> DFResult<ExecResult> {
562 let df = self.session_ctx.sql(sql).await?;
563 let df = if let Some(limit) = opts.limit {
564 df.limit(0, Some(limit))?
565 } else {
566 df
567 };
568 Ok(ExecResult::RecordBatchStream(df.execute_stream().await?))
569 }
570}