use std::error::Error;
use std::io::Write;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::*;
use arrow::datatypes::*;
use bytes::BufMut;
use bytes::BytesMut;
use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
use pgwire::api::results::DataRowEncoder;
use pgwire::api::results::FieldFormat;
use pgwire::error::PgWireError;
use pgwire::error::PgWireResult;
use pgwire::types::ToSqlText;
use postgres_types::{ToSql, Type};
use rust_decimal::Decimal;
use timezone::Tz;

use crate::error::ToSqlError;
use crate::list_encoder::encode_list;
use crate::struct_encoder::encode_struct;
23
24pub trait Encoder {
25 fn encode_field_with_type_and_format<T>(
26 &mut self,
27 value: &T,
28 data_type: &Type,
29 format: FieldFormat,
30 ) -> PgWireResult<()>
31 where
32 T: ToSql + ToSqlText + Sized;
33}
34
35impl Encoder for DataRowEncoder {
36 fn encode_field_with_type_and_format<T>(
37 &mut self,
38 value: &T,
39 data_type: &Type,
40 format: FieldFormat,
41 ) -> PgWireResult<()>
42 where
43 T: ToSql + ToSqlText + Sized,
44 {
45 self.encode_field_with_type_and_format(value, data_type, format)
46 }
47}
48
49pub(crate) struct EncodedValue {
50 pub(crate) bytes: BytesMut,
51}
52
53impl std::fmt::Debug for EncodedValue {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("EncodedValue").finish()
56 }
57}
58
59impl ToSql for EncodedValue {
60 fn to_sql(
61 &self,
62 _ty: &Type,
63 out: &mut BytesMut,
64 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
65 where
66 Self: Sized,
67 {
68 out.writer().write_all(&self.bytes)?;
69 Ok(postgres_types::IsNull::No)
70 }
71
72 fn accepts(_ty: &Type) -> bool
73 where
74 Self: Sized,
75 {
76 true
77 }
78
79 fn to_sql_checked(
80 &self,
81 ty: &Type,
82 out: &mut BytesMut,
83 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
84 self.to_sql(ty, out)
85 }
86}
87
88impl ToSqlText for EncodedValue {
89 fn to_sql_text(
90 &self,
91 _ty: &Type,
92 out: &mut BytesMut,
93 ) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
94 where
95 Self: Sized,
96 {
97 out.writer().write_all(&self.bytes)?;
98 Ok(postgres_types::IsNull::No)
99 }
100}
101
102fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> Option<bool> {
103 (!arr.is_null(idx)).then(|| {
104 arr.as_any()
105 .downcast_ref::<BooleanArray>()
106 .unwrap()
107 .value(idx)
108 })
109}
110
111macro_rules! get_primitive_value {
112 ($name:ident, $t:ty, $pt:ty) => {
113 fn $name(arr: &Arc<dyn Array>, idx: usize) -> Option<$pt> {
114 (!arr.is_null(idx)).then(|| {
115 arr.as_any()
116 .downcast_ref::<PrimitiveArray<$t>>()
117 .unwrap()
118 .value(idx)
119 })
120 }
121 };
122}
123
124get_primitive_value!(get_i8_value, Int8Type, i8);
125get_primitive_value!(get_i16_value, Int16Type, i16);
126get_primitive_value!(get_i32_value, Int32Type, i32);
127get_primitive_value!(get_i64_value, Int64Type, i64);
128get_primitive_value!(get_u8_value, UInt8Type, u8);
129get_primitive_value!(get_u16_value, UInt16Type, u16);
130get_primitive_value!(get_u32_value, UInt32Type, u32);
131get_primitive_value!(get_u64_value, UInt64Type, u64);
132get_primitive_value!(get_f32_value, Float32Type, f32);
133get_primitive_value!(get_f64_value, Float64Type, f64);
134
135fn get_utf8_view_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
136 (!arr.is_null(idx)).then(|| {
137 arr.as_any()
138 .downcast_ref::<StringViewArray>()
139 .unwrap()
140 .value(idx)
141 })
142}
143
144fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
145 (!arr.is_null(idx)).then(|| {
146 arr.as_any()
147 .downcast_ref::<StringArray>()
148 .unwrap()
149 .value(idx)
150 })
151}
152
153fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&str> {
154 (!arr.is_null(idx)).then(|| {
155 arr.as_any()
156 .downcast_ref::<LargeStringArray>()
157 .unwrap()
158 .value(idx)
159 })
160}
161
162fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
163 (!arr.is_null(idx)).then(|| {
164 arr.as_any()
165 .downcast_ref::<BinaryArray>()
166 .unwrap()
167 .value(idx)
168 })
169}
170
171fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> Option<&[u8]> {
172 (!arr.is_null(idx)).then(|| {
173 arr.as_any()
174 .downcast_ref::<LargeBinaryArray>()
175 .unwrap()
176 .value(idx)
177 })
178}
179
180fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
181 if arr.is_null(idx) {
182 return None;
183 }
184 arr.as_any()
185 .downcast_ref::<Date32Array>()
186 .unwrap()
187 .value_as_date(idx)
188}
189
190fn get_date64_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
191 if arr.is_null(idx) {
192 return None;
193 }
194 arr.as_any()
195 .downcast_ref::<Date64Array>()
196 .unwrap()
197 .value_as_date(idx)
198}
199
200fn get_time32_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
201 if arr.is_null(idx) {
202 return None;
203 }
204 arr.as_any()
205 .downcast_ref::<Time32SecondArray>()
206 .unwrap()
207 .value_as_datetime(idx)
208}
209
210fn get_time32_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
211 if arr.is_null(idx) {
212 return None;
213 }
214 arr.as_any()
215 .downcast_ref::<Time32MillisecondArray>()
216 .unwrap()
217 .value_as_datetime(idx)
218}
219
220fn get_time64_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
221 if arr.is_null(idx) {
222 return None;
223 }
224 arr.as_any()
225 .downcast_ref::<Time64MicrosecondArray>()
226 .unwrap()
227 .value_as_datetime(idx)
228}
229fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
230 if arr.is_null(idx) {
231 return None;
232 }
233 arr.as_any()
234 .downcast_ref::<Time64NanosecondArray>()
235 .unwrap()
236 .value_as_datetime(idx)
237}
238
239fn get_numeric_128_value(
240 arr: &Arc<dyn Array>,
241 idx: usize,
242 scale: u32,
243) -> PgWireResult<Option<Decimal>> {
244 if arr.is_null(idx) {
245 return Ok(None);
246 }
247
248 let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
249 let value = array.value(idx);
250 Decimal::try_from_i128_with_scale(value, scale)
251 .map_err(|e| {
252 let message = match e {
253 rust_decimal::Error::ExceedsMaximumPossibleValue => {
254 "Exceeds maximum possible value"
255 }
256 rust_decimal::Error::LessThanMinimumPossibleValue => {
257 "Less than minimum possible value"
258 }
259 rust_decimal::Error::ScaleExceedsMaximumPrecision(_) => {
260 "Scale exceeds maximum precision"
261 }
262 _ => unreachable!(),
263 };
264 PgWireError::ApiError(ToSqlError::from(message))
266 })
267 .map(Some)
268}
269
270pub fn encode_value<T: Encoder>(
271 encoder: &mut T,
272 arr: &Arc<dyn Array>,
273 idx: usize,
274 type_: &Type,
275 format: FieldFormat,
276) -> PgWireResult<()> {
277 match arr.data_type() {
278 DataType::Null => encoder.encode_field_with_type_and_format(&None::<i8>, type_, format)?,
279 DataType::Boolean => {
280 encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)?
281 }
282 DataType::Int8 => {
283 encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)?
284 }
285 DataType::Int16 => {
286 encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)?
287 }
288 DataType::Int32 => {
289 encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)?
290 }
291 DataType::Int64 => {
292 encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)?
293 }
294 DataType::UInt8 => encoder.encode_field_with_type_and_format(
295 &(get_u8_value(arr, idx).map(|x| x as i8)),
296 type_,
297 format,
298 )?,
299 DataType::UInt16 => encoder.encode_field_with_type_and_format(
300 &(get_u16_value(arr, idx).map(|x| x as i16)),
301 type_,
302 format,
303 )?,
304 DataType::UInt32 => {
305 encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)?
306 }
307 DataType::UInt64 => encoder.encode_field_with_type_and_format(
308 &(get_u64_value(arr, idx).map(|x| x as i64)),
309 type_,
310 format,
311 )?,
312 DataType::Float32 => {
313 encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)?
314 }
315 DataType::Float64 => {
316 encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)?
317 }
318 DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format(
319 &get_numeric_128_value(arr, idx, *s as u32)?,
320 type_,
321 format,
322 )?,
323 DataType::Utf8 => {
324 encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)?
325 }
326 DataType::Utf8View => encoder.encode_field_with_type_and_format(
327 &get_utf8_view_value(arr, idx),
328 type_,
329 format,
330 )?,
331 DataType::LargeUtf8 => encoder.encode_field_with_type_and_format(
332 &get_large_utf8_value(arr, idx),
333 type_,
334 format,
335 )?,
336 DataType::Binary => {
337 encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)?
338 }
339 DataType::LargeBinary => encoder.encode_field_with_type_and_format(
340 &get_large_binary_value(arr, idx),
341 type_,
342 format,
343 )?,
344 DataType::Date32 => {
345 encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)?
346 }
347 DataType::Date64 => {
348 encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)?
349 }
350 DataType::Time32(unit) => match unit {
351 TimeUnit::Second => encoder.encode_field_with_type_and_format(
352 &get_time32_second_value(arr, idx),
353 type_,
354 format,
355 )?,
356 TimeUnit::Millisecond => encoder.encode_field_with_type_and_format(
357 &get_time32_millisecond_value(arr, idx),
358 type_,
359 format,
360 )?,
361 _ => {}
362 },
363 DataType::Time64(unit) => match unit {
364 TimeUnit::Microsecond => encoder.encode_field_with_type_and_format(
365 &get_time64_microsecond_value(arr, idx),
366 type_,
367 format,
368 )?,
369 TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format(
370 &get_time64_nanosecond_value(arr, idx),
371 type_,
372 format,
373 )?,
374 _ => {}
375 },
376 DataType::Timestamp(unit, timezone) => match unit {
377 TimeUnit::Second => {
378 if arr.is_null(idx) {
379 return encoder.encode_field_with_type_and_format(
380 &None::<NaiveDateTime>,
381 type_,
382 format,
383 );
384 }
385 let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
386 if let Some(tz) = timezone {
387 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
388 let value = ts_array
389 .value_as_datetime_with_tz(idx, tz)
390 .map(|d| d.fixed_offset());
391 encoder.encode_field_with_type_and_format(&value, type_, format)?;
392 } else {
393 let value = ts_array.value_as_datetime(idx);
394 encoder.encode_field_with_type_and_format(&value, type_, format)?;
395 }
396 }
397 TimeUnit::Millisecond => {
398 if arr.is_null(idx) {
399 return encoder.encode_field_with_type_and_format(
400 &None::<NaiveDateTime>,
401 type_,
402 format,
403 );
404 }
405 let ts_array = arr
406 .as_any()
407 .downcast_ref::<TimestampMillisecondArray>()
408 .unwrap();
409 if let Some(tz) = timezone {
410 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
411 let value = ts_array
412 .value_as_datetime_with_tz(idx, tz)
413 .map(|d| d.fixed_offset());
414 encoder.encode_field_with_type_and_format(&value, type_, format)?;
415 } else {
416 let value = ts_array.value_as_datetime(idx);
417 encoder.encode_field_with_type_and_format(&value, type_, format)?;
418 }
419 }
420 TimeUnit::Microsecond => {
421 if arr.is_null(idx) {
422 return encoder.encode_field_with_type_and_format(
423 &None::<NaiveDateTime>,
424 type_,
425 format,
426 );
427 }
428 let ts_array = arr
429 .as_any()
430 .downcast_ref::<TimestampMicrosecondArray>()
431 .unwrap();
432 if let Some(tz) = timezone {
433 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
434 let value = ts_array
435 .value_as_datetime_with_tz(idx, tz)
436 .map(|d| d.fixed_offset());
437 encoder.encode_field_with_type_and_format(&value, type_, format)?;
438 } else {
439 let value = ts_array.value_as_datetime(idx);
440 encoder.encode_field_with_type_and_format(&value, type_, format)?;
441 }
442 }
443 TimeUnit::Nanosecond => {
444 if arr.is_null(idx) {
445 return encoder.encode_field_with_type_and_format(
446 &None::<NaiveDateTime>,
447 type_,
448 format,
449 );
450 }
451 let ts_array = arr
452 .as_any()
453 .downcast_ref::<TimestampNanosecondArray>()
454 .unwrap();
455 if let Some(tz) = timezone {
456 let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
457 let value = ts_array
458 .value_as_datetime_with_tz(idx, tz)
459 .map(|d| d.fixed_offset());
460 encoder.encode_field_with_type_and_format(&value, type_, format)?;
461 } else {
462 let value = ts_array.value_as_datetime(idx);
463 encoder.encode_field_with_type_and_format(&value, type_, format)?;
464 }
465 }
466 },
467 DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
468 if arr.is_null(idx) {
469 return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format);
470 }
471 let array = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
472 let value = encode_list(array, type_, format)?;
473 encoder.encode_field_with_type_and_format(&value, type_, format)?
474 }
475 DataType::Struct(_) => {
476 let fields = match type_.kind() {
477 postgres_types::Kind::Composite(fields) => fields,
478 _ => {
479 return Err(PgWireError::ApiError(ToSqlError::from(format!(
480 "Failed to unwrap a composite type from type {}",
481 type_
482 ))));
483 }
484 };
485 let value = encode_struct(arr, idx, fields, format)?;
486 encoder.encode_field_with_type_and_format(&value, type_, format)?
487 }
488 DataType::Dictionary(_, value_type) => {
489 if arr.is_null(idx) {
490 return encoder.encode_field_with_type_and_format(&None::<i8>, type_, format);
491 }
492 macro_rules! get_dict_values {
495 ($key_type:ty) => {
496 arr.as_any()
497 .downcast_ref::<DictionaryArray<$key_type>>()
498 .map(|dict| dict.values())
499 };
500 }
501
502 let values = get_dict_values!(Int8Type)
504 .or_else(|| get_dict_values!(Int16Type))
505 .or_else(|| get_dict_values!(Int32Type))
506 .or_else(|| get_dict_values!(Int64Type))
507 .or_else(|| get_dict_values!(UInt8Type))
508 .or_else(|| get_dict_values!(UInt16Type))
509 .or_else(|| get_dict_values!(UInt32Type))
510 .or_else(|| get_dict_values!(UInt64Type))
511 .ok_or_else(|| {
512 ToSqlError::from(format!(
513 "Unsupported dictionary key type for value type {}",
514 value_type
515 ))
516 })?;
517
518 if values.len() == 1 {
520 encode_value(encoder, values, 0, type_, format)?
521 } else {
522 encode_value(encoder, values, idx, type_, format)?
524 }
525 }
526 _ => {
527 return Err(PgWireError::ApiError(ToSqlError::from(format!(
528 "Unsupported Datatype {} and array {:?}",
529 arr.data_type(),
530 &arr
531 ))));
532 }
533 }
534
535 Ok(())
536}