arrow_pg/
list_encoder.rs

1use std::{str::FromStr, sync::Arc};
2
3use arrow::array::{
4    timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
5    LargeBinaryArray, PrimitiveArray, StringArray, Time32MillisecondArray, Time32SecondArray,
6    Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
7    TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
8};
9use arrow::{
10    datatypes::{
11        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
12        Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
13        Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
14    },
15    temporal_conversions::{as_date, as_time},
16};
17use bytes::{BufMut, BytesMut};
18use chrono::{DateTime, TimeZone, Utc};
19use pgwire::api::results::FieldFormat;
20use pgwire::error::{PgWireError, PgWireResult};
21use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
22use postgres_types::{ToSql, Type};
23use rust_decimal::Decimal;
24
25use crate::encoder::EncodedValue;
26use crate::error::ToSqlError;
27use crate::struct_encoder::encode_struct;
28
29fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
30    arr.as_any()
31        .downcast_ref::<BooleanArray>()
32        .unwrap()
33        .iter()
34        .collect()
35}
36
// Generates `fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>>`, which downcasts
// `arr` to `PrimitiveArray<$t>` and collects its elements.
//
// The generated function panics if `arr` is not a `PrimitiveArray<$t>`.
macro_rules! get_primitive_list_value {
    // Three-argument form: elements are collected as-is, so `$pt` must be the
    // native type yielded by `PrimitiveArray<$t>`.
    ($name:ident, $t:ty, $pt:ty) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let primitives = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            primitives.iter().collect()
        }
    };

    // Four-argument form: every non-null element is passed through `$f` to
    // produce a `$pt`.
    ($name:ident, $t:ty, $pt:ty, $f:expr) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let primitives = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            primitives.iter().map(|item| item.map($f)).collect()
        }
    };
}
59
60get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
61get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
62get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
63get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
64get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 });
65get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| {
66    val as i16
67});
68get_primitive_list_value!(get_u32_list_value, UInt32Type, u32);
69get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| {
70    val as i64
71});
72get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
73get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
74
75fn encode_field<T: ToSql + ToSqlText>(
76    t: &[T],
77    type_: &Type,
78    format: FieldFormat,
79) -> PgWireResult<EncodedValue> {
80    let mut bytes = BytesMut::new();
81    match format {
82        FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
83        FieldFormat::Binary => t.to_sql(type_, &mut bytes)?,
84    };
85    Ok(EncodedValue { bytes })
86}
87
88pub(crate) fn encode_list(
89    arr: Arc<dyn Array>,
90    type_: &Type,
91    format: FieldFormat,
92) -> PgWireResult<EncodedValue> {
93    match arr.data_type() {
94        DataType::Null => {
95            let mut bytes = BytesMut::new();
96            match format {
97                FieldFormat::Text => None::<i8>.to_sql_text(type_, &mut bytes),
98                FieldFormat::Binary => None::<i8>.to_sql(type_, &mut bytes),
99            }?;
100            Ok(EncodedValue { bytes })
101        }
102        DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format),
103        DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format),
104        DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format),
105        DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format),
106        DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format),
107        DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format),
108        DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format),
109        DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format),
110        DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format),
111        DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format),
112        DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format),
113        DataType::Decimal128(_, s) => {
114            let value: Vec<_> = arr
115                .as_any()
116                .downcast_ref::<Decimal128Array>()
117                .unwrap()
118                .iter()
119                .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
120                .collect();
121            encode_field(&value, type_, format)
122        }
123        DataType::Utf8 => {
124            let value: Vec<Option<&str>> = arr
125                .as_any()
126                .downcast_ref::<StringArray>()
127                .unwrap()
128                .iter()
129                .collect();
130            encode_field(&value, type_, format)
131        }
132        DataType::Binary => {
133            let value: Vec<Option<_>> = arr
134                .as_any()
135                .downcast_ref::<BinaryArray>()
136                .unwrap()
137                .iter()
138                .collect();
139            encode_field(&value, type_, format)
140        }
141        DataType::LargeBinary => {
142            let value: Vec<Option<_>> = arr
143                .as_any()
144                .downcast_ref::<LargeBinaryArray>()
145                .unwrap()
146                .iter()
147                .collect();
148            encode_field(&value, type_, format)
149        }
150
151        DataType::Date32 => {
152            let value: Vec<Option<_>> = arr
153                .as_any()
154                .downcast_ref::<Date32Array>()
155                .unwrap()
156                .iter()
157                .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
158                .collect();
159            encode_field(&value, type_, format)
160        }
161        DataType::Date64 => {
162            let value: Vec<Option<_>> = arr
163                .as_any()
164                .downcast_ref::<Date64Array>()
165                .unwrap()
166                .iter()
167                .map(|val| val.and_then(as_date::<Date64Type>))
168                .collect();
169            encode_field(&value, type_, format)
170        }
171        DataType::Time32(unit) => match unit {
172            TimeUnit::Second => {
173                let value: Vec<Option<_>> = arr
174                    .as_any()
175                    .downcast_ref::<Time32SecondArray>()
176                    .unwrap()
177                    .iter()
178                    .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
179                    .collect();
180                encode_field(&value, type_, format)
181            }
182            TimeUnit::Millisecond => {
183                let value: Vec<Option<_>> = arr
184                    .as_any()
185                    .downcast_ref::<Time32MillisecondArray>()
186                    .unwrap()
187                    .iter()
188                    .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
189                    .collect();
190                encode_field(&value, type_, format)
191            }
192            _ => {
193                unimplemented!()
194            }
195        },
196        DataType::Time64(unit) => match unit {
197            TimeUnit::Microsecond => {
198                let value: Vec<Option<_>> = arr
199                    .as_any()
200                    .downcast_ref::<Time64MicrosecondArray>()
201                    .unwrap()
202                    .iter()
203                    .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
204                    .collect();
205                encode_field(&value, type_, format)
206            }
207            TimeUnit::Nanosecond => {
208                let value: Vec<Option<_>> = arr
209                    .as_any()
210                    .downcast_ref::<Time64NanosecondArray>()
211                    .unwrap()
212                    .iter()
213                    .map(|val| val.and_then(as_time::<Time64NanosecondType>))
214                    .collect();
215                encode_field(&value, type_, format)
216            }
217            _ => {
218                unimplemented!()
219            }
220        },
221        DataType::Timestamp(unit, timezone) => match unit {
222            TimeUnit::Second => {
223                let array_iter = arr
224                    .as_any()
225                    .downcast_ref::<TimestampSecondArray>()
226                    .unwrap()
227                    .iter();
228
229                if let Some(tz) = timezone {
230                    let tz = Tz::from_str(tz.as_ref())
231                        .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
232                    let value: Vec<_> = array_iter
233                        .map(|i| {
234                            i.and_then(|i| {
235                                DateTime::from_timestamp(i, 0).map(|dt| {
236                                    Utc.from_utc_datetime(&dt.naive_utc())
237                                        .with_timezone(&tz)
238                                        .fixed_offset()
239                                })
240                            })
241                        })
242                        .collect();
243                    encode_field(&value, type_, format)
244                } else {
245                    let value: Vec<_> = array_iter
246                        .map(|i| {
247                            i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
248                        })
249                        .collect();
250                    encode_field(&value, type_, format)
251                }
252            }
253            TimeUnit::Millisecond => {
254                let array_iter = arr
255                    .as_any()
256                    .downcast_ref::<TimestampMillisecondArray>()
257                    .unwrap()
258                    .iter();
259
260                if let Some(tz) = timezone {
261                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
262                    let value: Vec<_> = array_iter
263                        .map(|i| {
264                            i.and_then(|i| {
265                                DateTime::from_timestamp_millis(i).map(|dt| {
266                                    Utc.from_utc_datetime(&dt.naive_utc())
267                                        .with_timezone(&tz)
268                                        .fixed_offset()
269                                })
270                            })
271                        })
272                        .collect();
273                    encode_field(&value, type_, format)
274                } else {
275                    let value: Vec<_> = array_iter
276                        .map(|i| {
277                            i.and_then(|i| {
278                                DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
279                            })
280                        })
281                        .collect();
282                    encode_field(&value, type_, format)
283                }
284            }
285            TimeUnit::Microsecond => {
286                let array_iter = arr
287                    .as_any()
288                    .downcast_ref::<TimestampMicrosecondArray>()
289                    .unwrap()
290                    .iter();
291
292                if let Some(tz) = timezone {
293                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
294                    let value: Vec<_> = array_iter
295                        .map(|i| {
296                            i.and_then(|i| {
297                                DateTime::from_timestamp_micros(i).map(|dt| {
298                                    Utc.from_utc_datetime(&dt.naive_utc())
299                                        .with_timezone(&tz)
300                                        .fixed_offset()
301                                })
302                            })
303                        })
304                        .collect();
305                    encode_field(&value, type_, format)
306                } else {
307                    let value: Vec<_> = array_iter
308                        .map(|i| {
309                            i.and_then(|i| {
310                                DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
311                            })
312                        })
313                        .collect();
314                    encode_field(&value, type_, format)
315                }
316            }
317            TimeUnit::Nanosecond => {
318                let array_iter = arr
319                    .as_any()
320                    .downcast_ref::<TimestampNanosecondArray>()
321                    .unwrap()
322                    .iter();
323
324                if let Some(tz) = timezone {
325                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
326                    let value: Vec<_> = array_iter
327                        .map(|i| {
328                            i.map(|i| {
329                                Utc.from_utc_datetime(
330                                    &DateTime::from_timestamp_nanos(i).naive_utc(),
331                                )
332                                .with_timezone(&tz)
333                                .fixed_offset()
334                            })
335                        })
336                        .collect();
337                    encode_field(&value, type_, format)
338                } else {
339                    let value: Vec<_> = array_iter
340                        .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
341                        .collect();
342                    encode_field(&value, type_, format)
343                }
344            }
345        },
346        DataType::Struct(_) => {
347            let fields = match type_.kind() {
348                postgres_types::Kind::Array(struct_type_) => Ok(struct_type_),
349                _ => Err(format!(
350                    "Expected list type found type {} of kind {:?}",
351                    type_,
352                    type_.kind()
353                )),
354            }
355            .and_then(|struct_type| match struct_type.kind() {
356                postgres_types::Kind::Composite(fields) => Ok(fields),
357                _ => Err(format!(
358                    "Failed to unwrap a composite type inside from type {} kind {:?}",
359                    type_,
360                    type_.kind()
361                )),
362            })
363            .map_err(ToSqlError::from)?;
364
365            let values: PgWireResult<Vec<_>> = (0..arr.len())
366                .map(|row| encode_struct(&arr, row, fields, format))
367                .map(|x| {
368                    if matches!(format, FieldFormat::Text) {
369                        x.map(|opt| {
370                            opt.map(|value| {
371                                let mut w = BytesMut::new();
372                                w.put_u8(b'"');
373                                w.put_slice(
374                                    QUOTE_ESCAPE
375                                        .replace_all(
376                                            &String::from_utf8_lossy(&value.bytes),
377                                            r#"\$1"#,
378                                        )
379                                        .as_bytes(),
380                                );
381                                w.put_u8(b'"');
382                                EncodedValue { bytes: w }
383                            })
384                        })
385                    } else {
386                        x
387                    }
388                })
389                .collect();
390            encode_field(&values?, type_, format)
391        }
392        // TODO: more types
393        list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
394            "Unsupported List Datatype {} and array {:?}",
395            list_type, &arr
396        )))),
397    }
398}