Skip to main content

arrow_pg/
list_encoder.rs

1use std::{str::FromStr, sync::Arc};
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{
5    array::{
6        timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
7        Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray,
8        DurationNanosecondArray, DurationSecondArray, IntervalDayTimeArray,
9        IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray,
10        LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, StringViewArray,
11        Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
12        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
13        TimestampSecondArray,
14    },
15    datatypes::{
16        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
17        Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
18        Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
19        TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
20    },
21    temporal_conversions::{as_date, as_time},
22};
23#[cfg(feature = "datafusion")]
24use datafusion::arrow::{
25    array::{
26        timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
27        Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray,
28        DurationNanosecondArray, DurationSecondArray, IntervalDayTimeArray,
29        IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray,
30        LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, StringViewArray,
31        Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
32        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
33        TimestampSecondArray,
34    },
35    datatypes::{
36        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
37        Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
38        Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
39        TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
40    },
41    temporal_conversions::{as_date, as_time},
42};
43
44use chrono::{DateTime, TimeZone, Utc};
45use pg_interval::Interval as PgInterval;
46use pgwire::api::results::FieldInfo;
47use pgwire::error::{PgWireError, PgWireResult};
48use rust_decimal::Decimal;
49
50use crate::encoder::Encoder;
51use crate::error::ToSqlError;
52use crate::struct_encoder::encode_structs;
53
54fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
55    arr.as_any()
56        .downcast_ref::<BooleanArray>()
57        .unwrap()
58        .iter()
59        .collect()
60}
61
/// Generates a function `$name` that downcasts an Arrow array to
/// `PrimitiveArray<$t>` and collects its elements into `Vec<Option<$pt>>`,
/// keeping null slots as `None`.
///
/// The optional fourth argument `$f` is applied to every non-null element and
/// must produce a `$pt`.
///
/// The generated function panics if the array is not a `PrimitiveArray<$t>`.
macro_rules! get_primitive_list_value {
    // Three-argument form: elements are collected unchanged.
    ($name:ident, $t:ty, $pt:ty) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            typed.iter().collect()
        }
    };

    // Four-argument form: every non-null element is passed through `$f`.
    ($name:ident, $t:ty, $pt:ty, $f:expr) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            typed.iter().map(|item| item.map($f)).collect()
        }
    };
}
84
85get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
86get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
87get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
88get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
89get_primitive_list_value!(get_u8_list_value, UInt8Type, i16, |val: u8| { val as i16 });
90get_primitive_list_value!(get_u16_list_value, UInt16Type, i32, |val: u16| {
91    val as i32
92});
93get_primitive_list_value!(get_u32_list_value, UInt32Type, i64, |val: u32| {
94    val as i64
95});
96get_primitive_list_value!(get_u64_list_value, UInt64Type, Decimal, |val: u64| {
97    Decimal::from(val)
98});
99get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
100get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
101
102pub fn encode_list<T: Encoder>(
103    encoder: &mut T,
104    arr: Arc<dyn Array>,
105    pg_field: &FieldInfo,
106) -> PgWireResult<()> {
107    match arr.data_type() {
108        DataType::Null => {
109            encoder.encode_field(&None::<i8>, pg_field)?;
110            Ok(())
111        }
112        DataType::Boolean => {
113            encoder.encode_field(&get_bool_list_value(&arr), pg_field)?;
114            Ok(())
115        }
116        DataType::Int8 => {
117            encoder.encode_field(&get_i8_list_value(&arr), pg_field)?;
118            Ok(())
119        }
120        DataType::Int16 => {
121            encoder.encode_field(&get_i16_list_value(&arr), pg_field)?;
122            Ok(())
123        }
124        DataType::Int32 => {
125            encoder.encode_field(&get_i32_list_value(&arr), pg_field)?;
126            Ok(())
127        }
128        DataType::Int64 => {
129            encoder.encode_field(&get_i64_list_value(&arr), pg_field)?;
130            Ok(())
131        }
132        DataType::UInt8 => {
133            encoder.encode_field(&get_u8_list_value(&arr), pg_field)?;
134            Ok(())
135        }
136        DataType::UInt16 => {
137            encoder.encode_field(&get_u16_list_value(&arr), pg_field)?;
138            Ok(())
139        }
140        DataType::UInt32 => {
141            encoder.encode_field(&get_u32_list_value(&arr), pg_field)?;
142            Ok(())
143        }
144        DataType::UInt64 => {
145            encoder.encode_field(&get_u64_list_value(&arr), pg_field)?;
146            Ok(())
147        }
148        DataType::Float32 => {
149            encoder.encode_field(&get_f32_list_value(&arr), pg_field)?;
150            Ok(())
151        }
152        DataType::Float64 => {
153            encoder.encode_field(&get_f64_list_value(&arr), pg_field)?;
154            Ok(())
155        }
156        DataType::Decimal128(_, s) => {
157            let value: Vec<_> = arr
158                .as_any()
159                .downcast_ref::<Decimal128Array>()
160                .unwrap()
161                .iter()
162                .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
163                .collect();
164            encoder.encode_field(&value, pg_field)
165        }
166        DataType::Utf8 => {
167            let value: Vec<Option<&str>> = arr
168                .as_any()
169                .downcast_ref::<StringArray>()
170                .unwrap()
171                .iter()
172                .collect();
173            encoder.encode_field(&value, pg_field)
174        }
175        DataType::Utf8View => {
176            let value: Vec<Option<&str>> = arr
177                .as_any()
178                .downcast_ref::<StringViewArray>()
179                .unwrap()
180                .iter()
181                .collect();
182            encoder.encode_field(&value, pg_field)
183        }
184        DataType::Binary => {
185            let value: Vec<Option<_>> = arr
186                .as_any()
187                .downcast_ref::<BinaryArray>()
188                .unwrap()
189                .iter()
190                .collect();
191            encoder.encode_field(&value, pg_field)
192        }
193        DataType::LargeBinary => {
194            let value: Vec<Option<_>> = arr
195                .as_any()
196                .downcast_ref::<LargeBinaryArray>()
197                .unwrap()
198                .iter()
199                .collect();
200            encoder.encode_field(&value, pg_field)
201        }
202        DataType::BinaryView => {
203            let value: Vec<Option<_>> = arr
204                .as_any()
205                .downcast_ref::<BinaryViewArray>()
206                .unwrap()
207                .iter()
208                .collect();
209            encoder.encode_field(&value, pg_field)
210        }
211
212        DataType::Date32 => {
213            let value: Vec<Option<_>> = arr
214                .as_any()
215                .downcast_ref::<Date32Array>()
216                .unwrap()
217                .iter()
218                .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
219                .collect();
220            encoder.encode_field(&value, pg_field)
221        }
222        DataType::Date64 => {
223            let value: Vec<Option<_>> = arr
224                .as_any()
225                .downcast_ref::<Date64Array>()
226                .unwrap()
227                .iter()
228                .map(|val| val.and_then(as_date::<Date64Type>))
229                .collect();
230            encoder.encode_field(&value, pg_field)
231        }
232        DataType::Time32(unit) => match unit {
233            TimeUnit::Second => {
234                let value: Vec<Option<_>> = arr
235                    .as_any()
236                    .downcast_ref::<Time32SecondArray>()
237                    .unwrap()
238                    .iter()
239                    .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
240                    .collect();
241                encoder.encode_field(&value, pg_field)
242            }
243            TimeUnit::Millisecond => {
244                let value: Vec<Option<_>> = arr
245                    .as_any()
246                    .downcast_ref::<Time32MillisecondArray>()
247                    .unwrap()
248                    .iter()
249                    .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
250                    .collect();
251                encoder.encode_field(&value, pg_field)
252            }
253            _ => {
254                // Time32 only supports Second and Millisecond in Arrow
255                // Other units are not available, so return an error
256                Err(PgWireError::ApiError("Unsupported Time32 unit".into()))
257            }
258        },
259        DataType::Time64(unit) => match unit {
260            TimeUnit::Microsecond => {
261                let value: Vec<Option<_>> = arr
262                    .as_any()
263                    .downcast_ref::<Time64MicrosecondArray>()
264                    .unwrap()
265                    .iter()
266                    .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
267                    .collect();
268                encoder.encode_field(&value, pg_field)
269            }
270            TimeUnit::Nanosecond => {
271                let value: Vec<Option<_>> = arr
272                    .as_any()
273                    .downcast_ref::<Time64NanosecondArray>()
274                    .unwrap()
275                    .iter()
276                    .map(|val| val.and_then(as_time::<Time64NanosecondType>))
277                    .collect();
278                encoder.encode_field(&value, pg_field)
279            }
280            _ => {
281                // Time64 only supports Microsecond and Nanosecond in Arrow
282                // Other units are not available, so return an error
283                Err(PgWireError::ApiError("Unsupported Time64 unit".into()))
284            }
285        },
286        DataType::Timestamp(unit, timezone) => match unit {
287            TimeUnit::Second => {
288                let array_iter = arr
289                    .as_any()
290                    .downcast_ref::<TimestampSecondArray>()
291                    .unwrap()
292                    .iter();
293
294                if let Some(tz) = timezone {
295                    let tz = Tz::from_str(tz.as_ref())
296                        .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
297                    let value: Vec<_> = array_iter
298                        .map(|i| {
299                            i.and_then(|i| {
300                                DateTime::from_timestamp(i, 0).map(|dt| {
301                                    Utc.from_utc_datetime(&dt.naive_utc())
302                                        .with_timezone(&tz)
303                                        .fixed_offset()
304                                })
305                            })
306                        })
307                        .collect();
308                    encoder.encode_field(&value, pg_field)
309                } else {
310                    let value: Vec<_> = array_iter
311                        .map(|i| {
312                            i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
313                        })
314                        .collect();
315                    encoder.encode_field(&value, pg_field)
316                }
317            }
318            TimeUnit::Millisecond => {
319                let array_iter = arr
320                    .as_any()
321                    .downcast_ref::<TimestampMillisecondArray>()
322                    .unwrap()
323                    .iter();
324
325                if let Some(tz) = timezone {
326                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
327                    let value: Vec<_> = array_iter
328                        .map(|i| {
329                            i.and_then(|i| {
330                                DateTime::from_timestamp_millis(i).map(|dt| {
331                                    Utc.from_utc_datetime(&dt.naive_utc())
332                                        .with_timezone(&tz)
333                                        .fixed_offset()
334                                })
335                            })
336                        })
337                        .collect();
338                    encoder.encode_field(&value, pg_field)
339                } else {
340                    let value: Vec<_> = array_iter
341                        .map(|i| {
342                            i.and_then(|i| {
343                                DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
344                            })
345                        })
346                        .collect();
347                    encoder.encode_field(&value, pg_field)
348                }
349            }
350            TimeUnit::Microsecond => {
351                let array_iter = arr
352                    .as_any()
353                    .downcast_ref::<TimestampMicrosecondArray>()
354                    .unwrap()
355                    .iter();
356
357                if let Some(tz) = timezone {
358                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
359                    let value: Vec<_> = array_iter
360                        .map(|i| {
361                            i.and_then(|i| {
362                                DateTime::from_timestamp_micros(i).map(|dt| {
363                                    Utc.from_utc_datetime(&dt.naive_utc())
364                                        .with_timezone(&tz)
365                                        .fixed_offset()
366                                })
367                            })
368                        })
369                        .collect();
370                    encoder.encode_field(&value, pg_field)
371                } else {
372                    let value: Vec<_> = array_iter
373                        .map(|i| {
374                            i.and_then(|i| {
375                                DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
376                            })
377                        })
378                        .collect();
379                    encoder.encode_field(&value, pg_field)
380                }
381            }
382            TimeUnit::Nanosecond => {
383                let array_iter = arr
384                    .as_any()
385                    .downcast_ref::<TimestampNanosecondArray>()
386                    .unwrap()
387                    .iter();
388
389                if let Some(tz) = timezone {
390                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
391                    let value: Vec<_> = array_iter
392                        .map(|i| {
393                            i.map(|i| {
394                                Utc.from_utc_datetime(
395                                    &DateTime::from_timestamp_nanos(i).naive_utc(),
396                                )
397                                .with_timezone(&tz)
398                                .fixed_offset()
399                            })
400                        })
401                        .collect();
402                    encoder.encode_field(&value, pg_field)
403                } else {
404                    let value: Vec<_> = array_iter
405                        .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
406                        .collect();
407                    encoder.encode_field(&value, pg_field)
408                }
409            }
410        },
411        DataType::Struct(arrow_fields) => encode_structs(encoder, &arr, arrow_fields, pg_field),
412        DataType::LargeUtf8 => {
413            let value: Vec<Option<&str>> = arr
414                .as_any()
415                .downcast_ref::<LargeStringArray>()
416                .unwrap()
417                .iter()
418                .collect();
419            encoder.encode_field(&value, pg_field)?;
420            Ok(())
421        }
422        DataType::Decimal256(_, s) => {
423            // Convert Decimal256 to string representation for now
424            // since rust_decimal doesn't support 256-bit decimals
425            let decimal_array = arr.as_any().downcast_ref::<Decimal256Array>().unwrap();
426            let value: Vec<Option<String>> = (0..decimal_array.len())
427                .map(|i| {
428                    if decimal_array.is_null(i) {
429                        None
430                    } else {
431                        // Convert to string representation
432                        let raw_value = decimal_array.value(i);
433                        let scale = *s as u32;
434                        // Convert i256 to string and handle decimal placement manually
435                        let value_str = raw_value.to_string();
436                        if scale == 0 {
437                            Some(value_str)
438                        } else {
439                            // Insert decimal point
440                            let mut chars: Vec<char> = value_str.chars().collect();
441                            if chars.len() <= scale as usize {
442                                // Prepend zeros if needed
443                                let zeros_needed = scale as usize - chars.len() + 1;
444                                chars.splice(0..0, std::iter::repeat_n('0', zeros_needed));
445                                chars.insert(1, '.');
446                            } else {
447                                let decimal_pos = chars.len() - scale as usize;
448                                chars.insert(decimal_pos, '.');
449                            }
450                            Some(chars.into_iter().collect())
451                        }
452                    }
453                })
454                .collect();
455            encoder.encode_field(&value, pg_field)?;
456            Ok(())
457        }
458        DataType::Duration(unit) => match unit {
459            TimeUnit::Second => {
460                let value: Vec<Option<PgInterval>> = arr
461                    .as_any()
462                    .downcast_ref::<DurationSecondArray>()
463                    .unwrap()
464                    .iter()
465                    .map(|val| val.map(|v| PgInterval::new(0, 0, v * 1_000_000i64)))
466                    .collect();
467                encoder.encode_field(&value, pg_field)?;
468                Ok(())
469            }
470            TimeUnit::Millisecond => {
471                let value: Vec<Option<PgInterval>> = arr
472                    .as_any()
473                    .downcast_ref::<DurationMillisecondArray>()
474                    .unwrap()
475                    .iter()
476                    .map(|val| val.map(|v| PgInterval::new(0, 0, v * 1_000i64)))
477                    .collect();
478                encoder.encode_field(&value, pg_field)?;
479                Ok(())
480            }
481            TimeUnit::Microsecond => {
482                let value: Vec<Option<PgInterval>> = arr
483                    .as_any()
484                    .downcast_ref::<DurationMicrosecondArray>()
485                    .unwrap()
486                    .iter()
487                    .map(|val| val.map(|v| PgInterval::new(0, 0, v)))
488                    .collect();
489                encoder.encode_field(&value, pg_field)?;
490                Ok(())
491            }
492            TimeUnit::Nanosecond => {
493                let value: Vec<Option<PgInterval>> = arr
494                    .as_any()
495                    .downcast_ref::<DurationNanosecondArray>()
496                    .unwrap()
497                    .iter()
498                    .map(|val| val.map(|v| PgInterval::new(0, 0, v / 1_000i64)))
499                    .collect();
500                encoder.encode_field(&value, pg_field)?;
501                Ok(())
502            }
503        },
504        DataType::Interval(interval_unit) => match interval_unit {
505            IntervalUnit::YearMonth => {
506                let value: Vec<Option<PgInterval>> = arr
507                    .as_any()
508                    .downcast_ref::<IntervalYearMonthArray>()
509                    .unwrap()
510                    .iter()
511                    .map(|val| val.map(|v| PgInterval::new(v, 0, 0)))
512                    .collect();
513                encoder.encode_field(&value, pg_field)?;
514                Ok(())
515            }
516            IntervalUnit::DayTime => {
517                let value: Vec<Option<PgInterval>> = arr
518                    .as_any()
519                    .downcast_ref::<IntervalDayTimeArray>()
520                    .unwrap()
521                    .iter()
522                    .map(|val| {
523                        val.map(|v| {
524                            let (days, millis) = IntervalDayTimeType::to_parts(v);
525                            PgInterval::new(0, days, millis as i64 * 1000i64)
526                        })
527                    })
528                    .collect();
529                encoder.encode_field(&value, pg_field)?;
530                Ok(())
531            }
532            IntervalUnit::MonthDayNano => {
533                let value: Vec<Option<PgInterval>> = arr
534                    .as_any()
535                    .downcast_ref::<IntervalMonthDayNanoArray>()
536                    .unwrap()
537                    .iter()
538                    .map(|val| {
539                        val.map(|v| {
540                            let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(v);
541                            PgInterval::new(months, days, nanos / 1000i64)
542                        })
543                    })
544                    .collect();
545                encoder.encode_field(&value, pg_field)?;
546                Ok(())
547            }
548        },
549        DataType::List(_) => {
550            // Support for nested lists (list of lists)
551            // For now, convert to string representation
552            let list_array = arr.as_any().downcast_ref::<ListArray>().unwrap();
553            let value: Vec<Option<String>> = (0..list_array.len())
554                .map(|i| {
555                    if list_array.is_null(i) {
556                        None
557                    } else {
558                        // Convert nested list to string representation
559                        Some(format!("[nested_list_{i}]"))
560                    }
561                })
562                .collect();
563            encoder.encode_field(&value, pg_field)?;
564            Ok(())
565        }
566        DataType::LargeList(_) => {
567            // Support for large lists
568            let list_array = arr.as_any().downcast_ref::<LargeListArray>().unwrap();
569            let value: Vec<Option<String>> = (0..list_array.len())
570                .map(|i| {
571                    if list_array.is_null(i) {
572                        None
573                    } else {
574                        Some(format!("[large_list_{i}]"))
575                    }
576                })
577                .collect();
578            encoder.encode_field(&value, pg_field)
579        }
580        DataType::Map(_, _) => {
581            // Support for map types
582            let map_array = arr.as_any().downcast_ref::<MapArray>().unwrap();
583            let value: Vec<Option<String>> = (0..map_array.len())
584                .map(|i| {
585                    if map_array.is_null(i) {
586                        None
587                    } else {
588                        Some(format!("{{map_{i}}}"))
589                    }
590                })
591                .collect();
592            encoder.encode_field(&value, pg_field)?;
593            Ok(())
594        }
595
596        DataType::Union(_, _) => {
597            // Support for union types
598            let value: Vec<Option<String>> = (0..arr.len())
599                .map(|i| {
600                    if arr.is_null(i) {
601                        None
602                    } else {
603                        Some(format!("union_{i}"))
604                    }
605                })
606                .collect();
607            encoder.encode_field(&value, pg_field)?;
608            Ok(())
609        }
610        DataType::Dictionary(_, _) => {
611            // Support for dictionary types
612            let value: Vec<Option<String>> = (0..arr.len())
613                .map(|i| {
614                    if arr.is_null(i) {
615                        None
616                    } else {
617                        Some(format!("dict_{i}"))
618                    }
619                })
620                .collect();
621            encoder.encode_field(&value, pg_field)?;
622            Ok(())
623        }
624        // TODO: add support for more advanced types (fixed size lists, etc.)
625        list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
626            "Unsupported List Datatype {} and array {:?}",
627            list_type, &arr
628        )))),
629    }
630}