arrow_pg/
encoder.rs

use std::error::Error;
use std::io::Write;
use std::str::FromStr;
use std::sync::Arc;

#[cfg(not(feature = "datafusion"))]
use arrow::{array::*, datatypes::*};
use bytes::BufMut;
use bytes::BytesMut;
use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
#[cfg(feature = "datafusion")]
use datafusion::arrow::{array::*, datatypes::*};
use pgwire::api::results::DataRowEncoder;
use pgwire::api::results::FieldFormat;
use pgwire::error::PgWireError;
use pgwire::error::PgWireResult;
use pgwire::types::ToSqlText;
use postgres_types::{ToSql, Type};
use rust_decimal::Decimal;
use timezone::Tz;

use crate::error::ToSqlError;
use crate::list_encoder::encode_list;
use crate::struct_encoder::encode_struct;
25
26pub trait Encoder {
27    fn encode_field_with_type_and_format<T>(
28        &mut self,
29        value: &T,
30        data_type: &Type,
31        format: FieldFormat,
32    ) -> PgWireResult<()>
33    where
34        T: ToSql + ToSqlText + Sized;
35}
36
/// Adapts pgwire's `DataRowEncoder` to the crate-level [`Encoder`] trait.
impl Encoder for DataRowEncoder {
    /// Forwards the field to `DataRowEncoder`'s own encoding method.
    fn encode_field_with_type_and_format<T>(
        &mut self,
        value: &T,
        data_type: &Type,
        format: FieldFormat,
    ) -> PgWireResult<()>
    where
        T: ToSql + ToSqlText + Sized,
    {
        // Rust method resolution prefers inherent methods over trait methods, so this call
        // resolves to `DataRowEncoder`'s inherent method of the same name rather than
        // recursing into this trait method.
        // NOTE(review): this relies on pgwire keeping that inherent method; if it were ever
        // removed, this would silently become infinite recursion.
        self.encode_field_with_type_and_format(value, data_type, format)
    }
}
50
/// A field payload whose bytes are written to the output unchanged by its `ToSql` and
/// `ToSqlText` impls, regardless of the requested Postgres type.
pub(crate) struct EncodedValue {
    /// The bytes copied verbatim into the output buffer.
    pub(crate) bytes: BytesMut,
}
54
55impl std::fmt::Debug for EncodedValue {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("EncodedValue").finish()
58    }
59}
60
61impl ToSql for EncodedValue {
62    fn to_sql(
63        &self,
64        _ty: &Type,
65        out: &mut BytesMut,
66    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
67    where
68        Self: Sized,
69    {
70        out.writer().write_all(&self.bytes)?;
71        Ok(postgres_types::IsNull::No)
72    }
73
74    fn accepts(_ty: &Type) -> bool
75    where
76        Self: Sized,
77    {
78        true
79    }
80
81    fn to_sql_checked(
82        &self,
83        ty: &Type,
84        out: &mut BytesMut,
85    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
86        self.to_sql(ty, out)
87    }
88}
89
90impl ToSqlText for EncodedValue {
91    fn to_sql_text(
92        &self,
93        _ty: &Type,
94        out: &mut BytesMut,
95    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
96    where
97        Self: Sized,
98    {
99        out.writer().write_all(&self.bytes)?;
100        Ok(postgres_types::IsNull::No)
101    }
102}
103
104fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
105    (!arr.is_null(idx)).then(|| {
106        arr.as_any()
107            .downcast_ref::<BooleanArray>()
108            .unwrap()
109            .value(idx)
110    })
111}
112
113macro_rules! get_primitive_value {
114    ($name:ident, $t:ty, $pt:ty) => {
115        fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
116            (!arr.is_null(idx)).then(|| {
117                arr.as_any()
118                    .downcast_ref::<PrimitiveArray<$t>>()
119                    .unwrap()
120                    .value(idx)
121            })
122        }
123    };
124}
125
126get_primitive_value!(get_i8_value, Int8Type, i8);
127get_primitive_value!(get_i16_value, Int16Type, i16);
128get_primitive_value!(get_i32_value, Int32Type, i32);
129get_primitive_value!(get_i64_value, Int64Type, i64);
130get_primitive_value!(get_u8_value, UInt8Type, u8);
131get_primitive_value!(get_u16_value, UInt16Type, u16);
132get_primitive_value!(get_u32_value, UInt32Type, u32);
133get_primitive_value!(get_u64_value, UInt64Type, u64);
134get_primitive_value!(get_f32_value, Float32Type, f32);
135get_primitive_value!(get_f64_value, Float64Type, f64);
136
137fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
138    (!arr.is_null(idx)).then(|| {
139        arr.as_any()
140            .downcast_ref::<StringViewArray>()
141            .unwrap()
142            .value(idx)
143    })
144}
145
146fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
147    (!arr.is_null(idx)).then(|| {
148        arr.as_any()
149            .downcast_ref::<StringArray>()
150            .unwrap()
151            .value(idx)
152    })
153}
154
155fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
156    (!arr.is_null(idx)).then(|| {
157        arr.as_any()
158            .downcast_ref::<LargeStringArray>()
159            .unwrap()
160            .value(idx)
161    })
162}
163
164fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
165    (!arr.is_null(idx)).then(|| {
166        arr.as_any()
167            .downcast_ref::<BinaryArray>()
168            .unwrap()
169            .value(idx)
170    })
171}
172
173fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
174    (!arr.is_null(idx)).then(|| {
175        arr.as_any()
176            .downcast_ref::<LargeBinaryArray>()
177            .unwrap()
178            .value(idx)
179    })
180}
181
182fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
183    if arr.is_null(idx) {
184        return None;
185    }
186    arr.as_any()
187        .downcast_ref::<Date32Array>()
188        .unwrap()
189        .value_as_date(idx)
190}
191
192fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
193    if arr.is_null(idx) {
194        return None;
195    }
196    arr.as_any()
197        .downcast_ref::<Date64Array>()
198        .unwrap()
199        .value_as_date(idx)
200}
201
202fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
203    if arr.is_null(idx) {
204        return None;
205    }
206    arr.as_any()
207        .downcast_ref::<Time32SecondArray>()
208        .unwrap()
209        .value_as_datetime(idx)
210}
211
212fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
213    if arr.is_null(idx) {
214        return None;
215    }
216    arr.as_any()
217        .downcast_ref::<Time32MillisecondArray>()
218        .unwrap()
219        .value_as_datetime(idx)
220}
221
222fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
223    if arr.is_null(idx) {
224        return None;
225    }
226    arr.as_any()
227        .downcast_ref::<Time64MicrosecondArray>()
228        .unwrap()
229        .value_as_datetime(idx)
230}
231fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
232    if arr.is_null(idx) {
233        return None;
234    }
235    arr.as_any()
236        .downcast_ref::<Time64NanosecondArray>()
237        .unwrap()
238        .value_as_datetime(idx)
239}
240
241fn get_numeric_128_value(
242    arr: &Arc<dyn Array>,
243    idx: usize,
244    scale: u32,
245) -> PgWireResult<Option<Decimal>> {
246    if arr.is_null(idx) {
247        return Ok(None);
248    }
249
250    let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
251    let value = array.value(idx);
252    Decimal::try_from_i128_with_scale(value, scale)
253        .map_err(|e| {
254            let message = match e {
255                rust_decimal::Error::ExceedsMaximumPossibleValue => {
256                    "Exceeds maximum possible value"
257                }
258                rust_decimal::Error::LessThanMinimumPossibleValue => {
259                    "Less than minimum possible value"
260                }
261                rust_decimal::Error::ScaleExceedsMaximumPrecision(_) => {
262                    "Scale exceeds maximum precision"
263                }
264                _ => unreachable!(),
265            };
266            // TODO: add error type in PgWireError
267            PgWireError::ApiError(ToSqlError::from(message))
268        })
269        .map(Some)
270}
271
272pub fn encode_value<T: Encoder>(
273    encoder: &mut T,
274    arr: &Arc<dyn Array>,
275    idx: usize,
276    type_: &Type,
277    format: FieldFormat,
278) -> PgWireResult<()> {
279    match arr.data_type() {
280        DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
281        DataType::Boolean => {
282            encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
283        }
284        DataType::Int8 => {
285            encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
286        }
287        DataType::Int16 => {
288            encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
289        }
290        DataType::Int32 => {
291            encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
292        }
293        DataType::Int64 => {
294            encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
295        }
296        DataType::UInt8 => encoder.encode_field_with_type_and_format(
297            &(get_u8_value(arr, idx).map(|x| x as i8)),
298            type_,
299            format,
300        )?,
301        DataType::UInt16 => encoder.encode_field_with_type_and_format(
302            &(get_u16_value(arr, idx).map(|x| x as i16)),
303            type_,
304            format,
305        )?,
306        DataType::UInt32 => {
307            encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
308        }
309        DataType::UInt64 => encoder.encode_field_with_type_and_format(
310            &(get_u64_value(arr, idx).map(|x| x as i64)),
311            type_,
312            format,
313        )?,
314        DataType::Float32 => {
315            encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
316        }
317        DataType::Float64 => {
318            encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
319        }
320        DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
321            &get_numeric_128_value(arr, idx, *s as u32)?,
322            type_,
323            format,
324        )?,
325        DataType::Utf8 => {
326            encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
327        }
328        DataType::Utf8View => encoder.encode_field_with_type_and_format(
329            &get_utf8_view_value(arr, idx),
330            type_,
331            format,
332        )?,
333        DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
334            &get_large_utf8_value(arr, idx),
335            type_,
336            format,
337        )?,
338        DataType::Binary => {
339            encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
340        }
341        DataType::LargeBinary => encoder.encode_field_with_type_and_format(
342            &get_large_binary_value(arr, idx),
343            type_,
344            format,
345        )?,
346        DataType::Date32 => {
347            encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
348        }
349        DataType::Date64 => {
350            encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
351        }
352        DataType::Time32(unit) => match unit {
353            TimeUnit::Second => encoder.encode_field_with_type_and_format(
354                &get_time32_second_value(arr, idx),
355                type_,
356                format,
357            )?,
358            TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
359                &get_time32_millisecond_value(arr, idx),
360                type_,
361                format,
362            )?,
363            _ => {}
364        },
365        DataType::Time64(unit) => match unit {
366            TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
367                &get_time64_microsecond_value(arr, idx),
368                type_,
369                format,
370            )?,
371            TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
372                &get_time64_nanosecond_value(arr, idx),
373                type_,
374                format,
375            )?,
376            _ => {}
377        },
378        DataType::Timestamp(unit, timezone) => match unit {
379            TimeUnit::Second => {
380                if arr.is_null(idx) {
381                    return encoder.encode_field_with_type_and_format(
382                        &None::<NaiveDateTime>,
383                        type_,
384                        format,
385                    );
386                }
387                let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
388                if let Some(tz) = timezone {
389                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
390                    let value = ts_array
391                        .value_as_datetime_with_tz(idx, tz)
392                        .map(|d| d.fixed_offset());
393                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
394                } else {
395                    let value = ts_array.value_as_datetime(idx);
396                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
397                }
398            }
399            TimeUnit::Millisecond => {
400                if arr.is_null(idx) {
401                    return encoder.encode_field_with_type_and_format(
402                        &None::<NaiveDateTime>,
403                        type_,
404                        format,
405                    );
406                }
407                let ts_array = arr
408                    .as_any()
409                    .downcast_ref::<TimestampMillisecondArray>()
410                    .unwrap();
411                if let Some(tz) = timezone {
412                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
413                    let value = ts_array
414                        .value_as_datetime_with_tz(idx, tz)
415                        .map(|d| d.fixed_offset());
416                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
417                } else {
418                    let value = ts_array.value_as_datetime(idx);
419                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
420                }
421            }
422            TimeUnit::Microsecond => {
423                if arr.is_null(idx) {
424                    return encoder.encode_field_with_type_and_format(
425                        &None::<NaiveDateTime>,
426                        type_,
427                        format,
428                    );
429                }
430                let ts_array = arr
431                    .as_any()
432                    .downcast_ref::<TimestampMicrosecondArray>()
433                    .unwrap();
434                if let Some(tz) = timezone {
435                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
436                    let value = ts_array
437                        .value_as_datetime_with_tz(idx, tz)
438                        .map(|d| d.fixed_offset());
439                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
440                } else {
441                    let value = ts_array.value_as_datetime(idx);
442                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
443                }
444            }
445            TimeUnit::Nanosecond => {
446                if arr.is_null(idx) {
447                    return encoder.encode_field_with_type_and_format(
448                        &None::<NaiveDateTime>,
449                        type_,
450                        format,
451                    );
452                }
453                let ts_array = arr
454                    .as_any()
455                    .downcast_ref::<TimestampNanosecondArray>()
456                    .unwrap();
457                if let Some(tz) = timezone {
458                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
459                    let value = ts_array
460                        .value_as_datetime_with_tz(idx, tz)
461                        .map(|d| d.fixed_offset());
462                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
463                } else {
464                    let value = ts_array.value_as_datetime(idx);
465                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
466                }
467            }
468        },
469        DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
470            if arr.is_null(idx) {
471                return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
472            }
473            let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
474            let value = encode_list(array, type_, format)?;
475            encoder.encode_field_with_type_and_format(&value, type_, format)?
476        }
477        DataType::Struct(_) => {
478            let fields = match type_.kind() {
479                postgres_types::Kind::Composite(fields) => fields,
480                _ => {
481                    return Err(PgWireError::ApiError(ToSqlError::from(format!(
482                        "Failed to unwrap a composite type from type {}",
483                        type_
484                    ))));
485                }
486            };
487            let value = encode_struct(arr, idx, fields, format)?;
488            encoder.encode_field_with_type_and_format(&value, type_, format)?
489        }
490        DataType::Dictionary(_, value_type) => {
491            if arr.is_null(idx) {
492                return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
493            }
494            // Get the dictionary values, ignoring keys
495            // We'll use Int32Type as a common key type, but we're only interested in values
496            macro_rules! get_dict_values {
497                ($key_type:ty) => {
498                    arr.as_any()
499                        .downcast_ref::<DictionaryArray<$key_type>>()
500                        .map(|dict| dict.values())
501                };
502            }
503
504            // Try to extract values using different key types
505            let values = get_dict_values!(Int8Type)
506                .or_else(|| get_dict_values!(Int16Type))
507                .or_else(|| get_dict_values!(Int32Type))
508                .or_else(|| get_dict_values!(Int64Type))
509                .or_else(|| get_dict_values!(UInt8Type))
510                .or_else(|| get_dict_values!(UInt16Type))
511                .or_else(|| get_dict_values!(UInt32Type))
512                .or_else(|| get_dict_values!(UInt64Type))
513                .ok_or_else(|| {
514                    ToSqlError::from(format!(
515                        "Unsupported dictionary key type for value type {}",
516                        value_type
517                    ))
518                })?;
519
520            // If the dictionary has only one value, treat it as a primitive
521            if values.len() == 1 {
522                encode_value(encoder, values, 0, type_, format)?
523            } else {
524                // Otherwise, use value directly indexed by values array
525                encode_value(encoder, values, idx, type_, format)?
526            }
527        }
528        _ => {
529            return Err(PgWireError::ApiError(ToSqlError::from(format!(
530                "Unsupported Datatype {} and array {:?}",
531                arr.data_type(),
532                &arr
533            ))));
534        }
535    }
536
537    Ok(())
538}