arrow_pg/
struct_encoder.rs

1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::array::{Array, StructArray};
5#[cfg(feature = "datafusion")]
6use datafusion::arrow::array::{Array, StructArray};
7
8use bytes::{BufMut, BytesMut};
9use pgwire::api::results::FieldFormat;
10use pgwire::error::PgWireResult;
11use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
12use postgres_types::{Field, IsNull, ToSql, Type};
13
14use crate::encoder::{encode_value, EncodedValue, Encoder};
15
16pub(crate) fn encode_struct(
17    arr: &Arc<dyn Array>,
18    idx: usize,
19    fields: &[Field],
20    format: FieldFormat,
21) -> PgWireResult<Option<EncodedValue>> {
22    let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
23    if arr.is_null(idx) {
24        return Ok(None);
25    }
26    let mut row_encoder = StructEncoder::new(fields.len());
27    for (i, arr) in arr.columns().iter().enumerate() {
28        let field = &fields[i];
29        let type_ = field.type_();
30        encode_value(&mut row_encoder, arr, idx, type_, format).unwrap();
31    }
32    Ok(Some(EncodedValue {
33        bytes: row_encoder.row_buffer,
34    }))
35}
36
37pub(crate) struct StructEncoder {
38    num_cols: usize,
39    curr_col: usize,
40    row_buffer: BytesMut,
41}
42
43impl StructEncoder {
44    pub(crate) fn new(num_cols: usize) -> Self {
45        Self {
46            num_cols,
47            curr_col: 0,
48            row_buffer: BytesMut::new(),
49        }
50    }
51}
52
53impl Encoder for StructEncoder {
54    fn encode_field_with_type_and_format<T>(
55        &mut self,
56        value: &T,
57        data_type: &Type,
58        format: FieldFormat,
59    ) -> PgWireResult<()>
60    where
61        T: ToSql + ToSqlText + Sized,
62    {
63        if format == FieldFormat::Text {
64            if self.curr_col == 0 {
65                self.row_buffer.put_slice(b"(");
66            }
67            // encode value in an intermediate buf
68            let mut buf = BytesMut::new();
69            value.to_sql_text(data_type, &mut buf)?;
70            let encoded_value_as_str = String::from_utf8_lossy(&buf);
71            if QUOTE_CHECK.is_match(&encoded_value_as_str) {
72                self.row_buffer.put_u8(b'"');
73                self.row_buffer.put_slice(
74                    QUOTE_ESCAPE
75                        .replace_all(&encoded_value_as_str, r#"\$1"#)
76                        .as_bytes(),
77                );
78                self.row_buffer.put_u8(b'"');
79            } else {
80                self.row_buffer.put_slice(&buf);
81            }
82            if self.curr_col == self.num_cols - 1 {
83                self.row_buffer.put_slice(b")");
84            } else {
85                self.row_buffer.put_slice(b",");
86            }
87        } else {
88            if self.curr_col == 0 && format == FieldFormat::Binary {
89                // Place Number of fields
90                self.row_buffer.put_i32(self.num_cols as i32);
91            }
92
93            self.row_buffer.put_u32(data_type.oid());
94            // remember the position of the 4-byte length field
95            let prev_index = self.row_buffer.len();
96            // write value length as -1 ahead of time
97            self.row_buffer.put_i32(-1);
98            let is_null = value.to_sql(data_type, &mut self.row_buffer)?;
99            if let IsNull::No = is_null {
100                let value_length = self.row_buffer.len() - prev_index - 4;
101                let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
102                length_bytes.put_i32(value_length as i32);
103            }
104        }
105        self.curr_col += 1;
106        Ok(())
107    }
108}