1use crate::cli_context::CliSessionContext;
21use crate::helper::split_from_semicolon;
22use crate::print_format::PrintFormat;
23use crate::{
24 command::{Command, OutputFormat},
25 helper::CliHelper,
26 object_storage::get_object_store,
27 print_options::{MaxRows, PrintOptions},
28};
29use datafusion::common::instant::Instant;
30use datafusion::common::{plan_datafusion_err, plan_err};
31use datafusion::config::ConfigFileType;
32use datafusion::datasource::listing::ListingTableUrl;
33use datafusion::error::{DataFusionError, Result};
34use datafusion::execution::memory_pool::MemoryConsumer;
35use datafusion::logical_expr::{DdlStatement, LogicalPlan};
36use datafusion::physical_plan::execution_plan::EmissionType;
37use datafusion::physical_plan::spill::get_record_batch_memory_size;
38use datafusion::physical_plan::{ExecutionPlanProperties, execute_stream};
39use datafusion::sql::parser::{DFParser, Statement};
40use datafusion::sql::sqlparser;
41use datafusion::sql::sqlparser::dialect::dialect_from_str;
42use futures::StreamExt;
43use log::warn;
44use object_store::Error::Generic;
45use rustyline::Editor;
46use rustyline::error::ReadlineError;
47use std::collections::HashMap;
48use std::fs::File;
49use std::io::BufReader;
50use std::io::prelude::*;
51use tokio::signal;
52
53pub async fn exec_from_commands(
55 ctx: &dyn CliSessionContext,
56 commands: Vec<String>,
57 print_options: &PrintOptions,
58) -> Result<()> {
59 for sql in commands {
60 exec_and_print(ctx, print_options, sql).await?;
61 }
62
63 Ok(())
64}
65
66pub async fn exec_from_lines(
68 ctx: &dyn CliSessionContext,
69 reader: &mut BufReader<File>,
70 print_options: &PrintOptions,
71) -> Result<()> {
72 let mut query = "".to_owned();
73
74 for line in reader.lines() {
75 match line {
76 Ok(line) if line.starts_with("#!") => {
77 continue;
78 }
79 Ok(line) if line.starts_with("--") => {
80 continue;
81 }
82 Ok(line) => {
83 let line = line.trim_end();
84 query.push_str(line);
85 if line.ends_with(';') {
86 match exec_and_print(ctx, print_options, query).await {
87 Ok(_) => {}
88 Err(err) => eprintln!("{err}"),
89 }
90 query = "".to_string();
91 } else {
92 query.push('\n');
93 }
94 }
95 _ => {
96 break;
97 }
98 }
99 }
100
101 if query.contains(|c| c != '\n') {
104 exec_and_print(ctx, print_options, query).await?;
105 }
106
107 Ok(())
108}
109
110pub async fn exec_from_files(
111 ctx: &dyn CliSessionContext,
112 files: Vec<String>,
113 print_options: &PrintOptions,
114) -> Result<()> {
115 let files = files
116 .into_iter()
117 .map(|file_path| File::open(file_path).unwrap())
118 .collect::<Vec<_>>();
119
120 for file in files {
121 let mut reader = BufReader::new(file);
122 exec_from_lines(ctx, &mut reader, print_options).await?;
123 }
124
125 Ok(())
126}
127
/// Runs the interactive REPL: reads lines with `rustyline`, dispatches
/// backslash meta-commands (e.g. `\q` to quit) and executes everything else
/// as SQL. History is loaded from and saved to `.history` in the current
/// directory.
pub async fn exec_from_repl(
    ctx: &dyn CliSessionContext,
    print_options: &mut PrintOptions,
) -> rustyline::Result<()> {
    let mut rl = Editor::new()?;
    // Line-editor helper (completion/highlighting) configured with the
    // session's current SQL dialect and color preference.
    rl.set_helper(Some(CliHelper::new(
        &ctx.task_ctx().session_config().options().sql_parser.dialect,
        print_options.color,
    )));
    // Best-effort: a missing history file is not an error.
    rl.load_history(".history").ok();

    loop {
        match rl.readline("> ") {
            // Lines starting with '\' are CLI meta-commands, not SQL.
            Ok(line) if line.starts_with('\\') => {
                rl.add_history_entry(line.trim_end())?;
                // Collapse runs of whitespace before parsing the command.
                let command = line.split_whitespace().collect::<Vec<_>>().join(" ");
                if let Ok(cmd) = &command[1..].parse::<Command>() {
                    match cmd {
                        Command::Quit => break,
                        Command::OutputFormat(subcommand) => {
                            if let Some(subcommand) = subcommand {
                                // An argument was given: switch output format.
                                if let Ok(command) = subcommand.parse::<OutputFormat>() {
                                    if let Err(e) = command.execute(print_options).await {
                                        eprintln!("{e}")
                                    }
                                } else {
                                    // NOTE(review): an invalid *format* argument is
                                    // reported with the generic "not a valid command"
                                    // message — possibly intentional, worth confirming.
                                    eprintln!(
                                        "'\\{}' is not a valid command, you can use '\\?' to see all commands",
                                        &line[1..]
                                    );
                                }
                            } else {
                                // No argument: report the current format.
                                println!("Output format is {:?}.", print_options.format);
                            }
                        }
                        _ => {
                            // All other recognized commands execute themselves.
                            if let Err(e) = cmd.execute(ctx, print_options).await {
                                eprintln!("{e}")
                            }
                        }
                    }
                } else {
                    eprintln!(
                        "'\\{}' is not a valid command, you can use '\\?' to see all commands",
                        &line[1..]
                    );
                }
            }
            // Anything else is SQL; one input line may contain several
            // ';'-separated statements.
            Ok(line) => {
                let lines = split_from_semicolon(&line);
                for line in lines {
                    rl.add_history_entry(line.trim_end())?;
                    // Race execution against Ctrl-C so a long-running query can
                    // be cancelled without leaving the REPL.
                    tokio::select! {
                        res = exec_and_print(ctx, print_options, line) => match res {
                            Ok(_) => {}
                            Err(err) => eprintln!("{err}"),
                        },
                        _ = signal::ctrl_c() => {
                            println!("^C");
                            continue
                        },
                    }
                    // The statement may have changed the dialect (e.g. via SET);
                    // keep the line-editor helper in sync.
                    rl.helper_mut().unwrap().set_dialect(
                        &ctx.task_ctx().session_config().options().sql_parser.dialect,
                    );
                }
            }
            // Ctrl-C at an empty prompt: clear any pending hint, re-prompt.
            Err(ReadlineError::Interrupted) => {
                println!("^C");
                rl.helper().unwrap().reset_hint();
                continue;
            }
            // Ctrl-D: exit, echoing the equivalent quit command.
            Err(ReadlineError::Eof) => {
                println!("\\q");
                break;
            }
            Err(err) => {
                eprintln!("Unknown error happened {err:?}");
                break;
            }
        }
    }

    rl.save_history(".history")
}
215
216pub(super) async fn exec_and_print(
217 ctx: &dyn CliSessionContext,
218 print_options: &PrintOptions,
219 sql: String,
220) -> Result<()> {
221 let task_ctx = ctx.task_ctx();
222 let options = task_ctx.session_config().options();
223 let dialect = &options.sql_parser.dialect;
224 let dialect = dialect_from_str(dialect).ok_or_else(|| {
225 plan_datafusion_err!(
226 "Unsupported SQL dialect: {dialect}. Available dialects: \
227 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
228 MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
229 )
230 })?;
231
232 let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
233 for statement in statements {
234 StatementExecutor::new(statement)
235 .execute(ctx, print_options)
236 .await?;
237 }
238
239 Ok(())
240}
241
/// Executes a single parsed SQL statement against the CLI session.
struct StatementExecutor {
    statement: Statement,
    // Clone of `statement`, kept only when it is CREATE EXTERNAL TABLE, so
    // the plan can be rebuilt once with S3 region auto-detection if the
    // first execution fails with an S3 object-store error.
    statement_for_retry: Option<Statement>,
}
247
248impl StatementExecutor {
249 fn new(statement: Statement) -> Self {
250 let statement_for_retry = matches!(statement, Statement::CreateExternalTable(_))
251 .then(|| statement.clone());
252
253 Self {
254 statement,
255 statement_for_retry,
256 }
257 }
258
259 async fn execute(
260 self,
261 ctx: &dyn CliSessionContext,
262 print_options: &PrintOptions,
263 ) -> Result<()> {
264 let now = Instant::now();
265 let (df, adjusted) = self
266 .create_and_execute_logical_plan(ctx, print_options)
267 .await?;
268 let physical_plan = df.create_physical_plan().await?;
269 let task_ctx = ctx.task_ctx();
270 let options = task_ctx.session_config().options();
271
272 let reservation =
274 MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool());
275
276 if physical_plan.boundedness().is_unbounded() {
277 if physical_plan.pipeline_behavior() == EmissionType::Final {
278 return plan_err!(
279 "The given query can generate a valid result only once \
280 the source finishes, but the source is unbounded"
281 );
282 }
283 let stream = execute_stream(physical_plan, task_ctx.clone())?;
286 print_options
287 .print_stream(stream, now, &options.format)
288 .await?;
289 } else {
290 let schema = physical_plan.schema();
292 let mut stream = execute_stream(physical_plan, task_ctx.clone())?;
293 let mut results = vec![];
294 let mut row_count = 0_usize;
295 let max_rows = match print_options.maxrows {
296 MaxRows::Unlimited => usize::MAX,
297 MaxRows::Limited(n) => n,
298 };
299 while let Some(batch) = stream.next().await {
300 let batch = batch?;
301 let curr_num_rows = batch.num_rows();
302 if row_count < max_rows.saturating_add(curr_num_rows) {
305 reservation.try_grow(get_record_batch_memory_size(&batch))?;
307 results.push(batch);
308 }
309 row_count += curr_num_rows;
310 }
311 adjusted.into_inner().print_batches(
312 schema,
313 &results,
314 now,
315 row_count,
316 &options.format,
317 )?;
318 reservation.free();
319 }
320
321 Ok(())
322 }
323
324 async fn create_and_execute_logical_plan(
325 mut self,
326 ctx: &dyn CliSessionContext,
327 print_options: &PrintOptions,
328 ) -> Result<(datafusion::dataframe::DataFrame, AdjustedPrintOptions)> {
329 let adjusted = AdjustedPrintOptions::new(print_options.clone())
330 .with_statement(&self.statement);
331
332 let plan = create_plan(ctx, self.statement, false).await?;
333 let adjusted = adjusted.with_plan(&plan);
334
335 let df = match ctx.execute_logical_plan(plan).await {
336 Ok(df) => Ok(df),
337 Err(DataFusionError::ObjectStore(err))
338 if matches!(err.as_ref(), Generic { store, source: _ } if "S3".eq_ignore_ascii_case(store))
339 && self.statement_for_retry.is_some() =>
340 {
341 warn!(
342 "S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration."
343 );
344 let plan =
345 create_plan(ctx, self.statement_for_retry.take().unwrap(), true)
346 .await?;
347 ctx.execute_logical_plan(plan).await
348 }
349 Err(e) => Err(e),
350 }?;
351
352 Ok((df, adjusted))
353 }
354}
355
/// `PrintOptions` wrapper that adapts display settings to the statement or
/// plan being executed (e.g. lifting the row limit for EXPLAIN output).
#[derive(Debug)]
struct AdjustedPrintOptions {
    inner: PrintOptions,
}
361
362impl AdjustedPrintOptions {
363 fn new(inner: PrintOptions) -> Self {
364 Self { inner }
365 }
366 fn with_statement(mut self, statement: &Statement) -> Self {
368 if let Statement::Statement(sql_stmt) = statement {
369 if let sqlparser::ast::Statement::ShowVariable { .. } = sql_stmt.as_ref() {
371 self.inner.maxrows = MaxRows::Unlimited
372 }
373 }
374 self
375 }
376
377 fn with_plan(mut self, plan: &LogicalPlan) -> Self {
379 if matches!(
382 plan,
383 LogicalPlan::Explain(_)
384 | LogicalPlan::DescribeTable(_)
385 | LogicalPlan::Analyze(_)
386 ) {
387 self.inner.maxrows = MaxRows::Unlimited;
388 }
389 self
390 }
391
392 fn into_inner(mut self) -> PrintOptions {
394 if self.inner.format == PrintFormat::Automatic {
395 self.inner.format = PrintFormat::Table;
396 }
397
398 self.inner
399 }
400}
401
402fn config_file_type_from_str(ext: &str) -> Option<ConfigFileType> {
403 match ext.to_lowercase().as_str() {
404 "csv" => Some(ConfigFileType::CSV),
405 "json" => Some(ConfigFileType::JSON),
406 "parquet" => Some(ConfigFileType::PARQUET),
407 _ => None,
408 }
409}
410
411async fn create_plan(
412 ctx: &dyn CliSessionContext,
413 statement: Statement,
414 resolve_region: bool,
415) -> Result<LogicalPlan, DataFusionError> {
416 let mut plan = ctx.session_state().statement_to_plan(statement).await?;
417
418 if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
422 let format = config_file_type_from_str(&cmd.file_type);
424 register_object_store_and_config_extensions(
425 ctx,
426 &cmd.location,
427 &cmd.options,
428 format,
429 resolve_region,
430 )
431 .await?;
432 }
433
434 if let LogicalPlan::Copy(copy_to) = &mut plan {
435 let format = config_file_type_from_str(©_to.file_type.get_ext());
436
437 register_object_store_and_config_extensions(
438 ctx,
439 ©_to.output_url,
440 ©_to.options,
441 format,
442 false,
443 )
444 .await?;
445 }
446 Ok(plan)
447}
448
449pub(crate) async fn register_object_store_and_config_extensions(
477 ctx: &dyn CliSessionContext,
478 location: &String,
479 options: &HashMap<String, String>,
480 format: Option<ConfigFileType>,
481 resolve_region: bool,
482) -> Result<()> {
483 let table_path = ListingTableUrl::parse(location)?;
485
486 let scheme = table_path.scheme();
488
489 let url = table_path.as_ref();
491
492 ctx.register_table_options_extension_from_scheme(scheme);
494
495 let mut table_options = ctx.session_state().default_table_options();
497 if let Some(format) = format {
498 table_options.set_config_format(format);
499 }
500 table_options.alter_with_string_hash_map(options)?;
501
502 let store = get_object_store(
504 &ctx.session_state(),
505 scheme,
506 url,
507 &table_options,
508 resolve_region,
509 )
510 .await?;
511
512 ctx.register_object_store(url, store);
514
515 Ok(())
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::common::plan_err;

    use datafusion::prelude::SessionContext;
    use datafusion_common::assert_contains;
    use url::Url;

    /// Plans `sql`, registers the object store referenced by the resulting
    /// `CREATE EXTERNAL TABLE` plan, and verifies the store can be resolved
    /// for `location`.
    async fn create_external_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
            let format = config_file_type_from_str(&cmd.file_type);
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.location,
                &cmd.options,
                format,
                false,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CreateExternalTable");
        }

        // Ensure the URL is resolvable by the object store registry.
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    /// Plans `sql`, registers the object store referenced by the resulting
    /// `COPY ... TO` plan, and verifies the store can be resolved for
    /// `location`.
    async fn copy_to_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Copy(cmd) = &plan {
            let format = config_file_type_from_str(&cmd.file_type.get_ext());
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.output_url,
                &cmd.options,
                format,
                false,
            )
            .await?;
        } else {
            // Fixed copy-paste error: this helper validates COPY plans.
            return plan_err!("LogicalPlan is not a CopyTo");
        }

        // Ensure the URL is resolvable by the object store registry.
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_http() -> Result<()> {
        let location = "http://example.com/file.parquet";
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }
    #[tokio::test]
    async fn copy_to_external_object_store_test() -> Result<()> {
        // Requires AWS credentials/endpoint in the environment; skipped otherwise.
        let aws_envs = vec![
            "AWS_ENDPOINT",
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_ALLOW_HTTP",
        ];
        for aws_env in aws_envs {
            if std::env::var(aws_env).is_err() {
                eprint!("aws envs not set, skipping s3 test");
                return Ok(());
            }
        }

        let locations = vec![
            "s3://bucket/path/file.parquet",
            "oss://bucket/path/file.parquet",
            "cos://bucket/path/file.parquet",
            "gcs://bucket/path/file.parquet",
        ];
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let dialect = &task_ctx.session_config().options().sql_parser.dialect;
        let dialect = dialect_from_str(dialect).ok_or_else(|| {
            plan_datafusion_err!(
                "Unsupported SQL dialect: {dialect}. Available dialects: \
        Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
        MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
            )
        })?;
        for location in locations {
            let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;");
            let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
            for statement in statements {
                let mut plan = create_plan(&ctx, statement, false).await?;
                if let LogicalPlan::Copy(copy_to) = &mut plan {
                    assert_eq!(copy_to.output_url, location);
                    assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string());
                    // Fixed mojibake: `&copy_to` had been corrupted to `©_to`.
                    ctx.runtime_env()
                        .object_store_registry
                        .get_store(&Url::parse(&copy_to.output_url).unwrap())?;
                } else {
                    return plan_err!("LogicalPlan is not a CopyTo");
                }
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn copy_to_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let location = "s3://bucket/path/file.parquet";

        let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET
            OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')");
        copy_to_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let region = "fake_us-east-2";
        let session_token = "fake_session_token";
        let location = "s3://bucket/path/file.parquet";

        // Credentials only.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        // Credentials plus explicit region and session token.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.region' '{region}', 'aws.session_token' '{session_token}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_oss() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "oss://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_cos() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "cos://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_gcs() -> Result<()> {
        let service_account_path = "fake_service_account_path";
        let service_account_key = "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}";
        let application_credentials_path = "fake_application_credentials_path";
        let location = "gcs://bucket/path/file.parquet";

        // A nonexistent service-account file surfaces as an OS "not found" error.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.service_account_path' '{service_account_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert_contains!(err.to_string(), "os error 2");

        // A fake inline key fails PEM parsing.
        let sql = format!(
            "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'"
        );
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert_contains!(err.to_string(), "Error reading pem file: no items found");

        // A nonexistent application-credentials file also reports "not found".
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert_contains!(err.to_string(), "os error 2");

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_local_file() -> Result<()> {
        let location = "path/to/file.parquet";

        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_format_option() -> Result<()> {
        let location = "path/to/file.cvs";

        let sql = format!(
            "CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')"
        );
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }
}