arrow_pg/
struct_encoder.rs1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::array::{Array, StructArray};
5#[cfg(feature = "datafusion")]
6use datafusion::arrow::array::{Array, StructArray};
7
8use bytes::{BufMut, BytesMut};
9use pgwire::api::results::FieldFormat;
10use pgwire::error::PgWireResult;
11use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
12use postgres_types::{Field, IsNull, ToSql, Type};
13
14use crate::encoder::{encode_value, EncodedValue, Encoder};
15
16pub(crate) fn encode_struct(
17 arr: &Arc<dyn Array>,
18 idx: usize,
19 fields: &[Field],
20 format: FieldFormat,
21) -> PgWireResult<Option<EncodedValue>> {
22 let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
23 if arr.is_null(idx) {
24 return Ok(None);
25 }
26 let mut row_encoder = StructEncoder::new(fields.len());
27 for (i, arr) in arr.columns().iter().enumerate() {
28 let field = &fields[i];
29 let type_ = field.type_();
30 encode_value(&mut row_encoder, arr, idx, type_, format).unwrap();
31 }
32 Ok(Some(EncodedValue {
33 bytes: row_encoder.row_buffer,
34 }))
35}
36
37pub(crate) struct StructEncoder {
38 num_cols: usize,
39 curr_col: usize,
40 row_buffer: BytesMut,
41}
42
43impl StructEncoder {
44 pub(crate) fn new(num_cols: usize) -> Self {
45 Self {
46 num_cols,
47 curr_col: 0,
48 row_buffer: BytesMut::new(),
49 }
50 }
51}
52
53impl Encoder for StructEncoder {
54 fn encode_field_with_type_and_format<T>(
55 &mut self,
56 value: &T,
57 data_type: &Type,
58 format: FieldFormat,
59 ) -> PgWireResult<()>
60 where
61 T: ToSql + ToSqlText + Sized,
62 {
63 if format == FieldFormat::Text {
64 if self.curr_col == 0 {
65 self.row_buffer.put_slice(b"(");
66 }
67 let mut buf = BytesMut::new();
69 value.to_sql_text(data_type, &mut buf)?;
70 let encoded_value_as_str = String::from_utf8_lossy(&buf);
71 if QUOTE_CHECK.is_match(&encoded_value_as_str) {
72 self.row_buffer.put_u8(b'"');
73 self.row_buffer.put_slice(
74 QUOTE_ESCAPE
75 .replace_all(&encoded_value_as_str, r#"\$1"#)
76 .as_bytes(),
77 );
78 self.row_buffer.put_u8(b'"');
79 } else {
80 self.row_buffer.put_slice(&buf);
81 }
82 if self.curr_col == self.num_cols - 1 {
83 self.row_buffer.put_slice(b")");
84 } else {
85 self.row_buffer.put_slice(b",");
86 }
87 } else {
88 if self.curr_col == 0 && format == FieldFormat::Binary {
89 self.row_buffer.put_i32(self.num_cols as i32);
91 }
92
93 self.row_buffer.put_u32(data_type.oid());
94 let prev_index = self.row_buffer.len();
96 self.row_buffer.put_i32(-1);
98 let is_null = value.to_sql(data_type, &mut self.row_buffer)?;
99 if let IsNull::No = is_null {
100 let value_length = self.row_buffer.len() - prev_index - 4;
101 let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
102 length_bytes.put_i32(value_length as i32);
103 }
104 }
105 self.curr_col += 1;
106 Ok(())
107 }
108}