// paimon_datafusion/sql_handler.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! SQL support for Paimon tables.
19//!
20//! DataFusion does not natively support all SQL statements needed by Paimon.
21//! This module provides [`PaimonSqlHandler`] which intercepts CREATE TABLE,
22//! ALTER TABLE, MERGE INTO, UPDATE and other SQL, translates them to Paimon
23//! catalog operations, and delegates everything else (SELECT, CREATE/DROP
24//! SCHEMA, DROP TABLE, etc.) to the underlying [`SessionContext`].
25//!
26//! Supported DDL:
27//! - `CREATE TABLE db.t (col TYPE, ..., PRIMARY KEY (col, ...)) [PARTITIONED BY (col TYPE, ...)] [WITH ('key' = 'val')]`
28//! - `ALTER TABLE db.t ADD COLUMN col TYPE`
29//! - `ALTER TABLE db.t DROP COLUMN col`
30//! - `ALTER TABLE db.t RENAME COLUMN old TO new`
31//! - `ALTER TABLE db.t RENAME TO new_name`
32
33use std::sync::Arc;
34
35use datafusion::arrow::array::StringArray;
36use datafusion::arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
37use datafusion::arrow::record_batch::RecordBatch;
38use datafusion::error::{DataFusionError, Result as DFResult};
39use datafusion::prelude::{DataFrame, SessionContext};
40use datafusion::sql::sqlparser::ast::{
41    AlterTableOperation, ColumnDef, CreateTable, CreateTableOptions, HiveDistributionStyle, Merge,
42    ObjectName, RenameTableNameKind, SqlOption, Statement, TableFactor, Update,
43};
44use datafusion::sql::sqlparser::dialect::GenericDialect;
45use datafusion::sql::sqlparser::parser::Parser;
46use paimon::catalog::{Catalog, Identifier};
47use paimon::spec::SchemaChange;
48
49use crate::error::to_datafusion_error;
50use paimon::arrow::arrow_to_paimon_type;
51
52/// Wraps a [`SessionContext`] and a Paimon [`Catalog`] to handle DDL statements
53/// that DataFusion does not natively support (e.g. ALTER TABLE).
54///
55/// For all other SQL, it delegates to the inner `SessionContext`.
56///
57/// # Example
58/// ```ignore
59/// let handler = PaimonSqlHandler::new(ctx, catalog);
60/// let df = handler.sql("ALTER TABLE paimon.db.t ADD COLUMN age INT").await?;
61/// ```
pub struct PaimonSqlHandler {
    /// DataFusion session used both to run delegated SQL and to build the
    /// small "OK" result frames for DDL statements.
    ctx: SessionContext,
    /// Paimon catalog that receives the translated DDL operations
    /// (create/alter/rename table).
    catalog: Arc<dyn Catalog>,
    /// The catalog name registered in the SessionContext (used to strip the catalog prefix).
    catalog_name: String,
}
68
69impl PaimonSqlHandler {
70    pub fn new(
71        ctx: SessionContext,
72        catalog: Arc<dyn Catalog>,
73        catalog_name: impl Into<String>,
74    ) -> Self {
75        Self {
76            ctx,
77            catalog,
78            catalog_name: catalog_name.into(),
79        }
80    }
81
82    /// Returns a reference to the inner [`SessionContext`].
83    pub fn ctx(&self) -> &SessionContext {
84        &self.ctx
85    }
86
87    /// Execute a SQL statement. ALTER TABLE is handled by Paimon directly;
88    /// everything else is delegated to DataFusion.
89    pub async fn sql(&self, sql: &str) -> DFResult<DataFrame> {
90        let dialect = GenericDialect {};
91        let statements = Parser::parse_sql(&dialect, sql)
92            .map_err(|e| DataFusionError::Plan(format!("SQL parse error: {e}")))?;
93
94        if statements.len() != 1 {
95            return Err(DataFusionError::Plan(
96                "Expected exactly one SQL statement".to_string(),
97            ));
98        }
99
100        match &statements[0] {
101            Statement::CreateTable(create_table) => self.handle_create_table(create_table).await,
102            Statement::AlterTable(alter_table) => {
103                self.handle_alter_table(
104                    &alter_table.name,
105                    &alter_table.operations,
106                    alter_table.if_exists,
107                )
108                .await
109            }
110            Statement::Merge(merge) => self.handle_merge_into(merge).await,
111            Statement::Update(update) => self.handle_update(update).await,
112            _ => self.ctx.sql(sql).await,
113        }
114    }
115
116    async fn handle_create_table(&self, ct: &CreateTable) -> DFResult<DataFrame> {
117        if ct.external {
118            return Err(DataFusionError::Plan(
119                "CREATE EXTERNAL TABLE is not supported. Use CREATE TABLE instead.".to_string(),
120            ));
121        }
122        if ct.location.is_some() {
123            return Err(DataFusionError::Plan(
124                "LOCATION is not supported for Paimon tables. Table path is determined by the catalog warehouse.".to_string(),
125            ));
126        }
127        if ct.query.is_some() {
128            return Err(DataFusionError::Plan(
129                "CREATE TABLE AS SELECT is not yet supported for Paimon tables.".to_string(),
130            ));
131        }
132
133        let identifier = self.resolve_table_name(&ct.name)?;
134
135        let mut builder = paimon::spec::Schema::builder();
136
137        // Columns
138        for col in &ct.columns {
139            let arrow_type = sql_data_type_to_arrow(&col.data_type)?;
140            let nullable = !col.options.iter().any(|opt| {
141                matches!(
142                    opt.option,
143                    datafusion::sql::sqlparser::ast::ColumnOption::NotNull
144                )
145            });
146            let paimon_type =
147                arrow_to_paimon_type(&arrow_type, nullable).map_err(to_datafusion_error)?;
148            builder = builder.column(col.name.value.clone(), paimon_type);
149        }
150
151        // Primary key from constraints: PRIMARY KEY (col, ...)
152        for constraint in &ct.constraints {
153            if let datafusion::sql::sqlparser::ast::TableConstraint::PrimaryKey(pk) = constraint {
154                let pk_cols: Vec<String> = pk
155                    .columns
156                    .iter()
157                    .map(|c| c.column.expr.to_string())
158                    .collect();
159                builder = builder.primary_key(pk_cols);
160            }
161        }
162
163        // Partition keys from PARTITIONED BY (col, ...)
164        if let HiveDistributionStyle::PARTITIONED { columns } = &ct.hive_distribution {
165            let partition_keys: Vec<String> =
166                columns.iter().map(|c| c.name.value.clone()).collect();
167            builder = builder.partition_keys(partition_keys);
168        }
169
170        // Table options from WITH ('key' = 'value', ...)
171        for (k, v) in extract_options(&ct.table_options)? {
172            builder = builder.option(k, v);
173        }
174
175        let schema = builder.build().map_err(to_datafusion_error)?;
176
177        self.catalog
178            .create_table(&identifier, schema, ct.if_not_exists)
179            .await
180            .map_err(to_datafusion_error)?;
181
182        ok_result(&self.ctx)
183    }
184
185    async fn handle_alter_table(
186        &self,
187        name: &ObjectName,
188        operations: &[AlterTableOperation],
189        if_exists: bool,
190    ) -> DFResult<DataFrame> {
191        let identifier = self.resolve_table_name(name)?;
192
193        let mut changes = Vec::new();
194        let mut rename_to: Option<Identifier> = None;
195
196        for op in operations {
197            match op {
198                AlterTableOperation::AddColumn { column_def, .. } => {
199                    let change = column_def_to_add_column(column_def)?;
200                    changes.push(change);
201                }
202                AlterTableOperation::DropColumn {
203                    column_names,
204                    if_exists: _,
205                    ..
206                } => {
207                    for col in column_names {
208                        changes.push(SchemaChange::drop_column(col.value.clone()));
209                    }
210                }
211                AlterTableOperation::RenameColumn {
212                    old_column_name,
213                    new_column_name,
214                } => {
215                    changes.push(SchemaChange::rename_column(
216                        old_column_name.value.clone(),
217                        new_column_name.value.clone(),
218                    ));
219                }
220                AlterTableOperation::RenameTable { table_name } => {
221                    let new_name = match table_name {
222                        RenameTableNameKind::To(name) | RenameTableNameKind::As(name) => {
223                            object_name_to_string(name)
224                        }
225                    };
226                    rename_to = Some(Identifier::new(identifier.database().to_string(), new_name));
227                }
228                other => {
229                    return Err(DataFusionError::Plan(format!(
230                        "Unsupported ALTER TABLE operation: {other}"
231                    )));
232                }
233            }
234        }
235
236        if let Some(new_identifier) = rename_to {
237            self.catalog
238                .rename_table(&identifier, &new_identifier, if_exists)
239                .await
240                .map_err(to_datafusion_error)?;
241        }
242
243        if !changes.is_empty() {
244            self.catalog
245                .alter_table(&identifier, changes, if_exists)
246                .await
247                .map_err(to_datafusion_error)?;
248        }
249
250        ok_result(&self.ctx)
251    }
252
253    async fn handle_merge_into(&self, merge: &Merge) -> DFResult<DataFrame> {
254        // Resolve the target table name from the MERGE INTO clause
255        let table_name = match &merge.table {
256            TableFactor::Table { name, .. } => name.clone(),
257            other => {
258                return Err(DataFusionError::Plan(format!(
259                    "Unsupported target table in MERGE INTO: {other}"
260                )))
261            }
262        };
263        let identifier = self.resolve_table_name(&table_name)?;
264
265        // Load the Paimon table from the catalog
266        let table = self
267            .catalog
268            .get_table(&identifier)
269            .await
270            .map_err(to_datafusion_error)?;
271
272        crate::merge_into::execute_merge_into(&self.ctx, merge, table).await
273    }
274
275    async fn handle_update(&self, update: &Update) -> DFResult<DataFrame> {
276        let table_name = match &update.table.relation {
277            TableFactor::Table { name, .. } => name.clone(),
278            other => {
279                return Err(DataFusionError::Plan(format!(
280                    "Unsupported target table in UPDATE: {other}"
281                )))
282            }
283        };
284        let identifier = self.resolve_table_name(&table_name)?;
285
286        let table = self
287            .catalog
288            .get_table(&identifier)
289            .await
290            .map_err(to_datafusion_error)?;
291
292        crate::update::execute_update(&self.ctx, update, table).await
293    }
294
295    /// Resolve an ObjectName like `paimon.db.table` or `db.table` to a Paimon Identifier.
296    fn resolve_table_name(&self, name: &ObjectName) -> DFResult<Identifier> {
297        let parts: Vec<String> = name
298            .0
299            .iter()
300            .filter_map(|p| p.as_ident().map(|id| id.value.clone()))
301            .collect();
302        match parts.len() {
303            3 => {
304                // catalog.database.table — strip catalog prefix
305                if parts[0] != self.catalog_name {
306                    return Err(DataFusionError::Plan(format!(
307                        "Unknown catalog '{}', expected '{}'",
308                        parts[0], self.catalog_name
309                    )));
310                }
311                Ok(Identifier::new(parts[1].clone(), parts[2].clone()))
312            }
313            2 => Ok(Identifier::new(parts[0].clone(), parts[1].clone())),
314            1 => Err(DataFusionError::Plan(format!(
315                "ALTER TABLE requires at least database.table, got: {}",
316                parts[0]
317            ))),
318            _ => Err(DataFusionError::Plan(format!(
319                "Invalid table reference: {name}"
320            ))),
321        }
322    }
323}
324
325/// Convert a sqlparser [`ColumnDef`] to a Paimon [`SchemaChange::AddColumn`].
326fn column_def_to_add_column(col: &ColumnDef) -> DFResult<SchemaChange> {
327    let arrow_type = sql_data_type_to_arrow(&col.data_type)?;
328    let nullable = !col.options.iter().any(|opt| {
329        matches!(
330            opt.option,
331            datafusion::sql::sqlparser::ast::ColumnOption::NotNull
332        )
333    });
334    let paimon_type = arrow_to_paimon_type(&arrow_type, nullable).map_err(to_datafusion_error)?;
335    Ok(SchemaChange::add_column(
336        col.name.value.clone(),
337        paimon_type,
338    ))
339}
340
341/// Convert a sqlparser SQL data type to an Arrow data type.
342fn sql_data_type_to_arrow(
343    sql_type: &datafusion::sql::sqlparser::ast::DataType,
344) -> DFResult<ArrowDataType> {
345    use datafusion::sql::sqlparser::ast::{ArrayElemTypeDef, DataType as SqlType};
346    match sql_type {
347        SqlType::Boolean => Ok(ArrowDataType::Boolean),
348        SqlType::TinyInt(_) => Ok(ArrowDataType::Int8),
349        SqlType::SmallInt(_) => Ok(ArrowDataType::Int16),
350        SqlType::Int(_) | SqlType::Integer(_) => Ok(ArrowDataType::Int32),
351        SqlType::BigInt(_) => Ok(ArrowDataType::Int64),
352        SqlType::Float(_) => Ok(ArrowDataType::Float32),
353        SqlType::Real => Ok(ArrowDataType::Float32),
354        SqlType::Double(_) | SqlType::DoublePrecision => Ok(ArrowDataType::Float64),
355        SqlType::Varchar(_) | SqlType::CharVarying(_) | SqlType::Text | SqlType::String(_) => {
356            Ok(ArrowDataType::Utf8)
357        }
358        SqlType::Char(_) | SqlType::Character(_) => Ok(ArrowDataType::Utf8),
359        SqlType::Binary(_) | SqlType::Varbinary(_) | SqlType::Blob(_) | SqlType::Bytea => {
360            Ok(ArrowDataType::Binary)
361        }
362        SqlType::Date => Ok(ArrowDataType::Date32),
363        SqlType::Timestamp(precision, tz_info) => {
364            use datafusion::sql::sqlparser::ast::TimezoneInfo;
365            let unit = match precision {
366                Some(0) => datafusion::arrow::datatypes::TimeUnit::Second,
367                Some(1..=3) | None => datafusion::arrow::datatypes::TimeUnit::Millisecond,
368                Some(4..=6) => datafusion::arrow::datatypes::TimeUnit::Microsecond,
369                _ => datafusion::arrow::datatypes::TimeUnit::Nanosecond,
370            };
371            let tz = match tz_info {
372                TimezoneInfo::None | TimezoneInfo::WithoutTimeZone => None,
373                _ => Some("UTC".into()),
374            };
375            Ok(ArrowDataType::Timestamp(unit, tz))
376        }
377        SqlType::Decimal(info) => {
378            use datafusion::sql::sqlparser::ast::ExactNumberInfo;
379            let (p, s) = match info {
380                ExactNumberInfo::PrecisionAndScale(p, s) => (*p as u8, *s as i8),
381                ExactNumberInfo::Precision(p) => (*p as u8, 0),
382                ExactNumberInfo::None => (10, 0),
383            };
384            Ok(ArrowDataType::Decimal128(p, s))
385        }
386        SqlType::Array(elem_def) => {
387            let elem_type = match elem_def {
388                ArrayElemTypeDef::AngleBracket(t)
389                | ArrayElemTypeDef::SquareBracket(t, _)
390                | ArrayElemTypeDef::Parenthesis(t) => sql_data_type_to_arrow(t)?,
391                ArrayElemTypeDef::None => {
392                    return Err(DataFusionError::Plan(
393                        "ARRAY type requires an element type".to_string(),
394                    ));
395                }
396            };
397            Ok(ArrowDataType::List(Arc::new(Field::new(
398                "element", elem_type, true,
399            ))))
400        }
401        SqlType::Map(key_type, value_type) => {
402            let key = sql_data_type_to_arrow(key_type)?;
403            let value = sql_data_type_to_arrow(value_type)?;
404            let entries = Field::new(
405                "entries",
406                ArrowDataType::Struct(
407                    vec![
408                        Field::new("key", key, false),
409                        Field::new("value", value, true),
410                    ]
411                    .into(),
412                ),
413                false,
414            );
415            Ok(ArrowDataType::Map(Arc::new(entries), false))
416        }
417        SqlType::Struct(fields, _) => {
418            let arrow_fields: Vec<Field> = fields
419                .iter()
420                .map(|f| {
421                    let name = f
422                        .field_name
423                        .as_ref()
424                        .map(|n| n.value.clone())
425                        .unwrap_or_default();
426                    let dt = sql_data_type_to_arrow(&f.field_type)?;
427                    Ok(Field::new(name, dt, true))
428                })
429                .collect::<DFResult<_>>()?;
430            Ok(ArrowDataType::Struct(arrow_fields.into()))
431        }
432        _ => Err(DataFusionError::Plan(format!(
433            "Unsupported SQL data type: {sql_type}"
434        ))),
435    }
436}
437
438fn object_name_to_string(name: &ObjectName) -> String {
439    name.0
440        .iter()
441        .filter_map(|p| p.as_ident().map(|id| id.value.clone()))
442        .collect::<Vec<_>>()
443        .join(".")
444}
445
446/// Extract key-value pairs from [`CreateTableOptions`].
447fn extract_options(opts: &CreateTableOptions) -> DFResult<Vec<(String, String)>> {
448    let sql_options = match opts {
449        CreateTableOptions::With(options)
450        | CreateTableOptions::Options(options)
451        | CreateTableOptions::TableProperties(options)
452        | CreateTableOptions::Plain(options) => options,
453        CreateTableOptions::None => return Ok(Vec::new()),
454    };
455    sql_options
456        .iter()
457        .map(|opt| match opt {
458            SqlOption::KeyValue { key, value } => {
459                let v = value.to_string();
460                // Strip surrounding quotes from the value if present.
461                let v = v
462                    .strip_prefix('\'')
463                    .and_then(|s| s.strip_suffix('\''))
464                    .unwrap_or(&v)
465                    .to_string();
466                Ok((key.value.clone(), v))
467            }
468            other => Err(DataFusionError::Plan(format!(
469                "Unsupported table option: {other}"
470            ))),
471        })
472        .collect()
473}
474
475/// Return an empty DataFrame with a single "result" column containing "OK".
476fn ok_result(ctx: &SessionContext) -> DFResult<DataFrame> {
477    let schema = Arc::new(Schema::new(vec![Field::new(
478        "result",
479        ArrowDataType::Utf8,
480        false,
481    )]));
482    let batch = RecordBatch::try_new(
483        schema.clone(),
484        vec![Arc::new(StringArray::from(vec!["OK"]))],
485    )?;
486    let df = ctx.read_batch(batch)?;
487    Ok(df)
488}
489
490#[cfg(test)]
491mod tests {
492    use super::*;
493    use std::collections::HashMap;
494    use std::sync::Mutex;
495
496    use async_trait::async_trait;
497    use datafusion::arrow::datatypes::TimeUnit;
498    use paimon::catalog::Database;
499    use paimon::spec::Schema as PaimonSchema;
500    use paimon::table::Table;
501
502    // ==================== Mock Catalog ====================
503
    // Records one invocation of a mutating catalog method so tests can assert
    // on exactly what the SQL handler asked the catalog to do.
    #[allow(clippy::enum_variant_names)]
    #[derive(Debug)]
    enum CatalogCall {
        /// Arguments captured from `Catalog::create_table`.
        CreateTable {
            identifier: Identifier,
            schema: PaimonSchema,
            ignore_if_exists: bool,
        },
        /// Arguments captured from `Catalog::alter_table`.
        AlterTable {
            identifier: Identifier,
            changes: Vec<SchemaChange>,
            ignore_if_not_exists: bool,
        },
        /// Arguments captured from `Catalog::rename_table`.
        RenameTable {
            from: Identifier,
            to: Identifier,
            ignore_if_not_exists: bool,
        },
    }
523
    /// Test double for [`Catalog`] that records mutating calls instead of
    /// touching real storage.
    struct MockCatalog {
        // Mutex for interior mutability: the Catalog trait methods take &self.
        calls: Mutex<Vec<CatalogCall>>,
    }
527
528    impl MockCatalog {
529        fn new() -> Self {
530            Self {
531                calls: Mutex::new(Vec::new()),
532            }
533        }
534
535        fn take_calls(&self) -> Vec<CatalogCall> {
536            std::mem::take(&mut *self.calls.lock().unwrap())
537        }
538    }
539
540    #[async_trait]
541    impl Catalog for MockCatalog {
542        async fn list_databases(&self) -> paimon::Result<Vec<String>> {
543            Ok(vec![])
544        }
545        async fn create_database(
546            &self,
547            _name: &str,
548            _ignore_if_exists: bool,
549            _properties: HashMap<String, String>,
550        ) -> paimon::Result<()> {
551            Ok(())
552        }
553        async fn get_database(&self, _name: &str) -> paimon::Result<Database> {
554            unimplemented!()
555        }
556        async fn drop_database(
557            &self,
558            _name: &str,
559            _ignore_if_not_exists: bool,
560            _cascade: bool,
561        ) -> paimon::Result<()> {
562            Ok(())
563        }
564        async fn get_table(&self, _identifier: &Identifier) -> paimon::Result<Table> {
565            unimplemented!()
566        }
567        async fn list_tables(&self, _database_name: &str) -> paimon::Result<Vec<String>> {
568            Ok(vec![])
569        }
570        async fn create_table(
571            &self,
572            identifier: &Identifier,
573            creation: PaimonSchema,
574            ignore_if_exists: bool,
575        ) -> paimon::Result<()> {
576            self.calls.lock().unwrap().push(CatalogCall::CreateTable {
577                identifier: identifier.clone(),
578                schema: creation,
579                ignore_if_exists,
580            });
581            Ok(())
582        }
583        async fn drop_table(
584            &self,
585            _identifier: &Identifier,
586            _ignore_if_not_exists: bool,
587        ) -> paimon::Result<()> {
588            Ok(())
589        }
590        async fn rename_table(
591            &self,
592            from: &Identifier,
593            to: &Identifier,
594            ignore_if_not_exists: bool,
595        ) -> paimon::Result<()> {
596            self.calls.lock().unwrap().push(CatalogCall::RenameTable {
597                from: from.clone(),
598                to: to.clone(),
599                ignore_if_not_exists,
600            });
601            Ok(())
602        }
603        async fn alter_table(
604            &self,
605            identifier: &Identifier,
606            changes: Vec<SchemaChange>,
607            ignore_if_not_exists: bool,
608        ) -> paimon::Result<()> {
609            self.calls.lock().unwrap().push(CatalogCall::AlterTable {
610                identifier: identifier.clone(),
611                changes,
612                ignore_if_not_exists,
613            });
614            Ok(())
615        }
616    }
617
618    fn make_handler(catalog: Arc<MockCatalog>) -> PaimonSqlHandler {
619        PaimonSqlHandler::new(SessionContext::new(), catalog, "paimon")
620    }
621
622    // ==================== sql_data_type_to_arrow tests ====================
623
624    #[test]
625    fn test_sql_type_boolean() {
626        use datafusion::sql::sqlparser::ast::DataType as SqlType;
627        assert_eq!(
628            sql_data_type_to_arrow(&SqlType::Boolean).unwrap(),
629            ArrowDataType::Boolean
630        );
631    }
632
633    #[test]
634    fn test_sql_type_integers() {
635        use datafusion::sql::sqlparser::ast::DataType as SqlType;
636        assert_eq!(
637            sql_data_type_to_arrow(&SqlType::TinyInt(None)).unwrap(),
638            ArrowDataType::Int8
639        );
640        assert_eq!(
641            sql_data_type_to_arrow(&SqlType::SmallInt(None)).unwrap(),
642            ArrowDataType::Int16
643        );
644        assert_eq!(
645            sql_data_type_to_arrow(&SqlType::Int(None)).unwrap(),
646            ArrowDataType::Int32
647        );
648        assert_eq!(
649            sql_data_type_to_arrow(&SqlType::Integer(None)).unwrap(),
650            ArrowDataType::Int32
651        );
652        assert_eq!(
653            sql_data_type_to_arrow(&SqlType::BigInt(None)).unwrap(),
654            ArrowDataType::Int64
655        );
656    }
657
658    #[test]
659    fn test_sql_type_floats() {
660        use datafusion::sql::sqlparser::ast::{DataType as SqlType, ExactNumberInfo};
661        assert_eq!(
662            sql_data_type_to_arrow(&SqlType::Float(ExactNumberInfo::None)).unwrap(),
663            ArrowDataType::Float32
664        );
665        assert_eq!(
666            sql_data_type_to_arrow(&SqlType::Real).unwrap(),
667            ArrowDataType::Float32
668        );
669        assert_eq!(
670            sql_data_type_to_arrow(&SqlType::DoublePrecision).unwrap(),
671            ArrowDataType::Float64
672        );
673    }
674
675    #[test]
676    fn test_sql_type_string_variants() {
677        use datafusion::sql::sqlparser::ast::DataType as SqlType;
678        for sql_type in [SqlType::Varchar(None), SqlType::Text, SqlType::String(None)] {
679            assert_eq!(
680                sql_data_type_to_arrow(&sql_type).unwrap(),
681                ArrowDataType::Utf8,
682                "failed for {sql_type:?}"
683            );
684        }
685    }
686
687    #[test]
688    fn test_sql_type_binary() {
689        use datafusion::sql::sqlparser::ast::DataType as SqlType;
690        assert_eq!(
691            sql_data_type_to_arrow(&SqlType::Bytea).unwrap(),
692            ArrowDataType::Binary
693        );
694    }
695
696    #[test]
697    fn test_sql_type_date() {
698        use datafusion::sql::sqlparser::ast::DataType as SqlType;
699        assert_eq!(
700            sql_data_type_to_arrow(&SqlType::Date).unwrap(),
701            ArrowDataType::Date32
702        );
703    }
704
705    #[test]
706    fn test_sql_type_timestamp_default() {
707        use datafusion::sql::sqlparser::ast::{DataType as SqlType, TimezoneInfo};
708        let result = sql_data_type_to_arrow(&SqlType::Timestamp(None, TimezoneInfo::None)).unwrap();
709        assert_eq!(
710            result,
711            ArrowDataType::Timestamp(TimeUnit::Millisecond, None)
712        );
713    }
714
715    #[test]
716    fn test_sql_type_timestamp_with_precision() {
717        use datafusion::sql::sqlparser::ast::{DataType as SqlType, TimezoneInfo};
718        // precision 0 => Second
719        assert_eq!(
720            sql_data_type_to_arrow(&SqlType::Timestamp(Some(0), TimezoneInfo::None)).unwrap(),
721            ArrowDataType::Timestamp(TimeUnit::Second, None)
722        );
723        // precision 3 => Millisecond
724        assert_eq!(
725            sql_data_type_to_arrow(&SqlType::Timestamp(Some(3), TimezoneInfo::None)).unwrap(),
726            ArrowDataType::Timestamp(TimeUnit::Millisecond, None)
727        );
728        // precision 6 => Microsecond
729        assert_eq!(
730            sql_data_type_to_arrow(&SqlType::Timestamp(Some(6), TimezoneInfo::None)).unwrap(),
731            ArrowDataType::Timestamp(TimeUnit::Microsecond, None)
732        );
733        // precision 9 => Nanosecond
734        assert_eq!(
735            sql_data_type_to_arrow(&SqlType::Timestamp(Some(9), TimezoneInfo::None)).unwrap(),
736            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
737        );
738    }
739
740    #[test]
741    fn test_sql_type_timestamp_with_tz() {
742        use datafusion::sql::sqlparser::ast::{DataType as SqlType, TimezoneInfo};
743        let result =
744            sql_data_type_to_arrow(&SqlType::Timestamp(None, TimezoneInfo::WithTimeZone)).unwrap();
745        assert_eq!(
746            result,
747            ArrowDataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into()))
748        );
749    }
750
751    #[test]
752    fn test_sql_type_decimal() {
753        use datafusion::sql::sqlparser::ast::{DataType as SqlType, ExactNumberInfo};
754        assert_eq!(
755            sql_data_type_to_arrow(&SqlType::Decimal(ExactNumberInfo::PrecisionAndScale(18, 2)))
756                .unwrap(),
757            ArrowDataType::Decimal128(18, 2)
758        );
759        assert_eq!(
760            sql_data_type_to_arrow(&SqlType::Decimal(ExactNumberInfo::Precision(10))).unwrap(),
761            ArrowDataType::Decimal128(10, 0)
762        );
763        assert_eq!(
764            sql_data_type_to_arrow(&SqlType::Decimal(ExactNumberInfo::None)).unwrap(),
765            ArrowDataType::Decimal128(10, 0)
766        );
767    }
768
769    #[test]
770    fn test_sql_type_unsupported() {
771        use datafusion::sql::sqlparser::ast::DataType as SqlType;
772        assert!(sql_data_type_to_arrow(&SqlType::Regclass).is_err());
773    }
774
775    #[test]
776    fn test_sql_type_array() {
777        use datafusion::sql::sqlparser::ast::{ArrayElemTypeDef, DataType as SqlType};
778        let result = sql_data_type_to_arrow(&SqlType::Array(ArrayElemTypeDef::AngleBracket(
779            Box::new(SqlType::Int(None)),
780        )))
781        .unwrap();
782        assert_eq!(
783            result,
784            ArrowDataType::List(Arc::new(Field::new("element", ArrowDataType::Int32, true)))
785        );
786    }
787
788    #[test]
789    fn test_sql_type_array_no_element() {
790        use datafusion::sql::sqlparser::ast::{ArrayElemTypeDef, DataType as SqlType};
791        assert!(sql_data_type_to_arrow(&SqlType::Array(ArrayElemTypeDef::None)).is_err());
792    }
793
794    #[test]
795    fn test_sql_type_map() {
796        use datafusion::sql::sqlparser::ast::DataType as SqlType;
797        let result = sql_data_type_to_arrow(&SqlType::Map(
798            Box::new(SqlType::Varchar(None)),
799            Box::new(SqlType::Int(None)),
800        ))
801        .unwrap();
802        let expected = ArrowDataType::Map(
803            Arc::new(Field::new(
804                "entries",
805                ArrowDataType::Struct(
806                    vec![
807                        Field::new("key", ArrowDataType::Utf8, false),
808                        Field::new("value", ArrowDataType::Int32, true),
809                    ]
810                    .into(),
811                ),
812                false,
813            )),
814            false,
815        );
816        assert_eq!(result, expected);
817    }
818
819    #[test]
820    fn test_sql_type_struct() {
821        use datafusion::sql::sqlparser::ast::{
822            DataType as SqlType, Ident, StructBracketKind, StructField,
823        };
824        let result = sql_data_type_to_arrow(&SqlType::Struct(
825            vec![
826                StructField {
827                    field_name: Some(Ident::new("name")),
828                    field_type: SqlType::Varchar(None),
829                    options: None,
830                },
831                StructField {
832                    field_name: Some(Ident::new("age")),
833                    field_type: SqlType::Int(None),
834                    options: None,
835                },
836            ],
837            StructBracketKind::AngleBrackets,
838        ))
839        .unwrap();
840        assert_eq!(
841            result,
842            ArrowDataType::Struct(
843                vec![
844                    Field::new("name", ArrowDataType::Utf8, true),
845                    Field::new("age", ArrowDataType::Int32, true),
846                ]
847                .into()
848            )
849        );
850    }
851
852    // ==================== resolve_table_name tests ====================
853
854    #[test]
855    fn test_resolve_three_part_name() {
856        let catalog = Arc::new(MockCatalog::new());
857        let handler = make_handler(catalog);
858        let dialect = GenericDialect {};
859        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM paimon.mydb.mytable").unwrap();
860        if let Statement::Query(q) = &stmts[0] {
861            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
862                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
863                    &sel.from[0].relation
864                {
865                    let id = handler.resolve_table_name(name).unwrap();
866                    assert_eq!(id.database(), "mydb");
867                    assert_eq!(id.object(), "mytable");
868                }
869            }
870        }
871    }
872
873    #[test]
874    fn test_resolve_two_part_name() {
875        let catalog = Arc::new(MockCatalog::new());
876        let handler = make_handler(catalog);
877        let dialect = GenericDialect {};
878        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM mydb.mytable").unwrap();
879        if let Statement::Query(q) = &stmts[0] {
880            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
881                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
882                    &sel.from[0].relation
883                {
884                    let id = handler.resolve_table_name(name).unwrap();
885                    assert_eq!(id.database(), "mydb");
886                    assert_eq!(id.object(), "mytable");
887                }
888            }
889        }
890    }
891
892    #[test]
893    fn test_resolve_wrong_catalog_name() {
894        let catalog = Arc::new(MockCatalog::new());
895        let handler = make_handler(catalog);
896        let dialect = GenericDialect {};
897        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM other.mydb.mytable").unwrap();
898        if let Statement::Query(q) = &stmts[0] {
899            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
900                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
901                    &sel.from[0].relation
902                {
903                    let err = handler.resolve_table_name(name).unwrap_err();
904                    assert!(err.to_string().contains("Unknown catalog"));
905                }
906            }
907        }
908    }
909
910    #[test]
911    fn test_resolve_single_part_name_error() {
912        let catalog = Arc::new(MockCatalog::new());
913        let handler = make_handler(catalog);
914        let dialect = GenericDialect {};
915        let stmts = Parser::parse_sql(&dialect, "SELECT * FROM mytable").unwrap();
916        if let Statement::Query(q) = &stmts[0] {
917            if let datafusion::sql::sqlparser::ast::SetExpr::Select(sel) = q.body.as_ref() {
918                if let datafusion::sql::sqlparser::ast::TableFactor::Table { name, .. } =
919                    &sel.from[0].relation
920                {
921                    let err = handler.resolve_table_name(name).unwrap_err();
922                    assert!(err.to_string().contains("at least database.table"));
923                }
924            }
925        }
926    }
927
928    // ==================== extract_options tests ====================
929
930    #[test]
931    fn test_extract_options_none() {
932        let opts = extract_options(&CreateTableOptions::None).unwrap();
933        assert!(opts.is_empty());
934    }
935
936    #[test]
937    fn test_extract_options_with_kv() {
938        // Parse a CREATE TABLE with WITH options to get a real CreateTableOptions
939        let dialect = GenericDialect {};
940        let stmts =
941            Parser::parse_sql(&dialect, "CREATE TABLE t (id INT) WITH ('bucket' = '4')").unwrap();
942        if let Statement::CreateTable(ct) = &stmts[0] {
943            let opts = extract_options(&ct.table_options).unwrap();
944            assert_eq!(opts.len(), 1);
945            assert_eq!(opts[0].0, "bucket");
946            assert_eq!(opts[0].1, "4");
947        } else {
948            panic!("expected CreateTable");
949        }
950    }
951
952    // ==================== PaimonSqlHandler::sql integration tests ====================
953
954    #[tokio::test]
955    async fn test_create_table_basic() {
956        let catalog = Arc::new(MockCatalog::new());
957        let handler = make_handler(catalog.clone());
958
959        handler
960            .sql("CREATE TABLE mydb.t1 (id INT NOT NULL, name VARCHAR, PRIMARY KEY (id))")
961            .await
962            .unwrap();
963
964        let calls = catalog.take_calls();
965        assert_eq!(calls.len(), 1);
966        if let CatalogCall::CreateTable {
967            identifier,
968            schema,
969            ignore_if_exists,
970        } = &calls[0]
971        {
972            assert_eq!(identifier.database(), "mydb");
973            assert_eq!(identifier.object(), "t1");
974            assert!(!ignore_if_exists);
975            assert_eq!(schema.primary_keys(), &["id"]);
976        } else {
977            panic!("expected CreateTable call");
978        }
979    }
980
981    #[tokio::test]
982    async fn test_create_table_if_not_exists() {
983        let catalog = Arc::new(MockCatalog::new());
984        let handler = make_handler(catalog.clone());
985
986        handler
987            .sql("CREATE TABLE IF NOT EXISTS mydb.t1 (id INT)")
988            .await
989            .unwrap();
990
991        let calls = catalog.take_calls();
992        assert_eq!(calls.len(), 1);
993        if let CatalogCall::CreateTable {
994            ignore_if_exists, ..
995        } = &calls[0]
996        {
997            assert!(ignore_if_exists);
998        } else {
999            panic!("expected CreateTable call");
1000        }
1001    }
1002
1003    #[tokio::test]
1004    async fn test_create_table_with_options() {
1005        let catalog = Arc::new(MockCatalog::new());
1006        let handler = make_handler(catalog.clone());
1007
1008        handler
1009            .sql("CREATE TABLE mydb.t1 (id INT) WITH ('bucket' = '4', 'file.format' = 'parquet')")
1010            .await
1011            .unwrap();
1012
1013        let calls = catalog.take_calls();
1014        assert_eq!(calls.len(), 1);
1015        if let CatalogCall::CreateTable { schema, .. } = &calls[0] {
1016            let opts = schema.options();
1017            assert_eq!(opts.get("bucket").unwrap(), "4");
1018            assert_eq!(opts.get("file.format").unwrap(), "parquet");
1019        } else {
1020            panic!("expected CreateTable call");
1021        }
1022    }
1023
1024    #[tokio::test]
1025    async fn test_create_table_three_part_name() {
1026        let catalog = Arc::new(MockCatalog::new());
1027        let handler = make_handler(catalog.clone());
1028
1029        handler
1030            .sql("CREATE TABLE paimon.mydb.t1 (id INT)")
1031            .await
1032            .unwrap();
1033
1034        let calls = catalog.take_calls();
1035        if let CatalogCall::CreateTable { identifier, .. } = &calls[0] {
1036            assert_eq!(identifier.database(), "mydb");
1037            assert_eq!(identifier.object(), "t1");
1038        } else {
1039            panic!("expected CreateTable call");
1040        }
1041    }
1042
1043    #[tokio::test]
1044    async fn test_alter_table_add_column() {
1045        let catalog = Arc::new(MockCatalog::new());
1046        let handler = make_handler(catalog.clone());
1047
1048        handler
1049            .sql("ALTER TABLE mydb.t1 ADD COLUMN age INT")
1050            .await
1051            .unwrap();
1052
1053        let calls = catalog.take_calls();
1054        assert_eq!(calls.len(), 1);
1055        if let CatalogCall::AlterTable {
1056            identifier,
1057            changes,
1058            ..
1059        } = &calls[0]
1060        {
1061            assert_eq!(identifier.database(), "mydb");
1062            assert_eq!(identifier.object(), "t1");
1063            assert_eq!(changes.len(), 1);
1064            assert!(
1065                matches!(&changes[0], SchemaChange::AddColumn { field_name, .. } if field_name == "age")
1066            );
1067        } else {
1068            panic!("expected AlterTable call");
1069        }
1070    }
1071
1072    #[tokio::test]
1073    async fn test_alter_table_drop_column() {
1074        let catalog = Arc::new(MockCatalog::new());
1075        let handler = make_handler(catalog.clone());
1076
1077        handler
1078            .sql("ALTER TABLE mydb.t1 DROP COLUMN age")
1079            .await
1080            .unwrap();
1081
1082        let calls = catalog.take_calls();
1083        assert_eq!(calls.len(), 1);
1084        if let CatalogCall::AlterTable { changes, .. } = &calls[0] {
1085            assert_eq!(changes.len(), 1);
1086            assert!(
1087                matches!(&changes[0], SchemaChange::DropColumn { field_name } if field_name == "age")
1088            );
1089        } else {
1090            panic!("expected AlterTable call");
1091        }
1092    }
1093
1094    #[tokio::test]
1095    async fn test_alter_table_rename_column() {
1096        let catalog = Arc::new(MockCatalog::new());
1097        let handler = make_handler(catalog.clone());
1098
1099        handler
1100            .sql("ALTER TABLE mydb.t1 RENAME COLUMN old_name TO new_name")
1101            .await
1102            .unwrap();
1103
1104        let calls = catalog.take_calls();
1105        assert_eq!(calls.len(), 1);
1106        if let CatalogCall::AlterTable { changes, .. } = &calls[0] {
1107            assert_eq!(changes.len(), 1);
1108            assert!(matches!(
1109                &changes[0],
1110                SchemaChange::RenameColumn { field_name, new_name }
1111                    if field_name == "old_name" && new_name == "new_name"
1112            ));
1113        } else {
1114            panic!("expected AlterTable call");
1115        }
1116    }
1117
1118    #[tokio::test]
1119    async fn test_alter_table_rename_table() {
1120        let catalog = Arc::new(MockCatalog::new());
1121        let handler = make_handler(catalog.clone());
1122
1123        handler
1124            .sql("ALTER TABLE mydb.t1 RENAME TO t2")
1125            .await
1126            .unwrap();
1127
1128        let calls = catalog.take_calls();
1129        assert_eq!(calls.len(), 1);
1130        if let CatalogCall::RenameTable { from, to, .. } = &calls[0] {
1131            assert_eq!(from.database(), "mydb");
1132            assert_eq!(from.object(), "t1");
1133            assert_eq!(to.database(), "mydb");
1134            assert_eq!(to.object(), "t2");
1135        } else {
1136            panic!("expected RenameTable call");
1137        }
1138    }
1139
1140    #[tokio::test]
1141    async fn test_alter_table_if_exists_add_column() {
1142        let catalog = Arc::new(MockCatalog::new());
1143        let handler = make_handler(catalog.clone());
1144
1145        handler
1146            .sql("ALTER TABLE IF EXISTS mydb.t1 ADD COLUMN age INT")
1147            .await
1148            .unwrap();
1149
1150        let calls = catalog.take_calls();
1151        assert_eq!(calls.len(), 1);
1152        if let CatalogCall::AlterTable {
1153            ignore_if_not_exists,
1154            ..
1155        } = &calls[0]
1156        {
1157            assert!(ignore_if_not_exists);
1158        } else {
1159            panic!("expected AlterTable call");
1160        }
1161    }
1162
1163    #[tokio::test]
1164    async fn test_alter_table_without_if_exists() {
1165        let catalog = Arc::new(MockCatalog::new());
1166        let handler = make_handler(catalog.clone());
1167
1168        handler
1169            .sql("ALTER TABLE mydb.t1 ADD COLUMN age INT")
1170            .await
1171            .unwrap();
1172
1173        let calls = catalog.take_calls();
1174        if let CatalogCall::AlterTable {
1175            ignore_if_not_exists,
1176            ..
1177        } = &calls[0]
1178        {
1179            assert!(!ignore_if_not_exists);
1180        } else {
1181            panic!("expected AlterTable call");
1182        }
1183    }
1184
1185    #[tokio::test]
1186    async fn test_alter_table_if_exists_rename() {
1187        let catalog = Arc::new(MockCatalog::new());
1188        let handler = make_handler(catalog.clone());
1189
1190        handler
1191            .sql("ALTER TABLE IF EXISTS mydb.t1 RENAME TO t2")
1192            .await
1193            .unwrap();
1194
1195        let calls = catalog.take_calls();
1196        assert_eq!(calls.len(), 1);
1197        if let CatalogCall::RenameTable {
1198            from,
1199            to,
1200            ignore_if_not_exists,
1201        } = &calls[0]
1202        {
1203            assert!(ignore_if_not_exists);
1204            assert_eq!(from.object(), "t1");
1205            assert_eq!(to.object(), "t2");
1206        } else {
1207            panic!("expected RenameTable call");
1208        }
1209    }
1210
1211    #[tokio::test]
1212    async fn test_alter_table_rename_three_part_name() {
1213        let catalog = Arc::new(MockCatalog::new());
1214        let handler = make_handler(catalog.clone());
1215
1216        handler
1217            .sql("ALTER TABLE paimon.mydb.t1 RENAME TO t2")
1218            .await
1219            .unwrap();
1220
1221        let calls = catalog.take_calls();
1222        assert_eq!(calls.len(), 1);
1223        if let CatalogCall::RenameTable { from, to, .. } = &calls[0] {
1224            assert_eq!(from.database(), "mydb");
1225            assert_eq!(from.object(), "t1");
1226            assert_eq!(to.database(), "mydb");
1227            assert_eq!(to.object(), "t2");
1228        } else {
1229            panic!("expected RenameTable call");
1230        }
1231    }
1232
1233    #[tokio::test]
1234    async fn test_sql_parse_error() {
1235        let catalog = Arc::new(MockCatalog::new());
1236        let handler = make_handler(catalog);
1237        let result = handler.sql("NOT VALID SQL !!!").await;
1238        assert!(result.is_err());
1239        assert!(result.unwrap_err().to_string().contains("SQL parse error"));
1240    }
1241
1242    #[tokio::test]
1243    async fn test_multiple_statements_error() {
1244        let catalog = Arc::new(MockCatalog::new());
1245        let handler = make_handler(catalog);
1246        let result = handler.sql("SELECT 1; SELECT 2").await;
1247        assert!(result.is_err());
1248        assert!(result
1249            .unwrap_err()
1250            .to_string()
1251            .contains("exactly one SQL statement"));
1252    }
1253
1254    #[tokio::test]
1255    async fn test_create_external_table_rejected() {
1256        let catalog = Arc::new(MockCatalog::new());
1257        let handler = make_handler(catalog);
1258        let result = handler
1259            .sql("CREATE EXTERNAL TABLE mydb.t1 (id INT) STORED AS PARQUET")
1260            .await;
1261        assert!(result.is_err());
1262        assert!(result
1263            .unwrap_err()
1264            .to_string()
1265            .contains("CREATE EXTERNAL TABLE is not supported"));
1266    }
1267
1268    #[tokio::test]
1269    async fn test_non_ddl_delegates_to_datafusion() {
1270        let catalog = Arc::new(MockCatalog::new());
1271        let handler = make_handler(catalog.clone());
1272        // SELECT should be delegated to DataFusion, not intercepted
1273        let df = handler.sql("SELECT 1 AS x").await.unwrap();
1274        let batches = df.collect().await.unwrap();
1275        assert_eq!(batches.len(), 1);
1276        assert_eq!(batches[0].num_rows(), 1);
1277        // No catalog calls
1278        assert!(catalog.take_calls().is_empty());
1279    }
1280}