arrow_pg/
struct_encoder.rs1use std::error::Error;
2use std::io::Write;
3use std::sync::Arc;
4
5#[cfg(not(feature = "datafusion"))]
6use arrow::array::{Array, StructArray};
7use arrow_schema::Fields;
8#[cfg(feature = "datafusion")]
9use datafusion::arrow::array::{Array, StructArray};
10
11use bytes::{BufMut, BytesMut};
12use pgwire::api::results::{FieldFormat, FieldInfo};
13use pgwire::error::PgWireResult;
14use pgwire::types::format::FormatOptions;
15use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
16use postgres_types::{Field, IsNull, ToSql, Type};
17
18use crate::datatypes::field_into_pg_type;
19use crate::encoder::{encode_value, Encoder};
20
21#[derive(Debug)]
22struct BytesWrapper(BytesMut, bool);
23
24impl ToSql for BytesWrapper {
25 fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Send + Sync>>
26 where
27 Self: Sized,
28 {
29 out.writer().write_all(&self.0)?;
30 Ok(IsNull::No)
31 }
32
33 fn accepts(_ty: &Type) -> bool
34 where
35 Self: Sized,
36 {
37 true
38 }
39
40 fn to_sql_checked(
41 &self,
42 ty: &Type,
43 out: &mut BytesMut,
44 ) -> Result<IsNull, Box<dyn Error + Send + Sync>> {
45 self.to_sql(ty, out)
46 }
47}
48
49impl ToSqlText for BytesWrapper {
50 fn to_sql_text(
51 &self,
52 _ty: &Type,
53 out: &mut BytesMut,
54 _format_options: &FormatOptions,
55 ) -> Result<IsNull, Box<dyn Error + Send + Sync>>
56 where
57 Self: Sized,
58 {
59 if self.1 {
60 out.put_u8(b'"');
61 out.put_slice(
62 QUOTE_ESCAPE
63 .replace_all(&String::from_utf8_lossy(&self.0), r#"\$1"#)
64 .as_bytes(),
65 );
66 out.put_u8(b'"');
67 } else {
68 out.put_slice(&self.0);
69 }
70 Ok(IsNull::No)
71 }
72}
73
74pub(crate) fn encode_structs<T: Encoder>(
75 encoder: &mut T,
76 arr: &Arc<dyn Array>,
77 arrow_fields: &Fields,
78 parent_pg_field_info: &FieldInfo,
79) -> PgWireResult<()> {
80 let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
81 let quote_wrapper = matches!(parent_pg_field_info.format(), FieldFormat::Text);
82
83 let fields = arrow_fields
84 .iter()
85 .map(|f| field_into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t)))
86 .collect::<PgWireResult<Vec<_>>>()?;
87
88 let values: PgWireResult<Vec<_>> = (0..arr.len())
89 .map(|row| {
90 if arr.is_null(row) {
91 Ok(None)
92 } else {
93 let mut row_encoder = StructEncoder::new(arrow_fields.len());
94
95 for (i, arr) in arr.columns().iter().enumerate() {
96 let field = &fields[i];
97 let type_ = field.type_();
98 let arrow_field = &arrow_fields[i];
99
100 let format = parent_pg_field_info.format();
101 let format_options = parent_pg_field_info.format_options().clone();
102 let mut pg_field =
103 FieldInfo::new(field.name().to_string(), None, None, type_.clone(), format);
104 pg_field = pg_field.with_format_options(format_options);
105
106 encode_value(&mut row_encoder, arr, row, arrow_field, &pg_field).unwrap();
107 }
108
109 Ok(Some(BytesWrapper(row_encoder.take_buffer(), quote_wrapper)))
110 }
111 })
112 .collect();
113 encoder.encode_field(&values?, parent_pg_field_info)
114}
115
116pub(crate) fn encode_struct<T: Encoder>(
117 encoder: &mut T,
118 arr: &Arc<dyn Array>,
119 idx: usize,
120 arrow_fields: &Fields,
121 parent_pg_field_info: &FieldInfo,
122) -> PgWireResult<()> {
123 let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
124 if arr.is_null(idx) {
125 return encoder.encode_field(&None::<&[i8]>, parent_pg_field_info);
126 }
127
128 let fields = arrow_fields
129 .iter()
130 .map(|f| field_into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t)))
131 .collect::<PgWireResult<Vec<_>>>()?;
132
133 let mut row_encoder = StructEncoder::new(arrow_fields.len());
134
135 for (i, arr) in arr.columns().iter().enumerate() {
136 let field = &fields[i];
137 let type_ = field.type_();
138
139 let arrow_field = &arrow_fields[i];
140
141 let mut pg_field = FieldInfo::new(
142 field.name().to_string(),
143 None,
144 None,
145 type_.clone(),
146 parent_pg_field_info.format(),
147 );
148 pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone());
149
150 encode_value(&mut row_encoder, arr, idx, arrow_field, &pg_field).unwrap();
151 }
152 let encoded_value = BytesWrapper(row_encoder.row_buffer, false);
153 encoder.encode_field(&encoded_value, parent_pg_field_info)
154}
155
156pub(crate) struct StructEncoder {
157 num_cols: usize,
158 curr_col: usize,
159 row_buffer: BytesMut,
160}
161
162impl StructEncoder {
163 pub(crate) fn new(num_cols: usize) -> Self {
164 Self {
165 num_cols,
166 curr_col: 0,
167 row_buffer: BytesMut::new(),
168 }
169 }
170
171 pub(crate) fn take_buffer(self) -> BytesMut {
172 self.row_buffer
173 }
174}
175
176impl Encoder for StructEncoder {
177 type Item = BytesMut;
178
179 fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
180 where
181 T: ToSql + ToSqlText + Sized,
182 {
183 let datatype = pg_field.datatype();
184 let format = pg_field.format();
185
186 if format == FieldFormat::Text {
187 if self.curr_col == 0 {
188 self.row_buffer.put_slice(b"(");
189 }
190 let mut buf = BytesMut::new();
192 value.to_sql_text(datatype, &mut buf, pg_field.format_options().as_ref())?;
193 let encoded_value_as_str = String::from_utf8_lossy(&buf);
194 if QUOTE_CHECK.is_match(&encoded_value_as_str) {
195 self.row_buffer.put_u8(b'"');
196 self.row_buffer.put_slice(
197 QUOTE_ESCAPE
198 .replace_all(&encoded_value_as_str, r#"\$1"#)
199 .as_bytes(),
200 );
201 self.row_buffer.put_u8(b'"');
202 } else {
203 self.row_buffer.put_slice(&buf);
204 }
205 if self.curr_col == self.num_cols - 1 {
206 self.row_buffer.put_slice(b")");
207 } else {
208 self.row_buffer.put_slice(b",");
209 }
210 } else {
211 if self.curr_col == 0 && format == FieldFormat::Binary {
212 self.row_buffer.put_i32(self.num_cols as i32);
214 }
215
216 self.row_buffer.put_u32(datatype.oid());
217 let prev_index = self.row_buffer.len();
219 self.row_buffer.put_i32(-1);
221 let is_null = value.to_sql(datatype, &mut self.row_buffer)?;
222 if let IsNull::No = is_null {
223 let value_length = self.row_buffer.len() - prev_index - 4;
224 let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
225 length_bytes.put_i32(value_length as i32);
226 }
227 }
228 self.curr_col += 1;
229 Ok(())
230 }
231
232 fn take_row(&mut self) -> Self::Item {
233 std::mem::take(&mut self.row_buffer)
234 }
235}