1use std::error::Error;
2use std::io::Write;
3use std::str::FromStr;
4use std::sync::Arc;
5
6#[cfg(not(feature = "datafusion"))]
7use arrow::{array::*, datatypes::*};
8use bytes::BufMut;
9use bytes::BytesMut;
10use chrono::NaiveTime;
11use chrono::{NaiveDate, NaiveDateTime};
12#[cfg(feature = "datafusion")]
13use datafusion::arrow::{array::*, datatypes::*};
14use pgwire::api::results::{DataRowEncoder, FieldInfo};
15use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
16use pgwire::types::format::FormatOptions;
17use pgwire::types::ToSqlText;
18use postgres_types::{ToSql, Type};
19use rust_decimal::Decimal;
20use timezone::Tz;
21
22use crate::error::ToSqlError;
23use crate::list_encoder::encode_list;
24use crate::struct_encoder::encode_struct;
25
26pub trait Encoder {
27 fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
28 where
29 T: ToSql + ToSqlText + Sized;
30}
31
32impl Encoder for DataRowEncoder {
33 fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
34 where
35 T: ToSql + ToSqlText + Sized,
36 {
37 self.encode_field_with_type_and_format(
38 value,
39 pg_field.datatype(),
40 pg_field.format(),
41 pg_field.format_options(),
42 )
43 }
44}
45
46pub(crate) struct EncodedValue {
47 pub(crate) bytes: BytesMut,
48}
49
50impl std::fmt::Debug for EncodedValue {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("EncodedValue").finish()
53 }
54}
55
56impl ToSql for EncodedValue {
57 fn to_sql(
58 &self,
59 _ty: &Type,
60 out: &mut BytesMut,
61 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
62 where
63 Self: Sized,
64 {
65 out.writer().write_all(&self.bytes)?;
66 Ok(postgres_types::IsNull::No)
67 }
68
69 fn accepts(_ty: &Type) -> bool
70 where
71 Self: Sized,
72 {
73 true
74 }
75
76 fn to_sql_checked(
77 &self,
78 ty: &Type,
79 out: &mut BytesMut,
80 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
81 self.to_sql(ty, out)
82 }
83}
84
85impl ToSqlText for EncodedValue {
86 fn to_sql_text(
87 &self,
88 _ty: &Type,
89 out: &mut BytesMut,
90 _format_options: &FormatOptions,
91 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
92 where
93 Self: Sized,
94 {
95 out.writer().write_all(&self.bytes)?;
96 Ok(postgres_types::IsNull::No)
97 }
98}
99
100fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
101 (!arr.is_null(idx)).then(|| {
102 arr.as_any()
103 .downcast_ref::<BooleanArray>()
104 .unwrap()
105 .value(idx)
106 })
107}
108
109macro_rules! get_primitive_value {
110 ($name:ident, $t:ty, $pt:ty) => {
111 fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
112 (!arr.is_null(idx)).then(|| {
113 arr.as_any()
114 .downcast_ref::<PrimitiveArray<$t>>()
115 .unwrap()
116 .value(idx)
117 })
118 }
119 };
120}
121
122get_primitive_value!(get_i8_value, Int8Type, i8);
123get_primitive_value!(get_i16_value, Int16Type, i16);
124get_primitive_value!(get_i32_value, Int32Type, i32);
125get_primitive_value!(get_i64_value, Int64Type, i64);
126get_primitive_value!(get_u8_value, UInt8Type, u8);
127get_primitive_value!(get_u16_value, UInt16Type, u16);
128get_primitive_value!(get_u32_value, UInt32Type, u32);
129get_primitive_value!(get_u64_value, UInt64Type, u64);
130get_primitive_value!(get_f32_value, Float32Type, f32);
131get_primitive_value!(get_f64_value, Float64Type, f64);
132
133fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
134 (!arr.is_null(idx)).then(|| {
135 arr.as_any()
136 .downcast_ref::<StringViewArray>()
137 .unwrap()
138 .value(idx)
139 })
140}
141
142fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
143 (!arr.is_null(idx)).then(|| {
144 arr.as_any()
145 .downcast_ref::<BinaryViewArray>()
146 .unwrap()
147 .value(idx)
148 })
149}
150
151fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
152 (!arr.is_null(idx)).then(|| {
153 arr.as_any()
154 .downcast_ref::<StringArray>()
155 .unwrap()
156 .value(idx)
157 })
158}
159
160fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
161 (!arr.is_null(idx)).then(|| {
162 arr.as_any()
163 .downcast_ref::<LargeStringArray>()
164 .unwrap()
165 .value(idx)
166 })
167}
168
169fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
170 (!arr.is_null(idx)).then(|| {
171 arr.as_any()
172 .downcast_ref::<BinaryArray>()
173 .unwrap()
174 .value(idx)
175 })
176}
177
178fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
179 (!arr.is_null(idx)).then(|| {
180 arr.as_any()
181 .downcast_ref::<LargeBinaryArray>()
182 .unwrap()
183 .value(idx)
184 })
185}
186
187fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
188 if arr.is_null(idx) {
189 return None;
190 }
191 arr.as_any()
192 .downcast_ref::<Date32Array>()
193 .unwrap()
194 .value_as_date(idx)
195}
196
197fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
198 if arr.is_null(idx) {
199 return None;
200 }
201 arr.as_any()
202 .downcast_ref::<Date64Array>()
203 .unwrap()
204 .value_as_date(idx)
205}
206
207fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
208 if arr.is_null(idx) {
209 return None;
210 }
211 arr.as_any()
212 .downcast_ref::<Time32SecondArray>()
213 .unwrap()
214 .value_as_time(idx)
215}
216
217fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
218 if arr.is_null(idx) {
219 return None;
220 }
221 arr.as_any()
222 .downcast_ref::<Time32MillisecondArray>()
223 .unwrap()
224 .value_as_time(idx)
225}
226
227fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
228 if arr.is_null(idx) {
229 return None;
230 }
231 arr.as_any()
232 .downcast_ref::<Time64MicrosecondArray>()
233 .unwrap()
234 .value_as_time(idx)
235}
236fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveTime> {
237 if arr.is_null(idx) {
238 return None;
239 }
240 arr.as_any()
241 .downcast_ref::<Time64NanosecondArray>()
242 .unwrap()
243 .value_as_time(idx)
244}
245
246fn get_numeric_128_value(
247 arr: &Arc<dyn Array>,
248 idx: usize,
249 scale: u32,
250) -> PgWireResult<Option<Decimal>> {
251 if arr.is_null(idx) {
252 return Ok(None);
253 }
254
255 let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
256 let value = array.value(idx);
257 Decimal::try_from_i128_with_scale(value, scale)
258 .map_err(|e| {
259 let error_code = match e {
260 rust_decimal::Error::ExceedsMaximumPossibleValue => {
261 "22003" }
263 rust_decimal::Error::LessThanMinimumPossibleValue => {
264 "22003" }
266 rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
267 return PgWireError::UserError(Box::new(ErrorInfo::new(
268 "ERROR".to_string(),
269 "22003".to_string(),
270 format!("Scale {scale} exceeds maximum precision for numeric type"),
271 )));
272 }
273 _ => "22003", };
275 PgWireError::UserError(Box::new(ErrorInfo::new(
276 "ERROR".to_string(),
277 error_code.to_string(),
278 format!("Numeric value conversion failed: {e}"),
279 )))
280 })
281 .map(Some)
282}
283
284pub fn encode_value<T: Encoder>(
285 encoder: &mut T,
286 arr: &Arc<dyn Array>,
287 idx: usize,
288 pg_field: &FieldInfo,
289) -> PgWireResult<()> {
290 let type_ = pg_field.datatype();
291
292 match arr.data_type() {
293 DataType::Null => encoder.encode_field(&None::<i8>, pg_field)?,
294 DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?,
295 DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx), pg_field)?,
296 DataType::Int16 => encoder.encode_field(&get_i16_value(arr, idx), pg_field)?,
297 DataType::Int32 => encoder.encode_field(&get_i32_value(arr, idx), pg_field)?,
298 DataType::Int64 => encoder.encode_field(&get_i64_value(arr, idx), pg_field)?,
299 DataType::UInt8 => {
300 encoder.encode_field(&(get_u8_value(arr, idx).map(|x| x as i8)), pg_field)?
301 }
302 DataType::UInt16 => {
303 encoder.encode_field(&(get_u16_value(arr, idx).map(|x| x as i16)), pg_field)?
304 }
305 DataType::UInt32 => encoder.encode_field(&get_u32_value(arr, idx), pg_field)?,
306 DataType::UInt64 => {
307 encoder.encode_field(&(get_u64_value(arr, idx).map(|x| x as i64)), pg_field)?
308 }
309 DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx), pg_field)?,
310 DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx), pg_field)?,
311 DataType::Decimal128(_, s) => {
312 encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?, pg_field)?
313 }
314 DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx), pg_field)?,
315 DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx), pg_field)?,
316 DataType::BinaryView => encoder.encode_field(&get_binary_view_value(arr, idx), pg_field)?,
317 DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx), pg_field)?,
318 DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx), pg_field)?,
319 DataType::LargeBinary => {
320 encoder.encode_field(&get_large_binary_value(arr, idx), pg_field)?
321 }
322 DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx), pg_field)?,
323 DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx), pg_field)?,
324 DataType::Time32(unit) => match unit {
325 TimeUnit::Second => {
326 encoder.encode_field(&get_time32_second_value(arr, idx), pg_field)?
327 }
328 TimeUnit::Millisecond => {
329 encoder.encode_field(&get_time32_millisecond_value(arr, idx), pg_field)?
330 }
331 _ => {}
332 },
333 DataType::Time64(unit) => match unit {
334 TimeUnit::Microsecond => {
335 encoder.encode_field(&get_time64_microsecond_value(arr, idx), pg_field)?
336 }
337 TimeUnit::Nanosecond => {
338 encoder.encode_field(&get_time64_nanosecond_value(arr, idx), pg_field)?
339 }
340 _ => {}
341 },
342 DataType::Timestamp(unit, timezone) => match unit {
343 TimeUnit::Second => {
344 if arr.is_null(idx) {
345 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
346 }
347 let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
348 if let Some(tz) = timezone {
349 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
350 let value = ts_array
351 .value_as_datetime_with_tz(idx, tz)
352 .map(|d| d.fixed_offset());
353
354 encoder.encode_field(&value, pg_field)?;
355 } else {
356 let value = ts_array.value_as_datetime(idx);
357 encoder.encode_field(&value, pg_field)?;
358 }
359 }
360 TimeUnit::Millisecond => {
361 if arr.is_null(idx) {
362 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
363 }
364 let ts_array = arr
365 .as_any()
366 .downcast_ref::<TimestampMillisecondArray>()
367 .unwrap();
368 if let Some(tz) = timezone {
369 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
370 let value = ts_array
371 .value_as_datetime_with_tz(idx, tz)
372 .map(|d| d.fixed_offset());
373 encoder.encode_field(&value, pg_field)?;
374 } else {
375 let value = ts_array.value_as_datetime(idx);
376 encoder.encode_field(&value, pg_field)?;
377 }
378 }
379 TimeUnit::Microsecond => {
380 if arr.is_null(idx) {
381 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
382 }
383 let ts_array = arr
384 .as_any()
385 .downcast_ref::<TimestampMicrosecondArray>()
386 .unwrap();
387 if let Some(tz) = timezone {
388 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
389 let value = ts_array
390 .value_as_datetime_with_tz(idx, tz)
391 .map(|d| d.fixed_offset());
392 encoder.encode_field(&value, pg_field)?;
393 } else {
394 let value = ts_array.value_as_datetime(idx);
395 encoder.encode_field(&value, pg_field)?;
396 }
397 }
398 TimeUnit::Nanosecond => {
399 if arr.is_null(idx) {
400 return encoder.encode_field(&None::<NaiveDateTime>, pg_field);
401 }
402 let ts_array = arr
403 .as_any()
404 .downcast_ref::<TimestampNanosecondArray>()
405 .unwrap();
406 if let Some(tz) = timezone {
407 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
408 let value = ts_array
409 .value_as_datetime_with_tz(idx, tz)
410 .map(|d| d.fixed_offset());
411 encoder.encode_field(&value, pg_field)?;
412 } else {
413 let value = ts_array.value_as_datetime(idx);
414 encoder.encode_field(&value, pg_field)?;
415 }
416 }
417 },
418 DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
419 if arr.is_null(idx) {
420 return encoder.encode_field(&None::<&[i8]>, pg_field);
421 }
422 let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
423 let value = encode_list(array, pg_field)?;
424 encoder.encode_field(&value, pg_field)?
425 }
426 DataType::Struct(_) => {
427 let fields = match type_.kind() {
428 postgres_types::Kind::Composite(fields) => fields,
429 _ => {
430 return Err(PgWireError::ApiError(ToSqlError::from(format!(
431 "Failed to unwrap a composite type from type {type_}"
432 ))));
433 }
434 };
435 let value = encode_struct(arr, idx, fields, pg_field)?;
436 encoder.encode_field(&value, pg_field)?
437 }
438 DataType::Dictionary(_, value_type) => {
439 if arr.is_null(idx) {
440 return encoder.encode_field(&None::<i8>, pg_field);
441 }
442 macro_rules! get_dict_values_and_index {
444 ($key_type:ty) => {
445 arr.as_any()
446 .downcast_ref::<DictionaryArray<$key_type>>()
447 .map(|dict| (dict.values(), dict.keys().value(idx) as usize))
448 };
449 }
450
451 let (values, idx) = get_dict_values_and_index!(Int8Type)
453 .or_else(|| get_dict_values_and_index!(Int16Type))
454 .or_else(|| get_dict_values_and_index!(Int32Type))
455 .or_else(|| get_dict_values_and_index!(Int64Type))
456 .or_else(|| get_dict_values_and_index!(UInt8Type))
457 .or_else(|| get_dict_values_and_index!(UInt16Type))
458 .or_else(|| get_dict_values_and_index!(UInt32Type))
459 .or_else(|| get_dict_values_and_index!(UInt64Type))
460 .ok_or_else(|| {
461 ToSqlError::from(format!(
462 "Unsupported dictionary key type for value type {value_type}"
463 ))
464 })?;
465
466 encode_value(encoder, values, idx, pg_field)?
467 }
468 _ => {
469 return Err(PgWireError::ApiError(ToSqlError::from(format!(
470 "Unsupported Datatype {} and array {:?}",
471 arr.data_type(),
472 &arr
473 ))));
474 }
475 }
476
477 Ok(())
478}
479
#[cfg(test)]
mod tests {
    use pgwire::api::results::FieldFormat;

    use super::*;

    /// A dictionary-encoded string column must encode the referenced value.
    #[test]
    fn encodes_dictionary_array() {
        /// Captures the text encoding of the last field it was given.
        #[derive(Default)]
        struct MockEncoder {
            encoded_value: String,
        }

        impl Encoder for MockEncoder {
            fn encode_field<T>(&mut self, value: &T, pg_field: &FieldInfo) -> PgWireResult<()>
            where
                T: ToSql + ToSqlText + Sized,
            {
                let mut bytes = BytesMut::new();
                // Fail loudly instead of discarding an encoding error.
                value
                    .to_sql_text(pg_field.datatype(), &mut bytes, &FormatOptions::default())
                    .expect("text encoding should succeed");
                self.encoded_value =
                    String::from_utf8(bytes.to_vec()).expect("encoded text should be UTF-8");
                Ok(())
            }
        }

        let val = "~!@&$[]()@@!!";
        let value = StringArray::from_iter_values([val]);
        let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
        let dict_arr: Arc<dyn Array> =
            Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());

        let mut encoder = MockEncoder::default();

        let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text);
        let result = encode_value(&mut encoder, &dict_arr, 2, &pg_field);

        assert!(result.is_ok(), "encode_value failed: {result:?}");
        assert_eq!(encoder.encoded_value, val);
    }

    /// 3723 seconds after midnight is 01:02:03.
    #[test]
    fn test_get_time32_second_value() {
        let array = Time32SecondArray::from_iter_values([3723_i32]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time32_second_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_opt(1, 2, 3));
    }

    /// 3_723_001 milliseconds after midnight is 01:02:03.001.
    #[test]
    fn test_get_time32_millisecond_value() {
        let array = Time32MillisecondArray::from_iter_values([3723001_i32]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time32_millisecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_milli_opt(1, 2, 3, 1));
    }

    /// 3_723_001_001 microseconds after midnight is 01:02:03.001001.
    #[test]
    fn test_get_time64_microsecond_value() {
        let array = Time64MicrosecondArray::from_iter_values([3723001001_i64]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time64_microsecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_micro_opt(1, 2, 3, 1001));
    }

    /// 3_723_001_001_001 nanoseconds after midnight is 01:02:03.001001001.
    #[test]
    fn test_get_time64_nanosecond_value() {
        let array = Time64NanosecondArray::from_iter_values([3723001001001_i64]);
        let array: Arc<dyn Array> = Arc::new(array);
        let value = get_time64_nanosecond_value(&array, 0);
        assert_eq!(value, NaiveTime::from_hms_nano_opt(1, 2, 3, 1001001));
    }
}