1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{datatypes::*, record_batch::RecordBatch};
5#[cfg(feature = "postgis")]
6use arrow_schema::extension::ExtensionType;
7#[cfg(feature = "datafusion")]
8use datafusion::arrow::{datatypes::*, record_batch::RecordBatch};
9
10use pgwire::api::portal::Format;
11use pgwire::api::results::FieldInfo;
12use pgwire::api::Type;
13use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
14use pgwire::messages::data::DataRow;
15use pgwire::types::format::FormatOptions;
16use postgres_types::Kind;
17
18use crate::row_encoder::RowEncoder;
19
20#[cfg(feature = "datafusion")]
21pub mod df;
22
23pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
24 let datatype = match arrow_type {
25 DataType::Null => Type::UNKNOWN,
26 DataType::Boolean => Type::BOOL,
27 DataType::Int8 => Type::INT2,
28 DataType::Int16 | DataType::UInt8 => Type::INT2,
29 DataType::Int32 | DataType::UInt16 => Type::INT4,
30 DataType::Int64 | DataType::UInt32 => Type::INT8,
31 DataType::UInt64 => Type::NUMERIC,
32 DataType::Timestamp(_, tz) => {
33 if tz.is_some() {
34 Type::TIMESTAMPTZ
35 } else {
36 Type::TIMESTAMP
37 }
38 }
39 DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
40 DataType::Date32 | DataType::Date64 => Type::DATE,
41 DataType::Interval(_) | DataType::Duration(_) => Type::INTERVAL,
42 DataType::Binary
43 | DataType::FixedSizeBinary(_)
44 | DataType::LargeBinary
45 | DataType::BinaryView => Type::BYTEA,
46 DataType::Float16 | DataType::Float32 => Type::FLOAT4,
47 DataType::Float64 => Type::FLOAT8,
48 DataType::Decimal128(_, _) => Type::NUMERIC,
49 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
50 DataType::List(field)
51 | DataType::FixedSizeList(field, _)
52 | DataType::LargeList(field)
53 | DataType::ListView(field)
54 | DataType::LargeListView(field) => match field.data_type() {
55 DataType::Boolean => Type::BOOL_ARRAY,
56 DataType::Int8 => Type::INT2_ARRAY,
57 DataType::Int16 | DataType::UInt8 => Type::INT2_ARRAY,
58 DataType::Int32 | DataType::UInt16 => Type::INT4_ARRAY,
59 DataType::Int64 | DataType::UInt32 => Type::INT8_ARRAY,
60 DataType::UInt64 | DataType::Decimal128(_, _) => Type::NUMERIC_ARRAY,
61 DataType::Timestamp(_, tz) => {
62 if tz.is_some() {
63 Type::TIMESTAMPTZ_ARRAY
64 } else {
65 Type::TIMESTAMP_ARRAY
66 }
67 }
68 DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
69 DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
70 DataType::Interval(_) | DataType::Duration(_) => Type::INTERVAL_ARRAY,
71 DataType::FixedSizeBinary(_)
72 | DataType::Binary
73 | DataType::LargeBinary
74 | DataType::BinaryView => Type::BYTEA_ARRAY,
75 DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
76 DataType::Float64 => Type::FLOAT8_ARRAY,
77 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
78 DataType::Struct(_) => Type::new(
79 Type::RECORD_ARRAY.name().into(),
80 Type::RECORD_ARRAY.oid(),
81 Kind::Array(field_into_pg_type(field)?),
82 Type::RECORD_ARRAY.schema().into(),
83 ),
84 list_type => {
85 return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
86 "ERROR".to_owned(),
87 "XX000".to_owned(),
88 format!("Unsupported List Datatype {list_type}"),
89 ))));
90 }
91 },
92 DataType::Dictionary(_, value_type) => into_pg_type(value_type.as_ref())?,
93 DataType::Struct(fields) => {
94 let name: String = fields
95 .iter()
96 .map(|x| x.name().clone())
97 .reduce(|a, b| a + ", " + &b)
98 .map(|x| format!("({x})"))
99 .unwrap_or("()".to_string());
100 let kind = Kind::Composite(
101 fields
102 .iter()
103 .map(|x| {
104 field_into_pg_type(x)
105 .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
106 })
107 .collect::<Result<Vec<_>, PgWireError>>()?,
108 );
109 Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
110 }
111 _ => {
112 return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
113 "ERROR".to_owned(),
114 "XX000".to_owned(),
115 format!("Unsupported Datatype {arrow_type}"),
116 ))));
117 }
118 };
119
120 Ok(datatype)
121}
122
123pub fn field_into_pg_type(field: &Arc<Field>) -> PgWireResult<Type> {
124 let arrow_type = field.data_type();
125
126 match field.extension_type_name() {
127 #[cfg(feature = "postgis")]
133 Some(geoarrow_schema::PointType::NAME) => Ok(Type::TEXT),
134 #[cfg(feature = "postgis")]
135 Some(geoarrow_schema::LineStringType::NAME) => Ok(Type::TEXT),
136 #[cfg(feature = "postgis")]
137 Some(geoarrow_schema::PolygonType::NAME) => Ok(Type::TEXT),
138 #[cfg(feature = "postgis")]
139 Some(geoarrow_schema::MultiPointType::NAME) => Ok(Type::TEXT),
140 #[cfg(feature = "postgis")]
141 Some(geoarrow_schema::MultiLineStringType::NAME) => Ok(Type::TEXT),
142 #[cfg(feature = "postgis")]
143 Some(geoarrow_schema::MultiPolygonType::NAME) => Ok(Type::TEXT),
144 #[cfg(feature = "postgis")]
145 Some(geoarrow_schema::GeometryCollectionType::NAME) => Ok(Type::TEXT),
146 #[cfg(feature = "postgis")]
147 Some(geoarrow_schema::GeometryType::NAME) => Ok(Type::TEXT),
148 #[cfg(feature = "postgis")]
149 Some(geoarrow_schema::RectType::NAME) => Ok(Type::TEXT),
150 #[cfg(feature = "postgis")]
151 Some(geoarrow_schema::WktType::NAME) => Ok(Type::TEXT),
152 #[cfg(feature = "postgis")]
153 Some(geoarrow_schema::WkbType::NAME) => Ok(Type::TEXT),
154
155 _ => into_pg_type(arrow_type),
156 }
157}
158
159pub fn arrow_schema_to_pg_fields(
160 schema: &Schema,
161 format: &Format,
162 data_format_options: Option<Arc<FormatOptions>>,
163) -> PgWireResult<Vec<FieldInfo>> {
164 let _ = data_format_options;
165 schema
166 .fields()
167 .iter()
168 .enumerate()
169 .map(|(idx, f)| {
170 let pg_type = field_into_pg_type(f)?;
171 let mut field_info =
172 FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx));
173 if let Some(data_format_options) = &data_format_options {
174 field_info = field_info.with_format_options(data_format_options.clone());
175 }
176
177 Ok(field_info)
178 })
179 .collect::<PgWireResult<Vec<FieldInfo>>>()
180}
181
182pub fn encode_recordbatch(
183 fields: Arc<Vec<FieldInfo>>,
184 record_batch: RecordBatch,
185) -> Box<impl Iterator<Item = PgWireResult<DataRow>>> {
186 let mut row_stream = RowEncoder::new(record_batch, fields);
187 Box::new(std::iter::from_fn(move || row_stream.next_row()))
188}