arrow_pg/
struct_encoder.rs

1use std::sync::Arc;
2
3use arrow::array::{Array, StructArray};
4use bytes::{BufMut, BytesMut};
5use pgwire::api::results::FieldFormat;
6use pgwire::error::PgWireResult;
7use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
8use postgres_types::{Field, IsNull, ToSql, Type};
9
10use crate::encoder::{encode_value, EncodedValue, Encoder};
11
12pub(crate) fn encode_struct(
13    arr: &Arc<dyn Array>,
14    idx: usize,
15    fields: &[Field],
16    format: FieldFormat,
17) -> PgWireResult<Option<EncodedValue>> {
18    let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
19    if arr.is_null(idx) {
20        return Ok(None);
21    }
22    let mut row_encoder = StructEncoder::new(fields.len());
23    for (i, arr) in arr.columns().iter().enumerate() {
24        let field = &fields[i];
25        let type_ = field.type_();
26        encode_value(&mut row_encoder, arr, idx, type_, format).unwrap();
27    }
28    Ok(Some(EncodedValue {
29        bytes: row_encoder.row_buffer,
30    }))
31}
32
33pub(crate) struct StructEncoder {
34    num_cols: usize,
35    curr_col: usize,
36    row_buffer: BytesMut,
37}
38
39impl StructEncoder {
40    pub(crate) fn new(num_cols: usize) -> Self {
41        Self {
42            num_cols,
43            curr_col: 0,
44            row_buffer: BytesMut::new(),
45        }
46    }
47}
48
49impl Encoder for StructEncoder {
50    fn encode_field_with_type_and_format<T>(
51        &mut self,
52        value: &T,
53        data_type: &Type,
54        format: FieldFormat,
55    ) -> PgWireResult<()>
56    where
57        T: ToSql + ToSqlText + Sized,
58    {
59        if format == FieldFormat::Text {
60            if self.curr_col == 0 {
61                self.row_buffer.put_slice(b"(");
62            }
63            // encode value in an intermediate buf
64            let mut buf = BytesMut::new();
65            value.to_sql_text(data_type, &mut buf)?;
66            let encoded_value_as_str = String::from_utf8_lossy(&buf);
67            if QUOTE_CHECK.is_match(&encoded_value_as_str) {
68                self.row_buffer.put_u8(b'"');
69                self.row_buffer.put_slice(
70                    QUOTE_ESCAPE
71                        .replace_all(&encoded_value_as_str, r#"\$1"#)
72                        .as_bytes(),
73                );
74                self.row_buffer.put_u8(b'"');
75            } else {
76                self.row_buffer.put_slice(&buf);
77            }
78            if self.curr_col == self.num_cols - 1 {
79                self.row_buffer.put_slice(b")");
80            } else {
81                self.row_buffer.put_slice(b",");
82            }
83        } else {
84            if self.curr_col == 0 && format == FieldFormat::Binary {
85                // Place Number of fields
86                self.row_buffer.put_i32(self.num_cols as i32);
87            }
88
89            self.row_buffer.put_u32(data_type.oid());
90            // remember the position of the 4-byte length field
91            let prev_index = self.row_buffer.len();
92            // write value length as -1 ahead of time
93            self.row_buffer.put_i32(-1);
94            let is_null = value.to_sql(data_type, &mut self.row_buffer)?;
95            if let IsNull::No = is_null {
96                let value_length = self.row_buffer.len() - prev_index - 4;
97                let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
98                length_bytes.put_i32(value_length as i32);
99            }
100        }
101        self.curr_col += 1;
102        Ok(())
103    }
104}