// datafusion_table_providers/sql/arrow_sql_gen/statement.rs

1use bigdecimal::BigDecimal;
2use chrono::{DateTime, Offset, TimeZone};
3use datafusion::arrow::{
4    array::{
5        array, timezone::Tz, Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
6        Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray,
7        StringViewArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
8    },
9    datatypes::{DataType, Field, Fields, IntervalUnit, Schema, SchemaRef, TimeUnit},
10    util::display::array_value_to_string,
11};
12use datafusion::sql::TableReference;
13use num_bigint::BigInt;
14use sea_query::{
15    Alias, ColumnDef, ColumnType, Expr, GenericBuilder, Index, InsertStatement, IntoIden,
16    IntoIndexColumn, Keyword, MysqlQueryBuilder, OnConflict, PostgresQueryBuilder, Query,
17    QueryBuilder, SeaRc, SimpleExpr, SqliteQueryBuilder, Table, TableRef,
18};
19use snafu::Snafu;
20use std::{str::FromStr, sync::Arc};
21use time::{OffsetDateTime, PrimitiveDateTime};
22
23#[derive(Debug, Snafu)]
24pub enum Error {
25    #[snafu(display("Failed to build insert statement: {source}"))]
26    FailedToCreateInsertStatement {
27        source: Box<dyn std::error::Error + Send + Sync>,
28    },
29
30    #[snafu(display("Unimplemented data type in insert statement: {data_type:?}"))]
31    UnimplementedDataTypeInInsertStatement { data_type: DataType },
32}
33
34pub type Result<T, E = Error> = std::result::Result<T, E>;
35
36pub struct CreateTableBuilder {
37    schema: SchemaRef,
38    table_name: String,
39    primary_keys: Vec<String>,
40    temporary: bool,
41}
42
43impl CreateTableBuilder {
44    #[must_use]
45    pub fn new(schema: SchemaRef, table_name: &str) -> Self {
46        Self {
47            schema,
48            table_name: table_name.to_string(),
49            primary_keys: Vec::new(),
50            temporary: false,
51        }
52    }
53
54    #[must_use]
55    pub fn primary_keys<T>(mut self, keys: Vec<T>) -> Self
56    where
57        T: Into<String>,
58    {
59        self.primary_keys = keys.into_iter().map(Into::into).collect();
60        self
61    }
62
63    #[must_use]
64    /// Set whether the table is temporary or not.
65    pub fn temporary(mut self, temporary: bool) -> Self {
66        self.temporary = temporary;
67        self
68    }
69
70    #[must_use]
71    #[cfg(feature = "postgres")]
72    pub fn build_postgres(self) -> Vec<String> {
73        use crate::sql::arrow_sql_gen::postgres::{
74            builder::TypeBuilder, get_postgres_composite_type_name,
75            map_data_type_to_column_type_postgres,
76        };
77        let schema = Arc::clone(&self.schema);
78        let table_name = self.table_name.clone();
79        let main_table_creation =
80            self.build(PostgresQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
81                map_data_type_to_column_type_postgres(f.data_type(), &table_name, f.name())
82            });
83
84        // Postgres supports composite types (i.e. Structs) but needs to have the type defined first
85        // https://www.postgresql.org/docs/current/rowtypes.html
86        let mut creation_stmts = Vec::new();
87        for field in schema.fields() {
88            let DataType::Struct(struct_inner_fields) = field.data_type() else {
89                continue;
90            };
91            let type_builder = TypeBuilder::new(
92                get_postgres_composite_type_name(&table_name, field.name()),
93                struct_inner_fields,
94            );
95            creation_stmts.push(type_builder.build());
96        }
97
98        creation_stmts.push(main_table_creation);
99        creation_stmts
100    }
101
102    #[must_use]
103    pub fn build_sqlite(self) -> String {
104        self.build(SqliteQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
105            // Sqlite does not natively support Arrays, Structs, etc
106            // so we use JSON column type for List, FixedSizeList, LargeList, Struct, etc
107            if f.data_type().is_nested() {
108                return ColumnType::JsonBinary;
109            }
110
111            map_data_type_to_column_type(f.data_type())
112        })
113    }
114
115    #[must_use]
116    pub fn build_mysql(self) -> String {
117        self.build(MysqlQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
118            // MySQL does not natively support Arrays, Structs, etc
119            // so we use JSON column type for List, FixedSizeList, LargeList, Struct, etc
120            if f.data_type().is_nested() {
121                return ColumnType::JsonBinary;
122            }
123            map_data_type_to_column_type(f.data_type())
124        })
125    }
126
127    #[must_use]
128    fn build<T: GenericBuilder>(
129        self,
130        query_builder: T,
131        map_data_type_to_column_type_fn: &dyn Fn(&Arc<Field>) -> ColumnType,
132    ) -> String {
133        let mut create_stmt = Table::create();
134        create_stmt
135            .table(Alias::new(self.table_name.clone()))
136            .if_not_exists();
137
138        for field in self.schema.fields() {
139            let column_type = map_data_type_to_column_type_fn(field);
140            let mut column_def = ColumnDef::new_with_type(Alias::new(field.name()), column_type);
141            if !field.is_nullable() {
142                column_def.not_null();
143            }
144
145            create_stmt.col(&mut column_def);
146        }
147
148        if !self.primary_keys.is_empty() {
149            let mut index = Index::create();
150            index.primary();
151            for key in self.primary_keys {
152                index.col(Alias::new(key).into_iden().into_index_column());
153            }
154            create_stmt.primary_key(&mut index);
155        }
156
157        if self.temporary {
158            create_stmt.temporary();
159        }
160
161        create_stmt.to_string(query_builder)
162    }
163}
164
/// Appends the value at `$row` of `$column` (downcast to `array::$array_type`) to `$row_values`.
///
/// - A null slot appends SQL `NULL` and then `continue`s the enclosing loop. The macro must
///   therefore be expanded inside a loop; in this file that is the per-column loop.
/// - If the downcast fails, nothing is appended.
macro_rules! push_value {
    ($row_values:expr, $column:expr, $row:expr, $array_type:ident) => {{
        match $column.as_any().downcast_ref::<array::$array_type>() {
            Some(typed_array) if typed_array.is_null($row) => {
                $row_values.push(Keyword::Null.into());
                continue;
            }
            Some(typed_array) => $row_values.push(typed_array.value($row).into()),
            None => {}
        }
    }};
}
177
/// Appends one list value to `$row_values` as an array expression cast to `$sql_type`.
///
/// `$list_array` is expected to hold one row's list contents (presumably the per-row slice taken
/// from a List/LargeList/FixedSizeList column; verify at the call site).
///
/// - It is evaluated once and downcast once to `$array_type`.
/// - Each element is read with `value(i)` into a `Vec<$vec_type>`, which is converted into a
///   [`SimpleExpr`] and cast to `$sql_type`.
/// - The cast is always applied because SeaQuery does not handle an empty, untyped array.
/// - If the downcast fails, an empty (still cast) array is appended.
///
/// `$data_type` is accepted for call-site compatibility but is not used.
///
/// NOTE(review): `value(i)` does not consult the validity bitmap. For Arrow arrays, null
/// elements inside the list are therefore emitted as whatever the value buffer holds, not as
/// SQL `NULL`.
macro_rules! push_list_values {
    ($data_type:expr, $list_array:expr, $row_values:expr, $array_type:ty, $vec_type:ty, $sql_type:expr) => {{
        // Evaluate the list expression and downcast it once, instead of once per element.
        let source_array = &$list_array;
        let mut list_values: Vec<$vec_type> = Vec::new();
        if let Some(typed_array) = source_array.as_any().downcast_ref::<$array_type>() {
            let len = source_array.len();
            list_values.reserve(len);
            for i in 0..len {
                list_values.push(typed_array.value(i));
            }
        }
        let expr: SimpleExpr = list_values.into();
        // We must cast here in case the array is empty which SeaQuery does not handle.
        $row_values.push(expr.cast_as(Alias::new($sql_type)));
    }};
}
192
193pub struct InsertBuilder {
194    table: TableReference,
195    record_batches: Vec<RecordBatch>,
196}
197
198pub fn use_json_insert_for_type<T: QueryBuilder + 'static>(
199    #[allow(unused_variables)] data_type: &DataType,
200    #[allow(unused_variables)] query_builder: &T,
201) -> bool {
202    #[cfg(feature = "sqlite")]
203    {
204        use std::any::Any;
205        let any_builder = query_builder as &dyn Any;
206        if any_builder.is::<SqliteQueryBuilder>() {
207            return data_type.is_nested();
208        }
209    }
210    #[cfg(feature = "mysql")]
211    {
212        use std::any::Any;
213        let any_builder = query_builder as &dyn Any;
214        if any_builder.is::<MysqlQueryBuilder>() {
215            return data_type.is_nested();
216        }
217    }
218    false
219}
220
221impl InsertBuilder {
222    #[must_use]
223    pub fn new(table: &TableReference, record_batches: Vec<RecordBatch>) -> Self {
224        Self {
225            table: table.clone(),
226            record_batches,
227        }
228    }
229
230    /// Create an Insert statement from a `RecordBatch`.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if a column's data type is not supported, or its conversion failed.
235    #[allow(clippy::too_many_lines)]
236    pub fn construct_insert_stmt<T: QueryBuilder + 'static>(
237        &self,
238        insert_stmt: &mut InsertStatement,
239        record_batch: &RecordBatch,
240        query_builder: &T,
241    ) -> Result<()> {
242        for row in 0..record_batch.num_rows() {
243            let mut row_values: Vec<SimpleExpr> = vec![];
244            for col in 0..record_batch.num_columns() {
245                let column = record_batch.column(col);
246                let column_data_type = column.data_type();
247
248                match column_data_type {
249                    DataType::Int8 => push_value!(row_values, column, row, Int8Array),
250                    DataType::Int16 => push_value!(row_values, column, row, Int16Array),
251                    DataType::Int32 => push_value!(row_values, column, row, Int32Array),
252                    DataType::Int64 => push_value!(row_values, column, row, Int64Array),
253                    DataType::UInt8 => push_value!(row_values, column, row, UInt8Array),
254                    DataType::UInt16 => push_value!(row_values, column, row, UInt16Array),
255                    DataType::UInt32 => push_value!(row_values, column, row, UInt32Array),
256                    DataType::UInt64 => push_value!(row_values, column, row, UInt64Array),
257                    DataType::Float32 => push_value!(row_values, column, row, Float32Array),
258                    DataType::Float64 => push_value!(row_values, column, row, Float64Array),
259                    DataType::Utf8 => push_value!(row_values, column, row, StringArray),
260                    DataType::LargeUtf8 => push_value!(row_values, column, row, LargeStringArray),
261                    DataType::Utf8View => push_value!(row_values, column, row, StringViewArray),
262                    DataType::Boolean => push_value!(row_values, column, row, BooleanArray),
263                    DataType::Decimal128(_, scale) => {
264                        let array = column.as_any().downcast_ref::<array::Decimal128Array>();
265                        if let Some(valid_array) = array {
266                            if valid_array.is_null(row) {
267                                row_values.push(Keyword::Null.into());
268                                continue;
269                            }
270                            row_values.push(
271                                BigDecimal::new(valid_array.value(row).into(), i64::from(*scale))
272                                    .into(),
273                            );
274                        }
275                    }
276                    DataType::Decimal256(_, scale) => {
277                        let array = column.as_any().downcast_ref::<array::Decimal256Array>();
278                        if let Some(valid_array) = array {
279                            if valid_array.is_null(row) {
280                                row_values.push(Keyword::Null.into());
281                                continue;
282                            }
283
284                            let bigint =
285                                BigInt::from_signed_bytes_le(&valid_array.value(row).to_le_bytes());
286
287                            row_values.push(BigDecimal::new(bigint, i64::from(*scale)).into());
288                        }
289                    }
290                    DataType::Date32 => {
291                        let array = column.as_any().downcast_ref::<array::Date32Array>();
292                        if let Some(valid_array) = array {
293                            if valid_array.is_null(row) {
294                                row_values.push(Keyword::Null.into());
295                                continue;
296                            }
297                            row_values.push(
298                                match OffsetDateTime::from_unix_timestamp(
299                                    i64::from(valid_array.value(row)) * 86_400,
300                                ) {
301                                    Ok(offset_time) => offset_time.date().into(),
302                                    Err(e) => {
303                                        return Result::Err(Error::FailedToCreateInsertStatement {
304                                            source: Box::new(e),
305                                        })
306                                    }
307                                },
308                            );
309                        }
310                    }
311                    DataType::Date64 => {
312                        let array = column.as_any().downcast_ref::<array::Date64Array>();
313                        if let Some(valid_array) = array {
314                            if valid_array.is_null(row) {
315                                row_values.push(Keyword::Null.into());
316                                continue;
317                            }
318                            row_values.push(
319                                match OffsetDateTime::from_unix_timestamp(
320                                    valid_array.value(row) / 1000,
321                                ) {
322                                    Ok(offset_time) => offset_time.date().into(),
323                                    Err(e) => {
324                                        return Result::Err(Error::FailedToCreateInsertStatement {
325                                            source: Box::new(e),
326                                        })
327                                    }
328                                },
329                            );
330                        }
331                    }
332                    DataType::Duration(time_unit) => match time_unit {
333                        TimeUnit::Second => {
334                            push_value!(row_values, column, row, DurationSecondArray);
335                        }
336                        TimeUnit::Microsecond => {
337                            push_value!(row_values, column, row, DurationMicrosecondArray);
338                        }
339                        TimeUnit::Millisecond => {
340                            push_value!(row_values, column, row, DurationMillisecondArray);
341                        }
342                        TimeUnit::Nanosecond => {
343                            push_value!(row_values, column, row, DurationNanosecondArray);
344                        }
345                    },
346                    DataType::Time32(time_unit) => match time_unit {
347                        TimeUnit::Millisecond => {
348                            let array = column
349                                .as_any()
350                                .downcast_ref::<array::Time32MillisecondArray>();
351                            if let Some(valid_array) = array {
352                                if valid_array.is_null(row) {
353                                    row_values.push(Keyword::Null.into());
354                                    continue;
355                                }
356
357                                let (h, m, s, micro) =
358                                    match OffsetDateTime::from_unix_timestamp_nanos(
359                                        i128::from(valid_array.value(row)) * 1_000_000,
360                                    ) {
361                                        Ok(timestamp) => timestamp.to_hms_micro(),
362                                        Err(e) => {
363                                            return Result::Err(
364                                                Error::FailedToCreateInsertStatement {
365                                                    source: Box::new(e),
366                                                },
367                                            )
368                                        }
369                                    };
370
371                                let time = match time::Time::from_hms_micro(h, m, s, micro) {
372                                    Ok(value) => value,
373                                    Err(e) => {
374                                        return Result::Err(Error::FailedToCreateInsertStatement {
375                                            source: Box::new(e),
376                                        })
377                                    }
378                                };
379
380                                row_values.push(time.into());
381                            }
382                        }
383                        TimeUnit::Second => {
384                            let array = column.as_any().downcast_ref::<array::Time32SecondArray>();
385                            if let Some(valid_array) = array {
386                                if valid_array.is_null(row) {
387                                    row_values.push(Keyword::Null.into());
388                                    continue;
389                                }
390
391                                let (h, m, s) = match OffsetDateTime::from_unix_timestamp(
392                                    i64::from(valid_array.value(row)),
393                                ) {
394                                    Ok(timestamp) => timestamp.to_hms(),
395                                    Err(e) => {
396                                        return Result::Err(Error::FailedToCreateInsertStatement {
397                                            source: Box::new(e),
398                                        })
399                                    }
400                                };
401
402                                let time = match time::Time::from_hms(h, m, s) {
403                                    Ok(value) => value,
404                                    Err(e) => {
405                                        return Result::Err(Error::FailedToCreateInsertStatement {
406                                            source: Box::new(e),
407                                        })
408                                    }
409                                };
410
411                                row_values.push(time.into());
412                            }
413                        }
414                        _ => unreachable!(),
415                    },
416                    DataType::Time64(time_unit) => match time_unit {
417                        TimeUnit::Nanosecond => {
418                            let array = column
419                                .as_any()
420                                .downcast_ref::<array::Time64NanosecondArray>();
421                            if let Some(valid_array) = array {
422                                if valid_array.is_null(row) {
423                                    row_values.push(Keyword::Null.into());
424                                    continue;
425                                }
426                                let (h, m, s, nano) =
427                                    match OffsetDateTime::from_unix_timestamp_nanos(i128::from(
428                                        valid_array.value(row),
429                                    )) {
430                                        Ok(timestamp) => timestamp.to_hms_nano(),
431                                        Err(e) => {
432                                            return Result::Err(
433                                                Error::FailedToCreateInsertStatement {
434                                                    source: Box::new(e),
435                                                },
436                                            )
437                                        }
438                                    };
439
440                                let time = match time::Time::from_hms_nano(h, m, s, nano) {
441                                    Ok(value) => value,
442                                    Err(e) => {
443                                        return Result::Err(Error::FailedToCreateInsertStatement {
444                                            source: Box::new(e),
445                                        })
446                                    }
447                                };
448
449                                row_values.push(time.into());
450                            }
451                        }
452                        TimeUnit::Microsecond => {
453                            let array = column
454                                .as_any()
455                                .downcast_ref::<array::Time64MicrosecondArray>();
456                            if let Some(valid_array) = array {
457                                if valid_array.is_null(row) {
458                                    row_values.push(Keyword::Null.into());
459                                    continue;
460                                }
461
462                                let (h, m, s, micro) =
463                                    match OffsetDateTime::from_unix_timestamp_nanos(
464                                        i128::from(valid_array.value(row)) * 1_000,
465                                    ) {
466                                        Ok(timestamp) => timestamp.to_hms_micro(),
467                                        Err(e) => {
468                                            return Result::Err(
469                                                Error::FailedToCreateInsertStatement {
470                                                    source: Box::new(e),
471                                                },
472                                            )
473                                        }
474                                    };
475
476                                let time = match time::Time::from_hms_micro(h, m, s, micro) {
477                                    Ok(value) => value,
478                                    Err(e) => {
479                                        return Result::Err(Error::FailedToCreateInsertStatement {
480                                            source: Box::new(e),
481                                        })
482                                    }
483                                };
484
485                                row_values.push(time.into());
486                            }
487                        }
488                        _ => unreachable!(),
489                    },
490                    DataType::Timestamp(TimeUnit::Second, timezone) => {
491                        let array = column
492                            .as_any()
493                            .downcast_ref::<array::TimestampSecondArray>();
494
495                        if let Some(valid_array) = array {
496                            if valid_array.is_null(row) {
497                                row_values.push(Keyword::Null.into());
498                                continue;
499                            }
500                            if let Some(timezone) = timezone {
501                                let utc_time = DateTime::from_timestamp_nanos(
502                                    valid_array.value(row) * 1_000_000_000,
503                                )
504                                .to_utc();
505                                let timezone = Tz::from_str(timezone).map_err(|_| {
506                                    Error::FailedToCreateInsertStatement {
507                                        source: "Unable to parse arrow timezone information".into(),
508                                    }
509                                })?;
510                                let offset = timezone
511                                    .offset_from_utc_datetime(&utc_time.naive_utc())
512                                    .fix();
513                                let time_with_offset = utc_time.with_timezone(&offset);
514                                row_values.push(time_with_offset.into());
515                            } else {
516                                insert_timestamp_into_row_values(
517                                    OffsetDateTime::from_unix_timestamp(valid_array.value(row)),
518                                    &mut row_values,
519                                )?;
520                            }
521                        }
522                    }
523                    DataType::Timestamp(TimeUnit::Millisecond, timezone) => {
524                        let array = column
525                            .as_any()
526                            .downcast_ref::<array::TimestampMillisecondArray>();
527
528                        if let Some(valid_array) = array {
529                            if valid_array.is_null(row) {
530                                row_values.push(Keyword::Null.into());
531                                continue;
532                            }
533                            if let Some(timezone) = timezone {
534                                let utc_time = DateTime::from_timestamp_nanos(
535                                    valid_array.value(row) * 1_000_000,
536                                )
537                                .to_utc();
538                                let timezone = Tz::from_str(timezone).map_err(|_| {
539                                    Error::FailedToCreateInsertStatement {
540                                        source: "Unable to parse arrow timezone information".into(),
541                                    }
542                                })?;
543                                let offset = timezone
544                                    .offset_from_utc_datetime(&utc_time.naive_utc())
545                                    .fix();
546                                let time_with_offset = utc_time.with_timezone(&offset);
547                                row_values.push(time_with_offset.into());
548                            } else {
549                                insert_timestamp_into_row_values(
550                                    OffsetDateTime::from_unix_timestamp_nanos(
551                                        i128::from(valid_array.value(row)) * 1_000_000,
552                                    ),
553                                    &mut row_values,
554                                )?;
555                            }
556                        }
557                    }
558                    DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
559                        let array = column
560                            .as_any()
561                            .downcast_ref::<array::TimestampMicrosecondArray>();
562
563                        if let Some(valid_array) = array {
564                            if valid_array.is_null(row) {
565                                row_values.push(Keyword::Null.into());
566                                continue;
567                            }
568                            if let Some(timezone) = timezone {
569                                let utc_time =
570                                    DateTime::from_timestamp_nanos(valid_array.value(row) * 1_000)
571                                        .to_utc();
572                                let timezone = Tz::from_str(timezone).map_err(|_| {
573                                    Error::FailedToCreateInsertStatement {
574                                        source: "Unable to parse arrow timezone information".into(),
575                                    }
576                                })?;
577                                let offset = timezone
578                                    .offset_from_utc_datetime(&utc_time.naive_utc())
579                                    .fix();
580                                let time_with_offset = utc_time.with_timezone(&offset);
581                                row_values.push(time_with_offset.into());
582                            } else {
583                                insert_timestamp_into_row_values(
584                                    OffsetDateTime::from_unix_timestamp_nanos(
585                                        i128::from(valid_array.value(row)) * 1_000,
586                                    ),
587                                    &mut row_values,
588                                )?;
589                            }
590                        }
591                    }
592                    DataType::Timestamp(TimeUnit::Nanosecond, timezone) => {
593                        let array = column
594                            .as_any()
595                            .downcast_ref::<array::TimestampNanosecondArray>();
596
597                        if let Some(valid_array) = array {
598                            if valid_array.is_null(row) {
599                                row_values.push(Keyword::Null.into());
600                                continue;
601                            }
602                            if let Some(timezone) = timezone {
603                                let utc_time =
604                                    DateTime::from_timestamp_nanos(valid_array.value(row)).to_utc();
605                                let timezone = Tz::from_str(timezone).map_err(|_| {
606                                    Error::FailedToCreateInsertStatement {
607                                        source: "Unable to parse arrow timezone information".into(),
608                                    }
609                                })?;
610                                let offset = timezone
611                                    .offset_from_utc_datetime(&utc_time.naive_utc())
612                                    .fix();
613                                let time_with_offset = utc_time.with_timezone(&offset);
614                                row_values.push(time_with_offset.into());
615                            } else {
616                                insert_timestamp_into_row_values(
617                                    OffsetDateTime::from_unix_timestamp_nanos(i128::from(
618                                        valid_array.value(row),
619                                    )),
620                                    &mut row_values,
621                                )?;
622                            }
623                        }
624                    }
625                    DataType::List(list_type) => {
626                        let array = column.as_any().downcast_ref::<array::ListArray>();
627                        if let Some(valid_array) = array {
628                            if valid_array.is_null(row) {
629                                row_values.push(Keyword::Null.into());
630                                continue;
631                            }
632                            let list_array = valid_array.value(row);
633
634                            if use_json_insert_for_type(column_data_type, query_builder) {
635                                insert_list_into_row_values_json(
636                                    list_array,
637                                    list_type,
638                                    &mut row_values,
639                                )?;
640                            } else {
641                                insert_list_into_row_values(list_array, list_type, &mut row_values);
642                            }
643                        }
644                    }
645                    DataType::LargeList(list_type) => {
646                        let array = column.as_any().downcast_ref::<array::LargeListArray>();
647                        if let Some(valid_array) = array {
648                            if valid_array.is_null(row) {
649                                row_values.push(Keyword::Null.into());
650                                continue;
651                            }
652                            let list_array = valid_array.value(row);
653
654                            if use_json_insert_for_type(column_data_type, query_builder) {
655                                insert_list_into_row_values_json(
656                                    list_array,
657                                    list_type,
658                                    &mut row_values,
659                                )?;
660                            } else {
661                                insert_list_into_row_values(list_array, list_type, &mut row_values);
662                            }
663                        }
664                    }
665                    DataType::FixedSizeList(list_type, _) => {
666                        let array = column.as_any().downcast_ref::<array::FixedSizeListArray>();
667                        if let Some(valid_array) = array {
668                            if valid_array.is_null(row) {
669                                row_values.push(Keyword::Null.into());
670                                continue;
671                            }
672                            let list_array = valid_array.value(row);
673
674                            if use_json_insert_for_type(column_data_type, query_builder) {
675                                insert_list_into_row_values_json(
676                                    list_array,
677                                    list_type,
678                                    &mut row_values,
679                                )?;
680                            } else {
681                                insert_list_into_row_values(list_array, list_type, &mut row_values);
682                            }
683                        }
684                    }
685                    DataType::Binary => {
686                        let array = column.as_any().downcast_ref::<array::BinaryArray>();
687
688                        if let Some(valid_array) = array {
689                            if valid_array.is_null(row) {
690                                row_values.push(Keyword::Null.into());
691                                continue;
692                            }
693
694                            row_values.push(valid_array.value(row).into());
695                        }
696                    }
697                    DataType::LargeBinary => {
698                        let array = column.as_any().downcast_ref::<array::LargeBinaryArray>();
699
700                        if let Some(valid_array) = array {
701                            if valid_array.is_null(row) {
702                                row_values.push(Keyword::Null.into());
703                                continue;
704                            }
705
706                            row_values.push(valid_array.value(row).into());
707                        }
708                    }
709                    DataType::FixedSizeBinary(_) => {
710                        let array = column
711                            .as_any()
712                            .downcast_ref::<array::FixedSizeBinaryArray>();
713
714                        if let Some(valid_array) = array {
715                            if valid_array.is_null(row) {
716                                row_values.push(Keyword::Null.into());
717                                continue;
718                            }
719
720                            row_values.push(valid_array.value(row).into());
721                        }
722                    }
723                    DataType::BinaryView => {
724                        let array = column.as_any().downcast_ref::<array::BinaryViewArray>();
725
726                        if let Some(valid_array) = array {
727                            if valid_array.is_null(row) {
728                                row_values.push(Keyword::Null.into());
729                                continue;
730                            }
731
732                            row_values.push(valid_array.value(row).into());
733                        }
734                    }
735                    DataType::Interval(interval_unit) => match interval_unit {
736                        IntervalUnit::DayTime => {
737                            let array = column
738                                .as_any()
739                                .downcast_ref::<array::IntervalDayTimeArray>();
740
741                            if let Some(valid_array) = array {
742                                if valid_array.is_null(row) {
743                                    row_values.push(Keyword::Null.into());
744                                    continue;
745                                }
746
747                                let interval_str =
748                                    if let Ok(str) = array_value_to_string(valid_array, row) {
749                                        str
750                                    } else {
751                                        let days = valid_array.value(row).days;
752                                        let milliseconds = valid_array.value(row).milliseconds;
753                                        format!("{days} days {milliseconds} milliseconds")
754                                    };
755
756                                row_values.push(interval_str.into());
757                            }
758                        }
759                        IntervalUnit::YearMonth => {
760                            let array = column
761                                .as_any()
762                                .downcast_ref::<array::IntervalYearMonthArray>();
763
764                            if let Some(valid_array) = array {
765                                if valid_array.is_null(row) {
766                                    row_values.push(Keyword::Null.into());
767                                    continue;
768                                }
769
770                                let interval_str =
771                                    if let Ok(str) = array_value_to_string(valid_array, row) {
772                                        str
773                                    } else {
774                                        let months = valid_array.value(row);
775                                        format!("{months} months")
776                                    };
777
778                                row_values.push(interval_str.into());
779                            }
780                        }
781                        // The smallest unit in Postgres for interval is microsecond
782                        // Cast with loss of precision in nano second
783                        IntervalUnit::MonthDayNano => {
784                            let array = column
785                                .as_any()
786                                .downcast_ref::<array::IntervalMonthDayNanoArray>();
787
788                            if let Some(valid_array) = array {
789                                if valid_array.is_null(row) {
790                                    row_values.push(Keyword::Null.into());
791                                    continue;
792                                }
793
794                                let interval_str =
795                                    if let Ok(str) = array_value_to_string(valid_array, row) {
796                                        str
797                                    } else {
798                                        let months = valid_array.value(row).months;
799                                        let days = valid_array.value(row).days;
800                                        let nanoseconds = valid_array.value(row).nanoseconds;
801                                        let micros = nanoseconds / 1_000;
802                                        format!("{months} months {days} days {micros} microseconds")
803                                    };
804
805                                row_values.push(interval_str.into());
806                            }
807                        }
808                    },
809                    DataType::Struct(fields) => {
810                        let array = column.as_any().downcast_ref::<array::StructArray>();
811
812                        if let Some(valid_array) = array {
813                            if valid_array.is_null(row) {
814                                row_values.push(Keyword::Null.into());
815                                continue;
816                            }
817
818                            if use_json_insert_for_type(column_data_type, query_builder) {
819                                insert_struct_into_row_values_json(
820                                    fields,
821                                    valid_array,
822                                    row,
823                                    &mut row_values,
824                                )?;
825                                continue;
826                            }
827
828                            let mut param_values: Vec<SimpleExpr> = vec![];
829
830                            for col in valid_array.columns() {
831                                match col.data_type() {
832                                    DataType::Int8 => {
833                                        let int_array =
834                                            col.as_any().downcast_ref::<array::Int8Array>();
835
836                                        if let Some(valid_int_array) = int_array {
837                                            param_values.push(valid_int_array.value(row).into());
838                                        }
839                                    }
840                                    DataType::Int16 => {
841                                        let int_array =
842                                            col.as_any().downcast_ref::<array::Int16Array>();
843
844                                        if let Some(valid_int_array) = int_array {
845                                            param_values.push(valid_int_array.value(row).into());
846                                        }
847                                    }
848                                    DataType::Int32 => {
849                                        let int_array =
850                                            col.as_any().downcast_ref::<array::Int32Array>();
851
852                                        if let Some(valid_int_array) = int_array {
853                                            param_values.push(valid_int_array.value(row).into());
854                                        }
855                                    }
856                                    DataType::Int64 => {
857                                        let int_array =
858                                            col.as_any().downcast_ref::<array::Int64Array>();
859
860                                        if let Some(valid_int_array) = int_array {
861                                            param_values.push(valid_int_array.value(row).into());
862                                        }
863                                    }
864                                    DataType::UInt8 => {
865                                        let int_array =
866                                            col.as_any().downcast_ref::<array::UInt8Array>();
867
868                                        if let Some(valid_int_array) = int_array {
869                                            param_values.push(valid_int_array.value(row).into());
870                                        }
871                                    }
872                                    DataType::UInt16 => {
873                                        let int_array =
874                                            col.as_any().downcast_ref::<array::UInt16Array>();
875
876                                        if let Some(valid_int_array) = int_array {
877                                            param_values.push(valid_int_array.value(row).into());
878                                        }
879                                    }
880                                    DataType::UInt32 => {
881                                        let int_array =
882                                            col.as_any().downcast_ref::<array::UInt32Array>();
883
884                                        if let Some(valid_int_array) = int_array {
885                                            param_values.push(valid_int_array.value(row).into());
886                                        }
887                                    }
888                                    DataType::UInt64 => {
889                                        let int_array =
890                                            col.as_any().downcast_ref::<array::UInt64Array>();
891
892                                        if let Some(valid_int_array) = int_array {
893                                            param_values.push(valid_int_array.value(row).into());
894                                        }
895                                    }
896                                    DataType::Float32 => {
897                                        let float_array =
898                                            col.as_any().downcast_ref::<array::Float32Array>();
899
900                                        if let Some(valid_float_array) = float_array {
901                                            param_values.push(valid_float_array.value(row).into());
902                                        }
903                                    }
904                                    DataType::Float64 => {
905                                        let float_array =
906                                            col.as_any().downcast_ref::<array::Float64Array>();
907
908                                        if let Some(valid_float_array) = float_array {
909                                            param_values.push(valid_float_array.value(row).into());
910                                        }
911                                    }
912                                    DataType::Utf8 => {
913                                        let string_array =
914                                            col.as_any().downcast_ref::<array::StringArray>();
915
916                                        if let Some(valid_string_array) = string_array {
917                                            param_values.push(valid_string_array.value(row).into());
918                                        }
919                                    }
920                                    DataType::Null => {
921                                        param_values.push(Keyword::Null.into());
922                                    }
923                                    DataType::Boolean => {
924                                        let bool_array =
925                                            col.as_any().downcast_ref::<array::BooleanArray>();
926
927                                        if let Some(valid_bool_array) = bool_array {
928                                            param_values.push(valid_bool_array.value(row).into());
929                                        }
930                                    }
931                                    DataType::Binary => {
932                                        let binary_array =
933                                            col.as_any().downcast_ref::<array::BinaryArray>();
934
935                                        if let Some(valid_binary_array) = binary_array {
936                                            param_values.push(valid_binary_array.value(row).into());
937                                        }
938                                    }
939                                    DataType::FixedSizeBinary(_) => {
940                                        let binary_array = col
941                                            .as_any()
942                                            .downcast_ref::<array::FixedSizeBinaryArray>();
943
944                                        if let Some(valid_binary_array) = binary_array {
945                                            param_values.push(valid_binary_array.value(row).into());
946                                        }
947                                    }
948                                    DataType::LargeBinary => {
949                                        let binary_array =
950                                            col.as_any().downcast_ref::<array::LargeBinaryArray>();
951
952                                        if let Some(valid_binary_array) = binary_array {
953                                            param_values.push(valid_binary_array.value(row).into());
954                                        }
955                                    }
956                                    DataType::LargeUtf8 => {
957                                        let string_array =
958                                            col.as_any().downcast_ref::<array::LargeStringArray>();
959
960                                        if let Some(valid_string_array) = string_array {
961                                            param_values.push(valid_string_array.value(row).into());
962                                        }
963                                    }
964                                    DataType::Utf8View => {
965                                        let view_array =
966                                            col.as_any().downcast_ref::<array::StringViewArray>();
967
968                                        if let Some(valid_view_array) = view_array {
969                                            param_values.push(valid_view_array.value(row).into());
970                                        }
971                                    }
972                                    DataType::BinaryView => {
973                                        let view_array =
974                                            col.as_any().downcast_ref::<array::BinaryViewArray>();
975
976                                        if let Some(valid_view_array) = view_array {
977                                            param_values.push(valid_view_array.value(row).into());
978                                        }
979                                    }
980                                    DataType::Float16
981                                    | DataType::Timestamp(_, _)
982                                    | DataType::Date32
983                                    | DataType::Date64
984                                    | DataType::Time32(_)
985                                    | DataType::Time64(_)
986                                    | DataType::Duration(_)
987                                    | DataType::Interval(_)
988                                    | DataType::List(_)
989                                    | DataType::ListView(_)
990                                    | DataType::FixedSizeList(_, _)
991                                    | DataType::LargeList(_)
992                                    | DataType::LargeListView(_)
993                                    | DataType::Struct(_)
994                                    | DataType::Union(_, _)
995                                    | DataType::Dictionary(_, _)
996                                    | DataType::Map(_, _)
997                                    | DataType::RunEndEncoded(_, _)
998                                    | DataType::Decimal32(_, _)
999                                    | DataType::Decimal64(_, _)
1000                                    | DataType::Decimal128(_, _)
1001                                    | DataType::Decimal256(_, _) => {
1002                                        unimplemented!(
1003                                            "Data type mapping not implemented for Struct of {}",
1004                                            col.data_type()
1005                                        )
1006                                    }
1007                                }
1008                            }
1009
1010                            let mut params_vec = Vec::new();
1011                            for param_value in &param_values {
1012                                let mut params_str = String::new();
1013                                query_builder.prepare_simple_expr(param_value, &mut params_str);
1014                                params_vec.push(params_str);
1015                            }
1016
1017                            let params = params_vec.join(", ");
1018                            row_values.push(Expr::cust(format!("ROW({params})")));
1019                        }
1020                    }
1021                    unimplemented_type => {
1022                        return Result::Err(Error::UnimplementedDataTypeInInsertStatement {
1023                            data_type: unimplemented_type.clone(),
1024                        })
1025                    }
1026                }
1027            }
1028            match insert_stmt.values(row_values) {
1029                Ok(_) => (),
1030                Err(e) => {
1031                    return Result::Err(Error::FailedToCreateInsertStatement {
1032                        source: Box::new(e),
1033                    })
1034                }
1035            }
1036        }
1037        Ok(())
1038    }
1039
1040    ///
1041    /// # Errors
1042    ///
1043    /// Returns an error if any `RecordBatch` fails to convert into a valid postgres insert statement.
1044    pub fn build_postgres(self, on_conflict: Option<OnConflict>) -> Result<String> {
1045        self.build(PostgresQueryBuilder, on_conflict)
1046    }
1047
1048    ///
1049    /// # Errors
1050    ///
1051    /// Returns an error if any `RecordBatch` fails to convert into a valid sqlite insert statement.
1052    pub fn build_sqlite(self, on_conflict: Option<OnConflict>) -> Result<String> {
1053        self.build(SqliteQueryBuilder, on_conflict)
1054    }
1055
1056    ///
1057    /// # Errors
1058    ///
1059    /// Returns an error if any `RecordBatch` fails to convert into a valid `MySQL` insert statement.
1060    pub fn build_mysql(self, on_conflict: Option<OnConflict>) -> Result<String> {
1061        self.build(MysqlQueryBuilder, on_conflict)
1062    }
1063
1064    /// # Errors
1065    ///
1066    /// Returns an error if any `RecordBatch` fails to convert into a valid insert statement. Upon
1067    /// error, no further `RecordBatch` is processed.
1068    pub fn build<T: GenericBuilder + 'static>(
1069        &self,
1070        query_builder: T,
1071        on_conflict: Option<OnConflict>,
1072    ) -> Result<String> {
1073        let columns: Vec<Alias> = (self.record_batches[0])
1074            .schema()
1075            .fields()
1076            .iter()
1077            .map(|field| Alias::new(field.name()))
1078            .collect();
1079
1080        let mut insert_stmt = Query::insert()
1081            .into_table(table_reference_to_sea_table_ref(&self.table))
1082            .columns(columns)
1083            .to_owned();
1084
1085        for record_batch in &self.record_batches {
1086            self.construct_insert_stmt(&mut insert_stmt, record_batch, &query_builder)?;
1087        }
1088        if let Some(on_conflict) = on_conflict {
1089            insert_stmt.on_conflict(on_conflict);
1090        }
1091        Ok(insert_stmt.to_string(query_builder))
1092    }
1093}
1094
1095fn table_reference_to_sea_table_ref(table: &TableReference) -> TableRef {
1096    match table {
1097        TableReference::Bare { table } => {
1098            TableRef::Table(SeaRc::new(Alias::new(table.to_string())))
1099        }
1100        TableReference::Partial { schema, table } => TableRef::SchemaTable(
1101            SeaRc::new(Alias::new(schema.to_string())),
1102            SeaRc::new(Alias::new(table.to_string())),
1103        ),
1104        TableReference::Full {
1105            catalog,
1106            schema,
1107            table,
1108        } => TableRef::DatabaseSchemaTable(
1109            SeaRc::new(Alias::new(catalog.to_string())),
1110            SeaRc::new(Alias::new(schema.to_string())),
1111            SeaRc::new(Alias::new(table.to_string())),
1112        ),
1113    }
1114}
1115
1116pub struct IndexBuilder {
1117    table_name: String,
1118    columns: Vec<String>,
1119    unique: bool,
1120}
1121
1122impl IndexBuilder {
1123    #[must_use]
1124    pub fn new(table_name: &str, columns: Vec<&str>) -> Self {
1125        Self {
1126            table_name: table_name.to_string(),
1127            columns: columns.into_iter().map(ToString::to_string).collect(),
1128            unique: false,
1129        }
1130    }
1131
1132    #[must_use]
1133    pub fn unique(mut self) -> Self {
1134        self.unique = true;
1135        self
1136    }
1137
1138    #[must_use]
1139    pub fn index_name(&self) -> String {
1140        format!("i_{}_{}", self.table_name, self.columns.join("_"))
1141    }
1142
1143    #[must_use]
1144    pub fn build_postgres(self) -> String {
1145        self.build(PostgresQueryBuilder)
1146    }
1147
1148    #[must_use]
1149    pub fn build_sqlite(self) -> String {
1150        self.build(SqliteQueryBuilder)
1151    }
1152
1153    #[must_use]
1154    pub fn build_mysql(self) -> String {
1155        self.build(MysqlQueryBuilder)
1156    }
1157
1158    #[must_use]
1159    pub fn build<T: GenericBuilder>(self, query_builder: T) -> String {
1160        let mut index = Index::create();
1161        index.table(Alias::new(&self.table_name));
1162        index.name(self.index_name());
1163        if self.unique {
1164            index.unique();
1165        }
1166        for column in self.columns {
1167            index.col(Alias::new(column).into_iden().into_index_column());
1168        }
1169        index.if_not_exists();
1170        index.to_string(query_builder)
1171    }
1172}
1173
1174fn insert_timestamp_into_row_values(
1175    timestamp: Result<OffsetDateTime, time::error::ComponentRange>,
1176    row_values: &mut Vec<SimpleExpr>,
1177) -> Result<()> {
1178    match timestamp {
1179        Ok(offset_time) => {
1180            row_values.push(PrimitiveDateTime::new(offset_time.date(), offset_time.time()).into());
1181            Ok(())
1182        }
1183        Err(e) => Err(Error::FailedToCreateInsertStatement {
1184            source: Box::new(e),
1185        }),
1186    }
1187}
1188
1189#[allow(clippy::needless_pass_by_value)]
1190fn insert_list_into_row_values(
1191    list_array: Arc<dyn Array>,
1192    list_type: &Arc<Field>,
1193    row_values: &mut Vec<SimpleExpr>,
1194) {
1195    match list_type.data_type() {
1196        DataType::Int8 => push_list_values!(
1197            list_type.data_type(),
1198            list_array,
1199            row_values,
1200            array::Int8Array,
1201            i8,
1202            "int2[]"
1203        ),
1204        DataType::Int16 => push_list_values!(
1205            list_type.data_type(),
1206            list_array,
1207            row_values,
1208            array::Int16Array,
1209            i16,
1210            "int2[]"
1211        ),
1212        DataType::Int32 => push_list_values!(
1213            list_type.data_type(),
1214            list_array,
1215            row_values,
1216            array::Int32Array,
1217            i32,
1218            "int4[]"
1219        ),
1220        DataType::Int64 => push_list_values!(
1221            list_type.data_type(),
1222            list_array,
1223            row_values,
1224            array::Int64Array,
1225            i64,
1226            "int8[]"
1227        ),
1228        DataType::Float32 => push_list_values!(
1229            list_type.data_type(),
1230            list_array,
1231            row_values,
1232            array::Float32Array,
1233            f32,
1234            "float4[]"
1235        ),
1236        DataType::Float64 => push_list_values!(
1237            list_type.data_type(),
1238            list_array,
1239            row_values,
1240            array::Float64Array,
1241            f64,
1242            "float8[]"
1243        ),
1244        DataType::Utf8 => {
1245            let mut list_values: Vec<String> = vec![];
1246            for i in 0..list_array.len() {
1247                let int_array = list_array.as_any().downcast_ref::<array::StringArray>();
1248                if let Some(valid_int_array) = int_array {
1249                    list_values.push(valid_int_array.value(i).to_string());
1250                }
1251            }
1252            let expr: SimpleExpr = list_values.into();
1253            // We must cast here in case the array is empty which SeaQuery does not handle.
1254            row_values.push(expr.cast_as(Alias::new("text[]")));
1255        }
1256        DataType::LargeUtf8 => {
1257            let mut list_values: Vec<String> = vec![];
1258            for i in 0..list_array.len() {
1259                let int_array = list_array
1260                    .as_any()
1261                    .downcast_ref::<array::LargeStringArray>();
1262                if let Some(valid_int_array) = int_array {
1263                    list_values.push(valid_int_array.value(i).to_string());
1264                }
1265            }
1266            let expr: SimpleExpr = list_values.into();
1267            // We must cast here in case the array is empty which SeaQuery does not handle.
1268            row_values.push(expr.cast_as(Alias::new("text[]")));
1269        }
1270        DataType::Utf8View => {
1271            let mut list_values: Vec<String> = vec![];
1272            for i in 0..list_array.len() {
1273                let view_array = list_array.as_any().downcast_ref::<array::StringViewArray>();
1274                if let Some(valid_view_array) = view_array {
1275                    list_values.push(valid_view_array.value(i).to_string());
1276                }
1277            }
1278            let expr: SimpleExpr = list_values.into();
1279            row_values.push(expr.cast_as(Alias::new("text[]")));
1280        }
1281        DataType::Boolean => push_list_values!(
1282            list_type.data_type(),
1283            list_array,
1284            row_values,
1285            array::BooleanArray,
1286            bool,
1287            "boolean[]"
1288        ),
1289        DataType::Binary => {
1290            let mut list_values: Vec<Vec<u8>> = Vec::new();
1291            for i in 0..list_array.len() {
1292                let temp_array = list_array.as_any().downcast_ref::<array::BinaryArray>();
1293                if let Some(valid_array) = temp_array {
1294                    list_values.push(valid_array.value(i).to_vec());
1295                }
1296            }
1297            let expr: SimpleExpr = list_values.into();
1298            // We must cast here in case the array is empty which SeaQuery does not handle.
1299            row_values.push(expr.cast_as(Alias::new("bytea[]")));
1300        }
1301        _ => unimplemented!(
1302            "Data type mapping not implemented for {}",
1303            list_type.data_type()
1304        ),
1305    }
1306}
1307
1308#[allow(clippy::cast_sign_loss)]
1309pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType {
1310    match data_type {
1311        DataType::Int8 => ColumnType::TinyInteger,
1312        DataType::Int16 => ColumnType::SmallInteger,
1313        DataType::Int32 => ColumnType::Integer,
1314        DataType::Int64 | DataType::Duration(_) => ColumnType::BigInteger,
1315        DataType::UInt8 => ColumnType::TinyUnsigned,
1316        DataType::UInt16 => ColumnType::SmallUnsigned,
1317        DataType::UInt32 => ColumnType::Unsigned,
1318        DataType::UInt64 => ColumnType::BigUnsigned,
1319        DataType::Float32 => ColumnType::Float,
1320        DataType::Float64 => ColumnType::Double,
1321        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => ColumnType::Text,
1322        DataType::Boolean => ColumnType::Boolean,
1323        #[allow(clippy::cast_sign_loss)] // This is safe because scale will never be negative
1324        DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => {
1325            ColumnType::Decimal(Some((u32::from(*p), *s as u32)))
1326        }
1327        DataType::Timestamp(_unit, time_zone) => {
1328            if time_zone.is_some() {
1329                return ColumnType::TimestampWithTimeZone;
1330            }
1331            ColumnType::Timestamp
1332        }
1333        DataType::Date32 | DataType::Date64 => ColumnType::Date,
1334        DataType::Time64(_unit) | DataType::Time32(_unit) => ColumnType::Time,
1335        DataType::List(list_type)
1336        | DataType::LargeList(list_type)
1337        | DataType::FixedSizeList(list_type, _) => {
1338            ColumnType::Array(map_data_type_to_column_type(list_type.data_type()).into())
1339        }
1340        // Originally mapped to VarBinary type, corresponding to MySQL's varbinary, which has a maximum length of 65535.
1341        // This caused the error: "Row size too large. The maximum row size for the used table type, not counting BLOBs, is 65535.
1342        // This includes storage overhead, check the manual. You have to change some columns to TEXT or BLOBs."
1343        // Changing to Blob fixes this issue. This change does not affect Postgres, and for Sqlite, the mapping type changes from varbinary_blob to blob.
1344        DataType::Binary | DataType::LargeBinary => ColumnType::Blob,
1345        DataType::FixedSizeBinary(num_bytes) => ColumnType::Binary(num_bytes.to_owned() as u32),
1346        DataType::Interval(_) => ColumnType::Interval(None, None),
1347        // Add more mappings here as needed
1348        _ => unimplemented!("Data type mapping not implemented for {:?}", data_type),
1349    }
1350}
1351
/// Serializes every element of a list's child array into a JSON array string.
///
/// * `$data_type` - the Arrow element type; used only in error messages.
/// * `$list_array` - the array holding the list's elements.
/// * `$array_type` - the concrete Arrow array type to downcast `$list_array` to.
/// * `$vec_type` - the Rust type each element is converted into before serialization.
///
/// Null elements are emitted as JSON `null`. Reading `value(i)` on a null slot would
/// otherwise yield a placeholder such as `0` or `""`.
///
/// The expansion uses `?`, so it must be invoked inside a function returning
/// `Result<_, Error>`. It fails with `Error::FailedToCreateInsertStatement` when the
/// downcast or the JSON serialization fails.
macro_rules! serialize_list_values {
    ($data_type:expr, $list_array:expr, $array_type:ty, $vec_type:ty) => {{
        let values = $list_array
            .as_any()
            .downcast_ref::<$array_type>()
            .ok_or_else(|| Error::FailedToCreateInsertStatement {
                source: format!(
                    "Failed to downcast list values of type {} to {}",
                    $data_type,
                    stringify!($array_type)
                )
                .into(),
            })?;

        let list_values: Vec<Option<$vec_type>> = (0..values.len())
            .map(|i| values.is_valid(i).then(|| <$vec_type>::from(values.value(i))))
            .collect();

        serde_json::to_string(&list_values).map_err(|e| Error::FailedToCreateInsertStatement {
            source: Box::new(e),
        })?
    }};
}
1366
1367fn insert_list_into_row_values_json(
1368    list_array: Arc<dyn Array>,
1369    list_type: &Arc<Field>,
1370    row_values: &mut Vec<SimpleExpr>,
1371) -> Result<()> {
1372    let data_type = list_type.data_type();
1373
1374    let json_string: String = match data_type {
1375        DataType::Int8 => serialize_list_values!(data_type, list_array, Int8Array, i8),
1376        DataType::Int16 => serialize_list_values!(data_type, list_array, Int16Array, i16),
1377        DataType::Int32 => serialize_list_values!(data_type, list_array, Int32Array, i32),
1378        DataType::Int64 => serialize_list_values!(data_type, list_array, Int64Array, i64),
1379        DataType::UInt8 => serialize_list_values!(data_type, list_array, UInt8Array, u8),
1380        DataType::UInt16 => serialize_list_values!(data_type, list_array, UInt16Array, u16),
1381        DataType::UInt32 => serialize_list_values!(data_type, list_array, UInt32Array, u32),
1382        DataType::UInt64 => serialize_list_values!(data_type, list_array, UInt64Array, u64),
1383        DataType::Float32 => serialize_list_values!(data_type, list_array, Float32Array, f32),
1384        DataType::Float64 => serialize_list_values!(data_type, list_array, Float64Array, f64),
1385        DataType::Utf8 => serialize_list_values!(data_type, list_array, StringArray, String),
1386        DataType::LargeUtf8 => {
1387            serialize_list_values!(data_type, list_array, LargeStringArray, String)
1388        }
1389        DataType::Utf8View => {
1390            serialize_list_values!(data_type, list_array, StringViewArray, String)
1391        }
1392        DataType::Boolean => serialize_list_values!(data_type, list_array, BooleanArray, bool),
1393        _ => unimplemented!(
1394            "List to json conversion is not implemented for {}",
1395            list_type.data_type()
1396        ),
1397    };
1398
1399    let expr: SimpleExpr = Expr::value(json_string);
1400    row_values.push(expr);
1401
1402    Ok(())
1403}
1404
1405fn insert_struct_into_row_values_json(
1406    fields: &Fields,
1407    array: &StructArray,
1408    row_index: usize,
1409    row_values: &mut Vec<SimpleExpr>,
1410) -> Result<()> {
1411    // The length of each column in a StructArray is the same as the length of the StructArray itself.
1412    // The presence of null values does not change the length of the columns (affects the validity bitmap only).
1413    // Similar to Recordbatch slice: https://github.com/apache/arrow-rs/blob/855666d9e9283c1ef11648762fe92c7c188b68f1/arrow-array/src/record_batch.rs#L375
1414    let single_row_columns: Vec<ArrayRef> = (0..array.num_columns())
1415        .map(|i| array.column(i).slice(row_index, 1))
1416        .collect();
1417
1418    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields.clone())), single_row_columns)
1419        .map_err(|e| Error::FailedToCreateInsertStatement {
1420            source: Box::new(e),
1421        })?;
1422
1423    let mut writer = datafusion::arrow::json::LineDelimitedWriter::new(Vec::new());
1424    writer
1425        .write(&batch)
1426        .map_err(|e| Error::FailedToCreateInsertStatement {
1427            source: Box::new(e),
1428        })?;
1429    writer
1430        .finish()
1431        .map_err(|e| Error::FailedToCreateInsertStatement {
1432            source: Box::new(e),
1433        })?;
1434    let json_bytes = writer.into_inner();
1435
1436    let json = String::from_utf8(json_bytes).map_err(|e| Error::FailedToCreateInsertStatement {
1437        source: Box::new(e),
1438    })?;
1439
1440    let expr: SimpleExpr = Expr::value(json);
1441    row_values.push(expr);
1442
1443    Ok(())
1444}
1445
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema};

    /// The `(id, name, age)` schema shared by several tests; only `age` is nullable.
    fn users_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ])
    }

    /// Two batches holding the same three rows (ids 1-3, names a-c, ages 10-30).
    /// The second batch labels its last column `blah` instead of `age`; the expected
    /// SQL in the insertion tests still uses `age` for every row.
    fn users_batches_with_renamed_column() -> Vec<RecordBatch> {
        let columns: Vec<ArrayRef> = vec![
            Arc::new(array::Int32Array::from(vec![1, 2, 3])),
            Arc::new(array::StringArray::from(vec!["a", "b", "c"])),
            Arc::new(array::Int32Array::from(vec![10, 20, 30])),
        ];
        let renamed_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("blah", DataType::Int32, true),
        ]);

        vec![users_schema(), renamed_schema]
            .into_iter()
            .map(|schema| {
                RecordBatch::try_new(Arc::new(schema), columns.clone())
                    .expect("Unable to build record batch")
            })
            .collect()
    }

    /// Builds a Postgres `INSERT` for `batches` into `table`.
    fn postgres_insert(table: &str, batches: Vec<RecordBatch>) -> String {
        InsertBuilder::new(&TableReference::from(table), batches)
            .build_postgres(None)
            .expect("Failed to build insert statement")
    }

    #[test]
    fn test_basic_table_creation() {
        let sql = CreateTableBuilder::new(Arc::new(users_schema()), "users").build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer )");
    }

    #[test]
    fn test_table_insertion() {
        let sql = postgres_insert("users", users_batches_with_renamed_column());

        assert_eq!(sql, "INSERT INTO \"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    #[test]
    fn test_table_insertion_with_schema() {
        let sql = postgres_insert("schema.users", users_batches_with_renamed_column());

        assert_eq!(sql, "INSERT INTO \"schema\".\"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    #[test]
    fn test_table_creation_with_primary_keys() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("id2", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]));

        let sql = CreateTableBuilder::new(schema, "users")
            .primary_keys(vec!["id", "id2"])
            .build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"id2\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer, PRIMARY KEY (\"id\", \"id2\") )");
    }

    #[test]
    fn test_temporary_table_creation() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let sql = CreateTableBuilder::new(schema, "users")
            .primary_keys(vec!["id"])
            .temporary(true)
            .build_sqlite();

        assert_eq!(sql, "CREATE TEMPORARY TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"name\" text NOT NULL, PRIMARY KEY (\"id\") )");
    }

    #[test]
    fn test_table_insertion_with_list() {
        let item_field = Field::new("item", DataType::Int32, true);
        let schema = Schema::new(vec![Field::new(
            "list",
            DataType::List(item_field.into()),
            true,
        )]);
        let rows = vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5), Some(6)]),
            Some(vec![Some(7), Some(8), Some(9)]),
        ];
        let list_array = array::ListArray::from_iter_primitive::<Int32Type, _, _>(rows);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_array)])
            .expect("Unable to build record batch");

        let sql = postgres_insert("arrays", vec![batch]);

        assert_eq!(
            sql,
            "INSERT INTO \"arrays\" (\"list\") VALUES (CAST(ARRAY [1,2,3] AS int4[])), (CAST(ARRAY [4,5,6] AS int4[])), (CAST(ARRAY [7,8,9] AS int4[]))"
        );
    }

    #[test]
    fn test_create_index() {
        let expected = r#"CREATE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#;

        assert_eq!(
            IndexBuilder::new("users", vec!["id", "name"]).build_postgres(),
            expected
        );
    }

    #[test]
    fn test_create_unique_index() {
        let expected =
            r#"CREATE UNIQUE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#;

        assert_eq!(
            IndexBuilder::new("users", vec!["id", "name"])
                .unique()
                .build_postgres(),
            expected
        );
    }
}