arrow_pg/
struct_encoder.rs1use std::error::Error;
2use std::io::Write;
3use std::sync::Arc;
4
5#[cfg(not(feature = "datafusion"))]
6use arrow::array::{Array, StructArray};
7use arrow_schema::Fields;
8#[cfg(feature = "datafusion")]
9use datafusion::arrow::array::{Array, StructArray};
10
11use bytes::{BufMut, BytesMut};
12use pgwire::api::results::{FieldFormat, FieldInfo};
13use pgwire::error::PgWireResult;
14use pgwire::types::format::FormatOptions;
15use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE};
16use postgres_types::{Field, IsNull, ToSql, Type};
17
18use crate::datatypes::field_into_pg_type;
19use crate::encoder::{encode_value, Encoder};
20
21#[derive(Debug)]
22struct BytesWrapper(BytesMut, bool);
23
24impl ToSql for BytesWrapper {
25 fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Send + Sync>>
26 where
27 Self: Sized,
28 {
29 out.writer().write_all(&self.0)?;
30 Ok(IsNull::No)
31 }
32
33 fn accepts(_ty: &Type) -> bool
34 where
35 Self: Sized,
36 {
37 true
38 }
39
40 fn to_sql_checked(
41 &self,
42 ty: &Type,
43 out: &mut BytesMut,
44 ) -> Result<IsNull, Box<dyn Error + Send + Sync>> {
45 self.to_sql(ty, out)
46 }
47}
48
49impl ToSqlText for BytesWrapper {
50 fn to_sql_text(
51 &self,
52 _ty: &Type,
53 out: &mut BytesMut,
54 _format_options: &FormatOptions,
55 ) -> Result<IsNull, Box<dyn Error + Send + Sync>>
56 where
57 Self: Sized,
58 {
59 if self.1 {
60 out.put_u8(b'"');
61 out.put_slice(
62 QUOTE_ESCAPE
63 .replace_all(&String::from_utf8_lossy(&self.0), r#"\$1"#)
64 .as_bytes(),
65 );
66 out.put_u8(b'"');
67 } else {
68 out.put_slice(&self.0);
69 }
70 Ok(IsNull::No)
71 }
72}
73
74pub(crate) fn encode_structs<T: Encoder>(
75 encoder: &mut T,
76 arr: &Arc<dyn Array>,
77 arrow_fields: &Fields,
78 parent_pg_field_info: &FieldInfo,
79) -> PgWireResult<()> {
80 let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
81 let quote_wrapper = matches!(parent_pg_field_info.format(), FieldFormat::Text);
82
83 let fields = arrow_fields
84 .iter()
85 .map(|f| field_into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t)))
86 .collect::<PgWireResult<Vec<_>>>()?;
87
88 let values: PgWireResult<Vec<_>> = (0..arr.len())
89 .map(|row| {
90 if arr.is_null(row) {
91 Ok(None)
92 } else {
93 let mut row_encoder = StructEncoder::new(arrow_fields.len());
94
95 for (i, arr) in arr.columns().iter().enumerate() {
96 let field = &fields[i];
97 let type_ = field.type_();
98 let arrow_field = &arrow_fields[i];
99
100 let format = parent_pg_field_info.format();
101 let format_options = parent_pg_field_info.format_options().clone();
102 let mut pg_field =
103 FieldInfo::new(field.name().to_string(), None, None, type_.clone(), format);
104 pg_field = pg_field.with_format_options(format_options);
105
106 encode_value(&mut row_encoder, arr, row, arrow_field, &pg_field).unwrap();
107 }
108
109 Ok(Some(BytesWrapper(row_encoder.take_buffer(), quote_wrapper)))
110 }
111 })
112 .collect();
113 encoder.encode_field(&values?, parent_pg_field_info)
114}
115
116pub(crate) fn encode_struct<T: Encoder>(
117 encoder: &mut T,
118 arr: &Arc<dyn Array>,
119 idx: usize,
120 arrow_fields: &Fields,
121 parent_pg_field_info: &FieldInfo,
122) -> PgWireResult<()> {
123 let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
124 if arr.is_null(idx) {
125 return Ok(());
126 }
127
128 let fields = arrow_fields
129 .iter()
130 .map(|f| field_into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t)))
131 .collect::<PgWireResult<Vec<_>>>()?;
132
133 let mut row_encoder = StructEncoder::new(arrow_fields.len());
134
135 for (i, arr) in arr.columns().iter().enumerate() {
136 let field = &fields[i];
137 let type_ = field.type_();
138
139 let arrow_field = &arrow_fields[i];
140
141 let mut pg_field = FieldInfo::new(
142 field.name().to_string(),
143 None,
144 None,
145 type_.clone(),
146 parent_pg_field_info.format(),
147 );
148 pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone());
149
150 encode_value(&mut row_encoder, arr, idx, arrow_field, &pg_field).unwrap();
151 }
152 let encoded_value = BytesWrapper(row_encoder.row_buffer, false);
153 encoder.encode_field(&encoded_value, parent_pg_field_info)
154}
155
156pub(crate) struct StructEncoder {
157 num_cols: usize,
158 curr_col: usize,
159 row_buffer: BytesMut,
160}
161
162impl StructEncoder {
163 pub(crate) fn new(num_cols: usize) -> Self {
164 Self {
165 num_cols,
166 curr_col: 0,
167 row_buffer: BytesMut::new(),
168 }
169 }
170
171 pub(crate) fn take_buffer(self) -> BytesMut {
172 self.row_buffer
173 }
174}
175
176impl Encoder for StructEncoder {
177 fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
178 where
179 T: ToSql + ToSqlText + Sized,
180 {
181 let datatype = pg_field.datatype();
182 let format = pg_field.format();
183
184 if format == FieldFormat::Text {
185 if self.curr_col == 0 {
186 self.row_buffer.put_slice(b"(");
187 }
188 let mut buf = BytesMut::new();
190 value.to_sql_text(datatype, &mut buf, pg_field.format_options().as_ref())?;
191 let encoded_value_as_str = String::from_utf8_lossy(&buf);
192 if QUOTE_CHECK.is_match(&encoded_value_as_str) {
193 self.row_buffer.put_u8(b'"');
194 self.row_buffer.put_slice(
195 QUOTE_ESCAPE
196 .replace_all(&encoded_value_as_str, r#"\$1"#)
197 .as_bytes(),
198 );
199 self.row_buffer.put_u8(b'"');
200 } else {
201 self.row_buffer.put_slice(&buf);
202 }
203 if self.curr_col == self.num_cols - 1 {
204 self.row_buffer.put_slice(b")");
205 } else {
206 self.row_buffer.put_slice(b",");
207 }
208 } else {
209 if self.curr_col == 0 && format == FieldFormat::Binary {
210 self.row_buffer.put_i32(self.num_cols as i32);
212 }
213
214 self.row_buffer.put_u32(datatype.oid());
215 let prev_index = self.row_buffer.len();
217 self.row_buffer.put_i32(-1);
219 let is_null = value.to_sql(datatype, &mut self.row_buffer)?;
220 if let IsNull::No = is_null {
221 let value_length = self.row_buffer.len() - prev_index - 4;
222 let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
223 length_bytes.put_i32(value_length as i32);
224 }
225 }
226 self.curr_col += 1;
227 Ok(())
228 }
229}