1use std::str::FromStr;
2use std::sync::Arc;
3
4#[cfg(not(feature = "datafusion"))]
5use arrow::{array::*, datatypes::*};
6use chrono::NaiveTime;
7use chrono::{NaiveDate, NaiveDateTime};
8#[cfg(feature = "datafusion")]
9use datafusion::arrow::{array::*, datatypes::*};
10use pg_interval::Interval as PgInterval;
11use pgwire::api::results::{DataRowEncoder, FieldInfo};
12use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
13use pgwire::types::ToSqlText;
14use postgres_types::ToSql;
15use rust_decimal::Decimal;
16use timezone::Tz;
17
18use crate::error::ToSqlError;
19#[cfg(feature = "geo")]
20use crate::geo_encoder::encode_geo;
21use crate::list_encoder::encode_list;
22use crate::struct_encoder::encode_struct;
23
/// Destination for encoded PostgreSQL field values.
///
/// The Arrow-to-PostgreSQL conversion in this module writes every value through this trait
/// instead of a concrete row encoder, so callers can plug in pgwire's `DataRowEncoder`
/// (implemented below) or a test double.
pub trait Encoder {
    /// Encodes `value` as one field, using the type and format information in `pg_field`.
    ///
    /// # Errors
    ///
    /// Implementations report encoding failures through the returned `PgWireResult`.
    fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
    where
        T: ToSql + ToSqlText + Sized;
}
29
30impl Encoder for DataRowEncoder {
31 fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
32 where
33 T: ToSql + ToSqlText + Sized,
34 {
35 self.encode_field_with_type_and_format(
36 value,
37 pg_field.datatype(),
38 pg_field.format(),
39 pg_field.format_options(),
40 )
41 }
42}
43
44fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
45 (!arr.is_null(idx)).then(|| {
46 arr.as_any()
47 .downcast_ref::<BooleanArray>()
48 .unwrap()
49 .value(idx)
50 })
51}
52
/// Generates `fn $name(arr, idx) -> Option<$pt>`, which reads one cell of a
/// `PrimitiveArray<$t>`: `None` for a null slot, otherwise the stored value.
/// The generated function panics if `arr` is not a `PrimitiveArray<$t>`.
macro_rules! get_primitive_value {
    ($name:ident, $t:ty, $pt:ty) => {
        fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
            if arr.is_null(idx) {
                None
            } else {
                let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
                Some(typed.value(idx))
            }
        }
    };
}
65
66get_primitive_value!(get_i8_value, Int8Type, i8);
67get_primitive_value!(get_i16_value, Int16Type, i16);
68get_primitive_value!(get_i32_value, Int32Type, i32);
69get_primitive_value!(get_i64_value, Int64Type, i64);
70get_primitive_value!(get_u8_value, UInt8Type, u8);
71get_primitive_value!(get_u16_value, UInt16Type, u16);
72get_primitive_value!(get_u32_value, UInt32Type, u32);
73get_primitive_value!(get_u64_value, UInt64Type, u64);
74
75fn get_u64_as_decimal_value(arr: &Arc<dyn Array>, idx: usize) -> Option<Decimal> {
76 get_u64_value(arr, idx).map(Decimal::from)
77}
78get_primitive_value!(get_f32_value, Float32Type, f32);
79get_primitive_value!(get_f64_value, Float64Type, f64);
80
81fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
82 (!arr.is_null(idx)).then(|| {
83 arr.as_any()
84 .downcast_ref::<StringViewArray>()
85 .unwrap()
86 .value(idx)
87 })
88}
89
90fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
91 (!arr.is_null(idx)).then(|| {
92 arr.as_any()
93 .downcast_ref::<BinaryViewArray>()
94 .unwrap()
95 .value(idx)
96 })
97}
98
99fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
100 (!arr.is_null(idx)).then(|| {
101 arr.as_any()
102 .downcast_ref::<StringArray>()
103 .unwrap()
104 .value(idx)
105 })
106}
107
108fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
109 (!arr.is_null(idx)).then(|| {
110 arr.as_any()
111 .downcast_ref::<LargeStringArray>()
112 .unwrap()
113 .value(idx)
114 })
115}
116
117fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
118 (!arr.is_null(idx)).then(|| {
119 arr.as_any()
120 .downcast_ref::<BinaryArray>()
121 .unwrap()
122 .value(idx)
123 })
124}
125
126fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
127 (!arr.is_null(idx)).then(|| {
128 arr.as_any()
129 .downcast_ref::<LargeBinaryArray>()
130 .unwrap()
131 .value(idx)
132 })
133}
134
135fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
136 if arr.is_null(idx) {
137 return None;
138 }
139 arr.as_any()
140 .downcast_ref::<Date32Array>()
141 .unwrap()
142 .value_as_date(idx)
143}
144
145fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
146 if arr.is_null(idx) {
147 return None;
148 }
149 arr.as_any()
150 .downcast_ref::<Date64Array>()
151 .unwrap()
152 .value_as_date(idx)
153}
154
155fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
156 if arr.is_null(idx) {
157 return None;
158 }
159 arr.as_any()
160 .downcast_ref::<Time32SecondArray>()
161 .unwrap()
162 .value_as_time(idx)
163}
164
165fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
166 if arr.is_null(idx) {
167 return None;
168 }
169 arr.as_any()
170 .downcast_ref::<Time32MillisecondArray>()
171 .unwrap()
172 .value_as_time(idx)
173}
174
175fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
176 if arr.is_null(idx) {
177 return None;
178 }
179 arr.as_any()
180 .downcast_ref::<Time64MicrosecondArray>()
181 .unwrap()
182 .value_as_time(idx)
183}
184fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
185 if arr.is_null(idx) {
186 return None;
187 }
188 arr.as_any()
189 .downcast_ref::<Time64NanosecondArray>()
190 .unwrap()
191 .value_as_time(idx)
192}
193
194fn get_numeric_128_value(
195 arr: &Arc<dyn Array>,
196 idx: usize,
197 scale: u32,
198) -> PgWireResult<Option<Decimal>> {
199 if arr.is_null(idx) {
200 return Ok(None);
201 }
202
203 let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
204 let value = array.value(idx);
205 Decimal::try_from_i128_with_scale(value, scale)
206 .map_err(|e| {
207 let error_code = match e {
208 rust_decimal::Error::ExceedsMaximumPossibleValue => {
209 "22003" }
211 rust_decimal::Error::LessThanMinimumPossibleValue => {
212 "22003" }
214 rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
215 return PgWireError::UserError(Box::new(ErrorInfo::new(
216 "ERROR".to_string(),
217 "22003".to_string(),
218 format!("Scale {scale} exceeds maximum precision for numeric type"),
219 )));
220 }
221 _ => "22003", };
223 PgWireError::UserError(Box::new(ErrorInfo::new(
224 "ERROR".to_string(),
225 error_code.to_string(),
226 format!("Numeric value conversion failed: {e}"),
227 )))
228 })
229 .map(Some)
230}
231
232pub fn encode_value<T: Encoder>(
233 encoder: &mut T,
234 arr: &Arc<dyn Array>,
235 idx: usize,
236 arrow_field: &Field,
237 pg_field: &FieldInfo,
238) -> PgWireResult<()> {
239 let arrow_type = arrow_field.data_type();
240
241 #[cfg(feature = "geo")]
242 if let Some(geoarrow_type) = geoarrow_schema::GeoArrowType::from_extension_field(arrow_field)
243 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
244 {
245 let geoarrow_array: Arc<dyn geoarrow::array::GeoArrowArray> =
246 geoarrow::array::from_arrow_array(arr, arrow_field)
247 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
248
249 return encode_geo(
250 encoder,
251 geoarrow_type,
252 &geoarrow_array,
253 idx,
254 arrow_field,
255 pg_field,
256 );
257 }
258
259 match arrow_type {
260 DataType::Null => encoder.encode_field(&None::<i8>, pg_field)?,
261 DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?,
262 DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx), pg_field)?,
263 DataType::Int16 => encoder.encode_field(&get_i16_value(arr, idx), pg_field)?,
264 DataType::Int32 => encoder.encode_field(&get_i32_value(arr, idx), pg_field)?,
265 DataType::Int64 => encoder.encode_field(&get_i64_value(arr, idx), pg_field)?,
266 DataType::UInt8 => {
267 encoder.encode_field(&(get_u8_value(arr, idx).map(|x| x as i16)), pg_field)?
268 }
269 DataType::UInt16 => {
270 encoder.encode_field(&(get_u16_value(arr, idx).map(|x| x as i32)), pg_field)?
271 }
272 DataType::UInt32 => {
273 encoder.encode_field(&get_u32_value(arr, idx).map(|x| x as i64), pg_field)?
274 }
275 DataType::UInt64 => encoder.encode_field(&get_u64_as_decimal_value(arr, idx), pg_field)?,
276 DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx), pg_field)?,
277 DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx), pg_field)?,
278 DataType::Decimal128(_, s) => {
279 encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?, pg_field)?
280 }
281 DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx), pg_field)?,
282 DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx), pg_field)?,
283 DataType::BinaryView => encoder.encode_field(&get_binary_view_value(arr, idx), pg_field)?,
284 DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx), pg_field)?,
285 DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx), pg_field)?,
286 DataType::LargeBinary => {
287 encoder.encode_field(&get_large_binary_value(arr, idx), pg_field)?
288 }
289 DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx), pg_field)?,
290 DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx), pg_field)?,
291 DataType::Time32(unit) => match unit {
292 TimeUnit::Second => {
293 encoder.encode_field(&get_time32_second_value(arr, idx), pg_field)?
294 }
295 TimeUnit::Millisecond => {
296 encoder.encode_field(&get_time32_millisecond_value(arr, idx), pg_field)?
297 }
298 _ => {}
299 },
300 DataType::Time64(unit) => match unit {
301 TimeUnit::Microsecond => {
302 encoder.encode_field(&get_time64_microsecond_value(arr, idx), pg_field)?
303 }
304 TimeUnit::Nanosecond => {
305 encoder.encode_field(&get_time64_nanosecond_value(arr, idx), pg_field)?
306 }
307 _ => {}
308 },
309 DataType::Timestamp(unit, timezone) => match unit {
310 TimeUnit::Second => {
311 if arr.is_null(idx) {
312 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
313 }
314 let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
315 if let Some(tz) = timezone {
316 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
317 let value = ts_array
318 .value_as_datetime_with_tz(idx, tz)
319 .map(|d| d.fixed_offset());
320
321 encoder.encode_field(&value, pg_field)?;
322 } else {
323 let value = ts_array.value_as_datetime(idx);
324 encoder.encode_field(&value, pg_field)?;
325 }
326 }
327 TimeUnit::Millisecond => {
328 if arr.is_null(idx) {
329 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
330 }
331 let ts_array = arr
332 .as_any()
333 .downcast_ref::<TimestampMillisecondArray>()
334 .unwrap();
335 if let Some(tz) = timezone {
336 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
337 let value = ts_array
338 .value_as_datetime_with_tz(idx, tz)
339 .map(|d| d.fixed_offset());
340 encoder.encode_field(&value, pg_field)?;
341 } else {
342 let value = ts_array.value_as_datetime(idx);
343 encoder.encode_field(&value, pg_field)?;
344 }
345 }
346 TimeUnit::Microsecond => {
347 if arr.is_null(idx) {
348 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
349 }
350 let ts_array = arr
351 .as_any()
352 .downcast_ref::<TimestampMicrosecondArray>()
353 .unwrap();
354 if let Some(tz) = timezone {
355 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
356 let value = ts_array
357 .value_as_datetime_with_tz(idx, tz)
358 .map(|d| d.fixed_offset());
359 encoder.encode_field(&value, pg_field)?;
360 } else {
361 let value = ts_array.value_as_datetime(idx);
362 encoder.encode_field(&value, pg_field)?;
363 }
364 }
365 TimeUnit::Nanosecond => {
366 if arr.is_null(idx) {
367 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
368 }
369 let ts_array = arr
370 .as_any()
371 .downcast_ref::<TimestampNanosecondArray>()
372 .unwrap();
373 if let Some(tz) = timezone {
374 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
375 let value = ts_array
376 .value_as_datetime_with_tz(idx, tz)
377 .map(|d| d.fixed_offset());
378 encoder.encode_field(&value, pg_field)?;
379 } else {
380 let value = ts_array.value_as_datetime(idx);
381 encoder.encode_field(&value, pg_field)?;
382 }
383 }
384 },
385 DataType::Interval(interval_unit) => match interval_unit {
386 IntervalUnit::YearMonth => {
387 let interval_array = arr
388 .as_any()
389 .downcast_ref::<IntervalYearMonthArray>()
390 .unwrap();
391 let months = IntervalYearMonthType::to_months(interval_array.value(idx));
392 encoder.encode_field(&PgInterval::new(months, 0, 0), pg_field)?;
393 }
394 IntervalUnit::DayTime => {
395 let interval_array = arr.as_any().downcast_ref::<IntervalDayTimeArray>().unwrap();
396 let (days, millis) = IntervalDayTimeType::to_parts(interval_array.value(idx));
397 encoder
398 .encode_field(&PgInterval::new(0, days, millis as i64 * 1000i64), pg_field)?;
399 }
400 IntervalUnit::MonthDayNano => {
401 let interval_array = arr
402 .as_any()
403 .downcast_ref::<IntervalMonthDayNanoArray>()
404 .unwrap();
405 let (months, days, nanoseconds) =
406 IntervalMonthDayNanoType::to_parts(interval_array.value(idx));
407
408 encoder.encode_field(
409 &PgInterval::new(months, days, nanoseconds / 1000i64),
410 pg_field,
411 )?;
412 }
413 },
414 DataType::Duration(unit) => match unit {
415 TimeUnit::Second => {
416 if arr.is_null(idx) {
417 return encoder.encode_field(&None::<PgInterval>, pg_field);
418 }
419 let duration_array = arr.as_any().downcast_ref::<DurationSecondArray>().unwrap();
420 let microseconds = duration_array.value(idx) * 1_000_000i64;
421 encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
422 }
423 TimeUnit::Millisecond => {
424 if arr.is_null(idx) {
425 return encoder.encode_field(&None::<PgInterval>, pg_field);
426 }
427 let duration_array = arr
428 .as_any()
429 .downcast_ref::<DurationMillisecondArray>()
430 .unwrap();
431 let microseconds = duration_array.value(idx) * 1_000i64;
432 encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
433 }
434 TimeUnit::Microsecond => {
435 if arr.is_null(idx) {
436 return encoder.encode_field(&None::<PgInterval>, pg_field);
437 }
438 let duration_array = arr
439 .as_any()
440 .downcast_ref::<DurationMicrosecondArray>()
441 .unwrap();
442 let microseconds = duration_array.value(idx);
443 encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
444 }
445 TimeUnit::Nanosecond => {
446 if arr.is_null(idx) {
447 return encoder.encode_field(&None::<PgInterval>, pg_field);
448 }
449 let duration_array = arr
450 .as_any()
451 .downcast_ref::<DurationNanosecondArray>()
452 .unwrap();
453 let microseconds = duration_array.value(idx) / 1_000i64;
454 encoder.encode_field(&PgInterval::new(0, 0, microseconds), pg_field)?;
455 }
456 },
457 DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
458 if arr.is_null(idx) {
459 return encoder.encode_field(&None::<&[i8]>, pg_field);
460 }
461 let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
462 encode_list(encoder, array, pg_field)?
463 }
464 DataType::Struct(arrow_fields) => encode_struct(encoder, arr, idx, arrow_fields, pg_field)?,
465 DataType::Dictionary(_, value_type) => {
466 if arr.is_null(idx) {
467 return encoder.encode_field(&None::<i8>, pg_field);
468 }
469 macro_rules! get_dict_values_and_index {
471 ($key_type:ty) => {
472 arr.as_any()
473 .downcast_ref::<DictionaryArray<$key_type>>()
474 .map(|dict| (dict.values(), dict.keys().value(idx) as usize))
475 };
476 }
477
478 let (values, idx) = get_dict_values_and_index!(Int8Type)
480 .or_else(|| get_dict_values_and_index!(Int16Type))
481 .or_else(|| get_dict_values_and_index!(Int32Type))
482 .or_else(|| get_dict_values_and_index!(Int64Type))
483 .or_else(|| get_dict_values_and_index!(UInt8Type))
484 .or_else(|| get_dict_values_and_index!(UInt16Type))
485 .or_else(|| get_dict_values_and_index!(UInt32Type))
486 .or_else(|| get_dict_values_and_index!(UInt64Type))
487 .ok_or_else(|| {
488 ToSqlError::from(format!(
489 "Unsupported dictionary key type for value type {value_type}"
490 ))
491 })?;
492
493 let inner_arrow_field = Field::new(pg_field.name(), *value_type.clone(), true);
494
495 encode_value(encoder, values, idx, &inner_arrow_field, pg_field)?
496 }
497 _ => {
498 return Err(PgWireError::ApiError(ToSqlError::from(format!(
499 "Unsupported Datatype {} and array {:?}",
500 arr.data_type(),
501 &arr
502 ))));
503 }
504 }
505
506 Ok(())
507}
508
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use pgwire::{api::results::FieldFormat, types::format::FormatOptions};
    use postgres_types::Type;

    use super::*;

    #[test]
    fn encodes_dictionary_array() {
        /// Test double that keeps the text-format rendering of the last encoded value.
        #[derive(Default)]
        struct MockEncoder {
            encoded_value: String,
        }

        impl Encoder for MockEncoder {
            fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
            where
                T: ToSql + ToSqlText + Sized,
            {
                let mut bytes = BytesMut::new();
                // Propagate text-encoding failures instead of silently discarding them.
                value.to_sql_text(pg_field.datatype(), &mut bytes, &FormatOptions::default())?;
                self.encoded_value =
                    String::from_utf8(bytes.to_vec()).expect("text encoding should be UTF-8");
                Ok(())
            }
        }

        let val = "~!@&$[]()@@!!";
        let value = StringArray::from_iter_values([val]);
        let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
        let dict_arr: Arc<dyn Array> =
            Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());

        let mut encoder = MockEncoder::default();

        let arrow_field = Field::new(
            "x",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
        );
        let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text);

        encode_value(&mut encoder, &dict_arr, 2, &arrow_field, &pg_field)
            .expect("dictionary value should encode");

        assert_eq!(encoder.encoded_value, val);
    }

    #[test]
    fn test_get_time32_second_value() {
        let array = Time32SecondArray::from_iter_values([3723_i32]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time32_second_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_opt(1, 2, 3));
    }

    #[test]
    fn test_get_time32_millisecond_value() {
        let array = Time32MillisecondArray::from_iter_values([3723001_i32]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time32_millisecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_milli_opt(1, 2, 3, 1));
    }

    #[test]
    fn test_get_time64_microsecond_value() {
        let array = Time64MicrosecondArray::from_iter_values([3723001001_i64]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time64_microsecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_micro_opt(1, 2, 3, 1001));
    }

    #[test]
    fn test_get_time64_nanosecond_value() {
        let array = Time64NanosecondArray::from_iter_values([3723001001001_i64]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time64_nanosecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_nano_opt(1, 2, 3, 1001001));
    }
}