1use std::error::Error;
2use std::io::Write;
3use std::str::FromStr;
4use std::sync::Arc;
5
6#[cfg(not(feature = "datafusion"))]
7use arrow::{array::*, datatypes::*};
8use bytes::BufMut;
9use bytes::BytesMut;
10use chrono::{NaiveDate, NaiveDateTime};
11#[cfg(feature = "datafusion")]
12use datafusion::arrow::{array::*, datatypes::*};
13use pgwire::api::results::DataRowEncoder;
14use pgwire::api::results::FieldFormat;
15use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
16use pgwire::types::ToSqlText;
17use postgres_types::{ToSql, Type};
18use rust_decimal::Decimal;
19use timezone::Tz;
20
21use crate::error::ToSqlError;
22use crate::list_encoder::encode_list;
23use crate::struct_encoder::encode_struct;
24
25pub trait Encoder {
26 fn encode_field_with_type_and_format<T>(
27 &mut self,
28 value: &T,
29 data_type: &Type,
30 format: FieldFormat,
31 ) -> PgWireResult<()>
32 where
33 T: ToSql + ToSqlText + Sized;
34}
35
36impl Encoder for DataRowEncoder {
37 fn encode_field_with_type_and_format<T>(
38 &mut self,
39 value: &T,
40 data_type: &Type,
41 format: FieldFormat,
42 ) -> PgWireResult<()>
43 where
44 T: ToSql + ToSqlText + Sized,
45 {
46 self.encode_field_with_type_and_format(value, data_type, format)
47 }
48}
49
50pub(crate) struct EncodedValue {
51 pub(crate) bytes: BytesMut,
52}
53
54impl std::fmt::Debug for EncodedValue {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("EncodedValue").finish()
57 }
58}
59
60impl ToSql for EncodedValue {
61 fn to_sql(
62 &self,
63 _ty: &Type,
64 out: &mut BytesMut,
65 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
66 where
67 Self: Sized,
68 {
69 out.writer().write_all(&self.bytes)?;
70 Ok(postgres_types::IsNull::No)
71 }
72
73 fn accepts(_ty: &Type) -> bool
74 where
75 Self: Sized,
76 {
77 true
78 }
79
80 fn to_sql_checked(
81 &self,
82 ty: &Type,
83 out: &mut BytesMut,
84 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
85 self.to_sql(ty, out)
86 }
87}
88
89impl ToSqlText for EncodedValue {
90 fn to_sql_text(
91 &self,
92 _ty: &Type,
93 out: &mut BytesMut,
94 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
95 where
96 Self: Sized,
97 {
98 out.writer().write_all(&self.bytes)?;
99 Ok(postgres_types::IsNull::No)
100 }
101}
102
103fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
104 (!arr.is_null(idx)).then(|| {
105 arr.as_any()
106 .downcast_ref::<BooleanArray>()
107 .unwrap()
108 .value(idx)
109 })
110}
111
112macro_rules! get_primitive_value {
113 ($name:ident, $t:ty, $pt:ty) => {
114 fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
115 (!arr.is_null(idx)).then(|| {
116 arr.as_any()
117 .downcast_ref::<PrimitiveArray<$t>>()
118 .unwrap()
119 .value(idx)
120 })
121 }
122 };
123}
124
125get_primitive_value!(get_i8_value, Int8Type, i8);
126get_primitive_value!(get_i16_value, Int16Type, i16);
127get_primitive_value!(get_i32_value, Int32Type, i32);
128get_primitive_value!(get_i64_value, Int64Type, i64);
129get_primitive_value!(get_u8_value, UInt8Type, u8);
130get_primitive_value!(get_u16_value, UInt16Type, u16);
131get_primitive_value!(get_u32_value, UInt32Type, u32);
132get_primitive_value!(get_u64_value, UInt64Type, u64);
133get_primitive_value!(get_f32_value, Float32Type, f32);
134get_primitive_value!(get_f64_value, Float64Type, f64);
135
136fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
137 (!arr.is_null(idx)).then(|| {
138 arr.as_any()
139 .downcast_ref::<StringViewArray>()
140 .unwrap()
141 .value(idx)
142 })
143}
144
145fn get_binary_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
146 (!arr.is_null(idx)).then(|| {
147 arr.as_any()
148 .downcast_ref::<BinaryViewArray>()
149 .unwrap()
150 .value(idx)
151 })
152}
153
154fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
155 (!arr.is_null(idx)).then(|| {
156 arr.as_any()
157 .downcast_ref::<StringArray>()
158 .unwrap()
159 .value(idx)
160 })
161}
162
163fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
164 (!arr.is_null(idx)).then(|| {
165 arr.as_any()
166 .downcast_ref::<LargeStringArray>()
167 .unwrap()
168 .value(idx)
169 })
170}
171
172fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<String> {
173 (!arr.is_null(idx)).then(|| {
174 String::from_utf8_lossy(
175 arr.as_any()
176 .downcast_ref::<BinaryArray>()
177 .unwrap()
178 .value(idx),
179 )
180 .to_string()
181 })
182}
183
184fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
185 (!arr.is_null(idx)).then(|| {
186 arr.as_any()
187 .downcast_ref::<LargeBinaryArray>()
188 .unwrap()
189 .value(idx)
190 })
191}
192
193fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
194 if arr.is_null(idx) {
195 return None;
196 }
197 arr.as_any()
198 .downcast_ref::<Date32Array>()
199 .unwrap()
200 .value_as_date(idx)
201}
202
203fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
204 if arr.is_null(idx) {
205 return None;
206 }
207 arr.as_any()
208 .downcast_ref::<Date64Array>()
209 .unwrap()
210 .value_as_date(idx)
211}
212
213fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
214 if arr.is_null(idx) {
215 return None;
216 }
217 arr.as_any()
218 .downcast_ref::<Time32SecondArray>()
219 .unwrap()
220 .value_as_datetime(idx)
221}
222
223fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
224 if arr.is_null(idx) {
225 return None;
226 }
227 arr.as_any()
228 .downcast_ref::<Time32MillisecondArray>()
229 .unwrap()
230 .value_as_datetime(idx)
231}
232
233fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
234 if arr.is_null(idx) {
235 return None;
236 }
237 arr.as_any()
238 .downcast_ref::<Time64MicrosecondArray>()
239 .unwrap()
240 .value_as_datetime(idx)
241}
242fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
243 if arr.is_null(idx) {
244 return None;
245 }
246 arr.as_any()
247 .downcast_ref::<Time64NanosecondArray>()
248 .unwrap()
249 .value_as_datetime(idx)
250}
251
252fn get_numeric_128_value(
253 arr: &Arc<dyn Array>,
254 idx: usize,
255 scale: u32,
256) -> PgWireResult<Option<Decimal>> {
257 if arr.is_null(idx) {
258 return Ok(None);
259 }
260
261 let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
262 let value = array.value(idx);
263 Decimal::try_from_i128_with_scale(value, scale)
264 .map_err(|e| {
265 let error_code = match e {
266 rust_decimal::Error::ExceedsMaximumPossibleValue => {
267 "22003" }
269 rust_decimal::Error::LessThanMinimumPossibleValue => {
270 "22003" }
272 rust_decimal::Error::ScaleExceedsMaximumPrecision(scale) => {
273 return PgWireError::UserError(Box::new(ErrorInfo::new(
274 "ERROR".to_string(),
275 "22003".to_string(),
276 format!("Scale {scale} exceeds maximum precision for numeric type"),
277 )));
278 }
279 _ => "22003", };
281 PgWireError::UserError(Box::new(ErrorInfo::new(
282 "ERROR".to_string(),
283 error_code.to_string(),
284 format!("Numeric value conversion failed: {e}"),
285 )))
286 })
287 .map(Some)
288}
289
290pub fn encode_value<T: Encoder>(
291 encoder: &mut T,
292 arr: &Arc<dyn Array>,
293 idx: usize,
294 type_: &Type,
295 format: FieldFormat,
296) -> PgWireResult<()> {
297 match arr.data_type() {
298 DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
299 DataType::Boolean => {
300 encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
301 }
302 DataType::Int8 => {
303 encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
304 }
305 DataType::Int16 => {
306 encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
307 }
308 DataType::Int32 => {
309 encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
310 }
311 DataType::Int64 => {
312 encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
313 }
314 DataType::UInt8 => encoder.encode_field_with_type_and_format(
315 &(get_u8_value(arr, idx).map(|x| x as i8)),
316 type_,
317 format,
318 )?,
319 DataType::UInt16 => encoder.encode_field_with_type_and_format(
320 &(get_u16_value(arr, idx).map(|x| x as i16)),
321 type_,
322 format,
323 )?,
324 DataType::UInt32 => {
325 encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
326 }
327 DataType::UInt64 => encoder.encode_field_with_type_and_format(
328 &(get_u64_value(arr, idx).map(|x| x as i64)),
329 type_,
330 format,
331 )?,
332 DataType::Float32 => {
333 encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
334 }
335 DataType::Float64 => {
336 encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
337 }
338 DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
339 &get_numeric_128_value(arr, idx, *s as u32)?,
340 type_,
341 format,
342 )?,
343 DataType::Utf8 => {
344 encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
345 }
346 DataType::Utf8View => encoder.encode_field_with_type_and_format(
347 &get_utf8_view_value(arr, idx),
348 type_,
349 format,
350 )?,
351 DataType::BinaryView => encoder.encode_field_with_type_and_format(
352 &get_binary_view_value(arr, idx),
353 type_,
354 format,
355 )?,
356 DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
357 &get_large_utf8_value(arr, idx),
358 type_,
359 format,
360 )?,
361 DataType::Binary => {
362 encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
363 }
364 DataType::LargeBinary => encoder.encode_field_with_type_and_format(
365 &get_large_binary_value(arr, idx),
366 type_,
367 format,
368 )?,
369 DataType::Date32 => {
370 encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
371 }
372 DataType::Date64 => {
373 encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
374 }
375 DataType::Time32(unit) => match unit {
376 TimeUnit::Second => encoder.encode_field_with_type_and_format(
377 &get_time32_second_value(arr, idx),
378 type_,
379 format,
380 )?,
381 TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
382 &get_time32_millisecond_value(arr, idx),
383 type_,
384 format,
385 )?,
386 _ => {}
387 },
388 DataType::Time64(unit) => match unit {
389 TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
390 &get_time64_microsecond_value(arr, idx),
391 type_,
392 format,
393 )?,
394 TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
395 &get_time64_nanosecond_value(arr, idx),
396 type_,
397 format,
398 )?,
399 _ => {}
400 },
401 DataType::Timestamp(unit, timezone) => match unit {
402 TimeUnit::Second => {
403 if arr.is_null(idx) {
404 return encoder.encode_field_with_type_and_format(
405 &None::<NaiveDateTime>,
406 type_,
407 format,
408 );
409 }
410 let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
411 if let Some(tz) = timezone {
412 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
413 let value = ts_array
414 .value_as_datetime_with_tz(idx, tz)
415 .map(|d| d.fixed_offset());
416 encoder.encode_field_with_type_and_format(&value, type_, format)?;
417 } else {
418 let value = ts_array.value_as_datetime(idx);
419 encoder.encode_field_with_type_and_format(&value, type_, format)?;
420 }
421 }
422 TimeUnit::Millisecond => {
423 if arr.is_null(idx) {
424 return encoder.encode_field_with_type_and_format(
425 &None::<NaiveDateTime>,
426 type_,
427 format,
428 );
429 }
430 let ts_array = arr
431 .as_any()
432 .downcast_ref::<TimestampMillisecondArray>()
433 .unwrap();
434 if let Some(tz) = timezone {
435 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
436 let value = ts_array
437 .value_as_datetime_with_tz(idx, tz)
438 .map(|d| d.fixed_offset());
439 encoder.encode_field_with_type_and_format(&value, type_, format)?;
440 } else {
441 let value = ts_array.value_as_datetime(idx);
442 encoder.encode_field_with_type_and_format(&value, type_, format)?;
443 }
444 }
445 TimeUnit::Microsecond => {
446 if arr.is_null(idx) {
447 return encoder.encode_field_with_type_and_format(
448 &None::<NaiveDateTime>,
449 type_,
450 format,
451 );
452 }
453 let ts_array = arr
454 .as_any()
455 .downcast_ref::<TimestampMicrosecondArray>()
456 .unwrap();
457 if let Some(tz) = timezone {
458 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
459 let value = ts_array
460 .value_as_datetime_with_tz(idx, tz)
461 .map(|d| d.fixed_offset());
462 encoder.encode_field_with_type_and_format(&value, type_, format)?;
463 } else {
464 let value = ts_array.value_as_datetime(idx);
465 encoder.encode_field_with_type_and_format(&value, type_, format)?;
466 }
467 }
468 TimeUnit::Nanosecond => {
469 if arr.is_null(idx) {
470 return encoder.encode_field_with_type_and_format(
471 &None::<NaiveDateTime>,
472 type_,
473 format,
474 );
475 }
476 let ts_array = arr
477 .as_any()
478 .downcast_ref::<TimestampNanosecondArray>()
479 .unwrap();
480 if let Some(tz) = timezone {
481 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
482 let value = ts_array
483 .value_as_datetime_with_tz(idx, tz)
484 .map(|d| d.fixed_offset());
485 encoder.encode_field_with_type_and_format(&value, type_, format)?;
486 } else {
487 let value = ts_array.value_as_datetime(idx);
488 encoder.encode_field_with_type_and_format(&value, type_, format)?;
489 }
490 }
491 },
492 DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
493 if arr.is_null(idx) {
494 return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
495 }
496 let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
497 let value = encode_list(array, type_, format)?;
498 encoder.encode_field_with_type_and_format(&value, type_, format)?
499 }
500 DataType::Struct(_) => {
501 let fields = match type_.kind() {
502 postgres_types::Kind::Composite(fields) => fields,
503 _ => {
504 return Err(PgWireError::ApiError(ToSqlError::from(format!(
505 "Failed to unwrap a composite type from type {type_}"
506 ))));
507 }
508 };
509 let value = encode_struct(arr, idx, fields, format)?;
510 encoder.encode_field_with_type_and_format(&value, type_, format)?
511 }
512 DataType::Dictionary(_, value_type) => {
513 if arr.is_null(idx) {
514 return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
515 }
516 macro_rules! get_dict_values {
519 ($key_type:ty) => {
520 arr.as_any()
521 .downcast_ref::<DictionaryArray<$key_type>>()
522 .map(|dict| dict.values())
523 };
524 }
525
526 let values = get_dict_values!(Int8Type)
528 .or_else(|| get_dict_values!(Int16Type))
529 .or_else(|| get_dict_values!(Int32Type))
530 .or_else(|| get_dict_values!(Int64Type))
531 .or_else(|| get_dict_values!(UInt8Type))
532 .or_else(|| get_dict_values!(UInt16Type))
533 .or_else(|| get_dict_values!(UInt32Type))
534 .or_else(|| get_dict_values!(UInt64Type))
535 .ok_or_else(|| {
536 ToSqlError::from(format!(
537 "Unsupported dictionary key type for value type {value_type}"
538 ))
539 })?;
540
541 if values.len() == 1 {
543 encode_value(encoder, values, 0, type_, format)?
544 } else {
545 encode_value(encoder, values, idx, type_, format)?
547 }
548 }
549 _ => {
550 return Err(PgWireError::ApiError(ToSqlError::from(format!(
551 "Unsupported Datatype {} and array {:?}",
552 arr.data_type(),
553 &arr
554 ))));
555 }
556 }
557
558 Ok(())
559}