Skip to main content

arrow_pg/
struct_encoder.rs

1use std::error::Error;
2use std::io::Write;
3use std::sync::Arc;
4
5#[cfg(not(feature = "datafusion"))]
6use arrow::array::{Array, StructArray};
7use arrow_schema::Fields;
8#[cfg(feature = "datafusion")]
9use datafusion::arrow::array::{Array, StructArray};
10
11use bytes::{BufMut, BytesMut};
12use pgwire::api::results::{FieldFormat, FieldInfo};
13use pgwire::error::PgWireResult;
14use pgwire::types::format::FormatOptions;
15use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
16use postgres_types::{Field, IsNull, ToSql, Type};
17
18use crate::datatypes::field_into_pg_type;
19use crate::encoder::{encode_value, Encoder};
20
21#[derive(Debug)]
22struct BytesWrapper(BytesMut, bool);
23
24impl ToSql for BytesWrapper {
25    fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Send + Sync>>
26    where
27        Self: Sized,
28    {
29        out.writer().write_all(&self.0)?;
30        Ok(IsNull::No)
31    }
32
33    fn accepts(_ty: &Type) -> bool
34    where
35        Self: Sized,
36    {
37        true
38    }
39
40    fn to_sql_checked(
41        &self,
42        ty: &Type,
43        out: &mut BytesMut,
44    ) -> Result<IsNull, Box<dyn Error + Send + Sync>> {
45        self.to_sql(ty, out)
46    }
47}
48
49impl ToSqlText for BytesWrapper {
50    fn to_sql_text(
51        &self,
52        _ty: &Type,
53        out: &mut BytesMut,
54        _format_options: &FormatOptions,
55    ) -> Result<IsNull, Box<dyn Error + Send + Sync>>
56    where
57        Self: Sized,
58    {
59        if self.1 {
60            out.put_u8(b'"');
61            out.put_slice(
62                QUOTE_ESCAPE
63                    .replace_all(&String::from_utf8_lossy(&self.0), r#"\$1"#)
64                    .as_bytes(),
65            );
66            out.put_u8(b'"');
67        } else {
68            out.put_slice(&self.0);
69        }
70        Ok(IsNull::No)
71    }
72}
73
74pub(crate) fn encode_structs<T: Encoder>(
75    encoder: &mut T,
76    arr: &Arc<dyn Array>,
77    arrow_fields: &Fields,
78    parent_pg_field_info: &FieldInfo,
79) -> PgWireResult<()> {
80    let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
81    let quote_wrapper = matches!(parent_pg_field_info.format(), FieldFormat::Text);
82
83    let fields = arrow_fields
84        .iter()
85        .map(|f| field_into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t)))
86        .collect::<PgWireResult<Vec<_>>>()?;
87
88    let values: PgWireResult<Vec<_>> = (0..arr.len())
89        .map(|row| {
90            if arr.is_null(row) {
91                Ok(None)
92            } else {
93                let mut row_encoder = StructEncoder::new(arrow_fields.len());
94
95                for (i, arr) in arr.columns().iter().enumerate() {
96                    let field = &fields[i];
97                    let type_ = field.type_();
98                    let arrow_field = &arrow_fields[i];
99
100                    let format = parent_pg_field_info.format();
101                    let format_options = parent_pg_field_info.format_options().clone();
102                    let mut pg_field =
103                        FieldInfo::new(field.name().to_string(), None, None, type_.clone(), format);
104                    pg_field = pg_field.with_format_options(format_options);
105
106                    encode_value(&mut row_encoder, arr, row, arrow_field, &pg_field).unwrap();
107                }
108
109                Ok(Some(BytesWrapper(row_encoder.take_buffer(), quote_wrapper)))
110            }
111        })
112        .collect();
113    encoder.encode_field(&values?, parent_pg_field_info)
114}
115
116pub(crate) fn encode_struct<T: Encoder>(
117    encoder: &mut T,
118    arr: &Arc<dyn Array>,
119    idx: usize,
120    arrow_fields: &Fields,
121    parent_pg_field_info: &FieldInfo,
122) -> PgWireResult<()> {
123    let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
124    if arr.is_null(idx) {
125        return encoder.encode_field(&None::<&[i8]>, parent_pg_field_info);
126    }
127
128    let fields = arrow_fields
129        .iter()
130        .map(|f| field_into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t)))
131        .collect::<PgWireResult<Vec<_>>>()?;
132
133    let mut row_encoder = StructEncoder::new(arrow_fields.len());
134
135    for (i, arr) in arr.columns().iter().enumerate() {
136        let field = &fields[i];
137        let type_ = field.type_();
138
139        let arrow_field = &arrow_fields[i];
140
141        let mut pg_field = FieldInfo::new(
142            field.name().to_string(),
143            None,
144            None,
145            type_.clone(),
146            parent_pg_field_info.format(),
147        );
148        pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone());
149
150        encode_value(&mut row_encoder, arr, idx, arrow_field, &pg_field).unwrap();
151    }
152    let encoded_value = BytesWrapper(row_encoder.row_buffer, false);
153    encoder.encode_field(&encoded_value, parent_pg_field_info)
154}
155
156pub(crate) struct StructEncoder {
157    num_cols: usize,
158    curr_col: usize,
159    row_buffer: BytesMut,
160}
161
162impl StructEncoder {
163    pub(crate) fn new(num_cols: usize) -> Self {
164        Self {
165            num_cols,
166            curr_col: 0,
167            row_buffer: BytesMut::new(),
168        }
169    }
170
171    pub(crate) fn take_buffer(self) -> BytesMut {
172        self.row_buffer
173    }
174}
175
176impl Encoder for StructEncoder {
177    type Item = BytesMut;
178
179    fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
180    where
181        T: ToSql + ToSqlText + Sized,
182    {
183        let datatype = pg_field.datatype();
184        let format = pg_field.format();
185
186        if format == FieldFormat::Text {
187            if self.curr_col == 0 {
188                self.row_buffer.put_slice(b"(");
189            }
190            // encode value in an intermediate buf
191            let mut buf = BytesMut::new();
192            value.to_sql_text(datatype, &mut buf, pg_field.format_options().as_ref())?;
193            let encoded_value_as_str = String::from_utf8_lossy(&buf);
194            if QUOTE_CHECK.is_match(&encoded_value_as_str) {
195                self.row_buffer.put_u8(b'"');
196                self.row_buffer.put_slice(
197                    QUOTE_ESCAPE
198                        .replace_all(&encoded_value_as_str, r#"\$1"#)
199                        .as_bytes(),
200                );
201                self.row_buffer.put_u8(b'"');
202            } else {
203                self.row_buffer.put_slice(&buf);
204            }
205            if self.curr_col == self.num_cols - 1 {
206                self.row_buffer.put_slice(b")");
207            } else {
208                self.row_buffer.put_slice(b",");
209            }
210        } else {
211            if self.curr_col == 0 && format == FieldFormat::Binary {
212                // Place Number of fields
213                self.row_buffer.put_i32(self.num_cols as i32);
214            }
215
216            self.row_buffer.put_u32(datatype.oid());
217            // remember the position of the 4-byte length field
218            let prev_index = self.row_buffer.len();
219            // write value length as -1 ahead of time
220            self.row_buffer.put_i32(-1);
221            let is_null = value.to_sql(datatype, &mut self.row_buffer)?;
222            if let IsNull::No = is_null {
223                let value_length = self.row_buffer.len() - prev_index - 4;
224                let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
225                length_bytes.put_i32(value_length as i32);
226            }
227        }
228        self.curr_col += 1;
229        Ok(())
230    }
231
232    fn take_row(&mut self) -> Self::Item {
233        std::mem::take(&mut self.row_buffer)
234    }
235}