arrow_pg/
list_encoder.rs

1use std::{str::FromStr, sync::Arc};
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{
5    array::{
6        timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
7        Decimal128Array, Decimal256Array, DurationMicrosecondArray, LargeBinaryArray,
8        LargeListArray, LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray,
9        StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
10        Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
11        TimestampNanosecondArray, TimestampSecondArray,
12    },
13    datatypes::{
14        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
15        Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
16        Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
17    },
18    temporal_conversions::{as_date, as_time},
19};
20#[cfg(feature = "datafusion")]
21use datafusion::arrow::{
22    array::{
23        timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
24        Decimal128Array, Decimal256Array, DurationMicrosecondArray, LargeBinaryArray,
25        LargeListArray, LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray,
26        StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
27        Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
28        TimestampNanosecondArray, TimestampSecondArray,
29    },
30    datatypes::{
31        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
32        Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
33        Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
34    },
35    temporal_conversions::{as_date, as_time},
36};
37
38use bytes::{BufMut, BytesMut};
39use chrono::{DateTime, TimeZone, Utc};
40use pgwire::api::results::FieldFormat;
41use pgwire::error::{PgWireError, PgWireResult};
42use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
43use postgres_types::{ToSql, Type};
44use rust_decimal::Decimal;
45
46use crate::encoder::EncodedValue;
47use crate::error::ToSqlError;
48use crate::struct_encoder::encode_struct;
49
50fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
51    arr.as_any()
52        .downcast_ref::<BooleanArray>()
53        .unwrap()
54        .iter()
55        .collect()
56}
57
/// Generates a function that collects a primitive Arrow array into a
/// `Vec<Option<_>>`.
///
/// * `(name, ArrowType, Out)` — the Arrow native type is emitted unchanged, so
///   `Out` must be the native type of `ArrowType`.
/// * `(name, ArrowType, Out, convert)` — every non-null element is passed
///   through `convert` before being collected.
///
/// The generated function panics if the array is not a
/// `PrimitiveArray<ArrowType>`.
macro_rules! get_primitive_list_value {
    // No conversion requested: reuse the converting arm with an identity closure.
    ($fn_name:ident, $arrow_ty:ty, $out_ty:ty) => {
        get_primitive_list_value!($fn_name, $arrow_ty, $out_ty, |native: $out_ty| native);
    };

    ($fn_name:ident, $arrow_ty:ty, $out_ty:ty, $convert:expr) => {
        fn $fn_name(arr: &Arc<dyn Array>) -> Vec<Option<$out_ty>> {
            let primitives = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$arrow_ty>>()
                .unwrap();
            primitives.iter().map(|slot| slot.map($convert)).collect()
        }
    };
}
80
81get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
82get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
83get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
84get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
85get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 });
86get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| {
87    val as i16
88});
89get_primitive_list_value!(get_u32_list_value, UInt32Type, u32);
90get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| {
91    val as i64
92});
93get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
94get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
95
96fn encode_field<T: ToSql + ToSqlText>(
97    t: &[T],
98    type_: &Type,
99    format: FieldFormat,
100) -> PgWireResult<EncodedValue> {
101    let mut bytes = BytesMut::new();
102    match format {
103        FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
104        FieldFormat::Binary => t.to_sql(type_, &mut bytes)?,
105    };
106    Ok(EncodedValue { bytes })
107}
108
109pub(crate) fn encode_list(
110    arr: Arc<dyn Array>,
111    type_: &Type,
112    format: FieldFormat,
113) -> PgWireResult<EncodedValue> {
114    match arr.data_type() {
115        DataType::Null => {
116            let mut bytes = BytesMut::new();
117            match format {
118                FieldFormat::Text => None::<i8>.to_sql_text(type_, &mut bytes),
119                FieldFormat::Binary => None::<i8>.to_sql(type_, &mut bytes),
120            }?;
121            Ok(EncodedValue { bytes })
122        }
123        DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format),
124        DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format),
125        DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format),
126        DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format),
127        DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format),
128        DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format),
129        DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format),
130        DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format),
131        DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format),
132        DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format),
133        DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format),
134        DataType::Decimal128(_, s) => {
135            let value: Vec<_> = arr
136                .as_any()
137                .downcast_ref::<Decimal128Array>()
138                .unwrap()
139                .iter()
140                .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
141                .collect();
142            encode_field(&value, type_, format)
143        }
144        DataType::Utf8 => {
145            let value: Vec<Option<&str>> = arr
146                .as_any()
147                .downcast_ref::<StringArray>()
148                .unwrap()
149                .iter()
150                .collect();
151            encode_field(&value, type_, format)
152        }
153        DataType::Utf8View => {
154            let value: Vec<Option<&str>> = arr
155                .as_any()
156                .downcast_ref::<StringViewArray>()
157                .unwrap()
158                .iter()
159                .collect();
160            encode_field(&value, type_, format)
161        }
162        DataType::Binary => {
163            let value: Vec<Option<_>> = arr
164                .as_any()
165                .downcast_ref::<BinaryArray>()
166                .unwrap()
167                .iter()
168                .collect();
169            encode_field(&value, type_, format)
170        }
171        DataType::LargeBinary => {
172            let value: Vec<Option<_>> = arr
173                .as_any()
174                .downcast_ref::<LargeBinaryArray>()
175                .unwrap()
176                .iter()
177                .collect();
178            encode_field(&value, type_, format)
179        }
180        DataType::BinaryView => {
181            let value: Vec<Option<_>> = arr
182                .as_any()
183                .downcast_ref::<BinaryViewArray>()
184                .unwrap()
185                .iter()
186                .collect();
187            encode_field(&value, type_, format)
188        }
189
190        DataType::Date32 => {
191            let value: Vec<Option<_>> = arr
192                .as_any()
193                .downcast_ref::<Date32Array>()
194                .unwrap()
195                .iter()
196                .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
197                .collect();
198            encode_field(&value, type_, format)
199        }
200        DataType::Date64 => {
201            let value: Vec<Option<_>> = arr
202                .as_any()
203                .downcast_ref::<Date64Array>()
204                .unwrap()
205                .iter()
206                .map(|val| val.and_then(as_date::<Date64Type>))
207                .collect();
208            encode_field(&value, type_, format)
209        }
210        DataType::Time32(unit) => match unit {
211            TimeUnit::Second => {
212                let value: Vec<Option<_>> = arr
213                    .as_any()
214                    .downcast_ref::<Time32SecondArray>()
215                    .unwrap()
216                    .iter()
217                    .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
218                    .collect();
219                encode_field(&value, type_, format)
220            }
221            TimeUnit::Millisecond => {
222                let value: Vec<Option<_>> = arr
223                    .as_any()
224                    .downcast_ref::<Time32MillisecondArray>()
225                    .unwrap()
226                    .iter()
227                    .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
228                    .collect();
229                encode_field(&value, type_, format)
230            }
231            _ => {
232                // Time32 only supports Second and Millisecond in Arrow
233                // Other units are not available, so return an error
234                Err(PgWireError::ApiError("Unsupported Time32 unit".into()))
235            }
236        },
237        DataType::Time64(unit) => match unit {
238            TimeUnit::Microsecond => {
239                let value: Vec<Option<_>> = arr
240                    .as_any()
241                    .downcast_ref::<Time64MicrosecondArray>()
242                    .unwrap()
243                    .iter()
244                    .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
245                    .collect();
246                encode_field(&value, type_, format)
247            }
248            TimeUnit::Nanosecond => {
249                let value: Vec<Option<_>> = arr
250                    .as_any()
251                    .downcast_ref::<Time64NanosecondArray>()
252                    .unwrap()
253                    .iter()
254                    .map(|val| val.and_then(as_time::<Time64NanosecondType>))
255                    .collect();
256                encode_field(&value, type_, format)
257            }
258            _ => {
259                // Time64 only supports Microsecond and Nanosecond in Arrow
260                // Other units are not available, so return an error
261                Err(PgWireError::ApiError("Unsupported Time64 unit".into()))
262            }
263        },
264        DataType::Timestamp(unit, timezone) => match unit {
265            TimeUnit::Second => {
266                let array_iter = arr
267                    .as_any()
268                    .downcast_ref::<TimestampSecondArray>()
269                    .unwrap()
270                    .iter();
271
272                if let Some(tz) = timezone {
273                    let tz = Tz::from_str(tz.as_ref())
274                        .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
275                    let value: Vec<_> = array_iter
276                        .map(|i| {
277                            i.and_then(|i| {
278                                DateTime::from_timestamp(i, 0).map(|dt| {
279                                    Utc.from_utc_datetime(&dt.naive_utc())
280                                        .with_timezone(&tz)
281                                        .fixed_offset()
282                                })
283                            })
284                        })
285                        .collect();
286                    encode_field(&value, type_, format)
287                } else {
288                    let value: Vec<_> = array_iter
289                        .map(|i| {
290                            i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
291                        })
292                        .collect();
293                    encode_field(&value, type_, format)
294                }
295            }
296            TimeUnit::Millisecond => {
297                let array_iter = arr
298                    .as_any()
299                    .downcast_ref::<TimestampMillisecondArray>()
300                    .unwrap()
301                    .iter();
302
303                if let Some(tz) = timezone {
304                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
305                    let value: Vec<_> = array_iter
306                        .map(|i| {
307                            i.and_then(|i| {
308                                DateTime::from_timestamp_millis(i).map(|dt| {
309                                    Utc.from_utc_datetime(&dt.naive_utc())
310                                        .with_timezone(&tz)
311                                        .fixed_offset()
312                                })
313                            })
314                        })
315                        .collect();
316                    encode_field(&value, type_, format)
317                } else {
318                    let value: Vec<_> = array_iter
319                        .map(|i| {
320                            i.and_then(|i| {
321                                DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
322                            })
323                        })
324                        .collect();
325                    encode_field(&value, type_, format)
326                }
327            }
328            TimeUnit::Microsecond => {
329                let array_iter = arr
330                    .as_any()
331                    .downcast_ref::<TimestampMicrosecondArray>()
332                    .unwrap()
333                    .iter();
334
335                if let Some(tz) = timezone {
336                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
337                    let value: Vec<_> = array_iter
338                        .map(|i| {
339                            i.and_then(|i| {
340                                DateTime::from_timestamp_micros(i).map(|dt| {
341                                    Utc.from_utc_datetime(&dt.naive_utc())
342                                        .with_timezone(&tz)
343                                        .fixed_offset()
344                                })
345                            })
346                        })
347                        .collect();
348                    encode_field(&value, type_, format)
349                } else {
350                    let value: Vec<_> = array_iter
351                        .map(|i| {
352                            i.and_then(|i| {
353                                DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
354                            })
355                        })
356                        .collect();
357                    encode_field(&value, type_, format)
358                }
359            }
360            TimeUnit::Nanosecond => {
361                let array_iter = arr
362                    .as_any()
363                    .downcast_ref::<TimestampNanosecondArray>()
364                    .unwrap()
365                    .iter();
366
367                if let Some(tz) = timezone {
368                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
369                    let value: Vec<_> = array_iter
370                        .map(|i| {
371                            i.map(|i| {
372                                Utc.from_utc_datetime(
373                                    &DateTime::from_timestamp_nanos(i).naive_utc(),
374                                )
375                                .with_timezone(&tz)
376                                .fixed_offset()
377                            })
378                        })
379                        .collect();
380                    encode_field(&value, type_, format)
381                } else {
382                    let value: Vec<_> = array_iter
383                        .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
384                        .collect();
385                    encode_field(&value, type_, format)
386                }
387            }
388        },
389        DataType::Struct(_) => {
390            let fields = match type_.kind() {
391                postgres_types::Kind::Array(struct_type_) => Ok(struct_type_),
392                _ => Err(format!(
393                    "Expected list type found type {} of kind {:?}",
394                    type_,
395                    type_.kind()
396                )),
397            }
398            .and_then(|struct_type| match struct_type.kind() {
399                postgres_types::Kind::Composite(fields) => Ok(fields),
400                _ => Err(format!(
401                    "Failed to unwrap a composite type inside from type {} kind {:?}",
402                    type_,
403                    type_.kind()
404                )),
405            })
406            .map_err(ToSqlError::from)?;
407
408            let values: PgWireResult<Vec<_>> = (0..arr.len())
409                .map(|row| encode_struct(&arr, row, fields, format))
410                .map(|x| {
411                    if matches!(format, FieldFormat::Text) {
412                        x.map(|opt| {
413                            opt.map(|value| {
414                                let mut w = BytesMut::new();
415                                w.put_u8(b'"');
416                                w.put_slice(
417                                    QUOTE_ESCAPE
418                                        .replace_all(
419                                            &String::from_utf8_lossy(&value.bytes),
420                                            r#"\$1"#,
421                                        )
422                                        .as_bytes(),
423                                );
424                                w.put_u8(b'"');
425                                EncodedValue { bytes: w }
426                            })
427                        })
428                    } else {
429                        x
430                    }
431                })
432                .collect();
433            encode_field(&values?, type_, format)
434        }
435        DataType::LargeUtf8 => {
436            let value: Vec<Option<&str>> = arr
437                .as_any()
438                .downcast_ref::<LargeStringArray>()
439                .unwrap()
440                .iter()
441                .collect();
442            encode_field(&value, type_, format)
443        }
444        DataType::Decimal256(_, s) => {
445            // Convert Decimal256 to string representation for now
446            // since rust_decimal doesn't support 256-bit decimals
447            let decimal_array = arr.as_any().downcast_ref::<Decimal256Array>().unwrap();
448            let value: Vec<Option<String>> = (0..decimal_array.len())
449                .map(|i| {
450                    if decimal_array.is_null(i) {
451                        None
452                    } else {
453                        // Convert to string representation
454                        let raw_value = decimal_array.value(i);
455                        let scale = *s as u32;
456                        // Convert i256 to string and handle decimal placement manually
457                        let value_str = raw_value.to_string();
458                        if scale == 0 {
459                            Some(value_str)
460                        } else {
461                            // Insert decimal point
462                            let mut chars: Vec<char> = value_str.chars().collect();
463                            if chars.len() <= scale as usize {
464                                // Prepend zeros if needed
465                                let zeros_needed = scale as usize - chars.len() + 1;
466                                chars.splice(0..0, std::iter::repeat_n('0', zeros_needed));
467                                chars.insert(1, '.');
468                            } else {
469                                let decimal_pos = chars.len() - scale as usize;
470                                chars.insert(decimal_pos, '.');
471                            }
472                            Some(chars.into_iter().collect())
473                        }
474                    }
475                })
476                .collect();
477            encode_field(&value, type_, format)
478        }
479        DataType::Duration(_) => {
480            // Convert duration to microseconds for now
481            let value: Vec<Option<i64>> = arr
482                .as_any()
483                .downcast_ref::<DurationMicrosecondArray>()
484                .unwrap()
485                .iter()
486                .collect();
487            encode_field(&value, type_, format)
488        }
489        DataType::List(_) => {
490            // Support for nested lists (list of lists)
491            // For now, convert to string representation
492            let list_array = arr.as_any().downcast_ref::<ListArray>().unwrap();
493            let value: Vec<Option<String>> = (0..list_array.len())
494                .map(|i| {
495                    if list_array.is_null(i) {
496                        None
497                    } else {
498                        // Convert nested list to string representation
499                        Some(format!("[nested_list_{i}]"))
500                    }
501                })
502                .collect();
503            encode_field(&value, type_, format)
504        }
505        DataType::LargeList(_) => {
506            // Support for large lists
507            let list_array = arr.as_any().downcast_ref::<LargeListArray>().unwrap();
508            let value: Vec<Option<String>> = (0..list_array.len())
509                .map(|i| {
510                    if list_array.is_null(i) {
511                        None
512                    } else {
513                        Some(format!("[large_list_{i}]"))
514                    }
515                })
516                .collect();
517            encode_field(&value, type_, format)
518        }
519        DataType::Map(_, _) => {
520            // Support for map types
521            let map_array = arr.as_any().downcast_ref::<MapArray>().unwrap();
522            let value: Vec<Option<String>> = (0..map_array.len())
523                .map(|i| {
524                    if map_array.is_null(i) {
525                        None
526                    } else {
527                        Some(format!("{{map_{i}}}"))
528                    }
529                })
530                .collect();
531            encode_field(&value, type_, format)
532        }
533
534        DataType::Union(_, _) => {
535            // Support for union types
536            let value: Vec<Option<String>> = (0..arr.len())
537                .map(|i| {
538                    if arr.is_null(i) {
539                        None
540                    } else {
541                        Some(format!("union_{i}"))
542                    }
543                })
544                .collect();
545            encode_field(&value, type_, format)
546        }
547        DataType::Dictionary(_, _) => {
548            // Support for dictionary types
549            let value: Vec<Option<String>> = (0..arr.len())
550                .map(|i| {
551                    if arr.is_null(i) {
552                        None
553                    } else {
554                        Some(format!("dict_{i}"))
555                    }
556                })
557                .collect();
558            encode_field(&value, type_, format)
559        }
560        // TODO: add support for more advanced types (fixed size lists, etc.)
561        list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
562            "Unsupported List Datatype {} and array {:?}",
563            list_type, &arr
564        )))),
565    }
566}