arrow_pg/
encoder.rs

1use std::str::FromStr;
2use std::sync::Arc;
3
4#[cfg(not(feature = "datafusion"))]
5use arrow::{array::*, datatypes::*};
6use chrono::NaiveTime;
7use chrono::{NaiveDate, NaiveDateTime};
8#[cfg(feature = "datafusion")]
9use datafusion::arrow::{array::*, datatypes::*};
10use pg_interval::Interval as PgInterval;
11use pgwire::api::results::{DataRowEncoder, FieldInfo};
12use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
13use pgwire::types::ToSqlText;
14use postgres_types::ToSql;
15use rust_decimal::Decimal;
16use timezone::Tz;
17
18use crate::error::ToSqlError;
19#[cfg(feature = "geo")]
20use crate::geo_encoder::encode_geo;
21use crate::list_encoder::encode_list;
22use crate::struct_encoder::encode_struct;
23
24pub trait Encoder {
25    fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
26    where
27        T: ToSql + ToSqlText + Sized;
28}
29
30impl Encoder for DataRowEncoder {
31    fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
32    where
33        T: ToSql + ToSqlText + Sized,
34    {
35        self.encode_field_with_type_and_format(
36            value,
37            pg_field.datatype(),
38            pg_field.format(),
39            pg_field.format_options(),
40        )
41    }
42}
43
44fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
45    (!arr.is_null(idx)).then(|| {
46        arr.as_any()
47            .downcast_ref::<BooleanArray>()
48            .unwrap()
49            .value(idx)
50    })
51}
52
53macro_rules! get_primitive_value {
54    ($name:ident, $t:ty, $pt:ty) => {
55        fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
56            (!arr.is_null(idx)).then(|| {
57                arr.as_any()
58                    .downcast_ref::<PrimitiveArray<$t>>()
59                    .unwrap()
60                    .value(idx)
61            })
62        }
63    };
64}
65
66get_primitive_value!(get_i8_value, Int8Type, i8);
67get_primitive_value!(get_i16_value, Int16Type, i16);
68get_primitive_value!(get_i32_value, Int32Type, i32);
69get_primitive_value!(get_i64_value, Int64Type, i64);
70get_primitive_value!(get_u8_value, UInt8Type, u8);
71get_primitive_value!(get_u16_value, UInt16Type, u16);
72get_primitive_value!(get_u32_value, UInt32Type, u32);
73get_primitive_value!(get_u64_value, UInt64Type, u64);
74
75fn get_u64_as_decimal_value(arr: &Arc<dyn Array>, idx: usize) -> Option<Decimal> {
76    get_u64_value(arr, idx).map(Decimal::from)
77}
78get_primitive_value!(get_f32_value, Float32Type, f32);
79get_primitive_value!(get_f64_value, Float64Type, f64);
80
81fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
82    (!arr.is_null(idx)).then(|| {
83        arr.as_any()
84            .downcast_ref::<StringViewArray>()
85            .unwrap()
86            .value(idx)
87    })
88}
89
90fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
91    (!arr.is_null(idx)).then(|| {
92        arr.as_any()
93            .downcast_ref::<BinaryViewArray>()
94            .unwrap()
95            .value(idx)
96    })
97}
98
99fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
100    (!arr.is_null(idx)).then(|| {
101        arr.as_any()
102            .downcast_ref::<StringArray>()
103            .unwrap()
104            .value(idx)
105    })
106}
107
108fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
109    (!arr.is_null(idx)).then(|| {
110        arr.as_any()
111            .downcast_ref::<LargeStringArray>()
112            .unwrap()
113            .value(idx)
114    })
115}
116
117fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
118    (!arr.is_null(idx)).then(|| {
119        arr.as_any()
120            .downcast_ref::<BinaryArray>()
121            .unwrap()
122            .value(idx)
123    })
124}
125
126fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
127    (!arr.is_null(idx)).then(|| {
128        arr.as_any()
129            .downcast_ref::<LargeBinaryArray>()
130            .unwrap()
131            .value(idx)
132    })
133}
134
135fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
136    if arr.is_null(idx) {
137        return None;
138    }
139    arr.as_any()
140        .downcast_ref::<Date32Array>()
141        .unwrap()
142        .value_as_date(idx)
143}
144
145fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
146    if arr.is_null(idx) {
147        return None;
148    }
149    arr.as_any()
150        .downcast_ref::<Date64Array>()
151        .unwrap()
152        .value_as_date(idx)
153}
154
155fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
156    if arr.is_null(idx) {
157        return None;
158    }
159    arr.as_any()
160        .downcast_ref::<Time32SecondArray>()
161        .unwrap()
162        .value_as_time(idx)
163}
164
165fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
166    if arr.is_null(idx) {
167        return None;
168    }
169    arr.as_any()
170        .downcast_ref::<Time32MillisecondArray>()
171        .unwrap()
172        .value_as_time(idx)
173}
174
175fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
176    if arr.is_null(idx) {
177        return None;
178    }
179    arr.as_any()
180        .downcast_ref::<Time64MicrosecondArray>()
181        .unwrap()
182        .value_as_time(idx)
183}
184fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
185    if arr.is_null(idx) {
186        return None;
187    }
188    arr.as_any()
189        .downcast_ref::<Time64NanosecondArray>()
190        .unwrap()
191        .value_as_time(idx)
192}
193
194fn get_numeric_128_value(
195    arr: &Arc<dyn Array>,
196    idx: usize,
197    scale: u32,
198) -> PgWireResult<Option<Decimal>> {
199    if arr.is_null(idx) {
200        return Ok(None);
201    }
202
203    let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
204    let value = array.value(idx);
205    Decimal::try_from_i128_with_scale(value, scale)
206        .map_err(|e| {
207            let error_code = match e {
208                rust_decimal::Error::ExceedsMaximumPossibleValue => {
209                    "22003" // numeric_value_out_of_range
210                }
211                rust_decimal::Error::LessThanMinimumPossibleValue => {
212                    "22003" // numeric_value_out_of_range
213                }
214                rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
215                    return PgWireError::UserError(Box::new(ErrorInfo::new(
216                        "ERROR".to_string(),
217                        "22003".to_string(),
218                        format!("Scale {scale} exceeds maximum precision for numeric type"),
219                    )));
220                }
221                _ => "22003", // generic numeric_value_out_of_range
222            };
223            PgWireError::UserError(Box::new(ErrorInfo::new(
224                "ERROR".to_string(),
225                error_code.to_string(),
226                format!("Numeric value conversion failed: {e}"),
227            )))
228        })
229        .map(Some)
230}
231
232pub fn encode_value<T: Encoder>(
233    encoder: &mut T,
234    arr: &Arc<dyn Array>,
235    idx: usize,
236    arrow_field: &Field,
237    pg_field: &FieldInfo,
238) -> PgWireResult<()> {
239    let arrow_type = arrow_field.data_type();
240
241    #[cfg(feature = "geo")]
242    if let Some(geoarrow_type) = geoarrow_schema::GeoArrowType::from_extension_field(arrow_field)
243        .map_err(|e| PgWireError::ApiError(Box::new(e)))?
244    {
245        let geoarrow_array: Arc<dyn geoarrow::array::GeoArrowArray> =
246            geoarrow::array::from_arrow_array(arr, arrow_field)
247                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
248
249        return encode_geo(
250            encoder,
251            geoarrow_type,
252            &geoarrow_array,
253            idx,
254            arrow_field,
255            pg_field,
256        );
257    }
258
259    match arrow_type {
260        DataType::Null => encoder.encode_field(&None::<i8>, pg_field)?,
261        DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?,
262        DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx), pg_field)?,
263        DataType::Int16 => encoder.encode_field(&get_i16_value(arr, idx), pg_field)?,
264        DataType::Int32 => encoder.encode_field(&get_i32_value(arr, idx), pg_field)?,
265        DataType::Int64 => encoder.encode_field(&get_i64_value(arr, idx), pg_field)?,
266        DataType::UInt8 => {
267            encoder.encode_field(&(get_u8_value(arr, idx).map(|x| x as i16)), pg_field)?
268        }
269        DataType::UInt16 => {
270            encoder.encode_field(&(get_u16_value(arr, idx).map(|x| x as i32)), pg_field)?
271        }
272        DataType::UInt32 => {
273            encoder.encode_field(&get_u32_value(arr, idx).map(|x| x as i64), pg_field)?
274        }
275        DataType::UInt64 => encoder.encode_field(&get_u64_as_decimal_value(arr, idx), pg_field)?,
276        DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx), pg_field)?,
277        DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx), pg_field)?,
278        DataType::Decimal128(_, s) => {
279            encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?, pg_field)?
280        }
281        DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx), pg_field)?,
282        DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx), pg_field)?,
283        DataType::BinaryView => encoder.encode_field(&get_binary_view_value(arr, idx), pg_field)?,
284        DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx), pg_field)?,
285        DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx), pg_field)?,
286        DataType::LargeBinary => {
287            encoder.encode_field(&get_large_binary_value(arr, idx), pg_field)?
288        }
289        DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx), pg_field)?,
290        DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx), pg_field)?,
291        DataType::Time32(unit) => match unit {
292            TimeUnit::Second => {
293                encoder.encode_field(&get_time32_second_value(arr, idx), pg_field)?
294            }
295            TimeUnit::Millisecond => {
296                encoder.encode_field(&get_time32_millisecond_value(arr, idx), pg_field)?
297            }
298            _ => {}
299        },
300        DataType::Time64(unit) => match unit {
301            TimeUnit::Microsecond => {
302                encoder.encode_field(&get_time64_microsecond_value(arr, idx), pg_field)?
303            }
304            TimeUnit::Nanosecond => {
305                encoder.encode_field(&get_time64_nanosecond_value(arr, idx), pg_field)?
306            }
307            _ => {}
308        },
309        DataType::Timestamp(unit, timezone) => match unit {
310            TimeUnit::Second => {
311                if arr.is_null(idx) {
312                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
313                }
314                let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
315                if let Some(tz) = timezone {
316                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
317                    let value = ts_array
318                        .value_as_datetime_with_tz(idx, tz)
319                        .map(|d| d.fixed_offset());
320
321                    encoder.encode_field(&value, pg_field)?;
322                } else {
323                    let value = ts_array.value_as_datetime(idx);
324                    encoder.encode_field(&value, pg_field)?;
325                }
326            }
327            TimeUnit::Millisecond => {
328                if arr.is_null(idx) {
329                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
330                }
331                let ts_array = arr
332                    .as_any()
333                    .downcast_ref::<TimestampMillisecondArray>()
334                    .unwrap();
335                if let Some(tz) = timezone {
336                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
337                    let value = ts_array
338                        .value_as_datetime_with_tz(idx, tz)
339                        .map(|d| d.fixed_offset());
340                    encoder.encode_field(&value, pg_field)?;
341                } else {
342                    let value = ts_array.value_as_datetime(idx);
343                    encoder.encode_field(&value, pg_field)?;
344                }
345            }
346            TimeUnit::Microsecond => {
347                if arr.is_null(idx) {
348                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
349                }
350                let ts_array = arr
351                    .as_any()
352                    .downcast_ref::<TimestampMicrosecondArray>()
353                    .unwrap();
354                if let Some(tz) = timezone {
355                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
356                    let value = ts_array
357                        .value_as_datetime_with_tz(idx, tz)
358                        .map(|d| d.fixed_offset());
359                    encoder.encode_field(&value, pg_field)?;
360                } else {
361                    let value = ts_array.value_as_datetime(idx);
362                    encoder.encode_field(&value, pg_field)?;
363                }
364            }
365            TimeUnit::Nanosecond => {
366                if arr.is_null(idx) {
367                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
368                }
369                let ts_array = arr
370                    .as_any()
371                    .downcast_ref::<TimestampNanosecondArray>()
372                    .unwrap();
373                if let Some(tz) = timezone {
374                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
375                    let value = ts_array
376                        .value_as_datetime_with_tz(idx, tz)
377                        .map(|d| d.fixed_offset());
378                    encoder.encode_field(&value, pg_field)?;
379                } else {
380                    let value = ts_array.value_as_datetime(idx);
381                    encoder.encode_field(&value, pg_field)?;
382                }
383            }
384        },
385        DataType::Interval(interval_unit) => match interval_unit {
386            IntervalUnit::YearMonth => {
387                let interval_array = arr
388                    .as_any()
389                    .downcast_ref::<IntervalYearMonthArray>()
390                    .unwrap();
391                let months = IntervalYearMonthType::to_months(interval_array.value(idx));
392                encoder.encode_field(&PgInterval::new(months, 0, 0), pg_field)?;
393            }
394            IntervalUnit::DayTime => {
395                let interval_array = arr.as_any().downcast_ref::<IntervalDayTimeArray>().unwrap();
396                let (days, millis) = IntervalDayTimeType::to_parts(interval_array.value(idx));
397                encoder
398                    .encode_field(&PgInterval::new(0, days, millis as i64 * 1000i64), pg_field)?;
399            }
400            IntervalUnit::MonthDayNano => {
401                let interval_array = arr
402                    .as_any()
403                    .downcast_ref::<IntervalMonthDayNanoArray>()
404                    .unwrap();
405                let (months, days, nanoseconds) =
406                    IntervalMonthDayNanoType::to_parts(interval_array.value(idx));
407
408                encoder.encode_field(
409                    &PgInterval::new(months, days, nanoseconds / 1000i64),
410                    pg_field,
411                )?;
412            }
413        },
414        DataType::Duration(unit) => match unit {
415            TimeUnit::Second => {
416                if arr.is_null(idx) {
417                    return encoder.encode_field(&None::<PgInterval>, pg_field);
418                }
419                let duration_array = arr.as_any().downcast_ref::<DurationSecondArray>().unwrap();
420                let microseconds = duration_array.value(idx) * 1_000_000i64;
421                encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
422            }
423            TimeUnit::Millisecond => {
424                if arr.is_null(idx) {
425                    return encoder.encode_field(&None::<PgInterval>, pg_field);
426                }
427                let duration_array = arr
428                    .as_any()
429                    .downcast_ref::<DurationMillisecondArray>()
430                    .unwrap();
431                let microseconds = duration_array.value(idx) * 1_000i64;
432                encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
433            }
434            TimeUnit::Microsecond => {
435                if arr.is_null(idx) {
436                    return encoder.encode_field(&None::<PgInterval>, pg_field);
437                }
438                let duration_array = arr
439                    .as_any()
440                    .downcast_ref::<DurationMicrosecondArray>()
441                    .unwrap();
442                let microseconds = duration_array.value(idx);
443                encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
444            }
445            TimeUnit::Nanosecond => {
446                if arr.is_null(idx) {
447                    return encoder.encode_field(&None::<PgInterval>, pg_field);
448                }
449                let duration_array = arr
450                    .as_any()
451                    .downcast_ref::<DurationNanosecondArray>()
452                    .unwrap();
453                let microseconds = duration_array.value(idx) / 1_000i64;
454                encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
455            }
456        },
457        DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
458            if arr.is_null(idx) {
459                return encoder.encode_field(&None::<&[i8]>, pg_field);
460            }
461            let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
462            encode_list(encoder, array, pg_field)?
463        }
464        DataType::Struct(arrow_fields) => encode_struct(encoder, arr, idx, arrow_fields, pg_field)?,
465        DataType::Dictionary(_, value_type) => {
466            if arr.is_null(idx) {
467                return encoder.encode_field(&None::<i8>, pg_field);
468            }
469            // Get the dictionary values and the mapped row index
470            macro_rules! get_dict_values_and_index {
471                ($key_type:ty) => {
472                    arr.as_any()
473                        .downcast_ref::<DictionaryArray<$key_type>>()
474                        .map(|dict| (dict.values(), dict.keys().value(idx) as usize))
475                };
476            }
477
478            // Try to extract values using different key types
479            let (values, idx) = get_dict_values_and_index!(Int8Type)
480                .or_else(|| get_dict_values_and_index!(Int16Type))
481                .or_else(|| get_dict_values_and_index!(Int32Type))
482                .or_else(|| get_dict_values_and_index!(Int64Type))
483                .or_else(|| get_dict_values_and_index!(UInt8Type))
484                .or_else(|| get_dict_values_and_index!(UInt16Type))
485                .or_else(|| get_dict_values_and_index!(UInt32Type))
486                .or_else(|| get_dict_values_and_index!(UInt64Type))
487                .ok_or_else(|| {
488                    ToSqlError::from(format!(
489                        "Unsupported dictionary key type for value type {value_type}"
490                    ))
491                })?;
492
493            let inner_arrow_field = Field::new(pg_field.name(), *value_type.clone(), true);
494
495            encode_value(encoder, values, idx, &inner_arrow_field, pg_field)?
496        }
497        _ => {
498            return Err(PgWireError::ApiError(ToSqlError::from(format!(
499                "Unsupported Datatype {} and array {:?}",
500                arr.data_type(),
501                &arr
502            ))));
503        }
504    }
505
506    Ok(())
507}
508
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use pgwire::{api::results::FieldFormat, types::format::FormatOptions};
    use postgres_types::Type;

    use super::*;

    /// A dictionary-encoded string column must encode the looked-up string,
    /// not the key.
    #[test]
    fn encodes_dictionary_array() {
        /// Minimal encoder that records the text form of the last field.
        #[derive(Default)]
        struct MockEncoder {
            encoded_value: String,
        }

        impl Encoder for MockEncoder {
            fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
            where
                T: ToSql + ToSqlText + Sized,
            {
                let mut bytes = BytesMut::new();
                // Propagate encoding failures so the test cannot pass on a
                // value that was never written.
                value
                    .to_sql_text(pg_field.datatype(), &mut bytes, &FormatOptions::default())
                    .map_err(PgWireError::ApiError)?;
                self.encoded_value =
                    String::from_utf8(bytes.to_vec()).expect("encoded text should be valid UTF-8");
                Ok(())
            }
        }

        let val = "~!@&$[]()@@!!";
        let value = StringArray::from_iter_values([val]);
        let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
        let dict_arr: Arc<dyn Array> =
            Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());

        let mut encoder = MockEncoder::default();

        let arrow_field = Field::new(
            "x",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
        );
        let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text);
        let result = encode_value(&mut encoder, &dict_arr, 2, &arrow_field, &pg_field);

        assert!(result.is_ok());

        assert_eq!(encoder.encoded_value, val);
    }

    #[test]
    fn test_get_time32_second_value() {
        let array = Time32SecondArray::from_iter_values([3723_i32]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time32_second_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_opt(1, 2, 3));
    }

    #[test]
    fn test_get_time32_millisecond_value() {
        let array = Time32MillisecondArray::from_iter_values([3723001_i32]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time32_millisecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_milli_opt(1, 2, 3, 1));
    }

    #[test]
    fn test_get_time64_microsecond_value() {
        let array = Time64MicrosecondArray::from_iter_values([3723001001_i64]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time64_microsecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_micro_opt(1, 2, 3, 1001));
    }

    #[test]
    fn test_get_time64_nanosecond_value() {
        let array = Time64NanosecondArray::from_iter_values([3723001001001_i64]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time64_nanosecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_nano_opt(1, 2, 3, 1001001));
    }
}