arrow_pg/
datatypes.rs

1use std::sync::Arc;
2
3use arrow::datatypes::*;
4use arrow::record_batch::RecordBatch;
5use pgwire::api::portal::Format;
6use pgwire::api::results::FieldInfo;
7use pgwire::api::Type;
8use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
9use pgwire::messages::data::DataRow;
10use postgres_types::Kind;
11
12use crate::row_encoder::RowEncoder;
13
14pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
15    Ok(match arrow_type {
16        DataType::Null => Type::UNKNOWN,
17        DataType::Boolean => Type::BOOL,
18        DataType::Int8 | DataType::UInt8 => Type::CHAR,
19        DataType::Int16 | DataType::UInt16 => Type::INT2,
20        DataType::Int32 | DataType::UInt32 => Type::INT4,
21        DataType::Int64 | DataType::UInt64 => Type::INT8,
22        DataType::Timestamp(_, tz) => {
23            if tz.is_some() {
24                Type::TIMESTAMPTZ
25            } else {
26                Type::TIMESTAMP
27            }
28        }
29        DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
30        DataType::Date32 | DataType::Date64 => Type::DATE,
31        DataType::Interval(_) => Type::INTERVAL,
32        DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
33        DataType::Float16 | DataType::Float32 => Type::FLOAT4,
34        DataType::Float64 => Type::FLOAT8,
35        DataType::Decimal128(_, _) => Type::NUMERIC,
36        DataType::Utf8 => Type::VARCHAR,
37        DataType::LargeUtf8 => Type::TEXT,
38        DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
39            match field.data_type() {
40                DataType::Boolean => Type::BOOL_ARRAY,
41                DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
42                DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
43                DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
44                DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
45                DataType::Timestamp(_, tz) => {
46                    if tz.is_some() {
47                        Type::TIMESTAMPTZ_ARRAY
48                    } else {
49                        Type::TIMESTAMP_ARRAY
50                    }
51                }
52                DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
53                DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
54                DataType::Interval(_) => Type::INTERVAL_ARRAY,
55                DataType::FixedSizeBinary(_) | DataType::Binary => Type::BYTEA_ARRAY,
56                DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
57                DataType::Float64 => Type::FLOAT8_ARRAY,
58                DataType::Utf8 => Type::VARCHAR_ARRAY,
59                DataType::LargeUtf8 => Type::TEXT_ARRAY,
60                struct_type @ DataType::Struct(_) => Type::new(
61                    Type::RECORD_ARRAY.name().into(),
62                    Type::RECORD_ARRAY.oid(),
63                    Kind::Array(into_pg_type(struct_type)?),
64                    Type::RECORD_ARRAY.schema().into(),
65                ),
66                list_type => {
67                    return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
68                        "ERROR".to_owned(),
69                        "XX000".to_owned(),
70                        format!("Unsupported List Datatype {list_type}"),
71                    ))));
72                }
73            }
74        }
75        DataType::Utf8View => Type::TEXT,
76        DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
77        DataType::Struct(fields) => {
78            let name: String = fields
79                .iter()
80                .map(|x| x.name().clone())
81                .reduce(|a, b| a + ", " + &b)
82                .map(|x| format!("({x})"))
83                .unwrap_or("()".to_string());
84            let kind = Kind::Composite(
85                fields
86                    .iter()
87                    .map(|x| {
88                        into_pg_type(x.data_type())
89                            .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
90                    })
91                    .collect::<Result<Vec<_>, PgWireError>>()?,
92            );
93            Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
94        }
95        _ => {
96            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
97                "ERROR".to_owned(),
98                "XX000".to_owned(),
99                format!("Unsupported Datatype {arrow_type}"),
100            ))));
101        }
102    })
103}
104
105pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
106    schema
107        .fields()
108        .iter()
109        .enumerate()
110        .map(|(idx, f)| {
111            let pg_type = into_pg_type(f.data_type())?;
112            Ok(FieldInfo::new(
113                f.name().into(),
114                None,
115                None,
116                pg_type,
117                format.format_for(idx),
118            ))
119        })
120        .collect::<PgWireResult<Vec<FieldInfo>>>()
121}
122
123pub fn encode_recordbatch(
124    fields: Arc<Vec<FieldInfo>>,
125    record_batch: RecordBatch,
126) -> Box<impl Iterator<Item = PgWireResult<DataRow>>> {
127    let mut row_stream = RowEncoder::new(record_batch, fields);
128    Box::new(std::iter::from_fn(move || row_stream.next_row()))
129}