1use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
4use convergence::protocol_ext::DataRowBatch;
5use datafusion::arrow::array::{
6 BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array,
7 Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
8 StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
9 TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array
10};
11use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
12use datafusion::arrow::record_batch::RecordBatch;
13
/// Downcasts a dynamically typed array to the concrete array type `$target`.
///
/// `$array` must expose an `as_any()` method returning a `&dyn Any`; the result is a shared
/// reference to the concrete type.
///
/// # Panics
///
/// Panics with "array cast failed" when the value behind `$array` is not a `$target`.
macro_rules! array_cast {
    ($target:ident, $array:expr) => {
        match $array.as_any().downcast_ref::<$target>() {
            Some(typed) => typed,
            None => panic!("array cast failed"),
        }
    };
}
19
/// Reads one element out of a dynamically typed array.
///
/// The array is first downcast to `$kind` with `array_cast!`, then the element at `$index` is
/// fetched. With three arguments the accessor is `value`; a fourth argument names a different
/// accessor method to call with the index instead.
macro_rules! array_val {
    // Default accessor: `value(index)`.
    ($kind:ident, $column:expr, $index:expr) => {
        array_cast!($kind, $column).value($index)
    };
    // Caller-chosen accessor method.
    ($kind:ident, $column:expr, $index:expr, $getter:ident) => {
        array_cast!($kind, $column).$getter($index)
    };
}
28
29pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
31 for row_idx in 0..arrow_batch.num_rows() {
32 let mut row = pg_batch.create_row();
33 for col_idx in 0..arrow_batch.num_columns() {
34 let col = arrow_batch.column(col_idx);
35 if col.is_null(row_idx) {
36 row.write_null();
37 } else {
38 match col.data_type() {
39 DataType::Boolean => row.write_bool(array_val!(BooleanArray, col, row_idx)),
40 DataType::Int8 => row.write_int2(array_val!(Int8Array, col, row_idx) as i16),
41 DataType::Int16 => row.write_int2(array_val!(Int16Array, col, row_idx)),
42 DataType::Int32 => row.write_int4(array_val!(Int32Array, col, row_idx)),
43 DataType::Int64 => row.write_int8(array_val!(Int64Array, col, row_idx)),
44 DataType::UInt8 => row.write_int2(array_val!(UInt8Array, col, row_idx) as i16),
45 DataType::UInt16 => row.write_int2(array_val!(UInt16Array, col, row_idx) as i16),
46 DataType::UInt32 => row.write_int4(array_val!(UInt32Array, col, row_idx) as i32),
47 DataType::UInt64 => row.write_int8(array_val!(UInt64Array, col, row_idx) as i64),
48 DataType::Float16 => row.write_float4(array_val!(Float16Array, col, row_idx).to_f32()),
49 DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)),
50 DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)),
51 DataType::Decimal128(p, s) => row.write_numeric_16(array_val!(Decimal128Array, col, row_idx), p, s),
52 DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)),
53 DataType::Utf8View => row.write_string(array_val!(StringViewArray, col, row_idx)),
54 DataType::Date32 => {
55 row.write_date(array_val!(Date32Array, col, row_idx, value_as_date).ok_or_else(|| {
56 ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
57 })?)
58 }
59 DataType::Date64 => {
60 row.write_date(array_val!(Date64Array, col, row_idx, value_as_date).ok_or_else(|| {
61 ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
62 })?)
63 }
64 DataType::Timestamp(unit, None) => row.write_timestamp(
65 match unit {
66 TimeUnit::Second => array_val!(TimestampSecondArray, col, row_idx, value_as_datetime),
67 TimeUnit::Millisecond => {
68 array_val!(TimestampMillisecondArray, col, row_idx, value_as_datetime)
69 }
70 TimeUnit::Microsecond => {
71 array_val!(TimestampMicrosecondArray, col, row_idx, value_as_datetime)
72 }
73 TimeUnit::Nanosecond => {
74 array_val!(TimestampNanosecondArray, col, row_idx, value_as_datetime)
75 }
76 }
77 .ok_or_else(|| {
78 ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported timestamp type")
79 })?,
80 ),
81 other => {
82 return Err(ErrorResponse::error(
83 SqlState::FeatureNotSupported,
84 format!("arrow to pg conversion not implemented for {}", other),
85 ))
86 }
87 };
88 }
89 }
90 }
91
92 Ok(())
93}
94
95pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
97 Ok(match ty {
98 DataType::Boolean => DataTypeOid::Bool,
99 DataType::Int8 | DataType::Int16 => DataTypeOid::Int2,
100 DataType::Int32 => DataTypeOid::Int4,
101 DataType::Int64 => DataTypeOid::Int8,
102 DataType::UInt8 | DataType::UInt16 => DataTypeOid::Int2,
104 DataType::UInt32 => DataTypeOid::Int4,
105 DataType::UInt64 => DataTypeOid::Int8,
106 DataType::Float16 | DataType::Float32 => DataTypeOid::Float4,
107 DataType::Float64 => DataTypeOid::Float8,
108 DataType::Decimal128(_, _) => DataTypeOid::Numeric,
109 DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
110 DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
111 DataType::Timestamp(_, None) => DataTypeOid::Timestamp,
112 other => {
113 return Err(ErrorResponse::error(
114 SqlState::FeatureNotSupported,
115 format!("arrow to pg conversion not implemented for {}", other),
116 ))
117 }
118 })
119}
120
121pub fn schema_to_field_desc(schema: &Schema) -> Result<Vec<FieldDescription>, ErrorResponse> {
123 schema
124 .fields()
125 .iter()
126 .map(|f| {
127 Ok(FieldDescription {
128 name: f.name().clone(),
129 data_type: data_type_to_oid(f.data_type())?,
130 })
131 })
132 .collect()
133}