use crate::cli_context::CliSessionContext;
use crate::helper::split_from_semicolon;
use crate::print_format::PrintFormat;
use crate::{
    command::{Command, OutputFormat},
    helper::CliHelper,
    object_storage::get_object_store,
    print_options::{MaxRows, PrintOptions},
};
use datafusion::common::instant::Instant;
use datafusion::common::{plan_datafusion_err, plan_err};
use datafusion::config::ConfigFileType;
use datafusion::datasource::listing::ListingTableUrl;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::memory_pool::MemoryConsumer;
use datafusion::logical_expr::{DdlStatement, LogicalPlan};
use datafusion::physical_plan::execution_plan::EmissionType;
use datafusion::physical_plan::spill::get_record_batch_memory_size;
use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties};
use datafusion::sql::parser::{DFParser, Statement};
use datafusion::sql::sqlparser;
use datafusion::sql::sqlparser::dialect::dialect_from_str;
use futures::StreamExt;
use log::warn;
use object_store::Error::Generic;
use rustyline::error::ReadlineError;
use rustyline::Editor;
use std::collections::HashMap;
use std::fs::File;
use std::io::prelude::*;
use std::io::BufReader;
use tokio::signal;

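/// Runs the given SQL commands one after another against the context, printing
/// each result with the provided print options.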
pub async fn exec_from_commands(
    ctx: &dyn CliSessionContext,
    commands: Vec<String>,
    print_options: &PrintOptions,
) -> Result<()> {
    for sql in commands {
        exec_and_print(ctx, print_options, sql).await?;
    }

    Ok(())
}

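/// Runs SQL statements read line by line from `reader` against the context.
/// Lines starting with `#!` or `--` are skipped, a statement is executed once a
/// line ends with `;`, and execution errors are printed without aborting the
/// remaining input.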
pub async fn exec_from_lines(
    ctx: &dyn CliSessionContext,
    reader: &mut BufReader<File>,
    print_options: &PrintOptions,
) -> Result<()> {
    let mut query = "".to_owned();

    for line in reader.lines() {
        match line {
            Ok(line) if line.starts_with("#!") => {
                continue;
            }
            Ok(line) if line.starts_with("--") => {
                continue;
            }
            Ok(line) => {
                let line = line.trim_end();
                query.push_str(line);
                if line.ends_with(';') {
                    match exec_and_print(ctx, print_options, query).await {
                        Ok(_) => {}
                        Err(err) => eprintln!("{err}"),
                    }
                    query = "".to_string();
                } else {
                    query.push('\n');
                }
            }
            _ => {
                break;
            }
        }
    }

    // Run any leftover query that was not terminated with a ';', ignoring input
    // that consists only of newlines.
    if query.contains(|c| c != '\n') {
        exec_and_print(ctx, print_options, query).await?;
    }

    Ok(())
}

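/// Runs the SQL statements contained in each of the given files against the context.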
pub async fn exec_from_files(
    ctx: &dyn CliSessionContext,
    files: Vec<String>,
    print_options: &PrintOptions,
) -> Result<()> {
    let files = files
        .into_iter()
        .map(|file_path| File::open(file_path).unwrap())
        .collect::<Vec<_>>();

    for file in files {
        let mut reader = BufReader::new(file);
        exec_from_lines(ctx, &mut reader, print_options).await?;
    }

    Ok(())
}

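/// Runs an interactive read-eval-print loop against the context, handling
/// `\`-prefixed CLI commands, statements separated by `;`, and Ctrl-C
/// interruption of in-flight queries. History is loaded from and saved to
/// `.history`.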
pub async fn exec_from_repl(
    ctx: &dyn CliSessionContext,
    print_options: &mut PrintOptions,
) -> rustyline::Result<()> {
    let mut rl = Editor::new()?;
    rl.set_helper(Some(CliHelper::new(
        &ctx.task_ctx().session_config().options().sql_parser.dialect,
        print_options.color,
    )));
    rl.load_history(".history").ok();

    loop {
        match rl.readline("> ") {
            Ok(line) if line.starts_with('\\') => {
                rl.add_history_entry(line.trim_end())?;
                let command = line.split_whitespace().collect::<Vec<_>>().join(" ");
                if let Ok(cmd) = &command[1..].parse::<Command>() {
                    match cmd {
                        Command::Quit => break,
                        Command::OutputFormat(subcommand) => {
                            if let Some(subcommand) = subcommand {
                                if let Ok(command) = subcommand.parse::<OutputFormat>() {
                                    if let Err(e) = command.execute(print_options).await {
                                        eprintln!("{e}")
                                    }
                                } else {
                                    eprintln!(
                                        "'\\{}' is not a valid command",
                                        &line[1..]
                                    );
                                }
                            } else {
                                println!("Output format is {:?}.", print_options.format);
                            }
                        }
                        _ => {
                            if let Err(e) = cmd.execute(ctx, print_options).await {
                                eprintln!("{e}")
                            }
                        }
                    }
                } else {
                    eprintln!("'\\{}' is not a valid command", &line[1..]);
                }
            }
            Ok(line) => {
                let lines = split_from_semicolon(&line);
                for line in lines {
                    rl.add_history_entry(line.trim_end())?;
                    tokio::select! {
                        res = exec_and_print(ctx, print_options, line) => match res {
                            Ok(_) => {}
                            Err(err) => eprintln!("{err}"),
                        },
                        _ = signal::ctrl_c() => {
                            println!("^C");
                            continue
                        },
                    }
                    // The dialect may have been changed by the executed statement
                    // (e.g. via SET), so refresh the completion helper.
                    rl.helper_mut().unwrap().set_dialect(
                        &ctx.task_ctx().session_config().options().sql_parser.dialect,
                    );
                }
            }
            Err(ReadlineError::Interrupted) => {
                println!("^C");
                continue;
            }
            Err(ReadlineError::Eof) => {
                println!("\\q");
                break;
            }
            Err(err) => {
                eprintln!("Unknown error happened {err:?}");
                break;
            }
        }
    }

    rl.save_history(".history")
}

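/// Parses `sql` with the configured dialect and executes each resulting
/// statement, printing its results with the given print options.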
pub(super) async fn exec_and_print(
    ctx: &dyn CliSessionContext,
    print_options: &PrintOptions,
    sql: String,
) -> Result<()> {
    let task_ctx = ctx.task_ctx();
    let options = task_ctx.session_config().options();
    let dialect = &options.sql_parser.dialect;
    let dialect = dialect_from_str(dialect).ok_or_else(|| {
        plan_datafusion_err!(
            "Unsupported SQL dialect: {dialect}. Available dialects: \
             Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
             MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
        )
    })?;

    let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
    for statement in statements {
        StatementExecutor::new(statement)
            .execute(ctx, print_options)
            .await?;
    }

    Ok(())
}

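/// Executes a single parsed [`Statement`] against the session context and
/// prints the result according to the (possibly adjusted) print options.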
struct StatementExecutor {
    statement: Statement,
    /// Copy of the statement, kept only for `CREATE EXTERNAL TABLE`, so the plan
    /// can be rebuilt with S3 region auto-detection if the first attempt fails.
    statement_for_retry: Option<Statement>,
}

impl StatementExecutor {
    fn new(statement: Statement) -> Self {
        let statement_for_retry = matches!(statement, Statement::CreateExternalTable(_))
            .then(|| statement.clone());

        Self {
            statement,
            statement_for_retry,
        }
    }

    async fn execute(
        self,
        ctx: &dyn CliSessionContext,
        print_options: &PrintOptions,
    ) -> Result<()> {
        let now = Instant::now();
        let (df, adjusted) = self
            .create_and_execute_logical_plan(ctx, print_options)
            .await?;
        let physical_plan = df.create_physical_plan().await?;
        let task_ctx = ctx.task_ctx();
        let options = task_ctx.session_config().options();

        // Track the memory used to buffer query results against the memory pool.
        let mut reservation =
            MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool());

        if physical_plan.boundedness().is_unbounded() {
            if physical_plan.pipeline_behavior() == EmissionType::Final {
                return plan_err!(
                    "The given query can generate a valid result only once \
                     the source finishes, but the source is unbounded"
                );
            }
            // The source is unbounded, so print results as they arrive instead
            // of buffering them.
            let stream = execute_stream(physical_plan, task_ctx.clone())?;
            print_options
                .print_stream(stream, now, &options.format)
                .await?;
        } else {
            // Bounded input: collect the batches and print them at the end.
            let schema = physical_plan.schema();
            let mut stream = execute_stream(physical_plan, task_ctx.clone())?;
            let mut results = vec![];
            let mut row_count = 0_usize;
            let max_rows = match print_options.maxrows {
                MaxRows::Unlimited => usize::MAX,
                MaxRows::Limited(n) => n,
            };
            while let Some(batch) = stream.next().await {
                let batch = batch?;
                let curr_num_rows = batch.num_rows();
                // Only buffer batches until the row limit is reached; keep
                // counting rows afterwards so the total can still be reported.
                if row_count < max_rows + curr_num_rows {
                    // Account for the buffered batch against the memory pool.
                    reservation.try_grow(get_record_batch_memory_size(&batch))?;
                    results.push(batch);
                }
                row_count += curr_num_rows;
            }
            adjusted.into_inner().print_batches(
                schema,
                &results,
                now,
                row_count,
                &options.format,
            )?;
            reservation.free();
        }

        Ok(())
    }

    async fn create_and_execute_logical_plan(
        mut self,
        ctx: &dyn CliSessionContext,
        print_options: &PrintOptions,
    ) -> Result<(datafusion::dataframe::DataFrame, AdjustedPrintOptions)> {
        let adjusted = AdjustedPrintOptions::new(print_options.clone())
            .with_statement(&self.statement);

        let plan = create_plan(ctx, self.statement, false).await?;
        let adjusted = adjusted.with_plan(&plan);

        let df = match ctx.execute_logical_plan(plan).await {
            Ok(df) => Ok(df),
            // A generic S3 object store error most likely means the configured
            // region is wrong; rebuild the plan with region auto-detection and retry.
            Err(DataFusionError::ObjectStore(err))
                if matches!(err.as_ref(), Generic { store, source: _ } if "S3".eq_ignore_ascii_case(store))
                    && self.statement_for_retry.is_some() =>
            {
                warn!("S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration.");
                let plan =
                    create_plan(ctx, self.statement_for_retry.take().unwrap(), true)
                        .await?;
                ctx.execute_logical_plan(plan).await
            }
            Err(e) => Err(e),
        }?;

        Ok((df, adjusted))
    }
}

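/// Tracks adjustments to the print options based on the statement and logical
/// plan being executed, e.g. lifting the row limit for `SHOW`, `EXPLAIN`,
/// `DESCRIBE` and `ANALYZE`.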
#[derive(Debug)]
struct AdjustedPrintOptions {
    inner: PrintOptions,
}

impl AdjustedPrintOptions {
    fn new(inner: PrintOptions) -> Self {
        Self { inner }
    }

    fn with_statement(mut self, statement: &Statement) -> Self {
        // SHOW should always display all of its rows.
        if let Statement::Statement(sql_stmt) = statement {
            if let sqlparser::ast::Statement::ShowVariable { .. } = sql_stmt.as_ref() {
                self.inner.maxrows = MaxRows::Unlimited
            }
        }
        self
    }

    fn with_plan(mut self, plan: &LogicalPlan) -> Self {
        // EXPLAIN, DESCRIBE and ANALYZE should also display all rows.
        if matches!(
            plan,
            LogicalPlan::Explain(_)
                | LogicalPlan::DescribeTable(_)
                | LogicalPlan::Analyze(_)
        ) {
            self.inner.maxrows = MaxRows::Unlimited;
        }
        self
    }

    fn into_inner(mut self) -> PrintOptions {
        // If no explicit format was chosen, fall back to Table output.
        if self.inner.format == PrintFormat::Automatic {
            self.inner.format = PrintFormat::Table;
        }

        self.inner
    }
}

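/// Maps a file extension (case-insensitively) to the corresponding
/// [`ConfigFileType`], if it is one of the supported formats.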
fn config_file_type_from_str(ext: &str) -> Option<ConfigFileType> {
    match ext.to_lowercase().as_str() {
        "csv" => Some(ConfigFileType::CSV),
        "json" => Some(ConfigFileType::JSON),
        "parquet" => Some(ConfigFileType::PARQUET),
        _ => None,
    }
}

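/// Creates a logical plan for the statement. For `CREATE EXTERNAL TABLE` and
/// `COPY ... TO` the object store backing the referenced location is registered
/// (together with any format-specific table options) before the plan is returned.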
async fn create_plan(
    ctx: &dyn CliSessionContext,
    statement: Statement,
    resolve_region: bool,
) -> Result<LogicalPlan, DataFusionError> {
    let mut plan = ctx.session_state().statement_to_plan(statement).await?;

    // For CREATE EXTERNAL TABLE, register the object store backing the table
    // location (and any format-specific options) before the table is created.
    if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
        let format = config_file_type_from_str(&cmd.file_type);
        register_object_store_and_config_extensions(
            ctx,
            &cmd.location,
            &cmd.options,
            format,
            resolve_region,
        )
        .await?;
    }

    // Likewise, COPY ... TO needs the destination's object store registered.
    if let LogicalPlan::Copy(copy_to) = &mut plan {
        let format = config_file_type_from_str(&copy_to.file_type.get_ext());

        register_object_store_and_config_extensions(
            ctx,
            &copy_to.output_url,
            &copy_to.options,
            format,
            false,
        )
        .await?;
    }
    Ok(plan)
}

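/// Registers the object store required to access `location`, along with the
/// table options extension for its URL scheme, in the session context.
///
/// The provided `options` are applied on top of the session's default table
/// options, optionally constrained to the given config file `format`. When
/// `resolve_region` is true, building the store may auto-detect the correct S3
/// region.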
pub(crate) async fn register_object_store_and_config_extensions(
    ctx: &dyn CliSessionContext,
    location: &String,
    options: &HashMap<String, String>,
    format: Option<ConfigFileType>,
    resolve_region: bool,
) -> Result<()> {
    // Parse the location into a table URL.
    let table_path = ListingTableUrl::parse(location)?;

    // Extract the scheme (e.g. "s3", "gcs") from the parsed URL.
    let scheme = table_path.scheme();

    // Obtain a reference to the underlying URL.
    let url = table_path.as_ref();

    // Register the table options extension that matches the URL scheme.
    ctx.register_table_options_extension_from_scheme(scheme);

    // Clone and modify the default table options based on the provided options.
    let mut table_options = ctx.session_state().default_table_options();
    if let Some(format) = format {
        table_options.set_config_format(format);
    }
    table_options.alter_with_string_hash_map(options)?;

    // Build the appropriate object store for the scheme, URL, and options.
    let store = get_object_store(
        &ctx.session_state(),
        scheme,
        url,
        &table_options,
        resolve_region,
    )
    .await?;

    // Register the object store in the session context's runtime environment.
    ctx.register_object_store(url, store);

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::common::plan_err;

    use datafusion::prelude::SessionContext;
    use url::Url;

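    /// Plans `sql`, registers the object store for the external table location,
    /// and checks that the store can be resolved from the runtime environment.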
    async fn create_external_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
            let format = config_file_type_from_str(&cmd.file_type);
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.location,
                &cmd.options,
                format,
                false,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CreateExternalTable");
        }

        // Ensure the URL is resolvable by the registered object store.
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    async fn copy_to_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Copy(cmd) = &plan {
            let format = config_file_type_from_str(&cmd.file_type.get_ext());
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.output_url,
                &cmd.options,
                format,
                false,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CopyTo");
        }

        // Ensure the URL is resolvable by the registered object store.
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_http() -> Result<()> {
        let location = "http://example.com/file.parquet";
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn copy_to_external_object_store_test() -> Result<()> {
        let aws_envs = vec![
            "AWS_ENDPOINT",
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_ALLOW_HTTP",
        ];
        for aws_env in aws_envs {
            if std::env::var(aws_env).is_err() {
                eprintln!("AWS environment variables not set, skipping S3 test");
                return Ok(());
            }
        }

        let locations = vec![
            "s3://bucket/path/file.parquet",
            "oss://bucket/path/file.parquet",
            "cos://bucket/path/file.parquet",
            "gcs://bucket/path/file.parquet",
        ];
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let dialect = &task_ctx.session_config().options().sql_parser.dialect;
        let dialect = dialect_from_str(dialect).ok_or_else(|| {
            plan_datafusion_err!(
                "Unsupported SQL dialect: {dialect}. Available dialects: \
                 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
                 MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
            )
        })?;
        for location in locations {
            let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;");
            let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
            for statement in statements {
                // Create the plan and verify the object store for the output URL
                // was registered.
                let mut plan = create_plan(&ctx, statement, false).await?;
                if let LogicalPlan::Copy(copy_to) = &mut plan {
                    assert_eq!(copy_to.output_url, location);
                    assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string());
                    ctx.runtime_env()
                        .object_store_registry
                        .get_store(&Url::parse(&copy_to.output_url).unwrap())?;
                } else {
                    return plan_err!("LogicalPlan is not a CopyTo");
                }
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn copy_to_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let location = "s3://bucket/path/file.parquet";

        let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET
            OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')");
        copy_to_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let region = "fake_us-east-2";
        let session_token = "fake_session_token";
        let location = "s3://bucket/path/file.parquet";

        // Region and session token are optional.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        // All S3 options provided.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.region' '{region}', 'aws.session_token' '{session_token}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_oss() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "oss://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_cos() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "cos://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_gcs() -> Result<()> {
        let service_account_path = "fake_service_account_path";
        let service_account_key =
            "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}";
        let application_credentials_path = "fake_application_credentials_path";
        let location = "gcs://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.service_account_path' '{service_account_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("os error 2"));

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("No RSA key found in pem file"), "{err}");

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("os error 2"));

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_local_file() -> Result<()> {
        let location = "path/to/file.parquet";

        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_format_option() -> Result<()> {
        let location = "path/to/file.csv";

        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }
}