1use std::{str::FromStr, sync::Arc};
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::{
5 array::{
6 timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
7 Decimal256Array, DurationMicrosecondArray, LargeBinaryArray, LargeListArray,
8 LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, Time32MillisecondArray,
9 Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
10 TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
11 TimestampSecondArray,
12 },
13 datatypes::{
14 DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
15 Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
16 Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
17 },
18 temporal_conversions::{as_date, as_time},
19};
20#[cfg(feature = "datafusion")]
21use datafusion::arrow::{
22 array::{
23 timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
24 Decimal256Array, DurationMicrosecondArray, LargeBinaryArray, LargeListArray,
25 LargeStringArray, ListArray, MapArray, PrimitiveArray, StringArray, Time32MillisecondArray,
26 Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
27 TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
28 TimestampSecondArray,
29 },
30 datatypes::{
31 DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
32 Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
33 Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
34 },
35 temporal_conversions::{as_date, as_time},
36};
37
38use bytes::{BufMut, BytesMut};
39use chrono::{DateTime, TimeZone, Utc};
40use pgwire::api::results::FieldFormat;
41use pgwire::error::{PgWireError, PgWireResult};
42use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
43use postgres_types::{ToSql, Type};
44use rust_decimal::Decimal;
45
46use crate::encoder::EncodedValue;
47use crate::error::ToSqlError;
48use crate::struct_encoder::encode_struct;
49
50fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
51 arr.as_any()
52 .downcast_ref::<BooleanArray>()
53 .unwrap()
54 .iter()
55 .collect()
56}
57
/// Generates a function `$name` that downcasts a dynamically typed Arrow array
/// to `PrimitiveArray<$t>` and collects its elements into `Vec<Option<$pt>>`.
///
/// The generated function panics if the array is not a `PrimitiveArray<$t>`.
macro_rules! get_primitive_list_value {
    // Three-argument form: elements are collected as yielded by the iterator.
    ($name:ident, $t:ty, $pt:ty) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let primitives = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$t>>()
                .unwrap();
            primitives.iter().collect()
        }
    };

    // Four-argument form: every present element is passed through `$f`
    // before being collected; absent elements stay `None`.
    ($name:ident, $t:ty, $pt:ty, $f:expr) => {
        fn $name(arr: &Arc<dyn Array>) -> Vec<Option<$pt>> {
            let primitives = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$t>>()
                .unwrap();
            primitives.iter().map(|slot| slot.map($f)).collect()
        }
    };
}
80
81get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
82get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
83get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
84get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
85get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 });
86get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| {
87 val as i16
88});
89get_primitive_list_value!(get_u32_list_value, UInt32Type, u32);
90get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| {
91 val as i64
92});
93get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
94get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
95
96fn encode_field<T: ToSql + ToSqlText>(
97 t: &[T],
98 type_: &Type,
99 format: FieldFormat,
100) -> PgWireResult<EncodedValue> {
101 let mut bytes = BytesMut::new();
102 match format {
103 FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
104 FieldFormat::Binary => t.to_sql(type_, &mut bytes)?,
105 };
106 Ok(EncodedValue { bytes })
107}
108
109pub(crate) fn encode_list(
110 arr: Arc<dyn Array>,
111 type_: &Type,
112 format: FieldFormat,
113) -> PgWireResult<EncodedValue> {
114 match arr.data_type() {
115 DataType::Null => {
116 let mut bytes = BytesMut::new();
117 match format {
118 FieldFormat::Text => None::<i8>.to_sql_text(type_, &mut bytes),
119 FieldFormat::Binary => None::<i8>.to_sql(type_, &mut bytes),
120 }?;
121 Ok(EncodedValue { bytes })
122 }
123 DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format),
124 DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format),
125 DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format),
126 DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format),
127 DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format),
128 DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format),
129 DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format),
130 DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format),
131 DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format),
132 DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format),
133 DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format),
134 DataType::Decimal128(_, s) => {
135 let value: Vec<_> = arr
136 .as_any()
137 .downcast_ref::<Decimal128Array>()
138 .unwrap()
139 .iter()
140 .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
141 .collect();
142 encode_field(&value, type_, format)
143 }
144 DataType::Utf8 => {
145 let value: Vec<Option<&str>> = arr
146 .as_any()
147 .downcast_ref::<StringArray>()
148 .unwrap()
149 .iter()
150 .collect();
151 encode_field(&value, type_, format)
152 }
153 DataType::Binary => {
154 let value: Vec<Option<_>> = arr
155 .as_any()
156 .downcast_ref::<BinaryArray>()
157 .unwrap()
158 .iter()
159 .collect();
160 encode_field(&value, type_, format)
161 }
162 DataType::LargeBinary => {
163 let value: Vec<Option<_>> = arr
164 .as_any()
165 .downcast_ref::<LargeBinaryArray>()
166 .unwrap()
167 .iter()
168 .collect();
169 encode_field(&value, type_, format)
170 }
171
172 DataType::Date32 => {
173 let value: Vec<Option<_>> = arr
174 .as_any()
175 .downcast_ref::<Date32Array>()
176 .unwrap()
177 .iter()
178 .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
179 .collect();
180 encode_field(&value, type_, format)
181 }
182 DataType::Date64 => {
183 let value: Vec<Option<_>> = arr
184 .as_any()
185 .downcast_ref::<Date64Array>()
186 .unwrap()
187 .iter()
188 .map(|val| val.and_then(as_date::<Date64Type>))
189 .collect();
190 encode_field(&value, type_, format)
191 }
192 DataType::Time32(unit) => match unit {
193 TimeUnit::Second => {
194 let value: Vec<Option<_>> = arr
195 .as_any()
196 .downcast_ref::<Time32SecondArray>()
197 .unwrap()
198 .iter()
199 .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
200 .collect();
201 encode_field(&value, type_, format)
202 }
203 TimeUnit::Millisecond => {
204 let value: Vec<Option<_>> = arr
205 .as_any()
206 .downcast_ref::<Time32MillisecondArray>()
207 .unwrap()
208 .iter()
209 .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
210 .collect();
211 encode_field(&value, type_, format)
212 }
213 _ => {
214 Err(PgWireError::ApiError("Unsupported Time32 unit".into()))
217 }
218 },
219 DataType::Time64(unit) => match unit {
220 TimeUnit::Microsecond => {
221 let value: Vec<Option<_>> = arr
222 .as_any()
223 .downcast_ref::<Time64MicrosecondArray>()
224 .unwrap()
225 .iter()
226 .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
227 .collect();
228 encode_field(&value, type_, format)
229 }
230 TimeUnit::Nanosecond => {
231 let value: Vec<Option<_>> = arr
232 .as_any()
233 .downcast_ref::<Time64NanosecondArray>()
234 .unwrap()
235 .iter()
236 .map(|val| val.and_then(as_time::<Time64NanosecondType>))
237 .collect();
238 encode_field(&value, type_, format)
239 }
240 _ => {
241 Err(PgWireError::ApiError("Unsupported Time64 unit".into()))
244 }
245 },
246 DataType::Timestamp(unit, timezone) => match unit {
247 TimeUnit::Second => {
248 let array_iter = arr
249 .as_any()
250 .downcast_ref::<TimestampSecondArray>()
251 .unwrap()
252 .iter();
253
254 if let Some(tz) = timezone {
255 let tz = Tz::from_str(tz.as_ref())
256 .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
257 let value: Vec<_> = array_iter
258 .map(|i| {
259 i.and_then(|i| {
260 DateTime::from_timestamp(i, 0).map(|dt| {
261 Utc.from_utc_datetime(&dt.naive_utc())
262 .with_timezone(&tz)
263 .fixed_offset()
264 })
265 })
266 })
267 .collect();
268 encode_field(&value, type_, format)
269 } else {
270 let value: Vec<_> = array_iter
271 .map(|i| {
272 i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
273 })
274 .collect();
275 encode_field(&value, type_, format)
276 }
277 }
278 TimeUnit::Millisecond => {
279 let array_iter = arr
280 .as_any()
281 .downcast_ref::<TimestampMillisecondArray>()
282 .unwrap()
283 .iter();
284
285 if let Some(tz) = timezone {
286 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
287 let value: Vec<_> = array_iter
288 .map(|i| {
289 i.and_then(|i| {
290 DateTime::from_timestamp_millis(i).map(|dt| {
291 Utc.from_utc_datetime(&dt.naive_utc())
292 .with_timezone(&tz)
293 .fixed_offset()
294 })
295 })
296 })
297 .collect();
298 encode_field(&value, type_, format)
299 } else {
300 let value: Vec<_> = array_iter
301 .map(|i| {
302 i.and_then(|i| {
303 DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
304 })
305 })
306 .collect();
307 encode_field(&value, type_, format)
308 }
309 }
310 TimeUnit::Microsecond => {
311 let array_iter = arr
312 .as_any()
313 .downcast_ref::<TimestampMicrosecondArray>()
314 .unwrap()
315 .iter();
316
317 if let Some(tz) = timezone {
318 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
319 let value: Vec<_> = array_iter
320 .map(|i| {
321 i.and_then(|i| {
322 DateTime::from_timestamp_micros(i).map(|dt| {
323 Utc.from_utc_datetime(&dt.naive_utc())
324 .with_timezone(&tz)
325 .fixed_offset()
326 })
327 })
328 })
329 .collect();
330 encode_field(&value, type_, format)
331 } else {
332 let value: Vec<_> = array_iter
333 .map(|i| {
334 i.and_then(|i| {
335 DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
336 })
337 })
338 .collect();
339 encode_field(&value, type_, format)
340 }
341 }
342 TimeUnit::Nanosecond => {
343 let array_iter = arr
344 .as_any()
345 .downcast_ref::<TimestampNanosecondArray>()
346 .unwrap()
347 .iter();
348
349 if let Some(tz) = timezone {
350 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
351 let value: Vec<_> = array_iter
352 .map(|i| {
353 i.map(|i| {
354 Utc.from_utc_datetime(
355 &DateTime::from_timestamp_nanos(i).naive_utc(),
356 )
357 .with_timezone(&tz)
358 .fixed_offset()
359 })
360 })
361 .collect();
362 encode_field(&value, type_, format)
363 } else {
364 let value: Vec<_> = array_iter
365 .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
366 .collect();
367 encode_field(&value, type_, format)
368 }
369 }
370 },
371 DataType::Struct(_) => {
372 let fields = match type_.kind() {
373 postgres_types::Kind::Array(struct_type_) => Ok(struct_type_),
374 _ => Err(format!(
375 "Expected list type found type {} of kind {:?}",
376 type_,
377 type_.kind()
378 )),
379 }
380 .and_then(|struct_type| match struct_type.kind() {
381 postgres_types::Kind::Composite(fields) => Ok(fields),
382 _ => Err(format!(
383 "Failed to unwrap a composite type inside from type {} kind {:?}",
384 type_,
385 type_.kind()
386 )),
387 })
388 .map_err(ToSqlError::from)?;
389
390 let values: PgWireResult<Vec<_>> = (0..arr.len())
391 .map(|row| encode_struct(&arr, row, fields, format))
392 .map(|x| {
393 if matches!(format, FieldFormat::Text) {
394 x.map(|opt| {
395 opt.map(|value| {
396 let mut w = BytesMut::new();
397 w.put_u8(b'"');
398 w.put_slice(
399 QUOTE_ESCAPE
400 .replace_all(
401 &String::from_utf8_lossy(&value.bytes),
402 r#"\$1"#,
403 )
404 .as_bytes(),
405 );
406 w.put_u8(b'"');
407 EncodedValue { bytes: w }
408 })
409 })
410 } else {
411 x
412 }
413 })
414 .collect();
415 encode_field(&values?, type_, format)
416 }
417 DataType::LargeUtf8 => {
418 let value: Vec<Option<&str>> = arr
419 .as_any()
420 .downcast_ref::<LargeStringArray>()
421 .unwrap()
422 .iter()
423 .collect();
424 encode_field(&value, type_, format)
425 }
426 DataType::Decimal256(_, s) => {
427 let decimal_array = arr.as_any().downcast_ref::<Decimal256Array>().unwrap();
430 let value: Vec<Option<String>> = (0..decimal_array.len())
431 .map(|i| {
432 if decimal_array.is_null(i) {
433 None
434 } else {
435 let raw_value = decimal_array.value(i);
437 let scale = *s as u32;
438 let value_str = raw_value.to_string();
440 if scale == 0 {
441 Some(value_str)
442 } else {
443 let mut chars: Vec<char> = value_str.chars().collect();
445 if chars.len() <= scale as usize {
446 let zeros_needed = scale as usize - chars.len() + 1;
448 chars.splice(0..0, std::iter::repeat_n('0', zeros_needed));
449 chars.insert(1, '.');
450 } else {
451 let decimal_pos = chars.len() - scale as usize;
452 chars.insert(decimal_pos, '.');
453 }
454 Some(chars.into_iter().collect())
455 }
456 }
457 })
458 .collect();
459 encode_field(&value, type_, format)
460 }
461 DataType::Duration(_) => {
462 let value: Vec<Option<i64>> = arr
464 .as_any()
465 .downcast_ref::<DurationMicrosecondArray>()
466 .unwrap()
467 .iter()
468 .collect();
469 encode_field(&value, type_, format)
470 }
471 DataType::List(_) => {
472 let list_array = arr.as_any().downcast_ref::<ListArray>().unwrap();
475 let value: Vec<Option<String>> = (0..list_array.len())
476 .map(|i| {
477 if list_array.is_null(i) {
478 None
479 } else {
480 Some(format!("[nested_list_{i}]"))
482 }
483 })
484 .collect();
485 encode_field(&value, type_, format)
486 }
487 DataType::LargeList(_) => {
488 let list_array = arr.as_any().downcast_ref::<LargeListArray>().unwrap();
490 let value: Vec<Option<String>> = (0..list_array.len())
491 .map(|i| {
492 if list_array.is_null(i) {
493 None
494 } else {
495 Some(format!("[large_list_{i}]"))
496 }
497 })
498 .collect();
499 encode_field(&value, type_, format)
500 }
501 DataType::Map(_, _) => {
502 let map_array = arr.as_any().downcast_ref::<MapArray>().unwrap();
504 let value: Vec<Option<String>> = (0..map_array.len())
505 .map(|i| {
506 if map_array.is_null(i) {
507 None
508 } else {
509 Some(format!("{{map_{i}}}"))
510 }
511 })
512 .collect();
513 encode_field(&value, type_, format)
514 }
515
516 DataType::Union(_, _) => {
517 let value: Vec<Option<String>> = (0..arr.len())
519 .map(|i| {
520 if arr.is_null(i) {
521 None
522 } else {
523 Some(format!("union_{i}"))
524 }
525 })
526 .collect();
527 encode_field(&value, type_, format)
528 }
529 DataType::Dictionary(_, _) => {
530 let value: Vec<Option<String>> = (0..arr.len())
532 .map(|i| {
533 if arr.is_null(i) {
534 None
535 } else {
536 Some(format!("dict_{i}"))
537 }
538 })
539 .collect();
540 encode_field(&value, type_, format)
541 }
542 list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
544 "Unsupported List Datatype {} and array {:?}",
545 list_type, &arr
546 )))),
547 }
548}