arrow_pg/
encoder.rs

use std::error::Error;
use std::io::Write;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::*;
use arrow::datatypes::*;
use bytes::BufMut;
use bytes::BytesMut;
use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
use pgwire::api::results::DataRowEncoder;
use pgwire::api::results::FieldFormat;
use pgwire::error::PgWireError;
use pgwire::error::PgWireResult;
use pgwire::types::ToSqlText;
use postgres_types::{ToSql, Type};
use rust_decimal::Decimal;
use timezone::Tz;

use crate::error::ToSqlError;
use crate::list_encoder::encode_list;
use crate::struct_encoder::encode_struct;
23
24pub trait Encoder {
25    fn encode_field_with_type_and_format<T>(
26        &mut self,
27        value: &T,
28        data_type: &Type,
29        format: FieldFormat,
30    ) -> PgWireResult<()>
31    where
32        T: ToSql + ToSqlText + Sized;
33}
34
35impl Encoder for DataRowEncoder {
36    fn encode_field_with_type_and_format<T>(
37        &mut self,
38        value: &T,
39        data_type: &Type,
40        format: FieldFormat,
41    ) -> PgWireResult<()>
42    where
43        T: ToSql + ToSqlText + Sized,
44    {
45        self.encode_field_with_type_and_format(value, data_type, format)
46    }
47}
48
/// A field value held as raw bytes.
///
/// The `ToSql` and `ToSqlText` impls in this file copy `bytes` to the output verbatim,
/// ignoring the requested PostgreSQL type.
pub(crate) struct EncodedValue {
    /// Bytes written out unchanged when this value is encoded.
    pub(crate) bytes: BytesMut,
}
52
53impl std::fmt::Debug for EncodedValue {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("EncodedValue").finish()
56    }
57}
58
59impl ToSql for EncodedValue {
60    fn to_sql(
61        &self,
62        _ty: &Type,
63        out: &mut BytesMut,
64    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
65    where
66        Self: Sized,
67    {
68        out.writer().write_all(&self.bytes)?;
69        Ok(postgres_types::IsNull::No)
70    }
71
72    fn accepts(_ty: &Type) -> bool
73    where
74        Self: Sized,
75    {
76        true
77    }
78
79    fn to_sql_checked(
80        &self,
81        ty: &Type,
82        out: &mut BytesMut,
83    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
84        self.to_sql(ty, out)
85    }
86}
87
88impl ToSqlText for EncodedValue {
89    fn to_sql_text(
90        &self,
91        _ty: &Type,
92        out: &mut BytesMut,
93    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
94    where
95        Self: Sized,
96    {
97        out.writer().write_all(&self.bytes)?;
98        Ok(postgres_types::IsNull::No)
99    }
100}
101
102fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
103    (!arr.is_null(idx)).then(|| {
104        arr.as_any()
105            .downcast_ref::<BooleanArray>()
106            .unwrap()
107            .value(idx)
108    })
109}
110
/// Defines `fn $fn_name(arr, idx) -> Option<$native>`, which reads slot `idx` of a
/// `PrimitiveArray<$arrow_type>` and yields `None` for a null slot.
///
/// The generated function panics if `arr` is not a `PrimitiveArray<$arrow_type>`.
macro_rules! get_primitive_value {
    ($fn_name:ident, $arrow_type:ty, $native:ty) => {
        fn $fn_name(arr: &Arc<dyn Array>, idx: usize) -> Option<$native> {
            if arr.is_null(idx) {
                return None;
            }
            let typed = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$arrow_type>>()
                .unwrap();
            Some(typed.value(idx))
        }
    };
}
123
124get_primitive_value!(get_i8_value, Int8Type, i8);
125get_primitive_value!(get_i16_value, Int16Type, i16);
126get_primitive_value!(get_i32_value, Int32Type, i32);
127get_primitive_value!(get_i64_value, Int64Type, i64);
128get_primitive_value!(get_u8_value, UInt8Type, u8);
129get_primitive_value!(get_u16_value, UInt16Type, u16);
130get_primitive_value!(get_u32_value, UInt32Type, u32);
131get_primitive_value!(get_u64_value, UInt64Type, u64);
132get_primitive_value!(get_f32_value, Float32Type, f32);
133get_primitive_value!(get_f64_value, Float64Type, f64);
134
135fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
136    (!arr.is_null(idx)).then(|| {
137        arr.as_any()
138            .downcast_ref::<StringViewArray>()
139            .unwrap()
140            .value(idx)
141    })
142}
143
144fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
145    (!arr.is_null(idx)).then(|| {
146        arr.as_any()
147            .downcast_ref::<StringArray>()
148            .unwrap()
149            .value(idx)
150    })
151}
152
153fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
154    (!arr.is_null(idx)).then(|| {
155        arr.as_any()
156            .downcast_ref::<LargeStringArray>()
157            .unwrap()
158            .value(idx)
159    })
160}
161
162fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
163    (!arr.is_null(idx)).then(|| {
164        arr.as_any()
165            .downcast_ref::<BinaryArray>()
166            .unwrap()
167            .value(idx)
168    })
169}
170
171fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
172    (!arr.is_null(idx)).then(|| {
173        arr.as_any()
174            .downcast_ref::<LargeBinaryArray>()
175            .unwrap()
176            .value(idx)
177    })
178}
179
180fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
181    if arr.is_null(idx) {
182        return None;
183    }
184    arr.as_any()
185        .downcast_ref::<Date32Array>()
186        .unwrap()
187        .value_as_date(idx)
188}
189
190fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
191    if arr.is_null(idx) {
192        return None;
193    }
194    arr.as_any()
195        .downcast_ref::<Date64Array>()
196        .unwrap()
197        .value_as_date(idx)
198}
199
200fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
201    if arr.is_null(idx) {
202        return None;
203    }
204    arr.as_any()
205        .downcast_ref::<Time32SecondArray>()
206        .unwrap()
207        .value_as_datetime(idx)
208}
209
210fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
211    if arr.is_null(idx) {
212        return None;
213    }
214    arr.as_any()
215        .downcast_ref::<Time32MillisecondArray>()
216        .unwrap()
217        .value_as_datetime(idx)
218}
219
220fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
221    if arr.is_null(idx) {
222        return None;
223    }
224    arr.as_any()
225        .downcast_ref::<Time64MicrosecondArray>()
226        .unwrap()
227        .value_as_datetime(idx)
228}
229fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
230    if arr.is_null(idx) {
231        return None;
232    }
233    arr.as_any()
234        .downcast_ref::<Time64NanosecondArray>()
235        .unwrap()
236        .value_as_datetime(idx)
237}
238
239fn get_numeric_128_value(
240    arr: &Arc<dyn Array>,
241    idx: usize,
242    scale: u32,
243) -> PgWireResult<Option<Decimal>> {
244    if arr.is_null(idx) {
245        return Ok(None);
246    }
247
248    let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
249    let value = array.value(idx);
250    Decimal::try_from_i128_with_scale(value, scale)
251        .map_err(|e| {
252            let message = match e {
253                rust_decimal::Error::ExceedsMaximumPossibleValue => {
254                    "Exceeds maximum possible value"
255                }
256                rust_decimal::Error::LessThanMinimumPossibleValue => {
257                    "Less than minimum possible value"
258                }
259                rust_decimal::Error::ScaleExceedsMaximumPrecision(_) => {
260                    "Scale exceeds maximum precision"
261                }
262                _ => unreachable!(),
263            };
264            // TODO: add error type in PgWireError
265            PgWireError::ApiError(ToSqlError::from(message))
266        })
267        .map(Some)
268}
269
270pub fn encode_value<T: Encoder>(
271    encoder: &mut T,
272    arr: &Arc<dyn Array>,
273    idx: usize,
274    type_: &Type,
275    format: FieldFormat,
276) -> PgWireResult<()> {
277    match arr.data_type() {
278        DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
279        DataType::Boolean => {
280            encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
281        }
282        DataType::Int8 => {
283            encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
284        }
285        DataType::Int16 => {
286            encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
287        }
288        DataType::Int32 => {
289            encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
290        }
291        DataType::Int64 => {
292            encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
293        }
294        DataType::UInt8 => encoder.encode_field_with_type_and_format(
295            &(get_u8_value(arr, idx).map(|x| x as i8)),
296            type_,
297            format,
298        )?,
299        DataType::UInt16 => encoder.encode_field_with_type_and_format(
300            &(get_u16_value(arr, idx).map(|x| x as i16)),
301            type_,
302            format,
303        )?,
304        DataType::UInt32 => {
305            encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
306        }
307        DataType::UInt64 => encoder.encode_field_with_type_and_format(
308            &(get_u64_value(arr, idx).map(|x| x as i64)),
309            type_,
310            format,
311        )?,
312        DataType::Float32 => {
313            encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
314        }
315        DataType::Float64 => {
316            encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
317        }
318        DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
319            &get_numeric_128_value(arr, idx, *s as u32)?,
320            type_,
321            format,
322        )?,
323        DataType::Utf8 => {
324            encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
325        }
326        DataType::Utf8View => encoder.encode_field_with_type_and_format(
327            &get_utf8_view_value(arr, idx),
328            type_,
329            format,
330        )?,
331        DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
332            &get_large_utf8_value(arr, idx),
333            type_,
334            format,
335        )?,
336        DataType::Binary => {
337            encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
338        }
339        DataType::LargeBinary => encoder.encode_field_with_type_and_format(
340            &get_large_binary_value(arr, idx),
341            type_,
342            format,
343        )?,
344        DataType::Date32 => {
345            encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
346        }
347        DataType::Date64 => {
348            encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
349        }
350        DataType::Time32(unit) => match unit {
351            TimeUnit::Second => encoder.encode_field_with_type_and_format(
352                &get_time32_second_value(arr, idx),
353                type_,
354                format,
355            )?,
356            TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
357                &get_time32_millisecond_value(arr, idx),
358                type_,
359                format,
360            )?,
361            _ => {}
362        },
363        DataType::Time64(unit) => match unit {
364            TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
365                &get_time64_microsecond_value(arr, idx),
366                type_,
367                format,
368            )?,
369            TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
370                &get_time64_nanosecond_value(arr, idx),
371                type_,
372                format,
373            )?,
374            _ => {}
375        },
376        DataType::Timestamp(unit, timezone) => match unit {
377            TimeUnit::Second => {
378                if arr.is_null(idx) {
379                    return encoder.encode_field_with_type_and_format(
380                        &None::<NaiveDateTime>,
381                        type_,
382                        format,
383                    );
384                }
385                let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
386                if let Some(tz) = timezone {
387                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
388                    let value = ts_array
389                        .value_as_datetime_with_tz(idx, tz)
390                        .map(|d| d.fixed_offset());
391                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
392                } else {
393                    let value = ts_array.value_as_datetime(idx);
394                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
395                }
396            }
397            TimeUnit::Millisecond => {
398                if arr.is_null(idx) {
399                    return encoder.encode_field_with_type_and_format(
400                        &None::<NaiveDateTime>,
401                        type_,
402                        format,
403                    );
404                }
405                let ts_array = arr
406                    .as_any()
407                    .downcast_ref::<TimestampMillisecondArray>()
408                    .unwrap();
409                if let Some(tz) = timezone {
410                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
411                    let value = ts_array
412                        .value_as_datetime_with_tz(idx, tz)
413                        .map(|d| d.fixed_offset());
414                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
415                } else {
416                    let value = ts_array.value_as_datetime(idx);
417                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
418                }
419            }
420            TimeUnit::Microsecond => {
421                if arr.is_null(idx) {
422                    return encoder.encode_field_with_type_and_format(
423                        &None::<NaiveDateTime>,
424                        type_,
425                        format,
426                    );
427                }
428                let ts_array = arr
429                    .as_any()
430                    .downcast_ref::<TimestampMicrosecondArray>()
431                    .unwrap();
432                if let Some(tz) = timezone {
433                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
434                    let value = ts_array
435                        .value_as_datetime_with_tz(idx, tz)
436                        .map(|d| d.fixed_offset());
437                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
438                } else {
439                    let value = ts_array.value_as_datetime(idx);
440                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
441                }
442            }
443            TimeUnit::Nanosecond => {
444                if arr.is_null(idx) {
445                    return encoder.encode_field_with_type_and_format(
446                        &None::<NaiveDateTime>,
447                        type_,
448                        format,
449                    );
450                }
451                let ts_array = arr
452                    .as_any()
453                    .downcast_ref::<TimestampNanosecondArray>()
454                    .unwrap();
455                if let Some(tz) = timezone {
456                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
457                    let value = ts_array
458                        .value_as_datetime_with_tz(idx, tz)
459                        .map(|d| d.fixed_offset());
460                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
461                } else {
462                    let value = ts_array.value_as_datetime(idx);
463                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
464                }
465            }
466        },
467        DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
468            if arr.is_null(idx) {
469                return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
470            }
471            let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
472            let value = encode_list(array, type_, format)?;
473            encoder.encode_field_with_type_and_format(&value, type_, format)?
474        }
475        DataType::Struct(_) => {
476            let fields = match type_.kind() {
477                postgres_types::Kind::Composite(fields) => fields,
478                _ => {
479                    return Err(PgWireError::ApiError(ToSqlError::from(format!(
480                        "Failed to unwrap a composite type from type {}",
481                        type_
482                    ))));
483                }
484            };
485            let value = encode_struct(arr, idx, fields, format)?;
486            encoder.encode_field_with_type_and_format(&value, type_, format)?
487        }
488        DataType::Dictionary(_, value_type) => {
489            if arr.is_null(idx) {
490                return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
491            }
492            // Get the dictionary values, ignoring keys
493            // We'll use Int32Type as a common key type, but we're only interested in values
494            macro_rules! get_dict_values {
495                ($key_type:ty) => {
496                    arr.as_any()
497                        .downcast_ref::<DictionaryArray<$key_type>>()
498                        .map(|dict| dict.values())
499                };
500            }
501
502            // Try to extract values using different key types
503            let values = get_dict_values!(Int8Type)
504                .or_else(|| get_dict_values!(Int16Type))
505                .or_else(|| get_dict_values!(Int32Type))
506                .or_else(|| get_dict_values!(Int64Type))
507                .or_else(|| get_dict_values!(UInt8Type))
508                .or_else(|| get_dict_values!(UInt16Type))
509                .or_else(|| get_dict_values!(UInt32Type))
510                .or_else(|| get_dict_values!(UInt64Type))
511                .ok_or_else(|| {
512                    ToSqlError::from(format!(
513                        "Unsupported dictionary key type for value type {}",
514                        value_type
515                    ))
516                })?;
517
518            // If the dictionary has only one value, treat it as a primitive
519            if values.len() == 1 {
520                encode_value(encoder, values, 0, type_, format)?
521            } else {
522                // Otherwise, use value directly indexed by values array
523                encode_value(encoder, values, idx, type_, format)?
524            }
525        }
526        _ => {
527            return Err(PgWireError::ApiError(ToSqlError::from(format!(
528                "Unsupported Datatype {} and array {:?}",
529                arr.data_type(),
530                &arr
531            ))));
532        }
533    }
534
535    Ok(())
536}