datafusion_cli/
exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Execution functions
19
20use crate::cli_context::CliSessionContext;
21use crate::helper::split_from_semicolon;
22use crate::print_format::PrintFormat;
23use crate::{
24    command::{Command, OutputFormat},
25    helper::CliHelper,
26    object_storage::get_object_store,
27    print_options::{MaxRows, PrintOptions},
28};
29use futures::StreamExt;
30use std::collections::HashMap;
31use std::fs::File;
32use std::io::prelude::*;
33use std::io::BufReader;
34
35use datafusion::common::instant::Instant;
36use datafusion::common::{plan_datafusion_err, plan_err};
37use datafusion::config::ConfigFileType;
38use datafusion::datasource::listing::ListingTableUrl;
39use datafusion::error::{DataFusionError, Result};
40use datafusion::logical_expr::{DdlStatement, LogicalPlan};
41use datafusion::physical_plan::execution_plan::EmissionType;
42use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties};
43use datafusion::sql::parser::{DFParser, Statement};
44use datafusion::sql::sqlparser::dialect::dialect_from_str;
45
46use datafusion::execution::memory_pool::MemoryConsumer;
47use datafusion::physical_plan::spill::get_record_batch_memory_size;
48use datafusion::sql::sqlparser;
49use rustyline::error::ReadlineError;
50use rustyline::Editor;
51use tokio::signal;
52
53/// run and execute SQL statements and commands, against a context with the given print options
54pub async fn exec_from_commands(
55    ctx: &dyn CliSessionContext,
56    commands: Vec<String>,
57    print_options: &PrintOptions,
58) -> Result<()> {
59    for sql in commands {
60        exec_and_print(ctx, print_options, sql).await?;
61    }
62
63    Ok(())
64}
65
66/// run and execute SQL statements and commands from a file, against a context with the given print options
67pub async fn exec_from_lines(
68    ctx: &dyn CliSessionContext,
69    reader: &mut BufReader<File>,
70    print_options: &PrintOptions,
71) -> Result<()> {
72    let mut query = "".to_owned();
73
74    for line in reader.lines() {
75        match line {
76            Ok(line) if line.starts_with("#!") => {
77                continue;
78            }
79            Ok(line) if line.starts_with("--") => {
80                continue;
81            }
82            Ok(line) => {
83                let line = line.trim_end();
84                query.push_str(line);
85                if line.ends_with(';') {
86                    match exec_and_print(ctx, print_options, query).await {
87                        Ok(_) => {}
88                        Err(err) => eprintln!("{err}"),
89                    }
90                    query = "".to_string();
91                } else {
92                    query.push('\n');
93                }
94            }
95            _ => {
96                break;
97            }
98        }
99    }
100
101    // run the left over query if the last statement doesn't contain ‘;’
102    // ignore if it only consists of '\n'
103    if query.contains(|c| c != '\n') {
104        exec_and_print(ctx, print_options, query).await?;
105    }
106
107    Ok(())
108}
109
110pub async fn exec_from_files(
111    ctx: &dyn CliSessionContext,
112    files: Vec<String>,
113    print_options: &PrintOptions,
114) -> Result<()> {
115    let files = files
116        .into_iter()
117        .map(|file_path| File::open(file_path).unwrap())
118        .collect::<Vec<_>>();
119
120    for file in files {
121        let mut reader = BufReader::new(file);
122        exec_from_lines(ctx, &mut reader, print_options).await?;
123    }
124
125    Ok(())
126}
127
128/// run and execute SQL statements and commands against a context with the given print options
129pub async fn exec_from_repl(
130    ctx: &dyn CliSessionContext,
131    print_options: &mut PrintOptions,
132) -> rustyline::Result<()> {
133    let mut rl = Editor::new()?;
134    rl.set_helper(Some(CliHelper::new(
135        &ctx.task_ctx().session_config().options().sql_parser.dialect,
136        print_options.color,
137    )));
138    rl.load_history(".history").ok();
139
140    loop {
141        match rl.readline("> ") {
142            Ok(line) if line.starts_with('\\') => {
143                rl.add_history_entry(line.trim_end())?;
144                let command = line.split_whitespace().collect::<Vec<_>>().join(" ");
145                if let Ok(cmd) = &command[1..].parse::<Command>() {
146                    match cmd {
147                        Command::Quit => break,
148                        Command::OutputFormat(subcommand) => {
149                            if let Some(subcommand) = subcommand {
150                                if let Ok(command) = subcommand.parse::<OutputFormat>() {
151                                    if let Err(e) = command.execute(print_options).await {
152                                        eprintln!("{e}")
153                                    }
154                                } else {
155                                    eprintln!(
156                                        "'\\{}' is not a valid command",
157                                        &line[1..]
158                                    );
159                                }
160                            } else {
161                                println!("Output format is {:?}.", print_options.format);
162                            }
163                        }
164                        _ => {
165                            if let Err(e) = cmd.execute(ctx, print_options).await {
166                                eprintln!("{e}")
167                            }
168                        }
169                    }
170                } else {
171                    eprintln!("'\\{}' is not a valid command", &line[1..]);
172                }
173            }
174            Ok(line) => {
175                let lines = split_from_semicolon(&line);
176                for line in lines {
177                    rl.add_history_entry(line.trim_end())?;
178                    tokio::select! {
179                        res = exec_and_print(ctx, print_options, line) => match res {
180                            Ok(_) => {}
181                            Err(err) => eprintln!("{err}"),
182                        },
183                        _ = signal::ctrl_c() => {
184                            println!("^C");
185                            continue
186                        },
187                    }
188                    // dialect might have changed
189                    rl.helper_mut().unwrap().set_dialect(
190                        &ctx.task_ctx().session_config().options().sql_parser.dialect,
191                    );
192                }
193            }
194            Err(ReadlineError::Interrupted) => {
195                println!("^C");
196                continue;
197            }
198            Err(ReadlineError::Eof) => {
199                println!("\\q");
200                break;
201            }
202            Err(err) => {
203                eprintln!("Unknown error happened {:?}", err);
204                break;
205            }
206        }
207    }
208
209    rl.save_history(".history")
210}
211
212pub(super) async fn exec_and_print(
213    ctx: &dyn CliSessionContext,
214    print_options: &PrintOptions,
215    sql: String,
216) -> Result<()> {
217    let now = Instant::now();
218    let task_ctx = ctx.task_ctx();
219    let dialect = &task_ctx.session_config().options().sql_parser.dialect;
220    let dialect = dialect_from_str(dialect).ok_or_else(|| {
221        plan_datafusion_err!(
222            "Unsupported SQL dialect: {dialect}. Available dialects: \
223                 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
224                 MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
225        )
226    })?;
227
228    let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
229    for statement in statements {
230        let adjusted =
231            AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement);
232
233        let plan = create_plan(ctx, statement).await?;
234        let adjusted = adjusted.with_plan(&plan);
235
236        let df = ctx.execute_logical_plan(plan).await?;
237        let physical_plan = df.create_physical_plan().await?;
238
239        // Track memory usage for the query result if it's bounded
240        let mut reservation =
241            MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool());
242
243        if physical_plan.boundedness().is_unbounded() {
244            if physical_plan.pipeline_behavior() == EmissionType::Final {
245                return plan_err!(
246                    "The given query can generate a valid result only once \
247                    the source finishes, but the source is unbounded"
248                );
249            }
250            // As the input stream comes, we can generate results.
251            // However, memory safety is not guaranteed.
252            let stream = execute_stream(physical_plan, task_ctx.clone())?;
253            print_options.print_stream(stream, now).await?;
254        } else {
255            // Bounded stream; collected results size is limited by the maxrows option
256            let schema = physical_plan.schema();
257            let mut stream = execute_stream(physical_plan, task_ctx.clone())?;
258            let mut results = vec![];
259            let mut row_count = 0_usize;
260            let max_rows = match print_options.maxrows {
261                MaxRows::Unlimited => usize::MAX,
262                MaxRows::Limited(n) => n,
263            };
264            while let Some(batch) = stream.next().await {
265                let batch = batch?;
266                let curr_num_rows = batch.num_rows();
267                // Stop collecting results if the number of rows exceeds the limit
268                // results batch should include the last batch that exceeds the limit
269                if row_count < max_rows + curr_num_rows {
270                    // Try to grow the reservation to accommodate the batch in memory
271                    reservation.try_grow(get_record_batch_memory_size(&batch))?;
272                    results.push(batch);
273                }
274                row_count += curr_num_rows;
275            }
276            adjusted
277                .into_inner()
278                .print_batches(schema, &results, now, row_count)?;
279            reservation.free();
280        }
281    }
282
283    Ok(())
284}
285
286/// Track adjustments to the print options based on the plan / statement being executed
287#[derive(Debug)]
288struct AdjustedPrintOptions {
289    inner: PrintOptions,
290}
291
292impl AdjustedPrintOptions {
293    fn new(inner: PrintOptions) -> Self {
294        Self { inner }
295    }
296    /// Adjust print options based on any statement specific requirements
297    fn with_statement(mut self, statement: &Statement) -> Self {
298        if let Statement::Statement(sql_stmt) = statement {
299            // SHOW / SHOW ALL
300            if let sqlparser::ast::Statement::ShowVariable { .. } = sql_stmt.as_ref() {
301                self.inner.maxrows = MaxRows::Unlimited
302            }
303        }
304        self
305    }
306
307    /// Adjust print options based on any plan specific requirements
308    fn with_plan(mut self, plan: &LogicalPlan) -> Self {
309        // For plans like `Explain` ignore `MaxRows` option and always display
310        // all rows
311        if matches!(
312            plan,
313            LogicalPlan::Explain(_)
314                | LogicalPlan::DescribeTable(_)
315                | LogicalPlan::Analyze(_)
316        ) {
317            self.inner.maxrows = MaxRows::Unlimited;
318        }
319        self
320    }
321
322    /// Finalize and return the inner `PrintOptions`
323    fn into_inner(mut self) -> PrintOptions {
324        if self.inner.format == PrintFormat::Automatic {
325            self.inner.format = PrintFormat::Table;
326        }
327
328        self.inner
329    }
330}
331
332fn config_file_type_from_str(ext: &str) -> Option<ConfigFileType> {
333    match ext.to_lowercase().as_str() {
334        "csv" => Some(ConfigFileType::CSV),
335        "json" => Some(ConfigFileType::JSON),
336        "parquet" => Some(ConfigFileType::PARQUET),
337        _ => None,
338    }
339}
340
341async fn create_plan(
342    ctx: &dyn CliSessionContext,
343    statement: Statement,
344) -> Result<LogicalPlan, DataFusionError> {
345    let mut plan = ctx.session_state().statement_to_plan(statement).await?;
346
347    // Note that cmd is a mutable reference so that create_external_table function can remove all
348    // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion
349    // will raise Configuration errors.
350    if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
351        // To support custom formats, treat error as None
352        let format = config_file_type_from_str(&cmd.file_type);
353        register_object_store_and_config_extensions(
354            ctx,
355            &cmd.location,
356            &cmd.options,
357            format,
358        )
359        .await?;
360    }
361
362    if let LogicalPlan::Copy(copy_to) = &mut plan {
363        let format = config_file_type_from_str(&copy_to.file_type.get_ext());
364
365        register_object_store_and_config_extensions(
366            ctx,
367            &copy_to.output_url,
368            &copy_to.options,
369            format,
370        )
371        .await?;
372    }
373    Ok(plan)
374}
375
376/// Asynchronously registers an object store and its configuration extensions
377/// to the session context.
378///
379/// This function dynamically registers a cloud object store based on the given
380/// location and options. It first parses the location to determine the scheme
381/// and constructs the URL accordingly. Depending on the scheme, it also registers
382/// relevant options. The function then alters the default table options with the
383/// given custom options. Finally, it retrieves and registers the object store
384/// in the session context.
385///
386/// # Parameters
387///
388/// * `ctx`: A reference to the `SessionContext` for registering the object store.
389/// * `location`: A string reference representing the location of the object store.
390/// * `options`: A reference to a hash map containing configuration options for
391///   the object store.
392///
393/// # Returns
394///
395/// A `Result<()>` which is an Ok value indicating successful registration, or
396/// an error upon failure.
397///
398/// # Errors
399///
400/// This function can return an error if the location parsing fails, options
401/// alteration fails, or if the object store cannot be retrieved and registered
402/// successfully.
403pub(crate) async fn register_object_store_and_config_extensions(
404    ctx: &dyn CliSessionContext,
405    location: &String,
406    options: &HashMap<String, String>,
407    format: Option<ConfigFileType>,
408) -> Result<()> {
409    // Parse the location URL to extract the scheme and other components
410    let table_path = ListingTableUrl::parse(location)?;
411
412    // Extract the scheme (e.g., "s3", "gcs") from the parsed URL
413    let scheme = table_path.scheme();
414
415    // Obtain a reference to the URL
416    let url = table_path.as_ref();
417
418    // Register the options based on the scheme extracted from the location
419    ctx.register_table_options_extension_from_scheme(scheme);
420
421    // Clone and modify the default table options based on the provided options
422    let mut table_options = ctx.session_state().default_table_options();
423    if let Some(format) = format {
424        table_options.set_config_format(format);
425    }
426    table_options.alter_with_string_hash_map(options)?;
427
428    // Retrieve the appropriate object store based on the scheme, URL, and modified table options
429    let store =
430        get_object_store(&ctx.session_state(), scheme, url, &table_options).await?;
431
432    // Register the retrieved object store in the session context's runtime environment
433    ctx.register_object_store(url, store);
434
435    Ok(())
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::common::plan_err;

    use datafusion::prelude::SessionContext;
    use url::Url;

    /// Plans `sql`, which must produce a `CREATE EXTERNAL TABLE` plan,
    /// registers the object store for it, and then checks that the runtime
    /// environment can resolve an object store for `location`.
    async fn create_external_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
            let format = config_file_type_from_str(&cmd.file_type);
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.location,
                &cmd.options,
                format,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CreateExternalTable");
        }

        // Ensure the URL is supported by the object store
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    /// Plans `sql`, which must produce a `COPY` plan, registers the object
    /// store for its output URL, and then checks that the runtime environment
    /// can resolve an object store for `location`.
    async fn copy_to_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();

        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Copy(cmd) = &plan {
            let format = config_file_type_from_str(&cmd.file_type.get_ext());
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.output_url,
                &cmd.options,
                format,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CopyTo");
        }

        // Ensure the URL is supported by the object store
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_http() -> Result<()> {
        // Should be OK
        let location = "http://example.com/file.parquet";
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    /// `create_plan` must register an object store for every supported
    /// remote scheme used as a `COPY ... TO` target.
    #[tokio::test]
    async fn copy_to_external_object_store_test() -> Result<()> {
        let locations = vec![
            "s3://bucket/path/file.parquet",
            "oss://bucket/path/file.parquet",
            "cos://bucket/path/file.parquet",
            "gcs://bucket/path/file.parquet",
        ];
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let dialect = &task_ctx.session_config().options().sql_parser.dialect;
        let dialect = dialect_from_str(dialect).ok_or_else(|| {
            plan_datafusion_err!(
                "Unsupported SQL dialect: {dialect}. Available dialects: \
                 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
                 MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
            )
        })?;
        for location in locations {
            let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;");
            let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
            for statement in statements {
                // Should not fail
                let plan = create_plan(&ctx, statement).await?;
                if let LogicalPlan::Copy(copy_to) = &plan {
                    assert_eq!(copy_to.output_url, location);
                    assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string());
                    ctx.runtime_env()
                        .object_store_registry
                        .get_store(&Url::parse(&copy_to.output_url).unwrap())?;
                } else {
                    return plan_err!("LogicalPlan is not a CopyTo");
                }
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn copy_to_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let location = "s3://bucket/path/file.parquet";

        // Missing region, use object_store defaults
        let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET
            OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')");
        copy_to_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let region = "fake_us-east-2";
        let session_token = "fake_session_token";
        let location = "s3://bucket/path/file.parquet";

        // Missing region, use object_store defaults
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        // Should be OK
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.region' '{region}', 'aws.session_token' '{session_token}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_oss() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "oss://bucket/path/file.parquet";

        // Should be OK
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_cos() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "cos://bucket/path/file.parquet";

        // Should be OK
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    /// Each GCS credential option is given a fake value, so registration is
    /// expected to fail with the error matching that option.
    #[tokio::test]
    async fn create_object_store_table_gcs() -> Result<()> {
        let service_account_path = "fake_service_account_path";
        let service_account_key =
            "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}";
        let application_credentials_path = "fake_application_credentials_path";
        let location = "gcs://bucket/path/file.parquet";

        // for service_account_path
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.service_account_path' '{service_account_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("os error 2"), "{err}");

        // for service_account_key
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("No RSA key found in pem file"), "{err}");

        // for application_credentials_path
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("os error 2"), "{err}");

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_local_file() -> Result<()> {
        let location = "path/to/file.parquet";

        // Ensure that local files are also registered
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_format_option() -> Result<()> {
        let location = "path/to/file.cvs";

        // Test with format options
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }
}