1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{datatypes::*, record_batch::RecordBatch};
5#[cfg(feature = "datafusion")]
6use datafusion::arrow::{datatypes::*, record_batch::RecordBatch};
7
8use pgwire::api::portal::Format;
9use pgwire::api::results::FieldInfo;
10use pgwire::api::Type;
11use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
12use pgwire::messages::data::DataRow;
13use postgres_types::Kind;
14
15use crate::row_encoder::RowEncoder;
16
17#[cfg(feature = "datafusion")]
18pub mod df;
19
20pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
21 Ok(match arrow_type {
22 DataType::Null => Type::UNKNOWN,
23 DataType::Boolean => Type::BOOL,
24 DataType::Int8 | DataType::UInt8 => Type::CHAR,
25 DataType::Int16 | DataType::UInt16 => Type::INT2,
26 DataType::Int32 | DataType::UInt32 => Type::INT4,
27 DataType::Int64 | DataType::UInt64 => Type::INT8,
28 DataType::Timestamp(_, tz) => {
29 if tz.is_some() {
30 Type::TIMESTAMPTZ
31 } else {
32 Type::TIMESTAMP
33 }
34 }
35 DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
36 DataType::Date32 | DataType::Date64 => Type::DATE,
37 DataType::Interval(_) => Type::INTERVAL,
38 DataType::Binary
39 | DataType::FixedSizeBinary(_)
40 | DataType::LargeBinary
41 | DataType::BinaryView => Type::BYTEA,
42 DataType::Float16 | DataType::Float32 => Type::FLOAT4,
43 DataType::Float64 => Type::FLOAT8,
44 DataType::Decimal128(_, _) => Type::NUMERIC,
45 DataType::Utf8 => Type::VARCHAR,
46 DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
47 DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
48 match field.data_type() {
49 DataType::Boolean => Type::BOOL_ARRAY,
50 DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
51 DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
52 DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
53 DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
54 DataType::Timestamp(_, tz) => {
55 if tz.is_some() {
56 Type::TIMESTAMPTZ_ARRAY
57 } else {
58 Type::TIMESTAMP_ARRAY
59 }
60 }
61 DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
62 DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
63 DataType::Interval(_) => Type::INTERVAL_ARRAY,
64 DataType::FixedSizeBinary(_)
65 | DataType::Binary
66 | DataType::LargeBinary
67 | DataType::BinaryView => Type::BYTEA_ARRAY,
68 DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
69 DataType::Float64 => Type::FLOAT8_ARRAY,
70 DataType::Utf8 => Type::VARCHAR_ARRAY,
71 DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
72 struct_type @ DataType::Struct(_) => Type::new(
73 Type::RECORD_ARRAY.name().into(),
74 Type::RECORD_ARRAY.oid(),
75 Kind::Array(into_pg_type(struct_type)?),
76 Type::RECORD_ARRAY.schema().into(),
77 ),
78 list_type => {
79 return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
80 "ERROR".to_owned(),
81 "XX000".to_owned(),
82 format!("Unsupported List Datatype {list_type}"),
83 ))));
84 }
85 }
86 }
87 DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
88 DataType::Struct(fields) => {
89 let name: String = fields
90 .iter()
91 .map(|x| x.name().clone())
92 .reduce(|a, b| a + ", " + &b)
93 .map(|x| format!("({x})"))
94 .unwrap_or("()".to_string());
95 let kind = Kind::Composite(
96 fields
97 .iter()
98 .map(|x| {
99 into_pg_type(x.data_type())
100 .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
101 })
102 .collect::<Result<Vec<_>, PgWireError>>()?,
103 );
104 Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
105 }
106 _ => {
107 return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
108 "ERROR".to_owned(),
109 "XX000".to_owned(),
110 format!("Unsupported Datatype {arrow_type}"),
111 ))));
112 }
113 })
114}
115
116pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
117 schema
118 .fields()
119 .iter()
120 .enumerate()
121 .map(|(idx, f)| {
122 let pg_type = into_pg_type(f.data_type())?;
123 Ok(FieldInfo::new(
124 f.name().into(),
125 None,
126 None,
127 pg_type,
128 format.format_for(idx),
129 ))
130 })
131 .collect::<PgWireResult<Vec<FieldInfo>>>()
132}
133
134pub fn encode_recordbatch(
135 fields: Arc<Vec<FieldInfo>>,
136 record_batch: RecordBatch,
137) -> Box<impl Iterator<Item = PgWireResult<DataRow>>> {
138 let mut row_stream = RowEncoder::new(record_batch, fields);
139 Box::new(std::iter::from_fn(move || row_stream.next_row()))
140}