1use std::sync::Arc;
34
35use datafusion::arrow::array::StringArray;
36use datafusion::arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
37use datafusion::arrow::record_batch::RecordBatch;
38use datafusion::error::{DataFusionError, Result as DFResult};
39use datafusion::prelude::{DataFrame, SessionContext};
40use datafusion::sql::sqlparser::ast::{
41 AlterTableOperation, ColumnDef, CreateTable, CreateTableOptions, HiveDistributionStyle, Merge,
42 ObjectName, RenameTableNameKind, SqlOption, Statement, TableFactor, Update,
43};
44use datafusion::sql::sqlparser::dialect::GenericDialect;
45use datafusion::sql::sqlparser::parser::Parser;
46use paimon::catalog::{Catalog, Identifier};
47use paimon::spec::SchemaChange;
48
49use crate::error::to_datafusion_error;
50use paimon::arrow::arrow_to_paimon_type;
51
/// SQL entry point that intercepts Paimon DDL/DML statements (CREATE TABLE,
/// ALTER TABLE, MERGE INTO, UPDATE) and routes them to a Paimon catalog,
/// delegating every other statement to DataFusion.
pub struct PaimonSqlHandler {
    // DataFusion session used for delegated statements and OK-result frames.
    ctx: SessionContext,
    // Paimon catalog receiving create/alter/rename table calls.
    catalog: Arc<dyn Catalog>,
    // Catalog prefix accepted in three-part names (`catalog.database.table`).
    catalog_name: String,
}
68
69impl PaimonSqlHandler {
70 pub fn new(
71 ctx: SessionContext,
72 catalog: Arc<dyn Catalog>,
73 catalog_name: impl Into<String>,
74 ) -> Self {
75 Self {
76 ctx,
77 catalog,
78 catalog_name: catalog_name.into(),
79 }
80 }
81
82 pub fn ctx(&self) -> &SessionContext {
84 &self.ctx
85 }
86
87 pub async fn sql(&self, sql: &str) -> DFResult<DataFrame> {
90 let dialect = GenericDialect {};
91 let statements = Parser::parse_sql(&dialect, sql)
92 .map_err(|e| DataFusionError::Plan(format!("SQL parse error: {e}")))?;
93
94 if statements.len() != 1 {
95 return Err(DataFusionError::Plan(
96 "Expected exactly one SQL statement".to_string(),
97 ));
98 }
99
100 match &statements[0] {
101 Statement::CreateTable(create_table) => self.handle_create_table(create_table).await,
102 Statement::AlterTable(alter_table) => {
103 self.handle_alter_table(
104 &alter_table.name,
105 &alter_table.operations,
106 alter_table.if_exists,
107 )
108 .await
109 }
110 Statement::Merge(merge) => self.handle_merge_into(merge).await,
111 Statement::Update(update) => self.handle_update(update).await,
112 _ => self.ctx.sql(sql).await,
113 }
114 }
115
116 async fn handle_create_table(&self, ct: &CreateTable) -> DFResult<DataFrame> {
117 if ct.external {
118 return Err(DataFusionError::Plan(
119 "CREATE EXTERNAL TABLE is not supported. Use CREATE TABLE instead.".to_string(),
120 ));
121 }
122 if ct.location.is_some() {
123 return Err(DataFusionError::Plan(
124 "LOCATION is not supported for Paimon tables. Table path is determined by the catalog warehouse.".to_string(),
125 ));
126 }
127 if ct.query.is_some() {
128 return Err(DataFusionError::Plan(
129 "CREATE TABLE AS SELECT is not yet supported for Paimon tables.".to_string(),
130 ));
131 }
132
133 let identifier = self.resolve_table_name(&ct.name)?;
134
135 let mut builder = paimon::spec::Schema::builder();
136
137 for col in &ct.columns {
139 let arrow_type = sql_data_type_to_arrow(&col.data_type)?;
140 let nullable = !col.options.iter().any(|opt| {
141 matches!(
142 opt.option,
143 datafusion::sql::sqlparser::ast::ColumnOption::NotNull
144 )
145 });
146 let paimon_type =
147 arrow_to_paimon_type(&arrow_type, nullable).map_err(to_datafusion_error)?;
148 builder = builder.column(col.name.value.clone(), paimon_type);
149 }
150
151 for constraint in &ct.constraints {
153 if let datafusion::sql::sqlparser::ast::TableConstraint::PrimaryKey(pk) = constraint {
154 let pk_cols: Vec<String> = pk
155 .columns
156 .iter()
157 .map(|c| c.column.expr.to_string())
158 .collect();
159 builder = builder.primary_key(pk_cols);
160 }
161 }
162
163 if let HiveDistributionStyle::PARTITIONED { columns } = &ct.hive_distribution {
165 let partition_keys: Vec<String> =
166 columns.iter().map(|c| c.name.value.clone()).collect();
167 builder = builder.partition_keys(partition_keys);
168 }
169
170 for (k, v) in extract_options(&ct.table_options)? {
172 builder = builder.option(k, v);
173 }
174
175 let schema = builder.build().map_err(to_datafusion_error)?;
176
177 self.catalog
178 .create_table(&identifier, schema, ct.if_not_exists)
179 .await
180 .map_err(to_datafusion_error)?;
181
182 ok_result(&self.ctx)
183 }
184
185 async fn handle_alter_table(
186 &self,
187 name: &ObjectName,
188 operations: &[AlterTableOperation],
189 if_exists: bool,
190 ) -> DFResult<DataFrame> {
191 let identifier = self.resolve_table_name(name)?;
192
193 let mut changes = Vec::new();
194 let mut rename_to: Option<Identifier> = None;
195
196 for op in operations {
197 match op {
198 AlterTableOperation::AddColumn { column_def, .. } => {
199 let change = column_def_to_add_column(column_def)?;
200 changes.push(change);
201 }
202 AlterTableOperation::DropColumn {
203 column_names,
204 if_exists: _,
205 ..
206 } => {
207 for col in column_names {
208 changes.push(SchemaChange::drop_column(col.value.clone()));
209 }
210 }
211 AlterTableOperation::RenameColumn {
212 old_column_name,
213 new_column_name,
214 } => {
215 changes.push(SchemaChange::rename_column(
216 old_column_name.value.clone(),
217 new_column_name.value.clone(),
218 ));
219 }
220 AlterTableOperation::RenameTable { table_name } => {
221 let new_name = match table_name {
222 RenameTableNameKind::To(name) | RenameTableNameKind::As(name) => {
223 object_name_to_string(name)
224 }
225 };
226 rename_to = Some(Identifier::new(identifier.database().to_string(), new_name));
227 }
228 other => {
229 return Err(DataFusionError::Plan(format!(
230 "Unsupported ALTER TABLE operation: {other}"
231 )));
232 }
233 }
234 }
235
236 if let Some(new_identifier) = rename_to {
237 self.catalog
238 .rename_table(&identifier, &new_identifier, if_exists)
239 .await
240 .map_err(to_datafusion_error)?;
241 }
242
243 if !changes.is_empty() {
244 self.catalog
245 .alter_table(&identifier, changes, if_exists)
246 .await
247 .map_err(to_datafusion_error)?;
248 }
249
250 ok_result(&self.ctx)
251 }
252
253 async fn handle_merge_into(&self, merge: &Merge) -> DFResult<DataFrame> {
254 let table_name = match &merge.table {
256 TableFactor::Table { name, .. } => name.clone(),
257 other => {
258 return Err(DataFusionError::Plan(format!(
259 "Unsupported target table in MERGE INTO: {other}"
260 )))
261 }
262 };
263 let identifier = self.resolve_table_name(&table_name)?;
264
265 let table = self
267 .catalog
268 .get_table(&identifier)
269 .await
270 .map_err(to_datafusion_error)?;
271
272 crate::merge_into::execute_merge_into(&self.ctx, merge, table).await
273 }
274
275 async fn handle_update(&self, update: &Update) -> DFResult<DataFrame> {
276 let table_name = match &update.table.relation {
277 TableFactor::Table { name, .. } => name.clone(),
278 other => {
279 return Err(DataFusionError::Plan(format!(
280 "Unsupported target table in UPDATE: {other}"
281 )))
282 }
283 };
284 let identifier = self.resolve_table_name(&table_name)?;
285
286 let table = self
287 .catalog
288 .get_table(&identifier)
289 .await
290 .map_err(to_datafusion_error)?;
291
292 crate::update::execute_update(&self.ctx, update, table).await
293 }
294
295 fn resolve_table_name(&self, name: &ObjectName) -> DFResult<Identifier> {
297 let parts: Vec<String> = name
298 .0
299 .iter()
300 .filter_map(|p| p.as_ident().map(|id| id.value.clone()))
301 .collect();
302 match parts.len() {
303 3 => {
304 if parts[0] != self.catalog_name {
306 return Err(DataFusionError::Plan(format!(
307 "Unknown catalog '{}', expected '{}'",
308 parts[0], self.catalog_name
309 )));
310 }
311 Ok(Identifier::new(parts[1].clone(), parts[2].clone()))
312 }
313 2 => Ok(Identifier::new(parts[0].clone(), parts[1].clone())),
314 1 => Err(DataFusionError::Plan(format!(
315 "ALTER TABLE requires at least database.table, got: {}",
316 parts[0]
317 ))),
318 _ => Err(DataFusionError::Plan(format!(
319 "Invalid table reference: {name}"
320 ))),
321 }
322 }
323}
324
325fn column_def_to_add_column(col: &ColumnDef) -> DFResult<SchemaChange> {
327 let arrow_type = sql_data_type_to_arrow(&col.data_type)?;
328 let nullable = !col.options.iter().any(|opt| {
329 matches!(
330 opt.option,
331 datafusion::sql::sqlparser::ast::ColumnOption::NotNull
332 )
333 });
334 let paimon_type = arrow_to_paimon_type(&arrow_type, nullable).map_err(to_datafusion_error)?;
335 Ok(SchemaChange::add_column(
336 col.name.value.clone(),
337 paimon_type,
338 ))
339}
340
/// Maps a parsed SQL data type to the closest Arrow data type.
///
/// Composite types (ARRAY, MAP, STRUCT) recurse on their element types.
/// Returns a plan error for SQL types with no Arrow mapping here.
fn sql_data_type_to_arrow(
    sql_type: &datafusion::sql::sqlparser::ast::DataType,
) -> DFResult<ArrowDataType> {
    use datafusion::sql::sqlparser::ast::{ArrayElemTypeDef, DataType as SqlType};
    match sql_type {
        SqlType::Boolean => Ok(ArrowDataType::Boolean),
        SqlType::TinyInt(_) => Ok(ArrowDataType::Int8),
        SqlType::SmallInt(_) => Ok(ArrowDataType::Int16),
        SqlType::Int(_) | SqlType::Integer(_) => Ok(ArrowDataType::Int32),
        SqlType::BigInt(_) => Ok(ArrowDataType::Int64),
        SqlType::Float(_) => Ok(ArrowDataType::Float32),
        SqlType::Real => Ok(ArrowDataType::Float32),
        SqlType::Double(_) | SqlType::DoublePrecision => Ok(ArrowDataType::Float64),
        // All textual SQL types collapse to Utf8; declared lengths (VARCHAR(n),
        // CHAR(n)) are not enforced.
        SqlType::Varchar(_) | SqlType::CharVarying(_) | SqlType::Text | SqlType::String(_) => {
            Ok(ArrowDataType::Utf8)
        }
        SqlType::Char(_) | SqlType::Character(_) => Ok(ArrowDataType::Utf8),
        SqlType::Binary(_) | SqlType::Varbinary(_) | SqlType::Blob(_) | SqlType::Bytea => {
            Ok(ArrowDataType::Binary)
        }
        SqlType::Date => Ok(ArrowDataType::Date32),
        SqlType::Timestamp(precision, tz_info) => {
            use datafusion::sql::sqlparser::ast::TimezoneInfo;
            // Bucket fractional-second precision into Arrow time units:
            // 0 -> seconds, 1..=3 or unspecified -> millis, 4..=6 -> micros,
            // anything larger -> nanos.
            let unit = match precision {
                Some(0) => datafusion::arrow::datatypes::TimeUnit::Second,
                Some(1..=3) | None => datafusion::arrow::datatypes::TimeUnit::Millisecond,
                Some(4..=6) => datafusion::arrow::datatypes::TimeUnit::Microsecond,
                _ => datafusion::arrow::datatypes::TimeUnit::Nanosecond,
            };
            // Any timezone-bearing timestamp is normalized to UTC.
            let tz = match tz_info {
                TimezoneInfo::None | TimezoneInfo::WithoutTimeZone => None,
                _ => Some("UTC".into()),
            };
            Ok(ArrowDataType::Timestamp(unit, tz))
        }
        SqlType::Decimal(info) => {
            use datafusion::sql::sqlparser::ast::ExactNumberInfo;
            // Unspecified precision/scale defaults to DECIMAL(10, 0).
            // NOTE(review): `as` narrowing is unchecked here; a declared
            // precision > 38 exceeds Decimal128's range and is not rejected
            // at this point — presumably caught downstream. TODO confirm.
            let (p, s) = match info {
                ExactNumberInfo::PrecisionAndScale(p, s) => (*p as u8, *s as i8),
                ExactNumberInfo::Precision(p) => (*p as u8, 0),
                ExactNumberInfo::None => (10, 0),
            };
            Ok(ArrowDataType::Decimal128(p, s))
        }
        SqlType::Array(elem_def) => {
            // Accept ARRAY<T>, T[] and ARRAY(T) spellings alike.
            let elem_type = match elem_def {
                ArrayElemTypeDef::AngleBracket(t)
                | ArrayElemTypeDef::SquareBracket(t, _)
                | ArrayElemTypeDef::Parenthesis(t) => sql_data_type_to_arrow(t)?,
                ArrayElemTypeDef::None => {
                    return Err(DataFusionError::Plan(
                        "ARRAY type requires an element type".to_string(),
                    ));
                }
            };
            Ok(ArrowDataType::List(Arc::new(Field::new(
                "element", elem_type, true,
            ))))
        }
        SqlType::Map(key_type, value_type) => {
            let key = sql_data_type_to_arrow(key_type)?;
            let value = sql_data_type_to_arrow(value_type)?;
            // Arrow maps are a list of non-nullable "entries" structs with a
            // required key and a nullable value; keys are not marked sorted.
            let entries = Field::new(
                "entries",
                ArrowDataType::Struct(
                    vec![
                        Field::new("key", key, false),
                        Field::new("value", value, true),
                    ]
                    .into(),
                ),
                false,
            );
            Ok(ArrowDataType::Map(Arc::new(entries), false))
        }
        SqlType::Struct(fields, _) => {
            let arrow_fields: Vec<Field> = fields
                .iter()
                .map(|f| {
                    // Anonymous struct fields get an empty name.
                    let name = f
                        .field_name
                        .as_ref()
                        .map(|n| n.value.clone())
                        .unwrap_or_default();
                    let dt = sql_data_type_to_arrow(&f.field_type)?;
                    Ok(Field::new(name, dt, true))
                })
                .collect::<DFResult<_>>()?;
            Ok(ArrowDataType::Struct(arrow_fields.into()))
        }
        _ => Err(DataFusionError::Plan(format!(
            "Unsupported SQL data type: {sql_type}"
        ))),
    }
}
437
438fn object_name_to_string(name: &ObjectName) -> String {
439 name.0
440 .iter()
441 .filter_map(|p| p.as_ident().map(|id| id.value.clone()))
442 .collect::<Vec<_>>()
443 .join(".")
444}
445
446fn extract_options(opts: &CreateTableOptions) -> DFResult<Vec<(String, String)>> {
448 let sql_options = match opts {
449 CreateTableOptions::With(options)
450 | CreateTableOptions::Options(options)
451 | CreateTableOptions::TableProperties(options)
452 | CreateTableOptions::Plain(options) => options,
453 CreateTableOptions::None => return Ok(Vec::new()),
454 };
455 sql_options
456 .iter()
457 .map(|opt| match opt {
458 SqlOption::KeyValue { key, value } => {
459 let v = value.to_string();
460 let v = v
462 .strip_prefix('\'')
463 .and_then(|s| s.strip_suffix('\''))
464 .unwrap_or(&v)
465 .to_string();
466 Ok((key.value.clone(), v))
467 }
468 other => Err(DataFusionError::Plan(format!(
469 "Unsupported table option: {other}"
470 ))),
471 })
472 .collect()
473}
474
475fn ok_result(ctx: &SessionContext) -> DFResult<DataFrame> {
477 let schema = Arc::new(Schema::new(vec![Field::new(
478 "result",
479 ArrowDataType::Utf8,
480 false,
481 )]));
482 let batch = RecordBatch::try_new(
483 schema.clone(),
484 vec![Arc::new(StringArray::from(vec!["OK"]))],
485 )?;
486 let df = ctx.read_batch(batch)?;
487 Ok(df)
488}
489
// Unit tests: SQL type mapping, table-name resolution, option extraction,
// and DDL routing verified through a call-recording mock catalog.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Mutex;

    use async_trait::async_trait;
    use datafusion::arrow::datatypes::TimeUnit;
    use paimon::catalog::Database;
    use paimon::spec::Schema as PaimonSchema;
    use paimon::table::Table;

    /// Records each mutating catalog call so tests can assert on exactly
    /// what the handler sent to the catalog.
    #[allow(clippy::enum_variant_names)]
    #[derive(Debug)]
    enum CatalogCall {
        CreateTable {
            identifier: Identifier,
            schema: PaimonSchema,
            ignore_if_exists: bool,
        },
        AlterTable {
            identifier: Identifier,
            changes: Vec<SchemaChange>,
            ignore_if_not_exists: bool,
        },
        RenameTable {
            from: Identifier,
            to: Identifier,
            ignore_if_not_exists: bool,
        },
    }

    /// Catalog stub that only records calls; read paths are unimplemented.
    struct MockCatalog {
        calls: Mutex<Vec<CatalogCall>>,
    }

    impl MockCatalog {
        fn new() -> Self {
            Self {
                calls: Mutex::new(Vec::new()),
            }
        }

        // Drains the recorded calls, leaving the mock empty for reuse.
        fn take_calls(&self) -> Vec<CatalogCall> {
            std::mem::take(&mut *self.calls.lock().unwrap())
        }
    }

    #[async_trait]
    impl Catalog for MockCatalog {
        async fn list_databases(&self) -> paimon::Result<Vec<String>> {
            Ok(vec![])
        }
        async fn create_database(
            &self,
            _name: &str,
            _ignore_if_exists: bool,
            _properties: HashMap<String, String>,
        ) -> paimon::Result<()> {
            Ok(())
        }
        async fn get_database(&self, _name: &str) -> paimon::Result<Database> {
            unimplemented!()
        }
        async fn drop_database(
            &self,
            _name: &str,
            _ignore_if_not_exists: bool,
            _cascade: bool,
        ) -> paimon::Result<()> {
            Ok(())
        }
        async fn get_table(&self, _identifier: &Identifier) -> paimon::Result<Table> {
            unimplemented!()
        }
        async fn list_tables(&self, _database_name: &str) -> paimon::Result<Vec<String>> {
            Ok(vec![])
        }
        async fn create_table(
            &self,
            identifier: &Identifier,
            creation: PaimonSchema,
            ignore_if_exists: bool,
        ) -> paimon::Result<()> {
            self.calls.lock().unwrap().push(CatalogCall::CreateTable {
                identifier: identifier.clone(),
                schema: creation,
                ignore_if_exists,
            });
            Ok(())
        }
        async fn drop_table(
            &self,
            _identifier: &Identifier,
            _ignore_if_not_exists: bool,
        ) -> paimon::Result<()> {
            Ok(())
        }
        async fn rename_table(
            &self,
            from: &Identifier,
            to: &Identifier,
            ignore_if_not_exists: bool,
        ) -> paimon::Result<()> {
            self.calls.lock().unwrap().push(CatalogCall::RenameTable {
                from: from.clone(),
                to: to.clone(),
                ignore_if_not_exists,
            });
            Ok(())
        }
        async fn alter_table(
            &self,
            identifier: &Identifier,
            changes: Vec<SchemaChange>,
            ignore_if_not_exists: bool,
        ) -> paimon::Result<()> {
            self.calls.lock().unwrap().push(CatalogCall::AlterTable {
                identifier: identifier.clone(),
                changes,
                ignore_if_not_exists,
            });
            Ok(())
        }
    }

    // Handler wired to a fresh session and the "paimon" catalog name.
    fn make_handler(catalog: Arc<MockCatalog>) -> PaimonSqlHandler {
        PaimonSqlHandler::new(SessionContext::new(), catalog, "paimon")
    }

    // --- sql_data_type_to_arrow ---

    #[test]
    fn test_sql_type_boolean() {
        use datafusion::sql::sqlparser::ast::DataType as SqlType;
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Boolean).unwrap(),
            ArrowDataType::Boolean
        );
    }

    #[test]
    fn test_sql_type_integers() {
        use datafusion::sql::sqlparser::ast::DataType as SqlType;
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::TinyInt(None)).unwrap(),
            ArrowDataType::Int8
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::SmallInt(None)).unwrap(),
            ArrowDataType::Int16
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Int(None)).unwrap(),
            ArrowDataType::Int32
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Integer(None)).unwrap(),
            ArrowDataType::Int32
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::BigInt(None)).unwrap(),
            ArrowDataType::Int64
        );
    }

    #[test]
    fn test_sql_type_floats() {
        use datafusion::sql::sqlparser::ast::{DataType as SqlType, ExactNumberInfo};
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Float(ExactNumberInfo::None)).unwrap(),
            ArrowDataType::Float32
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Real).unwrap(),
            ArrowDataType::Float32
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::DoublePrecision).unwrap(),
            ArrowDataType::Float64
        );
    }

    #[test]
    fn test_sql_type_string_variants() {
        use datafusion::sql::sqlparser::ast::DataType as SqlType;
        for sql_type in [SqlType::Varchar(None), SqlType::Text, SqlType::String(None)] {
            assert_eq!(
                sql_data_type_to_arrow(&sql_type).unwrap(),
                ArrowDataType::Utf8,
                "failed for {sql_type:?}"
            );
        }
    }

    #[test]
    fn test_sql_type_binary() {
        use datafusion::sql::sqlparser::ast::DataType as SqlType;
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Bytea).unwrap(),
            ArrowDataType::Binary
        );
    }

    #[test]
    fn test_sql_type_date() {
        use datafusion::sql::sqlparser::ast::DataType as SqlType;
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Date).unwrap(),
            ArrowDataType::Date32
        );
    }

    #[test]
    fn test_sql_type_timestamp_default() {
        use datafusion::sql::sqlparser::ast::{DataType as SqlType, TimezoneInfo};
        let result = sql_data_type_to_arrow(&SqlType::Timestamp(None, TimezoneInfo::None)).unwrap();
        assert_eq!(
            result,
            ArrowDataType::Timestamp(TimeUnit::Millisecond, None)
        );
    }

    #[test]
    fn test_sql_type_timestamp_with_precision() {
        use datafusion::sql::sqlparser::ast::{DataType as SqlType, TimezoneInfo};
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Timestamp(Some(0), TimezoneInfo::None)).unwrap(),
            ArrowDataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Timestamp(Some(3), TimezoneInfo::None)).unwrap(),
            ArrowDataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Timestamp(Some(6), TimezoneInfo::None)).unwrap(),
            ArrowDataType::Timestamp(TimeUnit::Microsecond, None)
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Timestamp(Some(9), TimezoneInfo::None)).unwrap(),
            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[test]
    fn test_sql_type_timestamp_with_tz() {
        use datafusion::sql::sqlparser::ast::{DataType as SqlType, TimezoneInfo};
        let result =
            sql_data_type_to_arrow(&SqlType::Timestamp(None, TimezoneInfo::WithTimeZone)).unwrap();
        assert_eq!(
            result,
            ArrowDataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into()))
        );
    }

    #[test]
    fn test_sql_type_decimal() {
        use datafusion::sql::sqlparser::ast::{DataType as SqlType, ExactNumberInfo};
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Decimal(ExactNumberInfo::PrecisionAndScale(18, 2)))
                .unwrap(),
            ArrowDataType::Decimal128(18, 2)
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Decimal(ExactNumberInfo::Precision(10))).unwrap(),
            ArrowDataType::Decimal128(10, 0)
        );
        assert_eq!(
            sql_data_type_to_arrow(&SqlType::Decimal(ExactNumberInfo::None)).unwrap(),
            ArrowDataType::Decimal128(10, 0)
        );
    }

    #[test]
    fn test_sql_type_unsupported() {
        use datafusion::sql::sqlparser::ast::DataType as SqlType;
        assert!(sql_data_type_to_arrow(&SqlType::Regclass).is_err());
    }

    #[test]
    fn test_sql_type_array() {
        use datafusion::sql::sqlparser::ast::{ArrayElemTypeDef, DataType as SqlType};
        let result = sql_data_type_to_arrow(&SqlType::Array(ArrayElemTypeDef::AngleBracket(
            Box::new(SqlType::Int(None)),
        )))
        .unwrap();
        assert_eq!(
            result,
            ArrowDataType::List(Arc::new(Field::new("element", ArrowDataType::Int32, true)))
        );
    }

    #[test]
    fn test_sql_type_array_no_element() {
        use datafusion::sql::sqlparser::ast::{ArrayElemTypeDef, DataType as SqlType};
        assert!(sql_data_type_to_arrow(&SqlType::Array(ArrayElemTypeDef::None)).is_err());
    }

    #[test]
    fn test_sql_type_map() {
        use datafusion::sql::sqlparser::ast::DataType as SqlType;
        let result = sql_data_type_to_arrow(&SqlType::Map(
            Box::new(SqlType::Varchar(None)),
            Box::new(SqlType::Int(None)),
        ))
        .unwrap();
        let expected = ArrowDataType::Map(
            Arc::new(Field::new(
                "entries",
                ArrowDataType::Struct(
                    vec![
                        Field::new("key", ArrowDataType::Utf8, false),
                        Field::new("value", ArrowDataType::Int32, true),
                    ]
                    .into(),
                ),
                false,
            )),
            false,
        );
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sql_type_struct() {
        use datafusion::sql::sqlparser::ast::{
            DataType as SqlType, Ident, StructBracketKind, StructField,
        };
        let result = sql_data_type_to_arrow(&SqlType::Struct(
            vec![
                StructField {
                    field_name: Some(Ident::new("name")),
                    field_type: SqlType::Varchar(None),
                    options: None,
                },
                StructField {
                    field_name: Some(Ident::new("age")),
                    field_type: SqlType::Int(None),
                    options: None,
                },
            ],
            StructBracketKind::AngleBrackets,
        ))
        .unwrap();
        assert_eq!(
            result,
            ArrowDataType::Struct(
                vec![
                    Field::new("name", ArrowDataType::Utf8, true),
                    Field::new("age", ArrowDataType::Int32, true),
                ]
                .into()
            )
        );
    }

    // --- resolve_table_name ---

    #[test]
    fn test_resolve_three_part_name() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog);
        let dialect = GenericDialect {};
        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM paimon.mydb.mytable").unwrap();
        if let Statement::Query(q) = &stmts[0] {
            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
                    &sel.from[0].relation
                {
                    let id = handler.resolve_table_name(name).unwrap();
                    assert_eq!(id.database(), "mydb");
                    assert_eq!(id.object(), "mytable");
                }
            }
        }
    }

    #[test]
    fn test_resolve_two_part_name() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog);
        let dialect = GenericDialect {};
        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM mydb.mytable").unwrap();
        if let Statement::Query(q) = &stmts[0] {
            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
                    &sel.from[0].relation
                {
                    let id = handler.resolve_table_name(name).unwrap();
                    assert_eq!(id.database(), "mydb");
                    assert_eq!(id.object(), "mytable");
                }
            }
        }
    }

    #[test]
    fn test_resolve_wrong_catalog_name() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog);
        let dialect = GenericDialect {};
        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM other.mydb.mytable").unwrap();
        if let Statement::Query(q) = &stmts[0] {
            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
                    &sel.from[0].relation
                {
                    let err = handler.resolve_table_name(name).unwrap_err();
                    assert!(err.to_string().contains("Unknown catalog"));
                }
            }
        }
    }

    #[test]
    fn test_resolve_single_part_name_error() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog);
        let dialect = GenericDialect {};
        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM mytable").unwrap();
        if let Statement::Query(q) = &stmts[0] {
            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
                    &sel.from[0].relation
                {
                    let err = handler.resolve_table_name(name).unwrap_err();
                    assert!(err.to_string().contains("at least database.table"));
                }
            }
        }
    }

    // --- extract_options ---

    #[test]
    fn test_extract_options_none() {
        let opts = extract_options(&CreateTableOptions::None).unwrap();
        assert!(opts.is_empty());
    }

    #[test]
    fn test_extract_options_with_kv() {
        let dialect = GenericDialect {};
        let stmts =
            Parser::parse_sql(&dialect, "CREATE TABLE t (id INT) WITH ('bucket' = '4')").unwrap();
        if let Statement::CreateTable(ct) = &stmts[0] {
            let opts = extract_options(&ct.table_options).unwrap();
            assert_eq!(opts.len(), 1);
            assert_eq!(opts[0].0, "bucket");
            assert_eq!(opts[0].1, "4");
        } else {
            panic!("expected CreateTable");
        }
    }

    // --- handler end-to-end against the mock catalog ---

    #[tokio::test]
    async fn test_create_table_basic() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("CREATE TABLE mydb.t1 (id INT NOT NULL, name VARCHAR, PRIMARY KEY (id))")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::CreateTable {
            identifier,
            schema,
            ignore_if_exists,
        } = &calls[0]
        {
            assert_eq!(identifier.database(), "mydb");
            assert_eq!(identifier.object(), "t1");
            assert!(!ignore_if_exists);
            assert_eq!(schema.primary_keys(), &["id"]);
        } else {
            panic!("expected CreateTable call");
        }
    }

    #[tokio::test]
    async fn test_create_table_if_not_exists() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("CREATE TABLE IF NOT EXISTS mydb.t1 (id INT)")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::CreateTable {
            ignore_if_exists, ..
        } = &calls[0]
        {
            assert!(ignore_if_exists);
        } else {
            panic!("expected CreateTable call");
        }
    }

    #[tokio::test]
    async fn test_create_table_with_options() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("CREATE TABLE mydb.t1 (id INT) WITH ('bucket' = '4', 'file.format' = 'parquet')")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::CreateTable { schema, .. } = &calls[0] {
            let opts = schema.options();
            assert_eq!(opts.get("bucket").unwrap(), "4");
            assert_eq!(opts.get("file.format").unwrap(), "parquet");
        } else {
            panic!("expected CreateTable call");
        }
    }

    #[tokio::test]
    async fn test_create_table_three_part_name() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("CREATE TABLE paimon.mydb.t1 (id INT)")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        if let CatalogCall::CreateTable { identifier, .. } = &calls[0] {
            assert_eq!(identifier.database(), "mydb");
            assert_eq!(identifier.object(), "t1");
        } else {
            panic!("expected CreateTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_add_column() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE mydb.t1 ADD COLUMN age INT")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::AlterTable {
            identifier,
            changes,
            ..
        } = &calls[0]
        {
            assert_eq!(identifier.database(), "mydb");
            assert_eq!(identifier.object(), "t1");
            assert_eq!(changes.len(), 1);
            assert!(
                matches!(&changes[0], SchemaChange::AddColumn { field_name, .. } if field_name == "age")
            );
        } else {
            panic!("expected AlterTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_drop_column() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE mydb.t1 DROP COLUMN age")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::AlterTable { changes, .. } = &calls[0] {
            assert_eq!(changes.len(), 1);
            assert!(
                matches!(&changes[0], SchemaChange::DropColumn { field_name } if field_name == "age")
            );
        } else {
            panic!("expected AlterTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_rename_column() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE mydb.t1 RENAME COLUMN old_name TO new_name")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::AlterTable { changes, .. } = &calls[0] {
            assert_eq!(changes.len(), 1);
            assert!(matches!(
                &changes[0],
                SchemaChange::RenameColumn { field_name, new_name }
                if field_name == "old_name" && new_name == "new_name"
            ));
        } else {
            panic!("expected AlterTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_rename_table() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE mydb.t1 RENAME TO t2")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::RenameTable { from, to, .. } = &calls[0] {
            assert_eq!(from.database(), "mydb");
            assert_eq!(from.object(), "t1");
            assert_eq!(to.database(), "mydb");
            assert_eq!(to.object(), "t2");
        } else {
            panic!("expected RenameTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_if_exists_add_column() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE IF EXISTS mydb.t1 ADD COLUMN age INT")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::AlterTable {
            ignore_if_not_exists,
            ..
        } = &calls[0]
        {
            assert!(ignore_if_not_exists);
        } else {
            panic!("expected AlterTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_without_if_exists() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE mydb.t1 ADD COLUMN age INT")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        if let CatalogCall::AlterTable {
            ignore_if_not_exists,
            ..
        } = &calls[0]
        {
            assert!(!ignore_if_not_exists);
        } else {
            panic!("expected AlterTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_if_exists_rename() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE IF EXISTS mydb.t1 RENAME TO t2")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::RenameTable {
            from,
            to,
            ignore_if_not_exists,
        } = &calls[0]
        {
            assert!(ignore_if_not_exists);
            assert_eq!(from.object(), "t1");
            assert_eq!(to.object(), "t2");
        } else {
            panic!("expected RenameTable call");
        }
    }

    #[tokio::test]
    async fn test_alter_table_rename_three_part_name() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());

        handler
            .sql("ALTER TABLE paimon.mydb.t1 RENAME TO t2")
            .await
            .unwrap();

        let calls = catalog.take_calls();
        assert_eq!(calls.len(), 1);
        if let CatalogCall::RenameTable { from, to, .. } = &calls[0] {
            assert_eq!(from.database(), "mydb");
            assert_eq!(from.object(), "t1");
            assert_eq!(to.database(), "mydb");
            assert_eq!(to.object(), "t2");
        } else {
            panic!("expected RenameTable call");
        }
    }

    #[tokio::test]
    async fn test_sql_parse_error() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog);
        let result = handler.sql("NOT VALID SQL !!!").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("SQL parse error"));
    }

    #[tokio::test]
    async fn test_multiple_statements_error() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog);
        let result = handler.sql("SELECT 1; SELECT 2").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("exactly one SQL statement"));
    }

    #[tokio::test]
    async fn test_create_external_table_rejected() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog);
        let result = handler
            .sql("CREATE EXTERNAL TABLE mydb.t1 (id INT) STORED AS PARQUET")
            .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("CREATE EXTERNAL TABLE is not supported"));
    }

    #[tokio::test]
    async fn test_non_ddl_delegates_to_datafusion() {
        let catalog = Arc::new(MockCatalog::new());
        let handler = make_handler(catalog.clone());
        let df = handler.sql("SELECT 1 AS x").await.unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
        // A plain SELECT must never touch the Paimon catalog.
        assert!(catalog.take_calls().is_empty());
    }
}
1280}