1use std::collections::HashMap;
21use std::fs::File;
22use std::io::prelude::*;
23use std::io::BufReader;
24
25use crate::cli_context::CliSessionContext;
26use crate::helper::split_from_semicolon;
27use crate::print_format::PrintFormat;
28use crate::{
29 command::{Command, OutputFormat},
30 helper::{unescape_input, CliHelper},
31 iceberg::transform_iceberg_input,
32 object_storage::get_object_store,
33 print_options::{MaxRows, PrintOptions},
34};
35
36use datafusion::common::instant::Instant;
37use datafusion::common::plan_datafusion_err;
38use datafusion::config::ConfigFileType;
39use datafusion::datasource::listing::ListingTableUrl;
40use datafusion::error::{DataFusionError, Result};
41use datafusion::logical_expr::{DdlStatement, LogicalPlan};
42use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties};
43use datafusion::sql::parser::{DFParser, Statement};
44use datafusion::sql::sqlparser::dialect::dialect_from_str;
45
46use datafusion::sql::sqlparser;
47use rustyline::error::ReadlineError;
48use rustyline::Editor;
49use tokio::signal;
50
51pub async fn exec_from_commands(
53 ctx: &dyn CliSessionContext,
54 commands: Vec<String>,
55 print_options: &PrintOptions,
56) -> Result<()> {
57 for sql in commands {
58 exec_and_print(ctx, print_options, sql).await?;
59 }
60
61 Ok(())
62}
63
64pub async fn exec_from_lines(
66 ctx: &dyn CliSessionContext,
67 reader: &mut BufReader<File>,
68 print_options: &PrintOptions,
69) -> Result<()> {
70 let mut query = "".to_owned();
71
72 for line in reader.lines() {
73 match line {
74 Ok(line) if line.starts_with("#!") => {
75 continue;
76 }
77 Ok(line) if line.starts_with("--") => {
78 continue;
79 }
80 Ok(line) => {
81 let line = line.trim_end();
82 query.push_str(line);
83 if line.ends_with(';') {
84 match exec_and_print(ctx, print_options, query).await {
85 Ok(_) => {}
86 Err(err) => eprintln!("{err}"),
87 }
88 query = "".to_string();
89 } else {
90 query.push('\n');
91 }
92 }
93 _ => {
94 break;
95 }
96 }
97 }
98
99 if query.contains(|c| c != '\n') {
102 exec_and_print(ctx, print_options, query).await?;
103 }
104
105 Ok(())
106}
107
108pub async fn exec_from_files(
109 ctx: &dyn CliSessionContext,
110 files: Vec<String>,
111 print_options: &PrintOptions,
112) -> Result<()> {
113 let files = files
114 .into_iter()
115 .map(|file_path| File::open(file_path).unwrap())
116 .collect::<Vec<_>>();
117
118 for file in files {
119 let mut reader = BufReader::new(file);
120 exec_from_lines(ctx, &mut reader, print_options).await?;
121 }
122
123 Ok(())
124}
125
126pub async fn exec_from_repl(
128 ctx: &dyn CliSessionContext,
129 print_options: &mut PrintOptions,
130) -> rustyline::Result<()> {
131 let mut rl = Editor::new()?;
132 rl.set_helper(Some(CliHelper::new(
133 &ctx.task_ctx().session_config().options().sql_parser.dialect,
134 print_options.color,
135 )));
136 rl.load_history(".history").ok();
137
138 loop {
139 match rl.readline("> ") {
140 Ok(line) if line.starts_with('\\') => {
141 rl.add_history_entry(line.trim_end())?;
142 let command = line.split_whitespace().collect::<Vec<_>>().join(" ");
143 if let Ok(cmd) = &command[1..].parse::<Command>() {
144 match cmd {
145 Command::Quit => break,
146 Command::OutputFormat(subcommand) => {
147 if let Some(subcommand) = subcommand {
148 if let Ok(command) = subcommand.parse::<OutputFormat>() {
149 if let Err(e) = command.execute(print_options).await {
150 eprintln!("{e}")
151 }
152 } else {
153 eprintln!(
154 "'\\{}' is not a valid command",
155 &line[1..]
156 );
157 }
158 } else {
159 println!("Output format is {:?}.", print_options.format);
160 }
161 }
162 _ => {
163 if let Err(e) = cmd.execute(ctx, print_options).await {
164 eprintln!("{e}")
165 }
166 }
167 }
168 } else {
169 eprintln!("'\\{}' is not a valid command", &line[1..]);
170 }
171 }
172 Ok(line) => {
173 let lines = split_from_semicolon(line);
174 for line in lines {
175 rl.add_history_entry(line.trim_end())?;
176 tokio::select! {
177 res = exec_and_print(ctx, print_options, line) => match res {
178 Ok(_) => {}
179 Err(err) => eprintln!("{err}"),
180 },
181 _ = signal::ctrl_c() => {
182 println!("^C");
183 continue
184 },
185 }
186 rl.helper_mut().unwrap().set_dialect(
188 &ctx.task_ctx().session_config().options().sql_parser.dialect,
189 );
190 }
191 }
192 Err(ReadlineError::Interrupted) => {
193 println!("^C");
194 continue;
195 }
196 Err(ReadlineError::Eof) => {
197 println!("\\q");
198 break;
199 }
200 Err(err) => {
201 eprintln!("Unknown error happened {:?}", err);
202 break;
203 }
204 }
205 }
206
207 rl.save_history(".history")
208}
209
210pub(super) async fn exec_and_print(
211 ctx: &dyn CliSessionContext,
212 print_options: &PrintOptions,
213 sql: String,
214) -> Result<()> {
215 let now = Instant::now();
216 let sql = unescape_input(&sql)?;
217 let task_ctx = ctx.task_ctx();
218 let dialect = &task_ctx.session_config().options().sql_parser.dialect;
219 let dialect = dialect_from_str(dialect).ok_or_else(|| {
220 plan_datafusion_err!(
221 "Unsupported SQL dialect: {dialect}. Available dialects: \
222 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
223 MsSQL, ClickHouse, BigQuery, Ansi."
224 )
225 })?;
226
227 let statements = DFParser::parse_sql_with_dialect(
228 &transform_iceberg_input(&sql),
229 dialect.as_ref(),
230 )?;
231 for statement in statements {
232 let adjusted =
233 AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement);
234
235 let plan = create_plan(ctx, statement).await?;
236 let adjusted = adjusted.with_plan(&plan);
237
238 let df = ctx.execute_logical_plan(plan).await?;
239 let physical_plan = df.create_physical_plan().await?;
240
241 if physical_plan.execution_mode().is_unbounded() {
242 let stream = execute_stream(physical_plan, task_ctx.clone())?;
243 print_options.print_stream(stream, now).await?;
244 } else {
245 let schema = physical_plan.schema();
246 let results = collect(physical_plan, task_ctx.clone()).await?;
247 adjusted.into_inner().print_batches(schema, &results, now)?;
248 }
249 }
250
251 Ok(())
252}
253
254#[derive(Debug)]
256struct AdjustedPrintOptions {
257 inner: PrintOptions,
258}
259
260impl AdjustedPrintOptions {
261 fn new(inner: PrintOptions) -> Self {
262 Self { inner }
263 }
264 fn with_statement(mut self, statement: &Statement) -> Self {
266 if let Statement::Statement(sql_stmt) = statement {
267 if let sqlparser::ast::Statement::ShowVariable { .. } = sql_stmt.as_ref() {
269 self.inner.maxrows = MaxRows::Unlimited
270 }
271 }
272 self
273 }
274
275 fn with_plan(mut self, plan: &LogicalPlan) -> Self {
277 if matches!(
280 plan,
281 LogicalPlan::Explain(_)
282 | LogicalPlan::DescribeTable(_)
283 | LogicalPlan::Analyze(_)
284 ) {
285 self.inner.maxrows = MaxRows::Unlimited;
286 }
287 self
288 }
289
290 fn into_inner(mut self) -> PrintOptions {
292 if self.inner.format == PrintFormat::Automatic {
293 self.inner.format = PrintFormat::Table;
294 }
295
296 self.inner
297 }
298}
299
300fn config_file_type_from_str(ext: &str) -> Option<ConfigFileType> {
301 match ext.to_lowercase().as_str() {
302 "csv" => Some(ConfigFileType::CSV),
303 "json" => Some(ConfigFileType::JSON),
304 "parquet" => Some(ConfigFileType::PARQUET),
305 _ => None,
306 }
307}
308
309async fn create_plan(
310 ctx: &dyn CliSessionContext,
311 statement: Statement,
312) -> Result<LogicalPlan, DataFusionError> {
313 let mut plan = ctx.session_state().statement_to_plan(statement).await?;
314
315 if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
319 if !(cmd.file_type.to_lowercase() == "iceberg") {
320 let format = config_file_type_from_str(&cmd.file_type);
322 register_object_store_and_config_extensions(
323 ctx,
324 &cmd.location,
325 &cmd.options,
326 format,
327 )
328 .await?;
329 }
330 }
331
332 if let LogicalPlan::Copy(copy_to) = &mut plan {
333 let format = config_file_type_from_str(©_to.file_type.get_ext());
334
335 register_object_store_and_config_extensions(
336 ctx,
337 ©_to.output_url,
338 ©_to.options,
339 format,
340 )
341 .await?;
342 }
343 Ok(plan)
344}
345
346pub(crate) async fn register_object_store_and_config_extensions(
374 ctx: &dyn CliSessionContext,
375 location: &String,
376 options: &HashMap<String, String>,
377 format: Option<ConfigFileType>,
378) -> Result<()> {
379 let table_path = ListingTableUrl::parse(location)?;
381
382 let scheme = table_path.scheme();
384
385 let url = table_path.as_ref();
387
388 ctx.register_table_options_extension_from_scheme(scheme);
390
391 let mut table_options = ctx.session_state().default_table_options();
393 if let Some(format) = format {
394 table_options.set_config_format(format);
395 }
396 table_options.alter_with_string_hash_map(options)?;
397
398 let store =
400 get_object_store(&ctx.session_state(), scheme, url, &table_options).await?;
401
402 ctx.register_object_store(url, store);
404
405 Ok(())
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::common::plan_err;

    use datafusion::prelude::SessionContext;
    use url::Url;

    /// Plans `sql` (expected to be a `CREATE EXTERNAL TABLE`), registers its
    /// object store, and checks that a store can be resolved for `location`.
    async fn create_external_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan {
            let format = config_file_type_from_str(&cmd.file_type);
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.location,
                &cmd.options,
                format,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CreateExternalTable");
        }

        // Ensure the URL is supported by a registered object store.
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    /// Plans `sql` (expected to be a `COPY ... TO`), registers its object
    /// store, and checks that a store can be resolved for `location`.
    async fn copy_to_table_test(location: &str, sql: &str) -> Result<()> {
        let ctx = SessionContext::new();
        let plan = ctx.state().create_logical_plan(sql).await?;

        if let LogicalPlan::Copy(cmd) = &plan {
            let format = config_file_type_from_str(&cmd.file_type.get_ext());
            register_object_store_and_config_extensions(
                &ctx,
                &cmd.output_url,
                &cmd.options,
                format,
            )
            .await?;
        } else {
            return plan_err!("LogicalPlan is not a CopyTo");
        }

        // Ensure the URL is supported by a registered object store.
        ctx.runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_http() -> Result<()> {
        let location = "http://example.com/file.parquet";
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn copy_to_external_object_store_test() -> Result<()> {
        let locations = vec![
            "s3://bucket/path/file.parquet",
            "oss://bucket/path/file.parquet",
            "cos://bucket/path/file.parquet",
            "gcs://bucket/path/file.parquet",
        ];
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let dialect = &task_ctx.session_config().options().sql_parser.dialect;
        let dialect = dialect_from_str(dialect).ok_or_else(|| {
            plan_datafusion_err!(
                "Unsupported SQL dialect: {dialect}. Available dialects: \
                 Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
                 MsSQL, ClickHouse, BigQuery, Ansi."
            )
        })?;
        for location in locations {
            let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;");
            let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
            for statement in statements {
                // `create_plan` registers the object store as a side effect.
                let plan = create_plan(&ctx, statement).await?;
                if let LogicalPlan::Copy(copy_to) = &plan {
                    assert_eq!(copy_to.output_url, location);
                    assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string());
                    ctx.runtime_env()
                        .object_store_registry
                        .get_store(&Url::parse(&copy_to.output_url).unwrap())?;
                } else {
                    return plan_err!("LogicalPlan is not a CopyTo");
                }
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn copy_to_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let location = "s3://bucket/path/file.parquet";

        // No region is supplied here.
        let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET
            OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')");
        copy_to_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_s3() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let region = "fake_us-east-2";
        let session_token = "fake_session_token";
        let location = "s3://bucket/path/file.parquet";

        // No region is supplied here.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        // Region and session token supplied explicitly.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.region' '{region}', 'aws.session_token' '{session_token}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_oss() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "oss://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_cos() -> Result<()> {
        let access_key_id = "fake_access_key_id";
        let secret_access_key = "fake_secret_access_key";
        let endpoint = "fake_endpoint";
        let location = "cos://bucket/path/file.parquet";

        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'");
        create_external_table_test(location, &sql).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_object_store_table_gcs() -> Result<()> {
        let service_account_path = "fake_service_account_path";
        let service_account_key =
            "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}";
        let application_credentials_path = "fake_application_credentials_path";
        let location = "gcs://bucket/path/file.parquet";

        // The service account file does not exist, so registration must fail.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.service_account_path' '{service_account_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("os error 2"));

        // The inline key is not a valid PEM key, so registration must fail.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("No RSA key found in pem file"), "{err}");

        // The credentials file does not exist, so registration must fail.
        let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET
            OPTIONS('gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'");
        let err = create_external_table_test(location, &sql)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("os error 2"));

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_local_file() -> Result<()> {
        let location = "path/to/file.parquet";

        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }

    #[tokio::test]
    async fn create_external_table_format_option() -> Result<()> {
        let location = "path/to/file.cvs";

        // `format.*` options must be accepted alongside the declared file type.
        let sql =
            format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')");
        create_external_table_test(location, &sql).await.unwrap();

        Ok(())
    }
}