use std::error::Error;
use std::io::Write;
use std::str::FromStr;
use std::sync::Arc;

#[cfg(not(feature = "datafusion"))]
use arrow::{array::*, datatypes::*};
use bytes::BufMut;
use bytes::BytesMut;
use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
#[cfg(feature = "datafusion")]
use datafusion::arrow::{array::*, datatypes::*};
use pgwire::api::results::DataRowEncoder;
use pgwire::api::results::FieldFormat;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::types::ToSqlText;
use postgres_types::{ToSql, Type};
use rust_decimal::Decimal;
use timezone::Tz;

use crate::error::ToSqlError;
use crate::list_encoder::encode_list;
use crate::struct_encoder::encode_struct;
24
25pub trait Encoder {
26 fn encode_field_with_type_and_format<T>(
27 &mut self,
28 value: &T,
29 data_type: &Type,
30 format: FieldFormat,
31 ) -> PgWireResult<()>
32 where
33 T: ToSql + ToSqlText + Sized;
34}
35
36impl Encoder for DataRowEncoder {
37 fn encode_field_with_type_and_format<T>(
38 &mut self,
39 value: &T,
40 data_type: &Type,
41 format: FieldFormat,
42 ) -> PgWireResult<()>
43 where
44 T: ToSql + ToSqlText + Sized,
45 {
46 self.encode_field_with_type_and_format(value, data_type, format)
47 }
48}
49
/// A field value carried as raw bytes.
///
/// Its `ToSql` and `ToSqlText` implementations copy `bytes` into the output unchanged,
/// whatever PostgreSQL type is requested.
pub(crate) struct EncodedValue {
    /// Bytes written verbatim to the output buffer when this value is encoded.
    pub(crate) bytes: BytesMut,
}
53
54impl std::fmt::Debug for EncodedValue {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("EncodedValue").finish()
57 }
58}
59
60impl ToSql for EncodedValue {
61 fn to_sql(
62 &self,
63 _ty: &Type,
64 out: &mut BytesMut,
65 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
66 where
67 Self: Sized,
68 {
69 out.writer().write_all(&self.bytes)?;
70 Ok(postgres_types::IsNull::No)
71 }
72
73 fn accepts(_ty: &Type) -> bool
74 where
75 Self: Sized,
76 {
77 true
78 }
79
80 fn to_sql_checked(
81 &self,
82 ty: &Type,
83 out: &mut BytesMut,
84 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
85 self.to_sql(ty, out)
86 }
87}
88
89impl ToSqlText for EncodedValue {
90 fn to_sql_text(
91 &self,
92 _ty: &Type,
93 out: &mut BytesMut,
94 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
95 where
96 Self: Sized,
97 {
98 out.writer().write_all(&self.bytes)?;
99 Ok(postgres_types::IsNull::No)
100 }
101}
102
103fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
104 (!arr.is_null(idx)).then(|| {
105 arr.as_any()
106 .downcast_ref::<BooleanArray>()
107 .unwrap()
108 .value(idx)
109 })
110}
111
/// Generates `fn $name(arr, idx) -> Option<$pt>` that reads slot `idx` of a
/// `PrimitiveArray<$t>`, yielding `None` for a null slot.
///
/// The generated function panics if `arr` is not a `PrimitiveArray<$t>`.
macro_rules! get_primitive_value {
    ($name:ident, $t:ty, $pt:ty) => {
        fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
            if arr.is_null(idx) {
                None
            } else {
                let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
                Some(typed.value(idx))
            }
        }
    };
}
124
125get_primitive_value!(get_i8_value, Int8Type, i8);
126get_primitive_value!(get_i16_value, Int16Type, i16);
127get_primitive_value!(get_i32_value, Int32Type, i32);
128get_primitive_value!(get_i64_value, Int64Type, i64);
129get_primitive_value!(get_u8_value, UInt8Type, u8);
130get_primitive_value!(get_u16_value, UInt16Type, u16);
131get_primitive_value!(get_u32_value, UInt32Type, u32);
132get_primitive_value!(get_u64_value, UInt64Type, u64);
133get_primitive_value!(get_f32_value, Float32Type, f32);
134get_primitive_value!(get_f64_value, Float64Type, f64);
135
136fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
137 (!arr.is_null(idx)).then(|| {
138 arr.as_any()
139 .downcast_ref::<StringViewArray>()
140 .unwrap()
141 .value(idx)
142 })
143}
144
145fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
146 (!arr.is_null(idx)).then(|| {
147 arr.as_any()
148 .downcast_ref::<BinaryViewArray>()
149 .unwrap()
150 .value(idx)
151 })
152}
153
154fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
155 (!arr.is_null(idx)).then(|| {
156 arr.as_any()
157 .downcast_ref::<StringArray>()
158 .unwrap()
159 .value(idx)
160 })
161}
162
163fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
164 (!arr.is_null(idx)).then(|| {
165 arr.as_any()
166 .downcast_ref::<LargeStringArray>()
167 .unwrap()
168 .value(idx)
169 })
170}
171
172fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
173 (!arr.is_null(idx)).then(|| {
174 arr.as_any()
175 .downcast_ref::<BinaryArray>()
176 .unwrap()
177 .value(idx)
178 })
179}
180
181fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
182 (!arr.is_null(idx)).then(|| {
183 arr.as_any()
184 .downcast_ref::<LargeBinaryArray>()
185 .unwrap()
186 .value(idx)
187 })
188}
189
190fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
191 if arr.is_null(idx) {
192 return None;
193 }
194 arr.as_any()
195 .downcast_ref::<Date32Array>()
196 .unwrap()
197 .value_as_date(idx)
198}
199
200fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
201 if arr.is_null(idx) {
202 return None;
203 }
204 arr.as_any()
205 .downcast_ref::<Date64Array>()
206 .unwrap()
207 .value_as_date(idx)
208}
209
210fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
211 if arr.is_null(idx) {
212 return None;
213 }
214 arr.as_any()
215 .downcast_ref::<Time32SecondArray>()
216 .unwrap()
217 .value_as_datetime(idx)
218}
219
220fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
221 if arr.is_null(idx) {
222 return None;
223 }
224 arr.as_any()
225 .downcast_ref::<Time32MillisecondArray>()
226 .unwrap()
227 .value_as_datetime(idx)
228}
229
230fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
231 if arr.is_null(idx) {
232 return None;
233 }
234 arr.as_any()
235 .downcast_ref::<Time64MicrosecondArray>()
236 .unwrap()
237 .value_as_datetime(idx)
238}
239fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
240 if arr.is_null(idx) {
241 return None;
242 }
243 arr.as_any()
244 .downcast_ref::<Time64NanosecondArray>()
245 .unwrap()
246 .value_as_datetime(idx)
247}
248
249fn get_numeric_128_value(
250 arr: &Arc<dyn Array>,
251 idx: usize,
252 scale: u32,
253) -> PgWireResult<Option<Decimal>> {
254 if arr.is_null(idx) {
255 return Ok(None);
256 }
257
258 let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
259 let value = array.value(idx);
260 Decimal::try_from_i128_with_scale(value, scale)
261 .map_err(|e| {
262 let error_code = match e {
263 rust_decimal::Error::ExceedsMaximumPossibleValue => {
264 "22003" }
266 rust_decimal::Error::LessThanMinimumPossibleValue => {
267 "22003" }
269 rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
270 return PgWireError::UserError(Box::new(ErrorInfo::new(
271 "ERROR".to_string(),
272 "22003".to_string(),
273 format!("Scale {scale} exceeds maximum precision for numeric type"),
274 )));
275 }
276 _ => "22003", };
278 PgWireError::UserError(Box::new(ErrorInfo::new(
279 "ERROR".to_string(),
280 error_code.to_string(),
281 format!("Numeric value conversion failed: {e}"),
282 )))
283 })
284 .map(Some)
285}
286
287pub fn encode_value<T: Encoder>(
288 encoder: &mut T,
289 arr: &Arc<dyn Array>,
290 idx: usize,
291 type_: &Type,
292 format: FieldFormat,
293) -> PgWireResult<()> {
294 match arr.data_type() {
295 DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
296 DataType::Boolean => {
297 encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
298 }
299 DataType::Int8 => {
300 encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
301 }
302 DataType::Int16 => {
303 encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
304 }
305 DataType::Int32 => {
306 encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
307 }
308 DataType::Int64 => {
309 encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
310 }
311 DataType::UInt8 => encoder.encode_field_with_type_and_format(
312 &(get_u8_value(arr, idx).map(|x| x as i8)),
313 type_,
314 format,
315 )?,
316 DataType::UInt16 => encoder.encode_field_with_type_and_format(
317 &(get_u16_value(arr, idx).map(|x| x as i16)),
318 type_,
319 format,
320 )?,
321 DataType::UInt32 => {
322 encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
323 }
324 DataType::UInt64 => encoder.encode_field_with_type_and_format(
325 &(get_u64_value(arr, idx).map(|x| x as i64)),
326 type_,
327 format,
328 )?,
329 DataType::Float32 => {
330 encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
331 }
332 DataType::Float64 => {
333 encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
334 }
335 DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
336 &get_numeric_128_value(arr, idx, *s as u32)?,
337 type_,
338 format,
339 )?,
340 DataType::Utf8 => {
341 encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
342 }
343 DataType::Utf8View => encoder.encode_field_with_type_and_format(
344 &get_utf8_view_value(arr, idx),
345 type_,
346 format,
347 )?,
348 DataType::BinaryView => encoder.encode_field_with_type_and_format(
349 &get_binary_view_value(arr, idx),
350 type_,
351 format,
352 )?,
353 DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
354 &get_large_utf8_value(arr, idx),
355 type_,
356 format,
357 )?,
358 DataType::Binary => {
359 encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
360 }
361 DataType::LargeBinary => encoder.encode_field_with_type_and_format(
362 &get_large_binary_value(arr, idx),
363 type_,
364 format,
365 )?,
366 DataType::Date32 => {
367 encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
368 }
369 DataType::Date64 => {
370 encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
371 }
372 DataType::Time32(unit) => match unit {
373 TimeUnit::Second => encoder.encode_field_with_type_and_format(
374 &get_time32_second_value(arr, idx),
375 type_,
376 format,
377 )?,
378 TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
379 &get_time32_millisecond_value(arr, idx),
380 type_,
381 format,
382 )?,
383 _ => {}
384 },
385 DataType::Time64(unit) => match unit {
386 TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
387 &get_time64_microsecond_value(arr, idx),
388 type_,
389 format,
390 )?,
391 TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
392 &get_time64_nanosecond_value(arr, idx),
393 type_,
394 format,
395 )?,
396 _ => {}
397 },
398 DataType::Timestamp(unit, timezone) => match unit {
399 TimeUnit::Second => {
400 if arr.is_null(idx) {
401 return encoder.encode_field_with_type_and_format(
402 &None::<NaiveDateTime>,
403 type_,
404 format,
405 );
406 }
407 let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
408 if let Some(tz) = timezone {
409 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
410 let value = ts_array
411 .value_as_datetime_with_tz(idx, tz)
412 .map(|d| d.fixed_offset());
413 encoder.encode_field_with_type_and_format(&value, type_, format)?;
414 } else {
415 let value = ts_array.value_as_datetime(idx);
416 encoder.encode_field_with_type_and_format(&value, type_, format)?;
417 }
418 }
419 TimeUnit::Millisecond => {
420 if arr.is_null(idx) {
421 return encoder.encode_field_with_type_and_format(
422 &None::<NaiveDateTime>,
423 type_,
424 format,
425 );
426 }
427 let ts_array = arr
428 .as_any()
429 .downcast_ref::<TimestampMillisecondArray>()
430 .unwrap();
431 if let Some(tz) = timezone {
432 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
433 let value = ts_array
434 .value_as_datetime_with_tz(idx, tz)
435 .map(|d| d.fixed_offset());
436 encoder.encode_field_with_type_and_format(&value, type_, format)?;
437 } else {
438 let value = ts_array.value_as_datetime(idx);
439 encoder.encode_field_with_type_and_format(&value, type_, format)?;
440 }
441 }
442 TimeUnit::Microsecond => {
443 if arr.is_null(idx) {
444 return encoder.encode_field_with_type_and_format(
445 &None::<NaiveDateTime>,
446 type_,
447 format,
448 );
449 }
450 let ts_array = arr
451 .as_any()
452 .downcast_ref::<TimestampMicrosecondArray>()
453 .unwrap();
454 if let Some(tz) = timezone {
455 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
456 let value = ts_array
457 .value_as_datetime_with_tz(idx, tz)
458 .map(|d| d.fixed_offset());
459 encoder.encode_field_with_type_and_format(&value, type_, format)?;
460 } else {
461 let value = ts_array.value_as_datetime(idx);
462 encoder.encode_field_with_type_and_format(&value, type_, format)?;
463 }
464 }
465 TimeUnit::Nanosecond => {
466 if arr.is_null(idx) {
467 return encoder.encode_field_with_type_and_format(
468 &None::<NaiveDateTime>,
469 type_,
470 format,
471 );
472 }
473 let ts_array = arr
474 .as_any()
475 .downcast_ref::<TimestampNanosecondArray>()
476 .unwrap();
477 if let Some(tz) = timezone {
478 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
479 let value = ts_array
480 .value_as_datetime_with_tz(idx, tz)
481 .map(|d| d.fixed_offset());
482 encoder.encode_field_with_type_and_format(&value, type_, format)?;
483 } else {
484 let value = ts_array.value_as_datetime(idx);
485 encoder.encode_field_with_type_and_format(&value, type_, format)?;
486 }
487 }
488 },
489 DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
490 if arr.is_null(idx) {
491 return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
492 }
493 let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
494 let value = encode_list(array, type_, format)?;
495 encoder.encode_field_with_type_and_format(&value, type_, format)?
496 }
497 DataType::Struct(_) => {
498 let fields = match type_.kind() {
499 postgres_types::Kind::Composite(fields) => fields,
500 _ => {
501 return Err(PgWireError::ApiError(ToSqlError::from(format!(
502 "Failed to unwrap a composite type from type {type_}"
503 ))));
504 }
505 };
506 let value = encode_struct(arr, idx, fields, format)?;
507 encoder.encode_field_with_type_and_format(&value, type_, format)?
508 }
509 DataType::Dictionary(_, value_type) => {
510 if arr.is_null(idx) {
511 return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
512 }
513 macro_rules! get_dict_values_and_index {
515 ($key_type:ty) => {
516 arr.as_any()
517 .downcast_ref::<DictionaryArray<$key_type>>()
518 .map(|dict| (dict.values(), dict.keys().value(idx) as usize))
519 };
520 }
521
522 let (values, idx) = get_dict_values_and_index!(Int8Type)
524 .or_else(|| get_dict_values_and_index!(Int16Type))
525 .or_else(|| get_dict_values_and_index!(Int32Type))
526 .or_else(|| get_dict_values_and_index!(Int64Type))
527 .or_else(|| get_dict_values_and_index!(UInt8Type))
528 .or_else(|| get_dict_values_and_index!(UInt16Type))
529 .or_else(|| get_dict_values_and_index!(UInt32Type))
530 .or_else(|| get_dict_values_and_index!(UInt64Type))
531 .ok_or_else(|| {
532 ToSqlError::from(format!(
533 "Unsupported dictionary key type for value type {value_type}"
534 ))
535 })?;
536
537 encode_value(encoder, values, idx, type_, format)?
538 }
539 _ => {
540 return Err(PgWireError::ApiError(ToSqlError::from(format!(
541 "Unsupported Datatype {} and array {:?}",
542 arr.data_type(),
543 &arr
544 ))));
545 }
546 }
547
548 Ok(())
549}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that keeps the text encoding of the most recently encoded field.
    #[derive(Default)]
    struct MockEncoder {
        encoded_value: String,
    }

    impl Encoder for MockEncoder {
        fn encode_field_with_type_and_format<T>(
            &mut self,
            value: &T,
            data_type: &Type,
            _format: FieldFormat,
        ) -> PgWireResult<()>
        where
            T: ToSql + ToSqlText + Sized,
        {
            let mut bytes = BytesMut::new();
            // Fail loudly instead of ignoring the result, so an encoding error cannot show up
            // as a misleading empty string.
            value
                .to_sql_text(data_type, &mut bytes)
                .expect("to_sql_text should succeed in tests");
            self.encoded_value =
                String::from_utf8(bytes.to_vec()).expect("text encoding should be UTF-8");
            Ok(())
        }
    }

    #[test]
    fn encodes_dictionary_array() {
        let val = "~!@&$[]()@@!!";
        let value = StringArray::from_iter_values([val]);
        let keys = Int8Array::from_iter_values([0, 0, 0, 0]);
        let dict_arr: Arc<dyn Array> =
            Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(value)).unwrap());

        let mut encoder = MockEncoder::default();

        encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text)
            .expect("dictionary value should encode");

        assert_eq!(encoder.encoded_value, val);
    }

    #[test]
    fn encodes_dictionary_value_by_key_not_row_index() {
        // Keys deliberately differ from row positions, so indexing the value array with the
        // row index instead of the key would produce the wrong string.
        let values = StringArray::from_iter_values(["a", "b", "c"]);
        let keys = UInt16Array::from_iter_values([2, 0, 1]);
        let dict_arr: Arc<dyn Array> =
            Arc::new(DictionaryArray::<UInt16Type>::try_new(keys, Arc::new(values)).unwrap());

        let mut encoder = MockEncoder::default();
        for (row, expected) in [(0, "c"), (1, "a"), (2, "b")] {
            encode_value(&mut encoder, &dict_arr, row, &Type::TEXT, FieldFormat::Text)
                .expect("dictionary value should encode");
            assert_eq!(encoder.encoded_value, expected);
        }
    }
}