// arrow_pg/datatypes.rs

1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{datatypes::*, record_batch::RecordBatch};
5#[cfg(feature = "postgis")]
6use arrow_schema::extension::ExtensionType;
7#[cfg(feature = "datafusion")]
8use datafusion::arrow::{datatypes::*, record_batch::RecordBatch};
9
10use pgwire::api::portal::Format;
11use pgwire::api::results::FieldInfo;
12use pgwire::api::Type;
13use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
14use pgwire::messages::data::DataRow;
15use pgwire::types::format::FormatOptions;
16use postgres_types::Kind;
17
18use crate::row_encoder::RowEncoder;
19
20#[cfg(feature = "datafusion")]
21pub mod df;
22
23pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
24    let datatype = match arrow_type {
25        DataType::Null => Type::UNKNOWN,
26        DataType::Boolean => Type::BOOL,
27        DataType::Int8 => Type::INT2,
28        DataType::Int16 | DataType::UInt8 => Type::INT2,
29        DataType::Int32 | DataType::UInt16 => Type::INT4,
30        DataType::Int64 | DataType::UInt32 => Type::INT8,
31        DataType::UInt64 => Type::NUMERIC,
32        DataType::Timestamp(_, tz) => {
33            if tz.is_some() {
34                Type::TIMESTAMPTZ
35            } else {
36                Type::TIMESTAMP
37            }
38        }
39        DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
40        DataType::Date32 | DataType::Date64 => Type::DATE,
41        DataType::Interval(_) | DataType::Duration(_) => Type::INTERVAL,
42        DataType::Binary
43        | DataType::FixedSizeBinary(_)
44        | DataType::LargeBinary
45        | DataType::BinaryView => Type::BYTEA,
46        DataType::Float16 | DataType::Float32 => Type::FLOAT4,
47        DataType::Float64 => Type::FLOAT8,
48        DataType::Decimal128(_, _) => Type::NUMERIC,
49        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
50        DataType::List(field)
51        | DataType::FixedSizeList(field, _)
52        | DataType::LargeList(field)
53        | DataType::ListView(field)
54        | DataType::LargeListView(field) => match field.data_type() {
55            DataType::Boolean => Type::BOOL_ARRAY,
56            DataType::Int8 => Type::INT2_ARRAY,
57            DataType::Int16 | DataType::UInt8 => Type::INT2_ARRAY,
58            DataType::Int32 | DataType::UInt16 => Type::INT4_ARRAY,
59            DataType::Int64 | DataType::UInt32 => Type::INT8_ARRAY,
60            DataType::UInt64 | DataType::Decimal128(_, _) => Type::NUMERIC_ARRAY,
61            DataType::Timestamp(_, tz) => {
62                if tz.is_some() {
63                    Type::TIMESTAMPTZ_ARRAY
64                } else {
65                    Type::TIMESTAMP_ARRAY
66                }
67            }
68            DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
69            DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
70            DataType::Interval(_) | DataType::Duration(_) => Type::INTERVAL_ARRAY,
71            DataType::FixedSizeBinary(_)
72            | DataType::Binary
73            | DataType::LargeBinary
74            | DataType::BinaryView => Type::BYTEA_ARRAY,
75            DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
76            DataType::Float64 => Type::FLOAT8_ARRAY,
77            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
78            DataType::Struct(_) => Type::new(
79                Type::RECORD_ARRAY.name().into(),
80                Type::RECORD_ARRAY.oid(),
81                Kind::Array(field_into_pg_type(field)?),
82                Type::RECORD_ARRAY.schema().into(),
83            ),
84            list_type => {
85                return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
86                    "ERROR".to_owned(),
87                    "XX000".to_owned(),
88                    format!("Unsupported List Datatype {list_type}"),
89                ))));
90            }
91        },
92        DataType::Dictionary(_, value_type) => into_pg_type(value_type.as_ref())?,
93        DataType::Struct(fields) => {
94            let name: String = fields
95                .iter()
96                .map(|x| x.name().clone())
97                .reduce(|a, b| a + ", " + &b)
98                .map(|x| format!("({x})"))
99                .unwrap_or("()".to_string());
100            let kind = Kind::Composite(
101                fields
102                    .iter()
103                    .map(|x| {
104                        field_into_pg_type(x)
105                            .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
106                    })
107                    .collect::<Result<Vec<_>, PgWireError>>()?,
108            );
109            Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
110        }
111        _ => {
112            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
113                "ERROR".to_owned(),
114                "XX000".to_owned(),
115                format!("Unsupported Datatype {arrow_type}"),
116            ))));
117        }
118    };
119
120    Ok(datatype)
121}
122
123pub fn field_into_pg_type(field: &Arc<Field>) -> PgWireResult<Type> {
124    let arrow_type = field.data_type();
125
126    match field.extension_type_name() {
127        // As of arrow 56, there are additional extension logical type that is
128        // defined using field metadata, for instance, json or geo.
129        //
130        // TODO: there is no fixed Geometry/Geography type id, here we use text
131        // for placeholder.
132        #[cfg(feature = "postgis")]
133        Some(geoarrow_schema::PointType::NAME) => Ok(Type::TEXT),
134        #[cfg(feature = "postgis")]
135        Some(geoarrow_schema::LineStringType::NAME) => Ok(Type::TEXT),
136        #[cfg(feature = "postgis")]
137        Some(geoarrow_schema::PolygonType::NAME) => Ok(Type::TEXT),
138        #[cfg(feature = "postgis")]
139        Some(geoarrow_schema::MultiPointType::NAME) => Ok(Type::TEXT),
140        #[cfg(feature = "postgis")]
141        Some(geoarrow_schema::MultiLineStringType::NAME) => Ok(Type::TEXT),
142        #[cfg(feature = "postgis")]
143        Some(geoarrow_schema::MultiPolygonType::NAME) => Ok(Type::TEXT),
144        #[cfg(feature = "postgis")]
145        Some(geoarrow_schema::GeometryCollectionType::NAME) => Ok(Type::TEXT),
146        #[cfg(feature = "postgis")]
147        Some(geoarrow_schema::GeometryType::NAME) => Ok(Type::TEXT),
148        #[cfg(feature = "postgis")]
149        Some(geoarrow_schema::RectType::NAME) => Ok(Type::TEXT),
150        #[cfg(feature = "postgis")]
151        Some(geoarrow_schema::WktType::NAME) => Ok(Type::TEXT),
152        #[cfg(feature = "postgis")]
153        Some(geoarrow_schema::WkbType::NAME) => Ok(Type::TEXT),
154
155        _ => into_pg_type(arrow_type),
156    }
157}
158
159pub fn arrow_schema_to_pg_fields(
160    schema: &Schema,
161    format: &Format,
162    data_format_options: Option<Arc<FormatOptions>>,
163) -> PgWireResult<Vec<FieldInfo>> {
164    let _ = data_format_options;
165    schema
166        .fields()
167        .iter()
168        .enumerate()
169        .map(|(idx, f)| {
170            let pg_type = field_into_pg_type(f)?;
171            let mut field_info =
172                FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx));
173            if let Some(data_format_options) = &data_format_options {
174                field_info = field_info.with_format_options(data_format_options.clone());
175            }
176
177            Ok(field_info)
178        })
179        .collect::<PgWireResult<Vec<FieldInfo>>>()
180}
181
182pub fn encode_recordbatch(
183    fields: Arc<Vec<FieldInfo>>,
184    record_batch: RecordBatch,
185) -> Box<impl Iterator<Item = PgWireResult<DataRow>>> {
186    let mut row_stream = RowEncoder::new(record_batch, fields);
187    Box::new(std::iter::from_fn(move || row_stream.next_row()))
188}