arrow_pg/
encoder.rs

1use std::error::Error;
2use std::io::Write;
3use std::str::FromStr;
4use std::sync::Arc;
5
6#[cfg(not(feature = "datafusion"))]
7use arrow::{array::*, datatypes::*};
8use bytes::BufMut;
9use bytes::BytesMut;
10use chrono::{NaiveDate, NaiveDateTime};
11#[cfg(feature = "datafusion")]
12use datafusion::arrow::{array::*, datatypes::*};
13use pgwire::api::results::DataRowEncoder;
14use pgwire::api::results::FieldFormat;
15use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
16use pgwire::types::ToSqlText;
17use postgres_types::{ToSql, Type};
18use rust_decimal::Decimal;
19use timezone::Tz;
20
21use crate::error::ToSqlError;
22use crate::list_encoder::encode_list;
23use crate::struct_encoder::encode_struct;
24
25pub trait Encoder {
26    fn encode_field_with_type_and_format<T>(
27        &mut self,
28        value: &T,
29        data_type: &Type,
30        format: FieldFormat,
31    ) -> PgWireResult<()>
32    where
33        T: ToSql + ToSqlText + Sized;
34}
35
36impl Encoder for DataRowEncoder {
37    fn encode_field_with_type_and_format<T>(
38        &mut self,
39        value: &T,
40        data_type: &Type,
41        format: FieldFormat,
42    ) -> PgWireResult<()>
43    where
44        T: ToSql + ToSqlText + Sized,
45    {
46        self.encode_field_with_type_and_format(value, data_type, format)
47    }
48}
49
50pub(crate) struct EncodedValue {
51    pub(crate) bytes: BytesMut,
52}
53
54impl std::fmt::Debug for EncodedValue {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("EncodedValue").finish()
57    }
58}
59
60impl ToSql for EncodedValue {
61    fn to_sql(
62        &self,
63        _ty: &Type,
64        out: &mut BytesMut,
65    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
66    where
67        Self: Sized,
68    {
69        out.writer().write_all(&self.bytes)?;
70        Ok(postgres_types::IsNull::No)
71    }
72
73    fn accepts(_ty: &Type) -> bool
74    where
75        Self: Sized,
76    {
77        true
78    }
79
80    fn to_sql_checked(
81        &self,
82        ty: &Type,
83        out: &mut BytesMut,
84    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
85        self.to_sql(ty, out)
86    }
87}
88
89impl ToSqlText for EncodedValue {
90    fn to_sql_text(
91        &self,
92        _ty: &Type,
93        out: &mut BytesMut,
94    ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
95    where
96        Self: Sized,
97    {
98        out.writer().write_all(&self.bytes)?;
99        Ok(postgres_types::IsNull::No)
100    }
101}
102
103fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
104    (!arr.is_null(idx)).then(|| {
105        arr.as_any()
106            .downcast_ref::<BooleanArray>()
107            .unwrap()
108            .value(idx)
109    })
110}
111
112macro_rules! get_primitive_value {
113    ($name:ident, $t:ty, $pt:ty) => {
114        fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
115            (!arr.is_null(idx)).then(|| {
116                arr.as_any()
117                    .downcast_ref::<PrimitiveArray<$t>>()
118                    .unwrap()
119                    .value(idx)
120            })
121        }
122    };
123}
124
125get_primitive_value!(get_i8_value, Int8Type, i8);
126get_primitive_value!(get_i16_value, Int16Type, i16);
127get_primitive_value!(get_i32_value, Int32Type, i32);
128get_primitive_value!(get_i64_value, Int64Type, i64);
129get_primitive_value!(get_u8_value, UInt8Type, u8);
130get_primitive_value!(get_u16_value, UInt16Type, u16);
131get_primitive_value!(get_u32_value, UInt32Type, u32);
132get_primitive_value!(get_u64_value, UInt64Type, u64);
133get_primitive_value!(get_f32_value, Float32Type, f32);
134get_primitive_value!(get_f64_value, Float64Type, f64);
135
136fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
137    (!arr.is_null(idx)).then(|| {
138        arr.as_any()
139            .downcast_ref::<StringViewArray>()
140            .unwrap()
141            .value(idx)
142    })
143}
144
145fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
146    (!arr.is_null(idx)).then(|| {
147        arr.as_any()
148            .downcast_ref::<BinaryViewArray>()
149            .unwrap()
150            .value(idx)
151    })
152}
153
154fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
155    (!arr.is_null(idx)).then(|| {
156        arr.as_any()
157            .downcast_ref::<StringArray>()
158            .unwrap()
159            .value(idx)
160    })
161}
162
163fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
164    (!arr.is_null(idx)).then(|| {
165        arr.as_any()
166            .downcast_ref::<LargeStringArray>()
167            .unwrap()
168            .value(idx)
169    })
170}
171
172fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<String> {
173    (!arr.is_null(idx)).then(|| {
174        String::from_utf8_lossy(
175            arr.as_any()
176                .downcast_ref::<BinaryArray>()
177                .unwrap()
178                .value(idx),
179        )
180        .to_string()
181    })
182}
183
184fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
185    (!arr.is_null(idx)).then(|| {
186        arr.as_any()
187            .downcast_ref::<LargeBinaryArray>()
188            .unwrap()
189            .value(idx)
190    })
191}
192
193fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
194    if arr.is_null(idx) {
195        return None;
196    }
197    arr.as_any()
198        .downcast_ref::<Date32Array>()
199        .unwrap()
200        .value_as_date(idx)
201}
202
203fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
204    if arr.is_null(idx) {
205        return None;
206    }
207    arr.as_any()
208        .downcast_ref::<Date64Array>()
209        .unwrap()
210        .value_as_date(idx)
211}
212
213fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
214    if arr.is_null(idx) {
215        return None;
216    }
217    arr.as_any()
218        .downcast_ref::<Time32SecondArray>()
219        .unwrap()
220        .value_as_datetime(idx)
221}
222
223fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
224    if arr.is_null(idx) {
225        return None;
226    }
227    arr.as_any()
228        .downcast_ref::<Time32MillisecondArray>()
229        .unwrap()
230        .value_as_datetime(idx)
231}
232
233fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
234    if arr.is_null(idx) {
235        return None;
236    }
237    arr.as_any()
238        .downcast_ref::<Time64MicrosecondArray>()
239        .unwrap()
240        .value_as_datetime(idx)
241}
242fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
243    if arr.is_null(idx) {
244        return None;
245    }
246    arr.as_any()
247        .downcast_ref::<Time64NanosecondArray>()
248        .unwrap()
249        .value_as_datetime(idx)
250}
251
252fn get_numeric_128_value(
253    arr: &Arc<dyn Array>,
254    idx: usize,
255    scale: u32,
256) -> PgWireResult<Option<Decimal>> {
257    if arr.is_null(idx) {
258        return Ok(None);
259    }
260
261    let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
262    let value = array.value(idx);
263    Decimal::try_from_i128_with_scale(value, scale)
264        .map_err(|e| {
265            let error_code = match e {
266                rust_decimal::Error::ExceedsMaximumPossibleValue => {
267                    "22003" // numeric_value_out_of_range
268                }
269                rust_decimal::Error::LessThanMinimumPossibleValue => {
270                    "22003" // numeric_value_out_of_range
271                }
272                rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
273                    return PgWireError::UserError(Box::new(ErrorInfo::new(
274                        "ERROR".to_string(),
275                        "22003".to_string(),
276                        format!("Scale {scale} exceeds maximum precision for numeric type"),
277                    )));
278                }
279                _ => "22003", // generic numeric_value_out_of_range
280            };
281            PgWireError::UserError(Box::new(ErrorInfo::new(
282                "ERROR".to_string(),
283                error_code.to_string(),
284                format!("Numeric value conversion failed: {e}"),
285            )))
286        })
287        .map(Some)
288}
289
290pub fn encode_value<T: Encoder>(
291    encoder: &mut T,
292    arr: &Arc<dyn Array>,
293    idx: usize,
294    type_: &Type,
295    format: FieldFormat,
296) -> PgWireResult<()> {
297    match arr.data_type() {
298        DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
299        DataType::Boolean => {
300            encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
301        }
302        DataType::Int8 => {
303            encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
304        }
305        DataType::Int16 => {
306            encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
307        }
308        DataType::Int32 => {
309            encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
310        }
311        DataType::Int64 => {
312            encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
313        }
314        DataType::UInt8 => encoder.encode_field_with_type_and_format(
315            &(get_u8_value(arr, idx).map(|x| x as i8)),
316            type_,
317            format,
318        )?,
319        DataType::UInt16 => encoder.encode_field_with_type_and_format(
320            &(get_u16_value(arr, idx).map(|x| x as i16)),
321            type_,
322            format,
323        )?,
324        DataType::UInt32 => {
325            encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
326        }
327        DataType::UInt64 => encoder.encode_field_with_type_and_format(
328            &(get_u64_value(arr, idx).map(|x| x as i64)),
329            type_,
330            format,
331        )?,
332        DataType::Float32 => {
333            encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
334        }
335        DataType::Float64 => {
336            encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
337        }
338        DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
339            &get_numeric_128_value(arr, idx, *s as u32)?,
340            type_,
341            format,
342        )?,
343        DataType::Utf8 => {
344            encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
345        }
346        DataType::Utf8View => encoder.encode_field_with_type_and_format(
347            &get_utf8_view_value(arr, idx),
348            type_,
349            format,
350        )?,
351        DataType::BinaryView => encoder.encode_field_with_type_and_format(
352            &get_binary_view_value(arr, idx),
353            type_,
354            format,
355        )?,
356        DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
357            &get_large_utf8_value(arr, idx),
358            type_,
359            format,
360        )?,
361        DataType::Binary => {
362            encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
363        }
364        DataType::LargeBinary => encoder.encode_field_with_type_and_format(
365            &get_large_binary_value(arr, idx),
366            type_,
367            format,
368        )?,
369        DataType::Date32 => {
370            encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
371        }
372        DataType::Date64 => {
373            encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
374        }
375        DataType::Time32(unit) => match unit {
376            TimeUnit::Second => encoder.encode_field_with_type_and_format(
377                &get_time32_second_value(arr, idx),
378                type_,
379                format,
380            )?,
381            TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
382                &get_time32_millisecond_value(arr, idx),
383                type_,
384                format,
385            )?,
386            _ => {}
387        },
388        DataType::Time64(unit) => match unit {
389            TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
390                &get_time64_microsecond_value(arr, idx),
391                type_,
392                format,
393            )?,
394            TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
395                &get_time64_nanosecond_value(arr, idx),
396                type_,
397                format,
398            )?,
399            _ => {}
400        },
401        DataType::Timestamp(unit, timezone) => match unit {
402            TimeUnit::Second => {
403                if arr.is_null(idx) {
404                    return encoder.encode_field_with_type_and_format(
405                        &None::<NaiveDateTime>,
406                        type_,
407                        format,
408                    );
409                }
410                let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
411                if let Some(tz) = timezone {
412                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
413                    let value = ts_array
414                        .value_as_datetime_with_tz(idx, tz)
415                        .map(|d| d.fixed_offset());
416                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
417                } else {
418                    let value = ts_array.value_as_datetime(idx);
419                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
420                }
421            }
422            TimeUnit::Millisecond => {
423                if arr.is_null(idx) {
424                    return encoder.encode_field_with_type_and_format(
425                        &None::<NaiveDateTime>,
426                        type_,
427                        format,
428                    );
429                }
430                let ts_array = arr
431                    .as_any()
432                    .downcast_ref::<TimestampMillisecondArray>()
433                    .unwrap();
434                if let Some(tz) = timezone {
435                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
436                    let value = ts_array
437                        .value_as_datetime_with_tz(idx, tz)
438                        .map(|d| d.fixed_offset());
439                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
440                } else {
441                    let value = ts_array.value_as_datetime(idx);
442                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
443                }
444            }
445            TimeUnit::Microsecond => {
446                if arr.is_null(idx) {
447                    return encoder.encode_field_with_type_and_format(
448                        &None::<NaiveDateTime>,
449                        type_,
450                        format,
451                    );
452                }
453                let ts_array = arr
454                    .as_any()
455                    .downcast_ref::<TimestampMicrosecondArray>()
456                    .unwrap();
457                if let Some(tz) = timezone {
458                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
459                    let value = ts_array
460                        .value_as_datetime_with_tz(idx, tz)
461                        .map(|d| d.fixed_offset());
462                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
463                } else {
464                    let value = ts_array.value_as_datetime(idx);
465                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
466                }
467            }
468            TimeUnit::Nanosecond => {
469                if arr.is_null(idx) {
470                    return encoder.encode_field_with_type_and_format(
471                        &None::<NaiveDateTime>,
472                        type_,
473                        format,
474                    );
475                }
476                let ts_array = arr
477                    .as_any()
478                    .downcast_ref::<TimestampNanosecondArray>()
479                    .unwrap();
480                if let Some(tz) = timezone {
481                    let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
482                    let value = ts_array
483                        .value_as_datetime_with_tz(idx, tz)
484                        .map(|d| d.fixed_offset());
485                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
486                } else {
487                    let value = ts_array.value_as_datetime(idx);
488                    encoder.encode_field_with_type_and_format(&value, type_, format)?;
489                }
490            }
491        },
492        DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
493            if arr.is_null(idx) {
494                return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
495            }
496            let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
497            let value = encode_list(array, type_, format)?;
498            encoder.encode_field_with_type_and_format(&value, type_, format)?
499        }
500        DataType::Struct(_) => {
501            let fields = match type_.kind() {
502                postgres_types::Kind::Composite(fields) => fields,
503                _ => {
504                    return Err(PgWireError::ApiError(ToSqlError::from(format!(
505                        "Failed to unwrap a composite type from type {type_}"
506                    ))));
507                }
508            };
509            let value = encode_struct(arr, idx, fields, format)?;
510            encoder.encode_field_with_type_and_format(&value, type_, format)?
511        }
512        DataType::Dictionary(_, value_type) => {
513            if arr.is_null(idx) {
514                return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
515            }
516            // Get the dictionary values, ignoring keys
517            // We'll use Int32Type as a common key type, but we're only interested in values
518            macro_rules! get_dict_values {
519                ($key_type:ty) => {
520                    arr.as_any()
521                        .downcast_ref::<DictionaryArray<$key_type>>()
522                        .map(|dict| dict.values())
523                };
524            }
525
526            // Try to extract values using different key types
527            let values = get_dict_values!(Int8Type)
528                .or_else(|| get_dict_values!(Int16Type))
529                .or_else(|| get_dict_values!(Int32Type))
530                .or_else(|| get_dict_values!(Int64Type))
531                .or_else(|| get_dict_values!(UInt8Type))
532                .or_else(|| get_dict_values!(UInt16Type))
533                .or_else(|| get_dict_values!(UInt32Type))
534                .or_else(|| get_dict_values!(UInt64Type))
535                .ok_or_else(|| {
536                    ToSqlError::from(format!(
537                        "Unsupported dictionary key type for value type {value_type}"
538                    ))
539                })?;
540
541            // If the dictionary has only one value, treat it as a primitive
542            if values.len() == 1 {
543                encode_value(encoder, values, 0, type_, format)?
544            } else {
545                // Otherwise, use value directly indexed by values array
546                encode_value(encoder, values, idx, type_, format)?
547            }
548        }
549        _ => {
550            return Err(PgWireError::ApiError(ToSqlError::from(format!(
551                "Unsupported Datatype {} and array {:?}",
552                arr.data_type(),
553                &arr
554            ))));
555        }
556    }
557
558    Ok(())
559}