convergence_arrow/
table.rs

1//! Utilities for converting between Arrow and Postgres formats.
2
3use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
4use convergence::protocol_ext::DataRowBatch;
5use datafusion::arrow::array::{
6	BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array,
7   	Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
8	StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
9	TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array
10};
11use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
12use datafusion::arrow::record_batch::RecordBatch;
13
// Downcasts an Arrow array (any value exposing `as_any()`) to the concrete array type `$ty`.
// Panics with "array cast failed" if the runtime type differs. At the call sites in this file
// the concrete type is chosen by matching on the column's `data_type()` first, so a mismatch
// indicates a bug rather than bad input.
macro_rules! array_cast {
	($ty:ident, $array:expr) => {
		$array
			.as_any()
			.downcast_ref::<$ty>()
			.expect("array cast failed")
	};
}

// Reads element `$idx` of `$array` after downcasting it to `$ty`.
// An optional fourth argument names the accessor to call (e.g. `value_as_date`);
// when it is omitted, the plain `value` accessor is used.
macro_rules! array_val {
	($ty:ident, $array:expr, $idx:expr, $accessor:ident) => {
		array_cast!($ty, $array).$accessor($idx)
	};
	($ty:ident, $array:expr, $idx:expr) => {
		array_cast!($ty, $array).value($idx)
	};
}
28
29/// Writes the contents of an Arrow [RecordBatch] into a Postgres [DataRowBatch].
30pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
31	for row_idx in 0..arrow_batch.num_rows() {
32		let mut row = pg_batch.create_row();
33		for col_idx in 0..arrow_batch.num_columns() {
34			let col = arrow_batch.column(col_idx);
35			if col.is_null(row_idx) {
36				row.write_null();
37			} else {
38				match col.data_type() {
39					DataType::Boolean => row.write_bool(array_val!(BooleanArray, col, row_idx)),
40					DataType::Int8 => row.write_int2(array_val!(Int8Array, col, row_idx) as i16),
41					DataType::Int16 => row.write_int2(array_val!(Int16Array, col, row_idx)),
42					DataType::Int32 => row.write_int4(array_val!(Int32Array, col, row_idx)),
43					DataType::Int64 => row.write_int8(array_val!(Int64Array, col, row_idx)),
44					DataType::UInt8 => row.write_int2(array_val!(UInt8Array, col, row_idx) as i16),
45					DataType::UInt16 => row.write_int2(array_val!(UInt16Array, col, row_idx) as i16),
46					DataType::UInt32 => row.write_int4(array_val!(UInt32Array, col, row_idx) as i32),
47					DataType::UInt64 => row.write_int8(array_val!(UInt64Array, col, row_idx) as i64),
48					DataType::Float16 => row.write_float4(array_val!(Float16Array, col, row_idx).to_f32()),
49					DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)),
50					DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)),
51					DataType::Decimal128(p, s) => row.write_numeric_16(array_val!(Decimal128Array, col, row_idx), p, s),
52					DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)),
53					DataType::Utf8View => row.write_string(array_val!(StringViewArray, col, row_idx)),
54					DataType::Date32 => {
55						row.write_date(array_val!(Date32Array, col, row_idx, value_as_date).ok_or_else(|| {
56							ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
57						})?)
58					}
59					DataType::Date64 => {
60						row.write_date(array_val!(Date64Array, col, row_idx, value_as_date).ok_or_else(|| {
61							ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
62						})?)
63					}
64					DataType::Timestamp(unit, None) => row.write_timestamp(
65						match unit {
66							TimeUnit::Second => array_val!(TimestampSecondArray, col, row_idx, value_as_datetime),
67							TimeUnit::Millisecond => {
68								array_val!(TimestampMillisecondArray, col, row_idx, value_as_datetime)
69							}
70							TimeUnit::Microsecond => {
71								array_val!(TimestampMicrosecondArray, col, row_idx, value_as_datetime)
72							}
73							TimeUnit::Nanosecond => {
74								array_val!(TimestampNanosecondArray, col, row_idx, value_as_datetime)
75							}
76						}
77						.ok_or_else(|| {
78							ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported timestamp type")
79						})?,
80					),
81					other => {
82						return Err(ErrorResponse::error(
83							SqlState::FeatureNotSupported,
84							format!("arrow to pg conversion not implemented for {}", other),
85						))
86					}
87				};
88			}
89		}
90	}
91
92	Ok(())
93}
94
95/// Converts an Arrow [DataType] into a Postgres [DataTypeOid].
96pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
97	Ok(match ty {
98		DataType::Boolean => DataTypeOid::Bool,
99		DataType::Int8 | DataType::Int16 => DataTypeOid::Int2,
100		DataType::Int32 => DataTypeOid::Int4,
101		DataType::Int64 => DataTypeOid::Int8,
102		// TODO: need to figure out a sensible mapping for unsigned
103		DataType::UInt8 | DataType::UInt16 => DataTypeOid::Int2,
104		DataType::UInt32 => DataTypeOid::Int4,
105		DataType::UInt64 => DataTypeOid::Int8,
106		DataType::Float16 | DataType::Float32 => DataTypeOid::Float4,
107		DataType::Float64 => DataTypeOid::Float8,
108		DataType::Decimal128(_, _) => DataTypeOid::Numeric,
109		DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
110		DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
111		DataType::Timestamp(_, None) => DataTypeOid::Timestamp,
112		other => {
113			return Err(ErrorResponse::error(
114				SqlState::FeatureNotSupported,
115				format!("arrow to pg conversion not implemented for {}", other),
116			))
117		}
118	})
119}
120
121/// Converts an Arrow [Schema] into a vector of Postgres [FieldDescription] instances.
122pub fn schema_to_field_desc(schema: &Schema) -> Result<Vec<FieldDescription>, ErrorResponse> {
123	schema
124		.fields()
125		.iter()
126		.map(|f| {
127			Ok(FieldDescription {
128				name: f.name().clone(),
129				data_type: data_type_to_oid(f.data_type())?,
130			})
131		})
132		.collect()
133}