1mod progress;
20
21use crate::config::AppConfig;
22use crate::db::register_db;
23use crate::{args::DftArgs, execution::AppExecution};
24use color_eyre::eyre::eyre;
25use color_eyre::Result;
26use datafusion::arrow::array::{RecordBatch, RecordBatchWriter};
27use datafusion::arrow::datatypes::SchemaRef;
28use datafusion::arrow::util::pretty::pretty_format_batches;
29use datafusion::arrow::{csv, json};
30use datafusion::sql::parser::DFParser;
31use datafusion_app::config::merge_configs;
32use datafusion_app::extensions::DftSessionStateBuilder;
33use datafusion_app::local::ExecutionContext;
34use datafusion_app::local_benchmarks::LocalBenchmarkStats;
35use futures::{Stream, StreamExt};
36use log::info;
37use parquet::{arrow::ArrowWriter, file::properties::WriterProperties};
38use std::error::Error;
39use std::fs::File;
40use std::io::Write;
41use std::path::{Path, PathBuf};
42#[cfg(feature = "flightsql")]
43use {
44 crate::args::{parse_headers_file, Command, FlightSqlCommand},
45 datafusion_app::{
46 config::{AuthConfig, FlightSQLConfig},
47 flightsql::FlightSQLContext,
48 flightsql_benchmarks::FlightSQLBenchmarkStats,
49 },
50 tonic::IntoRequest,
51};
52#[cfg(feature = "vortex")]
53use {
54 vortex::array::{arrow::FromArrowArray, ArrayRef},
55 vortex_file::VortexWriteOptions,
56 vortex_session::VortexSession,
57};
58
/// CSV header written to the `--save` results file for local benchmarks.
/// Must stay in sync with `LocalBenchmarkStats::to_summary_csv_row` (23 columns:
/// query, runs, then min/max/mean/median/percent_of_total for each of
/// logical_planning, physical_planning, execution and total, plus concurrency_mode).
const LOCAL_BENCHMARK_HEADER_ROW: &str =
    "query,runs,logical_planning_min,logical_planning_max,logical_planning_mean,logical_planning_median,logical_planning_percent_of_total,physical_planning_min,physical_planning_max,physical_planning_mean,physical_planning_median,physical_planning_percent_of_total,execution_min,execution_max,execution_mean,execution_median,execution_percent_of_total,total_min,total_max,total_mean,total_median,total_percent_of_total,concurrency_mode";
61
/// CSV header written to the `--save` results file for FlightSQL benchmarks.
/// Must stay in sync with `FlightSQLBenchmarkStats::to_summary_csv_row` (23 columns:
/// query, runs, then min/max/mean/median/percent_of_total for each of
/// get_flight_info, ttfb, do_get and total, plus concurrency_mode).
#[cfg(feature = "flightsql")]
const FLIGHTSQL_BENCHMARK_HEADER_ROW: &str =
    "query,runs,get_flight_info_min,get_flight_info_max,get_flight_info_mean,get_flight_info_median,get_flight_info_percent_of_total,ttfb_min,ttfb_max,ttfb_mean,ttfb_median,ttfb_percent_of_total,do_get_min,do_get_max,do_get_mean,do_get_median,do_get_percent_of_total,total_min,total_max,total_mean,total_median,total_percent_of_total,concurrency_mode";
65
/// Non-interactive (CLI) front end: executes, benchmarks, or analyzes SQL from
/// files and/or command-line strings based on the parsed arguments.
pub struct CliApp {
    // Shared execution state: local DataFusion context and, when the
    // `flightsql` feature is enabled, the FlightSQL client/context.
    app_execution: AppExecution,
    // Parsed command-line arguments driving what `execute_files_or_commands` does.
    args: DftArgs,
}
72
73impl CliApp {
74 pub fn new(app_execution: AppExecution, args: DftArgs) -> Self {
75 Self {
76 app_execution,
77 args,
78 }
79 }
80
81 fn validate_args(&self) -> color_eyre::Result<()> {
82 let more_than_one_command_or_file = (self.args.commands.len() > 1
83 || self.args.files.len() > 1)
84 && self.args.output.is_some();
85 if more_than_one_command_or_file {
86 return Err(eyre!(
87 "Output can only be saved for a single file or command"
88 ));
89 }
90
91 Ok(())
92 }
93
94 #[cfg(feature = "flightsql")]
95 async fn handle_flightsql_command(&self, command: FlightSqlCommand) -> color_eyre::Result<()> {
96 use futures::stream;
97
98 match command {
99 FlightSqlCommand::StatementQuery { sql } => self.exec_from_flightsql(sql, 0).await,
100 FlightSqlCommand::GetCatalogs => {
101 let flight_info = self
102 .app_execution
103 .flightsql_ctx()
104 .get_catalogs_flight_info()
105 .await?;
106 let streams = self
107 .app_execution
108 .flightsql_ctx()
109 .do_get(flight_info)
110 .await?;
111 let flight_batch_stream = stream::select_all(streams);
112 self.print_any_stream(flight_batch_stream).await;
113 Ok(())
114 }
115 FlightSqlCommand::GetDbSchemas {
116 catalog,
117 db_schema_filter_pattern,
118 } => {
119 let flight_info = self
120 .app_execution
121 .flightsql_ctx()
122 .get_db_schemas_flight_info(catalog, db_schema_filter_pattern)
123 .await?;
124 let streams = self
125 .app_execution
126 .flightsql_ctx()
127 .do_get(flight_info)
128 .await?;
129 let flight_batch_stream = stream::select_all(streams);
130 self.print_any_stream(flight_batch_stream).await;
131 Ok(())
132 }
133
134 FlightSqlCommand::GetTables {
135 catalog,
136 db_schema_filter_pattern,
137 table_name_filter_pattern,
138 table_types,
139 } => {
140 let flight_info = self
141 .app_execution
142 .flightsql_ctx()
143 .get_tables_flight_info(
144 catalog,
145 db_schema_filter_pattern,
146 table_name_filter_pattern,
147 table_types.unwrap_or_default(),
148 false,
149 )
150 .await?;
151 let streams = self
152 .app_execution
153 .flightsql_ctx()
154 .do_get(flight_info)
155 .await?;
156 let flight_batch_stream = stream::select_all(streams);
157 self.print_any_stream(flight_batch_stream).await;
158 Ok(())
159 }
160 FlightSqlCommand::GetTableTypes => {
161 let flight_info = self
162 .app_execution
163 .flightsql_ctx()
164 .get_table_types_flight_info()
165 .await?;
166 let streams = self
167 .app_execution
168 .flightsql_ctx()
169 .do_get(flight_info)
170 .await?;
171 let flight_batch_stream = stream::select_all(streams);
172 self.print_any_stream(flight_batch_stream).await;
173 Ok(())
174 }
175 FlightSqlCommand::GetSqlInfo { info } => {
176 let flight_info = self
177 .app_execution
178 .flightsql_ctx()
179 .get_sql_info_flight_info(info)
180 .await?;
181 let streams = self
182 .app_execution
183 .flightsql_ctx()
184 .do_get(flight_info)
185 .await?;
186 let flight_batch_stream = stream::select_all(streams);
187 self.print_any_stream(flight_batch_stream).await;
188 Ok(())
189 }
190 FlightSqlCommand::GetXdbcTypeInfo { data_type } => {
191 let flight_info = self
192 .app_execution
193 .flightsql_ctx()
194 .get_xdbc_type_info_flight_info(data_type)
195 .await?;
196 let streams = self
197 .app_execution
198 .flightsql_ctx()
199 .do_get(flight_info)
200 .await?;
201 let flight_batch_stream = stream::select_all(streams);
202 self.print_any_stream(flight_batch_stream).await;
203 Ok(())
204 }
205 }
206 }
207
208 pub async fn execute_files_or_commands(&self) -> color_eyre::Result<()> {
212 if self.args.run_ddl {
213 self.app_execution.execution_ctx().execute_ddl().await;
214 }
215
216 self.validate_args()?;
217
218 #[cfg(feature = "flightsql")]
219 if let Some(Command::FlightSql { command }) = &self.args.command {
220 return self.handle_flightsql_command(command.clone()).await;
221 };
222
223 #[cfg(not(feature = "flightsql"))]
224 match (
225 self.args.files.is_empty(),
226 self.args.commands.is_empty(),
227 self.args.flightsql,
228 self.args.bench,
229 self.args.analyze,
230 ) {
231 (_, _, true, _, _) => Err(eyre!(
233 "FLightSQL feature isn't enabled. Reinstall `dft` with `--features=flightsql`"
234 )),
235 (false, false, false, true, _) => {
236 Err(eyre!("Cannot benchmark without a command or file"))
237 }
238 (true, true, _, _, _) => Err(eyre!("No files or commands provided to execute")),
239 (false, false, _, false, _) => Err(eyre!(
240 "Cannot execute both files and commands at the same time"
241 )),
242 (_, _, false, true, true) => Err(eyre!(
243 "The `benchmark` and `analyze` flags are mutually exclusive"
244 )),
245
246 (false, true, _, false, false) => self.execute_files(&self.args.files).await,
248 (true, false, _, false, false) => self.execute_commands(&self.args.commands).await,
249
250 (false, true, _, true, false) => self.benchmark_files(&self.args.files).await,
252 (true, false, _, true, false) => self.benchmark_commands(&self.args.commands).await,
253
254 (false, true, _, false, true) => self.analyze_files(&self.args.files).await,
256 (true, false, _, false, true) => self.analyze_commands(&self.args.commands).await,
257 }
258 #[cfg(feature = "flightsql")]
259 match (
260 self.args.files.is_empty(),
261 self.args.commands.is_empty(),
262 self.args.flightsql,
263 self.args.bench,
264 self.args.analyze,
265 ) {
266 (true, true, _, _, _) => Err(eyre!("No files or commands provided to execute")),
268 (false, false, false, true, _) => {
269 Err(eyre!("Cannot benchmark without a command or file"))
270 }
271 (false, false, _, _, _) => Err(eyre!(
272 "Cannot execute both files and commands at the same time"
273 )),
274 (_, _, _, true, true) => Err(eyre!(
275 "The `benchmark` and `analyze` flags are mutually exclusive"
276 )),
277 (_, _, true, false, true) => Err(eyre!(
278 "The `analyze` flag is not currently supported with FlightSQL"
279 )),
280
281 (true, false, false, false, false) => self.execute_commands(&self.args.commands).await,
283 (false, true, false, false, false) => self.execute_files(&self.args.files).await,
284
285 (false, true, true, false, false) => {
287 self.flightsql_execute_files(&self.args.files).await
288 }
289 (true, false, true, false, false) => {
290 self.flightsql_execute_commands(&self.args.commands).await
291 }
292
293 (false, true, false, true, false) => self.benchmark_files(&self.args.files).await,
295 (false, true, true, true, false) => {
296 self.flightsql_benchmark_files(&self.args.files).await
297 }
298 (true, false, true, true, false) => {
299 self.flightsql_benchmark_commands(&self.args.commands).await
300 }
301 (true, false, false, true, false) => self.benchmark_commands(&self.args.commands).await,
302
303 (true, false, false, false, true) => self.analyze_commands(&self.args.commands).await,
305 (false, true, false, false, true) => self.analyze_files(&self.args.files).await,
306 }
307 }
308
309 async fn execute_files(&self, files: &[PathBuf]) -> Result<()> {
310 info!("Executing files: {:?}", files);
311 for file in files {
312 self.exec_from_file(file).await?
313 }
314
315 Ok(())
316 }
317
318 async fn benchmark_files(&self, files: &[PathBuf]) -> Result<()> {
319 if let Some(run_before_query) = &self.args.run_before {
320 self.app_execution
321 .execution_ctx()
322 .execute_sql_and_discard_results(run_before_query)
323 .await?;
324 }
325 info!("Benchmarking files: {:?}", files);
326 for file in files {
327 let query = std::fs::read_to_string(file)?;
328 let stats = self.benchmark_from_string(&query).await?;
329 println!("{}", stats);
330 }
331 Ok(())
332 }
333
334 async fn analyze_files(&self, files: &[PathBuf]) -> Result<()> {
335 info!("Analyzing files: {:?}", files);
336 for file in files {
337 let query = std::fs::read_to_string(file)?;
338 self.analyze_from_string(&query).await?;
339 }
340 Ok(())
341 }
342
343 #[cfg(feature = "flightsql")]
344 async fn flightsql_execute_files(&self, files: &[PathBuf]) -> color_eyre::Result<()> {
345 info!("Executing FlightSQL files: {:?}", files);
346 for (i, file) in files.iter().enumerate() {
347 let file = std::fs::read_to_string(file)?;
348 self.exec_from_flightsql(file, i).await?;
349 }
350
351 Ok(())
352 }
353
354 #[cfg(feature = "flightsql")]
355 async fn flightsql_benchmark_files(&self, files: &[PathBuf]) -> Result<()> {
356 info!("Benchmarking FlightSQL files: {:?}", files);
357
358 let mut open_opts = std::fs::OpenOptions::new();
359 let mut results_file = if let Some(p) = &self.args.save {
360 if !p.exists() {
361 if let Some(parent) = p.parent() {
362 std::fs::DirBuilder::new().recursive(true).create(parent)?;
363 }
364 };
365 if self.args.append && p.exists() {
366 open_opts.append(true).create(true);
367 Some(open_opts.open(p)?)
368 } else {
369 open_opts.write(true).create(true).truncate(true);
370 let mut file = open_opts.open(p)?;
371 writeln!(file, "{}", FLIGHTSQL_BENCHMARK_HEADER_ROW)?;
372 Some(file)
373 }
374 } else {
375 None
376 };
377
378 for file in files {
379 let query = std::fs::read_to_string(file)?;
380 let stats = self.flightsql_benchmark_from_string(&query).await?;
381 println!("{}", stats);
382 if let Some(ref mut results_file) = &mut results_file {
383 writeln!(results_file, "{}", stats.to_summary_csv_row())?
384 }
385 }
386
387 Ok(())
388 }
389
390 #[cfg(feature = "flightsql")]
391 async fn exec_from_flightsql(&self, sql: String, i: usize) -> color_eyre::Result<()> {
392 let client = self.app_execution.flightsql_client();
393 let mut guard = client.lock().await;
394 if let Some(client) = guard.as_mut() {
395 let start = if self.args.time {
396 Some(std::time::Instant::now())
397 } else {
398 None
399 };
400 let flight_info = client.execute(sql, None).await?;
401 for endpoint in flight_info.endpoint {
402 if let Some(ticket) = endpoint.ticket {
403 let stream = client.do_get(ticket.into_request()).await?;
404 if let Some(output_path) = &self.args.output {
405 self.output_stream(stream, output_path).await?
406 } else if let Some(start) = start {
407 self.exec_stream(stream).await;
408 let elapsed = start.elapsed();
409 println!("Query {i} executed in {:?}", elapsed);
410 } else {
411 self.print_any_stream(stream).await;
412 }
413 }
414 }
415 } else {
416 println!("No FlightSQL client configured. Add one in `~/.config/dft/config.toml`");
417 }
418
419 Ok(())
420 }
421
422 async fn execute_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
423 info!("Executing commands: {:?}", commands);
424 if let Some(run_before_query) = &self.args.run_before {
425 self.app_execution
426 .execution_ctx()
427 .execute_sql_and_discard_results(run_before_query)
428 .await?;
429 }
430
431 for command in commands {
432 self.exec_from_string(command).await?
433 }
434
435 Ok(())
436 }
437
438 async fn benchmark_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
439 if let Some(run_before_query) = &self.args.run_before {
440 self.app_execution
441 .execution_ctx()
442 .execute_sql_and_discard_results(run_before_query)
443 .await?;
444 }
445 info!("Benchmarking commands: {:?}", commands);
446 let mut open_opts = std::fs::OpenOptions::new();
447 let mut file = if let Some(p) = &self.args.save {
448 if !p.exists() {
449 if let Some(parent) = p.parent() {
450 std::fs::DirBuilder::new().recursive(true).create(parent)?;
451 }
452 };
453 if self.args.append && p.exists() {
454 open_opts.append(true).create(true);
455 Some(open_opts.open(p)?)
456 } else {
457 open_opts.write(true).create(true).truncate(true);
458 let mut file = open_opts.open(p)?;
459 writeln!(file, "{}", LOCAL_BENCHMARK_HEADER_ROW)?;
460 Some(file)
461 }
462 } else {
463 None
464 };
465
466 for command in commands {
467 let stats = self.benchmark_from_string(command).await?;
468 println!("{}", stats);
469 if let Some(ref mut file) = &mut file {
470 writeln!(file, "{}", stats.to_summary_csv_row())?;
471 }
472 }
473 Ok(())
474 }
475
476 async fn analyze_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
477 info!("Analyzing commands: {:?}", commands);
478 for command in commands {
479 self.analyze_from_string(command).await?;
480 }
481
482 Ok(())
483 }
484
485 #[cfg(feature = "flightsql")]
486 async fn flightsql_execute_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
487 info!("Executing FlightSQL commands: {:?}", commands);
488 for (i, command) in commands.iter().enumerate() {
489 self.exec_from_flightsql(command.to_string(), i).await?
490 }
491
492 Ok(())
493 }
494
495 #[cfg(feature = "flightsql")]
496 async fn flightsql_benchmark_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
497 info!("Benchmark FlightSQL commands: {:?}", commands);
498
499 let mut open_opts = std::fs::OpenOptions::new();
500 let mut file = if let Some(p) = &self.args.save {
501 if !p.exists() {
502 if let Some(parent) = p.parent() {
503 std::fs::DirBuilder::new().recursive(true).create(parent)?;
504 }
505 };
506 if self.args.append && p.exists() {
507 open_opts.append(true).create(true);
508 Some(open_opts.open(p)?)
509 } else {
510 open_opts.write(true).create(true).truncate(true);
511 let mut file = open_opts.open(p)?;
512 writeln!(file, "{}", FLIGHTSQL_BENCHMARK_HEADER_ROW)?;
513 Some(file)
514 }
515 } else {
516 None
517 };
518
519 for command in commands {
520 let stats = self.flightsql_benchmark_from_string(command).await?;
521 println!("{}", stats);
522 if let Some(ref mut file) = &mut file {
523 writeln!(file, "{}", stats.to_summary_csv_row())?
524 }
525 }
526
527 Ok(())
528 }
529
530 async fn exec_from_string(&self, sql: &str) -> Result<()> {
531 let dialect = datafusion::sql::sqlparser::dialect::GenericDialect {};
532 let statements = DFParser::parse_sql_with_dialect(sql, &dialect)?;
533 let start = if self.args.time {
534 Some(std::time::Instant::now())
535 } else {
536 None
537 };
538 for (i, statement) in statements.into_iter().enumerate() {
539 let stream = self
540 .app_execution
541 .execution_ctx()
542 .execute_statement(statement)
543 .await?;
544 if let Some(output_path) = &self.args.output {
545 self.output_stream(stream, output_path).await?;
546 } else if let Some(start) = start {
547 self.exec_stream(stream).await;
548 let elapsed = start.elapsed();
549 println!("Query {i} executed in {:?}", elapsed);
550 } else {
551 self.print_any_stream(stream).await;
552 }
553 }
554 Ok(())
555 }
556
557 async fn benchmark_from_string(&self, sql: &str) -> Result<LocalBenchmarkStats> {
558 use std::sync::Arc;
559
560 let iterations = self.args.benchmark_iterations.unwrap_or(
562 self.app_execution
563 .execution_ctx()
564 .config()
565 .benchmark_iterations,
566 );
567 let concurrency = if self.args.concurrent {
568 let parallelism = std::thread::available_parallelism()
569 .map(|n| n.get())
570 .unwrap_or(1);
571 std::cmp::min(iterations, parallelism)
572 } else {
573 1
574 };
575
576 let progress_reporter = Some(Arc::new(progress::IndicatifProgressReporter::new(
578 sql,
579 iterations,
580 self.args.concurrent,
581 concurrency,
582 ))
583 as Arc<dyn datafusion_app::local_benchmarks::BenchmarkProgressReporter>);
584
585 let stats = self
587 .app_execution
588 .execution_ctx()
589 .benchmark_query(
590 sql,
591 self.args.benchmark_iterations,
592 self.args.concurrent,
593 progress_reporter,
594 )
595 .await?;
596 Ok(stats)
597 }
598
599 async fn analyze_from_string(&self, sql: &str) -> Result<()> {
600 let mut stats = self
601 .app_execution
602 .execution_ctx()
603 .analyze_query(sql)
604 .await?;
605 stats.collect_stats();
606 println!("{}", stats);
607 Ok(())
608 }
609
610 #[cfg(feature = "flightsql")]
611 async fn flightsql_benchmark_from_string(&self, sql: &str) -> Result<FlightSQLBenchmarkStats> {
612 use std::sync::Arc;
613
614 let iterations = self.args.benchmark_iterations.unwrap_or(10);
617 let concurrency = if self.args.concurrent {
618 let parallelism = std::thread::available_parallelism()
619 .map(|n| n.get())
620 .unwrap_or(1);
621 std::cmp::min(iterations, parallelism)
622 } else {
623 1
624 };
625
626 let progress_reporter = Some(Arc::new(progress::IndicatifProgressReporter::new(
628 sql,
629 iterations,
630 self.args.concurrent,
631 concurrency,
632 ))
633 as Arc<dyn datafusion_app::local_benchmarks::BenchmarkProgressReporter>);
634
635 let stats = self
637 .app_execution
638 .flightsql_ctx()
639 .benchmark_query(
640 sql,
641 self.args.benchmark_iterations,
642 self.args.concurrent,
643 progress_reporter,
644 )
645 .await?;
646 Ok(stats)
647 }
648
649 pub async fn exec_from_file(&self, file: &Path) -> color_eyre::Result<()> {
652 let string = std::fs::read_to_string(file)?;
653
654 self.exec_from_string(&string).await?;
655
656 Ok(())
657 }
658
659 pub async fn execute_and_print_sql(&self, sql: &str) -> color_eyre::Result<()> {
661 let stream = self.app_execution.execution_ctx().execute_sql(sql).await?;
662 self.print_any_stream(stream).await;
663 Ok(())
664 }
665
666 async fn exec_stream<S, E>(&self, mut stream: S)
667 where
668 S: Stream<Item = Result<RecordBatch, E>> + Unpin,
669 E: Error,
670 {
671 while let Some(maybe_batch) = stream.next().await {
672 match maybe_batch {
673 Ok(_) => {}
674 Err(e) => {
675 println!("Error executing SQL: {e}");
676 break;
677 }
678 }
679 }
680 }
681
682 async fn print_any_stream<S, E>(&self, mut stream: S)
683 where
684 S: Stream<Item = Result<RecordBatch, E>> + Unpin,
685 E: Error,
686 {
687 while let Some(maybe_batch) = stream.next().await {
688 match maybe_batch {
689 Ok(batch) => match pretty_format_batches(&[batch]) {
690 Ok(d) => println!("{}", d),
691 Err(e) => println!("Error formatting batch: {e}"),
692 },
693 Err(e) => println!("Error executing SQL: {e}"),
694 }
695 }
696 }
697
698 async fn output_stream<S, E>(&self, mut stream: S, path: &Path) -> Result<()>
699 where
700 S: Stream<Item = Result<RecordBatch, E>> + Unpin,
701 E: Error,
702 {
703 if let Some(Ok(first_batch)) = stream.next().await {
705 let schema = first_batch.schema();
706 let mut writer = path_to_writer(path, schema)?;
707 writer.write(&first_batch)?;
708
709 while let Some(maybe_batch) = stream.next().await {
710 match maybe_batch {
711 Ok(batch) => writer.write(&batch)?,
712 Err(e) => return Err(eyre!("Error executing SQL: {e}")),
713 }
714 }
715 writer.close().await?;
716 }
717
718 Ok(())
719 }
720}
721
/// Adapter that buffers record batches in memory and writes them out as a
/// single Vortex file when closed.
#[cfg(feature = "vortex")]
struct VortexFileWriter {
    // Destination path; the file is (re-)created asynchronously at close time.
    path: PathBuf,
    // Batches accumulated via `write`, flushed together in `close`.
    batches: Vec<RecordBatch>,
}
728
#[cfg(feature = "vortex")]
impl VortexFileWriter {
    /// Create a writer targeting `path`.
    ///
    /// Vortex performs its own async I/O, so the pre-opened synchronous
    /// handle is dropped immediately and the file is re-created from `path`
    /// when `close` flushes the buffered batches.
    fn new(file: File, _schema: SchemaRef, path: &Path) -> Result<Self> {
        drop(file);
        let path = path.to_path_buf();
        Ok(Self {
            path,
            batches: Vec::new(),
        })
    }

    /// Buffer one batch in memory; no bytes reach disk until `close`.
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        let owned = batch.clone();
        self.batches.push(owned);
        Ok(())
    }

    /// Write all buffered batches to `self.path` as one Vortex file.
    /// A writer that never received a batch is a no-op (no file is created).
    async fn close(self) -> Result<()> {
        if self.batches.is_empty() {
            return Ok(());
        }

        let out_file = tokio::fs::File::create(&self.path).await?;

        // Vortex writes a single array, so merge every buffered batch first.
        let schema = self.batches[0].schema();
        let merged = datafusion::arrow::compute::concat_batches(&schema, &self.batches)?;

        let vortex_array = ArrayRef::from_arrow(merged, false);
        let array_stream = vortex_array.to_array_stream();

        let session = VortexSession::empty();
        VortexWriteOptions::new(session)
            .write(out_file, array_stream)
            .await
            .map_err(|err| eyre!("Failed to write Vortex file: {}", err))?;

        Ok(())
    }
}
774
/// Output sink for query results, selected from the `--output` path's file
/// extension in `path_to_writer`.
#[allow(clippy::large_enum_variant)]
enum AnyWriter {
    // Arrow CSV writer.
    Csv(csv::writer::Writer<File>),
    // Newline-delimited JSON writer.
    Json(json::writer::LineDelimitedWriter<File>),
    // Parquet writer.
    Parquet(ArrowWriter<File>),
    // Vortex writer: buffers batches and writes everything on close.
    #[cfg(feature = "vortex")]
    Vortex(VortexFileWriter),
}
786
787impl AnyWriter {
788 fn write(&mut self, batch: &RecordBatch) -> Result<()> {
789 match self {
790 AnyWriter::Csv(w) => Ok(w.write(batch)?),
791 AnyWriter::Json(w) => Ok(w.write(batch)?),
792 AnyWriter::Parquet(w) => Ok(w.write(batch)?),
793 #[cfg(feature = "vortex")]
794 AnyWriter::Vortex(w) => Ok(w.write(batch)?),
795 }
796 }
797
798 async fn close(self) -> Result<()> {
799 match self {
800 AnyWriter::Csv(w) => Ok(w.close()?),
801 AnyWriter::Json(w) => Ok(w.close()?),
802 AnyWriter::Parquet(w) => {
803 w.close()?;
804 Ok(())
805 }
806 #[cfg(feature = "vortex")]
807 AnyWriter::Vortex(w) => w.close().await,
808 }
809 }
810}
811
812fn path_to_writer(path: &Path, schema: SchemaRef) -> Result<AnyWriter> {
813 if let Some(extension) = path.extension() {
814 if let Some(e) = extension.to_ascii_lowercase().to_str() {
815 let file = std::fs::File::create(path)?;
816 return match e {
817 "csv" => Ok(AnyWriter::Csv(csv::writer::Writer::new(file))),
818 "json" => Ok(AnyWriter::Json(json::writer::LineDelimitedWriter::new(
819 file,
820 ))),
821 "parquet" => {
822 let props = WriterProperties::default();
823 let writer = ArrowWriter::try_new(file, schema, Some(props))?;
824 Ok(AnyWriter::Parquet(writer))
825 }
826 #[cfg(feature = "vortex")]
827 "vortex" => Ok(AnyWriter::Vortex(VortexFileWriter::new(
828 file, schema, path,
829 )?)),
830 _ => {
831 #[cfg(feature = "vortex")]
832 return Err(eyre!(
833 "Only 'csv', 'parquet', 'json', and 'vortex' file types can be output"
834 ));
835 #[cfg(not(feature = "vortex"))]
836 return Err(eyre!(
837 "Only 'csv', 'parquet', and 'json' file types can be output"
838 ));
839 }
840 };
841 }
842 }
843 Err(eyre!("Unable to parse extension"))
844}
845
/// Build the execution environment from CLI args + config, optionally attach a
/// FlightSQL client, register configured databases, and run the CLI app.
pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> {
    // CLI-specific execution settings are layered over the shared config.
    let merged_exec_config = merge_configs(config.shared.clone(), config.cli.execution.clone());
    let session_state_builder = DftSessionStateBuilder::try_new(Some(merged_exec_config.clone()))?
        .with_extensions()
        .await?;

    let session_state = session_state_builder.build()?;
    let execution_ctx = ExecutionContext::try_new(
        &merged_exec_config,
        session_state,
        crate::APP_NAME,
        env!("CARGO_PKG_VERSION"),
    )?;
    // `mut` is only needed when the flightsql feature attaches a context below.
    #[allow(unused_mut)]
    let mut app_execution = AppExecution::new(execution_ctx);
    #[cfg(feature = "flightsql")]
    {
        // Only set up a FlightSQL client when the user asked for it via the
        // `--flightsql` flag or a `flightsql` subcommand.
        if cli.flightsql || matches!(cli.command, Some(Command::FlightSql { .. })) {
            let auth = AuthConfig {
                basic_auth: config.flightsql_client.auth.basic_auth,
                bearer_token: config.flightsql_client.auth.bearer_token,
            };
            let flightsql_cfg = FlightSQLConfig::new(
                config.flightsql_client.connection_url,
                config.flightsql_client.benchmark_iterations,
                auth,
                config.flightsql_client.headers.clone(),
            );
            let flightsql_ctx = FlightSQLContext::new(flightsql_cfg);

            // Header precedence (later extends/overrides earlier): configured
            // headers, then the headers file (CLI path preferred over the
            // configured path), then individual `--header` flags.
            let mut all_headers = config.flightsql_client.headers.clone();

            let headers_file = cli
                .headers_file
                .as_ref()
                .or(config.flightsql_client.headers_file.as_ref());

            if let Some(file_path) = headers_file {
                match parse_headers_file(file_path) {
                    Ok(file_headers) => {
                        all_headers.extend(file_headers);
                    }
                    Err(e) => {
                        return Err(eyre!("Error reading headers file: {}", e));
                    }
                }
            }

            if let Some(cli_headers) = &cli.header {
                all_headers.extend(cli_headers.iter().cloned());
            }

            // Pass headers along only if any were collected.
            let headers = if all_headers.is_empty() {
                None
            } else {
                Some(all_headers)
            };

            flightsql_ctx
                .create_client(cli.host.clone(), headers)
                .await?;
            app_execution.with_flightsql_ctx(flightsql_ctx);
        }
    }
    register_db(app_execution.session_ctx(), &config.db).await?;
    let app = CliApp::new(app_execution, cli.clone());
    app.execute_files_or_commands().await?;
    Ok(())
}