1use std::convert;
2use std::io::Read;
3use std::sync::Arc;
4
5use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional;
6use crate::sql::arrow_sql_gen::statement::map_data_type_to_column_type;
7use arrow::array::{
8 ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
9 FixedSizeListBuilder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
10 Int8Builder, IntervalMonthDayNanoBuilder, LargeBinaryBuilder, LargeStringBuilder, ListBuilder,
11 RecordBatch, RecordBatchOptions, StringBuilder, StringDictionaryBuilder, StructBuilder,
12 Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder,
13};
14use arrow::datatypes::{
15 DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema,
16 SchemaRef, TimeUnit,
17};
18use bigdecimal::num_bigint::BigInt;
19use bigdecimal::num_bigint::Sign;
20use bigdecimal::BigDecimal;
21use bigdecimal::ToPrimitive;
22use byteorder::{BigEndian, ReadBytesExt};
23use chrono::{DateTime, Timelike, Utc};
24use composite::CompositeType;
25use geo_types::geometry::Point;
26use sea_query::{Alias, ColumnType, SeaRc};
27use serde_json::Value;
28use snafu::prelude::*;
29use std::time::{SystemTime, UNIX_EPOCH};
30use tokio_postgres::types::FromSql;
31use tokio_postgres::types::Kind;
32use tokio_postgres::{types::Type, Row};
33
34pub mod builder;
35pub mod composite;
36pub mod schema;
37
38#[derive(Debug, Snafu)]
39pub enum Error {
40 #[snafu(display("Failed to build record batch: {source}"))]
41 FailedToBuildRecordBatch {
42 source: datafusion::arrow::error::ArrowError,
43 },
44
45 #[snafu(display("No builder found for index {index}"))]
46 NoBuilderForIndex { index: usize },
47
48 #[snafu(display("Failed to downcast builder for {postgres_type}"))]
49 FailedToDowncastBuilder { postgres_type: String },
50
51 #[snafu(display("Integer overflow when converting u64 to i64: {source}"))]
52 FailedToConvertU64toI64 {
53 source: <u64 as convert::TryInto<i64>>::Error,
54 },
55
56 #[snafu(display("Integer overflow when converting u128 to i64: {source}"))]
57 FailedToConvertU128toI64 {
58 source: <u128 as convert::TryInto<i64>>::Error,
59 },
60
61 #[snafu(display("Failed to get a row value for {pg_type}: {source}"))]
62 FailedToGetRowValue {
63 pg_type: Type,
64 source: tokio_postgres::Error,
65 },
66
67 #[snafu(display("Failed to get a composite row value for {pg_type}: {source}"))]
68 FailedToGetCompositeRowValue {
69 pg_type: Type,
70 source: composite::Error,
71 },
72
73 #[snafu(display("Failed to parse raw Postgres Bytes as BigDecimal: {:?}", bytes))]
74 FailedToParseBigDecimalFromPostgres { bytes: Vec<u8> },
75
76 #[snafu(display("Cannot represent BigDecimal as i128: {big_decimal}"))]
77 FailedToConvertBigDecimalToI128 { big_decimal: BigDecimal },
78
79 #[snafu(display("Failed to find field {column_name} in schema"))]
80 FailedToFindFieldInSchema { column_name: String },
81
82 #[snafu(display("No Arrow field found for index {index}"))]
83 NoArrowFieldForIndex { index: usize },
84
85 #[snafu(display("No PostgreSQL scale found for index {index}"))]
86 NoPostgresScaleForIndex { index: usize },
87
88 #[snafu(display("No column name for index: {index}"))]
89 NoColumnNameForIndex { index: usize },
90
91 #[snafu(display("The field '{field_name}' has an unsupported data type: {data_type}."))]
92 UnsupportedDataType {
93 data_type: String,
94 field_name: String,
95 },
96}
97
98pub type Result<T, E = Error> = std::result::Result<T, E>;
99
/// Appends one nullable scalar from `$row` at column `$index` to a typed column builder.
///
/// `$builder` is the column's `&mut Option<Box<dyn ArrayBuilder>>` slot. The macro downcasts it
/// to `$builder_ty` and decodes the cell as `Option<$value_ty>`. A SQL NULL becomes an Arrow
/// null. On a missing builder, a wrong builder type, or a decode failure it early-returns an
/// error from the enclosing function.
macro_rules! handle_primitive_type {
    ($builder:expr, $type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = match $builder {
            Some(boxed_builder) => boxed_builder.as_any_mut().downcast_mut::<$builder_ty>(),
            None => return NoBuilderForIndexSnafu { index: $index }.fail(),
        };
        let Some(typed_builder) = typed_builder else {
            return FailedToDowncastBuilderSnafu {
                postgres_type: format!("{:?}", $type),
            }
            .fail();
        };

        let cell: Option<$value_ty> = $row
            .try_get($index)
            .context(FailedToGetRowValueSnafu { pg_type: $type })?;

        if let Some(value) = cell {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
121
/// Appends one nullable PostgreSQL array from `$row` at column `$i` to a `ListBuilder` column.
///
/// The cell is decoded as `Option<Vec<$value_type>>`. A SQL NULL array becomes a null list
/// entry; otherwise every element is appended as a non-null list item. On a missing builder,
/// a builder that is not a `$list_builder`, or a decode failure it early-returns an error from
/// the enclosing function.
macro_rules! handle_primitive_array_type {
    ($type:expr, $builder:expr, $row:expr, $i:expr, $list_builder:ty, $value_type:ty) => {{
        let list_builder = match $builder {
            Some(boxed_builder) => boxed_builder.as_any_mut().downcast_mut::<$list_builder>(),
            None => return NoBuilderForIndexSnafu { index: $i }.fail(),
        };
        let Some(list_builder) = list_builder else {
            return FailedToDowncastBuilderSnafu {
                postgres_type: format!("{:?}", $type),
            }
            .fail();
        };

        let elements: Option<Vec<$value_type>> = $row
            .try_get($i)
            .context(FailedToGetRowValueSnafu { pg_type: $type })?;

        if let Some(elements) = elements {
            list_builder.append_value(elements.into_iter().map(Some));
        } else {
            list_builder.append_null();
        }
    }};
}
145
/// Appends one field of a composite value to the matching child of a `StructBuilder`.
///
/// The child at position `$idx` is downcast to `$BuilderType`, and the field `$field_name` is
/// read from `$composite_type` as `Option<$ValueType>` (NULL becomes an Arrow null). On a
/// wrong child type or a decode failure it early-returns an error from the enclosing function.
macro_rules! handle_composite_type {
    ($BuilderType:ty, $ValueType:ty, $pg_type:expr, $composite_type:expr, $builder:expr, $idx:expr, $field_name:expr) => {{
        let child_builder = match $builder.field_builder::<$BuilderType>($idx) {
            Some(child_builder) => child_builder,
            None => {
                return FailedToDowncastBuilderSnafu {
                    postgres_type: format!("{}", $pg_type),
                }
                .fail();
            }
        };

        let field_value: Option<$ValueType> = $composite_type.try_get($field_name).context(
            FailedToGetCompositeRowValueSnafu {
                pg_type: $pg_type.clone(),
            },
        )?;

        if let Some(value) = field_value {
            child_builder.append_value(value);
        } else {
            child_builder.append_null();
        }
    }};
}
166
/// Dispatches on an Arrow field type and appends one composite field value using the
/// builder/value types listed in the `Variant => (Builder, Value)` table.
///
/// Each listed `DataType::Variant` expands to a `handle_composite_type!` call. Any Arrow
/// type that is not listed panics via `unimplemented!`.
macro_rules! handle_composite_types {
    ($field_type:expr, $pg_type:expr, $composite_type:expr, $builder:expr, $idx:expr, $field_name:expr, $($variant:ident => ($builder_ty:ty, $value_ty:ty)),*) => {
        match $field_type {
            $(
                DataType::$variant => handle_composite_type!(
                    $builder_ty,
                    $value_ty,
                    $pg_type,
                    $composite_type,
                    $builder,
                    $idx,
                    $field_name
                ),
            )*
            _ => unimplemented!("Unsupported field type {:?}", $field_type),
        }
    };
}
187
188#[allow(clippy::too_many_lines)]
195pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Result<RecordBatch> {
196 let mut arrow_fields: Vec<Option<Field>> = Vec::new();
197 let mut arrow_columns_builders: Vec<Option<Box<dyn ArrayBuilder>>> = Vec::new();
198 let mut postgres_types: Vec<Type> = Vec::new();
199 let mut postgres_numeric_scales: Vec<Option<u16>> = Vec::new();
200 let mut column_names: Vec<String> = Vec::new();
201
202 if !rows.is_empty() {
203 let row = &rows[0];
204 for column in row.columns() {
205 let column_name = column.name();
206 let column_type = column.type_();
207
208 let mut numeric_scale: Option<u16> = None;
209
210 let data_type = if *column_type == Type::NUMERIC {
211 if let Some(schema) = projected_schema.as_ref() {
212 match get_decimal_column_precision_and_scale(column_name, schema) {
213 Some((precision, scale)) => {
214 numeric_scale = Some(u16::try_from(scale).unwrap_or_default());
215 Some(DataType::Decimal128(precision, scale))
216 }
217 None => None,
218 }
219 } else {
220 None
221 }
222 } else {
223 map_column_type_to_data_type(column_type, column_name)?
224 };
225
226 match &data_type {
227 Some(data_type) => {
228 arrow_fields.push(Some(Field::new(column_name, data_type.clone(), true)));
229 }
230 None => arrow_fields.push(None),
231 }
232 postgres_numeric_scales.push(numeric_scale);
233 arrow_columns_builders
234 .push(map_data_type_to_array_builder_optional(data_type.as_ref()));
235 postgres_types.push(column_type.clone());
236 column_names.push(column_name.to_string());
237 }
238 }
239
240 for row in rows {
241 for (i, postgres_type) in postgres_types.iter().enumerate() {
242 let Some(builder) = arrow_columns_builders.get_mut(i) else {
243 return NoBuilderForIndexSnafu { index: i }.fail();
244 };
245
246 let Some(arrow_field) = arrow_fields.get_mut(i) else {
247 return NoArrowFieldForIndexSnafu { index: i }.fail();
248 };
249
250 let Some(postgres_numeric_scale) = postgres_numeric_scales.get_mut(i) else {
251 return NoPostgresScaleForIndexSnafu { index: i }.fail();
252 };
253
254 match *postgres_type {
255 Type::INT2 => {
256 handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, i);
257 }
258 Type::INT4 => {
259 handle_primitive_type!(builder, Type::INT4, Int32Builder, i32, row, i);
260 }
261 Type::INT8 => {
262 handle_primitive_type!(builder, Type::INT8, Int64Builder, i64, row, i);
263 }
264 Type::OID => {
265 handle_primitive_type!(builder, Type::OID, UInt32Builder, u32, row, i);
266 }
267 Type::XID => {
268 let Some(builder) = builder else {
269 return NoBuilderForIndexSnafu { index: i }.fail();
270 };
271 let Some(builder) = builder.as_any_mut().downcast_mut::<UInt32Builder>() else {
272 return FailedToDowncastBuilderSnafu {
273 postgres_type: format!("{postgres_type}"),
274 }
275 .fail();
276 };
277 let v = row
278 .try_get::<usize, Option<XidFromSql>>(i)
279 .with_context(|_| FailedToGetRowValueSnafu { pg_type: Type::XID })?;
280
281 match v {
282 Some(v) => {
283 builder.append_value(v.xid);
284 }
285 None => builder.append_null(),
286 }
287 }
288 Type::FLOAT4 => {
289 handle_primitive_type!(builder, Type::FLOAT4, Float32Builder, f32, row, i);
290 }
291 Type::FLOAT8 => {
292 handle_primitive_type!(builder, Type::FLOAT8, Float64Builder, f64, row, i);
293 }
294 Type::CHAR => {
295 handle_primitive_type!(builder, Type::CHAR, Int8Builder, i8, row, i);
296 }
297 Type::TEXT => {
298 handle_primitive_type!(builder, Type::TEXT, StringBuilder, &str, row, i);
299 }
300 Type::VARCHAR => {
301 handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, i);
302 }
303 Type::NAME => {
304 handle_primitive_type!(builder, Type::NAME, StringBuilder, &str, row, i);
305 }
306 Type::BYTEA => {
307 handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec<u8>, row, i);
308 }
309 Type::BPCHAR => {
310 let Some(builder) = builder else {
311 return NoBuilderForIndexSnafu { index: i }.fail();
312 };
313 let Some(builder) = builder.as_any_mut().downcast_mut::<StringBuilder>() else {
314 return FailedToDowncastBuilderSnafu {
315 postgres_type: format!("{postgres_type}"),
316 }
317 .fail();
318 };
319 let v: Option<&str> = row.try_get(i).context(FailedToGetRowValueSnafu {
320 pg_type: Type::BPCHAR,
321 })?;
322
323 match v {
324 Some(v) => builder.append_value(v.trim_end()),
325 None => builder.append_null(),
326 }
327 }
328 Type::BOOL => {
329 handle_primitive_type!(builder, Type::BOOL, BooleanBuilder, bool, row, i);
330 }
331 Type::MONEY => {
332 let Some(builder) = builder else {
333 return NoBuilderForIndexSnafu { index: i }.fail();
334 };
335 let Some(builder) = builder.as_any_mut().downcast_mut::<Int64Builder>() else {
336 return FailedToDowncastBuilderSnafu {
337 postgres_type: format!("{postgres_type}"),
338 }
339 .fail();
340 };
341 let v = row
342 .try_get::<usize, Option<MoneyFromSql>>(i)
343 .with_context(|_| FailedToGetRowValueSnafu {
344 pg_type: Type::MONEY,
345 })?;
346
347 match v {
348 Some(v) => {
349 builder.append_value(v.cash_value);
350 }
351 None => builder.append_null(),
352 }
353 }
354 Type::JSON | Type::JSONB => {
356 let Some(builder) = builder else {
357 return NoBuilderForIndexSnafu { index: i }.fail();
358 };
359 let Some(builder) = builder.as_any_mut().downcast_mut::<StringBuilder>() else {
360 return FailedToDowncastBuilderSnafu {
361 postgres_type: format!("{postgres_type}"),
362 }
363 .fail();
364 };
365 let v = row.try_get::<usize, Option<Value>>(i).with_context(|_| {
366 FailedToGetRowValueSnafu {
367 pg_type: postgres_type.clone(),
368 }
369 })?;
370
371 match v {
372 Some(v) => {
373 builder.append_value(v.to_string());
374 }
375 None => builder.append_null(),
376 }
377 }
378 Type::TIME => {
379 let Some(builder) = builder else {
380 return NoBuilderForIndexSnafu { index: i }.fail();
381 };
382 let Some(builder) = builder
383 .as_any_mut()
384 .downcast_mut::<Time64NanosecondBuilder>()
385 else {
386 return FailedToDowncastBuilderSnafu {
387 postgres_type: format!("{postgres_type}"),
388 }
389 .fail();
390 };
391 let v = row
392 .try_get::<usize, Option<chrono::NaiveTime>>(i)
393 .with_context(|_| FailedToGetRowValueSnafu {
394 pg_type: Type::TIME,
395 })?;
396
397 match v {
398 Some(v) => {
399 let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
400 * 1_000_000_000
401 + i64::from(v.nanosecond());
402 builder.append_value(timestamp);
403 }
404 None => builder.append_null(),
405 }
406 }
407 Type::POINT => {
408 let Some(builder) = builder else {
409 return NoBuilderForIndexSnafu { index: i }.fail();
410 };
411 let Some(builder) = builder
412 .as_any_mut()
413 .downcast_mut::<FixedSizeListBuilder<Float64Builder>>()
414 else {
415 return FailedToDowncastBuilderSnafu {
416 postgres_type: format!("{postgres_type}"),
417 }
418 .fail();
419 };
420
421 let v = row.try_get::<usize, Option<Point>>(i).with_context(|_| {
422 FailedToGetRowValueSnafu {
423 pg_type: Type::POINT,
424 }
425 })?;
426
427 if let Some(v) = v {
428 builder.values().append_value(v.x());
429 builder.values().append_value(v.y());
430 builder.append(true);
431 } else {
432 builder.values().append_null();
433 builder.values().append_null();
434 builder.append(false);
435 }
436 }
437 Type::INTERVAL => {
438 let Some(builder) = builder else {
439 return NoBuilderForIndexSnafu { index: i }.fail();
440 };
441 let Some(builder) = builder
442 .as_any_mut()
443 .downcast_mut::<IntervalMonthDayNanoBuilder>()
444 else {
445 return FailedToDowncastBuilderSnafu {
446 postgres_type: format!("{postgres_type}"),
447 }
448 .fail();
449 };
450
451 let v: Option<IntervalFromSql> =
452 row.try_get(i).context(FailedToGetRowValueSnafu {
453 pg_type: Type::INTERVAL,
454 })?;
455 match v {
456 Some(v) => {
457 let interval_month_day_nano = IntervalMonthDayNanoType::make_value(
458 v.month,
459 v.day,
460 v.time * 1_000,
461 );
462 builder.append_value(interval_month_day_nano);
463 }
464 None => builder.append_null(),
465 }
466 }
467 Type::NUMERIC => {
468 let v: Option<BigDecimalFromSql> =
469 row.try_get(i).context(FailedToGetRowValueSnafu {
470 pg_type: Type::NUMERIC,
471 })?;
472 let scale = {
473 if let Some(v) = &v {
474 v.scale()
475 } else {
476 0
477 }
478 };
479
480 let dec_builder = builder.get_or_insert_with(|| {
481 Box::new(
482 Decimal128Builder::new()
483 .with_precision_and_scale(38, scale.try_into().unwrap_or_default())
484 .unwrap_or_default(),
485 )
486 });
487
488 let Some(dec_builder) =
489 dec_builder.as_any_mut().downcast_mut::<Decimal128Builder>()
490 else {
491 return FailedToDowncastBuilderSnafu {
492 postgres_type: format!("{postgres_type}"),
493 }
494 .fail();
495 };
496
497 if arrow_field.is_none() {
498 let Some(field_name) = column_names.get(i) else {
499 return NoColumnNameForIndexSnafu { index: i }.fail();
500 };
501 let new_arrow_field = Field::new(
502 field_name,
503 DataType::Decimal128(38, scale.try_into().unwrap_or_default()),
504 true,
505 );
506
507 *arrow_field = Some(new_arrow_field);
508 }
509
510 if postgres_numeric_scale.is_none() {
511 *postgres_numeric_scale = Some(scale);
512 };
513
514 let Some(v) = v else {
515 dec_builder.append_null();
516 continue;
517 };
518
519 let dest_scale = postgres_numeric_scale.unwrap_or_default();
522 let Some(v_i128) = v.to_decimal_128_with_scale(dest_scale) else {
523 return FailedToConvertBigDecimalToI128Snafu {
524 big_decimal: v.inner,
525 }
526 .fail();
527 };
528 dec_builder.append_value(v_i128);
529 }
530 Type::TIMESTAMP => {
531 let Some(builder) = builder else {
532 return NoBuilderForIndexSnafu { index: i }.fail();
533 };
534 let Some(builder) = builder
535 .as_any_mut()
536 .downcast_mut::<TimestampNanosecondBuilder>()
537 else {
538 return FailedToDowncastBuilderSnafu {
539 postgres_type: format!("{postgres_type}"),
540 }
541 .fail();
542 };
543 let v = row
544 .try_get::<usize, Option<SystemTime>>(i)
545 .with_context(|_| FailedToGetRowValueSnafu {
546 pg_type: Type::TIMESTAMP,
547 })?;
548
549 match v {
550 Some(v) => {
551 if let Ok(v) = v.duration_since(UNIX_EPOCH) {
552 let timestamp: i64 = v
553 .as_nanos()
554 .try_into()
555 .context(FailedToConvertU128toI64Snafu)?;
556 builder.append_value(timestamp);
557 }
558 }
559 None => builder.append_null(),
560 }
561 }
562 Type::TIMESTAMPTZ => {
563 let v = row
564 .try_get::<usize, Option<DateTime<Utc>>>(i)
565 .with_context(|_| FailedToGetRowValueSnafu {
566 pg_type: Type::TIMESTAMPTZ,
567 })?;
568
569 let timestamptz_builder = builder.get_or_insert_with(|| {
570 Box::new(TimestampNanosecondBuilder::new().with_timezone("UTC"))
571 });
572
573 let Some(timestamptz_builder) = timestamptz_builder
574 .as_any_mut()
575 .downcast_mut::<TimestampNanosecondBuilder>()
576 else {
577 return FailedToDowncastBuilderSnafu {
578 postgres_type: format!("{postgres_type}"),
579 }
580 .fail();
581 };
582
583 if arrow_field.is_none() {
584 let Some(field_name) = column_names.get(i) else {
585 return NoColumnNameForIndexSnafu { index: i }.fail();
586 };
587 let new_arrow_field = Field::new(
588 field_name,
589 DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC"))),
590 true,
591 );
592
593 *arrow_field = Some(new_arrow_field);
594 }
595
596 match v {
597 Some(v) => {
598 let utc_timestamp =
599 v.to_utc().timestamp_nanos_opt().unwrap_or_default();
600 timestamptz_builder.append_value(utc_timestamp);
601 }
602 None => timestamptz_builder.append_null(),
603 }
604 }
605
606 Type::DATE => {
607 let Some(builder) = builder else {
608 return NoBuilderForIndexSnafu { index: i }.fail();
609 };
610 let Some(builder) = builder.as_any_mut().downcast_mut::<Date32Builder>() else {
611 return FailedToDowncastBuilderSnafu {
612 postgres_type: format!("{postgres_type}"),
613 }
614 .fail();
615 };
616 let v = row.try_get::<usize, Option<chrono::NaiveDate>>(i).context(
617 FailedToGetRowValueSnafu {
618 pg_type: Type::DATE,
619 },
620 )?;
621
622 match v {
623 Some(v) => builder.append_value(Date32Type::from_naive_date(v)),
624 None => builder.append_null(),
625 }
626 }
627 Type::UUID => {
628 let Some(builder) = builder else {
629 return NoBuilderForIndexSnafu { index: i }.fail();
630 };
631 let Some(builder) = builder.as_any_mut().downcast_mut::<StringBuilder>() else {
632 return FailedToDowncastBuilderSnafu {
633 postgres_type: format!("{postgres_type}"),
634 }
635 .fail();
636 };
637 let v = row.try_get::<usize, Option<uuid::Uuid>>(i).context(
638 FailedToGetRowValueSnafu {
639 pg_type: Type::UUID,
640 },
641 )?;
642
643 match v {
644 Some(v) => builder.append_value(v.to_string()),
645 None => builder.append_null(),
646 }
647 }
648 Type::INT2_ARRAY => handle_primitive_array_type!(
649 Type::INT2_ARRAY,
650 builder,
651 row,
652 i,
653 ListBuilder<Int16Builder>,
654 i16
655 ),
656 Type::INT4_ARRAY => handle_primitive_array_type!(
657 Type::INT4_ARRAY,
658 builder,
659 row,
660 i,
661 ListBuilder<Int32Builder>,
662 i32
663 ),
664 Type::INT8_ARRAY => handle_primitive_array_type!(
665 Type::INT8_ARRAY,
666 builder,
667 row,
668 i,
669 ListBuilder<Int64Builder>,
670 i64
671 ),
672 Type::OID_ARRAY => handle_primitive_array_type!(
673 Type::OID_ARRAY,
674 builder,
675 row,
676 i,
677 ListBuilder<UInt32Builder>,
678 u32
679 ),
680 Type::FLOAT4_ARRAY => handle_primitive_array_type!(
681 Type::FLOAT4_ARRAY,
682 builder,
683 row,
684 i,
685 ListBuilder<Float32Builder>,
686 f32
687 ),
688 Type::FLOAT8_ARRAY => handle_primitive_array_type!(
689 Type::FLOAT8_ARRAY,
690 builder,
691 row,
692 i,
693 ListBuilder<Float64Builder>,
694 f64
695 ),
696 Type::TEXT_ARRAY => handle_primitive_array_type!(
697 Type::TEXT_ARRAY,
698 builder,
699 row,
700 i,
701 ListBuilder<StringBuilder>,
702 String
703 ),
704 Type::BOOL_ARRAY => handle_primitive_array_type!(
705 Type::BOOL_ARRAY,
706 builder,
707 row,
708 i,
709 ListBuilder<BooleanBuilder>,
710 bool
711 ),
712 Type::BYTEA_ARRAY => handle_primitive_array_type!(
713 Type::BYTEA_ARRAY,
714 builder,
715 row,
716 i,
717 ListBuilder<BinaryBuilder>,
718 Vec<u8>
719 ),
720 _ if matches!(postgres_type.name(), "geometry" | "geography") => {
721 let Some(builder) = builder else {
722 return NoBuilderForIndexSnafu { index: i }.fail();
723 };
724 let Some(builder) = builder.as_any_mut().downcast_mut::<BinaryBuilder>() else {
725 return FailedToDowncastBuilderSnafu {
726 postgres_type: format!("{postgres_type}"),
727 }
728 .fail();
729 };
730 let v = row.try_get::<usize, Option<GeometryFromSql>>(i).context(
731 FailedToGetRowValueSnafu {
732 pg_type: postgres_type.clone(),
733 },
734 )?;
735
736 match v {
737 Some(v) => builder.append_value(v.wkb),
738 None => builder.append_null(),
739 }
740 }
741 _ if matches!(postgres_type.name(), "_geometry" | "_geography") => {
742 let Some(builder) = builder else {
743 return NoBuilderForIndexSnafu { index: i }.fail();
744 };
745 let Some(builder) = builder
746 .as_any_mut()
747 .downcast_mut::<ListBuilder<BinaryBuilder>>()
748 else {
749 return FailedToDowncastBuilderSnafu {
750 postgres_type: format!("{postgres_type}"),
751 }
752 .fail();
753 };
754 let v: Option<Vec<GeometryFromSql>> =
755 row.try_get(i).context(FailedToGetRowValueSnafu {
756 pg_type: postgres_type.clone(),
757 })?;
758 match v {
759 Some(v) => {
760 let v = v.into_iter().map(|item| Some(item.wkb));
761 builder.append_value(v);
762 }
763 None => builder.append_null(),
764 }
765 }
766 _ => match *postgres_type.kind() {
767 Kind::Composite(_) => {
768 let Some(builder) = builder else {
769 return NoBuilderForIndexSnafu { index: i }.fail();
770 };
771 let Some(builder) = builder.as_any_mut().downcast_mut::<StructBuilder>()
772 else {
773 return FailedToDowncastBuilderSnafu {
774 postgres_type: format!("{postgres_type}"),
775 }
776 .fail();
777 };
778
779 let v = row.try_get::<usize, Option<CompositeType>>(i).context(
780 FailedToGetRowValueSnafu {
781 pg_type: postgres_type.clone(),
782 },
783 )?;
784
785 let Some(composite_type) = v else {
786 builder.append_null();
787 continue;
788 };
789
790 builder.append(true);
791
792 let fields = composite_type.fields();
793 for (idx, field) in fields.iter().enumerate() {
794 let field_name = field.name();
795 let Some(field_type) =
796 map_column_type_to_data_type(field.type_(), field_name)?
797 else {
798 return FailedToDowncastBuilderSnafu {
799 postgres_type: format!("{}", field.type_()),
800 }
801 .fail();
802 };
803
804 handle_composite_types!(
805 field_type,
806 field.type_(),
807 composite_type,
808 builder,
809 idx,
810 field_name,
811 Boolean => (BooleanBuilder, bool),
812 Int8 => (Int8Builder, i8),
813 Int16 => (Int16Builder, i16),
814 Int32 => (Int32Builder, i32),
815 Int64 => (Int64Builder, i64),
816 UInt32 => (UInt32Builder, u32),
817 Float32 => (Float32Builder, f32),
818 Float64 => (Float64Builder, f64),
819 Binary => (BinaryBuilder, Vec<u8>),
820 LargeBinary => (LargeBinaryBuilder, Vec<u8>),
821 Utf8 => (StringBuilder, String),
822 LargeUtf8 => (LargeStringBuilder, String)
823 );
824 }
825 }
826 Kind::Enum(_) => {
827 let Some(builder) = builder else {
828 return NoBuilderForIndexSnafu { index: i }.fail();
829 };
830 let Some(builder) = builder
831 .as_any_mut()
832 .downcast_mut::<StringDictionaryBuilder<Int8Type>>()
833 else {
834 return FailedToDowncastBuilderSnafu {
835 postgres_type: format!("{postgres_type}"),
836 }
837 .fail();
838 };
839
840 let v = row.try_get::<usize, Option<EnumValueFromSql>>(i).context(
841 FailedToGetRowValueSnafu {
842 pg_type: postgres_type.clone(),
843 },
844 )?;
845
846 match v {
847 Some(v) => builder.append_value(v.enum_value),
848 None => builder.append_null(),
849 }
850 }
851 _ => {
852 return UnsupportedDataTypeSnafu {
853 data_type: postgres_type.to_string(),
854 field_name: column_names[i].clone(),
855 }
856 .fail();
857 }
858 },
859 }
860 }
861 }
862
863 let columns = arrow_columns_builders
864 .into_iter()
865 .filter_map(|builder| builder.map(|mut b| b.finish()))
866 .collect::<Vec<ArrayRef>>();
867 let arrow_fields = arrow_fields.into_iter().flatten().collect::<Vec<Field>>();
868
869 let options = &RecordBatchOptions::new().with_row_count(Some(rows.len()));
870 match RecordBatch::try_new_with_options(Arc::new(Schema::new(arrow_fields)), columns, options) {
871 Ok(record_batch) => Ok(record_batch),
872 Err(e) => Err(e).context(FailedToBuildRecordBatchSnafu),
873 }
874}
875
876fn map_column_type_to_data_type(column_type: &Type, field_name: &str) -> Result<Option<DataType>> {
877 match *column_type {
878 Type::INT2 => Ok(Some(DataType::Int16)),
879 Type::INT4 => Ok(Some(DataType::Int32)),
880 Type::INT8 | Type::MONEY => Ok(Some(DataType::Int64)),
881 Type::OID | Type::XID => Ok(Some(DataType::UInt32)),
882 Type::FLOAT4 => Ok(Some(DataType::Float32)),
883 Type::FLOAT8 => Ok(Some(DataType::Float64)),
884 Type::CHAR => Ok(Some(DataType::Int8)),
885 Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::UUID | Type::NAME => {
886 Ok(Some(DataType::Utf8))
887 }
888 Type::BYTEA => Ok(Some(DataType::Binary)),
889 Type::BOOL => Ok(Some(DataType::Boolean)),
890 Type::JSON | Type::JSONB => Ok(Some(DataType::Utf8)),
892 Type::NUMERIC => Ok(None),
894 Type::TIMESTAMPTZ => Ok(Some(DataType::Timestamp(
895 TimeUnit::Nanosecond,
896 Some(Arc::from("UTC")),
897 ))),
898 Type::TIMESTAMP => Ok(Some(DataType::Timestamp(TimeUnit::Nanosecond, None))),
900 Type::DATE => Ok(Some(DataType::Date32)),
901 Type::TIME => Ok(Some(DataType::Time64(TimeUnit::Nanosecond))),
902 Type::INTERVAL => Ok(Some(DataType::Interval(IntervalUnit::MonthDayNano))),
903 Type::POINT => Ok(Some(DataType::FixedSizeList(
904 Arc::new(Field::new("item", DataType::Float64, true)),
905 2,
906 ))),
907 Type::PG_NODE_TREE => Ok(Some(DataType::Utf8)),
908 Type::INT2_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
909 "item",
910 DataType::Int16,
911 true,
912 ))))),
913 Type::INT4_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
914 "item",
915 DataType::Int32,
916 true,
917 ))))),
918 Type::INT8_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
919 "item",
920 DataType::Int64,
921 true,
922 ))))),
923 Type::OID_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
924 "item",
925 DataType::UInt32,
926 true,
927 ))))),
928 Type::FLOAT4_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
929 "item",
930 DataType::Float32,
931 true,
932 ))))),
933 Type::FLOAT8_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
934 "item",
935 DataType::Float64,
936 true,
937 ))))),
938 Type::TEXT_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
939 "item",
940 DataType::Utf8,
941 true,
942 ))))),
943 Type::BOOL_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
944 "item",
945 DataType::Boolean,
946 true,
947 ))))),
948 Type::BYTEA_ARRAY => Ok(Some(DataType::List(Arc::new(Field::new(
949 "item",
950 DataType::Binary,
951 true,
952 ))))),
953 _ if matches!(column_type.name(), "geometry" | "geography") => Ok(Some(DataType::Binary)),
954 _ if matches!(column_type.name(), "_geometry" | "_geography") => Ok(Some(DataType::List(
955 Arc::new(Field::new("item", DataType::Binary, true)),
956 ))),
957 _ => match *column_type.kind() {
958 Kind::Composite(ref fields) => {
959 let mut arrow_fields = Vec::new();
960 for field in fields {
961 let field_name = field.name();
962 let field_type = map_column_type_to_data_type(field.type_(), field_name)?;
963 match field_type {
964 Some(field_type) => {
965 arrow_fields.push(Field::new(field_name, field_type, true));
966 }
967 None => {
968 return UnsupportedDataTypeSnafu {
969 data_type: field.type_().to_string(),
970 field_name: field_name.to_string(),
971 }
972 .fail();
973 }
974 }
975 }
976 Ok(Some(DataType::Struct(arrow_fields.into())))
977 }
978 Kind::Enum(_) => Ok(Some(DataType::Dictionary(
979 Box::new(DataType::Int8),
980 Box::new(DataType::Utf8),
981 ))),
982 _ => UnsupportedDataTypeSnafu {
983 data_type: column_type.to_string(),
984 field_name: field_name.to_string(),
985 }
986 .fail(),
987 },
988 }
989}
990
991pub(crate) fn map_data_type_to_column_type_postgres(
992 data_type: &DataType,
993 table_name: &str,
994 field_name: &str,
995) -> ColumnType {
996 match data_type {
997 DataType::Struct(_) => ColumnType::Custom(SeaRc::new(Alias::new(
998 get_postgres_composite_type_name(table_name, field_name),
999 ))),
1000 _ => map_data_type_to_column_type(data_type),
1001 }
1002}
1003
/// Builds the name of the PostgreSQL composite type backing a struct column:
/// `struct_<table_name>_<field_name>`.
#[must_use]
pub(crate) fn get_postgres_composite_type_name(table_name: &str, field_name: &str) -> String {
    ["struct", table_name, field_name].join("_")
}
1008
1009struct BigDecimalFromSql {
1010 inner: BigDecimal,
1011 scale: u16,
1012}
1013
1014impl BigDecimalFromSql {
1015 fn to_decimal_128_with_scale(&self, dest_scale: u16) -> Option<i128> {
1016 if dest_scale != self.scale {
1018 return (&self.inner * 10i128.pow(u32::from(dest_scale))).to_i128();
1019 }
1020
1021 (&self.inner * 10i128.pow(u32::from(self.scale))).to_i128()
1022 }
1023
1024 fn scale(&self) -> u16 {
1025 self.scale
1026 }
1027}
1028
1029#[allow(clippy::cast_sign_loss)]
1030#[allow(clippy::cast_possible_wrap)]
1031#[allow(clippy::cast_possible_truncation)]
1032impl<'a> FromSql<'a> for BigDecimalFromSql {
1033 fn from_sql(
1034 _ty: &Type,
1035 raw: &'a [u8],
1036 ) -> std::prelude::v1::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
1037 let raw_u16: Vec<u16> = raw
1038 .chunks(2)
1039 .map(|chunk| {
1040 if chunk.len() == 2 {
1041 u16::from_be_bytes([chunk[0], chunk[1]])
1042 } else {
1043 u16::from_be_bytes([chunk[0], 0])
1044 }
1045 })
1046 .collect();
1047
1048 let base_10_000_digit_count = raw_u16[0];
1049 let weight = raw_u16[1] as i16;
1050 let sign = raw_u16[2];
1051 let scale = raw_u16[3];
1052
1053 let mut base_10_000_digits = Vec::new();
1054 for i in 4..4 + base_10_000_digit_count {
1055 base_10_000_digits.push(raw_u16[i as usize]);
1056 }
1057
1058 let mut u8_digits = Vec::new();
1059 for &base_10_000_digit in base_10_000_digits.iter().rev() {
1060 let mut base_10_000_digit = base_10_000_digit;
1061 let mut temp_result = Vec::new();
1062 while base_10_000_digit > 0 {
1063 temp_result.push((base_10_000_digit % 10) as u8);
1064 base_10_000_digit /= 10;
1065 }
1066 while temp_result.len() < 4 {
1067 temp_result.push(0);
1068 }
1069 u8_digits.extend(temp_result);
1070 }
1071 u8_digits.reverse();
1072
1073 let value_scale = 4 * (i64::from(base_10_000_digit_count) - i64::from(weight) - 1);
1074 let size = i64::try_from(u8_digits.len())? + i64::from(scale) - value_scale;
1075 u8_digits.resize(size as usize, 0);
1076
1077 let sign = match sign {
1078 0x4000 => Sign::Minus,
1079 0x0000 => Sign::Plus,
1080 _ => {
1081 return Err(Box::new(Error::FailedToParseBigDecimalFromPostgres {
1082 bytes: raw.to_vec(),
1083 }))
1084 }
1085 };
1086
1087 let Some(digits) = BigInt::from_radix_be(sign, u8_digits.as_slice(), 10) else {
1088 return Err(Box::new(Error::FailedToParseBigDecimalFromPostgres {
1089 bytes: raw.to_vec(),
1090 }));
1091 };
1092 Ok(BigDecimalFromSql {
1093 inner: BigDecimal::new(digits, i64::from(scale)),
1094 scale,
1095 })
1096 }
1097
1098 fn accepts(ty: &Type) -> bool {
1099 matches!(*ty, Type::NUMERIC)
1100 }
1101}
1102
1103struct IntervalFromSql {
1106 time: i64,
1107 day: i32,
1108 month: i32,
1109}
1110
1111impl<'a> FromSql<'a> for IntervalFromSql {
1112 fn from_sql(
1113 _ty: &Type,
1114 raw: &'a [u8],
1115 ) -> std::prelude::v1::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
1116 let mut cursor = std::io::Cursor::new(raw);
1117
1118 let time = cursor.read_i64::<BigEndian>()?;
1119 let day = cursor.read_i32::<BigEndian>()?;
1120 let month = cursor.read_i32::<BigEndian>()?;
1121
1122 Ok(IntervalFromSql { time, day, month })
1123 }
1124
1125 fn accepts(ty: &Type) -> bool {
1126 matches!(*ty, Type::INTERVAL)
1127 }
1128}
1129
1130struct MoneyFromSql {
1132 cash_value: i64,
1133}
1134
1135impl<'a> FromSql<'a> for MoneyFromSql {
1136 fn from_sql(
1137 _ty: &Type,
1138 raw: &'a [u8],
1139 ) -> std::prelude::v1::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
1140 let mut cursor = std::io::Cursor::new(raw);
1141 let cash_value = cursor.read_i64::<BigEndian>()?;
1142 Ok(MoneyFromSql { cash_value })
1143 }
1144
1145 fn accepts(ty: &Type) -> bool {
1146 matches!(*ty, Type::MONEY)
1147 }
1148}
1149
1150struct EnumValueFromSql {
1151 enum_value: String,
1152}
1153
1154impl<'a> FromSql<'a> for EnumValueFromSql {
1155 fn from_sql(
1156 _ty: &Type,
1157 raw: &'a [u8],
1158 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
1159 let mut cursor = std::io::Cursor::new(raw);
1160 let mut enum_value = String::new();
1161 cursor.read_to_string(&mut enum_value)?;
1162 Ok(EnumValueFromSql { enum_value })
1163 }
1164
1165 fn accepts(ty: &Type) -> bool {
1166 matches!(*ty.kind(), Kind::Enum(_))
1167 }
1168}
1169
1170pub struct GeometryFromSql<'a> {
1171 wkb: &'a [u8],
1172}
1173
1174impl<'a> FromSql<'a> for GeometryFromSql<'a> {
1175 fn from_sql(
1176 _ty: &Type,
1177 raw: &'a [u8],
1178 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
1179 Ok(GeometryFromSql { wkb: raw })
1180 }
1181
1182 fn accepts(ty: &Type) -> bool {
1183 matches!(ty.name(), "geometry" | "geography")
1184 }
1185}
1186
1187struct XidFromSql {
1188 xid: u32,
1189}
1190
1191impl<'a> FromSql<'a> for XidFromSql {
1192 fn from_sql(
1193 _ty: &Type,
1194 raw: &'a [u8],
1195 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
1196 let mut cursor = std::io::Cursor::new(raw);
1197 let xid = cursor.read_u32::<BigEndian>()?;
1198 Ok(XidFromSql { xid })
1199 }
1200
1201 fn accepts(ty: &Type) -> bool {
1202 matches!(*ty, Type::XID)
1203 }
1204}
1205
1206fn get_decimal_column_precision_and_scale(
1207 column_name: &str,
1208 projected_schema: &SchemaRef,
1209) -> Option<(u8, i8)> {
1210 let field = projected_schema.field_with_name(column_name).ok()?;
1211 match field.data_type() {
1212 DataType::Decimal128(precision, scale) => Some((*precision, *scale)),
1213 _ => None,
1214 }
1215}
1216
#[cfg(test)]
mod tests {
    // Unit tests for the binary `FromSql` decoders and helpers in this module.
    use super::*;
    use chrono::NaiveTime;
    use datafusion::arrow::array::{Time64NanosecondArray, Time64NanosecondBuilder};
    use geo_types::{point, polygon, Geometry};
    use geozero::{CoordDimensions, ToWkb};
    use std::str::FromStr;

    /// Encodes 16-bit words big-endian, matching the word-oriented layout that
    /// `BigDecimalFromSql` decodes (header words, then base-10000 digits).
    fn encode_numeric_words(words: &[u16]) -> Vec<u8> {
        words.iter().flat_map(|word| word.to_be_bytes()).collect()
    }

    /// A custom `geometry` type descriptor used to exercise `GeometryFromSql`.
    fn geometry_type() -> Type {
        Type::new(
            "geometry".to_owned(),
            16462,
            Kind::Simple,
            "public".to_owned(),
        )
    }

    // Decoding is synchronous, so a plain #[test] suffices; no async runtime is needed.
    #[test]
    fn test_big_decimal_from_sql() {
        // Header words (PostgreSQL NUMERIC binary layout): ndigits, weight, sign, dscale.
        let positive_raw = encode_numeric_words(&[5, 3, 0, 5, 9345, 1293, 2903, 1293, 932]);
        let positive =
            BigDecimal::from_str("9345129329031293.0932").expect("Failed to parse big decimal");
        let positive_result = BigDecimalFromSql::from_sql(&Type::NUMERIC, positive_raw.as_slice())
            .expect("Failed to run FromSql");
        assert_eq!(positive_result.inner, positive);

        let negative_raw =
            encode_numeric_words(&[5, 3, 0x4000, 5, 9345, 1293, 2903, 1293, 932]);
        let negative =
            BigDecimal::from_str("-9345129329031293.0932").expect("Failed to parse big decimal");
        let negative_result = BigDecimalFromSql::from_sql(&Type::NUMERIC, negative_raw.as_slice())
            .expect("Failed to run FromSql");
        assert_eq!(negative_result.inner, negative);

        // Any sign word other than plus (0x0000) or minus (0x4000) is rejected.
        let bad_sign_raw =
            encode_numeric_words(&[5, 3, 0x1234, 5, 9345, 1293, 2903, 1293, 932]);
        assert!(BigDecimalFromSql::from_sql(&Type::NUMERIC, bad_sign_raw.as_slice()).is_err());
    }

    #[test]
    fn test_interval_from_sql() {
        for (time, day, month) in [(123_123_i64, 10_i32, 2_i32), (-123_123, -10, -2)] {
            let mut raw: Vec<u8> = Vec::new();
            raw.extend_from_slice(&time.to_be_bytes());
            raw.extend_from_slice(&day.to_be_bytes());
            raw.extend_from_slice(&month.to_be_bytes());

            let result = IntervalFromSql::from_sql(&Type::INTERVAL, raw.as_slice())
                .expect("Failed to run FromSql");
            assert_eq!(result.time, time);
            assert_eq!(result.day, day);
            assert_eq!(result.month, month);
        }

        // 8 bytes only cover `time`; the missing day/month fields must be an error.
        assert!(IntervalFromSql::from_sql(&Type::INTERVAL, &[0_u8; 8]).is_err());
        assert!(IntervalFromSql::accepts(&Type::INTERVAL));
    }

    #[test]
    fn test_money_from_sql() {
        for cash_value in [123_i64, -123] {
            let result = MoneyFromSql::from_sql(&Type::MONEY, &cash_value.to_be_bytes())
                .expect("Failed to run FromSql");
            assert_eq!(result.cash_value, cash_value);
        }

        // Fewer than 8 bytes cannot hold a MONEY value.
        assert!(MoneyFromSql::from_sql(&Type::MONEY, &[0_u8; 4]).is_err());
        assert!(MoneyFromSql::accepts(&Type::MONEY));
    }

    #[test]
    fn test_xid_from_sql() {
        let xid: u32 = 4_000_000_000;
        let result =
            XidFromSql::from_sql(&Type::XID, &xid.to_be_bytes()).expect("Failed to run FromSql");
        assert_eq!(result.xid, xid);

        assert!(XidFromSql::from_sql(&Type::XID, &[0_u8; 2]).is_err());
        assert!(XidFromSql::accepts(&Type::XID));
    }

    #[test]
    fn test_enum_value_from_sql() {
        let mood = Type::new(
            "mood".to_owned(),
            16500,
            Kind::Enum(vec!["happy".to_owned(), "sad".to_owned()]),
            "public".to_owned(),
        );
        assert!(EnumValueFromSql::accepts(&mood));
        assert!(!EnumValueFromSql::accepts(&Type::TEXT));

        let result = EnumValueFromSql::from_sql(&mood, b"happy").expect("Failed to run FromSql");
        assert_eq!(result.enum_value, "happy");

        // Labels must be valid UTF-8.
        assert!(EnumValueFromSql::from_sql(&mood, &[0xff_u8, 0xfe]).is_err());
    }

    #[test]
    fn test_chrono_naive_time_to_time64nanosecond() {
        // `expect` rather than `unwrap_or_default` so a bad fixture fails loudly
        // instead of silently becoming midnight.
        let chrono_naive_vec = vec![
            NaiveTime::from_hms_opt(10, 30, 0).expect("valid time of day"),
            NaiveTime::from_hms_opt(10, 45, 15).expect("valid time of day"),
        ];

        let time_array: Time64NanosecondArray = vec![
            (10 * 3600 + 30 * 60) * 1_000_000_000,
            (10 * 3600 + 45 * 60 + 15) * 1_000_000_000,
        ]
        .into();

        let mut builder = Time64NanosecondBuilder::new();
        for time in chrono_naive_vec {
            let timestamp: i64 = i64::from(time.num_seconds_from_midnight()) * 1_000_000_000
                + i64::from(time.nanosecond());
            builder.append_value(timestamp);
        }
        let converted_result = builder.finish();
        assert_eq!(converted_result, time_array);
    }

    #[test]
    fn test_geometry_from_sql() {
        let geometries = [
            Geometry::from(point! { x: 181.2, y: 51.79 }),
            Geometry::from(polygon![
                (x: -111., y: 45.),
                (x: -111., y: 41.),
                (x: -104., y: 41.),
                (x: -104., y: 45.),
            ]),
        ];

        for geometry in geometries {
            let wkb = geometry
                .to_wkb(CoordDimensions::xy())
                .expect("Failed to encode geometry as WKB");
            let result = GeometryFromSql::from_sql(&geometry_type(), wkb.as_slice())
                .expect("Failed to run FromSql");
            assert_eq!(result.wkb, wkb);
        }

        assert!(GeometryFromSql::accepts(&geometry_type()));
        assert!(!GeometryFromSql::accepts(&Type::TEXT));
    }

    #[test]
    fn test_get_decimal_column_precision_and_scale() {
        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("price", DataType::Decimal128(10, 2), true),
            Field::new("name", DataType::Utf8, true),
        ]));

        assert_eq!(
            get_decimal_column_precision_and_scale("price", &schema),
            Some((10, 2))
        );
        assert_eq!(get_decimal_column_precision_and_scale("name", &schema), None);
        assert_eq!(
            get_decimal_column_precision_and_scale("missing", &schema),
            None
        );
    }
}