1use bigdecimal::BigDecimal;
2use chrono::{DateTime, Offset, TimeZone};
3use datafusion::arrow::{
4 array::{
5 array, timezone::Tz, Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
6 Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray,
7 StringViewArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
8 },
9 datatypes::{DataType, Field, Fields, IntervalUnit, Schema, SchemaRef, TimeUnit},
10 util::display::array_value_to_string,
11};
12use datafusion::sql::TableReference;
13use num_bigint::BigInt;
14use sea_query::{
15 Alias, ColumnDef, ColumnType, Expr, GenericBuilder, Index, InsertStatement, IntoIden,
16 IntoIndexColumn, Keyword, MysqlQueryBuilder, OnConflict, PostgresQueryBuilder, Query,
17 QueryBuilder, SeaRc, SimpleExpr, SqliteQueryBuilder, Table, TableRef,
18};
19use snafu::Snafu;
20use std::{str::FromStr, sync::Arc};
21use time::{OffsetDateTime, PrimitiveDateTime};
22
/// Errors produced while turning Arrow record batches into SQL statements.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A value could not be converted into a SQL expression, or the insert
    /// statement itself could not be assembled (e.g. mismatched row arity).
    #[snafu(display("Failed to build insert statement: {source}"))]
    FailedToCreateInsertStatement {
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// The Arrow `DataType` of a column has no SQL mapping implemented in
    /// `construct_insert_stmt`.
    #[snafu(display("Unimplemented data type in insert statement: {data_type:?}"))]
    UnimplementedDataTypeInInsertStatement { data_type: DataType },
}
33
/// Module-local result type defaulting to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
35
/// Builds a `CREATE TABLE` statement from an Arrow schema.
pub struct CreateTableBuilder {
    // Arrow schema whose fields become the table's columns.
    schema: SchemaRef,
    // Unqualified name of the table to create.
    table_name: String,
    // Column names forming the composite primary key; empty = no PK.
    primary_keys: Vec<String>,
}
41
42impl CreateTableBuilder {
43 #[must_use]
44 pub fn new(schema: SchemaRef, table_name: &str) -> Self {
45 Self {
46 schema,
47 table_name: table_name.to_string(),
48 primary_keys: Vec::new(),
49 }
50 }
51
52 #[must_use]
53 pub fn primary_keys<T>(mut self, keys: Vec<T>) -> Self
54 where
55 T: Into<String>,
56 {
57 self.primary_keys = keys.into_iter().map(Into::into).collect();
58 self
59 }
60
61 #[must_use]
62 #[cfg(feature = "postgres")]
63 pub fn build_postgres(self) -> Vec<String> {
64 use crate::sql::arrow_sql_gen::postgres::{
65 builder::TypeBuilder, get_postgres_composite_type_name,
66 map_data_type_to_column_type_postgres,
67 };
68 let schema = Arc::clone(&self.schema);
69 let table_name = self.table_name.clone();
70 let main_table_creation =
71 self.build(PostgresQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
72 map_data_type_to_column_type_postgres(f.data_type(), &table_name, f.name())
73 });
74
75 let mut creation_stmts = Vec::new();
78 for field in schema.fields() {
79 let DataType::Struct(struct_inner_fields) = field.data_type() else {
80 continue;
81 };
82 let type_builder = TypeBuilder::new(
83 get_postgres_composite_type_name(&table_name, field.name()),
84 struct_inner_fields,
85 );
86 creation_stmts.push(type_builder.build());
87 }
88
89 creation_stmts.push(main_table_creation);
90 creation_stmts
91 }
92
93 #[must_use]
94 pub fn build_sqlite(self) -> String {
95 self.build(SqliteQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
96 if f.data_type().is_nested() {
99 return ColumnType::JsonBinary;
100 }
101
102 map_data_type_to_column_type(f.data_type())
103 })
104 }
105
106 #[must_use]
107 pub fn build_mysql(self) -> String {
108 self.build(MysqlQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
109 if f.data_type().is_nested() {
112 return ColumnType::JsonBinary;
113 }
114 map_data_type_to_column_type(f.data_type())
115 })
116 }
117
118 #[must_use]
119 fn build<T: GenericBuilder>(
120 self,
121 query_builder: T,
122 map_data_type_to_column_type_fn: &dyn Fn(&Arc<Field>) -> ColumnType,
123 ) -> String {
124 let mut create_stmt = Table::create();
125 create_stmt
126 .table(Alias::new(self.table_name.clone()))
127 .if_not_exists();
128
129 for field in self.schema.fields() {
130 let column_type = map_data_type_to_column_type_fn(field);
131 let mut column_def = ColumnDef::new_with_type(Alias::new(field.name()), column_type);
132 if !field.is_nullable() {
133 column_def.not_null();
134 }
135
136 create_stmt.col(&mut column_def);
137 }
138
139 if !self.primary_keys.is_empty() {
140 let mut index = Index::create();
141 index.primary();
142 for key in self.primary_keys {
143 index.col(Alias::new(key).into_iden().into_index_column());
144 }
145 create_stmt.primary_key(&mut index);
146 }
147
148 create_stmt.to_string(query_builder)
149 }
150}
151
/// Downcasts `$column` to the concrete Arrow `array::$array_type` and pushes
/// the value at `$row` into `$row_values` as a SQL expression.
///
/// A NULL cell is pushed as the SQL `NULL` keyword. NOTE: the macro expands a
/// `continue`, so it is only valid inside a loop (the per-column loop in
/// `construct_insert_stmt`). If the downcast fails, nothing is pushed.
macro_rules! push_value {
    ($row_values:expr, $column:expr, $row:expr, $array_type:ident) => {{
        let array = $column.as_any().downcast_ref::<array::$array_type>();
        if let Some(valid_array) = array {
            if valid_array.is_null($row) {
                $row_values.push(Keyword::Null.into());
                continue;
            }
            $row_values.push(valid_array.value($row).into());
        }
    }};
}
164
/// Collects every element of `$list_array` (an Arrow array of primitive
/// `$vec_type` values, concretely `$array_type`) into a `Vec`, converts it to
/// a `sea_query` expression, and pushes it onto `$row_values` with a cast to
/// the SQL array type named by `$sql_type` (e.g. `"int4[]"`).
///
/// If the downcast fails, an empty list is pushed (matching the original
/// behavior of silently skipping values).
macro_rules! push_list_values {
    ($data_type:expr, $list_array:expr, $row_values:expr, $array_type:ty, $vec_type:ty, $sql_type:expr) => {{
        let mut list_values: Vec<$vec_type> = Vec::new();
        // Downcast once: the target type is loop-invariant, so there is no
        // need to repeat the downcast per element.
        if let Some(valid_array) = $list_array.as_any().downcast_ref::<$array_type>() {
            list_values.reserve(valid_array.len());
            for i in 0..valid_array.len() {
                list_values.push(valid_array.value(i));
            }
        }
        let expr: SimpleExpr = list_values.into();
        $row_values.push(expr.cast_as(Alias::new($sql_type)));
    }};
}
179
/// Builds an `INSERT` statement from one or more Arrow record batches.
pub struct InsertBuilder {
    // Target table; may be bare, schema-qualified, or catalog-qualified.
    table: TableReference,
    // Batches whose rows become the VALUES tuples; all are assumed to share
    // the first batch's schema.
    record_batches: Vec<RecordBatch>,
}
184
185pub fn use_json_insert_for_type<T: QueryBuilder + 'static>(
186 data_type: &DataType,
187 query_builder: &T,
188) -> bool {
189 #[cfg(feature = "sqlite")]
190 {
191 use std::any::Any;
192 let any_builder = query_builder as &dyn Any;
193 if any_builder.is::<SqliteQueryBuilder>() {
194 return data_type.is_nested();
195 }
196 }
197 #[cfg(feature = "mysql")]
198 {
199 use std::any::Any;
200 let any_builder = query_builder as &dyn Any;
201 if any_builder.is::<MysqlQueryBuilder>() {
202 return data_type.is_nested();
203 }
204 }
205 false
206}
207
208impl InsertBuilder {
209 #[must_use]
210 pub fn new(table: &TableReference, record_batches: Vec<RecordBatch>) -> Self {
211 Self {
212 table: table.clone(),
213 record_batches,
214 }
215 }
216
    /// Appends one VALUES tuple per row of `record_batch` to `insert_stmt`,
    /// converting every Arrow cell into a `sea_query` expression.
    ///
    /// NULL cells become the SQL `NULL` keyword. Temporal types are converted
    /// through the `time`/`chrono` crates; lists become typed SQL arrays (or
    /// JSON where `use_json_insert_for_type` says so); structs become either
    /// JSON or a raw `ROW(...)` literal.
    ///
    /// # Errors
    ///
    /// Returns [`Error::FailedToCreateInsertStatement`] when a temporal value
    /// is out of range, a timezone string cannot be parsed, or the statement
    /// rejects the row; [`Error::UnimplementedDataTypeInInsertStatement`] for
    /// column types with no mapping.
    #[allow(clippy::too_many_lines)]
    pub fn construct_insert_stmt<T: QueryBuilder + 'static>(
        &self,
        insert_stmt: &mut InsertStatement,
        record_batch: &RecordBatch,
        query_builder: &T,
    ) -> Result<()> {
        for row in 0..record_batch.num_rows() {
            let mut row_values: Vec<SimpleExpr> = vec![];
            for col in 0..record_batch.num_columns() {
                let column = record_batch.column(col);
                let column_data_type = column.data_type();

                match column_data_type {
                    // Primitive and string types map directly via push_value!.
                    DataType::Int8 => push_value!(row_values, column, row, Int8Array),
                    DataType::Int16 => push_value!(row_values, column, row, Int16Array),
                    DataType::Int32 => push_value!(row_values, column, row, Int32Array),
                    DataType::Int64 => push_value!(row_values, column, row, Int64Array),
                    DataType::UInt8 => push_value!(row_values, column, row, UInt8Array),
                    DataType::UInt16 => push_value!(row_values, column, row, UInt16Array),
                    DataType::UInt32 => push_value!(row_values, column, row, UInt32Array),
                    DataType::UInt64 => push_value!(row_values, column, row, UInt64Array),
                    DataType::Float32 => push_value!(row_values, column, row, Float32Array),
                    DataType::Float64 => push_value!(row_values, column, row, Float64Array),
                    DataType::Utf8 => push_value!(row_values, column, row, StringArray),
                    DataType::LargeUtf8 => push_value!(row_values, column, row, LargeStringArray),
                    DataType::Utf8View => push_value!(row_values, column, row, StringViewArray),
                    DataType::Boolean => push_value!(row_values, column, row, BooleanArray),
                    DataType::Decimal128(_, scale) => {
                        let array = column.as_any().downcast_ref::<array::Decimal128Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            // The raw i128 is unscaled; reapply the Arrow scale.
                            row_values.push(
                                BigDecimal::new(valid_array.value(row).into(), i64::from(*scale))
                                    .into(),
                            );
                        }
                    }
                    DataType::Decimal256(_, scale) => {
                        let array = column.as_any().downcast_ref::<array::Decimal256Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            // Rebuild the 256-bit signed integer from its
                            // little-endian byte representation.
                            let bigint =
                                BigInt::from_signed_bytes_le(&valid_array.value(row).to_le_bytes());

                            row_values.push(BigDecimal::new(bigint, i64::from(*scale)).into());
                        }
                    }
                    DataType::Date32 => {
                        let array = column.as_any().downcast_ref::<array::Date32Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            // Date32 counts days since the Unix epoch; scale
                            // to seconds (86 400 per day) for OffsetDateTime.
                            row_values.push(
                                match OffsetDateTime::from_unix_timestamp(
                                    i64::from(valid_array.value(row)) * 86_400,
                                ) {
                                    Ok(offset_time) => offset_time.date().into(),
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                },
                            );
                        }
                    }
                    DataType::Date64 => {
                        let array = column.as_any().downcast_ref::<array::Date64Array>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            // Date64 is milliseconds since epoch; truncate to
                            // whole seconds before extracting the date.
                            row_values.push(
                                match OffsetDateTime::from_unix_timestamp(
                                    valid_array.value(row) / 1000,
                                ) {
                                    Ok(offset_time) => offset_time.date().into(),
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                },
                            );
                        }
                    }
                    // Durations insert the raw integer count in their unit.
                    DataType::Duration(time_unit) => match time_unit {
                        TimeUnit::Second => {
                            push_value!(row_values, column, row, DurationSecondArray);
                        }
                        TimeUnit::Microsecond => {
                            push_value!(row_values, column, row, DurationMicrosecondArray);
                        }
                        TimeUnit::Millisecond => {
                            push_value!(row_values, column, row, DurationMillisecondArray);
                        }
                        TimeUnit::Nanosecond => {
                            push_value!(row_values, column, row, DurationNanosecondArray);
                        }
                    },
                    DataType::Time32(time_unit) => match time_unit {
                        TimeUnit::Millisecond => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::Time32MillisecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                // Interpret the ms-of-day as an epoch-relative
                                // instant purely to split out h/m/s/µs.
                                let (h, m, s, micro) =
                                    match OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000_000,
                                    ) {
                                        Ok(timestamp) => timestamp.to_hms_micro(),
                                        Err(e) => {
                                            return Result::Err(
                                                Error::FailedToCreateInsertStatement {
                                                    source: Box::new(e),
                                                },
                                            )
                                        }
                                    };

                                let time = match time::Time::from_hms_micro(h, m, s, micro) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        TimeUnit::Second => {
                            let array = column.as_any().downcast_ref::<array::Time32SecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let (h, m, s) = match OffsetDateTime::from_unix_timestamp(
                                    i64::from(valid_array.value(row)),
                                ) {
                                    Ok(timestamp) => timestamp.to_hms(),
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                let time = match time::Time::from_hms(h, m, s) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        // Arrow only defines Second/Millisecond for Time32.
                        _ => unreachable!(),
                    },
                    DataType::Time64(time_unit) => match time_unit {
                        TimeUnit::Nanosecond => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::Time64NanosecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }
                                let (h, m, s, nano) =
                                    match OffsetDateTime::from_unix_timestamp_nanos(i128::from(
                                        valid_array.value(row),
                                    )) {
                                        Ok(timestamp) => timestamp.to_hms_nano(),
                                        Err(e) => {
                                            return Result::Err(
                                                Error::FailedToCreateInsertStatement {
                                                    source: Box::new(e),
                                                },
                                            )
                                        }
                                    };

                                let time = match time::Time::from_hms_nano(h, m, s, nano) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        TimeUnit::Microsecond => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::Time64MicrosecondArray>();
                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let (h, m, s, micro) =
                                    match OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000,
                                    ) {
                                        Ok(timestamp) => timestamp.to_hms_micro(),
                                        Err(e) => {
                                            return Result::Err(
                                                Error::FailedToCreateInsertStatement {
                                                    source: Box::new(e),
                                                },
                                            )
                                        }
                                    };

                                let time = match time::Time::from_hms_micro(h, m, s, micro) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        return Result::Err(Error::FailedToCreateInsertStatement {
                                            source: Box::new(e),
                                        })
                                    }
                                };

                                row_values.push(time.into());
                            }
                        }
                        // Arrow only defines Microsecond/Nanosecond for Time64.
                        _ => unreachable!(),
                    },
                    DataType::Timestamp(TimeUnit::Second, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampSecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                // NOTE(review): scaling seconds to nanoseconds
                                // in i64 can overflow for far-future values —
                                // confirm the expected timestamp range.
                                let utc_time = DateTime::from_timestamp_nanos(
                                    valid_array.value(row) * 1_000_000_000,
                                )
                                .to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                // Resolve the named zone to a fixed offset at
                                // this instant (handles DST).
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp(valid_array.value(row)),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    DataType::Timestamp(TimeUnit::Millisecond, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampMillisecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                let utc_time = DateTime::from_timestamp_nanos(
                                    valid_array.value(row) * 1_000_000,
                                )
                                .to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000_000,
                                    ),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampMicrosecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                let utc_time =
                                    DateTime::from_timestamp_nanos(valid_array.value(row) * 1_000)
                                        .to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp_nanos(
                                        i128::from(valid_array.value(row)) * 1_000,
                                    ),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    DataType::Timestamp(TimeUnit::Nanosecond, timezone) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::TimestampNanosecondArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            if let Some(timezone) = timezone {
                                let utc_time =
                                    DateTime::from_timestamp_nanos(valid_array.value(row)).to_utc();
                                let timezone = Tz::from_str(timezone).map_err(|_| {
                                    Error::FailedToCreateInsertStatement {
                                        source: "Unable to parse arrow timezone information".into(),
                                    }
                                })?;
                                let offset = timezone
                                    .offset_from_utc_datetime(&utc_time.naive_utc())
                                    .fix();
                                let time_with_offset = utc_time.with_timezone(&offset);
                                row_values.push(time_with_offset.into());
                            } else {
                                insert_timestamp_into_row_values(
                                    OffsetDateTime::from_unix_timestamp_nanos(i128::from(
                                        valid_array.value(row),
                                    )),
                                    &mut row_values,
                                )?;
                            }
                        }
                    }
                    // List variants: dialect decides between JSON and a typed
                    // SQL array literal.
                    DataType::List(list_type) => {
                        let array = column.as_any().downcast_ref::<array::ListArray>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            let list_array = valid_array.value(row);

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_list_into_row_values_json(
                                    list_array,
                                    list_type,
                                    &mut row_values,
                                )?;
                            } else {
                                insert_list_into_row_values(list_array, list_type, &mut row_values);
                            }
                        }
                    }
                    DataType::LargeList(list_type) => {
                        let array = column.as_any().downcast_ref::<array::LargeListArray>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            let list_array = valid_array.value(row);

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_list_into_row_values_json(
                                    list_array,
                                    list_type,
                                    &mut row_values,
                                )?;
                            } else {
                                insert_list_into_row_values(list_array, list_type, &mut row_values);
                            }
                        }
                    }
                    DataType::FixedSizeList(list_type, _) => {
                        let array = column.as_any().downcast_ref::<array::FixedSizeListArray>();
                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }
                            let list_array = valid_array.value(row);

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_list_into_row_values_json(
                                    list_array,
                                    list_type,
                                    &mut row_values,
                                )?;
                            } else {
                                insert_list_into_row_values(list_array, list_type, &mut row_values);
                            }
                        }
                    }
                    // Binary variants: the byte slice converts directly.
                    DataType::Binary => {
                        let array = column.as_any().downcast_ref::<array::BinaryArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    DataType::LargeBinary => {
                        let array = column.as_any().downcast_ref::<array::LargeBinaryArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    DataType::FixedSizeBinary(_) => {
                        let array = column
                            .as_any()
                            .downcast_ref::<array::FixedSizeBinaryArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    DataType::BinaryView => {
                        let array = column.as_any().downcast_ref::<array::BinaryViewArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            row_values.push(valid_array.value(row).into());
                        }
                    }
                    // Intervals insert their Arrow display string, with a
                    // hand-built fallback when display formatting fails.
                    DataType::Interval(interval_unit) => match interval_unit {
                        IntervalUnit::DayTime => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::IntervalDayTimeArray>();

                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let interval_str =
                                    if let Ok(str) = array_value_to_string(valid_array, row) {
                                        str
                                    } else {
                                        let days = valid_array.value(row).days;
                                        let milliseconds = valid_array.value(row).milliseconds;
                                        format!("{days} days {milliseconds} milliseconds")
                                    };

                                row_values.push(interval_str.into());
                            }
                        }
                        IntervalUnit::YearMonth => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::IntervalYearMonthArray>();

                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let interval_str =
                                    if let Ok(str) = array_value_to_string(valid_array, row) {
                                        str
                                    } else {
                                        let months = valid_array.value(row);
                                        format!("{months} months")
                                    };

                                row_values.push(interval_str.into());
                            }
                        }
                        IntervalUnit::MonthDayNano => {
                            let array = column
                                .as_any()
                                .downcast_ref::<array::IntervalMonthDayNanoArray>();

                            if let Some(valid_array) = array {
                                if valid_array.is_null(row) {
                                    row_values.push(Keyword::Null.into());
                                    continue;
                                }

                                let interval_str =
                                    if let Ok(str) = array_value_to_string(valid_array, row) {
                                        str
                                    } else {
                                        let months = valid_array.value(row).months;
                                        let days = valid_array.value(row).days;
                                        let nanoseconds = valid_array.value(row).nanoseconds;
                                        // Fallback renders at microsecond
                                        // precision (sub-µs truncated).
                                        let micros = nanoseconds / 1_000;
                                        format!("{months} months {days} days {micros} microseconds")
                                    };

                                row_values.push(interval_str.into());
                            }
                        }
                    },
                    DataType::Struct(fields) => {
                        let array = column.as_any().downcast_ref::<array::StructArray>();

                        if let Some(valid_array) = array {
                            if valid_array.is_null(row) {
                                row_values.push(Keyword::Null.into());
                                continue;
                            }

                            if use_json_insert_for_type(column_data_type, query_builder) {
                                insert_struct_into_row_values_json(
                                    fields,
                                    valid_array,
                                    row,
                                    &mut row_values,
                                )?;
                                continue;
                            }

                            // Non-JSON path: render the struct as a ROW(...)
                            // literal from its per-field expressions.
                            // NOTE(review): inner values are read without
                            // per-field is_null checks (only DataType::Null
                            // fields produce SQL NULL) — confirm inner
                            // columns are non-nullable in practice.
                            let mut param_values: Vec<SimpleExpr> = vec![];

                            for col in valid_array.columns() {
                                match col.data_type() {
                                    DataType::Int8 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int8Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Int16 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int16Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Int32 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int32Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Int64 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::Int64Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt8 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt8Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt16 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt16Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt32 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt32Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::UInt64 => {
                                        let int_array =
                                            col.as_any().downcast_ref::<array::UInt64Array>();

                                        if let Some(valid_int_array) = int_array {
                                            param_values.push(valid_int_array.value(row).into());
                                        }
                                    }
                                    DataType::Float32 => {
                                        let float_array =
                                            col.as_any().downcast_ref::<array::Float32Array>();

                                        if let Some(valid_float_array) = float_array {
                                            param_values.push(valid_float_array.value(row).into());
                                        }
                                    }
                                    DataType::Float64 => {
                                        let float_array =
                                            col.as_any().downcast_ref::<array::Float64Array>();

                                        if let Some(valid_float_array) = float_array {
                                            param_values.push(valid_float_array.value(row).into());
                                        }
                                    }
                                    DataType::Utf8 => {
                                        let string_array =
                                            col.as_any().downcast_ref::<array::StringArray>();

                                        if let Some(valid_string_array) = string_array {
                                            param_values.push(valid_string_array.value(row).into());
                                        }
                                    }
                                    DataType::Null => {
                                        param_values.push(Keyword::Null.into());
                                    }
                                    DataType::Boolean => {
                                        let bool_array =
                                            col.as_any().downcast_ref::<array::BooleanArray>();

                                        if let Some(valid_bool_array) = bool_array {
                                            param_values.push(valid_bool_array.value(row).into());
                                        }
                                    }
                                    DataType::Binary => {
                                        let binary_array =
                                            col.as_any().downcast_ref::<array::BinaryArray>();

                                        if let Some(valid_binary_array) = binary_array {
                                            param_values.push(valid_binary_array.value(row).into());
                                        }
                                    }
                                    DataType::FixedSizeBinary(_) => {
                                        let binary_array = col
                                            .as_any()
                                            .downcast_ref::<array::FixedSizeBinaryArray>();

                                        if let Some(valid_binary_array) = binary_array {
                                            param_values.push(valid_binary_array.value(row).into());
                                        }
                                    }
                                    DataType::LargeBinary => {
                                        let binary_array =
                                            col.as_any().downcast_ref::<array::LargeBinaryArray>();

                                        if let Some(valid_binary_array) = binary_array {
                                            param_values.push(valid_binary_array.value(row).into());
                                        }
                                    }
                                    DataType::LargeUtf8 => {
                                        let string_array =
                                            col.as_any().downcast_ref::<array::LargeStringArray>();

                                        if let Some(valid_string_array) = string_array {
                                            param_values.push(valid_string_array.value(row).into());
                                        }
                                    }
                                    DataType::Utf8View => {
                                        let view_array =
                                            col.as_any().downcast_ref::<array::StringViewArray>();

                                        if let Some(valid_view_array) = view_array {
                                            param_values.push(valid_view_array.value(row).into());
                                        }
                                    }
                                    DataType::BinaryView => {
                                        let view_array =
                                            col.as_any().downcast_ref::<array::BinaryViewArray>();

                                        if let Some(valid_view_array) = view_array {
                                            param_values.push(valid_view_array.value(row).into());
                                        }
                                    }
                                    DataType::Float16
                                    | DataType::Timestamp(_, _)
                                    | DataType::Date32
                                    | DataType::Date64
                                    | DataType::Time32(_)
                                    | DataType::Time64(_)
                                    | DataType::Duration(_)
                                    | DataType::Interval(_)
                                    | DataType::List(_)
                                    | DataType::ListView(_)
                                    | DataType::FixedSizeList(_, _)
                                    | DataType::LargeList(_)
                                    | DataType::LargeListView(_)
                                    | DataType::Struct(_)
                                    | DataType::Union(_, _)
                                    | DataType::Dictionary(_, _)
                                    | DataType::Map(_, _)
                                    | DataType::RunEndEncoded(_, _)
                                    | DataType::Decimal128(_, _)
                                    | DataType::Decimal256(_, _) => {
                                        unimplemented!(
                                            "Data type mapping not implemented for Struct of {}",
                                            col.data_type()
                                        )
                                    }
                                }
                            }

                            // Render each inner expression with the dialect's
                            // builder and splice them into a ROW(...) literal.
                            let mut params_vec = Vec::new();
                            for param_value in &param_values {
                                let mut params_str = String::new();
                                query_builder.prepare_simple_expr(param_value, &mut params_str);
                                params_vec.push(params_str);
                            }

                            let params = params_vec.join(", ");
                            row_values.push(Expr::cust(format!("ROW({params})")));
                        }
                    }
                    unimplemented_type => {
                        return Result::Err(Error::UnimplementedDataTypeInInsertStatement {
                            data_type: unimplemented_type.clone(),
                        })
                    }
                }
            }
            match insert_stmt.values(row_values) {
                Ok(_) => (),
                Err(e) => {
                    return Result::Err(Error::FailedToCreateInsertStatement {
                        source: Box::new(e),
                    })
                }
            }
        }
        Ok(())
    }
1024
    /// Renders the INSERT statement with PostgreSQL syntax.
    ///
    /// # Errors
    ///
    /// Returns an error if the statement cannot be constructed from the
    /// record batches.
    pub fn build_postgres(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(PostgresQueryBuilder, on_conflict)
    }

    /// Renders the INSERT statement with SQLite syntax.
    ///
    /// # Errors
    ///
    /// Returns an error if the statement cannot be constructed from the
    /// record batches.
    pub fn build_sqlite(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(SqliteQueryBuilder, on_conflict)
    }

    /// Renders the INSERT statement with MySQL syntax.
    ///
    /// # Errors
    ///
    /// Returns an error if the statement cannot be constructed from the
    /// record batches.
    pub fn build_mysql(self, on_conflict: Option<OnConflict>) -> Result<String> {
        self.build(MysqlQueryBuilder, on_conflict)
    }
1048
1049 pub fn build<T: GenericBuilder + 'static>(
1054 &self,
1055 query_builder: T,
1056 on_conflict: Option<OnConflict>,
1057 ) -> Result<String> {
1058 let columns: Vec<Alias> = (self.record_batches[0])
1059 .schema()
1060 .fields()
1061 .iter()
1062 .map(|field| Alias::new(field.name()))
1063 .collect();
1064
1065 let mut insert_stmt = Query::insert()
1066 .into_table(table_reference_to_sea_table_ref(&self.table))
1067 .columns(columns)
1068 .to_owned();
1069
1070 for record_batch in &self.record_batches {
1071 self.construct_insert_stmt(&mut insert_stmt, record_batch, &query_builder)?;
1072 }
1073 if let Some(on_conflict) = on_conflict {
1074 insert_stmt.on_conflict(on_conflict);
1075 }
1076 Ok(insert_stmt.to_string(query_builder))
1077 }
1078}
1079
1080fn table_reference_to_sea_table_ref(table: &TableReference) -> TableRef {
1081 match table {
1082 TableReference::Bare { table } => {
1083 TableRef::Table(SeaRc::new(Alias::new(table.to_string())))
1084 }
1085 TableReference::Partial { schema, table } => TableRef::SchemaTable(
1086 SeaRc::new(Alias::new(schema.to_string())),
1087 SeaRc::new(Alias::new(table.to_string())),
1088 ),
1089 TableReference::Full {
1090 catalog,
1091 schema,
1092 table,
1093 } => TableRef::DatabaseSchemaTable(
1094 SeaRc::new(Alias::new(catalog.to_string())),
1095 SeaRc::new(Alias::new(schema.to_string())),
1096 SeaRc::new(Alias::new(table.to_string())),
1097 ),
1098 }
1099}
1100
/// Builds a `CREATE INDEX` statement for a table.
pub struct IndexBuilder {
    // Table the index is created on.
    table_name: String,
    // Columns covered by the index, in index order.
    columns: Vec<String>,
    // Whether to emit CREATE UNIQUE INDEX.
    unique: bool,
}
1106
1107impl IndexBuilder {
1108 #[must_use]
1109 pub fn new(table_name: &str, columns: Vec<&str>) -> Self {
1110 Self {
1111 table_name: table_name.to_string(),
1112 columns: columns.into_iter().map(ToString::to_string).collect(),
1113 unique: false,
1114 }
1115 }
1116
1117 #[must_use]
1118 pub fn unique(mut self) -> Self {
1119 self.unique = true;
1120 self
1121 }
1122
1123 #[must_use]
1124 pub fn index_name(&self) -> String {
1125 format!("i_{}_{}", self.table_name, self.columns.join("_"))
1126 }
1127
1128 #[must_use]
1129 pub fn build_postgres(self) -> String {
1130 self.build(PostgresQueryBuilder)
1131 }
1132
1133 #[must_use]
1134 pub fn build_sqlite(self) -> String {
1135 self.build(SqliteQueryBuilder)
1136 }
1137
1138 #[must_use]
1139 pub fn build_mysql(self) -> String {
1140 self.build(MysqlQueryBuilder)
1141 }
1142
1143 #[must_use]
1144 pub fn build<T: GenericBuilder>(self, query_builder: T) -> String {
1145 let mut index = Index::create();
1146 index.table(Alias::new(&self.table_name));
1147 index.name(self.index_name());
1148 if self.unique {
1149 index.unique();
1150 }
1151 for column in self.columns {
1152 index.col(Alias::new(column).into_iden().into_index_column());
1153 }
1154 index.if_not_exists();
1155 index.to_string(query_builder)
1156 }
1157}
1158
1159fn insert_timestamp_into_row_values(
1160 timestamp: Result<OffsetDateTime, time::error::ComponentRange>,
1161 row_values: &mut Vec<SimpleExpr>,
1162) -> Result<()> {
1163 match timestamp {
1164 Ok(offset_time) => {
1165 row_values.push(PrimitiveDateTime::new(offset_time.date(), offset_time.time()).into());
1166 Ok(())
1167 }
1168 Err(e) => Err(Error::FailedToCreateInsertStatement {
1169 source: Box::new(e),
1170 }),
1171 }
1172}
1173
1174#[allow(clippy::needless_pass_by_value)]
1175fn insert_list_into_row_values(
1176 list_array: Arc<dyn Array>,
1177 list_type: &Arc<Field>,
1178 row_values: &mut Vec<SimpleExpr>,
1179) {
1180 match list_type.data_type() {
1181 DataType::Int8 => push_list_values!(
1182 list_type.data_type(),
1183 list_array,
1184 row_values,
1185 array::Int8Array,
1186 i8,
1187 "int2[]"
1188 ),
1189 DataType::Int16 => push_list_values!(
1190 list_type.data_type(),
1191 list_array,
1192 row_values,
1193 array::Int16Array,
1194 i16,
1195 "int2[]"
1196 ),
1197 DataType::Int32 => push_list_values!(
1198 list_type.data_type(),
1199 list_array,
1200 row_values,
1201 array::Int32Array,
1202 i32,
1203 "int4[]"
1204 ),
1205 DataType::Int64 => push_list_values!(
1206 list_type.data_type(),
1207 list_array,
1208 row_values,
1209 array::Int64Array,
1210 i64,
1211 "int8[]"
1212 ),
1213 DataType::Float32 => push_list_values!(
1214 list_type.data_type(),
1215 list_array,
1216 row_values,
1217 array::Float32Array,
1218 f32,
1219 "float4[]"
1220 ),
1221 DataType::Float64 => push_list_values!(
1222 list_type.data_type(),
1223 list_array,
1224 row_values,
1225 array::Float64Array,
1226 f64,
1227 "float8[]"
1228 ),
1229 DataType::Utf8 => {
1230 let mut list_values: Vec<String> = vec![];
1231 for i in 0..list_array.len() {
1232 let int_array = list_array.as_any().downcast_ref::<array::StringArray>();
1233 if let Some(valid_int_array) = int_array {
1234 list_values.push(valid_int_array.value(i).to_string());
1235 }
1236 }
1237 let expr: SimpleExpr = list_values.into();
1238 row_values.push(expr.cast_as(Alias::new("text[]")));
1240 }
1241 DataType::LargeUtf8 => {
1242 let mut list_values: Vec<String> = vec![];
1243 for i in 0..list_array.len() {
1244 let int_array = list_array
1245 .as_any()
1246 .downcast_ref::<array::LargeStringArray>();
1247 if let Some(valid_int_array) = int_array {
1248 list_values.push(valid_int_array.value(i).to_string());
1249 }
1250 }
1251 let expr: SimpleExpr = list_values.into();
1252 row_values.push(expr.cast_as(Alias::new("text[]")));
1254 }
1255 DataType::Utf8View => {
1256 let mut list_values: Vec<String> = vec![];
1257 for i in 0..list_array.len() {
1258 let view_array = list_array.as_any().downcast_ref::<array::StringViewArray>();
1259 if let Some(valid_view_array) = view_array {
1260 list_values.push(valid_view_array.value(i).to_string());
1261 }
1262 }
1263 let expr: SimpleExpr = list_values.into();
1264 row_values.push(expr.cast_as(Alias::new("text[]")));
1265 }
1266 DataType::Boolean => push_list_values!(
1267 list_type.data_type(),
1268 list_array,
1269 row_values,
1270 array::BooleanArray,
1271 bool,
1272 "boolean[]"
1273 ),
1274 DataType::Binary => {
1275 let mut list_values: Vec<Vec<u8>> = Vec::new();
1276 for i in 0..list_array.len() {
1277 let temp_array = list_array.as_any().downcast_ref::<array::BinaryArray>();
1278 if let Some(valid_array) = temp_array {
1279 list_values.push(valid_array.value(i).to_vec());
1280 }
1281 }
1282 let expr: SimpleExpr = list_values.into();
1283 row_values.push(expr.cast_as(Alias::new("bytea[]")));
1285 }
1286 _ => unimplemented!(
1287 "Data type mapping not implemented for {}",
1288 list_type.data_type()
1289 ),
1290 }
1291}
1292
/// Maps an Arrow `DataType` to the closest sea-query `ColumnType` for DDL
/// generation. Panics via `unimplemented!` on unmapped types.
#[allow(clippy::cast_sign_loss)]
pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType {
    match data_type {
        DataType::Int8 => ColumnType::TinyInteger,
        DataType::Int16 => ColumnType::SmallInteger,
        DataType::Int32 => ColumnType::Integer,
        // Durations are stored as their raw 64-bit tick count.
        DataType::Int64 | DataType::Duration(_) => ColumnType::BigInteger,
        DataType::UInt8 => ColumnType::TinyUnsigned,
        DataType::UInt16 => ColumnType::SmallUnsigned,
        DataType::UInt32 => ColumnType::Unsigned,
        DataType::UInt64 => ColumnType::BigUnsigned,
        DataType::Float32 => ColumnType::Float,
        DataType::Float64 => ColumnType::Double,
        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => ColumnType::Text,
        DataType::Boolean => ColumnType::Boolean,
        // NOTE(review): Arrow scale is a signed i8; `*s as u32` wraps for a
        // negative scale — presumably negative scales never reach here, but
        // worth confirming upstream.
        #[allow(clippy::cast_sign_loss)] DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => {
            ColumnType::Decimal(Some((u32::from(*p), *s as u32)))
        }
        DataType::Timestamp(_unit, time_zone) => {
            // Presence of a timezone string selects TIMESTAMPTZ; the time
            // unit itself is not encoded in the column type.
            if time_zone.is_some() {
                return ColumnType::TimestampWithTimeZone;
            }
            ColumnType::Timestamp
        }
        DataType::Date32 | DataType::Date64 => ColumnType::Date,
        DataType::Time64(_unit) | DataType::Time32(_unit) => ColumnType::Time,
        // All list flavors map recursively to an array of the element type.
        DataType::List(list_type)
        | DataType::LargeList(list_type)
        | DataType::FixedSizeList(list_type, _) => {
            ColumnType::Array(map_data_type_to_column_type(list_type.data_type()).into())
        }
        DataType::Binary | DataType::LargeBinary => ColumnType::Blob,
        DataType::FixedSizeBinary(num_bytes) => ColumnType::Binary(num_bytes.to_owned() as u32),
        DataType::Interval(_) => ColumnType::Interval(None, None),
        _ => unimplemented!("Data type mapping not implemented for {:?}", data_type),
    }
}
1336
/// Downcasts `$list_array` to `$array_type`, collects its values into a
/// `Vec<$vec_type>`, and serializes that vector to a JSON string.
/// A failed downcast yields `"[]"`; serde errors become
/// `Error::FailedToCreateInsertStatement` via `?` (so the surrounding
/// function must return `Result`).
macro_rules! serialize_list_values {
    ($data_type:expr, $list_array:expr, $array_type:ty, $vec_type:ty) => {{
        let mut list_values: Vec<$vec_type> = vec![];
        if let Some(array) = $list_array.as_any().downcast_ref::<$array_type>() {
            for i in 0..array.len() {
                list_values.push(array.value(i).into());
            }
        }

        serde_json::to_string(&list_values).map_err(|e| Error::FailedToCreateInsertStatement {
            source: Box::new(e),
        })?
    }};
}
1351
1352fn insert_list_into_row_values_json(
1353 list_array: Arc<dyn Array>,
1354 list_type: &Arc<Field>,
1355 row_values: &mut Vec<SimpleExpr>,
1356) -> Result<()> {
1357 let data_type = list_type.data_type();
1358
1359 let json_string: String = match data_type {
1360 DataType::Int8 => serialize_list_values!(data_type, list_array, Int8Array, i8),
1361 DataType::Int16 => serialize_list_values!(data_type, list_array, Int16Array, i16),
1362 DataType::Int32 => serialize_list_values!(data_type, list_array, Int32Array, i32),
1363 DataType::Int64 => serialize_list_values!(data_type, list_array, Int64Array, i64),
1364 DataType::UInt8 => serialize_list_values!(data_type, list_array, UInt8Array, u8),
1365 DataType::UInt16 => serialize_list_values!(data_type, list_array, UInt16Array, u16),
1366 DataType::UInt32 => serialize_list_values!(data_type, list_array, UInt32Array, u32),
1367 DataType::UInt64 => serialize_list_values!(data_type, list_array, UInt64Array, u64),
1368 DataType::Float32 => serialize_list_values!(data_type, list_array, Float32Array, f32),
1369 DataType::Float64 => serialize_list_values!(data_type, list_array, Float64Array, f64),
1370 DataType::Utf8 => serialize_list_values!(data_type, list_array, StringArray, String),
1371 DataType::LargeUtf8 => {
1372 serialize_list_values!(data_type, list_array, LargeStringArray, String)
1373 }
1374 DataType::Utf8View => {
1375 serialize_list_values!(data_type, list_array, StringViewArray, String)
1376 }
1377 DataType::Boolean => serialize_list_values!(data_type, list_array, BooleanArray, bool),
1378 _ => unimplemented!(
1379 "List to json conversion is not implemented for {}",
1380 list_type.data_type()
1381 ),
1382 };
1383
1384 let expr: SimpleExpr = Expr::value(json_string);
1385 row_values.push(expr);
1386
1387 Ok(())
1388}
1389
1390fn insert_struct_into_row_values_json(
1391 fields: &Fields,
1392 array: &StructArray,
1393 row_index: usize,
1394 row_values: &mut Vec<SimpleExpr>,
1395) -> Result<()> {
1396 let single_row_columns: Vec<ArrayRef> = (0..array.num_columns())
1400 .map(|i| array.column(i).slice(row_index, 1))
1401 .collect();
1402
1403 let batch = RecordBatch::try_new(Arc::new(Schema::new(fields.clone())), single_row_columns)
1404 .map_err(|e| Error::FailedToCreateInsertStatement {
1405 source: Box::new(e),
1406 })?;
1407
1408 let mut writer = datafusion::arrow::json::LineDelimitedWriter::new(Vec::new());
1409 writer
1410 .write(&batch)
1411 .map_err(|e| Error::FailedToCreateInsertStatement {
1412 source: Box::new(e),
1413 })?;
1414 writer
1415 .finish()
1416 .map_err(|e| Error::FailedToCreateInsertStatement {
1417 source: Box::new(e),
1418 })?;
1419 let json_bytes = writer.into_inner();
1420
1421 let json = String::from_utf8(json_bytes).map_err(|e| Error::FailedToCreateInsertStatement {
1422 source: Box::new(e),
1423 })?;
1424
1425 let expr: SimpleExpr = Expr::value(json);
1426 row_values.push(expr);
1427
1428 Ok(())
1429}
1430
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema};

    /// CREATE TABLE: nullable columns omit NOT NULL; non-nullable carry it.
    #[test]
    fn test_basic_table_creation() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users").build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer )");
    }

    /// INSERT concatenates rows across multiple batches; column names come
    /// from the FIRST batch's schema even when a later batch differs.
    #[test]
    fn test_table_insertion() {
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let id_array = array::Int32Array::from(vec![1, 2, 3]);
        let name_array = array::StringArray::from(vec!["a", "b", "c"]);
        let age_array = array::Int32Array::from(vec![10, 20, 30]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1.clone()),
            vec![
                Arc::new(id_array.clone()),
                Arc::new(name_array.clone()),
                Arc::new(age_array.clone()),
            ],
        )
        .expect("Unable to build record batch");

        // Second batch renames the last column to exercise mixed schemas.
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("blah", DataType::Int32, true),
        ]);

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(age_array),
            ],
        )
        .expect("Unable to build record batch");
        let record_batches = vec![batch1, batch2];

        let sql = InsertBuilder::new(&TableReference::from("users"), record_batches)
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(sql, "INSERT INTO \"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    /// A schema-qualified TableReference renders as "schema"."table".
    #[test]
    fn test_table_insertion_with_schema() {
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let id_array = array::Int32Array::from(vec![1, 2, 3]);
        let name_array = array::StringArray::from(vec!["a", "b", "c"]);
        let age_array = array::Int32Array::from(vec![10, 20, 30]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1.clone()),
            vec![
                Arc::new(id_array.clone()),
                Arc::new(name_array.clone()),
                Arc::new(age_array.clone()),
            ],
        )
        .expect("Unable to build record batch");

        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("blah", DataType::Int32, true),
        ]);

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(age_array),
            ],
        )
        .expect("Unable to build record batch");
        let record_batches = vec![batch1, batch2];

        let sql = InsertBuilder::new(&TableReference::from("schema.users"), record_batches)
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(sql, "INSERT INTO \"schema\".\"users\" (\"id\", \"name\", \"age\") VALUES (1, 'a', 10), (2, 'b', 20), (3, 'c', 30), (1, 'a', 10), (2, 'b', 20), (3, 'c', 30)");
    }

    /// Composite primary keys are appended as a single PRIMARY KEY clause.
    #[test]
    fn test_table_creation_with_primary_keys() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("id2", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int32, true),
        ]);
        let sql = CreateTableBuilder::new(SchemaRef::new(schema), "users")
            .primary_keys(vec!["id", "id2"])
            .build_sqlite();

        assert_eq!(sql, "CREATE TABLE IF NOT EXISTS \"users\" ( \"id\" integer NOT NULL, \"id2\" integer NOT NULL, \"name\" text NOT NULL, \"age\" integer, PRIMARY KEY (\"id\", \"id2\") )");
    }

    /// List columns are rendered as ARRAY literals CAST to the Postgres
    /// element array type (int4[]).
    #[test]
    fn test_table_insertion_with_list() {
        let schema1 = Schema::new(vec![Field::new(
            "list",
            DataType::List(Field::new("item", DataType::Int32, true).into()),
            true,
        )]);
        let list_array = array::ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5), Some(6)]),
            Some(vec![Some(7), Some(8), Some(9)]),
        ]);

        let batch = RecordBatch::try_new(Arc::new(schema1.clone()), vec![Arc::new(list_array)])
            .expect("Unable to build record batch");

        let sql = InsertBuilder::new(&TableReference::from("arrays"), vec![batch])
            .build_postgres(None)
            .expect("Failed to build insert statement");
        assert_eq!(
            sql,
            "INSERT INTO \"arrays\" (\"list\") VALUES (CAST(ARRAY [1,2,3] AS int4[])), (CAST(ARRAY [4,5,6] AS int4[])), (CAST(ARRAY [7,8,9] AS int4[]))"
        );
    }

    /// Index name is derived as i_<table>_<columns joined by underscores>.
    #[test]
    fn test_create_index() {
        let sql = IndexBuilder::new("users", vec!["id", "name"]).build_postgres();
        assert_eq!(
            sql,
            r#"CREATE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#
        );
    }

    /// unique() flips the statement to CREATE UNIQUE INDEX.
    #[test]
    fn test_create_unique_index() {
        let sql = IndexBuilder::new("users", vec!["id", "name"])
            .unique()
            .build_postgres();
        assert_eq!(
            sql,
            r#"CREATE UNIQUE INDEX IF NOT EXISTS "i_users_id_name" ON "users" ("id", "name")"#
        );
    }
}