arrow_pg/
struct_encoder.rs1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::array::{Array, StructArray};
5#[cfg(feature = "datafusion")]
6use datafusion::arrow::array::{Array, StructArray};
7
8use bytes::{BufMut, BytesMut};
9use pgwire::api::results::{FieldFormat, FieldInfo};
10use pgwire::error::PgWireResult;
11use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
12use postgres_types::{Field, IsNull, ToSql};
13
14use crate::encoder::{encode_value, EncodedValue, Encoder};
15
16pub(crate) fn encode_struct(
17 arr: &Arc<dyn Array>,
18 idx: usize,
19 fields: &[Field],
20 parent_pg_field_info: &FieldInfo,
21) -> PgWireResult<Option<EncodedValue>> {
22 let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
23 if arr.is_null(idx) {
24 return Ok(None);
25 }
26 let mut row_encoder = StructEncoder::new(fields.len());
27 for (i, arr) in arr.columns().iter().enumerate() {
28 let field = &fields[i];
29 let type_ = field.type_();
30
31 let mut pg_field = FieldInfo::new(
32 field.name().to_string(),
33 None,
34 None,
35 type_.clone(),
36 parent_pg_field_info.format(),
37 );
38
39 pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone());
40
41 encode_value(&mut row_encoder, arr, idx, &pg_field).unwrap();
42 }
43 Ok(Some(EncodedValue {
44 bytes: row_encoder.row_buffer,
45 }))
46}
47
48pub(crate) struct StructEncoder {
49 num_cols: usize,
50 curr_col: usize,
51 row_buffer: BytesMut,
52}
53
54impl StructEncoder {
55 pub(crate) fn new(num_cols: usize) -> Self {
56 Self {
57 num_cols,
58 curr_col: 0,
59 row_buffer: BytesMut::new(),
60 }
61 }
62}
63
64impl Encoder for StructEncoder {
65 fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
66 where
67 T: ToSql + ToSqlText + Sized,
68 {
69 let datatype = pg_field.datatype();
70 let format = pg_field.format();
71
72 if format == FieldFormat::Text {
73 if self.curr_col == 0 {
74 self.row_buffer.put_slice(b"(");
75 }
76 let mut buf = BytesMut::new();
78 value.to_sql_text(datatype, &mut buf, pg_field.format_options().as_ref())?;
79 let encoded_value_as_str = String::from_utf8_lossy(&buf);
80 if QUOTE_CHECK.is_match(&encoded_value_as_str) {
81 self.row_buffer.put_u8(b'"');
82 self.row_buffer.put_slice(
83 QUOTE_ESCAPE
84 .replace_all(&encoded_value_as_str, r#"\$1"#)
85 .as_bytes(),
86 );
87 self.row_buffer.put_u8(b'"');
88 } else {
89 self.row_buffer.put_slice(&buf);
90 }
91 if self.curr_col == self.num_cols - 1 {
92 self.row_buffer.put_slice(b")");
93 } else {
94 self.row_buffer.put_slice(b",");
95 }
96 } else {
97 if self.curr_col == 0 && format == FieldFormat::Binary {
98 self.row_buffer.put_i32(self.num_cols as i32);
100 }
101
102 self.row_buffer.put_u32(datatype.oid());
103 let prev_index = self.row_buffer.len();
105 self.row_buffer.put_i32(-1);
107 let is_null = value.to_sql(datatype, &mut self.row_buffer)?;
108 if let IsNull::No = is_null {
109 let value_length = self.row_buffer.len() - prev_index - 4;
110 let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
111 length_bytes.put_i32(value_length as i32);
112 }
113 }
114 self.curr_col += 1;
115 Ok(())
116 }
117}