1use std::{str::FromStr, sync::Arc};
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{
5 array::{
6 timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
7 Decimal128Array, Decimal256Array, DurationMicrosecondArray, LargeBinaryArray,
8 LargeListArray, LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray,
9 StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
10 Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
11 TimestampNanosecondArray, TimestampSecondArray,
12 },
13 datatypes::{
14 DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
15 Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
16 Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
17 },
18 temporal_conversions::{as_date, as_time},
19};
20#[cfg(feature = "datafusion")]
21use datafusion::arrow::{
22 array::{
23 timezone::Tz, Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
24 Decimal128Array, Decimal256Array, DurationMicrosecondArray, LargeBinaryArray,
25 LargeListArray, LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray,
26 StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
27 Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
28 TimestampNanosecondArray, TimestampSecondArray,
29 },
30 datatypes::{
31 DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
32 Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
33 Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
34 },
35 temporal_conversions::{as_date, as_time},
36};
37
38use bytes::{BufMut, BytesMut};
39use chrono::{DateTime, TimeZone, Utc};
40use pgwire::api::results::FieldFormat;
41use pgwire::error::{PgWireError, PgWireResult};
42use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
43use postgres_types::{ToSql, Type};
44use rust_decimal::Decimal;
45
46use crate::encoder::EncodedValue;
47use crate::error::ToSqlError;
48use crate::struct_encoder::encode_struct;
49
50fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
51 arr.as_any()
52 .downcast_ref::<BooleanArray>()
53 .unwrap()
54 .iter()
55 .collect()
56}
57
/// Generates `fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>>`.
///
/// The generated function downcasts `arr` to `PrimitiveArray<$t>` (panicking
/// if the downcast fails) and collects every slot, keeping nulls as `None`.
///
/// * Three-argument form: slots are collected as-is, so `$pt` must be the
///   native type of `$t`.
/// * Four-argument form: every non-null value is passed through the
///   conversion `$f` before being collected.
macro_rules! get_primitive_list_value {
    ($name:ident, $t:ty, $pt:ty) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            typed.iter().collect::<Vec<_>>()
        }
    };

    ($name:ident, $t:ty, $pt:ty, $f:expr) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let typed = arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
            typed
                .iter()
                .map(|slot| slot.map($f))
                .collect::<Vec<_>>()
        }
    };
}
80
81get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
82get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
83get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
84get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
85get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 });
86get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| {
87 val as i16
88});
89get_primitive_list_value!(get_u32_list_value, UInt32Type, u32);
90get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| {
91 val as i64
92});
93get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
94get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
95
96fn encode_field<T: ToSql + ToSqlText>(
97 t: &[T],
98 type_: &Type,
99 format: FieldFormat,
100) -> PgWireResult<EncodedValue> {
101 let mut bytes = BytesMut::new();
102 match format {
103 FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
104 FieldFormat::Binary => t.to_sql(type_, &mut bytes)?,
105 };
106 Ok(EncodedValue { bytes })
107}
108
109pub(crate) fn encode_list(
110 arr: Arc<dyn Array>,
111 type_: &Type,
112 format: FieldFormat,
113) -> PgWireResult<EncodedValue> {
114 match arr.data_type() {
115 DataType::Null => {
116 let mut bytes = BytesMut::new();
117 match format {
118 FieldFormat::Text => None::<i8>.to_sql_text(type_, &mut bytes),
119 FieldFormat::Binary => None::<i8>.to_sql(type_, &mut bytes),
120 }?;
121 Ok(EncodedValue { bytes })
122 }
123 DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format),
124 DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format),
125 DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format),
126 DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format),
127 DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format),
128 DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format),
129 DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format),
130 DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format),
131 DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format),
132 DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format),
133 DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format),
134 DataType::Decimal128(_, s) => {
135 let value: Vec<_> = arr
136 .as_any()
137 .downcast_ref::<Decimal128Array>()
138 .unwrap()
139 .iter()
140 .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
141 .collect();
142 encode_field(&value, type_, format)
143 }
144 DataType::Utf8 => {
145 let value: Vec<Option<&str>> = arr
146 .as_any()
147 .downcast_ref::<StringArray>()
148 .unwrap()
149 .iter()
150 .collect();
151 encode_field(&value, type_, format)
152 }
153 DataType::Utf8View => {
154 let value: Vec<Option<&str>> = arr
155 .as_any()
156 .downcast_ref::<StringViewArray>()
157 .unwrap()
158 .iter()
159 .collect();
160 encode_field(&value, type_, format)
161 }
162 DataType::Binary => {
163 let value: Vec<Option<_>> = arr
164 .as_any()
165 .downcast_ref::<BinaryArray>()
166 .unwrap()
167 .iter()
168 .collect();
169 encode_field(&value, type_, format)
170 }
171 DataType::LargeBinary => {
172 let value: Vec<Option<_>> = arr
173 .as_any()
174 .downcast_ref::<LargeBinaryArray>()
175 .unwrap()
176 .iter()
177 .collect();
178 encode_field(&value, type_, format)
179 }
180 DataType::BinaryView => {
181 let value: Vec<Option<_>> = arr
182 .as_any()
183 .downcast_ref::<BinaryViewArray>()
184 .unwrap()
185 .iter()
186 .collect();
187 encode_field(&value, type_, format)
188 }
189
190 DataType::Date32 => {
191 let value: Vec<Option<_>> = arr
192 .as_any()
193 .downcast_ref::<Date32Array>()
194 .unwrap()
195 .iter()
196 .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
197 .collect();
198 encode_field(&value, type_, format)
199 }
200 DataType::Date64 => {
201 let value: Vec<Option<_>> = arr
202 .as_any()
203 .downcast_ref::<Date64Array>()
204 .unwrap()
205 .iter()
206 .map(|val| val.and_then(as_date::<Date64Type>))
207 .collect();
208 encode_field(&value, type_, format)
209 }
210 DataType::Time32(unit) => match unit {
211 TimeUnit::Second => {
212 let value: Vec<Option<_>> = arr
213 .as_any()
214 .downcast_ref::<Time32SecondArray>()
215 .unwrap()
216 .iter()
217 .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
218 .collect();
219 encode_field(&value, type_, format)
220 }
221 TimeUnit::Millisecond => {
222 let value: Vec<Option<_>> = arr
223 .as_any()
224 .downcast_ref::<Time32MillisecondArray>()
225 .unwrap()
226 .iter()
227 .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
228 .collect();
229 encode_field(&value, type_, format)
230 }
231 _ => {
232 Err(PgWireError::ApiError("Unsupported Time32 unit".into()))
235 }
236 },
237 DataType::Time64(unit) => match unit {
238 TimeUnit::Microsecond => {
239 let value: Vec<Option<_>> = arr
240 .as_any()
241 .downcast_ref::<Time64MicrosecondArray>()
242 .unwrap()
243 .iter()
244 .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
245 .collect();
246 encode_field(&value, type_, format)
247 }
248 TimeUnit::Nanosecond => {
249 let value: Vec<Option<_>> = arr
250 .as_any()
251 .downcast_ref::<Time64NanosecondArray>()
252 .unwrap()
253 .iter()
254 .map(|val| val.and_then(as_time::<Time64NanosecondType>))
255 .collect();
256 encode_field(&value, type_, format)
257 }
258 _ => {
259 Err(PgWireError::ApiError("Unsupported Time64 unit".into()))
262 }
263 },
264 DataType::Timestamp(unit, timezone) => match unit {
265 TimeUnit::Second => {
266 let array_iter = arr
267 .as_any()
268 .downcast_ref::<TimestampSecondArray>()
269 .unwrap()
270 .iter();
271
272 if let Some(tz) = timezone {
273 let tz = Tz::from_str(tz.as_ref())
274 .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
275 let value: Vec<_> = array_iter
276 .map(|i| {
277 i.and_then(|i| {
278 DateTime::from_timestamp(i, 0).map(|dt| {
279 Utc.from_utc_datetime(&dt.naive_utc())
280 .with_timezone(&tz)
281 .fixed_offset()
282 })
283 })
284 })
285 .collect();
286 encode_field(&value, type_, format)
287 } else {
288 let value: Vec<_> = array_iter
289 .map(|i| {
290 i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
291 })
292 .collect();
293 encode_field(&value, type_, format)
294 }
295 }
296 TimeUnit::Millisecond => {
297 let array_iter = arr
298 .as_any()
299 .downcast_ref::<TimestampMillisecondArray>()
300 .unwrap()
301 .iter();
302
303 if let Some(tz) = timezone {
304 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
305 let value: Vec<_> = array_iter
306 .map(|i| {
307 i.and_then(|i| {
308 DateTime::from_timestamp_millis(i).map(|dt| {
309 Utc.from_utc_datetime(&dt.naive_utc())
310 .with_timezone(&tz)
311 .fixed_offset()
312 })
313 })
314 })
315 .collect();
316 encode_field(&value, type_, format)
317 } else {
318 let value: Vec<_> = array_iter
319 .map(|i| {
320 i.and_then(|i| {
321 DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
322 })
323 })
324 .collect();
325 encode_field(&value, type_, format)
326 }
327 }
328 TimeUnit::Microsecond => {
329 let array_iter = arr
330 .as_any()
331 .downcast_ref::<TimestampMicrosecondArray>()
332 .unwrap()
333 .iter();
334
335 if let Some(tz) = timezone {
336 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
337 let value: Vec<_> = array_iter
338 .map(|i| {
339 i.and_then(|i| {
340 DateTime::from_timestamp_micros(i).map(|dt| {
341 Utc.from_utc_datetime(&dt.naive_utc())
342 .with_timezone(&tz)
343 .fixed_offset()
344 })
345 })
346 })
347 .collect();
348 encode_field(&value, type_, format)
349 } else {
350 let value: Vec<_> = array_iter
351 .map(|i| {
352 i.and_then(|i| {
353 DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
354 })
355 })
356 .collect();
357 encode_field(&value, type_, format)
358 }
359 }
360 TimeUnit::Nanosecond => {
361 let array_iter = arr
362 .as_any()
363 .downcast_ref::<TimestampNanosecondArray>()
364 .unwrap()
365 .iter();
366
367 if let Some(tz) = timezone {
368 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
369 let value: Vec<_> = array_iter
370 .map(|i| {
371 i.map(|i| {
372 Utc.from_utc_datetime(
373 &DateTime::from_timestamp_nanos(i).naive_utc(),
374 )
375 .with_timezone(&tz)
376 .fixed_offset()
377 })
378 })
379 .collect();
380 encode_field(&value, type_, format)
381 } else {
382 let value: Vec<_> = array_iter
383 .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
384 .collect();
385 encode_field(&value, type_, format)
386 }
387 }
388 },
389 DataType::Struct(_) => {
390 let fields = match type_.kind() {
391 postgres_types::Kind::Array(struct_type_) => Ok(struct_type_),
392 _ => Err(format!(
393 "Expected list type found type {} of kind {:?}",
394 type_,
395 type_.kind()
396 )),
397 }
398 .and_then(|struct_type| match struct_type.kind() {
399 postgres_types::Kind::Composite(fields) => Ok(fields),
400 _ => Err(format!(
401 "Failed to unwrap a composite type inside from type {} kind {:?}",
402 type_,
403 type_.kind()
404 )),
405 })
406 .map_err(ToSqlError::from)?;
407
408 let values: PgWireResult<Vec<_>> = (0..arr.len())
409 .map(|row| encode_struct(&arr, row, fields, format))
410 .map(|x| {
411 if matches!(format, FieldFormat::Text) {
412 x.map(|opt| {
413 opt.map(|value| {
414 let mut w = BytesMut::new();
415 w.put_u8(b'"');
416 w.put_slice(
417 QUOTE_ESCAPE
418 .replace_all(
419 &String::from_utf8_lossy(&value.bytes),
420 r#"\$1"#,
421 )
422 .as_bytes(),
423 );
424 w.put_u8(b'"');
425 EncodedValue { bytes: w }
426 })
427 })
428 } else {
429 x
430 }
431 })
432 .collect();
433 encode_field(&values?, type_, format)
434 }
435 DataType::LargeUtf8 => {
436 let value: Vec<Option<&str>> = arr
437 .as_any()
438 .downcast_ref::<LargeStringArray>()
439 .unwrap()
440 .iter()
441 .collect();
442 encode_field(&value, type_, format)
443 }
444 DataType::Decimal256(_, s) => {
445 let decimal_array = arr.as_any().downcast_ref::<Decimal256Array>().unwrap();
448 let value: Vec<Option<String>> = (0..decimal_array.len())
449 .map(|i| {
450 if decimal_array.is_null(i) {
451 None
452 } else {
453 let raw_value = decimal_array.value(i);
455 let scale = *s as u32;
456 let value_str = raw_value.to_string();
458 if scale == 0 {
459 Some(value_str)
460 } else {
461 let mut chars: Vec<char> = value_str.chars().collect();
463 if chars.len() <= scale as usize {
464 let zeros_needed = scale as usize - chars.len() + 1;
466 chars.splice(0..0, std::iter::repeat_n('0', zeros_needed));
467 chars.insert(1, '.');
468 } else {
469 let decimal_pos = chars.len() - scale as usize;
470 chars.insert(decimal_pos, '.');
471 }
472 Some(chars.into_iter().collect())
473 }
474 }
475 })
476 .collect();
477 encode_field(&value, type_, format)
478 }
479 DataType::Duration(_) => {
480 let value: Vec<Option<i64>> = arr
482 .as_any()
483 .downcast_ref::<DurationMicrosecondArray>()
484 .unwrap()
485 .iter()
486 .collect();
487 encode_field(&value, type_, format)
488 }
489 DataType::List(_) => {
490 let list_array = arr.as_any().downcast_ref::<ListArray>().unwrap();
493 let value: Vec<Option<String>> = (0..list_array.len())
494 .map(|i| {
495 if list_array.is_null(i) {
496 None
497 } else {
498 Some(format!("[nested_list_{i}]"))
500 }
501 })
502 .collect();
503 encode_field(&value, type_, format)
504 }
505 DataType::LargeList(_) => {
506 let list_array = arr.as_any().downcast_ref::<LargeListArray>().unwrap();
508 let value: Vec<Option<String>> = (0..list_array.len())
509 .map(|i| {
510 if list_array.is_null(i) {
511 None
512 } else {
513 Some(format!("[large_list_{i}]"))
514 }
515 })
516 .collect();
517 encode_field(&value, type_, format)
518 }
519 DataType::Map(_, _) => {
520 let map_array = arr.as_any().downcast_ref::<MapArray>().unwrap();
522 let value: Vec<Option<String>> = (0..map_array.len())
523 .map(|i| {
524 if map_array.is_null(i) {
525 None
526 } else {
527 Some(format!("{{map_{i}}}"))
528 }
529 })
530 .collect();
531 encode_field(&value, type_, format)
532 }
533
534 DataType::Union(_, _) => {
535 let value: Vec<Option<String>> = (0..arr.len())
537 .map(|i| {
538 if arr.is_null(i) {
539 None
540 } else {
541 Some(format!("union_{i}"))
542 }
543 })
544 .collect();
545 encode_field(&value, type_, format)
546 }
547 DataType::Dictionary(_, _) => {
548 let value: Vec<Option<String>> = (0..arr.len())
550 .map(|i| {
551 if arr.is_null(i) {
552 None
553 } else {
554 Some(format!("dict_{i}"))
555 }
556 })
557 .collect();
558 encode_field(&value, type_, format)
559 }
560 list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
562 "Unsupported List Datatype {} and array {:?}",
563 list_type, &arr
564 )))),
565 }
566}