1use std::collections::HashMap;
20use std::sync::Arc;
21use std::vec;
22
23use arrow::datatypes::*;
24use datafusion_common::config::SqlParserOptions;
25use datafusion_common::error::add_possible_columns_to_diag;
26use datafusion_common::{
27 field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic,
28 SchemaError,
29};
30use sqlparser::ast::TimezoneInfo;
31use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
32use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
33use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias};
34
35use datafusion_common::TableReference;
36use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result};
37use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder};
38use datafusion_expr::utils::find_column_exprs;
39use datafusion_expr::{col, Expr};
40
41use crate::utils::make_decimal_type;
42pub use datafusion_expr::planner::ContextProvider;
43
/// Boolean switches carried by the SQL planner.
///
/// Start from [`ParserOptions::new`] (or `Default`, which is the same thing)
/// and flip individual switches with the chainable `with_*` methods.
#[derive(Debug, Clone, Copy)]
pub struct ParserOptions {
    /// Defaults to `false`. Not read in this part of the file; presumably
    /// controls whether float literals are planned as decimals — confirm at
    /// the use site.
    pub parse_float_as_decimal: bool,
    /// Defaults to `true`. Decides whether identifiers are passed through the
    /// identifier normalizer or kept exactly as written.
    pub enable_ident_normalization: bool,
    /// Defaults to `true`. When `false`, a `VARCHAR` type that carries a
    /// length is rejected with a planning error.
    pub support_varchar_with_length: bool,
    /// Defaults to `false`. Not read in this part of the file; presumably
    /// controls normalization of option values — confirm at the use site.
    pub enable_options_value_normalization: bool,
    /// Defaults to `false`. Not read in this part of the file; presumably
    /// controls whether source spans are recorded — confirm at the use site.
    pub collect_spans: bool,
}

impl ParserOptions {
    /// Returns the default option set: identifier normalization and
    /// `VARCHAR` lengths enabled, every other switch disabled.
    pub fn new() -> Self {
        ParserOptions {
            collect_spans: false,
            enable_options_value_normalization: false,
            support_varchar_with_length: true,
            enable_ident_normalization: true,
            parse_float_as_decimal: false,
        }
    }

    /// Returns a copy with `parse_float_as_decimal` set to `value`.
    pub fn with_parse_float_as_decimal(self, value: bool) -> Self {
        Self {
            parse_float_as_decimal: value,
            ..self
        }
    }

    /// Returns a copy with `enable_ident_normalization` set to `value`.
    pub fn with_enable_ident_normalization(self, value: bool) -> Self {
        Self {
            enable_ident_normalization: value,
            ..self
        }
    }

    /// Returns a copy with `support_varchar_with_length` set to `value`.
    pub fn with_support_varchar_with_length(self, value: bool) -> Self {
        Self {
            support_varchar_with_length: value,
            ..self
        }
    }

    /// Returns a copy with `enable_options_value_normalization` set to `value`.
    pub fn with_enable_options_value_normalization(self, value: bool) -> Self {
        Self {
            enable_options_value_normalization: value,
            ..self
        }
    }

    /// Returns a copy with `collect_spans` set to `value`.
    pub fn with_collect_spans(self, value: bool) -> Self {
        Self {
            collect_spans: value,
            ..self
        }
    }
}

impl Default for ParserOptions {
    /// Same as [`ParserOptions::new`].
    fn default() -> Self {
        ParserOptions::new()
    }
}
132
133impl From<&SqlParserOptions> for ParserOptions {
134 fn from(options: &SqlParserOptions) -> Self {
135 Self {
136 parse_float_as_decimal: options.parse_float_as_decimal,
137 enable_ident_normalization: options.enable_ident_normalization,
138 support_varchar_with_length: options.support_varchar_with_length,
139 enable_options_value_normalization: options
140 .enable_options_value_normalization,
141 collect_spans: options.collect_spans,
142 }
143 }
144}
145
/// Converts SQL identifiers into the `String` names used by the planner,
/// optionally applying normalization.
#[derive(Debug)]
pub struct IdentNormalizer {
    /// `true`: identifiers go through `crate::utils::normalize_ident`;
    /// `false`: the identifier's `value` is used as written.
    normalize: bool,
}
151
152impl Default for IdentNormalizer {
153 fn default() -> Self {
154 Self { normalize: true }
155 }
156}
157
158impl IdentNormalizer {
159 pub fn new(normalize: bool) -> Self {
160 Self { normalize }
161 }
162
163 pub fn normalize(&self, ident: Ident) -> String {
164 if self.normalize {
165 crate::utils::normalize_ident(ident)
166 } else {
167 ident.value
168 }
169 }
170}
171
/// Mutable state threaded through SQL planning: parameter types, registered
/// CTEs and a few optional schemas that callers swap in and out.
#[derive(Debug, Clone)]
pub struct PlannerContext {
    /// Data types supplied through `with_prepare_param_data_types` and read
    /// back with `prepare_param_data_types()`.
    prepare_param_data_types: Arc<Vec<DataType>>,
    /// CTE name -> logical plan, maintained by `insert_cte` / `remove_cte`
    /// and queried by `contains_cte` / `get_cte`.
    ctes: HashMap<String, Arc<LogicalPlan>>,
    /// Schema installed by `set_outer_query_schema`.
    /// NOTE(review): presumably the enclosing query's schema while planning a
    /// subquery — confirm against callers.
    outer_query_schema: Option<DFSchemaRef>,
    /// Schema installed by `set_outer_from_schema` and grown by
    /// `extend_outer_from_schema`.
    /// NOTE(review): presumably the accumulated schema of preceding FROM
    /// items — confirm against callers.
    outer_from_schema: Option<DFSchemaRef>,
    /// Schema installed by `set_table_schema` and returned by `table_schema()`.
    /// NOTE(review): presumably the schema of the table being created —
    /// confirm against callers.
    create_table_schema: Option<DFSchemaRef>,
}
200
201impl Default for PlannerContext {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
207impl PlannerContext {
208 pub fn new() -> Self {
210 Self {
211 prepare_param_data_types: Arc::new(vec![]),
212 ctes: HashMap::new(),
213 outer_query_schema: None,
214 outer_from_schema: None,
215 create_table_schema: None,
216 }
217 }
218
219 pub fn with_prepare_param_data_types(
221 mut self,
222 prepare_param_data_types: Vec<DataType>,
223 ) -> Self {
224 self.prepare_param_data_types = prepare_param_data_types.into();
225 self
226 }
227
228 pub fn outer_query_schema(&self) -> Option<&DFSchema> {
230 self.outer_query_schema.as_ref().map(|s| s.as_ref())
231 }
232
233 pub fn set_outer_query_schema(
236 &mut self,
237 mut schema: Option<DFSchemaRef>,
238 ) -> Option<DFSchemaRef> {
239 std::mem::swap(&mut self.outer_query_schema, &mut schema);
240 schema
241 }
242
243 pub fn set_table_schema(
244 &mut self,
245 mut schema: Option<DFSchemaRef>,
246 ) -> Option<DFSchemaRef> {
247 std::mem::swap(&mut self.create_table_schema, &mut schema);
248 schema
249 }
250
251 pub fn table_schema(&self) -> Option<DFSchemaRef> {
252 self.create_table_schema.clone()
253 }
254
255 pub fn outer_from_schema(&self) -> Option<Arc<DFSchema>> {
257 self.outer_from_schema.clone()
258 }
259
260 pub fn set_outer_from_schema(
262 &mut self,
263 mut schema: Option<DFSchemaRef>,
264 ) -> Option<DFSchemaRef> {
265 std::mem::swap(&mut self.outer_from_schema, &mut schema);
266 schema
267 }
268
269 pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> {
271 match self.outer_from_schema.as_mut() {
272 Some(from_schema) => Arc::make_mut(from_schema).merge(schema),
273 None => self.outer_from_schema = Some(Arc::clone(schema)),
274 };
275 Ok(())
276 }
277
278 pub fn prepare_param_data_types(&self) -> &[DataType] {
280 &self.prepare_param_data_types
281 }
282
283 pub fn contains_cte(&self, cte_name: &str) -> bool {
286 self.ctes.contains_key(cte_name)
287 }
288
289 pub fn insert_cte(&mut self, cte_name: impl Into<String>, plan: LogicalPlan) {
292 let cte_name = cte_name.into();
293 self.ctes.insert(cte_name, Arc::new(plan));
294 }
295
296 pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> {
299 self.ctes.get(cte_name).map(|cte| cte.as_ref())
300 }
301
302 pub(super) fn remove_cte(&mut self, cte_name: &str) {
304 self.ctes.remove(cte_name);
305 }
306}
307
/// SQL query planner state: a borrowed [`ContextProvider`], the planner
/// options, and the identifier normalizer derived from those options.
pub struct SqlToRel<'a, S: ContextProvider> {
    /// Source of configuration (`options()`) and of the optional type planner
    /// (`get_type_planner()`) consulted while planning.
    pub(crate) context_provider: &'a S,
    /// Planner switches; see [`ParserOptions`].
    pub(crate) options: ParserOptions,
    /// Built from `options.enable_ident_normalization` in `new_with_options`.
    pub(crate) ident_normalizer: IdentNormalizer,
}
331
332impl<'a, S: ContextProvider> SqlToRel<'a, S> {
333 pub fn new(context_provider: &'a S) -> Self {
337 let parser_options = ParserOptions::from(&context_provider.options().sql_parser);
338 Self::new_with_options(context_provider, parser_options)
339 }
340
341 pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self {
346 let ident_normalize = options.enable_ident_normalization;
347
348 SqlToRel {
349 context_provider,
350 options,
351 ident_normalizer: IdentNormalizer::new(ident_normalize),
352 }
353 }
354
355 pub fn build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema> {
356 let mut fields = Vec::with_capacity(columns.len());
357
358 for column in columns {
359 let data_type = self.convert_data_type(&column.data_type)?;
360 let not_nullable = column
361 .options
362 .iter()
363 .any(|x| x.option == ColumnOption::NotNull);
364 fields.push(Field::new(
365 self.ident_normalizer.normalize(column.name),
366 data_type,
367 !not_nullable,
368 ));
369 }
370
371 Ok(Schema::new(fields))
372 }
373
374 pub(super) fn build_column_defaults(
376 &self,
377 columns: &Vec<SQLColumnDef>,
378 planner_context: &mut PlannerContext,
379 ) -> Result<Vec<(String, Expr)>> {
380 let mut column_defaults = vec![];
381 let empty_schema = DFSchema::empty();
383 let error_desc = |e: DataFusionError| match e {
384 DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _) => {
385 plan_datafusion_err!(
386 "Column reference is not allowed in the DEFAULT expression : {}",
387 e
388 )
389 }
390 _ => e,
391 };
392
393 for column in columns {
394 if let Some(default_sql_expr) =
395 column.options.iter().find_map(|o| match &o.option {
396 ColumnOption::Default(expr) => Some(expr),
397 _ => None,
398 })
399 {
400 let default_expr = self
401 .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context)
402 .map_err(error_desc)?;
403 column_defaults.push((
404 self.ident_normalizer.normalize(column.name.clone()),
405 default_expr,
406 ));
407 }
408 }
409 Ok(column_defaults)
410 }
411
412 pub(crate) fn apply_table_alias(
414 &self,
415 plan: LogicalPlan,
416 alias: TableAlias,
417 ) -> Result<LogicalPlan> {
418 let idents = alias.columns.into_iter().map(|c| c.name).collect();
419 let plan = self.apply_expr_alias(plan, idents)?;
420
421 LogicalPlanBuilder::from(plan)
422 .alias(TableReference::bare(
423 self.ident_normalizer.normalize(alias.name),
424 ))?
425 .build()
426 }
427
428 pub(crate) fn apply_expr_alias(
429 &self,
430 plan: LogicalPlan,
431 idents: Vec<Ident>,
432 ) -> Result<LogicalPlan> {
433 if idents.is_empty() {
434 Ok(plan)
435 } else if idents.len() != plan.schema().fields().len() {
436 plan_err!(
437 "Source table contains {} columns but only {} names given as column alias",
438 plan.schema().fields().len(),
439 idents.len()
440 )
441 } else {
442 let fields = plan.schema().fields().clone();
443 LogicalPlanBuilder::from(plan)
444 .project(fields.iter().zip(idents.into_iter()).map(|(field, ident)| {
445 col(field.name()).alias(self.ident_normalizer.normalize(ident))
446 }))?
447 .build()
448 }
449 }
450
451 pub(crate) fn validate_schema_satisfies_exprs(
453 &self,
454 schema: &DFSchema,
455 exprs: &[Expr],
456 ) -> Result<()> {
457 find_column_exprs(exprs)
458 .iter()
459 .try_for_each(|col| match col {
460 Expr::Column(col) => match &col.relation {
461 Some(r) => schema.field_with_qualified_name(r, &col.name).map(|_| ()),
462 None => {
463 if !schema.fields_with_unqualified_name(&col.name).is_empty() {
464 Ok(())
465 } else {
466 Err(field_not_found(
467 col.relation.clone(),
468 col.name.as_str(),
469 schema,
470 ))
471 }
472 }
473 }
474 .map_err(|err: DataFusionError| match &err {
475 DataFusionError::SchemaError(
476 SchemaError::FieldNotFound {
477 field,
478 valid_fields,
479 },
480 _,
481 ) => {
482 let mut diagnostic = if let Some(relation) = &col.relation {
483 Diagnostic::new_error(
484 format!(
485 "column '{}' not found in '{}'",
486 &col.name, relation
487 ),
488 col.spans().first(),
489 )
490 } else {
491 Diagnostic::new_error(
492 format!("column '{}' not found", &col.name),
493 col.spans().first(),
494 )
495 };
496 add_possible_columns_to_diag(
497 &mut diagnostic,
498 field,
499 valid_fields,
500 );
501 err.with_diagnostic(diagnostic)
502 }
503 _ => err,
504 }),
505 _ => internal_err!("Not a column"),
506 })
507 }
508
509 pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
510 if let Some(type_planner) = self.context_provider.get_type_planner() {
512 if let Some(data_type) = type_planner.plan_type(sql_type)? {
513 return Ok(data_type);
514 }
515 }
516
517 match sql_type {
519 SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => {
520 let inner_data_type = self.convert_data_type(inner_sql_type)?;
522 Ok(DataType::new_list(inner_data_type, true))
523 }
524 SQLDataType::Array(ArrayElemTypeDef::SquareBracket(
525 inner_sql_type,
526 maybe_array_size,
527 )) => {
528 let inner_data_type = self.convert_data_type(inner_sql_type)?;
529 if let Some(array_size) = maybe_array_size {
530 Ok(DataType::new_fixed_size_list(
531 inner_data_type,
532 *array_size as i32,
533 true,
534 ))
535 } else {
536 Ok(DataType::new_list(inner_data_type, true))
537 }
538 }
539 SQLDataType::Array(ArrayElemTypeDef::None) => {
540 not_impl_err!("Arrays with unspecified type is not supported")
541 }
542 other => self.convert_simple_data_type(other),
543 }
544 }
545
    /// Converts a non-array SQL data type into an Arrow [`DataType`].
    ///
    /// Supported families: booleans, signed/unsigned integers, floats,
    /// character types (all mapped to `Utf8`), timestamps, dates, times,
    /// decimals, `BYTEA`, intervals and structs. Every other variant listed
    /// in the final arm yields a not-implemented error.
    ///
    /// # Errors
    /// * Planning error for `VARCHAR(n)` when
    ///   `options.support_varchar_with_length` is `false`.
    /// * Not-implemented error for `DOUBLE` with precision/scale, for `TIME`
    ///   with a time zone, and for every variant in the catch-all arm.
    /// * Errors from `make_decimal_type` or from converting struct field types.
    fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
        match sql_type {
            SQLDataType::Boolean | SQLDataType::Bool => Ok(DataType::Boolean),
            SQLDataType::TinyInt(_) => Ok(DataType::Int8),
            SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(DataType::Int16),
            SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => Ok(DataType::Int32),
            SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(DataType::Int64),
            SQLDataType::UnsignedTinyInt(_) => Ok(DataType::UInt8),
            SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => Ok(DataType::UInt16),
            SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) | SQLDataType::UnsignedInt4(_) => {
                Ok(DataType::UInt32)
            }
            // The declared length is only validated against the option; it is
            // not carried into the resulting type.
            SQLDataType::Varchar(length) => {
                match (length, self.options.support_varchar_with_length) {
                    (Some(_), false) => plan_err!("does not support Varchar with length, please set `support_varchar_with_length` to be true"),
                    _ => Ok(DataType::Utf8),
                }
            }
            SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64),
            // FLOAT maps to Float32 regardless of its argument.
            SQLDataType::Float(_) => Ok(DataType::Float32),
            SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32),
            SQLDataType::Double(ExactNumberInfo::None) | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64),
            SQLDataType::Double(ExactNumberInfo::Precision(_)|ExactNumberInfo::PrecisionAndScale(_, _)) => {
                not_impl_err!("Unsupported SQL type (precision/scale not supported) {sql_type}")
            }
            SQLDataType::Char(_)
            | SQLDataType::Text
            | SQLDataType::String(_) => Ok(DataType::Utf8),
            // Only precisions 0, 3, 6, 9 (or none) are handled here; any other
            // precision falls through to the unsupported arm at the bottom.
            SQLDataType::Timestamp(precision, tz_info)
            if precision.is_none() || [0, 3, 6, 9].contains(&precision.unwrap()) => {
                let tz = if matches!(tz_info, TimezoneInfo::Tz)
                    || matches!(tz_info, TimezoneInfo::WithTimeZone)
                {
                    // Timestamp with time zone: take the zone from the
                    // execution configuration.
                    self.context_provider.options().execution.time_zone.clone()
                } else {
                    // Timestamp without time zone.
                    None
                };
                let precision = match precision {
                    Some(0) => TimeUnit::Second,
                    Some(3) => TimeUnit::Millisecond,
                    Some(6) => TimeUnit::Microsecond,
                    None | Some(9) => TimeUnit::Nanosecond,
                    // The arm guard above admits only the values matched here.
                    _ => unreachable!(),
                };
                Ok(DataType::Timestamp(precision, tz.map(Into::into)))
            }
            SQLDataType::Date => Ok(DataType::Date32),
            SQLDataType::Time(None, tz_info) => {
                if matches!(tz_info, TimezoneInfo::None)
                    || matches!(tz_info, TimezoneInfo::WithoutTimeZone)
                {
                    Ok(DataType::Time64(TimeUnit::Nanosecond))
                } else {
                    // TIME with any time-zone qualifier is rejected.
                    not_impl_err!(
                        "Unsupported SQL type {sql_type:?}"
                    )
                }
            }
            SQLDataType::Numeric(exact_number_info)
            | SQLDataType::Decimal(exact_number_info) => {
                // Missing precision/scale are passed on as `None`;
                // `make_decimal_type` decides what to do with them.
                let (precision, scale) = match *exact_number_info {
                    ExactNumberInfo::None => (None, None),
                    ExactNumberInfo::Precision(precision) => (Some(precision), None),
                    ExactNumberInfo::PrecisionAndScale(precision, scale) => {
                        (Some(precision), Some(scale))
                    }
                };
                make_decimal_type(precision, scale)
            }
            SQLDataType::Bytea => Ok(DataType::Binary),
            SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
            SQLDataType::Struct(fields, _) => {
                let fields = fields
                    .iter()
                    .enumerate()
                    .map(|(idx, field)| {
                        let data_type = self.convert_data_type(&field.field_type)?;
                        // Unnamed fields are named `c<position>` (zero-based).
                        let field_name = match &field.field_name{
                            Some(ident) => ident.clone(),
                            None => Ident::new(format!("c{idx}"))
                        };
                        // Struct fields are always created nullable.
                        Ok(Arc::new(Field::new(
                            self.ident_normalizer.normalize(field_name),
                            data_type,
                            true,
                        )))
                    })
                    .collect::<Result<Vec<_>>>()?;
                Ok(DataType::Struct(Fields::from(fields)))
            }
            // Variants listed explicitly (no wildcard) so that a variant added
            // to `SQLDataType` causes a compile error here rather than being
            // silently treated as unsupported.
            SQLDataType::Nvarchar(_)
            | SQLDataType::JSON
            | SQLDataType::Uuid
            | SQLDataType::Binary(_)
            | SQLDataType::Varbinary(_)
            | SQLDataType::Blob(_)
            | SQLDataType::Datetime(_)
            | SQLDataType::Regclass
            | SQLDataType::Custom(_, _)
            | SQLDataType::Array(_)
            | SQLDataType::Enum(_, _)
            | SQLDataType::Set(_)
            | SQLDataType::MediumInt(_)
            | SQLDataType::UnsignedMediumInt(_)
            | SQLDataType::Character(_)
            | SQLDataType::CharacterVarying(_)
            | SQLDataType::CharVarying(_)
            | SQLDataType::CharacterLargeObject(_)
            | SQLDataType::CharLargeObject(_)
            // Timestamps whose precision failed the guard above.
            | SQLDataType::Timestamp(_, _)
            // TIME with an explicit precision.
            | SQLDataType::Time(Some(_), _)
            | SQLDataType::Dec(_)
            | SQLDataType::BigNumeric(_)
            | SQLDataType::BigDecimal(_)
            | SQLDataType::Clob(_)
            | SQLDataType::Bytes(_)
            | SQLDataType::Int64
            | SQLDataType::Float64
            | SQLDataType::JSONB
            | SQLDataType::Unspecified
            | SQLDataType::Int16
            | SQLDataType::Int32
            | SQLDataType::Int128
            | SQLDataType::Int256
            | SQLDataType::UInt8
            | SQLDataType::UInt16
            | SQLDataType::UInt32
            | SQLDataType::UInt64
            | SQLDataType::UInt128
            | SQLDataType::UInt256
            | SQLDataType::Float32
            | SQLDataType::Date32
            | SQLDataType::Datetime64(_, _)
            | SQLDataType::FixedString(_)
            | SQLDataType::Map(_, _)
            | SQLDataType::Tuple(_)
            | SQLDataType::Nested(_)
            | SQLDataType::Union(_)
            | SQLDataType::Nullable(_)
            | SQLDataType::LowCardinality(_)
            | SQLDataType::Trigger
            | SQLDataType::TinyBlob
            | SQLDataType::MediumBlob
            | SQLDataType::LongBlob
            | SQLDataType::TinyText
            | SQLDataType::MediumText
            | SQLDataType::LongText
            | SQLDataType::Bit(_)
            | SQLDataType::BitVarying(_)
            | SQLDataType::AnyType
            => not_impl_err!(
                "Unsupported SQL type {sql_type:?}"
            ),
        }
    }
714
715 pub(crate) fn object_name_to_table_reference(
716 &self,
717 object_name: ObjectName,
718 ) -> Result<TableReference> {
719 object_name_to_table_reference(
720 object_name,
721 self.options.enable_ident_normalization,
722 )
723 }
724}
725
726pub fn object_name_to_table_reference(
737 object_name: ObjectName,
738 enable_normalization: bool,
739) -> Result<TableReference> {
740 let ObjectName(idents) = object_name;
742 idents_to_table_reference(idents, enable_normalization)
743}
744
/// Helper that hands out the parts of a compound identifier one at a time,
/// starting from the last part.
struct IdentTaker {
    /// Applied to every identifier returned by `take`.
    normalizer: IdentNormalizer,
    /// Remaining identifier parts, in source order; `take` pops from the end.
    idents: Vec<Ident>,
}
749
750impl IdentTaker {
753 fn new(idents: Vec<Ident>, enable_normalization: bool) -> Self {
754 Self {
755 normalizer: IdentNormalizer::new(enable_normalization),
756 idents,
757 }
758 }
759
760 fn take(&mut self) -> String {
761 let ident = self.idents.pop().expect("no more identifiers");
762 self.normalizer.normalize(ident)
763 }
764
765 fn len(&self) -> usize {
767 self.idents.len()
768 }
769}
770
771impl std::fmt::Display for IdentTaker {
773 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
774 let mut first = true;
775 for ident in self.idents.iter() {
776 if !first {
777 write!(f, ".")?;
778 }
779 write!(f, "{}", ident)?;
780 first = false;
781 }
782
783 Ok(())
784 }
785}
786
787pub(crate) fn idents_to_table_reference(
789 idents: Vec<Ident>,
790 enable_normalization: bool,
791) -> Result<TableReference> {
792 let mut taker = IdentTaker::new(idents, enable_normalization);
793
794 match taker.len() {
795 1 => {
796 let table = taker.take();
797 Ok(TableReference::bare(table))
798 }
799 2 => {
800 let table = taker.take();
801 let schema = taker.take();
802 Ok(TableReference::partial(schema, table))
803 }
804 3 => {
805 let table = taker.take();
806 let schema = taker.take();
807 let catalog = taker.take();
808 Ok(TableReference::full(catalog, schema, table))
809 }
810 _ => plan_err!(
811 "Unsupported compound identifier '{}'. Expected 1, 2 or 3 parts, got {}",
812 taker,
813 taker.len()
814 ),
815 }
816}
817
818pub fn object_name_to_qualifier(
821 sql_table_name: &ObjectName,
822 enable_normalization: bool,
823) -> String {
824 let columns = vec!["table_name", "table_schema", "table_catalog"].into_iter();
825 let normalizer = IdentNormalizer::new(enable_normalization);
826 sql_table_name
827 .0
828 .iter()
829 .rev()
830 .zip(columns)
831 .map(|(ident, column_name)| {
832 format!(
833 r#"{} = '{}'"#,
834 column_name,
835 normalizer.normalize(ident.clone())
836 )
837 })
838 .collect::<Vec<_>>()
839 .join(" AND ")
840}