1use bigdecimal::BigDecimal;
2use chrono::{DateTime, Offset, TimeZone};
3use datafusion::arrow::{
4 array::{
5 array, timezone::Tz, Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
6 Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray,
7 StringViewArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
8 },
9 datatypes::{DataType, Field, Fields, IntervalUnit, Schema, SchemaRef, TimeUnit},
10 util::display::array_value_to_string,
11};
12use datafusion::sql::TableReference;
13use num_bigint::BigInt;
14use sea_query::{
15 Alias, ColumnDef, ColumnType, Expr, GenericBuilder, Index, InsertStatement, IntoIden,
16 IntoIndexColumn, Keyword, MysqlQueryBuilder, OnConflict, PostgresQueryBuilder, Query,
17 QueryBuilder, SeaRc, SimpleExpr, SqliteQueryBuilder, Table, TableRef,
18};
19use snafu::Snafu;
20use std::{str::FromStr, sync::Arc};
21use time::{OffsetDateTime, PrimitiveDateTime};
22
/// Errors that can occur while generating SQL statements from Arrow data.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A row value could not be converted into a SQL expression, or the
    /// insert statement itself could not be assembled.
    #[snafu(display("Failed to build insert statement: {source}"))]
    FailedToCreateInsertStatement {
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// The Arrow `DataType` of a column has no SQL mapping implemented here.
    #[snafu(display("Unimplemented data type in insert statement: {data_type:?}"))]
    UnimplementedDataTypeInInsertStatement { data_type: DataType },
}

/// Convenience alias that defaults the error type to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
35
/// Builder for `CREATE TABLE IF NOT EXISTS` statements derived from an Arrow schema.
pub struct CreateTableBuilder {
    // Arrow schema whose fields become the table's columns.
    schema: SchemaRef,
    // Target table name (used verbatim as an identifier alias).
    table_name: String,
    // Column names forming the composite primary key; empty means no PK clause.
    primary_keys: Vec<String>,
    // When true, emits a TEMPORARY table.
    temporary: bool,
}
42
43impl CreateTableBuilder {
44 #[must_use]
45 pub fn new(schema: SchemaRef, table_name: &str) -> Self {
46 Self {
47 schema,
48 table_name: table_name.to_string(),
49 primary_keys: Vec::new(),
50 temporary: false,
51 }
52 }
53
54 #[must_use]
55 pub fn primary_keys<T>(mut self, keys: Vec<T>) -> Self
56 where
57 T: Into<String>,
58 {
59 self.primary_keys = keys.into_iter().map(Into::into).collect();
60 self
61 }
62
63 #[must_use]
64 pub fn temporary(mut self, temporary: bool) -> Self {
66 self.temporary = temporary;
67 self
68 }
69
70 #[must_use]
71 #[cfg(feature = "postgres")]
72 pub fn build_postgres(self) -> Vec<String> {
73 use crate::sql::arrow_sql_gen::postgres::{
74 builder::TypeBuilder, get_postgres_composite_type_name,
75 map_data_type_to_column_type_postgres,
76 };
77 let schema = Arc::clone(&self.schema);
78 let table_name = self.table_name.clone();
79 let main_table_creation =
80 self.build(PostgresQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
81 map_data_type_to_column_type_postgres(f.data_type(), &table_name, f.name())
82 });
83
84 let mut creation_stmts = Vec::new();
87 for field in schema.fields() {
88 let DataType::Struct(struct_inner_fields) = field.data_type() else {
89 continue;
90 };
91 let type_builder = TypeBuilder::new(
92 get_postgres_composite_type_name(&table_name, field.name()),
93 struct_inner_fields,
94 );
95 creation_stmts.push(type_builder.build());
96 }
97
98 creation_stmts.push(main_table_creation);
99 creation_stmts
100 }
101
102 #[must_use]
103 pub fn build_sqlite(self) -> String {
104 self.build(SqliteQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
105 if f.data_type().is_nested() {
108 return ColumnType::JsonBinary;
109 }
110
111 map_data_type_to_column_type(f.data_type())
112 })
113 }
114
115 #[must_use]
116 pub fn build_mysql(self) -> String {
117 self.build(MysqlQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
118 if f.data_type().is_nested() {
121 return ColumnType::JsonBinary;
122 }
123 map_data_type_to_column_type(f.data_type())
124 })
125 }
126
127 #[must_use]
128 fn build<T: GenericBuilder>(
129 self,
130 query_builder: T,
131 map_data_type_to_column_type_fn: &dyn Fn(&Arc<Field>) -> ColumnType,
132 ) -> String {
133 let mut create_stmt = Table::create();
134 create_stmt
135 .table(Alias::new(self.table_name.clone()))
136 .if_not_exists();
137
138 for field in self.schema.fields() {
139 let column_type = map_data_type_to_column_type_fn(field);
140 let mut column_def = ColumnDef::new_with_type(Alias::new(field.name()), column_type);
141 if !field.is_nullable() {
142 column_def.not_null();
143 }
144
145 create_stmt.col(&mut column_def);
146 }
147
148 if !self.primary_keys.is_empty() {
149 let mut index = Index::create();
150 index.primary();
151 for key in self.primary_keys {
152 index.col(Alias::new(key).into_iden().into_index_column());
153 }
154 create_stmt.primary_key(&mut index);
155 }
156
157 if self.temporary {
158 create_stmt.temporary();
159 }
160
161 create_stmt.to_string(query_builder)
162 }
163}
164
/// Downcasts `$column` to `array::$array_type` and pushes the value at `$row`
/// onto `$row_values`, pushing SQL NULL for null slots.
///
/// NOTE: expands to a `continue`, so this macro may only be invoked inside a
/// loop body (the per-column loop in `construct_insert_stmt`).
macro_rules! push_value {
    ($row_values:expr, $column:expr, $row:expr, $array_type:ident) => {{
        let array = $column.as_any().downcast_ref::<array::$array_type>();
        if let Some(valid_array) = array {
            if valid_array.is_null($row) {
                $row_values.push(Keyword::Null.into());
                continue;
            }
            $row_values.push(valid_array.value($row).into());
        }
    }};
}
177
/// Collects every element of `$list_array` into a `Vec<$vec_type>` and pushes
/// it onto `$row_values` as a single expression cast to the SQL array type
/// named by `$sql_type` (e.g. `"int4[]"`).
///
/// If the downcast to `$array_type` fails, an empty list is pushed — matching
/// the previous behavior, which silently produced no elements in that case.
macro_rules! push_list_values {
    ($data_type:expr, $list_array:expr, $row_values:expr, $array_type:ty, $vec_type:ty, $sql_type:expr) => {{
        let mut list_values: Vec<$vec_type> = Vec::new();
        // Downcast once up front instead of once per element (the downcast is
        // loop-invariant; the original re-did it on every iteration).
        if let Some(valid_array) = $list_array.as_any().downcast_ref::<$array_type>() {
            list_values.reserve(valid_array.len());
            for i in 0..valid_array.len() {
                list_values.push(valid_array.value(i));
            }
        }
        let expr: SimpleExpr = list_values.into();
        $row_values.push(expr.cast_as(Alias::new($sql_type)));
    }};
}
192
/// Builder for multi-row `INSERT` statements from Arrow record batches.
pub struct InsertBuilder {
    // Destination table (bare, schema-qualified, or catalog-qualified).
    table: TableReference,
    // Batches whose rows become the VALUES tuples; all are assumed to share
    // the first batch's schema.
    record_batches: Vec<RecordBatch>,
}
197
198pub fn use_json_insert_for_type<T: QueryBuilder + 'static>(
199 #[allow(unused_variables)] data_type: &DataType,
200 #[allow(unused_variables)] query_builder: &T,
201) -> bool {
202 #[cfg(feature = "sqlite")]
203 {
204 use std::any::Any;
205 let any_builder = query_builder as &dyn Any;
206 if any_builder.is::<SqliteQueryBuilder>() {
207 return data_type.is_nested();
208 }
209 }
210 #[cfg(feature = "mysql")]
211 {
212 use std::any::Any;
213 let any_builder = query_builder as &dyn Any;
214 if any_builder.is::<MysqlQueryBuilder>() {
215 return data_type.is_nested();
216 }
217 }
218 false
219}
220
221impl InsertBuilder {
222 #[must_use]
223 pub fn new(table: &TableReference, record_batches: Vec<RecordBatch>) -> Self {
224 Self {
225 table: table.clone(),
226 record_batches,
227 }
228 }
229
    /// Appends one `VALUES (...)` tuple per row of `record_batch` to
    /// `insert_stmt`, converting each Arrow column value into a `sea_query`
    /// expression. Nested types (lists, structs) are serialized as JSON when
    /// `use_json_insert_for_type` says the dialect requires it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::FailedToCreateInsertStatement`] when a value cannot be
    /// converted (e.g. out-of-range timestamps or unparseable timezones), and
    /// [`Error::UnimplementedDataTypeInInsertStatement`] for Arrow types with
    /// no mapping.
    #[allow(clippy::too_many_lines)]
    pub fn construct_insert_stmt<T: QueryBuilder + 'static>(
        &self,
        insert_stmt: &mut InsertStatement,
        record_batch: &RecordBatch,
        query_builder: &T,
    ) -> Result<()> {
        for row in 0..record_batch.num_rows() {
            let mut row_values: Vec<SimpleExpr> = vec![];
            for col in 0..record_batch.num_columns() {
                let column = record_batch.column(col);
                let column_data_type = column.data_type();

                match column_data_type {
                    // Primitive scalars map directly; push_value! is NULL-aware
                    // and `continue`s past the rest of this match on NULL.
                    DataType::Int8 => push_value!(row_values, column, row, Int8Array),
                    DataType::Int16 => push_value!(row_values, column, row, Int16Array),
                    DataType::Int32 => push_value!(row_values, column, row, Int32Array),
                    DataType::Int64 => push_value!(row_values, column, row, Int64Array),
                    DataType::UInt8 => push_value!(row_values, column, row, UInt8Array),
                    DataType::UInt16 => push_value!(row_values, column, row, UInt16Array),
                    DataType::UInt32 => push_value!(row_values, column, row, UInt32Array),
                    DataType::UInt64 => push_value!(row_values, column, row, UInt64Array),
                    DataType::Float32 => push_value!(row_values, column, row, Float32Array),
                    DataType::Float64 => push_value!(row_values, column, row, Float64Array),
                    DataType::Utf8 => push_value!(row_values, column, row, StringArray),
                    DataType::LargeUtf8 => push_value!(row_values, column, row, LargeStringArray),
                    DataType::Utf8View => push_value!(row_values, column, row, StringViewArray),
                    DataType::Boolean => push_value!(row_values, column, row, BooleanArray),
                    // Decimals are rebuilt as BigDecimal so precision/scale survive.
                    DataType::Decimal128(_, scale) => {
                        let array = column.as_any().downcast_ref::<array::Decimal128Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            row_values.push(
                                BigDecimal::new(valid_array.value(row).into(), i64::from(*scale))
                                    .into(),
                            );
                        }
                    }
                    DataType::Decimal256(_, scale) => {
                        let array = column.as_any().downcast_ref::<array::Decimal256Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            // i256 has no direct BigDecimal conversion; go via
                            // its little-endian byte representation.
                            let bigint =
                                BigInt::from_signed_bytes_le(&valid_array.value(row).to_le_bytes());

                            row_values.push(BigDecimal::new(bigint, i64::from(*scale)).into());
                        }
                    }
                    // Date32 stores days since the Unix epoch.
                    DataType::Date32 => {
                        let array = column.as_any().downcast_ref::<array::Date32Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            row_values.push(
                                match OffsetDateTime::from_unix_timestamp(
                                    i64::from(valid_array.value(row)) * 86_400,
                                ) {
                                    Ok(offset_time) => offset_time.date().into(),
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                },
                            );
                        }
                    }
                    // Date64 stores milliseconds since the Unix epoch.
                    DataType::Date64 => {
                        let array = column.as_any().downcast_ref::<array::Date64Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            row_values.push(
                                match OffsetDateTime::from_unix_timestamp(
                                    valid_array.value(row) / 1000,
                                ) {
                                    Ok(offset_time) => offset_time.date().into(),
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                },
                            );
                        }
                    }
                    // Durations insert as their raw integer counts in the
                    // stated unit.
                    DataType::Duration(time_unit) => match time_unit {
                        TimeUnit::Second => {
                            push_value!(row_values, column, row, DurationSecondArray);
                        }
                        TimeUnit::Microsecond => {
                            push_value!(row_values, column, row, DurationMicrosecondArray);
                        }
                        TimeUnit::Millisecond => {
                            push_value!(row_values, column, row, DurationMillisecondArray);
                        }
                        TimeUnit::Nanosecond => {
                            push_value!(row_values, column, row, DurationNanosecondArray);
                        }
                    },
                    // Time-of-day: converted via an epoch-relative timestamp,
                    // keeping only the clock components.
                    DataType::Time32(time_unit) => match time_unit {
                        TimeUnit::Millisecond => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::Time32MillisecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let (h, m, s, micro) =
                                    match OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000_000,
                                    ) {
                                        Ok(timestamp) => timestamp.to_hms_micro(),
                                        Err(e) => {
                                            return Result::Err(
                                                Error::FailedToCreateInsertStatement {
                                                    source: Box::new(e),
                                                },
                                            )
                                        }
                                    };

                                let time = match time::Time::from_hms_micro(h, m, s, micro) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        TimeUnit::Second => {
                            let array = column.as_any().downcast_ref::<array::Time32SecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let (h, m, s) = match OffsetDateTime::from_unix_timestamp(
                                    i64::from(valid_array.value(row)),
                                ) {
                                    Ok(timestamp) => timestamp.to_hms(),
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                let time = match time::Time::from_hms(h, m, s) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        // Arrow only defines Second/Millisecond for Time32.
                        _ => unreachable!(),
                    },
                    DataType::Time64(time_unit) => match time_unit {
                        TimeUnit::Nanosecond => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::Time64NanosecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }
                                let (h, m, s, nano) =
                                    match OffsetDateTime::from_unix_timestamp_nanos(i128::from(
                                        valid_array.value(row),
                                    )) {
                                        Ok(timestamp) => timestamp.to_hms_nano(),
                                        Err(e) => {
                                            return Result::Err(
                                                Error::FailedToCreateInsertStatement {
                                                    source: Box::new(e),
                                                },
                                            )
                                        }
                                    };

                                let time = match time::Time::from_hms_nano(h, m, s, nano) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        TimeUnit::Microsecond => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::Time64MicrosecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let (h, m, s, micro) =
                                    match OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000,
                                    ) {
                                        Ok(timestamp) => timestamp.to_hms_micro(),
                                        Err(e) => {
                                            return Result::Err(
                                                Error::FailedToCreateInsertStatement {
                                                    source: Box::new(e),
                                                },
                                            )
                                        }
                                    };

                                let time = match time::Time::from_hms_micro(h, m, s, micro) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        // Arrow only defines Microsecond/Nanosecond for Time64.
                        _ => unreachable!(),
                    },
                    // Timestamps: with a timezone, the value is rendered at the
                    // zone's fixed offset for this instant; without one, as a
                    // naive (offset-less) datetime.
                    DataType::Timestamp(TimeUnit::Second, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampSecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                // NOTE(review): `value * 1_000_000_000` can
                                // overflow i64 for timestamps beyond ~292
                                // years from epoch — presumably out of range
                                // for this data; confirm upstream bounds.
                                let utc_time = DateTime::from_timestamp_nanos(
                                    valid_array.value(row) * 1_000_000_000,
                                )
                                .to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp(valid_array.value(row)),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    DataType::Timestamp(TimeUnit::Millisecond, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampMillisecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                let utc_time = DateTime::from_timestamp_nanos(
                                    valid_array.value(row) * 1_000_000,
                                )
                                .to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000_000,
                                    ),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampMicrosecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                let utc_time =
                                    DateTime::from_timestamp_nanos(valid_array.value(row) * 1_000)
                                        .to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000,
                                    ),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    DataType::Timestamp(TimeUnit::Nanosecond, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampNanosecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                let utc_time =
                                    DateTime::from_timestamp_nanos(valid_array.value(row)).to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp_nanos(i128::from(
                                        valid_array.value(row),
                                    )),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    // List-like columns: JSON for dialects without array types,
                    // otherwise a typed SQL array expression.
                    DataType::List(list_type) => {
                        let array = column.as_any().downcast_ref::<array::ListArray>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            let list_array = valid_array.value(row);

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_list_into_row_values_json(
                                    list_array,
                                    list_type,
                                    &mut row_values,
                                )?;
                            } else {
                                insert_list_into_row_values(list_array, list_type, &mut row_values);
                            }
                        }
                    }
                    DataType::LargeList(list_type) => {
                        let array = column.as_any().downcast_ref::<array::LargeListArray>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            let list_array = valid_array.value(row);

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_list_into_row_values_json(
                                    list_array,
                                    list_type,
                                    &mut row_values,
                                )?;
                            } else {
                                insert_list_into_row_values(list_array, list_type, &mut row_values);
                            }
                        }
                    }
                    DataType::FixedSizeList(list_type, _) => {
                        let array = column.as_any().downcast_ref::<array::FixedSizeListArray>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            let list_array = valid_array.value(row);

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_list_into_row_values_json(
                                    list_array,
                                    list_type,
                                    &mut row_values,
                                )?;
                            } else {
                                insert_list_into_row_values(list_array, list_type, &mut row_values);
                            }
                        }
                    }
                    // Binary payloads insert as raw byte values.
                    DataType::Binary => {
                        let array = column.as_any().downcast_ref::<array::BinaryArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    DataType::LargeBinary => {
                        let array = column.as_any().downcast_ref::<array::LargeBinaryArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    DataType::FixedSizeBinary(_) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::FixedSizeBinaryArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    DataType::BinaryView => {
                        let array = column.as_any().downcast_ref::<array::BinaryViewArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    // Intervals insert as human-readable strings; Arrow's own
                    // display formatting is preferred, with a manual fallback.
                    DataType::Interval(interval_unit) => match interval_unit {
                        IntervalUnit::DayTime => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::IntervalDayTimeArray>();

                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let interval_str =
                                    if let Ok(str) = array_value_to_string(valid_array, row) {
                                        str
                                    } else {
                                        let days = valid_array.value(row).days;
                                        let milliseconds = valid_array.value(row).milliseconds;
                                        format!("{days} days {milliseconds} milliseconds")
                                    };

                                row_values.push(interval_str.into());
                            }
                        }
                        IntervalUnit::YearMonth => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::IntervalYearMonthArray>();

                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let interval_str =
                                    if let Ok(str) = array_value_to_string(valid_array, row) {
                                        str
                                    } else {
                                        let months = valid_array.value(row);
                                        format!("{months} months")
                                    };

                                row_values.push(interval_str.into());
                            }
                        }
                        IntervalUnit::MonthDayNano => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::IntervalMonthDayNanoArray>();

                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let interval_str =
                                    if let Ok(str) = array_value_to_string(valid_array, row) {
                                        str
                                    } else {
                                        let months = valid_array.value(row).months;
                                        let days = valid_array.value(row).days;
                                        let nanoseconds = valid_array.value(row).nanoseconds;
                                        // Fallback truncates to microseconds.
                                        let micros = nanoseconds / 1_000;
                                        format!("{months} months {days} days {micros} microseconds")
                                    };

                                row_values.push(interval_str.into());
                            }
                        }
                    },
                    // Structs: JSON for dialects that need it, otherwise a
                    // `ROW(...)` expression built from the rendered children.
                    DataType::Struct(fields) => {
                        let array = column.as_any().downcast_ref::<array::StructArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_struct_into_row_values_json(
                                    fields,
                                    valid_array,
                                    row,
                                    &mut row_values,
                                )?;
                                continue;
                            }

                            let mut param_values: Vec<SimpleExpr> = vec![];

                            // NOTE(review): child nulls at `row` are not
                            // checked here (unlike top-level columns) —
                            // confirm struct children are non-nullable or
                            // handled upstream.
                            for col in valid_array.columns() {
                                match col.data_type() {
                                    DataType::Int8 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int8Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Int16 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int16Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Int32 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int32Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Int64 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int64Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt8 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt8Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt16 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt16Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt32 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt32Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt64 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt64Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Float32 => {
                                        let float_array =
                                            col.as_any().downcast_ref::<array::Float32Array>();

                                        if let Some(valid_float_array) = float_array {
                                            param_values.push(valid_float_array.value(row).into());
                                        }
                                    }
                                    DataType::Float64 => {
                                        let float_array =
                                            col.as_any().downcast_ref::<array::Float64Array>();

                                        if let Some(valid_float_array) = float_array {
                                            param_values.push(valid_float_array.value(row).into());
                                        }
                                    }
                                    DataType::Utf8 => {
                                        let string_array =
                                            col.as_any().downcast_ref::<array::StringArray>();

                                        if let Some(valid_string_array) = string_array {
                                            param_values.push(valid_string_array.value(row).into());
                                        }
                                    }
                                    DataType::Null => {
                                        param_values.push(Keyword::Null.into());
                                    }
                                    DataType::Boolean => {
                                        let bool_array =
                                            col.as_any().downcast_ref::<array::BooleanArray>();

                                        if let Some(valid_bool_array) = bool_array {
                                            param_values.push(valid_bool_array.value(row).into());
                                        }
                                    }
                                    DataType::Binary => {
                                        let binary_array =
                                            col.as_any().downcast_ref::<array::BinaryArray>();

                                        if let Some(valid_binary_array) = binary_array {
                                            param_values.push(valid_binary_array.value(row).into());
                                        }
                                    }
                                    DataType::FixedSizeBinary(_) => {
                                        let binary_array = col
                                            .as_any()
                                            .downcast_ref::<array::FixedSizeBinaryArray>();

                                        if let Some(valid_binary_array) = binary_array {
                                            param_values.push(valid_binary_array.value(row).into());
                                        }
                                    }
                                    DataType::LargeBinary => {
                                        let binary_array =
                                            col.as_any().downcast_ref::<array::LargeBinaryArray>();

                                        if let Some(valid_binary_array) = binary_array {
                                            param_values.push(valid_binary_array.value(row).into());
                                        }
                                    }
                                    DataType::LargeUtf8 => {
                                        let string_array =
                                            col.as_any().downcast_ref::<array::LargeStringArray>();

                                        if let Some(valid_string_array) = string_array {
                                            param_values.push(valid_string_array.value(row).into());
                                        }
                                    }
                                    DataType::Utf8View => {
                                        let view_array =
                                            col.as_any().downcast_ref::<array::StringViewArray>();

                                        if let Some(valid_view_array) = view_array {
                                            param_values.push(valid_view_array.value(row).into());
                                        }
                                    }
                                    DataType::BinaryView => {
                                        let view_array =
                                            col.as_any().downcast_ref::<array::BinaryViewArray>();

                                        if let Some(valid_view_array) = view_array {
                                            param_values.push(valid_view_array.value(row).into());
                                        }
                                    }
                                    // Nested / temporal children are not yet
                                    // supported inside ROW() construction.
                                    DataType::Float16
                                    | DataType::Timestamp(_, _)
                                    | DataType::Date32
                                    | DataType::Date64
                                    | DataType::Time32(_)
                                    | DataType::Time64(_)
                                    | DataType::Duration(_)
                                    | DataType::Interval(_)
                                    | DataType::List(_)
                                    | DataType::ListView(_)
                                    | DataType::FixedSizeList(_, _)
                                    | DataType::LargeList(_)
                                    | DataType::LargeListView(_)
                                    | DataType::Struct(_)
                                    | DataType::Union(_, _)
                                    | DataType::Dictionary(_, _)
                                    | DataType::Map(_, _)
                                    | DataType::RunEndEncoded(_, _)
                                    | DataType::Decimal128(_, _)
                                    | DataType::Decimal256(_, _) => {
                                        unimplemented!(
                                            "Data type mapping not implemented for Struct of {}",
                                            col.data_type()
                                        )
                                    }
                                }
                            }

                            // Render each child expression to SQL text and
                            // splice them into a raw ROW(...) expression.
                            let mut params_vec = Vec::new();
                            for param_value in &param_values {
                                let mut params_str = String::new();
                                query_builder.prepare_simple_expr(param_value, &mut params_str);
                                params_vec.push(params_str);
                            }

                            let params = params_vec.join(", ");
                            row_values.push(Expr::cust(format!("ROW({params})")));
                        }
                    }
                    unimplemented_type => {
                        return Result::Err(Error::UnimplementedDataTypeInInsertStatement {
                            data_type: unimplemented_type.clone(),
                        })
                    }
                }
            }
            match insert_stmt.values(row_values) {
                Ok(_) => (),
                Err(e) => {
                    return Result::Err(Error::FailedToCreateInsertStatement {
                        source: Box::new(e),
                    })
                }
            }
        }
        Ok(())
    }
1037
    /// Builds the INSERT statement as Postgres SQL.
    ///
    /// # Errors
    ///
    /// Propagates any conversion failure from [`Self::build`].
    pub fn build_postgres(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(PostgresQueryBuilder, on_conflict)
    }

    /// Builds the INSERT statement as SQLite SQL.
    ///
    /// # Errors
    ///
    /// Propagates any conversion failure from [`Self::build`].
    pub fn build_sqlite(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(SqliteQueryBuilder, on_conflict)
    }

    /// Builds the INSERT statement as MySQL SQL.
    ///
    /// # Errors
    ///
    /// Propagates any conversion failure from [`Self::build`].
    pub fn build_mysql(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(MysqlQueryBuilder, on_conflict)
    }
1061
1062 pub fn build<T: GenericBuilder + 'static>(
1067 &self,
1068 query_builder: T,
1069 on_conflict: Option<OnConflict>,
1070 ) -> Result<String> {
1071 let columns: Vec<Alias> = (self.record_batches[0])
1072 .schema()
1073 .fields()
1074 .iter()
1075 .map(|field| Alias::new(field.name()))
1076 .collect();
1077
1078 let mut insert_stmt = Query::insert()
1079 .into_table(table_reference_to_sea_table_ref(&self.table))
1080 .columns(columns)
1081 .to_owned();
1082
1083 for record_batch in &self.record_batches {
1084 self.construct_insert_stmt(&mut insert_stmt, record_batch, &query_builder)?;
1085 }
1086 if let Some(on_conflict) = on_conflict {
1087 insert_stmt.on_conflict(on_conflict);
1088 }
1089 Ok(insert_stmt.to_string(query_builder))
1090 }
1091}
1092
/// Converts a DataFusion [`TableReference`] into the equivalent `sea_query`
/// [`TableRef`], preserving however many qualification levels are present
/// (bare table, schema.table, or catalog.schema.table).
fn table_reference_to_sea_table_ref(table: &TableReference) -> TableRef {
    match table {
        TableReference::Bare { table } => {
            TableRef::Table(SeaRc::new(Alias::new(table.to_string())))
        }
        TableReference::Partial { schema, table } => TableRef::SchemaTable(
            SeaRc::new(Alias::new(schema.to_string())),
            SeaRc::new(Alias::new(table.to_string())),
        ),
        TableReference::Full {
            catalog,
            schema,
            table,
        } => TableRef::DatabaseSchemaTable(
            SeaRc::new(Alias::new(catalog.to_string())),
            SeaRc::new(Alias::new(schema.to_string())),
            SeaRc::new(Alias::new(table.to_string())),
        ),
    }
}
1113
/// Builder for `CREATE INDEX IF NOT EXISTS` statements.
pub struct IndexBuilder {
    // Table the index is created on.
    table_name: String,
    // Indexed columns, in order.
    columns: Vec<String>,
    // When true, emits a UNIQUE index.
    unique: bool,
}
1119
1120impl IndexBuilder {
1121 #[must_use]
1122 pub fn new(table_name: &str, columns: Vec<&str>) -> Self {
1123 Self {
1124 table_name: table_name.to_string(),
1125 columns: columns.into_iter().map(ToString::to_string).collect(),
1126 unique: false,
1127 }
1128 }
1129
1130 #[must_use]
1131 pub fn unique(mut self) -> Self {
1132 self.unique = true;
1133 self
1134 }
1135
1136 #[must_use]
1137 pub fn index_name(&self) -> String {
1138 format!("i_{}_{}", self.table_name, self.columns.join("_"))
1139 }
1140
1141 #[must_use]
1142 pub fn build_postgres(self) -> String {
1143 self.build(PostgresQueryBuilder)
1144 }
1145
1146 #[must_use]
1147 pub fn build_sqlite(self) -> String {
1148 self.build(SqliteQueryBuilder)
1149 }
1150
1151 #[must_use]
1152 pub fn build_mysql(self) -> String {
1153 self.build(MysqlQueryBuilder)
1154 }
1155
1156 #[must_use]
1157 pub fn build<T: GenericBuilder>(self, query_builder: T) -> String {
1158 let mut index = Index::create();
1159 index.table(Alias::new(&self.table_name));
1160 index.name(self.index_name());
1161 if self.unique {
1162 index.unique();
1163 }
1164 for column in self.columns {
1165 index.col(Alias::new(column).into_iden().into_index_column());
1166 }
1167 index.if_not_exists();
1168 index.to_string(query_builder)
1169 }
1170}
1171
1172fn insert_timestamp_into_row_values(
1173 timestamp: Result<OffsetDateTime, time::error::ComponentRange>,
1174 row_values: &mut Vec<SimpleExpr>,
1175) -> Result<()> {
1176 match timestamp {
1177 Ok(offset_time) => {
1178 row_values.push(PrimitiveDateTime::new(offset_time.date(), offset_time.time()).into());
1179 Ok(())
1180 }
1181 Err(e) => Err(Error::FailedToCreateInsertStatement {
1182 source: Box::new(e),
1183 }),
1184 }
1185}
1186
1187#[allow(clippy::needless_pass_by_value)]
1188fn insert_list_into_row_values(
1189 list_array: Arc<dyn Array>,
1190 list_type: &Arc<Field>,
1191 row_values: &mut Vec<SimpleExpr>,
1192) {
1193 match list_type.data_type() {
1194 DataType::Int8 => push_list_values!(
1195 list_type.data_type(),
1196 list_array,
1197 row_values,
1198 array::Int8Array,
1199 i8,
1200 "int2[]"
1201 ),
1202 DataType::Int16 => push_list_values!(
1203 list_type.data_type(),
1204 list_array,
1205 row_values,
1206 array::Int16Array,
1207 i16,
1208 "int2[]"
1209 ),
1210 DataType::Int32 => push_list_values!(
1211 list_type.data_type(),
1212 list_array,
1213 row_values,
1214 array::Int32Array,
1215 i32,
1216 "int4[]"
1217 ),
1218 DataType::Int64 => push_list_values!(
1219 list_type.data_type(),
1220 list_array,
1221 row_values,
1222 array::Int64Array,
1223 i64,
1224 "int8[]"
1225 ),
1226 DataType::Float32 => push_list_values!(
1227 list_type.data_type(),
1228 list_array,
1229 row_values,
1230 array::Float32Array,
1231 f32,
1232 "float4[]"
1233 ),
1234 DataType::Float64 => push_list_values!(
1235 list_type.data_type(),
1236 list_array,
1237 row_values,
1238 array::Float64Array,
1239 f64,
1240 "float8[]"
1241 ),
1242 DataType::Utf8 => {
1243 let mut list_values: Vec<String> = vec![];
1244 for i in 0..list_array.len() {
1245 let int_array = list_array.as_any().downcast_ref::<array::StringArray>();
1246 if let Some(valid_int_array) = int_array {
1247 list_values.push(valid_int_array.value(i).to_string());
1248 }
1249 }
1250 let expr: SimpleExpr = list_values.into();
1251 row_values.push(expr.cast_as(Alias::new("text[]")));
1253 }
1254 DataType::LargeUtf8 => {
1255 let mut list_values: Vec<String> = vec![];
1256 for i in 0..list_array.len() {
1257 let int_array = list_array
1258 .as_any()
1259 .downcast_ref::<array::LargeStringArray>();
1260 if let Some(valid_int_array) = int_array {
1261 list_values.push(valid_int_array.value(i).to_string());
1262 }
1263 }
1264 let expr: SimpleExpr = list_values.into();
1265 row_values.push(expr.cast_as(Alias::new("text[]")));
1267 }
1268 DataType::Utf8View => {
1269 let mut list_values: Vec<String> = vec![];
1270 for i in 0..list_array.len() {
1271 let view_array = list_array.as_any().downcast_ref::<array::StringViewArray>();
1272 if let Some(valid_view_array) = view_array {
1273 list_values.push(valid_view_array.value(i).to_string());
1274 }
1275 }
1276 let expr: SimpleExpr = list_values.into();
1277 row_values.push(expr.cast_as(Alias::new("text[]")));
1278 }
1279 DataType::Boolean => push_list_values!(
1280 list_type.data_type(),
1281 list_array,
1282 row_values,
1283 array::BooleanArray,
1284 bool,
1285 "boolean[]"
1286 ),
1287 DataType::Binary => {
1288 let mut list_values: Vec<Vec<u8>> = Vec::new();
1289 for i in 0..list_array.len() {
1290 let temp_array = list_array.as_any().downcast_ref::<array::BinaryArray>();
1291 if let Some(valid_array) = temp_array {
1292 list_values.push(valid_array.value(i).to_vec());
1293 }
1294 }
1295 let expr: SimpleExpr = list_values.into();
1296 row_values.push(expr.cast_as(Alias::new("bytea[]")));
1298 }
1299 _ => unimplemented!(
1300 "Data type mapping not implemented for {}",
1301 list_type.data_type()
1302 ),
1303 }
1304}
1305
1306#[allow(clippy::cast_sign_loss)]
1307pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType {
1308 match data_type {
1309 DataType::Int8 => ColumnType::TinyInteger,
1310 DataType::Int16 => ColumnType::SmallInteger,
1311 DataType::Int32 => ColumnType::Integer,
1312 DataType::Int64 | DataType::Duration(_) => ColumnType::BigInteger,
1313 DataType::UInt8 => ColumnType::TinyUnsigned,
1314 DataType::UInt16 => ColumnType::SmallUnsigned,
1315 DataType::UInt32 => ColumnType::Unsigned,
1316 DataType::UInt64 => ColumnType::BigUnsigned,
1317 DataType::Float32 => ColumnType::Float,
1318 DataType::Float64 => ColumnType::Double,
1319 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => ColumnType::Text,
1320 DataType::Boolean => ColumnType::Boolean,
1321 #[allow(clippy::cast_sign_loss)] DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => {
1323 ColumnType::Decimal(Some((u32::from(*p), *s as u32)))
1324 }
1325 DataType::Timestamp(_unit, time_zone) => {
1326 if time_zone.is_some() {
1327 return ColumnType::TimestampWithTimeZone;
1328 }
1329 ColumnType::Timestamp
1330 }
1331 DataType::Date32 | DataType::Date64 => ColumnType::Date,
1332 DataType::Time64(_unit) | DataType::Time32(_unit) => ColumnType::Time,
1333 DataType::List(list_type)
1334 | DataType::LargeList(list_type)
1335 | DataType::FixedSizeList(list_type, _) => {
1336 ColumnType::Array(map_data_type_to_column_type(list_type.data_type()).into())
1337 }
1338 DataType::Binary | DataType::LargeBinary => ColumnType::Blob,
1343 DataType::FixedSizeBinary(num_bytes) => ColumnType::Binary(num_bytes.to_owned() as u32),
1344 DataType::Interval(_) => ColumnType::Interval(None, None),
1345 _ => unimplemented!("Data type mapping not implemented for {:?}", data_type),
1347 }
1348}
1349
// Serializes the values of a typed Arrow array into a JSON array string.
//
// `$data_type` is accepted for call-site symmetry with the other list macros
// but is not used in the expansion. A failed downcast produces an empty JSON
// list (`[]`). Serialization failures are converted into
// `Error::FailedToCreateInsertStatement` and propagated with `?`, so this
// macro must expand inside a function returning `Result`.
macro_rules! serialize_list_values {
    ($data_type:expr, $list_array:expr, $array_type:ty, $vec_type:ty) => {{
        let mut list_values: Vec<$vec_type> = vec![];
        if let Some(array) = $list_array.as_any().downcast_ref::<$array_type>() {
            for i in 0..array.len() {
                list_values.push(array.value(i).into());
            }
        }

        serde_json::to_string(&list_values).map_err(|e| Error::FailedToCreateInsertStatement {
            source: Box::new(e),
        })?
    }};
}
1364
1365fn insert_list_into_row_values_json(
1366 list_array: Arc<dyn Array>,
1367 list_type: &Arc<Field>,
1368 row_values: &mut Vec<SimpleExpr>,
1369) -> Result<()> {
1370 let data_type = list_type.data_type();
1371
1372 let json_string: String = match data_type {
1373 DataType::Int8 => serialize_list_values!(data_type, list_array, Int8Array, i8),
1374 DataType::Int16 => serialize_list_values!(data_type, list_array, Int16Array, i16),
1375 DataType::Int32 => serialize_list_values!(data_type, list_array, Int32Array, i32),
1376 DataType::Int64 => serialize_list_values!(data_type, list_array, Int64Array, i64),
1377 DataType::UInt8 => serialize_list_values!(data_type, list_array, UInt8Array, u8),
1378 DataType::UInt16 => serialize_list_values!(data_type, list_array, UInt16Array, u16),
1379 DataType::UInt32 => serialize_list_values!(data_type, list_array, UInt32Array, u32),
1380 DataType::UInt64 => serialize_list_values!(data_type, list_array, UInt64Array, u64),
1381 DataType::Float32 => serialize_list_values!(data_type, list_array, Float32Array, f32),
1382 DataType::Float64 => serialize_list_values!(data_type, list_array, Float64Array, f64),
1383 DataType::Utf8 => serialize_list_values!(data_type, list_array, StringArray, String),
1384 DataType::LargeUtf8 => {
1385 serialize_list_values!(data_type, list_array, LargeStringArray, String)
1386 }
1387 DataType::Utf8View => {
1388 serialize_list_values!(data_type, list_array, StringViewArray, String)
1389 }
1390 DataType::Boolean => serialize_list_values!(data_type, list_array, BooleanArray, bool),
1391 _ => unimplemented!(
1392 "List to json conversion is not implemented for {}",
1393 list_type.data_type()
1394 ),
1395 };
1396
1397 let expr: SimpleExpr = Expr::value(json_string);
1398 row_values.push(expr);
1399
1400 Ok(())
1401}
1402
1403fn insert_struct_into_row_values_json(
1404 fields: &Fields,
1405 array: &StructArray,
1406 row_index: usize,
1407 row_values: &mut Vec<SimpleExpr>,
1408) -> Result<()> {
1409 let single_row_columns: Vec<ArrayRef> = (0..array.num_columns())
1413 .map(|i| array.column(i).slice(row_index, 1))
1414 .collect();
1415
1416 let batch = RecordBatch::try_new(Arc::new(Schema::new(fields.clone())), single_row_columns)
1417 .map_err(|e| Error::FailedToCreateInsertStatement {
1418 source: Box::new(e),
1419 })?;
1420
1421 let mut writer = datafusion::arrow::json::LineDelimitedWriter::new(Vec::new());
1422 writer
1423 .write(&batch)
1424 .map_err(|e| Error::FailedToCreateInsertStatement {
1425 source: Box::new(e),
1426 })?;
1427 writer
1428 .finish()
1429 .map_err(|e| Error::FailedToCreateInsertStatement {
1430 source: Box::new(e),
1431 })?;
1432 let json_bytes = writer.into_inner();
1433
1434 let json = String::from_utf8(json_bytes).map_err(|e| Error::FailedToCreateInsertStatement {
1435 source: Box::new(e),
1436 })?;
1437
1438 let expr: SimpleExpr = Expr::value(json);
1439 row_values.push(expr);
1440
1441 Ok(())
1442}
1443
#[cfg(test)]
mod tests {
    //! Unit tests for the table/insert/index SQL builders, asserting the
    //! exact rendered SQL strings.
    use std::sync::Arc;

    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema};

    // A plain three-column schema renders a CREATE TABLE with NOT NULL on
    // non-nullable fields only.
    #[test]
    fn test_basic_table_creation() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users").build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer )");
    }

    // Two batches are concatenated into a single multi-row VALUES list; the
    // column names come from the first batch even when the second batch's
    // field names differ.
    #[test]
    fn test_table_insertion() {
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let id_array = array::Int32Array::from(vec![1, 2, 3]);
        let name_array = array::StringArray::from(vec!["a", "b", "c"]);
        let age_array = array::Int32Array::from(vec![10, 20, 30]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1.clone()),
            vec![
                Arc::new(id_array.clone()),
                Arc::new(name_array.clone()),
                Arc::new(age_array.clone()),
            ],
        )
        .expect("Unable to build record batch");

        // Same data but the third field is named differently.
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("blah", DataType::Int32, true),
        ]);

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(age_array),
            ],
        )
        .expect("Unable to build record batch");
        let record_batches = vec![batch1, batch2];

        let sql = InsertBuilder::new(&TableReference::from("users"), record_batches)
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(sql, "INSERT INTO \"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    // A schema-qualified table reference renders as "schema"."table".
    #[test]
    fn test_table_insertion_with_schema() {
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let id_array = array::Int32Array::from(vec![1, 2, 3]);
        let name_array = array::StringArray::from(vec!["a", "b", "c"]);
        let age_array = array::Int32Array::from(vec![10, 20, 30]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1.clone()),
            vec![
                Arc::new(id_array.clone()),
                Arc::new(name_array.clone()),
                Arc::new(age_array.clone()),
            ],
        )
        .expect("Unable to build record batch");

        // Same data but the third field is named differently.
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("blah", DataType::Int32, true),
        ]);

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(age_array),
            ],
        )
        .expect("Unable to build record batch");
        let record_batches = vec![batch1, batch2];

        let sql = InsertBuilder::new(&TableReference::from("schema.users"), record_batches)
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(sql, "INSERT INTO \"schema\".\"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    // Composite primary keys render as a trailing PRIMARY KEY clause.
    #[test]
    fn test_table_creation_with_primary_keys() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("id2", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users")
            .primary_keys(vec!["id", "id2"])
            .build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"id2\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer, PRIMARY KEY (\"id\", \"id2\") )");
    }

    // The temporary(true) flag adds the TEMPORARY keyword.
    #[test]
    fn test_temporary_table_creation() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users")
            .primary_keys(vec!["id"])
            .temporary(true)
            .build_sqlite();

        assert_eq!(sql, "CREATE TEMPORARY TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"name\" text NOT NULL, PRIMARY KEY (\"id\") )");
    }

    // List columns render as ARRAY literals cast to the Postgres array type.
    #[test]
    fn test_table_insertion_with_list() {
        let schema1 = Schema::new(vec![Field::new(
            "list",
            DataType::List(Field::new("item", DataType::Int32, true).into()),
            true,
        )]);
        let list_array = array::ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5), Some(6)]),
            Some(vec![Some(7), Some(8), Some(9)]),
        ]);

        let batch = RecordBatch::try_new(Arc::new(schema1.clone()), vec![Arc::new(list_array)])
            .expect("Unable to build record batch");

        let sql = InsertBuilder::new(&TableReference::from("arrays"), vec![batch])
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(
            sql,
            "INSERT INTO \"arrays\" (\"list\") VALUES (CAST(ARRAY [1,2,3] AS int4[])), (CAST(ARRAY [4,5,6] AS int4[])), (CAST(ARRAY [7,8,9] AS int4[]))"
        );
    }

    // The index name is derived as i_<table>_<col>_<col>.
    #[test]
    fn test_create_index() {
        let sql = IndexBuilder::new("users", vec!["id", "name"]).build_postgres();
        assert_eq!(
            sql,
            r#"CREATE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#
        );
    }

    // unique() adds the UNIQUE keyword to the rendered statement.
    #[test]
    fn test_create_unique_index() {
        let sql = IndexBuilder::new("users", vec!["id", "name"])
            .unique()
            .build_postgres();
        assert_eq!(
            sql,
            r#"CREATE UNIQUE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#
        );
    }
}