arrow_pg/
list_encoder.rs

1use std::{str::FromStr, sync::Arc};
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{
5    array::{
6        timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
7        Decimal256Array, DurationMicrosecondArray, LargeBinaryArray, LargeListArray,
8        LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, Time32MillisecondArray,
9        Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
10        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
11        TimestampSecondArray,
12    },
13    datatypes::{
14        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
15        Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
16        Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
17    },
18    temporal_conversions::{as_date, as_time},
19};
20#[cfg(feature = "datafusion")]
21use datafusion::arrow::{
22    array::{
23        timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
24        Decimal256Array, DurationMicrosecondArray, LargeBinaryArray, LargeListArray,
25        LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, Time32MillisecondArray,
26        Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
27        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
28        TimestampSecondArray,
29    },
30    datatypes::{
31        DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
32        Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
33        Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
34    },
35    temporal_conversions::{as_date, as_time},
36};
37
38use bytes::{BufMut, BytesMut};
39use chrono::{DateTime, TimeZone, Utc};
40use pgwire::api::results::FieldFormat;
41use pgwire::error::{PgWireError, PgWireResult};
42use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
43use postgres_types::{ToSql, Type};
44use rust_decimal::Decimal;
45
46use crate::encoder::EncodedValue;
47use crate::error::ToSqlError;
48use crate::struct_encoder::encode_struct;
49
50fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
51    arr.as_any()
52        .downcast_ref::<BooleanArray>()
53        .unwrap()
54        .iter()
55        .collect()
56}
57
/// Generates a function `$name` that collects a `PrimitiveArray<$t>` into a
/// `Vec<Option<$pt>>`, keeping null slots as `None`.
///
/// * Three-argument form: the array's native values are collected unchanged.
/// * Four-argument form: every non-null value is first passed through `$f`.
///
/// The generated function panics if the array is not a `PrimitiveArray<$t>`.
macro_rules! get_primitive_list_value {
    ($name:ident, $t:ty, $pt:ty) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            typed.iter().collect()
        }
    };

    ($name:ident, $t:ty, $pt:ty, $f:expr) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            typed.iter().map(|slot| slot.map($f)).collect()
        }
    };
}
80
81get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
82get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
83get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
84get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
85get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 });
86get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| {
87    val as i16
88});
89get_primitive_list_value!(get_u32_list_value, UInt32Type, u32);
90get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| {
91    val as i64
92});
93get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
94get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
95
96fn encode_field<T: ToSql + ToSqlText>(
97    t: &[T],
98    type_: &Type,
99    format: FieldFormat,
100) -> PgWireResult<EncodedValue> {
101    let mut bytes = BytesMut::new();
102    match format {
103        FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
104        FieldFormat::Binary => t.to_sql(type_, &mut bytes)?,
105    };
106    Ok(EncodedValue { bytes })
107}
108
109pub(crate) fn encode_list(
110    arr: Arc<dyn Array>,
111    type_: &Type,
112    format: FieldFormat,
113) -> PgWireResult<EncodedValue> {
114    match arr.data_type() {
115        DataType::Null => {
116            let mut bytes = BytesMut::new();
117            match format {
118                FieldFormat::Text => None::<i8>.to_sql_text(type_, &mut bytes),
119                FieldFormat::Binary => None::<i8>.to_sql(type_, &mut bytes),
120            }?;
121            Ok(EncodedValue { bytes })
122        }
123        DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format),
124        DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format),
125        DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format),
126        DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format),
127        DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format),
128        DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format),
129        DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format),
130        DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format),
131        DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format),
132        DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format),
133        DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format),
134        DataType::Decimal128(_, s) => {
135            let value: Vec<_> = arr
136                .as_any()
137                .downcast_ref::<Decimal128Array>()
138                .unwrap()
139                .iter()
140                .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
141                .collect();
142            encode_field(&value, type_, format)
143        }
144        DataType::Utf8 => {
145            let value: Vec<Option<&str>> = arr
146                .as_any()
147                .downcast_ref::<StringArray>()
148                .unwrap()
149                .iter()
150                .collect();
151            encode_field(&value, type_, format)
152        }
153        DataType::Binary => {
154            let value: Vec<Option<_>> = arr
155                .as_any()
156                .downcast_ref::<BinaryArray>()
157                .unwrap()
158                .iter()
159                .collect();
160            encode_field(&value, type_, format)
161        }
162        DataType::LargeBinary => {
163            let value: Vec<Option<_>> = arr
164                .as_any()
165                .downcast_ref::<LargeBinaryArray>()
166                .unwrap()
167                .iter()
168                .collect();
169            encode_field(&value, type_, format)
170        }
171
172        DataType::Date32 => {
173            let value: Vec<Option<_>> = arr
174                .as_any()
175                .downcast_ref::<Date32Array>()
176                .unwrap()
177                .iter()
178                .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
179                .collect();
180            encode_field(&value, type_, format)
181        }
182        DataType::Date64 => {
183            let value: Vec<Option<_>> = arr
184                .as_any()
185                .downcast_ref::<Date64Array>()
186                .unwrap()
187                .iter()
188                .map(|val| val.and_then(as_date::<Date64Type>))
189                .collect();
190            encode_field(&value, type_, format)
191        }
192        DataType::Time32(unit) => match unit {
193            TimeUnit::Second => {
194                let value: Vec<Option<_>> = arr
195                    .as_any()
196                    .downcast_ref::<Time32SecondArray>()
197                    .unwrap()
198                    .iter()
199                    .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
200                    .collect();
201                encode_field(&value, type_, format)
202            }
203            TimeUnit::Millisecond => {
204                let value: Vec<Option<_>> = arr
205                    .as_any()
206                    .downcast_ref::<Time32MillisecondArray>()
207                    .unwrap()
208                    .iter()
209                    .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
210                    .collect();
211                encode_field(&value, type_, format)
212            }
213            _ => {
214                // Time32 only supports Second and Millisecond in Arrow
215                // Other units are not available, so return an error
216                Err(PgWireError::ApiError("Unsupported Time32 unit".into()))
217            }
218        },
219        DataType::Time64(unit) => match unit {
220            TimeUnit::Microsecond => {
221                let value: Vec<Option<_>> = arr
222                    .as_any()
223                    .downcast_ref::<Time64MicrosecondArray>()
224                    .unwrap()
225                    .iter()
226                    .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
227                    .collect();
228                encode_field(&value, type_, format)
229            }
230            TimeUnit::Nanosecond => {
231                let value: Vec<Option<_>> = arr
232                    .as_any()
233                    .downcast_ref::<Time64NanosecondArray>()
234                    .unwrap()
235                    .iter()
236                    .map(|val| val.and_then(as_time::<Time64NanosecondType>))
237                    .collect();
238                encode_field(&value, type_, format)
239            }
240            _ => {
241                // Time64 only supports Microsecond and Nanosecond in Arrow
242                // Other units are not available, so return an error
243                Err(PgWireError::ApiError("Unsupported Time64 unit".into()))
244            }
245        },
246        DataType::Timestamp(unit, timezone) => match unit {
247            TimeUnit::Second => {
248                let array_iter = arr
249                    .as_any()
250                    .downcast_ref::<TimestampSecondArray>()
251                    .unwrap()
252                    .iter();
253
254                if let Some(tz) = timezone {
255                    let tz = Tz::from_str(tz.as_ref())
256                        .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
257                    let value: Vec<_> = array_iter
258                        .map(|i| {
259                            i.and_then(|i| {
260                                DateTime::from_timestamp(i, 0).map(|dt| {
261                                    Utc.from_utc_datetime(&dt.naive_utc())
262                                        .with_timezone(&tz)
263                                        .fixed_offset()
264                                })
265                            })
266                        })
267                        .collect();
268                    encode_field(&value, type_, format)
269                } else {
270                    let value: Vec<_> = array_iter
271                        .map(|i| {
272                            i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
273                        })
274                        .collect();
275                    encode_field(&value, type_, format)
276                }
277            }
278            TimeUnit::Millisecond => {
279                let array_iter = arr
280                    .as_any()
281                    .downcast_ref::<TimestampMillisecondArray>()
282                    .unwrap()
283                    .iter();
284
285                if let Some(tz) = timezone {
286                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
287                    let value: Vec<_> = array_iter
288                        .map(|i| {
289                            i.and_then(|i| {
290                                DateTime::from_timestamp_millis(i).map(|dt| {
291                                    Utc.from_utc_datetime(&dt.naive_utc())
292                                        .with_timezone(&tz)
293                                        .fixed_offset()
294                                })
295                            })
296                        })
297                        .collect();
298                    encode_field(&value, type_, format)
299                } else {
300                    let value: Vec<_> = array_iter
301                        .map(|i| {
302                            i.and_then(|i| {
303                                DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
304                            })
305                        })
306                        .collect();
307                    encode_field(&value, type_, format)
308                }
309            }
310            TimeUnit::Microsecond => {
311                let array_iter = arr
312                    .as_any()
313                    .downcast_ref::<TimestampMicrosecondArray>()
314                    .unwrap()
315                    .iter();
316
317                if let Some(tz) = timezone {
318                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
319                    let value: Vec<_> = array_iter
320                        .map(|i| {
321                            i.and_then(|i| {
322                                DateTime::from_timestamp_micros(i).map(|dt| {
323                                    Utc.from_utc_datetime(&dt.naive_utc())
324                                        .with_timezone(&tz)
325                                        .fixed_offset()
326                                })
327                            })
328                        })
329                        .collect();
330                    encode_field(&value, type_, format)
331                } else {
332                    let value: Vec<_> = array_iter
333                        .map(|i| {
334                            i.and_then(|i| {
335                                DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
336                            })
337                        })
338                        .collect();
339                    encode_field(&value, type_, format)
340                }
341            }
342            TimeUnit::Nanosecond => {
343                let array_iter = arr
344                    .as_any()
345                    .downcast_ref::<TimestampNanosecondArray>()
346                    .unwrap()
347                    .iter();
348
349                if let Some(tz) = timezone {
350                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
351                    let value: Vec<_> = array_iter
352                        .map(|i| {
353                            i.map(|i| {
354                                Utc.from_utc_datetime(
355                                    &DateTime::from_timestamp_nanos(i).naive_utc(),
356                                )
357                                .with_timezone(&tz)
358                                .fixed_offset()
359                            })
360                        })
361                        .collect();
362                    encode_field(&value, type_, format)
363                } else {
364                    let value: Vec<_> = array_iter
365                        .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
366                        .collect();
367                    encode_field(&value, type_, format)
368                }
369            }
370        },
371        DataType::Struct(_) => {
372            let fields = match type_.kind() {
373                postgres_types::Kind::Array(struct_type_) => Ok(struct_type_),
374                _ => Err(format!(
375                    "Expected list type found type {} of kind {:?}",
376                    type_,
377                    type_.kind()
378                )),
379            }
380            .and_then(|struct_type| match struct_type.kind() {
381                postgres_types::Kind::Composite(fields) => Ok(fields),
382                _ => Err(format!(
383                    "Failed to unwrap a composite type inside from type {} kind {:?}",
384                    type_,
385                    type_.kind()
386                )),
387            })
388            .map_err(ToSqlError::from)?;
389
390            let values: PgWireResult<Vec<_>> = (0..arr.len())
391                .map(|row| encode_struct(&arr, row, fields, format))
392                .map(|x| {
393                    if matches!(format, FieldFormat::Text) {
394                        x.map(|opt| {
395                            opt.map(|value| {
396                                let mut w = BytesMut::new();
397                                w.put_u8(b'"');
398                                w.put_slice(
399                                    QUOTE_ESCAPE
400                                        .replace_all(
401                                            &String::from_utf8_lossy(&value.bytes),
402                                            r#"\$1"#,
403                                        )
404                                        .as_bytes(),
405                                );
406                                w.put_u8(b'"');
407                                EncodedValue { bytes: w }
408                            })
409                        })
410                    } else {
411                        x
412                    }
413                })
414                .collect();
415            encode_field(&values?, type_, format)
416        }
417        DataType::LargeUtf8 => {
418            let value: Vec<Option<&str>> = arr
419                .as_any()
420                .downcast_ref::<LargeStringArray>()
421                .unwrap()
422                .iter()
423                .collect();
424            encode_field(&value, type_, format)
425        }
426        DataType::Decimal256(_, s) => {
427            // Convert Decimal256 to string representation for now
428            // since rust_decimal doesn't support 256-bit decimals
429            let decimal_array = arr.as_any().downcast_ref::<Decimal256Array>().unwrap();
430            let value: Vec<Option<String>> = (0..decimal_array.len())
431                .map(|i| {
432                    if decimal_array.is_null(i) {
433                        None
434                    } else {
435                        // Convert to string representation
436                        let raw_value = decimal_array.value(i);
437                        let scale = *s as u32;
438                        // Convert i256 to string and handle decimal placement manually
439                        let value_str = raw_value.to_string();
440                        if scale == 0 {
441                            Some(value_str)
442                        } else {
443                            // Insert decimal point
444                            let mut chars: Vec<char> = value_str.chars().collect();
445                            if chars.len() <= scale as usize {
446                                // Prepend zeros if needed
447                                let zeros_needed = scale as usize - chars.len() + 1;
448                                chars.splice(0..0, std::iter::repeat_n('0', zeros_needed));
449                                chars.insert(1, '.');
450                            } else {
451                                let decimal_pos = chars.len() - scale as usize;
452                                chars.insert(decimal_pos, '.');
453                            }
454                            Some(chars.into_iter().collect())
455                        }
456                    }
457                })
458                .collect();
459            encode_field(&value, type_, format)
460        }
461        DataType::Duration(_) => {
462            // Convert duration to microseconds for now
463            let value: Vec<Option<i64>> = arr
464                .as_any()
465                .downcast_ref::<DurationMicrosecondArray>()
466                .unwrap()
467                .iter()
468                .collect();
469            encode_field(&value, type_, format)
470        }
471        DataType::List(_) => {
472            // Support for nested lists (list of lists)
473            // For now, convert to string representation
474            let list_array = arr.as_any().downcast_ref::<ListArray>().unwrap();
475            let value: Vec<Option<String>> = (0..list_array.len())
476                .map(|i| {
477                    if list_array.is_null(i) {
478                        None
479                    } else {
480                        // Convert nested list to string representation
481                        Some(format!("[nested_list_{i}]"))
482                    }
483                })
484                .collect();
485            encode_field(&value, type_, format)
486        }
487        DataType::LargeList(_) => {
488            // Support for large lists
489            let list_array = arr.as_any().downcast_ref::<LargeListArray>().unwrap();
490            let value: Vec<Option<String>> = (0..list_array.len())
491                .map(|i| {
492                    if list_array.is_null(i) {
493                        None
494                    } else {
495                        Some(format!("[large_list_{i}]"))
496                    }
497                })
498                .collect();
499            encode_field(&value, type_, format)
500        }
501        DataType::Map(_, _) => {
502            // Support for map types
503            let map_array = arr.as_any().downcast_ref::<MapArray>().unwrap();
504            let value: Vec<Option<String>> = (0..map_array.len())
505                .map(|i| {
506                    if map_array.is_null(i) {
507                        None
508                    } else {
509                        Some(format!("{{map_{i}}}"))
510                    }
511                })
512                .collect();
513            encode_field(&value, type_, format)
514        }
515
516        DataType::Union(_, _) => {
517            // Support for union types
518            let value: Vec<Option<String>> = (0..arr.len())
519                .map(|i| {
520                    if arr.is_null(i) {
521                        None
522                    } else {
523                        Some(format!("union_{i}"))
524                    }
525                })
526                .collect();
527            encode_field(&value, type_, format)
528        }
529        DataType::Dictionary(_, _) => {
530            // Support for dictionary types
531            let value: Vec<Option<String>> = (0..arr.len())
532                .map(|i| {
533                    if arr.is_null(i) {
534                        None
535                    } else {
536                        Some(format!("dict_{i}"))
537                    }
538                })
539                .collect();
540            encode_field(&value, type_, format)
541        }
542        // TODO: add support for more advanced types (fixed size lists, etc.)
543        list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
544            "Unsupported List Datatype {} and array {:?}",
545            list_type, &arr
546        )))),
547    }
548}