Skip to main content

datafusion_dft/cli/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17//! [`CliApp`]: Command Line User Interface
18
19mod progress;
20
21use crate::config::AppConfig;
22use crate::db::register_db;
23use crate::{args::DftArgs, execution::AppExecution};
24use color_eyre::eyre::eyre;
25use color_eyre::Result;
26use datafusion::arrow::array::{RecordBatch, RecordBatchWriter};
27use datafusion::arrow::datatypes::SchemaRef;
28use datafusion::arrow::util::pretty::pretty_format_batches;
29use datafusion::arrow::{csv, json};
30use datafusion::sql::parser::DFParser;
31use datafusion_app::config::merge_configs;
32use datafusion_app::extensions::DftSessionStateBuilder;
33use datafusion_app::local::ExecutionContext;
34use datafusion_app::local_benchmarks::LocalBenchmarkStats;
35use futures::{Stream, StreamExt};
36use log::info;
37use parquet::{arrow::ArrowWriter, file::properties::WriterProperties};
38use std::error::Error;
39use std::fs::File;
40use std::io::Write;
41use std::path::{Path, PathBuf};
42#[cfg(feature = "flightsql")]
43use {
44    crate::args::{parse_headers_file, Command, FlightSqlCommand},
45    datafusion_app::{
46        config::{AuthConfig, FlightSQLConfig},
47        flightsql::FlightSQLContext,
48        flightsql_benchmarks::FlightSQLBenchmarkStats,
49    },
50    tonic::IntoRequest,
51};
52#[cfg(feature = "vortex")]
53use {
54    vortex::array::{arrow::FromArrowArray, ArrayRef},
55    vortex_file::VortexWriteOptions,
56    vortex_session::VortexSession,
57};
58
/// CSV header written as the first line of a local benchmark `save` file.
///
/// Layout: `query`, `runs`, then `min`/`max`/`mean`/`median`/`percent_of_total`
/// for each of `logical_planning`, `physical_planning`, `execution` and `total`,
/// and finally `concurrency_mode` (23 columns).
/// NOTE(review): must stay in sync with `LocalBenchmarkStats::to_summary_csv_row`
/// (defined outside this file) — confirm column order there.
const LOCAL_BENCHMARK_HEADER_ROW: &str =
    "query,runs,logical_planning_min,logical_planning_max,logical_planning_mean,logical_planning_median,logical_planning_percent_of_total,physical_planning_min,physical_planning_max,physical_planning_mean,physical_planning_median,physical_planning_percent_of_total,execution_min,execution_max,execution_mean,execution_median,execution_percent_of_total,total_min,total_max,total_mean,total_median,total_percent_of_total,concurrency_mode";

/// CSV header written as the first line of a FlightSQL benchmark `save` file.
///
/// Same layout as the local header, with the measured phases being
/// `get_flight_info`, `ttfb`, `do_get` and `total` (23 columns).
/// NOTE(review): must stay in sync with `FlightSQLBenchmarkStats::to_summary_csv_row`.
#[cfg(feature = "flightsql")]
const FLIGHTSQL_BENCHMARK_HEADER_ROW: &str =
    "query,runs,get_flight_info_min,get_flight_info_max,get_flight_info_mean,get_flight_info_median,get_flight_info_percent_of_total,ttfb_min,ttfb_max,ttfb_mean,ttfb_median,ttfb_percent_of_total,do_get_min,do_get_max,do_get_mean,do_get_median,do_get_percent_of_total,total_min,total_max,total_mean,total_median,total_percent_of_total,concurrency_mode";
65
66/// Encapsulates the command line interface
67pub struct CliApp {
68    /// Execution context for running queries
69    app_execution: AppExecution,
70    args: DftArgs,
71}
72
73impl CliApp {
74    pub fn new(app_execution: AppExecution, args: DftArgs) -> Self {
75        Self {
76            app_execution,
77            args,
78        }
79    }
80
81    fn validate_args(&self) -> color_eyre::Result<()> {
82        let more_than_one_command_or_file = (self.args.commands.len() > 1
83            || self.args.files.len() > 1)
84            && self.args.output.is_some();
85        if more_than_one_command_or_file {
86            return Err(eyre!(
87                "Output can only be saved for a single file or command"
88            ));
89        }
90
91        Ok(())
92    }
93
94    #[cfg(feature = "flightsql")]
95    async fn handle_flightsql_command(&self, command: FlightSqlCommand) -> color_eyre::Result<()> {
96        use futures::stream;
97
98        match command {
99            FlightSqlCommand::StatementQuery { sql } => self.exec_from_flightsql(sql, 0).await,
100            FlightSqlCommand::GetCatalogs => {
101                let flight_info = self
102                    .app_execution
103                    .flightsql_ctx()
104                    .get_catalogs_flight_info()
105                    .await?;
106                let streams = self
107                    .app_execution
108                    .flightsql_ctx()
109                    .do_get(flight_info)
110                    .await?;
111                let flight_batch_stream = stream::select_all(streams);
112                self.print_any_stream(flight_batch_stream).await;
113                Ok(())
114            }
115            FlightSqlCommand::GetDbSchemas {
116                catalog,
117                db_schema_filter_pattern,
118            } => {
119                let flight_info = self
120                    .app_execution
121                    .flightsql_ctx()
122                    .get_db_schemas_flight_info(catalog, db_schema_filter_pattern)
123                    .await?;
124                let streams = self
125                    .app_execution
126                    .flightsql_ctx()
127                    .do_get(flight_info)
128                    .await?;
129                let flight_batch_stream = stream::select_all(streams);
130                self.print_any_stream(flight_batch_stream).await;
131                Ok(())
132            }
133
134            FlightSqlCommand::GetTables {
135                catalog,
136                db_schema_filter_pattern,
137                table_name_filter_pattern,
138                table_types,
139            } => {
140                let flight_info = self
141                    .app_execution
142                    .flightsql_ctx()
143                    .get_tables_flight_info(
144                        catalog,
145                        db_schema_filter_pattern,
146                        table_name_filter_pattern,
147                        table_types.unwrap_or_default(),
148                        false,
149                    )
150                    .await?;
151                let streams = self
152                    .app_execution
153                    .flightsql_ctx()
154                    .do_get(flight_info)
155                    .await?;
156                let flight_batch_stream = stream::select_all(streams);
157                self.print_any_stream(flight_batch_stream).await;
158                Ok(())
159            }
160            FlightSqlCommand::GetTableTypes => {
161                let flight_info = self
162                    .app_execution
163                    .flightsql_ctx()
164                    .get_table_types_flight_info()
165                    .await?;
166                let streams = self
167                    .app_execution
168                    .flightsql_ctx()
169                    .do_get(flight_info)
170                    .await?;
171                let flight_batch_stream = stream::select_all(streams);
172                self.print_any_stream(flight_batch_stream).await;
173                Ok(())
174            }
175            FlightSqlCommand::GetSqlInfo { info } => {
176                let flight_info = self
177                    .app_execution
178                    .flightsql_ctx()
179                    .get_sql_info_flight_info(info)
180                    .await?;
181                let streams = self
182                    .app_execution
183                    .flightsql_ctx()
184                    .do_get(flight_info)
185                    .await?;
186                let flight_batch_stream = stream::select_all(streams);
187                self.print_any_stream(flight_batch_stream).await;
188                Ok(())
189            }
190            FlightSqlCommand::GetXdbcTypeInfo { data_type } => {
191                let flight_info = self
192                    .app_execution
193                    .flightsql_ctx()
194                    .get_xdbc_type_info_flight_info(data_type)
195                    .await?;
196                let streams = self
197                    .app_execution
198                    .flightsql_ctx()
199                    .do_get(flight_info)
200                    .await?;
201                let flight_batch_stream = stream::select_all(streams);
202                self.print_any_stream(flight_batch_stream).await;
203                Ok(())
204            }
205        }
206    }
207
208    /// Execute the provided sql, which was passed as an argument from CLI.
209    ///
210    /// Optionally, use the FlightSQL client for execution.
211    pub async fn execute_files_or_commands(&self) -> color_eyre::Result<()> {
212        if self.args.run_ddl {
213            self.app_execution.execution_ctx().execute_ddl().await;
214        }
215
216        self.validate_args()?;
217
218        #[cfg(feature = "flightsql")]
219        if let Some(Command::FlightSql { command }) = &self.args.command {
220            return self.handle_flightsql_command(command.clone()).await;
221        };
222
223        #[cfg(not(feature = "flightsql"))]
224        match (
225            self.args.files.is_empty(),
226            self.args.commands.is_empty(),
227            self.args.flightsql,
228            self.args.bench,
229            self.args.analyze,
230        ) {
231            // Error cases
232            (_, _, true, _, _) => Err(eyre!(
233                "FLightSQL feature isn't enabled. Reinstall `dft` with `--features=flightsql`"
234            )),
235            (false, false, false, true, _) => {
236                Err(eyre!("Cannot benchmark without a command or file"))
237            }
238            (true, true, _, _, _) => Err(eyre!("No files or commands provided to execute")),
239            (false, false, _, false, _) => Err(eyre!(
240                "Cannot execute both files and commands at the same time"
241            )),
242            (_, _, false, true, true) => Err(eyre!(
243                "The `benchmark` and `analyze` flags are mutually exclusive"
244            )),
245
246            // Execution cases
247            (false, true, _, false, false) => self.execute_files(&self.args.files).await,
248            (true, false, _, false, false) => self.execute_commands(&self.args.commands).await,
249
250            // Benchmark cases
251            (false, true, _, true, false) => self.benchmark_files(&self.args.files).await,
252            (true, false, _, true, false) => self.benchmark_commands(&self.args.commands).await,
253
254            // Analyze cases
255            (false, true, _, false, true) => self.analyze_files(&self.args.files).await,
256            (true, false, _, false, true) => self.analyze_commands(&self.args.commands).await,
257        }
258        #[cfg(feature = "flightsql")]
259        match (
260            self.args.files.is_empty(),
261            self.args.commands.is_empty(),
262            self.args.flightsql,
263            self.args.bench,
264            self.args.analyze,
265        ) {
266            // Error cases
267            (true, true, _, _, _) => Err(eyre!("No files or commands provided to execute")),
268            (false, false, false, true, _) => {
269                Err(eyre!("Cannot benchmark without a command or file"))
270            }
271            (false, false, _, _, _) => Err(eyre!(
272                "Cannot execute both files and commands at the same time"
273            )),
274            (_, _, _, true, true) => Err(eyre!(
275                "The `benchmark` and `analyze` flags are mutually exclusive"
276            )),
277            (_, _, true, false, true) => Err(eyre!(
278                "The `analyze` flag is not currently supported with FlightSQL"
279            )),
280
281            // Execution cases
282            (true, false, false, false, false) => self.execute_commands(&self.args.commands).await,
283            (false, true, false, false, false) => self.execute_files(&self.args.files).await,
284
285            // FlightSQL execution cases
286            (false, true, true, false, false) => {
287                self.flightsql_execute_files(&self.args.files).await
288            }
289            (true, false, true, false, false) => {
290                self.flightsql_execute_commands(&self.args.commands).await
291            }
292
293            // Benchmark cases
294            (false, true, false, true, false) => self.benchmark_files(&self.args.files).await,
295            (false, true, true, true, false) => {
296                self.flightsql_benchmark_files(&self.args.files).await
297            }
298            (true, false, true, true, false) => {
299                self.flightsql_benchmark_commands(&self.args.commands).await
300            }
301            (true, false, false, true, false) => self.benchmark_commands(&self.args.commands).await,
302
303            // Analyze cases
304            (true, false, false, false, true) => self.analyze_commands(&self.args.commands).await,
305            (false, true, false, false, true) => self.analyze_files(&self.args.files).await,
306        }
307    }
308
309    async fn execute_files(&self, files: &[PathBuf]) -> Result<()> {
310        info!("Executing files: {:?}", files);
311        for file in files {
312            self.exec_from_file(file).await?
313        }
314
315        Ok(())
316    }
317
318    async fn benchmark_files(&self, files: &[PathBuf]) -> Result<()> {
319        if let Some(run_before_query) = &self.args.run_before {
320            self.app_execution
321                .execution_ctx()
322                .execute_sql_and_discard_results(run_before_query)
323                .await?;
324        }
325        info!("Benchmarking files: {:?}", files);
326        for file in files {
327            let query = std::fs::read_to_string(file)?;
328            let stats = self.benchmark_from_string(&query).await?;
329            println!("{}", stats);
330        }
331        Ok(())
332    }
333
334    async fn analyze_files(&self, files: &[PathBuf]) -> Result<()> {
335        info!("Analyzing files: {:?}", files);
336        for file in files {
337            let query = std::fs::read_to_string(file)?;
338            self.analyze_from_string(&query).await?;
339        }
340        Ok(())
341    }
342
343    #[cfg(feature = "flightsql")]
344    async fn flightsql_execute_files(&self, files: &[PathBuf]) -> color_eyre::Result<()> {
345        info!("Executing FlightSQL files: {:?}", files);
346        for (i, file) in files.iter().enumerate() {
347            let file = std::fs::read_to_string(file)?;
348            self.exec_from_flightsql(file, i).await?;
349        }
350
351        Ok(())
352    }
353
354    #[cfg(feature = "flightsql")]
355    async fn flightsql_benchmark_files(&self, files: &[PathBuf]) -> Result<()> {
356        info!("Benchmarking FlightSQL files: {:?}", files);
357
358        let mut open_opts = std::fs::OpenOptions::new();
359        let mut results_file = if let Some(p) = &self.args.save {
360            if !p.exists() {
361                if let Some(parent) = p.parent() {
362                    std::fs::DirBuilder::new().recursive(true).create(parent)?;
363                }
364            };
365            if self.args.append && p.exists() {
366                open_opts.append(true).create(true);
367                Some(open_opts.open(p)?)
368            } else {
369                open_opts.write(true).create(true).truncate(true);
370                let mut file = open_opts.open(p)?;
371                writeln!(file, "{}", FLIGHTSQL_BENCHMARK_HEADER_ROW)?;
372                Some(file)
373            }
374        } else {
375            None
376        };
377
378        for file in files {
379            let query = std::fs::read_to_string(file)?;
380            let stats = self.flightsql_benchmark_from_string(&query).await?;
381            println!("{}", stats);
382            if let Some(ref mut results_file) = &mut results_file {
383                writeln!(results_file, "{}", stats.to_summary_csv_row())?
384            }
385        }
386
387        Ok(())
388    }
389
390    #[cfg(feature = "flightsql")]
391    async fn exec_from_flightsql(&self, sql: String, i: usize) -> color_eyre::Result<()> {
392        let client = self.app_execution.flightsql_client();
393        let mut guard = client.lock().await;
394        if let Some(client) = guard.as_mut() {
395            let start = if self.args.time {
396                Some(std::time::Instant::now())
397            } else {
398                None
399            };
400            let flight_info = client.execute(sql, None).await?;
401            for endpoint in flight_info.endpoint {
402                if let Some(ticket) = endpoint.ticket {
403                    let stream = client.do_get(ticket.into_request()).await?;
404                    if let Some(output_path) = &self.args.output {
405                        self.output_stream(stream, output_path).await?
406                    } else if let Some(start) = start {
407                        self.exec_stream(stream).await;
408                        let elapsed = start.elapsed();
409                        println!("Query {i} executed in {:?}", elapsed);
410                    } else {
411                        self.print_any_stream(stream).await;
412                    }
413                }
414            }
415        } else {
416            println!("No FlightSQL client configured.  Add one in `~/.config/dft/config.toml`");
417        }
418
419        Ok(())
420    }
421
422    async fn execute_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
423        info!("Executing commands: {:?}", commands);
424        if let Some(run_before_query) = &self.args.run_before {
425            self.app_execution
426                .execution_ctx()
427                .execute_sql_and_discard_results(run_before_query)
428                .await?;
429        }
430
431        for command in commands {
432            self.exec_from_string(command).await?
433        }
434
435        Ok(())
436    }
437
438    async fn benchmark_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
439        if let Some(run_before_query) = &self.args.run_before {
440            self.app_execution
441                .execution_ctx()
442                .execute_sql_and_discard_results(run_before_query)
443                .await?;
444        }
445        info!("Benchmarking commands: {:?}", commands);
446        let mut open_opts = std::fs::OpenOptions::new();
447        let mut file = if let Some(p) = &self.args.save {
448            if !p.exists() {
449                if let Some(parent) = p.parent() {
450                    std::fs::DirBuilder::new().recursive(true).create(parent)?;
451                }
452            };
453            if self.args.append && p.exists() {
454                open_opts.append(true).create(true);
455                Some(open_opts.open(p)?)
456            } else {
457                open_opts.write(true).create(true).truncate(true);
458                let mut file = open_opts.open(p)?;
459                writeln!(file, "{}", LOCAL_BENCHMARK_HEADER_ROW)?;
460                Some(file)
461            }
462        } else {
463            None
464        };
465
466        for command in commands {
467            let stats = self.benchmark_from_string(command).await?;
468            println!("{}", stats);
469            if let Some(ref mut file) = &mut file {
470                writeln!(file, "{}", stats.to_summary_csv_row())?;
471            }
472        }
473        Ok(())
474    }
475
476    async fn analyze_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
477        info!("Analyzing commands: {:?}", commands);
478        for command in commands {
479            self.analyze_from_string(command).await?;
480        }
481
482        Ok(())
483    }
484
485    #[cfg(feature = "flightsql")]
486    async fn flightsql_execute_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
487        info!("Executing FlightSQL commands: {:?}", commands);
488        for (i, command) in commands.iter().enumerate() {
489            self.exec_from_flightsql(command.to_string(), i).await?
490        }
491
492        Ok(())
493    }
494
495    #[cfg(feature = "flightsql")]
496    async fn flightsql_benchmark_commands(&self, commands: &[String]) -> color_eyre::Result<()> {
497        info!("Benchmark FlightSQL commands: {:?}", commands);
498
499        let mut open_opts = std::fs::OpenOptions::new();
500        let mut file = if let Some(p) = &self.args.save {
501            if !p.exists() {
502                if let Some(parent) = p.parent() {
503                    std::fs::DirBuilder::new().recursive(true).create(parent)?;
504                }
505            };
506            if self.args.append && p.exists() {
507                open_opts.append(true).create(true);
508                Some(open_opts.open(p)?)
509            } else {
510                open_opts.write(true).create(true).truncate(true);
511                let mut file = open_opts.open(p)?;
512                writeln!(file, "{}", FLIGHTSQL_BENCHMARK_HEADER_ROW)?;
513                Some(file)
514            }
515        } else {
516            None
517        };
518
519        for command in commands {
520            let stats = self.flightsql_benchmark_from_string(command).await?;
521            println!("{}", stats);
522            if let Some(ref mut file) = &mut file {
523                writeln!(file, "{}", stats.to_summary_csv_row())?
524            }
525        }
526
527        Ok(())
528    }
529
530    async fn exec_from_string(&self, sql: &str) -> Result<()> {
531        let dialect = datafusion::sql::sqlparser::dialect::GenericDialect {};
532        let statements = DFParser::parse_sql_with_dialect(sql, &dialect)?;
533        let start = if self.args.time {
534            Some(std::time::Instant::now())
535        } else {
536            None
537        };
538        for (i, statement) in statements.into_iter().enumerate() {
539            let stream = self
540                .app_execution
541                .execution_ctx()
542                .execute_statement(statement)
543                .await?;
544            if let Some(output_path) = &self.args.output {
545                self.output_stream(stream, output_path).await?;
546            } else if let Some(start) = start {
547                self.exec_stream(stream).await;
548                let elapsed = start.elapsed();
549                println!("Query {i} executed in {:?}", elapsed);
550            } else {
551                self.print_any_stream(stream).await;
552            }
553        }
554        Ok(())
555    }
556
557    async fn benchmark_from_string(&self, sql: &str) -> Result<LocalBenchmarkStats> {
558        use std::sync::Arc;
559
560        // Calculate iterations and concurrency
561        let iterations = self.args.benchmark_iterations.unwrap_or(
562            self.app_execution
563                .execution_ctx()
564                .config()
565                .benchmark_iterations,
566        );
567        let concurrency = if self.args.concurrent {
568            let parallelism = std::thread::available_parallelism()
569                .map(|n| n.get())
570                .unwrap_or(1);
571            std::cmp::min(iterations, parallelism)
572        } else {
573            1
574        };
575
576        // Create progress reporter
577        let progress_reporter = Some(Arc::new(progress::IndicatifProgressReporter::new(
578            sql,
579            iterations,
580            self.args.concurrent,
581            concurrency,
582        ))
583            as Arc<dyn datafusion_app::local_benchmarks::BenchmarkProgressReporter>);
584
585        // Call benchmark with reporter
586        let stats = self
587            .app_execution
588            .execution_ctx()
589            .benchmark_query(
590                sql,
591                self.args.benchmark_iterations,
592                self.args.concurrent,
593                progress_reporter,
594            )
595            .await?;
596        Ok(stats)
597    }
598
599    async fn analyze_from_string(&self, sql: &str) -> Result<()> {
600        let mut stats = self
601            .app_execution
602            .execution_ctx()
603            .analyze_query(sql)
604            .await?;
605        stats.collect_stats();
606        println!("{}", stats);
607        Ok(())
608    }
609
610    #[cfg(feature = "flightsql")]
611    async fn flightsql_benchmark_from_string(&self, sql: &str) -> Result<FlightSQLBenchmarkStats> {
612        use std::sync::Arc;
613
614        // Calculate iterations and concurrency
615        // Use a default of 10 if not specified (matches default in FlightSQLConfig)
616        let iterations = self.args.benchmark_iterations.unwrap_or(10);
617        let concurrency = if self.args.concurrent {
618            let parallelism = std::thread::available_parallelism()
619                .map(|n| n.get())
620                .unwrap_or(1);
621            std::cmp::min(iterations, parallelism)
622        } else {
623            1
624        };
625
626        // Create progress reporter
627        let progress_reporter = Some(Arc::new(progress::IndicatifProgressReporter::new(
628            sql,
629            iterations,
630            self.args.concurrent,
631            concurrency,
632        ))
633            as Arc<dyn datafusion_app::local_benchmarks::BenchmarkProgressReporter>);
634
635        // Call benchmark with reporter
636        let stats = self
637            .app_execution
638            .flightsql_ctx()
639            .benchmark_query(
640                sql,
641                self.args.benchmark_iterations,
642                self.args.concurrent,
643                progress_reporter,
644            )
645            .await?;
646        Ok(stats)
647    }
648
649    /// run and execute SQL statements and commands from a file, against a context
650    /// with the given print options
651    pub async fn exec_from_file(&self, file: &Path) -> color_eyre::Result<()> {
652        let string = std::fs::read_to_string(file)?;
653
654        self.exec_from_string(&string).await?;
655
656        Ok(())
657    }
658
659    /// executes a sql statement and prints the result to stdout
660    pub async fn execute_and_print_sql(&self, sql: &str) -> color_eyre::Result<()> {
661        let stream = self.app_execution.execution_ctx().execute_sql(sql).await?;
662        self.print_any_stream(stream).await;
663        Ok(())
664    }
665
666    async fn exec_stream<S, E>(&self, mut stream: S)
667    where
668        S: Stream<Item = Result<RecordBatch, E>> + Unpin,
669        E: Error,
670    {
671        while let Some(maybe_batch) = stream.next().await {
672            match maybe_batch {
673                Ok(_) => {}
674                Err(e) => {
675                    println!("Error executing SQL: {e}");
676                    break;
677                }
678            }
679        }
680    }
681
682    async fn print_any_stream<S, E>(&self, mut stream: S)
683    where
684        S: Stream<Item = Result<RecordBatch, E>> + Unpin,
685        E: Error,
686    {
687        while let Some(maybe_batch) = stream.next().await {
688            match maybe_batch {
689                Ok(batch) => match pretty_format_batches(&[batch]) {
690                    Ok(d) => println!("{}", d),
691                    Err(e) => println!("Error formatting batch: {e}"),
692                },
693                Err(e) => println!("Error executing SQL: {e}"),
694            }
695        }
696    }
697
698    async fn output_stream<S, E>(&self, mut stream: S, path: &Path) -> Result<()>
699    where
700        S: Stream<Item = Result<RecordBatch, E>> + Unpin,
701        E: Error,
702    {
703        // We get the schema from the first batch and use that for creating the writer
704        if let Some(Ok(first_batch)) = stream.next().await {
705            let schema = first_batch.schema();
706            let mut writer = path_to_writer(path, schema)?;
707            writer.write(&first_batch)?;
708
709            while let Some(maybe_batch) = stream.next().await {
710                match maybe_batch {
711                    Ok(batch) => writer.write(&batch)?,
712                    Err(e) => return Err(eyre!("Error executing SQL: {e}")),
713                }
714            }
715            writer.close().await?;
716        }
717
718        Ok(())
719    }
720}
721
/// Collects Arrow record batches and writes them as one Vortex file on `close`.
#[cfg(feature = "vortex")]
struct VortexFileWriter {
    /// Destination path; the file is (re)created in `close`.
    path: PathBuf,
    /// Every batch passed to `write`, concatenated in `close`.
    batches: Vec<RecordBatch>,
}

#[cfg(feature = "vortex")]
impl VortexFileWriter {
    /// Remembers the destination path.
    ///
    /// The `std::fs::File` handed in is closed right away because `close`
    /// reopens the path with `tokio::fs::File`. `_schema` is unused; the schema
    /// is taken from the buffered batches instead.
    fn new(file: File, _schema: SchemaRef, path: &Path) -> Result<Self> {
        std::mem::drop(file);
        Ok(VortexFileWriter {
            path: PathBuf::from(path),
            batches: vec![],
        })
    }

    /// Buffers a copy of `batch`; nothing reaches disk until `close`.
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.batches.push(batch.clone());
        Ok(())
    }

    /// Concatenates the buffered batches, converts the result to a Vortex array
    /// and writes it to `path`. Does nothing if no batch was buffered.
    ///
    /// NOTE: all batches stay in memory and `concat_batches` builds a combined
    /// copy, so peak memory is likely about twice the result size.
    async fn close(self) -> Result<()> {
        let schema = match self.batches.first() {
            Some(first) => first.schema(),
            None => return Ok(()),
        };

        // `tokio::fs::File` is the handle type passed to the Vortex writer.
        let file = tokio::fs::File::create(&self.path).await?;
        let combined = datafusion::arrow::compute::concat_batches(&schema, &self.batches)?;
        let array_stream = ArrayRef::from_arrow(combined, false).to_array_stream();

        VortexWriteOptions::new(VortexSession::empty())
            .write(file, array_stream)
            .await
            .map_err(|e| eyre!("Failed to write Vortex file: {}", e))?;
        Ok(())
    }
}
774
775/// We use an Enum for this because of limitations with using trait objects and the `close` method
776/// on a writer taking `self` as an argument which requires a size for the trait object which is
777/// not known at compile time.
778#[allow(clippy::large_enum_variant)]
779enum AnyWriter {
780    Csv(csv::writer::Writer<File>),
781    Json(json::writer::LineDelimitedWriter<File>),
782    Parquet(ArrowWriter<File>),
783    #[cfg(feature = "vortex")]
784    Vortex(VortexFileWriter),
785}
786
787impl AnyWriter {
788    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
789        match self {
790            AnyWriter::Csv(w) => Ok(w.write(batch)?),
791            AnyWriter::Json(w) => Ok(w.write(batch)?),
792            AnyWriter::Parquet(w) => Ok(w.write(batch)?),
793            #[cfg(feature = "vortex")]
794            AnyWriter::Vortex(w) => Ok(w.write(batch)?),
795        }
796    }
797
798    async fn close(self) -> Result<()> {
799        match self {
800            AnyWriter::Csv(w) => Ok(w.close()?),
801            AnyWriter::Json(w) => Ok(w.close()?),
802            AnyWriter::Parquet(w) => {
803                w.close()?;
804                Ok(())
805            }
806            #[cfg(feature = "vortex")]
807            AnyWriter::Vortex(w) => w.close().await,
808        }
809    }
810}
811
812fn path_to_writer(path: &Path, schema: SchemaRef) -> Result<AnyWriter> {
813    if let Some(extension) = path.extension() {
814        if let Some(e) = extension.to_ascii_lowercase().to_str() {
815            let file = std::fs::File::create(path)?;
816            return match e {
817                "csv" => Ok(AnyWriter::Csv(csv::writer::Writer::new(file))),
818                "json" => Ok(AnyWriter::Json(json::writer::LineDelimitedWriter::new(
819                    file,
820                ))),
821                "parquet" => {
822                    let props = WriterProperties::default();
823                    let writer = ArrowWriter::try_new(file, schema, Some(props))?;
824                    Ok(AnyWriter::Parquet(writer))
825                }
826                #[cfg(feature = "vortex")]
827                "vortex" => Ok(AnyWriter::Vortex(VortexFileWriter::new(
828                    file, schema, path,
829                )?)),
830                _ => {
831                    #[cfg(feature = "vortex")]
832                    return Err(eyre!(
833                        "Only 'csv', 'parquet', 'json', and 'vortex' file types can be output"
834                    ));
835                    #[cfg(not(feature = "vortex"))]
836                    return Err(eyre!(
837                        "Only 'csv', 'parquet', and 'json' file types can be output"
838                    ));
839                }
840            };
841        }
842    }
843    Err(eyre!("Unable to parse extension"))
844}
845
846pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> {
847    let merged_exec_config = merge_configs(config.shared.clone(), config.cli.execution.clone());
848    let session_state_builder = DftSessionStateBuilder::try_new(Some(merged_exec_config.clone()))?
849        .with_extensions()
850        .await?;
851
852    // CLI mode: executing commands from files or CLI arguments
853    let session_state = session_state_builder.build()?;
854    let execution_ctx = ExecutionContext::try_new(
855        &merged_exec_config,
856        session_state,
857        crate::APP_NAME,
858        env!("CARGO_PKG_VERSION"),
859    )?;
860    #[allow(unused_mut)]
861    let mut app_execution = AppExecution::new(execution_ctx);
862    #[cfg(feature = "flightsql")]
863    {
864        if cli.flightsql || matches!(cli.command, Some(Command::FlightSql { .. })) {
865            let auth = AuthConfig {
866                basic_auth: config.flightsql_client.auth.basic_auth,
867                bearer_token: config.flightsql_client.auth.bearer_token,
868            };
869            let flightsql_cfg = FlightSQLConfig::new(
870                config.flightsql_client.connection_url,
871                config.flightsql_client.benchmark_iterations,
872                auth,
873                config.flightsql_client.headers.clone(),
874            );
875            let flightsql_ctx = FlightSQLContext::new(flightsql_cfg);
876
877            // Three-way header merge: config < file < CLI
878            let mut all_headers = config.flightsql_client.headers.clone();
879
880            // Merge headers from file (if specified in config or CLI)
881            let headers_file = cli
882                .headers_file
883                .as_ref()
884                .or(config.flightsql_client.headers_file.as_ref());
885
886            if let Some(file_path) = headers_file {
887                match parse_headers_file(file_path) {
888                    Ok(file_headers) => {
889                        all_headers.extend(file_headers);
890                    }
891                    Err(e) => {
892                        return Err(eyre!("Error reading headers file: {}", e));
893                    }
894                }
895            }
896
897            // Merge CLI headers (highest precedence)
898            if let Some(cli_headers) = &cli.header {
899                all_headers.extend(cli_headers.iter().cloned());
900            }
901
902            let headers = if all_headers.is_empty() {
903                None
904            } else {
905                Some(all_headers)
906            };
907
908            flightsql_ctx
909                .create_client(cli.host.clone(), headers)
910                .await?;
911            app_execution.with_flightsql_ctx(flightsql_ctx);
912        }
913    }
914    register_db(app_execution.session_ctx(), &config.db).await?;
915    let app = CliApp::new(app_execution, cli.clone());
916    app.execute_files_or_commands().await?;
917    Ok(())
918}