frostbow_cli/
exec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Execution functions
19
20use std::collections::HashMap;
21use std::fs::File;
22use std::io::prelude::*;
23use std::io::BufReader;
24
25use crate::cli_context::CliSessionContext;
26use crate::helper::split_from_semicolon;
27use crate::print_format::PrintFormat;
28use crate::{
29    command::{Command, OutputFormat},
30    helper::{unescape_input, CliHelper},
31    iceberg::transform_iceberg_input,
32    object_storage::get_object_store,
33    print_options::{MaxRows, PrintOptions},
34};
35
36use datafusion::common::instant::Instant;
37use datafusion::common::plan_datafusion_err;
38use datafusion::config::ConfigFileType;
39use datafusion::datasource::listing::ListingTableUrl;
40use datafusion::error::{DataFusionError, Result};
41use datafusion::logical_expr::{DdlStatement, LogicalPlan};
42use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties};
43use datafusion::sql::parser::{DFParser, Statement};
44use datafusion::sql::sqlparser::dialect::dialect_from_str;
45
46use datafusion::sql::sqlparser;
47use rustyline::error::ReadlineError;
48use rustyline::Editor;
49use tokio::signal;
50
51/// run and execute SQL statements and commands, against a context with the given print options
52pub async fn exec_from_commands(
53    ctx: &dyn CliSessionContext,
54    commands: Vec<String>,
55    print_options: &PrintOptions,
56) -> Result<()> {
57    for sql in commands {
58        exec_and_print(ctx, print_options, sql).await?;
59    }
60
61    Ok(())
62}
63
64/// run and execute SQL statements and commands from a file, against a context with the given print options
65pub async fn exec_from_lines(
66    ctx: &dyn CliSessionContext,
67    reader: &mut BufReader<File>,
68    print_options: &PrintOptions,
69) -> Result<()> {
70    let mut query = "".to_owned();
71
72    for line in reader.lines() {
73        match line {
74            Ok(line) if line.starts_with("#!") => {
75                continue;
76            }
77            Ok(line) if line.starts_with("--") => {
78                continue;
79            }
80            Ok(line) => {
81                let line = line.trim_end();
82                query.push_str(line);
83                if line.ends_with(';') {
84                    match exec_and_print(ctx, print_options, query).await {
85                        Ok(_) => {}
86                        Err(err) => eprintln!("{err}"),
87                    }
88                    query = "".to_string();
89                } else {
90                    query.push('\n');
91                }
92            }
93            _ => {
94                break;
95            }
96        }
97    }
98
99    // run the left over query if the last statement doesn't contain ‘;’
100    // ignore if it only consists of '\n'
101    if query.contains(|c| c != '\n') {
102        exec_and_print(ctx, print_options, query).await?;
103    }
104
105    Ok(())
106}
107
108pub async fn exec_from_files(
109    ctx: &dyn CliSessionContext,
110    files: Vec<String>,
111    print_options: &PrintOptions,
112) -> Result<()> {
113    let files = files
114        .into_iter()
115        .map(|file_path| File::open(file_path).unwrap())
116        .collect::<Vec<_>>();
117
118    for file in files {
119        let mut reader = BufReader::new(file);
120        exec_from_lines(ctx, &mut reader, print_options).await?;
121    }
122
123    Ok(())
124}
125
126/// run and execute SQL statements and commands against a context with the given print options
127pub async fn exec_from_repl(
128    ctx: &dyn CliSessionContext,
129    print_options: &mut PrintOptions,
130) -> rustyline::Result<()> {
131    let mut rl = Editor::new()?;
132    rl.set_helper(Some(CliHelper::new(
133        &ctx.task_ctx().session_config().options().sql_parser.dialect,
134        print_options.color,
135    )));
136    rl.load_history(".history").ok();
137
138    loop {
139        match rl.readline("> ") {
140            Ok(line) if line.starts_with('\\') => {
141                rl.add_history_entry(line.trim_end())?;
142                let command = line.split_whitespace().collect::<Vec<_>>().join(" ");
143                if let Ok(cmd) = &command[1..].parse::<Command>() {
144                    match cmd {
145                        Command::Quit => break,
146                        Command::OutputFormat(subcommand) => {
147                            if let Some(subcommand) = subcommand {
148                                if let Ok(command) = subcommand.parse::<OutputFormat>() {
149                                    if let Err(e) = command.execute(print_options).await {
150                                        eprintln!("{e}")
151                                    }
152                                } else {
153                                    eprintln!(
154                                        "'\\{}' is not a valid command",
155                                        &line[1..]
156                                    );
157                                }
158                            } else {
159                                println!("Output format is {:?}.", print_options.format);
160                            }
161                        }
162                        _ => {
163                            if let Err(e) = cmd.execute(ctx, print_options).await {
164                                eprintln!("{e}")
165                            }
166                        }
167                    }
168                } else {
169                    eprintln!("'\\{}' is not a valid command", &line[1..]);
170                }
171            }
172            Ok(line) => {
173                let lines = split_from_semicolon(line);
174                for line in lines {
175                    rl.add_history_entry(line.trim_end())?;
176                    tokio::select! {
177                        res = exec_and_print(ctx, print_options, line) => match res {
178                            Ok(_) => {}
179                            Err(err) => eprintln!("{err}"),
180                        },
181                        _ = signal::ctrl_c() => {
182                            println!("^C");
183                            continue
184                        },
185                    }
186                    // dialect might have changed
187                    rl.helper_mut().unwrap().set_dialect(
188                        &ctx.task_ctx().session_config().options().sql_parser.dialect,
189                    );
190                }
191            }
192            Err(ReadlineError::Interrupted) => {
193                println!("^C");
194                continue;
195            }
196            Err(ReadlineError::Eof) => {
197                println!("\\q");
198                break;
199            }
200            Err(err) => {
201                eprintln!("Unknown error happened {:?}", err);
202                break;
203            }
204        }
205    }
206
207    rl.save_history(".history")
208}
209
210pub(super) async fn exec_and_print(
211    ctx: &dyn CliSessionContext,
212    print_options: &PrintOptions,
213    sql: String,
214) -> Result<()> {
215    let now = Instant::now();
216    let sql = unescape_input(&sql)?;
217    let task_ctx = ctx.task_ctx();
218    let dialect = &task_ctx.session_config().options().sql_parser.dialect;
219    let dialect = dialect_from_str(dialect).ok_or_else(|| {
220        plan_datafusion_err!(
221            "Unsupported SQL dialect: {dialect}. Available dialects: \
222                 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
223                 MsSQL, ClickHouse, BigQuery, Ansi."
224        )
225    })?;
226
227    let statements = DFParser::parse_sql_with_dialect(
228        &transform_iceberg_input(&sql),
229        dialect.as_ref(),
230    )?;
231    for statement in statements {
232        let adjusted =
233            AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement);
234
235        let plan = create_plan(ctx, statement).await?;
236        let adjusted = adjusted.with_plan(&plan);
237
238        let df = ctx.execute_logical_plan(plan).await?;
239        let physical_plan = df.create_physical_plan().await?;
240
241        if physical_plan.execution_mode().is_unbounded() {
242            let stream = execute_stream(physical_plan, task_ctx.clone())?;
243            print_options.print_stream(stream, now).await?;
244        } else {
245            let schema = physical_plan.schema();
246            let results = collect(physical_plan, task_ctx.clone()).await?;
247            adjusted.into_inner().print_batches(schema, &results, now)?;
248        }
249    }
250
251    Ok(())
252}
253
254/// Track adjustments to the print options based on the plan / statement being executed
255#[derive(Debug)]
256struct AdjustedPrintOptions {
257    inner: PrintOptions,
258}
259
260impl AdjustedPrintOptions {
261    fn new(inner: PrintOptions) -> Self {
262        Self { inner }
263    }
264    /// Adjust print options based on any statement specific requirements
265    fn with_statement(mut self, statement: &Statement) -> Self {
266        if let Statement::Statement(sql_stmt) = statement {
267            // SHOW / SHOW ALL
268            if let sqlparser::ast::Statement::ShowVariable { .. } = sql_stmt.as_ref() {
269                self.inner.maxrows = MaxRows::Unlimited
270            }
271        }
272        self
273    }
274
275    /// Adjust print options based on any plan specific requirements
276    fn with_plan(mut self, plan: &LogicalPlan) -> Self {
277        // For plans like `Explain` ignore `MaxRows` option and always display
278        // all rows
279        if matches!(
280            plan,
281            LogicalPlan::Explain(_)
282                | LogicalPlan::DescribeTable(_)
283                | LogicalPlan::Analyze(_)
284        ) {
285            self.inner.maxrows = MaxRows::Unlimited;
286        }
287        self
288    }
289
290    /// Finalize and return the inner `PrintOptions`
291    fn into_inner(mut self) -> PrintOptions {
292        if self.inner.format == PrintFormat::Automatic {
293            self.inner.format = PrintFormat::Table;
294        }
295
296        self.inner
297    }
298}
299
300fn config_file_type_from_str(ext: &str) -> Option<ConfigFileType> {
301    match ext.to_lowercase().as_str() {
302        "csv" => Some(ConfigFileType::CSV),
303        "json" => Some(ConfigFileType::JSON),
304        "parquet" => Some(ConfigFileType::PARQUET),
305        _ => None,
306    }
307}
308
309async fn create_plan(
310    ctx: &dyn CliSessionContext,
311    statement: Statement,
312) -> Result<LogicalPlan, DataFusionError> {
313    let mut plan = ctx.session_state().statement_to_plan(statement).await?;
314
315    // Note that cmd is a mutable reference so that create_external_table function can remove all
316    // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion
317    // will raise Configuration errors.
318    if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
319        if !(cmd.file_type.to_lowercase() == "iceberg") {
320            // To support custom formats, treat error as None
321            let format = config_file_type_from_str(&cmd.file_type);
322            register_object_store_and_config_extensions(
323                ctx,
324                &cmd.location,
325                &cmd.options,
326                format,
327            )
328            .await?;
329        }
330    }
331
332    if let LogicalPlan::Copy(copy_to) = &mut plan {
333        let format = config_file_type_from_str(&copy_to.file_type.get_ext());
334
335        register_object_store_and_config_extensions(
336            ctx,
337            &copy_to.output_url,
338            &copy_to.options,
339            format,
340        )
341        .await?;
342    }
343    Ok(plan)
344}
345
346/// Asynchronously registers an object store and its configuration extensions
347/// to the session context.
348///
349/// This function dynamically registers a cloud object store based on the given
350/// location and options. It first parses the location to determine the scheme
351/// and constructs the URL accordingly. Depending on the scheme, it also registers
352/// relevant options. The function then alters the default table options with the
353/// given custom options. Finally, it retrieves and registers the object store
354/// in the session context.
355///
356/// # Parameters
357///
358/// * `ctx`: A reference to the `SessionContext` for registering the object store.
359/// * `location`: A string reference representing the location of the object store.
360/// * `options`: A reference to a hash map containing configuration options for
361///   the object store.
362///
363/// # Returns
364///
365/// A `Result<()>` which is an Ok value indicating successful registration, or
366/// an error upon failure.
367///
368/// # Errors
369///
370/// This function can return an error if the location parsing fails, options
371/// alteration fails, or if the object store cannot be retrieved and registered
372/// successfully.
373pub(crate) async fn register_object_store_and_config_extensions(
374    ctx: &dyn CliSessionContext,
375    location: &String,
376    options: &HashMap<String, String>,
377    format: Option<ConfigFileType>,
378) -> Result<()> {
379    // Parse the location URL to extract the scheme and other components
380    let table_path = ListingTableUrl::parse(location)?;
381
382    // Extract the scheme (e.g., "s3", "gcs") from the parsed URL
383    let scheme = table_path.scheme();
384
385    // Obtain a reference to the URL
386    let url = table_path.as_ref();
387
388    // Register the options based on the scheme extracted from the location
389    ctx.register_table_options_extension_from_scheme(scheme);
390
391    // Clone and modify the default table options based on the provided options
392    let mut table_options = ctx.session_state().default_table_options();
393    if let Some(format) = format {
394        table_options.set_config_format(format);
395    }
396    table_options.alter_with_string_hash_map(options)?;
397
398    // Retrieve the appropriate object store based on the scheme, URL, and modified table options
399    let store =
400        get_object_store(&ctx.session_state(), scheme, url, &table_options).await?;
401
402    // Register the retrieved object store in the session context's runtime environment
403    ctx.register_object_store(url, store);
404
405    Ok(())
406}
407
408#[cfg(test)]
409mod tests {
410    use super::*;
411
412    use datafusion::common::plan_err;
413
414    use datafusion::prelude::SessionContext;
415    use url::Url;
416
417    async fn create_external_table_test(location: &str, sql: &str) -> Result<()> {
418        let ctx = SessionContext::new();
419        let plan = ctx.state().create_logical_plan(sql).await?;
420
421        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
422            let format = config_file_type_from_str(&cmd.file_type);
423            register_object_store_and_config_extensions(
424                &ctx,
425                &cmd.location,
426                &cmd.options,
427                format,
428            )
429            .await?;
430        } else {
431            return plan_err!("LogicalPlan is not a CreateExternalTable");
432        }
433
434        // Ensure the URL is supported by the object store
435        ctx.runtime_env()
436            .object_store(ListingTableUrl::parse(location)?)?;
437
438        Ok(())
439    }
440
441    async fn copy_to_table_test(location: &str, sql: &str) -> Result<()> {
442        let ctx = SessionContext::new();
443        // AWS CONFIG register.
444
445        let plan = ctx.state().create_logical_plan(sql).await?;
446
447        if let LogicalPlan::Copy(cmd) = &plan {
448            let format = config_file_type_from_str(&cmd.file_type.get_ext());
449            register_object_store_and_config_extensions(
450                &ctx,
451                &cmd.output_url,
452                &cmd.options,
453                format,
454            )
455            .await?;
456        } else {
457            return plan_err!("LogicalPlan is not a CreateExternalTable");
458        }
459
460        // Ensure the URL is supported by the object store
461        ctx.runtime_env()
462            .object_store(ListingTableUrl::parse(location)?)?;
463
464        Ok(())
465    }
466
467    #[tokio::test]
468    async fn create_object_store_table_http() -> Result<()> {
469        // Should be OK
470        let location = "http://example.com/file.parquet";
471        let sql =
472            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
473        create_external_table_test(location, &sql).await?;
474
475        Ok(())
476    }
477    #[tokio::test]
478    async fn copy_to_external_object_store_test() -> Result<()> {
479        let locations = vec![
480            "s3://bucket/path/file.parquet",
481            "oss://bucket/path/file.parquet",
482            "cos://bucket/path/file.parquet",
483            "gcs://bucket/path/file.parquet",
484        ];
485        let ctx = SessionContext::new();
486        let task_ctx = ctx.task_ctx();
487        let dialect = &task_ctx.session_config().options().sql_parser.dialect;
488        let dialect = dialect_from_str(dialect).ok_or_else(|| {
489            plan_datafusion_err!(
490                "Unsupported SQL dialect: {dialect}. Available dialects: \
491                 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
492                 MsSQL, ClickHouse, BigQuery, Ansi."
493            )
494        })?;
495        for location in locations {
496            let sql = format!("copy (values (1,2)) to '{}' STORED AS PARQUET;", location);
497            let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
498            for statement in statements {
499                //Should not fail
500                let mut plan = create_plan(&ctx, statement).await?;
501                if let LogicalPlan::Copy(copy_to) = &mut plan {
502                    assert_eq!(copy_to.output_url, location);
503                    assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string());
504                    ctx.runtime_env()
505                        .object_store_registry
506                        .get_store(&Url::parse(&copy_to.output_url).unwrap())?;
507                } else {
508                    return plan_err!("LogicalPlan is not a CopyTo");
509                }
510            }
511        }
512        Ok(())
513    }
514
515    #[tokio::test]
516    async fn copy_to_object_store_table_s3() -> Result<()> {
517        let access_key_id = "fake_access_key_id";
518        let secret_access_key = "fake_secret_access_key";
519        let location = "s3://bucket/path/file.parquet";
520
521        // Missing region, use object_store defaults
522        let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET
523            OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')");
524        copy_to_table_test(location, &sql).await?;
525
526        Ok(())
527    }
528
529    #[tokio::test]
530    async fn create_object_store_table_s3() -> Result<()> {
531        let access_key_id = "fake_access_key_id";
532        let secret_access_key = "fake_secret_access_key";
533        let region = "fake_us-east-2";
534        let session_token = "fake_session_token";
535        let location = "s3://bucket/path/file.parquet";
536
537        // Missing region, use object_store defaults
538        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
539            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}') LOCATION '{location}'");
540        create_external_table_test(location, &sql).await?;
541
542        // Should be OK
543        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
544            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.region' '{region}', 'aws.session_token' '{session_token}') LOCATION '{location}'");
545        create_external_table_test(location, &sql).await?;
546
547        Ok(())
548    }
549
550    #[tokio::test]
551    async fn create_object_store_table_oss() -> Result<()> {
552        let access_key_id = "fake_access_key_id";
553        let secret_access_key = "fake_secret_access_key";
554        let endpoint = "fake_endpoint";
555        let location = "oss://bucket/path/file.parquet";
556
557        // Should be OK
558        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
559            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'");
560        create_external_table_test(location, &sql).await?;
561
562        Ok(())
563    }
564
565    #[tokio::test]
566    async fn create_object_store_table_cos() -> Result<()> {
567        let access_key_id = "fake_access_key_id";
568        let secret_access_key = "fake_secret_access_key";
569        let endpoint = "fake_endpoint";
570        let location = "cos://bucket/path/file.parquet";
571
572        // Should be OK
573        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
574            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'");
575        create_external_table_test(location, &sql).await?;
576
577        Ok(())
578    }
579
580    #[tokio::test]
581    async fn create_object_store_table_gcs() -> Result<()> {
582        let service_account_path = "fake_service_account_path";
583        let service_account_key =
584            "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}";
585        let application_credentials_path = "fake_application_credentials_path";
586        let location = "gcs://bucket/path/file.parquet";
587
588        // for service_account_path
589        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
590            OPTIONS('gcp.service_account_path' '{service_account_path}') LOCATION '{location}'");
591        let err = create_external_table_test(location, &sql)
592            .await
593            .unwrap_err();
594        assert!(err.to_string().contains("os error 2"));
595
596        // for service_account_key
597        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'");
598        let err = create_external_table_test(location, &sql)
599            .await
600            .unwrap_err()
601            .to_string();
602        assert!(err.contains("No RSA key found in pem file"), "{err}");
603
604        // for application_credentials_path
605        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
606            OPTIONS('gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'");
607        let err = create_external_table_test(location, &sql)
608            .await
609            .unwrap_err();
610        assert!(err.to_string().contains("os error 2"));
611
612        Ok(())
613    }
614
615    #[tokio::test]
616    async fn create_external_table_local_file() -> Result<()> {
617        let location = "path/to/file.parquet";
618
619        // Ensure that local files are also registered
620        let sql =
621            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
622        create_external_table_test(location, &sql).await.unwrap();
623
624        Ok(())
625    }
626
627    #[tokio::test]
628    async fn create_external_table_format_option() -> Result<()> {
629        let location = "path/to/file.cvs";
630
631        // Test with format options
632        let sql =
633            format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')");
634        create_external_table_test(location, &sql).await.unwrap();
635
636        Ok(())
637    }
638}