1use crate::cli_context::CliSessionContext;
21use crate::helper::split_from_semicolon;
22use crate::print_format::PrintFormat;
23use crate::{
24 command::{Command, OutputFormat},
25 helper::CliHelper,
26 object_storage::get_object_store,
27 print_options::{MaxRows, PrintOptions},
28};
29use futures::StreamExt;
30use std::collections::HashMap;
31use std::fs::File;
32use std::io::prelude::*;
33use std::io::BufReader;
34
35use datafusion::common::instant::Instant;
36use datafusion::common::{plan_datafusion_err, plan_err};
37use datafusion::config::ConfigFileType;
38use datafusion::datasource::listing::ListingTableUrl;
39use datafusion::error::{DataFusionError, Result};
40use datafusion::logical_expr::{DdlStatement, LogicalPlan};
41use datafusion::physical_plan::execution_plan::EmissionType;
42use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties};
43use datafusion::sql::parser::{DFParser, Statement};
44use datafusion::sql::sqlparser::dialect::dialect_from_str;
45
46use datafusion::execution::memory_pool::MemoryConsumer;
47use datafusion::physical_plan::spill::get_record_batch_memory_size;
48use datafusion::sql::sqlparser;
49use rustyline::error::ReadlineError;
50use rustyline::Editor;
51use tokio::signal;
52
53pub async fn exec_from_commands(
55 ctx: &dyn CliSessionContext,
56 commands: Vec<String>,
57 print_options: &PrintOptions,
58) -> Result<()> {
59 for sql in commands {
60 exec_and_print(ctx, print_options, sql).await?;
61 }
62
63 Ok(())
64}
65
66pub async fn exec_from_lines(
68 ctx: &dyn CliSessionContext,
69 reader: &mut BufReader<File>,
70 print_options: &PrintOptions,
71) -> Result<()> {
72 let mut query = "".to_owned();
73
74 for line in reader.lines() {
75 match line {
76 Ok(line) if line.starts_with("#!") => {
77 continue;
78 }
79 Ok(line) if line.starts_with("--") => {
80 continue;
81 }
82 Ok(line) => {
83 let line = line.trim_end();
84 query.push_str(line);
85 if line.ends_with(';') {
86 match exec_and_print(ctx, print_options, query).await {
87 Ok(_) => {}
88 Err(err) => eprintln!("{err}"),
89 }
90 query = "".to_string();
91 } else {
92 query.push('\n');
93 }
94 }
95 _ => {
96 break;
97 }
98 }
99 }
100
101 if query.contains(|c| c != '\n') {
104 exec_and_print(ctx, print_options, query).await?;
105 }
106
107 Ok(())
108}
109
110pub async fn exec_from_files(
111 ctx: &dyn CliSessionContext,
112 files: Vec<String>,
113 print_options: &PrintOptions,
114) -> Result<()> {
115 let files = files
116 .into_iter()
117 .map(|file_path| File::open(file_path).unwrap())
118 .collect::<Vec<_>>();
119
120 for file in files {
121 let mut reader = BufReader::new(file);
122 exec_from_lines(ctx, &mut reader, print_options).await?;
123 }
124
125 Ok(())
126}
127
/// Run an interactive read-eval-print loop against `ctx`.
///
/// At start-up, command history is loaded from `.history` in the current working
/// directory; failure to load it is ignored. History is written back there when the
/// loop ends.
///
/// Input handling:
/// * A line starting with `\` is parsed as a CLI [`Command`]:
///   - `Quit` ends the loop.
///   - `OutputFormat` with an argument updates `print_options`; without one, it prints
///     the current format.
///   - Any other command runs via its `execute` method.
/// * Any other line is split into pieces by `split_from_semicolon`. Each piece is added
///   to history and executed with `exec_and_print`. Execution errors are printed to
///   stderr and do not end the loop.
/// * `ReadlineError::Interrupted` prints `^C` and re-prompts.
/// * `ReadlineError::Eof` prints `\q` and exits.
/// * Any other readline error is printed and ends the loop.
///
/// # Errors
/// Returns an error in any of these cases:
/// * the editor cannot be created;
/// * a history entry cannot be added;
/// * saving the history on exit fails.
pub async fn exec_from_repl(
    ctx: &dyn CliSessionContext,
    print_options: &mut PrintOptions,
) -> rustyline::Result<()> {
    let mut rl = Editor::new()?;
    rl.set_helper(Some(CliHelper::new(
        &ctx.task_ctx().session_config().options().sql_parser.dialect,
        print_options.color,
    )));
    // Best effort: history from a previous session is optional.
    rl.load_history(".history").ok();

    loop {
        match rl.readline("> ") {
            Ok(line) if line.starts_with('\\') => {
                rl.add_history_entry(line.trim_end())?;
                // Normalise whitespace so arguments are single-space separated.
                let command = line.split_whitespace().collect::<Vec<_>>().join(" ");
                // `command` starts with the single-byte `\`, so slicing at 1 is safe.
                if let Ok(cmd) = &command[1..].parse::<Command>() {
                    match cmd {
                        Command::Quit => break,
                        Command::OutputFormat(subcommand) => {
                            if let Some(subcommand) = subcommand {
                                if let Ok(command) = subcommand.parse::<OutputFormat>() {
                                    if let Err(e) = command.execute(print_options).await {
                                        eprintln!("{e}")
                                    }
                                } else {
                                    eprintln!(
                                        "'\\{}' is not a valid command",
                                        &line[1..]
                                    );
                                }
                            } else {
                                println!("Output format is {:?}.", print_options.format);
                            }
                        }
                        _ => {
                            if let Err(e) = cmd.execute(ctx, print_options).await {
                                eprintln!("{e}")
                            }
                        }
                    }
                } else {
                    eprintln!("'\\{}' is not a valid command", &line[1..]);
                }
            }
            Ok(line) => {
                let lines = split_from_semicolon(&line);
                for line in lines {
                    rl.add_history_entry(line.trim_end())?;
                    // Race the statement against Ctrl-C. If Ctrl-C wins, the unfinished
                    // `exec_and_print` future is dropped, so a long query can be
                    // abandoned without leaving the REPL.
                    tokio::select! {
                        res = exec_and_print(ctx, print_options, line) => match res {
                            Ok(_) => {}
                            Err(err) => eprintln!("{err}"),
                        },
                        _ = signal::ctrl_c() => {
                            println!("^C");
                            // NOTE(review): this `continue` targets the inner `for`. The
                            // remaining `;`-separated pieces of the same input line still
                            // run, and the dialect refresh below is skipped. Confirm that
                            // this is intended.
                            continue
                        },
                    }
                    // Re-read the configured dialect after each statement. Presumably this
                    // lets a statement that changes `sql_parser.dialect` take effect in the
                    // helper. `unwrap` cannot fail: a helper was installed via `set_helper`
                    // above.
                    rl.helper_mut().unwrap().set_dialect(
                        &ctx.task_ctx().session_config().options().sql_parser.dialect,
                    );
                }
            }
            Err(ReadlineError::Interrupted) => {
                println!("^C");
                continue;
            }
            Err(ReadlineError::Eof) => {
                println!("\\q");
                break;
            }
            Err(err) => {
                eprintln!("Unknown error happened {:?}", err);
                break;
            }
        }
    }

    rl.save_history(".history")
}
211
212pub(super) async fn exec_and_print(
213 ctx: &dyn CliSessionContext,
214 print_options: &PrintOptions,
215 sql: String,
216) -> Result<()> {
217 let now = Instant::now();
218 let task_ctx = ctx.task_ctx();
219 let dialect = &task_ctx.session_config().options().sql_parser.dialect;
220 let dialect = dialect_from_str(dialect).ok_or_else(|| {
221 plan_datafusion_err!(
222 "Unsupported SQL dialect: {dialect}. Available dialects: \
223 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
224 MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
225 )
226 })?;
227
228 let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
229 for statement in statements {
230 let adjusted =
231 AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement);
232
233 let plan = create_plan(ctx, statement).await?;
234 let adjusted = adjusted.with_plan(&plan);
235
236 let df = ctx.execute_logical_plan(plan).await?;
237 let physical_plan = df.create_physical_plan().await?;
238
239 let mut reservation =
241 MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool());
242
243 if physical_plan.boundedness().is_unbounded() {
244 if physical_plan.pipeline_behavior() == EmissionType::Final {
245 return plan_err!(
246 "The given query can generate a valid result only once \
247 the source finishes, but the source is unbounded"
248 );
249 }
250 let stream = execute_stream(physical_plan, task_ctx.clone())?;
253 print_options.print_stream(stream, now).await?;
254 } else {
255 let schema = physical_plan.schema();
257 let mut stream = execute_stream(physical_plan, task_ctx.clone())?;
258 let mut results = vec![];
259 let mut row_count = 0_usize;
260 let max_rows = match print_options.maxrows {
261 MaxRows::Unlimited => usize::MAX,
262 MaxRows::Limited(n) => n,
263 };
264 while let Some(batch) = stream.next().await {
265 let batch = batch?;
266 let curr_num_rows = batch.num_rows();
267 if row_count < max_rows + curr_num_rows {
270 reservation.try_grow(get_record_batch_memory_size(&batch))?;
272 results.push(batch);
273 }
274 row_count += curr_num_rows;
275 }
276 adjusted
277 .into_inner()
278 .print_batches(schema, &results, now, row_count)?;
279 reservation.free();
280 }
281 }
282
283 Ok(())
284}
285
286#[derive(Debug)]
288struct AdjustedPrintOptions {
289 inner: PrintOptions,
290}
291
292impl AdjustedPrintOptions {
293 fn new(inner: PrintOptions) -> Self {
294 Self { inner }
295 }
296 fn with_statement(mut self, statement: &Statement) -> Self {
298 if let Statement::Statement(sql_stmt) = statement {
299 if let sqlparser::ast::Statement::ShowVariable { .. } = sql_stmt.as_ref() {
301 self.inner.maxrows = MaxRows::Unlimited
302 }
303 }
304 self
305 }
306
307 fn with_plan(mut self, plan: &LogicalPlan) -> Self {
309 if matches!(
312 plan,
313 LogicalPlan::Explain(_)
314 | LogicalPlan::DescribeTable(_)
315 | LogicalPlan::Analyze(_)
316 ) {
317 self.inner.maxrows = MaxRows::Unlimited;
318 }
319 self
320 }
321
322 fn into_inner(mut self) -> PrintOptions {
324 if self.inner.format == PrintFormat::Automatic {
325 self.inner.format = PrintFormat::Table;
326 }
327
328 self.inner
329 }
330}
331
332fn config_file_type_from_str(ext: &str) -> Option<ConfigFileType> {
333 match ext.to_lowercase().as_str() {
334 "csv" => Some(ConfigFileType::CSV),
335 "json" => Some(ConfigFileType::JSON),
336 "parquet" => Some(ConfigFileType::PARQUET),
337 _ => None,
338 }
339}
340
341async fn create_plan(
342 ctx: &dyn CliSessionContext,
343 statement: Statement,
344) -> Result<LogicalPlan, DataFusionError> {
345 let mut plan = ctx.session_state().statement_to_plan(statement).await?;
346
347 if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
351 let format = config_file_type_from_str(&cmd.file_type);
353 register_object_store_and_config_extensions(
354 ctx,
355 &cmd.location,
356 &cmd.options,
357 format,
358 )
359 .await?;
360 }
361
362 if let LogicalPlan::Copy(copy_to) = &mut plan {
363 let format = config_file_type_from_str(©_to.file_type.get_ext());
364
365 register_object_store_and_config_extensions(
366 ctx,
367 ©_to.output_url,
368 ©_to.options,
369 format,
370 )
371 .await?;
372 }
373 Ok(plan)
374}
375
376pub(crate) async fn register_object_store_and_config_extensions(
404 ctx: &dyn CliSessionContext,
405 location: &String,
406 options: &HashMap<String, String>,
407 format: Option<ConfigFileType>,
408) -> Result<()> {
409 let table_path = ListingTableUrl::parse(location)?;
411
412 let scheme = table_path.scheme();
414
415 let url = table_path.as_ref();
417
418 ctx.register_table_options_extension_from_scheme(scheme);
420
421 let mut table_options = ctx.session_state().default_table_options();
423 if let Some(format) = format {
424 table_options.set_config_format(format);
425 }
426 table_options.alter_with_string_hash_map(options)?;
427
428 let store =
430 get_object_store(&ctx.session_state(), scheme, url, &table_options).await?;
431
432 ctx.register_object_store(url, store);
434
435 Ok(())
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::common::plan_err;

    use datafusion::prelude::SessionContext;
    use url::Url;

    /// Plan `sql` (expected to be `CREATE EXTERNAL TABLE`) and register its object
    /// store. Then check that a store resolves for `location`.
    async fn create_external_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
            let format = config_file_type_from_str(&cmd.file_type);
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.location,
                &cmd.options,
                format,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CreateExternalTable");
        }

        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    /// Plan `sql` (expected to be `COPY ... TO`) and register its object store.
    /// Then check that a store resolves for `location`.
    async fn copy_to_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Copy(cmd) = &plan {
            let format = config_file_type_from_str(&cmd.file_type.get_ext());
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.output_url,
                &cmd.options,
                format,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CopyTo");
        }

        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_http() -> Result<()> {
        let location = "http://example.com/file.parquet";
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn copy_to_external_object_store_test() -> Result<()> {
        let locations = [
            "s3://bucket/path/file.parquet",
            "oss://bucket/path/file.parquet",
            "cos://bucket/path/file.parquet",
            "gcs://bucket/path/file.parquet",
        ];
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let dialect = &task_ctx.session_config().options().sql_parser.dialect;
        let dialect = dialect_from_str(dialect).ok_or_else(|| {
            plan_datafusion_err!(
                "Unsupported SQL dialect: {dialect}. Available dialects: \
                     Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
                     MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
            )
        })?;
        for location in locations {
            let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;");
            let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
            for statement in statements {
                let plan = create_plan(&ctx, statement).await?;
                if let LogicalPlan::Copy(copy_to) = &plan {
                    assert_eq!(copy_to.output_url, location);
                    assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string());
                    ctx.runtime_env()
                        .object_store_registry
                        .get_store(&Url::parse(&copy_to.output_url).unwrap())?;
                } else {
                    return plan_err!("LogicalPlan is not a CopyTo");
                }
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn copy_to_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let location = "s3://bucket/path/file.parquet";

        let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET
            OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')");
        copy_to_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let region = "fake_us-east-2";
        let session_token = "fake_session_token";
        let location = "s3://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.region' '{region}', 'aws.session_token' '{session_token}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_oss() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "oss://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_cos() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "cos://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_gcs() -> Result<()> {
        let service_account_path = "fake_service_account_path";
        let service_account_key =
            "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}";
        let application_credentials_path = "fake_application_credentials_path";
        let location = "gcs://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.service_account_path' '{service_account_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("os error 2"));

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("No RSA key found in pem file"), "{err}");

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("os error 2"));

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_local_file() -> Result<()> {
        let location = "path/to/file.parquet";

        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_format_option() -> Result<()> {
        let location = "path/to/file.cvs";

        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }
}