1use std::error::Error;
2use std::io::Write;
3use std::str::FromStr;
4use std::sync::Arc;
5
6#[cfg(not(feature = "datafusion"))]
7use arrow::{array::*, datatypes::*};
8use bytes::BufMut;
9use bytes::BytesMut;
10use chrono::{NaiveDate, NaiveDateTime};
11#[cfg(feature = "datafusion")]
12use datafusion::arrow::{array::*, datatypes::*};
13use pgwire::api::results::DataRowEncoder;
14use pgwire::api::results::FieldFormat;
15use pgwire::error::PgWireError;
16use pgwire::error::PgWireResult;
17use pgwire::types::ToSqlText;
18use postgres_types::{ToSql, Type};
19use rust_decimal::Decimal;
20use timezone::Tz;
21
22use crate::error::ToSqlError;
23use crate::list_encoder::encode_list;
24use crate::struct_encoder::encode_struct;
25
26pub trait Encoder {
27 fn encode_field_with_type_and_format<T>(
28 &mut self,
29 value: &T,
30 data_type: &Type,
31 format: FieldFormat,
32 ) -> PgWireResult<()>
33 where
34 T: ToSql + ToSqlText + Sized;
35}
36
37impl Encoder for DataRowEncoder {
38 fn encode_field_with_type_and_format<T>(
39 &mut self,
40 value: &T,
41 data_type: &Type,
42 format: FieldFormat,
43 ) -> PgWireResult<()>
44 where
45 T: ToSql + ToSqlText + Sized,
46 {
47 self.encode_field_with_type_and_format(value, data_type, format)
48 }
49}
50
51pub(crate) struct EncodedValue {
52 pub(crate) bytes: BytesMut,
53}
54
55impl std::fmt::Debug for EncodedValue {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("EncodedValue").finish()
58 }
59}
60
61impl ToSql for EncodedValue {
62 fn to_sql(
63 &self,
64 _ty: &Type,
65 out: &mut BytesMut,
66 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
67 where
68 Self: Sized,
69 {
70 out.writer().write_all(&self.bytes)?;
71 Ok(postgres_types::IsNull::No)
72 }
73
74 fn accepts(_ty: &Type) -> bool
75 where
76 Self: Sized,
77 {
78 true
79 }
80
81 fn to_sql_checked(
82 &self,
83 ty: &Type,
84 out: &mut BytesMut,
85 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
86 self.to_sql(ty, out)
87 }
88}
89
90impl ToSqlText for EncodedValue {
91 fn to_sql_text(
92 &self,
93 _ty: &Type,
94 out: &mut BytesMut,
95 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
96 where
97 Self: Sized,
98 {
99 out.writer().write_all(&self.bytes)?;
100 Ok(postgres_types::IsNull::No)
101 }
102}
103
104fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
105 (!arr.is_null(idx)).then(|| {
106 arr.as_any()
107 .downcast_ref::<BooleanArray>()
108 .unwrap()
109 .value(idx)
110 })
111}
112
113macro_rules! get_primitive_value {
114 ($name:ident, $t:ty, $pt:ty) => {
115 fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
116 (!arr.is_null(idx)).then(|| {
117 arr.as_any()
118 .downcast_ref::<PrimitiveArray<$t>>()
119 .unwrap()
120 .value(idx)
121 })
122 }
123 };
124}
125
126get_primitive_value!(get_i8_value, Int8Type, i8);
127get_primitive_value!(get_i16_value, Int16Type, i16);
128get_primitive_value!(get_i32_value, Int32Type, i32);
129get_primitive_value!(get_i64_value, Int64Type, i64);
130get_primitive_value!(get_u8_value, UInt8Type, u8);
131get_primitive_value!(get_u16_value, UInt16Type, u16);
132get_primitive_value!(get_u32_value, UInt32Type, u32);
133get_primitive_value!(get_u64_value, UInt64Type, u64);
134get_primitive_value!(get_f32_value, Float32Type, f32);
135get_primitive_value!(get_f64_value, Float64Type, f64);
136
137fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
138 (!arr.is_null(idx)).then(|| {
139 arr.as_any()
140 .downcast_ref::<StringViewArray>()
141 .unwrap()
142 .value(idx)
143 })
144}
145
146fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
147 (!arr.is_null(idx)).then(|| {
148 arr.as_any()
149 .downcast_ref::<StringArray>()
150 .unwrap()
151 .value(idx)
152 })
153}
154
155fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
156 (!arr.is_null(idx)).then(|| {
157 arr.as_any()
158 .downcast_ref::<LargeStringArray>()
159 .unwrap()
160 .value(idx)
161 })
162}
163
164fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
165 (!arr.is_null(idx)).then(|| {
166 arr.as_any()
167 .downcast_ref::<BinaryArray>()
168 .unwrap()
169 .value(idx)
170 })
171}
172
173fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
174 (!arr.is_null(idx)).then(|| {
175 arr.as_any()
176 .downcast_ref::<LargeBinaryArray>()
177 .unwrap()
178 .value(idx)
179 })
180}
181
182fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
183 if arr.is_null(idx) {
184 return None;
185 }
186 arr.as_any()
187 .downcast_ref::<Date32Array>()
188 .unwrap()
189 .value_as_date(idx)
190}
191
192fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
193 if arr.is_null(idx) {
194 return None;
195 }
196 arr.as_any()
197 .downcast_ref::<Date64Array>()
198 .unwrap()
199 .value_as_date(idx)
200}
201
202fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
203 if arr.is_null(idx) {
204 return None;
205 }
206 arr.as_any()
207 .downcast_ref::<Time32SecondArray>()
208 .unwrap()
209 .value_as_datetime(idx)
210}
211
212fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
213 if arr.is_null(idx) {
214 return None;
215 }
216 arr.as_any()
217 .downcast_ref::<Time32MillisecondArray>()
218 .unwrap()
219 .value_as_datetime(idx)
220}
221
222fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
223 if arr.is_null(idx) {
224 return None;
225 }
226 arr.as_any()
227 .downcast_ref::<Time64MicrosecondArray>()
228 .unwrap()
229 .value_as_datetime(idx)
230}
231fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
232 if arr.is_null(idx) {
233 return None;
234 }
235 arr.as_any()
236 .downcast_ref::<Time64NanosecondArray>()
237 .unwrap()
238 .value_as_datetime(idx)
239}
240
241fn get_numeric_128_value(
242 arr: &Arc<dyn Array>,
243 idx: usize,
244 scale: u32,
245) -> PgWireResult<Option<Decimal>> {
246 if arr.is_null(idx) {
247 return Ok(None);
248 }
249
250 let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
251 let value = array.value(idx);
252 Decimal::try_from_i128_with_scale(value, scale)
253 .map_err(|e| {
254 let message = match e {
255 rust_decimal::Error::ExceedsMaximumPossibleValue => {
256 "Exceeds maximum possible value"
257 }
258 rust_decimal::Error::LessThanMinimumPossibleValue => {
259 "Less than minimum possible value"
260 }
261 rust_decimal::Error::ScaleExceedsMaximumPrecision(_) => {
262 "Scale exceeds maximum precision"
263 }
264 _ => unreachable!(),
265 };
266 PgWireError::ApiError(ToSqlError::from(message))
268 })
269 .map(Some)
270}
271
272pub fn encode_value<T: Encoder>(
273 encoder: &mut T,
274 arr: &Arc<dyn Array>,
275 idx: usize,
276 type_: &Type,
277 format: FieldFormat,
278) -> PgWireResult<()> {
279 match arr.data_type() {
280 DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
281 DataType::Boolean => {
282 encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
283 }
284 DataType::Int8 => {
285 encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
286 }
287 DataType::Int16 => {
288 encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
289 }
290 DataType::Int32 => {
291 encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
292 }
293 DataType::Int64 => {
294 encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
295 }
296 DataType::UInt8 => encoder.encode_field_with_type_and_format(
297 &(get_u8_value(arr, idx).map(|x| x as i8)),
298 type_,
299 format,
300 )?,
301 DataType::UInt16 => encoder.encode_field_with_type_and_format(
302 &(get_u16_value(arr, idx).map(|x| x as i16)),
303 type_,
304 format,
305 )?,
306 DataType::UInt32 => {
307 encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
308 }
309 DataType::UInt64 => encoder.encode_field_with_type_and_format(
310 &(get_u64_value(arr, idx).map(|x| x as i64)),
311 type_,
312 format,
313 )?,
314 DataType::Float32 => {
315 encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
316 }
317 DataType::Float64 => {
318 encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
319 }
320 DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
321 &get_numeric_128_value(arr, idx, *s as u32)?,
322 type_,
323 format,
324 )?,
325 DataType::Utf8 => {
326 encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
327 }
328 DataType::Utf8View => encoder.encode_field_with_type_and_format(
329 &get_utf8_view_value(arr, idx),
330 type_,
331 format,
332 )?,
333 DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
334 &get_large_utf8_value(arr, idx),
335 type_,
336 format,
337 )?,
338 DataType::Binary => {
339 encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
340 }
341 DataType::LargeBinary => encoder.encode_field_with_type_and_format(
342 &get_large_binary_value(arr, idx),
343 type_,
344 format,
345 )?,
346 DataType::Date32 => {
347 encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
348 }
349 DataType::Date64 => {
350 encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
351 }
352 DataType::Time32(unit) => match unit {
353 TimeUnit::Second => encoder.encode_field_with_type_and_format(
354 &get_time32_second_value(arr, idx),
355 type_,
356 format,
357 )?,
358 TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
359 &get_time32_millisecond_value(arr, idx),
360 type_,
361 format,
362 )?,
363 _ => {}
364 },
365 DataType::Time64(unit) => match unit {
366 TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
367 &get_time64_microsecond_value(arr, idx),
368 type_,
369 format,
370 )?,
371 TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
372 &get_time64_nanosecond_value(arr, idx),
373 type_,
374 format,
375 )?,
376 _ => {}
377 },
378 DataType::Timestamp(unit, timezone) => match unit {
379 TimeUnit::Second => {
380 if arr.is_null(idx) {
381 return encoder.encode_field_with_type_and_format(
382 &None::<NaiveDateTime>,
383 type_,
384 format,
385 );
386 }
387 let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
388 if let Some(tz) = timezone {
389 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
390 let value = ts_array
391 .value_as_datetime_with_tz(idx, tz)
392 .map(|d| d.fixed_offset());
393 encoder.encode_field_with_type_and_format(&value, type_, format)?;
394 } else {
395 let value = ts_array.value_as_datetime(idx);
396 encoder.encode_field_with_type_and_format(&value, type_, format)?;
397 }
398 }
399 TimeUnit::Millisecond => {
400 if arr.is_null(idx) {
401 return encoder.encode_field_with_type_and_format(
402 &None::<NaiveDateTime>,
403 type_,
404 format,
405 );
406 }
407 let ts_array = arr
408 .as_any()
409 .downcast_ref::<TimestampMillisecondArray>()
410 .unwrap();
411 if let Some(tz) = timezone {
412 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
413 let value = ts_array
414 .value_as_datetime_with_tz(idx, tz)
415 .map(|d| d.fixed_offset());
416 encoder.encode_field_with_type_and_format(&value, type_, format)?;
417 } else {
418 let value = ts_array.value_as_datetime(idx);
419 encoder.encode_field_with_type_and_format(&value, type_, format)?;
420 }
421 }
422 TimeUnit::Microsecond => {
423 if arr.is_null(idx) {
424 return encoder.encode_field_with_type_and_format(
425 &None::<NaiveDateTime>,
426 type_,
427 format,
428 );
429 }
430 let ts_array = arr
431 .as_any()
432 .downcast_ref::<TimestampMicrosecondArray>()
433 .unwrap();
434 if let Some(tz) = timezone {
435 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
436 let value = ts_array
437 .value_as_datetime_with_tz(idx, tz)
438 .map(|d| d.fixed_offset());
439 encoder.encode_field_with_type_and_format(&value, type_, format)?;
440 } else {
441 let value = ts_array.value_as_datetime(idx);
442 encoder.encode_field_with_type_and_format(&value, type_, format)?;
443 }
444 }
445 TimeUnit::Nanosecond => {
446 if arr.is_null(idx) {
447 return encoder.encode_field_with_type_and_format(
448 &None::<NaiveDateTime>,
449 type_,
450 format,
451 );
452 }
453 let ts_array = arr
454 .as_any()
455 .downcast_ref::<TimestampNanosecondArray>()
456 .unwrap();
457 if let Some(tz) = timezone {
458 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
459 let value = ts_array
460 .value_as_datetime_with_tz(idx, tz)
461 .map(|d| d.fixed_offset());
462 encoder.encode_field_with_type_and_format(&value, type_, format)?;
463 } else {
464 let value = ts_array.value_as_datetime(idx);
465 encoder.encode_field_with_type_and_format(&value, type_, format)?;
466 }
467 }
468 },
469 DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
470 if arr.is_null(idx) {
471 return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
472 }
473 let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
474 let value = encode_list(array, type_, format)?;
475 encoder.encode_field_with_type_and_format(&value, type_, format)?
476 }
477 DataType::Struct(_) => {
478 let fields = match type_.kind() {
479 postgres_types::Kind::Composite(fields) => fields,
480 _ => {
481 return Err(PgWireError::ApiError(ToSqlError::from(format!(
482 "Failed to unwrap a composite type from type {}",
483 type_
484 ))));
485 }
486 };
487 let value = encode_struct(arr, idx, fields, format)?;
488 encoder.encode_field_with_type_and_format(&value, type_, format)?
489 }
490 DataType::Dictionary(_, value_type) => {
491 if arr.is_null(idx) {
492 return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
493 }
494 macro_rules! get_dict_values {
497 ($key_type:ty) => {
498 arr.as_any()
499 .downcast_ref::<DictionaryArray<$key_type>>()
500 .map(|dict| dict.values())
501 };
502 }
503
504 let values = get_dict_values!(Int8Type)
506 .or_else(|| get_dict_values!(Int16Type))
507 .or_else(|| get_dict_values!(Int32Type))
508 .or_else(|| get_dict_values!(Int64Type))
509 .or_else(|| get_dict_values!(UInt8Type))
510 .or_else(|| get_dict_values!(UInt16Type))
511 .or_else(|| get_dict_values!(UInt32Type))
512 .or_else(|| get_dict_values!(UInt64Type))
513 .ok_or_else(|| {
514 ToSqlError::from(format!(
515 "Unsupported dictionary key type for value type {}",
516 value_type
517 ))
518 })?;
519
520 if values.len() == 1 {
522 encode_value(encoder, values, 0, type_, format)?
523 } else {
524 encode_value(encoder, values, idx, type_, format)?
526 }
527 }
528 _ => {
529 return Err(PgWireError::ApiError(ToSqlError::from(format!(
530 "Unsupported Datatype {} and array {:?}",
531 arr.data_type(),
532 &arr
533 ))));
534 }
535 }
536
537 Ok(())
538}