arrow_pg/
struct_encoder.rs

1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::array::{Array, StructArray};
5#[cfg(feature = "datafusion")]
6use datafusion::arrow::array::{Array, StructArray};
7
8use bytes::{BufMut, BytesMut};
9use pgwire::api::results::{FieldFormat, FieldInfo};
10use pgwire::error::PgWireResult;
11use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
12use postgres_types::{Field, IsNull, ToSql};
13
14use crate::encoder::{encode_value, EncodedValue, Encoder};
15
16pub(crate) fn encode_struct(
17    arr: &Arc<dyn Array>,
18    idx: usize,
19    fields: &[Field],
20    parent_pg_field_info: &FieldInfo,
21) -> PgWireResult<Option<EncodedValue>> {
22    let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
23    if arr.is_null(idx) {
24        return Ok(None);
25    }
26    let mut row_encoder = StructEncoder::new(fields.len());
27    for (i, arr) in arr.columns().iter().enumerate() {
28        let field = &fields[i];
29        let type_ = field.type_();
30
31        let mut pg_field = FieldInfo::new(
32            field.name().to_string(),
33            None,
34            None,
35            type_.clone(),
36            parent_pg_field_info.format(),
37        );
38
39        pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone());
40
41        encode_value(&mut row_encoder, arr, idx, &pg_field).unwrap();
42    }
43    Ok(Some(EncodedValue {
44        bytes: row_encoder.row_buffer,
45    }))
46}
47
48pub(crate) struct StructEncoder {
49    num_cols: usize,
50    curr_col: usize,
51    row_buffer: BytesMut,
52}
53
54impl StructEncoder {
55    pub(crate) fn new(num_cols: usize) -> Self {
56        Self {
57            num_cols,
58            curr_col: 0,
59            row_buffer: BytesMut::new(),
60        }
61    }
62}
63
64impl Encoder for StructEncoder {
65    fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
66    where
67        T: ToSql + ToSqlText + Sized,
68    {
69        let datatype = pg_field.datatype();
70        let format = pg_field.format();
71
72        if format == FieldFormat::Text {
73            if self.curr_col == 0 {
74                self.row_buffer.put_slice(b"(");
75            }
76            // encode value in an intermediate buf
77            let mut buf = BytesMut::new();
78            value.to_sql_text(datatype, &mut buf, pg_field.format_options().as_ref())?;
79            let encoded_value_as_str = String::from_utf8_lossy(&buf);
80            if QUOTE_CHECK.is_match(&encoded_value_as_str) {
81                self.row_buffer.put_u8(b'"');
82                self.row_buffer.put_slice(
83                    QUOTE_ESCAPE
84                        .replace_all(&encoded_value_as_str, r#"\$1"#)
85                        .as_bytes(),
86                );
87                self.row_buffer.put_u8(b'"');
88            } else {
89                self.row_buffer.put_slice(&buf);
90            }
91            if self.curr_col == self.num_cols - 1 {
92                self.row_buffer.put_slice(b")");
93            } else {
94                self.row_buffer.put_slice(b",");
95            }
96        } else {
97            if self.curr_col == 0 && format == FieldFormat::Binary {
98                // Place Number of fields
99                self.row_buffer.put_i32(self.num_cols as i32);
100            }
101
102            self.row_buffer.put_u32(datatype.oid());
103            // remember the position of the 4-byte length field
104            let prev_index = self.row_buffer.len();
105            // write value length as -1 ahead of time
106            self.row_buffer.put_i32(-1);
107            let is_null = value.to_sql(datatype, &mut self.row_buffer)?;
108            if let IsNull::No = is_null {
109                let value_length = self.row_buffer.len() - prev_index - 4;
110                let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
111                length_bytes.put_i32(value_length as i32);
112            }
113        }
114        self.curr_col += 1;
115        Ok(())
116    }
117}