1use std::{str::FromStr, sync::Arc};
2
3use arrow::array::{
4 timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array,
5 LargeBinaryArray, PrimitiveArray, StringArray, Time32MillisecondArray, Time32SecondArray,
6 Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
7 TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
8};
9use arrow::{
10 datatypes::{
11 DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type,
12 Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
13 Time64NanosecondType, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
14 },
15 temporal_conversions::{as_date, as_time},
16};
17use bytes::{BufMut, BytesMut};
18use chrono::{DateTime, TimeZone, Utc};
19use pgwire::api::results::FieldFormat;
20use pgwire::error::{PgWireError, PgWireResult};
21use pgwire::types::{ToSqlText, QUOTE_ESCAPE};
22use postgres_types::{ToSql, Type};
23use rust_decimal::Decimal;
24
25use crate::encoder::EncodedValue;
26use crate::error::ToSqlError;
27use crate::struct_encoder::encode_struct;
28
29fn get_bool_list_value(arr: &Arc<dyn Array>) -> Vec<Option<bool>> {
30 arr.as_any()
31 .downcast_ref::<BooleanArray>()
32 .unwrap()
33 .iter()
34 .collect()
35}
36
// Generates a helper `fn $fn_name(arr: &Arc<dyn Array>) -> Vec<Option<$native>>`
// that downcasts `arr` to `PrimitiveArray<$arrow_ty>` and collects its
// elements, keeping null slots as `None`.
//
// The generated function panics if the downcast fails.
macro_rules! get_primitive_list_value {
    // Three-argument form: elements are collected as-is.
    ($fn_name:ident, $arrow_ty:ty, $native:ty) => {
        fn $fn_name(arr: &Arc<dyn Array>) -> Vec<Option<$native>> {
            let typed = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$arrow_ty>>()
                .unwrap();
            typed.iter().collect()
        }
    };

    // Four-argument form: every non-null element is passed through `$convert`
    // before being collected.
    ($fn_name:ident, $arrow_ty:ty, $native:ty, $convert:expr) => {
        fn $fn_name(arr: &Arc<dyn Array>) -> Vec<Option<$native>> {
            let typed = arr
                .as_any()
                .downcast_ref::<PrimitiveArray<$arrow_ty>>()
                .unwrap();
            typed.iter().map(|slot| slot.map($convert)).collect()
        }
    };
}
59
60get_primitive_list_value!(get_i8_list_value, Int8Type, i8);
61get_primitive_list_value!(get_i16_list_value, Int16Type, i16);
62get_primitive_list_value!(get_i32_list_value, Int32Type, i32);
63get_primitive_list_value!(get_i64_list_value, Int64Type, i64);
64get_primitive_list_value!(get_u8_list_value, UInt8Type, i8, |val: u8| { val as i8 });
65get_primitive_list_value!(get_u16_list_value, UInt16Type, i16, |val: u16| {
66 val as i16
67});
68get_primitive_list_value!(get_u32_list_value, UInt32Type, u32);
69get_primitive_list_value!(get_u64_list_value, UInt64Type, i64, |val: u64| {
70 val as i64
71});
72get_primitive_list_value!(get_f32_list_value, Float32Type, f32);
73get_primitive_list_value!(get_f64_list_value, Float64Type, f64);
74
75fn encode_field<T: ToSql + ToSqlText>(
76 t: &[T],
77 type_: &Type,
78 format: FieldFormat,
79) -> PgWireResult<EncodedValue> {
80 let mut bytes = BytesMut::new();
81 match format {
82 FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?,
83 FieldFormat::Binary => t.to_sql(type_, &mut bytes)?,
84 };
85 Ok(EncodedValue { bytes })
86}
87
88pub(crate) fn encode_list(
89 arr: Arc<dyn Array>,
90 type_: &Type,
91 format: FieldFormat,
92) -> PgWireResult<EncodedValue> {
93 match arr.data_type() {
94 DataType::Null => {
95 let mut bytes = BytesMut::new();
96 match format {
97 FieldFormat::Text => None::<i8>.to_sql_text(type_, &mut bytes),
98 FieldFormat::Binary => None::<i8>.to_sql(type_, &mut bytes),
99 }?;
100 Ok(EncodedValue { bytes })
101 }
102 DataType::Boolean => encode_field(&get_bool_list_value(&arr), type_, format),
103 DataType::Int8 => encode_field(&get_i8_list_value(&arr), type_, format),
104 DataType::Int16 => encode_field(&get_i16_list_value(&arr), type_, format),
105 DataType::Int32 => encode_field(&get_i32_list_value(&arr), type_, format),
106 DataType::Int64 => encode_field(&get_i64_list_value(&arr), type_, format),
107 DataType::UInt8 => encode_field(&get_u8_list_value(&arr), type_, format),
108 DataType::UInt16 => encode_field(&get_u16_list_value(&arr), type_, format),
109 DataType::UInt32 => encode_field(&get_u32_list_value(&arr), type_, format),
110 DataType::UInt64 => encode_field(&get_u64_list_value(&arr), type_, format),
111 DataType::Float32 => encode_field(&get_f32_list_value(&arr), type_, format),
112 DataType::Float64 => encode_field(&get_f64_list_value(&arr), type_, format),
113 DataType::Decimal128(_, s) => {
114 let value: Vec<_> = arr
115 .as_any()
116 .downcast_ref::<Decimal128Array>()
117 .unwrap()
118 .iter()
119 .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
120 .collect();
121 encode_field(&value, type_, format)
122 }
123 DataType::Utf8 => {
124 let value: Vec<Option<&str>> = arr
125 .as_any()
126 .downcast_ref::<StringArray>()
127 .unwrap()
128 .iter()
129 .collect();
130 encode_field(&value, type_, format)
131 }
132 DataType::Binary => {
133 let value: Vec<Option<_>> = arr
134 .as_any()
135 .downcast_ref::<BinaryArray>()
136 .unwrap()
137 .iter()
138 .collect();
139 encode_field(&value, type_, format)
140 }
141 DataType::LargeBinary => {
142 let value: Vec<Option<_>> = arr
143 .as_any()
144 .downcast_ref::<LargeBinaryArray>()
145 .unwrap()
146 .iter()
147 .collect();
148 encode_field(&value, type_, format)
149 }
150
151 DataType::Date32 => {
152 let value: Vec<Option<_>> = arr
153 .as_any()
154 .downcast_ref::<Date32Array>()
155 .unwrap()
156 .iter()
157 .map(|val| val.and_then(|x| as_date::<Date32Type>(x as i64)))
158 .collect();
159 encode_field(&value, type_, format)
160 }
161 DataType::Date64 => {
162 let value: Vec<Option<_>> = arr
163 .as_any()
164 .downcast_ref::<Date64Array>()
165 .unwrap()
166 .iter()
167 .map(|val| val.and_then(as_date::<Date64Type>))
168 .collect();
169 encode_field(&value, type_, format)
170 }
171 DataType::Time32(unit) => match unit {
172 TimeUnit::Second => {
173 let value: Vec<Option<_>> = arr
174 .as_any()
175 .downcast_ref::<Time32SecondArray>()
176 .unwrap()
177 .iter()
178 .map(|val| val.and_then(|x| as_time::<Time32SecondType>(x as i64)))
179 .collect();
180 encode_field(&value, type_, format)
181 }
182 TimeUnit::Millisecond => {
183 let value: Vec<Option<_>> = arr
184 .as_any()
185 .downcast_ref::<Time32MillisecondArray>()
186 .unwrap()
187 .iter()
188 .map(|val| val.and_then(|x| as_time::<Time32MillisecondType>(x as i64)))
189 .collect();
190 encode_field(&value, type_, format)
191 }
192 _ => {
193 unimplemented!()
194 }
195 },
196 DataType::Time64(unit) => match unit {
197 TimeUnit::Microsecond => {
198 let value: Vec<Option<_>> = arr
199 .as_any()
200 .downcast_ref::<Time64MicrosecondArray>()
201 .unwrap()
202 .iter()
203 .map(|val| val.and_then(as_time::<Time64MicrosecondType>))
204 .collect();
205 encode_field(&value, type_, format)
206 }
207 TimeUnit::Nanosecond => {
208 let value: Vec<Option<_>> = arr
209 .as_any()
210 .downcast_ref::<Time64NanosecondArray>()
211 .unwrap()
212 .iter()
213 .map(|val| val.and_then(as_time::<Time64NanosecondType>))
214 .collect();
215 encode_field(&value, type_, format)
216 }
217 _ => {
218 unimplemented!()
219 }
220 },
221 DataType::Timestamp(unit, timezone) => match unit {
222 TimeUnit::Second => {
223 let array_iter = arr
224 .as_any()
225 .downcast_ref::<TimestampSecondArray>()
226 .unwrap()
227 .iter();
228
229 if let Some(tz) = timezone {
230 let tz = Tz::from_str(tz.as_ref())
231 .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?;
232 let value: Vec<_> = array_iter
233 .map(|i| {
234 i.and_then(|i| {
235 DateTime::from_timestamp(i, 0).map(|dt| {
236 Utc.from_utc_datetime(&dt.naive_utc())
237 .with_timezone(&tz)
238 .fixed_offset()
239 })
240 })
241 })
242 .collect();
243 encode_field(&value, type_, format)
244 } else {
245 let value: Vec<_> = array_iter
246 .map(|i| {
247 i.and_then(|i| DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc()))
248 })
249 .collect();
250 encode_field(&value, type_, format)
251 }
252 }
253 TimeUnit::Millisecond => {
254 let array_iter = arr
255 .as_any()
256 .downcast_ref::<TimestampMillisecondArray>()
257 .unwrap()
258 .iter();
259
260 if let Some(tz) = timezone {
261 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
262 let value: Vec<_> = array_iter
263 .map(|i| {
264 i.and_then(|i| {
265 DateTime::from_timestamp_millis(i).map(|dt| {
266 Utc.from_utc_datetime(&dt.naive_utc())
267 .with_timezone(&tz)
268 .fixed_offset()
269 })
270 })
271 })
272 .collect();
273 encode_field(&value, type_, format)
274 } else {
275 let value: Vec<_> = array_iter
276 .map(|i| {
277 i.and_then(|i| {
278 DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
279 })
280 })
281 .collect();
282 encode_field(&value, type_, format)
283 }
284 }
285 TimeUnit::Microsecond => {
286 let array_iter = arr
287 .as_any()
288 .downcast_ref::<TimestampMicrosecondArray>()
289 .unwrap()
290 .iter();
291
292 if let Some(tz) = timezone {
293 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
294 let value: Vec<_> = array_iter
295 .map(|i| {
296 i.and_then(|i| {
297 DateTime::from_timestamp_micros(i).map(|dt| {
298 Utc.from_utc_datetime(&dt.naive_utc())
299 .with_timezone(&tz)
300 .fixed_offset()
301 })
302 })
303 })
304 .collect();
305 encode_field(&value, type_, format)
306 } else {
307 let value: Vec<_> = array_iter
308 .map(|i| {
309 i.and_then(|i| {
310 DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
311 })
312 })
313 .collect();
314 encode_field(&value, type_, format)
315 }
316 }
317 TimeUnit::Nanosecond => {
318 let array_iter = arr
319 .as_any()
320 .downcast_ref::<TimestampNanosecondArray>()
321 .unwrap()
322 .iter();
323
324 if let Some(tz) = timezone {
325 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
326 let value: Vec<_> = array_iter
327 .map(|i| {
328 i.map(|i| {
329 Utc.from_utc_datetime(
330 &DateTime::from_timestamp_nanos(i).naive_utc(),
331 )
332 .with_timezone(&tz)
333 .fixed_offset()
334 })
335 })
336 .collect();
337 encode_field(&value, type_, format)
338 } else {
339 let value: Vec<_> = array_iter
340 .map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
341 .collect();
342 encode_field(&value, type_, format)
343 }
344 }
345 },
346 DataType::Struct(_) => {
347 let fields = match type_.kind() {
348 postgres_types::Kind::Array(struct_type_) => Ok(struct_type_),
349 _ => Err(format!(
350 "Expected list type found type {} of kind {:?}",
351 type_,
352 type_.kind()
353 )),
354 }
355 .and_then(|struct_type| match struct_type.kind() {
356 postgres_types::Kind::Composite(fields) => Ok(fields),
357 _ => Err(format!(
358 "Failed to unwrap a composite type inside from type {} kind {:?}",
359 type_,
360 type_.kind()
361 )),
362 })
363 .map_err(ToSqlError::from)?;
364
365 let values: PgWireResult<Vec<_>> = (0..arr.len())
366 .map(|row| encode_struct(&arr, row, fields, format))
367 .map(|x| {
368 if matches!(format, FieldFormat::Text) {
369 x.map(|opt| {
370 opt.map(|value| {
371 let mut w = BytesMut::new();
372 w.put_u8(b'"');
373 w.put_slice(
374 QUOTE_ESCAPE
375 .replace_all(
376 &String::from_utf8_lossy(&value.bytes),
377 r#"\$1"#,
378 )
379 .as_bytes(),
380 );
381 w.put_u8(b'"');
382 EncodedValue { bytes: w }
383 })
384 })
385 } else {
386 x
387 }
388 })
389 .collect();
390 encode_field(&values?, type_, format)
391 }
392 list_type => Err(PgWireError::ApiError(ToSqlError::from(format!(
394 "Unsupported List Datatype {} and array {:?}",
395 list_type, &arr
396 )))),
397 }
398}