arrow_pg/
encoder.rs

1use std::error::Error;
2use std::io::Write;
3use std::str::FromStr;
4use std::sync::Arc;
5
6#[cfg(not(feature = "datafusion"))]
7use arrow::{array::*, datatypes::*};
8use bytes::BufMut;
9use bytes::BytesMut;
10use chrono::NaiveTime;
11use chrono::{NaiveDate, NaiveDateTime};
12#[cfg(feature = "datafusion")]
13use datafusion::arrow::{array::*, datatypes::*};
14use pgwire::api::results::{DataRowEncoder, FieldInfo};
15use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
16use pgwire::types::format::FormatOptions;
17use pgwire::types::ToSqlText;
18use postgres_types::{ToSql, Type};
19use rust_decimal::Decimal;
20use timezone::Tz;
21
22use crate::error::ToSqlError;
23use crate::list_encoder::encode_list;
24use crate::struct_encoder::encode_struct;
25
26pub trait Encoder {
27    fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
28    where
29        T: ToSql + ToSqlText + Sized;
30}
31
32impl Encoder for DataRowEncoder {
33    fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
34    where
35        T: ToSql + ToSqlText + Sized,
36    {
37        self.encode_field_with_type_and_format(
38            value,
39            pg_field.datatype(),
40            pg_field.format(),
41            pg_field.format_options(),
42        )
43    }
44}
45
46pub(crate) struct EncodedValue {
47    pub(crate) bytes: BytesMut,
48}
49
50impl std::fmt::Debug for EncodedValue {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("EncodedValue").finish()
53    }
54}
55
56impl ToSql for EncodedValue {
57    fn to_sql(
58        &self,
59        _ty: &Type,
60        out: &mut BytesMut,
61    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
62    where
63        Self: Sized,
64    {
65        out.writer().write_all(&self.bytes)?;
66        Ok(postgres_types::IsNull::No)
67    }
68
69    fn accepts(_ty: &Type) -> bool
70    where
71        Self: Sized,
72    {
73        true
74    }
75
76    fn to_sql_checked(
77        &self,
78        ty: &Type,
79        out: &mut BytesMut,
80    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
81        self.to_sql(ty, out)
82    }
83}
84
85impl ToSqlText for EncodedValue {
86    fn to_sql_text(
87        &self,
88        _ty: &Type,
89        out: &mut BytesMut,
90        _format_options: &FormatOptions,
91    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
92    where
93        Self: Sized,
94    {
95        out.writer().write_all(&self.bytes)?;
96        Ok(postgres_types::IsNull::No)
97    }
98}
99
100fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
101    (!arr.is_null(idx)).then(|| {
102        arr.as_any()
103            .downcast_ref::<BooleanArray>()
104            .unwrap()
105            .value(idx)
106    })
107}
108
109macro_rules! get_primitive_value {
110    ($name:ident, $t:ty, $pt:ty) => {
111        fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
112            (!arr.is_null(idx)).then(|| {
113                arr.as_any()
114                    .downcast_ref::<PrimitiveArray<$t>>()
115                    .unwrap()
116                    .value(idx)
117            })
118        }
119    };
120}
121
122get_primitive_value!(get_i8_value, Int8Type, i8);
123get_primitive_value!(get_i16_value, Int16Type, i16);
124get_primitive_value!(get_i32_value, Int32Type, i32);
125get_primitive_value!(get_i64_value, Int64Type, i64);
126get_primitive_value!(get_u8_value, UInt8Type, u8);
127get_primitive_value!(get_u16_value, UInt16Type, u16);
128get_primitive_value!(get_u32_value, UInt32Type, u32);
129get_primitive_value!(get_u64_value, UInt64Type, u64);
130get_primitive_value!(get_f32_value, Float32Type, f32);
131get_primitive_value!(get_f64_value, Float64Type, f64);
132
133fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
134    (!arr.is_null(idx)).then(|| {
135        arr.as_any()
136            .downcast_ref::<StringViewArray>()
137            .unwrap()
138            .value(idx)
139    })
140}
141
142fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
143    (!arr.is_null(idx)).then(|| {
144        arr.as_any()
145            .downcast_ref::<BinaryViewArray>()
146            .unwrap()
147            .value(idx)
148    })
149}
150
151fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
152    (!arr.is_null(idx)).then(|| {
153        arr.as_any()
154            .downcast_ref::<StringArray>()
155            .unwrap()
156            .value(idx)
157    })
158}
159
160fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
161    (!arr.is_null(idx)).then(|| {
162        arr.as_any()
163            .downcast_ref::<LargeStringArray>()
164            .unwrap()
165            .value(idx)
166    })
167}
168
169fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
170    (!arr.is_null(idx)).then(|| {
171        arr.as_any()
172            .downcast_ref::<BinaryArray>()
173            .unwrap()
174            .value(idx)
175    })
176}
177
178fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
179    (!arr.is_null(idx)).then(|| {
180        arr.as_any()
181            .downcast_ref::<LargeBinaryArray>()
182            .unwrap()
183            .value(idx)
184    })
185}
186
187fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
188    if arr.is_null(idx) {
189        return None;
190    }
191    arr.as_any()
192        .downcast_ref::<Date32Array>()
193        .unwrap()
194        .value_as_date(idx)
195}
196
197fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
198    if arr.is_null(idx) {
199        return None;
200    }
201    arr.as_any()
202        .downcast_ref::<Date64Array>()
203        .unwrap()
204        .value_as_date(idx)
205}
206
207fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
208    if arr.is_null(idx) {
209        return None;
210    }
211    arr.as_any()
212        .downcast_ref::<Time32SecondArray>()
213        .unwrap()
214        .value_as_time(idx)
215}
216
217fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
218    if arr.is_null(idx) {
219        return None;
220    }
221    arr.as_any()
222        .downcast_ref::<Time32MillisecondArray>()
223        .unwrap()
224        .value_as_time(idx)
225}
226
227fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
228    if arr.is_null(idx) {
229        return None;
230    }
231    arr.as_any()
232        .downcast_ref::<Time64MicrosecondArray>()
233        .unwrap()
234        .value_as_time(idx)
235}
236fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
237    if arr.is_null(idx) {
238        return None;
239    }
240    arr.as_any()
241        .downcast_ref::<Time64NanosecondArray>()
242        .unwrap()
243        .value_as_time(idx)
244}
245
246fn get_numeric_128_value(
247    arr: &Arc<dyn Array>,
248    idx: usize,
249    scale: u32,
250) -> PgWireResult<Option<Decimal>> {
251    if arr.is_null(idx) {
252        return Ok(None);
253    }
254
255    let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
256    let value = array.value(idx);
257    Decimal::try_from_i128_with_scale(value, scale)
258        .map_err(|e| {
259            let error_code = match e {
260                rust_decimal::Error::ExceedsMaximumPossibleValue => {
261                    "22003" // numeric_value_out_of_range
262                }
263                rust_decimal::Error::LessThanMinimumPossibleValue => {
264                    "22003" // numeric_value_out_of_range
265                }
266                rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
267                    return PgWireError::UserError(Box::new(ErrorInfo::new(
268                        "ERROR".to_string(),
269                        "22003".to_string(),
270                        format!("Scale {scale} exceeds maximum precision for numeric type"),
271                    )));
272                }
273                _ => "22003", // generic numeric_value_out_of_range
274            };
275            PgWireError::UserError(Box::new(ErrorInfo::new(
276                "ERROR".to_string(),
277                error_code.to_string(),
278                format!("Numeric value conversion failed: {e}"),
279            )))
280        })
281        .map(Some)
282}
283
284pub fn encode_value<T: Encoder>(
285    encoder: &mut T,
286    arr: &Arc<dyn Array>,
287    idx: usize,
288    pg_field: &FieldInfo,
289) -> PgWireResult<()> {
290    let type_ = pg_field.datatype();
291
292    match arr.data_type() {
293        DataType::Null => encoder.encode_field(&None::<i8>, pg_field)?,
294        DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?,
295        DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx), pg_field)?,
296        DataType::Int16 => encoder.encode_field(&get_i16_value(arr, idx), pg_field)?,
297        DataType::Int32 => encoder.encode_field(&get_i32_value(arr, idx), pg_field)?,
298        DataType::Int64 => encoder.encode_field(&get_i64_value(arr, idx), pg_field)?,
299        DataType::UInt8 => {
300            encoder.encode_field(&(get_u8_value(arr, idx).map(|x| x as i8)), pg_field)?
301        }
302        DataType::UInt16 => {
303            encoder.encode_field(&(get_u16_value(arr, idx).map(|x| x as i16)), pg_field)?
304        }
305        DataType::UInt32 => encoder.encode_field(&get_u32_value(arr, idx), pg_field)?,
306        DataType::UInt64 => {
307            encoder.encode_field(&(get_u64_value(arr, idx).map(|x| x as i64)), pg_field)?
308        }
309        DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx), pg_field)?,
310        DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx), pg_field)?,
311        DataType::Decimal128(_, s) => {
312            encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?, pg_field)?
313        }
314        DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx), pg_field)?,
315        DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx), pg_field)?,
316        DataType::BinaryView => encoder.encode_field(&get_binary_view_value(arr, idx), pg_field)?,
317        DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx), pg_field)?,
318        DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx), pg_field)?,
319        DataType::LargeBinary => {
320            encoder.encode_field(&get_large_binary_value(arr, idx), pg_field)?
321        }
322        DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx), pg_field)?,
323        DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx), pg_field)?,
324        DataType::Time32(unit) => match unit {
325            TimeUnit::Second => {
326                encoder.encode_field(&get_time32_second_value(arr, idx), pg_field)?
327            }
328            TimeUnit::Millisecond => {
329                encoder.encode_field(&get_time32_millisecond_value(arr, idx), pg_field)?
330            }
331            _ => {}
332        },
333        DataType::Time64(unit) => match unit {
334            TimeUnit::Microsecond => {
335                encoder.encode_field(&get_time64_microsecond_value(arr, idx), pg_field)?
336            }
337            TimeUnit::Nanosecond => {
338                encoder.encode_field(&get_time64_nanosecond_value(arr, idx), pg_field)?
339            }
340            _ => {}
341        },
342        DataType::Timestamp(unit, timezone) => match unit {
343            TimeUnit::Second => {
344                if arr.is_null(idx) {
345                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
346                }
347                let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
348                if let Some(tz) = timezone {
349                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
350                    let value = ts_array
351                        .value_as_datetime_with_tz(idx, tz)
352                        .map(|d| d.fixed_offset());
353
354                    encoder.encode_field(&value, pg_field)?;
355                } else {
356                    let value = ts_array.value_as_datetime(idx);
357                    encoder.encode_field(&value, pg_field)?;
358                }
359            }
360            TimeUnit::Millisecond => {
361                if arr.is_null(idx) {
362                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
363                }
364                let ts_array = arr
365                    .as_any()
366                    .downcast_ref::<TimestampMillisecondArray>()
367                    .unwrap();
368                if let Some(tz) = timezone {
369                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
370                    let value = ts_array
371                        .value_as_datetime_with_tz(idx, tz)
372                        .map(|d| d.fixed_offset());
373                    encoder.encode_field(&value, pg_field)?;
374                } else {
375                    let value = ts_array.value_as_datetime(idx);
376                    encoder.encode_field(&value, pg_field)?;
377                }
378            }
379            TimeUnit::Microsecond => {
380                if arr.is_null(idx) {
381                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
382                }
383                let ts_array = arr
384                    .as_any()
385                    .downcast_ref::<TimestampMicrosecondArray>()
386                    .unwrap();
387                if let Some(tz) = timezone {
388                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
389                    let value = ts_array
390                        .value_as_datetime_with_tz(idx, tz)
391                        .map(|d| d.fixed_offset());
392                    encoder.encode_field(&value, pg_field)?;
393                } else {
394                    let value = ts_array.value_as_datetime(idx);
395                    encoder.encode_field(&value, pg_field)?;
396                }
397            }
398            TimeUnit::Nanosecond => {
399                if arr.is_null(idx) {
400                    return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
401                }
402                let ts_array = arr
403                    .as_any()
404                    .downcast_ref::<TimestampNanosecondArray>()
405                    .unwrap();
406                if let Some(tz) = timezone {
407                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
408                    let value = ts_array
409                        .value_as_datetime_with_tz(idx, tz)
410                        .map(|d| d.fixed_offset());
411                    encoder.encode_field(&value, pg_field)?;
412                } else {
413                    let value = ts_array.value_as_datetime(idx);
414                    encoder.encode_field(&value, pg_field)?;
415                }
416            }
417        },
418        DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
419            if arr.is_null(idx) {
420                return encoder.encode_field(&None::<&[i8]>, pg_field);
421            }
422            let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
423            let value = encode_list(array, pg_field)?;
424            encoder.encode_field(&value, pg_field)?
425        }
426        DataType::Struct(_) => {
427            let fields = match type_.kind() {
428                postgres_types::Kind::Composite(fields) => fields,
429                _ => {
430                    return Err(PgWireError::ApiError(ToSqlError::from(format!(
431                        "Failed to unwrap a composite type from type {type_}"
432                    ))));
433                }
434            };
435            let value = encode_struct(arr, idx, fields, pg_field)?;
436            encoder.encode_field(&value, pg_field)?
437        }
438        DataType::Dictionary(_, value_type) => {
439            if arr.is_null(idx) {
440                return encoder.encode_field(&None::<i8>, pg_field);
441            }
442            // Get the dictionary values and the mapped row index
443            macro_rules! get_dict_values_and_index {
444                ($key_type:ty) => {
445                    arr.as_any()
446                        .downcast_ref::<DictionaryArray<$key_type>>()
447                        .map(|dict| (dict.values(), dict.keys().value(idx) as usize))
448                };
449            }
450
451            // Try to extract values using different key types
452            let (values, idx) = get_dict_values_and_index!(Int8Type)
453                .or_else(|| get_dict_values_and_index!(Int16Type))
454                .or_else(|| get_dict_values_and_index!(Int32Type))
455                .or_else(|| get_dict_values_and_index!(Int64Type))
456                .or_else(|| get_dict_values_and_index!(UInt8Type))
457                .or_else(|| get_dict_values_and_index!(UInt16Type))
458                .or_else(|| get_dict_values_and_index!(UInt32Type))
459                .or_else(|| get_dict_values_and_index!(UInt64Type))
460                .ok_or_else(|| {
461                    ToSqlError::from(format!(
462                        "Unsupported dictionary key type for value type {value_type}"
463                    ))
464                })?;
465
466            encode_value(encoder, values, idx, pg_field)?
467        }
468        _ => {
469            return Err(PgWireError::ApiError(ToSqlError::from(format!(
470                "Unsupported Datatype {} and array {:?}",
471                arr.data_type(),
472                &arr
473            ))));
474        }
475    }
476
477    Ok(())
478}
479
#[cfg(test)]
mod tests {
    use pgwire::api::results::FieldFormat;

    use super::*;

    /// Records the text-format rendering of the most recently encoded field.
    #[derive(Default)]
    struct MockEncoder {
        encoded_value: String,
    }

    impl Encoder for MockEncoder {
        fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
        where
            T: ToSql + ToSqlText + Sized,
        {
            let mut bytes = BytesMut::new();
            // Fail the test loudly instead of silently ignoring encode errors.
            value
                .to_sql_text(pg_field.datatype(), &mut bytes, &FormatOptions::default())
                .expect("text encoding should succeed");
            self.encoded_value =
                String::from_utf8(bytes.to_vec()).expect("text encoding should be UTF-8");
            Ok(())
        }
    }

    #[test]
    fn encodes_dictionary_array() {
        let val = "~!@&$[]()@@!!";
        let values = StringArray::from_iter_values([val]);
        let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
        let dict_arr: Arc<dyn Array> =
            Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap());

        let mut encoder = MockEncoder::default();
        let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text);
        encode_value(&mut encoder, &dict_arr, 2, &pg_field).expect("encoding should succeed");

        assert_eq!(encoder.encoded_value, val);
    }

    #[test]
    fn test_get_time32_second_value() {
        let array: Arc<dyn Array> = Arc::new(Time32SecondArray::from_iter_values([3723_i32]));
        assert_eq!(
            get_time32_second_value(&array, 0),
            NaiveTime::from_hms_opt(1, 2, 3)
        );
    }

    #[test]
    fn test_get_time32_millisecond_value() {
        let array: Arc<dyn Array> =
            Arc::new(Time32MillisecondArray::from_iter_values([3723001_i32]));
        assert_eq!(
            get_time32_millisecond_value(&array, 0),
            NaiveTime::from_hms_milli_opt(1, 2, 3, 1)
        );
    }

    #[test]
    fn test_get_time64_microsecond_value() {
        let array: Arc<dyn Array> =
            Arc::new(Time64MicrosecondArray::from_iter_values([3723001001_i64]));
        assert_eq!(
            get_time64_microsecond_value(&array, 0),
            NaiveTime::from_hms_micro_opt(1, 2, 3, 1001)
        );
    }

    #[test]
    fn test_get_time64_nanosecond_value() {
        let array: Arc<dyn Array> =
            Arc::new(Time64NanosecondArray::from_iter_values([3723001001001_i64]));
        assert_eq!(
            get_time64_nanosecond_value(&array, 0),
            NaiveTime::from_hms_nano_opt(1, 2, 3, 1001001)
        );
    }

    #[test]
    fn time_getter_returns_none_for_null_slot() {
        let array: Arc<dyn Array> = Arc::new(Time32SecondArray::from(vec![None, Some(1)]));
        assert_eq!(get_time32_second_value(&array, 0), None);
        assert_eq!(
            get_time32_second_value(&array, 1),
            NaiveTime::from_hms_opt(0, 0, 1)
        );
    }
}