arrow_pg/
encoder.rs

use std::error::Error;
use std::io::Write;
use std::str::FromStr;
use std::sync::Arc;

#[cfg(not(feature = "datafusion"))]
use arrow::{array::*, datatypes::*};
use bytes::BufMut;
use bytes::BytesMut;
use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
#[cfg(feature = "datafusion")]
use datafusion::arrow::{array::*, datatypes::*};
use pgwire::api::results::DataRowEncoder;
use pgwire::api::results::FieldFormat;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::types::ToSqlText;
use postgres_types::{ToSql, Type};
use rust_decimal::Decimal;
use timezone::Tz;

use crate::error::ToSqlError;
use crate::list_encoder::encode_list;
use crate::struct_encoder::encode_struct;
24
25pub trait Encoder {
26    fn encode_field_with_type_and_format<T>(
27        &mut self,
28        value: &T,
29        data_type: &Type,
30        format: FieldFormat,
31    ) -> PgWireResult<()>
32    where
33        T: ToSql + ToSqlText + Sized;
34}
35
36impl Encoder for DataRowEncoder {
37    fn encode_field_with_type_and_format<T>(
38        &mut self,
39        value: &T,
40        data_type: &Type,
41        format: FieldFormat,
42    ) -> PgWireResult<()>
43    where
44        T: ToSql + ToSqlText + Sized,
45    {
46        self.encode_field_with_type_and_format(value, data_type, format)
47    }
48}
49
50pub(crate) struct EncodedValue {
51    pub(crate) bytes: BytesMut,
52}
53
54impl std::fmt::Debug for EncodedValue {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("EncodedValue").finish()
57    }
58}
59
60impl ToSql for EncodedValue {
61    fn to_sql(
62        &self,
63        _ty: &Type,
64        out: &mut BytesMut,
65    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
66    where
67        Self: Sized,
68    {
69        out.writer().write_all(&self.bytes)?;
70        Ok(postgres_types::IsNull::No)
71    }
72
73    fn accepts(_ty: &Type) -> bool
74    where
75        Self: Sized,
76    {
77        true
78    }
79
80    fn to_sql_checked(
81        &self,
82        ty: &Type,
83        out: &mut BytesMut,
84    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
85        self.to_sql(ty, out)
86    }
87}
88
89impl ToSqlText for EncodedValue {
90    fn to_sql_text(
91        &self,
92        _ty: &Type,
93        out: &mut BytesMut,
94    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
95    where
96        Self: Sized,
97    {
98        out.writer().write_all(&self.bytes)?;
99        Ok(postgres_types::IsNull::No)
100    }
101}
102
103fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
104    (!arr.is_null(idx)).then(|| {
105        arr.as_any()
106            .downcast_ref::<BooleanArray>()
107            .unwrap()
108            .value(idx)
109    })
110}
111
112macro_rules! get_primitive_value {
113    ($name:ident, $t:ty, $pt:ty) => {
114        fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
115            (!arr.is_null(idx)).then(|| {
116                arr.as_any()
117                    .downcast_ref::<PrimitiveArray<$t>>()
118                    .unwrap()
119                    .value(idx)
120            })
121        }
122    };
123}
124
125get_primitive_value!(get_i8_value, Int8Type, i8);
126get_primitive_value!(get_i16_value, Int16Type, i16);
127get_primitive_value!(get_i32_value, Int32Type, i32);
128get_primitive_value!(get_i64_value, Int64Type, i64);
129get_primitive_value!(get_u8_value, UInt8Type, u8);
130get_primitive_value!(get_u16_value, UInt16Type, u16);
131get_primitive_value!(get_u32_value, UInt32Type, u32);
132get_primitive_value!(get_u64_value, UInt64Type, u64);
133get_primitive_value!(get_f32_value, Float32Type, f32);
134get_primitive_value!(get_f64_value, Float64Type, f64);
135
136fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
137    (!arr.is_null(idx)).then(|| {
138        arr.as_any()
139            .downcast_ref::<StringViewArray>()
140            .unwrap()
141            .value(idx)
142    })
143}
144
145fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
146    (!arr.is_null(idx)).then(|| {
147        arr.as_any()
148            .downcast_ref::<BinaryViewArray>()
149            .unwrap()
150            .value(idx)
151    })
152}
153
154fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
155    (!arr.is_null(idx)).then(|| {
156        arr.as_any()
157            .downcast_ref::<StringArray>()
158            .unwrap()
159            .value(idx)
160    })
161}
162
163fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
164    (!arr.is_null(idx)).then(|| {
165        arr.as_any()
166            .downcast_ref::<LargeStringArray>()
167            .unwrap()
168            .value(idx)
169    })
170}
171
172fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
173    (!arr.is_null(idx)).then(|| {
174        arr.as_any()
175            .downcast_ref::<BinaryArray>()
176            .unwrap()
177            .value(idx)
178    })
179}
180
181fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
182    (!arr.is_null(idx)).then(|| {
183        arr.as_any()
184            .downcast_ref::<LargeBinaryArray>()
185            .unwrap()
186            .value(idx)
187    })
188}
189
190fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
191    if arr.is_null(idx) {
192        return None;
193    }
194    arr.as_any()
195        .downcast_ref::<Date32Array>()
196        .unwrap()
197        .value_as_date(idx)
198}
199
200fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
201    if arr.is_null(idx) {
202        return None;
203    }
204    arr.as_any()
205        .downcast_ref::<Date64Array>()
206        .unwrap()
207        .value_as_date(idx)
208}
209
210fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
211    if arr.is_null(idx) {
212        return None;
213    }
214    arr.as_any()
215        .downcast_ref::<Time32SecondArray>()
216        .unwrap()
217        .value_as_datetime(idx)
218}
219
220fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
221    if arr.is_null(idx) {
222        return None;
223    }
224    arr.as_any()
225        .downcast_ref::<Time32MillisecondArray>()
226        .unwrap()
227        .value_as_datetime(idx)
228}
229
230fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
231    if arr.is_null(idx) {
232        return None;
233    }
234    arr.as_any()
235        .downcast_ref::<Time64MicrosecondArray>()
236        .unwrap()
237        .value_as_datetime(idx)
238}
239fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
240    if arr.is_null(idx) {
241        return None;
242    }
243    arr.as_any()
244        .downcast_ref::<Time64NanosecondArray>()
245        .unwrap()
246        .value_as_datetime(idx)
247}
248
249fn get_numeric_128_value(
250    arr: &Arc<dyn Array>,
251    idx: usize,
252    scale: u32,
253) -> PgWireResult<Option<Decimal>> {
254    if arr.is_null(idx) {
255        return Ok(None);
256    }
257
258    let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
259    let value = array.value(idx);
260    Decimal::try_from_i128_with_scale(value, scale)
261        .map_err(|e| {
262            let error_code = match e {
263                rust_decimal::Error::ExceedsMaximumPossibleValue => {
264                    "22003" // numeric_value_out_of_range
265                }
266                rust_decimal::Error::LessThanMinimumPossibleValue => {
267                    "22003" // numeric_value_out_of_range
268                }
269                rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
270                    return PgWireError::UserError(Box::new(ErrorInfo::new(
271                        "ERROR".to_string(),
272                        "22003".to_string(),
273                        format!("Scale {scale} exceeds maximum precision for numeric type"),
274                    )));
275                }
276                _ => "22003", // generic numeric_value_out_of_range
277            };
278            PgWireError::UserError(Box::new(ErrorInfo::new(
279                "ERROR".to_string(),
280                error_code.to_string(),
281                format!("Numeric value conversion failed: {e}"),
282            )))
283        })
284        .map(Some)
285}
286
287pub fn encode_value<T: Encoder>(
288    encoder: &mut T,
289    arr: &Arc<dyn Array>,
290    idx: usize,
291    type_: &Type,
292    format: FieldFormat,
293) -> PgWireResult<()> {
294    match arr.data_type() {
295        DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
296        DataType::Boolean => {
297            encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
298        }
299        DataType::Int8 => {
300            encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
301        }
302        DataType::Int16 => {
303            encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
304        }
305        DataType::Int32 => {
306            encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
307        }
308        DataType::Int64 => {
309            encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
310        }
311        DataType::UInt8 => encoder.encode_field_with_type_and_format(
312            &(get_u8_value(arr, idx).map(|x| x as i8)),
313            type_,
314            format,
315        )?,
316        DataType::UInt16 => encoder.encode_field_with_type_and_format(
317            &(get_u16_value(arr, idx).map(|x| x as i16)),
318            type_,
319            format,
320        )?,
321        DataType::UInt32 => {
322            encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
323        }
324        DataType::UInt64 => encoder.encode_field_with_type_and_format(
325            &(get_u64_value(arr, idx).map(|x| x as i64)),
326            type_,
327            format,
328        )?,
329        DataType::Float32 => {
330            encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
331        }
332        DataType::Float64 => {
333            encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
334        }
335        DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
336            &get_numeric_128_value(arr, idx, *s as u32)?,
337            type_,
338            format,
339        )?,
340        DataType::Utf8 => {
341            encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
342        }
343        DataType::Utf8View => encoder.encode_field_with_type_and_format(
344            &get_utf8_view_value(arr, idx),
345            type_,
346            format,
347        )?,
348        DataType::BinaryView => encoder.encode_field_with_type_and_format(
349            &get_binary_view_value(arr, idx),
350            type_,
351            format,
352        )?,
353        DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
354            &get_large_utf8_value(arr, idx),
355            type_,
356            format,
357        )?,
358        DataType::Binary => {
359            encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
360        }
361        DataType::LargeBinary => encoder.encode_field_with_type_and_format(
362            &get_large_binary_value(arr, idx),
363            type_,
364            format,
365        )?,
366        DataType::Date32 => {
367            encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
368        }
369        DataType::Date64 => {
370            encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
371        }
372        DataType::Time32(unit) => match unit {
373            TimeUnit::Second => encoder.encode_field_with_type_and_format(
374                &get_time32_second_value(arr, idx),
375                type_,
376                format,
377            )?,
378            TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
379                &get_time32_millisecond_value(arr, idx),
380                type_,
381                format,
382            )?,
383            _ => {}
384        },
385        DataType::Time64(unit) => match unit {
386            TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
387                &get_time64_microsecond_value(arr, idx),
388                type_,
389                format,
390            )?,
391            TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
392                &get_time64_nanosecond_value(arr, idx),
393                type_,
394                format,
395            )?,
396            _ => {}
397        },
398        DataType::Timestamp(unit, timezone) => match unit {
399            TimeUnit::Second => {
400                if arr.is_null(idx) {
401                    return encoder.encode_field_with_type_and_format(
402                        &None::<NaiveDateTime>,
403                        type_,
404                        format,
405                    );
406                }
407                let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
408                if let Some(tz) = timezone {
409                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
410                    let value = ts_array
411                        .value_as_datetime_with_tz(idx, tz)
412                        .map(|d| d.fixed_offset());
413                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
414                } else {
415                    let value = ts_array.value_as_datetime(idx);
416                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
417                }
418            }
419            TimeUnit::Millisecond => {
420                if arr.is_null(idx) {
421                    return encoder.encode_field_with_type_and_format(
422                        &None::<NaiveDateTime>,
423                        type_,
424                        format,
425                    );
426                }
427                let ts_array = arr
428                    .as_any()
429                    .downcast_ref::<TimestampMillisecondArray>()
430                    .unwrap();
431                if let Some(tz) = timezone {
432                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
433                    let value = ts_array
434                        .value_as_datetime_with_tz(idx, tz)
435                        .map(|d| d.fixed_offset());
436                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
437                } else {
438                    let value = ts_array.value_as_datetime(idx);
439                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
440                }
441            }
442            TimeUnit::Microsecond => {
443                if arr.is_null(idx) {
444                    return encoder.encode_field_with_type_and_format(
445                        &None::<NaiveDateTime>,
446                        type_,
447                        format,
448                    );
449                }
450                let ts_array = arr
451                    .as_any()
452                    .downcast_ref::<TimestampMicrosecondArray>()
453                    .unwrap();
454                if let Some(tz) = timezone {
455                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
456                    let value = ts_array
457                        .value_as_datetime_with_tz(idx, tz)
458                        .map(|d| d.fixed_offset());
459                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
460                } else {
461                    let value = ts_array.value_as_datetime(idx);
462                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
463                }
464            }
465            TimeUnit::Nanosecond => {
466                if arr.is_null(idx) {
467                    return encoder.encode_field_with_type_and_format(
468                        &None::<NaiveDateTime>,
469                        type_,
470                        format,
471                    );
472                }
473                let ts_array = arr
474                    .as_any()
475                    .downcast_ref::<TimestampNanosecondArray>()
476                    .unwrap();
477                if let Some(tz) = timezone {
478                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
479                    let value = ts_array
480                        .value_as_datetime_with_tz(idx, tz)
481                        .map(|d| d.fixed_offset());
482                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
483                } else {
484                    let value = ts_array.value_as_datetime(idx);
485                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
486                }
487            }
488        },
489        DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
490            if arr.is_null(idx) {
491                return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
492            }
493            let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
494            let value = encode_list(array, type_, format)?;
495            encoder.encode_field_with_type_and_format(&value, type_, format)?
496        }
497        DataType::Struct(_) => {
498            let fields = match type_.kind() {
499                postgres_types::Kind::Composite(fields) => fields,
500                _ => {
501                    return Err(PgWireError::ApiError(ToSqlError::from(format!(
502                        "Failed to unwrap a composite type from type {type_}"
503                    ))));
504                }
505            };
506            let value = encode_struct(arr, idx, fields, format)?;
507            encoder.encode_field_with_type_and_format(&value, type_, format)?
508        }
509        DataType::Dictionary(_, value_type) => {
510            if arr.is_null(idx) {
511                return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
512            }
513            // Get the dictionary values and the mapped row index
514            macro_rules! get_dict_values_and_index {
515                ($key_type:ty) => {
516                    arr.as_any()
517                        .downcast_ref::<DictionaryArray<$key_type>>()
518                        .map(|dict| (dict.values(), dict.keys().value(idx) as usize))
519                };
520            }
521
522            // Try to extract values using different key types
523            let (values, idx) = get_dict_values_and_index!(Int8Type)
524                .or_else(|| get_dict_values_and_index!(Int16Type))
525                .or_else(|| get_dict_values_and_index!(Int32Type))
526                .or_else(|| get_dict_values_and_index!(Int64Type))
527                .or_else(|| get_dict_values_and_index!(UInt8Type))
528                .or_else(|| get_dict_values_and_index!(UInt16Type))
529                .or_else(|| get_dict_values_and_index!(UInt32Type))
530                .or_else(|| get_dict_values_and_index!(UInt64Type))
531                .ok_or_else(|| {
532                    ToSqlError::from(format!(
533                        "Unsupported dictionary key type for value type {value_type}"
534                    ))
535                })?;
536
537            encode_value(encoder, values, idx, type_, format)?
538        }
539        _ => {
540            return Err(PgWireError::ApiError(ToSqlError::from(format!(
541                "Unsupported Datatype {} and array {:?}",
542                arr.data_type(),
543                &arr
544            ))));
545        }
546    }
547
548    Ok(())
549}
550
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encodes_dictionary_array() {
        /// Test double that keeps the text-format encoding of the last field it received.
        #[derive(Default)]
        struct MockEncoder {
            encoded_value: String,
        }

        impl Encoder for MockEncoder {
            fn encode_field_with_type_and_format<T>(
                &mut self,
                value: &T,
                data_type: &Type,
                _format: FieldFormat,
            ) -> PgWireResult<()>
            where
                T: ToSql + ToSqlText + Sized,
            {
                let mut bytes = BytesMut::new();
                // Propagate encoding failures so the test reports them instead of
                // asserting against a partially written buffer.
                value
                    .to_sql_text(data_type, &mut bytes)
                    .map_err(PgWireError::ApiError)?;
                self.encoded_value = String::from_utf8(bytes.to_vec())
                    .expect("text encoding should be valid UTF-8");
                Ok(())
            }
        }

        let val = "~!@&$[]()@@!!";
        let value = StringArray::from_iter_values([val]);
        let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
        let dict_arr: Arc<dyn Array> =
            Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());

        let mut encoder = MockEncoder::default();

        encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text)
            .expect("dictionary value should encode");

        assert_eq!(encoder.encoded_value, val);
    }
}