arrow_pg/
datatypes.rs

1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{datatypes::*, record_batch::RecordBatch};
5#[cfg(feature = "datafusion")]
6use datafusion::arrow::{datatypes::*, record_batch::RecordBatch};
7
8use pgwire::api::portal::Format;
9use pgwire::api::results::FieldInfo;
10use pgwire::api::Type;
11use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
12use pgwire::messages::data::DataRow;
13use pgwire::types::format::FormatOptions;
14use postgres_types::Kind;
15
16use crate::row_encoder::RowEncoder;
17
18#[cfg(feature = "datafusion")]
19pub mod df;
20
21pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
22    Ok(match arrow_type {
23        DataType::Null => Type::UNKNOWN,
24        DataType::Boolean => Type::BOOL,
25        DataType::Int8 | DataType::UInt8 => Type::CHAR,
26        DataType::Int16 | DataType::UInt16 => Type::INT2,
27        DataType::Int32 | DataType::UInt32 => Type::INT4,
28        DataType::Int64 | DataType::UInt64 => Type::INT8,
29        DataType::Timestamp(_, tz) => {
30            if tz.is_some() {
31                Type::TIMESTAMPTZ
32            } else {
33                Type::TIMESTAMP
34            }
35        }
36        DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
37        DataType::Date32 | DataType::Date64 => Type::DATE,
38        DataType::Interval(_) => Type::INTERVAL,
39        DataType::Binary
40        | DataType::FixedSizeBinary(_)
41        | DataType::LargeBinary
42        | DataType::BinaryView => Type::BYTEA,
43        DataType::Float16 | DataType::Float32 => Type::FLOAT4,
44        DataType::Float64 => Type::FLOAT8,
45        DataType::Decimal128(_, _) => Type::NUMERIC,
46        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
47        DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
48            match field.data_type() {
49                DataType::Boolean => Type::BOOL_ARRAY,
50                DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
51                DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
52                DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
53                DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
54                DataType::Timestamp(_, tz) => {
55                    if tz.is_some() {
56                        Type::TIMESTAMPTZ_ARRAY
57                    } else {
58                        Type::TIMESTAMP_ARRAY
59                    }
60                }
61                DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
62                DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
63                DataType::Interval(_) => Type::INTERVAL_ARRAY,
64                DataType::FixedSizeBinary(_)
65                | DataType::Binary
66                | DataType::LargeBinary
67                | DataType::BinaryView => Type::BYTEA_ARRAY,
68                DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
69                DataType::Float64 => Type::FLOAT8_ARRAY,
70                DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
71                struct_type @ DataType::Struct(_) => Type::new(
72                    Type::RECORD_ARRAY.name().into(),
73                    Type::RECORD_ARRAY.oid(),
74                    Kind::Array(into_pg_type(struct_type)?),
75                    Type::RECORD_ARRAY.schema().into(),
76                ),
77                list_type => {
78                    return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
79                        "ERROR".to_owned(),
80                        "XX000".to_owned(),
81                        format!("Unsupported List Datatype {list_type}"),
82                    ))));
83                }
84            }
85        }
86        DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
87        DataType::Struct(fields) => {
88            let name: String = fields
89                .iter()
90                .map(|x| x.name().clone())
91                .reduce(|a, b| a + ", " + &b)
92                .map(|x| format!("({x})"))
93                .unwrap_or("()".to_string());
94            let kind = Kind::Composite(
95                fields
96                    .iter()
97                    .map(|x| {
98                        into_pg_type(x.data_type())
99                            .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
100                    })
101                    .collect::<Result<Vec<_>, PgWireError>>()?,
102            );
103            Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
104        }
105        _ => {
106            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
107                "ERROR".to_owned(),
108                "XX000".to_owned(),
109                format!("Unsupported Datatype {arrow_type}"),
110            ))));
111        }
112    })
113}
114
115pub fn arrow_schema_to_pg_fields(
116    schema: &Schema,
117    format: &Format,
118    data_format_options: Option<Arc<FormatOptions>>,
119) -> PgWireResult<Vec<FieldInfo>> {
120    let _ = data_format_options;
121    schema
122        .fields()
123        .iter()
124        .enumerate()
125        .map(|(idx, f)| {
126            let pg_type = into_pg_type(f.data_type())?;
127            let mut field_info =
128                FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx));
129            if let Some(data_format_options) = &data_format_options {
130                field_info = field_info.with_format_options(data_format_options.clone());
131            }
132
133            Ok(field_info)
134        })
135        .collect::<PgWireResult<Vec<FieldInfo>>>()
136}
137
138pub fn encode_recordbatch(
139    fields: Arc<Vec<FieldInfo>>,
140    record_batch: RecordBatch,
141) -> Box<impl Iterator<Item = PgWireResult<DataRow>>> {
142    let mut row_stream = RowEncoder::new(record_batch, fields);
143    Box::new(std::iter::from_fn(move || row_stream.next_row()))
144}