1use bigdecimal::BigDecimal;
2use chrono::{DateTime, Offset, TimeZone};
3use datafusion::arrow::{
4 array::{
5 array, timezone::Tz, Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
6 Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray,
7 StringViewArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
8 },
9 datatypes::{DataType, Field, Fields, IntervalUnit, Schema, SchemaRef, TimeUnit},
10 util::display::array_value_to_string,
11};
12use datafusion::sql::TableReference;
13use num_bigint::BigInt;
14use sea_query::{
15 Alias, ColumnDef, ColumnType, Expr, GenericBuilder, Index, InsertStatement, IntoIden,
16 IntoIndexColumn, Keyword, MysqlQueryBuilder, OnConflict, PostgresQueryBuilder, Query,
17 QueryBuilder, SeaRc, SimpleExpr, SqliteQueryBuilder, Table, TableRef,
18};
19use snafu::Snafu;
20use std::{str::FromStr, sync::Arc};
21use time::{OffsetDateTime, PrimitiveDateTime};
22
/// Errors produced while generating SQL statements from Arrow record batches.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A value could not be converted into a SQL expression, or the insert
    /// statement itself could not be assembled.
    #[snafu(display("Failed to build insert statement: {source}"))]
    FailedToCreateInsertStatement {
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// An Arrow data type with no SQL mapping was encountered.
    #[snafu(display("Unimplemented data type in insert statement: {data_type:?}"))]
    UnimplementedDataTypeInInsertStatement { data_type: DataType },
}
33
/// Convenience alias that defaults the error type to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
35
/// Builder for `CREATE TABLE` statements derived from an Arrow schema.
pub struct CreateTableBuilder {
    // Arrow schema the table columns are derived from.
    schema: SchemaRef,
    // Unqualified name of the table to create.
    table_name: String,
    // Column names forming the table's primary key (may be empty).
    primary_keys: Vec<String>,
    // When true, emits a TEMPORARY table.
    temporary: bool,
}
42
43impl CreateTableBuilder {
44 #[must_use]
45 pub fn new(schema: SchemaRef, table_name: &str) -> Self {
46 Self {
47 schema,
48 table_name: table_name.to_string(),
49 primary_keys: Vec::new(),
50 temporary: false,
51 }
52 }
53
54 #[must_use]
55 pub fn primary_keys<T>(mut self, keys: Vec<T>) -> Self
56 where
57 T: Into<String>,
58 {
59 self.primary_keys = keys.into_iter().map(Into::into).collect();
60 self
61 }
62
63 #[must_use]
64 pub fn temporary(mut self, temporary: bool) -> Self {
66 self.temporary = temporary;
67 self
68 }
69
70 #[must_use]
71 #[cfg(feature = "postgres")]
72 pub fn build_postgres(self) -> Vec<String> {
73 use crate::sql::arrow_sql_gen::postgres::{
74 builder::TypeBuilder, get_postgres_composite_type_name,
75 map_data_type_to_column_type_postgres,
76 };
77 let schema = Arc::clone(&self.schema);
78 let table_name = self.table_name.clone();
79 let main_table_creation =
80 self.build(PostgresQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
81 map_data_type_to_column_type_postgres(f.data_type(), &table_name, f.name())
82 });
83
84 let mut creation_stmts = Vec::new();
87 for field in schema.fields() {
88 let DataType::Struct(struct_inner_fields) = field.data_type() else {
89 continue;
90 };
91 let type_builder = TypeBuilder::new(
92 get_postgres_composite_type_name(&table_name, field.name()),
93 struct_inner_fields,
94 );
95 creation_stmts.push(type_builder.build());
96 }
97
98 creation_stmts.push(main_table_creation);
99 creation_stmts
100 }
101
102 #[must_use]
103 pub fn build_sqlite(self) -> String {
104 self.build(SqliteQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
105 if f.data_type().is_nested() {
108 return ColumnType::JsonBinary;
109 }
110
111 map_data_type_to_column_type(f.data_type())
112 })
113 }
114
115 #[must_use]
116 pub fn build_mysql(self) -> String {
117 self.build(MysqlQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
118 if f.data_type().is_nested() {
121 return ColumnType::JsonBinary;
122 }
123 map_data_type_to_column_type(f.data_type())
124 })
125 }
126
127 #[must_use]
128 fn build<T: GenericBuilder>(
129 self,
130 query_builder: T,
131 map_data_type_to_column_type_fn: &dyn Fn(&Arc<Field>) -> ColumnType,
132 ) -> String {
133 let mut create_stmt = Table::create();
134 create_stmt
135 .table(Alias::new(self.table_name.clone()))
136 .if_not_exists();
137
138 for field in self.schema.fields() {
139 let column_type = map_data_type_to_column_type_fn(field);
140 let mut column_def = ColumnDef::new_with_type(Alias::new(field.name()), column_type);
141 if !field.is_nullable() {
142 column_def.not_null();
143 }
144
145 create_stmt.col(&mut column_def);
146 }
147
148 if !self.primary_keys.is_empty() {
149 let mut index = Index::create();
150 index.primary();
151 for key in self.primary_keys {
152 index.col(Alias::new(key).into_iden().into_index_column());
153 }
154 create_stmt.primary_key(&mut index);
155 }
156
157 if self.temporary {
158 create_stmt.temporary();
159 }
160
161 create_stmt.to_string(query_builder)
162 }
163}
164
/// Downcasts `$column` to `array::$array_type` and pushes the value at `$row`
/// onto `$row_values`; a null slot pushes SQL NULL instead.
///
/// NOTE(review): the null branch uses `continue`, so this macro must be
/// invoked inside a loop body — confirm at every call site.
macro_rules! push_value {
    ($row_values:expr, $column:expr, $row:expr, $array_type:ident) => {{
        let array = $column.as_any().downcast_ref::<array::$array_type>();
        if let Some(valid_array) = array {
            if valid_array.is_null($row) {
                $row_values.push(Keyword::Null.into());
                continue;
            }
            $row_values.push(valid_array.value($row).into());
        }
    }};
}
177
/// Collects every element of `$list_array` into a `Vec<$vec_type>`, converts
/// the vector into a single SQL expression, and pushes it onto `$row_values`
/// cast to the SQL array type named by `$sql_type` (e.g. `"int4[]"`).
///
/// NOTE(review): `$data_type` is accepted but unused in the expansion;
/// elements for which the downcast fails are silently skipped.
macro_rules! push_list_values {
    ($data_type:expr, $list_array:expr, $row_values:expr, $array_type:ty, $vec_type:ty, $sql_type:expr) => {{
        let mut list_values: Vec<$vec_type> = Vec::new();
        for i in 0..$list_array.len() {
            let temp_array = $list_array.as_any().downcast_ref::<$array_type>();
            if let Some(valid_array) = temp_array {
                list_values.push(valid_array.value(i));
            }
        }
        let expr: SimpleExpr = list_values.into();
        $row_values.push(expr.cast_as(Alias::new($sql_type)));
    }};
}
192
/// Builder for `INSERT` statements generated from Arrow record batches.
pub struct InsertBuilder {
    // Target table (optionally schema-/catalog-qualified).
    table: TableReference,
    // Batches whose rows become the VALUES of the insert.
    record_batches: Vec<RecordBatch>,
}
197
198pub fn use_json_insert_for_type<T: QueryBuilder + 'static>(
199 #[allow(unused_variables)] data_type: &DataType,
200 #[allow(unused_variables)] query_builder: &T,
201) -> bool {
202 #[cfg(feature = "sqlite")]
203 {
204 use std::any::Any;
205 let any_builder = query_builder as &dyn Any;
206 if any_builder.is::<SqliteQueryBuilder>() {
207 return data_type.is_nested();
208 }
209 }
210 #[cfg(feature = "mysql")]
211 {
212 use std::any::Any;
213 let any_builder = query_builder as &dyn Any;
214 if any_builder.is::<MysqlQueryBuilder>() {
215 return data_type.is_nested();
216 }
217 }
218 false
219}
220
221impl InsertBuilder {
222 #[must_use]
223 pub fn new(table: &TableReference, record_batches: Vec<RecordBatch>) -> Self {
224 Self {
225 table: table.clone(),
226 record_batches,
227 }
228 }
229
230 #[allow(clippy::too_many_lines)]
236 pub fn construct_insert_stmt<T: QueryBuilder + 'static>(
237 &self,
238 insert_stmt: &mut InsertStatement,
239 record_batch: &RecordBatch,
240 query_builder: &T,
241 ) -> Result<()> {
242 for row in 0..record_batch.num_rows() {
243 let mut row_values: Vec<SimpleExpr> = vec![];
244 for col in 0..record_batch.num_columns() {
245 let column = record_batch.column(col);
246 let column_data_type = column.data_type();
247
248 match column_data_type {
249 DataType::Int8 => push_value!(row_values, column, row, Int8Array),
250 DataType::Int16 => push_value!(row_values, column, row, Int16Array),
251 DataType::Int32 => push_value!(row_values, column, row, Int32Array),
252 DataType::Int64 => push_value!(row_values, column, row, Int64Array),
253 DataType::UInt8 => push_value!(row_values, column, row, UInt8Array),
254 DataType::UInt16 => push_value!(row_values, column, row, UInt16Array),
255 DataType::UInt32 => push_value!(row_values, column, row, UInt32Array),
256 DataType::UInt64 => push_value!(row_values, column, row, UInt64Array),
257 DataType::Float32 => push_value!(row_values, column, row, Float32Array),
258 DataType::Float64 => push_value!(row_values, column, row, Float64Array),
259 DataType::Utf8 => push_value!(row_values, column, row, StringArray),
260 DataType::LargeUtf8 => push_value!(row_values, column, row, LargeStringArray),
261 DataType::Utf8View => push_value!(row_values, column, row, StringViewArray),
262 DataType::Boolean => push_value!(row_values, column, row, BooleanArray),
263 DataType::Decimal128(_, scale) => {
264 let array = column.as_any().downcast_ref::<array::Decimal128Array>();
265 if let Some(valid_array) = array {
266 if valid_array.is_null(row) {
267 row_values.push(Keyword::Null.into());
268 continue;
269 }
270 row_values.push(
271 BigDecimal::new(valid_array.value(row).into(), i64::from(*scale))
272 .into(),
273 );
274 }
275 }
276 DataType::Decimal256(_, scale) => {
277 let array = column.as_any().downcast_ref::<array::Decimal256Array>();
278 if let Some(valid_array) = array {
279 if valid_array.is_null(row) {
280 row_values.push(Keyword::Null.into());
281 continue;
282 }
283
284 let bigint =
285 BigInt::from_signed_bytes_le(&valid_array.value(row).to_le_bytes());
286
287 row_values.push(BigDecimal::new(bigint, i64::from(*scale)).into());
288 }
289 }
290 DataType::Date32 => {
291 let array = column.as_any().downcast_ref::<array::Date32Array>();
292 if let Some(valid_array) = array {
293 if valid_array.is_null(row) {
294 row_values.push(Keyword::Null.into());
295 continue;
296 }
297 row_values.push(
298 match OffsetDateTime::from_unix_timestamp(
299 i64::from(valid_array.value(row)) * 86_400,
300 ) {
301 Ok(offset_time) => offset_time.date().into(),
302 Err(e) => {
303 return Result::Err(Error::FailedToCreateInsertStatement {
304 source: Box::new(e),
305 })
306 }
307 },
308 );
309 }
310 }
311 DataType::Date64 => {
312 let array = column.as_any().downcast_ref::<array::Date64Array>();
313 if let Some(valid_array) = array {
314 if valid_array.is_null(row) {
315 row_values.push(Keyword::Null.into());
316 continue;
317 }
318 row_values.push(
319 match OffsetDateTime::from_unix_timestamp(
320 valid_array.value(row) / 1000,
321 ) {
322 Ok(offset_time) => offset_time.date().into(),
323 Err(e) => {
324 return Result::Err(Error::FailedToCreateInsertStatement {
325 source: Box::new(e),
326 })
327 }
328 },
329 );
330 }
331 }
332 DataType::Duration(time_unit) => match time_unit {
333 TimeUnit::Second => {
334 push_value!(row_values, column, row, DurationSecondArray);
335 }
336 TimeUnit::Microsecond => {
337 push_value!(row_values, column, row, DurationMicrosecondArray);
338 }
339 TimeUnit::Millisecond => {
340 push_value!(row_values, column, row, DurationMillisecondArray);
341 }
342 TimeUnit::Nanosecond => {
343 push_value!(row_values, column, row, DurationNanosecondArray);
344 }
345 },
346 DataType::Time32(time_unit) => match time_unit {
347 TimeUnit::Millisecond => {
348 let array = column
349 .as_any()
350 .downcast_ref::<array::Time32MillisecondArray>();
351 if let Some(valid_array) = array {
352 if valid_array.is_null(row) {
353 row_values.push(Keyword::Null.into());
354 continue;
355 }
356
357 let (h, m, s, micro) =
358 match OffsetDateTime::from_unix_timestamp_nanos(
359 i128::from(valid_array.value(row)) * 1_000_000,
360 ) {
361 Ok(timestamp) => timestamp.to_hms_micro(),
362 Err(e) => {
363 return Result::Err(
364 Error::FailedToCreateInsertStatement {
365 source: Box::new(e),
366 },
367 )
368 }
369 };
370
371 let time = match time::Time::from_hms_micro(h, m, s, micro) {
372 Ok(value) => value,
373 Err(e) => {
374 return Result::Err(Error::FailedToCreateInsertStatement {
375 source: Box::new(e),
376 })
377 }
378 };
379
380 row_values.push(time.into());
381 }
382 }
383 TimeUnit::Second => {
384 let array = column.as_any().downcast_ref::<array::Time32SecondArray>();
385 if let Some(valid_array) = array {
386 if valid_array.is_null(row) {
387 row_values.push(Keyword::Null.into());
388 continue;
389 }
390
391 let (h, m, s) = match OffsetDateTime::from_unix_timestamp(
392 i64::from(valid_array.value(row)),
393 ) {
394 Ok(timestamp) => timestamp.to_hms(),
395 Err(e) => {
396 return Result::Err(Error::FailedToCreateInsertStatement {
397 source: Box::new(e),
398 })
399 }
400 };
401
402 let time = match time::Time::from_hms(h, m, s) {
403 Ok(value) => value,
404 Err(e) => {
405 return Result::Err(Error::FailedToCreateInsertStatement {
406 source: Box::new(e),
407 })
408 }
409 };
410
411 row_values.push(time.into());
412 }
413 }
414 _ => unreachable!(),
415 },
416 DataType::Time64(time_unit) => match time_unit {
417 TimeUnit::Nanosecond => {
418 let array = column
419 .as_any()
420 .downcast_ref::<array::Time64NanosecondArray>();
421 if let Some(valid_array) = array {
422 if valid_array.is_null(row) {
423 row_values.push(Keyword::Null.into());
424 continue;
425 }
426 let (h, m, s, nano) =
427 match OffsetDateTime::from_unix_timestamp_nanos(i128::from(
428 valid_array.value(row),
429 )) {
430 Ok(timestamp) => timestamp.to_hms_nano(),
431 Err(e) => {
432 return Result::Err(
433 Error::FailedToCreateInsertStatement {
434 source: Box::new(e),
435 },
436 )
437 }
438 };
439
440 let time = match time::Time::from_hms_nano(h, m, s, nano) {
441 Ok(value) => value,
442 Err(e) => {
443 return Result::Err(Error::FailedToCreateInsertStatement {
444 source: Box::new(e),
445 })
446 }
447 };
448
449 row_values.push(time.into());
450 }
451 }
452 TimeUnit::Microsecond => {
453 let array = column
454 .as_any()
455 .downcast_ref::<array::Time64MicrosecondArray>();
456 if let Some(valid_array) = array {
457 if valid_array.is_null(row) {
458 row_values.push(Keyword::Null.into());
459 continue;
460 }
461
462 let (h, m, s, micro) =
463 match OffsetDateTime::from_unix_timestamp_nanos(
464 i128::from(valid_array.value(row)) * 1_000,
465 ) {
466 Ok(timestamp) => timestamp.to_hms_micro(),
467 Err(e) => {
468 return Result::Err(
469 Error::FailedToCreateInsertStatement {
470 source: Box::new(e),
471 },
472 )
473 }
474 };
475
476 let time = match time::Time::from_hms_micro(h, m, s, micro) {
477 Ok(value) => value,
478 Err(e) => {
479 return Result::Err(Error::FailedToCreateInsertStatement {
480 source: Box::new(e),
481 })
482 }
483 };
484
485 row_values.push(time.into());
486 }
487 }
488 _ => unreachable!(),
489 },
490 DataType::Timestamp(TimeUnit::Second, timezone) => {
491 let array = column
492 .as_any()
493 .downcast_ref::<array::TimestampSecondArray>();
494
495 if let Some(valid_array) = array {
496 if valid_array.is_null(row) {
497 row_values.push(Keyword::Null.into());
498 continue;
499 }
500 if let Some(timezone) = timezone {
501 let utc_time = DateTime::from_timestamp_nanos(
502 valid_array.value(row) * 1_000_000_000,
503 )
504 .to_utc();
505 let timezone = Tz::from_str(timezone).map_err(|_| {
506 Error::FailedToCreateInsertStatement {
507 source: "Unable to parse arrow timezone information".into(),
508 }
509 })?;
510 let offset = timezone
511 .offset_from_utc_datetime(&utc_time.naive_utc())
512 .fix();
513 let time_with_offset = utc_time.with_timezone(&offset);
514 row_values.push(time_with_offset.into());
515 } else {
516 insert_timestamp_into_row_values(
517 OffsetDateTime::from_unix_timestamp(valid_array.value(row)),
518 &mut row_values,
519 )?;
520 }
521 }
522 }
523 DataType::Timestamp(TimeUnit::Millisecond, timezone) => {
524 let array = column
525 .as_any()
526 .downcast_ref::<array::TimestampMillisecondArray>();
527
528 if let Some(valid_array) = array {
529 if valid_array.is_null(row) {
530 row_values.push(Keyword::Null.into());
531 continue;
532 }
533 if let Some(timezone) = timezone {
534 let utc_time = DateTime::from_timestamp_nanos(
535 valid_array.value(row) * 1_000_000,
536 )
537 .to_utc();
538 let timezone = Tz::from_str(timezone).map_err(|_| {
539 Error::FailedToCreateInsertStatement {
540 source: "Unable to parse arrow timezone information".into(),
541 }
542 })?;
543 let offset = timezone
544 .offset_from_utc_datetime(&utc_time.naive_utc())
545 .fix();
546 let time_with_offset = utc_time.with_timezone(&offset);
547 row_values.push(time_with_offset.into());
548 } else {
549 insert_timestamp_into_row_values(
550 OffsetDateTime::from_unix_timestamp_nanos(
551 i128::from(valid_array.value(row)) * 1_000_000,
552 ),
553 &mut row_values,
554 )?;
555 }
556 }
557 }
558 DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
559 let array = column
560 .as_any()
561 .downcast_ref::<array::TimestampMicrosecondArray>();
562
563 if let Some(valid_array) = array {
564 if valid_array.is_null(row) {
565 row_values.push(Keyword::Null.into());
566 continue;
567 }
568 if let Some(timezone) = timezone {
569 let utc_time =
570 DateTime::from_timestamp_nanos(valid_array.value(row) * 1_000)
571 .to_utc();
572 let timezone = Tz::from_str(timezone).map_err(|_| {
573 Error::FailedToCreateInsertStatement {
574 source: "Unable to parse arrow timezone information".into(),
575 }
576 })?;
577 let offset = timezone
578 .offset_from_utc_datetime(&utc_time.naive_utc())
579 .fix();
580 let time_with_offset = utc_time.with_timezone(&offset);
581 row_values.push(time_with_offset.into());
582 } else {
583 insert_timestamp_into_row_values(
584 OffsetDateTime::from_unix_timestamp_nanos(
585 i128::from(valid_array.value(row)) * 1_000,
586 ),
587 &mut row_values,
588 )?;
589 }
590 }
591 }
592 DataType::Timestamp(TimeUnit::Nanosecond, timezone) => {
593 let array = column
594 .as_any()
595 .downcast_ref::<array::TimestampNanosecondArray>();
596
597 if let Some(valid_array) = array {
598 if valid_array.is_null(row) {
599 row_values.push(Keyword::Null.into());
600 continue;
601 }
602 if let Some(timezone) = timezone {
603 let utc_time =
604 DateTime::from_timestamp_nanos(valid_array.value(row)).to_utc();
605 let timezone = Tz::from_str(timezone).map_err(|_| {
606 Error::FailedToCreateInsertStatement {
607 source: "Unable to parse arrow timezone information".into(),
608 }
609 })?;
610 let offset = timezone
611 .offset_from_utc_datetime(&utc_time.naive_utc())
612 .fix();
613 let time_with_offset = utc_time.with_timezone(&offset);
614 row_values.push(time_with_offset.into());
615 } else {
616 insert_timestamp_into_row_values(
617 OffsetDateTime::from_unix_timestamp_nanos(i128::from(
618 valid_array.value(row),
619 )),
620 &mut row_values,
621 )?;
622 }
623 }
624 }
625 DataType::List(list_type) => {
626 let array = column.as_any().downcast_ref::<array::ListArray>();
627 if let Some(valid_array) = array {
628 if valid_array.is_null(row) {
629 row_values.push(Keyword::Null.into());
630 continue;
631 }
632 let list_array = valid_array.value(row);
633
634 if use_json_insert_for_type(column_data_type, query_builder) {
635 insert_list_into_row_values_json(
636 list_array,
637 list_type,
638 &mut row_values,
639 )?;
640 } else {
641 insert_list_into_row_values(list_array, list_type, &mut row_values);
642 }
643 }
644 }
645 DataType::LargeList(list_type) => {
646 let array = column.as_any().downcast_ref::<array::LargeListArray>();
647 if let Some(valid_array) = array {
648 if valid_array.is_null(row) {
649 row_values.push(Keyword::Null.into());
650 continue;
651 }
652 let list_array = valid_array.value(row);
653
654 if use_json_insert_for_type(column_data_type, query_builder) {
655 insert_list_into_row_values_json(
656 list_array,
657 list_type,
658 &mut row_values,
659 )?;
660 } else {
661 insert_list_into_row_values(list_array, list_type, &mut row_values);
662 }
663 }
664 }
665 DataType::FixedSizeList(list_type, _) => {
666 let array = column.as_any().downcast_ref::<array::FixedSizeListArray>();
667 if let Some(valid_array) = array {
668 if valid_array.is_null(row) {
669 row_values.push(Keyword::Null.into());
670 continue;
671 }
672 let list_array = valid_array.value(row);
673
674 if use_json_insert_for_type(column_data_type, query_builder) {
675 insert_list_into_row_values_json(
676 list_array,
677 list_type,
678 &mut row_values,
679 )?;
680 } else {
681 insert_list_into_row_values(list_array, list_type, &mut row_values);
682 }
683 }
684 }
685 DataType::Binary => {
686 let array = column.as_any().downcast_ref::<array::BinaryArray>();
687
688 if let Some(valid_array) = array {
689 if valid_array.is_null(row) {
690 row_values.push(Keyword::Null.into());
691 continue;
692 }
693
694 row_values.push(valid_array.value(row).into());
695 }
696 }
697 DataType::LargeBinary => {
698 let array = column.as_any().downcast_ref::<array::LargeBinaryArray>();
699
700 if let Some(valid_array) = array {
701 if valid_array.is_null(row) {
702 row_values.push(Keyword::Null.into());
703 continue;
704 }
705
706 row_values.push(valid_array.value(row).into());
707 }
708 }
709 DataType::FixedSizeBinary(_) => {
710 let array = column
711 .as_any()
712 .downcast_ref::<array::FixedSizeBinaryArray>();
713
714 if let Some(valid_array) = array {
715 if valid_array.is_null(row) {
716 row_values.push(Keyword::Null.into());
717 continue;
718 }
719
720 row_values.push(valid_array.value(row).into());
721 }
722 }
723 DataType::BinaryView => {
724 let array = column.as_any().downcast_ref::<array::BinaryViewArray>();
725
726 if let Some(valid_array) = array {
727 if valid_array.is_null(row) {
728 row_values.push(Keyword::Null.into());
729 continue;
730 }
731
732 row_values.push(valid_array.value(row).into());
733 }
734 }
735 DataType::Interval(interval_unit) => match interval_unit {
736 IntervalUnit::DayTime => {
737 let array = column
738 .as_any()
739 .downcast_ref::<array::IntervalDayTimeArray>();
740
741 if let Some(valid_array) = array {
742 if valid_array.is_null(row) {
743 row_values.push(Keyword::Null.into());
744 continue;
745 }
746
747 let interval_str =
748 if let Ok(str) = array_value_to_string(valid_array, row) {
749 str
750 } else {
751 let days = valid_array.value(row).days;
752 let milliseconds = valid_array.value(row).milliseconds;
753 format!("{days} days {milliseconds} milliseconds")
754 };
755
756 row_values.push(interval_str.into());
757 }
758 }
759 IntervalUnit::YearMonth => {
760 let array = column
761 .as_any()
762 .downcast_ref::<array::IntervalYearMonthArray>();
763
764 if let Some(valid_array) = array {
765 if valid_array.is_null(row) {
766 row_values.push(Keyword::Null.into());
767 continue;
768 }
769
770 let interval_str =
771 if let Ok(str) = array_value_to_string(valid_array, row) {
772 str
773 } else {
774 let months = valid_array.value(row);
775 format!("{months} months")
776 };
777
778 row_values.push(interval_str.into());
779 }
780 }
781 IntervalUnit::MonthDayNano => {
784 let array = column
785 .as_any()
786 .downcast_ref::<array::IntervalMonthDayNanoArray>();
787
788 if let Some(valid_array) = array {
789 if valid_array.is_null(row) {
790 row_values.push(Keyword::Null.into());
791 continue;
792 }
793
794 let interval_str =
795 if let Ok(str) = array_value_to_string(valid_array, row) {
796 str
797 } else {
798 let months = valid_array.value(row).months;
799 let days = valid_array.value(row).days;
800 let nanoseconds = valid_array.value(row).nanoseconds;
801 let micros = nanoseconds / 1_000;
802 format!("{months} months {days} days {micros} microseconds")
803 };
804
805 row_values.push(interval_str.into());
806 }
807 }
808 },
809 DataType::Struct(fields) => {
810 let array = column.as_any().downcast_ref::<array::StructArray>();
811
812 if let Some(valid_array) = array {
813 if valid_array.is_null(row) {
814 row_values.push(Keyword::Null.into());
815 continue;
816 }
817
818 if use_json_insert_for_type(column_data_type, query_builder) {
819 insert_struct_into_row_values_json(
820 fields,
821 valid_array,
822 row,
823 &mut row_values,
824 )?;
825 continue;
826 }
827
828 let mut param_values: Vec<SimpleExpr> = vec![];
829
830 for col in valid_array.columns() {
831 match col.data_type() {
832 DataType::Int8 => {
833 let int_array =
834 col.as_any().downcast_ref::<array::Int8Array>();
835
836 if let Some(valid_int_array) = int_array {
837 param_values.push(valid_int_array.value(row).into());
838 }
839 }
840 DataType::Int16 => {
841 let int_array =
842 col.as_any().downcast_ref::<array::Int16Array>();
843
844 if let Some(valid_int_array) = int_array {
845 param_values.push(valid_int_array.value(row).into());
846 }
847 }
848 DataType::Int32 => {
849 let int_array =
850 col.as_any().downcast_ref::<array::Int32Array>();
851
852 if let Some(valid_int_array) = int_array {
853 param_values.push(valid_int_array.value(row).into());
854 }
855 }
856 DataType::Int64 => {
857 let int_array =
858 col.as_any().downcast_ref::<array::Int64Array>();
859
860 if let Some(valid_int_array) = int_array {
861 param_values.push(valid_int_array.value(row).into());
862 }
863 }
864 DataType::UInt8 => {
865 let int_array =
866 col.as_any().downcast_ref::<array::UInt8Array>();
867
868 if let Some(valid_int_array) = int_array {
869 param_values.push(valid_int_array.value(row).into());
870 }
871 }
872 DataType::UInt16 => {
873 let int_array =
874 col.as_any().downcast_ref::<array::UInt16Array>();
875
876 if let Some(valid_int_array) = int_array {
877 param_values.push(valid_int_array.value(row).into());
878 }
879 }
880 DataType::UInt32 => {
881 let int_array =
882 col.as_any().downcast_ref::<array::UInt32Array>();
883
884 if let Some(valid_int_array) = int_array {
885 param_values.push(valid_int_array.value(row).into());
886 }
887 }
888 DataType::UInt64 => {
889 let int_array =
890 col.as_any().downcast_ref::<array::UInt64Array>();
891
892 if let Some(valid_int_array) = int_array {
893 param_values.push(valid_int_array.value(row).into());
894 }
895 }
896 DataType::Float32 => {
897 let float_array =
898 col.as_any().downcast_ref::<array::Float32Array>();
899
900 if let Some(valid_float_array) = float_array {
901 param_values.push(valid_float_array.value(row).into());
902 }
903 }
904 DataType::Float64 => {
905 let float_array =
906 col.as_any().downcast_ref::<array::Float64Array>();
907
908 if let Some(valid_float_array) = float_array {
909 param_values.push(valid_float_array.value(row).into());
910 }
911 }
912 DataType::Utf8 => {
913 let string_array =
914 col.as_any().downcast_ref::<array::StringArray>();
915
916 if let Some(valid_string_array) = string_array {
917 param_values.push(valid_string_array.value(row).into());
918 }
919 }
920 DataType::Null => {
921 param_values.push(Keyword::Null.into());
922 }
923 DataType::Boolean => {
924 let bool_array =
925 col.as_any().downcast_ref::<array::BooleanArray>();
926
927 if let Some(valid_bool_array) = bool_array {
928 param_values.push(valid_bool_array.value(row).into());
929 }
930 }
931 DataType::Binary => {
932 let binary_array =
933 col.as_any().downcast_ref::<array::BinaryArray>();
934
935 if let Some(valid_binary_array) = binary_array {
936 param_values.push(valid_binary_array.value(row).into());
937 }
938 }
939 DataType::FixedSizeBinary(_) => {
940 let binary_array = col
941 .as_any()
942 .downcast_ref::<array::FixedSizeBinaryArray>();
943
944 if let Some(valid_binary_array) = binary_array {
945 param_values.push(valid_binary_array.value(row).into());
946 }
947 }
948 DataType::LargeBinary => {
949 let binary_array =
950 col.as_any().downcast_ref::<array::LargeBinaryArray>();
951
952 if let Some(valid_binary_array) = binary_array {
953 param_values.push(valid_binary_array.value(row).into());
954 }
955 }
956 DataType::LargeUtf8 => {
957 let string_array =
958 col.as_any().downcast_ref::<array::LargeStringArray>();
959
960 if let Some(valid_string_array) = string_array {
961 param_values.push(valid_string_array.value(row).into());
962 }
963 }
964 DataType::Utf8View => {
965 let view_array =
966 col.as_any().downcast_ref::<array::StringViewArray>();
967
968 if let Some(valid_view_array) = view_array {
969 param_values.push(valid_view_array.value(row).into());
970 }
971 }
972 DataType::BinaryView => {
973 let view_array =
974 col.as_any().downcast_ref::<array::BinaryViewArray>();
975
976 if let Some(valid_view_array) = view_array {
977 param_values.push(valid_view_array.value(row).into());
978 }
979 }
980 DataType::Float16
981 | DataType::Timestamp(_, _)
982 | DataType::Date32
983 | DataType::Date64
984 | DataType::Time32(_)
985 | DataType::Time64(_)
986 | DataType::Duration(_)
987 | DataType::Interval(_)
988 | DataType::List(_)
989 | DataType::ListView(_)
990 | DataType::FixedSizeList(_, _)
991 | DataType::LargeList(_)
992 | DataType::LargeListView(_)
993 | DataType::Struct(_)
994 | DataType::Union(_, _)
995 | DataType::Dictionary(_, _)
996 | DataType::Map(_, _)
997 | DataType::RunEndEncoded(_, _)
998 | DataType::Decimal32(_, _)
999 | DataType::Decimal64(_, _)
1000 | DataType::Decimal128(_, _)
1001 | DataType::Decimal256(_, _) => {
1002 unimplemented!(
1003 "Data type mapping not implemented for Struct of {}",
1004 col.data_type()
1005 )
1006 }
1007 }
1008 }
1009
1010 let mut params_vec = Vec::new();
1011 for param_value in ¶m_values {
1012 let mut params_str = String::new();
1013 query_builder.prepare_simple_expr(param_value, &mut params_str);
1014 params_vec.push(params_str);
1015 }
1016
1017 let params = params_vec.join(", ");
1018 row_values.push(Expr::cust(format!("ROW({params})")));
1019 }
1020 }
1021 unimplemented_type => {
1022 return Result::Err(Error::UnimplementedDataTypeInInsertStatement {
1023 data_type: unimplemented_type.clone(),
1024 })
1025 }
1026 }
1027 }
1028 match insert_stmt.values(row_values) {
1029 Ok(_) => (),
1030 Err(e) => {
1031 return Result::Err(Error::FailedToCreateInsertStatement {
1032 source: Box::new(e),
1033 })
1034 }
1035 }
1036 }
1037 Ok(())
1038 }
1039
    /// Renders the INSERT statement using PostgreSQL syntax.
    ///
    /// # Errors
    ///
    /// Returns an error if any record-batch value cannot be converted into a
    /// SQL expression.
    pub fn build_postgres(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(PostgresQueryBuilder, on_conflict)
    }
1047
    /// Renders the INSERT statement using SQLite syntax.
    ///
    /// # Errors
    ///
    /// Returns an error if any record-batch value cannot be converted into a
    /// SQL expression.
    pub fn build_sqlite(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(SqliteQueryBuilder, on_conflict)
    }
1055
    /// Renders the INSERT statement using MySQL syntax.
    ///
    /// # Errors
    ///
    /// Returns an error if any record-batch value cannot be converted into a
    /// SQL expression.
    pub fn build_mysql(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(MysqlQueryBuilder, on_conflict)
    }
1063
1064 pub fn build<T: GenericBuilder + 'static>(
1069 &self,
1070 query_builder: T,
1071 on_conflict: Option<OnConflict>,
1072 ) -> Result<String> {
1073 let columns: Vec<Alias> = (self.record_batches[0])
1074 .schema()
1075 .fields()
1076 .iter()
1077 .map(|field| Alias::new(field.name()))
1078 .collect();
1079
1080 let mut insert_stmt = Query::insert()
1081 .into_table(table_reference_to_sea_table_ref(&self.table))
1082 .columns(columns)
1083 .to_owned();
1084
1085 for record_batch in &self.record_batches {
1086 self.construct_insert_stmt(&mut insert_stmt, record_batch, &query_builder)?;
1087 }
1088 if let Some(on_conflict) = on_conflict {
1089 insert_stmt.on_conflict(on_conflict);
1090 }
1091 Ok(insert_stmt.to_string(query_builder))
1092 }
1093}
1094
1095fn table_reference_to_sea_table_ref(table: &TableReference) -> TableRef {
1096 match table {
1097 TableReference::Bare { table } => {
1098 TableRef::Table(SeaRc::new(Alias::new(table.to_string())))
1099 }
1100 TableReference::Partial { schema, table } => TableRef::SchemaTable(
1101 SeaRc::new(Alias::new(schema.to_string())),
1102 SeaRc::new(Alias::new(table.to_string())),
1103 ),
1104 TableReference::Full {
1105 catalog,
1106 schema,
1107 table,
1108 } => TableRef::DatabaseSchemaTable(
1109 SeaRc::new(Alias::new(catalog.to_string())),
1110 SeaRc::new(Alias::new(schema.to_string())),
1111 SeaRc::new(Alias::new(table.to_string())),
1112 ),
1113 }
1114}
1115
/// Builder for `CREATE INDEX` statements.
pub struct IndexBuilder {
    // Table the index is created on.
    table_name: String,
    // Columns covered by the index, in order.
    columns: Vec<String>,
    // When true, emits a UNIQUE index.
    unique: bool,
}
1121
1122impl IndexBuilder {
1123 #[must_use]
1124 pub fn new(table_name: &str, columns: Vec<&str>) -> Self {
1125 Self {
1126 table_name: table_name.to_string(),
1127 columns: columns.into_iter().map(ToString::to_string).collect(),
1128 unique: false,
1129 }
1130 }
1131
1132 #[must_use]
1133 pub fn unique(mut self) -> Self {
1134 self.unique = true;
1135 self
1136 }
1137
1138 #[must_use]
1139 pub fn index_name(&self) -> String {
1140 format!("i_{}_{}", self.table_name, self.columns.join("_"))
1141 }
1142
1143 #[must_use]
1144 pub fn build_postgres(self) -> String {
1145 self.build(PostgresQueryBuilder)
1146 }
1147
1148 #[must_use]
1149 pub fn build_sqlite(self) -> String {
1150 self.build(SqliteQueryBuilder)
1151 }
1152
1153 #[must_use]
1154 pub fn build_mysql(self) -> String {
1155 self.build(MysqlQueryBuilder)
1156 }
1157
1158 #[must_use]
1159 pub fn build<T: GenericBuilder>(self, query_builder: T) -> String {
1160 let mut index = Index::create();
1161 index.table(Alias::new(&self.table_name));
1162 index.name(self.index_name());
1163 if self.unique {
1164 index.unique();
1165 }
1166 for column in self.columns {
1167 index.col(Alias::new(column).into_iden().into_index_column());
1168 }
1169 index.if_not_exists();
1170 index.to_string(query_builder)
1171 }
1172}
1173
1174fn insert_timestamp_into_row_values(
1175 timestamp: Result<OffsetDateTime, time::error::ComponentRange>,
1176 row_values: &mut Vec<SimpleExpr>,
1177) -> Result<()> {
1178 match timestamp {
1179 Ok(offset_time) => {
1180 row_values.push(PrimitiveDateTime::new(offset_time.date(), offset_time.time()).into());
1181 Ok(())
1182 }
1183 Err(e) => Err(Error::FailedToCreateInsertStatement {
1184 source: Box::new(e),
1185 }),
1186 }
1187}
1188
1189#[allow(clippy::needless_pass_by_value)]
1190fn insert_list_into_row_values(
1191 list_array: Arc<dyn Array>,
1192 list_type: &Arc<Field>,
1193 row_values: &mut Vec<SimpleExpr>,
1194) {
1195 match list_type.data_type() {
1196 DataType::Int8 => push_list_values!(
1197 list_type.data_type(),
1198 list_array,
1199 row_values,
1200 array::Int8Array,
1201 i8,
1202 "int2[]"
1203 ),
1204 DataType::Int16 => push_list_values!(
1205 list_type.data_type(),
1206 list_array,
1207 row_values,
1208 array::Int16Array,
1209 i16,
1210 "int2[]"
1211 ),
1212 DataType::Int32 => push_list_values!(
1213 list_type.data_type(),
1214 list_array,
1215 row_values,
1216 array::Int32Array,
1217 i32,
1218 "int4[]"
1219 ),
1220 DataType::Int64 => push_list_values!(
1221 list_type.data_type(),
1222 list_array,
1223 row_values,
1224 array::Int64Array,
1225 i64,
1226 "int8[]"
1227 ),
1228 DataType::Float32 => push_list_values!(
1229 list_type.data_type(),
1230 list_array,
1231 row_values,
1232 array::Float32Array,
1233 f32,
1234 "float4[]"
1235 ),
1236 DataType::Float64 => push_list_values!(
1237 list_type.data_type(),
1238 list_array,
1239 row_values,
1240 array::Float64Array,
1241 f64,
1242 "float8[]"
1243 ),
1244 DataType::Utf8 => {
1245 let mut list_values: Vec<String> = vec![];
1246 for i in 0..list_array.len() {
1247 let int_array = list_array.as_any().downcast_ref::<array::StringArray>();
1248 if let Some(valid_int_array) = int_array {
1249 list_values.push(valid_int_array.value(i).to_string());
1250 }
1251 }
1252 let expr: SimpleExpr = list_values.into();
1253 row_values.push(expr.cast_as(Alias::new("text[]")));
1255 }
1256 DataType::LargeUtf8 => {
1257 let mut list_values: Vec<String> = vec![];
1258 for i in 0..list_array.len() {
1259 let int_array = list_array
1260 .as_any()
1261 .downcast_ref::<array::LargeStringArray>();
1262 if let Some(valid_int_array) = int_array {
1263 list_values.push(valid_int_array.value(i).to_string());
1264 }
1265 }
1266 let expr: SimpleExpr = list_values.into();
1267 row_values.push(expr.cast_as(Alias::new("text[]")));
1269 }
1270 DataType::Utf8View => {
1271 let mut list_values: Vec<String> = vec![];
1272 for i in 0..list_array.len() {
1273 let view_array = list_array.as_any().downcast_ref::<array::StringViewArray>();
1274 if let Some(valid_view_array) = view_array {
1275 list_values.push(valid_view_array.value(i).to_string());
1276 }
1277 }
1278 let expr: SimpleExpr = list_values.into();
1279 row_values.push(expr.cast_as(Alias::new("text[]")));
1280 }
1281 DataType::Boolean => push_list_values!(
1282 list_type.data_type(),
1283 list_array,
1284 row_values,
1285 array::BooleanArray,
1286 bool,
1287 "boolean[]"
1288 ),
1289 DataType::Binary => {
1290 let mut list_values: Vec<Vec<u8>> = Vec::new();
1291 for i in 0..list_array.len() {
1292 let temp_array = list_array.as_any().downcast_ref::<array::BinaryArray>();
1293 if let Some(valid_array) = temp_array {
1294 list_values.push(valid_array.value(i).to_vec());
1295 }
1296 }
1297 let expr: SimpleExpr = list_values.into();
1298 row_values.push(expr.cast_as(Alias::new("bytea[]")));
1300 }
1301 _ => unimplemented!(
1302 "Data type mapping not implemented for {}",
1303 list_type.data_type()
1304 ),
1305 }
1306}
1307
1308#[allow(clippy::cast_sign_loss)]
1309pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType {
1310 match data_type {
1311 DataType::Int8 => ColumnType::TinyInteger,
1312 DataType::Int16 => ColumnType::SmallInteger,
1313 DataType::Int32 => ColumnType::Integer,
1314 DataType::Int64 | DataType::Duration(_) => ColumnType::BigInteger,
1315 DataType::UInt8 => ColumnType::TinyUnsigned,
1316 DataType::UInt16 => ColumnType::SmallUnsigned,
1317 DataType::UInt32 => ColumnType::Unsigned,
1318 DataType::UInt64 => ColumnType::BigUnsigned,
1319 DataType::Float32 => ColumnType::Float,
1320 DataType::Float64 => ColumnType::Double,
1321 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => ColumnType::Text,
1322 DataType::Boolean => ColumnType::Boolean,
1323 #[allow(clippy::cast_sign_loss)] DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => {
1325 ColumnType::Decimal(Some((u32::from(*p), *s as u32)))
1326 }
1327 DataType::Timestamp(_unit, time_zone) => {
1328 if time_zone.is_some() {
1329 return ColumnType::TimestampWithTimeZone;
1330 }
1331 ColumnType::Timestamp
1332 }
1333 DataType::Date32 | DataType::Date64 => ColumnType::Date,
1334 DataType::Time64(_unit) | DataType::Time32(_unit) => ColumnType::Time,
1335 DataType::List(list_type)
1336 | DataType::LargeList(list_type)
1337 | DataType::FixedSizeList(list_type, _) => {
1338 ColumnType::Array(map_data_type_to_column_type(list_type.data_type()).into())
1339 }
1340 DataType::Binary | DataType::LargeBinary => ColumnType::Blob,
1345 DataType::FixedSizeBinary(num_bytes) => ColumnType::Binary(num_bytes.to_owned() as u32),
1346 DataType::Interval(_) => ColumnType::Interval(None, None),
1347 _ => unimplemented!("Data type mapping not implemented for {:?}", data_type),
1349 }
1350}
1351
/// Downcasts `$list_array` to `$array_type`, collects its elements into a
/// `Vec<$vec_type>`, and serializes that vector to a JSON string.
///
/// `$data_type` is accepted for call-site symmetry but is not used here. A
/// failed downcast serializes as an empty JSON array (`"[]"`). Serialization
/// failures propagate via `?` as `Error::FailedToCreateInsertStatement`.
macro_rules! serialize_list_values {
    ($data_type:expr, $list_array:expr, $array_type:ty, $vec_type:ty) => {{
        let values: Vec<$vec_type> = match $list_array.as_any().downcast_ref::<$array_type>() {
            Some(typed) => (0..typed.len()).map(|i| typed.value(i).into()).collect(),
            None => Vec::new(),
        };

        serde_json::to_string(&values).map_err(|e| Error::FailedToCreateInsertStatement {
            source: Box::new(e),
        })?
    }};
}
1366
/// Serializes a single list cell to a JSON array string and pushes it onto
/// `row_values` as a plain string value (for backends without native arrays).
///
/// # Errors
/// Returns `Error::FailedToCreateInsertStatement` if JSON serialization fails
/// (propagated from `serialize_list_values!`).
///
/// # Panics
/// Panics via `unimplemented!` for element types with no JSON mapping.
fn insert_list_into_row_values_json(
    list_array: Arc<dyn Array>,
    list_type: &Arc<Field>,
    row_values: &mut Vec<SimpleExpr>,
) -> Result<()> {
    let data_type = list_type.data_type();

    // Dispatch on the element type; each branch downcasts to the concrete
    // Arrow array type and serializes its values via serde_json.
    let json_string: String = match data_type {
        DataType::Int8 => serialize_list_values!(data_type, list_array, Int8Array, i8),
        DataType::Int16 => serialize_list_values!(data_type, list_array, Int16Array, i16),
        DataType::Int32 => serialize_list_values!(data_type, list_array, Int32Array, i32),
        DataType::Int64 => serialize_list_values!(data_type, list_array, Int64Array, i64),
        DataType::UInt8 => serialize_list_values!(data_type, list_array, UInt8Array, u8),
        DataType::UInt16 => serialize_list_values!(data_type, list_array, UInt16Array, u16),
        DataType::UInt32 => serialize_list_values!(data_type, list_array, UInt32Array, u32),
        DataType::UInt64 => serialize_list_values!(data_type, list_array, UInt64Array, u64),
        DataType::Float32 => serialize_list_values!(data_type, list_array, Float32Array, f32),
        DataType::Float64 => serialize_list_values!(data_type, list_array, Float64Array, f64),
        DataType::Utf8 => serialize_list_values!(data_type, list_array, StringArray, String),
        DataType::LargeUtf8 => {
            serialize_list_values!(data_type, list_array, LargeStringArray, String)
        }
        DataType::Utf8View => {
            serialize_list_values!(data_type, list_array, StringViewArray, String)
        }
        DataType::Boolean => serialize_list_values!(data_type, list_array, BooleanArray, bool),
        _ => unimplemented!(
            "List to json conversion is not implemented for {}",
            list_type.data_type()
        ),
    };

    // Stored as a text value; no cast is applied here.
    let expr: SimpleExpr = Expr::value(json_string);
    row_values.push(expr);

    Ok(())
}
1404
1405fn insert_struct_into_row_values_json(
1406 fields: &Fields,
1407 array: &StructArray,
1408 row_index: usize,
1409 row_values: &mut Vec<SimpleExpr>,
1410) -> Result<()> {
1411 let single_row_columns: Vec<ArrayRef> = (0..array.num_columns())
1415 .map(|i| array.column(i).slice(row_index, 1))
1416 .collect();
1417
1418 let batch = RecordBatch::try_new(Arc::new(Schema::new(fields.clone())), single_row_columns)
1419 .map_err(|e| Error::FailedToCreateInsertStatement {
1420 source: Box::new(e),
1421 })?;
1422
1423 let mut writer = datafusion::arrow::json::LineDelimitedWriter::new(Vec::new());
1424 writer
1425 .write(&batch)
1426 .map_err(|e| Error::FailedToCreateInsertStatement {
1427 source: Box::new(e),
1428 })?;
1429 writer
1430 .finish()
1431 .map_err(|e| Error::FailedToCreateInsertStatement {
1432 source: Box::new(e),
1433 })?;
1434 let json_bytes = writer.into_inner();
1435
1436 let json = String::from_utf8(json_bytes).map_err(|e| Error::FailedToCreateInsertStatement {
1437 source: Box::new(e),
1438 })?;
1439
1440 let expr: SimpleExpr = Expr::value(json);
1441 row_values.push(expr);
1442
1443 Ok(())
1444}
1445
// Unit tests exercising the DDL/DML builders against their rendered SQL.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema};

    // Plain SQLite table: NOT NULL for non-nullable fields, nothing for
    // nullable ones.
    #[test]
    fn test_basic_table_creation() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users").build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer )");
    }

    // Two batches are concatenated into one multi-row INSERT. Note the second
    // batch's schema names its third column "blah", yet the expected SQL uses
    // "age" — the column list evidently comes from the first batch.
    #[test]
    fn test_table_insertion() {
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let id_array = array::Int32Array::from(vec![1, 2, 3]);
        let name_array = array::StringArray::from(vec!["a", "b", "c"]);
        let age_array = array::Int32Array::from(vec![10, 20, 30]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1.clone()),
            vec![
                Arc::new(id_array.clone()),
                Arc::new(name_array.clone()),
                Arc::new(age_array.clone()),
            ],
        )
        .expect("Unable to build record batch");

        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("blah", DataType::Int32, true),
        ]);

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(age_array),
            ],
        )
        .expect("Unable to build record batch");
        let record_batches = vec![batch1, batch2];

        let sql = InsertBuilder::new(&TableReference::from("users"), record_batches)
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(sql, "INSERT INTO \"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    // Same as above but with a schema-qualified table reference
    // ("schema.users" → "schema"."users").
    #[test]
    fn test_table_insertion_with_schema() {
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let id_array = array::Int32Array::from(vec![1, 2, 3]);
        let name_array = array::StringArray::from(vec!["a", "b", "c"]);
        let age_array = array::Int32Array::from(vec![10, 20, 30]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1.clone()),
            vec![
                Arc::new(id_array.clone()),
                Arc::new(name_array.clone()),
                Arc::new(age_array.clone()),
            ],
        )
        .expect("Unable to build record batch");

        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("blah", DataType::Int32, true),
        ]);

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(age_array),
            ],
        )
        .expect("Unable to build record batch");
        let record_batches = vec![batch1, batch2];

        let sql = InsertBuilder::new(&TableReference::from("schema.users"), record_batches)
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(sql, "INSERT INTO \"schema\".\"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    // Composite primary key renders as a trailing PRIMARY KEY (...) clause.
    #[test]
    fn test_table_creation_with_primary_keys() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("id2", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users")
            .primary_keys(vec!["id", "id2"])
            .build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"id2\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer, PRIMARY KEY (\"id\", \"id2\") )");
    }

    // temporary(true) prepends the TEMPORARY keyword.
    #[test]
    fn test_temporary_table_creation() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users")
            .primary_keys(vec!["id"])
            .temporary(true)
            .build_sqlite();

        assert_eq!(sql, "CREATE TEMPORARY TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"name\" text NOT NULL, PRIMARY KEY (\"id\") )");
    }

    // Int32 list columns render as CAST(ARRAY [...] AS int4[]) literals.
    #[test]
    fn test_table_insertion_with_list() {
        let schema1 = Schema::new(vec![Field::new(
            "list",
            DataType::List(Field::new("item", DataType::Int32, true).into()),
            true,
        )]);
        let list_array = array::ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5), Some(6)]),
            Some(vec![Some(7), Some(8), Some(9)]),
        ]);

        let batch = RecordBatch::try_new(Arc::new(schema1.clone()), vec![Arc::new(list_array)])
            .expect("Unable to build record batch");

        let sql = InsertBuilder::new(&TableReference::from("arrays"), vec![batch])
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(
            sql,
            "INSERT INTO \"arrays\" (\"list\") VALUES (CAST(ARRAY [1,2,3] AS int4[])), (CAST(ARRAY [4,5,6] AS int4[])), (CAST(ARRAY [7,8,9] AS int4[]))"
        );
    }

    // Index name is derived as i_<table>_<cols...>.
    #[test]
    fn test_create_index() {
        let sql = IndexBuilder::new("users", vec!["id", "name"]).build_postgres();
        assert_eq!(
            sql,
            r#"CREATE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#
        );
    }

    // unique() switches the statement to CREATE UNIQUE INDEX.
    #[test]
    fn test_create_unique_index() {
        let sql = IndexBuilder::new("users", vec!["id", "name"])
            .unique()
            .build_postgres();
        assert_eq!(
            sql,
            r#"CREATE UNIQUE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#
        );
    }
}