arrow_pg/
struct_encoder.rs1use std::sync::Arc;
2
3use arrow::array::{Array, StructArray};
4use bytes::{BufMut, BytesMut};
5use pgwire::api::results::FieldFormat;
6use pgwire::error::PgWireResult;
7use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
8use postgres_types::{Field, IsNull, ToSql, Type};
9
10use crate::encoder::{encode_value, EncodedValue, Encoder};
11
12pub(crate) fn encode_struct(
13 arr: &Arc<dyn Array>,
14 idx: usize,
15 fields: &[Field],
16 format: FieldFormat,
17) -> PgWireResult<Option<EncodedValue>> {
18 let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
19 if arr.is_null(idx) {
20 return Ok(None);
21 }
22 let mut row_encoder = StructEncoder::new(fields.len());
23 for (i, arr) in arr.columns().iter().enumerate() {
24 let field = &fields[i];
25 let type_ = field.type_();
26 encode_value(&mut row_encoder, arr, idx, type_, format).unwrap();
27 }
28 Ok(Some(EncodedValue {
29 bytes: row_encoder.row_buffer,
30 }))
31}
32
33pub(crate) struct StructEncoder {
34 num_cols: usize,
35 curr_col: usize,
36 row_buffer: BytesMut,
37}
38
39impl StructEncoder {
40 pub(crate) fn new(num_cols: usize) -> Self {
41 Self {
42 num_cols,
43 curr_col: 0,
44 row_buffer: BytesMut::new(),
45 }
46 }
47}
48
49impl Encoder for StructEncoder {
50 fn encode_field_with_type_and_format<T>(
51 &mut self,
52 value: &T,
53 data_type: &Type,
54 format: FieldFormat,
55 ) -> PgWireResult<()>
56 where
57 T: ToSql + ToSqlText + Sized,
58 {
59 if format == FieldFormat::Text {
60 if self.curr_col == 0 {
61 self.row_buffer.put_slice(b"(");
62 }
63 let mut buf = BytesMut::new();
65 value.to_sql_text(data_type, &mut buf)?;
66 let encoded_value_as_str = String::from_utf8_lossy(&buf);
67 if QUOTE_CHECK.is_match(&encoded_value_as_str) {
68 self.row_buffer.put_u8(b'"');
69 self.row_buffer.put_slice(
70 QUOTE_ESCAPE
71 .replace_all(&encoded_value_as_str, r#"\$1"#)
72 .as_bytes(),
73 );
74 self.row_buffer.put_u8(b'"');
75 } else {
76 self.row_buffer.put_slice(&buf);
77 }
78 if self.curr_col == self.num_cols - 1 {
79 self.row_buffer.put_slice(b")");
80 } else {
81 self.row_buffer.put_slice(b",");
82 }
83 } else {
84 if self.curr_col == 0 && format == FieldFormat::Binary {
85 self.row_buffer.put_i32(self.num_cols as i32);
87 }
88
89 self.row_buffer.put_u32(data_type.oid());
90 let prev_index = self.row_buffer.len();
92 self.row_buffer.put_i32(-1);
94 let is_null = value.to_sql(data_type, &mut self.row_buffer)?;
95 if let IsNull::No = is_null {
96 let value_length = self.row_buffer.len() - prev_index - 4;
97 let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
98 length_bytes.put_i32(value_length as i32);
99 }
100 }
101 self.curr_col += 1;
102 Ok(())
103 }
104}