// arrow_pg/list_encoder.rs

use std::{str::FromStr, sync::Arc};

#[cfg(not(feature = "datafusion"))]
use arrow::{
    array::{
        timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
        Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray,
        DurationNanosecondArray, DurationSecondArray, IntervalDayTimeArray,
        IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray,
        LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, StringViewArray,
        Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
        TimestampSecondArray,
    },
    datatypes::{
        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
        Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
        Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
        TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
    },
    temporal_conversions::{as_date, as_time},
    util::display::array_value_to_string,
};
#[cfg(feature = "datafusion")]
use datafusion::arrow::{
    array::{
        timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
        Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray,
        DurationNanosecondArray, DurationSecondArray, IntervalDayTimeArray,
        IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray,
        LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, StringViewArray,
        Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
        TimestampSecondArray,
    },
    datatypes::{
        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
        Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
        Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
        TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
    },
    temporal_conversions::{as_date, as_time},
    util::display::array_value_to_string,
};

use chrono::{DateTime, TimeZone, Utc};
use pg_interval::Interval as PgInterval;
use pgwire::api::results::FieldInfo;
use pgwire::error::{PgWireError, PgWireResult};
use rust_decimal::Decimal;

use crate::encoder::Encoder;
use crate::error::ToSqlError;
use crate::struct_encoder::encode_structs;

52fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
53    arr.as_any()
54        .downcast_ref::<BooleanArray>()
55        .unwrap()
56        .iter()
57        .collect()
58}
59
/// Defines `fn $fn_name(arr: &Arc<dyn Array>) -> Vec<Option<$out>>`, which
/// downcasts `arr` to `PrimitiveArray<$arrow_ty>` and collects one entry per
/// slot, keeping null slots as `None`.
///
/// The three-argument form copies native values unchanged. The four-argument
/// form applies `$convert` to every non-null value; in this file that is used
/// to widen unsigned integers into signed/decimal types.
///
/// The generated function panics if `arr` is not a `PrimitiveArray<$arrow_ty>`.
macro_rules! get_primitive_list_value {
    ($fn_name:ident, $arrow_ty:ty, $out:ty) => {
        fn $fn_name(arr: &Arc<dyn Array>) -> Vec<Option<$out>> {
            let typed = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$arrow_ty>>()
                .unwrap();
            typed.iter().collect()
        }
    };

    ($fn_name:ident, $arrow_ty:ty, $out:ty, $convert:expr) => {
        fn $fn_name(arr: &Arc<dyn Array>) -> Vec<Option<$out>> {
            let typed = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$arrow_ty>>()
                .unwrap();
            typed.iter().map(|slot| slot.map($convert)).collect()
        }
    };
}

83get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
84get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
85get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
86get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
87get_primitive_list_value!(get_u8_list_value, UInt8Type, i16, |val: u8| { val as i16 });
88get_primitive_list_value!(get_u16_list_value, UInt16Type, i32, |val: u16| {
89    val as i32
90});
91get_primitive_list_value!(get_u32_list_value, UInt32Type, i64, |val: u32| {
92    val as i64
93});
94get_primitive_list_value!(get_u64_list_value, UInt64Type, Decimal, |val: u64| {
95    Decimal::from(val)
96});
97get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
98get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
99
100pub fn encode_list<T: Encoder>(
101    encoder: &mut T,
102    arr: Arc<dyn Array>,
103    pg_field: &FieldInfo,
104) -> PgWireResult<()> {
105    match arr.data_type() {
106        DataType::Null => {
107            encoder.encode_field(&None::<i8>, pg_field)?;
108            Ok(())
109        }
110        DataType::Boolean => {
111            encoder.encode_field(&get_bool_list_value(&arr), pg_field)?;
112            Ok(())
113        }
114        DataType::Int8 => {
115            encoder.encode_field(&get_i8_list_value(&arr), pg_field)?;
116            Ok(())
117        }
118        DataType::Int16 => {
119            encoder.encode_field(&get_i16_list_value(&arr), pg_field)?;
120            Ok(())
121        }
122        DataType::Int32 => {
123            encoder.encode_field(&get_i32_list_value(&arr), pg_field)?;
124            Ok(())
125        }
126        DataType::Int64 => {
127            encoder.encode_field(&get_i64_list_value(&arr), pg_field)?;
128            Ok(())
129        }
130        DataType::UInt8 => {
131            encoder.encode_field(&get_u8_list_value(&arr), pg_field)?;
132            Ok(())
133        }
134        DataType::UInt16 => {
135            encoder.encode_field(&get_u16_list_value(&arr), pg_field)?;
136            Ok(())
137        }
138        DataType::UInt32 => {
139            encoder.encode_field(&get_u32_list_value(&arr), pg_field)?;
140            Ok(())
141        }
142        DataType::UInt64 => {
143            encoder.encode_field(&get_u64_list_value(&arr), pg_field)?;
144            Ok(())
145        }
146        DataType::Float32 => {
147            encoder.encode_field(&get_f32_list_value(&arr), pg_field)?;
148            Ok(())
149        }
150        DataType::Float64 => {
151            encoder.encode_field(&get_f64_list_value(&arr), pg_field)?;
152            Ok(())
153        }
154        DataType::Decimal128(_, s) => {
155            let value: Vec<_> = arr
156                .as_any()
157                .downcast_ref::<Decimal128Array>()
158                .unwrap()
159                .iter()
160                .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
161                .collect();
162            encoder.encode_field(&value, pg_field)
163        }
164        DataType::Utf8 => {
165            let value: Vec<Option<&str>> = arr
166                .as_any()
167                .downcast_ref::<StringArray>()
168                .unwrap()
169                .iter()
170                .collect();
171            encoder.encode_field(&value, pg_field)
172        }
173        DataType::Utf8View => {
174            let value: Vec<Option<&str>> = arr
175                .as_any()
176                .downcast_ref::<StringViewArray>()
177                .unwrap()
178                .iter()
179                .collect();
180            encoder.encode_field(&value, pg_field)
181        }
182        DataType::Binary => {
183            let value: Vec<Option<_>> = arr
184                .as_any()
185                .downcast_ref::<BinaryArray>()
186                .unwrap()
187                .iter()
188                .collect();
189            encoder.encode_field(&value, pg_field)
190        }
191        DataType::LargeBinary => {
192            let value: Vec<Option<_>> = arr
193                .as_any()
194                .downcast_ref::<LargeBinaryArray>()
195                .unwrap()
196                .iter()
197                .collect();
198            encoder.encode_field(&value, pg_field)
199        }
200        DataType::BinaryView => {
201            let value: Vec<Option<_>> = arr
202                .as_any()
203                .downcast_ref::<BinaryViewArray>()
204                .unwrap()
205                .iter()
206                .collect();
207            encoder.encode_field(&value, pg_field)
208        }
209
210        DataType::Date32 => {
211            let value: Vec<Option<_>> = arr
212                .as_any()
213                .downcast_ref::<Date32Array>()
214                .unwrap()
215                .iter()
216                .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
217                .collect();
218            encoder.encode_field(&value, pg_field)
219        }
220        DataType::Date64 => {
221            let value: Vec<Option<_>> = arr
222                .as_any()
223                .downcast_ref::<Date64Array>()
224                .unwrap()
225                .iter()
226                .map(|val| val.and_then(as_date::<Date64Type>))
227                .collect();
228            encoder.encode_field(&value, pg_field)
229        }
230        DataType::Time32(unit) => match unit {
231            TimeUnit::Second => {
232                let value: Vec<Option<_>> = arr
233                    .as_any()
234                    .downcast_ref::<Time32SecondArray>()
235                    .unwrap()
236                    .iter()
237                    .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
238                    .collect();
239                encoder.encode_field(&value, pg_field)
240            }
241            TimeUnit::Millisecond => {
242                let value: Vec<Option<_>> = arr
243                    .as_any()
244                    .downcast_ref::<Time32MillisecondArray>()
245                    .unwrap()
246                    .iter()
247                    .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
248                    .collect();
249                encoder.encode_field(&value, pg_field)
250            }
251            _ => {
252                // Time32 only supports Second and Millisecond in Arrow
253                // Other units are not available, so return an error
254                Err(PgWireError::ApiError("Unsupported Time32 unit".into()))
255            }
256        },
257        DataType::Time64(unit) => match unit {
258            TimeUnit::Microsecond => {
259                let value: Vec<Option<_>> = arr
260                    .as_any()
261                    .downcast_ref::<Time64MicrosecondArray>()
262                    .unwrap()
263                    .iter()
264                    .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
265                    .collect();
266                encoder.encode_field(&value, pg_field)
267            }
268            TimeUnit::Nanosecond => {
269                let value: Vec<Option<_>> = arr
270                    .as_any()
271                    .downcast_ref::<Time64NanosecondArray>()
272                    .unwrap()
273                    .iter()
274                    .map(|val| val.and_then(as_time::<Time64NanosecondType>))
275                    .collect();
276                encoder.encode_field(&value, pg_field)
277            }
278            _ => {
279                // Time64 only supports Microsecond and Nanosecond in Arrow
280                // Other units are not available, so return an error
281                Err(PgWireError::ApiError("Unsupported Time64 unit".into()))
282            }
283        },
284        DataType::Timestamp(unit, timezone) => match unit {
285            TimeUnit::Second => {
286                let array_iter = arr
287                    .as_any()
288                    .downcast_ref::<TimestampSecondArray>()
289                    .unwrap()
290                    .iter();
291
292                if let Some(tz) = timezone {
293                    let tz = Tz::from_str(tz.as_ref())
294                        .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
295                    let value: Vec<_> = array_iter
296                        .map(|i| {
297                            i.and_then(|i| {
298                                DateTime::from_timestamp(i, 0).map(|dt| {
299                                    Utc.from_utc_datetime(&dt.naive_utc())
300                                        .with_timezone(&tz)
301                                        .fixed_offset()
302                                })
303                            })
304                        })
305                        .collect();
306                    encoder.encode_field(&value, pg_field)
307                } else {
308                    let value: Vec<_> = array_iter
309                        .map(|i| {
310                            i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
311                        })
312                        .collect();
313                    encoder.encode_field(&value, pg_field)
314                }
315            }
316            TimeUnit::Millisecond => {
317                let array_iter = arr
318                    .as_any()
319                    .downcast_ref::<TimestampMillisecondArray>()
320                    .unwrap()
321                    .iter();
322
323                if let Some(tz) = timezone {
324                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
325                    let value: Vec<_> = array_iter
326                        .map(|i| {
327                            i.and_then(|i| {
328                                DateTime::from_timestamp_millis(i).map(|dt| {
329                                    Utc.from_utc_datetime(&dt.naive_utc())
330                                        .with_timezone(&tz)
331                                        .fixed_offset()
332                                })
333                            })
334                        })
335                        .collect();
336                    encoder.encode_field(&value, pg_field)
337                } else {
338                    let value: Vec<_> = array_iter
339                        .map(|i| {
340                            i.and_then(|i| {
341                                DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
342                            })
343                        })
344                        .collect();
345                    encoder.encode_field(&value, pg_field)
346                }
347            }
348            TimeUnit::Microsecond => {
349                let array_iter = arr
350                    .as_any()
351                    .downcast_ref::<TimestampMicrosecondArray>()
352                    .unwrap()
353                    .iter();
354
355                if let Some(tz) = timezone {
356                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
357                    let value: Vec<_> = array_iter
358                        .map(|i| {
359                            i.and_then(|i| {
360                                DateTime::from_timestamp_micros(i).map(|dt| {
361                                    Utc.from_utc_datetime(&dt.naive_utc())
362                                        .with_timezone(&tz)
363                                        .fixed_offset()
364                                })
365                            })
366                        })
367                        .collect();
368                    encoder.encode_field(&value, pg_field)
369                } else {
370                    let value: Vec<_> = array_iter
371                        .map(|i| {
372                            i.and_then(|i| {
373                                DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
374                            })
375                        })
376                        .collect();
377                    encoder.encode_field(&value, pg_field)
378                }
379            }
380            TimeUnit::Nanosecond => {
381                let array_iter = arr
382                    .as_any()
383                    .downcast_ref::<TimestampNanosecondArray>()
384                    .unwrap()
385                    .iter();
386
387                if let Some(tz) = timezone {
388                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
389                    let value: Vec<_> = array_iter
390                        .map(|i| {
391                            i.map(|i| {
392                                Utc.from_utc_datetime(
393                                    &DateTime::from_timestamp_nanos(i).naive_utc(),
394                                )
395                                .with_timezone(&tz)
396                                .fixed_offset()
397                            })
398                        })
399                        .collect();
400                    encoder.encode_field(&value, pg_field)
401                } else {
402                    let value: Vec<_> = array_iter
403                        .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
404                        .collect();
405                    encoder.encode_field(&value, pg_field)
406                }
407            }
408        },
409        DataType::Struct(arrow_fields) => encode_structs(encoder, &arr, arrow_fields, pg_field),
410        DataType::LargeUtf8 => {
411            let value: Vec<Option<&str>> = arr
412                .as_any()
413                .downcast_ref::<LargeStringArray>()
414                .unwrap()
415                .iter()
416                .collect();
417            encoder.encode_field(&value, pg_field)?;
418            Ok(())
419        }
420        DataType::Decimal256(_, s) => {
421            // Convert Decimal256 to string representation for now
422            // since rust_decimal doesn't support 256-bit decimals
423            let decimal_array = arr.as_any().downcast_ref::<Decimal256Array>().unwrap();
424            let value: Vec<Option<String>> = (0..decimal_array.len())
425                .map(|i| {
426                    if decimal_array.is_null(i) {
427                        None
428                    } else {
429                        // Convert to string representation
430                        let raw_value = decimal_array.value(i);
431                        let scale = *s as u32;
432                        // Convert i256 to string and handle decimal placement manually
433                        let value_str = raw_value.to_string();
434                        if scale == 0 {
435                            Some(value_str)
436                        } else {
437                            // Insert decimal point
438                            let mut chars: Vec<char> = value_str.chars().collect();
439                            if chars.len() <= scale as usize {
440                                // Prepend zeros if needed
441                                let zeros_needed = scale as usize - chars.len() + 1;
442                                chars.splice(0..0, std::iter::repeat_n('0', zeros_needed));
443                                chars.insert(1, '.');
444                            } else {
445                                let decimal_pos = chars.len() - scale as usize;
446                                chars.insert(decimal_pos, '.');
447                            }
448                            Some(chars.into_iter().collect())
449                        }
450                    }
451                })
452                .collect();
453            encoder.encode_field(&value, pg_field)?;
454            Ok(())
455        }
456        DataType::Duration(unit) => match unit {
457            TimeUnit::Second => {
458                let value: Vec<Option<PgInterval>> = arr
459                    .as_any()
460                    .downcast_ref::<DurationSecondArray>()
461                    .unwrap()
462                    .iter()
463                    .map(|val| val.map(|v| PgInterval::new(0, 0, v * 1_000_000i64)))
464                    .collect();
465                encoder.encode_field(&value, pg_field)?;
466                Ok(())
467            }
468            TimeUnit::Millisecond => {
469                let value: Vec<Option<PgInterval>> = arr
470                    .as_any()
471                    .downcast_ref::<DurationMillisecondArray>()
472                    .unwrap()
473                    .iter()
474                    .map(|val| val.map(|v| PgInterval::new(0, 0, v * 1_000i64)))
475                    .collect();
476                encoder.encode_field(&value, pg_field)?;
477                Ok(())
478            }
479            TimeUnit::Microsecond => {
480                let value: Vec<Option<PgInterval>> = arr
481                    .as_any()
482                    .downcast_ref::<DurationMicrosecondArray>()
483                    .unwrap()
484                    .iter()
485                    .map(|val| val.map(|v| PgInterval::new(0, 0, v)))
486                    .collect();
487                encoder.encode_field(&value, pg_field)?;
488                Ok(())
489            }
490            TimeUnit::Nanosecond => {
491                let value: Vec<Option<PgInterval>> = arr
492                    .as_any()
493                    .downcast_ref::<DurationNanosecondArray>()
494                    .unwrap()
495                    .iter()
496                    .map(|val| val.map(|v| PgInterval::new(0, 0, v / 1_000i64)))
497                    .collect();
498                encoder.encode_field(&value, pg_field)?;
499                Ok(())
500            }
501        },
502        DataType::Interval(interval_unit) => match interval_unit {
503            arrow::datatypes::IntervalUnit::YearMonth => {
504                let value: Vec<Option<PgInterval>> = arr
505                    .as_any()
506                    .downcast_ref::<IntervalYearMonthArray>()
507                    .unwrap()
508                    .iter()
509                    .map(|val| val.map(|v| PgInterval::new(v, 0, 0)))
510                    .collect();
511                encoder.encode_field(&value, pg_field)?;
512                Ok(())
513            }
514            arrow::datatypes::IntervalUnit::DayTime => {
515                let value: Vec<Option<PgInterval>> = arr
516                    .as_any()
517                    .downcast_ref::<IntervalDayTimeArray>()
518                    .unwrap()
519                    .iter()
520                    .map(|val| {
521                        val.map(|v| {
522                            let (days, millis) = arrow::datatypes::IntervalDayTimeType::to_parts(v);
523                            PgInterval::new(0, days, millis as i64 * 1000i64)
524                        })
525                    })
526                    .collect();
527                encoder.encode_field(&value, pg_field)?;
528                Ok(())
529            }
530            arrow::datatypes::IntervalUnit::MonthDayNano => {
531                let value: Vec<Option<PgInterval>> = arr
532                    .as_any()
533                    .downcast_ref::<IntervalMonthDayNanoArray>()
534                    .unwrap()
535                    .iter()
536                    .map(|val| {
537                        val.map(|v| {
538                            let (months, days, nanos) =
539                                arrow::datatypes::IntervalMonthDayNanoType::to_parts(v);
540                            PgInterval::new(months, days, nanos / 1000i64)
541                        })
542                    })
543                    .collect();
544                encoder.encode_field(&value, pg_field)?;
545                Ok(())
546            }
547        },
548        DataType::List(_) => {
549            // Support for nested lists (list of lists)
550            // For now, convert to string representation
551            let list_array = arr.as_any().downcast_ref::<ListArray>().unwrap();
552            let value: Vec<Option<String>> = (0..list_array.len())
553                .map(|i| {
554                    if list_array.is_null(i) {
555                        None
556                    } else {
557                        // Convert nested list to string representation
558                        Some(format!("[nested_list_{i}]"))
559                    }
560                })
561                .collect();
562            encoder.encode_field(&value, pg_field)?;
563            Ok(())
564        }
565        DataType::LargeList(_) => {
566            // Support for large lists
567            let list_array = arr.as_any().downcast_ref::<LargeListArray>().unwrap();
568            let value: Vec<Option<String>> = (0..list_array.len())
569                .map(|i| {
570                    if list_array.is_null(i) {
571                        None
572                    } else {
573                        Some(format!("[large_list_{i}]"))
574                    }
575                })
576                .collect();
577            encoder.encode_field(&value, pg_field)
578        }
579        DataType::Map(_, _) => {
580            // Support for map types
581            let map_array = arr.as_any().downcast_ref::<MapArray>().unwrap();
582            let value: Vec<Option<String>> = (0..map_array.len())
583                .map(|i| {
584                    if map_array.is_null(i) {
585                        None
586                    } else {
587                        Some(format!("{{map_{i}}}"))
588                    }
589                })
590                .collect();
591            encoder.encode_field(&value, pg_field)?;
592            Ok(())
593        }
594
595        DataType::Union(_, _) => {
596            // Support for union types
597            let value: Vec<Option<String>> = (0..arr.len())
598                .map(|i| {
599                    if arr.is_null(i) {
600                        None
601                    } else {
602                        Some(format!("union_{i}"))
603                    }
604                })
605                .collect();
606            encoder.encode_field(&value, pg_field)?;
607            Ok(())
608        }
609        DataType::Dictionary(_, _) => {
610            // Support for dictionary types
611            let value: Vec<Option<String>> = (0..arr.len())
612                .map(|i| {
613                    if arr.is_null(i) {
614                        None
615                    } else {
616                        Some(format!("dict_{i}"))
617                    }
618                })
619                .collect();
620            encoder.encode_field(&value, pg_field)?;
621            Ok(())
622        }
623        // TODO: add support for more advanced types (fixed size lists, etc.)
624        list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
625            "Unsupported List Datatype {} and array {:?}",
626            list_type, &arr
627        )))),
628    }
629}